# BYOL in PyTorch

This notebook was written by Ying Liu (yili@di.ku.dk), a student in the Department of Computer Science at Københavns Universitet. The code is based on the JAX code provided with the original paper.

## Algorithm

<div style="display:flex; justify-content: center; align-items: center; background: white">
    <img src="https://imgur.com/rBZRkOf.png" style="width:45%; margin-right:1rem" />
    <img src="https://imgur.com/uudkfAk.png" style="width:45%;" />
</div>

BYOL has three important components: the **encoder** $f$ (ResNet-50), the **projector** $g$, and the **predictor** $q$.

- Online Network: encoder + projector + predictor, with parameters $\theta$;

- Target Network: encoder + projector, with parameters $\xi$.

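During training, the online network is updated by minimizing the **loss function**, while the target network is updated slowly as an **exponential moving average** (EMA) of the online network's parameters: $\xi \leftarrow \tau \xi + (1 - \tau)\theta$, for a decay rate $\tau$ close to 1.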
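
The EMA update can be sketched as follows. This is a minimal illustration rather than the notebook's training code: the helper `update_target`, the argument name `tau`, and the value `0.99` are hypothetical, and the two `nn.Linear` layers stand in for the real online and target networks.

```python
import copy

import torch
from torch import nn


@torch.no_grad()
def update_target(online: nn.Module, target: nn.Module, tau: float) -> None:
    """Hypothetical helper: target <- tau * target + (1 - tau) * online."""
    for p_online, p_target in zip(online.parameters(), target.parameters()):
        p_target.mul_(tau).add_(p_online, alpha=1.0 - tau)


# Stand-in networks; the target starts as a copy of the online network.
online = nn.Linear(4, 2)
target = copy.deepcopy(online)
for p in target.parameters():
    p.requires_grad_(False)  # the target is never updated by gradients

# Pretend an optimizer step changed the online weights.
with torch.no_grad():
    online.weight.add_(1.0)

before = target.weight.clone()
update_target(online, target, tau=0.99)  # illustrative decay rate
```

With `tau` close to 1 the target moves only a small fraction of the way toward the online weights at each step. The sketch updates parameters only; buffers such as BatchNorm running statistics would need separate handling.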

In [10]:
import torch
import random
from torch import nn
import torch.nn.functional as F
import torchvision.transforms as T
from torchvision.models import resnet50
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor

### Hyperparameters Dict

In [None]:
HPS = dict(
    max_steps=int(1000. * 1281167 / 4096), # 1000 epochs
    batch_size=4096,
    mlp_hidden_size=4096,
    projection_size=256,
    base_target_ema=4e-3,
    optimizer_config=dict(
        optimizer_name='lars',
        beta=0.9,
        trust_coef=1e-3,
        weight_decay=1.5e-6,
        # As in SimCLR and the official implementation of LARS, we exclude bias
        # and batchnorm weights from the LARS adaptation and weight decay.
        exclude_bias_from_adaption=True),
    learning_rate_schedule=dict(
        # The learning rate increases linearly up to
        # its base value * batch_size / 256 after warmup_steps
        # global steps and is then annealed with a cosine schedule.
        base_learning_rate=0.2,
        warmup_steps=int(10. * 1281167 / 4096),
        anneal_schedule='cosine'),
    batchnorm_kwargs=dict(
        decay_rate=0.9,
        eps=1e-5),
    seed=1337,
    )
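
The learning-rate schedule described in the comments above can be sketched in plain Python. The function name `learning_rate` and its signature are hypothetical, and annealing all the way to zero at `max_steps` is an assumption; the comments only say that a cosine schedule follows the warmup.

```python
import math


def learning_rate(step, batch_size, base_learning_rate, warmup_steps, max_steps):
    """Linear warmup to the scaled base rate, then cosine annealing (assumed to reach zero)."""
    scaled_lr = base_learning_rate * batch_size / 256.0
    if step < warmup_steps:
        return scaled_lr * step / warmup_steps
    # Fraction of the post-warmup phase completed, clipped to [0, 1].
    progress = (step - warmup_steps) / max(1, max_steps - warmup_steps)
    progress = min(progress, 1.0)
    return scaled_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


# Same values as in the hyperparameter dict.
max_steps = int(1000.0 * 1281167 / 4096)
warmup_steps = int(10.0 * 1281167 / 4096)
for step in (0, warmup_steps // 2, warmup_steps, max_steps):
    lr = learning_rate(step, 4096, 0.2, warmup_steps, max_steps)
    print(f"step {step:>6d}: lr = {lr:.4f}")
```

With a batch size of 4096 and a base rate of 0.2, the peak learning rate after warmup is 0.2 * 4096 / 256 = 3.2.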

### Loss Function

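The loss function is defined first. It L2-normalizes the online prediction `x` and the target projection `y` and returns $2 - 2\cos(x, y)$ for each sample, where $\cos$ is the cosine similarity.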

In [None]:
def loss_fn(x, y):
    x = F.normalize(x, dim=-1, p=2)
    y = F.normalize(y, dim=-1, p=2)
    return 2 - 2 * (x * y).sum(dim=-1)
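
The value returned by `loss_fn` equals the squared Euclidean distance between the L2-normalized inputs. Writing $\bar{x} = x / \lVert x \rVert_2$ and $\bar{y} = y / \lVert y \rVert_2$ (notation introduced here for illustration), both vectors have unit norm, so:

$$
\lVert \bar{x} - \bar{y} \rVert_2^2
= \lVert \bar{x} \rVert_2^2 + \lVert \bar{y} \rVert_2^2 - 2\langle \bar{x}, \bar{y} \rangle
= 2 - 2\,\frac{\langle x, y \rangle}{\lVert x \rVert_2 \, \lVert y \rVert_2}
$$

The loss therefore ranges from 0 (vectors pointing in the same direction) to 4 (vectors pointing in opposite directions).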

### Random Augmentation

The augmentations are the same as in SimCLR and are used to generate two views of the same image. They consist of the following operations:

A **random** patch of the image is selected and **resized** to 224x224 with a **random horizontal flip**, followed by a **color distortion**, consisting of a **random sequence of brightness, contrast, saturation, hue adjustments**, and an optional grayscale conversion. Finally, **Gaussian blur and solarization** are applied to the patches.

In [None]:
class RandomApply(nn.Module):
    def __init__(self, fn, p):
        super().__init__()
        self.fn = fn
        self.p = p
        
    def forward(self, x):
        if random.random() > self.p:
            return x
        return self.fn(x)

In [None]:
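# Default augmentation pipeline; expects float image tensors with values in [0, 1].
image_size = 224  # crop size stated in the description above
DEFAULT_AUG = torch.nn.Sequential(
    RandomApply(
        T.ColorJitter(0.8, 0.8, 0.8, 0.2),
        p=0.3
    ),
    T.RandomGrayscale(p=0.2),
    T.RandomHorizontalFlip(),
    RandomApply(
        T.GaussianBlur((3, 3), (1.0, 2.0)),
        p=0.2
    ),
    T.RandomResizedCrop((image_size, image_size)),
    T.Normalize(
        mean=torch.tensor([0.485, 0.456, 0.406]),
        std=torch.tensor([0.229, 0.224, 0.225])),
    # Solarization, mentioned in the description, is not applied here.
)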

### Model Implementation

The cells below define the **encoder**, the **projector**, and the **predictor**; the **loss function** was defined above.

### Encoder

The encoder is a ResNetV1_50x1 model with the last fully connected layer removed.

In [12]:
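class BYOLEncoder(nn.Module):
    def __init__(self):
        super(BYOLEncoder, self).__init__()
        # Randomly initialized ResNet-50; `weights=None` replaces the
        # deprecated `pretrained=False` (torchvision >= 0.13).
        resnet = resnet50(weights=None)
        # Keep everything up to the global average pool; drop the final fc layer.
        self.encoder = nn.Sequential(*list(resnet.children())[:-1])

    def forward(self, x):
        # (N, 2048, 1, 1) -> (N, 2048) so the output can feed an MLP.
        return torch.flatten(self.encoder(x), 1)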

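### Projector and Predictor

The projector and the predictor are both MLPs with one hidden layer, built by the helper below.

In [None]:
def MLP(dim, hidden_size, projection_size):
    bn = HPS['batchnorm_kwargs']
    return nn.Sequential(
        nn.Linear(dim, hidden_size),
        # PyTorch's `momentum` is 1 - decay_rate; `decay_rate` itself is not
        # a valid BatchNorm1d argument.
        nn.BatchNorm1d(hidden_size, momentum=1 - bn['decay_rate'], eps=bn['eps']),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_size, projection_size)
    )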
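
As a rough illustration of how this helper could produce the projector and the predictor, the sketch below repeats the layer layout under the hypothetical name `mlp` (with default BatchNorm settings) so that it runs on its own. The input width 2048 is the size of ResNet-50's pooled features; 4096 and 256 are `mlp_hidden_size` and `projection_size` from the hyperparameters.

```python
import torch
from torch import nn


def mlp(dim, hidden_size, projection_size):
    # Same layer layout as the notebook's MLP, repeated so this sketch is self-contained.
    return nn.Sequential(
        nn.Linear(dim, hidden_size),
        nn.BatchNorm1d(hidden_size),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_size, projection_size),
    )


projector = mlp(2048, 4096, 256)  # encoder features -> projection
predictor = mlp(256, 4096, 256)   # projection -> prediction

features = torch.randn(4, 2048)   # stand-in for flattened encoder output
projection = projector(features)
prediction = predictor(projection)
print(projection.shape, prediction.shape)
```

The predictor maps the projection space back onto itself, which is why its input and output widths are both 256.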

### Model Composition

In [None]:
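class network(nn.Module):
    def __init__(self, encoder, projector, predictor=None):
        super(network, self).__init__()
        self.encoder = encoder
        self.projector = projector
        # The online network has a predictor; the target network does not.
        self.predictor = predictor

    def forward(self, x):
        # Flatten the pooled encoder features to (batch, features).
        representation = torch.flatten(self.encoder(x), 1)
        projection = self.projector(representation)
        if self.predictor is None:
            return projection
        return self.predictor(projection)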
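
One way the online and target networks might interact in a single training step is sketched below. It is an illustration under stated assumptions, not the notebook's training loop: the tiny linear layers and their sizes are hypothetical stand-ins for the encoder, projector, and predictor, the random tensors stand in for two augmented views, and the loss is symmetrized over the two views as described in the BYOL paper.

```python
import copy

import torch
import torch.nn.functional as F
from torch import nn


def loss_fn(x, y):
    # Same loss as defined earlier in the notebook, repeated for self-containment.
    x = F.normalize(x, dim=-1, p=2)
    y = F.normalize(y, dim=-1, p=2)
    return 2 - 2 * (x * y).sum(dim=-1)


torch.manual_seed(0)
# Tiny stand-ins (hypothetical sizes) for encoder + projector and for the predictor.
online_backbone = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))
online_predictor = nn.Linear(8, 8)
target_backbone = copy.deepcopy(online_backbone).requires_grad_(False)

view_one, view_two = torch.randn(4, 16), torch.randn(4, 16)

# Online predictions for both views.
pred_one = online_predictor(online_backbone(view_one))
pred_two = online_predictor(online_backbone(view_two))

# Target projections carry no gradient (stop-gradient).
with torch.no_grad():
    targ_one = target_backbone(view_one)
    targ_two = target_backbone(view_two)

# Symmetrized loss: each view's prediction is matched to the other view's target.
loss = (loss_fn(pred_one, targ_two) + loss_fn(pred_two, targ_one)).mean()
loss.backward()
```

After `loss.backward()`, only the online parameters hold gradients; the target parameters would instead be refreshed by the EMA update.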

In [2]:
# Get cpu, gpu or mps device for training.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

Using mps device


In [None]:
# count the number of files within the directory
import os
import os.path
import glob
import numpy as np

def get_num_files(directory):
    """Count all files under `directory`, including its subdirectories."""
    if not os.path.exists(directory):
        return 0
    return sum(len(files) for _, _, files in os.walk(directory))