Import Modules

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset, TensorDataset
from torchvision import datasets, transforms
from torch.optim import Adam
import torchvision.transforms as T
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split
from preprocessing import ImageDataset

Hyperperameters

In [6]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

kwargs = {'num_workers': 4, 'pin_memory': True} 

batch_size = 100

x_dim  = 784
hidden_dim = 400
latent_dim = 200

lr = 1e-3

epochs = 10

# Model parameters
input_channels = 3
latent_dim = 20

Define Model

In [54]:
class Encoder(nn.Module):
    def __init__(self, input_channels, latent_dim):
        super(Encoder, self).__init__()
        self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=4, stride=2, padding=1)  # (32, 128, 128)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)  # (64, 64, 64)
        self.conv3 = nn.Conv2d(64, 32, kernel_size=4, stride=2, padding=1)  # (32, 32, 32)
        self.flatten = nn.Flatten()
        
        # Output mean vector
        self.fc_mean = nn.Linear(32 ** 3, latent_dim)
        
        # Output covariance matrix (as a flattened vector)
        self.fc_cov = nn.Linear(32 ** 3, latent_dim * latent_dim)

        self.relu = nn.LeakyReLU(0.2)

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.flatten(x)
        
        mean = self.fc_mean(x)
        cov_flat = self.fc_cov(x)
        cov_matrix = cov_flat.view(-1, mean.size(1), mean.size(1))  # Reshape to square covariance matrix

        cov_matrix = 0.5 * (cov_matrix + cov_matrix.transpose(-1, -2))  # Ensure symmetry
        cov_matrix = cov_matrix + 1e-4 * torch.eye(latent_dim).to(cov_matrix.device)  # Ensure positive definiteness
        
        # Ensure the covariance matrix is positive-definite
        cov_matrix = torch.clamp(cov_matrix, min=1e-4)  # Prevent it from becoming too small
        
        return mean, cov_matrix


class Decoder(nn.Module):
    def __init__(self, latent_dim, output_channels):
        super(Decoder, self).__init__()
        self.fc = nn.Linear(latent_dim, 32 ** 3)
        self.deconv1 = nn.ConvTranspose2d(32, 64, kernel_size=4, stride=2, padding=1) # (64, 64, 64)
        self.deconv2 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1)  # (32, 128, 128)
        self.deconv3 = nn.ConvTranspose2d(32, output_channels, kernel_size=4, stride=2, padding=1)  # (3, 256, 256)

    def forward(self, z):
        x = F.relu(self.fc(z))
        x = x.view(-1, 32, 32, 32)  # Reshape to spatial dimensions
        x = F.relu(self.deconv1(x))
        x = torch.sigmoid(self.deconv2(x))
        x = torch.sigmoid(self.deconv3(x))
        return x

class VAE(nn.Module):
    def __init__(self, input_channels, latent_dim):
        super(VAE, self).__init__()
        self.encoder = Encoder(input_channels, latent_dim)
        self.decoder = Decoder(latent_dim, input_channels)

    def reparameterize(self, mean, cov_matrix):
        batch_size, latent_dim = mean.shape
    
        # Ensure positive-definiteness and numerical stability
        cov_matrix = cov_matrix + 1e-4 * torch.eye(latent_dim).to(cov_matrix.device)  # Ensure positive definiteness
        cov_matrix = torch.clamp(cov_matrix, min=1e-4)  # Prevent it from becoming too small
    
        # Use Cholesky decomposition to obtain L such that L * L^T = cov_matrix
        L = torch.linalg.cholesky(cov_matrix)
    
        # Sample from standard normal
        epsilon = torch.randn(batch_size, latent_dim).to(cov_matrix.device)
    
        # Reparameterization trick: z = mean + L * epsilon
        return mean + torch.matmul(L, epsilon.unsqueeze(-1)).squeeze(-1)

    def forward(self, x):
        mean, cov_matrix = self.encoder(x)
        z = self.reparameterize(mean, cov_matrix)
        reconstructed = self.decoder(z)
        return reconstructed, mean, cov_matrix


