In [1]:
import matplotlib.pyplot as plt  # plotting library
import numpy as np
import pandas as pd
import random
import io

import torch
import tensorflow as tf
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
import einops
from einops.layers.torch import Rearrange


2023-01-06 12:44:25.595178: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-01-06 12:44:26.593856: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /home/advait/miniconda3/envs/rohan/lib/python3.9/site-packages/nvidia/cublas/lib/:/home/advait/miniconda3/envs/rohan/lib/python3.9/site-packages/nvidia/cublas/lib/::/home/advait/miniconda3/envs/rohan/lib/
2023-01-06 12:44:26.593961: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: c

In [2]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Selected device: {device}")

writer = SummaryWriter()

data_dir = "dataset"
train_dataset = torchvision.datasets.MNIST(data_dir, train=True, download=True)
test_dataset = torchvision.datasets.MNIST(data_dir, train=False, download=True)

train_transform = test_transform = transforms.ToTensor()
train_dataset.transform = train_transform
test_dataset.transform = test_transform

m = len(train_dataset)

train_data, val_data = random_split(train_dataset, [int(m - m * 0.2), int(m * 0.2)])
batch_size = 256

train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size)
valid_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=True
)


Selected device: cuda


In [3]:
class Encoder(nn.Module):
    def __init__(self, latent_dims):
        super(Encoder, self).__init__()
        self.main_block = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5, stride=2, padding=0),  # 1x28x28 -> 32x12x12
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=5, stride=2, padding=0),  # 32x12x12 -> 64x4x4
            nn.ReLU(),
            Rearrange("b c h w -> b (c h w)"),
            nn.Linear(64 * 4 * 4, 64 * 4 * 4),
            nn.ReLU(),
            nn.Linear(64 * 4 * 4, 64 * 4 * 4),
            nn.ReLU(),
        )
        self.mu = nn.Sequential(
            nn.LazyLinear(latent_dims),
        )
        self.log_var = nn.Sequential(
            nn.LazyLinear(latent_dims),
            nn.ReLU(),
        )

    def sample(self, mu, log_var):
        std = log_var.mul(0.5).exp_()
        esp = torch.randn(*std.size()).to(device)
        z = mu + std * esp
        return z

    def forward(self, x):
        x = x.to(device)
        x = self.main_block(x)
        mu = self.mu(x)
        log_var = self.log_var(x)
        z = self.sample(mu, log_var)
        return z, mu, log_var


class Decoder(nn.Module):
    def __init__(self, latent_dims):
        super(Decoder, self).__init__()
        self.main_block = nn.Sequential(
            nn.Linear(latent_dims, 64 * 4 * 4),
            nn.ReLU(),
            nn.Linear(64 * 4 * 4, 64 * 4 * 4),
            nn.ReLU(),
            Rearrange("b (c h w) -> b c h w", c=64, h=4, w=4),
            nn.ConvTranspose2d(
                64, 32, 5, stride=2, output_padding=1
            ),  # 64x4x4 -> 32x12x12
            nn.ReLU(),
            nn.ConvTranspose2d(
                32, 1, 5, stride=2, output_padding=1
            ),  # 32x12x12 -> 1x28x28
            nn.Sigmoid(),
        )

    def forward(self, x):
        x = x.to(device)
        return self.main_block(x)


class VariationalAutoencoder(nn.Module):
    def __init__(self, latent_dims):
        super(VariationalAutoencoder, self).__init__()
        self.encoder = Encoder(latent_dims)
        self.decoder = Decoder(latent_dims)

    def forward(self, x):
        x = x.to(device)
        z, mu, log_var = self.encoder(x)
        return self.decoder(z), z, mu, log_var


In [4]:
# Initialize model

torch.manual_seed(0)

latent_dims = 10
lr = 1e-3

vae = VariationalAutoencoder(latent_dims=latent_dims)
optim = torch.optim.Adam(vae.parameters(), lr=lr, weight_decay=1e-5)

vae.to(device)




