In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal, Bernoulli, Categorical
import time
from torch.utils.data import DataLoader
import torchvision
from torchvision.transforms import transforms

In [3]:
class Encoder(nn.Module):
    """Fully connected network that outputs the parameters of a diagonal Gaussian.

    The network is a stack of ``Linear -> ReLU`` blocks (the first maps
    ``input_size`` to ``hidden_size``, any further ones keep ``hidden_size``)
    followed by a final ``Linear`` of width ``2 * latent_size``. At least one
    hidden block is always built, even when ``num_layers`` is below 1.

    Args:
        num_layers: Number of hidden ``Linear -> ReLU`` blocks.
        input_size: Width of the last dimension of the input.
        hidden_size: Width of every hidden block.
        latent_size: Width of each of the two output halves.
    """

    def __init__(self, num_layers, input_size, hidden_size, latent_size):
        super().__init__()
        self.input_size = input_size
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.latent_size = latent_size

        # Feature widths at the boundaries of the hidden blocks.
        widths = [input_size] + [hidden_size] * max(num_layers, 1)

        stack = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack.extend((nn.Linear(fan_in, fan_out), nn.ReLU()))
        # Output head: one half for the mean, one half for the log-variance.
        stack.append(nn.Linear(hidden_size, 2 * latent_size))

        self.layers = stack
        self.net = nn.Sequential(*stack)

    def forward(self, X):
        """Run the network and split its output along the last dimension.

        Returns:
            Tuple ``(mu, log_sigma2)``: the first and second
            ``latent_size``-wide slices of the network output.
        """
        head = self.net(X)
        mu, log_sigma2 = head.split(self.latent_size, dim=-1)
        return mu, log_sigma2

In [4]:
class Decoder(nn.Module):
    """Fully connected network mapping a latent vector to ``output_size`` values.

    Every ``Linear`` layer except the last is followed by a ``ReLU``; the last
    layer's output is returned as-is. At least one hidden layer is always
    built, even when ``num_layers`` is below 1.

    Args:
        num_layers: Number of hidden ``Linear -> ReLU`` blocks.
        latent_size: Width of the last dimension of the input.
        hidden_size: Width of every hidden block.
        output_size: Width of the final linear layer.
    """

    def __init__(self, num_layers, latent_size, hidden_size, output_size):
        super().__init__()
        self.num_layers = num_layers
        self.latent_size = latent_size
        self.hidden_size = hidden_size
        self.output_size = output_size

        # Layer widths from input to output, e.g. [latent, hidden, ..., hidden, output].
        dims = [latent_size, *([hidden_size] * max(num_layers, 1)), output_size]
        final_idx = len(dims) - 2

        modules = []
        for idx, (d_in, d_out) in enumerate(zip(dims, dims[1:])):
            modules.append(nn.Linear(d_in, d_out))
            if idx != final_idx:
                # No activation after the output layer.
                modules.append(nn.ReLU())

        self.layers = modules
        self.net = nn.Sequential(*modules)

    def forward(self, X):
        """Return the raw output of the final linear layer for input ``X``."""
        return self.net(X)

In [5]:
class VariationalAutoencoder(nn.Module):
    """Encoder/decoder pair trained through the reparameterization trick.

    Args:
        input_size: Width of the flattened input; also used as the decoder's
            output width.
        hidden_size: Hidden width passed to both sub-networks.
        latent_size: Latent width passed to both sub-networks.
        num_layers: Hidden-layer count passed to both sub-networks.
    """

    def __init__(self, input_size, hidden_size, latent_size, num_layers):
        super().__init__()

        # Settings the two sub-networks have in common.
        shared = dict(
            num_layers=num_layers,
            hidden_size=hidden_size,
            latent_size=latent_size,
        )
        self.encoder = Encoder(input_size=input_size, **shared)
        self.decoder = Decoder(output_size=input_size, **shared)

        self._init_params()

    def forward(self, x):
        """Encode ``x``, draw one latent sample, and decode it.

        A fresh standard-normal noise tensor is drawn on every call, so the
        output is stochastic regardless of train/eval mode.

        Returns:
            Tuple ``(logits, mu, log_sigma2)``: the decoder output for the
            sampled latent, followed by the two encoder outputs.
        """
        mu, log_sigma2 = self.encoder(x)

        # z = mu + sigma * eps with eps ~ N(0, I) and sigma = exp(log_sigma2 / 2).
        noise = torch.randn_like(mu)
        std = (0.5 * log_sigma2).exp()
        z = mu + noise * std

        return self.decoder(z), mu, log_sigma2

    def _init_params(self):
        """Re-initialise every ``nn.Linear``: weights ~ N(0, 0.01^2), biases zero."""
        linear_layers = (m for m in self.modules() if isinstance(m, nn.Linear))
        for layer in linear_layers:
            nn.init.normal_(layer.weight, mean=0.0, std=0.01)
            if layer.bias is None:
                continue
            nn.init.zeros_(layer.bias)


