# Code challenge: self-supervised learning and embeddings generation
To address this challenge, we will follow the SimCLR approach of Chen et al. (2020)...

## Import libraries and modules

In [137]:
import math
import torch # load pytorch for machine learning
import torch.nn as nn # load module for neural networks
import torchvision.transforms as tvtran # load module for transforming / augmenting data
import torchvision.datasets as tvdat # load module for handling datasets
import torchvision.models as tvmod
from torchvision.models import resnet18, ResNet18_Weights

## Dataset loading
We begin by loading the dataset of pet images from the root folder:

In [114]:
x_raw = tvdat.ImageFolder('../CodeChallenge/data/', transform=None) 

## Dataset augmentation
Next we specify a set of transformations to aid contrastive learning, based on recommendations in the SimCLR paper [Chen et al. (2020)](https://arxiv.org/pdf/2002.05709):

In [115]:
transforms = tvtran.Compose([
    # spatial transformations
    tvtran.RandomResizedCrop(size=224), # random crop, resized to 224x224
    tvtran.RandomHorizontalFlip(p=0.5), # flip horizontally 50% of the time
    # colour distortion
    tvtran.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.2), # illustrative parameters
    tvtran.RandomGrayscale(p=0.2), # convert to greyscale with probability 20%
    tvtran.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)), # illustrative parameters
    # convert to tensor
    tvtran.ToTensor()
])

The SimCLR approach requires two augmented versions of each image. Define a class for generating a customised dataset, using the base class `torch.utils.data.Dataset`:

In [118]:
class SimCLR_custom_dat(torch.utils.data.Dataset):
    def __init__(self, raw_dataset, transform):
        self.raw_dataset = raw_dataset
        self.transform = transform
    def __len__(self):
        return len(self.raw_dataset)
    def __getitem__(self, idx):
        img, label = self.raw_dataset[idx] # image and label from ImageFolder dataset
        # apply two stochastic augmentations to the same image
        x1 = self.transform(img)
        x2 = self.transform(img)
        return x1, x2, label

Now, we can generate our customised dataset of augmented (and challenging) images, providing positive pairs for contrastive learning:

In [119]:
x_simclr = SimCLR_custom_dat(x_raw, transforms)

## Data loader
Create a data loader for inputting images into the model:

In [138]:
batch_size = 32 # this is arbitrary; may be tweaked later

train_loader = torch.utils.data.DataLoader(x_simclr, batch_size=batch_size, shuffle=True, drop_last=True) # drop_last=True drops incomplete batches

## Base encoder
The framework specified in Chen et al. (2020) consists of a base encoder (denoted `f` in the paper, see Fig. 2), which creates representations of the two augmented views in each positive pair, and a projection head (denoted `g`).

For the base encoder, we will use ResNet18 with weights pre-trained on ImageNet (`IMAGENET1K_V1`). We choose this relatively small network to demonstrate a proof of concept without requiring large compute resources.

In [126]:
f = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1) # load ResNet18 as the base encoder, together with pre-trained weights

The pre-trained ResNet18 ends in a fully connected layer that outputs scores for the 1000 ImageNet classes (which cover many object types, not only animals), but we are interested in just 3 classes (chinchilla, rabbit, hamster). Therefore, we truncate the base model to remove this final layer. The output of the truncated model is that of the penultimate layer (average pooling): a 512-dimensional feature vector for each image, with shape `[batch, 512, 1, 1]` before flattening:

In [127]:
f_trunc = nn.Sequential(*list(f.children())[:-1]) # remove final classification layer

## Projection head
The projection head `g` is a small neural network that maps the representations generated by `f_trunc` to the space where the contrastive loss is applied.

Following Chen et al. (2020), we define `g` to be a multi-layer perceptron with one hidden layer. 

In [130]:
g = nn.Sequential(
    nn.Linear(512, 512), # hidden layer
    nn.ReLU(), # activation function
    nn.Linear(512, 128) # output layer
)
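As a sanity check on the dimensions, the sketch below passes a random batch of 512-dimensional vectors through an identically shaped head and normalises the outputs to unit length, so that dot products between rows are cosine similarities. The names `proj_head`, `h_dummy`, and `z_unit` are assumptions for this example, and the normalisation step is shown here only for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Same layer sizes as g above, rebuilt so the example is self-contained
proj_head = nn.Sequential(
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, 128),
)

h_dummy = torch.randn(4, 512)         # stand-in for flattened encoder outputs
z_dummy = proj_head(h_dummy)          # projected embeddings, shape [4, 128]
z_unit = F.normalize(z_dummy, dim=1)  # each row scaled to unit L2 norm
cosine = z_unit @ z_unit.t()          # pairwise cosine similarities, shape [4, 4]

print(z_dummy.shape)
print(cosine.shape)
```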

## Combined model
Next, we use the PyTorch base class `torch.nn.Module` to define our custom neural network architecture, which combines the encoder `f_trunc` with the projection head `g`:

In [132]:
class CustomModel(nn.Module):
    def __init__(self, ff, gg):
        super(CustomModel, self).__init__()
        self.encoder = ff
        self.projection_head = gg
    def forward(self, x):
        h = self.encoder(x)           # representation of x generated by f_trunc; shape: [batch, 512, 1, 1] for ResNet18
        h = h.view(h.size(0), -1)     # flatten to shape: [batch, 512]
        z = self.projection_head(h)   # project to 128-dim; z will be used in the contrastive loss calculation
        return z

Create an instance of our model:

In [133]:
model = CustomModel(f_trunc, g)
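The outputs `z` of this model feed the contrastive loss. For orientation, here is a minimal sketch of the NT-Xent (normalised temperature-scaled cross-entropy) loss used in SimCLR, evaluated on random tensors rather than real model outputs. The function name `nt_xent_loss` and the default `temperature=0.5` are assumptions for this example, not values fixed by the notebook.

```python
import torch
import torch.nn.functional as F

def nt_xent_loss(z1, z2, temperature=0.5):
    """Sketch of NT-Xent for n positive pairs (z1[i], z2[i]); temperature is an assumed default."""
    n = z1.size(0)
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)  # [2n, d], unit-norm rows
    sim = z @ z.t() / temperature                       # scaled cosine similarities
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(self_mask, float("-inf"))     # a sample is never its own negative
    # Row i's positive is row i + n, and row i + n's positive is row i
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
    return F.cross_entropy(sim, targets)

torch.manual_seed(0)
z_a = torch.randn(8, 128, requires_grad=True)
z_b = torch.randn(8, 128)

loss_random = nt_xent_loss(z_a, z_b)           # unrelated "pairs"
loss_aligned = nt_xent_loss(z_a, z_a.detach()) # perfectly aligned pairs
loss_random.backward()                         # gradients flow back to z_a

print(loss_random.item(), loss_aligned.item())
```

Treating each row of the similarity matrix as a classification problem, with the positive partner as the correct class, lets `F.cross_entropy` do the softmax and averaging. Aligned pairs should give a lower loss than unrelated ones.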

## Training parameters
Here, we specify the device on which the training will occur (CPU will be used here for demonstration), and also the optimizer to be used in training:

In [136]:
# move model to CPU
model = model.to('cpu')

# optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) # use Adam optimizer with low learning rate (lr)
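If a GPU is available, the same setup can be made device-agnostic. The sketch below is one common pattern, shown on a small stand-in layer rather than the full model. The names `device`, `toy_model`, and `toy_optimizer` are assumptions for this example.

```python
import torch
import torch.nn as nn

# Use a GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

toy_model = nn.Linear(512, 128).to(device)  # stand-in for the combined model
toy_optimizer = torch.optim.Adam(toy_model.parameters(), lr=1e-4)

batch = torch.randn(4, 512).to(device)      # inputs must live on the same device
out = toy_model(batch)

# One illustrative optimisation step on a dummy objective
toy_optimizer.zero_grad()
out.pow(2).mean().backward()
toy_optimizer.step()

print(device, out.shape)
```

Creating the optimizer after moving the model keeps the optimizer's parameter references on the chosen device, which matches the order used in the cell above.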

## Self-supervised training