VariationalAutoencoder(
  (encoder): Encoder(
    (main_block): Sequential(
      (0): Conv2d(1, 32, kernel_size=(5, 5), stride=(2, 2))
      (1): ReLU()
      (2): Conv2d(32, 64, kernel_size=(5, 5), stride=(2, 2))
      (3): ReLU()
      (4): Rearrange('b c h w -> b (c h w)')
      (5): Linear(in_features=1024, out_features=1024, bias=True)
      (6): ReLU()
      (7): Linear(in_features=1024, out_features=1024, bias=True)
      (8): ReLU()
    )
    (mu): Sequential(
      (0): LazyLinear(in_features=0, out_features=10, bias=True)
    )
    (log_var): Sequential(
      (0): LazyLinear(in_features=0, out_features=10, bias=True)
      (1): ReLU()
    )
  )
  (decoder): Decoder(
    (main_block): Sequential(
      (0): Linear(in_features=10, out_features=1024, bias=True)
      (1): ReLU()
      (2): Linear(in_features=1024, out_features=1024, bias=True)
      (3): ReLU()
      (4): Rearrange('b (c h w) -> b c h w', c=64, h=4, w=4)
      (5): ConvTranspose2d(64, 32, kernel_size=(5, 5), s

In [5]:
def loss_function(x_hat, x, mu, log_var):
    bce = F.binary_cross_entropy(x_hat, x, reduction="sum")
    kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return bce + kl


def train_epoch(vae, device, dataloader, optimizer):
    # Set train mode for both the encoder and the decoder
    vae.train()
    train_loss = 0.0
    # Iterate the dataloader (we do not need the label values, this is unsupervised learning)
    for x, _ in dataloader:
        # Move tensor to the proper device
        x = x.to(device)
        x_hat, _, mu, log_var = vae(x)
        loss = loss_function(x_hat, x, mu, log_var)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    return train_loss / len(dataloader.dataset)


def test_epoch(vae, device, dataloader):
    # Set evaluation mode for encoder and decoder
    vae.eval()
    val_loss = 0.0
    with torch.no_grad():  # No need to track the gradients
        for x, _ in dataloader:
            # Move tensor to the proper device
            x = x.to(device)
            x_hat, _, mu, log_var = vae(x)
            loss = loss_function(x_hat, x, mu, log_var)
            val_loss += loss.item()

    return val_loss / len(dataloader.dataset)


def plot_ae_outputs(vae, n=10):
    plt.figure(figsize=(16, 4.5))
    targets = test_dataset.targets.numpy()
    t_idx = {i: np.where(targets == i)[0][0] for i in range(n)}
    for i in range(n):
        ax = plt.subplot(2, n, i + 1)
        img = test_dataset[t_idx[i]][0].unsqueeze(0).to(device)
        vae.eval()
        with torch.no_grad():
            x_hat, _, _, _ = vae(img)
        plt.imshow(img.cpu().squeeze().numpy(), cmap="gist_gray")
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        if i == n // 2:
            ax.set_title("Original images")
        ax = plt.subplot(2, n, i + 1 + n)
        plt.imshow(x_hat.cpu().squeeze().numpy(), cmap="gist_gray")
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        if i == n // 2:
            ax.set_title("Reconstructed images")
    return plt.gcf()


In [8]:
max_epochs = 500
val_losses = []

print("Beginning training")
writer.add_hparams({
    "latent_dims": "int",
    "lr": "float",
    "batch_size": "int",
}, {
    "latent_dims": latent_dims,
    "lr": lr,
    "batch_size": batch_size,
})

fig = plot_ae_outputs(vae, n=10)
writer.add_figure("VAE Output", fig, 0)

for epoch in range(max_epochs):
    train_loss = train_epoch(vae, device, train_loader, optim)
    val_loss = test_epoch(vae, device, valid_loader)
    val_losses.append(val_loss)

    writer.add_scalar("Loss/train", train_loss, epoch)
    writer.add_scalar("Loss/val", val_loss, epoch)
    fig = plot_ae_outputs(vae, n=10)
    writer.add_figure("VAE Output", fig, epoch)

    writer.flush()

    print(
        "\n EPOCH {}/{} \t train loss {:.3f} \t val loss {:.3f}".format(
            epoch + 1, max_epochs, train_loss, val_loss
        )
    )
    if len(val_losses) >= 5 and val_losses[-5] < val_loss:
        print("Validation loss is no longer shrinking. Quitting early.")
        break

writer.flush()
writer.close()

Beginning training

 EPOCH 1/500 	 train loss 138.552 	 val loss 138.983

 EPOCH 2/500 	 train loss 138.409 	 val loss 138.573

 EPOCH 3/500 	 train loss 138.329 	 val loss 138.751

 EPOCH 4/500 	 train loss 138.300 	 val loss 138.829

 EPOCH 5/500 	 train loss 138.181 	 val loss 138.697

 EPOCH 6/500 	 train loss 138.048 	 val loss 138.284

 EPOCH 7/500 	 train loss 138.019 	 val loss 138.537

 EPOCH 8/500 	 train loss 138.001 	 val loss 138.458

 EPOCH 9/500 	 train loss 137.796 	 val loss 138.672

 EPOCH 10/500 	 train loss 137.862 	 val loss 138.502
Validation loss is no longer shrinking. Quitting early.


In [None]:
# Random samples from the latent space

def show_image(img):
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))

vae.eval()
with torch.no_grad():
    # sample latent vectors from the normal distribution
    latent = torch.randn(100, latent_dims, device=device)

    # reconstruct images from the latent vectors
    img_recon = vae.decoder(latent)
    img_recon = img_recon.cpu()

    fig, ax = plt.subplots(figsize=(20, 8.5))
    show_image(torchvision.utils.make_grid(img_recon.data, 10, 5))
    plt.show()


In [None]:
# Average z and std of a bunch of zeroes

zeroes = [
    train_dataset.data[i] for i, label in enumerate(train_dataset.targets) if label == 0
]
zeroes = zeroes[:100]
zeroes = einops.rearrange(torch.stack(zeroes).type(torch.float32), "b h w -> b 1 h w")
zeroes.to(device)
with torch.no_grad():
    vae.eval()
    _, mu, log_var = vae.encoder(zeroes)

    # Average mu and log_var for zeroes in our dataset
    mu_avg = einops.reduce(mu, "b mu -> mu", "mean").cpu()
    log_var_avg = einops.reduce(log_var, "b l -> l", "mean").cpu()
    std_avg = log_var.cpu().mul(0.5).exp_()

    # Generate a bunch of zeroes from latent sampling
    esp = torch.randn(100, latent_dims, dtype=torch.float32)
    z = (esp.mul(std_avg).add(esp)).to(device)

    img_recon = vae.decoder(z)
    img_recon = img_recon.cpu()

    fig, ax = plt.subplots(figsize=(20, 8.5))
    show_image(torchvision.utils.make_grid(img_recon.data, 10, 5))
    plt.show()
