In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import gc  # Add garbage collector

# Release GPU memory left over from earlier runs in this kernel.
# Run the Python garbage collector first: empty_cache() only hands back cached
# blocks that no live tensor is using, and tensors caught in reference cycles
# stay alive until the cycle collector has run.
if torch.cuda.is_available():
    gc.collect()
    torch.cuda.empty_cache()
    
def print_gpu_memory():
    """Print the allocated and reserved CUDA memory in gigabytes.

    Does nothing when CUDA is not available. Values are reported in units of
    1e9 bytes with two decimals.
    """
    if not torch.cuda.is_available():
        return

    allocated_gb = torch.cuda.memory_allocated() / 1e9
    reserved_gb = torch.cuda.memory_reserved() / 1e9
    print(f"GPU memory allocated: {allocated_gb:.2f} GB")
    # "cached" here is the amount reported by torch.cuda.memory_reserved().
    print(f"GPU memory cached: {reserved_gb:.2f} GB")


In [2]:
# MNIST training split, stored under ./data (downloaded on first use).
# ToTensor turns each image into a tensor; the loader then serves shuffled
# mini-batches of 64 samples.
transform = transforms.ToTensor()
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
)


In [3]:
class BetaVAE(nn.Module):
    """Convolutional variational autoencoder used for beta-VAE training.

    The encoder maps a single-channel image to the mean and log-variance of a
    diagonal Gaussian over ``latent_dim`` dimensions. The decoder maps a latent
    vector back to a 1x28x28 image whose values come out of a sigmoid.

    Args:
        latent_dim: Size of the latent vector (default 20).
    """

    def __init__(self, latent_dim=20):
        super().__init__()
        self.latent_dim = latent_dim

        # Encoder: three stride-2 convolutions. For a 28x28 input the spatial
        # size shrinks 28 -> 14 -> 7 -> 4, leaving a 128x4x4 feature map.
        encoder_layers = [
            nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Two linear heads on the flattened features: mean and log-variance.
        self.flatten = nn.Flatten()
        feature_size = 128 * 4 * 4
        self.fc_mu = nn.Linear(feature_size, latent_dim)
        self.fc_logvar = nn.Linear(feature_size, latent_dim)

        # Decoder: project the latent vector back to 128x4x4, then upsample
        # with three stride-2 transposed convolutions (4 -> 8 -> 16 -> 32).
        self.fc_decode = nn.Linear(latent_dim, feature_size)
        decoder_layers = [
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for batch ``x``."""
        features = self.flatten(self.encoder(x))
        return self.fc_mu(features), self.fc_logvar(features)

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        ``std`` is recovered from the log-variance as ``exp(0.5 * logvar)``.
        Noise is drawn on every call, regardless of train/eval mode.
        """
        std = (0.5 * logvar).exp()
        noise = torch.randn_like(std)
        return noise * std + mu

    def decode(self, z):
        """Map latent vectors ``z`` to images of shape ``(N, 1, 28, 28)``."""
        feature_map = self.fc_decode(z).view(-1, 128, 4, 4)
        upsampled = self.decoder(feature_map)
        # The transposed convolutions yield 32x32; keep the top-left 28x28.
        return upsampled[..., :28, :28]

    def forward(self, x):
        """Encode, sample and decode; returns ``(x_hat, mu, logvar)``."""
        mu, logvar = self.encode(x)
        x_hat = self.decode(self.reparameterize(mu, logvar))
        return x_hat, mu, logvar


In [4]:
def beta_vae_loss(x, x_hat, mu, logvar, beta=4):
    """Compute the beta-VAE objective, summed over the whole batch.

    Args:
        x: Target tensor passed to binary cross-entropy.
        x_hat: Reconstruction tensor passed to binary cross-entropy.
        mu: Posterior means.
        logvar: Posterior log-variances.
        beta: Weight applied to the KL term (default 4).

    Returns:
        Tuple ``(total, recon_loss, kl_div)`` where
        ``total = recon_loss + beta * kl_div``. Both terms are sums, not
        means, so callers normalise by the number of samples themselves.
    """
    reconstruction = F.binary_cross_entropy(x_hat, x, reduction='sum')
    # Closed-form KL(N(mu, sigma^2) || N(0, I)), summed over every element.
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kl = -0.5 * kl_terms.sum()
    objective = reconstruction + beta * kl
    return objective, reconstruction, kl


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed PyTorch's random number generators before anything stochastic happens,
# so weight initialisation, batch shuffling and the reparameterisation noise
# repeat across Restart & Run All (GPU kernels may still add small run-to-run
# differences).
SEED = 42
torch.manual_seed(SEED)

model = BetaVAE(latent_dim=20).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
num_epochs = 10
beta = 4.0

# Per-epoch histories; every entry is an average per training sample.
losses, recons, kls = [], [], []

# The loss terms are summed over each batch, so epoch totals are divided by
# the dataset size to obtain per-sample values.
num_samples = len(train_loader.dataset)

for epoch in range(1, num_epochs + 1):
    model.train()
    total_loss, total_recon, total_kl = 0, 0, 0

    for batch in train_loader:
        x, _ = batch  # labels are not used
        x = x.to(device)
        x_hat, mu, logvar = model(x)
        loss, recon, kl = beta_vae_loss(x, x_hat, mu, logvar, beta)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_recon += recon.item()
        total_kl += kl.item()

    losses.append(total_loss / num_samples)
    recons.append(total_recon / num_samples)
    kls.append(total_kl / num_samples)

    print(f"Epoch {epoch}, Loss: {losses[-1]:.2f}, Recon: {recons[-1]:.2f}, KL: {kls[-1]:.2f}")




Epoch 1, Loss: 177.78, Recon: 152.42, KL: 6.34
Epoch 2, Loss: 152.24, Recon: 116.11, KL: 9.03
Epoch 3, Loss: 149.43, Recon: 111.84, KL: 9.40
Epoch 4, Loss: 147.94, Recon: 109.79, KL: 9.54
Epoch 5, Loss: 146.94, Recon: 108.37, KL: 9.64
Epoch 6, Loss: 146.15, Recon: 107.35, KL: 9.70
Epoch 7, Loss: 145.72, Recon: 106.60, KL: 9.78
Epoch 8, Loss: 145.12, Recon: 105.89, KL: 9.81
Epoch 9, Loss: 144.77, Recon: 105.48, KL: 9.82
Epoch 10, Loss: 144.46, Recon: 105.05, KL: 9.85


: 

In [None]:
# Plot the per-sample training curves collected in losses / recons / kls.
# Epochs are numbered from 1 so the x-axis lines up with the "Epoch N" lines
# printed during training (plotting the bare lists would start at 0).
epochs = range(1, len(losses) + 1)

loss_fig, loss_ax = plt.subplots(figsize=(10, 5))
loss_ax.plot(epochs, losses, label='Total Loss')
loss_ax.plot(epochs, recons, label='Reconstruction Loss')
loss_ax.plot(epochs, kls, label='KL Divergence')
loss_ax.set_xlabel("Epoch")
loss_ax.set_ylabel("Loss")
loss_ax.legend()
loss_ax.set_title("Beta-VAE Training Loss")
loss_ax.grid(True)
plt.show()
     

In [None]:
def generate_images(model, num_images=16):
    """Decode latent vectors drawn from a standard normal into images.

    Args:
        model: Module exposing ``latent_dim`` and ``decode(z)``.
        num_images: Number of latent vectors to sample (default 16).

    Returns:
        The decoded batch, moved to the CPU.

    Note:
        The model is switched to eval mode and is left in that mode.
    """
    model.eval()

    # Sample on the device the model actually lives on rather than relying on
    # a notebook-global ``device``; fall back to CPU for parameter-free models.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        target_device = first_param.device
    else:
        target_device = torch.device("cpu")

    with torch.no_grad():
        # Create the noise directly on the target device (no CPU -> GPU copy).
        z = torch.randn(num_images, model.latent_dim, device=target_device)
        samples = model.decode(z)
    return samples.cpu()

samples = generate_images(model)

# Show the 16 generated digits on a 4x4 grid, one image per axes.
sample_fig, sample_axes = plt.subplots(4, 4, figsize=(4, 4))
for sample_ax, image in zip(sample_axes.flat, samples):
    # image[0] selects the first (only) channel as a 2-D array for imshow.
    sample_ax.imshow(image[0], cmap='gray')
    sample_ax.axis('off')
sample_fig.suptitle("Generated Digits from Beta-VAE")
sample_fig.tight_layout()
plt.show()
     