In [None]:
# !pip install lightly
# !pip install scikit-dimension

In [2]:
import random

import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision import datasets
import torchvision.transforms as T

# One seed shared by every random number generator used in this notebook.
SEED = 42
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed torch, numpy and the stdlib generator with the same value.
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

print(DEVICE)

cpu


In [3]:
# utils and metrics
from skdim.id import MLE # Maximum-Likelihood ID estimation

def model_param_count(model):
    """Return the total number of scalar entries over all parameters of `model`.

    model: object exposing `.parameters()` that yields tensors (e.g. an nn.Module)

    The result is always a Python int: `numel()` is used instead of
    `np.prod(p.shape)` because the latter yields the float 1.0 for a 0-dim
    parameter (empty shape), which silently turned the whole count into a float.
    """
    return sum(p.numel() for p in model.parameters())

class MeanSquareDistancesStat:
    """Records how far each stored representation moved between consecutive epochs.

    Rows are addressed by push order: the k-th `push` of an epoch fills rows
    [k * batch_size, (k + 1) * batch_size). The history is therefore only a
    per-sample statistic when batches arrive in the same order every epoch.
    During the first epoch the "previous" representation is the all-zero buffer.
    """
    def __init__(self, ds_size, batch_size, model, n_epochs, hidden_dim):
        """
        ds_size: number of samples pushed per epoch (sum of all batch sizes of one epoch)
        batch_size: size of batch
        model: encoder that is expected to return representations of size (batch_size, hidden_dim)
        n_epochs: number of epochs to record; pushing past that raises ValueError
        hidden_dim: dimension of representations
        """
        # Representations seen in the previous epoch, one row per sample slot.
        self.last_epoch_repr = torch.zeros(ds_size, hidden_dim)
        # mdp_hist[e, i]: mean squared change of slot i between epoch e - 1 and epoch e.
        self.mdp_hist = torch.zeros((n_epochs, ds_size))
        self.encoder = model
        self.batch_size = batch_size

        self.n_samples = ds_size
        self.n_epochs = n_epochs

        self.cur_batch = 0
        self.cur_epoch = 0

    @torch.no_grad()
    def mdp(self, x_prev, x_next):
        """Mean squared difference between `x_next` and `x_prev`, reduced over dim 1."""
        return torch.mean((x_next - x_prev) ** 2, dim=1)

    @torch.no_grad()
    def push(self, x_batch):
        """Encode `x_batch` and record its distance to the rows stored for the same slot.

        Raises ValueError when all `n_epochs` epochs are already recorded, or when
        the batch does not have as many rows as the slot it is supposed to fill.
        """
        if self.cur_epoch == self.n_epochs:
            raise ValueError('Statistics is already collected')

        l = self.cur_batch * self.batch_size
        r = min(l + self.batch_size, self.n_samples)

        # Runs without autograd (see decorator) and lands on the buffer's device/dtype:
        # storing a graph-attached tensor in the buffer would keep every batch's
        # graph alive, and a CUDA output cannot be combined with the CPU buffer.
        x_repr_cur = self.encoder(x_batch).detach().to(
            device=self.last_epoch_repr.device, dtype=self.last_epoch_repr.dtype
        )
        if x_repr_cur.shape[0] != r - l:
            raise ValueError(
                f'Expected a batch of {r - l} samples for rows [{l}, {r}), got {x_repr_cur.shape[0]}; '
                'ds_size must equal the number of samples pushed per epoch'
            )
        x_repr_prev = self.last_epoch_repr[l:r]

        self.mdp_hist[self.cur_epoch, l:r] = self.mdp(x_repr_prev, x_repr_cur)

        self.last_epoch_repr[l:r] = x_repr_cur
        self.cur_batch += 1

        # Every slot has been filled: the next push starts a new epoch.
        if self.cur_batch * self.batch_size >= self.n_samples:
            self.cur_batch = 0
            self.cur_epoch += 1

class IDStats:
    """Keeps a history of estimates produced by an skdim `MLE` estimator."""
    def __init__(self, ds_size:int, mle_kwargs):
        """
        ds_size: number of rows every pushed array must have
        mle_kwargs: keyword arguments forwarded to `MLE`
        """
        self.mle_est = MLE(**mle_kwargs)
        self.id_hist = []
        self.ds_size = ds_size

    def push(self, x_enc):
        """Run `fit_transform_pw` on `x_enc` and append whatever it returns to `id_hist`."""
        assert x_enc.shape[0] == self.ds_size
        # presumably the point-wise intrinsic-dimension estimates - confirm against skdim docs
        estimate = self.mle_est.fit_transform_pw(x_enc)
        self.id_hist.append(estimate)

        # QUESTION: should we:
        # - fit to train => predict on test
        # - fit on test => predict on test
        # - fit on train => predict on train



## Self-Supervised Contrastive Models

## Barlow Twins
https://arxiv.org/abs/2103.03230

In [None]:
from lightly.loss import BarlowTwinsLoss
from lightly.models.modules import BarlowTwinsProjectionHead
from lightly.transforms.byol_transform import (
    BYOLView1Transform,
    BYOLView2Transform,
    MultiViewTransform
)
from lightly.transforms.utils import IMAGENET_NORMALIZE

