In [None]:
import torch
from torch import optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
from models.vanilla_vae import VanillaVAE
from scipy import stats

In [None]:
# Get the MNIST dataset loaders
def get_mnist_loaders(batch_size=64):
    """
    Build MNIST loaders for anomaly detection.

    Train on digits 0-8 only (normal). Test on all digits (0-9).

    Images are resized with Resize(64), converted to float tensors by
    ToTensor (values in [0, 1]; no normalization is applied), and repeated
    to 3 channels so they fit a 3-channel VAE input.

    Args:
        batch_size: mini-batch size used for both loaders.

    Returns:
        (train_loader, test_loader): the train loader is shuffled and excludes
        digit 9; the test loader is not shuffled and contains every digit.
    """
    transform = transforms.Compose([
        transforms.Resize(64),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.repeat(3, 1, 1)),  # (1, H, W) -> (3, H, W)
    ])

    train_full = datasets.MNIST(
        root="./data", train=True, download=True, transform=transform
    )
    test_full = datasets.MNIST(
        root="./data", train=False, download=True, transform=transform
    )

    # Train only on digits 0-8 (normal data). Filter on the label tensor
    # directly: iterating the dataset would run the full image transform
    # (decode, resize, repeat) on all 60k images just to read their labels.
    normal_indices = (train_full.targets != 9).nonzero(as_tuple=True)[0].tolist()
    train_normal = Subset(train_full, normal_indices)

    train_loader = DataLoader(train_normal, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_full, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader

# Gradient-descent training
def train_mnist_vae(latent_dim=128, batch_size=64, num_epochs=20, lr=0.005, device=None, verbose=True):
    """Train a VanillaVAE on MNIST digits 0-8 with Adam and exponential LR decay.

    Args:
        latent_dim: size of the VAE latent space.
        batch_size: mini-batch size for the train/test loaders.
        num_epochs: number of passes over the training set.
        lr: initial Adam learning rate; multiplied by 0.95 after every epoch.
        device: torch device to train on. Defaults to CUDA when available,
            otherwise CPU.
        verbose: if True, print the learning rate and the per-sample average
            loss, reconstruction loss and KLD after each epoch.

    Returns:
        (model, test_loader, device): the trained model, the test loader over
        all digits 0-9, and the device the model lives on.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_loader, test_loader = get_mnist_loaders(batch_size=batch_size)
    num_train = len(train_loader.dataset)

    model = VanillaVAE(in_channels=3, latent_dim=latent_dim).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=0.0)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        total_recon = 0.0
        total_kld = 0.0

        for x, _ in train_loader:
            x = x.to(device)
            batch_n = x.size(0)

            recons, x_in, mu, log_var = model(x)
            # Fraction of the training set in this batch, passed as M_N.
            # Presumably used by VanillaVAE.loss_function to weight the KLD
            # term -- confirm in models/vanilla_vae.py.
            M_N = batch_n / num_train
            loss_dict = model.loss_function(
                recons, x_in, mu, log_var, M_N=M_N
            )
            loss = loss_dict["loss"]

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Scale batch values by batch size so the epoch average is per
            # sample (assumes loss_function returns batch means -- TODO confirm).
            total_loss += loss.item() * batch_n
            total_recon += loss_dict["Reconstruction_Loss"].item() * batch_n
            total_kld += loss_dict["KLD"].item() * batch_n

        if verbose:
            current_lr = optimizer.param_groups[0]["lr"]
            print(
                f"Epoch {epoch + 1}/{num_epochs} | lr {current_lr:.6f} | "
                f"loss {total_loss / num_train:.4f} | "
                f"recon {total_recon / num_train:.4f} | "
                f"KLD {total_kld / num_train:.4f}"
            )

        scheduler.step()

    return model, test_loader, device


In [None]:
# Train the model and save its weights (training was originally run on Google Colab)
checkpoint_path = "mnist_vae_vanilla.pth"
model, test_loader, device = train_mnist_vae(
    latent_dim=128,
    batch_size=64,
    num_epochs=20,
    lr=0.005,
    device=None,
)
torch.save(model.state_dict(), checkpoint_path)

In [None]:
# Load the model. This cell sets up its own `device` so it also works in a
# fresh kernel where the (long) training cell above was not run.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VanillaVAE(in_channels=3, latent_dim=128).to(device)
# map_location lets a checkpoint saved on a GPU (e.g. Colab) load on a
# CPU-only machine. weights_only restricts unpickling to tensors/containers,
# which is all a state_dict needs (requires torch >= 1.13).
state_dict = torch.load("mnist_vae_vanilla.pth", map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()

# Get test data (same preprocessing as training: Resize(64), ToTensor -> [0, 1], 3 channels)
transform = transforms.Compose([transforms.Resize(64),transforms.ToTensor(),transforms.Lambda(lambda x: x.repeat(3, 1, 1))])
test_full = datasets.MNIST(root="./data", train=False, download=True, transform=transform)
test_loader = DataLoader(test_full, batch_size=64, shuffle=False)

# Compute per-sample reconstruction errors
def compute_reconstruction_errors(model, data_loader, device):
    """Return per-sample reconstruction MSE and the matching labels.

    The model is switched to eval mode and run without gradients. For every
    sample, the squared difference between the reconstruction and the input
    returned by the model is averaged over dims (1, 2, 3).

    Args:
        model: callable returning (recons, input, mu, log_var) for a batch.
        data_loader: iterable of (images, labels) batches.
        device: device the images are moved to before the forward pass.

    Returns:
        (errors, labels): two NumPy arrays concatenated over all batches.
    """
    model.eval()
    error_chunks, label_chunks = [], []
    with torch.no_grad():
        for images, targets in data_loader:
            recons, inputs, _mu, _log_var = model(images.to(device))
            squared_diff = (recons - inputs).pow(2)
            per_sample_mse = squared_diff.mean(dim=(1, 2, 3))
            error_chunks.append(per_sample_mse.cpu().numpy())
            label_chunks.append(targets.numpy())
    return np.concatenate(error_chunks), np.concatenate(label_chunks)

# Plot histogram
def plot_anomaly_histogram(errors, labels):
    """Overlay density histograms of reconstruction error for normal vs. anomalous digits.

    Digits 0-8 are treated as normal and digit 9 as the anomaly. Both histograms
    use one shared set of 50 bin edges computed over all errors, so their
    densities can be compared bar for bar. The figure is saved to
    ``anomaly_histogram.png`` and then shown.

    Args:
        errors: per-sample reconstruction errors (presumably the array returned
            by compute_reconstruction_errors).
        labels: digit labels aligned element-wise with ``errors``.
    """
    errors = np.asarray(errors)
    labels = np.asarray(labels)
    anomaly_mask = labels == 9
    normal_errors = errors[~anomaly_mask]
    anomaly_errors = errors[anomaly_mask]

    # Shared bin edges: with `bins=50` per call, each histogram would get its
    # own range and bin width, which makes the overlaid densities misleading.
    bin_edges = np.histogram_bin_edges(errors, bins=50)

    fig, ax = plt.subplots(figsize=(10, 6))

    # Histograms density
    ax.hist(normal_errors, bins=bin_edges, alpha=0.4, label="Digits 0-8 (normal)",
            density=True, color='steelblue')
    ax.hist(anomaly_errors, bins=bin_edges, alpha=0.4, label="Digit 9 (anomaly)",
            density=True, color='coral')

    ax.set_xlabel("Reconstruction MSE", fontsize=12)
    ax.set_ylabel("Density", fontsize=12)
    ax.legend(fontsize=10)
    ax.set_title("MNIST VAE Reconstruction Errors", fontsize=14)
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig('anomaly_histogram.png', dpi=150, bbox_inches='tight')
    plt.show()

# Show reconstructions
def show_reconstructions(model, data_loader, device, num_examples=8):
    """Plot originals vs. reconstructions for normal digits and for anomalies.

    Takes the first batch of ``data_loader``, reconstructs it with
    ``model.generate`` and draws up to ``num_examples`` normal (0-8) and
    anomalous (9) samples as two-row figures (originals on top, reconstructions
    below). Each figure is saved as a PNG derived from its title
    ("normal_08.png", "anomaly_9.png") and shown. A category with no samples in
    the first batch is skipped.
    """
    model.eval()

    # Take a single batch
    x_batch, y_batch = next(iter(data_loader))
    x_batch = x_batch.to(device)
    y_batch = y_batch.to(device)

    # Get reconstructions
    with torch.no_grad():
        recons = model.generate(x_batch)

    def select_indices(mask, k):
        # First k positions where mask is True.
        return mask.nonzero(as_tuple=True)[0][:k]

    # Normal digits: 0-8, anomalies: 9
    normal_idx = select_indices(y_batch != 9, num_examples)
    anomaly_idx = select_indices(y_batch == 9, num_examples)

    def to_image(t):
        # CHW tensor -> HWC array for imshow. Inputs come from ToTensor (values
        # in [0, 1], no Normalize) and the model is trained to reproduce them,
        # so values are shown as-is; clamping only guards against overshoot.
        # (The previous `(img + 1) / 2` assumed [-1, 1] data and washed out the
        # originals to a gray background.)
        img = t.detach().cpu().permute(1, 2, 0)
        return torch.clamp(img, 0, 1).numpy()

    def plot_pairs(indices, title):
        if len(indices) == 0:
            return

        n = len(indices)
        fig, axes = plt.subplots(2, n, figsize=(2 * n, 4))

        # If n == 1, axes will not be 2D by default
        if n == 1:
            axes = axes.reshape(2, 1)

        for col, idx in enumerate(indices):
            orig_img = to_image(x_batch[idx])
            rec_img = to_image(recons[idx])

            axes[0, col].imshow(orig_img, cmap="gray")
            axes[0, col].axis("off")

            axes[1, col].imshow(rec_img, cmap="gray")
            axes[1, col].axis("off")

        axes[0, 0].set_title(f"{title} - original")
        axes[1, 0].set_title(f"{title} - reconstruction")

        fig.tight_layout()

        # e.g. "Normal (0-8)" -> "normal_08.png"
        filename = (
            title.lower()
                 .replace(" ", "_")
                 .replace("(", "")
                 .replace(")", "")
                 .replace("-", "")
            + ".png"
        )
        fig.savefig(filename, dpi=150, bbox_inches="tight")
        plt.show()

    plot_pairs(normal_idx, "Normal (0-8)")
    plot_pairs(anomaly_idx, "Anomaly (9)")

def analyze_latent_dimensions(model, data_loader, device, threshold=0.01):
    """Count latent dimensions whose average KL to N(0, 1) falls below a threshold.

    Every batch is encoded with ``model.encode`` (no gradients). The resulting
    (mu, log_var) is treated as a diagonal Gaussian, and the closed-form
    KL(N(mu, exp(log_var)) || N(0, 1)) is averaged over all samples per
    dimension. Dimensions with average KL below ``threshold`` are reported as
    collapsed (their posterior is essentially the prior).

    Args:
        model: object with ``eval()`` and ``encode(x) -> (mu, log_var)``.
        data_loader: iterable of (images, labels) batches; labels are ignored.
        device: device the images are moved to before encoding.
        threshold: KL value below which a dimension counts as collapsed.

    Returns:
        The per-dimension average KL as a 1-D tensor (assuming ``encode``
        returns (batch, latent_dim) tensors -- confirm for other encoders).
    """
    model.eval()
    all_mu = []
    all_logvar = []

    with torch.no_grad():
        for x, _ in data_loader:
            x = x.to(device)
            mu, log_var = model.encode(x)
            all_mu.append(mu.cpu())
            all_logvar.append(log_var.cpu())

    all_mu = torch.cat(all_mu)
    all_logvar = torch.cat(all_logvar)

    # Calculate KL divergence per dimension, averaged over samples
    kl_per_dim = -0.5 * (1 + all_logvar - all_mu**2 - all_logvar.exp())
    kl_per_dim = kl_per_dim.mean(dim=0)

    collapsed_dims = (kl_per_dim < threshold).sum().item()

    print(f"Number of collapsed dimensions (KL < {threshold}): {collapsed_dims}/{len(kl_per_dim)}")
    return kl_per_dim

def _first_example_of_digit(data_loader, digit):
    """Return the first image labelled `digit` in `data_loader`, keeping a leading batch dim of 1."""
    for x_batch, y_batch in data_loader:
        hits = (y_batch == digit).nonzero(as_tuple=True)[0]
        if len(hits) > 0:
            idx = int(hits[0])
            return x_batch[idx:idx + 1]
    raise ValueError(f"digit {digit} not found in data_loader")


def show_interpolation_artifacts(model, data_loader, device, num_steps=10, start_digit=0, end_digit=8):
    """Show VAE disadvantage: Unrealistic interpolations

    Encodes the first example of ``start_digit`` and of ``end_digit`` found in
    ``data_loader`` and linearly interpolates between their encoder means in
    ``num_steps`` evenly spaced steps (alpha from 0 to 1). Each point is decoded,
    and the images are plotted left to right with their alpha as the title. The
    figure is saved to ``vae_interpolation_artifacts.png`` and shown.

    Raises:
        ValueError: if either digit does not occur anywhere in ``data_loader``.
    """
    model.eval()

    # Search the whole loader, not just the first batch, so a batch that
    # happens to lack one of the digits does not raise IndexError.
    x1 = _first_example_of_digit(data_loader, start_digit).to(device)
    x2 = _first_example_of_digit(data_loader, end_digit).to(device)

    with torch.no_grad():
        mu1, _ = model.encode(x1)
        mu2, _ = model.encode(x2)

        # Interpolate between the two encoder means
        alphas = torch.linspace(0, 1, num_steps)
        interpolated = []
        for alpha in alphas:
            a = float(alpha)
            z = (1 - a) * mu1 + a * mu2
            interpolated.append(model.decode(z))

        interpolated = torch.cat(interpolated)

    # squeeze=False keeps `axes` 2-D so num_steps == 1 also works.
    fig, axes = plt.subplots(1, num_steps, figsize=(num_steps * 1.5, 2), squeeze=False)
    for ax, alpha, decoded in zip(axes[0], alphas, interpolated):
        # Data is in [0, 1] (ToTensor, no Normalize), so show values as-is;
        # clamp only guards against decoder overshoot.
        img = torch.clamp(decoded.cpu().permute(1, 2, 0), 0, 1).numpy()
        ax.imshow(img, cmap='gray')
        ax.axis('off')
        ax.set_title(f'{float(alpha):.1f}', fontsize=8)

    fig.suptitle('Interpolations\n',
                 fontsize=12)
    fig.tight_layout()
    fig.savefig('vae_interpolation_artifacts.png', dpi=150, bbox_inches='tight')
    plt.show()



# Anomaly-detection analysis on the full test set (digit 9 = anomaly)
errors, labels = compute_reconstruction_errors(model=model, data_loader=test_loader, device=device)
plot_anomaly_histogram(errors=errors, labels=labels)

# Qualitative checks: reconstructions, posterior collapse, latent interpolation
show_reconstructions(model=model, data_loader=test_loader, device=device)

analyze_latent_dimensions(model=model, data_loader=test_loader, device=device)

show_interpolation_artifacts(model=model, data_loader=test_loader, device=device)