In [10]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms as transforms

In [11]:
# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Transform: Normalize images to [-1,1]
transform = transforms.Compose([
    transforms.ToTensor(),
    #transforms.Normalize((0.5,), (0.5,)),
    #transforms.ConvertImageDtype(torch.float)
])

# Load Fashion-MNIST dataset
train_dataset = torchvision.datasets.FashionMNIST(root="./data", train=True, transform=transform)
train_loader  = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, drop_last=True)

In [12]:
# Define VAE Model
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),  # -> [B, 32, 14, 14]
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # -> [B, 32, 7, 7]
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # -> [B, 64, 4, 4]
            nn.ReLU(),
            nn.MaxPool2d(2, 2)  # -> [B, 64, 2, 2]
        )

        self.fc1  = nn.Linear(64 * 2 * 2, 400)  # match encoder output
        self.fc21 = nn.Linear(400, 20)  # mu
        self.fc22 = nn.Linear(400, 20)  # logvar

        # Latent to decoder
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 64 * 7 * 7)

        # Decoder (transposed convs)
        self.deconv1 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.deconv2 = nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, padding=1, output_padding=1)

    def encode(self, x):
        x = self.encoder(x)  # [B, 64, 2, 2]
        x = x.view(x.size(0), -1)  # flatten to [B, 256]
        h = F.relu(self.fc1(x))   # [B, 400]
        return self.fc21(h), self.fc22(h)  # mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def decode(self, z):
        x = F.relu(self.fc3(z))             # [B, 400]
        x = F.relu(self.fc4(x))             # [B, 3136]
        x = x.view(-1, 64, 7, 7)            # reshape to [B, 64, 7, 7]
        x = F.relu(self.deconv1(x))         # -> [B, 32, 14, 14]
        x = torch.sigmoid(self.deconv2(x))  # -> [B, 1, 28, 28]
        return x

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        reconstructed = self.decode(z)
        return reconstructed, mu, logvar


'''
# Loss function: Reconstruction loss + KL Divergence loss
def loss_function(recon_x, x, mu, logvar):
    BCE = nn.F.binary_cross_entropy(recon_x.view(x.size(0), -1), x.view(x.size(0), -1), reduction='sum')
    KL = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KL
'''

def loss_function(recon_x, x, mu, logvar):
    recon_x_flat = recon_x.view(x.size(0), -1)
    x_flat = x.view(x.size(0), -1)

    MSE = F.mse_loss(recon_x_flat, x_flat, reduction='mean')
    KL = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())/x.size(0)
    return (MSE + KL) #/ x.size(0)

In [None]:
# Initialize model, loss, optimizer
model = VAE().to(device)   # default to model.train() if model.eval() is not called before training
optimizer = optim.Adam(model.parameters(), lr=0.0001)
# Training loop
num_epochs = 30
for epoch in range(num_epochs):
    total_loss = 0
    model.train()
    for images, _ in train_loader:
        images = images.to(device)
        
        recon_images, mu, logvar = model(images)
        loss = loss_function(recon_images, images, mu, logvar)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss:.4f}")

Epoch [1/30], Loss: 1.5245
Epoch [2/30], Loss: 1.2912
Epoch [3/30], Loss: 1.2820
Epoch [4/30], Loss: 1.2791
Epoch [5/30], Loss: 1.2782
Epoch [6/30], Loss: 1.2775
Epoch [7/30], Loss: 1.2769
Epoch [8/30], Loss: 1.2765
Epoch [9/30], Loss: 1.2764
Epoch [10/30], Loss: 1.2762
Epoch [11/30], Loss: 1.2759
Epoch [12/30], Loss: 1.2756


In [None]:
# Function to visualize reconstructed images
def visualize_reconstruction(model, data_loader):
    model.eval()
    with torch.no_grad():
        images, _ = next(iter(data_loader))
        images = images.to(device)
        recon_images, _, _ = model(images)

        images = images.cpu().numpy()
        recon_images = recon_images.cpu().numpy()

        fig, axes = plt.subplots(2, 10, figsize=(10, 2))
        for i in range(10):
            axes[0, i].imshow(images[i][0] * 0.5 + 0.5, cmap='gray')  # Original
            axes[0, i].axis('off')
            axes[1, i].imshow(recon_images[i][0] * 0.5 + 0.5, cmap='gray')  # Reconstructed
            axes[1, i].axis('off')

        axes[0, 0].set_title("Original")
        axes[1, 0].set_title("Reconstructed")
        plt.show()

# Run visualization
visualize_reconstruction(model, train_loader)