In [1]:
# Load some mnist data
from torchvision import transforms, datasets
from torch.utils.data import DataLoader

# Hyperparameters used by the data loaders and the training loop
batch_size = 256
learning_rate = 1e-3
epochs = 100

# -----------------------
# 2. Data Loading
# -----------------------
# Convert every image to a tensor when it is read from the dataset
transform = transforms.ToTensor()

# Build the train and test splits from the same settings; only `train` differs
train_dataset, test_dataset = (
    datasets.MNIST(root="./data", train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Only the training batches are shuffled; the test order stays fixed
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [2]:
import torch
class Encoder(torch.nn.Module):
    """Maps a flattened input to the parameters of a diagonal Gaussian.

    One hidden layer with ReLU feeds two separate linear heads: one for the
    mean and one for the log-variance of the latent distribution.
    """

    def __init__(self, input_size=784, hidden_size=400, latent_size=20):
        """
        Args:
            input_size: number of features of each input vector
            hidden_size: width of the hidden layer
            latent_size: dimensionality of the latent Gaussian
        """
        super().__init__()

        # Shared hidden layer
        self.fc1 = torch.nn.Linear(in_features=input_size, out_features=hidden_size)

        # Head producing the mean (mu) of the latent Gaussian
        self.fc_mu = torch.nn.Linear(in_features=hidden_size, out_features=latent_size)

        # Head producing the log-variance; predicting log(sigma^2) rather than
        # sigma keeps the output unconstrained and is numerically more stable
        self.fc_logvar = torch.nn.Linear(in_features=hidden_size, out_features=latent_size)

        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """
        Encode x into latent Gaussian parameters

        Returns:
            (mu, logvar): both produced from the same hidden activation
        """
        # ReLU is applied to the hidden layer only (non-linearity);
        # mu and logvar are left unbounded
        features = self.fc1(x)
        features = self.relu(features)

        return self.fc_mu(features), self.fc_logvar(features)

class Decoder(torch.nn.Module):
    """Maps a latent vector back to input space through one hidden layer."""

    def __init__(self, latent_size=20, hidden_size=400, output_size=784):
        """
        Args:
            latent_size: dimensionality of the latent vector z
            hidden_size: width of the hidden layer
            output_size: number of features of the reconstructed output
        """
        super().__init__()

        self.fc1 = torch.nn.Linear(in_features=latent_size, out_features=hidden_size)
        self.fc2 = torch.nn.Linear(in_features=hidden_size, out_features=output_size)
        self.relu = torch.nn.ReLU()
        # Sigmoid squashes every output value into the (0, 1) range
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, z):
        """
        Decode latent variable z into a reconstruction

        Returns:
            Tensor with one value per output feature, each between 0 and 1
        """
        activated = self.relu(self.fc1(z))
        logits = self.fc2(activated)

        return self.sigmoid(logits)

class VAE(torch.nn.Module):
    """Variational autoencoder built from an Encoder and a Decoder."""

    def __init__(self, input_size=784, hidden_size=400, latent_size=20):
        """
        Args:
            input_size: number of features of each input (also the decoder output size)
            hidden_size: hidden width shared by encoder and decoder
            latent_size: dimensionality of the latent space
        """
        super().__init__()

        self.encoder = Encoder(input_size=input_size, hidden_size=hidden_size, latent_size=latent_size)
        self.decoder = Decoder(latent_size=latent_size, hidden_size=hidden_size, output_size=input_size)

    def reparameterize(self,mu,logvar):
        """
        Draw z = mu + sigma * eps with eps ~ N(0, I), so that the sample stays
        differentiable with respect to mu and logvar
        """
        assert mu.size() == logvar.size()
        # logvar = log(sigma^2)  =>  sigma = exp(logvar / 2)
        sigma = (0.5 * logvar).exp()
        # Standard normal noise with the same shape as sigma
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self,x):
        """
        Encode x, sample a latent vector, and decode it

        Returns:
            (x_reconstructed, mu, logvar)
        """
        mu, logvar = self.encoder(x)
        x_reconstructed = self.decoder(self.reparameterize(mu, logvar))

        return x_reconstructed, mu, logvar

    def loss(self, x, x_reconstructed, mu, logvar):
        """
        VAE objective: summed reconstruction error plus KL divergence

        Both terms are summed over every element passed in (no averaging).
        """
        # Reconstruction term: binary cross-entropy of the reconstruction against x
        reconstruction_term = torch.nn.functional.binary_cross_entropy(
            input=x_reconstructed, target=x, reduction='sum'
        )

        # KL(q(z|x) || p(z)) for a diagonal Gaussian posterior and a standard normal prior
        kl_term = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum()

        return reconstruction_term + kl_term

