# Exercise: VAE

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms, utils as vutils
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import numpy as np
import os
from tqdm import tqdm
import random

In [None]:
# Configs
class Config:
    """Hyperparameters and paths shared by every cell of the notebook."""

    # --- Dataset selection and filesystem locations ---
    dataset = 'fashion'
    data_dir = './data'
    out_dir = './outputs'

    # Number of samples processed per training iteration.
    batch_size = 128
    # Number of times the model sees the whole training set.
    epochs = 10
    # Learning rate.
    lr = 2e-3
    # Dimensionality of the latent space.
    latent_dim = 2
    # Weight of the KL-divergence term in the loss.
    beta = 1.0
    # Fraction of the training dataset used for validation.
    valid_split = 0.1
    seed = 42
    # Number of parallel processes used to load data (PyTorch DataLoader).
    num_workers = 2
    # Use the GPU when one is visible to PyTorch, otherwise fall back to CPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # If True, the script also trains a classic Autoencoder (AE) as a
    # baseline for comparison.
    compare_ae = False

# Create the output directory up front; exist_ok keeps re-running the cell harmless.
os.makedirs(name=Config.out_dir, exist_ok=True)

def set_seed(seed):
    """Seed Python's, NumPy's and PyTorch's (CPU and CUDA) random generators.

    Args:
        seed: Integer seed applied to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

# Seed every random generator before any model or data object is created.
set_seed(seed=Config.seed)

# Report which device the rest of the notebook will run on.
print(f"Using device: {Config.device}")

In [None]:
# Model Definitions
class Encoder(nn.Module):
    """Convolutional encoder that outputs a latent mean and log-variance.

    Three stride-2 convolutions feed two parallel linear heads. The heads
    expect 128*4*4 flattened features, which is what the convolution stack
    produces for a 1x28x28 input (28 -> 14 -> 7 -> 4).

    Args:
        latent_dim: Size of each output vector (mean and log-variance).
    """

    def __init__(self, latent_dim):
        super().__init__()
        widths = (1, 32, 64, 128)
        stages = []
        # Each stage halves the spatial resolution and widens the channels.
        for c_in, c_out in zip(widths, widths[1:]):
            stages.append(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1))
            stages.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*stages)

        n_features = widths[-1] * 4 * 4
        self.fc_mu = nn.Linear(n_features, latent_dim)
        self.fc_logvar = nn.Linear(n_features, latent_dim)

    def forward(self, x):
        """Return ``(mu, logvar)`` for the batch of images ``x``."""
        features = self.conv(x).flatten(start_dim=1)
        return self.fc_mu(features), self.fc_logvar(features)

class Decoder(nn.Module):
    """Transposed-convolution decoder mapping latent vectors to image logits.

    A linear layer expands the latent vector to a 128x4x4 feature map, which
    the transposed convolutions upsample 4 -> 8 -> 15 -> 30. The forward pass
    then crops the 30x30 result to the central 28x28 window.

    Args:
        latent_dim: Size of the latent vectors fed to ``forward``.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.fc = nn.Linear(in_features=latent_dim, out_features=128 * 4 * 4)

        upsample_1 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)
        upsample_2 = nn.ConvTranspose2d(
            64, 32, kernel_size=4, stride=2, padding=2, output_padding=1
        )
        upsample_3 = nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1)
        to_logits = nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1)
        self.deconv = nn.Sequential(
            upsample_1, nn.ReLU(inplace=True),
            upsample_2, nn.ReLU(inplace=True),
            upsample_3, nn.ReLU(inplace=True),
            to_logits,
        )

    def forward(self, z):
        """Return unnormalised pixel logits of shape (batch, 1, 28, 28)."""
        feature_map = self.fc(z).view(-1, 128, 4, 4)
        padded_logits = self.deconv(feature_map)
        # Drop the one-pixel border of the 30x30 output.
        return padded_logits[..., 1:29, 1:29]

class VAE(nn.Module):
    """Variational autoencoder pairing an ``Encoder`` with a ``Decoder``.

    Args:
        latent_dim: Latent size passed to both sub-networks.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps`` sampled from N(0, I).

        Sampling the noise separately keeps ``z`` differentiable with
        respect to ``mu`` and ``logvar``.
        """
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        """Encode ``x``, sample a latent, and decode it.

        Returns:
            Tuple ``(x_logits, mu, logvar)``.
        """
        mu, logvar = self.encoder(x)
        latent = self.reparameterize(mu, logvar)
        return self.decoder(latent), mu, logvar

