In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
from torchvision.datasets import STL10
import torch

In [None]:
# Preprocessing: convert to a float tensor, then shift/scale every channel
# with mean 0.5 and std 0.5, i.e. (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,)*3, (0.5,)*3)
])

# Load all 105,000 training images (ignore labels)
full_dataset = STL10(root='./data', split='train+unlabeled', download=True, transform=transform)

# Fraction of the images used for training; the remainder is held out.
train_percent = 0.9

train_size = int(train_percent * len(full_dataset))
test_size  = len(full_dataset) - train_size  # remainder, so the two sizes always sum to the dataset length

# Seed the split so the same images land in the held-out set on every
# Restart & Run All; otherwise test losses are not comparable between runs.
split_seed = 42
train_set, test_set = random_split(
    full_dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(split_seed),
)



# Part 1
## Section 1

In [None]:
class VAEEncoder(nn.Module):
    """Convolutional encoder producing a sampled latent code plus its Gaussian parameters.

    Three stride-2 convolutions (3 -> 32 -> 64 -> 128 channels) are followed by
    two linear heads, one for the mean and one for the log-variance. Each
    convolution halves the spatial size, and the heads expect 128 * 12 * 12
    features, which corresponds to 96x96 input images.
    """

    def __init__(self, z_dim):
        """Build the layers; ``z_dim`` is the size of the latent vector."""
        super().__init__()
        self.z_dim = z_dim

        # Conv2d positional arguments: in_channels, out_channels, kernel_size, stride, padding.
        self.conv1 = nn.Conv2d(3, 32, 4, 2, 1)
        self.conv2 = nn.Conv2d(32, 64, 4, 2, 1)
        self.conv3 = nn.Conv2d(64, 128, 4, 2, 1)

        # Size of the flattened output of conv3 that both heads consume.
        flat_features = 128 * 12 * 12
        self.fc_mu = nn.Linear(flat_features, z_dim)
        self.fc_logvar = nn.Linear(flat_features, z_dim)

    def reparametrization(self, mu, logvar):
        """Sample ``mu + sigma * eps`` with ``eps ~ N(0, I)`` and ``sigma = exp(0.5 * logvar)``.

        Working with the log-variance keeps the implied variance positive for
        any real-valued network output.
        """
        # TODO: confirm with the professor that the log-variance form is acceptable
        # (it is supposed to behave better for gradients).
        assert mu.size() == logvar.size()
        noise = torch.randn_like(mu)
        std = torch.exp(0.5 * logvar)
        return mu + std * noise

    def forward(self, inputs):
        """Encode ``inputs`` and return ``(z, mu, logvar)``.

        ``mu`` and ``logvar`` are returned alongside the sample so callers can
        compute the loss and track the posterior parameters.
        """
        features = inputs
        for conv in (self.conv1, self.conv2, self.conv3):
            features = torch.relu(conv(features))

        # Flatten everything except the batch dimension for the linear heads.
        flat = features.view(features.size(0), -1)

        mu = self.fc_mu(flat)
        logvar = self.fc_logvar(flat)
        return self.reparametrization(mu, logvar), mu, logvar

    def save_params(self, path):
        """Write this module's ``state_dict`` to ``path``."""
        torch.save(self.state_dict(), path)

    def load_params(self, path):
        """Load a ``state_dict`` previously saved at ``path`` into this module."""
        self.load_state_dict(torch.load(path))

