**GROUP 1 : ASSIGNMENT 2 : VAE**

In [3]:
import torch
import time
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from torchvision.transforms import InterpolationMode
import os
import zipfile
import random
from PIL import Image
# from torch.utils.data import Dataset, DataLoader
# from scipy.linalg import sqrtm
from torchvision import transforms, models
# import pandas as pd
# import json
from torch.optim import lr_scheduler
import torch.nn.functional as F


In [4]:
# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
# Dont forget to use the GPU using .to(device)!
# Use anaconda3/bin/python 3.12.4 base environment while using FIST server

cuda


In [None]:
from google.colab import drive
drive.mount('/content/drive')

**Organising Data for Butterfly and Animal dataset and also DataAugmentation**

In [None]:
# Butterfly dataset
# train has 6499 images from 75 classes
# test has 2786 images

# Convert the image to a tensor and normalize it between 0 and 1
transform = transforms.Compose([transforms.Resize((128, 128)),transforms.ToTensor()])
# Load CSV containing filenames and labels
df = pd.read_csv(r"C:\ADRL data\Butterfly dataset\Training_set.csv")
N = len(df)
# Create a mapping of distinct labels to indices
class_names = sorted(df['label'].unique())
class_to_idx = {class_name: idx for idx, class_name in enumerate(class_names)}

# Save the class_to_idx dictionary to a JSON file
with open(r"C:\ADRL data\Butterfly dataset\class_to_idx.json", 'w') as f:
    json.dump(class_to_idx, f)
with open(r"C:\ADRL data\Butterfly dataset\class_to_idx.json", 'r') as f:
    loaded_class_to_idx = json.load(f)
print(loaded_class_to_idx)

# Initialize a tensor to store the images
# Shape: (N, 3 (channels), 128 (H), 128 (W))
image_tensor = torch.zeros((N, 3, 128, 128))
# Initialize a tensor to store class indices
# Shape: (N,)
class_indices = torch.zeros(N, dtype=torch.long)  # Store class indices

# Process each image and store it in the appropriate class index
for i, row in df.iterrows():
    print(i)
    img_path = os.path.join(r"C:\ADRL data\Butterfly dataset\train", row['filename'])  # Assuming images are in the 'train' folder
    label = row['label']

    # Load the image
    img = Image.open(img_path).convert('RGB')  # Ensure image is in RGB mode
    # (3,224,224)

    # Apply transformations
    img_tensor = transform(img)
    # Get class index
    class_idx = class_to_idx[label]

    # Store the image tensor and class index
    image_tensor[i] = img_tensor
    class_indices[i] = class_idx

torch.save(image_tensor, r"C:\ADRL data\Butterfly dataset\butterfly_training_images.pth")
torch.save(class_indices, r"C:\ADRL data\Butterfly dataset\butterfly_training_classindices.pth")
# Now, image_tensor contains the (N, 3, 128, 128) tensor.
print(f'Successfully created tensor with shape: {image_tensor.shape}')


In [None]:
# Loading the dataset in [0,1] range
animal_images = torch.load("/home/sahapthank/models/training_images.pt")
butterfly_images = torch.load("/home/sahapthank/models/butterfly_training_images.pth" , weights_only = True)
butterfly_labels = torch.load("/home/sahapthank/models/butterfly_training_classindices.pth")
anime_images = torch.load("/home/sahapthank/models/anime_images.pth")
animal_images_augmented = torch.load("/home/sahapthank/models/animal_images_augmented.pth")
butterfly_images_augmented = torch.load("/home/sahapthank/models/butterfly_images_augmented.pth")


In [5]:
# DATA AUGMENTATION FOR BUTTERFLY
horizontal_flip = transforms.RandomHorizontalFlip(p=1)    # Randomly flip the image horizontally
rotation = transforms.RandomRotation(degrees=20)    # Random rotation with expanding size
color_jitter = transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)  # Random color adjustments

tensor_shape = (6499,3,128,128)
flipped_images = torch.zeros(tensor_shape)
rotated_images = torch.zeros(tensor_shape)
color_jittered_images = torch.zeros(tensor_shape)

for i in range(6499):
    flipped_images[i] = horizontal_flip(butterfly_images[i])
    rotated_images[i] = rotation(butterfly_images[i])
    color_jittered_images[i]= color_jitter(butterfly_images[i])

butterfly_images_augmented = torch.stack([butterfly_images,flipped_images,rotated_images,color_jittered_images])

In [None]:
# DATA AUGMENTATION FOR ANIMAL
horizontal_flip = transforms.RandomHorizontalFlip(p=1)    # Randomly flip the image horizontally
rotation = transforms.RandomRotation(degrees=20)    # Random rotation with expanding size
color_jitter = transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)  # Random color adjustments

tensor_shape = (90,60,3,128,128)
flipped_images = torch.zeros(tensor_shape)
rotated_images = torch.zeros(tensor_shape)
color_jittered_images = torch.zeros(tensor_shape)

for i in range(90):
    for j in range(60):
        flipped_images[i][j] = horizontal_flip(animal_images[i][j])
        rotated_images[i][j] = rotation(animal_images[i][j])
        color_jittered_images[i][j] = color_jitter(animal_images[i][j])

animal_images_augmented = torch.stack([animal_images,flipped_images,rotated_images,color_jittered_images])

**[Q1] Vanilla VAE implementation**

**Some implementation Observations**
* In the ENCODER we use {1 NN} that use the features extracted from CNN layers to predict mean and log(variance). Since this can be any real number using ReLU as final activation saturates the learning after 50 steps and variance learnt becomes 0 collapsing all images!
* Therefore to cover the entire range space of (-inf,inf) there is no last activation function and nn.Identity is used!
* Since VAE is trained using Pixel-based MSE_loss it is better to keep the images in [0,1] as opposed to [-1,1] while using GAN. Therefore Tanh is replaced by Sigmoid in the last layer of DECODER
* Without batch normalisation even prior to training just after initialisation the DECODER gives images which are just "black" meaning most probably "nan". Therefore batch-normalisation is crucial similar to Generator of GAN in all layers except the last one
* KL divergence becoming close to zero very quickly and saturating..so we define kl_weight to manually control weightage given to it and focus first on reconstruction loss only









In [None]:
# In a vanilla VAE, the encoder does following:
# 1) Input: Image of shape (bs,3,128,128)
# 2) Outputs 2 things for a fixed (x_i) when z is d-dimensional
# Mean vector (μ) of shape (bs,d) of {q_phi(z|x_i)}
# Log variance vector (log(σ²)) of shape (bs,d) since co-variance is assumed to be diagonal [dimensions are indep]
# This is done to prevent the model from learning negative values of variance

class Vanilla_VAE_Encoder(nn.Module):
    def __init__(self, architecture):
        # architecture = [cnn_architecture, nn_architecute]
        # nn_architecture = [nn_layers, nn_batch_norm , nn_activation]
        # cnn_architecture = [conv_params, cnn_batchnorm, cnn_activation_fn]
        # conv_params = [[in_channels, output_channels, kernel_size, stride, padding], ...]
        # nn/cnn_batchnorm = [False, True, ...]
        # nn/cnn_activation = [nn.ReLU(), nn.LeakyReLU(), ...]

        super(Vanilla_VAE_Encoder, self).__init__()
        assert len(architecture[0][0]) == len(architecture[0][1]) == len(architecture[0][2])
        assert len(architecture[1][0]) == len(architecture[1][1]) == len(architecture[1][2])
        self.conv_params = architecture[0][0]
        self.cnn_batchnorm = architecture[0][1]
        self.cnn_activation = architecture[0][2]
        self.nn_layers = architecture[1][0]
        self.nn_batchnorm = architecture[1][1]
        self.nn_activation = architecture[1][2]

        layers_0 = []
        for j,i in enumerate(self.conv_params):
            (layers_0).append(nn.Conv2d(in_channels=i[0],out_channels=i[1],kernel_size=i[2],stride=i[3],padding=i[4]))
            if (self.cnn_batchnorm)[j]:
                (layers_0).append(nn.BatchNorm2d(i[1]))
            (layers_0).append(self.cnn_activation[j])

        layers_1 = []
        for i in range(len(self.nn_layers)):
            in_dim, out_dim = self.nn_layers[i]
            (layers_1).append(nn.Linear(in_dim, out_dim))
            if self.nn_batchnorm[i]:
                (layers_1).append(nn.BatchNorm1d(out_dim))
            (layers_1).append(self.nn_activation[i])

        # Stack convolutional layers in a Sequential block
        # Create a single FFNN for mu and sigma o/p is "2d" dimension
        self.cnn = nn.Sequential(*(layers_0))
        self.nn = nn.Sequential(*(layers_1))

    def forward(self, x):
        features = self.cnn(x)
        # Flatten the tensor for the fully connected layers
        features = torch.flatten(features, start_dim=1)
        c = self.nn(features) #(bs,2d)
        d = int(c.size(1)/2)
        mu = c[:, :d]  #(bs,d)
        logvar = c[:, d:] #(bs,d)
        return mu,logvar

    # We will use this to prevent calculating gradients wrt parameters of Encoder
    def set_requires_grad(self, requires_grad):
        for param in self.parameters():
            param.requires_grad = requires_grad


