Prepared by Vladimir Zaigrajew and Tymoteusz Kwieciński

# Exercise Week 4 - Evaluation of Self-supervised Learning Methods
In the previous exercise, we trained self-supervised models on the rotation prediction task using a dataset of traffic signs. We then used the trained models to extract features and trained a classifier on top of them. We compared three approaches: a classifier trained on features extracted from the self-supervised model, the self-supervised model fine-tuned on the downstream task (classification), and a classifier trained from scratch on the same dataset. That was a downstream-task evaluation on the traffic sign dataset. In this exercise, we focus on the evaluation of self-supervised models with less compute-intensive methods: visualization techniques and an alternative to the linear probing (training a linear classifier) used in the previous exercise.

In this exercise, we will focus on **4** different models:
- ResNet-18 trained with SSL task (Rotation Prediction - same as in the previous exercise)
- ResNet-18 trained from scratch (Supervised)
- ResNet-18 pretrained by the PyTorch team on ImageNet (Supervised)
- PCA model (classical unsupervised learning method)


To train some of our models, we will use the STL10 dataset for self-supervised training and the CIFAR-100 dataset for supervised training. STL10 has 10 classes; its `train+unlabeled` split contains 105,000 images (5,000 labeled and 100,000 unlabeled, which does not matter because we only use it for self-supervised training). STL10 also has a labeled test set of 8,000 images. CIFAR-100 consists of 60,000 32x32 color images in 100 classes, with 600 images per class, divided into 50,000 training images and 10,000 test images.

After training, we will evaluate the learned representations on the following datasets:
- CIFAR-10: 10 classes of 32x32 color images (6,000 per class, 60,000 in total); it is a separate dataset, not a subset of CIFAR-100
- GTSRB: 43 classes of traffic signs with an uneven number of images per class, which you may already know from the previous exercise
- FGVCAircraft: 100 classes of aircraft models (100 images per class)

This exercise is less compute-intensive: your main task is to analyze the results of evaluating our models on various tasks and to write a short report about them.

So let's start!

# Part I. Prepare the environment
First, prepare an environment by installing all the required libraries. To find out which packages you need, inspect the cells with the import statements. Paste into the cell below the command that installs the libraries with **specific versions**, for example:
```python
%pip install numpy==1.21.0 pandas==1.3.0 matplotlib==3.4.2 torch==1.9.0 torchvision==0.10.0
```

In [None]:
# This is a placeholder for package installation.

Once the packages are installed, run the cell below to check that everything imports correctly; all of these packages are needed for the task.

In [None]:
import os
import random
import numpy as np
from tqdm import tqdm
from sklearn.neighbors import KNeighborsClassifier
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler

import torch
import torchvision
from torchvision import transforms as T

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.colorbar import ColorbarBase
from matplotlib.colors import Normalize

Finally, in the previous exercise I asked you how results can be made reproducible, specifically by seeding the random number generators. The code below sets all random number generators to the same seed, which is important for reproducing the results. You can change the seed to any number you want, but use the same seed in every cell where you set a random number generator. If you still don't know how this works, read up on it, as this is basic knowledge for ML practitioners.

In [None]:
# seed everything for reproducibility
SEED = 42

def seed_everything(seed: int=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)  # only affects Python processes started after this line
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_everything(SEED)
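`seed_everything` seeds the global generators. One detail worth knowing: a `DataLoader` with `shuffle=True` draws its order from the global PyTorch RNG unless it is given its own `torch.Generator`, so any extra random call earlier in the notebook changes the order. The sketch below (toy data, hypothetical helper name `shuffled_order`) shows that a seeded generator makes the shuffle order reproducible on its own; it assumes `num_workers=0`, as used in this notebook.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def shuffled_order(seed: int, n: int = 10) -> list[int]:
    """Return the order in which a shuffled DataLoader visits samples 0..n-1."""
    dataset = TensorDataset(torch.arange(n))
    # A dedicated generator makes the shuffle order depend only on `seed`,
    # not on how many random numbers were drawn elsewhere in the notebook.
    g = torch.Generator()
    g.manual_seed(seed)
    loader = DataLoader(dataset, batch_size=4, shuffle=True, generator=g)
    return [int(x) for (batch,) in loader for x in batch]


order_a = shuffled_order(42)
torch.rand(100)  # consumes global RNG state; the loaders use their own generator
order_b = shuffled_order(42)
print(order_a == order_b)
```

To apply this to the loaders in Part V, you could pass `generator=torch.Generator().manual_seed(SEED)` to `torch.utils.data.DataLoader`.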

# Part II. Prepare the data

Now we need to prepare the data: the CIFAR-100 and STL10 datasets for training and the other datasets for evaluation. We will use the torchvision library to download and prepare the datasets. I recommend checking how each dataset is structured (the domain of the data, the classes and the image sizes), as this will help you when analyzing the results.



In [None]:
# define the transform to apply to the data
# the training transform includes random cropping and resizing. Random cropping is used to
# augment the data and make the model more robust to variations in the input data.
# The size 128 was chosen to be the most compatible with all datasets (it is also good to consider
# the size of the original images when writing the report).
transform_train = T.Compose([
    T.RandomResizedCrop(128),
    T.ToTensor(),
])
# The test transform includes resizing and center cropping to ensure that the input data is
# the same size as the training data (128x128).
test_transform = T.Compose([
    T.Resize(134),
    T.CenterCrop(128),
    T.ToTensor(),
])

dataset_path = './data'

# load the STL10 dataset
train_dataset_ssl = torchvision.datasets.STL10(root=dataset_path, split='train+unlabeled', download=True, transform=transform_train)
test_dataset_ssl = torchvision.datasets.STL10(root=dataset_path, split='test', download=True, transform=test_transform)

# load the CIFAR-100 dataset (used for supervised training from scratch)
train_dataset = torchvision.datasets.CIFAR100(root=dataset_path, train=True, download=True, transform=transform_train)
test_dataset = torchvision.datasets.CIFAR100(root=dataset_path, train=False, download=True, transform=test_transform)

# load the CIFAR-10 test dataset
test_dataset_cifar10 = torchvision.datasets.CIFAR10(root=dataset_path, train=False, download=True, transform=test_transform)

# load the GTSRB test dataset (torchvision defaults to split='train')
test_dataset_gtsrb = torchvision.datasets.GTSRB(root=dataset_path, split='test', download=True, transform=test_transform)

# load the FGVCAircraft test dataset
test_dataset_fgv_aircraft = torchvision.datasets.FGVCAircraft(root=dataset_path, split='test', download=True, transform=test_transform)

Next, we need a dataset class for the rotation prediction task based on STL10. You implemented such a class in the previous exercise; this time it is already provided, so you can simply use the code below.

The class wraps the STL10 dataset and returns each image rotated by 0, 90, 180, or 270 degrees, chosen at random. The label is the index of the rotation angle (0, 1, 2, or 3). The images keep the transformations of the wrapped dataset (`transform_train` for the training split, `test_transform` for the test split). Check the code below to see how it works.

In [None]:
# Rotation Dataset
class SSLRot(torch.utils.data.Dataset):
    def __init__(self, dataset: torch.utils.data.Dataset, angles: list[int]):
        """
        Initialize the rotation dataset.
        
        Args:
            dataset (torch.utils.data.Dataset): The original dataset to apply rotations to.
            angles (list[int]): List of rotation angles in degrees (e.g. [0, 90, 180, 270]).
        """
        super(SSLRot, self).__init__()
        self.original_dataset = dataset
        self.angles = angles

    def __len__(self) -> int:
        """
        Return the length of the dataset.
        
        Returns:
            int: The number of samples in the dataset.
        """
        return len(self.original_dataset)
    
    def rand_rotate(self, img: torch.Tensor) -> tuple[torch.Tensor, int]:
        """
        Randomly rotates the image by one of the angles in `self.angles` (e.g. 0, 90, 180, or 270 degrees).

        Args:
            img (torch.Tensor): Input image tensor of shape (C, H, W).

        Returns:
            tuple: Rotated image tensor and the corresponding rotation label (index into `self.angles`).
        """
        rot_label = random.randint(0, len(self.angles) - 1)
        rotated_img = T.functional.rotate(img, self.angles[rot_label])
        return rotated_img, rot_label
        
    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Retrieves the rotated image and its corresponding rotation label.
        
        Args:
            idx (int): Index of the sample to retrieve.
        
        Returns:
            tuple: Rotated image tensor and the rotation label as a tensor.
        """
        img, _ = self.original_dataset[idx]
        rotated_img, rot_label = self.rand_rotate(img)
        return rotated_img, torch.tensor(rot_label, dtype=torch.long)

In [None]:
# Define the rotation angles
angles = [0, 90, 180, 270] 

# Create the rotation datasets for training and testing
rotation_dataset_train = SSLRot(train_dataset_ssl, angles)
rotation_dataset_test = SSLRot(test_dataset_ssl, angles)

# Print the lengths of the datasets
len(rotation_dataset_train), len(rotation_dataset_test)
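`SSLRot` samples one random rotation per image. A common alternative, shown here only as a sketch and not what `SSLRot` does, is to feed every image in all four rotations within the same batch. `torch.rot90` produces exact 90-degree rotations without interpolation, counter-clockwise like torchvision's `rotate`; `make_rotation_batch` is a hypothetical helper name.

```python
import torch


def make_rotation_batch(images: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Return every image in all four rotations plus the rotation labels.

    images: tensor of shape (B, C, H, W). Label k means a counter-clockwise
    rotation by k * 90 degrees, i.e. the index into [0, 90, 180, 270].
    """
    rotated = [torch.rot90(images, k, dims=(2, 3)) for k in range(4)]
    # Labels follow the concatenation order: B zeros, then B ones, and so on
    labels = torch.arange(4).repeat_interleave(images.shape[0])
    return torch.cat(rotated, dim=0), labels


x = torch.arange(2 * 1 * 2 * 2, dtype=torch.float32).reshape(2, 1, 2, 2)
batch, labels = make_rotation_batch(x)
print(batch.shape, labels.tolist())
```

This quadruples the batch size, so the batch size in the loader would typically be reduced accordingly.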

# Part III. Load the models

Now we need the models. In the previous exercise we used the `ResNet18` model from `torchvision`, and this time we use the same architecture. Your task is to load the models and modify them where needed. As described earlier, we will use the following models:
- ResNet-18 trained with the SSL task: load an untrained ResNet-18 (no weights) and change the `fc` layer to 4 outputs (for the rotation prediction task)
- ResNet-18 trained from scratch: load an untrained ResNet-18 (no weights) and change the `fc` layer to 100 outputs (it is trained on CIFAR-100 in Part V)
- ResNet-18 pretrained by the PyTorch team with the `IMAGENET1K_V1` weights: load ResNet-18 with the ImageNet weights and **modify nothing**

In [None]:
# This is a placeholder for model loading. Please name the models:
# ssl_model
# scratch_model
# model_finetuned (the pretrained one)

# Part IV. Prepare the training

In this part you don't need to do anything, as the code is already prepared for you. It implements the training and validation steps needed to train both the SSL model and the model trained from scratch.

In [None]:
def train_one_epoch(model: torch.nn.Module, optimizer: torch.optim.Optimizer, train_loader: torch.utils.data.DataLoader, criterion: torch.nn.Module, device: torch.device) -> tuple[float, float]:
    """Train the model for one epoch.
    Args:
        model (torch.nn.Module): The model to train.
        optimizer (torch.optim.Optimizer): The optimizer to use.
        train_loader (torch.utils.data.DataLoader): The training data loader.
        criterion (torch.nn.Module): The loss function.
        device (torch.device): The device to use for training (CPU or GPU).
    Returns:
        tuple: The average loss and accuracy for the epoch.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    
    for inputs, targets in tqdm(train_loader, desc="Training"):
        inputs, targets = inputs.to(device), targets.to(device)
        
        # Zero the parameter gradients
        optimizer.zero_grad()
        
        # Forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        
        # Statistics
        total_loss += loss.item() * inputs.size(0)
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
    
    train_loss = total_loss / total
    train_acc = 100.0 * correct / total
    
    return train_loss, train_acc

def validate(model: torch.nn.Module, val_loader: torch.utils.data.DataLoader, device: torch.device) -> float:
    """Validate the model.
    Args:
        model (nn.Module): The model to validate.
        val_loader (torch.utils.data.DataLoader): The validation data loader.
        device (torch.device): The device to use for validation (CPU or GPU).
    
    Returns:
        float: The average accuracy for the validation set.    
    """
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for inputs, targets in tqdm(val_loader, desc="Validating"):
            inputs, targets = inputs.to(device), targets.to(device)
            
            # Forward
            outputs = model(inputs)
            
            # Statistics
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    
    val_acc = 100.0 * correct / total
    
    return val_acc


def train(model: torch.nn.Module, train_loader: torch.utils.data.DataLoader, val_loader: torch.utils.data.DataLoader, optimizer: torch.optim.Optimizer, criterion: torch.nn.Module, num_epochs: int=10, device: torch.device=torch.device("cpu")) -> tuple[list[float], list[float], list[dict]]:
    """Train the model.
    
    Args:
        model (torch.nn.Module): The model to train.
        train_loader (torch.utils.data.DataLoader): The training data loader.
        val_loader (torch.utils.data.DataLoader): The validation data loader.
        optimizer (torch.optim.Optimizer): The optimizer to use.
        criterion (torch.nn.Module): The loss function.
        num_epochs (int): The number of epochs to train for.
        device (torch.device): The device to use for training (CPU or GPU).
    
    Returns:
        tuple: A tuple containing the training accuracy, validation accuracy, and model state dictionaries.
    """
    train_accs = []
    val_accs = []
    state_dicts = []
    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        
        # Train
        train_loss, train_acc = train_one_epoch(model, optimizer, train_loader, criterion, device)
        train_accs.append(train_acc)
        
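        # Save a copy of the model state: state_dict() returns tensors that share memory
        # with the parameters, so without cloning every entry would hold the final weights
        state_dicts.append({k: v.detach().cpu().clone() for k, v in model.state_dict().items()})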
        
        # Validate
        val_acc = validate(model, val_loader, device)
        val_accs.append(val_acc)
        
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, Val Acc: {val_acc:.2f}%")
    return train_accs, val_accs, state_dicts


def visualize_predictions(model: torch.nn.Module, dataset: torch.utils.data.Dataset, device: torch.device, class_names: list[str], num_images: int=5):
    """Visualize predictions of the model on a subset of the dataset.
    Args:
        model (torch.nn.Module): The model to use for predictions.
        dataset (torch.utils.data.Dataset): The dataset to visualize.
        device (torch.device): The device to use for predictions (CPU or GPU).
        class_names (list[str]): The list of class names.
        num_images (int): The number of images to visualize.
        
    """
    model.eval()
    indices = random.sample(range(len(dataset)), num_images)
    
    _, axes = plt.subplots(1, num_images, figsize=(15, 5))
    
    with torch.no_grad():
        for i, idx in enumerate(indices):
            img, label = dataset[idx]
            img = img.unsqueeze(0).to(device)
            output = model(img)
            pred_label = output.argmax(dim=1).item()
            
            axes[i].imshow(img.squeeze(0).permute(1, 2, 0).cpu())
            axes[i].set_title(f"Pred: {class_names[pred_label]}\nTrue: {class_names[label]}")
            axes[i].axis('off')
    
    plt.show()
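Keeping one checkpoint per epoch in memory, as `train` does, has a subtle pitfall: `model.state_dict()` returns tensors that share memory with the model's parameters, so later in-place optimizer updates also change a previously stored dict. The toy example below demonstrates this with a single linear layer and shows that cloning the tensors keeps a true snapshot.

```python
import torch

layer = torch.nn.Linear(2, 1)
shallow = layer.state_dict()  # tensors share storage with the parameters
snapshot = {k: v.detach().clone() for k, v in layer.state_dict().items()}

with torch.no_grad():
    layer.weight.add_(1.0)  # in-place update, like optimizer.step()

print(torch.equal(shallow["weight"], layer.weight))   # True: the "saved" dict followed the update
print(torch.equal(snapshot["weight"], layer.weight))  # False: the clone kept the old values
```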

# Part V. Train SSL and Supervised models

In this part we will train and evaluate both models, using the rotation dataset (STL10) and the CIFAR-100 dataset. As I mentioned earlier, this part can be resource-intensive, so I have already done the training; you can find the weights of the trained models in our Slack group. You **can** train the models yourself, but you can also just load the weights I provided. I will not tell you which cell loads the weights; figure it out from the code below :)

In the evaluation phase of our training we will plot the accuracy of the model on the train and validation datasets. We will also plot example images with their predicted labels to visually assess the model's performance.

In [None]:
### PARAMETERS ###
BATCH_SIZE = 256
NUM_EPOCHS = 20
LEARNING_RATE = 0.001
NUM_WORKERS = 0
NUMBER_OF_CLASSES = 4
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

In [None]:
# Get the data loaders
train_dl = torch.utils.data.DataLoader(rotation_dataset_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_dl = torch.utils.data.DataLoader(rotation_dataset_test, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# Prepare the model
ssl_model = ssl_model.to(device)

# Define the optimizer
optimizer = torch.optim.AdamW(ssl_model.parameters(), lr=LEARNING_RATE)

# Define the loss function
criterion = torch.nn.CrossEntropyLoss()

# Train the model
train_accs, val_accs, state_dicts = train(ssl_model, train_dl, val_dl, optimizer, criterion, num_epochs=NUM_EPOCHS, device=device)


In [None]:
plt.plot(train_accs, label='train accuracy')
plt.plot(val_accs, label='val accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.title('Training and Validation Accuracy on Rotation Dataset')
plt.legend()
plt.show()

In [None]:
# Save the model
arg_max = np.argmax(val_accs)
print(f"Best model at epoch {arg_max+1} with accuracy {val_accs[arg_max]:.2f}%")
ssl_model.load_state_dict(state_dicts[arg_max])
torch.save(ssl_model.state_dict(), os.path.join(dataset_path, 'ssl_model.pth'))

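# Load the model (map_location lets GPU-saved weights load on any device)
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ssl_model.load_state_dict(torch.load(os.path.join(dataset_path, 'ssl_model.pth'), map_location=device))
# ssl_model = ssl_model.to(device)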

# Check the performance of the model on the test set
# train_dl = torch.utils.data.DataLoader(rotation_dataset_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
# val_dl = torch.utils.data.DataLoader(rotation_dataset_test, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# train_acc = validate(ssl_model, train_dl, device)
# print(f"Train Accuracy: {train_acc:.2f}%")
# val_acc = validate(ssl_model, val_dl, device)
# print(f"Validation Accuracy: {val_acc:.2f}%")
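If you load the provided weights instead of training, they may have been saved on a GPU. `torch.load` accepts a `map_location` argument that moves the stored tensors to the device you pass. A minimal save/load round trip, using a toy layer and a temporary file name chosen only for illustration:

```python
import os
import tempfile

import torch

model = torch.nn.Linear(4, 2)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "demo_model.pth")  # hypothetical file name
    torch.save(model.state_dict(), path)
    # map_location remaps tensors saved on another device (e.g. a GPU) to `device`
    state = torch.load(path, map_location=device)

restored = torch.nn.Linear(4, 2).to(device)
restored.load_state_dict(state)
```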

In [None]:
visualize_predictions(ssl_model, rotation_dataset_test, num_images=5, device=device, class_names=angles)

Now that the SSL model is trained, let's train our supervised model on CIFAR-100. The drill is the same as before: you can either train the model yourself or load the weights I provided. The code is already prepared, so just run it and check the results.

In [None]:
### PARAMETERS ###
BATCH_SIZE = 256
NUM_EPOCHS = 15
LEARNING_RATE = 0.001
NUM_WORKERS = 0
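NUMBER_OF_CLASSES = 100  # CIFAR-100
# pick the best available accelerator (CUDA, Apple MPS) and fall back to CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")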

In [None]:
# Get the data loaders
train_dl = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_dl = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# Load the model
scratch_model = scratch_model.to(device)

# Define the optimizer
optimizer = torch.optim.AdamW(scratch_model.parameters(), lr=LEARNING_RATE)

# Define the loss function
criterion = torch.nn.CrossEntropyLoss()

# Train the model
train_accs, val_accs, state_dict = train(scratch_model, train_dl, val_dl, optimizer, criterion, num_epochs=NUM_EPOCHS, device=device)

In [None]:
plt.plot(train_accs, label='train accuracy')
plt.plot(val_accs, label='val accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.title('Training and Validation Accuracy on CIFAR-100 Dataset')
plt.legend()
plt.show()

In [None]:
# Save the model
arg_max = np.argmax(val_accs)
print(f"Best model at epoch {arg_max+1} with accuracy {val_accs[arg_max]:.2f}%")
scratch_model.load_state_dict(state_dict[arg_max])
torch.save(scratch_model.state_dict(), os.path.join(dataset_path, 'scratch_model.pth'))

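# Load the model (map_location lets GPU-saved weights load on any device)
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# scratch_model.load_state_dict(torch.load(os.path.join(dataset_path, 'scratch_model.pth'), map_location=device))
# scratch_model = scratch_model.to(device)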

# Check the performance of the model on the test set
# train_acc = validate(scratch_model, train_dl, device)
# print(f"Train Accuracy: {train_acc:.2f}%")
# val_acc = validate(scratch_model, val_dl, device)
# print(f"Validation Accuracy: {val_acc:.2f}%")

In [None]:
visualize_predictions(scratch_model, test_dataset, num_images=5, device=device, class_names=train_dataset.classes)

# Part VI. Prepare the feature extraction from ResNet models and visualization methods

The code below, which is already prepared for you, implements feature extraction from the ResNet models: `extract_features` takes a ResNet-style model and a batch of images and returns the features that the model feeds into its final `fc` layer (a 512-dimensional embedding for ResNet-18).

The second function, `plot_tsne`, plots a t-SNE visualization of the features. t-SNE is a dimensionality reduction technique that is particularly well suited to visualizing high-dimensional data. It converts the similarities between data points into joint probabilities and then minimizes the Kullback-Leibler divergence between the joint probabilities in the high-dimensional space and the corresponding probabilities in the low-dimensional embedding.
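For reference, the standard t-SNE objective can be written as follows, where $x_i$ are the high-dimensional features, $y_i$ their 2D positions, $N$ the number of points, and each bandwidth $\sigma_i$ is chosen so that the conditional distribution of point $i$ matches the chosen perplexity:

$$
p_{j\mid i} = \frac{\exp\left(-\lVert x_i - x_j\rVert^2 / 2\sigma_i^2\right)}{\sum_{k \neq i} \exp\left(-\lVert x_i - x_k\rVert^2 / 2\sigma_i^2\right)}, \qquad p_{ij} = \frac{p_{j\mid i} + p_{i\mid j}}{2N}
$$

$$
q_{ij} = \frac{\left(1 + \lVert y_i - y_j\rVert^2\right)^{-1}}{\sum_{k \neq l} \left(1 + \lVert y_k - y_l\rVert^2\right)^{-1}}, \qquad C = \mathrm{KL}(P \,\Vert\, Q) = \sum_{i \neq j} p_{ij} \log \frac{p_{ij}}{q_{ij}}
$$

The heavy-tailed Student-t kernel in $q_{ij}$ lets moderately distant points spread apart in 2D. As a consequence, distances between well-separated clusters in a t-SNE plot are not very meaningful; when writing the report, focus mainly on which points are grouped together.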

We pass the features extracted from all our models to `plot_tsne`. The function projects each set of features to 2D and plots them with the `matplotlib` library. We also pass the labels of the data points so that they are colored according to their class.

In [None]:
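def extract_features(model: torch.nn.Module, batch_x: torch.Tensor) -> np.ndarray:
    """
    Extract features from the model (the input of its final `fc` layer).
    
    Args:
        model (torch.nn.Module): The ResNet-style model to use for feature extraction.
        batch_x (torch.Tensor): The input batch of images.
        
    Returns:
        np.ndarray: The extracted features of shape (batch_size, feature_dim).
    """
    features = []
    
    def hook_fn(module, inputs, output):
        # `inputs` is a tuple; its first element is the embedding fed into `fc`
        features.append(inputs[0].detach().cpu().numpy())
    
    hook = model.fc.register_forward_hook(hook_fn)
    try:
        with torch.no_grad():  # no gradients are needed for feature extraction
            model(batch_x)
    finally:
        hook.remove()  # always remove the hook, even if the forward pass fails
    features = features[0]
    return features.reshape(features.shape[0], -1)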

def plot_tsne(
    scratch_features: np.ndarray,
    ssl_features: np.ndarray,
    model_finetuned_features: np.ndarray,
    pca_features: np.ndarray,
    labels: np.ndarray
) -> None:
    """
    Visualize the features using t-SNE with a color gradient bar at the bottom.

    Args:
        scratch_features (np.ndarray): Features extracted from the model trained from scratch.
        ssl_features (np.ndarray): Features extracted from the model trained with SSL.
        model_finetuned_features (np.ndarray): Features extracted from the pretrained model by Pytorch team.
        pca_features (np.ndarray): Features extracted from PCA.
        labels (np.ndarray): Class labels for the features.
    """
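    # Make sure labels support element-wise comparison (e.g. when passed as a list)
    labels = np.asarray(labels)

    # Standardize each feature set before t-SNE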
    scaler = StandardScaler()
    scaler_scratch_features = scaler.fit_transform(scratch_features)
    scaler_ssl_features = scaler.fit_transform(ssl_features)
    scaler_model_finetuned_features = scaler.fit_transform(model_finetuned_features)
    scaler_pca_features = scaler.fit_transform(pca_features)
    
    # Apply TSNE
    tsne = TSNE(n_components=2, random_state=SEED)
    tsne_scratch_features = tsne.fit_transform(scaler_scratch_features)
    tsne_ssl_features = tsne.fit_transform(scaler_ssl_features)
    tsne_model_finetuned_features = tsne.fit_transform(scaler_model_finetuned_features)
    tsne_pca_features = tsne.fit_transform(scaler_pca_features)
    
    # Create a figure with GridSpec for better control
    fig = plt.figure(figsize=(20, 18))
    gs = gridspec.GridSpec(3, 2, height_ratios=[1, 1, 0.15])
    
    # Get unique classes for colormap
    unique_labels = np.unique(labels)
    num_classes = len(unique_labels)
    
    # Create a color map
    cmap = plt.get_cmap('viridis', num_classes)
    colors = [cmap(i) for i in range(num_classes)]
    
    # Plot 1: Model Trained from Scratch
    ax1 = fig.add_subplot(gs[0, 0])
    for i, label in enumerate(unique_labels):
        mask = labels == label
        ax1.scatter(
            tsne_scratch_features[mask, 0], 
            tsne_scratch_features[mask, 1],
            color=colors[i], 
            s=20, 
            alpha=0.7,
            label=str(label)  # Convert label to string for plotting
        )
    ax1.set_title('Model Trained from Scratch', fontsize=16)
    ax1.set_xlabel('t-SNE Dimension 1', fontsize=12)
    ax1.set_ylabel('t-SNE Dimension 2', fontsize=12)
    
    # Plot 2: Model Trained with SSL
    ax2 = fig.add_subplot(gs[0, 1])
    for i, label in enumerate(unique_labels):
        mask = labels == label
        ax2.scatter(
            tsne_ssl_features[mask, 0], 
            tsne_ssl_features[mask, 1],
            color=colors[i], 
            s=20, 
            alpha=0.7
        )
    ax2.set_title('Model Trained with SSL', fontsize=16)
    ax2.set_xlabel('t-SNE Dimension 1', fontsize=12)
    ax2.set_ylabel('t-SNE Dimension 2', fontsize=12)
    
    # Plot 3: Model pretrained on ImageNet by the PyTorch team
    ax3 = fig.add_subplot(gs[1, 0])
    for i, label in enumerate(unique_labels):
        mask = labels == label
        ax3.scatter(
            tsne_model_finetuned_features[mask, 0], 
            tsne_model_finetuned_features[mask, 1],
            color=colors[i], 
            s=20, 
            alpha=0.7
        )
    ax3.set_title('Pretrained Model (ImageNet)', fontsize=16)
    ax3.set_xlabel('t-SNE Dimension 1', fontsize=12)
    ax3.set_ylabel('t-SNE Dimension 2', fontsize=12)
    
    # Plot 4: PCA
    ax4 = fig.add_subplot(gs[1, 1])
    for i, label in enumerate(unique_labels):
        mask = labels == label
        ax4.scatter(
            tsne_pca_features[mask, 0], 
            tsne_pca_features[mask, 1],
            color=colors[i], 
            s=20, 
            alpha=0.7
        )
    ax4.set_title('PCA', fontsize=16)
    ax4.set_xlabel('t-SNE Dimension 1', fontsize=12)
    ax4.set_ylabel('t-SNE Dimension 2', fontsize=12)
    
    # Create a horizontal colorbar at the bottom
    cbar_ax = fig.add_subplot(gs[2, :])
    norm = Normalize(vmin=0, vmax=num_classes-1)
    cb = ColorbarBase(
        cbar_ax, 
        cmap=cmap,
        norm=norm,
        orientation='horizontal'
    )
    
    cb.set_label(f'Classes (Total: {num_classes})', fontsize=14)
    
    # Add ticks for each class on the colorbar
    if num_classes <= 20:
        # For a smaller number of classes, show all ticks
        tick_positions = np.arange(num_classes)
        tick_labels = [str(label) for label in unique_labels]
    else:
        # For many classes, show every nth class
        n = max(1, num_classes // 20)
        tick_positions = np.arange(0, num_classes, n)
        tick_labels = [str(unique_labels[i]) for i in range(0, num_classes, n)]
    
    cb.set_ticks(tick_positions)
    cb.set_ticklabels(tick_labels)
    
    # Calculate frequency of each class
    class_counts = {}
    for label in labels:
        if label in class_counts:
            class_counts[label] += 1
        else:
            class_counts[label] = 1
    
    # Add frequency information above the colorbar
    if num_classes <= 40:
        for i, label in enumerate(unique_labels):
            if i % (max(1, num_classes // 20)) == 0:
                freq = class_counts.get(label, 0)
                cbar_ax.text(
                    i/(num_classes-1), 
                    1.1, 
                    f"{freq}", 
                    horizontalalignment='center',
                    fontsize=8,
                    transform=cbar_ax.transAxes  # x and y are fractions of the colorbar axes
                )
    
    # Set the main title
    plt.suptitle(f'Comparing Feature Representations Using t-SNE Visualization\n(Total Classes: {num_classes})', fontsize=20)
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()

In the lecture I mentioned a newer visualization method called UMAP. Your task is to implement a UMAP visualization in the same way as it is done in the `plot_tsne` function. You can use the `umap-learn` library for this task; its documentation is here: https://umap-learn.readthedocs.io/en/latest/. UMAP is a more recent dimensionality reduction method that is often faster than t-SNE and tends to preserve more of the global structure of the data. It is grounded in topological data analysis: it builds a weighted nearest-neighbor graph (a fuzzy topological representation) of the high-dimensional data and then optimizes a low-dimensional layout by minimizing the cross-entropy between the high-dimensional and low-dimensional representations.
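A sketch of a shared projection step that `plot_umap` could reuse: it standardizes the features as `plot_tsne` does and then applies either t-SNE or UMAP. It assumes the `umap-learn` package (imported as `umap`) and its scikit-learn-style `umap.UMAP(n_components=..., random_state=...).fit_transform(...)` interface; the import happens only when UMAP is requested, so the demo below runs with t-SNE alone. `embed_2d` is a hypothetical name, and this is not a full `plot_umap` implementation.

```python
import numpy as np
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler


def embed_2d(features: np.ndarray, method: str = "tsne", seed: int = 42) -> np.ndarray:
    """Standardize `features` and project them to 2-D with t-SNE or UMAP."""
    scaled = StandardScaler().fit_transform(features)
    if method == "tsne":
        reducer = TSNE(n_components=2, random_state=seed)
    elif method == "umap":
        import umap  # provided by the umap-learn package; imported only when needed
        reducer = umap.UMAP(n_components=2, random_state=seed)
    else:
        raise ValueError(f"unknown method: {method}")
    return reducer.fit_transform(scaled)


rng = np.random.default_rng(0)
demo = rng.normal(size=(60, 16))  # stand-in for extracted features
print(embed_2d(demo, "tsne").shape)
```

Fixing `random_state` makes both projections repeatable, which helps when comparing the four models side by side.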

In [None]:
def plot_umap(
    scratch_features: np.ndarray,
    ssl_features: np.ndarray,
    model_finetuned_features: np.ndarray,
    pca_features: np.ndarray,
    labels: np.ndarray
) -> None:
    """
    Visualize the features using UMAP. This function first takes the features provided to the function and applies
    UMAP to reduce the dimensionality of the features to 2D. It then creates a 2x2 grid of subplots, where each
    subplot corresponds to a different set of features. The first subplot shows the features extracted from the model
    trained from scratch, the second subplot shows the features extracted from the model trained with SSL, the third
    subplot shows the features extracted from the ImageNet-pretrained model, and the fourth subplot shows the PCA features.
    Each subplot contains a scatter plot where each point is colored according to its class label. The legend is
    shown in the fourth subplot.
    
    Args:
        scratch_features (np.ndarray): Features extracted from the model trained from scratch.
        ssl_features (np.ndarray): Features extracted from the model trained with SSL.
        model_finetuned_features (np.ndarray): Features extracted from the pretrained model by Pytorch team.
        pca_features (np.ndarray): Features extracted from PCA.
        labels (np.ndarray): Class labels for the features.
    """
    # CODE FOR UMAP VISUALIZATION
    pass

# Part VII. Train the PCA model

In this part we will train the PCA model. PCA is a classical unsupervised method for dimensionality reduction; I will not go into its mathematical details, as you should already know how it works from previous courses. In the cell below, fit PCA on CIFAR-10 images (or, preferably, on STL10 images since you have already downloaded that dataset; its test split is enough). PCA must work on **flattened images**, and the number of components must equal the embedding dimension of the ResNet models, `scratch_model.fc.in_features`. If fitting takes too long, use a subset of the data (I used a random sample of 5,000 images). The data must also be standardized before fitting: use `Pipeline` from the `sklearn` library to combine the standardization step with PCA, and name the pipeline `pca` with a step called `'pca'`, because the cells below use `pca.transform(...)` and `pca.named_steps['pca']`.
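A minimal sketch of the required `Pipeline`, assuming random stand-in data instead of real images. With the 128x128 test transform, each flattened image has 3 x 128 x 128 = 49,152 values, and the number of components would be `scratch_model.fc.in_features` (512 for ResNet-18); the toy sizes below are smaller so it runs quickly. The step name `'scaler'` is a free choice, but `'pca'` must match `pca.named_steps['pca']` used in the plotting cell.

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# Hypothetical stand-in data: 300 "images" of shape 3x16x16, flattened to vectors
rng = np.random.default_rng(42)
images = rng.random((300, 3, 16, 16)).astype(np.float32)
flat = images.reshape(images.shape[0], -1)  # (300, 768)

n_components = 64  # in the notebook: scratch_model.fc.in_features (512)
pca = Pipeline([
    ("scaler", StandardScaler()),  # zero mean, unit variance per pixel
    ("pca", PCA(n_components=n_components, random_state=42)),
])
pca.fit(flat)

embeddings = pca.transform(flat)
print(embeddings.shape)
```

`n_components` cannot exceed the number of fitted samples, so a 5,000-image subset is comfortably large enough for 512 components.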

In [None]:
# Placeholder for the PCA model you need to create and train

# Visualize the explained variance ratio of the PCA components
plt.plot(np.cumsum(pca.named_steps['pca'].explained_variance_ratio_))
plt.xlabel('Number of Components')
plt.ylabel('Cumulative Explained Variance')
plt.title('PCA Explained Variance')
plt.grid()
plt.show()

# Part VIII. Evaluate the models on CIFAR-10 dataset

In this part we will evaluate the models on the CIFAR-10 dataset. CIFAR-10 contains 32x32 color images like CIFAR-100 but has only 10 classes; it is a separate dataset, not a subset of CIFAR-100. We will extract features from our 4 models and evaluate them on this dataset.

Our evaluation includes visualizing the features with the t-SNE and UMAP methods. To measure performance, as an alternative to linear probing, we will use a k-nearest neighbors (KNN) classifier. KNN is a simple and effective classification method: for each data point it finds the k nearest neighbors in the feature space and assigns the class label by a majority vote among those neighbors.
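Since only the CIFAR-10 test split is passed through the models, a KNN evaluation needs its own train/test split of the extracted features. One way to do this is sketched below; `knn_accuracy` is a hypothetical helper, `k=5` and the 80/20 stratified split are assumptions rather than values prescribed by this exercise, and the demo uses synthetic blobs as stand-in embeddings. `metric="cosine"` is a common choice for comparing embeddings and is supported by scikit-learn's brute-force neighbor search.

```python
from sklearn.datasets import make_blobs
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier


def knn_accuracy(features, labels, k=5, metric="euclidean", test_size=0.2, seed=42):
    """Fit a k-NN classifier on one split of the features and return held-out accuracy."""
    x_train, x_test, y_train, y_test = train_test_split(
        features, labels, test_size=test_size, random_state=seed, stratify=labels
    )
    knn = KNeighborsClassifier(n_neighbors=k, metric=metric)
    knn.fit(x_train, y_train)
    return knn.score(x_test, y_test)  # mean accuracy on the held-out split


# Stand-in for extracted embeddings: 3 well-separated clusters in 32-D
feats, labs = make_blobs(n_samples=300, centers=3, n_features=32, cluster_std=0.5, random_state=0)
print(knn_accuracy(feats, labs, k=5))
print(knn_accuracy(feats, labs, k=5, metric="cosine"))
```

With the arrays from the extraction cell, this would be called as e.g. `knn_accuracy(ssl_cifar10, cifar10_labels)`, once for each of the four feature sets, keeping `k`, the metric and the split seed the same so the numbers are comparable.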

The cell below will extract features from our models using the CIFAR-10 dataset.

In [None]:
# Extraction Hyperparameters
BATCH_SIZE = 64
NUM_WORKERS = 0

# Move models to eval mode and to the same device
scratch_model = scratch_model.to(device).eval()
ssl_model = ssl_model.to(device).eval()
model_finetuned = model_finetuned.to(device).eval()

# Lists to store features
scratch_cifar10 = []
ssl_cifar10 = []
model_finetuned_cifar10 = []
pca_cifar10 = []
cifar10_labels = []

# Create a DataLoader for the test dataset (it is more efficient to use a DataLoader)
dl = torch.utils.data.DataLoader(test_dataset_cifar10, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
for batch_x, labels in tqdm(dl, desc="Extracting features from CIFAR-10 test set"):
    batch_x = batch_x.to(device)
    cifar10_labels.extend(labels.tolist())  # store plain Python ints, not 0-d tensors

    # Extract features without building autograd graphs
    with torch.no_grad():
        scratch_features = extract_features(scratch_model, batch_x)
        ssl_features = extract_features(ssl_model, batch_x)
        model_finetuned_features = extract_features(model_finetuned, batch_x)
    # PCA expects flattened images as a NumPy array on the CPU
    pca_features = pca.transform(batch_x.cpu().numpy().reshape(batch_x.shape[0], -1))

    # Fill the lists
    scratch_cifar10.append(scratch_features)
    ssl_cifar10.append(ssl_features)
    model_finetuned_cifar10.append(model_finetuned_features)
    pca_cifar10.append(pca_features)

# Concatenate the features from all batches
scratch_cifar10 = np.concatenate(scratch_cifar10, axis=0)
ssl_cifar10 = np.concatenate(ssl_cifar10, axis=0)
model_finetuned_cifar10 = np.concatenate(model_finetuned_cifar10, axis=0)
pca_cifar10 = np.concatenate(pca_cifar10, axis=0)
scratch_cifar10.shape, ssl_cifar10.shape, model_finetuned_cifar10.shape, pca_cifar10.shape

Let's now visualize our features with t-SNE.

In [None]:
plot_tsne(
    scratch_cifar10, 
    ssl_cifar10, 
    model_finetuned_cifar10, 
    pca_cifar10, 
    cifar10_labels
)

Now let's visualize our features with UMAP.

In [None]:
plot_umap(
    scratch_cifar10,
    ssl_cifar10,
    model_finetuned_cifar10, 
    pca_cifar10, 
    cifar10_labels
)

Finally, let's evaluate our models using a KNN classifier. KNN is a simple and effective classification method that finds the k nearest neighbors of a data point in the feature space and assigns the class label by majority vote among those neighbors. Your task is to implement this evaluation using `KNeighborsClassifier` from the `sklearn` library. Each set of features extracted from a model should be used as input to a separate KNN classifier, trained on those features and the labels used in the visualizations. Use 5 neighbors. For the evaluation you can use the built-in `score` method of `KNeighborsClassifier` and print the model's accuracy on the same training features. In short, you need to implement the following steps:
1. Create the KNN classifier with 5 neighbors
2. Train the KNN classifier on the extracted features and labels
3. Evaluate the KNN classifier on the same training features
4. Print the accuracy of the KNN classifier

Repeat this 4 times, once for each set of features we extracted.
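
The four steps can be sketched as a single loop over feature sets. The dictionary keys and synthetic features below are hypothetical placeholders, not the notebook's variables; in the notebook the dictionary would hold `scratch_cifar10`, `ssl_cifar10`, `model_finetuned_cifar10` and `pca_cifar10`, with `cifar10_labels` as the labels. The held-out split at the end of the loop is an optional extra beyond what the exercise asks for.

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

rng = np.random.default_rng(0)
labels = np.repeat(np.arange(4), 50)  # 4 classes, 50 samples each

# Hypothetical feature sets standing in for the per-model features:
# one where class information is encoded, one that is pure noise.
feature_sets = {
    "informative": rng.normal(size=(200, 16)) + labels[:, None] * 3.0,
    "noise": rng.normal(size=(200, 16)),
}

results = {}
for name, feats in feature_sets.items():
    knn = KNeighborsClassifier(n_neighbors=5)  # 1. create with 5 neighbors
    knn.fit(feats, labels)                     # 2. train on features + labels
    train_acc = knn.score(feats, labels)       # 3. evaluate on the same features
    print(f"{name:>12}: KNN accuracy (same features) = {train_acc:.3f}")  # 4. print

    # Optional extra: a held-out estimate where points cannot vote for themselves.
    X_tr, X_te, y_tr, y_te = train_test_split(
        feats, labels, test_size=0.3, random_state=0, stratify=labels
    )
    held_out_acc = KNeighborsClassifier(n_neighbors=5).fit(X_tr, y_tr).score(X_te, y_te)
    print(f"{name:>12}: KNN accuracy (held-out 30%) = {held_out_acc:.3f}")
    results[name] = (train_acc, held_out_acc)
```

Because `score` is called on the very points the classifier was fitted on, each point finds itself at distance 0 and supplies one of its own 5 votes, so the "same features" accuracy tends to be optimistic. Comparing it with the held-out number is one way to see how large that effect is.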

In [None]:
# Placeholder for the KNN classifier

# Part IX. Evaluate the models on the GTSRB dataset
In this part we will evaluate the models on the GTSRB dataset, which contains 43 classes of traffic signs. The task is the same as in the previous part: we will extract features from our models and evaluate them with visualizations and a KNN classifier. You need to write the whole evaluation code yourself, but as you may guess, you only need to copy the code from the previous section and switch the dataset from CIFAR-10 to GTSRB.
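
Instead of copying the extraction cell for every dataset, you could wrap it in a function that takes an iterable of batches and a dictionary of feature extractors. The helper below, `collect_features`, is hypothetical (it is not defined anywhere in the notebook) and uses only NumPy so that it runs on its own, with plain lists of `(images, labels)` batches standing in for a DataLoader.

```python
from typing import Callable, Dict, Iterable, Tuple

import numpy as np


def collect_features(
    batches: Iterable[Tuple[np.ndarray, np.ndarray]],
    extractors: Dict[str, Callable[[np.ndarray], np.ndarray]],
) -> Tuple[Dict[str, np.ndarray], np.ndarray]:
    """Run every extractor on every batch and stack the results per extractor."""
    per_model = {name: [] for name in extractors}
    all_labels = []
    for images, labels in batches:
        all_labels.extend(np.asarray(labels).tolist())
        for name, fn in extractors.items():
            per_model[name].append(fn(images))
    features = {name: np.concatenate(chunks, axis=0) for name, chunks in per_model.items()}
    return features, np.asarray(all_labels)


# Toy usage: two batches of 4 "images" and two trivial extractors.
rng = np.random.default_rng(0)
toy_batches = [(rng.normal(size=(4, 3, 8, 8)), np.arange(4)) for _ in range(2)]
toy_extractors = {
    "flatten": lambda x: x.reshape(x.shape[0], -1),
    "channel_mean": lambda x: x.mean(axis=(2, 3)),
}
feats, labels = collect_features(toy_batches, toy_extractors)
print({k: v.shape for k, v in feats.items()}, labels.shape)
```

In the notebook, the extractors would be something like `lambda x: extract_features(ssl_model, x.to(device))` for each ResNet and a flatten-then-`pca.transform` function for PCA, so that switching from CIFAR-10 to GTSRB or FGVCAircraft only changes the DataLoader you pass in.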

In [None]:
# Your time to shine

# Part X. Evaluate the models on the FGVCAircraft dataset
You know the drill: copy the code, change the dataset, and run it. The FGVCAircraft dataset contains 100 classes of aircraft. The task is the same as in the previous part: we will extract features from our models and evaluate them with visualizations and a KNN classifier.

In [None]:
# And one last time

Now for the main part of our exercise: the report.

In this part you need to write a short report (really short, don't overdo it) on the results of the evaluation. You can use the following questions to guide you:
- What can be read from the t-SNE and UMAP visualizations?
- What are the advantages and disadvantages of using t-SNE and UMAP for visualization?
- What are the advantages and disadvantages of using KNN classifier for evaluation?
- What are the results of the evaluation on the CIFAR-10 dataset?
- What are the results of the evaluation on the GTSRB dataset?
- What are the results of the evaluation on the FGVCAircraft dataset?
- Why are the results different for each dataset? (Hint: think about the domain of the data, the domain the model was trained on, and the task it solved.)
- Why do you think the best-performing model on each dataset beats the others? Answer this question for each dataset separately.

**I want you to know that this notebook is not about showing you that SSL is awesome, but about demonstrating how we can evaluate model features and how we can use them for different tasks. I want you to understand that showing only the results is often not enough: you need to know how to interpret them in the context of the current task and, in the future, how to improve them. That's why I care more about your short report than about the results in the notebook.**

Send me the notebook, with the report in a markdown cell, via Slack. Please send the notebook file itself; I don't need your Google Colab link. I will check the notebook and send you feedback. If you have any questions or problems with the notebook, feel free to ask me on Slack.
