# BYOL

In this notebook we are going to implement [BYOL: Bootstrap Your Own Latent](https://arxiv.org/pdf/2006.07733.pdf) and compare the results of a classification task before and after pretraining the model with BYOL.

### Data Augmentations

In [1]:
import random
from typing import Callable, Tuple
import torch
import torchvision
from torch import nn, Tensor
from torchvision import transforms as T
from torch.nn import functional as F


class RandomApply(nn.Module):
    def __init__(self, fn: Callable, p: float):
        super().__init__()
        self.fn = fn
        self.p = p

    def forward(self, x: Tensor) -> Tensor:
        # your code

        if random.random() > self.p:
          return x
        else:
          return self.fn(x)



def default_augmentation(image_size: Tuple[int, int] = (224, 224)) -> nn.Module:
    """
        1. resize images to 'image_size'
        2. RandomApply color jitter
        3. RandomApply grayscale
        4. RandomApply horizon flip
        5. RandomApply gaussian blur with kernel_size(3, 3), sigma=(1.5, 1.5)
        6. RandomApply ResizedCrop to 'image_size'
        7. Normalize
        choosing hyperparameters that are not mentioned is up to you
    """
    return nn.Sequential(
        # your code
        T.Resize(size=image_size),
        RandomApply(T.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
        T.RandomGrayscale(p=0.2),
        T.RandomHorizontalFlip(p=0.5),
        RandomApply(T.GaussianBlur((3, 3), (1.5, 1.5)), p=0.1),
        T.RandomResizedCrop(size=image_size),
        T.Normalize(
            mean=torch.tensor([0.485, 0.456, 0.406]),
            std=torch.tensor([0.229, 0.224, 0.225]),
        )
    )

# Model
We will use ResNet18 as our representation model.

In [2]:
def get_encoder_model():
    resnet = torchvision.models.resnet18()
    # remove last fully-connected layer
    # your code
    resnet = torch.nn.Sequential(*list(resnet.children())[:-1])

    return resnet

### Loss Function
We need to use NormalizedMSELoss as our loss function.
$$NormalizedMSELoss(v_1, v_2) = \Vert \bar{v_1} - \bar{v_2}\Vert_2^2 = 2 - 2.\frac{\langle v_1, v_2 \rangle}{\Vert v_1\Vert_2 \Vert v_2\Vert_2}$$

In [3]:
class NormalizedMSELoss(nn.Module):
    def __init__(self) -> None:
        super(NormalizedMSELoss,self).__init__()

    def forward(self, view1: Tensor, view2: Tensor) -> Tensor:
        # your code
        v1 = F.normalize(view1, dim=-1)
        v2 = F.normalize(view2, dim=-1)
        return 2 - 2 * (v1 * v2).sum(dim=-1)

### MLP
Here you will implement a simple MLP class with one hidden layer with BatchNorm and ReLU activation, and a linear output layer. This class will be used for both the projections and the prediction networks.

In [4]:
class MLP(nn.Module):
    def __init__(self, input_dim: int=4096, projection_dim: int = 256, hidden_dim: int = 4096) -> None:
        super(MLP,self).__init__()

        # your code
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, projection_dim)
        )

    def forward(self, x: Tensor) -> Tensor:
        # your code
        return self.net(x)


### Encoder + Projector Network
This is the network structure that is shared between online and target networks. It consists of our encoder model, followed by a projection MLP.

In [5]:
class EncoderProjecter(nn.Module):
    def __init__(self,
                 encoder: nn.Module,
                 hidden_dim: int = 4096,
                 projection_out_dim: int = 256
                 ) -> None:
        super(EncoderProjecter, self).__init__()

        # your code
        self.encoder = encoder
        self.projector = MLP(512)

    def forward(self, x: Tensor) -> Tensor:
        # your code
        x = self.encoder(x).squeeze()
        return self.projector(x)


## BYOL

In [6]:
import copy

