### building vae using PyTorch from scratch
* using MNIST dataset

In [1]:
import torch
from torch import nn, optim
import torch.nn.functional as F

In [2]:
import torchvision.datasets as datasets
from tqdm import tqdm # progress bar
from torchvision import transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader

In [3]:
class VariationalAutoEncoder(nn.Module): # inherits from nn.Module
    def __init__(self, input_dim, h_dim=200, z_dim=20):
        super().__init__() # initializes all internal PyTorch machinery in nn.Module
        # define basic layers
        # encoder
        self.img_2hid = nn.Linear(input_dim, h_dim)
        self.hid_2mu = nn.Linear(h_dim, z_dim)
        self.hid_2sigma = nn.Linear(h_dim, z_dim)

        # decoder
        self.z_2hid = nn.Linear(z_dim, h_dim)
        self.h_2img = nn.Linear(h_dim, input_dim)

        self.relu = nn.ReLU()

    # encoder and decoder are helper functions to make the code modular and clean
    def encode(self,x):
        # q_phi(z|x)

        h = self.relu(self.img_2hid(x))
        mu, sigma = self.hid_2mu(h), self.hid_2sigma(h) # mu and sigma come from linear layer, no need for relu
        return mu, sigma
        
    def decode(self,z):
        # p_theta(x|z)

        h = self.relu(self.z_2hid(z))
        return torch.sigmoid(self.h_2img(h)) # to bring scale back to 0-1

    # combine everything in forward function
    # defines the full model logic
    # by default being called when calling the whole class
    def forward(self,x):
        mu, sigma = self.encode(x)
        epsilon = torch.randn_like(sigma) # sample from N(0,1) with same shape as sigma
        z_reparametrized = mu + sigma * epsilon
        x_reconstructed = self.decode(z_reparametrized)
        return x_reconstructed, mu, sigma # need all 3 to calculate loss function

In [4]:
# test the model 
if __name__ == "__main__":
    x = torch.randn(1,28*28)
    vae = VariationalAutoEncoder(input_dim=784)
    x_reconstructed, mu, sigma = vae(x)
    print(x_reconstructed.shape)
    print(mu.shape)
    print(sigma.shape)

torch.Size([1, 784])
torch.Size([1, 20])
torch.Size([1, 20])


In [5]:
# configuration
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
INPUT_DIM = 784
H_DIM = 200
Z_DIM = 20

NUM_EPOCHS = 10
BATCH_SIZE = 32
LR_RATE = 3e-4 # Karpathy constant


In [6]:
# download MNIST dataset to local drive
dataset = datasets.MNIST(root="dataset/", train=True, transform=transforms.ToTensor(), download=True)

In [7]:
train_loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)

In [8]:
model = VariationalAutoEncoder(INPUT_DIM, H_DIM, Z_DIM).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LR_RATE)
loss_fn = nn.BCELoss(reduction="sum")

In [9]:
# start training
for epoch in range(NUM_EPOCHS):
    loop = tqdm(enumerate(train_loader)) # to get progress bar
    for i, (x, _) in loop:
        # forward pass
        x = x.to(DEVICE).view(x.shape[0], INPUT_DIM) # reshape input dimension (can also use reshape)
        x_reconstructed, mu, sigma = model(x)

        # compute loss
        reconstruction_loss = loss_fn(x_reconstructed, x)
        kl_div = -torch.sum(1 + torch.log(sigma.pow(2)) - mu.pow(2) - sigma.pow(2))

        # back prop
        loss = reconstruction_loss + kl_div
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loop.set_postfix(loss=loss.item())

1875it [00:05, 345.71it/s, loss=5.24e+3]
1875it [00:05, 348.38it/s, loss=4.96e+3]
1875it [00:05, 344.50it/s, loss=4.84e+3]
1875it [00:05, 335.69it/s, loss=4.29e+3]
1875it [00:05, 351.17it/s, loss=4.42e+3]
1875it [00:06, 307.87it/s, loss=4.27e+3]
1875it [00:05, 335.41it/s, loss=3.9e+3] 
1875it [00:05, 343.22it/s, loss=4.14e+3]
1875it [00:05, 338.09it/s, loss=4.32e+3]
1875it [00:05, 338.45it/s, loss=3.94e+3]


In [53]:
# how to do inference
def inference(digit, num_examples=1):
    images = []
    idx = 0
    for x, y in dataset:
        if y == idx:
            images.append(x)
            idx += 1
        if idx == 10:
            break

    encodings_digit = []
    for d in range(10):
        with torch.no_grad():
            mu, sigma = model.encode(images[d].view(1, 784))
        encodings_digit.append((mu, sigma))

    # Create the output folder if it doesn't exist
    output_dir = "generated_digits"
    os.makedirs(output_dir, exist_ok=True)
    
    mu, sigma = encodings_digit[digit]
    for example in range(num_examples):
        epsilon = torch.randn_like(sigma)
        z = mu + sigma * epsilon
        out = model.decode(z)
        out = out.view(-1, 1, 28, 28) # PyTorch tensor method to reshape tensor (-1 is the placeholder to auto infer the dim)
        save_image(out, os.path.join(output_dir, f"generated_{digit}_ex{example}.png"))
        

In [54]:
for idx in range(10):
    inference(idx, num_examples=5)