In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim.lr_scheduler import StepLR
from torchvision.utils import save_image
from torchvision.utils import make_grid
from torchmetrics.image.fid import FrechetInceptionDistance
import numpy as np
import matplotlib.pyplot as plt
import os
from PIL import Image
import random
from tqdm import tqdm
from sklearn.metrics import confusion_matrix
from sklearn.manifold import TSNE
import seaborn as sns
import wandb

In [4]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Device Used: {device}")

Device Used: cuda


In [7]:
class AFHQDataset(Dataset):
    """
    Image dataset read from a directory that holds one sub-directory per class.

    Sub-directories of ``root_dir`` are visited in sorted order and numbered
    0, 1, 2, ...; every file inside them whose name ends with a known image
    extension (compared case-insensitively) becomes one sample.
    """

    def __init__(self, root_dir, transform=None):
        """
        root_dir: path of the parent directory that contains one sub-directory per class.
        transform: optional callable applied to each loaded RGB image
                   (None returns the PIL image unchanged).
        """
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []    # full path of every sample
        self.labels = []         # integer class label, parallel to image_paths
        self.class_mapping = {}  # label -> sub-directory (class) name

        extensions = (".jpg", ".jpeg", ".png")

        # Only real directories are classes. Filtering BEFORE enumerate keeps the
        # labels consecutive and stops stray files in the root (README, .DS_Store, ...)
        # from being listed as if they were folders.
        categories = sorted(
            entry for entry in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, entry))
        )

        for label, category in enumerate(categories):
            full_path = os.path.join(root_dir, category)
            self.class_mapping[label] = category
            # Sorted so the sample order does not depend on the file system.
            for img_name in sorted(os.listdir(full_path)):
                # Lower-case the name so ".JPG" / ".PNG" files are not skipped.
                if img_name.lower().endswith(extensions):
                    self.image_paths.append(os.path.join(full_path, img_name))
                    self.labels.append(label)

    def __len__(self):
        """Number of image files found under root_dir."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return (image, label) for sample ``idx``; image is RGB, transformed if a transform was given."""
        img_path = self.image_paths[idx]
        label = self.labels[idx]

        # The context manager closes the file handle right away; convert()
        # returns an independent, fully loaded image.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, label

In [8]:
# Dataset Hyperparameters
img_size = 64    # side length (pixels) the images are resized to
batch_size = 64  # samples per DataLoader batch

# Dataset paths.
# The AFHQ root can be overridden with the AFHQ_ROOT environment variable so the
# notebook also runs on other machines; the default is the original location.
afhq_root = os.environ.get('AFHQ_ROOT', '/home/user/javeda1/stargan-v2/data/afhq')
train_dir = os.path.join(afhq_root, 'train')
val_dir = os.path.join(afhq_root, 'val')