In [None]:
class VAEDecoder(nn.Module):
    """Transposed-convolution decoder mapping latent vectors to 3-channel images.

    A linear layer expands the latent vector to a 128 x 12 x 12 feature map,
    which three stride-2 transposed convolutions upsample (each doubles the
    spatial size: 12 -> 24 -> 48 -> 96) down to 3 output channels.

    The output activation is ``tanh`` so that pixels lie in [-1, 1]. This
    matches the notebook's input pipeline, which normalises images with mean
    0.5 / std 0.5, and the ``* 0.5 + 0.5`` de-normalisation used when
    displaying decoded images. A sigmoid output (range [0, 1]) could never
    reproduce the negative half of the target range.
    """

    def __init__(self, z_dim , **kwargs):
        """Build the layers.

        Args:
            z_dim: size of the latent vector accepted by ``forward``.
            **kwargs: forwarded unchanged to ``nn.Module.__init__``.
        """
        super().__init__(**kwargs)
        self.z_dim = z_dim
        self.fc = nn.Linear(z_dim, 128 * 12 * 12)
        # ConvTranspose2d positional arguments: in_channels, out_channels, kernel_size, stride, padding.
        self.deconv1 = nn.ConvTranspose2d(128, 64, 4, 2, 1)
        self.deconv2 = nn.ConvTranspose2d(64, 32, 4, 2, 1)
        self.deconv3 = nn.ConvTranspose2d(32, 3, 4, 2, 1)


    def forward(self, z_values):
        """Decode a batch of latent vectors of shape ``(batch, z_dim)``.

        Returns:
            Tensor of shape ``(batch, 3, 96, 96)`` with values in [-1, 1].
        """
        x = torch.relu(self.fc(z_values))
        x = x.view(x.size(0), 128, 12, 12) # reshape for the conv layers

        x = torch.relu(self.deconv1(x))
        x = torch.relu(self.deconv2(x))
        # tanh (not sigmoid): the reconstruction targets are normalised to [-1, 1].
        x = torch.tanh(self.deconv3(x))
        return x

    def save_params(self, path):
        """Write this module's ``state_dict`` to ``path``."""
        torch.save(self.state_dict(), path)
    def load_params(self, path):
        """Load a ``state_dict`` previously saved at ``path`` into this module."""
        self.load_state_dict(torch.load(path))

In [None]:
class VAE(nn.Module):
    """Variational autoencoder that pairs a ``VAEEncoder`` with a ``VAEDecoder``."""

    def __init__(self, z_dim):
        """Create the encoder and decoder, both using a latent size of ``z_dim``."""
        super().__init__()
        self.encoder = VAEEncoder(z_dim=z_dim)
        self.decoder = VAEDecoder(z_dim=z_dim)

    def forward(self, inputs):
        """Encode ``inputs``, decode the sampled latent, return ``(reconstructed, mu, logvar)``."""
        latent, mu, logvar = self.encoder(inputs)
        return self.decoder(latent), mu, logvar

    def encode(self, img):
        """Run only the encoder; returns the tensors ``(z, mu, logvar)``."""
        return self.encoder(img)

    def generate(self, z_values):
        """Run only the decoder on the given latent vectors."""
        return self.decoder(z_values)

    def compute_loss(self, x, reconstructed, mu, logvar):
        """Return ``(total, reconstruction, kl)``, each averaged over the batch.

        The reconstruction term is the summed squared error between
        ``reconstructed`` and ``x``. The KL term is the closed-form divergence
        between N(mu, exp(logvar)) and a standard normal, summed over every
        element. Both sums are divided by the batch size ``x.size(0)``.
        """
        batch_size = x.size(0)

        # Reconstruction: squared error summed over all pixels and samples.
        squared_error = nn.functional.mse_loss(reconstructed, x, reduction='sum')

        # KL divergence: -0.5 * sum(1 + logvar - mu^2 - exp(logvar)).
        kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
        kl_total = -0.5 * kl_terms.sum()

        total = (squared_error + kl_total) / batch_size
        return total, squared_error / batch_size, kl_total / batch_size

    def save_params(self, path):
        """Write the full model ``state_dict`` to ``path``."""
        torch.save(self.state_dict(), path)

    def load_params(self, path):
        """Load a ``state_dict`` previously saved at ``path`` into this model."""
        self.load_state_dict(torch.load(path))

## Section 2 
### Training 