class BYOL(nn.Module):
    def __init__(self,
                 model: nn.Module,
                 hidden_dim: int = 4096,
                 projection_out_dim: int = 256,
                 target_decay: float = 0.99

                ) -> None:
        super(BYOL, self).__init__()
        # your code
        self.online_network = EncoderProjecter(model, hidden_dim, projection_out_dim)  # encoder + projector
        self.online_predictor = MLP(256)
        self.target_network = copy.deepcopy(self.online_network)
        for p in self.target_network.parameters():
          p.requires_grad = False
        self.target_network.eval()
        self.loss_function = NormalizedMSELoss()
        self.target_decay = target_decay
    @torch.no_grad()
    def soft_update_target_network(self) -> None:
        # your code
         for online_params, target_params in zip(self.online_network.parameters(), self.target_network.parameters()):
            target_params.data = (self.target_decay * target_params.data) + ((1 - self.target_decay) * online_params.data)




    def forward(self, view1, view2) -> Tuple[Tensor]:
        # return online projection and target projection of view
        # your code
        with torch.no_grad():
          target_projection1 = self.target_network(view1)
          target_projection2 = self.target_network(view2)

        online_pred1 = self.online_predictor(self.online_network(view1))
        online_pred2 = self.online_predictor(self.online_network(view2))

        loss1 = self.loss_function(online_pred1, target_projection2)
        loss2 = self.loss_function(online_pred2, target_projection1)

        #self.soft_update_target_network()

        return torch.mean(loss1 + loss2)
    '''def loss(self, online_pred1, target_projection1, online_pred2, target_projection2):
        # compute loss once for (online_prediction1, target_projection2) and once for (online_prediction2, target_projection1).
        # then return the mean.
        # your code



        return (loss1 + loss2) / 2'''

# STL10 Datasets

We need 3 separate datasets from STL10 for this experiment:
1. `"train"` -- Contains only labeled training images. Used for supervised training.
2. `"train+unlabeled"` -- Contains training images, plus a large number of unlabelled images.  Used for self-supervised learning with BYOL.
3. `"test"` -- Labeled test images.  We use it both as a validation set, and for computing the final model accuracy.

In [7]:
from torchvision.datasets import STL10
from torchvision.transforms import ToTensor


TRAIN_DATASET = STL10(root="data", split="train", download=True, transform=ToTensor())
TRAIN_UNLABELED_DATASET = STL10(root="data", split="train+unlabeled", download=True, transform=ToTensor())
TEST_DATASET = STL10(root="data", split="test", download=True, transform=ToTensor())

Downloading http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz to data/stl10_binary.tar.gz


100%|██████████| 2640397119/2640397119 [02:38<00:00, 16626326.09it/s]


Extracting data/stl10_binary.tar.gz to data
Files already downloaded and verified
Files already downloaded and verified


Create dataloaders:

In [8]:
# your code
from torch.utils.data import DataLoader
device = DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
train_loader = DataLoader(
    TRAIN_DATASET,
    batch_size=256,
    shuffle=True,

)
val_loader = DataLoader(
    TEST_DATASET,
    batch_size=256,
)
train_un_loader = DataLoader(
    TRAIN_UNLABELED_DATASET,
    batch_size=256,
    shuffle=True,

)

# Supervised Training without BYOL

First create a classifier model by simply adding a linear layer on top of the encoder model. Then train the model using the labeled training set. Performance should be pretty good already.

In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm


class Classifier(nn.Module):
    def __init__(self, encoder, num_classes):
        super(Classifier, self).__init__()
        self.encoder = encoder
        self.linear = nn.Linear(512, num_classes, device='cuda')

    def forward(self, x):
        x = self.encoder(x).squeeze()
        x = self.linear(x)
        return x

# Get the encoder model
encoder = get_encoder_model().to(device)
num_classes = 10
classifier = Classifier(encoder, num_classes)

criterion = nn.CrossEntropyLoss()
LR = 2e-4
optimizer = optim.Adam(classifier.parameters(), lr=LR)
num_epochs = 5


