# CIFAR10 Variational Auto-Encoder

## Setup

In [None]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F

## The Model

In [None]:
from operator import itemgetter

import torch.optim as optim
import pytorch_lightning as pl

from variational_autoencoder import VariationalAutoEncoder, _tensor_size_3_t

?VariationalAutoEncoder

### Model Parameters

In [None]:
# factor the reconstruction loss is multiplied with
R_SCALING_FACTOR = 1000
# learning rate handed to the optimizer
LEARNING_RATE = 0.0005

# keyword arguments for the model constructor; every list holds one entry
# per layer, in layer order
MODEL_PARAMS = dict(
    # training hyper-parameters
    r_scaling_factor=R_SCALING_FACTOR,
    learning_rate=LEARNING_RATE,
    # encoder layers
    enc_in_channels=[3, 32, 64],
    enc_out_channels=[32, 64, 128],
    enc_kernel_sizes=[3, 3, 3],
    enc_strides=[1, 2, 1],
    enc_paddings=[1, 1, 1],
    # decoder layers (channel counts mirror the encoder)
    dec_in_channels=[128, 64, 32],
    dec_out_channels=[64, 32, 3],
    dec_kernel_sizes=[3, 3, 3],
    dec_strides=[1, 2, 1],
    dec_paddings=[1, 1, 1],
    dec_output_paddings=[0, 1, 0],
    # size of the latent space and regularisation switches
    latent_dim=4,
    use_batchnorm=True,
    use_dropout=True,
)

