## Implement AutoEncoder from Scratch (PyTorch Lightning)

An ```autoencoder``` is a type of ```artificial neural network``` used to learn efficient data codings in an ```unsupervised manner```. The aim of an autoencoder is to learn a representation (encoding) for a set of data, typically for dimensionality reduction, by training the network to ignore signal ```“noise”```.

In [2]:
import torch
from torch import nn, optim
from torch.utils.data import random_split, DataLoader
from torch.nn import functional as F
from torchvision import transforms

import pytorch_lightning as pl
from torchvision.datasets import MNIST

## 1. Setting Up ```LightningDataModule```

In [7]:
class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir='./', batch_size=64, num_workers=1):
        super().__init__()
        
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

        # We hardcode dataset specific stuff here.
        self.num_classes = 10
        self.dims = (1, 28, 28)
        self.transform = transforms.Compose([
            transforms.ToTensor()
        ])
    
    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)
    
    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)
            
    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=self.num_workers)

## 2. Create the Model Architecture

### Basic AutoEncoder

We generate a ```lower-dimensional``` representation of an image, that can be ```decoded``` to reconstruct the original image

The difference in different architectures of Auto Encoders is:
- Method of creating a ```lower-dimensional representation```
- Method of ```reconstruction```.

<img src="https://miro.medium.com/max/573/1*IougT0pP1prt_kISvPLkoA.png"/>

## Basic Architecture

### Q1. Method of Creating lower-dimensional representation:
1. Flatten the image i.e, if the image is of size 100X100 it is flattened to the shape of 10,000X1.
2. Send it to a Dense Layer which takes the flattened shape to the size of the compressed representation

### Q2. Decoding Method:

- Q1. How are they creating lower-dimensional representations?
- Q2. How are they reconstructing the images back?

<b>For Encoding a batch of images</b>

In [17]:
class AutoEncoder(pl.LightningModule):
    def __init__(self, input_shape, representation_size):
        super().__init__()

        self.save_hyperparameters() # Saves the hyperparams -- input_shape, representation_size

        self.input_shape = input_shape
        self.representation_size = representation_size

        # Calculate the flattened size
        flattened_size = 1
        for x in self.input_shape:
            flattened_size *= x

        self.flattened_size = flattened_size

        # Initialise the Dense Layers
        self.input_to_representation = nn.Linear(self.flattened_size, self.representation_size)
        self.representation_to_output = nn.Linear(self.representation_size, self.flattened_size)


    def forward(self, image_batch):
        ## ENCODING
        # image_batch: [batch_size, ...] -- Other dimensions are the input_shape
        flattened = image_batch.view(-1, self.flattened_size)
        # flattened: [batch_size, flattened_size]
        representation = F.relu(self.input_to_representation(flattened))
        # representation: [batch_size, representation_size]

        ## DECODING
        flat_reconstructed = F.relu(self.representation_to_output(representation))
        # flat_reconstructed: [batch_size, flattened_size]
        reconstructed = flat_reconstructed.view(-1, *self.input_shape)
        # reconstructed is same shape as image_batch

        return reconstructed
    
    def training_step(self, batch, batch_idx):
        batch_images = batch[0]
        # Get the reconstructed images
        reconstructed_images = self.forward(batch_images)
        # Calculate loss
        batch_loss = F.mse_loss(reconstructed_images, batch_images)

        # store the result
        result = pl.TrainResult(minimize=batch_loss)
        result.batch_loss = batch_loss
        result.log('train_loss', batch_loss, prog_bar=True)

        return result
    
    def validation_step(self, batch, batch_idx):
        batch_images = batch[0]
        # Get the reconstructed images
        reconstructed_images = self.forward(batch_images)
        # Calculate loss
        batch_loss = F.mse_loss(reconstructed_images, batch_images)

        # store the result
        result = pl.TrainResult(checkpoint_on=batch_loss)
        result.batch_loss = batch_loss

        return result

    def test_step(self, batch, batch_idx):
        batch_images = batch[0]
        # Get the reconstructed images
        reconstructed_images = self.forward(batch_images)
        # Calculate loss
        batch_loss = F.mse_loss(reconstructed_images, batch_images)

        # store the result
        result = pl.TrainResult(checkpoint_on=batch_loss)
        result.batch_loss = batch_loss

        return result   

    def validation_end(self, outputs):
        # Take mean of all batch losses
        avg_loss = outputs.batch_loss.mean()
        result = pl.TrainResult(checkpoint_on=avg_loss)
        result.log('val_loss', avg_loss, prog_bar=True)
        return result

    def test_epoch_end(self, outputs):
        # Take mean of all batch losses
        avg_loss = outputs.batch_loss.mean()
        result = pl.TrainResult()
        result.log('test_loss', avg_loss, prog_bar=True)
        return result
    
    def configure_optimizers(self):
        return optim.Adam(self.parameters())

## Train Model

In [18]:
mnist_dm = MNISTDataModule()
model = AutoEncoder(input_shape=mnist_dm.size(), representation_size=128)

# We use 16-bit precision for lesser memory usage.
# progress_bar_refresh_rate=5, to avoid Colab from crashing

trainer = pl.Trainer(gpus=1, max_epochs=5, precision=16, progress_bar_refresh_rate=5)
trainer.fit(model, mnist_dm)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Using native 16bit precision.

  | Name                     | Type   | Params
----------------------------------------------------
0 | input_to_representation  | Linear | 100 K 
1 | representation_to_output | Linear | 101 K 
----------------------------------------------------
201 K     Trainable params
0         Non-trainable params
201 K     Total params


Validation sanity check: 0it [00:00, ?it/s]

AttributeError: module 'pytorch_lightning' has no attribute 'TrainResult'

## Visualize

In [19]:
import matplotlib.pyplot as plt
from PIL import Image
%matplotlib inline

trans = transforms.ToPILImage()

In [20]:
model.eval()
for batch in mnist_dm.val_dataloader():
    original_imgs = batch[0]
    outputs = model(original_imgs)
    for i in range(len(outputs)):
        plt.figure()
        plt.imshow(trans(outputs[i]).convert("RGB"))
        plt.figure()
        plt.imshow(trans(original_imgs[i]).convert("RGB"))
        if i==3:
            break
    break

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking arugment for argument mat1 in method wrapper_addmm)

In [21]:
%load_ext tensorboard
%tensorboard --logdir ./lightning_logs/