In [6]:
def negative_elbo(logits, labels, mu, log_sigma2):
    """Negative evidence lower bound for a Bernoulli decoder and Gaussian posterior.

    Computes ``-ELBO = -log p(x|z) + KL(q(z|x) || N(0, I))``, summed over every
    element of the inputs (no averaging over the batch).

    Args:
        logits: Decoder outputs before the sigmoid.
        labels: Reconstruction targets, same shape as ``logits``. Binary
            cross-entropy expects values in [0, 1].
        mu: Posterior means.
        log_sigma2: Posterior log-variances, same shape as ``mu``.

    Returns:
        Scalar tensor to be minimised.
    """
    # Summed binary cross-entropy is already the *negative* Bernoulli
    # log-likelihood, so it enters the negative ELBO with a plus sign.
    reconstruction_nll = F.binary_cross_entropy_with_logits(
        logits, labels, reduction='sum'
    )

    # Closed-form KL between N(mu, sigma^2) and N(0, 1), summed over all entries.
    sigma2 = torch.exp(log_sigma2)
    kl_divergence = 0.5 * torch.sum(torch.square(mu) + sigma2 - log_sigma2 - 1)

    return reconstruction_nll + kl_divergence

In [7]:
def train_epoch(model, dataloader, criterion, optimizer, device, update_freq=10):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss.

    Each batch is moved to ``device``, flattened to ``(batch, -1)`` and used as
    both the model input and the reconstruction target.

    Args:
        model: Callable module returning ``(logits, mu, log_sigma2)``.
        dataloader: Iterable of ``(features, _)`` pairs; must support ``len()``
            when progress lines are printed.
        criterion: Called as ``criterion(logits, targets, mu, log_sigma2)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the features are moved to.
        update_freq: Print a progress line every ``update_freq`` batches.

    Returns:
        Sum of the per-batch loss values divided by the number of batches.
    """
    model.train()

    batch_losses = []

    for step, (batch, _) in enumerate(dataloader, start=1):
        # Inputs double as targets: flatten once and reuse.
        flat = batch.to(device)
        flat = flat.view(flat.size(0), -1)

        optimizer.zero_grad()
        logits, mu, log_sigma2 = model(flat)
        loss = criterion(logits, flat, mu, log_sigma2)
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())

        if step % update_freq == 0:
            print(f'   Batch [{step}/{len(dataloader)}] - Loss: {loss.item():.4f}')

    return sum(batch_losses, 0.0) / len(batch_losses)

In [8]:
def evaluate(model, dataloader, criterion, device, use_mean=False):
    """Evaluate the average per-batch loss and collect reconstructions.

    Args:
        model: Module returning ``(logits, mu, log_sigma2)`` when called. When
            ``use_mean`` is True it must also expose ``encoder`` (returning
            ``(mu, log_sigma2)``) and ``decoder`` sub-modules.
        dataloader: Iterable of ``(features, _)`` pairs.
        criterion: Called as ``criterion(logits, features, mu, log_sigma2)``.
        device: Device the features are moved to.
        use_mean:
            If True  -> z = mu (deterministic reconstruction)
            If False -> z sampled by the model's own forward pass

    Returns:
        Tuple ``(avg_loss, all_reconstructions, all_originals)``: the mean of
        the per-batch losses, the sigmoid of the logits for every sample, and
        the flattened inputs, the latter two concatenated on CPU along dim 0.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0

    all_reconstructions = []
    all_originals = []

    with torch.no_grad():
        for features, _ in dataloader:
            features = features.to(device)

            # Flatten
            features = features.view(features.size(0), -1)

            if use_mean:
                # Skip sampling: decode the posterior mean directly.
                mu, log_sigma2 = model.encoder(features)
                logits = model.decoder(mu)
            else:
                # Stochastic forward pass (the model draws the latent sample).
                logits, mu, log_sigma2 = model(features)

            loss = criterion(logits, features, mu, log_sigma2)

            total_loss += loss.item()
            num_batches += 1

            # Convert logits -> probabilities for visualization
            probs = torch.sigmoid(logits)

            all_reconstructions.append(probs.cpu())
            all_originals.append(features.cpu())

    avg_loss = total_loss / num_batches

    all_reconstructions = torch.cat(all_reconstructions, dim=0)
    all_originals = torch.cat(all_originals, dim=0)

    return avg_loss, all_reconstructions, all_originals