In [None]:
class LitVAE(pl.LightningModule):
    """
    Lightning module that wraps a `VariationalAutoEncoder` and trains it
    with a scaled reconstruction (MSE) loss plus a KL-divergence loss.
    """
    def __init__(self,
                 r_scaling_factor: float = 1000,
                 learning_rate: float = 0.0005,
                 **kwargs) -> None:
        """
        Parameters
        ----------
        - `r_scaling_factor: float`:
            scaling factor for the reconstruction loss
        - `learning_rate: float`:
            learning rate for the optimizer
        - `**kwargs`:
            arguments to pass to the variational autoencoder constructor
        """
        super().__init__()

        self.r_scaling_factor = r_scaling_factor
        self.learning_rate = learning_rate
        # remembered so the graph-logging code can rebuild an identically
        # configured model without relying on any module-level globals
        self._vae_kwargs = dict(kwargs)

        self.vae = VariationalAutoEncoder(**kwargs)

    def forward(self, x) -> _tensor_size_3_t:
        """
        Run `x` through the wrapped VAE and return its output unchanged
        (unpacked elsewhere as `mean, log_variance, generated_images`).
        """
        return self.vae(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the loss of one training batch."""
        loss = self.shared_step(batch)

        self.log("train_loss_step", loss)

        return {"loss": loss}

    def training_epoch_end(self, outputs) -> None:
        """
        Log the mean training loss of the epoch; after the first epoch
        also add the computation graph to the logger.
        """
        # add computation graph (once is enough, the architecture is fixed)
        if self.current_epoch == 0:
            self._log_graph()

        epoch_loss = self.average_metric(outputs, "loss")
        self.logger.experiment.add_scalar("train_loss_epoch", epoch_loss, self.current_epoch)

    def validation_step(self, batch, batch_idx):
        """Compute and log the loss of one validation batch."""
        loss = self.shared_step(batch)

        self.log("valid_loss_step", loss)

        return {"loss": loss}

    def validation_epoch_end(self, outputs) -> None:
        """Log the mean validation loss of the epoch."""
        epoch_loss = self.average_metric(outputs, "loss")
        self.logger.experiment.add_scalar("valid_loss_epoch", epoch_loss, self.current_epoch)

    def test_step(self, batch, batch_idx):
        """Compute and log the loss of one test batch."""
        loss = self.shared_step(batch)

        self.log("test_loss_step", loss)

        return {"loss": loss}

    def test_epoch_end(self, outputs) -> None:
        """Log the mean test loss of the epoch."""
        epoch_loss = self.average_metric(outputs, "loss")
        self.logger.experiment.add_scalar("test_loss_epoch", epoch_loss, self.current_epoch)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters of this module."""
        return optim.Adam(self.parameters(), lr=self.learning_rate)

    def shared_step(self, batch) -> torch.Tensor:
        """
        Forward `batch` through the model and return its loss.

        Parameters
        ----------
        - `batch`:
            pair `(images, labels)`; the labels are ignored

        Returns
        -------
        the combined loss computed by `calculate_loss`
        """
        # images are both samples and targets thus original
        # labels from the dataset are not required
        true_images, _ = batch

        # perform a forward pass through the VAE
        # mean and log_variance are used to calculate the KL Divergence loss
        # generated_images are the reconstructions of true_images
        mean, log_variance, generated_images = self(true_images)

        return self.calculate_loss(mean, log_variance, generated_images, true_images)

    def calculate_loss(self, mean, log_variance, predictions, targets) -> torch.Tensor:
        """
        Return `r_scaling_factor * MSE(predictions, targets) + KL loss`.
        """
        # reconstruction loss (mean over all elements)
        r_loss = F.mse_loss(predictions, targets)
        # KL-Loss (sum over all elements)
        kl_loss = self.kl_loss(mean, log_variance)

        return r_loss * self.r_scaling_factor + kl_loss

    def kl_loss(self, mean, log_variance) -> torch.Tensor:
        """
        Return `-0.5 * sum(1 + log_variance - mean^2 - exp(log_variance))`.
        """
        # NOTE(review): the sum runs over every element of the batch while the
        # reconstruction loss is a mean, so the balance between both terms
        # depends on the batch size - confirm this is intended
        return -0.5 * torch.sum(1 + log_variance - torch.square(mean) - torch.exp(log_variance))

    def average_metric(self, metrics, metric_name) -> torch.Tensor:
        """
        Return the mean of `metric_name` over the step outputs in `metrics`
        (a sequence of dicts that each map `metric_name` to a tensor).
        """
        return torch.stack([x[metric_name] for x in metrics]).mean()

    def _log_graph(self) -> None:
        """
        Add the computation graph of this architecture to the logger.

        A fresh copy of the model is traced instead of `self` so the random
        sample input cannot touch the state of the model being trained.
        """
        # single random image with 3 channels and 32x32 pixels
        sample_input = torch.randn((1, 3, 32, 32))
        # rebuild from the arguments this instance was created with
        sample_model = LitVAE(r_scaling_factor=self.r_scaling_factor,
                              learning_rate=self.learning_rate,
                              **self._vae_kwargs)

        self.logger.experiment.add_graph(sample_model, sample_input)


## The Data

In [None]:
import os 

from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import CIFAR10 

In [None]:
# the datasets live in a ".data" folder one level above the current
# working directory
DATA_DIR = os.path.join(
    os.getcwd(),
    "../.data",
)

In [None]:
class CIFAR10DataModule(pl.LightningDataModule):
    """
    Implements the data loading functionality.

    Downloads CIFAR10, splits the official training set into a training
    and a validation part and exposes one dataloader per split.
    """
    def __init__(self, root_dir: str = DATA_DIR, batch_size: int = 32, num_workers: int = 0, pin_memory: bool = False, val_size: int = 5_000) -> None:
        """
        Parameters
        ----------
        - `root_dir: str`:
            directory the dataset is downloaded to and read from
        - `batch_size: int`:
            batch size used by every dataloader
        - `num_workers: int`:
            `num_workers` argument of every dataloader
        - `pin_memory: bool`:
            `pin_memory` argument of every dataloader
        - `val_size: int`:
            number of training images held out for validation
        """
        super().__init__()

        self.root_dir = root_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.val_size = val_size

    def prepare_data(self):
        """Download the CIFAR10 training and test data into `root_dir`."""
        # download training and test data
        CIFAR10(root=self.root_dir, train=True, download=True)
        # setting train to False downloads the test data
        CIFAR10(root=self.root_dir, train=False, download=True)

    def setup(self, stage=None):
        """
        Create the training, validation and test datasets.
        `stage` is accepted but not used.
        """
        # transform the images into tensors and normalize every channel
        # with mean 0.5 and std 0.5
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        # load training and test data
        full_train = CIFAR10(self.root_dir, train=True, transform=transform)
        self.cifar10_test = CIFAR10(self.root_dir,
                                    train=False,
                                    transform=transform)

        # split training data into training and validation sets; with the
        # default val_size the 50_000 training images become 45_000 / 5_000
        train_size = len(full_train) - self.val_size
        self.cifar10_train, self.cifar10_val = random_split(
            full_train, [train_size, self.val_size])

    def train_dataloader(self):
        """Return the training dataloader (reshuffled every epoch)."""
        return self._make_loader(self.cifar10_train, shuffle=True)

    def val_dataloader(self):
        """Return the validation dataloader (fixed order)."""
        return self._make_loader(self.cifar10_val)

    def test_dataloader(self):
        """Return the test dataloader (fixed order)."""
        return self._make_loader(self.cifar10_test)

    def _make_loader(self, dataset, shuffle: bool = False) -> DataLoader:
        """Build a dataloader for `dataset` with the settings of this module."""
        return DataLoader(dataset,
                          batch_size=self.batch_size,
                          shuffle=shuffle,
                          num_workers=self.num_workers,
                          pin_memory=self.pin_memory)

## The Engineering

### Training

#### Setup Model-Checkpointing

In [None]:
from pytorch_lightning.callbacks import ModelCheckpoint

# save checkpoints based on the validation loss that is logged under the
# name "valid_loss_step"; verbose=True reports whenever one is written
checkpoint_callback = ModelCheckpoint(
    verbose=True,
    monitor="valid_loss_step",
)

In [None]:
# data: default settings except for four dataloader workers
cifar10_dm = CIFAR10DataModule(num_workers=4)
# model: configured from the hyper-parameter dictionary
model = LitVAE(**MODEL_PARAMS)

# train for a single epoch and let the callback handle checkpointing
trainer = pl.Trainer(
    callbacks=[checkpoint_callback],
    max_epochs=1,
)


In [None]:
# run training and validation; the data module supplies the dataloaders
trainer.fit(model, cifar10_dm)

#### Visualize Training

In [None]:
# load the TensorBoard notebook extension and open it on the log directory
%load_ext tensorboard
%tensorboard --logdir lightning_logs/

### Testing

#### Evaluate the best model checkpoint

In [None]:
# display the best value of the monitored metric ("valid_loss_step")
# recorded by the checkpoint callback
checkpoint_callback.best_model_score

In [None]:
# evaluate on the test set; neither a model nor data is passed explicitly
# NOTE(review): presumably this reuses the model and data module from the
# preceding `trainer.fit` call - confirm for the installed Lightning version
trainer.test()

## Prediction

### The Model

In [None]:
# restore the model from the best checkpoint written during training; the
# hyper-parameters are passed along as constructor arguments
best_checkpoint_path = checkpoint_callback.best_model_path
model = LitVAE.load_from_checkpoint(best_checkpoint_path, **MODEL_PARAMS)

### The Data

In [None]:
# download the training and test data
cifar10_dm.prepare_data()
# build the train / validation / test datasets
cifar10_dm.setup()

In [None]:
# how many test images are reconstructed and plotted
NUM_IMAGES = 10

# fetch the first test batch, drop its labels and keep the first
# NUM_IMAGES images
first_test_batch = next(iter(cifar10_dm.test_dataloader()))
true_images = first_test_batch[0][:NUM_IMAGES]

true_images.size()

### The Prediction

In [None]:

# inference: switch layers such as dropout / batch-norm to evaluation
# behaviour and skip gradient tracking, which is not needed here
model.eval()
with torch.no_grad():
    mean, log_variance, reconstructed_images = model(true_images)

reconstructed_images.size()

### Visualizing the predictions

In [None]:
import matplotlib.pyplot as plt 

In [None]:
fig = plt.figure(figsize=(15, 3))
fig.subplots_adjust(hspace=0.4, wspace=0.4)

to_pil = transforms.ToPILImage()

# first row: original images, second row: their reconstructions
for row, image_batch in enumerate((true_images, reconstructed_images)):
    for col in range(NUM_IMAGES):
        # unnormalize the image (inverse of mean=0.5, std=0.5)
        img = image_batch[col].detach().cpu().squeeze() * 0.5 + 0.5
        # clamp to [0, 1]: reconstructions are not guaranteed to stay in
        # range and out-of-range values would distort the 8-bit conversion
        img = to_pil(img.clamp(0.0, 1.0))
        ax = fig.add_subplot(2, NUM_IMAGES, row * NUM_IMAGES + col + 1)
        ax.axis('off')
        ax.imshow(img)