# Lab11 - Self-supervised learning

In today's lab scenario we will focus on self-supervised learning, a part of semi-supervised learning that concerns learning representations from unlabeled data.

In particular, we will implement RotNet - a neural network that tries to learn a meaningful embedding of an image by solving the "pretext" task of rotation prediction.

You can read more about the idea here: https://arxiv.org/abs/2012.01985v2

After learning the representation, we will use it in a semi-supervised scenario and compare it with a model trained only on the labeled part of the dataset.

---

We will implement our models for the [Fashion-MNIST](https://github.com/zalandoresearch/fashion-mnist) dataset, so let's start by downloading the data.

In [None]:
import torchvision
import torch

from torchvision import transforms

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])

fashion_mnist_data_train_raw = torchvision.datasets.FashionMNIST('./data', train=True, download=True, transform=transform)

fashion_mnist_data_test = torchvision.datasets.FashionMNIST('./data', train=False, download=True, transform=transform)

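To perform the experiment properly, we will split the training data to simulate a setting in which labels are missing for the majority of samples: only 2% of the training images keep their labels, and the whole training set with labels hidden serves as the unlabeled pool.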

In [None]:
from torch.utils.data.dataset import Dataset

train_size = int(0.02 * len(fashion_mnist_data_train_raw))
train_indices = torch.randperm(len(fashion_mnist_data_train_raw))[:train_size]
fashion_mnist_data_train = torch.utils.data.Subset(fashion_mnist_data_train_raw, train_indices)

class PoolDataset(Dataset):
  def __init__(self, dataset_to_hide_labels):
    self.dataset = dataset_to_hide_labels

  def __getitem__(self, index):
    X, y = self.dataset[index]
    return X

  def __len__(self):
    return len(self.dataset)

fashion_mnist_data_pool = PoolDataset(fashion_mnist_data_train_raw)

train_data_loader = torch.utils.data.DataLoader(fashion_mnist_data_train,
                                          batch_size=64,
                                          shuffle=True,
                                          num_workers=2)

test_data_loader = torch.utils.data.DataLoader(fashion_mnist_data_test,
                                          batch_size=64,
                                          shuffle=False,
                                          num_workers=2)

Now we will define a simple neural network that will serve as our model for representation learning. We will reuse some chunks of code from lab7.

In [None]:
import torch.nn as nn

model = nn.Sequential(
          nn.Conv2d(1,32, 3),
          nn.ReLU(),
          nn.MaxPool2d(2),
          nn.Dropout(),
          nn.Conv2d(32,64, 3),
          nn.ReLU(),
          nn.MaxPool2d(2),
          nn.Dropout(),
          nn.Conv2d(64, 32, 3),
          nn.ReLU(),
          nn.MaxPool2d(2),
          nn.Flatten(),
          nn.Dropout(),
          nn.Linear(32, 10)
        )
model = model.float()

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.CrossEntropyLoss()
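To see why the final `nn.Linear` layer takes 32 input features, it can help to trace the tensor shape through the convolutional part of the network. The sketch below rebuilds the same layer stack without dropout (dropout does not change shapes) and records the output shape of every layer for a single 28x28 image. The names `feature_extractor` and `trace_shapes` are hypothetical helpers introduced only for this illustration.

```python
import torch
import torch.nn as nn

# Same convolutional stack as the model above, without dropout
# (dropout leaves tensor shapes unchanged).
feature_extractor = nn.Sequential(
    nn.Conv2d(1, 32, 3), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(32, 64, 3), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(64, 32, 3), nn.ReLU(), nn.MaxPool2d(2),
    nn.Flatten(),
)


def trace_shapes(module, x):
    """Return (layer name, output shape) for every layer of a Sequential."""
    shapes = []
    with torch.no_grad():
        for layer in module:
            x = layer(x)
            shapes.append((layer.__class__.__name__, tuple(x.shape)))
    return shapes


dummy_batch = torch.zeros(1, 1, 28, 28)  # one Fashion-MNIST-sized image
shape_trace = trace_shapes(feature_extractor, dummy_batch)
for name, shape in shape_trace:
    print(f"{name:>10}: {shape}")
```

The spatial size shrinks 28 -> 26 -> 13 -> 11 -> 5 -> 3 -> 1, so after flattening each image is represented by the 32 channels of the last convolution.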

Below we define functions for training and evaluating the network.

In [None]:
from sklearn.metrics import balanced_accuracy_score

def train_loop(dataloader, model, loss_fn, optimizer, num_epochs=1):
    size = len(dataloader.dataset)
    model.train()  # enable dropout while training
    for epoch in range(num_epochs):
        for batch, (X, y) in enumerate(dataloader):
            pred = model(X)
            loss = loss_fn(pred, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch % 100 == 0:
                current = batch * len(X)
                print(f"Epoch {epoch}, loss: {loss.item():>7f}  [{current:>5d}/{size:>5d}]")


def test_loop(dataloader, model, loss_fn):
    num_batches = len(dataloader)
    test_loss, labels_estimated, correct_labels = 0, [], []

    model.eval()  # disable dropout so that evaluation is deterministic
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct_labels.extend(y.numpy())
            labels_estimated.extend(pred.argmax(1).numpy())

    test_loss /= num_batches  # average loss per batch
    print(f"Test Error: \n BAC: {balanced_accuracy_score(correct_labels, labels_estimated)} \n Loss {test_loss}")

1. Train a classifier on the training set and evaluate its performance.

2. Train a RotNet on the pool data.

We can model RotNet as a classifier with a categorical output that predicts one of the following categories:
- Image wasn't rotated
- Image was rotated by 90 degrees
- Image was rotated by 180 degrees
- Image was rotated by 270 degrees

During training, artificially create a label for each sample from the data pool by rotating the image and using the applied rotation as the label.
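One possible way to generate the rotation labels on the fly is sketched below. It assumes square images in a batch of shape `(N, C, H, W)` and uses `torch.rot90` to produce all four rotations of every image, with the number of quarter turns as the class label. The helpers `make_rotation_batch` and `rotnet_train_step` are hypothetical names introduced here, not part of the lab code, and this is only one of several reasonable designs (another is to sample a single random rotation per image).

```python
import torch


def make_rotation_batch(images):
    """Build all four rotations of every image in the batch.

    images: tensor of shape (N, C, H, W) with H == W.
    Returns (rotated, labels): rotated has shape (4N, C, H, W) and
    labels[i] in {0, 1, 2, 3} is the number of 90-degree turns applied.
    """
    rotated = torch.cat([torch.rot90(images, k, dims=(2, 3)) for k in range(4)])
    labels = torch.arange(4).repeat_interleave(images.shape[0])
    return rotated, labels


def rotnet_train_step(model, images, loss_fn, optimizer):
    """Run one optimization step on the rotation-prediction pretext task."""
    model.train()
    rotated, labels = make_rotation_batch(images)
    loss = loss_fn(model(rotated), labels)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```

Since `PoolDataset` returns only `X`, a `DataLoader` built over `fashion_mnist_data_pool` yields plain image batches that can be passed to `rotnet_train_step` directly. The model used for this step needs four outputs instead of ten.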

3. Fine-tune the RotNet network on the downstream task - fashion category classification. Evaluate the resulting network on the test set.
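A minimal sketch of the fine-tuning setup, assuming the pretrained RotNet is an `nn.Sequential` whose last layer is the 4-way rotation head: the head is swapped for a fresh 10-way layer, and the remaining layers can optionally be frozen. `replace_head` and its `freeze_backbone` flag are hypothetical names used for illustration. Whether to freeze the backbone or train it end to end is a design choice worth comparing experimentally.

```python
import torch
import torch.nn as nn


def replace_head(pretrained, num_classes, freeze_backbone=False):
    """Swap the final Linear layer of an nn.Sequential for a new one."""
    old_head = pretrained[-1]
    if not isinstance(old_head, nn.Linear):
        raise TypeError("expected the last layer to be nn.Linear")

    if freeze_backbone:
        # Freeze all existing weights; the new head created below stays trainable.
        for param in pretrained.parameters():
            param.requires_grad = False

    pretrained[-1] = nn.Linear(old_head.in_features, num_classes)
    return pretrained
```

After the swap, a new optimizer should be created, for example `torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=0.001)`, so that it tracks the new head and skips any frozen parameters.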

4.* Try to add another head to the self-supervised network that learns another "pretext" task.
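One way to structure a network with several pretext heads is a shared backbone followed by one linear layer per task, with the per-task losses summed. The sketch below assumes every task is a classification problem. The class `MultiHeadNet`, the function `multi_task_loss`, and the task names passed in `head_sizes` are hypothetical; the choice of the second pretext task is left open.

```python
import torch
import torch.nn as nn


class MultiHeadNet(nn.Module):
    """Shared feature extractor with one linear head per pretext task."""

    def __init__(self, backbone, feature_dim, head_sizes):
        super().__init__()
        self.backbone = backbone
        # head_sizes maps a task name to its number of classes.
        self.heads = nn.ModuleDict(
            {name: nn.Linear(feature_dim, n) for name, n in head_sizes.items()}
        )

    def forward(self, x):
        features = self.backbone(x)
        return {name: head(features) for name, head in self.heads.items()}


def multi_task_loss(outputs, targets, loss_fn):
    """Sum the per-task losses; unweighted, which is an assumption."""
    return sum(loss_fn(outputs[name], targets[name]) for name in outputs)
```

For the model in this lab, `backbone` would be the layers up to and including the last `nn.Dropout`, and `feature_dim` would be 32.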