In [None]:
import torch
import random
from torch.utils.data import Dataset
from torchvision import transforms
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
import torch.distributions as dist
import torch.nn.functional as F
from torchvision.datasets import ImageFolder
import torch.optim as optim
from tqdm import tqdm
import matplotlib.pyplot as plt
import torchvision.utils as vutils
import numpy as np
%matplotlib inline

# Question 1: GMM

In [None]:
class GMM:
    """Gaussian mixture model with full covariance matrices, fitted with EM.

    Parameters
    ----------
    num_components : int
        Number of Gaussian components K.
    num_iterations : int
        Maximum number of EM iterations performed by ``train``.
    tolerance : float
        ``train`` stops early once the change in mean log-likelihood between
        two consecutive iterations falls below this value.
    device : str or torch.device, default 'cpu'
        Device on which the parameters are stored and the EM steps run.
    reg_covar : float, default 1e-6
        Value added to the diagonal of every covariance matrix in the M-step
        to keep it positive definite.

    Attributes
    ----------
    means : Tensor of shape (K, D) or None before training
    covariances : Tensor of shape (K, D, D) or None before training
    mixing_coefficients : Tensor of shape (K,) or None before training
    likelihoods : list of float
        Mean log-likelihood recorded after every EM iteration of the last
        ``train`` call.
    """

    def __init__(self, num_components, num_iterations, tolerance, device='cpu', reg_covar=1e-6):
        self.num_components = num_components
        self.num_iterations = num_iterations
        self.tolerance = tolerance
        self.device = device
        self.reg_covar = reg_covar
        self.means = None
        self.covariances = None
        self.mixing_coefficients = None
        self.likelihoods = []

    def initialize_parameters(self, data):
        """Initialise means from random data rows, identity covariances, uniform weights."""
        num_samples, num_features = data.shape
        if num_samples >= self.num_components:
            # Distinct rows: two components never start from the same point.
            indices = torch.randperm(num_samples)[:self.num_components]
        else:
            indices = torch.randint(0, num_samples, (self.num_components,))
        self.means = data[indices].to(self.device)
        self.covariances = torch.eye(
            num_features, device=self.device, dtype=data.dtype
        ).repeat(self.num_components, 1, 1)
        self.mixing_coefficients = torch.full(
            (self.num_components,), 1.0 / self.num_components,
            device=self.device, dtype=data.dtype,
        )

    def _weighted_log_probs(self, data):
        """Return the (N, K) matrix ``log pi_k + log N(x_n | mu_k, Sigma_k)``."""
        data = data.to(self.device)
        # Clamp so that an exactly-zero weight gives a very negative number, not -inf/NaN.
        tiny = torch.finfo(self.mixing_coefficients.dtype).tiny
        log_pi = torch.log(self.mixing_coefficients.clamp_min(tiny))
        columns = []
        # One component at a time keeps peak memory at O(N * D) instead of O(N * K * D).
        for k in range(self.num_components):
            mvn = dist.MultivariateNormal(self.means[k], self.covariances[k])
            columns.append(log_pi[k] + mvn.log_prob(data))
        return torch.stack(columns, dim=1)

    def E_Step(self, data):
        """Return the (N, K) posterior responsibilities p(component k | x_n).

        The posterior is the softmax over k of ``log pi_k + log N(x_n | ...)``;
        each row sums to one.
        """
        return F.softmax(self._weighted_log_probs(data), dim=1)

    def M_Step(self, data, responsibilities):
        """Re-estimate weights, means and covariances from the responsibilities."""
        data = data.to(self.device)
        responsibilities = responsibilities.to(self.device)
        num_features = data.shape[1]

        self.mixing_coefficients = responsibilities.mean(dim=0)

        # Guard the divisions below against a component that received no mass.
        eps = torch.finfo(responsibilities.dtype).eps
        component_mass = responsibilities.sum(dim=0).clamp_min(eps)

        # (K, N) @ (N, D): weighted sums without building an N x K x D tensor.
        self.means = (responsibilities.t() @ data) / component_mass.unsqueeze(1)

        ridge = self.reg_covar * torch.eye(num_features, device=self.device, dtype=data.dtype)
        covariances = []
        for k in range(self.num_components):
            centered = data - self.means[k]
            weighted = centered * responsibilities[:, k].unsqueeze(1)
            cov = (weighted.t() @ centered) / component_mass[k]
            # Remove round-off asymmetry before adding the diagonal regulariser.
            cov = 0.5 * (cov + cov.t())
            covariances.append(cov + ridge)
        self.covariances = torch.stack(covariances)

    def compute_likelihood(self, data):
        """Return the mean per-sample log-likelihood ``mean_n log sum_k pi_k N(x_n | ...)``."""
        return torch.logsumexp(self._weighted_log_probs(data), dim=1).mean()

    def train(self, data):
        """Fit the model to ``data`` (N, D) with EM and record the likelihood curve."""
        data = data.to(self.device)
        self.initialize_parameters(data)
        self.likelihoods = []  # start a fresh curve on every call
        for iteration in range(self.num_iterations):
            responsibilities = self.E_Step(data)
            self.M_Step(data, responsibilities)
            log_likelihood = self.compute_likelihood(data)
            self.likelihoods.append(log_likelihood.item())
            if iteration > 0 and abs(self.likelihoods[-1] - self.likelihoods[-2]) < self.tolerance:
                break

    def generate_samples(self, num_samples=100):
        """Draw ``num_samples`` points from the fitted mixture.

        Components are chosen according to ``mixing_coefficients``; the result
        has shape (num_samples, D) and lives on the same device as the means.
        """
        num_features = self.means.shape[1]
        samples = torch.zeros(
            num_samples, num_features, device=self.means.device, dtype=self.means.dtype
        )
        if num_samples == 0:
            return samples

        components = torch.multinomial(self.mixing_coefficients, num_samples, replacement=True)
        for k in components.unique().tolist():
            mask = components == k
            mvn = dist.MultivariateNormal(self.means[k], self.covariances[k])
            samples[mask] = mvn.sample((int(mask.sum()),))
        return samples

    def plot_likelihood_curve(self):
        """Plot the mean log-likelihood recorded after each EM iteration."""
        plt.plot(range(len(self.likelihoods)), self.likelihoods)
        plt.xlabel("Iteration")
        plt.ylabel("Log Likelihood")
        plt.title("Likelihood Curve")
        plt.show()

    def visualize_samples(self, num_samples=100, image_shape=(3, 16, 16)):
        """Sample from the model and show the samples as a square grid of images.

        ``image_shape`` is the (channels, height, width) layout each flattened
        sample is reshaped to; pixel values are clipped to [0, 1] for display.
        """
        generated_samples = self.generate_samples(num_samples).detach().cpu()

        # Smallest square grid that holds every sample (100 samples -> 10 x 10).
        side = max(1, int(np.ceil(np.sqrt(num_samples))))
        fig, axarr = plt.subplots(side, side, figsize=(10, 10), squeeze=False)

        for index, ax in enumerate(axarr.flat):
            ax.axis('off')
            if index >= num_samples:
                continue  # unused cell in the last row
            sample = generated_samples[index].view(*image_shape).permute(1, 2, 0).numpy()
            ax.imshow(np.clip(sample, 0, 1))

        plt.suptitle(f"Generated Samples ({self.num_components} Components)")
        plt.show()

    def calculate_entropy(self, labels):
        """Return the Shannon entropy (in bits) of the empirical label distribution."""
        unique_labels, label_counts = torch.unique(labels, return_counts=True)
        probabilities = label_counts.float() / len(labels)
        entropy_value = -torch.sum(probabilities * torch.log2(probabilities))
        return entropy_value

    def calculate_mutual_information(self, class_labels, cluster_assignments):
        """Return the mutual information (in bits) between two non-negative integer labelings."""
        class_ids = class_labels.long()
        cluster_ids = cluster_assignments.long()
        num_classes = int(class_ids.max()) + 1
        num_clusters = int(cluster_ids.max()) + 1

        # Contingency table: entry (c, k) counts samples with class c in cluster k.
        joint_counts = torch.bincount(
            class_ids * num_clusters + cluster_ids,
            minlength=num_classes * num_clusters,
        ).view(num_classes, num_clusters).float()

        p_joint = joint_counts / joint_counts.sum()
        p_class = p_joint.sum(dim=1)
        p_cluster = p_joint.sum(dim=0)
        p_independent = torch.outer(p_class, p_cluster)

        # Empty cells contribute 0 to the sum; skipping them avoids log(0).
        occupied = p_joint > 0
        mutual_info = torch.sum(
            p_joint[occupied] * torch.log2(p_joint[occupied] / p_independent[occupied])
        )
        return mutual_info

    def calculate_nmi(self, class_labels, cluster_assignments):
        """Return NMI = 2 * I(class; cluster) / (H(class) + H(cluster))."""
        mutual_info = self.calculate_mutual_information(class_labels, cluster_assignments)

        entropy_class = self.calculate_entropy(class_labels)
        entropy_cluster = self.calculate_entropy(cluster_assignments)

        nmi = 2 * mutual_info / (entropy_class + entropy_cluster)
        return nmi




