In [None]:
!pip install lightning

Collecting lightning
  Using cached lightning-2.4.0-py3-none-any.whl.metadata (38 kB)
Collecting torchmetrics<3.0,>=0.7.0 (from lightning)
  Using cached torchmetrics-1.4.1-py3-none-any.whl.metadata (20 kB)
Collecting pytorch-lightning (from lightning)
  Using cached pytorch_lightning-2.4.0-py3-none-any.whl.metadata (21 kB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch<4.0,>=2.1.0->lightning)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cusolver-cu12==11.4.5.107 (from torch<4.0,>=2.1.0->lightning)
  Using cached nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Using cached lightning-2.4.0-py3-none-any.whl (810 kB)
Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)
Using cached nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl (124.2 MB)
Using cached torchmetrics-1.4.1-py3-none-any.whl (866 kB)
Using cached pytorch_lightning-2.4.0-py3-none-any.

In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2


import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import numpy as np
import scipy.ndimage as ndi
import os
import pandas as pd
import cv2
import random

from sklearn.model_selection import train_test_split
import shutil
from PIL import Image

#imports for PyTorch Lightning
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import Callback

from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from tqdm.notebook import tqdm
from torch import nn, optim
from torch.utils.data import DataLoader, random_split, Dataset
from torchvision import datasets, transforms
import torch.nn.functional as F
import torch

# Mount Google Drive in the Colab runtime at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


Implementation of a CNN-based Variational Autoencoder Model with PyTorch Lightning

In [None]:
#encoder implementation
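#the 4th conv layer in the encoder and the 1st transpose_conv layer in the decoder add an extra 2x down/up-sampling stage (32 -> 16) for higher compression. Both are currently used in the forward functions, and the linear layer sizes (256 * 16 * 16) depend on them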

class Encoder(nn.Module):
    """Convolutional encoder half of the VAE.

    Four stride-2 convolutions each halve the spatial size
    (256 -> 128 -> 64 -> 32 -> 16 for a single-channel 256x256 image).
    The resulting 256 x 16 x 16 feature map is flattened and fed to two
    independent linear heads.

    Args:
        latent_dim: Width of the latent mean / log-variance vectors.
    """

    def __init__(self, latent_dim=128):
        super().__init__()
        # Every conv stage shares the same downsampling geometry.
        downsample = dict(kernel_size=3, stride=2, padding=1)
        self.conv1 = nn.Conv2d(1, 32, **downsample)     # 256 -> 128
        self.conv2 = nn.Conv2d(32, 64, **downsample)    # 128 -> 64
        self.conv3 = nn.Conv2d(64, 128, **downsample)   # 64 -> 32
        self.conv4 = nn.Conv2d(128, 256, **downsample)  # 32 -> 16

        # Size of the flattened conv4 output: channels * height * width.
        flat_features = 256 * 16 * 16
        self.fc_mu = nn.Linear(flat_features, latent_dim)       # latent mean head
        self.fc_log_var = nn.Linear(flat_features, latent_dim)  # latent log-variance head

    def forward(self, x):
        """Encode a batch of images.

        Returns:
            Tuple ``(mu, log_var)``, each of shape ``(batch, latent_dim)``.
        """
        features = x
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            features = F.relu(conv(features))
        # Collapse (C, H, W) into one axis so the linear heads can consume it.
        flat = features.view(features.shape[0], -1)
        return self.fc_mu(flat), self.fc_log_var(flat)

#VAE Decoder Implementation
class Decoder(nn.Module):
    """Transposed-convolution decoder half of the VAE.

    A linear layer expands the latent vector to a 256 x 16 x 16 feature map.
    Four stride-2 transposed convolutions then double the spatial size each
    time (16 -> 32 -> 64 -> 128 -> 256), ending in a single-channel image.

    Args:
        latent_dim: Width of the latent vector fed to the decoder.
    """

    def __init__(self, latent_dim=128):
        super().__init__()
        # Output size must equal the flattened 256 x 16 x 16 map reshaped in forward().
        self.fc = nn.Linear(latent_dim, 256 * 16 * 16)

        # Every deconv stage shares the same upsampling geometry.
        upsample = dict(kernel_size=3, stride=2, padding=1, output_padding=1)
        self.deconv1 = nn.ConvTranspose2d(256, 128, **upsample)  # 16 -> 32
        self.deconv2 = nn.ConvTranspose2d(128, 64, **upsample)   # 32 -> 64
        self.deconv3 = nn.ConvTranspose2d(64, 32, **upsample)    # 64 -> 128
        self.deconv4 = nn.ConvTranspose2d(32, 1, **upsample)     # 128 -> 256

    def forward(self, z):
        """Decode latent vectors into images with pixel values in [0, 1]."""
        hidden = self.fc(z)
        # Restore the (channels, height, width) layout the deconv stack expects.
        hidden = hidden.view(hidden.shape[0], 256, 16, 16)
        for deconv in (self.deconv1, self.deconv2, self.deconv3):
            hidden = F.relu(deconv(hidden))
        # Sigmoid on the last stage keeps the output in [0, 1].
        return torch.sigmoid(self.deconv4(hidden))

In [None]:
#reparam step for VAEs:
def reparameterize(mu, log_var):
    """Sample ``z = mu + sigma * eps`` with ``eps`` drawn from a standard normal.

    Args:
        mu: Latent mean tensor.
        log_var: Latent log-variance tensor; ``sigma = exp(0.5 * log_var)``.

    Returns:
        A tensor sampled around ``mu``; the noise is drawn with the shape,
        dtype and device of ``sigma``.
    """
    sigma = (0.5 * log_var).exp()
    noise = torch.randn_like(sigma)
    return noise * sigma + mu

In [None]:
#VAE Class
class VAE(L.LightningModule):
    """Variational autoencoder wrapped as a Lightning module.

    Args:
        encoder: Module mapping an input batch to ``(mu, log_var)``.
        decoder: Module mapping a latent sample back to a reconstruction.

    Attributes:
        test_loss_list: Per-batch losses (Python floats) collected by
            ``test_step`` during the most recent test run. It is replaced
            with a fresh list at the start of every test epoch, so repeated
            ``trainer.test(...)`` calls (e.g. one per dataset) do not mix
            their results.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.test_loss_list = []

    def forward(self, x):
        """Run one encoder -> sample -> decoder pass.

        Returns:
            ``(reconstructed_img, mu, log_var)``. The reconstruction feeds the
            reconstruction loss; ``mu`` and ``log_var`` feed the KL divergence.
        """
        mu, log_var = self.encoder(x)
        z = reparameterize(mu, log_var)
        reconstructed_img = self.decoder(z)
        return reconstructed_img, mu, log_var

    @staticmethod
    def _unpack_batch(batch):
        """Return the input tensor from a batch.

        Accepts either a bare tensor or a list/tuple such as ``(inputs, labels)``;
        for the latter only the first element is used, since the VAE is trained
        without labels.
        """
        if isinstance(batch, (list, tuple)):
            return batch[0]
        return batch

    def _get_vae_loss(self, batch):
        """Compute reconstruction loss + KL divergence for one batch.

        Both terms are summed (not averaged) over every element of the batch,
        so the returned value scales with the batch size.
        """
        x = self._unpack_batch(batch)
        x_hat, mu, log_var = self(x)  # forward pass

        # Reconstruction loss: summed squared error between output and input.
        recon_loss = F.mse_loss(x_hat, x, reduction='sum')

        # KL divergence between N(mu, exp(log_var)) and the standard normal prior.
        kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

        return recon_loss + kl_loss

    def training_step(self, batch, batch_idx):
        """Minimise the combined reconstruction + KL loss."""
        loss = self._get_vae_loss(batch)
        self.log("train_loss", loss, logger=True, prog_bar=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the combined loss as ``val_loss`` (monitored by the LR scheduler)."""
        loss = self._get_vae_loss(batch)
        self.log("val_loss", loss, logger=True, prog_bar=True, on_epoch=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Log the combined loss and record it in ``test_loss_list``."""
        loss = self._get_vae_loss(batch)
        self.test_loss_list.append(loss.item())
        self.log("test_loss", loss, logger=True, prog_bar=True, on_epoch=True)
        return loss

    def on_test_epoch_start(self):
        """Start each test run with an empty loss list.

        A new list is bound (rather than clearing in place) so that a reference
        saved from a previous test run keeps its values.
        """
        self.test_loss_list = []

    def on_test_epoch_end(self):
        """Return the collected test losses.

        The losses are meant to be read from ``self.test_loss_list`` after
        ``trainer.test(...)``; the dict is also returned for callers that invoke
        this hook directly.
        """
        return {"test_loss_list": self.test_loss_list}

    def configure_optimizers(self):
        """Adam with L2 regularisation plus a reduce-on-plateau LR schedule."""
        # weight_decay adds an L2 penalty to reduce overfitting.
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3, weight_decay=1e-5)
        # Cut the LR by 5x after 10 epochs without val_loss improvement, down to 5e-5.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.2, patience=10, min_lr=5e-5)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}


In [None]:
#VAE Initializations
# Size of the VAE latent space; a larger value might be better for larger images.
latent_dim = 128

# Build both halves with the same latent size, then wrap them in the Lightning module.
encoder = Encoder(latent_dim)
decoder = Decoder(latent_dim)
vae = VAE(encoder=encoder, decoder=decoder)

Dataset implementation

Test the VAE with the MNIST and Fashion-MNIST datasets to check that the model is working correctly. The Fashion-MNIST dataset serves as the anomaly data: if the model is working correctly, Fashion-MNIST samples should have a higher loss than MNIST digits.