In [9]:
# Preprocessing applied to every image: resize to img_size x img_size, convert
# to a tensor, then normalise each of the three channels as (value - 0.5) / 0.5.
# (A GaussianBlur(3, sigma=(0.1, 2.0)) augmentation was tried here and is disabled.)
transform = transforms.Compose([
    transforms.Resize(size=(img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

In [10]:
# Load the train and val dataset
train_dataset = AFHQDataset(root_dir=train_dir, transform=transform)
val_dataset = AFHQDataset(root_dir=val_dir, transform=transform)

# DataLoaders for train and val sets
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

train_size = len(train_loader.dataset)
val_size = len(val_loader.dataset)

print(f"Train dataset size: {train_size}")
print(f"Validation dataset size: {val_size}")

Train dataset size: 14630
Validation dataset size: 1500


In [11]:
def vae_loss_function(recon_x, x, mu, log_var, kl_weight=1):
    """
    VAE objective: summed squared reconstruction error plus a weighted KL term.

    recon_x: reconstruction produced by the model.
    x: target the reconstruction is compared against.
    mu, log_var: mean and log-variance tensors used in the KL term.
    kl_weight: multiplier applied to the KL term in the total.

    Returns (total, recon_loss, kl_divergence). All three are sums over every
    element passed in; nothing is averaged here.
    """
    squared_error = F.mse_loss(recon_x, x, reduction='sum')

    # -0.5 * sum(1 + log_var - mu^2 - exp(log_var))
    kl_terms = 1 + log_var - mu.pow(2) - log_var.exp()
    kl_to_prior = -0.5 * torch.sum(kl_terms)

    total = squared_error + kl_to_prior * kl_weight
    return total, squared_error, kl_to_prior

def train_vae(model, train_loader, optimizer, kl_weight, device):
    """
    Trains the Variational Autoencoder (VAE) for one epoch on the given training data loader.

    model: module whose forward pass returns (recon_x, mu, log_var).
    train_loader: iterable of (inputs, labels) batches; the labels are ignored.
    optimizer: optimizer stepped once per batch.
    kl_weight: weight of the KL term, forwarded to vae_loss_function.
    device: device each input batch is moved to.

    Returns a dict with 'total_loss', 'recon_loss' and 'kl_loss', each averaged
    over the samples actually processed (NaN if the loader yielded no samples).
    """
    model.train()
    running_loss = 0.0
    running_recon_loss = 0.0
    running_kl_loss = 0.0
    samples_seen = 0

    for inputs, _ in train_loader:  # labels not used
        inputs = inputs.to(device)
        optimizer.zero_grad()
        recon_x, mu, log_var = model(inputs)
        loss, recon_loss, kl_loss = vae_loss_function(recon_x, inputs, mu, log_var, kl_weight)
        loss.backward()
        optimizer.step()

        # The losses are sums over the batch, so accumulate them as-is
        running_loss += loss.item()
        running_recon_loss += recon_loss.item()
        running_kl_loss += kl_loss.item()
        # ... and count the samples that produced them. len(loader.dataset)
        # over-counts when the loader drops the last batch or uses a sampler.
        samples_seen += inputs.size(0)

    # NaN (instead of ZeroDivisionError) flags an epoch that saw no data.
    denominator = samples_seen if samples_seen else float('nan')
    return {
        'total_loss': running_loss / denominator,
        'recon_loss': running_recon_loss / denominator,
        'kl_loss': running_kl_loss / denominator
    }

def evaluate_vae(model, val_loader, kl_weight, device):
    """
    Evaluates the Variational Autoencoder (VAE) on the validation dataset after each epoch.

    model: module whose forward pass returns (recon_x, mu, log_var).
    val_loader: iterable of (inputs, labels) batches; the labels are ignored.
    kl_weight: weight of the KL term, forwarded to vae_loss_function.
    device: device each input batch is moved to.

    The model is switched to eval mode and left in it. Returns a dict with
    'total_loss', 'recon_loss' and 'kl_loss', each averaged over the samples
    actually processed (NaN if the loader yielded no samples).
    """
    model.eval()
    running_loss = 0.0
    running_recon_loss = 0.0
    running_kl_loss = 0.0
    samples_seen = 0

    with torch.no_grad():
        for inputs, _ in val_loader:
            inputs = inputs.to(device)
            recon_x, mu, log_var = model(inputs)
            loss, recon_loss, kl_loss = vae_loss_function(recon_x, inputs, mu, log_var, kl_weight)

            # The losses are sums over the batch, so accumulate them as-is
            running_loss += loss.item()
            running_recon_loss += recon_loss.item()
            running_kl_loss += kl_loss.item()
            # ... and count the samples that produced them. len(loader.dataset)
            # over-counts when the loader drops the last batch or uses a sampler.
            samples_seen += inputs.size(0)

    # NaN (instead of ZeroDivisionError) flags an evaluation that saw no data.
    denominator = samples_seen if samples_seen else float('nan')
    return {
        'total_loss': running_loss / denominator,
        'recon_loss': running_recon_loss / denominator,
        'kl_loss': running_kl_loss / denominator
    }

def run_vae_training(
    model, train_loader, val_loader, device, num_epochs, learning_rate=0.001,
    project="vae-training", name="vae_run", kl_weight=0.1, step_size=30, gamma=0.1):
    """Train and evaluate the model for a given number of epochs with W&B logging.

    model: VAE whose forward pass returns (reconstruction, mu, log_var).
    train_loader, val_loader: loaders yielding (inputs, labels) batches.
    device: device the batches are moved to.
    num_epochs: number of train + evaluate rounds.
    learning_rate: initial learning rate of the Adam optimizer.
    project, name: W&B project and run name.
    kl_weight: weight of the KL term, forwarded to train_vae / evaluate_vae.
    step_size, gamma: StepLR schedule parameters.

    Returns the trained model (the same object that was passed in).
    The W&B run is closed even when training raises or is interrupted; in that
    case it is marked as failed and the exception is re-raised.
    """

    print(f"Training Name: {name}")
    print(f"Total num. of Epochs: {num_epochs}")
    print(f"Learning Rate: {learning_rate}")
    print(f"KL Weight used for Loss function: {kl_weight}\n")

    # One fixed batch for visualization, sampled once so every epoch logs the same images
    inputs, _ = next(iter(train_loader))
    inputs = inputs.to(device)

    # Initialize W&B logging
    wandb.init(project=project, name=name,
               config={
                   "learning_rate": learning_rate,
                   "num_epochs": num_epochs,
                   "step_size": step_size,
                   "gamma": gamma,
                   "kl_weight": kl_weight
               })

    try:
        optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)  # Optimizer
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)  # Learning rate scheduler

        for epoch in tqdm(range(num_epochs)):
            # Train for one epoch
            train_metrics = train_vae(model, train_loader, optimizer, kl_weight, device)
            # Evaluate after each epoch
            val_metrics = evaluate_vae(model, val_loader, kl_weight, device)

            # Read the rate that was used for THIS epoch before stepping the schedule
            current_lr = scheduler.get_last_lr()[0]
            scheduler.step()

            # Reconstruct the fixed batch for the image log. Set eval mode
            # explicitly rather than relying on evaluate_vae having left it on.
            model.eval()
            with torch.no_grad():
                recon_x, _, _ = model(inputs)

                # Normalize and convert to image format
                recon_x = recon_x.view(-1, *inputs.shape[1:])
                recon_grid = make_grid(recon_x.cpu().detach() * 0.5 + 0.5, normalize=True, pad_value=1, padding=10)
                original_grid = make_grid(inputs.cpu().detach() * 0.5 + 0.5, normalize=True, pad_value=1, padding=10)

            # Log all data to W&B
            wandb.log({
                "epoch": epoch + 1,
                "train/total_loss": train_metrics['total_loss'],
                "train/recon_loss": train_metrics['recon_loss'],
                "train/kl_loss": train_metrics['kl_loss'],
                "val/total_loss": val_metrics['total_loss'],
                "val/recon_loss": val_metrics['recon_loss'],
                "val/kl_loss": val_metrics['kl_loss'],
                "learning_rate": current_lr,
                "original_images": wandb.Image(original_grid),
                "reconstructed_images": wandb.Image(recon_grid),
            })

            # Print stats after each epoch
            print(f"\nEpoch {epoch+1}/{num_epochs}")
            print(f"Learning Rate: {current_lr:.6f}")
            print(f"Train - Total: {train_metrics['total_loss']:.4f}, "
                  f"Recon: {train_metrics['recon_loss']:.4f}, "
                  f"KL: {train_metrics['kl_loss']:.4f}")
            print(f"Eval  - Total: {val_metrics['total_loss']:.4f}, "
                  f"Recon: {val_metrics['recon_loss']:.4f}, "
                  f"KL: {val_metrics['kl_loss']:.4f}")
    except BaseException:
        # Also covers KeyboardInterrupt (stopping the cell): close the run as
        # failed so it is not left open in the kernel, then propagate.
        wandb.finish(exit_code=1)
        raise

    # End W&B run
    wandb.finish()

    return model

In [15]:
class VAEEncoder(nn.Module):
    """
    Convolutional encoder of the VAE.

    A stride-1 stem is followed by four stride-2 convolutions (each halves the
    spatial size). The flattened output of the last one feeds two linear heads
    that produce mu and log_var. The 1024 * 4 * 4 flatten size corresponds to a
    64x64 input (64 -> 32 -> 16 -> 8 -> 4). The activation after the third
    stride-2 block is also returned so the decoder can use it as a skip connection.
    """

    def __init__(self, latent_dim):
        super().__init__()

        # Stem: keeps the spatial size, lifts 3 input channels to 64
        self.conv_initial = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1)

        # Four downsampling convolutions (kernel 4, stride 2, padding 1)
        self.conv1 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=2, padding=1)
        self.conv4 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=4, stride=2, padding=1)

        # One batch norm per downsampling convolution
        self.bn1 = nn.BatchNorm2d(num_features=128)
        self.bn2 = nn.BatchNorm2d(num_features=256)
        self.bn3 = nn.BatchNorm2d(num_features=512)
        self.bn4 = nn.BatchNorm2d(num_features=1024)

        # Linear heads mapping the flattened features to the latent parameters
        flat_features = 1024 * 4 * 4
        self.fc_mu = nn.Linear(in_features=flat_features, out_features=latent_dim)
        self.fc_var = nn.Linear(in_features=flat_features, out_features=latent_dim)

        # Dropout applied to the flattened features
        self.dropout = nn.Dropout(p=0.2)

    def forward(self, x):
        """Return (mu, log_var, skip_connection) for the image batch x."""
        slope = 0.2  # negative slope shared by every LeakyReLU

        features = F.leaky_relu(self.conv_initial(x), slope)

        # First three conv -> batch norm -> LeakyReLU stages
        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3))
        for conv, norm in stages:
            features = F.leaky_relu(norm(conv(features)), slope)

        # The output of the third stage doubles as the skip connection
        skip_connection = features

        deepest = F.leaky_relu(self.bn4(self.conv4(skip_connection)), slope)

        # Flatten per sample, then regularise
        flattened = self.dropout(deepest.view(deepest.size(0), -1))

        mu = self.fc_mu(flattened)
        log_var = self.fc_var(flattened)
        return mu, log_var, skip_connection