In [None]:
# Define the path to your dataset and other hyperparameters
data_dir_train = 'afhq/train/'
data_dir_val = 'afhq/val/'
shuffle_dataset = True
num_workers = 4

# 16x16 thumbnails keep the flattened feature vector (3*16*16 = 768) small
# enough for full-covariance Gaussians.
transform = transforms.Compose([
    transforms.Resize((16, 16)),
    transforms.ToTensor(),
])

train_dataset = ImageFolder(root=data_dir_train, transform=transform)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=shuffle_dataset,
    num_workers=num_workers
)

# Collect images AND labels from the same batches. The loader shuffles, so
# `train_dataset.targets` (dataset order) would not line up with the rows of `data`.
image_batches = []
label_batches = []
for images, labels in train_dataloader:
    image_batches.append(images.view(images.size(0), -1))  # Flatten each image
    label_batches.append(labels)

data = torch.cat(image_batches, dim=0)
true_labels = torch.cat(label_batches, dim=0)

num_mixture = [2, 3, 4, 8, 10, 12, 14, 16]
nmi = []

for num_components in num_mixture:
    print(f"Training GMM with {num_components} components...")

    # Create GMM model and train
    gmm_model = GMM(num_components=num_components, num_iterations=40, tolerance=1e-4)
    gmm_model.train(data)

    # Plot likelihood curve
    gmm_model.plot_likelihood_curve()

    # Generate and visualize samples:
    gmm_model.visualize_samples(100)

    # Hard cluster assignment = component with the largest responsibility
    cluster_assignments = gmm_model.E_Step(data).argmax(dim=1).cpu()

    # Calculate NMI; store a plain float so the list can be plotted directly
    nmi_score = float(gmm_model.calculate_nmi(true_labels, cluster_assignments))
    nmi.append(nmi_score)
    print(f"NMI for {num_components} components: {nmi_score}")


In [None]:
# NMI as a function of the number of mixture components.
# float() accepts both Python floats and 0-dim tensors stored in `nmi`.
plt.figure()
plt.plot(num_mixture, [float(score) for score in nmi], marker='o')
plt.xlabel('Number of mixture components')
plt.ylabel('NMI')
plt.title('NMI vs. number of GMM components')
plt.grid(True)
plt.show()

* A GMM class is defined, which has methods for initializing model parameters, performing the Expectation-Maximization (E-Step and M-Step) for training, computing the likelihood of data given the model, training the model, generating samples, and visualizing the samples.

* The code is set up to train the GMM on a dataset of images. It loads and preprocesses images using the transforms.Compose function, and creates a DataLoader for the training dataset. The images are flattened to create feature vectors for GMM training.

* The GMM is trained with different numbers of mixture components (num_mixture) ranging from 2 to 16. For each number of components, the GMM is trained, the likelihood curve is plotted, and samples are generated and visualized.

 * The code calculates the Normalized Mutual Information (NMI) between the true class labels and the cluster assignments obtained from the GMM. NMI is used to evaluate the quality of clustering.

* The NMI scores for the different numbers of mixture components are stored in the `nmi` list for further analysis. The likelihood curves and generated samples are shown above.

* The NMI scores are very low. Because of computational limitations the images are downsampled to 16 x 16; at this resolution the objects are not clearly distinguishable, and most images end up assigned to a single cluster.

* NMI scores for the mixture components [2, 3, 4, 8, 10, 12, 14, 16] are [0.00010338137508369982, 0.00023964836145751178, 0.00034696151851676404,0.0014466078719124198, 0.0012399664847180247, 8.245532080763951e-05,0.00047308814828284085, 8.962965512182564e-05]

* It was observed that when the image size is increased (for example to 32 x 32), the estimated covariance matrices may no longer be positive semi-definite.

# Question 2: Vanilla VAE

