In [2]:
import importlib
import torch
import torchvision

from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import time

from learning_loop import BATCH_SIZE

In [2]:
# STL-10 labelled training split, downloaded into ./data on first run and
# converted to float tensors by ToTensor().
stl = datasets.STL10(
    root="data",
    split="train",
    transform=ToTensor(),
    download=True,
)

Downloading http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz to data/stl10_binary.tar.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting data/stl10_binary.tar.gz to data


In [230]:
def off_diagonal(x):
    """Return the off-diagonal entries of a square 2-D tensor as a 1-D tensor.

    Entries are produced in row-major order (row 0 without x[0, 0], then row 1
    without x[1, 1], ...), giving n * (n - 1) values for an n x n input.

    Raises:
        AssertionError: if ``x`` is not square.
    """
    n_rows, n_cols = x.shape
    assert n_rows == n_cols
    flat = x.reshape(-1)
    # Trimming the final entry x[n-1, n-1] leaves n*n - 1 = (n - 1)(n + 1)
    # values. Folded into rows of width n + 1, row r starts exactly at x[r, r],
    # so dropping column 0 removes every remaining diagonal entry.
    folded = flat[:-1].reshape(n_rows - 1, n_rows + 1)
    return folded[:, 1:].reshape(-1)

class BarlowTwins(nn.Module):
    """Barlow Twins-style model: ResNet-50 backbone followed by a linear projector.

    This variant works on a single view. ``forward`` builds the empirical
    auto-correlation matrix of one batch of normalised representations with
    itself (the original Barlow Twins formulation cross-correlates two
    augmented views) and penalises its deviation from the identity.

    Args:
        batch_size: nominal batch size, stored as ``self.batch_size`` for
            backward compatibility. Correlation matrices are normalised by the
            number of rows actually passed in, so a short final batch from a
            ``DataLoader`` is handled correctly.
        lam: weight of the off-diagonal (redundancy-reduction) term.
            Defaults to 0.0051, the value previously hard-coded here.
    """

    def __init__(self, batch_size, lam=0.0051):
        super().__init__()
        self.lam = lam
        self.batch_size = batch_size

        # Backbone: ResNet-50 with its classification head replaced by an
        # identity, so it outputs the pooled feature vector that the projector
        # consumes (sizes[0] = 2048 below).
        self.backbone = torchvision.models.resnet50(zero_init_residual=True)
        self.backbone.fc = nn.Identity()

        # Projector widths. With only two entries the hidden-layer loop below
        # runs zero times and the projector is a single bias-free Linear layer;
        # longer lists add Linear -> BatchNorm1d -> ReLU hidden stages.
        sizes = [2048, 1024]

        layers = []
        for i in range(len(sizes) - 2):
            layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=False))
            layers.append(nn.BatchNorm1d(sizes[i + 1]))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Linear(sizes[-2], sizes[-1], bias=False))
        self.projector = nn.Sequential(*layers)

        # Non-affine BatchNorm that normalises each representation dimension
        # before the correlation matrix is formed.
        self.bn = nn.BatchNorm1d(sizes[-1], affine=False)

    @torch.no_grad()
    def scale_model(self, alpha):
        """Multiply, in place, every parameter whose name contains "weight" by ``alpha``.

        This covers Linear/Conv weights and any BatchNorm affine ``weight``
        parameters (e.g. in the backbone). Biases and BatchNorm running
        statistics are left unchanged. Runs under ``torch.no_grad()`` so the
        in-place update is not recorded by autograd.

        Args:
            alpha: scalar multiplier.
        """
        for name, param in self.named_parameters():
            # Only weights are rescaled; biases keep their values.
            if "weight" in name:
                param.mul_(alpha)

    def forward_reps(self, y1):
        """Backbone -> projector -> non-affine BatchNorm; returns (N, 1024) representations."""
        return self.bn(self.projector(self.backbone(y1)))

    @torch.no_grad()
    def cov_eig(self, y1):
        """Eigenvalues of the representation auto-correlation matrix, largest first.

        Args:
            y1: an input batch for the backbone (in this notebook, a batch of
                STL-10 image tensors; other callers unverified).

        Returns:
            list[float]: eigenvalues of ``reps.T @ reps / N`` in descending
            order, where N is the number of rows in ``y1``.

        Note:
            No train/eval switch happens here. In training mode the BatchNorm
            layers normalise with (and update running statistics from) this
            batch, exactly as ``forward`` does.
        """
        reps = self.forward_reps(y1)
        cov = (reps.T @ reps) / reps.shape[0]
        # The matrix is symmetric, so eigvalsh returns real eigenvalues in
        # ascending order. torch.eig was deprecated and has since been removed.
        return torch.linalg.eigvalsh(cov).flip(0).tolist()

    def forward(self, y1):
        """Single-view Barlow Twins loss for one batch.

        With ``Z = forward_reps(y1)`` (N rows) and ``C = Z.T @ Z / N``, returns
        ``sum_i (C_ii - 1)^2 + lam * sum_{i != j} C_ij^2`` as a scalar tensor.
        """
        reps = self.forward_reps(y1)

        # Empirical auto-correlation matrix, normalised by the real batch size
        # (no cross-GPU reduction is performed here).
        c = (reps.T @ reps) / reps.shape[0]

        on_diag = (torch.diagonal(c) - 1).pow(2).sum()
        off_diag = off_diagonal(c).pow(2).sum()
        return on_diag + self.lam * off_diag

In [226]:
# One fixed (unshuffled) batch of STL-10 training images, reused below to
# probe the eigenvalue spectrum of the model's representations.
stl_loader = DataLoader(dataset=stl, batch_size=BATCH_SIZE)
data, cls = next(iter(stl_loader))

In [243]:
# Seed the global torch RNG so the random ResNet-50 / projector initialisation,
# and therefore the eigenvalues printed below, are reproducible across runs.
torch.manual_seed(0)
model = BarlowTwins(BATCH_SIZE)

In [244]:
# Largest eigenvalue of the auto-correlation matrix of `model`'s
# representations of `data` (cov_eig returns eigenvalues largest first).
model.cov_eig(y1=data)[0]

193.686767578125

In [245]:
# NOTE: scale_model multiplies the weights of `model` in place, so each re-run
# of this cell compounds the factor (0.04, 0.04**2, ...) and changes the value
# shown; re-create the model before running it to reproduce this output.
model.scale_model(alpha=0.04)
model.cov_eig(y1=data)[0]

3.767479574889876e-05