In [3]:
# Install PyTorch Lightning into the kernel's environment, pinned to the version this notebook was run with (install log: lightning-2.4.0).
%pip install -q lightning==2.4.0

Collecting lightning
  Using cached lightning-2.4.0-py3-none-any.whl.metadata (38 kB)
Collecting torchmetrics<3.0,>=0.7.0 (from lightning)
  Using cached torchmetrics-1.4.1-py3-none-any.whl.metadata (20 kB)
Collecting pytorch-lightning (from lightning)
  Using cached pytorch_lightning-2.4.0-py3-none-any.whl.metadata (21 kB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch<4.0,>=2.1.0->lightning)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cusolver-cu12==11.4.5.107 (from torch<4.0,>=2.1.0->lightning)
  Using cached nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Using cached lightning-2.4.0-py3-none-any.whl (810 kB)
Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)
Using cached nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl (124.2 MB)
Using cached torchmetrics-1.4.1-py3-none-any.whl (866 kB)
Using cached pytorch_lightning-2.4.0-py3-none-any.

In [4]:
import albumentations as A
from albumentations.pytorch import ToTensorV2


import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import numpy as np
import scipy.ndimage as ndi
import os
import pandas as pd
import cv2
import random

from sklearn.model_selection import train_test_split
import shutil
from PIL import Image

# imports for PyTorch Lightning
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import Callback

from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from tqdm.notebook import tqdm
from torch import nn, optim
from torch.utils.data import DataLoader, random_split, Dataset
from torchvision import datasets, transforms
import torch.nn.functional as F
import torch

from google.colab import drive
# Mount Google Drive at /content/drive (Colab-specific; uses google.colab.drive imported above).
# NOTE(review): nothing in this section reads from /content/drive (the MNIST data is downloaded
# to local directories) — remove if no later cell needs Drive.
drive.mount('/content/drive')

Mounted at /content/drive


Implementation of a CNN-based Variational Autoencoder (VAE) model with PyTorch Lightning

In [5]:
# VAE encoder implementation.
# Three stride-2 convolutions downsample the image 8x. Passing extra_downsample=True also applies
# conv4 (16x, 128 channels) for higher compression; the Decoder must use the same setting.

class Encoder(nn.Module):
    """Convolutional VAE encoder mapping a 1-channel square image to (mu, log_var).

    Args:
        latent_dim: size of the latent vectors. Larger images may benefit from a larger value.
        img_size: height/width of the square input images. Must be divisible by 8
            (by 16 when ``extra_downsample`` is True). Defaults to 512.
        extra_downsample: if True, additionally apply ``conv4`` (one more 2x downsample,
            128 channels) before the linear heads for higher compression.

    Raises:
        ValueError: if ``img_size`` is not divisible by the total downsampling factor.
    """

    def __init__(self, latent_dim=64, img_size=512, extra_downsample=False):
        super().__init__()

        self.latent_dim = latent_dim  # larger image, higher latent_dim might be better
        self.img_size = img_size
        self.extra_downsample = extra_downsample

        factor = 16 if extra_downsample else 8
        if img_size % factor != 0:
            raise ValueError(
                f"img_size={img_size} must be divisible by {factor} "
                f"(extra_downsample={extra_downsample})"
            )

        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1, stride=2)    # S   -> S/2  (512 -> 256)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1, stride=2)   # S/2 -> S/4  (256 -> 128)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1, stride=2)   # S/4 -> S/8  (128 -> 64)
        # conv4 is always registered so state_dict keys do not depend on extra_downsample.
        self.conv4 = nn.Conv2d(64, 128, kernel_size=3, padding=1, stride=2)  # S/8 -> S/16 (64 -> 32)
        self.relu = nn.ReLU()

        # Linear heads for the latent mean and log-variance. Their input size is
        # channels x final height x final width of the last conv feature map
        # (64 * 64 * 64 for the default 512x512 input).
        out_channels = 128 if extra_downsample else 64
        side = img_size // factor
        flat_dim = out_channels * side * side
        self.fc_mu = nn.Linear(flat_dim, self.latent_dim)
        self.fc_log_var = nn.Linear(flat_dim, self.latent_dim)

    def forward(self, x):
        """Encode images into latent distribution parameters.

        Args:
            x: float tensor of shape (batch, 1, img_size, img_size).

        Returns:
            (mu, log_var), each of shape (batch, latent_dim).
        """
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        if self.extra_downsample:
            x = self.relu(self.conv4(x))
        x = x.flatten(1)  # flatten feature maps before passing into the linear layers

        mu = self.fc_mu(x)
        # Log-variance instead of variance: unconstrained output, better numerical stability/convergence.
        log_var = self.fc_log_var(x)
        return mu, log_var