In [None]:
# Training
# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VAE().to(device)

model.train()  # put the module in training mode before optimizing
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

for epoch in range(epochs):
    train_loss = 0
    for images, _ in train_loader:  # labels are not used
        # Move the batch to the device and flatten it to (batch_size, 784)
        images = images.to(device).view(-1, 784)

        optimizer.zero_grad()
        recon, mu, logvar = model(images)
        loss = model.loss(images, recon, mu, logvar)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    # The batch losses are accumulated as sums, so divide by the number of
    # training samples to report a per-sample average
    avg_loss = train_loss / len(train_loader.dataset)
    print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.2f}")

In [None]:
import matplotlib.pyplot as plt

# -----------------------
# 6. Testing
# -----------------------
# Average per-sample VAE loss over the test set, computed without gradients.
# Note: the forward pass still samples z, so the value varies slightly between runs.
model.eval()
test_loss = 0
with torch.no_grad():
    for data, _ in test_loader:  # labels are not used
        data = data.view(-1, 784).to(device)
        recon, mu, logvar = model(data)
        # VAE.loss takes (x, x_reconstructed, mu, logvar): the original batch
        # comes first. Passing them the other way round would use the raw
        # images as the BCE prediction and the reconstruction as the target,
        # giving a loss that is not comparable to the training loss.
        loss = model.loss(data, recon, mu, logvar)
        test_loss += loss.item()

# Batch losses are sums, so normalize by the number of test samples
test_loss /= len(test_loader.dataset)
print(f"Test Loss: {test_loss:.4f}")

# -----------------------
# 7. Visualization
# -----------------------
# Reconstruct the first batch of test images
data_iter = iter(test_loader)
images, _ = next(data_iter)
images = images.to(device).view(-1, 784)

with torch.no_grad():
    recon, _, _ = model(images)

# Back to 28x28 arrays on the CPU so matplotlib can draw them
images = images.view(-1, 28, 28).cpu().numpy()
recon = recon.view(-1, 28, 28).cpu().numpy()

# Top row: originals, bottom row: reconstructions (first 8 of the batch)
fig, axes = plt.subplots(2, 8, figsize=(12, 3))
for row, batch in enumerate((images, recon)):
    for col in range(8):
        axes[row, col].imshow(batch[col], cmap='gray')
        axes[row, col].axis('off')
axes[0, 0].set_title("Original")
axes[1, 0].set_title("Reconstructed")
plt.show()

# Sample from the latent space
# Read the latent width from the decoder's first layer instead of hard-coding
# it, so this cell keeps working when the VAE is built with another latent_size.
latent_size = model.decoder.fc1.in_features
with torch.no_grad():
    # Draw 16 latent vectors from the standard normal prior and decode them
    z = torch.randn(16, latent_size).to(device)
    sample = model.decoder(z)
    sample = sample.view(-1, 28, 28).cpu().numpy()

# 2 x 8 grid; axes.flat walks the grid row by row, matching the sample order
fig, axes = plt.subplots(2, 8, figsize=(12, 3))
for ax, digit in zip(axes.flat, sample):
    ax.imshow(digit, cmap='gray')
    ax.axis('off')
plt.suptitle("Randomly Sampled Digits")
plt.show()

In [6]:
# Persist only the learned parameters (state dict), not the whole module object
torch.save(obj=model.state_dict(), f="vae.pth")