In [None]:
# The decoders input is z = (bs,d) and outputs image y similar to Generator of DCGAN
# While inference we only need the Decoder and can pass any random z to get an image
# While trying to optimize using MSE loss it is implicitly assumed that the decoder outputs the mean image
# Note that VAE work on PIXEL-BASED LOSSES so better in range [0,1]

class Vanilla_VAE_Decoder_TransposeConv(nn.Module):
    def __init__(self, architecture):
        # The architecture is almost same as DCGAN
        # Suppose we use transpose_convolutions then we first use 1 Linear layer to convert to higher dim
        # Then we apply series of transpose convolutions!

        super(Vanilla_VAE_Decoder_TransposeConv, self).__init__()
        assert len(architecture[0]) == len(architecture[1]) == len(architecture[2])
        self.transpose_conv_params = architecture[0]
        self.use_batchnorm = architecture[1]
        self.activation_fn = architecture[2]
        layers = []
        # Starting with input latent vector (e.g., size d = 100)
        for j, i in enumerate(self.transpose_conv_params):
            layers.append(nn.ConvTranspose2d(in_channels=i[0], out_channels=i[1], kernel_size=i[2], stride=i[3], padding=i[4]))
            if self.use_batchnorm[j]:
                layers.append(nn.BatchNorm2d(i[1]))
            layers.append(self.activation_fn[j])
        # nn.Sigmoid() is used as the final activation function to [0,1]
        self.model = nn.Sequential(*layers)

    # Sampling images from the generator G
    def forward(self, x):
        return self.model(x)

    # We will use this to prevent calculating gradients wrt parameters of Decoder
    def set_requires_grad(self, requires_grad):
        for param in self.parameters():
            param.requires_grad = requires_grad



In [None]:
# Reparameterisation trick
# This acts as g_phi(e,x_i) as in the paper
# g is a differentiable transformation
def reparam_Gaussian(mu,sigma,L):
    # General method is using Cholensky Decomposition
    # But here lets use diagonal matrix to simplify things
    # Ensure mu,sigma are of shape (bs,d)
    bs = mu.size(0)
    d = mu.size(1)
    e = torch.distributions.MultivariateNormal(torch.zeros((d,)), torch.eye(d)).sample([bs, L]).to(device)
    # e is of shape (bs,L,d)
    # To broadcast we make it mu,sigma (bs,1,d)
    # (sigma * e ) is done elementwise
    mu = mu.unsqueeze(1)
    sigma = sigma.unsqueeze(1)
    z = (mu + (sigma * e))
    return z #(bs,L,d)

# Sample from a Gaussian for inference
def sample_gaussian(mean, covariance_matrix , n_samples):
    # mean should be of size ([d])
    # variance is the covariance matrix of size ([d,d]) must be +ve semidefinite
    z = torch.distributions.MultivariateNormal(mean,covariance_matrix)
    return z.sample([n_samples])
    # returns of size ([n_samples,d])

# Sample from Real_images [BUTTERFLY]
def sample_real_butterfly(n_samples):
    i_indices = torch.randint(0, 6499, (n_samples,)).tolist()
    sampled_images = []
    for i in i_indices:
        sampled_images.append(butterfly_images[i])
    sampled_images_tensor = torch.stack(sampled_images)
    return sampled_images_tensor

# Sample from Real_images [ANIMAL]
def sample_real_animal(n_samples):
    i_indices = torch.randint(0, 90, (n_samples,)).tolist() # chooses class
    j_indices = torch.randint(0, 60, (n_samples,)).tolist() # chooses image in class
    sampled_images = []
    for i, j in zip(i_indices, j_indices):
        sampled_images.append(animal_images[i, j])
    sampled_images_tensor = torch.stack(sampled_images)
    return sampled_images_tensor

# Sample from Real_images [ANIMAL] including augmented
def sample_real_augmented_animal(n_samples):
    i_indices = torch.randint(0, 90, (n_samples,)).tolist() # chooses class
    j_indices = torch.randint(0, 60, (n_samples,)).tolist() # chooses image in class
    k_indices = torch.randint(0, 4, (n_samples,)).tolist() # chooses augmentation type
    sampled_images = []
    for i, j , k in zip(i_indices, j_indices , k_indices):
        sampled_images.append(animal_images_augmented[k, i, j])
    sampled_images_tensor = torch.stack(sampled_images)
    return sampled_images_tensor

# Sample from Real_images [BUTTERFLY] including augmented
def sample_real_augmented_butterfly(n_samples):
    i_indices = torch.randint(0, 6499, (n_samples,)).tolist() # chooses image
    k_indices = torch.randint(0, 4, (n_samples,)).tolist() # chooses augmentation type
    sampled_images = []
    for i, k in zip(i_indices, k_indices):
        sampled_images.append(butterfly_images_augmented[k, i])
    sampled_images_tensor = torch.stack(sampled_images)
    return sampled_images_tensor

# Sample from Real_images [ANIME]
def sample_real_anime(n_samples):
    i_indices = torch.randint(0, 21551, (n_samples,)).tolist()
    sampled_images = []
    for i in i_indices:
        sampled_images.append(anime_images[i])
    sampled_images_tensor = torch.stack(sampled_images)
    return sampled_images_tensor


In [None]:
# Formula for CNN convolutions
# output_size = ({input_size - kernel_size + 2*padding}/stride) + 1
# Formula for Transpose Conv
# Assume input size is [d_1,1,1] gets converted to [3,128,128]
# output_size = {(input_size - 1) * stride} - (2*padding) + kernel_size}

# TRIAL 1
# d_1 = 128
# conv_params_1 = [[3,16,4,2,1],[16,16,4,2,1],[16,32,4,2,1],[32,64,4,2,1]]
# cnn_batchnorm_1 = [False]+ [True for _ in range(3)]
# cnn_activation_1 = [nn.LeakyReLU(0.2, inplace=False) for _ in range(4)]
# nn_layers_1 = [[4096,1024],[1024,2*d_1]]
# nn_batchnorm_1 = [True for _ in range(1)] + [False]
# nn_activation_1 = [nn.ReLU() for _ in range(1)] + [nn.Identity()]
# cnn_arch_1 = [conv_params_1, cnn_batchnorm_1, cnn_activation_1]
# nn_arch_1 = [nn_layers_1, nn_batchnorm_1, nn_activation_1]
# E_arch_1 = [cnn_arch_1, nn_arch_1]
# transpose_conv_params_1 = [[d_1,512,4,1,0],[512,256,4,2,1],[256,128,4,2,1],[128,64,4,2,1],[64,32,4,2,1],[32,3,4,2,1]]
# use_batchnorm_1 = [True for _ in range(5)] + [False]
# activation_fn_1 = [nn.ReLU(inplace = False) for _ in range(5)] + [nn.Sigmoid()]
# D_arch_1 = [transpose_conv_params_1, use_batchnorm_1, activation_fn_1]

# TRIAL 2
# TRIAL 2 [17,481,424 3,838,371]
# d_1 = 128
# conv_params_1 = [[3,16,4,2,1],[16,32,4,2,1],[32,64,4,2,1],[64,128,4,2,1]]
# cnn_batchnorm_1 = [True for _ in range(4)]
# cnn_activation_1 = [nn.LeakyReLU(0.2, inplace=False) for _ in range(4)]
# nn_layers_1 = [[8192,2048],[2048,2*d_1]]
# nn_batchnorm_1 = [True for _ in range(1)] + [False]
# nn_activation_1 = [nn.ReLU() for _ in range(1)] + [nn.Identity()]
# cnn_arch_1 = [conv_params_1, cnn_batchnorm_1, cnn_activation_1]
# nn_arch_1 = [nn_layers_1, nn_batchnorm_1, nn_activation_1]
# E_arch_1 = [cnn_arch_1, nn_arch_1]
# transpose_conv_params_1 = [[d_1,512,4,1,0],[512,256,4,2,1],[256,128,4,2,1],[128,64,4,2,1],[64,32,4,2,1],[32,3,4,2,1]]
# use_batchnorm_1 = [True for _ in range(5)] + [False]
# activation_fn_1 = [nn.ReLU(inplace = False) for _ in range(5)] + [nn.Sigmoid()]
# D_arch_1 = [transpose_conv_params_1, use_batchnorm_1, activation_fn_1]