def loss_function(reconstructed, original, mean, cov_matrix):
    # Reconstruction loss
    reconstruction_loss = F.mse_loss(reconstructed, original, reduction='sum')
    
    # KL Divergence for Multivariate Gaussian
    batch_size, latent_dim, _ = cov_matrix.size()
    
    cov_trace = torch.diagonal(cov_matrix, dim1=-2, dim2=-1).sum(dim=-1)  # Tr(cov)
    cov_det = torch.linalg.det(cov_matrix) + 1e-6  # Prevent log of zero
    
    # Avoid log(0) or negative determinant
    kl_divergence = 0.5 * (torch.sum(mean ** 2, dim=-1) + cov_trace - latent_dim - torch.log(cov_det))
    kl_divergence = kl_divergence.sum()

    return reconstruction_loss + kl_divergence


Training Loop

In [56]:
image_dataset = ImageDataset("map_images_original/labels.csv", "map_images_original/")

# Split dataset into training and testing sets
indices = list(range(len(image_dataset)))
train_labels, test_labels = train_test_split(indices, test_size=0.2, random_state=42)  # 80% train, 20% test
train_dataset = Subset(image_dataset, train_labels)
test_dataset = Subset(image_dataset, test_labels)

# Create a new DataLoader
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, **kwargs)

# Define model
vae = VAE(input_channels, latent_dim).to(DEVICE)

# Optimizer
optimizer = Adam(vae.parameters(), lr=lr)

# Training loop
vae.train()
for epoch in range(epochs):
    total_loss = 0
    for batch_idx, (images, _) in enumerate(train_loader):
        images = images.to(torch.float32)
        images = images.to(DEVICE)

        optimizer.zero_grad()
        reconstructed, mean, cov_matrix = vae(images)
        
        # Compute loss
        loss = loss_function(reconstructed, images, mean, cov_matrix)
        
        loss.backward()
        
        # Gradient clipping to avoid exploding gradients
        torch.nn.utils.clip_grad_norm_(vae.parameters(), max_norm=1.0)
        
        optimizer.step()
        total_loss += loss.item()
        
    print(f"Epoch {epoch + 1}, Loss: {total_loss / (batch_idx * batch_size)}")
# Saving the model and optimizer
torch.save({
    'model_state_dict': vae.state_dict(),
    'optimizer_state_dict': optimizer.state_dict()
}, "MVVAE_checkpoint.pth")

_LinAlgError: linalg.cholesky: (Batch element 0): The factorization could not be completed because the input is not positive-definite (the leading minor of order 3 is not positive-definite).

Test Loop

In [None]:
test_loader = DataLoader(dataset=test_dataset,  batch_size=batch_size, shuffle=False,  **kwargs)

# Reloading the model and optimizer
checkpoint = torch.load("MVVAE_checkpoint.pth")
vae = VAE(input_channels=input_channels, latent_dim=latent_dim)  # Reinitialize model
vae.load_state_dict(checkpoint['model_state_dict'])
vae.to(DEVICE)
optimizer = torch.optim.Adam(vae.parameters(), lr=lr)  # Reinitialize optimizer
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

vae.eval()
for batch_idx, (test_images, _) in enumerate(test_loader):
    test_images = test_images.to(DEVICE)
    reconstructed, mean, log_var = vae(test_images)
    loss = loss_function(reconstructed, test_images, mean, log_var)
print(loss.item() / batch_size)

In [None]:
def show_image(original_batch, reconstructed_batch, index):
    # Extract the specific image and move to CPU
    original = original_batch[index].detach().cpu()
    reconstructed = reconstructed_batch[index].detach().cpu()

    transform = T.Compose([
    T.ToPILImage()
    ])
    
    
    original = transform(original)
    reconstructed = transform(reconstructed)
    
        
    # Create the plot
    fig, axes = plt.subplots(1, 2, figsize=(6, 3))
    
    # Original image
    axes[0].imshow(original)
    axes[0].set_title("Original")
    axes[0].axis('off')
    
    # Reconstructed image
    axes[1].imshow(reconstructed)
    axes[1].set_title("Generated")
    axes[1].axis('off')
    
    # Display the plot
    plt.tight_layout()
    plt.show()


show_image(test_images, reconstructed, index=15)