class VAEDecoder(nn.Module):
    """
    Transposed-convolution decoder of the VAE.

    The latent vector is projected to a 1024 x 4 x 4 feature map and upsampled
    by four stride-2 transposed convolutions (4 -> 8 -> 16 -> 32 -> 64). After
    the first upsampling step the skip connection is concatenated on the channel
    axis, which is why conv2 takes 1024 input channels. A final 3x3 convolution
    followed by tanh produces the 3-channel output.
    """

    def __init__(self, latent_dim):
        super().__init__()

        # Projection from the latent vector to the first feature map
        self.fc = nn.Linear(in_features=latent_dim, out_features=1024 * 4 * 4)

        # Four upsampling transposed convolutions (kernel 4, stride 2, padding 1)
        self.conv1 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.ConvTranspose2d(in_channels=1024, out_channels=256, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1)
        self.conv4 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1)

        # One batch norm per upsampling step
        self.bn1 = nn.BatchNorm2d(num_features=512)
        self.bn2 = nn.BatchNorm2d(num_features=256)
        self.bn3 = nn.BatchNorm2d(num_features=128)
        self.bn4 = nn.BatchNorm2d(num_features=64)

        # Output convolution: 64 channels -> 3 channels, spatial size unchanged
        self.conv_final = nn.Conv2d(in_channels=64, out_channels=3, kernel_size=3, stride=1, padding=1)

        # Dropout module (created here; forward() does not call it)
        self.dropout = nn.Dropout(p=0.2)

    def forward(self, z, skip_connection):
        """Decode latent z, using the encoder's skip_connection, into an image batch in [-1, 1]."""
        # Project and reshape to a 1024 x 4 x 4 feature map per sample
        hidden = F.relu(self.fc(z))
        hidden = hidden.view(hidden.size(0), 1024, 4, 4)

        # First upsampling step
        hidden = F.relu(self.bn1(self.conv1(hidden)))

        # Skip connection: concatenate along the channel dimension
        hidden = torch.cat((hidden, skip_connection), dim=1)

        # Remaining transposed conv -> batch norm -> ReLU stages
        stages = ((self.conv2, self.bn2), (self.conv3, self.bn3), (self.conv4, self.bn4))
        for deconv, norm in stages:
            hidden = F.relu(norm(deconv(hidden)))

        # tanh keeps the output in [-1, 1]
        return torch.tanh(self.conv_final(hidden))