In [None]:
# Training: fit the VAE and record per-epoch train/test losses.
from torch.amp import autocast  # mixed precision, used to speed up training
from torch.amp import GradScaler
import time

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Mixed precision is only enabled on CUDA. On a CPU fallback both autocast and
# the gradient scaler are created disabled, so the loop runs in plain float32
# instead of requesting CUDA autocast on a machine that has no GPU.
use_amp = device.type == 'cuda'


n_epochs = 50
batch_size = 256

z_dim = 20
T_max = n_epochs  # Number of epochs for a full cosine cycle
eta_min = 1e-5  # Minimum learning rate

vae = VAE(z_dim=z_dim).to(device)
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)

# Stepped once per epoch, so the learning rate anneals from 1e-3 to eta_min over the run.
scheduler = CosineAnnealingLR(optimizer, T_max=T_max, eta_min=eta_min)

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers = 4)
test_loader  = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers = 4)


# Per-epoch history; every entry is an average per sample.
Train_losses = {"total": [], "reconstruction": [], "kl_divergence": []}
Test_losses = {"total": [], "reconstruction": [], "kl_divergence": []}

scaler = GradScaler(device.type, enabled=use_amp)

for epoch in range(n_epochs):
    # ---- training pass ----
    vae.train()
    train_loss = 0
    train_recon_loss = 0
    train_kl_div = 0
    div_train = 0  # number of samples seen this epoch
    t0 = time.time()

    for batch_idx, (data, _) in enumerate(train_loader):

        data = data.to(device)
        optimizer.zero_grad()

        # Forward pass and loss under autocast (a no-op when use_amp is False).
        with autocast(device.type, enabled=use_amp):
            reconstructed, mu, logvar = vae(data)
            loss, recon_loss, kl_div = vae.compute_loss(data, reconstructed, mu, logvar)

        # Scale the loss before backward, then let the scaler drive the
        # optimizer step and adjust its scale factor.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # compute_loss returns per-sample means; weight by the batch size so
        # a smaller last batch does not bias the epoch average.
        train_mult = data.size(0)
        div_train += train_mult

        train_loss += loss.item() * train_mult
        train_recon_loss += recon_loss.item() * train_mult
        train_kl_div += kl_div.item() * train_mult

    scheduler.step()

    Train_losses["total"].append(train_loss / div_train)
    Train_losses["reconstruction"].append(train_recon_loss / div_train)
    Train_losses["kl_divergence"].append(train_kl_div / div_train)

    # ---- evaluation pass on the held-out split (no gradients) ----
    vae.eval()
    with torch.no_grad():
        test_loss = 0
        test_recon_loss = 0
        test_kl_div = 0
        div_test = 0  # number of held-out samples seen

        for batch_idx, (data, _) in enumerate(test_loader):

            data = data.to(device)
            reconstructed, mu, logvar = vae(data)
            loss_test, recon_loss_test, kl_div_test = vae.compute_loss(data, reconstructed, mu, logvar)

            test_mult = data.size(0)  # same batch-size weighting as in training
            div_test += test_mult

            test_loss += loss_test.item() * test_mult
            test_recon_loss += recon_loss_test.item() * test_mult
            test_kl_div += kl_div_test.item() * test_mult

        Test_losses["total"].append(test_loss / div_test)
        Test_losses["reconstruction"].append(test_recon_loss / div_test)
        Test_losses["kl_divergence"].append(test_kl_div / div_test)

    # Wall-clock time for the whole epoch (training + evaluation), in minutes.
    t = (time.time() - t0) / 60
    # Report exactly the values stored in the history so the log and the plots agree.
    print(f'Epoch {epoch+1}, Train Loss: {Train_losses["total"][-1]}, Test loss : {Test_losses["total"][-1]}, time taken: {t:.2f} minutes')


### Display losses 

In [None]:
# Plot the recorded loss curves: training on the left panel, test on the right.
loss_fig, loss_axes = plt.subplots(1, 2, figsize=(12, 5))

