# Generating Handwritten Digits with VAEs in PyTorch

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torchvision import datasets, transforms
import torchvision.transforms as transforms
from torch.utils.data import Dataset, random_split, DataLoader
from torchvision.utils import save_image

from torchsummary import summary

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import os
import math
from PIL import Image
from PIL import Image
from IPython.display import display
import glob

In [2]:
# Seed PyTorch's RNGs before the model is built and the DataLoader starts shuffling,
# so weight init, batch order, reparameterisation noise and latent samples repeat on a
# fresh "Restart & Run All". (Some cuDNN kernels may still be non-deterministic.)
SEED = 0
torch.manual_seed(SEED)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [3]:
# MNIST splits converted with ToTensor. NOTE: despite the names, x_data / y_data are the
# train / test *splits* of MNIST, not inputs / labels (the labels are ignored later).
x_data = datasets.MNIST(root='./mnist_data/', train=True, transform=transforms.ToTensor(), download=True)
# download=True only fetches missing files, so this line no longer relies on the call
# above having downloaded the test files as a side effect.
y_data = datasets.MNIST(root='./mnist_data/', train=False, transform=transforms.ToTensor(), download=True)

# Shuffle the training data each epoch; evaluation order does not matter.
train_loader = DataLoader(dataset=x_data, batch_size=100, shuffle=True)
test_loader = DataLoader(dataset=y_data, batch_size=100, shuffle=False)

In [4]:
class VAE(nn.Module):
    """Convolutional variational autoencoder for single-channel 28x28 images.

    Encoder: two stride-2 convolutions (1 -> 32 -> 64 channels, 28x28 -> 14x14 -> 7x7),
    flattened and fed to two linear heads giving the posterior mean and log-variance.
    Decoder: a linear layer back to a 64x7x7 map, then two stride-2 transposed
    convolutions (64 -> 32 -> 1 channels, 7x7 -> 14x14 -> 28x28) and a sigmoid.

    Args:
        latent_size: dimensionality of the latent code z (default 4).
    """

    def __init__(self, latent_size=4):
        super().__init__()
        self.latent_size = latent_size

        conv_kwargs = dict(kernel_size=4, stride=2, padding=1)  # halves / doubles H and W
        flat_features = 64 * 7 * 7  # size of the flattened 64x7x7 encoder feature map

        # Encoder convolutions.
        self.l1 = nn.Conv2d(1, 32, **conv_kwargs)
        self.l2 = nn.Conv2d(32, 64, **conv_kwargs)

        # Posterior heads: l21 -> mean, l22 -> log-variance.
        self.l21 = nn.Linear(flat_features, latent_size)
        self.l22 = nn.Linear(flat_features, latent_size)

        # Latent code -> flattened decoder feature map.
        self.f = nn.Linear(latent_size, flat_features)

        # Decoder transposed convolutions.
        self.l3 = nn.ConvTranspose2d(64, 32, **conv_kwargs)
        self.l4 = nn.ConvTranspose2d(32, 1, **conv_kwargs)

    def encoder(self, x_in):
        """Map a batch of images to ``(mu, log_var)``, each of shape (batch, latent_size)."""
        features = F.relu(self.l2(F.relu(self.l1(x_in))))
        features = features.view(features.size(0), -1)
        return self.l21(features), self.l22(features)

    def decoder(self, z):
        """Map latent codes of shape (batch, latent_size) to (batch, 1, 28, 28) images in (0, 1)."""
        hidden = self.f(z)
        hidden = hidden.view(hidden.size(0), 64, 7, 7)
        hidden = F.relu(self.l3(hidden))
        return torch.sigmoid(self.l4(hidden))

    def sampling(self, mu, log_var):
        """Reparameterisation trick: return ``mu + sigma * eps`` with ``eps ~ N(0, I)``.

        Fresh noise is drawn on every call, including in eval mode.
        """
        sigma = (0.5 * log_var).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x_in):
        """Encode, sample a latent code and decode; returns ``(reconstruction, mu, log_var)``."""
        mu, log_var = self.encoder(x_in)
        return self.decoder(self.sampling(mu, log_var)), mu, log_var

In [5]:
# Default 4-dimensional latent space; Module.to returns the module itself, so the
# instance can be built and moved in one expression. The bare name displays its repr.
vae = VAE().to(device)
vae

