In [None]:
import os
from pathlib import Path

import lightning as L
import matplotlib.pyplot as plt
import numpy as np
import rootutils
import torch
import torch.nn as nn
import torch.nn.functional as F
from mpl_toolkits.axes_grid1 import ImageGrid

root_p = rootutils.setup_root(search_from=os.getcwd(), indicator=".project-root")

from src.data.mnist_datamodule import MNISTDataModule
from src.models.vae_components.vanilla_vae import VAE
from src.models.vae_module import VAEModule

%load_ext autoreload
%autoreload 2

In [None]:
# Pick the GPU when one is available; later cells move their tensors to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dm = MNISTDataModule(num_workers=0, transform="default")
# input_dim=784 matches the flattened 28x28 images used by the later cells.
# NOTE(review): the sampling cells further down decode 2-element latent vectors
# (torch.tensor([[xi, yi]])), which looks inconsistent with latent_dim=200 — confirm.
model = VAE(input_dim=784, hidden_dim=400, latent_dim=200)

# The Lightning wrapper is left disabled here; training uses the hand-written
# loop defined in a later cell instead.
# pl_module = VAEModule(model, lr=0.001)
# Presumably prepares the training split so dm.train_dataloader() can be used — verify.
dm.setup("fit")
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [None]:
# Test dataloader: pull one batch from the training loader and show its first
# 25 images in a 5x5 grid as a visual sanity check.
dataiter = iter(dm.train_dataloader())
batch = next(dataiter)

num_samples = 25
# batch[0] is the image part of the batch; [i, 0] picks sample i and index 0 of
# the second axis (presumably the channel axis — confirm).
# assumes the batch holds at least num_samples items — TODO confirm batch size
sample_images = [batch[0][i, 0] for i in range(num_samples)]

fig = plt.figure(figsize=(5, 5))
grid = ImageGrid(fig, 111, nrows_ncols=(5, 5), axes_pad=0.1)

# One image per grid axis; tick marks are hidden because only the pixels matter.
for ax, im in zip(grid, sample_images):
    ax.imshow(im, cmap="gray")
    ax.axis("off")

plt.show()

In [None]:
# Define training routine


def loss_function(x, x_hat, mean, log_var, kld_weight=1.0):
    """Compute the VAE loss: reconstruction MSE plus a weighted KL divergence.

    Adapted from
    https://github.com/AntixK/PyTorch-VAE/blob/master/models/vanilla_vae.py

    Args:
        x: Input batch the model tried to reconstruct.
        x_hat: Reconstruction of ``x``.
        mean: Latent means; summed over ``dim=1`` and averaged over ``dim=0``,
            i.e. indexed as (batch, latent_dim).
        log_var: Latent log-variances, same shape as ``mean``.
        kld_weight: Multiplier for the KL term. The default of 1.0 keeps the
            plain sum of both terms. The reconstruction term is a mean over
            every element while the KL term is summed over latent dimensions,
            so the two are on different scales; the referenced implementation
            scales the KL term by a minibatch weight for that reason. Pass a
            value below 1.0 to do the same.

    Returns:
        Scalar tensor ``recons_loss + kld_weight * kld_loss``.
    """
    # Loss function from https://github.com/AntixK/PyTorch-VAE/blob/master/models/vanilla_vae.py
    recons_loss = nn.functional.mse_loss(x_hat, x)
    # Closed-form KL(N(mean, exp(log_var)) || N(0, I)) per sample, then batch mean.
    kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mean**2 - log_var.exp(), dim=1), dim=0)

    return recons_loss + kld_weight * kld_loss


def train(model, optimizer, epochs, device, dataloader=None, loss_fn=None):
    """Train ``model`` for ``epochs`` passes and print the mean batch loss per epoch.

    Args:
        model: Module whose forward pass returns ``(x_hat, mean, log_var)``.
        optimizer: Optimizer that updates ``model``'s parameters.
        epochs: Number of passes over ``dataloader``.
        device: Device every input batch is moved to before the forward pass.
        dataloader: Re-iterable yielding ``(x, target)`` batches; the target is
            ignored. Defaults to ``dm.train_dataloader()`` from the notebook's
            data module.
        loss_fn: Callable ``(x, x_hat, mean, log_var) -> scalar tensor``.
            Defaults to the notebook's ``loss_function``.

    Returns:
        Sum of the per-batch loss values of the last epoch, or 0.0 when no
        epoch ran.
    """
    if dataloader is None:
        dataloader = dm.train_dataloader()
    if loss_fn is None:
        loss_fn = loss_function

    model.train()
    # Defined up front so that epochs == 0 still returns a value.
    overall_loss = 0.0
    for epoch in range(epochs):
        overall_loss = 0.0
        num_batches = 0
        for x, _ in dataloader:
            # Flatten every sample to a vector, whatever its trailing shape is.
            x = x.reshape(x.size(0), -1)
            x = x.to(device)

            optimizer.zero_grad()

            x_hat, mean, log_var = model(x)
            loss = loss_fn(x, x_hat, mean, log_var)

            overall_loss += loss.item()
            num_batches += 1

            loss.backward()
            optimizer.step()

        # Each batch loss is already a per-batch value, so the epoch average is
        # the sum divided by the number of batches (not by index * batch size).
        # max(..., 1) keeps an empty loader from dividing by zero.
        average_loss = overall_loss / max(num_batches, 1)
        print("\tEpoch", epoch + 1, "\tAverage Loss: ", average_loss)
    return overall_loss