# (history key, legend label, line colour) for each of the three curves.
curve_specs = [
    ("total", 'Total Loss', 'tab:blue'),
    ("reconstruction", 'Reconstruction Loss', 'tab:orange'),
    ("kl_divergence", 'KL Divergence', 'tab:green'),
]
panels = [
    (loss_axes[0], Train_losses, 'Training Losses'),
    (loss_axes[1], Test_losses, 'Test Losses'),
]

for panel_ax, history, panel_title in panels:
    for key, legend_label, colour in curve_specs:
        panel_ax.plot(history[key], label=legend_label, color=colour)
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel('Epoch')
    panel_ax.set_ylabel('Loss')
    panel_ax.legend()
    panel_ax.grid(True, linestyle='--', alpha=0.5)

loss_fig.tight_layout()
plt.show()



### Generate 10 images 

In [None]:
# Draw n latent vectors from a standard normal and decode them into images.
n = 10
z_values = torch.randn(n, z_dim).to(device)
generated_images = vae.generate(z_values)

# Convert the whole batch once: detach from autograd, move to the CPU, put
# channels last for imshow, and undo the mean-0.5 / std-0.5 normalisation.
images_hwc = generated_images.detach().cpu().permute(0, 2, 3, 1).numpy() * 0.5 + 0.5

# Show the samples on a 2 x 5 grid.
plt.figure(figsize=(10, 4))
for position, picture in enumerate(images_hwc, start=1):
    plt.subplot(2, 5, position)
    plt.imshow(picture)
    plt.axis("off")

# Part 2


In [None]:
# Pick two pictures at random (from two different classes) and encode them.
import gc
import random

# Drop the notebook-level name for the large training dataset. Guarded so that
# re-running this cell does not fail with NameError once the name is gone.
# NOTE(review): train_set / test_set were built from this dataset and
# presumably still reference it, so this may not actually free the memory — confirm.
try:
    del full_dataset
except NameError:
    pass
gc.collect()

# Labelled split only: labels are needed to choose two different classes.
dataset = STL10(root='./data', split='train', download=True, transform=transform)


labels = dataset.labels # Make sure labels exist

classes = list(set(labels)) # Get unique classes
class1, class2 = random.sample(classes, 2) # Randomly pick two *different* classes

# Get indices for each class
indices_class1 = [i for i, y in enumerate(labels) if y == class1]
indices_class2 = [i for i, y in enumerate(labels) if y == class2]

# Randomly pick one image from each class
idx1 = random.choice(indices_class1)
idx2 = random.choice(indices_class2)

img1, label1 = dataset[idx1]
img2, label2 = dataset[idx2]

# Encoding is inference only: skip autograd so the latents carry no graph.
# vae.encode returns (z, mu, logvar); the sampled z is the one interpolated.
with torch.no_grad():
    z1, _, _ = vae.encode(img1.unsqueeze(0).to(device)) # unsqueeze to add batch dimension
    z2, _, _ = vae.encode(img2.unsqueeze(0).to(device)) # unsqueeze to add batch dimension

# Linear interpolation: z = (1 - lambda) * z1 + lambda * z2, with lambda in [0, 1].
lambdas = [0, 0.25,0.5,0.75,1]
z_values = [(1-lambda_) * z1 + lambda_ * z2 for lambda_ in lambdas]

In [None]:
# Decode the interpolated latents and display the resulting images.
# Decoding runs under no_grad and the result is moved to the CPU, because
# Tensor.numpy() raises on tensors that require grad or live on a CUDA device.
with torch.no_grad():
    decoded_batch = vae.generate(torch.cat(z_values, dim=0)) # cat because the decoder expects a batch

fig, axs = plt.subplots(1, len(decoded_batch), figsize=(12, 3))
for i, img in enumerate(decoded_batch):
    img = img.detach().cpu().permute(1, 2, 0) * 0.5 + 0.5  # channels last + denormalize
    axs[i].imshow(img.numpy())
    axs[i].set_title(f"λ={lambdas[i]}")
    axs[i].axis('off')
plt.show()