def elbo_bce_loss(x_logits, x, mu, logvar, beta=1.0):
    """Compute the beta-weighted VAE loss (reconstruction + beta * KL).

    Both terms are summed over every element of the batch, not averaged.

    Args:
        x_logits: Decoder output before the sigmoid.
        x: Target images; used as the binary cross-entropy targets.
        mu: Latent means from the encoder.
        logvar: Latent log-variances from the encoder.
        beta: Weight applied to the KL term.

    Returns:
        Tuple ``(total, recon, kld)`` of scalar tensors.
    """
    recon = F.binary_cross_entropy_with_logits(
        input=x_logits, target=x, reduction='sum'
    )
    # Closed-form KL divergence between N(mu, exp(logvar)) and N(0, I).
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kld = -0.5 * kl_terms.sum()
    total = recon + beta * kld
    return total, recon, kld

In [None]:
# Data Preparation
# Convert each image to a tensor as it is loaded.
transform = transforms.ToTensor()

# Both dataset classes are constructed with the same arguments, so select the
# class once instead of duplicating the train/test calls in each branch.
# Any value other than 'fashion' falls back to MNIST.
dataset_cls = datasets.FashionMNIST if Config.dataset == 'fashion' else datasets.MNIST
full_train = dataset_cls(Config.data_dir, train=True, download=True, transform=transform)
test = dataset_cls(Config.data_dir, train=False, download=True, transform=transform)

val_size = int(Config.valid_split * len(full_train))
train_size = len(full_train) - val_size

# Split with a dedicated, seeded generator: the train/validation partition then
# stays the same when this cell is re-run, instead of depending on how much of
# the global RNG stream earlier cells have consumed.
split_rng = torch.Generator().manual_seed(Config.seed)
train_ds, val_ds = random_split(full_train, [train_size, val_size], generator=split_rng)

# Config.num_workers was previously defined but never applied; pass it to every
# loader so the setting actually controls data-loading parallelism.
loader_kwargs = dict(batch_size=Config.batch_size, num_workers=Config.num_workers)
train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_ds, **loader_kwargs)
test_loader = DataLoader(test, **loader_kwargs)

In [None]:
# Training Loop
model = VAE(Config.latent_dim).to(Config.device)
optimizer = torch.optim.Adam(model.parameters(), lr=Config.lr)

for epoch in range(1, Config.epochs + 1):
    model.train()
    # Running sums of the batch-summed loss terms for this epoch.
    total_loss = total_recon = total_kld = 0.0
    for x, _ in tqdm(train_loader, desc=f'Epoch {epoch}'):
        x = x.to(Config.device)
        x_logits, mu, logvar = model(x)
        loss, recon, kld = elbo_bce_loss(x_logits, x, mu, logvar, Config.beta)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        total_recon += recon.item()
        total_kld += kld.item()

    # Report once per epoch. This line used to sit outside the epoch loop, so
    # only the final epoch's metrics were ever printed. The sums are divided by
    # the number of training samples to give per-sample averages.
    print(f'Epoch {epoch}: ELBO={total_loss/len(train_ds):.4f} Recon={total_recon/len(train_ds):.4f} KLD={total_kld/len(train_ds):.4f}')

In [None]:
# Visualization
model.eval()
with torch.no_grad():
    # Collect the latent mean and the label of every validation sample.
    all_z = []
    all_labels = []
    for x, labels in val_loader:
        x = x.to(Config.device)
        _, mu, logvar = model(x)
        all_z.append(mu.cpu().numpy())
        all_labels.append(labels.cpu().numpy())

    all_z = np.concatenate(all_z, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)

    # Latent space: plotted directly when it is 2-D, projected to 2-D with PCA
    # when it has more than two dimensions.
    z = all_z
    if Config.latent_dim > 2:
        z = PCA(2).fit_transform(z)

    # Colorbar tick labels must match the dataset that was loaded: clothing
    # names for FashionMNIST, digit labels for the MNIST fallback.
    if Config.dataset == 'fashion':
        class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
                       'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
    else:
        class_names = [str(digit) for digit in range(10)]

    plt.figure(figsize=(10, 10))
    scatter = plt.scatter(z[:,0], z[:,1], s=5, c=all_labels, cmap='tab10', alpha=0.6)
    plt.title('Latent Space')
    cbar = plt.colorbar(scatter, ticks=range(len(class_names)))
    cbar.set_ticklabels(class_names)
    plt.show()

    # Save reconstructions of the first validation batch: the grid holds all
    # originals first, followed by all of their reconstructions.
    x, labels = next(iter(val_loader))
    x = x.to(Config.device)
    x_logits, mu, logvar = model(x)
    recon = torch.sigmoid(x_logits).cpu()
    grid = torch.cat([x.cpu(), recon])
    vutils.save_image(grid, os.path.join(Config.out_dir, 'vae_recon.png'), nrow=8)


    # Random samples: decode 64 latent vectors drawn from a standard normal.
    z_rand = torch.randn(64, Config.latent_dim).to(Config.device)
    x_gen = torch.sigmoid(model.decoder(z_rand)).cpu()
    vutils.save_image(x_gen, os.path.join(Config.out_dir, 'vae_samples.png'), nrow=8)