In [1]:
import torch
from torch import nn

INPUT_DIM = 784
H_DIM = 200
Z_DIM = 1

class VariationalAutoEncoder(nn.Module):
    def __init__(self, input_dim, h_dim=H_DIM, z_dim=Z_DIM):
        super().__init__()
        
        # encoder
        self.img_2hid = nn.Linear(input_dim, h_dim)
        self.hid_2mu = nn.Linear(h_dim, z_dim)
        self.hid_2sigma = nn.Linear(h_dim, z_dim)

        # decoder
        self.z_2hid = nn.Linear(z_dim, h_dim)
        self.hid_2img = nn.Linear(h_dim, input_dim)

        self.relu = nn.ReLU()

    def encode(self, x):
        h = self.relu(self.img_2hid(x))
        mu, sigma = self.hid_2mu(h), self.hid_2sigma(h)
        return mu, sigma

    def decode(self, z):
        h = self.relu(self.z_2hid(z))
        return torch.sigmoid(self.hid_2img(h))

    def forward(self, x):
        mu, sigma = self.encode(x)
        # reparametrization trick
        epsilon = torch.randn_like(sigma)
        z_new = mu + sigma*epsilon
        x_reconstructed = self.decode(z_new)
        return x_reconstructed, mu, sigma


In [4]:
import torch
import torchvision.datasets as datasets  # Standard datasets
from tqdm import tqdm
from torch import nn, optim
from torchvision import transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt

# Set random seed for reproducibility
seed = 42
torch.manual_seed(seed)

# Configuration
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_EPOCHS = 30
BATCH_SIZE = 32
LR_RATE = 3e-4  

# Dataset Loading
dataset = datasets.MNIST(root="dataset/", train=True, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)
model = VariationalAutoEncoder(INPUT_DIM, H_DIM, Z_DIM).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LR_RATE)
loss_fn = nn.BCELoss(reduction="sum")

In [5]:
# Training
for epoch in range(NUM_EPOCHS):
    
    loop = tqdm(enumerate(train_loader)) # for printing purpose

    for i, (x, _) in loop:

        # forward pass
        x = x.to(DEVICE).view(x.shape[0], INPUT_DIM)
        x_reconstructed, mu, sigma = model(x)

        # compute loss
        reconstruction_loss = loss_fn(x_reconstructed, x)
        kl_div = -torch.sum(1 + torch.log(sigma.pow(2)) - mu.pow(2) - sigma.pow(2))

        # backprop
        loss = reconstruction_loss + kl_div
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loop.set_postfix(loss=loss.item())

1875it [00:08, 224.47it/s, loss=6.3e+3] 
1875it [00:08, 227.57it/s, loss=5.67e+3]
1875it [00:08, 224.24it/s, loss=5.72e+3]
1875it [00:08, 228.02it/s, loss=5.76e+3]
1875it [00:08, 218.82it/s, loss=6.1e+3] 
1875it [00:08, 222.83it/s, loss=5.59e+3]
1875it [00:08, 222.01it/s, loss=5.45e+3]
1875it [00:08, 219.43it/s, loss=5.64e+3]
1875it [00:08, 220.44it/s, loss=5.62e+3]
1875it [00:08, 224.00it/s, loss=5.92e+3]
1875it [00:08, 224.18it/s, loss=5.69e+3]
1875it [00:08, 226.19it/s, loss=5.48e+3]
1875it [00:08, 228.29it/s, loss=6.28e+3]
1875it [00:08, 225.52it/s, loss=6.32e+3]
1875it [00:08, 224.90it/s, loss=6.05e+3]
1875it [00:08, 226.28it/s, loss=5.93e+3]
1875it [00:08, 223.69it/s, loss=5.95e+3]
1875it [00:08, 224.02it/s, loss=5.93e+3]
1875it [00:08, 223.45it/s, loss=6.02e+3]
1875it [00:08, 227.28it/s, loss=5.99e+3]
1875it [00:08, 227.00it/s, loss=5.56e+3]
1875it [00:08, 215.86it/s, loss=5.52e+3]
1875it [00:08, 223.13it/s, loss=6.4e+3] 
1875it [00:08, 221.33it/s, loss=5.43e+3]
1875it [00:08, 2

In [6]:
# Save the model's parameters
torch.save(model.state_dict(), f'output/MNIST/model_{Z_DIM}.pth')