Read the article Barlow Twins - Self-Supervised Learning via Redundancy Reduction: https://arxiv.org/abs/2103.03230

model architecture: https://www.researchgate.net/figure/Schematic-representation-of-Barlow-twinsZbontar-et-al-2021_fig1_362858330

CIFAR-10 Dataset: https://www.cs.toronto.edu/~kriz/cifar.html 

In [1]:
import random
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms, datasets

In [None]:
train_set = datasets.CIFAR10(root='./data', train=True, download=True)
val_set = datasets.CIFAR10(root='./data', train=False)
print(f"Total training examples: {len(train_set)}")
print(f"Total validation examples: {len(val_set)}")

plt.figure(figsize=(10, 10))
for n in range(25):
    ax = plt.subplot(5, 5, n + 1)
    plt.imshow(np.asarray(train_set[n][0]).astype("int"))
    plt.axis("off")
plt.show()

In [4]:
def get_transform():
    '''
    returns a transform that randomly crops, flips, jitters color or drops color from the input
    '''
    return transforms.Compose([
                transforms.RandomResizedCrop(32, scale=[0.75, 1.0], 
                                            interpolation=Image.BICUBIC),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomApply(
                    [transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                            saturation=0.2, hue=0.1)],
                    p=0.9
                ),
                transforms.RandomGrayscale(p=0.3),
            ])

In [None]:
t1 = get_transform()
t2 = get_transform()

plt.figure(figsize=(8, 4))
for n in range(8):
    im_t1 = t1(train_set[n][0])
    plt.subplot(2, 4, n + 1)
    plt.imshow(np.asarray(im_t1).astype("int"))
    plt.axis("off")
plt.show()

plt.figure(figsize=(8, 4))
for n in range(8):
    im_t2 = t2(train_set[n][0])
    plt.subplot(2, 4, n + 1)
    plt.imshow(np.asarray(im_t2).astype("int"))
    plt.axis("off")
plt.show()

In [6]:
def off_diagonal(x):
    '''
    returns a flattened view of the off-diagonal elements of a square matrix x
    '''
    n, m = x.shape
    assert n == m
    
    def flatten(t):
        return t.reshape(t.shape[0], -1)

    return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()

In [7]:
def barlow_loss(z1, z2, bn, lambd):
    '''
    return the barlow twins loss function for a pair of features. Makes use of the off_diagonal function.
    
    :param z1: first input feature
    :param z2: second input feature
    :param bn: nn.BatchNorm1d layer applied to z1 and z2
    :param lambd: trade-off hyper-parameter lambda
    '''
    # empirical cross-correlation matrix
    c = torch.mm(bn(z1).T, bn(z2))
    c.div_(z1.shape[0])

    on_diag = torch.diagonal(c).add_(-1).pow_(2).sum()
    off_diag = off_diagonal(c).pow_(2).sum()
    return (on_diag + lambd * off_diag)

In [8]:
class Projector(nn.Module):
    '''
    2-layer neural network (512 -> 256), (256 -> 128), ReLU non-linearity
    '''
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(512, 256)
        self.fc2 = nn.Linear(256, 128)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [9]:
class BarlowTwins(nn.Module):
    '''
    Full Barlow Twins model with encoder, projector and loss
    '''
    def __init__(self, encoder, projector, lambd):
        '''
        :param encoder: encoder network
        :param projector: projector network
        :param lambd: tradeoff function (hyper-parameter)
        '''
        super().__init__()
        self.encoder = encoder
        self.projector = projector
        self.lambd = lambd     

        # normalization layer for the representations z1 and z2
        self.bn = nn.BatchNorm1d(128, affine=False)

    def forward(self, y1, y2):
        z1 = self.encoder(y1)
        z2 = self.encoder(y2)
        z1 = self.projector(z1)
        z2 = self.projector(z2)

        return barlow_loss(z1, z2, self.bn, self.lambd)

In [10]:
cifar_train_mean = [125.30691805, 122.95039414, 113.86538318]
cifar_train_std = [62.99321928, 62.08870764, 66.70489964]

class Transform:
    def __init__(self, t1, t2):
        '''
        :param t1: Transforms to be applied to first input
        :param t2: Transforms to be applied to second input
        '''
        self.t1 = transforms.Compose([
                t1,
                transforms.ToTensor(),
                transforms.Normalize(mean=cifar_train_mean, std=cifar_train_std)
            ])
        self.t2 = transforms.Compose([
                t2,
                transforms.ToTensor(),
                transforms.Normalize(mean=cifar_train_mean, std=cifar_train_std)
            ])

    def __call__(self, x):
        y1 = self.t1(x)
        y2 = self.t2(x)
        return y1, y2

In [11]:
# Hyper-parameters
EPOCHS = 10
LR = 0.001
BATCH = 256
LAMBDA = 5e-3

# Initialize encoder, projector and full model
encoder = models.resnet18(pretrained=False)
encoder.fc = nn.Identity() # removes the 1000-dimensional classification layer
projector = Projector()
twins = BarlowTwins(encoder, projector, LAMBDA).cuda()