# VAE decoder implementation (mirror of the encoder).
class Decoder(nn.Module):
    """Convolutional VAE decoder mapping latent vectors back to 1-channel square images.

    Args:
        latent_dim: size of the latent vectors; must match the Encoder's latent_dim.
        img_size: height/width of the square output images. Must be divisible by 8
            (by 16 when ``extra_downsample`` is True). Defaults to 512.
        extra_downsample: if True, start from a 128-channel feature map and apply
            ``deconv1`` first; use together with an encoder that downsamples 16x.

    Raises:
        ValueError: if ``img_size`` is not divisible by the total upsampling factor.
    """

    def __init__(self, latent_dim=64, img_size=512, extra_downsample=False):  # latent_dim needs to match the encoder's
        super().__init__()

        self.latent_dim = latent_dim
        self.img_size = img_size
        self.extra_downsample = extra_downsample

        factor = 16 if extra_downsample else 8
        if img_size % factor != 0:
            raise ValueError(
                f"img_size={img_size} must be divisible by {factor} "
                f"(extra_downsample={extra_downsample})"
            )
        # Shape of the feature map the fc output is reshaped into (64 x 64 x 64 for the default 512x512).
        self._start_channels = 128 if extra_downsample else 64
        self._start_side = img_size // factor

        # Output size mirrors the encoder's linear-layer input size.
        self.fc = nn.Linear(self.latent_dim, self._start_channels * self._start_side * self._start_side)

        # Each transposed conv doubles height/width (reverse of the encoder's conv layers).
        # deconv1 is always registered so state_dict keys do not depend on extra_downsample.
        self.deconv1 = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1, stride=2, output_padding=1)
        self.deconv2 = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1, stride=2, output_padding=1)
        self.deconv3 = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1, stride=2, output_padding=1)
        self.deconv4 = nn.ConvTranspose2d(16, 1, kernel_size=3, padding=1, stride=2, output_padding=1)
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()

    def forward(self, z):
        """Decode latent vectors into images.

        Args:
            z: tensor of shape (batch, latent_dim).

        Returns:
            Tensor of shape (batch, 1, img_size, img_size) with values in [0, 1] (sigmoid output).
        """
        x = self.fc(z)
        # Reshape back to the encoder's final conv feature-map shape.
        x = x.view(x.size(0), self._start_channels, self._start_side, self._start_side)
        if self.extra_downsample:
            x = self.relu(self.deconv1(x))
        x = self.relu(self.deconv2(x))
        x = self.relu(self.deconv3(x))
        return self.sigmoid(self.deconv4(x))

In [6]:
# Reparameterization step for VAEs
def reparameterize(mu, log_var):
    """Sample z ~ N(mu, exp(log_var)) using the reparameterization trick.

    The sample is written as z = mu + sigma * eps with eps ~ N(0, 1), so gradients
    can flow back to mu and log_var through a deterministic expression.

    Args:
        mu: tensor of latent means.
        log_var: tensor of latent log-variances (expected to match mu's shape).

    Returns:
        Tensor of samples, mu + sigma * eps.
    """
    sigma = log_var.mul(0.5).exp()   # standard deviation = exp(log_var / 2)
    noise = torch.randn_like(sigma)  # eps ~ N(0, 1), same shape/dtype/device as sigma
    return mu + sigma * noise

