# Self-supervised Learning
The objective of this lab project is to deepen your understanding of Self-Supervised Learning (SSL). By the end of the notebook, you will:
- Train models using different pretext tasks: colorization, inpainting, and masked reconstruction.
- Fine-tune the models on the downstream task of interest.
- Compare the performance of the backbones obtained from the different pretext tasks.

## Library imports

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import numpy as np
import matplotlib.pyplot as plt

## Pretext and Downstream Tasks

We will train three different models using three different pretext tasks. All three models will be trained on the SVHN dataset. The three models are the following:
- A Colorization Neural Network
- An Inpainting Neural Network
- A Masked Autoencoder

We will build a common architecture so that the three models are as similar as possible. It consists of an encoder and a decoder. Once the models are pre-trained on their respective pretext tasks, we will use the pre-trained encoders to evaluate the learnt representations on two downstream tasks: image classification on MNIST and on SVHN. To that end, we will follow a linear evaluation protocol: the weights of the pre-trained encoders are frozen, and a linear classifier is trained on the learnt representations.

## Data Preparation

In [None]:
batch_size = 128

# SVHN Dataset (Train and Test)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((32, 32)),
])

mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((32, 32)),
])

# Use SVHN train dataset for pre-training
svhn_train = datasets.SVHN(root='../data', split='train', download=True, transform=transform)
svhn_train_loader = DataLoader(svhn_train, batch_size=batch_size, shuffle=True)

# Use SVHN test dataset for monitoring
svhn_test = datasets.SVHN(root='../data', split='test', download=True, transform=transform)
svhn_test_loader = DataLoader(svhn_test, batch_size=batch_size, shuffle=False)

""" # Use SVHN extra dataset for pre-training
svhn_extra = datasets.SVHN(root='../data', split='extra', download=True, transform=transform)
extra_loader = DataLoader(svhn_extra, batch_size=64, shuffle=True) """

# MNIST Dataset (Fine-tuning and classification)
mnist_train = datasets.MNIST(root='../data', train=True, download=True, transform=mnist_transform)
mnist_test = datasets.MNIST(root='../data', train=False, download=True, transform=mnist_transform)
mnist_train_loader = DataLoader(mnist_train, batch_size=64, shuffle=True)
mnist_test_loader = DataLoader(mnist_test, batch_size=64, shuffle=False)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Shared architecture
We will create a shared architecture with the following layers:
- The encoder (in sequential order):
    - A convolution with 64 filters, kernel size 4, stride 2, padding 1
    - A ReLU activation
    - A convolution with 128 filters, kernel size 4, stride 2, padding 1
    - A ReLU activation
    - A convolution with `latent_dim` filters, kernel size 4, stride 2, padding 1
    - A ReLU activation

The encoder should take the number of input channels `in_channels` and the hidden dimension `latent_dim` as arguments.
- The decoder (in sequential order):
    - A transpose convolution with 128 filters, kernel size 4, stride 2, padding 1
    - A ReLU activation
    - A transpose convolution with 64 filters, kernel size 4, stride 2, padding 1
    - A ReLU activation
    - A transpose convolution with `out_channels` filters, kernel size 4, stride 2, padding 1
    - A ReLU activation

The decoder should take the number of output channels `out_channels` and the hidden dimension `latent_dim` as arguments.
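
As a quick sanity check, the spatial sizes produced by these layers can be computed with the standard output-size formulas for convolutions and transpose convolutions. The sketch below is plain Python. The helper names `conv_out` and `conv_transpose_out` are illustrative, and it assumes 32x32 inputs and stride 2 for every layer, including all decoder layers.

```python
def conv_out(size, kernel=4, stride=2, padding=1):
    """Output size of a standard convolution along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose_out(size, kernel=4, stride=2, padding=1):
    """Output size of a transpose convolution (no output_padding, dilation 1)."""
    return (size - 1) * stride - 2 * padding + kernel


# Encoder: three stride-2 convolutions applied to a 32x32 image.
encoder_sizes = [32]
for _ in range(3):
    encoder_sizes.append(conv_out(encoder_sizes[-1]))

# Decoder: three transpose convolutions (stride 2 is assumed for all of them).
decoder_sizes = [encoder_sizes[-1]]
for _ in range(3):
    decoder_sizes.append(conv_transpose_out(decoder_sizes[-1]))

print(encoder_sizes)  # [32, 16, 8, 4]
print(decoder_sizes)  # [4, 8, 16, 32]
```

Under these assumptions the encoder output is a `latent_dim` x 4 x 4 feature map, and the decoder brings it back to the 32x32 resolution of the input images.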

In [None]:
# TODO: Define the encoder architecture
class Encoder(nn.Module):


# TODO: Define the decoder architecture
class Decoder(nn.Module):

In [4]:
# %load solutions/encoder_decoder.py

## Trainer function

In [5]:
def train_ssl_model(model, 
                    train_loader, 
                    test_loader, 
                    criterion,
                    optimizer,
                    device=device,
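                    epochs=5):
    """Pre-train `model` on its pretext task and return the trained encoder.

    `model(images)` must return a tuple whose first element is the reconstruction;
    the loss compares it with the original (unperturbed) images.
    """
    model.to(device)  # Move the model to the device once, before training

    for epoch in range(epochs):
        model.train()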
        total_train_loss = 0
        for images, _ in train_loader:
            images = images.to(device)
            optimizer.zero_grad()
            output, _ = model(images)
            loss = criterion(output, images)
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()

        avg_train_loss = total_train_loss / len(train_loader)

        model.eval()
        total_val_loss = 0
        with torch.no_grad():
            for images, _ in test_loader:
                images = images.to(device)
                output, _ = model(images)
                val_loss = criterion(output, images)
                total_val_loss += val_loss.item()

        avg_val_loss = total_val_loss / len(test_loader)

        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {avg_train_loss:.4f}, Avg Val Loss: {avg_val_loss:.4f}")
    
    return model.encoder

## Pretext 1: Colorization

**Questions.** 
1. What are the values of `in_channels` and `out_channels` for the Colorization model?
2. Given the structure of the `train_ssl_model` function above, where should the conversion from the original images to the grayscale ones happen?
3. What loss should we use for the training?
4. What optimizer would you choose for the training?
5. Are there other hyper-parameters that we need to choose?
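
As a hint for question 2: `train_ssl_model` passes the original RGB images to the model and compares the output with those same images, so one option is to perform the grayscale conversion inside the model's `forward` method. The sketch below shows one possible conversion on a batch of tensors. The function name `rgb_to_gray` and the choice of ITU-R BT.601 luma weights are assumptions made for illustration, not part of the lab's solution.

```python
import torch


def rgb_to_gray(images):
    """Convert a batch of RGB images (N, 3, H, W) to grayscale (N, 1, H, W)."""
    # ITU-R BT.601 luma weights: one common choice, assumed here for illustration.
    weights = torch.tensor([0.299, 0.587, 0.114],
                           device=images.device, dtype=images.dtype)
    # Weighted sum over the channel dimension, keeping a single channel.
    return (images * weights.view(1, 3, 1, 1)).sum(dim=1, keepdim=True)


batch = torch.rand(8, 3, 32, 32)  # stand-in for a batch of SVHN images
gray = rgb_to_gray(batch)         # single-channel input for the encoder
print(batch.shape, gray.shape)
```

`torchvision.transforms.Grayscale`, which this notebook uses later for the SVHN downstream task, is an alternative that also accepts batched tensors.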

In [23]:
# TODO: Define the colorization model
class ColorizationModel(nn.Module):

In [6]:
# %load solutions/colorization_model.py

### Colorization Training

In [None]:
# TODO: Pre-train the encoder using SVHN dataset
colorization_model = # TODO: Instantiate the colorization model
colorization_encoder = # TODO: Pre-train the encoder using SVHN dataset

In [None]:
# %load solutions/pretrain_colorization.py

### Visualize Colorization

In [None]:
import random

# Visualize colorization on random test images

def visualize_reconstructions(model, data_loader, device, num_images=5):
    # Set the model to evaluation mode
    model.eval()

    # Index the underlying dataset directly instead of materializing it as a list
    dataset = data_loader.dataset

    # Randomly select `num_images` images from the dataset
    random_indices = random.sample(range(len(dataset)), num_images)
    random_images = [dataset[i][0] for i in random_indices]  # Extract only the images, ignoring labels

    # Stack the images into a batch
    images = torch.stack(random_images)

    # Move images to the specified device
    images = images.to(device)
    
    # Run the original images through the model, which perturbs them internally
    # and returns (reconstruction, perturbed input)
    with torch.no_grad():
        reconstructed_images, perturbed_images = model(images)
    
    # Move images back to CPU for visualization; clamp reconstructions to the
    # [0, 1] range that imshow expects for float images
    images = images.cpu()
    reconstructed_images = reconstructed_images.cpu().clamp(0, 1)
    perturbed_images = perturbed_images.cpu()
    
    # Plot the perturbed input, the ground truth, and the reconstruction
    # (squeeze=False keeps `axes` two-dimensional even when num_images == 1)
    fig, axes = plt.subplots(num_images, 3, figsize=(10, num_images * 4), squeeze=False)
    for i in range(num_images):
        # Perturbed input (grayscale or masked, depending on the pretext task)
        axes[i, 0].imshow(perturbed_images[i].permute(1, 2, 0).squeeze(), cmap='gray')
        axes[i, 0].set_title("Perturbed Input")
        axes[i, 0].axis('off')
        
        # Ground truth (original RGB image)
        axes[i, 1].imshow(images[i].permute(1, 2, 0))
        axes[i, 1].set_title("Ground Truth (RGB)")
        axes[i, 1].axis('off')
        
        # Reconstruction produced by the model
        axes[i, 2].imshow(reconstructed_images[i].permute(1, 2, 0))
        axes[i, 2].set_title("Reconstruction")
        axes[i, 2].axis('off')
    
    plt.tight_layout()
    plt.show()

# Visualize colorization on random test images
visualize_reconstructions(colorization_model, svhn_test_loader, device)

## Pretext 2: Inpainting

**Questions.** 
1. What are the values of `in_channels` and `out_channels` for the Inpainting model?
2. Given the structure of the `train_ssl_model` function above, where should the masking of the original images happen?
3. What loss should we use for the training?
4. What optimizer would you choose for the training?
5. Are there other hyper-parameters that we need to choose?

In [24]:
#TODO: Define the inpainting model by adding an apply_mask method inside the class
class InpaintingModel(nn.Module):
    # TODO: Implement the __init__ method

    # TODO: Implement the forward method

    # TODO: Implement the apply_mask method
    def apply_mask(self, x):
        masked_x = x.clone()

        for i in range(masked_x.size(0)): # Loop over the batch size
            ul_x = # TODO: Randomly sample the x coordinate of the upper left corner
            ul_y = # TODO: Randomly sample the y coordinate of the upper left corner
            # TODO: Apply the mask to the image

        return masked_x
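
The skeleton above masks one region per image. A minimal standalone sketch of that idea follows, assuming a square mask filled with zeros. The function name `apply_square_mask` and the default `mask_size=8` are hypothetical choices, not values given by the lab.

```python
import torch


def apply_square_mask(x, mask_size=8):
    """Zero out one randomly placed mask_size x mask_size square per image."""
    masked_x = x.clone()  # leave the original batch untouched
    _, _, height, width = masked_x.shape
    for i in range(masked_x.size(0)):  # one independent square per image
        # Upper-left corner, sampled so that the square stays inside the image.
        ul_y = torch.randint(0, height - mask_size + 1, (1,)).item()
        ul_x = torch.randint(0, width - mask_size + 1, (1,)).item()
        masked_x[i, :, ul_y:ul_y + mask_size, ul_x:ul_x + mask_size] = 0
    return masked_x


batch = torch.ones(4, 3, 32, 32)
masked = apply_square_mask(batch, mask_size=8)
# Each image loses 3 channels * 8 * 8 = 192 values.
print((masked == 0).sum(dim=(1, 2, 3)))
```

The mask size is an extra hyper-parameter: a larger square makes the pretext task harder, since more context has to be inferred.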

In [9]:
# %load solutions/inpainting_model.py

### Inpainting Training

In [None]:
# TODO: Pre-train the encoder using SVHN dataset
inpainting_model = # TODO: Instantiate the inpainting model
inpainting_encoder = # TODO: Pre-train the encoder using SVHN dataset

In [None]:
# %load solutions/pretrain_inpainting.py

### Visualize inpainting

In [None]:
visualize_reconstructions(inpainting_model, svhn_test_loader, device=device, num_images=5)

## Pretext 3: Masked Autoencoder

**Questions.** 
1. What are the values of `in_channels` and `out_channels` for the Masked Autoencoder model?
2. Given the structure of the `train_ssl_model` function above, where should the masking of the original images happen?
3. What loss should we use for the training?
4. What optimizer would you choose for the training?
5. Are there other hyper-parameters that we need to choose?

In [149]:
#TODO: Define the MAE model by adding an apply_mask method inside the class
class MaskedAutoencoderModel(nn.Module):
    # TODO: Implement the __init__ method

    # TODO: Implement the forward method
    
    # TODO: Implement the apply_mask method
    def apply_mask(self, x):
        x_masked = x.clone()
        mask = # TODO create a random mask with the right average number of pixels masked
        x_masked[mask] = 0
        return x_masked
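
One way to build such a random mask is to draw an independent uniform number per pixel and mask the pixels whose draw falls below the target ratio. The sketch below is a standalone illustration under stated assumptions: the function name `apply_random_mask`, the default `mask_ratio=0.5`, and the choice to share the mask across colour channels are all hypothetical.

```python
import torch


def apply_random_mask(x, mask_ratio=0.5):
    """Zero out each pixel location independently with probability mask_ratio."""
    n, _, h, w = x.shape
    # One uniform draw per pixel location; the mask is shared across channels.
    mask = torch.rand(n, 1, h, w, device=x.device) < mask_ratio
    # masked_fill broadcasts the mask over the channel dimension and returns a copy.
    return x.masked_fill(mask, 0.0)


torch.manual_seed(0)
batch = torch.ones(16, 3, 32, 32)
masked = apply_random_mask(batch, mask_ratio=0.5)
# The masked fraction is close to mask_ratio on average, not exactly equal to it.
print((masked == 0).float().mean().item())
```

Because each pixel is drawn independently, the number of masked pixels varies from image to image; only its average matches `mask_ratio`.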

In [12]:
# %load solutions/mae_model.py

### MAE Training

In [None]:
# TODO: Pre-train the encoder using SVHN dataset
mae_model = # TODO: Instantiate the MAE model
mae_encoder = # TODO: Pre-train the encoder using SVHN dataset

In [None]:
# %load solutions/pretrain_mae.py

### MAE Visualization

In [None]:
visualize_reconstructions(mae_model, svhn_test_loader, device=device, num_images=5)

## Downstream Task 1: Classification on MNIST

In [17]:
class Classifier(nn.Module):
    def __init__(self, encoder, num_classes=10):
        super(Classifier, self).__init__()
        self.encoder = encoder
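        # The head expects 128 * 2 * 2 = 512 flattened features, so the
        # encoder output must flatten to exactly that size
        self.fc = nn.Linear(128 * 2 * 2, num_classes)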

    def forward(self, x):
        z = self.encoder(x)
        z = z.view(z.size(0), -1)
        return self.fc(z)

# Linear evaluation loop: the encoder is frozen and only the linear head is trained
def fine_tune_classification(encoder, train_loader, test_loader, epochs=5, encoder_in_channels=3):
    model = Classifier(encoder).to(device)

    # Freeze the encoder's weights before building the optimizer
    for param in model.encoder.parameters():
        param.requires_grad = False

    # Optimize only the parameters of the linear head
    optimizer = optim.Adam(model.fc.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        correct = 0
        total = 0
        for images, labels in train_loader:
            if encoder_in_channels == 3:
                images = torch.cat((images, images, images), dim=1)
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            output = model(images)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            # Compute accuracy
            _, predicted = torch.max(output, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

        print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(train_loader):.4f}, Accuracy: {100 * correct / total:.2f}%")

    # Evaluate on test set
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            if encoder_in_channels == 3:
                images = torch.cat((images, images, images), dim=1)
            images = images.to(device)
            labels = labels.to(device)
            output = model(images)
            _, predicted = torch.max(output, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f'Test Accuracy: {100 * correct / total:.2f}%')
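
The hard-coded `128 * 2 * 2` input size of the linear head only matches encoders whose flattened output has exactly 512 features. A possible alternative, sketched below, is to infer the size with a dummy forward pass. The helper `infer_feature_dim` and the stand-in encoder with `latent_dim=32` are hypothetical and only serve to make the example runnable.

```python
import torch
import torch.nn as nn


def infer_feature_dim(encoder, in_channels=3, image_size=32):
    """Return the flattened feature size the encoder produces for one image."""
    was_training = encoder.training
    encoder.eval()
    with torch.no_grad():
        device = next(encoder.parameters()).device
        dummy = torch.zeros(1, in_channels, image_size, image_size, device=device)
        feature_dim = encoder(dummy).flatten(1).size(1)
    encoder.train(was_training)  # restore the previous train/eval mode
    return feature_dim


# Hypothetical stand-in encoder: three stride-2 convolutions, latent_dim=32.
toy_encoder = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1), nn.ReLU(),
    nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1), nn.ReLU(),
    nn.Conv2d(128, 32, kernel_size=4, stride=2, padding=1), nn.ReLU(),
)

feature_dim = infer_feature_dim(toy_encoder)  # 32 channels * 4 * 4
head = nn.Linear(feature_dim, 10)
print(feature_dim)
```

With this stand-in encoder the inferred size happens to be 512, the same value as `128 * 2 * 2`; a different `latent_dim` would require a different head size.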

In [None]:
# 1. Fine-tune Colorization model
fine_tune_classification(colorization_encoder, mnist_train_loader, mnist_test_loader, encoder_in_channels=1)

# 2. Fine-tune Inpainting model
fine_tune_classification(inpainting_encoder, mnist_train_loader, mnist_test_loader)

# 3. Fine-tune Masked Autoencoder model
fine_tune_classification(mae_encoder, mnist_train_loader, mnist_test_loader)

## Downstream Task 2: Classification on SVHN

In [19]:
# Reuses the `Classifier` class defined in Downstream Task 1

# Linear evaluation loop: the encoder is frozen and only the linear head is trained
def fine_tune_svhn(encoder, train_loader, test_loader, epochs=5, encoder_in_channels=3):
    model = Classifier(encoder).to(device)

    # Freeze the encoder's weights before building the optimizer
    for param in model.encoder.parameters():
        param.requires_grad = False

    # Optimize only the parameters of the linear head
    optimizer = optim.Adam(model.fc.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        correct = 0
        total = 0
        for images, labels in train_loader:
            if encoder_in_channels == 1:
                images = transforms.Grayscale()(images)  # Convert RGB to Grayscale
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            output = model(images)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            # Compute accuracy
            _, predicted = torch.max(output, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

        print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(train_loader):.4f}, Accuracy: {100 * correct / total:.2f}%")

    # Evaluate on test set
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            if encoder_in_channels == 1:
                images = transforms.Grayscale()(images)  # Convert RGB to Grayscale
            images = images.to(device)
            labels = labels.to(device)
            output = model(images)
            _, predicted = torch.max(output, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f'Test Accuracy: {100 * correct / total:.2f}%')

In [None]:
# 1. Fine-tune Colorization model
fine_tune_svhn(colorization_encoder, svhn_train_loader, svhn_test_loader, encoder_in_channels=1)

# 2. Fine-tune Inpainting model
fine_tune_svhn(inpainting_encoder, svhn_train_loader, svhn_test_loader)

# 3. Fine-tune Masked Autoencoder model
fine_tune_svhn(mae_encoder, svhn_train_loader, svhn_test_loader)