### An example of variational autoencoder on MNIST dataset

In [2]:
import torch 
import torchvision
import torch.nn as nn
import numpy as np
import torchvision.transforms as transforms
import torch.nn.functional as F
from torchvision.utils import save_image
import os

'''move the computations to the GPU if cuda is available, otherwise the computations will be run on CPU'''
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

'''defining model parameters'''
image_size = 784
h_dim = 400
z_dim = 20
num_epochs = 15
batch_size = 128
learning_rate = 1e-3

'''download the training and test set'''
train_dataset = torchvision.datasets.MNIST(root='data/', 
                                           train=True, 
                                           transform=transforms.ToTensor(),  
                                           download=True)

test_dataset = torchvision.datasets.MNIST(root='data/', 
                                          train=False, 
                                          transform=transforms.ToTensor())

'''use dataloader to shuffle and batch the data'''
train_loader = torch.utils.data.DataLoader(dataset = train_dataset, batch_size = batch_size, shuffle = True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)

'''define the model'''
class vae(nn.Module):
    def __init__(self, image_size = 784, h_dim = 400, z_dim = 20):
        super(vae, self).__init__()
        self.layer_1 = nn.Linear(image_size, h_dim)
        self.layer_2_1 = nn.Linear(h_dim, z_dim)
        self.layer_2_2 = nn.Linear(h_dim, z_dim)
        self.layer_3 = nn.Linear(z_dim, h_dim)
        self.layer_4 = nn.Linear(h_dim, image_size)
        self.relu = nn.ReLU()
    
    def encoder(self, x):
        out = self.layer_1(x)
        out = self.relu(out)
        mu = self.layer_2_1(out)
        std = self.layer_2_2(out)
        return mu, std

    def decoder(self, z):
        out = self.layer_3(z)
        out = self.relu(out)
        out = self.layer_4(out)
        out = torch.sigmoid(out)
        return out
    
    def forward(self, x):
        mu, std = self.encoder(x)
        z = self.parameterize(mu, std)
        out = self.decoder(z)
        return mu, std, out
    
    def parameterize(self, mu, std):
        #Returns a tensor with the same size as input that is filled with random numbers 
        #from a normal distribution with mean 0 and variance 1
        sample = torch.randn_like(std)
        out = mu + (std*sample)
        return out

'''instantiate the model'''
model = vae().to(device)
#criterion = kl_divergence + loss_of_generator
'''Adam optimizer is used as the optimization function. We optimized all the model parameters, with a given learning rate.'''
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)

'''training'''
for epoch in range(num_epochs):
    for i, (img, label) in enumerate(train_loader):
        img = img.reshape(-1, image_size).to(device)
        mu, std, out = model(img)
        
        gen_loss = F.binary_cross_entropy(out, img, size_average=False)
        
        kl_div = - 0.5 * torch.sum(1 + std - mu.pow(2) - std.exp())
        
        loss = gen_loss + kl_div
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if i%100 == 0:
            print('epoch [{}/{}], step [{}/{}], loss {:.4f}'.format(epoch, num_epochs, i, len(train_loader), loss))

epoch [0/15], step [0/469], loss 69694.5703
epoch [0/15], step [100/469], loss 21035.7832
epoch [0/15], step [200/469], loss 16209.1182
epoch [0/15], step [300/469], loss 13802.0586
epoch [0/15], step [400/469], loss 12742.9785
epoch [1/15], step [0/469], loss 11931.6426
epoch [1/15], step [100/469], loss 11973.4199
epoch [1/15], step [200/469], loss 11213.8047
epoch [1/15], step [300/469], loss 10665.0469
epoch [1/15], step [400/469], loss 11209.6924
epoch [2/15], step [0/469], loss 10739.9209
epoch [2/15], step [100/469], loss 10400.5293
epoch [2/15], step [200/469], loss 10643.6660
epoch [2/15], step [300/469], loss 10382.5850
epoch [2/15], step [400/469], loss 10576.7363
epoch [3/15], step [0/469], loss 10327.3721
epoch [3/15], step [100/469], loss 10249.4600
epoch [3/15], step [200/469], loss 10094.6299
epoch [3/15], step [300/469], loss 10378.8857
epoch [3/15], step [400/469], loss 9671.6836
epoch [4/15], step [0/469], loss 10149.3115
epoch [4/15], step [100/469], loss 10028.2197

In [6]:
from torchvision.utils import save_image
import os
from PIL import Image

sample_dir = "data/"

with torch.no_grad():
    z = torch.randn(batch_size, z_dim).to(device)
    out = model.decoder(z).view(-1, 1, 28, 28)
    save_image(out, os.path.join(sample_dir, 'sampled-{}.png'.format(epoch+1)))
    img = Image.open(os.path.join(sample_dir, 'sampled-{}.png'.format(epoch+1)))
    img.show()