In [7]:
#VAE Class
class VAE(L.LightningModule):
    """Variational autoencoder wrapping an encoder/decoder pair for Lightning training.

    Loss per sample = summed squared reconstruction error + KL divergence of
    N(mu, exp(log_var)) from N(0, I); the logged step loss is the sum over the batch.

    Attributes:
        test_loss_list: per-image test losses from the most recent test run, in
            dataloader order (one float per image regardless of test batch size).
            Usable as anomaly scores.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.test_loss_list = []

    def forward(self, x):
        """One encoder -> sample -> decoder pass.

        Returns:
            (reconstructed_img, mu, log_var): the reconstruction is used for the
            reconstruction loss, (mu, log_var) for the KL divergence.
        """
        mu, log_var = self.encoder(x)
        z = reparameterize(mu, log_var)
        reconstructed_img = self.decoder(z)
        return reconstructed_img, mu, log_var

    def _per_sample_vae_loss(self, batch):
        """Return a 1-D tensor with reconstruction + KL loss for each sample in the batch."""
        x = batch
        x_hat, mu, log_var = self(x)  # get from forward pass

        # reconstruction loss: squared error summed over every element of each sample
        recon_loss = F.mse_loss(x_hat, x, reduction='none').flatten(1).sum(dim=1)

        # closed-form KL(N(mu, exp(log_var)) || N(0, I)), summed over latent dims of each sample
        kl_loss = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp()).flatten(1).sum(dim=1)

        return recon_loss + kl_loss

    def _get_vae_loss(self, batch):
        """Return the combined reconstruction + KL loss summed over the batch (scalar tensor)."""
        return self._per_sample_vae_loss(batch).sum()

    def training_step(self, batch, batch_idx):
        loss = self._get_vae_loss(batch)  # minimize the combined recon + KL loss
        self.log("train_loss", loss, logger=True, prog_bar=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._get_vae_loss(batch)
        self.log("val_loss", loss, logger=True, prog_bar=True, on_epoch=True)
        return loss

    def on_test_epoch_start(self):
        # Start every test run with a fresh list so scores from an earlier
        # trainer.test(...) call (e.g. a different dataset) are not mixed in.
        # Rebinding (instead of .clear()) keeps any list a caller saved from a previous run intact.
        self.test_loss_list = []

    def test_step(self, batch, batch_idx):
        per_sample = self._per_sample_vae_loss(batch)
        # Record one score per image, independent of the test DataLoader's batch size.
        self.test_loss_list.extend(per_sample.detach().cpu().tolist())
        loss = per_sample.sum()
        self.log("test_loss", loss, logger=True, prog_bar=True, on_epoch=True)
        return loss

    def on_test_epoch_end(self):
        # Lightning ignores this return value; the per-image losses are read via `model.test_loss_list`.
        return {"test_loss_list": self.test_loss_list}

    def configure_optimizers(self):
        """Adam with L2 weight decay plus ReduceLROnPlateau on the logged val_loss."""
        # weight_decay adds an L2 penalty to reduce overfitting
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3, weight_decay=1e-5)
        # Multiply the LR by 0.2 when val_loss plateaus, never going below 5e-5.
        # NOTE(review): patience=10 needs more than 10 consecutive non-improving validation
        # epochs, so the LR can never be reduced in a 10-epoch run.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.2, patience=10, min_lr=5e-5)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}


In [9]:
#VAE Initializations
SEED = 42
# Seed Python, NumPy and torch RNGs so weight initialisation and later random ops
# (data split, shuffling, reparameterization noise) are reproducible across kernel restarts.
L.seed_everything(SEED)

latent_dim = 64  # latent dimension of the VAE; larger might be better for larger images

encoder = Encoder(latent_dim=latent_dim)
decoder = Decoder(latent_dim=latent_dim)

vae = VAE(encoder, decoder)

Dataset implementation

Testing the VAE with the MNIST and Fashion-MNIST datasets to check that the model works correctly. Fashion-MNIST serves as the anomalous data: if the model is working correctly, Fashion-MNIST images should have a higher loss than MNIST digits.

In [14]:
###Mnist dataset###
class MNISTNoLabels(Dataset):
    """torchvision MNIST / Fashion-MNIST wrapper that returns only the images (labels dropped).

    Args:
        root: directory the dataset is stored in / downloaded to.
        train: load the training split if True, otherwise the test split.
        transform: transform applied to each image by the underlying torchvision dataset.
        mnistType: 'mnist' for handwritten digits or 'fashion' for Fashion-MNIST.

    Raises:
        ValueError: if mnistType is neither 'mnist' nor 'fashion'.
    """

    def __init__(self, root, train=True, transform=None, mnistType=None):
        if mnistType == 'mnist':
            dataset_cls = datasets.MNIST
        elif mnistType == 'fashion':
            dataset_cls = datasets.FashionMNIST
        else:
            # Fail fast: otherwise self.mnist would be unset and len()/indexing would
            # raise a confusing AttributeError later.
            raise ValueError(f"mnistType must be 'mnist' or 'fashion', got {mnistType!r}")
        self.mnist = dataset_cls(root=root, train=train, download=True, transform=transform)

    def __len__(self):
        return len(self.mnist)

    def __getitem__(self, idx):
        image, _ = self.mnist[idx]  # discard the label, only return the image
        return image

# Preprocessing: upscale the 28x28 images to 512x512 so they match the VAE's input size,
# then convert to a PyTorch tensor.
transformMnist = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),  # convert the image to a PyTorch tensor
])
# Fashion-MNIST uses exactly the same preprocessing (separate name kept for later cells).
transformMnistFashion = transformMnist

SPLIT_SEED = 42  # fixed seed -> identical train/val split after every kernel restart


def split_train_val(dataset, train_frac=0.8, batch_size=64, seed=SPLIT_SEED):
    """Randomly split `dataset` into train/val subsets and wrap each in a DataLoader.

    Args:
        dataset: map-style dataset to split.
        train_frac: fraction of samples assigned to the training subset.
        batch_size: batch size of both loaders.
        seed: seed of the generator used by random_split (makes the split reproducible).

    Returns:
        (train_subset, val_subset, train_loader, val_loader); only the training loader shuffles.
    """
    n_train = int(train_frac * len(dataset))
    n_val = len(dataset) - n_train
    generator = torch.Generator().manual_seed(seed)
    train_subset, val_subset = random_split(dataset, [n_train, n_val], generator=generator)
    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)
    return train_subset, val_subset, train_loader, val_loader


### MNIST (normal data): 80% train / 20% validation ###
mnist_no_labels = MNISTNoLabels(root='mnist_data', train=True, transform=transformMnist, mnistType='mnist')
train_dataset_mnist, val_dataset_mnist, train_loader_mnist, val_loader_mnist = split_train_val(mnist_no_labels)

test_dataset_mnist = MNISTNoLabels(root='mnist_data', train=False, transform=transformMnist, mnistType='mnist')
test_loader_mnist = DataLoader(test_dataset_mnist, batch_size=1, shuffle=False)  # one test step per image

### Fashion-MNIST (anomalous data): 80% train / 20% validation ###
fashion_no_labels = MNISTNoLabels(root='fashion_mnist_data', train=True, transform=transformMnistFashion, mnistType='fashion')
train_dataset_fashion, val_dataset_fashion, train_loader_fashion, val_loader_fashion = split_train_val(fashion_no_labels)

test_dataset_fashion = MNISTNoLabels(root='fashion_mnist_data', train=False, transform=transformMnistFashion, mnistType='fashion')
test_loader_fashion = DataLoader(test_dataset_fashion, batch_size=1, shuffle=False)  # one test step per image

# Keep the module-level train_size / val_size names (values of the last split, as before).
train_size, val_size = len(train_dataset_fashion), len(val_dataset_fashion)

In [15]:
# Train the VAE on MNIST only (the "good" / in-distribution data) for 10 epochs;
# accelerator="auto" lets Lightning pick the available device.
trainerMnist = Trainer(max_epochs=10, accelerator="auto")
trainerMnist.fit(model=vae, train_dataloaders=train_loader_mnist, val_dataloaders=val_loader_mnist)

INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name    | Type    | Params | Mode
-------------------------------------------
0 | encoder | Encoder | 33.7 M | eval
1 | decoder | Decoder | 17.1 M | eval
-------------------------------------------
50.8 M    Trainable params
0         Non-trainable params
50.8 M    Total params
203.152   Total estimated model params size (MB)
0         Modules in train mode
16        Modules in eval mode
INFO:lightning.pytorch.callbacks.model_summary:
  | Name    | Type    | Params | Mode
--

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO: `Trainer.fit` stopped: `max_epochs=10` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=10` reached.


In [16]:
# Evaluate on held-out MNIST digits (good data): this loss is the in-distribution baseline.
trainerMnist.test(model=vae, dataloaders=test_loader_mnist)

INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

[{'test_loss': 513.9217529296875}]

In [17]:
# Evaluate on Fashion-MNIST (bad / anomalous data): expected to score a higher loss than MNIST.
trainerMnist.test(model=vae, dataloaders=test_loader_fashion)

INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

[{'test_loss': 21137.888671875}]