class ConvVAE(nn.Module):
    """
    Convolutional VAE: a VAEEncoder and a VAEDecoder joined by the
    reparameterization step, with the encoder's skip connection passed
    straight to the decoder.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.encoder = VAEEncoder(latent_dim)
        self.decoder = VAEDecoder(latent_dim)

    def reparameterize(self, mu, log_var):
        """Sample z = mu + eps * std in training mode; return mu itself otherwise."""
        if not self.training:
            return mu
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Return (recon_x, mu, log_var) for the image batch x."""
        # Encode: latent parameters plus the feature map used as skip connection
        mu, log_var, skip_connection = self.encoder(x)
        latent = self.reparameterize(mu, log_var)

        # Decode from the latent sample together with the skip connection
        reconstruction = self.decoder(latent, skip_connection)
        return reconstruction, mu, log_var


In [None]:
# Model training parameters
learning_rate = 0.0001
step_size = 10   # StepLR step size passed to run_vae_training
gamma = 0.5      # StepLR decay factor passed to run_vae_training

kl_weight = 0.01

num_epochs = 40

latent_dim = 512  # define latent dimension

name = f"run_kl_wgt_{kl_weight}_ep_{num_epochs}_latent_dim_{latent_dim}_skip_cn"
project = "assignment-5-v2"

