In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from torchsummary import summary

In [None]:
# NOTE(review): Normalize(0.5, 0.5) maps pixel values from [0, 1] to [-1, 1].
# The training cell below builds its own loader with ToTensor() only, which is
# what the sigmoid decoder + BCE loss need (BCE requires targets in [0, 1]).
# This normalized loader is not used by any cell in this notebook.
transform = transforms.Compose([
    transforms.ToTensor(),
    # MNIST has a single channel, hence one-element mean/std tuples.
    transforms.Normalize((0.5,), (0.5,)),
])

mnist_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

data_loader = torch.utils.data.DataLoader(dataset=mnist_data,
                                          batch_size=64,
                                          shuffle=True)


In [None]:
import torch
from torch import nn


class VariationalAutoEncoder(nn.Module):
    """Fully-connected VAE with one hidden layer in the encoder and the decoder.

    Encoder: input_dim -> h_dim (ReLU) -> (mu, sigma), each of size z_dim.
    Decoder: z_dim -> h_dim (ReLU) -> input_dim (sigmoid).

    Args:
        input_dim: Length of each flattened input vector (784 for 28x28 MNIST).
        h_dim: Width of the hidden layer.
        z_dim: Dimensionality of the latent space.
    """

    def __init__(self, input_dim, h_dim=200, z_dim=20):
        super().__init__()
        # encoder
        self.img_2hid = nn.Linear(input_dim, h_dim)
        self.hid_2mu = nn.Linear(h_dim, z_dim)
        # This head predicts log(sigma^2). The attribute name and shape are kept
        # unchanged so the module's parameter layout stays the same.
        self.hid_2sigma = nn.Linear(h_dim, z_dim)

        # decoder
        self.z_2hid = nn.Linear(z_dim, h_dim)
        self.hid_2img = nn.Linear(h_dim, input_dim)

        self.relu = nn.ReLU()

    def encode(self, x):
        """Map x of shape (batch, input_dim) to (mu, sigma), each (batch, z_dim).

        sigma is the standard deviation of q(z|x) and is strictly positive: the
        linear head predicts log-variance, mapped through exp(0.5 * log_var).
        A raw linear output could be 0 or negative, which makes
        log(sigma ** 2) in the KL term -inf and its gradient blow up.
        """
        h = self.relu(self.img_2hid(x))
        mu = self.hid_2mu(h)
        log_var = self.hid_2sigma(h)
        sigma = torch.exp(0.5 * log_var)
        return mu, sigma

    def decode(self, z):
        """Map z of shape (batch, z_dim) to (batch, input_dim) values in (0, 1)."""
        h = self.relu(self.z_2hid(z))
        return torch.sigmoid(self.hid_2img(h))

    def forward(self, x):
        """Encode x, sample z with the reparameterization trick and decode it.

        Returns:
            (x_reconstructed, mu, sigma); sigma is the positive standard deviation.
        """
        mu, sigma = self.encode(x)
        # z = mu + sigma * eps with eps ~ N(0, I) keeps sampling differentiable
        # with respect to mu and sigma.
        epsilon = torch.randn_like(sigma)
        z_new = mu + sigma * epsilon
        x_reconstructed = self.decode(z_new)
        return x_reconstructed, mu, sigma


if __name__ == "__main__":
    # Smoke test: push a random batch through an untrained VAE and check shapes.
    x = torch.randn(4, 28*28)
    vae = VariationalAutoEncoder(input_dim=784)
    x_reconstructed, mu, sigma = vae(x)
    print(x_reconstructed.shape)
    print(mu.shape)
    print(sigma.shape)
    # torchsummary.summary defaults to device="cuda". On a GPU machine it would
    # feed CUDA inputs to this CPU model and fail. It also prints its table
    # itself rather than returning it, so it is not wrapped in print().
    summary(vae, input_size=(784,), device="cpu")




