HHU Deep Learning, SS2022/23, 21.04.2023, Prof. Dr. Markus Kollmann

Lectures and tutoring are given by Tim Kaiser, Nikolas Adaloglou and Felix Michels.

# Assignment 03 - Rotation prediction as a pretext task


## Contents

1. Preparation and imports
2. Preparing the data
3. Load and modify ResNet18
4. Launch Training
5. Visualizing the best model
6. Validation accuracy
7. Compute features
8. Linear evaluation: Probing
9. TSNE visualization

## Introduction to rotation prediction


Rotation prediction provides a simple yet effective way to learn rich representations from unlabeled image data. The basic idea is to rotate an image by one of a fixed set of angles (here 0°, 90°, 180° or 270°) and to train the network to predict which rotation was applied.

To solve this task, the network has to recognize the objects in an image and their canonical orientation (for example, that a person's head is usually above the feet). It is therefore forced to learn semantic features that are useful for downstream tasks such as object recognition or image classification. Note that these features are not invariant to rotation: a rotation-invariant representation would make the applied rotation impossible to predict.

Because the rotation labels are generated automatically, rotation prediction can be applied to large amounts of unlabeled data, which makes it a good choice for pretraining neural networks. It has therefore become a popular pretext task in computer vision.
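As a minimal sketch of how the pretext labels can be generated (assuming image tensors of shape `(B, C, H, W)` with square images, such as STL10's 96×96), `torch.rot90` produces the four rotations, and the rotation class id is simply the number of 90° turns. The paper reports that feeding all four rotated copies of each image in the same batch works well; `make_rotation_batch` is a hypothetical helper, not part of this assignment's `utils.py`.

```python
import torch


def make_rotation_batch(images: torch.Tensor):
    """Return all four rotations of each image together with their rotation class ids.

    images: float tensor of shape (B, C, H, W) with H == W.
    Class id k corresponds to a counter-clockwise rotation by k * 90 degrees.
    """
    rotated, labels = [], []
    for k in range(4):
        # Rotate over the spatial dims (H, W): k=1 -> 90°, k=2 -> 180°, k=3 -> 270°
        rotated.append(torch.rot90(images, k, dims=(2, 3)))
        labels.append(torch.full((images.shape[0],), k, dtype=torch.long))
    return torch.cat(rotated), torch.cat(labels)


x = torch.rand(8, 3, 96, 96)  # stand-in for a batch of STL10 images
x_rot, y_rot = make_rotation_batch(x)
print(x_rot.shape, y_rot.shape)  # torch.Size([32, 3, 96, 96]) torch.Size([32])
```

The labels come for free, which is exactly why the task scales to unlabeled data.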

In this exercise, we will train a ResNet18 on the task of rotation prediction.

Related paper: [Unsupervised Representation Learning by Predicting Image Rotations](https://arxiv.org/pdf/1803.07728.pdf)

# Part I. Preparation and imports

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.utils.data as data
import random
import matplotlib.pyplot as plt
from torchvision import transforms as T
from tqdm import tqdm

from utils import *

# Part II. Preparing the data

**Exercise:** Implement the dataset class below. It should return randomly rotated images from the STL10 dataset. Instead of a tuple `(img, label)`, the dataset should now return `(img, rotation_class_id)`, where `rotation_class_id` identifies the applied rotation (e.g. 0, 1, 2, 3 for 0°, 90°, 180°, 270°).

In [None]:
class STL10Rot(Dataset):
    def __init__(self, split="unlabeled", transform=None):
        super(STL10Rot, self).__init__()
        ### START CODE HERE ### (≈ 3 lines)
       
        ### END CODE HERE ###

    def __len__(self):
        return len(self.dataset)
    
    def rand_rotate(self, img):
        ### START CODE HERE ### (≈ 3 lines)
       
        ### END CODE HERE ###
    
    def __getitem__(self, idx):
        ### START CODE HERE ### (≈ 3 lines)
        
        ### END CODE HERE ###

def load_data(batch_size=128, train_split="unlabeled", test_split="test", rotation=True):
    # Returns a train and validation dataloader for STL10 dataset
    transf = T.Compose([T.ToTensor()])
    if rotation:
        train_ds = STL10Rot(split=train_split, transform=transf)
        val_ds = STL10Rot(split=test_split, transform=transf)
    else:
        train_ds = torchvision.datasets.STL10(root='../data', split=train_split, transform=transf, download=True)
        val_ds = torchvision.datasets.STL10(root='../data', split=test_split, transform=transf, download=True)
    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
    return train_dl, val_dl

def test_load_data():
    train_dl, val_dl = load_data(batch_size=4)
    for i, (x, y) in enumerate(train_dl):
        print(x.shape, y)
        plt.figure(figsize=(6, 6))
        imshow(x[0,...])
        break
    for i, (x, y) in enumerate(val_dl):
        print(x.shape, y.shape)
        plt.figure(figsize=(6, 6))
        imshow(x[0,...])
        break

test_load_data()
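One possible way to approach the exercise, sketched under assumptions: instead of building STL10 inside the class, the hypothetical `RotationDataset` below wraps any dataset that yields `(C, H, W)` image tensors with square images, discards the original label, and applies a rotation sampled uniformly from the four classes with `torch.rot90`. It is not the reference solution and does not match the `STL10Rot(split, transform)` signature that `load_data` expects.

```python
import random

import torch
from torch.utils.data import Dataset


class RotationDataset(Dataset):
    """Wrap any (img, label) dataset and return (rotated_img, rotation_class_id)."""

    def __init__(self, base_dataset):
        super().__init__()
        self.dataset = base_dataset  # must yield (C, H, W) tensors with H == W

    def __len__(self):
        return len(self.dataset)

    def rand_rotate(self, img):
        # Sample a rotation class uniformly: 0 -> 0°, 1 -> 90°, 2 -> 180°, 3 -> 270°
        k = random.randint(0, 3)
        # Rotate counter-clockwise over the spatial dims (H, W) of a (C, H, W) tensor
        return torch.rot90(img, k, dims=(1, 2)), k

    def __getitem__(self, idx):
        img, _ = self.dataset[idx]  # the original class label is discarded
        return self.rand_rotate(img)
```

With STL10 it could be used as `RotationDataset(torchvision.datasets.STL10(root='../data', split='unlabeled', transform=T.ToTensor(), download=True))`. Taking the base dataset as an argument makes the wrapper easy to test on small synthetic data; the default collate function turns the integer `k` into a `torch.long` label tensor.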

# Part III. Load and modify ResNet18

**Exercise:** Load and modify the ResNet18 architecture so that it works with our data and predicts rotation angles.

Which layers do you need to modify?

In [None]:
def load_resnet():
    ### START CODE HERE ### (≈ 3 lines)
   
    ### END CODE HERE ###
    return model

def test_load_resnet():
    model = load_resnet()
    train_dl , _ = load_data(batch_size=4)
    x, target = next(iter(train_dl))
    y = model(x)
    loss = F.cross_entropy(y, target)
    print(y.shape, target.shape, loss)

test_load_resnet()

# Part IV. Launch training!

**Exercise:** Choose the model and hyperparameters and launch the training using the `pretrain` function from `utils.py`, which was imported at the top of the notebook. The final training should take no more than 1 h (about 50 min with our setup, reaching a train loss of 0.041 and a validation loss of 0.082).

In [None]:
### START CODE HERE ### (≈ 6 lines)

### END CODE HERE ###
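The signature of `pretrain` is not shown here, so the following is only a hypothetical standalone sketch of what one epoch of rotation-prediction training typically looks like: cross-entropy between the 4-way logits and the rotation class ids, followed by an optimizer step. `train_one_epoch` is an illustrative name and not part of `utils.py`.

```python
import torch
import torch.nn.functional as F


def train_one_epoch(model, loader, optimizer, device="cpu"):
    """Run one training epoch and return the mean training loss."""
    model.train()  # BatchNorm uses batch statistics during training
    total_loss, n = 0.0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        loss = F.cross_entropy(model(x), y)  # y holds rotation class ids
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * x.size(0)
        n += x.size(0)
    return total_loss / n
```

As an assumed starting point (not the setting behind the losses quoted above), `torch.optim.Adam(model.parameters(), lr=1e-3)` is a common choice; keeping the checkpoint with the lowest validation loss matches the notion of a "best model" used in the next part.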

# Part V. Visualizing the best model

**Exercise:** Load the best model using the `load_model` function from `utils.py`. Use the validation split to visualize the rotated images together with their predicted and true angles. Show 16 predictions in a 4x4 grid.

In [None]:
### START CODE HERE ### (≈ 15 lines)

### END CODE HERE ###
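A minimal plotting sketch, assuming images are `(C, H, W)` tensors in [0, 1] and that class ids 0 to 3 map to 0°, 90°, 180° and 270°. `plot_rotation_grid` and `ANGLES` are hypothetical names; the commented lines show how it might be called with a model and a validation loader from `load_data(batch_size=16)`.

```python
import matplotlib.pyplot as plt

ANGLES = [0, 90, 180, 270]  # assumed mapping: rotation class id -> angle in degrees


def plot_rotation_grid(images, preds, targets, nrows=4, ncols=4):
    """Show images in a grid with 'predicted / true' rotation angles as titles."""
    fig, axes = plt.subplots(nrows, ncols, figsize=(10, 10))
    for ax, img, p, t in zip(axes.flat, images, preds, targets):
        # (C, H, W) tensor -> (H, W, C) array, as expected by imshow
        ax.imshow(img.permute(1, 2, 0).cpu().numpy())
        color = "green" if p == t else "red"  # highlight wrong predictions
        ax.set_title(f"pred {ANGLES[int(p)]}° / true {ANGLES[int(t)]}°", color=color)
        ax.axis("off")
    fig.tight_layout()
    return fig


# Possible usage:
# model.eval()
# x, y = next(iter(val_dl))
# with torch.no_grad():
#     preds = model(x).argmax(dim=1)
# plot_rotation_grid(x[:16], preds[:16], y[:16])
```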

# Part VI. Validation accuracy

**Exercise:** What percentage of image rotations is the model able to predict correctly on the validation split?

In [None]:
### START CODE HERE ### (≈ 6 lines of code)

### END CODE HERE ###

### Expected result

```
Val Accuracy 97.99 || Loss Val 0.070
```
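A minimal evaluation sketch (the helper name `evaluate` is hypothetical): it switches the model to eval mode, disables gradients, and accumulates correct argmax predictions and the summed cross-entropy over the loader. Since `STL10Rot` presumably samples a fresh rotation each time an image is loaded, the accuracy may vary slightly between runs.

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def evaluate(model, loader, device="cpu"):
    """Return (accuracy in %, mean cross-entropy loss) over a labelled loader."""
    model.eval()  # BatchNorm uses its running statistics
    correct, total, loss_sum = 0, 0, 0.0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        loss_sum += F.cross_entropy(logits, y, reduction="sum").item()
        correct += (logits.argmax(dim=1) == y).sum().item()
        total += y.size(0)
    return 100.0 * correct / total, loss_sum / total
```

A matching printout would be `acc, loss = evaluate(model, val_dl)` followed by `print(f"Val Accuracy {acc:.2f} || Loss Val {loss:.3f}")`.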

# Part VII. Compute features



**Exercise:** Use the `get_features` function from `utils.py` to generate a set of features from your trained encoder. The features should be the output of the last layer **before** the fully connected one at the end.

- Hint: First spend some time figuring out an elegant way to extract them. Loading the data and running `get_features` take 3 of the 4 lines of code below.
- Use `rotation=False` in `load_data()` to load the images with their class labels.

In [None]:
### START CODE HERE ### (≈ 4 lines of code)

### END CODE HERE ###
torch.save(train_feats, "train_feats.pth")
torch.save(val_feats, "val_feats.pth")
torch.save(train_labels, "train_labels.pth")
torch.save(val_labels, "val_labels.pth")
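One elegant option, sketched here as a hypothetical standalone alternative because the signature of `get_features` is not shown: replace the classification head with `nn.Identity()`, so the network returns the input of the old `fc` layer (for torchvision's ResNet18, the 512-dimensional globally pooled features). The sketch assumes the head is stored in `model.fc` and works on a copy so that the trained rotation head stays intact.

```python
import copy

import torch
import torch.nn as nn


@torch.no_grad()
def extract_features(model, loader, device="cpu"):
    """Return (features, labels) from the layer right before `model.fc`."""
    encoder = copy.deepcopy(model)  # keep the original rotation head untouched
    encoder.fc = nn.Identity()      # forward now stops before the final linear layer
    encoder.eval().to(device)       # BatchNorm in inference mode
    feats, labels = [], []
    for x, y in loader:
        feats.append(encoder(x.to(device)).cpu())
        labels.append(y)
    return torch.cat(feats), torch.cat(labels)
```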

# Part VIII. Linear evaluation: Probing

**Exercise:** Evaluate the quality of the representations with linear probing, using the `linear_eval` function from `utils.py`. Train a linear classifier on top of the saved features of the **training split** of STL10.

- Hint: Remember that every `torch.nn.Module` (for example `nn.Linear`) is technically a "model" with all the functionality you need. You don't need to implement your own class for the classifier.
- Use `rotation=False` in `load_data()` to load the images with their class labels.

In [None]:
### START CODE HERE ### (≈ 10 lines of code)

### END CODE HERE ###

### Expected results

```
Ep 49/50: Accuracy : Train:64.06 	 Val:60.33 || Loss: Train 1.154 	 Val 1.191: 100%|██████████| 50/50 [00:04<00:00, 11.00it/s]
```
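Since the signature of `linear_eval` is not shown, here is a hypothetical standalone sketch of linear probing: the encoder stays frozen (only precomputed features are used) and a single `nn.Linear` layer is trained with cross-entropy on the 10 STL10 classes. The defaults (50 epochs to match the log above, Adam with `lr=1e-3`, batch size 256) are assumptions, not the settings that produced the expected result.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


def linear_probe(train_feats, train_labels, num_classes=10, epochs=50, lr=1e-3, batch_size=256):
    """Train a single linear layer on frozen features and return it."""
    clf = nn.Linear(train_feats.shape[1], num_classes)  # e.g. 512 -> 10 for ResNet18 on STL10
    opt = torch.optim.Adam(clf.parameters(), lr=lr)
    loader = DataLoader(TensorDataset(train_feats, train_labels), batch_size=batch_size, shuffle=True)
    for _ in range(epochs):
        for x, y in loader:
            loss = F.cross_entropy(clf(x), y)
            opt.zero_grad()
            loss.backward()
            opt.step()
    return clf
```

Validation accuracy can then be read off as `(clf(val_feats).argmax(1) == val_labels).float().mean()`. Because only a linear map is trained, this accuracy measures how linearly separable the learned features already are.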

# Part IX. TSNE visualization

**Exercise:** Visualize the computed features with TSNE.

In [None]:
from sklearn.manifold import TSNE

def tsne_plot_embeddings(features, labels, class_names):
    plt.figure(figsize=(12, 12))
    ### START CODE HERE ### (≈ 6 lines of code)
    
    ### END CODE HERE ###
    plt.legend(class_names, fontsize=18, loc='center left', bbox_to_anchor=(1.05, 0.5))
    plt.title('TSNE plot STL10 learned features from rotation prediction', fontsize=18)
    plt.gca().axes.get_yaxis().set_visible(False)
    plt.gca().axes.get_xaxis().set_visible(False)
    plt.savefig("tsne_plot_embeddings_solution.png")
    plt.show()
    
val_feats, val_labels = torch.load("val_feats.pth"), torch.load("val_labels.pth")
class_names = torchvision.datasets.STL10(root='../data').classes
tsne_plot_embeddings(val_feats, val_labels, class_names)

### Expected result

<img src="tsne_plot_embeddings_solution.png" alt="TSNE plot of STL10 features learned by rotation prediction" width="800" height="800">
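A possible way to fill in the plotting function, sketched under assumptions: scikit-learn's `TSNE` embeds the features in 2D, and then one `plt.scatter` call is made per class in ascending class-id order, so that the subsequent `plt.legend(class_names, ...)` assigns each name to the right color (this relies on every class appearing in `labels`). `tsne_scatter`, the perplexity of 30 and the PCA initialization are illustrative choices, not the reference solution.

```python
import matplotlib.pyplot as plt
import numpy as np
from sklearn.manifold import TSNE


def tsne_scatter(features, labels, perplexity=30.0, seed=0):
    """Embed features in 2D with t-SNE and draw one scatter call per class."""
    feats = np.asarray(features, dtype=np.float32)  # accepts CPU tensors or arrays
    labels = np.asarray(labels)
    emb = TSNE(n_components=2, perplexity=perplexity, init="pca",
               random_state=seed).fit_transform(feats)
    # np.unique returns sorted class ids, so the legend order matches class_names
    for c in np.unique(labels):
        mask = labels == c
        plt.scatter(emb[mask, 0], emb[mask, 1], s=5)
    return emb
```

Inside `tsne_plot_embeddings`, the placeholder lines could then be replaced by a single call `tsne_scatter(features, labels)`.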

# Conclusion and Bonus reads

Are the features visually separable in the 2D space? Why is that?

That's the end of this exercise. If you reached this point, congratulations!

If you are interested in delving further into this topic, here are some links:

- [Self-supervised learning and computer vision](https://www.fast.ai/posts/2020-01-13-self_supervised.html)
- [Self-Supervised Representation Learning: Introduction, Advances and Challenges](https://arxiv.org/abs/2110.09327)