VAE(
  (l1): Conv2d(1, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (l2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (l21): Linear(in_features=3136, out_features=4, bias=True)
  (l22): Linear(in_features=3136, out_features=4, bias=True)
  (f): Linear(in_features=4, out_features=3136, bias=True)
  (l3): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (l4): ConvTranspose2d(32, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
)

In [6]:
# torchsummary's `summary` builds its probe input on device="cuda" unless told otherwise;
# pass the device actually in use so this cell also runs on CPU-only machines.
summary(vae, (1, 28, 28), device=device.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 14, 14]             544
            Conv2d-2             [-1, 64, 7, 7]          32,832
            Linear-3                    [-1, 4]          12,548
            Linear-4                    [-1, 4]          12,548
            Linear-5                 [-1, 3136]          15,680
   ConvTranspose2d-6           [-1, 32, 14, 14]          32,800
   ConvTranspose2d-7            [-1, 1, 28, 28]             513
Total params: 107,465
Trainable params: 107,465
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.15
Params size (MB): 0.41
Estimated Total Size (MB): 0.56
----------------------------------------------------------------


In [7]:
optimizer = optim.Adam(params=vae.parameters(), lr=1e-3)  # Adam's default learning rate, made explicit

def loss_function(recon_x, x, mu, log_var, beta=1.0):
    """Negative ELBO of a VAE with a Bernoulli decoder and a standard-normal prior.

    Both terms are summed over every element of the batch (not averaged), so divide by
    the number of samples to get a per-sample loss.

    Args:
        recon_x: decoder output, values in [0, 1], same shape as ``x``.
        x: target values in [0, 1].
        mu: posterior means (summed over all elements).
        log_var: posterior log-variances, same shape as ``mu``.
        beta: weight on the KL term. 1.0 (default) is the standard VAE objective;
            other values give a beta-VAE.

    Returns:
        Scalar tensor ``BCE + beta * KL(N(mu, exp(log_var)) || N(0, I))``.
    """
    bce = F.binary_cross_entropy(recon_x, x, reduction='sum')
    # Closed-form KL divergence to N(0, I) (formula as given in MIT 6.S191).
    kld = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return bce + beta * kld

In [8]:
def train(epoch):
    """Run one optimisation pass of the global ``vae`` over ``train_loader``.

    Uses the module-level ``vae``, ``optimizer``, ``train_loader``, ``device`` and
    ``loss_function``. Prints the mean per-sample training loss with ``end=" "`` so the
    following ``test(epoch)`` output continues on the same line.

    Args:
        epoch: epoch number, used only in the printed message.

    Returns:
        Mean per-sample training loss for this epoch (float).
    """
    vae.train()
    train_loss = 0.0
    for data, _ in train_loader:  # labels are unused: the VAE is unsupervised
        data = data.to(device)
        optimizer.zero_grad()

        r_batch, mu, log_var = vae(data)
        loss = loss_function(r_batch, data, mu, log_var)

        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # loss_function sums over the batch, so normalise by samples, not batches.
    mean_loss = train_loss / len(train_loader.dataset)
    print('Epoch: {} Train mean loss: {:.8f}'.format(epoch, mean_loss), end=" ")
    return mean_loss

In [9]:
def test(epoch):
    """Evaluate the global ``vae`` on ``test_loader`` and report the mean loss.

    Uses the module-level ``vae``, ``test_loader``, ``device`` and ``loss_function``.
    The model is put in eval mode and gradients are disabled. ``VAE.sampling`` still
    draws fresh noise in eval mode, so the reported loss is a stochastic estimate.

    Args:
        epoch: epoch number, used only in the printed message.

    Returns:
        Mean per-sample test loss (float).
    """
    vae.eval()
    test_loss = 0.0
    with torch.no_grad():
        for data, _ in test_loader:
            data = data.to(device)
            r, mu, log_var = vae(data)
            test_loss += loss_function(r, data, mu, log_var).item()

    # loss_function sums over the batch, so normalise by samples, not batches.
    test_loss /= len(test_loader.dataset)
    print('Epoch: {} Test mean loss: {:.8f}'.format(epoch, test_loss))
    return test_loss

In [10]:
n_epoches = 10  # full passes over the training set

# Epochs are numbered from 1 so the log reads "Epoch: 1 ... Epoch: 10".
for epoch in range(1, 1 + n_epoches):
    train(epoch)  # optimisation pass; prints without a trailing newline
    test(epoch)   # held-out evaluation; completes the same output line

Epoch: 1 Train mean loss: 172.63472850 Epoch: 1 Test mean loss: 148.49528750
Epoch: 2 Train mean loss: 144.33558420 Epoch: 2 Test mean loss: 141.12898662
Epoch: 3 Train mean loss: 139.76057521 Epoch: 3 Test mean loss: 138.40836152
Epoch: 4 Train mean loss: 137.40822023 Epoch: 4 Test mean loss: 136.47317852
Epoch: 5 Train mean loss: 136.01164126 Epoch: 5 Test mean loss: 135.31066396
Epoch: 6 Train mean loss: 134.97169360 Epoch: 6 Test mean loss: 134.51169756
Epoch: 7 Train mean loss: 134.14543091 Epoch: 7 Test mean loss: 134.00177539
Epoch: 8 Train mean loss: 133.52039969 Epoch: 8 Test mean loss: 133.63631113
Epoch: 9 Train mean loss: 132.97341929 Epoch: 9 Test mean loss: 133.19777441
Epoch: 10 Train mean loss: 132.54359199 Epoch: 10 Test mean loss: 132.77700938


In [11]:
# Walk a straight line through latent space: start from a random code, shift every
# coordinate by +0.05 per step, and save each decoded image as a PNG frame.
SAMPLE_DIR = './samplesCVAE'
os.makedirs(SAMPLE_DIR, exist_ok=True)  # saving into a missing directory fails on a fresh run

vae.eval()
with torch.no_grad():
    z = torch.randn(1, vae.latent_size, device=device)  # latent size taken from the model
    for i in range(100):
        z = z + 0.05
        sample = vae.decoder(z)  # already on `device`; shape (1, 1, 28, 28)
        save_image(sample.view(1, 28, 28), os.path.join(SAMPLE_DIR, f'sample{i}.png'))