# Bootstrap Your Own Latent (BYOL)

In this session we are going to implement the Bootstrap Your Own Latent (BYOL) paper (https://arxiv.org/abs/2006.07733).

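It uses MoCo-style training (an asymmetric Siamese network whose target branch is an exponential-moving-average copy of the online branch), but with an L2 loss between normalized embeddings instead of a contrastive loss (it is not a contrastive method).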

In [1]:
import os
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from torchvision.io import read_image

import torchvision
import torchvision.models as models
import torchvision.transforms as transforms

from PIL import Image
import copy

In [7]:
# loss fn
def loss_fn(x, y):
    """Return the BYOL regression loss between two batches of embeddings.

    Both inputs are L2-normalised along their last dimension, and the result is
    ``2 - 2 * cos(x, y)`` computed row by row. This equals the squared Euclidean
    distance between the normalised vectors, so it lies in [0, 4]. The last
    dimension is reduced away (one value per leading index).
    """
    x_unit = F.normalize(x, dim=-1, p=2)
    y_unit = F.normalize(y, dim=-1, p=2)
    cosine = (x_unit * y_unit).sum(dim=-1)
    return 2 - 2 * cosine


class EMA:
    """Exponential-moving-average helper with a fixed decay ``beta``."""

    def __init__(self, beta):
        # Weight given to the previous average; (1 - beta) goes to the new value.
        self.beta = beta

    def update_average(self, old, new):
        """Return ``old * beta + (1 - beta) * new``, or ``new`` when ``old`` is None."""
        if old is not None:
            return old * self.beta + (1 - self.beta) * new
        return new

def update_moving_average(ema_updater, ma_model, current_model):
    """Blend every parameter of ``ma_model`` towards ``current_model`` in place.

    Parameters are paired positionally with ``zip``, so both models are expected
    to have the same architecture. Each ``ma_model`` parameter's ``.data`` is
    replaced with ``ema_updater.update_average(old, new)``. Only
    ``.parameters()`` are visited; buffers are not touched here.
    """
    param_pairs = zip(ma_model.parameters(), current_model.parameters())
    for target_param, online_param in param_pairs:
        target_param.data = ema_updater.update_average(target_param.data, online_param.data)

# MLP head factory (used here for the projector; full BYOL also uses one for the predictor)

def MLP(dim, projection_size, hidden_size=4096, sync_batchnorm=None):
    """Build a two-layer MLP head.

    Layout: Linear(dim -> hidden_size) -> BatchNorm1d -> ReLU ->
    Linear(hidden_size -> projection_size).

    ``sync_batchnorm`` is accepted for interface compatibility but is ignored:
    a plain ``nn.BatchNorm1d`` is always used.
    """
    layers = [
        nn.Linear(dim, hidden_size),
        nn.BatchNorm1d(hidden_size),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_size, projection_size),
    ]
    return nn.Sequential(*layers)


class BYOL(nn.Module):
    """Simplified BYOL model: online encoder + projector vs. an EMA target encoder.

    The online branch is ``online_projector(online_net(x))``. The target branch
    is a frozen deep copy of ``online_net`` (the encoder only, no projector). It
    is created lazily and changes only through :meth:`update_moving_average`.

    NOTE(review): the BYOL paper also puts a predictor on the online branch and
    a projector on the target branch. Both are omitted here, so online
    projections are regressed directly onto target *encoder* features. That only
    works because the projector's output size equals the encoder feature size.

    Args:
        backbone: encoder with an ``fc`` attribute (e.g. a torchvision ResNet).
            Its ``fc`` is replaced in place by ``nn.Identity()``, so the model
            returns the features that used to feed the classifier.
        moving_average_decay: EMA decay ``beta`` used for the target update.
        feature_dim: size of the backbone output once ``fc`` is removed. If
            None, it is taken from ``backbone.fc.in_features`` (512 for
            ResNet-18), falling back to 512 when that attribute is missing.
        hidden_size: hidden width of the projector MLP.
    """

    def __init__(self, backbone, moving_average_decay=0.99,
                 feature_dim=None, hidden_size=4096):
        super().__init__()

        self.target_ema_updater = EMA(moving_average_decay)

        if feature_dim is None:
            # Infer the feature size from the classifier that is about to be
            # removed; 512 reproduces the previously hard-coded value.
            feature_dim = getattr(getattr(backbone, "fc", None), "in_features", 512)

        self.online_net = backbone
        self.online_net.fc = nn.Identity()  # drop the classifier, keep the features
        # Output size must equal feature_dim: it is compared with target encoder features.
        self.online_projector = MLP(feature_dim, feature_dim, hidden_size)

        # Created lazily by _get_target_encoder().
        self.target_net = None

    def _get_target_encoder(self):
        """Return the target encoder, creating it on first use.

        The target is a deep copy of ``online_net`` only (the projector is not
        copied), with ``requires_grad`` disabled. The optimizer therefore never
        updates it; it changes only through the EMA update.
        """
        if self.target_net is None:
            target_net = copy.deepcopy(self.online_net)
            for param in target_net.parameters():
                param.requires_grad = False
            self.target_net = target_net
        return self.target_net

    def update_moving_average(self):
        """Move the target encoder's parameters towards the online encoder's (EMA).

        Call this after ``optimizer.step()``. If no forward pass has run yet,
        the target is first created as a copy of the online encoder, making the
        update effectively a no-op. Previously this case failed on ``None``.
        """
        target_net = self._get_target_encoder()
        with torch.no_grad():
            update_moving_average(self.target_ema_updater, target_net, self.online_net)

    def forward(self, x1, x2):
        """Return the symmetrised BYOL loss for two augmented views of one batch.

        Both views are concatenated along dim 0, so each branch runs once, and
        the result is split back into views. BatchNorm layers (e.g. in the
        projector) therefore see both views as a single batch. The online
        projection of view 1 regresses the target features of view 2 and vice
        versa. The two per-sample losses are summed and averaged over the batch.

        Raises:
            ValueError: if ``x1`` and ``x2`` have different batch sizes, since
                ``chunk(2)`` would then mis-pair the views.
        """
        if x1.shape[0] != x2.shape[0]:
            raise ValueError(
                f"x1 and x2 must have the same batch size, got {x1.shape[0]} and {x2.shape[0]}"
            )

        images = torch.cat((x1, x2), dim=0)

        online_projections = self.online_projector(self.online_net(images))
        online_pred_one, online_pred_two = online_projections.chunk(2, dim=0)

        # Targets act as fixed labels: no gradient flows into the target branch,
        # which is updated by EMA only (this stop-gradient is part of how BYOL
        # avoids representation collapse). Outputs computed under no_grad
        # already require no grad, so no extra .detach() is needed.
        with torch.no_grad():
            target_net = self._get_target_encoder()
            target_projections = target_net(images)
            target_proj_one, target_proj_two = target_projections.chunk(2, dim=0)

        loss_one = loss_fn(online_pred_one, target_proj_two)
        loss_two = loss_fn(online_pred_two, target_proj_one)

        return (loss_one + loss_two).mean()

## Exercise 0

Study the above code.
- Where is the EMA update performed?
- Why does it compute both the loss_one and loss_two values?

In [None]:
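# Answer in report:
# - EMA update: BYOL.update_moving_average() calls the module-level
#   update_moving_average(), which sets each target parameter to
#   EMA.update_average(old, new) = beta * old + (1 - beta) * new.
#   The training loop calls it right after optimizer.step().
# - loss_one / loss_two: gradients flow only through the online branch (the target
#   is computed under torch.no_grad()), so each view must play both roles. Again
#   for efficiency, both views are processed as one concatenated batch and then
#   split, and the "crossed" pairs are compared (online view 1 vs target view 2,
#   and online view 2 vs target view 1).
# - This cannot be done with a contrastive loss: it would compute the same thing.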

## Exercise 1

Write the training loop for MoCo-style training as used in BYOL.
Use the Dataset that creates two augmented views of each image and the Siamese network from the previous lab sessions [1](https://colab.research.google.com/drive/1NJwAFbRiD4MdwWf__6P2Lm0xYk_DNdVu?usp=sharing) and [2](https://colab.research.google.com/drive/1AMkh0q8L5nJScx7v6cMWoK336zqOqDY6?usp=sharing).

In [3]:
class SiameseNetASIM(nn.Module):
    """Asymmetric Siamese network: each view goes through its own encoder.

    ``forward`` returns both feature batches stacked along dim 0: the features
    of ``x1`` first, followed by the features of ``x2``.
    """

    def __init__(self, backbone, backbone2):
        super().__init__()
        self.encoder = backbone    # applied to the first view
        self.encoder2 = backbone2  # applied to the second view

    def forward(self, x1, x2):
        return torch.cat([self.encoder(x1), self.encoder2(x2)], 0)

In [9]:
class CustomImageDataset(Dataset):
    """Dataset yielding two separately transformed views of each image plus its label.

    Args:
        data: indexable collection of images, or of file paths (``str``) that
            are loaded with ``read_image`` when accessed.
        targets: labels aligned with ``data``; its length is the dataset length.
        transform: optional callable, called twice per item (once per view).
        target_transform: optional callable applied to the label.
    """

    def __init__(self, data, targets, transform=None, target_transform=None):
        self.imgs = data
        self.targets = targets
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        image = self.imgs[idx]
        if isinstance(image, str):
            image = read_image(image)  # a path: load the image from disk
        label = self.targets[idx]

        if not self.transform:
            view_a = view_b = image
        else:
            # Two separate calls, so a random transform samples each view independently.
            view_a = self.transform(image)
            view_b = self.transform(image)

        if self.target_transform:
            label = self.target_transform(label)
        return view_a, view_b, label


# CIFAR-10 training split (downloaded into ./data on first run).
data = torchvision.datasets.CIFAR10(root='./data', train=True, download=True)
size = 32  # output side of RandomResizedCrop (CIFAR-10 images are 32x32)
s = 1  # colour-jitter strength
color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)

# GaussianBlur requires an odd, positive kernel size. Use ~10% of the image side
# and bump even values to the next odd one: int(0.1 * 32) = 3 is unchanged, but
# e.g. size = 64 would give 6 and make GaussianBlur raise.
blur_kernel_size = max(1, int(0.1 * size))
if blur_kernel_size % 2 == 0:
    blur_kernel_size += 1

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.RandomResizedCrop(size=size),
                                transforms.RandomHorizontalFlip(),
                                transforms.RandomApply([color_jitter], p=0.8),
                                transforms.RandomGrayscale(p=0.2),
                                transforms.GaussianBlur(kernel_size=blur_kernel_size)])

Files already downloaded and verified


In [10]:
trainset = CustomImageDataset(data.data, data.targets, transform=transform)
dataloader = DataLoader(trainset, batch_size=64, shuffle=True)

# Use the GPU when available. The target encoder is copied lazily from the
# online encoder, so it ends up on the same device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

backbone = models.resnet18()
byol = BYOL(backbone).to(device)
byol.train()

# Only the online branch is optimised with SGD. The target encoder has
# requires_grad=False and is updated by EMA only, so it is not passed to the
# optimizer (this saves time and memory).
online_net_params = {'params': byol.online_net.parameters()}
online_projector_params = {'params': byol.online_projector.parameters()}
optimizer = optim.SGD([online_net_params, online_projector_params], lr=0.001, momentum=0.9, weight_decay=1e-04)

max_iters = 4  # short demo run; raise or remove the limit for real training

# The loop variable is `batch`, not `data`: rebinding `data` would shadow the
# CIFAR-10 dataset used above, so re-running this cell would fail on `data.data`.
for idx, batch in enumerate(dataloader):
    view1, view2, _ = batch
    view1 = view1.to(device)
    view2 = view2.to(device)

    optimizer.zero_grad()
    iter_loss = byol(view1, view2)
    print(f"Iter {idx} loss: {iter_loss.item()}")
    iter_loss.backward()  # gradients reach only the online encoder and projector

    optimizer.step()  # SGD update of the online branch only
    byol.update_moving_average()  # EMA update of the target encoder

    if idx + 1 == max_iters:
        break

Iter 0 loss: 4.021104335784912
Iter 1 loss: 3.9903438091278076
Iter 2 loss: 3.8605329990386963
Iter 3 loss: 3.7224299907684326