In [9]:
def train_vae(num_epochs, model, train_dataloader, test_dataloader, criterion, optimizer, device, update_freq):
    """Train ``model`` for ``num_epochs`` epochs, evaluating after each one.

    Every epoch runs ``train_epoch`` on ``train_dataloader`` and then
    ``evaluate`` on ``test_dataloader`` (with its default settings), and prints
    a one-line summary of both losses and the epoch's wall-clock time.

    Args:
        num_epochs: Number of passes over the training data.
        model: Model forwarded to ``train_epoch`` / ``evaluate``.
        train_dataloader: Batches used for optimisation.
        test_dataloader: Batches used for evaluation.
        criterion: Loss callable forwarded to both helpers.
        optimizer: Optimizer forwarded to ``train_epoch``.
        device: Device forwarded to both helpers.
        update_freq: Progress-print interval forwarded to ``train_epoch``.

    Returns:
        Tuple ``(train_losses, test_losses)`` with one entry per epoch.
    """
    train_losses = []
    test_losses = []

    start_time = time.time()

    for epoch in range(num_epochs):
        epoch_start_time = time.time()
        print(f"\nEpoch [{epoch+1}/{num_epochs}]")
        print("-" * 50)

        # Train
        print("INFO - Training...")
        train_loss = train_epoch(model, train_dataloader, criterion, optimizer, device, update_freq=update_freq)

        # Eval
        print("INFO - Evaluating...")
        test_loss, _, _ = evaluate(model, test_dataloader, criterion, device)

        # Store metrics
        train_losses.append(train_loss)
        test_losses.append(test_loss)

        # Report the epoch's results; the timing was previously computed but never shown.
        epoch_time = time.time() - epoch_start_time
        print(
            f"INFO - Epoch summary - Train loss: {train_loss:.4f} | "
            f"Test loss: {test_loss:.4f} | Time: {epoch_time:.2f}s"
        )

    total_time = time.time() - start_time
    print("\n" + "=" * 80)
    print("✅ Training completed!")
    print(f"🕒 Total training time: {total_time:.2f}s ({total_time/60:.1f} minutes)")

    return train_losses, test_losses

In [10]:
def download_MNIST(normalize=False):
    """Download (if needed) and return the MNIST train and test datasets.

    Images are zero-padded by 2 pixels on every side and converted to tensors.

    Args:
        normalize: If True, additionally standardise pixels with mean 0.1307
            and std 0.3081. Left off by default because standardised pixels
            fall outside [0, 1] and are therefore not valid targets for a
            binary cross-entropy / Bernoulli reconstruction loss.

    Returns:
        Tuple ``(train_dataset, test_dataset)``.
    """
    # Define data transforms for preprocessing
    steps = [
        transforms.Pad(2),
        transforms.ToTensor(),
    ]
    if normalize:
        steps.append(transforms.Normalize((0.1307,), (0.3081,)))  # Mean and Std for MNIST
    transform = transforms.Compose(steps)

    # Load training dataset
    train_dataset = torchvision.datasets.MNIST(
        root='../data/train',
        train=True,
        download=True,
        transform=transform
    )

    # Load test dataset
    test_dataset = torchvision.datasets.MNIST(
        root='../data/test',
        train=False,
        download=True,
        transform=transform
    )

    return train_dataset, test_dataset

In [11]:
def create_DataLoaders(train_dataset, test_dataset, batch_size, shuffle_train, num_workers):
    """Wrap the train and test datasets in ``DataLoader`` objects.

    Args:
        train_dataset: Dataset used for training.
        test_dataset: Dataset used for evaluation.
        batch_size: Batch size for both loaders.
        shuffle_train: Whether the training loader shuffles; the test loader
            never shuffles.
        num_workers: Number of worker processes for both loaders.

    Returns:
        Tuple ``(train_loader, test_loader)``.
    """
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=shuffle_train,
        num_workers=num_workers,
        pin_memory=True
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        # Honour the caller's setting (e.g. 0 where worker processes are unavailable).
        num_workers=num_workers,
        pin_memory=True
    )

    return train_loader, test_loader