for epoch in tqdm(range(num_epochs)):
    classifier.train()
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = classifier(inputs.to(device))
        loss = criterion(outputs, labels.to(device))

        loss.backward()
        optimizer.step()


    classifier.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            outputs = classifier(inputs.to(device))
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels.to(device)).sum().item()

    print(f'Epoch {epoch + 1}/{num_epochs}, Accuracy: {100 * correct / total:.2f}%')


 20%|██        | 1/5 [00:15<01:01, 15.32s/it]

Epoch 1/5, Accuracy: 10.01%


 40%|████      | 2/5 [00:24<00:34, 11.50s/it]

Epoch 2/5, Accuracy: 15.36%


 60%|██████    | 3/5 [00:32<00:19,  9.98s/it]

Epoch 3/5, Accuracy: 40.21%


 80%|████████  | 4/5 [00:41<00:09,  9.54s/it]

Epoch 4/5, Accuracy: 45.24%


100%|██████████| 5/5 [00:50<00:00, 10.05s/it]

Epoch 5/5, Accuracy: 45.30%





به علت کمبود وقت ۵ ایپاک آموزش داده شد. همچنین با ۱۵ ایپاک در حد ۴ درصد تفاوت دارد

# Self-Supervised Training with BYOL

Now perform the self-supervised training. This is the most computationally intensive part of the script.

In [10]:
import torch.optim as optim
from tqdm import tqdm

encoder2 = get_encoder_model()
byol = BYOL(encoder2, 4096, 256, 0.99).to(device)
Path = './content/'
LR = 2e-4
optimizer2 = optim.Adam(byol.parameters(), lr=LR)
aug = default_augmentation((96,96))

num_epochs = 20
for epoch in tqdm(range(num_epochs)):
    total_loss = 0
    for inputs, labels in train_un_loader:
        inputs = inputs.to(device)
        with torch.no_grad():
          v1 ,  v2 = aug(inputs), aug(inputs)

        optimizer2.zero_grad()

        loss = byol(v1, v2)
        loss.backward()

        optimizer2.step()
        byol.soft_update_target_network()
        total_loss += loss.item()
    print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss}')


  5%|▌         | 1/20 [04:31<1:26:06, 271.90s/it]

Epoch 1/20, Loss: 13.009098197100684


 10%|█         | 2/20 [09:03<1:21:28, 271.61s/it]

Epoch 2/20, Loss: 0.11659725764911855


 15%|█▌        | 3/20 [13:33<1:16:49, 271.16s/it]

Epoch 3/20, Loss: 0.015903343723039143


 20%|██        | 4/20 [18:03<1:12:07, 270.46s/it]

Epoch 4/20, Loss: 0.00648481408279622


 25%|██▌       | 5/20 [22:32<1:07:31, 270.07s/it]

Epoch 5/20, Loss: 0.004222342373395804


 30%|███       | 6/20 [27:02<1:02:59, 269.95s/it]

Epoch 6/20, Loss: 0.0030689763834743644


 35%|███▌      | 7/20 [31:33<58:36, 270.46s/it]  

Epoch 7/20, Loss: 0.0022416956726374337


 40%|████      | 8/20 [36:04<54:05, 270.42s/it]

Epoch 8/20, Loss: 0.00190783981270215


 45%|████▌     | 9/20 [40:34<49:35, 270.50s/it]

Epoch 9/20, Loss: 0.001542078610327735


 50%|█████     | 10/20 [45:05<45:04, 270.48s/it]

Epoch 10/20, Loss: 0.0015028218740553712


 55%|█████▌    | 11/20 [49:34<40:29, 269.98s/it]

Epoch 11/20, Loss: 0.00132805211478626


 60%|██████    | 12/20 [54:04<35:59, 269.98s/it]

Epoch 12/20, Loss: 0.001988662034364097


 65%|██████▌   | 13/20 [58:35<31:32, 270.29s/it]

Epoch 13/20, Loss: 0.002682100795027509


 70%|███████   | 14/20 [1:03:04<27:00, 270.10s/it]

