# Variational Autoencoder

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
class VAE(nn.Module):
    """Fully connected variational autoencoder.

    The encoder maps a flattened input to the mean and log-variance of a
    diagonal Gaussian over the latent code; the decoder maps a latent sample
    back to per-element values in (0, 1) via a sigmoid.

    Args:
        latent_dim: Size of the latent code ``z``.
        input_dim: Number of elements in one flattened input sample
            (784 by default, i.e. a 28x28 image).
        hidden_dim: Width of the single hidden layer used by both the
            encoder and the decoder.
    """

    def __init__(self, latent_dim=20, input_dim=784, hidden_dim=400):
        super().__init__()
        # Remembered so that forward() can flatten inputs to the right width.
        self.input_dim = input_dim
        # Encoder: one shared hidden layer followed by two parallel heads.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        # Decoder: linear layers from the latent code back to input size.
        self.fc_dec1 = nn.Linear(latent_dim, hidden_dim)
        self.fc_dec2 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Return ``(mu, logvar)`` for a flattened batch ``x``.

        ``x`` must already have ``input_dim`` as its last dimension; both
        outputs have ``latent_dim`` as their last dimension.
        """
        h = F.relu(self.fc1(x))
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, 1)``.

        Writing the sample as a deterministic function of ``mu`` and
        ``logvar`` plus independent noise keeps it differentiable with
        respect to both encoder outputs.
        """
        std = torch.exp(0.5 * logvar)      # log-variance -> standard deviation
        eps = torch.randn_like(std)        # noise with std's shape/dtype/device
        return mu + eps * std

    def decode(self, z):
        """Map latent codes ``z`` to reconstructions with values in (0, 1)."""
        h = F.relu(self.fc_dec1(z))
        # Sigmoid keeps every output a valid Bernoulli mean.
        x_recon = torch.sigmoid(self.fc_dec2(h))
        return x_recon

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode it.

        ``x`` may have any shape whose element count is a multiple of
        ``input_dim``; it is flattened to ``(-1, input_dim)`` first.

        Returns:
            Tuple ``(x_recon, mu, logvar)``: the flattened reconstruction and
            the posterior parameters produced by the encoder.
        """
        # reshape (not view) so non-contiguous inputs are accepted as well.
        flat = x.reshape(-1, self.input_dim)
        mu, logvar = self.encode(flat)
        z = self.reparameterize(mu, logvar)
        x_recon = self.decode(z)
        return x_recon, mu, logvar

In [None]:
# Put the model on the same device the batches are moved to; otherwise the
# first forward pass fails with a device mismatch whenever `device` is not
# the CPU. Done before building the optimizer so it tracks the moved params.
model = VAE(latent_dim=20).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for batch in data_loader:    # iterate minibatches
    # NOTE(review): assumes each batch is a bare image tensor of shape
    # [batch, 1, 28, 28] with values in [0, 1] (required by the BCE target);
    # a loader yielding (image, label) pairs would need unpacking - confirm.
    x = batch.to(device)
    optimizer.zero_grad()
    x_recon, mu, logvar = model(x)
    # Reconstruction loss (Bernoulli cross-entropy per pixel), summed over
    # the batch. The target is reshaped to match the flattened reconstruction;
    # reshape (not view) also accepts non-contiguous batches.
    recon_loss = F.binary_cross_entropy(x_recon, x.reshape(x_recon.shape), reduction='sum')
    # Closed-form KL divergence between q(z|x) = N(mu, sigma^2) and the
    # prior p(z) = N(0, 1), summed over latent dimensions and the batch.
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # Negative ELBO for the minibatch (sum, not mean, over samples).
    loss = recon_loss + kl_loss
    loss.backward()
    optimizer.step()
