## BYOL

An end-to-end demonstration of BYOL in action, using STL10 dataset :
We use Kornia for implementing the transformations, a great Python library with fully differentiable computer vision operations and use PyTorch Lightning which is fantastic library for deep learning projects/research written in PyTorch, which includes conveniences like multi-GPU training, experiment logging, model checkpointing, and mixed-precision training. 
Following that, we'll require an Encoder module. The Encoder is in charge of extracting features from the base model and projecting them into a latent space with a lower dimension. We'll build a wrapper class to implement it, which will allow us to use BYOL with any model.  There are two primary elements - 
Feature Extractor: collects the outputs from one of the last model layers.
Projector: a linear layer, which projects outputs down lower dimensions.
BYOL has two Encoder networks that are identical. The first is trained as normal, with each training batch updating its weights. A running average of the first Encoder's weights is used to update the second (referred to as the "target" network). During training, a raw training batch is supplied to the target network, and a transformed version of the same batch is delivered to the other encoder. For its respective data, each network develops a low-dimensional, latent representation. Then, using a multi-layer perceptron, we try to anticipate the output of the target network. The similarity between this prediction and the output of the target network is maximised by BYOL.

The loss function which we have used here is called contrastive loss.

In [1]:
# Install dependencies.  Note that pytorch and torchvision are pre-installed 
# in standard Colab instances, so no need to worry about those.
!pip install -q kornia pytorch_lightning