# TRIAL 3
d_1 = 128
conv_params_1 = [[3,16,4,2,1],[16,32,4,2,1],[32,64,4,2,1],[64,128,4,2,1],[128,256,4,2,1],[256,512,4,1,0]]
cnn_batchnorm_1 = [True for _ in range(6)]
cnn_activation_1 = [nn.LeakyReLU(0.2, inplace=False) for _ in range(6)]
nn_layers_1 = [[512,512],[512,2*d_1]]
nn_batchnorm_1 = [True for _ in range(1)] + [False]
nn_activation_1 = [nn.ReLU() for _ in range(1)] + [nn.Identity()]
cnn_arch_1 = [conv_params_1, cnn_batchnorm_1, cnn_activation_1]
nn_arch_1 = [nn_layers_1, nn_batchnorm_1, nn_activation_1]
E_arch_1 = [cnn_arch_1, nn_arch_1]
transpose_conv_params_1 = [[d_1,512,4,1,0],[512,256,4,2,1],[256,128,4,2,1],[128,64,4,2,1],[64,32,4,2,1],[32,3,4,2,1]]
use_batchnorm_1 = [True for _ in range(5)] + [False]
activation_fn_1 = [nn.ReLU(inplace = False) for _ in range(5)] + [nn.Sigmoid()]
D_arch_1 = [transpose_conv_params_1, use_batchnorm_1, activation_fn_1]



In [None]:
# Vanilla VAE is based on Auto-Encoding Variational Bayes paper
# Approximates the log-likelihood of the data = KL(P_r,P_theta)
# EM fails when we use NN to model q(z|x) due to intractability so we use Variational method
# A = E_{q_phi(z|x)}(log p_theta(x|z)) is the reconstruction-term
# B = D_KL(q_phi(z|x) || p(z)) is the prior-matching term
# C = D_KL(q_phi(z|x) || p_theta(z|x))
# We get l_theta(x) := log(p_theta(x)) = A - B + C = ELBO + C = F_theta(q_phi) + C
# C is intractable so (A-B) is the one optimised to get tighter bounds for l_theta(x)
# For ease of computations q_phi(z|x) and p_theta(z) are assumed to be Gaussians!
# p_theta(z) is actually fixed to be N(0,I) but can be modified to make it learnable
# q_phi(z|x) is N(z,mu_phi(x),sigma_squared_phi(x)I)

# Initialize model, loss function, and optimizer
# l is for #samples used while reparameterising
lr_VAE, m, d = 1e-4, 64, 128
E_arch = E_arch_1
D_arch = D_arch_1
E = Vanilla_VAE_Encoder(E_arch).to(device)
D = Vanilla_VAE_Decoder_TransposeConv(D_arch).to(device)
elbo = []
A_loss = []
B_loss = []
vae_params = list(E.parameters()) + list(D.parameters())
vae_optimizer = optim.Adam(vae_params, lr=lr_VAE, betas=(0.8, 0.98), eps=1e-8, weight_decay=0)
# vae_optimizer = optim.RMSprop(vae_params, lr=lr_VAE, alpha=0.99, eps=1e-8, weight_decay=0)
# vae_optimizer = optim.SGD(vae_params, lr=lr_VAE, momentum=0.9, weight_decay=0)
# scheduler = lr_scheduler.ReduceLROnPlateau(vae_optimizer, mode='min', factor=0.7, patience=5, verbose=True)


In [None]:
# For manual changing during training
lr_VAE, m, d = 1e-4, 64, 128
vae_optimizer = optim.Adam(vae_params, lr=lr_VAE, betas=(0.8, 0.98), eps=1e-8, weight_decay=0)
# vae_optimizer = optim.RMSprop(vae_params, lr=lr_VAE, alpha=0.99, eps=1e-8, weight_decay=0)
# vae_optimizer = optim.SGD(vae_params, lr=lr_VAE, momentum=0.9, weight_decay=0)


In [None]:
# We use the "2nd Loss function mentioned in the paper := L_b"
# Since B is KL btw 2 gaussians and is known analytically [like regularisation]
# A is estimated using Monte-Carlo
# Vanilla VAE training algorithm
num_epochs = 50
mini_batch_epochs = 101
kl_weight = 1e-5 * (1)
L = 3
mse_loss = nn.MSELoss()
# Training loop
for epoch in range(num_epochs):
    for _ in range(mini_batch_epochs):
        D.set_requires_grad(True)
        E.set_requires_grad(True)
        vae_optimizer.zero_grad()

        x = sample_real_butterfly(m).to(device)
        mu_x, log_sigma_squared_x = E(x)
        sigma_squared_x = torch.exp(log_sigma_squared_x)
        sigma_x = torch.sqrt(sigma_squared_x)

        # First of all the prior term B can be computed
        # Each x_i we get a scalar
        # We do a sum over dimensions d
        # So final shape of B is (bs,1)
        B = (kl_weight) * (0.5) * torch.sum(((-1) - (log_sigma_squared_x) + (mu_x.pow(2)) + (sigma_squared_x)), dim=1, keepdim=True)

        # We need to sample "L-many z_j" for each "x_i"
        # Using the z we need to get y = (bs,L,3,128,128) as images!
        z = (reparam_Gaussian(mu_x, sigma_x, L)).to(device)
        D_output = []
        for j in range(L):
            # Select the jth sample from z
            z_j = (z[:, j, :]).unsqueeze(-1).unsqueeze(-1)  # Shape (bs,d,1,1)
            D_images = D(z_j)  # Shape (bs, 3, 128, 128)
            D_output.append(D_images)
        y = torch.stack(D_output, dim=1).to(device) #(bs,L,3,128,128)

        # Finally the reconstruction loss term A has to be computed
        # MSE pixel-wise broadcasts and averages over {bs,L,pixels}
        x_r = x.unsqueeze(1).repeat(1, L, 1, 1, 1).to(device)
        A = mse_loss(y,x_r)

        # We need to maximize the ELBO
        loss_VAE = ((B.mean()) + A)
        loss_VAE.backward()
        vae_optimizer.step()
        vae_optimizer.zero_grad()

    # .item() detaches tensor automatically
    B1 = (B.mean().item()/(kl_weight))
    elbo.append(-(B1 + A.item()))
    A_loss.append(A.item())
    B_loss.append(B1)


In [None]:
# Create a figure with three subplots (1 row, 3 columns)
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 6))
# Plot A_loss in the first subplot
A_steps = range(len(A_loss))
ax1.plot(A_steps, A_loss, label='A_loss', color='blue')
ax1.set_xlabel('Epochs')
ax1.set_ylabel('A')
ax1.set_title('Reconstruction_Loss')
ax1.legend()
# Plot B_loss in the second subplot
B_steps = range(len(B_loss))
ax2.plot(B_steps, B_loss, label='B_loss', color='orange')
ax2.set_xlabel('Epochs')
ax2.set_ylabel('B')
ax2.set_title('KL_Divg')
ax2.legend()
# Plot ELBO in the third subplot
vae_steps = range(len(elbo))
ax3.plot(vae_steps, elbo, label='ELBO', color='green')
ax3.set_xlabel('Epochs')
ax3.set_ylabel('ELBO')
ax3.set_title('ELBO')
ax3.legend()
# Display the plots
plt.tight_layout()
plt.show()


In [None]:
# Inference
d = 128
s = 100
with torch.no_grad():
    z = (sample_gaussian(torch.zeros((d,)),torch.eye(d),s)).unsqueeze(-1).unsqueeze(-1).to(device)
    y = (D(z))
    images = y.detach().cpu().numpy()
    # Create a figure for the grid of images
    fig, axes = plt.subplots(nrows=10, ncols=10, figsize=(15, 15))
    # Loop through the 100 images and display them in the grid
    for i, ax in enumerate(axes.flat):
        img = images[i].transpose(1, 2, 0)
        ax.imshow(img)
        ax.axis('off')
    plt.tight_layout()
    plt.show()