# Run the hand-written loop for 10 epochs; the value it returns is shown as the cell output.
train(model=model, optimizer=optimizer, epochs=10, device=device)

In [None]:
# Load model from Lightning checkpoint
# NOTE(review): the run directory below is hard-coded to one specific training
# run, so this cell only works where that exact checkpoint exists.
pl_module = VAEModule.load_from_checkpoint(
    root_p / "logs" / "vanilla_vae" / "runs" / "2023-12-20_11-01-43" / "checkpoints" / "last.ckpt",
    model=model,
    # Remap stored tensors onto the active device so a checkpoint saved on a
    # GPU can still be loaded on a CPU-only machine.
    map_location=device,
)
# Keep the network on the same device as the latent samples created below.
model = pl_module.model.to(device)
model.eval()


def generate_digit(mean, var):
    """Decode the latent point ``[mean, var]`` and display it as a 28x28 image.

    Despite their names, the two arguments are used as the two coordinates of
    a single latent vector. Relies on the notebook-level ``model`` and
    ``device``.
    """
    latent = torch.tensor([[mean, var]], dtype=torch.float)
    latent = latent.to(device)
    decoded = model.decode(latent)
    # Drop autograd tracking and move to host memory before plotting.
    image = decoded.detach().cpu().reshape(28, 28)
    plt.imshow(image, cmap="gray")
    plt.axis("off")
    plt.show()


# Two separate statements: wrapping both calls in a tuple made the cell echo
# "(None, None)" as its output, since generate_digit returns nothing.
generate_digit(0.0, 1.0)
generate_digit(1.0, 0.0)

In [None]:
def plot_latent_space(model, scale=1.0, n=25, digit_size=28, figsize=15):
    """Decode an n x n grid of 2-element latent points and show them as one tiled image.

    Args:
        model: Object exposing ``decode``; it is called with a (1, 2) float
            tensor, and the first item of its output is reshaped to
            ``(digit_size, digit_size)``.
        scale: Both latent coordinates sweep the range ``[-scale, scale]``.
        n: Number of grid points per axis.
        digit_size: Side length, in pixels, of each decoded tile.
        figsize: Side length of the square matplotlib figure.

    Relies on the notebook-level ``device`` for placing the latent samples.
    """
    canvas = np.zeros((digit_size * n, digit_size * n))

    # The first coordinate ascends left to right; the second is reversed so it
    # descends from the top row to the bottom row.
    grid_x = np.linspace(-scale, scale, n)
    grid_y = np.linspace(-scale, scale, n)[::-1]

    # Row-major walk over the grid: row picks the second coordinate, col the first.
    for row, col in np.ndindex(n, n):
        z_sample = torch.tensor([[grid_x[col], grid_y[row]]], dtype=torch.float).to(device)
        decoded = model.decode(z_sample)
        tile = decoded[0].detach().cpu().reshape(digit_size, digit_size)
        top = row * digit_size
        left = col * digit_size
        canvas[top : top + digit_size, left : left + digit_size] = tile

    plt.figure(figsize=(figsize, figsize))
    plt.title("VAE Latent Space Visualization")
    # One tick at the centre of every tile, labelled with its rounded latent value.
    half_tile = digit_size // 2
    tick_positions = np.arange(half_tile, n * digit_size + half_tile, digit_size)
    x_labels = np.round(grid_x, 1)
    y_labels = np.round(grid_y, 1)
    plt.xticks(tick_positions, x_labels)
    plt.yticks(tick_positions, y_labels)
    plt.xlabel("mean, z [0]")
    plt.ylabel("var, z [1]")
    plt.imshow(canvas, cmap="Greys_r")
    plt.show()


# Render the latent grid with the default scale, grid size and figure size.
plot_latent_space(model=model)