# Homework 10 - Self-supervised Learning

Self-supervised learning (SSL) is a family of machine learning methods that learn representations from unlabeled data. 
Unlike supervised learning, it does not require labels: the training signal comes from the data itself, for example by predicting a hidden or corrupted part of an input from the rest. 
The learned representations can then be reused for downstream tasks, which is especially useful when labeled data is scarce.

### [SSL methods](https://arxiv.org/pdf/2006.08218.pdf):
- Reconstruct from a corrupted (or partial) version
  - [Denoising Autoencoders](https://www.cs.toronto.edu/~larocheh/publications/icml-2008-denoising-autoencoders.pdf), [Masked Autoencoders](https://arxiv.org/abs/2111.06377)
  - [In-painting](https://arxiv.org/abs/1604.07379)
- Visual common sense tasks
  - [Relative patch prediction](https://arxiv.org/abs/1505.05192)
  - [Jigsaw puzzles](https://arxiv.org/abs/1603.09246)
  - [Rotation prediction](https://arxiv.org/abs/1803.07728)
- Contrastive Learning
  - [Simple Framework for Contrastive Learning of Visual Representations (SimCLR)](https://arxiv.org/abs/2002.05709)
  - [Momentum Contrast (MoCo)](https://arxiv.org/abs/1911.05722)
  - [Bootstrap Your Own Latent (BYOL)](https://arxiv.org/abs/2006.07733)

Details of each method can be found by following its link, and implementations are available in [OpenMixup](https://github.com/Westlake-AI/openmixup).

In this notebook, we will start by implementing a simple autoencoder:

## Autoencoder

Autoencoders (AEs) attempt to find a compressed representation of a dataset. The encoder maps each input to a smaller number of values, a compressed vector. During training, this vector is passed through a decoder, which attempts to reconstruct the original data. The loss is the mean squared error (MSE) between the original data $x_{ori}$ and the reconstruction $x_{rec}$, averaged over the $D$ elements of the input:

$$\ell_{MSE} = \frac{1}{D}\sum_{i=1}^{D}\left(x_{ori}^{(i)}-x_{rec}^{(i)}\right)^2$$
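
As a quick check of how this loss behaves in PyTorch, the sketch below compares a hand-written mean of squared differences with `nn.MSELoss`. The tensors `x_ori` and `x_rec` are random stand-ins (an assumption for illustration, not MNIST data). By default `nn.MSELoss` uses `reduction="mean"`, so it averages over every element; `reduction="sum"` returns the un-averaged sum instead.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Stand-in tensors (assumption): a batch of 4 "images" and noisy "reconstructions"
x_ori = torch.rand(4, 1, 28, 28)
x_rec = x_ori + 0.1 * torch.randn_like(x_ori)

# Mean of the squared element-wise differences, written out by hand
manual_mse = ((x_ori - x_rec) ** 2).mean()

# The same quantity from the built-in criterion (default reduction="mean")
builtin_mse = nn.MSELoss()(x_rec, x_ori)

# With reduction="sum" the squared differences are summed instead of averaged
summed_se = nn.MSELoss(reduction="sum")(x_rec, x_ori)

print(manual_mse.item(), builtin_mse.item(), (summed_se / x_ori.numel()).item())
```

The three printed values should agree up to floating-point error.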

![An autoencoder.](./imgs/AE.png)

The compressed space is also called the latent space.

In [None]:
%matplotlib inline

import time

import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision import datasets, transforms
from utils import test_network, view_recon

In [2]:
# Get the data 
transform = transforms.Compose([transforms.ToTensor(),
                            transforms.Normalize((0.5,), (0.5,)),  # one-element tuples: MNIST has a single channel
                            ])
trainset = datasets.MNIST('./data/', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

testset = datasets.MNIST('./data/', train=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)

## Network Building

We will construct the autoencoder network by creating a class derived from `nn.Module`. 

The approach I took was to split the forward pass into two parts, the encoder and the decoder. 
This is done because, after training, we will mainly use the encoded vector, so it helps to be able to call the encoder on its own. 
The `forward` method then connects the two.

```python
class Network(nn.Module):
    def __init__(self):
        super().__init__()
        
        # Define the layers and operations here

    def forward(self, x):

        x = self.encoder(x)
        logits = self.decoder(x)

        return logits

    def encoder(self, x):
        # Forward pass through the encoder
        return x

    def decoder(self, x):
        # Forward pass through the decoder
        return x
```

### Encoder
To create the encoder, we'll combine `nn.Conv2d` with dropout from `nn.Dropout2d`.
Downsampling can be done with max-pooling, but many developers now use strided convolutions instead. 
At inference time, call `net.eval()` to switch layers such as dropout and batch normalization to their evaluation behavior (dropout is disabled and batch normalization uses its running statistics). 
To switch back to training mode, use `net.train()`.
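
The two downsampling options and the effect of `net.eval()` can be illustrated with a minimal sketch. The channel counts, kernel sizes, and the random input `x` are assumptions chosen for illustration, not values prescribed by this homework.

```python
import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(8, 1, 28, 28)  # stand-in for a batch of MNIST-sized images

# Option 1: a convolution that keeps the size, followed by 2x2 max-pooling
pooled = nn.Sequential(nn.Conv2d(1, 16, kernel_size=3, padding=1), nn.MaxPool2d(2))
# Option 2: a single convolution with stride 2
strided = nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1)

print(pooled(x).shape)   # torch.Size([8, 16, 14, 14])
print(strided(x).shape)  # torch.Size([8, 16, 14, 14])

# Dropout is random in training mode and a no-op in evaluation mode
net = nn.Sequential(strided, nn.ReLU(), nn.Dropout2d(p=0.5))

net.eval()
eval_a, eval_b = net(x), net(x)    # identical: dropout is disabled

net.train()
train_a, train_b = net(x), net(x)  # differ: different channels are dropped each call

print(torch.equal(eval_a, eval_b), torch.equal(train_a, train_b))
```

Both options halve the spatial resolution; the strided version does so without a separate pooling layer.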

### Decoder
With the decoder, you'll need to upsample the layers. This can be done with transposed convolutions (`nn.ConvTranspose2d`) or with nearest-neighbor upsampling (`nn.UpsamplingNearest2d`). With transposed convolutions, you define the kernel size and strides as usual, and when you call the module you can pass an `output_size` to select the desired output shape. 
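
A short sketch of both upsampling routes follows. The latent shape `(8, 16, 7, 7)` and the layer sizes are assumptions for illustration. With stride 2, a transposed convolution can produce more than one valid output size, which is why `output_size` is useful. The follow-up convolution after the nearest-neighbor step is likewise an illustrative choice.

```python
import torch
from torch import nn

z = torch.randn(8, 16, 7, 7)  # stand-in for an encoded batch (assumed shape)

# Route 1: transposed convolution
up_t = nn.ConvTranspose2d(16, 8, kernel_size=3, stride=2, padding=1)
default_out = up_t(z)                       # default output size
sized_out = up_t(z, output_size=[14, 14])   # request 14x14 explicitly

print(default_out.shape)  # torch.Size([8, 8, 13, 13])
print(sized_out.shape)    # torch.Size([8, 8, 14, 14])

# Route 2: nearest-neighbor upsampling followed by an ordinary convolution
up_nn = nn.Sequential(
    nn.UpsamplingNearest2d(scale_factor=2),
    nn.Conv2d(16, 8, kernel_size=3, padding=1),
)
nn_out = up_nn(z)

print(nn_out.shape)       # torch.Size([8, 8, 14, 14])
```

To my knowledge, `nn.Upsample(scale_factor=2, mode="nearest")` is the more general module for the second route and can be swapped in.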

In [3]:
class Autoencoder(nn.Module):
    def __init__(self, drop_prob=0.5):
        super().__init__()
        # Tip: Encoder
        pass
        
        # Tip: Decoder
        pass
    
    def forward(self, x):
        x = self.encode(x)
        logits = self.decode(x)
        
        return logits

    def encode(self, x):
        # Tip: encoding x to the latent z
        pass
        
    def decode(self, x):
        # Tip: decoding the latent z to x'
        pass
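
For orientation, here is one possible way the template above could be filled in. It is a sketch under stated assumptions, not the reference solution. The class name `ConvAutoencoder`, the channel counts, and the final `Tanh` are illustrative choices. `Tanh` is used because `Normalize` with mean 0.5 and std 0.5 maps the `ToTensor` output from [0, 1] to [-1, 1].

```python
import torch
from torch import nn


class ConvAutoencoder(nn.Module):  # hypothetical name, to avoid clashing with the template
    def __init__(self, drop_prob=0.5):
        super().__init__()
        # Encoder: two strided convolutions, 28x28 -> 14x14 -> 7x7
        self.enc = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Dropout2d(drop_prob),
            nn.Conv2d(16, 8, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        )
        # Decoder: two transposed convolutions, 7x7 -> 14x14 -> 28x28
        self.dec = nn.Sequential(
            nn.ConvTranspose2d(8, 16, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Tanh(),  # outputs in [-1, 1], matching the normalized inputs
        )

    def forward(self, x):
        return self.decode(self.encode(x))

    def encode(self, x):
        # Map the input x to the latent z
        return self.enc(x)

    def decode(self, z):
        # Map the latent z back to a reconstruction x'
        return self.dec(z)


model = ConvAutoencoder(drop_prob=0.1)
x = torch.randn(4, 1, 28, 28)  # stand-in batch
z = model.encode(x)
recon = model(x)
print(z.shape, recon.shape)  # torch.Size([4, 8, 7, 7]) torch.Size([4, 1, 28, 28])
```

Here the latent holds 8 x 7 x 7 = 392 values per image, compared with 784 input pixels.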

## Training

Now we'll train the network. We're using the MSE for the loss, so we'll use `criterion = nn.MSELoss()`. As before, we update the weights by doing a forward pass through the network, calculating the loss, getting the gradients with `loss.backward()`, and then making an update step with `optimizer.step()` (here with an Adam optimizer, `optim.Adam`).
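
The update sequence described above can be sketched in isolation on a toy problem. The small fully connected model and the random batch are assumptions that keep the example fast and self-contained; they stand in for the convolutional autoencoder and the MNIST loader.

```python
import torch
from torch import nn, optim

torch.manual_seed(0)

# Toy stand-in model (assumption): a tiny fully connected autoencoder
model = nn.Sequential(nn.Flatten(), nn.Linear(784, 32), nn.ReLU(), nn.Linear(32, 784))
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

images = torch.rand(64, 1, 28, 28) * 2 - 1  # random batch in [-1, 1]
targets = images.flatten(1)                 # self-supervised: the input is the target

losses = []
for step in range(50):
    optimizer.zero_grad()              # reset gradients from the previous step
    output = model(images)             # forward pass
    loss = criterion(output, targets)  # reconstruction loss
    loss.backward()                    # compute gradients
    optimizer.step()                   # update the weights
    losses.append(loss.item())

print(f"first loss: {losses[0]:.4f}, last loss: {losses[-1]:.4f}")
```

Repeating the step on the same batch should drive the loss down, which is a convenient sanity check before launching a full training run.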

In [None]:
net = Autoencoder(drop_prob=0.1)

criterion = nn.MSELoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)
epochs = 10
print_every = 200  # Number of training steps to print losses
show_every = 500  # Number of training steps to show reconstructions
cuda = torch.cuda.is_available()  # Train on the GPU only if one is available

if cuda:
    net.cuda()

running_loss = 0
for e in range(epochs):
    start = time.time() # Start timing
    for i, (images, _) in enumerate(trainloader, 0): # Grab the images and labels
        
        inputs = images   # Variable wrappers are no longer needed for tensors
        targets = images  # self-supervised: the input is its own target

        if cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        
        optimizer.zero_grad() # Reset current gradients
        
        output = net(inputs) # Forward pass (calling the module runs forward)
        loss = criterion(output, targets) # Calculate the loss
        loss.backward()
        optimizer.step() # Update network weights
        
        running_loss += loss.item() # Accumulate the loss over all batches
        
        if i % print_every == 0:
            net.eval()
            stop = time.time()
            # Test loss (no gradients are needed for evaluation)
            val_loss = 0
            with torch.no_grad():
                # Use separate names so the training batch `images` is not overwritten
                for ii, (val_images, _) in enumerate(testloader):

                    inputs = val_images
                    targets = val_images

                    if cuda:
                        inputs, targets = inputs.cuda(), targets.cuda()

                    output = net(inputs)
                    val_loss += criterion(output, targets).item()
                
            print("Epoch: {}/{}..".format(e+1, epochs),
                "Loss: {:.4f}..".format(running_loss/print_every),
                "Test loss: {:.4f}..".format(val_loss/(ii+1)),
                "{:.4f} s/batch".format((stop - start)/print_every)
                )
            
            running_loss = 0
            start = time.time()
            net.train()
            
        if i % show_every == 0:
            net.eval()
            img = images[3].unsqueeze(0)  # Add a batch dimension: (1, 1, 28, 28)
            with torch.no_grad():
                # Run the reconstruction without tracking gradients, then move it to the CPU
                recon = net(img.cuda() if cuda else img).cpu()
            
            view_recon(img, recon)
            plt.show()
            net.train()