[K     |████████████████████████████████| 493 kB 5.4 MB/s 
[K     |████████████████████████████████| 582 kB 20.0 MB/s 
[K     |████████████████████████████████| 408 kB 44.7 MB/s 
[K     |████████████████████████████████| 136 kB 42.3 MB/s 
[K     |████████████████████████████████| 596 kB 40.6 MB/s 
[K     |████████████████████████████████| 1.1 MB 30.4 MB/s 
[K     |████████████████████████████████| 94 kB 421 kB/s 
[K     |████████████████████████████████| 144 kB 36.3 MB/s 
[K     |████████████████████████████████| 271 kB 34.5 MB/s 
[?25h

## BYOL

An end-to-end demonstration of BYOL in action, using STL10 dataset :
We use Kornia for implementing the transformations, a great Python library with fully differentiable computer vision operations and use PyTorch Lightning which is fantastic library for deep learning projects/research written in PyTorch, which includes conveniences like multi-GPU training, experiment logging, model checkpointing, and mixed-precision training. 
Following that, we'll require an Encoder module. The Encoder is in charge of extracting features from the base model and projecting them into a latent space with a lower dimension. We'll build a wrapper class to implement it, which will allow us to use BYOL with any model.  There are two primary elements - 
Feature Extractor: collects the outputs from one of the last model layers.
Projector: a linear layer, which projects outputs down lower dimensions.
BYOL has two Encoder networks that are identical. The first is trained as normal, with each training batch updating its weights. A running average of the first Encoder's weights is used to update the second (referred to as the "target" network). During training, a raw training batch is supplied to the target network, and a transformed version of the same batch is delivered to the other encoder. For its respective data, each network develops a low-dimensional, latent representation. Then, using a multi-layer perceptron, we try to anticipate the output of the target network. The similarity between this prediction and the output of the target network is maximised by BYOL.

The loss function which we have used here is called contrastive loss.

### Data Augmentations

In [2]:
import random
from typing import Callable, Tuple

from kornia import augmentation as aug
from kornia import filters
from kornia.geometry import transform as tf
import torch
from torch import nn, Tensor


class RandomApply(nn.Module):
    def __init__(self, fn: Callable, p: float):
        super().__init__()
        self.fn = fn
        self.p = p

    def forward(self, x: Tensor) -> Tensor:
        return x if random.random() > self.p else self.fn(x)


def default_augmentation(image_size: Tuple[int, int] = (224, 224)) -> nn.Module:
    return nn.Sequential(
        tf.Resize(size=image_size),
        RandomApply(aug.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
        aug.RandomGrayscale(p=0.2),
        aug.RandomHorizontalFlip(),
        RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
        aug.RandomResizedCrop(size=image_size),
        aug.Normalize(
            mean=torch.tensor([0.485, 0.456, 0.406]),
            std=torch.tensor([0.229, 0.224, 0.225]),
        ),
    )

### Encoder wrapper for features identification

In [3]:
from typing import Union


def mlp(dim: int, projection_size: int = 256, hidden_size: int = 4096) -> nn.Module:
    return nn.Sequential(
        nn.Linear(dim, hidden_size),
        nn.BatchNorm1d(hidden_size),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_size, projection_size),
    )


class EncoderWrapper(nn.Module):
    def __init__(
        self,
        model: nn.Module,
        projection_size: int = 256,
        hidden_size: int = 4096,
        layer: Union[str, int] = -2,
    ):
        super().__init__()
        self.model = model
        self.projection_size = projection_size
        self.hidden_size = hidden_size
        self.layer = layer

        self._projector = None
        self._projector_dim = None
        self._encoded = torch.empty(0)
        self._register_hook()

    @property
    def projector(self):
        if self._projector is None:
            self._projector = mlp(
                self._projector_dim, self.projection_size, self.hidden_size
            )
        return self._projector

    def _hook(self, _, __, output):
        output = output.flatten(start_dim=1)
        if self._projector_dim is None:
            self._projector_dim = output.shape[-1]
        self._encoded = self.projector(output)

    def _register_hook(self):
        if isinstance(self.layer, str):
            layer = dict([*self.model.named_modules()])[self.layer]
        else:
            layer = list(self.model.children())[self.layer]

        layer.register_forward_hook(self._hook)

    def forward(self, x: Tensor) -> Tensor:
        _ = self.model(x)
        return self._encoded

## BYOL module

In [4]:
from copy import deepcopy
from itertools import chain
from typing import Dict, List

import pytorch_lightning as pl
from torch import optim
import torch.nn.functional as f


def normalized_mse(x: Tensor, y: Tensor) -> Tensor:
    x = f.normalize(x, dim=-1)
    y = f.normalize(y, dim=-1)
    return 2 - 2 * (x * y).sum(dim=-1)


class BYOL(pl.LightningModule):
    def __init__(
        self,
        model: nn.Module,
        image_size: Tuple[int, int] = (128, 128),
        hidden_layer: Union[str, int] = -2,
        projection_size: int = 256,
        hidden_size: int = 4096,
        augment_fn: Callable = None,
        beta: float = 0.999,
    ):
        super().__init__()
        self.augment = default_augmentation(image_size) if augment_fn is None else augment_fn
        self.beta = beta
        self.encoder = EncoderWrapper(
            model, projection_size, hidden_size, layer=hidden_layer
        )
        self.predictor = nn.Linear(projection_size, projection_size, hidden_size)
        
        self._target = None

        self.encoder(torch.zeros(2, 3, *image_size))

    def forward(self, x: Tensor) -> Tensor:
        return self.predictor(self.encoder(x))

    @property
    def target(self):
        if self._target is None:
            self._target = deepcopy(self.encoder)
        return self._target

    def update_target(self):
        for p, pt in zip(self.encoder.parameters(), self.target.parameters()):
            pt.data = self.beta * pt.data + (1 - self.beta) * p.data

    # --- Methods required for PyTorch Lightning only! ---

    def configure_optimizers(self):
        optimizer = getattr(optim, self.hparams.get("optimizer", "Adam"))
        lr = self.hparams.get("lr", 1e-4)
        weight_decay = self.hparams.get("weight_decay", 1e-6)
        return optimizer(self.parameters(), lr=lr, weight_decay=weight_decay)

    def training_step(self, batch, *_) -> Dict[str, Union[Tensor, Dict]]:
        x = batch[0]
        with torch.no_grad():
            x1, x2 = self.augment(x), self.augment(x)

        pred1, pred2 = self.forward(x1), self.forward(x2)
        with torch.no_grad():
            targ1, targ2 = self.target(x1), self.target(x2)
        loss = torch.mean(normalized_mse(pred1, targ2) + normalized_mse(pred2, targ1))

        self.log("train_loss", loss.item())
        self.update_target()

        return {"loss": loss}

    @torch.no_grad()
    def validation_step(self, batch, *_) -> Dict[str, Union[Tensor, Dict]]:
        x = batch[0]
        x1, x2 = self.augment(x), self.augment(x)
        pred1, pred2 = self.forward(x1), self.forward(x2)
        targ1, targ2 = self.target(x1), self.target(x2)
        loss = torch.mean(normalized_mse(pred1, targ2) + normalized_mse(pred2, targ1))

        return {"loss": loss}

    @torch.no_grad()
    def validation_epoch_end(self, outputs: List[Dict]) -> Dict:
        val_loss = sum(x["loss"] for x in outputs) / len(outputs)
        self.log("val_loss", val_loss.item())

In [5]:
class SupervisedLightningModule(pl.LightningModule):
    def __init__(self, model: nn.Module, **hparams):
        super().__init__()
        self.model = model

    def forward(self, x: Tensor) -> Tensor:
        return self.model(x)

    def configure_optimizers(self):
        optimizer = getattr(optim, self.hparams.get("optimizer", "Adam"))
        lr = self.hparams.get("lr", 1e-4)
        weight_decay = self.hparams.get("weight_decay", 1e-6)
        return optimizer(self.parameters(), lr=lr, weight_decay=weight_decay)

    def training_step(self, batch, *_) -> Dict[str, Union[Tensor, Dict]]:
        x, y = batch
        loss = f.cross_entropy(self.forward(x), y)
        self.log("train_loss", loss.item())
        return {"loss": loss}

    @torch.no_grad()
    def validation_step(self, batch, *_) -> Dict[str, Union[Tensor, Dict]]:
        x, y = batch
        loss = f.cross_entropy(self.forward(x), y)
        return {"loss": loss}

    @torch.no_grad()
    def validation_epoch_end(self, outputs: List[Dict]) -> Dict:
        val_loss = sum(x["loss"] for x in outputs) / len(outputs)
        self.log("val_loss", val_loss.item())

### STL10 Datasets

We need 3 separate datasets from STL10 for this experiment:
1. `"train"` -- Contains only labeled training images. Used for supervised training.
2. `"train+unlabeled"` -- Contains training images, plus a large number of unlabelled images.  Used for self-supervised learning with BYOL.
3. `"test"` -- Labeled test images.  We use it both as a validation set, and for computing the final model accuracy.

In [6]:
from torchvision.datasets import STL10
from torchvision.transforms import ToTensor


TRAIN_DATASET = STL10(root="data", split="train", download=True, transform=ToTensor())
TRAIN_UNLABELED_DATASET = STL10(
    root="data", split="train+unlabeled", download=True, transform=ToTensor()
)
TEST_DATASET = STL10(root="data", split="test", download=True, transform=ToTensor())

Downloading http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz to data/stl10_binary.tar.gz


  0%|          | 0/2640397119 [00:00<?, ?it/s]

Extracting data/stl10_binary.tar.gz to data
Files already downloaded and verified
Files already downloaded and verified


## Training

In [7]:
from os import cpu_count

from torch.utils.data import DataLoader
from torchvision.models import resnet18


model = resnet18(pretrained=True)
supervised = SupervisedLightningModule(model)
trainer = pl.Trainer(max_epochs=25, gpus=-1, weights_summary=None)
train_loader = DataLoader(
    TRAIN_DATASET,
    batch_size=128,
    shuffle=True,
    drop_last=True,
)
val_loader = DataLoader(
    TEST_DATASET,
    batch_size=128,
)
trainer.fit(supervised, train_loader, val_loader)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


  0%|          | 0.00/44.7M [00:00<?, ?B/s]

  "Setting `Trainer(weights_summary=None)` is deprecated in v1.5 and will be removed"
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: /content/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Sanity Checking: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

## Test accuracy of resnet18

In [27]:
def accuracy(pred: Tensor, labels: Tensor) -> float:
    return (pred.argmax(dim=-1) == labels).float().mean().item()


model.cuda()
acc = sum([accuracy(model(x.cuda()), y.cuda()) for x, y in val_loader]) / len(val_loader)
print(f"Accuracy: {acc:.3f}")

Accuracy: 0.847


## Train accuracy of resnet18

In [9]:
acc = sum([accuracy(model(x.cuda()), y.cuda()) for x, y in train_loader]) / len(train_loader)
print(f"Accuracy: {acc:.3f}")

Accuracy: 1.000


In [10]:
model = resnet18(pretrained=True)
byol = BYOL(model, image_size=(96, 96))
trainer = pl.Trainer(
    max_epochs=50, 
    gpus=-1,
    accumulate_grad_batches=2048 // 128,
    weights_summary=None,
)
train_loader = DataLoader(
    TRAIN_DATASET,
    batch_size=128,
    shuffle=True,
    drop_last=True,
)
trainer.fit(byol, train_loader, val_loader)

  "Setting `Trainer(weights_summary=None)` is deprecated in v1.5 and will be removed"
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Sanity Checking: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [11]:

# and load the state dictionary into the new model.
#
# This ensures that we remove all hooks from the previous model,
# which are automatically implemented by BYOL.
state_dict = model.state_dict()
model = resnet18()
model.load_state_dict(state_dict)

supervised = SupervisedLightningModule(model)
trainer = pl.Trainer(
    max_epochs=25, 
    gpus=-1,
    weights_summary=None,
)
train_loader = DataLoader(
    TRAIN_DATASET,
    batch_size=128,
    shuffle=True,
    drop_last=True,
)
trainer.fit(supervised, train_loader, val_loader)

  "Setting `Trainer(weights_summary=None)` is deprecated in v1.5 and will be removed"
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Sanity Checking: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

## TRAIN ACCURACY

In [28]:

model.cuda()
acc = sum([accuracy(model(x.cuda()), y.cuda()) for x, y in train_loader]) / len(train_loader)
print(f"Accuracy: {acc:.3f}")

Accuracy: 0.996


## TEST ACCURACY

In [14]:

model.cuda()
acc = sum([accuracy(model(x.cuda()), y.cuda()) for x, y in val_loader]) / len(val_loader)
print(f"Accuracy: {acc:.3f}")

Accuracy: 0.847


In [None]:
def accuracy(pred: Tensor, labels: Tensor) -> float:
    return (pred.argmax(dim=-1))
acc = list([accuracy(model(x.cuda()), y.cuda()) for x, y in val_loader])


In [16]:
import tensorflow as tf

In [17]:
from re import I
from sklearn.metrics import f1_score
from sklearn.metrics import recall_score
tf.compat.v1.disable_eager_execution()
labels=[y.cuda() for x,y in val_loader]
# recall = recall_score(labels.numpy(), accuracy(), labels=[1,2], average='micro')

temp=[]
for i in labels:
  temp=temp+list(i.cpu().numpy())

acc_temp=[]
for i in acc:
  acc_temp=acc_temp+list(i.cpu().numpy())

## Performance metrics

In [18]:
recall = recall_score(temp,acc_temp, average='macro')
print('Recall: %.3f' % recall)

Recall: 0.847


In [19]:
res=f1_score(temp, acc_temp, average='macro')
print('F1score: %.4f' % res)

F1score: 0.8475


In [25]:
recall = recall_score(temp,acc_temp,average="micro")
print('Recall: %.3f' % recall)

Recall: 0.847


In [21]:
res=f1_score(temp, acc_temp, average='micro')
print('F1score: %.4f' % res)

F1score: 0.8472


## Confusion matrix

In [22]:
from sklearn.metrics import confusion_matrix
confusion_matrix(temp,acc_temp)

array([[735,   9,   8,   3,   3,   6,   6,   1,  16,  13],
       [  5, 697,   0,  25,   7,  24,   0,  40,   1,   1],
       [ 10,   1, 718,   4,   1,   0,   3,   1,   2,  60],
       [  1,  25,   0, 567,  48, 107,   7,  41,   1,   3],
       [  1,  30,   1,  49, 644,  31,  26,  17,   0,   1],
       [  2,  18,   0,  59,  23, 579,  51,  68,   0,   0],
       [  4,   3,   0,  10,  21,  69, 670,  18,   3,   2],
       [  1,   9,   1,  44,  18,  29,   5, 691,   1,   1],
       [ 26,   1,   1,   1,   0,   0,   0,   0, 752,  19],
       [ 16,   1,  30,   0,   0,   0,   2,   0,  26, 725]])