class BarlowTwins(nn.Module):
    """Backbone followed by a Barlow Twins projection head.

    backbone: module mapping an input batch to features; its output is
        flattened to (batch, -1) before entering the head
    input_dim, hidden_dim, output_dim: the three sizes passed to
        `BarlowTwinsProjectionHead`; the defaults keep the previously
        hard-coded 512 / 2048 / 2048
    """
    def __init__(self, backbone, input_dim=512, hidden_dim=2048, output_dim=2048):
        super().__init__()
        self.backbone = backbone
        self.projection_head = BarlowTwinsProjectionHead(input_dim, hidden_dim, output_dim)

    def forward(self, x):
        """Return the projection-head output for the batch `x`."""
        # Tensor.flatten takes `start_dim` / `end_dim`; `dim=` is not an accepted keyword.
        features = self.backbone(x).flatten(start_dim=1)
        return self.projection_head(features)

In [13]:
CHS = 3  # number of image channels; multiplied with the pixel count to size the backbone input
IMG_SIZE = (32, 32)  # image size; IMG_SIZE[0] is handed to the BYOL view transforms as input_size
BS = 256  # batch size used by the DataLoader and by the distance tracker

class BYOLTransformWrapped(MultiViewTransform):
    """BYOL two-view transform with the non-augmented image prepended.

    The transform list is, in order: ToTensor + ImageNet normalisation of the
    original image, then the first BYOL view, then the second BYOL view.
    """
    def __init__(self, view_1_transform=None, view_2_transform=None):
        """
        view_1_transform: transform for the first augmented view; None (or any
            falsy value) falls back to BYOLView1Transform()
        view_2_transform: transform for the second augmented view; None (or any
            falsy value) falls back to BYOLView2Transform()
        """
        view_1_transform = view_1_transform or BYOLView1Transform()
        view_2_transform = view_2_transform or BYOLView2Transform()
        transforms = [
            T.Compose([T.ToTensor(), T.Normalize(mean=IMAGENET_NORMALIZE["mean"], std=IMAGENET_NORMALIZE["std"])]),
            view_1_transform,
            view_2_transform
        ]
        super().__init__(transforms=transforms)


# Original image plus two BYOL views per sample, Gaussian blur disabled (probability 0.0).
# note: this thing works only with 3 (presumably 3-channel images - TODO confirm)
transform = BYOLTransformWrapped(
    view_1_transform=BYOLView1Transform(gaussian_blur=0.0, input_size=IMG_SIZE[0]),
    view_2_transform=BYOLView2Transform(gaussian_blur=0.0, input_size=IMG_SIZE[0]),
)

# CIFAR10 is downloaded into ./data when missing; every item goes through `transform`.
ds = datasets.CIFAR10('./data', download=True, transform=transform)
# Shuffled batches of BS samples; the trailing incomplete batch is dropped.
loader = DataLoader(dataset=ds, batch_size=BS, num_workers=4, drop_last=True, shuffle=True)

Files already downloaded and verified


In [6]:
# Sanity check: dataset length and the shape of the first view (index 0) of one batch.
len(ds), next(iter(loader))[0][0].shape

(50000, torch.Size([64, 3, 32, 32]))

In [11]:
# Two-layer MLP backbone: flattens the image, then two 512-unit ELU layers.
backbone = nn.Sequential(*[
    nn.Flatten(),
    nn.Linear(np.prod(IMG_SIZE) * CHS, 512),
    nn.ELU(),
    nn.Linear(512, 512),
    nn.ELU(),
])

print(model_param_count(backbone))

1836032


In [14]:
NUM_EPOCHS = 10

model = BarlowTwins(backbone).to(DEVICE)
criterion = BarlowTwinsLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.06)

# The tracker starts a new epoch once `ds_size` samples were pushed. When the
# loader drops the trailing incomplete batch, fewer than len(ds) samples arrive
# per epoch, so the tracker must be sized by what the loader actually yields.
tracked_samples = len(loader) * BS if loader.drop_last else len(ds)
msd_tracker = MeanSquareDistancesStat(tracked_samples, BS, model, NUM_EPOCHS, hidden_dim=2048)

# NOTE(review): the tracker assigns rows by push order while the loader is created
# with shuffle=True, so a row does not hold the same sample in consecutive epochs.
# Confirm whether per-sample distances are intended; that would need a fixed order.
for epoch in range(NUM_EPOCHS):
    total_loss = 0.0
    for x, _ in tqdm(loader):
        # Three views per sample: the non-augmented image and two augmented ones.
        x_orig, x0, x1 = x
        x_orig = x_orig.to(DEVICE)
        x0 = x0.to(DEVICE)
        x1 = x1.to(DEVICE)

        z0 = model(x0)
        z1 = model(x1)

        loss = criterion(z0, z1)
        total_loss += loss.item()
        loss.backward()

        # Bookkeeping only: no autograd graph is needed for the tracker's forward pass.
        with torch.no_grad():
            msd_tracker.push(x_orig)
        optimizer.step()
        optimizer.zero_grad()

    avg_loss = total_loss / len(loader)
    print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")

 64%|██████▍   | 125/195 [04:45<02:34,  2.21s/it]