In [None]:
def plot_reconstructed_images(model, data_loader, device, num_images=8):
    """Show originals (top row) and their VAE reconstructions (bottom row).

    Takes the first batch from ``data_loader``, runs it through ``model`` without
    gradients and plots ``num_images`` pairs as 28x28 grayscale images.

    Args:
        model: VAE whose ``forward(x)`` returns ``(x_reconstructed, mu, sigma)``.
        data_loader: Iterable of ``(images, labels)`` batches.
        device: Device the model lives on.
        num_images: Number of pairs to plot; capped at the batch size.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    # Restore the caller's train/eval mode afterwards. This is called between
    # training epochs, and leaving the model in eval mode would change the
    # behaviour of any dropout/batch-norm layers for the rest of training.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            batch = next(iter(data_loader), None)
            if batch is None:
                raise ValueError("data_loader yielded no batches")
            x, _ = batch
            # Flatten each image to one vector. The width comes from the batch
            # itself rather than from a global constant defined in another cell.
            x = x.to(device).view(x.shape[0], -1)
            x_reconstructed, _, _ = model(x)
    finally:
        model.train(was_training)

    num_images = min(num_images, x.shape[0])
    # squeeze=False keeps `axes` 2-D even when num_images == 1.
    fig, axes = plt.subplots(2, num_images, figsize=(num_images*2, 4), squeeze=False)
    for i in range(num_images):
        axes[0, i].imshow(x[i].cpu().reshape(28, 28), cmap='gray')
        axes[0, i].axis('off')
        axes[0, i].set_title("Original")

        axes[1, i].imshow(x_reconstructed[i].cpu().reshape(28, 28), cmap='gray')
        axes[1, i].axis('off')
        axes[1, i].set_title("Reconstructed")

    plt.tight_layout()
    plt.show()




In [None]:
def generate_random_images(model, device, num_images=10, z_dim=2):
    """Sample z ~ N(0, I), decode it and show the results in a single row.

    Args:
        model: VAE exposing ``decode(z)``; each decoded row is reshaped to 28x28.
        device: Device the model lives on.
        num_images: Number of samples to draw and plot.
        z_dim: Latent size; must equal the model's latent dimensionality.
    """
    # Restore the caller's train/eval mode afterwards, because this helper is
    # called between training epochs.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            z = torch.randn(num_images, z_dim).to(device)
            generated_images = model.decode(z).cpu()
    finally:
        model.train(was_training)

    # squeeze=False keeps `axes` 2-D even when num_images == 1
    # (otherwise plt.subplots returns a single Axes that cannot be indexed).
    fig, axes = plt.subplots(1, num_images, figsize=(num_images * 2, 2), squeeze=False)
    for i in range(num_images):
        ax = axes[0, i]
        ax.imshow(generated_images[i].reshape(28, 28), cmap='gray')
        ax.axis('off')
        ax.set_title(f'Sample {i+1}')
    plt.show()




In [None]:
import matplotlib.pyplot as plt
import numpy as np

def visualize_latent_space(model, data_loader, device, num_batches=100):
    """Scatter the encoder means of up to ``num_batches`` batches, coloured by label.

    Only the first two latent dimensions are plotted.

    Args:
        model: VAE exposing ``encode(x)`` that returns ``(mu, sigma)``.
        data_loader: Iterable of ``(images, labels)`` batches.
        device: Device the model lives on.
        num_batches: Maximum number of batches to encode.

    Raises:
        ValueError: if no batches were encoded, or the latent space has fewer
            than two dimensions.
    """
    was_training = model.training
    model.eval()
    all_mu, all_labels = [], []
    try:
        with torch.no_grad():
            for i, (x, labels) in enumerate(data_loader):
                if i >= num_batches:
                    break
                x = x.to(device).view(x.shape[0], -1)  # flatten each image to a single vector
                mu, _ = model.encode(x)
                all_mu.append(mu.cpu().numpy())
                all_labels.append(labels.numpy())
    finally:
        # Restore the caller's train/eval mode; this runs between training epochs.
        model.train(was_training)

    if not all_mu:
        raise ValueError("no batches were encoded; check data_loader and num_batches")
    all_mu = np.concatenate(all_mu)
    all_labels = np.concatenate(all_labels)
    if all_mu.shape[1] < 2:
        raise ValueError(f"latent space must have at least 2 dimensions, got {all_mu.shape[1]}")

    fig, ax = plt.subplots(figsize=(8, 6))
    # x axis: latent dimension 1, y axis: latent dimension 2
    scatter = ax.scatter(all_mu[:, 0], all_mu[:, 1], c=all_labels, cmap='tab10', alpha=0.7)
    fig.colorbar(scatter, ax=ax, ticks=range(10))
    ax.set_xlabel('Latent dimension 1')
    ax.set_ylabel('Latent dimension 2')
    ax.set_title('Latent Space Visualization')
    ax.grid()
    plt.show()




In [None]:
import torch
import torchvision.datasets as datasets  # Standard datasets
from tqdm import tqdm
from torch import nn, optim
from torchvision import transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader

# Configuration
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
INPUT_DIM = 784
H_DIM = 200
Z_DIM = 2
NUM_EPOCHS = 10
BATCH_SIZE = 32
LR_RATE = 3e-4  # Karpathy constant
SEED = 0
LOG_EPS = 1e-8  # keeps log(sigma^2) finite if sigma collapses to 0

# Seed before building the model and loader so weight init and shuffling are reproducible.
torch.manual_seed(SEED)

# Dataset Loading
dataset = datasets.MNIST(root="dataset/", train=True, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)
model = VariationalAutoEncoder(INPUT_DIM, H_DIM, Z_DIM).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LR_RATE)
loss_fn = nn.BCELoss(reduction="sum")

for epoch in range(NUM_EPOCHS):
    # The plotting helpers called at the end of each epoch switch the model to
    # eval mode; make sure every epoch trains in training mode.
    model.train()
    loop = tqdm(train_loader, desc=f'Epoch [{epoch+1}/{NUM_EPOCHS}]')
    for x, _ in loop:
        x = x.to(DEVICE).view(x.shape[0], INPUT_DIM)
        x_reconstructed, mu, sigma = model(x)
        reconstruction_loss = loss_fn(x_reconstructed, x)
        # Closed-form KL(N(mu, sigma^2) || N(0, I)), summed over batch and latent dims:
        #   -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        # The 0.5 factor is part of the formula; without it the KL term is doubled.
        log_var = torch.log(sigma.pow(2) + LOG_EPS)
        kl_div = -0.5 * torch.sum(1 + log_var - mu.pow(2) - sigma.pow(2))

        loss = reconstruction_loss + kl_div  # negative ELBO, summed over the batch
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Report the per-sample loss so the number is comparable across batch sizes.
        loop.set_postfix(loss=f'{loss.item()/len(x):.4f}')

    plot_reconstructed_images(model, train_loader, DEVICE)
    visualize_latent_space(model, train_loader, DEVICE)
    generate_random_images(model, DEVICE, num_images=10, z_dim=Z_DIM)