Epoch 14/20, Loss: 0.0019195920788206422


 75%|███████▌  | 15/20 [1:07:33<22:28, 269.72s/it]

Epoch 15/20, Loss: 0.00380253968796751


 80%|████████  | 16/20 [1:12:04<17:59, 269.96s/it]

Epoch 16/20, Loss: 0.002904784959355311


 85%|████████▌ | 17/20 [1:16:34<13:30, 270.15s/it]

Epoch 17/20, Loss: 0.00850325170904398


 90%|█████████ | 18/20 [1:21:03<08:59, 269.80s/it]

Epoch 18/20, Loss: 0.001862303167627033


 95%|█████████▌| 19/20 [1:25:35<04:30, 270.25s/it]

Epoch 19/20, Loss: 0.0006754969244866516


100%|██████████| 20/20 [1:30:05<00:00, 270.27s/it]

Epoch 20/20, Loss: 0.003070410899908893





# Supervised Training Again

Extract the encoder network's state dictionary from BYOL, and load it into our ResNet18 model before starting training.  Then run supervised training, and watch the accuracy improve from last time!

In [11]:
encoder_ = torch.nn.Sequential(*list(byol.online_network.children())[:-1])


num_classes = 10
classifier2 = Classifier(encoder_, num_classes)

criterion = nn.CrossEntropyLoss()
LR = 2e-4
optimizer = optim.Adam(classifier2.parameters(), lr=LR)
num_epochs = 15


for epoch in tqdm(range(num_epochs)):
    classifier2.train()
    for inputs, labels in train_loader:

        optimizer.zero_grad()

        outputs = classifier2(inputs.to(device))
        loss = criterion(outputs, labels.to(device))
        loss.backward()

        optimizer.step()

    classifier2.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            outputs = classifier2(inputs.to(device))
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels.to(device)).sum().item()

    print(f'Epoch {epoch + 1}/{num_epochs}, Accuracy: {100 * correct / total:.2f}%')


  7%|▋         | 1/15 [00:09<02:07,  9.07s/it]

Epoch 1/15, Accuracy: 14.25%


 13%|█▎        | 2/15 [00:17<01:52,  8.66s/it]

Epoch 2/15, Accuracy: 36.89%


 20%|██        | 3/15 [00:26<01:45,  8.81s/it]

Epoch 3/15, Accuracy: 43.48%


 27%|██▋       | 4/15 [00:35<01:39,  9.00s/it]

Epoch 4/15, Accuracy: 45.41%


 33%|███▎      | 5/15 [00:44<01:30,  9.01s/it]

Epoch 5/15, Accuracy: 46.10%


 40%|████      | 6/15 [00:53<01:19,  8.81s/it]

Epoch 6/15, Accuracy: 51.42%


 47%|████▋     | 7/15 [01:02<01:10,  8.86s/it]

Epoch 7/15, Accuracy: 47.67%


 53%|█████▎    | 8/15 [01:11<01:03,  9.02s/it]

Epoch 8/15, Accuracy: 51.66%


 60%|██████    | 9/15 [01:20<00:53,  8.96s/it]

Epoch 9/15, Accuracy: 53.24%


 67%|██████▋   | 10/15 [01:29<00:44,  8.88s/it]

Epoch 10/15, Accuracy: 53.96%


 73%|███████▎  | 11/15 [01:38<00:35,  8.91s/it]

Epoch 11/15, Accuracy: 54.26%


 80%|████████  | 12/15 [01:47<00:26,  8.98s/it]

Epoch 12/15, Accuracy: 55.10%


 87%|████████▋ | 13/15 [01:55<00:17,  8.82s/it]

Epoch 13/15, Accuracy: 55.39%


 93%|█████████▎| 14/15 [02:04<00:08,  8.87s/it]

Epoch 14/15, Accuracy: 57.01%


100%|██████████| 15/15 [02:13<00:00,  8.91s/it]

Epoch 15/15, Accuracy: 57.16%



