Prepared by Tymoteusz Kwieciński and Vladimir Zaigrajew

Adapted from materials from HHU Deep Learning, SS2022/23, 21.04.2023, Prof. Dr. Markus Kollmann


# Exercise Week 3 - Self-supervised learning with rotation prediction


## Introduction to rotation prediction

As we learned from the Week 2 lecture, self-supervised learning (SSL) is a powerful approach to learn representations from unlabeled data. Traditional methods such as PCA, t-SNE, or UMAP are often limited in their ability to capture complex relationships in high-dimensional data. SSL methods via deep learning, on the other hand, can learn rich representations by leveraging the structure of the data itself. After training, these representations can be used for various downstream tasks such as classification, segmentation, or object detection.

In today's task you will train such a model on **rotation prediction**, which provides a simple yet effective way to learn rich representations from unlabeled image data. The basic idea is that the network is trained to predict by which angle (e.g., 0°, 90°, 180°, or 270°) a given image has been rotated.

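To solve this task, the network has to recognize the objects in the image and their usual orientation, so it is pushed to learn semantic features such as shapes and object parts. These features can be very useful for downstream tasks such as object recognition or image classification.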

Rotation prediction is also a relatively simple task that can be applied to large amounts of unlabeled data, which makes it a great candidate for introducing self-supervised learning in practice.

### Exercise description

In this exercise, we will verify how a simple SSL framework trained for rotation prediction can be applied to training a model for a downstream task. We will use the German Traffic Sign Recognition Benchmark (GTSRB), which consists of more than 50,000 images in 43 classes. This task will focus on the classification of traffic signs, which is a common problem in computer vision and has many practical applications, such as autonomous driving or advanced driver assistance systems.

In our case, we have a fully labeled dataset, but let us assume that it is only partly labeled. First, we will split the dataset into two parts. The first, larger part will be used for the SSL framework. The second part, with a limited number of labels, will be used for the classification. In this way, we mimic the situation where only a limited amount of labeled data is available and act as if the larger part of the dataset were unlabeled. One of the benefits of SSL is that it can improve the performance of models trained on limited labeled data by leveraging the information contained in the unlabeled data, so let's see if this is the case.

On the first dataset, we will train a ResNet18 network to predict the artificial rotation of the image. Then, we will replace the classification head of this network for a different task, label classification, and fine-tune it on the smaller dataset. A second model will be trained from scratch on the smaller dataset only. In the end, we will visually compare the features learned by the models.

Related paper: [Unsupervised Representation Learning by predicting Image Rotations](https://arxiv.org/pdf/1803.07728.pdf)

## Part I. Preparation and imports
First, prepare an environment by installing all the required libraries. In the previous exercises I asked about the reproducibility of results and mentioned that, to achieve it, the packages we use should be pinned to specific versions. This time I want you to do it: provide a cell with `pip install` commands that pin every package to the version you used when running your code, for example:
```bash
pip install vladimir_is_awesome==1.0.0
``` 

In [None]:
# %pip install your_packages

In [None]:
import os
import time
import random
import numpy as np
from tqdm import tqdm
from copy import deepcopy
from sklearn.model_selection import train_test_split

import torch
import torch.nn.functional as F

import torchvision
from torchvision import transforms as T

import matplotlib.pyplot as plt

Now, the most important part for the reproducibility of results - the random seed. I hope that you already understand why we need to do it. If not, you can read this [reddit post](https://www.reddit.com/r/learnpython/comments/s678b0/explain_randomseed_like_im_five/) or just search the internet. The code below will set the random seed for all libraries that we will use in this exercise.
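Before that cell, here is a minimal sketch of what seeding does. The helper `draw` is a hypothetical name used only for this illustration; it reseeds Python's `random` module and NumPy and then draws a few numbers, so the same seed should reproduce the same values.

```python
import random

import numpy as np


def draw(seed: int):
    """Reseed both generators and draw a few numbers (illustrative helper)."""
    random.seed(seed)
    np.random.seed(seed)
    return random.random(), np.random.rand(3).tolist()


first = draw(42)
second = draw(42)   # same seed, so the same numbers are drawn again
other = draw(7)     # different seed, so different numbers

print(first == second, first == other)
```

The same idea applies to `torch`, which keeps its own generators and therefore needs its own seeding calls.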

In [None]:
# seed everything for reproducibility
SEED = 42

def seed_everything(seed: int=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_everything(SEED)

## Part II. Preparing the data

Now, we need to prepare the datasets. In this task we will mimic the situation where we have a lot of unlabeled data and we need to label it manually. After weeks, we label some proportion of the data. As we learned, SSL methods enable learning from unlabeled data, so we will use the larger part of the dataset that is still unlabeled for training the SSL model. The second part of the dataset will be used for the classification task. We will check if the model trained on the unlabeled data and fine-tuned on labeled data will perform better than the one trained on the labeled data only.

We will use the GTSRB dataset, a large dataset of color traffic sign images. The official benchmark contains 43 classes and more than 50,000 images in total, split into 39,209 training images and 12,630 test images; the copy downloaded through `torchvision` may contain a different number of training images, so rely on the sizes printed by the code below. We will resize the images to 224x224 pixels. We will use **60%** of the training set for the SSL task and **40%** for the classification task (see `SSL_SIZE` in the code below). The test set will be used for the evaluation of the models.
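The split is stratified, meaning that both parts keep the class proportions of the full training set. A minimal sketch on toy labels, assuming three made-up classes with 50, 30, and 20 samples (the names `toy_targets`, `ssl_idx`, and `clf_idx` are illustrative only):

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Toy labels: 50 samples of class 0, 30 of class 1, 20 of class 2.
toy_targets = np.array([0] * 50 + [1] * 30 + [2] * 20)

# Split the sample indices, keeping class proportions in both parts.
ssl_idx, clf_idx = train_test_split(
    np.arange(len(toy_targets)),
    test_size=0.4,
    random_state=42,
    stratify=toy_targets,
)

ssl_counts = np.bincount(toy_targets[ssl_idx]).tolist()
clf_counts = np.bincount(toy_targets[clf_idx]).tolist()
print(ssl_counts, clf_counts)
```

Without `stratify`, rare classes could end up almost entirely in one of the two parts, which would make the comparison between the models less fair.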

To implement this, we will first download the dataset, preprocess the images, and then create the SSL and classification datasets. We will use the `torchvision` library to download and preprocess the dataset. The `torchvision` library provides a convenient way to download and preprocess datasets, and it also provides a number of useful utilities for working with images.

If you want to experiment more, we encourage you to try out different datasets, such as STL10 (the resolution of this dataset is higher and your visualizations may be more exciting, but you would need to modify the dataset splits).

In [None]:
# define the transform to apply to the data - resize to 224x224 and convert to tensor.
transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])

# Load the datasets
path_where_data_is_stored = '../data' # change this to the path where you want to store the data

### START CODE HERE ### (≈ 2 lines) Load the train and test splits of the `GTSRB` dataset from `torchvision.datasets`.
# Store them in `full_dataset` (train split) and `dataset_test` (test split); both names are used below.
# Provide the path where you want to store the data (`path_where_data_is_stored`) and set `download=True` to download the dataset if it is not already present.
# Also, set the `transform` to the one defined above.
### END CODE HERE ###

# Let's now split the dataset into an SSL dataset and a classification dataset
SSL_SIZE = 0.6 # fraction of the training set used for the SSL (rotation) task; the rest is used for classification
targets = np.array([y for _, y in full_dataset])
SSL_indices, classification_indices = train_test_split(
    np.arange(len(targets)),
    test_size=1-SSL_SIZE,
    random_state=SEED,
    stratify=targets
)
SSL_dataset = torch.utils.data.Subset(full_dataset, SSL_indices)
classification_dataset = torch.utils.data.Subset(full_dataset, classification_indices)

print(f"Train full dataset size: {len(full_dataset)}")
print(f"Train SSL dataset size: {len(SSL_dataset)}")
print(f"Train classification dataset size: {len(classification_dataset)}")
print(f"Test dataset size: {len(dataset_test)}")

In [None]:
map_idx_to_class = [
    "Speed limit (20km/h)",
    "Speed limit (30km/h)",
    "Speed limit (50km/h)",
    "Speed limit (60km/h)",
    "Speed limit (70km/h)",
    "Speed limit (80km/h)",
    "End of speed limit (80km/h)",
    "Speed limit (100km/h)",
    "Speed limit (120km/h)",
    "No passing",
    "No passing for vehicles over 3.5 metric tons",
    "Right-of-way at the next intersection",
    "Priority road",
    "Yield",
    "Stop",
    "No vehicles",
    "Vehicles over 3.5 metric tons prohibited",
    "No entry",
    "General caution",
    "Dangerous curve to the left",
    "Dangerous curve to the right",
    "Double curve",
    "Bumpy road",
    "Slippery road",
    "Road narrows on the right",
    "Road work",
    "Traffic signals",
    "Pedestrians",
    "Children crossing",
    "Bicycles crossing",
    "Beware of ice/snow",
    "Wild animals crossing",
    "End of all speed and passing limits",
    "Turn right ahead",
    "Turn left ahead",
    "Ahead only",
    "Go straight or right",
    "Go straight or left",
    "Keep right",
    "Keep left",
    "Roundabout mandatory",
    "End of no passing",
    "End of no passing by vehicles over 3.5 metric tons"
]
angles = [0, 90, 180, 270]

In [None]:
image, target = next(iter(SSL_dataset))
image.shape, target, len(map_idx_to_class), len(angles)

At this point we have the original `GTSRB` dataset containing images and labels. We divided the training set into two parts: `classification_dataset` and `SSL_dataset`. The first one will be used for the classification task, while the second one will be used for the SSL task. In both cases, we will use the test set `dataset_test` for validation. For the classification task, the original `GTSRB` is already prepared, but for the rotation task we need to prepare the dataset with random rotations. We prepared the class `SSLRot` which you need to finish. The class should load the images, discard the label, and randomly rotate the image and record the rotation. The class should return a tuple of `(img, rotation_class_id)`, where `rotation_class_id` is an integer from 0 to 3, representing the rotation angle (0°, 90°, 180°, or 270°).
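The rotation step itself can be sketched on a toy tensor before you fill in the class. This is only one possible approach, assuming `torch.rot90` is used for the rotation; the helper `rotate_by_label` and the toy image are hypothetical names for this illustration.

```python
import random

import torch


def rotate_by_label(img: torch.Tensor, k: int) -> torch.Tensor:
    """Rotate a (C, H, W) image by k * 90 degrees in the (H, W) plane."""
    return torch.rot90(img, k=k, dims=(1, 2))


# Toy "image" with 1 channel, height 3, width 4.
toy_img = torch.arange(12, dtype=torch.float32).reshape(1, 3, 4)

# All four rotations; height and width swap for odd k.
rotated = [rotate_by_label(toy_img, k) for k in range(4)]
shapes = [tuple(r.shape) for r in rotated]

# One random training sample: (rotated image, rotation class id).
label = random.randrange(4)
sample = (rotate_by_label(toy_img, label), label)
print(shapes, sample[1])
```

`torch.rot90` rotates counter-clockwise for positive `k`; the direction does not matter for the task as long as the same convention is used for training and evaluation. Since the images here are resized to a square, the rotated image keeps the same shape.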

In [None]:
# Rotation Dataset
class SSLRot(torch.utils.data.Dataset):
    def __init__(self, dataset: torch.utils.data.Dataset, angles: list[int]):
        super(SSLRot, self).__init__()
        self.original_dataset = dataset
        self.angles = angles

    def __len__(self):
        return len(self.original_dataset)
    
    def rand_rotate(self, img: torch.Tensor) -> tuple[torch.Tensor, int]:
        """
        Randomly rotates the image by 0, 90, 180, or 270 degrees.

        Args:
            img (torch.Tensor): Input image tensor of shape (C, H, W).

        Returns:
            tuple: Rotated image tensor and the corresponding rotation label (0, 1, 2, or 3).
        """
        ### START CODE HERE ###
        
        
        
        ### END CODE HERE ###

    def __getitem__(self, idx):
        ### START CODE HERE ###
        # Get the image from the original dataset, ignore its label (second element),
        # and use `self.rand_rotate` to obtain `rotated_img` and `rot_label`
        
        ### END CODE HERE ###
        return rotated_img, torch.tensor(rot_label, dtype=torch.long)

Let's visualize what the dataset looks like. If you get all 4 images rotated by 0 degrees, just rerun the cell. The random seed is fixed, but every access to the dataset draws a new random rotation, so rerunning the cell should give different rotations. If you always get the same result, something is wrong with your code.

In [None]:
rotation_dataset_train = SSLRot(SSL_dataset, angles)
rotation_dataset_test = SSLRot(dataset_test, angles)

fig, ax = plt.subplots(2, 2, figsize=(6, 4))
img, rot = rotation_dataset_train[0]
ax[0, 0].imshow(img.permute(1, 2, 0))
ax[0, 0].set_title(f"Image rotation: {angles[rot]} degrees")

img, rot = rotation_dataset_train[1]
ax[0, 1].imshow(img.permute(1, 2, 0))
ax[0, 1].set_title(f"Image rotation: {angles[rot]} degrees")

img, rot = rotation_dataset_test[2]
ax[1, 0].imshow(img.permute(1, 2, 0))
ax[1, 0].set_title(f"Image rotation: {angles[rot]} degrees")

img, rot = rotation_dataset_test[3]
ax[1, 1].imshow(img.permute(1, 2, 0))
ax[1, 1].set_title(f"Image rotation: {angles[rot]} degrees")
plt.tight_layout()
plt.show()

## Part III. Load and modify ResNet18

Now, we need to get the model. In the previous exercise we used the `ResNet18` model from `torchvision`. This time we will use the same model, but we will modify it to predict the rotation of the image and later to predict classes from the original dataset. The `ResNet18` model is a convolutional neural network (CNN) that is widely used for image classification tasks. It consists of 18 layers and is known for its ability to learn rich representations from images.

**Exercise:** Load and modify the ResNet18 architecture so that you can use it with our data to predict rotation angles.
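The head-replacement pattern can be sketched on a toy network. `TinyNet` below is a hypothetical stand-in, not the ResNet18 from `torchvision`; it only mimics the fact that the final layer is stored in an attribute called `fc`.

```python
import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for a CNN whose final layer is called `fc`."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        self.fc = nn.Linear(8, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(self.features(x))


toy_model = TinyNet(num_classes=10)

# Swap the head: keep the input width, change the number of outputs to 4.
toy_model.fc = nn.Linear(toy_model.fc.in_features, 4)

out = toy_model(torch.randn(2, 3, 32, 32))
out_shape = tuple(out.shape)
print(out_shape)
```

Reading `in_features` from the old layer avoids hard-coding the feature width, so the same line works for backbones of different sizes.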

In [None]:
def load_resnet_rotation(number_of_classes=4):
    ### START CODE HERE ### (≈ 2 lines)
    # Load ResNet18 without pre-trained weights from `torchvision.models`.
    # Modify the final fully connected layer to output `number_of_classes` classes (4 rotation labels by default).
    # If you don't know how to modify the last layer, in the previous exercise tutorial there was a part where the authors modified the last layer of ResNet18 to output 2 classes.
    ### END CODE HERE ###
    return model

def test_load_resnet_rotation():
    model = load_resnet_rotation()
    x, y = rotation_dataset_train[0]
    x = x.unsqueeze(0)  # Add a batch dimension
    y = y.unsqueeze(0)  # Add a batch dimension
    
    with torch.no_grad():
        pred_y = model(x)
        
    loss = F.cross_entropy(pred_y, y)
    pred_y_class = torch.argmax(pred_y, dim=1)
    print(f"Input shape: {x.shape}, Model output: {pred_y.shape}, Model predicted {pred_y_class}, Ground truth: {y}, Loss: {loss.item()}")

test_load_resnet_rotation()

You should see an input shape like `torch.Size([1, 3, 224, 224])` and an output shape like `torch.Size([1, 4])`. The first is the shape of the input image (with a batch dimension) and the second is the shape of the model output, whose last dimension should equal the number of classes. In our case, we have 4 classes (0°, 90°, 180°, or 270°). The predicted class is most likely wrong, because the model has not been trained yet and is effectively guessing.
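As a reference point for the printed loss: a model that assigns equal scores to all 4 classes has a cross-entropy loss of $\ln 4 \approx 1.386$. The short check below uses hand-made logits (not the ResNet18 output) to show this; an untrained network usually starts somewhere near this value, though not exactly at it.

```python
import math

import torch
import torch.nn.functional as F

# Equal scores for each of the 4 rotation classes (hand-made, not a model output).
logits = torch.zeros(1, 4)
target = torch.tensor([2])  # any class index gives the same loss here

uniform_loss = F.cross_entropy(logits, target).item()
chance_loss = math.log(4)

print(uniform_loss, chance_loss)
```

If the training loss stays close to this value after several epochs, the model is not doing better than guessing.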

## Part IV. Launch training!

Let's now train the model. We will use the `SSLRot` dataset for training. The model should be trained for `n` epochs. We provide the functions for training one epoch and for validation; your task is to fill in the top-level `train` function that runs the whole training.

We do not expect the model to achieve perfect accuracy.

We encourage you to save the checkpoints of models with different parameters, and later decide to use the best one.
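A minimal sketch of saving and restoring a checkpoint together with its training history, shown on a toy linear layer rather than the ResNet18. The file name `toy_checkpoint.pt` and the dictionary keys are arbitrary choices for this illustration, and the accuracy values are placeholders, not real results.

```python
import os
import tempfile

import torch
from torch import nn

toy = nn.Linear(4, 2)
# Placeholder history values, only to show the structure.
history = {"train_accs": [50.0, 75.0], "val_accs": [48.0, 70.0]}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy_checkpoint.pt")
    # Save the weights (state dict) and the history in one file.
    torch.save({"model_state": toy.state_dict(), "history": history}, path)
    checkpoint = torch.load(path, map_location="cpu")

# Restore into a freshly created model with the same architecture.
restored = nn.Linear(4, 2)
restored.load_state_dict(checkpoint["model_state"])
print(checkpoint["history"])
```

Saving the state dict instead of the whole model object keeps the file independent of the exact class definition, at the cost of having to rebuild the architecture before loading.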

In [None]:
def train_one_epoch(model: torch.nn.Module, optimizer: torch.optim.Optimizer, train_loader: torch.utils.data.DataLoader, criterion: torch.nn.Module, device: torch.device) -> tuple[float, float]:
    """Train the model for one epoch.
    Args:
        model (torch.nn.Module): The model to train.
        optimizer (torch.optim.Optimizer): The optimizer to use.
        train_loader (torch.utils.data.DataLoader): The training data loader.
        criterion (torch.nn.Module): The loss function.
        device (torch.device): The device to use for training (CPU or GPU).
    Returns:
        tuple: The average loss and accuracy for the epoch.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    
    for inputs, targets in tqdm(train_loader, desc="Training"):
        inputs, targets = inputs.to(device), targets.to(device)
        
        # Zero the parameter gradients
        optimizer.zero_grad()
        
        # Forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        
        # Statistics
        total_loss += loss.item() * inputs.size(0)
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
    
    train_loss = total_loss / total
    train_acc = 100.0 * correct / total
    
    return train_loss, train_acc

def validate(model: torch.nn.Module, val_loader: torch.utils.data.DataLoader, device: torch.device) -> float:
    """Validate the model.
    Args:
        model (torch.nn.Module): The model to validate.
        val_loader (torch.utils.data.DataLoader): The validation data loader.
        device (torch.device): The device to use for validation (CPU or GPU).
    
    Returns:
        float: The average accuracy for the validation set.    
    """
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for inputs, targets in tqdm(val_loader, desc="Validating"):
            inputs, targets = inputs.to(device), targets.to(device)
            
            # Forward
            outputs = model(inputs)
            
            # Statistics
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    
    val_acc = 100.0 * correct / total
    
    return val_acc

def visualize_predictions(model: torch.nn.Module, dataset: torch.utils.data.Dataset, device: torch.device, class_names: list[str], num_images: int=5):
    """Visualize predictions of the model on a subset of the dataset.
    Args:
        model (torch.nn.Module): The model to use for predictions.
        dataset (torch.utils.data.Dataset): The dataset to visualize.
        device (torch.device): The device to use for predictions (CPU or GPU).
        class_names (list[str]): The list of class names.
        num_images (int): The number of images to visualize.
        
    """
    model.eval()
    indices = random.sample(range(len(dataset)), num_images)
    
    _, axes = plt.subplots(1, num_images, figsize=(15, 5))
    
    with torch.no_grad():
        for i, idx in enumerate(indices):
            img, label = dataset[idx]
            img = img.unsqueeze(0).to(device)
            output = model(img)
            pred_label = output.argmax(dim=1).item()
            
            axes[i].imshow(img.squeeze(0).permute(1, 2, 0).cpu())
            axes[i].set_title(f"Pred: {class_names[pred_label]}\nTrue: {class_names[label]}")
            axes[i].axis('off')
    
    plt.show()

Now your part!

In [None]:
def train(model: torch.nn.Module, train_loader: torch.utils.data.DataLoader, val_loader: torch.utils.data.DataLoader, optimizer: torch.optim.Optimizer, criterion: torch.nn.Module, num_epochs: int=10, device: torch.device=torch.device("cpu")) -> tuple[list[float], list[float]]:
    """Train the model.
    
    Args:
        model (torch.nn.Module): The model to train.
        train_loader (torch.utils.data.DataLoader): The training data loader.
        val_loader (torch.utils.data.DataLoader): The validation data loader.
        optimizer (torch.optim.Optimizer): The optimizer to use.
        criterion (torch.nn.Module): The loss function.
        num_epochs (int): The number of epochs to train for.
        device (torch.device): The device to use for training (CPU or GPU).
    
    Returns:
        tuple: A tuple containing the training accuracies and validation accuracies for each epoch.
    """
    train_accs = []
    val_accs = []
    ### START CODE HERE ###





    ### END CODE HERE ###
    return train_accs, val_accs

Now, we have everything we need to train our model (at least the main meat). Below you have parameters that you can modify for the training. You can change the number of epochs, batch size, learning rate, and other parameters. You can also use different optimizers, such as Adam or SGD. You can also use different learning rate schedulers, such as StepLR or CosineAnnealingLR. We encourage you to experiment with different parameters and see how they affect the training process. 

Increasing the batch size can make each epoch faster, and training for more epochs can improve the results, but be mindful of memory and time usage. Adjust these parameters based on your hardware capabilities and monitor the training performance. That's why we advise using Google Colab with a GPU.
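If you want to try a learning rate scheduler, the sketch below shows how `StepLR` changes the learning rate over epochs. It uses a single dummy parameter and made-up settings (`step_size=2`, `gamma=0.5`) instead of a real model, so it only illustrates the call order, not a recommended configuration.

```python
import torch

# Dummy parameter standing in for model.parameters().
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.SGD(params, lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)

lrs = []
for epoch in range(6):
    lrs.append(optimizer.param_groups[0]["lr"])
    # In real training, the forward pass and loss.backward() come before this.
    optimizer.step()
    # Step the scheduler once per epoch, after the optimizer.
    scheduler.step()

print(lrs)
```

With these settings the learning rate is halved every two epochs. In the training loop of this notebook, the scheduler step would go at the end of each epoch.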

**I ENCOURAGE YOU TO PLAY WITH THESE PARAMETERS FOR EACH MODEL**

In [None]:
### PARAMETERS ###
BATCH_SIZE = 64
NUM_EPOCHS = 4
LEARNING_RATE = 0.001
NUM_WORKERS = 0
NUMBER_OF_CLASSES = 4
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Let's now train the model. The placeholders for saving and loading the model can be ignored if you don't want to try various hyperparameters.

In [None]:
# Get the data loaders
train_dl = torch.utils.data.DataLoader(rotation_dataset_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_dl = torch.utils.data.DataLoader(rotation_dataset_test, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# Load the model
ssl_model = load_resnet_rotation(NUMBER_OF_CLASSES)
ssl_model = ssl_model.to(device)

# Define the optimizer
optimizer = torch.optim.AdamW(ssl_model.parameters(), lr=LEARNING_RATE)

# Define the loss function
criterion = torch.nn.CrossEntropyLoss()

# Train the model
train_accs, val_accs = train(ssl_model, train_dl, val_dl, optimizer, criterion, num_epochs=NUM_EPOCHS, device=device)

### START CODE HERE ### (≈ 1 lines)
# Save the model and the training history
### END CODE HERE ###

In [None]:
# We encourage you to try different hyperparameters and see how the model performs.
# Based on various experiments, you can choose the best model and load it here for the evaluation.

### START CODE HERE ### (≈ 1 lines)
# Load the model and the training history
### END CODE HERE ###

Let's now view the results of our computations by plotting accuracy on train and validation sets.

In [None]:
plt.plot(train_accs, label='train accuracy')
plt.plot(val_accs, label='val accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.title('Training and Validation Accuracy on Rotation Dataset')
plt.legend()
plt.show()

And finally let's view the results by visualizing the images and what the model predicted for them.

In [None]:
visualize_predictions(ssl_model, rotation_dataset_test, num_images=5, device=device, class_names=angles)

At the bottom you will have space to answer what you saw, so don't do it here.

## Part V. Train the empty model on classification task

Now, we will train the original "empty" model on the classification task using the dataset `classification_dataset`. This model has random weights and we will train it from scratch. We will use the same training script as before, but this time we will use the `classification_dataset` dataset for training. In this part we want to check if the amount of labeled data is enough to train the model from scratch.

In [None]:
### PARAMETERS ###
BATCH_SIZE = 64
NUM_EPOCHS = 6
LEARNING_RATE = 0.001
NUM_WORKERS = 0
NUMBER_OF_CLASSES = 43
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

In [None]:
# Get the data loaders
train_dl = torch.utils.data.DataLoader(classification_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_dl = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# Load the model
empty_model = load_resnet_rotation(NUMBER_OF_CLASSES)
empty_model = empty_model.to(device)

# Define the optimizer
optimizer = torch.optim.AdamW(empty_model.parameters(), lr=LEARNING_RATE)

# Define the loss function
criterion = torch.nn.CrossEntropyLoss()

# Train the model
train_accs, val_accs = train(empty_model, train_dl, val_dl, optimizer, criterion, num_epochs=NUM_EPOCHS, device=device)

In [None]:
plt.plot(train_accs, label='train accuracy')
plt.plot(val_accs, label='val accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.title('Training and Validation Accuracy on GTSRB Dataset')
plt.legend()
plt.show()

In [None]:
visualize_predictions(empty_model, dataset_test, num_images=5, device=device, class_names=map_idx_to_class)

## Part VI. Train classifier head on top of the SSL model

Okay, we trained our model from scratch, we got the metrics etc., now our question is: can we improve our classifier's performance? Normally we would do hyperparameter tuning to enhance performance or go and label more data. But we have a model that was trained on unlabeled data, so let's use it as our feature extractor. In this section we will train a classifier head on top of the SSL model and see if we can improve the performance or get any other benefits from using the SSL model.

Your task is to modify the last layer of the SSL model to predict the classes of the dataset (it is now set to 4 classes). In this section we will use the SSL backbone as the feature extractor and train only a classifier head. To do it, you need to freeze the weights of the SSL model and then change the classification head from 4 classes to 43 classes.
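The freeze-then-replace pattern can be sketched on a toy network. The `nn.Sequential` model below is a hypothetical stand-in for the SSL model; the point is only that parameters frozen with `requires_grad = False` stay fixed, while a newly created head is trainable by default.

```python
import torch
from torch import nn

# Toy stand-in: two "backbone" layers followed by a 4-class head.
toy = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 4))

# 1) Freeze every existing parameter.
for param in toy.parameters():
    param.requires_grad = False

# 2) Replace the head; new layers have requires_grad=True by default.
toy[2] = nn.Linear(8, 43)

trainable = [name for name, p in toy.named_parameters() if p.requires_grad]
n_trainable = sum(p.numel() for p in toy.parameters() if p.requires_grad)

# Optionally hand only the trainable parameters to the optimizer.
optimizer = torch.optim.AdamW(
    [p for p in toy.parameters() if p.requires_grad], lr=1e-3
)
print(trainable, n_trainable)
```

The order matters: freezing after replacing the head would freeze the new head as well.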

In [None]:
### PARAMETERS ###
BATCH_SIZE = 64
NUM_EPOCHS = 20
LEARNING_RATE = 0.001
NUM_WORKERS = 0
NUMBER_OF_CLASSES = 43
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

In [None]:
# Get the data loaders
train_dl = torch.utils.data.DataLoader(classification_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_dl = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# Copy the SSL model so that the original stays untouched
model = deepcopy(ssl_model)

### START CODE HERE ### (≈ 3 lines) freeze the whole model and substitute the last layer with a new one that has the number of classes equal to `NUMBER_OF_CLASSES`.



### END CODE HERE ###
model = model.to(device)

# Define the optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

# Define the loss function
criterion = torch.nn.CrossEntropyLoss()

# Train the model
train_accs, val_accs = train(model, train_dl, val_dl, optimizer, criterion, num_epochs=NUM_EPOCHS, device=device)

In [None]:
plt.plot(train_accs, label='train accuracy')
plt.plot(val_accs, label='val accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.title('Training and Validation Accuracy on GTSRB Dataset with trained classifier on SSL model')
plt.legend()
plt.show()

In [None]:
visualize_predictions(model, dataset_test, num_images=5, device=device, class_names=map_idx_to_class)

## Part VII. Finetune the SSL model for the classification task

Finally, to squeeze the last drop of performance from our SSL model, we will not freeze the weights of the SSL model during training. This should give us better performance on the classification task. So do the same as previously but remove the part where you freeze the model weights.

In [None]:
### PARAMETERS ###
BATCH_SIZE = 64
NUM_EPOCHS = 4
LEARNING_RATE = 0.0001
NUM_WORKERS = 0
NUMBER_OF_CLASSES = 43
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

In [None]:
# Get the data loaders
train_dl = torch.utils.data.DataLoader(classification_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_dl = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# Copy the SSL model (this time without freezing its weights)
model_finetune = deepcopy(ssl_model)

### START CODE HERE ### (≈ 1 line) substitute the last layer with a new one that has the number of classes equal to `NUMBER_OF_CLASSES`.

### END CODE HERE ###
model_finetune = model_finetune.to(device)

# Define the optimizer
optimizer = torch.optim.AdamW(model_finetune.parameters(), lr=LEARNING_RATE)

# Define the loss function
criterion = torch.nn.CrossEntropyLoss()

# Train the model
train_accs, val_accs = train(model_finetune, train_dl, val_dl, optimizer, criterion, num_epochs=NUM_EPOCHS, device=device)

In [None]:
plt.plot(train_accs, label='train accuracy')
plt.plot(val_accs, label='val accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.title('Training and Validation Accuracy on GTSRB Dataset fully fine-tuned')
plt.legend()
plt.show()

In [None]:
visualize_predictions(model_finetune, dataset_test, num_images=5, device=device, class_names=map_idx_to_class)

## Conclusion and Bonus reads

After your training, I want you to answer the following questions:
1. Did the model trained on the SSL dataset perform better than the one trained only on the classification dataset?
2. Which training strategy performed better? The one where you trained the classifier head on top of the SSL model or the one where you finetuned the SSL model?
3. Besides the accuracy metric, what other benefits did you get from using the SSL model?
4. What are your conclusions? Did you expect this? If not, why?
5. Was SSL useful in this case?

Now that you have trained the model and answered the questions, I want you to rerun the whole notebook with `SSL_SIZE` changed from 0.6 to 0.8 (from 60% to 80%) and answer the same questions again. The additional question is: What happened when the amount of labeled data decreased?

After you've answered both sets of questions, I want you to write your responses here in this cell with a clear separation between the results for different `SSL_SIZE` values. You will send me the Google Colab link with your notebook and I will check the answers and the code.

If you would like to delve into this topic further, here are some links:

- [Self-supervised learning and computer vision](https://www.fast.ai/posts/2020-01-13-self_supervised.html)
- [Self-Supervised Representation Learning: Introduction, Advances and Challenges](https://arxiv.org/abs/2110.09327)

As an additional part, below is code that, when run, shows how the representation before the classifier head looks when mapped to 2D space via t-SNE. We use `1000` samples from the test set and plot them on a scatter plot for the three models that we used in this notebook: the SSL model, the model trained from scratch, and the fine-tuned SSL model. We also color each point by the class of the image from the classification dataset. I hope that when you look at this plot you will think about which representation would be easier to classify and why. We will explore this question further in next week's exercise.
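The feature extraction relies on a forward hook registered on the last layer, which records the input that layer receives. A minimal sketch of that mechanism on a toy network (the `nn.Sequential` model and the names `captured` and `handle` are illustrative only):

```python
import torch
from torch import nn

toy = nn.Sequential(nn.Linear(6, 3), nn.ReLU(), nn.Linear(3, 2))
captured = []


def hook_fn(module, inputs, output):
    # `inputs` is a tuple; its first element is what the last layer receives.
    captured.append(inputs[0].detach())


# Attach the hook to the last layer, run one forward pass, then detach it.
handle = toy[2].register_forward_hook(hook_fn)
with torch.no_grad():
    toy(torch.randn(5, 6))
handle.remove()

features = captured[0]
print(tuple(features.shape))
```

Removing the hook after use matters: a hook left in place would keep appending to `captured` on every later forward pass.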

In [None]:
# As an additional part of the exercise, you can view the model representations of the images with the help of `TSNE`.

from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler

def get_features(model, batch_x):
    features = []
    
    def hook_fn(module, input, output):
        features.append(input[0].detach().cpu().numpy())
    
    hook = model.fc.register_forward_hook(hook_fn)
    model(batch_x)
    hook.remove()
    
    return features[0].reshape(features[0].shape[0], -1).squeeze()

ssl_model_features = []
empty_model_features = []
finetuned_model_features = []
labels = []

ssl_model.eval()
empty_model.eval()
model_finetune.eval()

num_samples = 1000
sample_indices = np.random.choice(len(dataset_test), min(num_samples, len(dataset_test)), replace=False)


for idx in tqdm(sample_indices, desc="Extracting features"):
    # `image` instead of `input` to avoid shadowing the Python builtin
    image, target = dataset_test[idx]
    image = image.unsqueeze(0).to(device)
    
    with torch.no_grad():
        ssl_model_features.append(get_features(ssl_model, image))
        empty_model_features.append(get_features(empty_model, image))
        finetuned_model_features.append(get_features(model_finetune, image))
    labels.append(target)
    
ssl_model_features = np.array(ssl_model_features)
empty_model_features = np.array(empty_model_features)
finetuned_model_features = np.array(finetuned_model_features)
labels = np.array(labels)

# Standardize the features
scaler = StandardScaler()
ssl_model_features = scaler.fit_transform(ssl_model_features)
empty_model_features = scaler.fit_transform(empty_model_features)
finetuned_model_features = scaler.fit_transform(finetuned_model_features)

# Apply TSNE
tsne = TSNE(n_components=2, random_state=SEED)
ssl_model_features_2d = tsne.fit_transform(ssl_model_features)
empty_model_features_2d = tsne.fit_transform(empty_model_features)
finetuned_model_features_2d = tsne.fit_transform(finetuned_model_features)

In [None]:
from matplotlib.lines import Line2D

# Create 2x2 grid - 3 plots for visualizations and 1 for the legend
fig, axs = plt.subplots(2, 2, figsize=(20, 16))
plt.subplots_adjust(hspace=0.3)

# Get unique classes for legend
unique_labels = np.unique(labels)
num_classes = len(unique_labels)

# Create a color map
cmap = plt.get_cmap('viridis', num_classes)  # `plt.cm.get_cmap` was removed in recent Matplotlib versions
colors = [cmap(i) for i in range(num_classes)]

# Plot 1: SSL Model
ax1 = axs[0, 0]
for i, label in enumerate(unique_labels):
    mask = labels == label
    ax1.scatter(
        ssl_model_features_2d[mask, 0], 
        ssl_model_features_2d[mask, 1],
        color=colors[i], 
        s=20, 
        alpha=0.7
    )
ax1.set_title('SSL Model Features (Rotation Pre-training)', fontsize=16)
ax1.set_xlabel('t-SNE Dimension 1', fontsize=12)
ax1.set_ylabel('t-SNE Dimension 2', fontsize=12)

# Plot 2: Model Trained from Scratch
ax2 = axs[0, 1]
for i, label in enumerate(unique_labels):
    mask = labels == label
    ax2.scatter(
        empty_model_features_2d[mask, 0], 
        empty_model_features_2d[mask, 1],
        color=colors[i], 
        s=20, 
        alpha=0.7
    )
ax2.set_title('Model Trained from Scratch', fontsize=16)
ax2.set_xlabel('t-SNE Dimension 1', fontsize=12)
ax2.set_ylabel('t-SNE Dimension 2', fontsize=12)

# Plot 3: Fine-tuned SSL Model
ax3 = axs[1, 0]
for i, label in enumerate(unique_labels):
    mask = labels == label
    ax3.scatter(
        finetuned_model_features_2d[mask, 0], 
        finetuned_model_features_2d[mask, 1],
        color=colors[i], 
        s=20, 
        alpha=0.7
    )
ax3.set_title('Fine-tuned SSL Model', fontsize=16)
ax3.set_xlabel('t-SNE Dimension 1', fontsize=12)
ax3.set_ylabel('t-SNE Dimension 2', fontsize=12)

# Create legend in the 4th subplot
ax4 = axs[1, 1]
ax4.axis('off')  # Turn off axis

# Create legend elements
legend_elements = []
for i, label in enumerate(unique_labels):
    # Truncate long class names
    class_name = map_idx_to_class[label]
    if len(class_name) > 25:
        class_name = class_name[:25] + "..."
    
    legend_elements.append(
        Line2D([0], [0], marker='o', color='w', markerfacecolor=colors[i], 
               markersize=10, label=f"{label}: {class_name}")
    )

# Create the legend with multiple columns
ax4.legend(handles=legend_elements, loc='center', ncol=2, fontsize=10)
ax4.set_title('Traffic Sign Classes', fontsize=16)

plt.suptitle('Comparing Feature Representations Using t-SNE Visualization', fontsize=20)
plt.tight_layout(rect=[0, 0, 1, 0.97])
plt.show()