# Practicals - PyTorch: Autoencoder

In [None]:
# common imports
import numpy as np
import os
import random
import math

%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt

# Path variables
PROJECT_ROOT_DIR = "."
MODEL_PATH = os.path.join(PROJECT_ROOT_DIR, "models")
DATA_PATH = os.path.join(PROJECT_ROOT_DIR, "data")
logging_dir = os.path.join(PROJECT_ROOT_DIR, "my_logs_4")
os.makedirs(MODEL_PATH, exist_ok=True)
os.makedirs(DATA_PATH, exist_ok=True)

In [None]:
# torch imports
import torch, torch.nn as nn
import torch.utils.data as data
import torch.nn.functional as F
import torch.optim as optim

# Torchvision
import torchvision
from torchvision.datasets import CIFAR10
from torchvision import transforms


%load_ext tensorboard

In [None]:
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        
set_seed(42)

In [None]:
# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Autoencoders - Introduction

In this tutorial, we will take a closer look at autoencoders (AE). 

An Autoencoder consists of: 
- an encoder, which encodes an image $x$ to a smaller-dimensional so-called "latent vector" $z$. 
- a decoder, which takes this latent vector $z$ and recreates the image $x$. 

Why? The encoding $z$ of $x$ has fewer dimensions than $x$, yet it must contain all the information needed to reconstruct $x$ (i.e. everything important). So the encoder is forced to learn the most important features of the image!

Advantage of autoencoders: they are an unsupervised method! You can use them even if you don't have labels.

E.g., if you want to use a CNN to learn features but don't have labels to train it on, use an autoencoder.

### Dataset 

We again work with the CIFAR10 dataset. In CIFAR10, each image has 3 color channels and is 32x32 pixels in size.

In contrast to the previous CIFAR10 tutorial, we do not normalize the data explicitly to a mean of 0 and a standard deviation of 1; instead, we roughly approximate this by scaling the data to the range [-1, 1]. Limiting the range makes our task of predicting/reconstructing images easier.
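As a quick worked example in plain Python (independent of torchvision): `transforms.ToTensor` scales 8-bit pixel values to [0, 1], and `transforms.Normalize` with mean 0.5 and std 0.5 then computes `(v - 0.5) / 0.5` per channel, which maps [0, 1] onto [-1, 1]. The helper names below are hypothetical; the inverse `unnormalize` is what you need before plotting images with `imshow`, which expects float values in [0, 1].

```python
def normalize(v, mean=0.5, std=0.5):
    """What transforms.Normalize computes per channel: (v - mean) / std."""
    return (v - mean) / std


def unnormalize(v, mean=0.5, std=0.5):
    """Inverse mapping, e.g. before showing an image with imshow."""
    return v * std + mean


# ToTensor scales 8-bit pixels to [0, 1]; Normalize(0.5, 0.5) then maps [0, 1] to [-1, 1]
print([normalize(p / 255) for p in (0, 255)])  # [-1.0, 1.0]
```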

In the following cell: 
- define a transform that turns the input to a tensor and then normalizes it, where we assume the mean and standard deviation are 0.5 in all channels
- load the CIFAR10 training set with this transform and split it into a train and a val set of sizes [45000, 5000] like before
- load the test set with the same transform
- define `train_loader`, `val_loader`, and `test_loader` with batch size = 256

In [None]:
# Let's draw some of the test data
example_data, example_targets = next(iter(test_loader))

fig = plt.figure()
for i in range(6):
    plt.subplot(2, 3, i + 1)
    plt.tight_layout()
    # Undo Normalize(0.5, 0.5): map [-1, 1] back to [0, 1] so imshow does not clip
    plt.imshow(example_data[i].permute(1, 2, 0) * 0.5 + 0.5)
    plt.title("Ground Truth: {}".format(example_targets[i]))
    plt.xticks([])
    plt.yticks([])
plt.show()

### Building the AE

The autoencoder consists of an **encoder** that maps the input $x$ to a lower-dimensional feature vector $z$, and a **decoder** that reconstructs the input $\hat{x}$ from $z$. We train the model by comparing $x$ to $\hat{x}$ and optimizing the parameters to minimize the MSE between these.
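Written out, with $f_\theta$ denoting the encoder, $g_\phi$ the decoder (notation introduced here for illustration) and a batch of $B$ images with $C$ channels of size $H \times W$, the objective minimized during training, as implemented by `reconstruction_loss` in the training section, is:

$$
z = f_\theta(x), \qquad \hat{x} = g_\phi(z), \qquad
\mathcal{L}(\theta, \phi) = \frac{1}{B} \sum_{b=1}^{B} \sum_{c=1}^{C} \sum_{h=1}^{H} \sum_{w=1}^{W} \left( x_{b,c,h,w} - \hat{x}_{b,c,h,w} \right)^2
$$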

#### First Step: The Encoder

Note: we do not apply batch normalization here. In training mode, BatchNorm normalizes each feature with statistics computed over the whole batch, so the encoding of one image would depend on the other images in the batch. This can introduce correlations into the encoding or decoding that we do not want. So: best practice = no BatchNorm in autoencoders! If you want to normalize, use e.g. LayerNorm instead, which computes its statistics per sample. Here, the model is so small that we don't need normalization.
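A small sketch of the effect described above, with toy shapes chosen only for illustration: in training mode, `nn.BatchNorm1d` normalizes with statistics of the whole batch, so changing the *other* samples changes the output for sample 0, while `nn.LayerNorm` normalizes each sample on its own.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(8, 4)       # a batch of 8 "encodings" with 4 features each
x_other = x.clone()
x_other[1:] += 5.0          # change every sample except sample 0

bn = nn.BatchNorm1d(4).train()  # train mode: uses batch statistics
ln = nn.LayerNorm(4)            # per-sample statistics

bn_changed = not torch.allclose(bn(x)[0], bn(x_other)[0])
ln_changed = not torch.allclose(ln(x)[0], ln(x_other)[0])
print(bn_changed, ln_changed)  # True False
```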

In the following cell, complete the code that defines the AE encoder class:
- `self.net` consists of the following module types:
    - module type 1: 3x3 convolution with stride 2 and padding=1 + activation function (halves height and width)
    - module type 2: 3x3 "same" convolution with stride 1 and padding=1 + activation function (keeps height and width)
- in the following order:
    - type 1 (c_hid out channels)
    - type 2 (2*c_hid out channels)
    - type 1 (2*c_hid out channels)
    - type 2 (2*c_hid out channels)
    - type 1 (2*c_hid out channels)
    - flatten
    - Linear with output dimension = latent_dim

- Define the parameter initialization `_init_params()` as follows (you can, e.g., look at how it was done in the GoogleNet class in the last practicals): for the modules in `self.net`, use Xavier normal (= Glorot) initialization for the Linear layer and Kaiming normal (= He) initialization for the Conv2d layers.

- `forward` simply applies `self.net` to the input.


In [None]:
class Encoder(nn.Module):
    
    def __init__(self, 
                 num_input_channels : int, 
                 base_channel_size : int, 
                 latent_dim : int, 
                 act_fn : object = nn.GELU):
        """
        Inputs: 
            - num_input_channels : Number of input channels of the image. For CIFAR, this parameter is 3
            - base_channel_size : Number of channels we use in the first convolutional layers. Deeper layers use a multiple of it (here 2*base_channel_size).
            - latent_dim : Dimensionality of latent representation z
            - act_fn : Activation function used throughout the encoder network
        """
        super().__init__()
        c_hid = base_channel_size
        self.net = nn.Sequential(
            #TODO
        )

        self._init_params()

    def _init_params(self):
        # TODO
        pass
        
    def forward(self, x):
        return self.net(x)
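If you get stuck, here is one possible way to fill in the specification above, written as a separate, hypothetically named class `SketchEncoder` so that it does not replace your own `Encoder`. It assumes 32x32 inputs (so three stride-2 convolutions leave a 4x4 feature map), zero-initialized biases, and `nonlinearity="relu"` as the gain for Kaiming initialization, since `torch.nn.init` has no GELU-specific gain; none of these choices are prescribed by the exercise.

```python
import torch
import torch.nn as nn


class SketchEncoder(nn.Module):
    """Hypothetical sketch of the Encoder specification (assumes 32x32 inputs)."""

    def __init__(self, num_input_channels: int, base_channel_size: int,
                 latent_dim: int, act_fn: type = nn.GELU):
        super().__init__()
        c_hid = base_channel_size
        self.net = nn.Sequential(
            # type 1: stride-2 conv, 32x32 -> 16x16
            nn.Conv2d(num_input_channels, c_hid, kernel_size=3, stride=2, padding=1), act_fn(),
            # type 2: "same" conv, keeps 16x16
            nn.Conv2d(c_hid, 2 * c_hid, kernel_size=3, padding=1), act_fn(),
            # type 1: 16x16 -> 8x8
            nn.Conv2d(2 * c_hid, 2 * c_hid, kernel_size=3, stride=2, padding=1), act_fn(),
            # type 2: keeps 8x8
            nn.Conv2d(2 * c_hid, 2 * c_hid, kernel_size=3, padding=1), act_fn(),
            # type 1: 8x8 -> 4x4
            nn.Conv2d(2 * c_hid, 2 * c_hid, kernel_size=3, stride=2, padding=1), act_fn(),
            nn.Flatten(),  # (B, 2*c_hid, 4, 4) -> (B, 2*c_hid*16)
            nn.Linear(2 * c_hid * 4 * 4, latent_dim),
        )
        self._init_params()

    def _init_params(self):
        for m in self.net.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)  # Glorot
                nn.init.zeros_(m.bias)            # assumption: zero biases
            elif isinstance(m, nn.Conv2d):
                # He init; the "relu" gain stands in for GELU (assumption)
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                nn.init.zeros_(m.bias)

    def forward(self, x):
        return self.net(x)


enc = SketchEncoder(num_input_channels=3, base_channel_size=32, latent_dim=128)
z = enc(torch.zeros(2, 3, 32, 32))
print(z.shape)  # torch.Size([2, 128])
```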

**Question:** As always, write down dimensions and numbers of parameters for the Encoder above in the table below!
**Answer:** 

Input: 3x32x32

|Layer|Layer-Output-Dimension|Weights|Biases|
|:---|:---|:---|:---|
|Conv2d 1|x|x|x|
|Conv2d 2|x|x|x|
|Conv2d 3|x|x|x|
|Conv2d 4|x|x|x|
|Conv2d 5|x|x|x|
|Flatten|x|x|x|
|Linear|x|x|x|
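To check the table entries, the standard formulas for `nn.Conv2d` (dilation 1) and `nn.Linear` can be evaluated with a few lines of plain Python. The helper names are hypothetical, and `c_hid = 32` in the example is only an assumed value, since the exercise leaves `base_channel_size` open.

```python
def conv2d_out(size, kernel=3, stride=1, padding=1):
    """Spatial output size of nn.Conv2d (dilation 1): floor((size + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


def conv2d_params(c_in, c_out, kernel=3):
    """(weights, biases) of nn.Conv2d: c_out * c_in * k * k weights and c_out biases."""
    return c_out * c_in * kernel * kernel, c_out


def linear_params(n_in, n_out):
    """(weights, biases) of nn.Linear: n_in * n_out weights and n_out biases."""
    return n_in * n_out, n_out


# First encoder layer (type 1, stride 2) with an assumed c_hid = 32
print(conv2d_out(32, stride=2))  # 16
print(conv2d_params(3, 32))      # (864, 32)
```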


#### Second Step: The Decoder

The decoder mirrors the encoder. The main difference is that we replace the strided convolutions (type 1 above) by transposed convolutions `nn.ConvTranspose2d` (sometimes called "deconvolutions") with stride=2, which upsample the features symmetrically to how the encoder downsampled them.
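Why `output_padding=1`? For `nn.ConvTranspose2d` with dilation 1, the output size is $(H_{in} - 1)\cdot s - 2p + k + \text{output\_padding}$. A small plain-Python check (hypothetical helper name) shows that with $k=3$, $s=2$, $p=1$, the extra output padding is what makes each layer exactly double the size, from 4x4 back to 32x32, instead of turning 4 into 7.

```python
def conv_transpose2d_out(size, kernel=3, stride=2, padding=1, output_padding=1):
    """Spatial output size of nn.ConvTranspose2d (dilation 1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding


sizes = [4]
for _ in range(3):  # three transposed convolutions in the decoder
    sizes.append(conv_transpose2d_out(sizes[-1]))
print(sizes)                                      # [4, 8, 16, 32]
print(conv_transpose2d_out(4, output_padding=0))  # 7
```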

Construct the Decoder `self.net` as an (almost) mirror image of the Encoder `self.net`:
- the decoder `self.net` consists of the following module types:
    - module type 1: 3x3 **transposed convolution** with stride 2, padding=1 and `output_padding=1` (doubles height and width) + activation function
    - module type 2: 3x3 "same" **convolution** with stride 1 and padding=1 + activation function
- in the following order:
    - type 1 (2*c_hid out channels)
    - type 2 (2*c_hid out channels)
    - type 1 (c_hid out channels)
    - type 2 (c_hid out channels)
    - type 1 (num_input_channels out channels), but with Tanh in place of the activation function, so that the output lies in [-1, 1] like the normalized inputs

Construct the Decoder `self.linear` as the "reverse" of the Encoder's final Linear layer (input and output dimensions swapped, i.e. from latent_dim to 2*c_hid*4*4), followed by `act_fn()`.

Define the forward function as:
- `self.linear`,
- then reshape to batch x (2*c_hid) x H x W, where H x W = 4x4,
- then `self.net`.

In [None]:
class Decoder(nn.Module):
    
    def __init__(self, 
                 num_input_channels : int, 
                 base_channel_size : int, 
                 latent_dim : int, 
                 act_fn : object = nn.GELU):
        """
        Inputs: 
            - num_input_channels : Number of channels of the image to reconstruct. For CIFAR, this parameter is 3
            - base_channel_size : Number of channels we use in the last convolutional layers. Earlier layers use a multiple of it (here 2*base_channel_size).
            - latent_dim : Dimensionality of latent representation z
            - act_fn : Activation function used throughout the decoder network
        """
        super().__init__()
        c_hid = base_channel_size
        self.linear = nn.Sequential(
            #TODO
        )
        self.net = nn.Sequential(
            #TODO
        )

        self._init_params()

    def _init_params(self):
        # TODO
        pass

    def forward(self, x):
        # TODO
        raise NotImplementedError

**Question:** Write down dimensions and numbers of parameters for the Decoder above in the table below!

**Answer:** 

|Layer|Layer-Output-Dimension|Weights|Biases|
|:---|:---|:---|:---|
|Linear|x|x|x|
|TransposeConv2d 1|x|x|x|
|Conv2d 1|x|x|x|
|TransposeConv2d 2|x|x|x|
|Conv2d 2|x|x|x|
|TransposeConv2d 3|x|x|x|
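As with the encoder, here is one possible way to fill in the Decoder specification, as a separate, hypothetically named class `SketchDecoder`. It assumes the same initialization scheme as the encoder (extended to `nn.ConvTranspose2d`, which the specification does not mention) and applies Tanh instead of the activation function after the last transposed convolution.

```python
import torch
import torch.nn as nn


class SketchDecoder(nn.Module):
    """Hypothetical sketch of the Decoder specification (reconstructs 32x32 images)."""

    def __init__(self, num_input_channels: int, base_channel_size: int,
                 latent_dim: int, act_fn: type = nn.GELU):
        super().__init__()
        c_hid = base_channel_size
        self.c_hid = c_hid
        # Reverse of the encoder's final Linear layer: latent_dim -> 2*c_hid*4*4
        self.linear = nn.Sequential(nn.Linear(latent_dim, 2 * c_hid * 4 * 4), act_fn())
        up = dict(kernel_size=3, stride=2, padding=1, output_padding=1)  # doubles H and W
        self.net = nn.Sequential(
            nn.ConvTranspose2d(2 * c_hid, 2 * c_hid, **up), act_fn(),             # 4x4 -> 8x8
            nn.Conv2d(2 * c_hid, 2 * c_hid, kernel_size=3, padding=1), act_fn(),
            nn.ConvTranspose2d(2 * c_hid, c_hid, **up), act_fn(),                 # 8x8 -> 16x16
            nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=1), act_fn(),
            nn.ConvTranspose2d(c_hid, num_input_channels, **up),                  # 16x16 -> 32x32
            nn.Tanh(),  # outputs in [-1, 1], the range of the normalized inputs
        )
        self._init_params()

    def _init_params(self):
        # Assumption: same scheme as the encoder, extended to transposed convolutions
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                nn.init.zeros_(m.bias)

    def forward(self, z):
        x = self.linear(z)
        x = x.reshape(x.shape[0], 2 * self.c_hid, 4, 4)  # batch x channels x 4 x 4
        return self.net(x)


dec = SketchDecoder(num_input_channels=3, base_channel_size=32, latent_dim=128)
x_hat = dec(torch.randn(2, 128))
print(x_hat.shape)  # torch.Size([2, 3, 32, 32])
```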



#### Putting it all together

Finally, combine the encoder and decoder together into the autoencoder architecture: 
Define the forward function below.
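The forward pass of the autoencoder is just the composition of the two parts, $\hat{x} = \text{decoder}(\text{encoder}(x))$. Here is a self-contained sketch with tiny stand-in modules (a single Linear layer each, assumed only to keep the example small; they are not the Encoder/Decoder defined above):

```python
import torch
import torch.nn as nn


class ComposedAutoencoder(nn.Module):
    """Hypothetical minimal autoencoder: forward returns decoder(encoder(x))."""

    def __init__(self, encoder: nn.Module, decoder: nn.Module):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x):
        z = self.encoder(x)      # (batch, latent_dim)
        x_hat = self.decoder(z)  # same shape as x
        return x_hat


# Tiny stand-ins, for illustration only
latent_dim = 16
encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, latent_dim))
decoder = nn.Sequential(nn.Linear(latent_dim, 3 * 32 * 32), nn.Tanh(),
                        nn.Unflatten(1, (3, 32, 32)))
ae = ComposedAutoencoder(encoder, decoder)
x_hat = ae(torch.zeros(2, 3, 32, 32))
print(x_hat.shape)  # torch.Size([2, 3, 32, 32])
```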

In [None]:
class Autoencoder(nn.Module):
    
    def __init__(self, 
                 base_channel_size: int, 
                 latent_dim: int, 
                 encoder_class : object = Encoder,
                 decoder_class : object = Decoder,
                 num_input_channels: int = 3, 
                 width: int = 32, 
                 height: int = 32):
        super().__init__()
        # Creating encoder and decoder
        self.encoder = encoder_class(num_input_channels, base_channel_size, latent_dim)
        self.decoder = decoder_class(num_input_channels, base_channel_size, latent_dim)
        # Example input array needed for visualizing the graph of the network (batch x channels x height x width)
        self.example_input_array = torch.zeros(2, num_input_channels, height, width)
        
    def forward(self, x):
        """
        The forward function takes in an image and returns the reconstructed image
        """
        # TODO
        raise NotImplementedError

    

## Training the Autoencoder

### Loss Function

For the loss function, we use the mean squared error (MSE), but with a specific reduction: we sum the squared error over all pixels and channels of each instance (an instance is an image here!) and then compute the mean only over the batch.

This is why we can't simply call `mse_loss` with its default reduction on the entire batch of shape [batch_size, channels, height, width]: it would average the squared error over all dimensions, i.e. over every pixel of every image, not just over the batch.

The following function `reconstruction_loss` fixes that:
`mse_loss` with `reduction="none"` returns the element-wise squared error (without any averaging) over all 4 tensor dimensions.
Then we sum up the squared errors for each image (dims 1, 2, 3) and compute the mean only over the batch (dim 0).


In [None]:
def reconstruction_loss(x, x_hat):
    loss = F.mse_loss(x, x_hat, reduction="none")
    loss = loss.sum(dim=[1,2,3]).mean(dim=[0])
    return loss
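A quick NumPy check of this reduction on random data with CIFAR10-like shapes: summing the squared error per image and averaging over the batch gives exactly `C*H*W = 3072` times the plain element-wise mean that the default `reduction="mean"` would return, so for a fixed image size the two losses differ only by a constant factor.

```python
import numpy as np

rng = np.random.default_rng(0)
B, C, H, W = 4, 3, 32, 32
x = rng.uniform(-1, 1, size=(B, C, H, W))
x_hat = rng.uniform(-1, 1, size=(B, C, H, W))

sq_err = (x - x_hat) ** 2                   # like F.mse_loss(..., reduction="none")
per_image_sum = sq_err.sum(axis=(1, 2, 3))  # one value per image, shape (B,)
recon_loss = per_image_sum.mean()           # mean over the batch only
plain_mse = sq_err.mean()                   # like reduction="mean": mean over every element

print(np.isclose(recon_loss, plain_mse * C * H * W))  # True
```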

### Training and Eval Functions

We use a modified version of the familiar training and evaluation functions:
- without accuracy this time, since we want to reconstruct the images, not classify them
- with the loss_module replaced by the custom function `reconstruction_loss`
- without the CIFAR10 class labels: since we want to reconstruct the images, the input image itself serves as the target!

If you want to write training code in a leaner way in the future, with pre-implemented training loops, callbacks and TensorBoard logging, you could look into [Pytorch Lightning](https://www.pytorchlightning.ai/index.html).

In [None]:
# Import tensorboard logger from PyTorch
from torch.utils.tensorboard import SummaryWriter

def train_model_with_logger(model, optimizer, scheduler, data_loader, val_loader, num_epochs=50, logging_dir=logging_dir):
    # Create TensorBoard logger
    writer = SummaryWriter(logging_dir)
    model_plotted = False


    val_scores = []
    # Training loop
    for epoch in range(num_epochs):

        # Set model to train mode
        model.train()
        
        epoch_loss = 0.0
        
        for data_inputs, _ in data_loader:
            
            ## Step 1: Move input data to device (only strictly necessary if we use GPU)
            data_inputs = data_inputs.to(device)

            # For the very first batch, we visualize the computation graph in TensorBoard
            if not model_plotted:
                writer.add_graph(model, data_inputs)
                model_plotted = True

            ## Step 2: Run the model on the input data
            preds = model(data_inputs)

            ## Step 3: Calculate the loss 
            loss = reconstruction_loss(preds, data_inputs)
            
            ## Step 4: Perform backpropagation
            # Before calculating the gradients, we need to ensure that they are all zero.
            # The gradients would not be overwritten, but actually added to the existing ones.
            optimizer.zero_grad()
            # Perform backpropagation
            loss.backward()

            ## Step 5: Update the parameters
            optimizer.step()

            ## Step 6: Accumulate the batch loss (averaged over the epoch below)
            epoch_loss += loss.item()

        
        # Validation at the end of each epoch:
        val_loss = eval_model(model, val_loader)

        ## Scheduler Step: Do one LR scheduler step
        scheduler.step(val_loss)

        # Add average loss to TensorBoard
        epoch_loss /= len(data_loader)
        writer.add_scalar('training_loss',
                          epoch_loss,
                          global_step = epoch + 1)
        
        
        # Add validation loss to TensorBoard
        writer.add_scalar('validation_loss',
                          val_loss,
                          global_step = epoch + 1)
        
        
        # Produce the output
        print(f"Epoch: {epoch + 1} Training loss: {epoch_loss:.2f} Validation loss: {val_loss:.2f}")

    writer.close()

In [None]:
def eval_model(model, data_loader):
    model.eval() # Set model to eval mode
    loss = 0.0

    with torch.no_grad(): # Deactivate gradients for the following code
        for data_inputs, _ in data_loader:

            # Determine prediction of model on dev set
            data_inputs = data_inputs.to(device)
            preds = model(data_inputs)

            # Compute the batch's loss
            loss += reconstruction_loss(preds, data_inputs).item()

    loss /= len(data_loader)
    return loss

### Training the model 

For latent dimension = 128 create an instance `model_128` of the Autoencoder, push the model to the device and train it for 20 epochs. 

Configure the training by defining:
- the optimizer to be Adam with lr=1e-3
- the scheduler to be ReduceLROnPlateau with factor 0.2, patience=5, min_lr=5e-5

Train the model. 

Evaluate the model on the test set
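A sketch of the optimizer and scheduler configuration described above (the helper name `configure_training` is hypothetical). The quick check uses a stand-in `nn.Linear` model and a constant validation loss: the first step sets the best value, and six non-improving steps then exceed `patience=5`, so the learning rate drops by `factor=0.2`.

```python
import torch.nn as nn
import torch.optim as optim


def configure_training(model):
    """Adam(lr=1e-3) + ReduceLROnPlateau(factor=0.2, patience=5, min_lr=5e-5)."""
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.2, patience=5, min_lr=5e-5)
    return optimizer, scheduler


# Quick check with a stand-in model and a constant "validation loss"
opt, sched = configure_training(nn.Linear(4, 2))
for _ in range(7):
    sched.step(1.0)
print(opt.param_groups[0]["lr"])  # reduced from 1e-3 by the factor 0.2

# In the notebook (base_channel_size is your choice; the exercise does not fix it):
# model_128 = Autoencoder(base_channel_size=..., latent_dim=128).to(device)
# optimizer, scheduler = configure_training(model_128)
# train_model_with_logger(model_128, optimizer, scheduler, train_loader, val_loader, num_epochs=20)
```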

In [None]:
eval_model(model_128, test_loader)

## Generating new images with an Autoencoder?

Recall: An Autoencoder consists of
- an encoder, which encodes an image $x$ to a smaller-dimensional so-called "latent vector" $z$. 
- a decoder, which takes this latent vector $z$ and recreates the image $x$. 

Idea: can we train an autoencoder and then use its decoder to generate new random images (similar to Midjourney) from a random latent vector $z$, i.e. a random vector of dimension `latent_dim`? Let's find out below:

- we create a random `latent_vector` from the 128-dimensional latent space
- with `torch.no_grad()` (since we don't need gradients), we apply the decoder of `model_128` to this `latent_vector` and move the result to the CPU.

In [None]:
latent_vector = torch.randn(1, 128, device=device)

with torch.no_grad():
    img = model_128.decoder(latent_vector)
    img = img.squeeze(0).cpu()

plt.figure(figsize=(8,5))
# Undo the [-1, 1] normalization so that imshow receives values in [0, 1]
plt.imshow(img.permute(1, 2, 0) * 0.5 + 0.5)
plt.axis('off')
plt.show()

As we can see, the generated images are not realistic. Since the autoencoder was free to structure the latent space in whatever way suits reconstruction best, there is no incentive to map every possible latent vector to a realistic image. However, if the latent space is trained to follow a known probability distribution that we can sample from, we could do exactly what we tried above! This is the idea behind a variant of the autoencoder called the **Variational Autoencoder (VAE)**. It adds a probabilistic model on top of the autoencoder and is able to generate new samples.

## Finding visually similar images

One application of autoencoders is to build an image-based search engine that retrieves visually similar images. This can be done by representing all images by their latent vectors and finding the closest $K$ images in this latent space. The first step of such a search engine is to encode all images into $z$. In the following, we will use the training set as the search corpus and the test set as queries to the system.

Write a function `embed_imgs(model, data_loader)` that encodes all images in `data_loader` using our `model_128` and returns both the images and their encodings as tensors (built from the lists `img_list` and `embed_list`):
- create empty lists `img_list` and `embed_list`
- we don't train, so put the model in eval mode
- cycle through batches from the data_loader as usual
- with `torch.no_grad()`: push the images to the device and compute their embeddings with `model.encoder`
- add the images to `img_list` and the encodings to `embed_list`
- return `torch.cat(..., dim=0)` for both lists

In [None]:
def embed_imgs(model, data_loader):
    # Encode all images in data_loader using model, and return both images and encodings
    img_list, embed_list = [], []
    model.eval()
    for imgs, _ in data_loader:
        with torch.no_grad():
            z = model.encoder(imgs.to(device))
        img_list.append(imgs)
        # Keep encodings on the CPU so their sort indices can index the (CPU) image tensor later
        embed_list.append(z.cpu())
    return (torch.cat(img_list, dim=0), torch.cat(embed_list, dim=0))



Apply it to both train_loader and test_loader to generate train_img_embeds and test_img_embeds

In [None]:
train_img_embeds = embed_imgs(model_128, train_loader)
test_img_embeds = embed_imgs(model_128, test_loader)

We now have the images and encodings (as a tuple of two tensors) for both the train and test data. Below is a function `find_similar_images(query_img, query_z, key_embeds, K=8)` that finds the $K$ closest images:
- it computes the distance between `query_z` and all encodings in `key_embeds` with `torch.cdist(p=2)`, which computes pairwise $p$-norm distances between two batches of vectors (for p=2, this is the ordinary Euclidean distance)
- then it sorts the distances with `torch.sort` and plots the K closest images.
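The same search can be sketched in plain NumPy on toy 2-D "encodings" (hypothetical helper `k_nearest`): compute the Euclidean distance from the query to every key, sort in ascending order, and keep the first $K$ indices.

```python
import numpy as np


def k_nearest(query_z, key_z, K=8):
    """Indices and distances of the K rows of key_z closest to query_z (Euclidean, p=2)."""
    dist = np.sqrt(((key_z - query_z[None, :]) ** 2).sum(axis=1))  # shape (N,)
    order = np.argsort(dist, kind="stable")                          # ascending distance
    return order[:K], dist[order[:K]]


keys = np.array([[0.0, 0.0], [3.0, 4.0], [1.0, 0.0], [0.0, 2.0]])
idx, d = k_nearest(np.array([0.0, 0.0]), keys, K=3)
print(idx, d)  # [0 2 3] [0. 1. 2.]
```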

In [None]:
def find_similar_images(query_img, query_z, key_embeds, K=8):
    # Find the closest K images. We use the Euclidean distance here, but others, such as the cosine distance, can also be used.
    dist = torch.cdist(query_z[None,:], key_embeds[1], p=2)
    dist = dist.squeeze(dim=0)
    dist, indices = torch.sort(dist)
    # Plot the K closest images
    imgs_to_display = torch.cat([query_img[None], key_embeds[0][indices[:K]]], dim=0)
    grid = torchvision.utils.make_grid(imgs_to_display, nrow=K+1, normalize=True, value_range=(-1,1))
    grid = grid.permute(1, 2, 0)
    plt.figure(figsize=(12,3))
    plt.imshow(grid)
    plt.axis('off')
    plt.show()

Apply this function to the first 8 test images and their encodings, with `key_embeds=train_img_embeds`.

In [None]:
# Plot the closest images for the first N test images as example
for i in range(8):
    find_similar_images(test_img_embeds[0][i], test_img_embeds[1][i], key_embeds=train_img_embeds)

Outlook: If you want to see how important initialization is for NNs, play around with the initializations or do not initialize at all and see what happens.