# Variational Autoencoders (VAEs) from Scratch
Two points are worth keeping in mind before we start implementing the VAE:
1. Predict $\log(\sigma^2)$ instead of the variance. The encoder outputs $\mu$ for the mean and $\log(\sigma^2)$ for the variance, which makes training more stable. A variance must be positive, whereas the output of a linear layer is unconstrained. We could enforce positivity with a ReLU, but its gradient is zero for negative inputs and it can output a variance of exactly zero. The log transform avoids this: small variances $\sigma^2 \in (0, 1]$ are spread over $(-\infty, 0]$, and in general $\log : (0, +\infty) \to (-\infty, +\infty)$, so any real-valued network output corresponds to a valid variance.
2. Reparameterization trick: we need to sample $z$ from $q(z|x)$, but sampling directly introduces a stochastic node, and backpropagation can only pass gradients through deterministic operations. We therefore need to keep the randomness of sampling while preserving the differentiability of the operations involved. We sample $\epsilon \sim N(0, 1)$ and then scale and shift it: $z = \mu + \sigma \epsilon$. Gradients now flow through the deterministic nodes $\mu$ and $\sigma$, while the randomness is isolated in $\epsilon$.
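
The two points above can be checked on a toy example. The sketch below is only an illustration, not part of the model: the values of `z_mean` and `z_log_var` are made up to stand in for encoder outputs. It shows that the standard deviation recovered from a log-variance is always positive, and that gradients reach both parameters even though `z` is random.

```python
import torch

torch.manual_seed(0)

# Hypothetical encoder outputs for a single 2-D latent code (made-up values).
z_mean = torch.tensor([0.5, -1.0], requires_grad=True)
z_log_var = torch.tensor([0.0, -2.0], requires_grad=True)

# exp(0.5 * log_var) is positive whatever the sign of log_var.
std = torch.exp(0.5 * z_log_var)

# The randomness lives in eps, which does not depend on any parameter.
eps = torch.randn_like(std)
z = z_mean + eps * std

# Backpropagate through the sample.
z.sum().backward()

# dz/d(mean) = 1 and dz/d(log_var) = 0.5 * eps * std.
print(z_mean.grad)
print(z_log_var.grad)
```

Had `z` been drawn directly from a distribution object with a non-differentiable sampler, neither tensor would have received a gradient.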

In [None]:
# import modules
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

# Define hyperparameters
batch_size = 100
latent_dim = 2
epochs = 100

# Load MNIST dataset
transform = transforms.ToTensor()
train_dataset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=True, transform=transform)
test_dataset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=False, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Define VAE model
class VAE(nn.Module):
    def __init__(self, input_dim, latent_dim):
        super(VAE, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, latent_dim * 2)
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, input_dim),
            nn.Sigmoid()
        )

    def encode(self, x):
        z_mean, z_log_var = self.encoder(x).chunk(2, dim=1)
        return z_mean, z_log_var

    def reparameterize(self, z_mean, z_log_var):
        std = torch.exp(0.5 * z_log_var)
        eps = torch.randn_like(std)
        return z_mean + eps * std

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        z_mean, z_log_var = self.encode(x)
        z = self.reparameterize(z_mean, z_log_var)
        x_recon = self.decode(z)
        return x_recon, z_mean, z_log_var


In [None]:
# Initialize VAE and optimizer
vae = VAE(input_dim=784, latent_dim=latent_dim)
optimizer = optim.Adam(vae.parameters(), lr=0.001)

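# Train VAE
for epoch in range(epochs):
    total_loss = 0.0
    for x, _ in train_loader:
        x = x.view(-1, 784)  # flatten 28x28 images into 784-dim vectors
        x_recon, z_mean, z_log_var = vae(x)
        # Sum both terms so they are on the same scale; a mean-reduced MSE
        # would be dwarfed by the summed KL term.
        reconstruction_loss = nn.functional.mse_loss(x_recon, x, reduction='sum')
        kl_divergence_loss = 0.5 * torch.sum(torch.exp(z_log_var) + z_mean ** 2 - 1 - z_log_var)
        loss = (reconstruction_loss + kl_divergence_loss) / x.size(0)  # average per example
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Report the mean per-example loss over the epoch, not just the last batch
    print(f'Epoch {epoch+1}, Loss: {total_loss / len(train_loader):.4f}')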
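
The KL term in the training loop is the closed-form KL divergence between the diagonal Gaussian $N(\mu, \sigma^2)$ and the standard normal prior $N(0, 1)$. As a sanity check, the sketch below compares that expression with the value computed by `torch.distributions`. The inputs are random stand-ins for encoder outputs, not values from the trained model.

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)

# Made-up encoder outputs: a batch of 4 examples with 2 latent dimensions.
z_mean = torch.randn(4, 2)
z_log_var = torch.randn(4, 2)

# Closed form used in the training loop, summed over all elements.
kl_closed_form = 0.5 * torch.sum(torch.exp(z_log_var) + z_mean ** 2 - 1 - z_log_var)

# Reference value: KL(N(mean, std^2) || N(0, 1)) from torch.distributions.
posterior = Normal(z_mean, torch.exp(0.5 * z_log_var))
prior = Normal(torch.zeros_like(z_mean), torch.ones_like(z_mean))
kl_reference = kl_divergence(posterior, prior).sum()

print(kl_closed_form.item(), kl_reference.item())
```

The two numbers should agree up to floating-point error.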


In [None]:
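# Test VAE
vae.eval()
with torch.no_grad():
    for x, _ in test_loader:
        x = x.view(-1, 784)
        x_recon, _, _ = vae(x)
        # Summed squared reconstruction error, averaged per image in the batch
        print(torch.sum((x_recon - x) ** 2).item() / x.size(0))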