<a href="https://colab.research.google.com/github/apester/IME/blob/main/Lab11_VAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms, utils
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# --- Hyperparameters --------------------------------------------------------
input_dim = 28 * 28   # each MNIST image is 28x28 pixels, flattened into one vector
latent_dim = 20       # size of the latent code z
batch_size = 128      # samples per mini-batch
epochs = 10           # number of full passes over the training set
learning_rate = 1e-3  # Adam step size

# Define the Variational Autoencoder (VAE) model
class VAE(nn.Module):
    """Fully connected variational autoencoder.

    The encoder maps a flattened input of size ``input_dim`` through one hidden
    ReLU layer to the mean and log-variance of a diagonal Gaussian over a
    ``latent_dim``-dimensional latent code. The decoder maps a latent code back
    through one hidden ReLU layer to ``input_dim`` values squashed into [0, 1]
    by a sigmoid.

    Args:
        input_dim: Number of features in one flattened input sample.
        latent_dim: Dimensionality of the latent space.
        hidden_dim: Width of the hidden layer in both the encoder and the
            decoder (default 400, the previously hard-coded width).
    """

    def __init__(self, input_dim, latent_dim, hidden_dim=400):
        super().__init__()
        # Encoder: a shared hidden layer followed by two linear heads.
        # Attribute names are kept stable so saved state_dicts still load.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)  # mean head
        self.fc22 = nn.Linear(hidden_dim, latent_dim)  # log-variance head

        # Decoder: latent code -> hidden layer -> reconstruction.
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Return ``(mu, logvar)`` of q(z|x) for the input ``x``."""
        h1 = F.relu(self.fc1(x))
        mu = self.fc21(h1)
        logvar = self.fc22(h1)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Draw z = mu + eps * sigma with eps ~ N(0, I).

        Writing the sample this way keeps it differentiable with respect to
        ``mu`` and ``logvar``. Note that it samples in both train and eval mode.
        """
        std = torch.exp(0.5 * logvar)  # sigma = exp(log(sigma^2) / 2)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes ``z`` to reconstructions with values in [0, 1]."""
        h3 = F.relu(self.fc3(z))
        # The sigmoid keeps outputs in [0, 1], as binary_cross_entropy requires.
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        """Encode, sample a latent code and decode.

        Returns:
            Tuple ``(recon_x, mu, logvar)``.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon_x = self.decode(z)
        return recon_x, mu, logvar

# Loss function: Reconstruction loss + KL divergence
def loss_function(recon_x, x, mu, logvar, beta=1.0):
    """Return the VAE loss (negative ELBO) summed over the whole batch.

    Args:
        recon_x: Decoder output with values in [0, 1], same shape as ``x``.
        x: Target values in [0, 1].
        mu: Mean of the approximate posterior q(z|x).
        logvar: Log-variance of q(z|x), same shape as ``mu``.
        beta: Weight on the KL term. The default 1.0 gives the standard VAE
            objective; other values give a beta-VAE.

    Returns:
        Scalar tensor ``BCE + beta * KLD``.
    """
    # Reconstruction term: binary cross-entropy summed (not averaged) over
    # batch and features, so it is on the same scale as the summed KL below.
    bce = F.binary_cross_entropy(recon_x, x, reduction='sum')
    # Closed-form KL(N(mu, sigma^2) || N(0, I)), summed over latent dims and batch.
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return bce + beta * kld

# MNIST train/test splits; ToTensor() converts each image to a float tensor.
transform = transforms.ToTensor()
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)

# Shuffle only the training data; the test order stays fixed.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# Build the model on the selected device and optimise it with Adam.
model = VAE(input_dim=input_dim, latent_dim=latent_dim)
model = model.to(device)
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

# Training function
def train(epoch, log_interval=100):
    """Run one optimisation pass over ``train_loader``.

    Uses the module-level ``model``, ``optimizer``, ``train_loader``,
    ``device``, ``input_dim`` and ``loss_function``.

    Args:
        epoch: Epoch number, used only in log messages.
        log_interval: Print the current batch loss every ``log_interval``
            batches (must be a positive integer).

    Returns:
        The summed training loss divided by the number of training samples.
    """
    model.train()
    dataset_size = len(train_loader.dataset)
    samples_seen = 0  # counts actual samples, so a short final batch is reported correctly
    train_loss = 0.0
    for batch_idx, (data, _) in enumerate(train_loader):
        # Flatten images into (batch, input_dim) vectors for the fully connected encoder.
        data = data.view(-1, input_dim).to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

        if batch_idx % log_interval == 0:
            # loss is summed over the batch; divide to report a per-sample value.
            print(f"Epoch {epoch} [{samples_seen}/{dataset_size}]  Loss: {loss.item()/len(data):.4f}")
        samples_seen += len(data)

    avg_loss = train_loss / dataset_size
    print(f"====> Epoch {epoch} Average loss: {avg_loss:.4f}")
    return avg_loss

# Testing/Validation function
def test(epoch):
    """Evaluate the model on ``test_loader`` and return the mean per-sample loss.

    Uses the module-level ``model``, ``test_loader``, ``device``, ``input_dim``
    and ``loss_function``. ``model.reparameterize`` samples even in eval mode,
    so the reported loss is a stochastic estimate of the negative ELBO.

    Args:
        epoch: Not used; kept so the call mirrors ``train(epoch)``.

    Returns:
        The summed test loss divided by the number of test samples.
    """
    model.eval()
    test_loss = 0.0
    with torch.no_grad():  # no gradients are needed for evaluation
        for data, _ in test_loader:
            data = data.view(-1, input_dim).to(device)
            recon, mu, logvar = model(data)
            test_loss += loss_function(recon, data, mu, logvar).item()
    test_loss /= len(test_loader.dataset)
    print(f"====> Test set loss: {test_loss:.4f}")
    return test_loss

# Training loop: each epoch (numbered from 1) runs one optimisation pass over
# the training set, then an evaluation pass over the test set.
for epoch in range(1, epochs + 1):
    train(epoch)
    test(epoch)

# Visualize reconstructions from the test set
def visualize_reconstructions(num_images=8, filename='reconstructions.png'):
    """Save an image grid comparing test images with their reconstructions.

    Takes the first batch of ``test_loader`` and runs it through ``model``.
    Writes a grid with the first ``num_images`` originals in the top row and
    their reconstructions in the row below. Reconstructions are decoded from a
    sampled latent code, because ``reparameterize`` samples even in eval mode.

    Args:
        num_images: Images per row; capped at the size of the first batch.
        filename: Output path for the image grid.
    """
    model.eval()
    with torch.no_grad():
        data, _ = next(iter(test_loader))
        data = data.view(-1, input_dim).to(device)
        recon, _, _ = model(data)
        # Cap at the batch size so that originals and reconstructions stay on
        # separate rows even when the batch is smaller than num_images.
        n = min(num_images, data.size(0))
        # Reshape back to (N, 1, 28, 28) so save_image can tile the images.
        comparison = torch.cat([data.view(-1, 1, 28, 28)[:n],
                                recon.view(-1, 1, 28, 28)[:n]])
        utils.save_image(comparison.cpu(), filename, nrow=n)
        print(f"Reconstructed images saved as {filename}")


visualize_reconstructions()

# Generate new samples from the latent space
def generate_samples(num_samples=16, filename='generated_samples.png', nrow=4):
    """Decode random latent codes drawn from N(0, I) and save them as an image grid.

    Args:
        num_samples: Number of images to generate.
        filename: Output path for the image grid.
        nrow: Number of images per row in the grid.

    Returns:
        The generated images as a CPU tensor of shape (num_samples, 1, 28, 28).
    """
    model.eval()
    with torch.no_grad():
        # N(0, I) is the prior that the KL term in loss_function pulls q(z|x)
        # towards. Sampling directly on the device avoids a host-to-device copy.
        z = torch.randn(num_samples, latent_dim, device=device)
        samples = model.decode(z).cpu().view(-1, 1, 28, 28)
        utils.save_image(samples, filename, nrow=nrow)
        print(f"Generated sample images saved as {filename}")
    return samples


generate_samples()

# Optional helper: display a saved image file (e.g. the grids written above) with matplotlib
def display_image(filename, figsize=(8, 8)):
    """Read an image file with matplotlib and show it without axes.

    Args:
        filename: Path to the image file, read with ``plt.imread``.
        figsize: Figure size in inches (width, height).
    """
    img = plt.imread(filename)
    plt.figure(figsize=figsize)
    # cmap only affects single-channel (2-D) images such as grayscale PNGs, which
    # would otherwise get matplotlib's default colour map. It is ignored for RGB(A).
    plt.imshow(img, cmap='gray')
    plt.axis('off')
    plt.show()

# Uncomment the lines below to display the saved images inline (e.g. when running in a Jupyter notebook)
# display_image('reconstructions.png')
# display_image('generated_samples.png')