# Resolve and create the output location BEFORE training, so a missing name or
# folder fails immediately instead of after the whole run has finished.
# NOTE(review): `saved_model_folder` is not defined in the visible cells of this
# notebook; an existing value is kept, otherwise "saved_models" is used - confirm.
try:
    saved_model_folder
except NameError:
    saved_model_folder = "saved_models"
os.makedirs(saved_model_folder, exist_ok=True)
save_path = os.path.join(saved_model_folder, name)

# Load model
model = ConvVAE(latent_dim=latent_dim).to(device)

model = run_vae_training(
    model, train_loader, val_loader, device,
    num_epochs=num_epochs, learning_rate=learning_rate,
    step_size=step_size, gamma=gamma,
    kl_weight=kl_weight,
    name=name, project=project
)

# Saves the whole module object (not just a state_dict), so loading it later
# needs the model classes to be defined/importable.
torch.save(model, save_path)
print(f"Model saved at: {save_path}")

Training Name: run_kl_wgt_0.01_ep_40_latent_dim_512_skip_cn
Total num. of Epochs: 40
Learning Rate: 0.0001
KL Weight used for Loss function: 0.01



  2%|█▍                                                      | 1/40 [00:11<07:38, 11.75s/it]


Epoch 1/40
Learning Rate: 0.000100
Train - Total: 365.4440, Recon: 363.4789, KL: 196.5119
Eval  - Total: 126.2105, Recon: 125.7657, KL: 44.4807


  5%|██▊                                                     | 2/40 [00:26<08:41, 13.73s/it]


Epoch 2/40
Learning Rate: 0.000100
Train - Total: 112.9381, Recon: 112.4148, KL: 52.3299
Eval  - Total: 79.9486, Recon: 79.6469, KL: 30.1692


  8%|████▏                                                   | 3/40 [00:36<07:20, 11.90s/it]


Epoch 3/40
Learning Rate: 0.000100
Train - Total: 83.8239, Recon: 83.4087, KL: 41.5205
Eval  - Total: 83.0544, Recon: 82.7761, KL: 27.8362


 10%|█████▌                                                  | 4/40 [00:51<07:51, 13.09s/it]


Epoch 4/40
Learning Rate: 0.000100
Train - Total: 68.7571, Recon: 68.4081, KL: 34.8940
Eval  - Total: 69.1650, Recon: 68.9845, KL: 18.0593


 12%|███████                                                 | 5/40 [01:01<06:55, 11.87s/it]


Epoch 5/40
Learning Rate: 0.000100
Train - Total: 60.6038, Recon: 60.2789, KL: 32.4947
Eval  - Total: 55.8070, Recon: 55.6468, KL: 16.0229


 15%|████████▍                                               | 6/40 [01:16<07:28, 13.20s/it]


Epoch 6/40
Learning Rate: 0.000100
Train - Total: 54.0693, Recon: 53.7859, KL: 28.3386
Eval  - Total: 50.0897, Recon: 49.9214, KL: 16.8365


 18%|█████████▊                                              | 7/40 [01:29<07:09, 13.01s/it]


Epoch 7/40
Learning Rate: 0.000100
Train - Total: 49.2945, Recon: 49.0250, KL: 26.9478
Eval  - Total: 52.4603, Recon: 52.3126, KL: 14.7680


In [3]:
# Reload the saved model and print the score returned by compute_fid_score for
# the validation loader (presumably the Frechet Inception Distance - confirm).
# NOTE(review): `load_model` and `compute_fid_score` are not defined in the visible
# cells of this notebook, and `save_path`, `val_loader`, `device` and `name` come from
# earlier cells; on a fresh kernel this cell raises NameError until those exist.
model = load_model(save_path)
# Cap the sample count at 2000, or at the validation set size if that is smaller
max_samples = min(2000, len(val_loader.dataset))
fid_score = compute_fid_score(model, val_loader, device, max_samples=max_samples)
print(f"FID Score for model {name}: {fid_score}")

NameError: name 'load_model' is not defined

In [4]:
# Take one batch from the validation loader and pass it to visualize_reconstructions.
# NOTE(review): `visualize_reconstructions` is not defined in the visible cells of this
# notebook, and `val_loader`, `model` and `device` come from earlier cells; on a fresh
# kernel this cell raises NameError until those exist.
data_iter = iter(val_loader)
images, _ = next(data_iter)  # labels are discarded
visualize_reconstructions(model, images, device)

NameError: name 'val_loader' is not defined