In [None]:
class VAE(nn.Module):
    """Convolutional variational autoencoder.

    The encoder halves the spatial resolution five times (stride-2 convolutions)
    and its flattened output is expected to have 64 * 4 * 4 features, which
    corresponds to 128 x 128 inputs. The decoder mirrors the encoder with
    transposed convolutions and ends in a sigmoid.
    """

    def __init__(self, input_channels, latent_dim):
        super().__init__()

        flat_features = 64 * 4 * 4

        # Encoder: Conv -> ReLU -> BatchNorm for each consecutive channel pair.
        encoder_widths = [input_channels, 4, 8, 16, 32, 64]
        encoder_layers = []
        for c_in, c_out in zip(encoder_widths[:-1], encoder_widths[1:]):
            encoder_layers.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            encoder_layers.append(nn.ReLU())
            encoder_layers.append(nn.BatchNorm2d(c_out))
        self.encoder = nn.Sequential(*encoder_layers)  # output shape (batch, 64, 4, 4)

        # Two linear heads produce the posterior mean and log-variance.
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_log_var = nn.Linear(flat_features, latent_dim)

        # Decoder: project the latent code back to (64, 4, 4), then upsample.
        decoder_layers = [
            nn.Linear(latent_dim, flat_features),
            nn.ReLU(),
            nn.BatchNorm1d(flat_features),
            nn.Unflatten(1, (64, 4, 4)),
        ]
        decoder_widths = [64, 32, 16, 8, 4]
        for c_in, c_out in zip(decoder_widths[:-1], decoder_widths[1:]):
            decoder_layers.append(nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            decoder_layers.append(nn.ReLU())
            decoder_layers.append(nn.BatchNorm2d(c_out))
        decoder_layers.append(
            nn.ConvTranspose2d(decoder_widths[-1], input_channels, kernel_size=4, stride=2, padding=1)
        )
        decoder_layers.append(nn.Sigmoid())  # pixel values between 0 and 1
        self.decoder = nn.Sequential(*decoder_layers)

    def reparameterize(self, mu, log_var):
        """Draw z = mu + sigma * eps with eps ~ N(0, I) and sigma = exp(log_var / 2)."""
        sigma = torch.exp(0.5 * log_var)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode it.

        Returns the tuple (reconstruction, mu, log_var).
        """
        features = self.encoder(x)
        features = features.view(features.size(0), -1)  # flatten per sample

        mu = self.fc_mu(features)
        log_var = self.fc_log_var(features)

        latent = self.reparameterize(mu, log_var)
        reconstruction = self.decoder(latent)
        return reconstruction, mu, log_var


In [None]:
class VAELoss(nn.Module):
    """VAE objective: mean-squared reconstruction error plus a scaled KL term."""

    def __init__(self):
        super().__init__()

    def forward(self, x, x_recon, mu, logvar, latent_dim, image_size):
        """Return (total_loss, likelihood_loss, kl_loss).

        ``likelihood_loss`` is the element-wise mean squared error between
        ``x_recon`` and ``x``. ``kl_loss`` is the element-wise mean of
        KL(N(mu, sigma) || N(0, 1)) multiplied by
        ``latent_dim / (3 * image_size * image_size)``.
        """
        reconstruction_term = F.mse_loss(x_recon, x)

        # Approximate posterior q(z|x) and standard-normal prior p(z).
        q_z = torch.distributions.Normal(mu, torch.sqrt(torch.exp(logvar)))
        p_z = torch.distributions.Normal(0, 1)

        kl_term = torch.distributions.kl.kl_divergence(q_z, p_z).mean()
        kl_term = kl_term * (latent_dim / (3 * image_size * image_size))

        return reconstruction_term + kl_term, reconstruction_term, kl_term


In [None]:
# Set hyperparameters
# Data / model dimensions
input_dim = 3        # passed to VAE as the number of input channels
image_size = 128     # images are resized to image_size x image_size
latent_dim = 32

# Optimisation settings
batch_size = 128
learning_rate = 1e-3
num_epochs = 30

# Number of latent draws per input that the training loop iterates over
num_samples_z = list(range(1, 8, 2))

In [None]:
# Define the path to your dataset and other hyperparameters
data_dir_train = 'afhq/train/'
data_dir_val = 'afhq/val/'
shuffle_dataset = True
num_workers = 4


# Resize every image to the square resolution the VAE expects and scale to [0, 1].
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
])


# Use the path variables above so the dataset location is defined in one place.
train_dataset = ImageFolder(root=data_dir_train, transform=transform)
val_dataset = ImageFolder(root=data_dir_val, transform=transform)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=shuffle_dataset,
    num_workers=num_workers
)

# Shuffling is kept for validation as well: the reconstruction plot takes the
# first batch of this loader as its random sample of images.
val_dataloader = DataLoader(
    val_dataset,
    batch_size=100,
    shuffle=shuffle_dataset,
    num_workers=num_workers
)


In [None]:
# Prefer the GPU when one is available.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

loss_fn = VAELoss()
vae = VAE(input_dim, latent_dim)
optimizer = optim.Adam(vae.parameters(), lr=learning_rate)

# Move the model's parameters to the selected device (in place).
vae.to(device)

In [None]:
def plot(total_losses, likelihood_losses, kl_losses, num,
         val_total=None, val_likelihood=None, val_kl=None):
    """Plot per-epoch training and validation loss curves for one experiment.

    Parameters
    ----------
    total_losses, likelihood_losses, kl_losses : sequence of float
        Per-epoch training losses (drawn dashed).
    num : int
        Number of latent samples per input; only used in the title.
    val_total, val_likelihood, val_kl : sequence of float, optional
        Per-epoch validation losses (drawn solid). When omitted, the
        module-level lists ``val_total_losses``, ``val_likelihood_losses`` and
        ``val_kl_losses`` are used if they exist, which keeps the original
        four-argument call working.
    """
    module_globals = globals()
    if val_total is None:
        val_total = module_globals.get('val_total_losses', [])
    if val_likelihood is None:
        val_likelihood = module_globals.get('val_likelihood_losses', [])
    if val_kl is None:
        val_kl = module_globals.get('val_kl_losses', [])

    # (values, legend label, line style, colour) in legend order
    curves = [
        (total_losses, 'Train Total Loss', '--', 'blue'),
        (likelihood_losses, 'Train Reconstruction Loss', '--', 'green'),
        (kl_losses, 'Train KL Divergence Loss', '--', 'red'),
        (val_total, 'Validation Total Loss', '-', 'blue'),
        (val_likelihood, 'Validation Reconstruction Loss', '-', 'green'),
        (val_kl, 'Validation KL Divergence Loss', '-', 'red'),
    ]

    plt.figure(figsize=(12, 6))
    for values, label, linestyle, color in curves:
        epochs = range(1, len(values) + 1)
        plt.plot(epochs, values, label=label, linestyle=linestyle, color=color)

    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title(f'Loss Curves (Training & Validation)  for num_sample_z = {num}')
    plt.grid(True, which='both', linestyle='--', linewidth=0.5)
    plt.tight_layout()
    plt.show()


In [None]:
# Function to generate a grid of original and reconstructed images
def plot_original_reconstructed(vae, dataloader, device, num, num_samples=100):
    """Show the first batch of ``dataloader`` next to its VAE reconstructions.

    Up to ``num_samples`` originals (left) and their reconstructions (right)
    are arranged in grids with 10 images per row. ``num`` only appears in the
    title. The model is switched to evaluation mode.
    """
    vae.eval()
    with torch.no_grad():
        # Take the first batch of the loader and run it through the model.
        batch_images, _labels = next(iter(dataloader))
        batch_images = batch_images.to(device)
        batch_recon, _mu, _log_var = vae(batch_images)

        # One grid per image set, each normalised by make_grid.
        grid_original, grid_recon = (
            vutils.make_grid(tensor[:num_samples], nrow=10, padding=2, normalize=True)
            for tensor in (batch_images, batch_recon)
        )

        # Concatenate along the width so the two grids sit side by side.
        side_by_side = torch.cat([grid_original, grid_recon], dim=2)
        canvas = side_by_side.permute(1, 2, 0).cpu().numpy()

        plt.figure(figsize=(20, 10))
        plt.imshow(canvas)
        plt.axis('off')
        plt.title(f'Original (Left) vs. Reconstructed (Right) Images for num_sample_z = {num}')
        plt.show()



# Function to generate random images
def generate_images(vae, device, num, num_samples=100):
    """Decode ``num_samples`` standard-normal latent codes and show them as a grid.

    The latent dimensionality is read from the model (``vae.fc_mu.out_features``)
    so the function cannot disagree with the model it is given; the
    module-level ``latent_dim`` is used only when the model has no ``fc_mu``
    layer. ``num`` only appears in the title. The model is switched to
    evaluation mode.
    """
    vae.eval()  # Set the model to evaluation mode
    with torch.no_grad():
        mu_head = getattr(vae, 'fc_mu', None)
        code_size = getattr(mu_head, 'out_features', None)
        if code_size is None:
            code_size = latent_dim

        # Generate random samples from the latent prior
        latent_samples = torch.randn(num_samples, code_size, device=device)

        # Pass the samples through the decoder to generate images
        generated_images = vae.decoder(latent_samples)

        # Create a grid of generated images (10 per row)
        generated_grid = vutils.make_grid(generated_images, nrow=10, padding=2, normalize=True)

        # Plot the grid of generated images
        plt.figure(figsize=(10, 5))
        plt.imshow(generated_grid.permute(1, 2, 0).cpu().numpy())
        plt.axis('off')
        plt.title(f'Generated Images for num_sample_z = {num}')
        plt.show()



In [None]:
def _mc_losses(model, x, num):
    """Encode ``x`` once and average the reconstruction and KL losses over ``num`` latent draws."""
    _, mu, logvar = model(x)
    likelihood_sum = 0
    kl_sum = 0
    for _draw in range(num):
        z = model.reparameterize(mu, logvar)
        x_recon = model.decoder(z)
        _, likelihood_loss, kl_loss = loss_fn(x, x_recon, mu, logvar, latent_dim, image_size)
        likelihood_sum = likelihood_sum + likelihood_loss
        kl_sum = kl_sum + kl_loss
    return likelihood_sum / num, kl_sum / num


for num in num_samples_z:
    # Start every setting from a fresh model and optimizer; otherwise each run
    # would continue from the weights of the previous one and the loss curves
    # for different `num` values would not be comparable.
    vae = VAE(input_dim, latent_dim).to(device)
    optimizer = optim.Adam(vae.parameters(), lr=learning_rate)

    # Per-epoch training losses for plotting
    total_losses = []
    likelihood_losses = []
    kl_losses = []

    early_stopping_patience = 5
    epochs_without_improvement = 0
    best_val_loss = float('inf')
    checkpoint_path = 'best_model_weights.pth'
    checkpoint_saved = False

    # Per-epoch validation losses (also read by plot())
    val_total_losses = []
    val_likelihood_losses = []
    val_kl_losses = []

    for epoch in range(num_epochs):
        # The plotting helpers and the validation pass leave the model in eval
        # mode, so training mode is set explicitly at the start of every epoch.
        vae.train()

        total_loss_epoch = 0.0
        likelihood_loss_epoch = 0.0
        kl_loss_epoch = 0.0

        # Use tqdm to create a progress bar for the training loop
        progress_bar = tqdm(enumerate(train_dataloader), total=len(train_dataloader))

        for batch_idx, batch in progress_bar:
            x, _ = batch
            x = x.to(device)  # the model lives on `device`

            optimizer.zero_grad()
            batch_likelihood_loss, batch_kl_loss = _mc_losses(vae, x, num)
            total_loss = batch_likelihood_loss + batch_kl_loss

            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(vae.parameters(), max_norm=1)
            optimizer.step()

            # Accumulate the averaged batch losses (not just the last draw) so
            # the three logged curves are consistent with each other.
            total_loss_epoch += total_loss.item()
            likelihood_loss_epoch += batch_likelihood_loss.item()
            kl_loss_epoch += batch_kl_loss.item()

            # Update the progress bar description
            progress_bar.set_description(f'Epoch [{epoch + 1}/{num_epochs}]')

        # Calculate and store the average loss for the epoch
        avg_total_loss = total_loss_epoch / len(train_dataloader)
        avg_likelihood_loss = likelihood_loss_epoch / len(train_dataloader)
        avg_kl_loss = kl_loss_epoch / len(train_dataloader)

        total_losses.append(avg_total_loss)
        likelihood_losses.append(avg_likelihood_loss)
        kl_losses.append(avg_kl_loss)

        print(f'Total Loss: {avg_total_loss}, Likelihood Loss: {avg_likelihood_loss}, KL Loss: {avg_kl_loss}')

        # Validation loop
        vae.eval()
        val_loss_epoch = 0.0
        val_likelihood_loss_epoch = 0.0
        val_kl_loss_epoch = 0.0
        with torch.no_grad():
            for batch in val_dataloader:
                x, _ = batch
                x = x.to(device)

                val_batch_likelihood_loss, val_batch_kl_loss = _mc_losses(vae, x, num)
                val_total_loss = val_batch_likelihood_loss + val_batch_kl_loss

                val_loss_epoch += val_total_loss.item()
                val_likelihood_loss_epoch += val_batch_likelihood_loss.item()
                val_kl_loss_epoch += val_batch_kl_loss.item()

        avg_val_total_loss = val_loss_epoch / len(val_dataloader)
        avg_val_likelihood_loss = val_likelihood_loss_epoch / len(val_dataloader)
        avg_val_kl_loss = val_kl_loss_epoch / len(val_dataloader)

        val_total_losses.append(avg_val_total_loss)
        val_likelihood_losses.append(avg_val_likelihood_loss)
        val_kl_losses.append(avg_val_kl_loss)

        print(f'Validation - Total Loss: {avg_val_total_loss}, Likelihood Loss: {avg_val_likelihood_loss}, KL Loss: {avg_val_kl_loss}')

        # Early stopping on the validation total loss
        if avg_val_total_loss < best_val_loss:
            best_val_loss = avg_val_total_loss
            epochs_without_improvement = 0
            torch.save(vae.state_dict(), checkpoint_path)  # save the model weights
            checkpoint_saved = True
        else:
            epochs_without_improvement += 1

        if epochs_without_improvement >= early_stopping_patience:
            print(f'Early stopping after {early_stopping_patience} epochs without improvement.')
            break

    # Evaluate and visualise the best checkpoint of this run rather than
    # whatever weights the last epoch happened to produce.
    if checkpoint_saved:
        vae.load_state_dict(torch.load(checkpoint_path, map_location=device))

    # Plotting of total loss, likelihood loss and KL loss
    plot(total_losses, likelihood_losses, kl_losses, num)
    # Plot original and reconstructed images
    plot_original_reconstructed(vae, val_dataloader, device, num)
    # Generate and plot new images
    generate_images(vae, device, num)



* VAE Architecture and Loss Function: The VAE architecture is defined in the `VAE` class and consists of an encoder and a decoder. The `VAELoss` class defines the custom loss function used to train the VAE. It combines a reconstruction loss (mean squared error) with a KL divergence loss that regularizes the latent space.

* Hyperparameters: Key hyperparameters such as latent_dim, learning_rate, batch_size, num_epochs, input_dim, and image_size are set. num_samples_z is a list containing the different numbers of samples drawn from the latent space (z) that will be experimented with.

* Data Preparation: The code prepares the dataset for training the VAE using PyTorch's data loading utilities, such as ImageFolder for loading image data and creating data loaders.

* Training Loop: The code runs a training loop for each value of num_samples_z. Within the loop, the VAE is trained for a specified number of epochs. During each training epoch, the code computes the total loss, likelihood loss, and KL divergence loss for both the training and validation datasets. The training loop also includes early stopping, which monitors the validation loss and stops training if no improvement is observed after a certain number of epochs.

* Plotting and Visualization: After training, the code generates plots to visualize the training and validation losses for each experiment. It also generates and plots original and reconstructed images to assess the quality of the VAE's reconstructions. Finally, it generates new images from the VAE and plots them to see how well the model can generate novel data.

* We adjust the model's architecture and training process to reduce the risk of gradients becoming NaN during training. In the code above, the measures taken are gradient-norm clipping (maximum norm of 1) and batch normalization after every layer. Further strategies for numerical stability, such as proper weight initialization, learning rate scheduling, and regularization techniques, serve the same purpose of preventing numerical instability and keeping the VAE's training process stable.


### CNN

In [None]:
# Define the CNN-based classifier model using nn.Sequential
class CNNClassifier(nn.Module):
    """Three-stage convolutional classifier for square RGB images.

    Parameters
    ----------
    num_classes : int
        Number of output logits.
    image_size : int, default 128
        Side length of the input images. Each of the three max-pooling stages
        halves the resolution, so the feature map entering the linear head has
        side ``image_size // 8``. The default reproduces the original
        64 * 16 * 16 input features.
    """

    def __init__(self, num_classes, image_size=128):
        super(CNNClassifier, self).__init__()
        if image_size < 8:
            raise ValueError("image_size must be at least 8 (three 2x poolings)")

        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        feature_side = image_size // 8
        self.classifier = nn.Sequential(
            nn.Linear(64 * feature_side * feature_side, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        """Return class logits of shape (batch, num_classes)."""
        x = self.features(x)
        # Flatten per sample. Unlike view(-1, C*H*W), this keeps the batch
        # dimension fixed, so an input of the wrong size raises a shape error
        # in the linear layer instead of silently merging samples.
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)
        return x


num_classes = len(np.unique(train_dataset.targets))
# Initialize the classifier model and optimizer. The model must live on the
# same device as the batches that are moved to `device` in the loops below.
classifier = CNNClassifier(num_classes).to(device)
optimizer = optim.Adam(classifier.parameters(), lr=learning_rate)

# Define the loss function
criterion = nn.CrossEntropyLoss()

train_losses = []
val_losses = []
val_accuracies = []

best_val_loss = float('inf')
best_epoch = None  # index into val_accuracies of the best checkpoint
patience = 5
counter = 0


def _snapshot_state(model):
    """Return a detached copy of the model's state dict (independent of later updates)."""
    return {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}


# Defined up front so the restore below works even if no epoch improves.
best_model_state = _snapshot_state(classifier)

for epoch in range(num_epochs):
    classifier.train()
    total_train_loss = 0.0
    for batch_data, batch_labels in tqdm(train_dataloader, desc=f'Epoch {epoch + 1}/{num_epochs}'):
        optimizer.zero_grad()
        batch_data = batch_data.to(device)
        batch_labels = batch_labels.to(device)
        outputs = classifier(batch_data)
        loss = criterion(outputs, batch_labels)
        loss.backward()
        optimizer.step()
        total_train_loss += loss.item()

    # Calculate average training loss for this epoch
    avg_train_loss = total_train_loss / len(train_dataloader)
    train_losses.append(avg_train_loss)

    # Validation loop
    classifier.eval()
    total_val_loss = 0.0
    correct_predictions = 0
    total_predictions = 0

    with torch.no_grad():
        for batch_data, batch_labels in val_dataloader:
            batch_data = batch_data.to(device)
            batch_labels = batch_labels.to(device)
            outputs = classifier(batch_data)
            loss = criterion(outputs, batch_labels)
            total_val_loss += loss.item()

            # Calculate validation accuracy
            predicted = outputs.argmax(dim=1)
            total_predictions += batch_labels.size(0)
            correct_predictions += (predicted == batch_labels).sum().item()

    # Calculate average validation loss for this epoch
    avg_val_loss = total_val_loss / len(val_dataloader)
    val_losses.append(avg_val_loss)

    # Calculate validation accuracy
    val_accuracy = (correct_predictions / total_predictions) * 100.0
    val_accuracies.append(val_accuracy)

    # Print training and validation losses and accuracy for this epoch
    print(f'Epoch [{epoch + 1}/{num_epochs}] Train Loss: {avg_train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%')

    # Implement early stopping
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        best_epoch = epoch
        counter = 0
        # state_dict() returns references to the live parameters, so the
        # tensors are cloned; otherwise later epochs would overwrite the "best" state.
        best_model_state = _snapshot_state(classifier)
    else:
        counter += 1
        if counter >= patience:
            print(f'Early stopping at epoch {epoch + 1} due to no improvement in validation loss.')
            break

# Restore the best model's state
classifier.load_state_dict(best_model_state)

# Report the accuracy that belongs to the restored (best validation loss)
# checkpoint; fall back to the last epoch if no epoch ever improved.
final_val_accuracy = val_accuracies[best_epoch] if best_epoch is not None else val_accuracies[-1]
print(f'Final Validation Accuracy: {final_val_accuracy:.2f}%')


#### CNN-based Classifier:

#### Advantages:
* Achieves a high final validation accuracy (96.20%).
* Effective in extracting image features directly from raw data.
#### Considerations:
* Larger model size and computation requirements.
* May require a larger dataset for optimal performance.

### Posterior Inference

In [None]:
def get_latent_vector(vae, dataloader, device):
    """Encode every batch of ``dataloader`` and return sampled latent codes with labels.

    For each batch the model's ``mu`` and ``log_var`` are used to draw one
    latent sample via ``vae.reparameterize``. Returns two NumPy arrays,
    ``(latent_vectors, labels)``, concatenated over all batches along axis 0.
    The model is switched to evaluation mode.
    """
    vae.eval()
    code_chunks, label_chunks = [], []

    with torch.no_grad():
        for image_batch, label_batch in dataloader:
            _, mean, log_variance = vae(image_batch.to(device))
            sampled_codes = vae.reparameterize(mean, log_variance)
            code_chunks.append(sampled_codes.cpu().numpy())
            label_chunks.append(label_batch.numpy())

    return np.concatenate(code_chunks, axis=0), np.concatenate(label_chunks, axis=0)

# Encode the training split first, then the validation split, with the same model.
(train_latent_vectors, train_labels), (val_latent_vectors, val_labels) = (
    get_latent_vector(vae, loader, device)
    for loader in (train_dataloader, val_dataloader)
)


In [None]:
# Sanity check: display the array shapes of the latent vectors and labels for both splits.
train_latent_vectors.shape, train_labels.shape, val_latent_vectors.shape,  val_labels.shape

### MLP classifier

In [None]:
class MLPClassifier(nn.Module):
    """Fully connected classifier operating on latent vectors.

    Five hidden stages (256 -> 128 -> 64 -> 32 -> 16), each built as
    ``Linear -> ReLU -> BatchNorm1d``, followed by a final ``Linear`` layer
    that outputs ``num_classes`` scores.

    Args:
        input_size: number of features of each input vector.
        num_classes: number of output scores.
    """

    def __init__(self, input_size, num_classes):
        super().__init__()

        stages = []
        fan_in = input_size
        for width in (256, 128, 64, 32, 16):
            stages += [nn.Linear(fan_in, width), nn.ReLU(), nn.BatchNorm1d(width)]
            fan_in = width
        stages.append(nn.Linear(fan_in, num_classes))

        # The attribute name is kept so state_dict keys stay "mlp_classifier.<idx>.*".
        self.mlp_classifier = nn.Sequential(*stages)

    def forward(self, x):
        """Return the raw (un-normalised) class scores for the batch ``x``."""
        return self.mlp_classifier(x)

    
# --- Train the MLP classifier on the VAE latent vectors -----------------------
num_classes = len(np.unique(train_dataset.targets))
mlp_cls = MLPClassifier(latent_dim, num_classes)
mlp_cls.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(mlp_cls.parameters(), lr=learning_rate)

# Per-epoch history; one value is appended to each list at the end of every epoch.
train_losses = []
val_losses = []
val_accuracies = []

# Early stopping: stop once the validation loss has not improved for `patience`
# consecutive epochs.
patience = 5
best_val_loss = float('inf')
counter = 0

# The latent vectors do not change during training, so the datasets and loaders
# are built once here rather than on every epoch. torch.as_tensor accepts both
# NumPy arrays and tensors, so the cell can be re-run after a later conversion.
train_dataset1 = torch.utils.data.TensorDataset(
    torch.as_tensor(train_latent_vectors, dtype=torch.float32),
    torch.as_tensor(train_labels, dtype=torch.long),
)
train_loader1 = torch.utils.data.DataLoader(train_dataset1, batch_size=batch_size, shuffle=True)

val_dataset1 = torch.utils.data.TensorDataset(
    torch.as_tensor(val_latent_vectors, dtype=torch.float32),
    torch.as_tensor(val_labels, dtype=torch.long),
)
val_loader1 = torch.utils.data.DataLoader(val_dataset1, batch_size=batch_size, shuffle=False)

progress_bar = tqdm(range(num_epochs), desc="Training Progress", dynamic_ncols=True)

for epoch in progress_bar:
    # ---- training pass ----
    mlp_cls.train()
    total_train_loss = 0.0

    for inputs, labels in train_loader1:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = mlp_cls(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_train_loss += loss.item()

    average_train_loss = total_train_loss / len(train_loader1)

    # ---- validation pass ----
    mlp_cls.eval()
    total_val_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in val_loader1:
            inputs = inputs.to(device)
            labels = labels.to(device)
            val_outputs = mlp_cls(inputs)
            val_loss = criterion(val_outputs, labels)
            total_val_loss += val_loss.item()
            _, predicted = torch.max(val_outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    average_val_loss = total_val_loss / len(val_loader1)
    val_accuracy = 100 * correct / total

    # Record this epoch so the curves can be inspected or plotted afterwards.
    train_losses.append(average_train_loss)
    val_losses.append(average_val_loss)
    val_accuracies.append(val_accuracy)

    progress_bar.set_description(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {average_train_loss:.4f}, Val Loss: {average_val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%")

    # Early-stopping bookkeeping.
    if average_val_loss < best_val_loss:
        best_val_loss = average_val_loss
        counter = 0
    else:
        counter += 1

    if counter >= patience:
        print("Early stopping triggered.")
        break

# Accuracy of the last completed epoch (not necessarily the best epoch).
final_val_accuracy = val_accuracy

print(f"Final Validation Accuracy: {final_val_accuracy:.2f}%")


#### MLP-based Classifier with Latent Vectors:

#### Advantages:
* Achieves a respectable validation accuracy (81.20%).
* Smaller model size, computationally efficient.
* Effective transfer learning from VAE latent vectors.
#### Considerations:
* Lower accuracy compared to the CNN-based classifier.

In [None]:
# Persist the current VAE weights before the beta sweep below rebinds `vae`.
# NOTE(review): the file name is spelled 'vanila'; it is a runtime path, so it is left as-is.
torch.save(vae.state_dict(), 'q2_vanila_vae.pth')

### Question 3: Beta VAE

In [None]:
# KL weights to sweep over; one VAE is trained from scratch for each value.
beta = [.01, .1, .5, .8, .95, 1, 1.05, 1.5, 5, 10, 100, 500 ]
num_samples_z = 1  # latent samples drawn per batch when estimating the loss

num_epochs = 10
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def plot_original_reconstructed(vae, dataloader, device, num_samples=100, beta_value=None):
    """Show the first batch of ``dataloader`` next to its VAE reconstructions.

    Args:
        vae: model whose call returns ``(reconstruction, mu, log_var)``.
        dataloader: iterable yielding ``(images, labels)``; only its first batch is used.
        device: device the images are moved to before the forward pass.
        num_samples: maximum number of images placed in each grid.
        beta_value: if given, appended to the figure title.

    Side effect: leaves ``vae`` in eval mode.
    """
    vae.eval()
    with torch.no_grad():
        images, _ = next(iter(dataloader))
        images = images.to(device)

        reconstructed_images, _, _ = vae(images)

        original_grid = vutils.make_grid(images[:num_samples], nrow=10, padding=2, normalize=True)
        reconstructed_grid = vutils.make_grid(reconstructed_images[:num_samples], nrow=10, padding=2, normalize=True)

        # Grids are (C, H, W); concatenating on dim=2 puts them side by side.
        combined_grid = torch.cat([original_grid, reconstructed_grid], dim=2)

    title = 'Original (Left) vs. Reconstructed (Right) Images'
    if beta_value is not None:
        title = f'Original (Left) vs. Reconstructed (Right) Images for beta =  {beta_value}'

    plt.figure(figsize=(20, 10))
    plt.imshow(combined_grid.permute(1, 2, 0).cpu().numpy())
    plt.axis('off')
    plt.title(title)
    plt.show()


def generate_images(vae, device, num_samples=100, beta_value=None):
    """Decode standard-normal latent samples and display them as an image grid.

    Reads the notebook-level ``latent_dim`` for the size of each latent sample.

    Args:
        vae: model exposing a ``decoder`` module.
        device: device the latent samples are moved to.
        num_samples: number of images to generate.
        beta_value: if given, appended to the figure title.

    Side effect: leaves ``vae`` in eval mode.
    """
    vae.eval()
    with torch.no_grad():
        latent_samples = torch.randn(num_samples, latent_dim).to(device)
        generated_images = vae.decoder(latent_samples)
        generated_grid = vutils.make_grid(generated_images, nrow=10, padding=2, normalize=True)

    title = 'Generated Images'
    if beta_value is not None:
        title = f'Generated Images for beta = {beta_value}'

    plt.figure(figsize=(10, 5))
    plt.imshow(generated_grid.permute(1, 2, 0).cpu().numpy())
    plt.axis('off')
    plt.title(title)
    plt.show()


# Loss curves of every run, keyed by beta, so they survive the sweep.
beta_histories = {}

for beta_value in beta:
    vae = VAE(input_dim, latent_dim)
    optimizer = optim.Adam(vae.parameters(), lr=learning_rate)
    loss_fn = VAELoss()
    vae.to(device)

    # Per-epoch averages for this beta.
    total_losses = []
    likelihood_losses = []
    kl_losses = []

    for epoch in range(num_epochs):
        vae.train()
        total_loss_epoch = 0.0
        likelihood_loss_epoch = 0.0
        kl_loss_epoch = 0.0

        progress_bar = tqdm(enumerate(train_dataloader), total=len(train_dataloader))

        for batch_idx, batch in progress_bar:
            optimizer.zero_grad()
            x, _ = batch
            # The model lives on `device`; the batch has to be moved there too.
            x = x.to(device)

            # Average the two loss terms over `num_samples_z` latent samples.
            batch_likelihood_loss = 0
            batch_kl_loss = 0
            for _ in range(num_samples_z):
                x_recon, mu, logvar = vae(x)
                _, likelihood_loss, kl_loss = loss_fn(x, x_recon, mu, logvar, latent_dim, image_size)
                batch_likelihood_loss += likelihood_loss
                batch_kl_loss += kl_loss

            batch_likelihood_loss /= num_samples_z
            batch_kl_loss /= num_samples_z

            # Beta-VAE objective: reconstruction term plus beta-weighted KL term.
            total_loss = batch_likelihood_loss + beta_value * batch_kl_loss

            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(vae.parameters(), max_norm=1)

            optimizer.step()

            # Accumulate the batch averages (not just the last latent sample's values).
            total_loss_epoch += total_loss.item()
            likelihood_loss_epoch += batch_likelihood_loss.item()
            kl_loss_epoch += batch_kl_loss.item()

            progress_bar.set_description(f'Epoch [{epoch + 1}/{num_epochs}]')

        avg_total_loss = total_loss_epoch / len(train_dataloader)
        avg_likelihood_loss = likelihood_loss_epoch / len(train_dataloader)
        avg_kl_loss = kl_loss_epoch / len(train_dataloader)

        total_losses.append(avg_total_loss)
        likelihood_losses.append(avg_likelihood_loss)
        kl_losses.append(avg_kl_loss)

        print(f'Total Loss: {avg_total_loss}, Likelihood Loss: {avg_likelihood_loss}, KL Loss: {avg_kl_loss}')

    beta_histories[beta_value] = {
        'total': total_losses,
        'likelihood': likelihood_losses,
        'kl': kl_losses,
    }

    # Qualitative evaluation for this beta.
    plot_original_reconstructed(vae, val_dataloader, device, beta_value=beta_value)
    generate_images(vae, device, beta_value=beta_value)

    # The file name pattern is kept unchanged: the interpolation cell further
    # down loads 'q2_vanilla_vae 0.8.pth'.
    filename = f'q2_vanilla_vae {beta_value}.pth'
    torch.save(vae.state_dict(), filename)


#### Effect of Beta on Training:

* As "beta" increases, the weight of the KL divergence term in the VAE loss becomes more significant relative to the reconstruction loss. This encourages the VAE to have a more structured latent space, which may lead to better-organized representations of data.

#### Impact on Reconstruction:

* Lower values of "beta" (e.g., 0.01, 0.1) result in less emphasis on the KL divergence term. As a result, the VAE may prioritize accurate image reconstruction over enforcing a structured latent space. This can lead to reconstructions that are closer to the original data, but the latent space may not be well-organized.

* Higher values of "beta" (e.g., 5, 10, 100, 500) give more weight to the KL divergence term. This encourages the VAE to have a more structured latent space but can lead to reconstructions that may deviate from the original data as the model seeks to impose a more defined structure on the latent space.

#### Effect on Image Generation:

* Low "beta" values (e.g., 0.01, 0.1) may produce generated images that closely resemble training data but exhibit less variety.

* Higher "beta" values (e.g., 5, 10, 100, 500) may result in generated images with more diversity, but they might not be as faithful to the original training data. The model prioritizes exploring the structured latent space, which can lead to novel but less realistic image samples.

In summary, adjusting the "beta" hyperparameter in a VAE allows you to control the trade-off between reconstruction quality and the structure of the latent space. Lower "beta" values prioritize reconstruction fidelity, while higher "beta" values encourage a more organized latent space and diverse image generation.

#### Posterior Inference for Beta = 0.8

In [None]:
# Reload the VAE trained with beta = 0.8 and visualise latent-space interpolations.
vae = VAE(input_dim, latent_dim)
# The model stays on the CPU in this cell, so map the checkpoint there explicitly;
# this also lets a checkpoint written from a GPU run load on a CPU-only machine.
vae.load_state_dict(torch.load('q2_vanilla_vae 0.8.pth', map_location='cpu'))
vae.eval()

num_interpolations = 10  # number of frames per interpolation strip
num_pairs = 10           # number of random image pairs to visualise
pair_indices = torch.randint(0, len(val_dataset), (num_pairs, 2))

for pair_index in pair_indices:
    # Plain Python ints: usable as dataset indices and printed cleanly in the title.
    first_idx, second_idx = pair_index.tolist()
    image1 = val_dataset[first_idx]
    image2 = val_dataset[second_idx]

    with torch.no_grad():
        # Encode both images (element 0 of each dataset item) and sample a latent for each.
        _, mu1, log_var1 = vae(image1[0].unsqueeze(0))
        _, mu2, log_var2 = vae(image2[0].unsqueeze(0))
        latent1 = vae.reparameterize(mu1, log_var1)
        latent2 = vae.reparameterize(mu2, log_var2)

        # Linear interpolation running from image 1 (weight 0) to image 2
        # (weight 1), so the panels read left-to-right in the order of the title.
        interpolated_latents = [
            (1 - alpha) * latent1 + alpha * latent2
            for alpha in torch.linspace(0, 1, num_interpolations)
        ]

        # One row per interpolation step; the latent size is inferred instead of
        # being hard-coded.
        latent_batch = torch.stack(interpolated_latents).view(num_interpolations, -1)
        interpolated_images = vae.decoder(latent_batch)

    plt.figure(figsize=(15, 1.5))
    for panel, img in enumerate(interpolated_images):
        plt.subplot(1, num_interpolations, panel + 1)
        plt.imshow(img.permute(1, 2, 0).cpu().numpy())  # CHW -> HWC for matplotlib
        plt.axis('off')

    plt.suptitle(f'Interpolation between Pair {first_idx} and Pair {second_idx}')
    plt.show()

* The code provides a visual representation of the VAE's latent space by interpolating between the latent representations of pairs of images and generating images at various points along the interpolation path. 
* This demonstrates the VAE's ability to capture and manipulate image features in a structured and continuous latent space. 
* The results show that the VAE can smoothly transition between images, indicating its effectiveness in encoding and decoding visual information.

### Question 4: Vanilla Autoencoder

In [None]:
class Vanila_AE(nn.Module):
    """Plain (non-variational) convolutional autoencoder.

    Encoder: five ``Conv2d(kernel 4, stride 2, padding 1) -> ReLU -> BatchNorm2d``
    stages with channel widths 4, 8, 16, 32, 64. The flattened encoder output
    is projected to ``latent_dim`` by ``fc_z``, which expects ``64 * 4 * 4``
    features (a 4x4 map after five stride-2 stages, i.e. 128x128 inputs).

    Decoder: ``Linear -> ReLU -> BatchNorm1d -> Unflatten`` back to
    ``(64, 4, 4)``, then five ``ConvTranspose2d`` stages mirroring the encoder;
    the last one is followed by ``Sigmoid``.

    Args:
        input_channels: number of channels of the input (and output) images.
        latent_dim: size of the latent vector ``z``.
    """

    def __init__(self, input_channels, latent_dim):
        super().__init__()

        # ---- encoder ----
        enc_widths = [input_channels, 4, 8, 16, 32, 64]
        enc_layers = []
        for c_in, c_out in zip(enc_widths[:-1], enc_widths[1:]):
            enc_layers += [
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.ReLU(),
                nn.BatchNorm2d(c_out),
            ]
        self.encoder = nn.Sequential(*enc_layers)

        # ---- bottleneck projection ----
        self.fc_z = nn.Linear(64 * 4 * 4, latent_dim)

        # ---- decoder ----
        dec_layers = [
            nn.Linear(latent_dim, 64 * 4 * 4),
            nn.ReLU(),
            nn.BatchNorm1d(64 * 4 * 4),
            nn.Unflatten(1, (64, 4, 4)),
        ]
        for c_in, c_out in ((64, 32), (32, 16), (16, 8), (8, 4)):
            dec_layers += [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.ReLU(),
                nn.BatchNorm2d(c_out),
            ]
        dec_layers += [
            nn.ConvTranspose2d(4, input_channels, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid(),  # outputs in [0, 1]
        ]
        self.decoder = nn.Sequential(*dec_layers)

    def forward(self, x):
        """Encode ``x`` and decode it again.

        Returns:
            ``(z, dec_output)``: the latent vectors and the reconstruction.
        """
        features = self.encoder(x)
        flat = features.view(features.size(0), -1)
        z = self.fc_z(flat)
        return z, self.decoder(z)


In [None]:
# Build the autoencoder, pick the compute device and create its Adam optimizer.
# `input_dim` is passed as the model's `input_channels` argument.
ae=Vanila_AE(input_dim, latent_dim)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 
optimizer = optim.Adam(ae.parameters(), lr=learning_rate)
# Last expression of the cell: moves the model to `device`.
ae.to(device)

In [None]:
# Average reconstruction (MSE) loss of each epoch, kept for plotting.
total_losses = []
num_epochs = 10

for epoch in range(num_epochs):
    # Make sure BatchNorm layers are in training mode even if an evaluation
    # helper switched the model to eval mode before this cell is (re-)run.
    ae.train()
    total_loss_epoch = 0.0

    progress_bar = tqdm(enumerate(train_dataloader), total=len(train_dataloader))

    for batch_idx, batch in progress_bar:
        optimizer.zero_grad()
        x, _ = batch
        # The model lives on `device`; the batch has to be moved there too.
        x = x.to(device)

        # A plain autoencoder has a single objective: pixel-wise reconstruction error.
        latent, x_recon = ae(x)
        total_loss = nn.functional.mse_loss(x_recon, x)

        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(ae.parameters(), max_norm=1)

        optimizer.step()

        total_loss_epoch += total_loss.item()

        progress_bar.set_description(f'Epoch [{epoch + 1}/{num_epochs}]')

    # Mean loss over the batches of this epoch.
    avg_total_loss = total_loss_epoch / len(train_dataloader)
    total_losses.append(avg_total_loss)

    print(f'Total Loss: {avg_total_loss}')
    



In [None]:
# Training curve of the autoencoder: average reconstruction (MSE) loss per epoch.
plt.figure(figsize=(10, 5))
# The x-axis follows the recorded history, so the plot stays valid even if
# num_epochs is changed after training.
plt.plot(range(1, len(total_losses) + 1), total_losses, label='Total Loss')
plt.xlabel('Epoch')
plt.ylabel('Average MSE loss')
plt.title('Vanilla autoencoder training loss')
plt.legend()  # without this the label passed to plot() is never shown
plt.show()

In [None]:
# Write the trained autoencoder's parameters to disk.
filename = 'q4_vanilla_ae.pth'
torch.save(ae.state_dict(), filename)

In [None]:
def plot_original_reconstructed(ae, dataloader, device, num_samples=100):
    """Display one batch of inputs beside the autoencoder's reconstructions.

    The first batch yielded by ``dataloader`` is moved to ``device`` and passed
    through ``ae`` (whose call returns ``(latent, reconstruction)``). Up to
    ``num_samples`` originals and reconstructions are laid out as two image
    grids, joined left/right, and shown with matplotlib.

    Side effect: leaves ``ae`` in eval mode.
    """
    ae.eval()
    with torch.no_grad():
        batch_images, _ = next(iter(dataloader))
        batch_images = batch_images.to(device)

        _, recon_images = ae(batch_images)

        grid_kwargs = dict(nrow=10, padding=2, normalize=True)
        left_grid = vutils.make_grid(batch_images[:num_samples], **grid_kwargs)
        right_grid = vutils.make_grid(recon_images[:num_samples], **grid_kwargs)

        # Grids are (C, H, W); joining on dim=2 places them side by side.
        side_by_side = torch.cat([left_grid, right_grid], dim=2)

        plt.figure(figsize=(20, 10))
        plt.imshow(side_by_side.permute(1, 2, 0).cpu().numpy())
        plt.axis('off')
        plt.title('Original (Left) vs. Reconstructed (Right) Images')
        plt.show()

# Compare a batch of validation images with their autoencoder reconstructions.
plot_original_reconstructed(ae, val_dataloader, device)

def generate_images(ae, device, num_samples=100):
    """Decode standard-normal latent samples and show them as an image grid.

    Draws ``num_samples`` vectors of size ``latent_dim`` (read from the
    notebook's global scope) with ``torch.randn``, passes them through
    ``ae.decoder`` and plots the resulting images in rows of 10.

    Side effect: leaves ``ae`` in eval mode.

    NOTE(review): the autoencoder is trained with a reconstruction loss only,
    so nothing ties its latent codes to a standard normal -- keep that in mind
    when interpreting these samples.
    """
    ae.eval()
    with torch.no_grad():
        z = torch.randn(num_samples, latent_dim).to(device)
        decoded = ae.decoder(z)
        grid = vutils.make_grid(decoded, nrow=10, padding=2, normalize=True)

        plt.figure(figsize=(10, 5))
        plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
        plt.axis('off')
        plt.title('Generated Images')
        plt.show()

# Decode random latent samples with the autoencoder's decoder and plot them.
generate_images(ae, device)


In [None]:
def get_latent_vector_ae(ae, dataloader, device):
    """Run ``ae`` over ``dataloader`` and gather the latent vectors with their labels.

    The model is put in eval mode and evaluated without gradient tracking.

    Args:
        ae: model whose call returns ``(z, reconstruction)``.
        dataloader: iterable yielding ``(images, labels)`` tensor pairs.
        device: device the images are moved to before the forward pass.

    Returns:
        ``(latent_vectors, labels)`` as NumPy arrays concatenated along axis 0,
        in the order the batches were yielded.
    """
    ae.eval()
    z_chunks, label_chunks = [], []

    with torch.no_grad():
        for batch_images, batch_labels in dataloader:
            z, _ = ae(batch_images.to(device))
            z_chunks.append(z.cpu().numpy())
            label_chunks.append(batch_labels.numpy())

    return np.concatenate(z_chunks, axis=0), np.concatenate(label_chunks, axis=0)

# Encode the training and validation sets with the trained autoencoder.
# NOTE: this rebinds train_latent_vectors / val_latent_vectors (and the label
# arrays), replacing the VAE latents computed earlier in the notebook.
train_latent_vectors, train_labels = get_latent_vector_ae(ae, train_dataloader, device)
val_latent_vectors, val_labels = get_latent_vector_ae(ae, val_dataloader, device)




In [None]:
# NOTE(review): this re-declares the MLPClassifier already defined in the VAE
# section with the same architecture; the later definition shadows the earlier one.
class MLPClassifier(nn.Module):
    """MLP that maps latent vectors to class scores.

    Layer widths: ``input_size -> 256 -> 128 -> 64 -> 32 -> 16 -> num_classes``.
    Every hidden layer is ``Linear -> ReLU -> BatchNorm1d``; the output layer is
    a bare ``Linear``.
    """

    def __init__(self, input_size, num_classes):
        super().__init__()

        widths = [input_size, 256, 128, 64, 32, 16]
        hidden = [
            layer
            for n_in, n_out in zip(widths, widths[1:])
            for layer in (nn.Linear(n_in, n_out), nn.ReLU(), nn.BatchNorm1d(n_out))
        ]
        self.mlp_classifier = nn.Sequential(*hidden, nn.Linear(widths[-1], num_classes))

    def forward(self, x):
        """Return the raw class scores for the batch ``x``."""
        scores = self.mlp_classifier(x)
        return scores

    
# --- Train the MLP classifier on the autoencoder latent vectors ---------------
num_classes = len(np.unique(train_dataset.targets))
mlp_cls = MLPClassifier(latent_dim, num_classes)
mlp_cls.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(mlp_cls.parameters(), lr=learning_rate)

# Per-epoch history; one value is appended to each list at the end of every epoch.
train_losses = []
val_losses = []
val_accuracies = []

# Early stopping: stop once the validation loss has not improved for `patience`
# consecutive epochs.
patience = 5
best_val_loss = float('inf')
counter = 0

# The latent vectors are fixed during training, so the datasets and loaders are
# built once. torch.as_tensor accepts NumPy arrays as well as tensors, so this
# cell still runs if the latents have already been converted to tensors by the
# conversion cell further down.
train_dataset1 = torch.utils.data.TensorDataset(
    torch.as_tensor(train_latent_vectors, dtype=torch.float32),
    torch.as_tensor(train_labels, dtype=torch.long),
)
train_loader1 = torch.utils.data.DataLoader(train_dataset1, batch_size=batch_size, shuffle=True)

val_dataset1 = torch.utils.data.TensorDataset(
    torch.as_tensor(val_latent_vectors, dtype=torch.float32),
    torch.as_tensor(val_labels, dtype=torch.long),
)
val_loader1 = torch.utils.data.DataLoader(val_dataset1, batch_size=batch_size, shuffle=False)

progress_bar = tqdm(range(num_epochs), desc="Training Progress", dynamic_ncols=True)

for epoch in progress_bar:
    # ---- training pass ----
    mlp_cls.train()
    total_train_loss = 0.0

    for inputs, labels in train_loader1:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = mlp_cls(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_train_loss += loss.item()

    average_train_loss = total_train_loss / len(train_loader1)

    # ---- validation pass ----
    mlp_cls.eval()
    total_val_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in val_loader1:
            inputs = inputs.to(device)
            labels = labels.to(device)
            val_outputs = mlp_cls(inputs)
            val_loss = criterion(val_outputs, labels)
            total_val_loss += val_loss.item()
            _, predicted = torch.max(val_outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    average_val_loss = total_val_loss / len(val_loader1)
    val_accuracy = 100 * correct / total

    # Record this epoch so the curves can be inspected or plotted afterwards.
    train_losses.append(average_train_loss)
    val_losses.append(average_val_loss)
    val_accuracies.append(val_accuracy)

    progress_bar.set_description(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {average_train_loss:.4f}, Val Loss: {average_val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%")

    # Early-stopping bookkeeping.
    if average_val_loss < best_val_loss:
        best_val_loss = average_val_loss
        counter = 0
    else:
        counter += 1

    if counter >= patience:
        print("Early stopping triggered.")
        break

# Accuracy of the last completed epoch (not necessarily the best epoch).
final_val_accuracy = val_accuracy

print(f"Final Validation Accuracy: {final_val_accuracy:.2f}%")


#### Comparing accuracy with VAE 
* Accuracy with AE = 81.00% and VAE = 81.20%
* This similarity could be attributed to the quality of the features extracted by the encoder part of both AE and VAE.
* It's possible that the latent spaces learned by both models capture similar underlying patterns in the data.

In [None]:
# Convert the autoencoder latent vectors to float32 tensors on `device`.
# torch.as_tensor accepts both NumPy arrays and tensors, so re-running this cell
# is harmless (torch.from_numpy raises a TypeError when handed a tensor).
train_latent_vectors = torch.as_tensor(train_latent_vectors, dtype=torch.float32, device=device)
val_latent_vectors = torch.as_tensor(val_latent_vectors, dtype=torch.float32, device=device)

In [None]:
# Fit a 10-component GMM to the autoencoder latent vectors.
# The GMM's `device` argument defaults to 'cpu', while the latents were moved to
# `device` in the previous cell; pass it explicitly so the mixture parameters
# and the data live on the same device.
gmm_on_latents = GMM(num_components=10, num_iterations=40, tolerance=1e-4, device=device)
gmm_on_latents.train(train_latent_vectors)
# Plot likelihood curve
gmm_on_latents.plot_likelihood_curve()

In [None]:
# Sanity check: shape of a batch of `num_samples` latent vectors drawn from the fitted GMM.
num_samples=100
gmm_on_latents.generate_samples(num_samples).shape

In [None]:
# Draw latent vectors from the fitted GMM and decode them with the autoencoder.
# The decoder contains BatchNorm layers, so make eval mode explicit instead of
# relying on an earlier cell having switched it.
ae.eval()

# Use `num_samples` consistently (it was previously mixed with a literal 100)
# and make sure the latents sit on the same device as the decoder.
new_latents = gmm_on_latents.generate_samples(num_samples).to(device)
with torch.no_grad():
    reconstructed_images = ae.decoder(new_latents)

reconstructed_grid = vutils.make_grid(reconstructed_images[:num_samples], nrow=10, padding=2, normalize=True)

plt.figure(figsize=(10, 5))
plt.imshow(reconstructed_grid.permute(1, 2, 0).cpu().numpy())  # CHW -> HWC for matplotlib
plt.axis('off')
plt.title('GMM parameter for AE')
plt.show()


* It is difficult to compare these results with the GMM from Question 1, because the images obtained from that GMM are 16 x 16 while the images obtained from the AE are 128 x 128.