In [None]:
# Plotting reconstructed images when fed with training images
s = 100
with torch.no_grad():
    x = sample_real_butterfly(s).to(device)
    mu_x, log_sigma_squared_x = E(x)
    sigma_squared_x = torch.exp(log_sigma_squared_x)
    sigma_x = torch.sqrt(sigma_squared_x)
    z = (reparam_Gaussian(mu_x, sigma_x, 1)).to(device)
    z = z.squeeze(1).unsqueeze(-1).unsqueeze(-1)
    y = D(z)

# Assuming both sets have the same number of images
num_images = min(x.size(0), y.size(0))
images_0 = x.detach().cpu().numpy()
images_1 = y.detach().cpu().numpy()
# Create a figure for the grid of images
fig, axes = plt.subplots(nrows=num_images//10, ncols=20, figsize=(30, 15))

# Loop through the images and display them side by side
for i in range(num_images):
    # Show the first set of images
    img_0 = images_0[i].transpose(1, 2, 0)
    axes[i // 10, (i % 10) * 2].imshow(img_0)
    axes[i // 10, (i % 10) * 2].axis('off')
    # Show the second set of images
    img_1 = images_1[i].transpose(1, 2, 0)
    axes[i // 10, (i % 10) * 2 + 1].imshow(img_1)
    axes[i // 10, (i % 10) * 2 + 1].axis('off')

plt.tight_layout()
plt.show()


**Some other techniques were tried**
1. KL cost annealing https://arxiv.org/pdf/1511.06349 Section 3.1
2. Cyclical KL Annealing Schedule https://aclanthology.org/N19-1021.pdf





**[Q2] CNN-based Classifier**

In [None]:
class Classifier(nn.Module):
    def __init__(self, architecture):
        # architecture = [cnn_architecture, nn_architecute]
        # nn_architecture = [nn_layers, nn_batch_norm , nn_activation]
        # cnn_architecture = [conv_params, cnn_batchnorm, cnn_activation_fn]
        # conv_params = [[in_channels, output_channels, kernel_size, stride, padding], ...]
        # nn/cnn_batchnorm = [False, True, ...]
        # nn/cnn_activation = [nn.ReLU(), nn.LeakyReLU(), ...]

        super(Classifier, self).__init__()
        assert len(architecture[0][0]) == len(architecture[0][1]) == len(architecture[0][2])
        assert len(architecture[1][0]) == len(architecture[1][1]) == len(architecture[1][2])
        self.conv_params = architecture[0][0]
        self.cnn_batchnorm = architecture[0][1]
        self.cnn_activation = architecture[0][2]
        self.nn_layers = architecture[1][0]
        self.nn_batchnorm = architecture[1][1]
        self.nn_activation = architecture[1][2]

        layers_0 = []
        for j,i in enumerate(self.conv_params):
            (layers_0).append(nn.Conv2d(in_channels=i[0],out_channels=i[1],kernel_size=i[2],stride=i[3],padding=i[4]))
            if (self.cnn_batchnorm)[j]:
                (layers_0).append(nn.BatchNorm2d(i[1]))
            (layers_0).append(self.cnn_activation[j])

        layers_1 = []
        for i in range(len(self.nn_layers)):
            in_dim, out_dim = self.nn_layers[i]
            (layers_1).append(nn.Linear(in_dim, out_dim))
            if self.nn_batchnorm[i]:
                (layers_1).append(nn.BatchNorm1d(out_dim))
            (layers_1).append(self.nn_activation[i])

        # Stack convolutional layers in a Sequential block
        # Create 1 FFNN
        self.cnn = nn.Sequential(*(layers_0))
        self.nn = nn.Sequential(*(layers_1))

    def forward(self, x):
        features = self.cnn(x)
        features = torch.flatten(features, start_dim=1)
        logits = self.nn(features)
        return logits

    # We will use this to prevent calculating gradients wrt parameters of Classifier
    def set_requires_grad(self, requires_grad):
        for param in self.parameters():
            param.requires_grad = requires_grad


In [None]:
# Formula for CNN convolutions
# output_size = ({input_size - kernel_size + 2*padding}/stride) + 1

# TRIAL 1
conv_params_3 = [[3,16,4,2,1],[16,16,4,2,1],[16,32,4,2,1],[32,32,4,2,1]]
# CONVERT from [3,128,128] = 49,152 to [32,8,8] = 2048
cnn_batchnorm_3 = [True for _ in range(4)]
cnn_activation_3 = [nn.LeakyReLU(0.2, inplace=False) for _ in range(4)]
classifier_cnn = [conv_params_3, cnn_batchnorm_3, cnn_activation_3]
nn_layers_3 = [[2048,75]]
nn_batchnorm_3 = [True for _ in range(0)] + [False]
nn_activation_3 = [nn.ReLU() for _ in range(0)] + [nn.Identity()]
# We will use soft-max while doing the CE loss
classifier_nn = [nn_layers_3, nn_batchnorm_3, nn_activation_3]
classifier_arch = [classifier_cnn , classifier_nn]


In [None]:
Classifier_check = Classifier(classifier_arch)
r1 = torch.randn((7,3,128,128))
r2 = Classifier_check(r1)
print(r2.size())
TP1 = sum(p.numel() for p in Classifier_check.parameters())
print(TP1)


In [None]:
# Initialize model, loss function, and optimizer
lr_CNN, m = 1e-4, 64
C = Classifier(classifier_arch).to(device)
CE_loss = []
C_params = list(C.parameters())
C_optimizer = optim.Adam(C_params, lr=lr_CNN, betas=(0.8, 0.98), eps=1e-8, weight_decay=0)

# Since butterfly images test set has no labels we create a random split over training
all_indices = set(range(6499))
train_split = set(random.sample(range(6499), 5000))
test_split = list(all_indices - train_split)
train_split = list(train_split)


In [None]:
# CNN training algorithm
num_epochs = 10
mini_batch_epochs = 77
criterion = nn.CrossEntropyLoss()
# Training loop
for epoch in range(num_epochs):
    for _ in range(mini_batch_epochs):
        C.set_requires_grad(True)
        C_optimizer.zero_grad()

        j_indices = torch.randint(0,5000, (m,)).tolist()
        i_indices = [train_split[t] for t in j_indices]
        sampled_images = []
        Labels = []
        for i in i_indices:
            sampled_images.append(butterfly_images[i])
            Labels.append(butterfly_labels[i])

        x = torch.stack(sampled_images).to(device)
        Labels = torch.tensor(Labels).to(device)
        y = C(x)
        loss = criterion(y,Labels)
        loss.backward()
        C_optimizer.step()
        C_optimizer.zero_grad()

    CE_loss.append(loss.item())


In [None]:
# Plot the losses
plt.plot(CE_loss, label='Cross-Entropy Loss')
# Adding titles and labels
plt.title('Cross-Entropy Loss Over Time')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
# Show the plot
plt.show()


In [None]:
# Determine the final classification accuracy on TRAINING SET
batch_size = 64
Correct = 0

# We iterate through in batches
# Due to BN layers bs must be > 1
for i in range(0, 5000, batch_size):
    with torch.no_grad():
        batch_indices = train_split[i:i+batch_size]
        # Create batches of butterfly_images and labels
        batch_images = butterfly_images[batch_indices].to(device)
        batch_labels = butterfly_labels[batch_indices].to(device)
        # Classification using the images
        y = C(batch_images)
        # Get predicted classes
        predictions = torch.argmax(y, dim=1)
        # Check how many predictions match the ground truth labels
        Correct += (predictions == batch_labels).sum().item()

# Final accuracy calculation
accuracy = Correct * 100 / 5000
print("Final Accuracy of CNN on TRAIN:", accuracy)


In [None]:
# Determine the final classification accuracy on TEST SET
batch_size = 64
Correct = 0

# We iterate through in batches
# Due to BN layers bs must be > 1
for i in range(0, 1499, batch_size):
    with torch.no_grad():
        batch_indices = test_split[i:i+batch_size]
        # Create batches of butterfly_images and labels
        batch_images = butterfly_images[batch_indices].to(device)
        batch_labels = butterfly_labels[batch_indices].to(device)
        # Classification using the images
        y = C(batch_images)
        # Get predicted classes
        predictions = torch.argmax(y, dim=1)
        # Check how many predictions match the ground truth labels
        Correct += (predictions == batch_labels).sum().item()

# Final accuracy calculation
accuracy = Correct * 100 / 1499
print("Final Accuracy of CNN on TEST:", accuracy)


**Accuracy of Classifier directly on the Butterfly Images**

* After 25 epochs: Final Accuracy of CNN on TRAIN: 97.82%
* After 25 epochs: Final Accuracy of CNN on TEST: 56.43%
* #CNN parameters = 183,403







**[Q3] Posterior Inference and MLP classification using Latents**

In [None]:
class Classifier_MLP(nn.Module):
    def __init__(self, architecture):
        # architecture = nn_architecture
        # nn_architecture = [nn_layers, nn_batch_norm , nn_activation]
        # nn_batchnorm = [False, True, ...]
        # nn_activation = [nn.ReLU(), nn.LeakyReLU(), ...]

        super(Classifier_MLP, self).__init__()
        assert len(architecture[0]) == len(architecture[1]) == len(architecture[2])
        self.nn_layers = architecture[0]
        self.nn_batchnorm = architecture[1]
        self.nn_activation = architecture[2]

        layers_1 = []
        for i in range(len(self.nn_layers)):
            in_dim, out_dim = self.nn_layers[i]
            (layers_1).append(nn.Linear(in_dim, out_dim))
            if self.nn_batchnorm[i]:
                (layers_1).append(nn.BatchNorm1d(out_dim))
            (layers_1).append(self.nn_activation[i])

        # Create 1 FFNN
        self.nn = nn.Sequential(*(layers_1))

    def forward(self, x):
        return (self.nn(x))

    # We will use this to prevent calculating gradients wrt parameters of Classifier
    def set_requires_grad(self, requires_grad):
        for param in self.parameters():
            param.requires_grad = requires_grad


In [None]:
d_latent = 128

#TRIAL_ 1 [43,211]
nn_layers_4 = [[d_latent,d_latent],[d_latent,d_latent],[d_latent,75]]
nn_batchnorm_4 = [True for _ in range(2)] + [False]
nn_activation_4 = [nn.ReLU() for _ in range(2)] + [nn.Identity()]
# We will use soft-max while doing the CE loss
classifier_arch = [nn_layers_4, nn_batchnorm_4, nn_activation_4]


In [None]:
Classifier_check = Classifier_MLP(classifier_arch)
r1 = torch.randn((7,d_latent))
r2 = Classifier_check(r1)
print(r2.size())
TP1 = sum(p.numel() for p in Classifier_check.parameters())
print(TP1)


In [None]:
# Initialize model, loss function, and optimizer
lr_NN, m = 1e-4, 64
C = Classifier_MLP(classifier_arch).to(device)
CE_loss = []
C_params = list(C.parameters())
C_optimizer = optim.Adam(C_params, lr=lr_NN, betas=(0.8, 0.98), eps=1e-8, weight_decay=0)
E_c = torch.load("/home/sahapthank/saha_adrl/E_1f_New_400.pth").to(device)

# Since butterfly images test set has no labels we create a random split over training
all_indices = set(range(6499))
train_split = set(random.sample(range(6499), 5000))
test_split = list(all_indices - train_split)
train_split = list(train_split)


In [None]:
# We train NN on z which are latents of the Encoder
num_epochs = 10
mini_batch_epochs = 77
criterion = nn.CrossEntropyLoss()
# Training loop
for epoch in range(num_epochs):
    for _ in range(mini_batch_epochs):
        E_c.set_requires_grad(False)
        C.set_requires_grad(True)
        C_optimizer.zero_grad()

        j_indices = torch.randint(0,5000, (m,)).tolist()
        i_indices = [train_split[t] for t in j_indices]
        sampled_images = []
        Labels = []
        for i in i_indices:
            sampled_images.append(butterfly_images[i])
            Labels.append(butterfly_labels[i])

        x = torch.stack(sampled_images).to(device)
        Labels = torch.tensor(Labels).to(device)
        mu_x, log_sigma_squared_x = E_c(x)
        sigma_squared_x = torch.exp(log_sigma_squared_x)
        sigma_x = torch.sqrt(sigma_squared_x)
        z = (reparam_Gaussian(mu_x, sigma_x, 1)).squeeze(1).to(device)
        y = C(z)

        loss = criterion(y,Labels)
        loss.backward()
        C_optimizer.step()
        C_optimizer.zero_grad()

    CE_loss.append(loss.item())


In [None]:
# Plot the losses
plt.plot(CE_loss, label='Cross-Entropy Loss')
# Adding titles and labels
plt.title('Cross-Entropy Loss Over Time')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
# Show the plot
plt.show()


In [None]:
# Determine the final classification accuracy on TRAINING SET
batch_size = 64
Correct = 0

# We iterate through in batches
# Due to BN layers bs must be > 1
for i in range(0, 5000, batch_size):
    with torch.no_grad():
        # Create batches of butterfly_images and labels
        batch_indices = train_split[i:i+batch_size]
        batch_images = butterfly_images[batch_indices].to(device)
        batch_labels = butterfly_labels[batch_indices].to(device)

        # Forward pass through Encoder
        mu_x, log_sigma_squared_x = E_c(batch_images)
        sigma_squared_x = torch.exp(log_sigma_squared_x)
        sigma_x = torch.sqrt(sigma_squared_x)
        z = (reparam_Gaussian(mu_x, sigma_x, 1)).squeeze(1).to(device)
        y = C(z)

        # Get predicted classes
        predictions = torch.argmax(y, dim=1)
        # Check how many predictions match the ground truth labels
        Correct += (predictions == batch_labels).sum().item()

# Final accuracy calculation
accuracy = Correct * 100 / 5000
print("Final Accuracy of MLP based on VAE Latents on TRAIN:", accuracy)


In [None]:
# Determine the final classification accuracy on TEST SET
batch_size = 64
Correct = 0

# We iterate through in batches
# Due to BN layers batch_size must be > 1
for i in range(0, 1499, batch_size):
    with torch.no_grad():
        batch_indices = test_split[i:i+batch_size]
        # Create batches of butterfly_images and labels
        batch_images = butterfly_images[batch_indices].to(device)
        batch_labels = butterfly_labels[batch_indices].to(device)

        # Forward pass through Encoder
        mu_x, log_sigma_squared_x = E_c(batch_images)
        sigma_squared_x = torch.exp(log_sigma_squared_x)
        sigma_x = torch.sqrt(sigma_squared_x)
        z = (reparam_Gaussian(mu_x, sigma_x, 1)).squeeze(1).to(device)
        y = C(z)

        # Get predicted classes
        predictions = torch.argmax(y, dim=1)
        # Check how many predictions match the ground truth labels
        Correct += (predictions == batch_labels).sum().item()

# Final accuracy calculation
accuracy = Correct * 100 / 1499
print("Final Accuracy of MLP based on VAE Latents on TEST:", accuracy)


**Accuracy of a simple MLP [3-layers] on VAE Latents**

The VAE latents were not that good. But even with 1/4 parameters of the CNN , accuracy is reasonable highlighting the ability of VAE to compress information.
* Final Accuracy of MLP based on VAE Latents on TRAIN: 62.48%
* Final Accuracy of MLP based on VAE Latents on TEST: 32.68%
* #MLP parameters  = 43,211






**[Q7 and Q8] VQ-VAE Implementation**

**Some implementation Observations**
* In VQ-VAE training we vaey the architecture and the beta value.
* Also as expected uniform sampling for generation does not give any good images! Fitting GMM improved it but still image quality is bad [this might be due to fitting using only diagonal covariances]. Fitting more gaussian components on the entire latents increased recognisable images but it seems the entire space is not mapped.
* This could be due to using lot of embedding vectors and small encoer having only 40k parameters. So we change K,D and experiment.








In [None]:
# VQ-VAE is the SoTa lets see if we can do it properly!
# Learning useful representations without supervision remains a key challenge in ML
# (VQ-VAE) differs from VAEs in two key ways:
# 1) encoder network outputs discrete, rather than continuous, codes; and the prior
# 2) is learnt rather than static.
# Using the VQ method allows the
# model to circumvent issues of “posterior collapse” — where the latents are ignored
# when they are paired with a powerful autoregressive decoder — typically observed in the VAE framework.

# For speech discrete encodings k are 1D and each z_q corresponds to e_q in R^d
# For images (d1,d2) 2D encoding where each (vector along pixel) gets mapped to an embedding vector
# There exist K embedding vectors!
# Basically a CNN is used to convert to z_e(x) = (D,d1,d2) and each {z_e(x)}i,j as a D-dim vectors
# We use nearest neighbours and map it to e_i in R^D!
# We now get z_q(x) = (D,d1,d2) which using a Transpose_CNN we get final image

# Note that there is no real gradient defined for discretisation using nearest neighbour
# however we approximate the gradient similar to the straight-through estimator
# and just copy gradients from decoder input z_q(x) to encoder output z_e(x)
# Since the output representation of the encoder and the input to the decoder share the same D dimensional space,
# the gradients contain useful information for how the encoder has to change its output

class VQ_VAE_Encoder_0(nn.Module):
    def __init__(self, architecture):
        # Has only Convolutions, architecture = [cnn_architecture]
        # cnn_architecture = [conv_params, cnn_batchnorm, cnn_activation_fn]
        # conv_params = [[in_channels, output_channels, kernel_size, stride, padding], ...]
        # cnn_batchnorm = [False, True, ...]
        # cnn_activation = [nn.ReLU(), nn.LeakyReLU(), ...]

        super(VQ_VAE_Encoder_0, self).__init__()
        assert len(architecture[0]) == len(architecture[1]) == len(architecture[2])
        self.conv_params = architecture[0]
        self.cnn_batchnorm = architecture[1]
        self.cnn_activation = architecture[2]

        layers_0 = []
        for j,i in enumerate(self.conv_params):
            (layers_0).append(nn.Conv2d(in_channels=i[0],out_channels=i[1],kernel_size=i[2],stride=i[3],padding=i[4]))
            if (self.cnn_batchnorm)[j]:
                (layers_0).append(nn.BatchNorm2d(i[1]))
            (layers_0).append(self.cnn_activation[j])

        # Stack convolutional layers in a Sequential block
        self.cnn = nn.Sequential(*(layers_0))

    def forward(self, x):
        z_e = self.cnn(x)
        return z_e

    # We will use this to prevent calculating gradients wrt parameters of Encoder
    def set_requires_grad(self, requires_grad):
        for param in self.parameters():
            param.requires_grad = requires_grad

class VQ_VAE_Decoder_0(nn.Module):
    def __init__(self, architecture):
        # The architecture is almost same as DCGAN using transpose_convolutions only

        super(VQ_VAE_Decoder_0, self).__init__()
        assert len(architecture[0]) == len(architecture[1]) == len(architecture[2])
        self.transpose_conv_params = architecture[0]
        self.use_batchnorm = architecture[1]
        self.activation_fn = architecture[2]

        layers = []
        # Starting with input latent vector of 2D image!
        for j, i in enumerate(self.transpose_conv_params):
            layers.append(nn.ConvTranspose2d(in_channels=i[0], out_channels=i[1], kernel_size=i[2], stride=i[3], padding=i[4]))
            if self.use_batchnorm[j]:
                layers.append(nn.BatchNorm2d(i[1]))
            layers.append(self.activation_fn[j])
        self.model = nn.Sequential(*layers)

    # Sampling images from Decoder
    def forward(self, x):
        return self.model(x)

    # We will use this to prevent calculating gradients wrt parameters of Decoder
    def set_requires_grad(self, requires_grad):
        for param in self.parameters():
            param.requires_grad = requires_grad


In [None]:
# Since we assume a uniform prior for z, the KL term that usually appears in the ELBO is constant
# w.r.t. the encoder parameters and can thus be ignored for training
# Note that in VAE KL is only used for encoder training!
# VQVAE has 3 loss terms for the reconstruction
# A = logp(x|z_q(x)) := reconstruction {Encoder and Decoder}
# B = l2 ((sg(z_e(x)),e)) := VQ Objective {Embedding Vectors only}
# C =  beta * l2 ((z_e(x),sg(e))) := CommitmentLoss {Encoder only}

# When we use N = k1*k2 dicrete latents (In paper they use (32,32) for ImageNet)
# The loss terms are averaged for each latent (1/N factor)
# One term for each latent arises in {B,C}
# Whilst training the VQ-VAE, the prior is kept constant and uniform.
# Each region of the image (represented by a latent code) has an equal probability
# of being encoded by any of the embedding vectors from the codebook.

class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_loss_beta):
        super(VectorQuantizer, self).__init__()
        # Codebook containing K embedding vectors
        # Each vector is of dim D
        # D would be the final #output channels in CNN
        # Underscore is used for inplace operation
        # Initialize with unform weights from [-1/K,1/K]
        self.K = num_embeddings
        self.D = embedding_dim
        self.embedding = nn.Embedding(self.K, self.D) #(K,D)
        self.embedding.weight.data.uniform_(-1/self.K, 1/self.K)
        self.beta = commitment_loss_beta

    def forward(self, inputs):
        # convert inputs z_e(x) from (bs,c,h,w) to (bs,h,w,c)
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        input_shape = inputs.shape
        # Flatten input to (bs*h*w,c=D)
        flat_input = inputs.view(-1, self.D)
        # Calculate distances of {each pixel_vector with each embedding vector}
        # Final shape is (bs*h*w = N,K)
        distances = (torch.sum(flat_input**2, dim=1, keepdim=True)
                    + (torch.sum((self.embedding.weight)**2, dim=1))
                    - (2 * torch.matmul(flat_input, self.embedding.weight.t())))

        # Encoding
        # Using one-hot encodings allows gradients to flow correctly during backpropagation.
        # If we directly assign embeddings based on indices
        # the computation graph might not accurately reflect the necessary operations for gradient calculation
        # especially when dealing with operations like straight-through estimators
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1) #(N,1)
        encodings = torch.zeros(encoding_indices.shape[0], self.K, device=inputs.device)
        # Keep shape same but put ones in the appropriate positions and 0-elsewhere
        encodings.scatter_(1,encoding_indices,1) #(N,K)

        # Quantize to (N,D) and unflatten to (bs,h,w,c)
        quantized = torch.matmul(encodings, self.embedding.weight).view(input_shape)

        # Calculate the Loss associated
        # Averaging over {batch and EmbeddingVectors is done by default}
        # MSE btw {z_e(x),sg(e)}
        commitment_loss = F.mse_loss(quantized.detach(),inputs)
        # MSE btw {sg(z_e(x)),e}
        vq_objective = F.mse_loss(quantized,inputs.detach())
        loss_B_C = (vq_objective) + (self.beta * commitment_loss)

        # Trick for straight-through estimation
        quantized = inputs + (quantized - inputs).detach()

        # convert quantized back to (bs,c,h,w) and return (Loss,z_q(x)),B_loss,C_loss)
        return (loss_B_C, quantized.permute(0, 3, 1, 2).contiguous(),commitment_loss.detach(),vq_objective.detach())

    # We will use this to prevent calculating gradients wrt parameters of VQ
    def set_requires_grad(self, requires_grad):
        for param in self.parameters():
            param.requires_grad = requires_grad


In [None]:
# Formula for CNN convolutions
# output_size = ({input_size - kernel_size + 2*padding}/stride) + 1
# Formula for Transpose Conv
# output_size = {(input_size - 1) * stride} - (2*padding) + kernel_size}

# # TRIAL 1 [E = 43,488 V = 32,768 D = 853,315]
# conv_params_2 = [[3,16,4,2,1],[16,16,1,1,0],[16,32,4,2,1],[32,32,1,1,0],[32,64,4,2,1]]
# # CONVERT TO [64,16,16] so {D = 64}
# cnn_batchnorm_2 = [True for _ in range(5)]
# cnn_activation_2 = [nn.LeakyReLU(0.2, inplace=False) for _ in range(5)]
# vq_E_arch_2 = [conv_params_2, cnn_batchnorm_2, cnn_activation_2]
# transpose_conv_params_2 = [[64,256,4,2,1],[256,128,4,2,1],[128,32,4,2,1],[32,3,1,1,0]]
# use_batchnorm_2 = [True for _ in range(3)] + [False]
# activation_fn_2 = [nn.ReLU(inplace = False) for _ in range(3)] + [nn.Sigmoid()]
# vq_D_arch_2 = [transpose_conv_params_2, use_batchnorm_2, activation_fn_2]

# TRIAL 2 [E = 47,776 V = 32,768 D = 937,859]
# Somebody please add this using different architectures to try...it is boring!
conv_params_2 = [[3,16,4,2,1],[16,16,1,1,0],[16,32,4,2,1],[32,32,1,1,0],[32,64,4,2,1],[64,64,1,1,0]]
# CONVERT TO [64,16,16] so D = 64
cnn_batchnorm_2 = [True for _ in range(6)]
cnn_activation_2 = [nn.LeakyReLU(0.2, inplace=False) for _ in range(6)]
vq_E_arch_2 = [conv_params_2, cnn_batchnorm_2, cnn_activation_2]
transpose_conv_params_2 = [[64,256,4,2,1],[256,128,4,2,1],[128,128,1,1,0],[128,64,4,2,1],[64,32,1,1,0],[32,3,1,1,0]]
use_batchnorm_2 = [True for _ in range(5)] + [False]
activation_fn_2 = [nn.ReLU(inplace = False) for _ in range(5)] + [nn.Sigmoid()]
vq_D_arch_2 = [transpose_conv_params_2, use_batchnorm_2, activation_fn_2]


In [None]:
# Rough checking if code so far behaves as expected
E_check = VQ_VAE_Encoder_0(vq_E_arch_2)
D_check = VQ_VAE_Decoder_0(vq_D_arch_2)
V_check = VectorQuantizer(512,64,2)
r1 = torch.randn((7,3,128,128))
r2 = E_check(r1)
print(r2.size())
r4 = torch.randn((7,64,16,16))
r5 = D_check(r4)
print(r5.size())
print(V_check(r4)[1].size())
TP1 = sum(p.numel() for p in E_check.parameters())
TP2 = sum(p.numel() for p in V_check.parameters())
TP3 = sum(p.numel() for p in D_check.parameters())
print(TP1,TP2,TP3)


In [None]:
# Initialize model, loss function, and optimizer
lr_vq_VAE, m = 1e-4, 64
beta = 2
E_arch = vq_E_arch_2
D_arch = vq_D_arch_2
E = VQ_VAE_Encoder_0(E_arch).to(device)
D = VQ_VAE_Decoder_0(D_arch).to(device)
vq = VectorQuantizer(512,64,beta).to(device)
A_loss = []
B_loss = []
C_loss = []
vq_loss = []
vae_params = list(E.parameters()) + list(D.parameters()) + list(vq.parameters())
vae_optimizer = optim.Adam(vae_params, lr=lr_vq_VAE, betas=(0.8, 0.98), eps=1e-8, weight_decay=0)
# vae_optimizer = optim.RMSprop(vae_params, lr=lr_VAE, alpha=0.99, eps=1e-8, weight_decay=0)
# vae_optimizer = optim.SGD(vae_params, lr=lr_VAE, momentum=0.9, weight_decay=0)


In [None]:
# For manual changing during training
lr_vq_VAE, m = 1e-4, 64
vq.beta = 2
vae_optimizer = optim.Adam(vae_params, lr=lr_vq_VAE, betas=(0.8, 0.98), eps=1e-8, weight_decay=0)
# vae_optimizer = optim.RMSprop(vae_params, lr=lr_VAE, alpha=0.99, eps=1e-8, weight_decay=0)
# vae_optimizer = optim.SGD(vae_params, lr=lr_VAE, momentum=0.9, weight_decay=0)


In [None]:
# VQ VAE training algorithm
num_epochs = 900
mini_batch_epochs = 100
# Training loop
for epoch in range(num_epochs):
    for _ in range(mini_batch_epochs):
        # start = time.time()
        D.set_requires_grad(True)
        E.set_requires_grad(True)
        vq.set_requires_grad(True)
        vae_optimizer.zero_grad()

        x = sample_real_butterfly(m).to(device)
        z_e = E(x)
        loss_B_C , z_q , C, B = vq(z_e)
        y = D(z_q)
        reconstruction_loss = (F.mse_loss(x,y))
        # Averaging over {batch+dimension} done by default

        loss_vq_VAE = (loss_B_C + reconstruction_loss)
        loss_vq_VAE.backward()
        vae_optimizer.step()
        vae_optimizer.zero_grad()

    A_loss.append(reconstruction_loss.detach().cpu())
    B_loss.append(B.cpu())
    C_loss.append(C.cpu())
    vq_loss.append(loss_vq_VAE.detach().cpu())


In [None]:
# Create a figure with three subplots (1 row, 3 columns)
fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(18, 6))

# Plot A_loss in the first subplot
A_steps = range(len(A_loss))
ax1.plot(A_steps, A_loss, label='A_loss', color='blue')
ax1.set_xlabel('Epochs')
ax1.set_ylabel('A')
ax1.set_title('Reconstruction_Loss')
ax1.legend()
# Plot B_loss in the second subplot
B_steps = range(len(B_loss))
ax2.plot(B_steps, B_loss, label='B_loss', color='orange')
ax2.set_xlabel('Epochs')
ax2.set_ylabel('B')
ax2.set_title('VQ_Objective')
ax2.legend()
# Plot C_loss in the third subplot
vae_steps = range(len(C_loss))
ax3.plot(vae_steps, C_loss, label='C_loss', color='green')
ax3.set_xlabel('Epochs')
ax3.set_ylabel('C')
ax3.set_title('Commitment_Loss')
ax3.legend()
# Plot vq_loss in the fourth subplot
vae_steps = range(len(vq_loss))
ax4.plot(vae_steps, vq_loss, label='vq_loss', color='black')
ax4.set_xlabel('Epochs')
ax4.set_ylabel('Overall')
ax4.set_title('vq_Loss')
ax4.legend()

# Display the plots
plt.tight_layout()
plt.show()


In [None]:
# Plotting reconstructed images when fed with training images
s = 100
with torch.no_grad():
    x = sample_real_butterfly(s).to(device)
    z_e = E(x)
    z_q = vq(z_e)[1]
    y = D(z_q)

# Assuming both sets have the same number of images
num_images = min(x.size(0), y.size(0))
images_0 = x.detach().cpu().numpy()
images_1 = y.detach().cpu().numpy()
# Create a figure for the grid of images
fig, axes = plt.subplots(nrows=num_images//10, ncols=20, figsize=(30, 15))

# Loop through the images and display them side by side
for i in range(num_images):
    # Show the first set of images
    img_0 = images_0[i].transpose(1, 2, 0)
    axes[i // 10, (i % 10) * 2].imshow(img_0)
    axes[i // 10, (i % 10) * 2].axis('off')
    # Show the second set of images
    img_1 = images_1[i].transpose(1, 2, 0)
    axes[i // 10, (i % 10) * 2 + 1].imshow(img_1)
    axes[i // 10, (i % 10) * 2 + 1].axis('off')

plt.tight_layout()
plt.show()


In [None]:
# Inference M1 using random [assuming uniform prior]
s = 100
K = 512
with torch.no_grad():
    shape = (s,16,16,64)
    random_indices = torch.randint(0,K,(s*16*16,1)).to(device) #(N,1)
    # Now we need to map to embedding vectors
    # We use the previous logic of matmul
    random_encodings = torch.zeros(random_indices.shape[0],K,device=device)
    random_encodings.scatter_(1,random_indices,1) #(N,K)
    random_quantized = torch.matmul(random_encodings, vq.embedding.weight).view(shape)
    random_quantized = random_quantized.permute(0, 3, 1, 2).contiguous()
    # print(random_quantized.shape)
    generated_images = D(random_quantized)

    images = generated_images.detach().cpu().numpy()
    # Create a figure for the grid of images
    fig, axes = plt.subplots(nrows=10, ncols=10, figsize=(15, 15))
    # Loop through the 100 images and display them in the grid
    for i, ax in enumerate(axes.flat):
        img = images[i].transpose(1, 2, 0)
        ax.imshow(img)
        ax.axis('off')
    plt.tight_layout()
    plt.show()


In [None]:
# Inference M2 using GMM [Fit GMM on entire latent space of real_images]
# We treat latent vectors coming from GMM to be in space of dimension [64*16*16]
# E = torch.load("A2_weights/E_2a_100.pth").to(device)
# D = torch.load("A2_weights/D_2a_100.pth").to(device)
# vq = torch.load("A2_weights/vq_2a_100.pth").to(device)
all_latent_vectors = torch.zeros(6499,64,16,16)
batch_size = 64
for i in range(0, 6499, batch_size):
    with torch.no_grad():
        x = butterfly_images[i:min(i+batch_size,6499)].to(device)
        z_e = E(x)
        z_q = vq(z_e)[1]
        # Calculate the end index for slicing
        end_index = min(i + batch_size, 6499)
        # Insert the batch of latents into the correct slice of T
        all_latent_vectors[i:end_index] = z_q
all_latent_vectors = all_latent_vectors.view(-1,64*16*16)


In [None]:
from sklearn.mixture import GaussianMixture
# Fit the GMM and generate images
gmm = GaussianMixture(n_components=256, covariance_type='diag',max_iter=100,verbose=2, verbose_interval=1)
gmm.fit(all_latent_vectors.cpu().numpy())
# Access convergence results
print("Converged:", gmm.converged_)
print("Number of iterations:", gmm.n_iter_)
print("Log-Likelihood:", gmm.score(all_latent_vectors.cpu().numpy()))


In [None]:
s = 100
random_quantized, _ = gmm.sample(s)
random_quantized = (torch.from_numpy(random_quantized.reshape(-1, 64, 16, 16)).float()).to(device)
with torch.no_grad():
    generated_images = D(random_quantized)
    images = generated_images.detach().cpu().numpy()
    # Create a figure for the grid of images
    fig, axes = plt.subplots(nrows=10, ncols=10, figsize=(15, 15))
    # Loop through the 100 images and display them in the grid
    for i, ax in enumerate(axes.flat):
        img = images[i].transpose(1, 2, 0)
        ax.imshow(img)
        ax.axis('off')
    plt.tight_layout()
    plt.show()


**Classifying the images using the learned latents**

In [None]:
# A [CNN + NN] for classification
class Classifier_0(nn.Module):
    def __init__(self, architecture):
        # architecture = [cnn_architecture, nn_architecute]
        # nn_architecture = [nn_layers, nn_batch_norm , nn_activation]
        # cnn_architecture = [conv_params, cnn_batchnorm, cnn_activation_fn]
        # conv_params = [[in_channels, output_channels, kernel_size, stride, padding], ...]
        # nn/cnn_batchnorm = [False, True, ...]
        # nn/cnn_activation = [nn.ReLU(), nn.LeakyReLU(), ...]

        super(Classifier_0, self).__init__()
        assert len(architecture[0][0]) == len(architecture[0][1]) == len(architecture[0][2])
        assert len(architecture[1][0]) == len(architecture[1][1]) == len(architecture[1][2])
        self.conv_params = architecture[0][0]
        self.cnn_batchnorm = architecture[0][1]
        self.cnn_activation = architecture[0][2]
        self.nn_layers = architecture[1][0]
        self.nn_batchnorm = architecture[1][1]
        self.nn_activation = architecture[1][2]

        layers_0 = []
        for j,i in enumerate(self.conv_params):
            (layers_0).append(nn.Conv2d(in_channels=i[0],out_channels=i[1],kernel_size=i[2],stride=i[3],padding=i[4]))
            if (self.cnn_batchnorm)[j]:
                (layers_0).append(nn.BatchNorm2d(i[1]))
            (layers_0).append(self.cnn_activation[j])

        layers_1 = []
        for i in range(len(self.nn_layers)):
            in_dim, out_dim = self.nn_layers[i]
            (layers_1).append(nn.Linear(in_dim, out_dim))
            if self.nn_batchnorm[i]:
                (layers_1).append(nn.BatchNorm1d(out_dim))
            (layers_1).append(self.nn_activation[i])

        # Stack convolutional layers in a Sequential block
        # Create 1 FFNN
        self.cnn = nn.Sequential(*(layers_0))
        self.nn = nn.Sequential(*(layers_1))

    def forward(self, x):
        features = self.cnn(x)
        features = torch.flatten(features, start_dim=1)
        logits = self.nn(features)
        return logits

    # We will use this to prevent calculating gradients wrt parameters of Classifier_0
    def set_requires_grad(self, requires_grad):
        for param in self.parameters():
            param.requires_grad = requires_grad


In [None]:
# Formula for CNN convolutions
# output_size = ({input_size - kernel_size + 2*padding}/stride) + 1

# TRIAL 1
conv_params_4 = [[64,64,1,1,0],[64,64,4,2,1],[64,128,4,2,1]]
# CONVERT from [64,16,16] = 16,384 to [128,4,4] = 2048
cnn_batchnorm_4 = [True for _ in range(3)]
cnn_activation_4 = [nn.LeakyReLU(0.2, inplace=False) for _ in range(3)]
classifier_0_cnn = [conv_params_4, cnn_batchnorm_4, cnn_activation_4]
nn_layers_4 = [[2048,75]]
nn_batchnorm_4 = [True for _ in range(0)] + [False]
nn_activation_4 = [nn.ReLU() for _ in range(0)] + [nn.Identity()]
# We will use soft-max while doing the CE loss
classifier_0_nn = [nn_layers_4, nn_batchnorm_4, nn_activation_4]
classifier_0_arch = [classifier_0_cnn , classifier_0_nn]


In [None]:
Classifier_check = Classifier_0(classifier_0_arch)
r1 = torch.randn((7,64,16,16))
r2 = Classifier_check(r1)
print(r2.size())
TP1 = sum(p.numel() for p in Classifier_check.parameters())
print(TP1)


In [None]:
# Initialize model, loss function, and optimizer
lr_CNN, m = 1e-4, 64
C = Classifier_0(classifier_0_arch).to(device)
CE_loss = []
C_params = list(C.parameters())
C_optimizer = optim.Adam(C_params, lr=lr_CNN, betas=(0.8, 0.98), eps=1e-8, weight_decay=0)
E_c = torch.load("A2_weights/E_2b_700.pth").to(device)
D_c = torch.load("A2_weights/D_2b_700.pth").to(device)
vq_c = torch.load("A2_weights/vq_2b_700.pth").to(device)

# Since butterfly images test set has no labels we create a random split over training
all_indices = set(range(6499))
train_split = set(random.sample(range(6499), 5000))
test_split = list(all_indices - train_split)
train_split = list(train_split)


In [None]:
# We train CNN on z_q(x) of shape (64,16,16)/(1,16,16)
num_epochs = 10
mini_batch_epochs = 77
criterion = nn.CrossEntropyLoss()
# Training loop
for epoch in range(num_epochs):
    for _ in range(mini_batch_epochs):
        D_c.set_requires_grad(False)
        E_c.set_requires_grad(False)
        vq_c.set_requires_grad(False)
        C.set_requires_grad(True)
        C_optimizer.zero_grad()

        j_indices = torch.randint(0,5000, (m,)).tolist()
        i_indices = [train_split[t] for t in j_indices]
        sampled_images = []
        Labels = []
        for i in i_indices:
            sampled_images.append(butterfly_images[i])
            Labels.append(butterfly_labels[i])

        x = torch.stack(sampled_images).to(device)
        Labels = torch.tensor(Labels).to(device)
        z_e = E_c(x)
        z_q = vq_c(z_e)[1]
        y = C(z_q)

        loss = criterion(y,Labels)
        loss.backward()
        C_optimizer.step()
        C_optimizer.zero_grad()

    CE_loss.append(loss.item())


In [None]:
# Plot the losses
plt.plot(CE_loss, label='Cross-Entropy Loss')
# Adding titles and labels
plt.title('Cross-Entropy Loss Over Time')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
# Show the plot
plt.show()


In [None]:
# Determine the final classification accuracy on TRAINING SET
batch_size = 64
Correct = 0

# We iterate through in batches
# Due to BN layers bs must be > 1
for i in range(0, 5000, batch_size):
    with torch.no_grad():
        # Create batches of butterfly_images and labels
        batch_indices = train_split[i:i+batch_size]
        batch_images = butterfly_images[batch_indices].to(device)
        batch_labels = butterfly_labels[batch_indices].to(device)
        # Forward pass through the VQ-VAE encoder and quantizer
        z_e = E_c(batch_images)
        z_q = vq_c(z_e)[1]
        # Classification using the quantized latents
        y = C(z_q)
        # Get predicted classes
        predictions = torch.argmax(y, dim=1)
        # Check how many predictions match the ground truth labels
        Correct += (predictions == batch_labels).sum().item()

# Final accuracy calculation
accuracy = Correct * 100 / 5000
print("Final Accuracy of CNN based on VQ-VAE Latents:", accuracy)


In [None]:
# Determine the final classification accuracy on TEST SET
batch_size = 64
Correct = 0

# We iterate through in batches
# Due to BN layers batch_size must be > 1
for i in range(0, 1499, batch_size):
    with torch.no_grad():
        batch_indices = test_split[i:i+batch_size]
        # Create batches of butterfly_images and labels
        batch_images = butterfly_images[batch_indices].to(device)
        batch_labels = butterfly_labels[batch_indices].to(device)
        z_e = E_c(batch_images)
        z_q = vq_c(z_e)[1]
        # Classification using the quantized latents
        y = C(z_q)
        # Get predicted classes
        predictions = torch.argmax(y, dim=1)
        # Check how many predictions match the ground truth labels
        Correct += (predictions == batch_labels).sum().item()

# Final accuracy calculation
accuracy = Correct * 100 / 1499
print("Final Accuracy of CNN on TEST:", accuracy)


**Classification Accuracy of CNN trained on VQ-VAE Latents [64 x 16 x 16 as i/p]**

* After 25 epochs: Final Accuracy of CNN(VQ-VAE Latents) on TRAIN : 99.86
* After 25 epochs: Final Accuracy of CNN(VQ-VAE Latents) on TEST : 51.63
* #CNN Parameters : 355,147