# Dataset and optimizer
dataset = datasets.CIFAR10(root='./data', train=True, transform=Transform(t1, t2))
loader = torch.utils.data.DataLoader(dataset,
                                        batch_size=BATCH,
                                        num_workers=4, # For some students, this cell took >1h to run. But when they used num_workers=0 it worked
                                        shuffle=True)
optimizer = torch.optim.Adam(twins.parameters(), lr=LR)

# Training loop
for epoch in range(EPOCHS):
    for batch_idx, ((x1,x2), _) in enumerate(loader):
        loss = twins(x1.cuda(), x2.cuda())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f"Epoch: {epoch + 1}, Loss: {float(loss)}")

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and will be removed in 0.15, "
  cpuset_checked))


Epoch: 1, Loss: 7.610368251800537
Epoch: 2, Loss: 8.644915580749512
Epoch: 3, Loss: 5.734394073486328
Epoch: 4, Loss: 5.0028839111328125
Epoch: 5, Loss: 5.47348690032959
Epoch: 6, Loss: 5.095256328582764
Epoch: 7, Loss: 5.4128570556640625
Epoch: 8, Loss: 5.202373027801514
Epoch: 9, Loss: 5.410131454467773
Epoch: 10, Loss: 4.758655548095703


In [13]:
NUM_SAMPLES = 1000

# 1000 random (image, label) pairs from train set
train_indices = random.sample(range(len(train_set)), k=NUM_SAMPLES)
train_subset = train_set.data[train_indices]
train_subset_labels = np.array(train_set.targets)[train_indices]

# 1000 random (image, label) pairs from validation set
val_indices = random.sample(range(len(val_set)), k=NUM_SAMPLES)
val_subset = val_set.data[val_indices]
val_subset_labels = np.array(val_set.targets)[val_indices]

In [14]:
# We calculate this for each one step of validation. We take one validation sample and look into all used training samples to find out
# K number of training samples with lowest distance (L1) to validation sample
# We count how many of those K samples belonge to each class (here we have 10 classes) and return the class with highest count (where majority of K fell)


def predict_knn(sample, train_data, train_labels, k):
    '''
    returns the predicted label for a specific validation sample
    
    :param sample: single example from validation set
    :param train_data: full training set as a single array
    :param train_labels: full set of training labels and a single array
    :param k: number of nearest neighbors used for k-NN voting
    '''
    data = train_data.reshape(NUM_SAMPLES, -1)
    label_count = np.zeros(10)            # because dataset used has 10 classes
    dist = np.sum(np.abs(sample.flatten() - data), axis=1)
    idx = np.argpartition(dist,k)         # partitions and sorts dist from small to large valued k chunks 
    min_ind = idx[:k]                     # we take only first k chunk, as it containes lowest values (smalest distances between one validation and all test examples)
    for x in min_ind:
        label_count[train_labels[x]] +=1
    return np.argmax(label_count)

In [15]:
# Dataloaders for extracting self-supervised features
test_transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=cifar_train_mean, std=cifar_train_std)
            ])

train_set = datasets.CIFAR10(root='./data', train=True, transform=test_transform)
val_set = datasets.CIFAR10(root='./data', train=False, transform=test_transform)

train_subset_torch = torch.utils.data.Subset(train_set, train_indices)
val_subset_torch = torch.utils.data.Subset(val_set, val_indices)

train_loader = torch.utils.data.DataLoader(train_subset_torch,
                                        batch_size=NUM_SAMPLES,
                                        shuffle=False)
val_loader = torch.utils.data.DataLoader(val_subset_torch,
                                        batch_size=NUM_SAMPLES,
                                        shuffle=False)

# Extract features with the trained encoder
# We use a single batch of size 1000
for batch in train_loader:
    train_features = encoder(batch[0].cuda()).data.cpu().numpy()

for batch in val_loader:
    val_features = encoder(batch[0].cuda()).data.cpu().numpy()

In [16]:
predictions_7 = []
predictions_13 = []
predictions_19 = []
for sample in val_features:
    predictions_7.append(predict_knn(sample, train_features, train_subset_labels, k=7))
    predictions_13.append(predict_knn(sample, train_features, train_subset_labels, k=13))
    predictions_19.append(predict_knn(sample, train_features, train_subset_labels, k=19))

In [17]:
matches_7 = (np.array(predictions_7) == val_subset_labels)
accuracy_7 = np.sum(matches_7)/NUM_SAMPLES * 100
print(f"k-NN accuracy (k=7): {accuracy_7}%")

matches_13 = (np.array(predictions_13) == val_subset_labels)
accuracy_13 = np.sum(matches_13)/NUM_SAMPLES * 100
print(f"k-NN accuracy (k=13): {accuracy_13}%")

matches_19 = (np.array(predictions_19) == val_subset_labels)
accuracy_19 = np.sum(matches_19)/NUM_SAMPLES * 100
print(f"k-NN accuracy (k=19): {accuracy_19}%")

k-NN accuracy (k=7): 31.8%
k-NN accuracy (k=13): 31.5%
k-NN accuracy (k=19): 32.300000000000004%
