<table style="background-color:#FFFFFF">   
  <tr>     
  <td><img src="https://upload.wikimedia.org/wikipedia/commons/9/95/Logo_EPFL_2019.svg" width="150x"/>
  </td>     
  <td>
  <h1> <b>CS-461: Foundation Models and Generative AI</b> </h1>
  Prof. Charlotte Bunne  
  </td>   
  </tr>
</table>

# 📚 Exercise Session (Coding Part) - 2

In this exercise, you will implement, explore and compare three self-supervised training frameworks.

1. **SimCLR (2020)** ([Link to paper](https://arxiv.org/pdf/2002.05709))

2. **BYOL - Bootstrap your own latent (2020)** ([Link to paper](https://arxiv.org/pdf/2006.07733))

3. **Barlow Twins (2021)** ([Link to paper](https://arxiv.org/pdf/2103.03230))

Each of these introduced important ideas that shaped current state-of-the-art self-supervised learning frameworks such as DINOv2.

First, we import the following packages:

In [None]:
import copy
import torch
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
import numpy as np

import torch.nn as nn
import torch.nn.functional as F
import math
import os

from torchvision.datasets import CIFAR10
from torchvision.transforms import v2 as T
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset
device = 'cuda' if torch.cuda.is_available() else 'cpu'

from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import LogisticRegression
from sklearn import metrics as skm
from torchvision.models import resnet18

## 1. Datasets & Transformations

In the following, we will work with the **CIFAR10 dataset**, consisting of 60,000 32x32 color images in 10 classes, with 6,000 images per class.\
There are 50,000 training images and 10,000 test images.

Let's download/load it and define a default transformation that turns a PIL image into a `torch.Tensor`.


In [None]:
default_transform = T.Compose([
    T.ToTensor(),
])

train = CIFAR10('./data', train=True, download=True, transform=default_transform)
test = CIFAR10('./data', train=False, download=True, transform=default_transform)
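For reference, `ToTensor` converts an `H×W×C` image with `uint8` values in [0, 255] into a `C×H×W` float tensor with values in [0, 1]. The NumPy sketch below mimics that conversion on a random fake image (the helper `to_chw_float` is hypothetical and only meant for illustration). It also shows why the visualization cell below moves the channel axis back to the end with `permute(1, 2, 0)` before calling `imshow`.

```python
import numpy as np

def to_chw_float(img_hwc_uint8):
    """Hypothetical helper mimicking ToTensor: HWC uint8 in [0, 255] -> CHW float32 in [0, 1]."""
    return np.transpose(img_hwc_uint8, (2, 0, 1)).astype(np.float32) / 255.0

rng = np.random.default_rng(0)
fake_image = rng.integers(0, 256, size=(32, 32, 3)).astype(np.uint8)  # a fake 32x32 RGB image
x = to_chw_float(fake_image)
print(x.shape, x.min() >= 0.0, x.max() <= 1.0)  # (3, 32, 32) True True

# imshow expects HWC, so the channel axis is moved back to the end
x_hwc = np.transpose(x, (1, 2, 0))
print(x_hwc.shape)  # (32, 32, 3)
```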

Let's have a look at a few examples. 

#### Task: Visualization
Complete the following code cell to visualize the first 9 images in a 3x3 grid. 

In [None]:
fig = plt.figure(figsize=(6,6))
grid = ImageGrid(fig, 111, nrows_ncols=(3,3), axes_pad=0.2)

for ax, (img, label) in zip(grid, train):
    ax.imshow(img.permute(1, 2, 0))  # (C, H, W) -> (H, W, C) for imshow
    ax.set_title(train.classes[label], fontsize=8)
    ax.axis('off')

## 2. Evaluation Pipeline

One challenge in self-supervised training is that the training losses typically do not directly correspond to the model's downstream performance on an actual task of interest, e.g., image classification.

Therefore, we next implement a simple evaluation pipeline, which we will later use to monitor our training progress.

#### Task: Evaluation Tools
Complete the following function `extract_features_and_labels`, which encodes all samples provided by the dataloader. If `normalize` is `True`, it should additionally L2-normalize each feature vector.

**Caution:** The `forward` pass of our models will later return either a batch of feature vectors of shape `(B,D)` or a `list`/`tuple` whose first element is the batch of feature vectors. Make sure to account for this.
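L2-normalizing each feature vector (what `normalize=True` should do, e.g. via `F.normalize(features, dim=1)`) divides every row by its Euclidean norm, so that dot products between features become cosine similarities. A minimal NumPy sketch of that operation; the name `l2_normalize_rows` and the `eps` guard against zero-norm rows are our own choices:

```python
import numpy as np

def l2_normalize_rows(features, eps=1e-12):
    """Divide each row by its L2 norm (eps guards against division by zero)."""
    norms = np.linalg.norm(features, axis=1, keepdims=True)
    return features / np.maximum(norms, eps)

feats = np.array([[3.0, 4.0], [1.0, 0.0], [0.0, 2.0]])
unit = l2_normalize_rows(feats)
print(unit[0])                        # [0.6 0.8]
print(np.linalg.norm(unit, axis=1))   # every row now has norm 1
# dot products of unit vectors are cosine similarities
print(unit @ unit.T)
```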

In [None]:
def extract_features_and_labels(model, dataloader, normalize=False):
    """
    Extract features and labels from a dataloader using the given model.
    model: an encoder model taking as input a batch of images (batch_size, channels, height, width) and outputting either a batch of feature vectors (batch_size, feature_dim) or a list/tuple in which the first element is the batch of feature vectors (batch_size, feature_dim)
    dataloader: a PyTorch dataloader providing batches of (images, labels)
    returns: features (num_samples, feature_dim), labels (num_samples,)
    """
    features = []
    labels = []

    for x, y in tqdm(dataloader, disable=True):
        x = x.to(device)
        with torch.no_grad():
            feat = model(x)
        # Models returning (repr, proj, ...) -> keep only the representation
        if isinstance(feat, (tuple, list)):
            feat = feat[0]
        features.append(feat.cpu())
        labels.append(y)

    features = torch.cat(features, dim=0)
    labels = torch.cat(labels, dim=0)

    if normalize:
        features = F.normalize(features, dim=1)

    return features, labels

Given a classification task at hand, two established routines to evaluate the quality of representations are **Nearest Neighbour Probes** and **Linear Probes**, in which a k-NN classifier or a logistic regression model, respectively, is trained on the frozen features and evaluated on the test set.

Implement such probes in the following functions `run_knn_probe` and `run_linear_probe`. These functions should return the achieved test accuracy. Feel free to use tools from `sklearn`.
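Conceptually, the k-NN probe assigns each test sample the majority label among its k closest training features. The NumPy sketch below spells this out; the name `knn_predict` and the tie-breaking toward the smaller label via `np.bincount(...).argmax()` are our own assumptions, and `KNeighborsClassifier` may resolve ties differently. Like sklearn's default metric, it uses Euclidean distance. For L2-normalized features, $\|a - b\|_2^2 = 2 - 2\,a^\top b$, so ranking neighbours by Euclidean distance is the same as ranking them by cosine similarity.

```python
import numpy as np

def knn_predict(train_x, train_y, test_x, k=5):
    """Majority vote among the k nearest training points (Euclidean distance)."""
    # squared distances between every test and every train point: (n_test, n_train)
    d2 = ((test_x[:, None, :] - train_x[None, :, :]) ** 2).sum(axis=-1)
    nearest = np.argsort(d2, axis=1)[:, :k]   # indices of the k nearest neighbours
    votes = train_y[nearest]                   # their labels, shape (n_test, k)
    # bincount(...).argmax() breaks ties toward the smaller label
    return np.array([np.bincount(v).argmax() for v in votes])

rng = np.random.default_rng(0)
# two well-separated toy clusters with labels 0 and 1
train_x = np.concatenate([rng.normal(-3, 1, (50, 2)), rng.normal(3, 1, (50, 2))])
train_y = np.array([0] * 50 + [1] * 50)
test_x = np.array([[-3.0, -3.0], [3.0, 3.0]])
pred = knn_predict(train_x, train_y, test_x, k=5)
accuracy = (pred == np.array([0, 1])).mean()
print(pred, accuracy)
```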

In [None]:
def run_knn_probe(train_features, train_labels, test_features, test_labels):
    """
    Runs a k-NN probe on the given features and labels.
    train_features: (num_train_samples, feature_dim)
    train_labels: (num_train_samples,)
    test_features: (num_test_samples, feature_dim)
    test_labels: (num_test_samples,)
    returns: accuracy (float)
    """
    knn = KNeighborsClassifier(n_neighbors=5, n_jobs=-1)
    knn.fit(train_features, train_labels)
    test_preds = knn.predict(test_features)
    accuracy = skm.accuracy_score(test_labels, test_preds)
    return accuracy

def run_linear_probe(train_features, train_labels, test_features, test_labels):
    """
    Runs a linear probe on the given features and labels.
    train_features: (num_train_samples, feature_dim)
    train_labels: (num_train_samples,)
    test_features: (num_test_samples, feature_dim)
    test_labels: (num_test_samples,)
    returns: accuracy (float)
    """
    # TODO
    logreg = LogisticRegression(max_iter=1000, n_jobs=-1)
    logreg.fit(train_features, train_labels)
    test_preds = logreg.predict(test_features)
    accuracy = skm.accuracy_score(test_labels, test_preds)
    return accuracy

Let's test our evaluation pipeline on a randomly initialized ResNet18 model.

In [None]:
model = resnet18(weights=None)
model.fc = nn.Identity()
model = model.to(device)

In [None]:
train_dataloader = DataLoader(train, batch_size=256, shuffle=True, num_workers=10)
test_dataloader = DataLoader(test, batch_size=256, shuffle=False, num_workers=10)

train_features, train_labels = extract_features_and_labels(model, train_dataloader)
test_features, test_labels = extract_features_and_labels(model, test_dataloader)

knn_accuracy = run_knn_probe(train_features.numpy(), train_labels.numpy(), test_features.numpy(), test_labels.numpy())
linear_accuracy = run_linear_probe(train_features.numpy(), train_labels.numpy(), test_features.numpy(), test_labels.numpy())

print(f'k-NN accuracy: {knn_accuracy*100:.2f}%')
print(f'Linear probe accuracy: {linear_accuracy*100:.2f}%')

Interestingly, even the representations of a randomly initialized ResNet18 already achieve accuracies well above chance level (10% for the ten CIFAR10 classes). Keep this observation in mind for later!

## 3. SimCLR

We will again start by implementing **SimCLR** (Simple Framework for Contrastive Learning of Visual Representations), a self-supervised learning method for training deep visual representations without labeled data.\
If you have already successfully implemented and trained the SimCLR model in Exercise Session 1 (Task 3), feel free to skip the implementation part of this section and load your model from last week for the evaluation.

SimCLR is based on contrastive learning: representations are learned by maximizing agreement between differently augmented views of the same image (termed positive pairs) while minimizing agreement between views of different images (termed negative pairs).  

**Main components:**
1. **Data Augmentation:** Generate two correlated views of the same image (e.g., random crop, color distortion, Gaussian blur).  
2. **Encoder Network:** A deep neural network (commonly ResNet) extracts feature representations from each view.  
3. **Projection Head:** A small MLP maps the representations into a latent space where contrastive loss is applied.  
4. **Contrastive Loss (NT-Xent):** Encourages representations of augmented views of the same image to be close, while pushing apart those of different images.  

#### Augmentations

A key contribution of the SimCLR paper was its systematic study and ablation of image augmentations for contrastive learning. Based on these findings, the SimCLR framework applies the following transformations: random cropping and resizing, horizontal flips, color distortions, grayscale conversion, and Gaussian blur.

We implement these transformations in the following.

In [None]:
class SimCLRTransform:

    def __init__(self, size=32, s=0.5, blur_p=0.5):
        color_jitter = T.ColorJitter(0.8*s, 0.8*s, 0.8*s, 0.2*s)
        k = 3 if size <= 32 else 5
        base = [
            T.RandomResizedCrop(size=size, scale=(0.2, 1.0)),
            T.RandomHorizontalFlip(),
            T.RandomApply([color_jitter], p=0.8),
            T.RandomGrayscale(p=0.2),
            T.RandomApply([T.GaussianBlur(kernel_size=k, sigma=(0.1, 2.0))], p=blur_p),
            T.ToTensor()
        ]
        self.train_transform = T.Compose(base)

    def __call__(self, x):
        return self.train_transform(x), self.train_transform(x)

def simclr_collate_fn(batch):
    xs1, xs2, ys = [], [], []
    for (x1, x2), y in batch:
        xs1.append(x1)
        xs2.append(x2)
        ys.append(y)
    return torch.stack(xs1), torch.stack(xs2), torch.tensor(ys)

#### Task: SimCLR Model Architecture

In SimCLR, the model architecture consists of two stacked components:  
1. An encoder $f(\cdot)$, e.g., a CNN, that yields the image representation.  
2. A projector head $g(\cdot)$, e.g., an MLP, that yields projections of the image representations. These are used only during training.  

Implement the `forward` pass of the architecture below. Follow the specifications in the docstring. We already provide example network layers (tested for convergence), but feel free to explore alternative (potentially better) configurations!  

**Caution:** Make sure to L2-normalize the projections before returning them.


In [None]:
class SimCLRModel(nn.Module):

    def __init__(self, proj_dim=128, hidden=2048):
        super().__init__()
        enc = resnet18(weights=None)
        enc.conv1 = nn.Conv2d(3, 64, 3, 1, 1, bias=False)
        enc.maxpool = nn.Identity()
        enc.fc = nn.Identity()
        self.encoder = enc

        self.projector = nn.Sequential(
            nn.Linear(512, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, proj_dim) # (bs, proj_dim)
        )

    def normalize(self, x, eps=1e-8):
        """
        Normalizes a batch of feature vectors.
        """
        return x / (x.norm(dim=-1, keepdim=True) + eps)

    def forward(self, x):
        """
        x: (batch_size, channels, height, width) tensor of images
        returns: repr (batch_size, feature_dim), proj (batch_size, proj_dim)
        """
        # TODO
        repr = self.encoder(x)
        proj = self.normalize(self.projector(repr))
        return repr, proj

#### Task: NT-Xent Loss

As mentioned, in SimCLR training we maximize agreement between different views of the same image, while minimizing agreement between views of distinct images. For this, we use the normalized temperature-scaled cross entropy (NT-Xent) loss.

Let $\{x_i\}_{i=1}^{2N}$ be a batch of $2N$ image views obtained from $N$ images, with $x_{2k-1}$ and $x_{2k}$ being two views of the same original image.  
For each $i = 1, \dots, 2N$, we obtain a projected representation $\mathbf{z}_i$ as
$$
\mathbf{z}_i = g(f(x_i)).
$$

For each positive pair $(i,j)$ and temperature $\tau$, the NT-Xent loss is computed as
$$
\ell_{i,j} = - \log \frac{\exp\big(\mathrm{sim}(\mathbf{z}_i, \mathbf{z}_j)/\tau\big)}
{\sum_{k=1}^{2N} \mathbf{1}_{[k \neq i]} \exp\big(\mathrm{sim}(\mathbf{z}_i, \mathbf{z}_k)/\tau\big)}
$$
where $\mathrm{sim}(\mathbf{u}, \mathbf{v}) = \mathbf{u}^\top \mathbf{v} / (\|\mathbf{u}\|_2 \|\mathbf{v}\|_2)$ is the cosine similarity, and for a complete batch as 
$$
\mathcal{L} = \frac{1}{2N} \sum_{k=1}^{N} \big[\ell_{2k-1,2k} + \ell_{2k,2k-1}\big].
$$

Implement this loss in the following function `nt_xent`.  
It takes as input two tensors of projected representations, `z1` and `z2`, each corresponding to different augmentations of the same batch of images; i.e., for each index `i`, the vectors `z1[i]` and `z2[i]` form a positive pair.
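Before vectorizing, it can help to write the loss exactly as in the formulas above, with explicit loops. The NumPy sketch below (a hypothetical `nt_xent_reference`, not part of the exercise) assumes that `z1[i]` and `z2[i]` are already L2-normalized, so the dot product equals $\mathrm{sim}$. Instead of interleaving the views as $x_{2k-1}, x_{2k}$, it stacks them so that rows `i` and `i + B` form a positive pair (the same loss up to a reordering), and it averages $\ell$ over all $2B$ anchors. It can serve as a cross-check for a vectorized implementation on small inputs.

```python
import math
import numpy as np

def nt_xent_reference(z1, z2, tau=0.5):
    """Loop-based NT-Xent; z1[i], z2[i] are L2-normalized views of image i."""
    z = np.concatenate([z1, z2], axis=0)   # (2B, d): rows i and i + B are a positive pair
    n = z.shape[0]
    B = n // 2
    sim = z @ z.T / tau                    # cosine similarity / temperature
    losses = []
    for i in range(n):
        j = (i + B) % n                    # index of the positive partner
        denom = sum(math.exp(sim[i, k]) for k in range(n) if k != i)
        losses.append(-math.log(math.exp(sim[i, j]) / denom))
    return sum(losses) / n                 # average over all 2B anchors

# Two orthogonal images, each with two identical views
z1 = np.array([[1.0, 0.0], [0.0, 1.0]])
z2 = z1.copy()
loss = nt_xent_reference(z1, z2, tau=0.5)
# Each anchor: positive similarity 1/0.5 = 2, two negatives with similarity 0
print(loss, math.log(1 + 2 * math.exp(-2)))
```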


In [None]:
def nt_xent(z1, z2, tau=0.5):
    """
    Computes NT-Xent loss.
    z1: (batch_size, feature_dim) tensor of normalized projection vectors
    z2: (batch_size, feature_dim) tensor of normalized projection vectors
    returns: loss (scalar)
    """
    # TODO
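    B, d = z1.shape
    z = torch.cat([z1, z2], dim=0)              # (2B, d); rows i and i + B form a positive pair
    sim = (z @ z.t()) / tau                     # (2B, 2B) cosine similarities / tau (z is L2-normalized)
    mask = torch.eye(2*B, dtype=torch.bool, device=z.device)
    sim.masked_fill_(mask, -1e9)                # drop self-similarities (k != i) from the softmax
    targets = torch.arange(B, device=z.device)
    targets = torch.cat([targets + B, targets], dim=0)  # index of each row's positive partner
    return F.cross_entropy(sim, targets)        # mean of l_{i,j} over all 2B anchors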

Let us initialize the datasets and dataloaders. For evaluation, we prepare extra dataloaders without the SimCLR augmentations.


In [None]:
simclr_transform = SimCLRTransform(size=32)

train_ds = CIFAR10(root="./data", train=True, download=True, transform=simclr_transform)
test_ds  = CIFAR10(root="./data", train=False, download=True, transform=simclr_transform)

train_loader = DataLoader(train_ds, batch_size=256, shuffle=True, num_workers=10, pin_memory=True, collate_fn=simclr_collate_fn)
test_loader  = DataLoader(test_ds,  batch_size=256, num_workers=10, pin_memory=True, collate_fn=simclr_collate_fn)

train_ds_noaugment = CIFAR10(root="./data", train=True, download=True, transform=default_transform)
test_ds_noaugment  = CIFAR10(root="./data", train=False, download=True, transform=default_transform)

train_loader_noaugment = DataLoader(train_ds_noaugment, batch_size=256, shuffle=False, num_workers=10, pin_memory=True)
test_loader_noaugment  = DataLoader(test_ds_noaugment,  batch_size=256, shuffle=False, num_workers=10, pin_memory=True)

#### Task: SimCLR Training

Now that all components are ready, it's time to put everything together.  

Complete the training pipeline below and train your SimCLR model.


In [None]:
simclr_model = SimCLRModel(proj_dim=128).to(device)

total_epochs = 50
warmup_epochs = 10

def lr_lambda(epoch):
    # linear warmup over the first warmup_epochs epochs
    if epoch < warmup_epochs:
        return (epoch + 1) / float(warmup_epochs)
    # cosine decay from 1 to 0 over the remaining epochs
    t = (epoch - warmup_epochs) / float(total_epochs - warmup_epochs)
    return 0.5 * (1 + math.cos(math.pi * t))

optimizer = torch.optim.AdamW(simclr_model.parameters(), lr=0.6, weight_decay=0.0)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

all_accuracies = []
all_losses = []
for epoch in range(total_epochs):
    simclr_model.train()
    avg_loss = 0.
    for x1, x2, y in tqdm(train_loader):
        # TODO
        x1, x2 = x1.to(device), x2.to(device)

        _, proj1 = simclr_model(x1)
        _, proj2 = simclr_model(x2)
        loss = nt_xent(proj1, proj2, tau=0.5)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        avg_loss += loss.item()

    scheduler.step()
    simclr_model.eval()
    train_features, train_labels = extract_features_and_labels(simclr_model, train_loader_noaugment, normalize=True)
    test_features, test_labels = extract_features_and_labels(simclr_model, test_loader_noaugment, normalize=True)
    acc = run_knn_probe(train_features, train_labels, test_features, test_labels)
    avg_loss = avg_loss / len(train_loader)
    print(f"Epoch {epoch+1}, Average Loss: {avg_loss:.4f}, kNN Accuracy: {acc:.4f}")
    all_accuracies.append(acc)
    all_losses.append(avg_loss)

os.makedirs('./checkpoints', exist_ok=True)
torch.save(simclr_model.state_dict(), f'./checkpoints/simclr_cifar10.pth')
print('Model saved to ./checkpoints/simclr_cifar10.pth')
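The `lr_lambda` schedule in the training cell above returns a multiplier on the base learning rate: a linear warmup over the first `warmup_epochs` epochs, followed by a cosine decay towards 0. Because `scheduler.step()` is called once per epoch, epoch `e` trains with `0.6 * lr_lambda(e)`. Evaluating a standalone copy (here called `lr_multiplier`, a hypothetical name) for the same settings makes the shape explicit:

```python
import math

total_epochs, warmup_epochs = 50, 10

def lr_multiplier(epoch):
    """Linear warmup, then cosine decay from 1 to 0 (same formula as lr_lambda)."""
    if epoch < warmup_epochs:
        return (epoch + 1) / warmup_epochs
    t = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)  # progress in [0, 1)
    return 0.5 * (1 + math.cos(math.pi * t))

for epoch in [0, 4, 9, 10, 30, 49]:
    print(epoch, round(lr_multiplier(epoch), 4))
```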

Let's visualize the training progress by plotting the loss and accuracy curves!

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
axes[0].plot(range(1, total_epochs+1), all_losses)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Average NT-Xent Loss')

axes[1].plot(range(1, total_epochs+1), all_accuracies)
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('k-NN Accuracy')
axes[1].set_ylim(0, 1)

plt.show()

Let's also run a quick linear probe for the final results!

In [None]:
train_features, train_labels = extract_features_and_labels(simclr_model, train_loader_noaugment, normalize=True)
test_features, test_labels = extract_features_and_labels(simclr_model, test_loader_noaugment, normalize=True)

acc = run_linear_probe(train_features, train_labels, test_features, test_labels)
print(f'Final Linear Probe Accuracy: {acc*100:.2f}%')
print(f'Final k-NN Accuracy: {all_accuracies[-1]*100:.2f}%')

#### Task: Shortcomings
What are potential shortcomings and practical disadvantages of the SimCLR framework?

In [None]:
# TODO

**Solution:**
1. **Large Batch Size Requirement:** To get enough negative samples, SimCLR typically requires very large batch sizes (e.g., 4096 or more), leading to high GPU memory demands.
2. **Sensitive to Augmentations:** The quality of the learned representations depends heavily on the choice of data augmentations (e.g., color distortion, cropping). See the original paper for more information.
3. **Inefficient Negative Sampling:** The negatives are simply all other views in the batch; there is no explicit mining of hard negatives, so many negatives carry little information.
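The batch-size issue follows directly from how NT-Xent forms negatives: with $B$ images there are $2B$ views, each anchor is contrasted against the $2B - 2$ views that come from other images, and the similarity matrix has $(2B)^2$ entries. A few lines of arithmetic in plain Python (the helper `contrastive_batch_stats` is hypothetical) for the batch size used in this notebook and for a large batch:

```python
def contrastive_batch_stats(batch_size):
    views = 2 * batch_size               # two augmented views per image
    negatives_per_anchor = views - 2     # all views except itself and its positive
    sim_matrix_entries = views ** 2      # size of the (2B x 2B) similarity matrix
    return negatives_per_anchor, sim_matrix_entries

for B in (256, 4096):
    neg, entries = contrastive_batch_stats(B)
    print(f"B={B}: {neg} negatives per anchor, {entries:,} similarity entries")
```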

## 4. BYOL

Next, we will implement **BYOL (Bootstrap Your Own Latent)**, a self-supervised representation learning framework that takes a different approach from contrastive methods.  

Recall our initial observation: even a randomly initialized CNN can produce representations that achieve non-trivial k-NN accuracy. BYOL builds on this insight.  

The key idea is that a given encoder $f_{\text{target}}$, even if random or mediocre, can serve as a teacher to train a stronger encoder $f_{\text{online}}$. The training objective is to make $f_{\text{online}}$ imitate the outputs of $f_{\text{target}}$ when both are fed *different augmentations of the same input image*. This process is known as *bootstrapping*.  

Although BYOL still aligns representations of different views of the same image, unlike SimCLR, it is **not strictly contrastive**: it avoids the use of negative examples altogether.


We use the same set of augmentations as in SimCLR.

In [None]:
class BYOLTransform:

    def __init__(self, size=32, s=0.5, blur_p=0.5):
        color_jitter = T.ColorJitter(0.8*s, 0.8*s, 0.8*s, 0.2*s)
        k = 3 if size <= 32 else 5
        base = [
            T.RandomResizedCrop(size=size, scale=(0.2, 1.0)),
            T.RandomHorizontalFlip(),
            T.RandomApply([color_jitter], p=0.8),
            T.RandomGrayscale(p=0.2),
            T.RandomApply([T.GaussianBlur(kernel_size=k, sigma=(0.1, 2.0))], p=blur_p),
            T.ToTensor()
        ]
        self.train_transform = T.Compose(base)

    def __call__(self, x):
        return self.train_transform(x), self.train_transform(x)


def byol_collate_fn(batch):
    xs1, xs2, ys = [], [], []
    for (x1, x2), y in batch:
        xs1.append(x1)
        xs2.append(x2)
        ys.append(y)
    return torch.stack(xs1), torch.stack(xs2), torch.tensor(ys)

#### Task: BYOL Architecture

The BYOL architecture builds on the SimCLR architecture but introduces a third component on top of the encoder $f(\cdot)$ and the projector head $g(\cdot)$: the predictor head $q(\cdot)$.  

Implement the `forward` pass of the provided `BYOLModel`. It should output three tensors:  
1. The representations $f(x)$,  
2. Their normalized projections $\mathrm{normalize}(g(f(x)))$, and  
3. The normalized predictions $\mathrm{normalize}(q(g(f(x))))$.  


In [None]:
class BYOLModel(nn.Module):

    def __init__(self, proj_dim=128, hidden_dim=2048):
        super().__init__()
        enc = resnet18(weights=None)
        enc.conv1 = nn.Conv2d(3, 64, 3, 1, 1, bias=False)
        enc.maxpool = nn.Identity()
        enc.fc = nn.Identity()
        self.encoder = enc

        self.projector = nn.Sequential(
            nn.Linear(512, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, proj_dim)
        )

        self.predictor = nn.Sequential(
            nn.Linear(proj_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, proj_dim)
        )

    def forward(self, x):
        """
        x: (batch_size, channels, height, width) tensor of images
        returns: repr (batch_size, feature_dim), proj (batch_size, proj_dim), pred (batch_size, proj_dim)
        """
        # TODO
        repr = self.encoder(x)
        proj = self.projector(repr)
        pred = self.predictor(proj)

        proj = F.normalize(proj, dim=-1, eps=1e-8)
        pred = F.normalize(pred, dim=-1, eps=1e-8)

        return repr, proj, pred

Let's initialize our datasets and dataloaders.

In [None]:
byol_transform = BYOLTransform(size=32)

train_ds = CIFAR10(root="./data", train=True, download=True, transform=byol_transform)
test_ds  = CIFAR10(root="./data", train=False, download=True, transform=byol_transform)

train_loader = DataLoader(train_ds, batch_size=256, shuffle=True, num_workers=10, pin_memory=True, collate_fn=byol_collate_fn)
test_loader  = DataLoader(test_ds,  batch_size=256, num_workers=10, pin_memory=True, collate_fn=byol_collate_fn)

train_ds_noaugment = CIFAR10(root="./data", train=True, download=True, transform=default_transform)
test_ds_noaugment  = CIFAR10(root="./data", train=False, download=True, transform=default_transform)

train_loader_noaugment = DataLoader(train_ds_noaugment, batch_size=256, shuffle=False, num_workers=10, pin_memory=True)
test_loader_noaugment  = DataLoader(test_ds_noaugment,  batch_size=256, shuffle=False, num_workers=10, pin_memory=True)

Let's initialize our models. In BYOL, we train two networks:  
1. The **online network**, and  
2. The **target network**.  

The online network is trained via gradient descent to imitate the target network.  

We initialize both networks with the same weights, but the parameters of the target network are frozen.


In [None]:
online_network = BYOLModel(proj_dim=128).to(device)
target_network = BYOLModel(proj_dim=128).to(device)

target_network.load_state_dict(copy.deepcopy(online_network.state_dict()))
for param in target_network.parameters():
    param.requires_grad = False

#### Task: EMA Update

Given an initially random target network, after some iterations the potential to learn from it via augmentation-invariant imitation will become saturated. Therefore, we need to update the target network as training progresses.  

In BYOL, this is achieved by updating the weights of the target network as an exponential moving average (EMA) of the online network after every optimization step. For each weight $w$ and update rate $\beta$, the update is given by:  
$$
w_{\mathrm{target}} = \beta \, w_{\mathrm{target}} + (1 - \beta) \, w_{\mathrm{online}}
$$

Implement this update rule in the `update` function of the following helper object.

In [None]:
class EMA:

    def __init__(self, beta=0.99):
        self.beta = beta

    @torch.no_grad()
    def update(self, online, target):
        """
        Update target network parameters (inplace) using exponential moving average of online network parameters.
        online: online network (nn.Module)
        target: target network (nn.Module)
        """
        for online_param, target_param in zip(online.parameters(), target.parameters()):
            # w_target <- beta * w_target + (1 - beta) * w_online, computed in place
            target_param.mul_(self.beta).add_(online_param, alpha=1 - self.beta)
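To build intuition for $\beta$: if the online weight were held fixed at $w_{\mathrm{online}}$, applying the update $k$ times gives $w_{\mathrm{target}}^{(k)} = \beta^k w_{\mathrm{target}}^{(0)} + (1 - \beta^k)\, w_{\mathrm{online}}$. The target therefore forgets its initial value with a time constant of roughly $1/(1-\beta)$ steps, i.e. about 100 steps for $\beta = 0.99$. A scalar sketch in plain Python checking that closed form (the variable names are illustrative):

```python
beta = 0.99
w_target, w_online = 0.0, 1.0   # hypothetical scalar weights; online weight held fixed

for step in range(100):
    # same rule as the EMA update: w_t <- beta * w_t + (1 - beta) * w_o
    w_target = beta * w_target + (1 - beta) * w_online

closed_form = beta ** 100 * 0.0 + (1 - beta ** 100) * w_online
print(w_target, closed_form)    # both ~0.634: after 100 steps the target covers ~63% of the gap
```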

#### Task: BYOL Training

In the BYOL training loss, the **normalized predictions** of the online network are compared against the **normalized projections** of the target network. See the illustration below. 

![BYOL Diagram](https://raw.githubusercontent.com/lucidrains/byol-pytorch/master/diagram.png)

For an image $x$ and transformations $t, t'$, the loss $\mathcal{L}_{BYOL}(x,t,t')$ is computed as 
$$
\mathcal{L}_{BYOL}(x,t,t') 
= \left\| \frac{q_{\theta}(z^{(1)})}{\|q_{\theta}(z^{(1)})\|_2} - \frac{\text{sg}(z^{(2)})}{\|\text{sg}(z^{(2)})\|_2} \right\|_2^2
= 2 - 2 \cdot \frac{\langle q_{\theta}(z^{(1)}), \; \text{sg}(z^{(2)}) \rangle}{\| q_{\theta}(z^{(1)}) \|_2 \, \| \text{sg}(z^{(2)}) \|_2}
$$
where $z^{(1)}$ is the projection of $t(x)$ under the online (student) network, $z^{(2)}$ is the projection of $t'(x)$ under the target (teacher) network, $q_{\theta}$ is the online predictor, and $\text{sg}(\cdot)$ denotes the stop-gradient operator (no gradients flow into the target network).

The total loss for a batch $\{x_i\}_{i=1}^{N}$ with per-sample transformations $t_i, t_i'$ reads
$$
\mathcal{L}_{BYOL} = \frac{1}{N}\sum_{i=1}^{N} \big[\mathcal{L}_{BYOL}(x_i,t_i,t_i') + \mathcal{L}_{BYOL}(x_i,t_i',t_i)\big].
$$
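The second equality in the per-sample loss is the identity $\|\bar a - \bar b\|_2^2 = 2 - 2\,\bar a^\top \bar b$ for unit vectors $\bar a, \bar b$. The NumPy sketch below checks it numerically and computes the symmetrized batch loss from random stand-ins: `pred_1, pred_2` play the role of online predictions for views 1 and 2, and `proj_1, proj_2` the target projections (all names are hypothetical). NumPy has no autograd, so the stop-gradient is implicit here.

```python
import numpy as np

def normalize(x):
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

def byol_pair_loss(pred, proj):
    """Per-sample 2 - 2 * cosine(pred, proj); proj plays the role of sg(z)."""
    return 2 - 2 * (normalize(pred) * normalize(proj)).sum(axis=-1)

rng = np.random.default_rng(0)
N, d = 4, 8
pred_1, pred_2 = rng.normal(size=(N, d)), rng.normal(size=(N, d))   # online predictions
proj_1, proj_2 = rng.normal(size=(N, d)), rng.normal(size=(N, d))   # target projections

# identity check: squared distance of unit vectors == 2 - 2 * cosine
sq_dist = ((normalize(pred_1) - normalize(proj_2)) ** 2).sum(axis=-1)
print(np.allclose(sq_dist, byol_pair_loss(pred_1, proj_2)))

# symmetrized batch loss: view 1 predicts view 2 and vice versa
loss = (byol_pair_loss(pred_1, proj_2) + byol_pair_loss(pred_2, proj_1)).mean()
print(loss)
```

The training loop in this notebook concatenates both views and averages over all $2N$ rows instead, which yields this quantity divided by 2; the constant factor only rescales the gradients.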

Complete and run the provided training loop below accordingly.\
*Hint: Don't be surprised if with the provided configuration the model takes a bit longer to converge than SimCLR.*

In [None]:
total_epochs = 50

ema = EMA(beta=0.99)

optimizer = torch.optim.AdamW(online_network.parameters(), lr=0.5, weight_decay=1e-6)


all_accuracies = []
all_losses = []
for epoch in range(total_epochs):
    avg_loss = 0.
    online_network.train()
    for x1, x2, y in tqdm(train_loader):
        # TODO
        x1, x2 = x1.to(device), x2.to(device)

        _, _, pred_online = online_network(torch.concat([x1, x2], dim=0))

        with torch.no_grad():
            _, proj_target, _ = target_network(torch.concat([x2, x1], dim=0))
        
        byol_loss = 2 - 2 * (pred_online * proj_target.detach()).sum(dim=-1)
        byol_loss = byol_loss.mean()

        optimizer.zero_grad()
        byol_loss.backward()
        avg_loss += byol_loss.item()
        optimizer.step()

        ema.update(online_network, target_network)

    online_network.eval()
    train_features, train_labels = extract_features_and_labels(online_network, train_loader_noaugment, normalize=True)
    test_features, test_labels = extract_features_and_labels(online_network, test_loader_noaugment, normalize=True)
    acc = run_knn_probe(train_features, train_labels, test_features, test_labels)
    avg_loss = avg_loss / len(train_loader)
    print(f"Epoch {epoch+1}, Average Loss: {avg_loss:.4f}, kNN Accuracy: {acc:.4f}")
    all_accuracies.append(acc)
    all_losses.append(avg_loss)

os.makedirs('./checkpoints', exist_ok=True)
torch.save(online_network.state_dict(), f'./checkpoints/byol_cifar10.pth')
print('Model saved to ./checkpoints/byol_cifar10.pth')

Let's plot the loss and accuracy curves!

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
axes[0].plot(range(1, total_epochs+1), all_losses)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Average BYOL Loss')

axes[1].plot(range(1, total_epochs+1), all_accuracies)
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('k-NN Accuracy')
axes[1].set_ylim(0, 1)

plt.show()

Let's also run a quick linear probe for the final results!

In [None]:
train_features, train_labels = extract_features_and_labels(online_network, train_loader_noaugment, normalize=True)
test_features, test_labels = extract_features_and_labels(online_network, test_loader_noaugment, normalize=True)

acc = run_linear_probe(train_features, train_labels, test_features, test_labels)
print(f'Final Linear Probe Accuracy: {acc*100:.2f}%')
print(f'Final k-NN Accuracy: {all_accuracies[-1]*100:.2f}%')

#### Task: On the Role of the Predictor Head

The asymmetry in the BYOL loss, i.e., matching the prediction from one view to the projection of the other view, is an important feature of the BYOL architecture.  

**Discussion:** What could be the role of this asymmetry?  

*Hint:* Consider what would happen without the predictor. What would be a local minimum of the loss in that case?


In [None]:
# TODO 

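**Solution:** The predictor helps to avoid collapse, where the model maps every input to the same constant representation. Without it, the online and target projections could match trivially by outputting a constant vector, which minimizes the loss while encoding nothing about the image. Collapse is a common issue in self-supervised learning methods that align views without negative examples; in BYOL, the predictor together with the slowly moving EMA target and the stop-gradient helps prevent it in practice.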
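To make the hint concrete: without a predictor, the loss would compare online and target projections directly, and an encoder that maps every image to the same unit vector makes every per-sample loss $2 - 2\cos(\cdot,\cdot)$ exactly zero, regardless of the input. A tiny NumPy sketch of such a collapsed solution (purely illustrative; `collapsed_output` is a hypothetical constant embedding):

```python
import numpy as np

collapsed_output = np.array([1.0, 0.0, 0.0])       # same unit vector for every image
online_proj = np.tile(collapsed_output, (5, 1))    # 5 different images, identical outputs
target_proj = np.tile(collapsed_output, (5, 1))

loss = (2 - 2 * (online_proj * target_proj).sum(axis=-1)).mean()
print(loss)  # 0.0: the minimum, yet the embedding carries no information about the images
```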

## 5. Barlow Twins

Finally, let us implement **Barlow Twins**. Similar to BYOL, it is a self-supervised learning method that aligns representations of different views of the same image without using negative examples.  

However, unlike BYOL, it uses only a single network: both views pass through the same encoder and projector, with no target network and no predictor.


#### Augmentations and Architecture

We provide the transformation as well as a model architecture, which is very similar to that of SimCLR.

In [None]:
class BarlowTwinsTransform:

    def __init__(self, size=32, s=0.5, blur_p=0.5):
        color_jitter = T.ColorJitter(0.8*s, 0.8*s, 0.8*s, 0.2*s)
        k = 3 if size <= 32 else 5
        base = [
            T.RandomResizedCrop(size=size, scale=(0.2, 1.0)),
            T.RandomHorizontalFlip(),
            T.RandomApply([color_jitter], p=0.8),
            T.RandomGrayscale(p=0.2),
            T.RandomApply([T.GaussianBlur(kernel_size=k, sigma=(0.1, 2.0))], p=blur_p),
            T.ToTensor()
        ]
        self.train_transform = T.Compose(base)

    def __call__(self, x):
        return self.train_transform(x), self.train_transform(x)


def barlowtwins_collate_fn(batch):
    xs1, xs2, ys = [], [], []
    for (x1, x2), y in batch:
        xs1.append(x1)
        xs2.append(x2)
        ys.append(y)
    return torch.stack(xs1), torch.stack(xs2), torch.tensor(ys)

In [None]:
class BarlowTwinsModel(nn.Module):

    def __init__(self, proj_dim=128, hidden_dim=2048):
        super().__init__()
        enc = resnet18(weights=None)
        enc.conv1 = nn.Conv2d(3, 64, 3, 1, 1, bias=False)
        enc.maxpool = nn.Identity()
        enc.fc = nn.Identity()
        self.encoder = enc

        self.projector = nn.Sequential(
            nn.Linear(512, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, proj_dim)
        )

    def forward(self, x):
        repr = self.encoder(x)
        proj = self.projector(repr)

        return repr, proj

#### Task: Barlow Twins Loss

Let $z_1, z_2 \in \mathbb{R}^{N\times d}$ be the projected representations of the same images $x$ under different augmentations. The Barlow Twins loss is defined as follows:

Let $\bar{z}_1, \bar{z}_2$ be the batch-normalized projected representations (i.e., zero mean and unit standard deviation along the batch dimension), and let $\mathcal{C}$ be the resulting normalized cross-correlation matrix:  
$$
\mathcal{C} = \frac{\bar{z}_1^T \bar{z}_2}{N}
$$

The Barlow Twins loss is given by:  
$$
\mathcal{L}_{BT} = \sum_{i=1}^{d} \left( 1 - \mathcal{C}_{ii} \right)^2 \;+\; \alpha \sum_{i=1}^{d} \sum_{j \neq i} \mathcal{C}_{ij}^2
$$
where $\alpha$ is a trade-off factor.  

**Questions:**  
1. Hypothesize the role of the two terms in the loss function.  
2. How does this loss avoid collapse to a trivial representation?


In [None]:
# TODO 

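**Solution:**

**Invariance:** $\sum_i (1 - \mathcal{C}_{ii})^2$ aligns the embeddings of the same image under different augmentations by pushing each diagonal entry of the cross-correlation matrix towards 1.  
**Redundancy reduction:** $\alpha \sum_{i} \sum_{j \neq i} \mathcal{C}_{ij}^2$ decorrelates the feature dimensions by pushing the off-diagonal entries towards 0, so that different dimensions encode different information.

Collapse is avoided because $\mathcal{C}$ is computed from batch-standardized features: a constant embedding has zero variance along the batch, so its standardized features, and hence $\mathcal{C}_{ii}$, are zero, which the invariance term penalizes. The redundancy reduction term additionally prevents dimensional collapse, where all dimensions encode the same information.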
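A small NumPy sketch of the cross-correlation step shows why standardization exposes a collapsed embedding: after standardizing along the batch, a constant feature becomes all zeros, so its diagonal entry $\mathcal{C}_{ii}$ is 0 instead of 1. A duplicated feature, in contrast, shows up as an off-diagonal entry close to 1. The helper `cross_correlation` is hypothetical and uses the population variance (`np.var`'s default), matching $\mathcal{C} = \bar z_1^\top \bar z_2 / N$.

```python
import numpy as np

def cross_correlation(z1, z2, eps=1e-5):
    """Standardize each feature along the batch, then C = z1_bar^T z2_bar / N."""
    z1_bar = (z1 - z1.mean(axis=0)) / np.sqrt(z1.var(axis=0) + eps)
    z2_bar = (z2 - z2.mean(axis=0)) / np.sqrt(z2.var(axis=0) + eps)
    return z1_bar.T @ z2_bar / z1.shape[0]

# Collapsed embedding: every image is mapped to the same vector
collapsed = np.ones((8, 3))
print(np.diag(cross_correlation(collapsed, collapsed)))  # [0. 0. 0.] -> invariance term = 3

# Informative but redundant embedding: feature 1 duplicates feature 0
rng = np.random.default_rng(0)
f = rng.normal(size=(8, 1))
redundant = np.hstack([f, f])
print(cross_correlation(redundant, redundant))  # off-diagonal ~1 -> penalized by the redundancy term
```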

Complete the loss function below accordingly.

In [None]:
def off_diagonal(x):
    # returns the off-diagonal entries of a square (d, d) matrix as a flat vector
    n, m = x.shape
    assert n == m
    return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()

def barlow_loss(z1, z2, alpha, eps=1e-5):
    """
    z1: (batch_size, feature_dim) tensor of projection vectors
    z2: (batch_size, feature_dim) tensor of projection vectors
    alpha: redundancy reduction strength (float)
    returns: loss (scalar)
    """
    # TODO 
    B, d = z1.shape
    # standardize each feature along the batch (population variance, as in batch norm)
    z1 = (z1 - z1.mean(dim=0, keepdim=True)) / torch.sqrt(z1.var(dim=0, unbiased=False, keepdim=True) + eps) # (B, d)
    z2 = (z2 - z2.mean(dim=0, keepdim=True)) / torch.sqrt(z2.var(dim=0, unbiased=False, keepdim=True) + eps) # (B, d)

    cross_corr = z1.T @ z2 / B  # (d, d)

    invariance_term = torch.diagonal(cross_corr).add_(-1).pow_(2).sum()  # sum_i (C_ii - 1)^2
    redundancy_term = off_diagonal(cross_corr).pow_(2).sum()             # sum_{i != j} C_ij^2

    loss = invariance_term + alpha * redundancy_term

    return loss
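The `off_diagonal` helper relies on a reshaping trick: for a square $d \times d$ matrix, dropping the last element of the flattened matrix (the last diagonal entry) and viewing the rest as $(d-1) \times (d+1)$ places every remaining diagonal entry in the first column, which is then discarded. The NumPy sketch below checks that trick against a plain boolean mask and uses the mask in a reference version of $\mathcal{L}_{BT}$. The names `off_diagonal_np` and `barlow_loss_reference` are hypothetical, and standardization uses the population variance to match $\mathcal{C} = \bar z_1^\top \bar z_2 / N$.

```python
import numpy as np

def off_diagonal_np(x):
    """Flatten trick: off-diagonal entries of a square matrix, row by row."""
    n, m = x.shape
    assert n == m
    return x.flatten()[:-1].reshape(n - 1, n + 1)[:, 1:].flatten()

def barlow_loss_reference(z1, z2, alpha, eps=1e-5):
    N, d = z1.shape
    z1_bar = (z1 - z1.mean(axis=0)) / np.sqrt(z1.var(axis=0) + eps)
    z2_bar = (z2 - z2.mean(axis=0)) / np.sqrt(z2.var(axis=0) + eps)
    C = z1_bar.T @ z2_bar / N                      # (d, d) cross-correlation
    invariance = ((1 - np.diag(C)) ** 2).sum()     # push C_ii -> 1
    off_mask = ~np.eye(d, dtype=bool)
    redundancy = (C[off_mask] ** 2).sum()          # push C_ij -> 0 for i != j
    return invariance + alpha * redundancy

x = np.arange(9.0).reshape(3, 3)
print(off_diagonal_np(x))  # [1. 2. 3. 5. 6. 7.]

# Two standardized, uncorrelated features: C is (almost) the identity, so the loss is ~0
z = np.array([[1.0, 1.0], [-1.0, 1.0], [1.0, -1.0], [-1.0, -1.0]])
print(barlow_loss_reference(z, z, alpha=0.005))
```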

Let's initialize datasets and model.

In [None]:
barlowtwins_transform = BarlowTwinsTransform(size=32)

train_ds = CIFAR10(root="./data", train=True, download=True, transform=barlowtwins_transform)
test_ds  = CIFAR10(root="./data", train=False, download=True, transform=barlowtwins_transform)

train_loader = DataLoader(train_ds, batch_size=256, shuffle=True, num_workers=10, pin_memory=True, collate_fn=barlowtwins_collate_fn)
test_loader  = DataLoader(test_ds,  batch_size=256, num_workers=10, pin_memory=True, collate_fn=barlowtwins_collate_fn)

train_ds_noaugment = CIFAR10(root="./data", train=True, download=True, transform=default_transform)
test_ds_noaugment  = CIFAR10(root="./data", train=False, download=True, transform=default_transform)

train_loader_noaugment = DataLoader(train_ds_noaugment, batch_size=256, shuffle=False, num_workers=2, pin_memory=True)
test_loader_noaugment  = DataLoader(test_ds_noaugment,  batch_size=256, shuffle=False, num_workers=2, pin_memory=True)

In [None]:
barlow_model = BarlowTwinsModel(proj_dim=128).to(device)

#### Task: Barlow Twins Training

Complete the training pipeline below. We suggest $\alpha = 0.005$.

In [None]:
total_epochs = 50

optimizer = torch.optim.AdamW(barlow_model.parameters(), lr=0.5, weight_decay=1e-6)

all_accuracies = []
all_losses = []
for epoch in range(total_epochs):
    barlow_model.train()
    avg_loss = 0.
    for x1, x2, y in tqdm(train_loader):
        # TODO
        x1, x2 = x1.to(device), x2.to(device)

        _, proj1 = barlow_model(x1)
        _, proj2 = barlow_model(x2)

        loss = barlow_loss(proj1, proj2, alpha=0.005)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        avg_loss += loss.item()

    barlow_model.eval()

    train_features, train_labels = extract_features_and_labels(barlow_model, train_loader_noaugment, normalize=True)
    test_features, test_labels = extract_features_and_labels(barlow_model, test_loader_noaugment, normalize=True)
    acc = run_knn_probe(train_features, train_labels, test_features, test_labels)
    avg_loss = avg_loss / len(train_loader)
    print(f'Epoch {epoch+1}, average loss: {avg_loss:.4f}, kNN accuracy: {acc:.4f}')

    all_accuracies.append(acc)
    all_losses.append(avg_loss)

os.makedirs('./checkpoints', exist_ok=True)
torch.save(barlow_model.state_dict(), f'./checkpoints/barlowtwins_cifar10.pth')
print('Model saved to ./checkpoints/barlowtwins_cifar10.pth')

Let's plot the loss and accuracy curves!

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
axes[0].plot(range(1, total_epochs+1), all_losses)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Average Barlow Twins Loss')

axes[1].plot(range(1, total_epochs+1), all_accuracies)
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('k-NN Accuracy')
axes[1].set_ylim(0, 1)

plt.show()

Let's evaluate the final linear probe accuracy!

In [None]:
train_features, train_labels = extract_features_and_labels(barlow_model, train_loader_noaugment, normalize=True)
test_features, test_labels = extract_features_and_labels(barlow_model, test_loader_noaugment, normalize=True)

acc = run_linear_probe(train_features, train_labels, test_features, test_labels)
print(f'Final Linear Probe Accuracy: {acc*100:.2f}%')
print(f'Final k-NN Accuracy: {all_accuracies[-1]*100:.2f}%')