## Import Required Libraries

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os

## Device, Hyperparameters, and Output Paths

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Optimisation hyperparameters shared by both experiments.
BATCH_SIZE = 128
EPOCHS = 50
LR = 1e-3

# Size of the latent code; the latent-space plots below use its first two coordinates.
LATENT_DIM = 2

# Output layout: one sub-directory per experiment under BASE_DIR.
BASE_DIR = "vae_comparison"
NO_KL_DIR = BASE_DIR + "/no_kl"
KL_DIR = BASE_DIR + "/with_kl"

In [None]:
# Build the per-experiment output tree (reconstruction grids, prior samples,
# checkpoints). exist_ok=True makes re-running this cell harmless.
for run_dir in (NO_KL_DIR, KL_DIR):
    for sub_dir in ("recon", "samples", "models"):
        os.makedirs(f"{run_dir}/{sub_dir}", exist_ok=True)

## Dataset Loading and Preprocessing

In [None]:
# ToTensor yields float image tensors; the loss functions use them directly as
# binary cross-entropy targets.
transform = transforms.ToTensor()

# Both splits share the same root directory, download flag and transform.
_fashion_mnist_common = dict(root="./data", download=True, transform=transform)
train_data = datasets.FashionMNIST(train=True, **_fashion_mnist_common)
test_data = datasets.FashionMNIST(train=False, **_fashion_mnist_common)

# Only the training split is shuffled. The test order stays fixed, so the
# per-epoch reconstruction grids (first test batch) always show the same images.
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

## Variational Autoencoder Architecture

In [None]:
class VAE(nn.Module):
    """Fully connected variational autoencoder for flattened 28x28 images.

    The encoder maps each 784-feature row to the mean and log-variance of a
    diagonal Gaussian over a ``latent_dim``-dimensional code. The decoder maps
    a code back to 784 values passed through a Sigmoid.

    Args:
        latent_dim: Size of the latent code.
    """

    def __init__(self, latent_dim):
        super().__init__()

        # Shared trunk: 784 -> 512 -> 256.
        self.encoder = nn.Sequential(
            nn.Linear(784, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
        )

        # Two parallel heads on the trunk output. These attribute names become
        # the state_dict keys written by torch.save when checkpointing.
        self.mu = nn.Linear(256, latent_dim)
        self.logvar = nn.Linear(256, latent_dim)

        # Mirror of the encoder: latent_dim -> 256 -> 512 -> 784, squashed by a Sigmoid.
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 784),
            nn.Sigmoid(),
        )

    def reparameterize(self, mu, logvar):
        """Return mu + sigma * eps with eps ~ N(0, I) and sigma = exp(logvar / 2).

        The noise is drawn independently of mu and logvar, so gradients flow
        through both of them.
        """
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        """Encode ``x``, sample a latent code, and decode it.

        A fresh latent sample is drawn on every call, including in eval mode.

        Returns:
            (reconstruction, mu, logvar): reconstruction has 784 columns, one
            row per 784-element slice of ``x``; mu and logvar have
            ``latent_dim`` columns.
        """
        features = self.encoder(x.view(-1, 784))
        mu, logvar = self.mu(features), self.logvar(features)
        z = self.reparameterize(mu, logvar)
        return self.decoder(z), mu, logvar

## Loss Function Without KL Divergence (Reconstruction Only)

In [None]:
def loss_no_kl(recon, x):
    """Reconstruction-only objective: binary cross-entropy summed over pixels and batch.

    Args:
        recon: Decoder output with 784 columns and values in [0, 1].
        x: Target images; flattened with ``view(-1, 784)`` to line up with ``recon``.

    Returns:
        Scalar tensor holding the summed (not averaged) BCE.
    """
    bce = nn.functional.binary_cross_entropy
    return bce(recon, x.view(-1, 784), reduction="sum")

## Loss Function With KL Divergence (Reconstruction + KL)

In [None]:
def loss_with_kl(recon, x, mu, logvar):
    """Negative ELBO: summed BCE reconstruction term plus analytic KL to N(0, I).

    The KL term is -0.5 * sum(1 + logvar - mu^2 - exp(logvar)), summed over
    every latent dimension and every batch row.

    Args:
        recon: Decoder output with 784 columns and values in [0, 1].
        x: Target images; flattened with ``view(-1, 784)`` to line up with ``recon``.
        mu: Encoder means.
        logvar: Encoder log-variances, same shape as ``mu``.

    Returns:
        Scalar tensor: reconstruction BCE + KL.
    """
    recon_term = nn.functional.binary_cross_entropy(recon, x.view(-1, 784), reduction="sum")
    kl_term = (1 + logvar - mu ** 2 - torch.exp(logvar)).sum() * -0.5
    return recon_term + kl_term

In [None]:
def save_recon(model, loader, epoch, path):
    """Save a 2x8 grid of originals (top row) and their reconstructions (bottom row).

    Uses the first 8 images of the first batch from ``loader`` and writes
    ``{path}/recon/epoch_{epoch}.png``. The model is left in eval mode.
    """
    model.eval()
    images, _labels = next(iter(loader))
    images = images.to(DEVICE)

    with torch.no_grad():
        outputs = model(images)[0]

    plt.figure(figsize=(8, 4))
    for col in range(8):
        # Top row: the input image (channel 0).
        plt.subplot(2, 8, col + 1)
        plt.imshow(images[col][0].cpu(), cmap="gray")
        plt.axis("off")

        # Bottom row: the matching reconstruction, reshaped back to 28x28.
        plt.subplot(2, 8, col + 9)
        plt.imshow(outputs[col].view(28, 28).cpu(), cmap="gray")
        plt.axis("off")

    plt.savefig(f"{path}/recon/epoch_{epoch}.png")
    plt.close()

In [None]:
def save_samples(model, epoch, path):
    """Decode 16 latent codes drawn from N(0, I) and save them as a 4x4 image grid.

    Writes ``{path}/samples/epoch_{epoch}.png``. The model is left in eval mode.

    The latent size is read from the model's ``mu`` head instead of the global
    LATENT_DIM. This keeps the function working for a VAE built with a
    different latent size, such as a reloaded checkpoint.
    """
    model.eval()
    latent_dim = model.mu.out_features
    with torch.no_grad():
        z = torch.randn(16, latent_dim).to(DEVICE)
        samples = model.decoder(z)

    samples = samples.view(-1, 1, 28, 28).cpu()

    plt.figure(figsize=(4, 4))
    for i in range(16):
        plt.subplot(4, 4, i + 1)
        plt.imshow(samples[i][0], cmap="gray")
        plt.axis("off")

    plt.savefig(f"{path}/samples/epoch_{epoch}.png")
    plt.close()

## Training Loop

In [None]:
def train(model, optimizer, use_kl, save_path):
    """Train ``model`` for EPOCHS epochs over ``train_loader``.

    Args:
        model: The VAE to optimise.
        optimizer: Optimiser over ``model``'s parameters.
        use_kl: If truthy, use ``loss_with_kl``; otherwise use ``loss_no_kl``.
        save_path: Experiment directory. Per-epoch reconstruction grids and
            prior samples are written under it.

    After each epoch, prints the summed loss divided by the training-set size
    (a per-example average, since both losses sum over the batch). It then
    saves a reconstruction grid from ``test_loader`` and a grid of prior samples.
    """
    n_train = len(train_loader.dataset)
    for epoch in range(1, EPOCHS + 1):
        model.train()
        epoch_loss = 0

        for batch, _labels in train_loader:
            batch = batch.to(DEVICE)
            optimizer.zero_grad()
            recon, mu, logvar = model(batch)

            loss = loss_with_kl(recon, batch, mu, logvar) if use_kl else loss_no_kl(recon, batch)

            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        print(f"Epoch {epoch} | Loss: {epoch_loss/n_train:.4f}")

        save_recon(model, test_loader, epoch, save_path)
        save_samples(model, epoch, save_path)

In [None]:
# Two identically configured VAEs, one per experiment, each with its own Adam
# optimiser. The generators run in order, so model_no_kl is built (and
# initialised) before model_kl.
model_no_kl, model_kl = (VAE(LATENT_DIM).to(DEVICE) for _ in range(2))
opt_no_kl, opt_kl = (optim.Adam(m.parameters(), lr=LR) for m in (model_no_kl, model_kl))

## Training Without KL Divergence

In [None]:
# Baseline: reconstruction loss only.
print("\nTraining WITHOUT KL Divergence")
train(model_no_kl, opt_no_kl, use_kl=False, save_path=NO_KL_DIR)

## Training With KL Divergence

In [None]:
# Full VAE objective: reconstruction loss plus KL divergence to N(0, I).
print("\nTraining WITH KL Divergence")
train(model_kl, opt_kl, use_kl=True, save_path=KL_DIR)

## Saving the Models

In [None]:
# Write each model's state_dict into its experiment's models/ directory.
for trained, out_dir in ((model_no_kl, NO_KL_DIR), (model_kl, KL_DIR)):
    torch.save(trained.state_dict(), f"{out_dir}/models/model.pth")

## Latent Space Visualization Without KL Divergence

In [None]:
def plot_and_save_latent_no_kl(model, loader, save_path):
    """Scatter-plot encoder means for every sample in ``loader``, coloured by label.

    Only the first two coordinates of ``mu`` are plotted. The figure is saved
    to ``{save_path}/latent_space_without_kl.png`` at 300 dpi, then shown and
    closed. The model is left in eval mode.
    """
    model.eval()

    with torch.no_grad():
        # (mu on CPU, labels) for each batch; index 1 of the forward output is mu.
        batches = [(model(images.to(DEVICE))[1].cpu(), targets) for images, targets in loader]

    means = torch.cat([m for m, _ in batches])
    classes = torch.cat([t for _, t in batches])

    plt.figure(figsize=(6, 6))
    plt.scatter(means[:, 0], means[:, 1], c=classes, cmap="tab10", s=5)
    plt.colorbar()
    plt.title("Latent Space WITHOUT KL Divergence")

    plt.savefig(f"{save_path}/latent_space_without_kl.png", dpi=300)
    plt.show()
    plt.close()

## Latent Space Visualization With KL Divergence

In [None]:
def plot_and_save_latent_with_kl(model, loader, save_path):
    """Scatter-plot encoder means for every sample in ``loader``, coloured by label.

    Only the first two coordinates of ``mu`` are plotted. The figure is saved
    to ``{save_path}/latent_space_with_kl.png`` at 300 dpi, then shown and
    closed. The model is left in eval mode.
    """
    model.eval()

    with torch.no_grad():
        # (mu on CPU, labels) for each batch; index 1 of the forward output is mu.
        batches = [(model(images.to(DEVICE))[1].cpu(), targets) for images, targets in loader]

    means = torch.cat([m for m, _ in batches])
    classes = torch.cat([t for _, t in batches])

    plt.figure(figsize=(6, 6))
    plt.scatter(means[:, 0], means[:, 1], c=classes, cmap="tab10", s=5)
    plt.colorbar()
    plt.title("Latent Space WITH KL Divergence")

    plt.savefig(f"{save_path}/latent_space_with_kl.png", dpi=300)
    plt.show()
    plt.close()

In [None]:
# Plot both latent spaces on the same test set, no-KL model first.
for plot_fn, trained_model, out_dir in (
    (plot_and_save_latent_no_kl, model_no_kl, NO_KL_DIR),
    (plot_and_save_latent_with_kl, model_kl, KL_DIR),
):
    plot_fn(trained_model, test_loader, out_dir)

## Zip the Results Folder for Download to Your Local Machine

In [None]:
import shutil

zip_name = "vae_results"
# Reuse BASE_DIR rather than repeating the literal, so the archive always
# covers the directory the experiments actually wrote to.
folder_to_zip = BASE_DIR

# make_archive archives the *contents* of folder_to_zip (no_kl/, with_kl/)
# and returns the path of the file it created; report that path instead of a
# hand-built name.
archive_path = shutil.make_archive(zip_name, 'zip', folder_to_zip)

print("ZIP file created:", archive_path)

## Conclusion
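The KL divergence term is what turns an autoencoder into a generative model. It pulls each encoded distribution toward the standard normal prior, which gives the latent space a structured layout, so codes sampled from N(0, I) decode into plausible images. Without it, the encoder is free to place codes anywhere, and random samples from the prior can land in regions the decoder never learned to map.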