# Generating Handwritten Digits with VAEs in PyTorch

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torchvision import datasets, transforms
import torchvision.transforms as transforms
from torch.utils.data import Dataset, random_split, DataLoader
from torchvision.utils import save_image

from torchsummary import summary

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import os
import math
from PIL import Image
from PIL import Image
from IPython.display import display
import glob

In [2]:
# Pick the compute device once: the GPU when CUDA is usable, the CPU otherwise.
# The model and every batch are moved to this device in later cells.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


In [3]:
# NOTE: despite their names, x_data / y_data are the MNIST *training* and *test*
# splits (train=True / train=False), not inputs and labels.
# Only the first call asks for a download; the second presumably relies on the
# files already fetched into ./mnist_data/ by the first -- confirm on a fresh checkout.
x_data = datasets.MNIST(
    root='./mnist_data/',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
y_data = datasets.MNIST(
    root='./mnist_data/',
    train=False,
    transform=transforms.ToTensor(),
    download=False,
)

# Batches of 100 images; only the training split is reshuffled.
train_loader = DataLoader(dataset=x_data, batch_size=100, shuffle=True)
test_loader = DataLoader(dataset=y_data, batch_size=100, shuffle=False)

In [4]:
class VAE(nn.Module):
    """Fully connected variational autoencoder.

    The encoder maps a flattened input of width ``x`` through two ReLU layers
    (``h1`` then ``h2`` units) to the mean and log-variance of a ``z``-dimensional
    Gaussian. The decoder mirrors it (``z -> h2 -> h1 -> x``) and ends with a
    sigmoid, so every reconstructed value lies in [0, 1].

    Args:
        x: width of the flattened input (784 for 28x28 images).
        h1: size of the first encoder / last decoder hidden layer.
        h2: size of the second encoder / first decoder hidden layer.
        z: dimensionality of the latent space.
    """

    def __init__(self, x, h1, h2, z):
        super(VAE, self).__init__()

        # Encoder: x -> h1 -> h2 -> (mu, log_var)
        self.l1 = nn.Linear(x, h1)
        self.l2 = nn.Linear(h1, h2)
        self.l31 = nn.Linear(h2, z)  # mean head
        self.l32 = nn.Linear(h2, z)  # log-variance head

        # Decoder: z -> h2 -> h1 -> x
        self.l4 = nn.Linear(z, h2)
        self.l5 = nn.Linear(h2, h1)
        self.l6 = nn.Linear(h1, x)

    def encoder(self, x_in):
        """Return ``(mu, log_var)`` for an already flattened batch ``x_in``."""
        h = F.relu(self.l1(x_in))
        h = F.relu(self.l2(h))
        return self.l31(h), self.l32(h)

    def decoder(self, z):
        """Map latent vectors ``z`` to reconstructions with values in [0, 1]."""
        h = F.relu(self.l4(z))
        h = F.relu(self.l5(h))
        return torch.sigmoid(self.l6(h))

    def sampling(self, mu, log_var):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)`` (reparameterisation)."""
        std = torch.exp(0.5*log_var)
        eps = torch.randn_like(std)
        return torch.add(eps.mul(std), mu)

    def forward(self, x_in):
        """Encode, sample and decode a batch.

        ``x_in`` is flattened to ``(-1, x)`` where ``x`` is the input width given
        to the constructor, so the model is not tied to 784-pixel inputs.

        Returns:
            Tuple ``(reconstruction, mu, log_var)``.
        """
        # Use the first layer's real input width instead of a hard-coded 784.
        mu, log_var = self.encoder(x_in.view(-1, self.l1.in_features))
        z = self.sampling(mu, log_var)
        return self.decoder(z), mu, log_var

In [5]:
# Inputs are 28x28 images flattened to 784 values; the latent space is
# 2-dimensional.
vae = VAE(x=28 * 28, h1=512, h2=256, z=2)

# Move the parameters to the selected device. As the cell's last expression,
# the returned module is also what the notebook displays.
vae.to(device)

VAE(
  (l1): Linear(in_features=784, out_features=512, bias=True)
  (l2): Linear(in_features=512, out_features=256, bias=True)
  (l31): Linear(in_features=256, out_features=2, bias=True)
  (l32): Linear(in_features=256, out_features=2, bias=True)
  (l4): Linear(in_features=2, out_features=256, bias=True)
  (l5): Linear(in_features=256, out_features=512, bias=True)
  (l6): Linear(in_features=512, out_features=784, bias=True)
)

In [6]:
# Print a per-layer table of output shapes and parameter counts for one
# single-channel 28x28 input (VAE.forward flattens it before the first layer).
summary(vae, (1, 28, 28))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                  [-1, 512]         401,920
            Linear-2                  [-1, 256]         131,328
            Linear-3                    [-1, 2]             514
            Linear-4                    [-1, 2]             514
            Linear-5                  [-1, 256]             768
            Linear-6                  [-1, 512]         131,584
            Linear-7                  [-1, 784]         402,192
Total params: 1,068,820
Trainable params: 1,068,820
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.02
Params size (MB): 4.08
Estimated Total Size (MB): 4.10
----------------------------------------------------------------


In [7]:
# Adam over every VAE parameter; no hyperparameters are passed, so the
# library defaults apply.
optimizer = optim.Adam(vae.parameters())

def loss_function(recon_x, x, mu, log_var):
    """Return the VAE loss: summed reconstruction BCE plus KL divergence.

    Args:
        recon_x: reconstructions with values in [0, 1], shape ``(batch, features)``.
        x: original inputs in any shape that flattens to ``(batch, features)``.
        mu: latent means produced by the encoder.
        log_var: latent log-variances produced by the encoder.

    Returns:
        Scalar tensor, summed (not averaged) over the whole batch.
    """
    # Flatten the target to the reconstruction's width rather than a fixed 784,
    # so the loss works for any input size the model was built with.
    target = x.view(-1, recon_x.size(-1))
    BCE = F.binary_cross_entropy(recon_x, target, reduction='sum')
    # KL divergence between N(mu, exp(log_var)) and N(0, I) (form from MIT 6.S191).
    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return BCE + KLD

In [8]:
def train(epoch):
    """Run one optimisation pass over ``train_loader`` and print the mean loss.

    Uses the notebook-level ``vae``, ``optimizer``, ``device`` and
    ``loss_function``. The printed value is the summed loss divided by the
    number of training samples. The message ends with a space instead of a
    newline.

    Args:
        epoch: epoch number, used only in the printed message.
    """
    vae.train()
    running_loss = 0
    for images, _ in train_loader:
        images = images.to(device)

        # Clear gradients left over from the previous batch.
        optimizer.zero_grad()

        reconstruction, mu, log_var = vae(images)
        batch_loss = loss_function(reconstruction, images, mu, log_var)

        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()

    mean_loss = running_loss / len(train_loader.dataset)
    print('Epoch: {} Train mean loss: {:.8f}'.format(epoch, mean_loss), end=" ")

In [9]:
def test(epoch):
    """Evaluate ``vae`` on ``test_loader`` and print the mean per-sample loss.

    Runs in eval mode with gradient tracking disabled. Uses the notebook-level
    ``vae``, ``device`` and ``loss_function``.

    Args:
        epoch: epoch number, used only in the printed message.
    """
    vae.eval()
    total_loss = 0
    with torch.no_grad():
        for images, _ in test_loader:
            images = images.to(device)
            reconstruction, mu, log_var = vae(images)
            batch_loss = loss_function(reconstruction, images, mu, log_var)
            total_loss += batch_loss.item()

    # Normalise by the number of samples, not the number of batches.
    mean_loss = total_loss / len(test_loader.dataset)
    print('Epoch: {} Test mean loss: {:.8f}'.format(epoch, mean_loss))

In [10]:
# Number of full passes over the training set.
n_epoches = 50

# Epochs are numbered from 1. Each one trains and then evaluates; train()
# prints without a trailing newline, so both messages share one output line.
for epoch in range(1, n_epoches+1):
    train(epoch)
    test(epoch)

Epoch: 1 Train mean loss: 178.41738145 Epoch: 1 Test mean loss: 162.53292871
Epoch: 2 Train mean loss: 157.97955275 Epoch: 2 Test mean loss: 154.79316367
Epoch: 3 Train mean loss: 152.61461733 Epoch: 3 Test mean loss: 151.00139668
Epoch: 4 Train mean loss: 149.27334487 Epoch: 4 Test mean loss: 148.64444863
Epoch: 5 Train mean loss: 147.17804229 Epoch: 5 Test mean loss: 146.64882686
Epoch: 6 Train mean loss: 145.73971899 Epoch: 6 Test mean loss: 145.84128945
Epoch: 7 Train mean loss: 144.52863384 Epoch: 7 Test mean loss: 144.99415273
Epoch: 8 Train mean loss: 143.56036440 Epoch: 8 Test mean loss: 144.05461689
Epoch: 9 Train mean loss: 142.85794180 Epoch: 9 Test mean loss: 143.67806162
Epoch: 10 Train mean loss: 142.33500907 Epoch: 10 Test mean loss: 143.24529170
Epoch: 11 Train mean loss: 141.67030042 Epoch: 11 Test mean loss: 142.41908887
Epoch: 12 Train mean loss: 141.20681271 Epoch: 12 Test mean loss: 142.28842197
Epoch: 13 Train mean loss: 140.65099388 Epoch: 13 Test mean loss: 142.

In [11]:
# Decode 100 latent points along a diagonal walk and save each as a PNG.
# save_image does not create missing directories, so make sure the output
# folder exists before the first write.
os.makedirs('./samples', exist_ok=True)

with torch.no_grad():
    # Random starting point in the 2-D latent space.
    z = torch.randn(1, 2).to(device)
    for i in range(100):
        # Shift both latent coordinates by 0.05 per step.
        z = torch.add(z, 0.05)

        # The decoder output already lives on the same device as z.
        sample = vae.decoder(z)
        # Reshape the flat 784-vector to one single-channel 28x28 image.
        save_image(sample.view(1, 28, 28), './samples/sample' + str(i) + '.png')