# Import required libraries

In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_image
from matplotlib import pyplot as plt
%matplotlib inline

# Device configuration: run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyper-parameters

In [2]:
# Hyper-parameters
image_size = 784      # flattened input length; batches are reshaped to (-1, image_size) before the model
h_dim = 400           # hidden layer width (same value as the VAE class default)
z_dim = 20            # latent dimensionality; also the width of the noise used when sampling images
num_epochs = 20       # number of passes over the data loader
batch_size = 128      # samples per optimisation step and per sampled image grid
learning_rate = 1e-3  # step size handed to the Adam optimiser

# Seed PyTorch's global RNG so weight initialisation, shuffling and latent sampling
# repeat across Restart & Run All. The trailing ';' hides the Generator repr in the cell output.
seed = 0
torch.manual_seed(seed);

# Load dataset

In [3]:
# MNIST dataset
# download=True fetches the files into `root` when they are missing (it is a no-op when they are
# already there), so the cell also works on a machine that has never downloaded MNIST.
# NOTE(review): train=False selects the test split, so the model below is fitted on test images
# (the "Step [../79]" lines in the log fit a ~10k-image split at batch_size=128) - confirm this is intended.
dataset = torchvision.datasets.MNIST(
    root='/data',
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)

# Data loader: shuffled mini-batches of `batch_size` images
data_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

# Variational Autoencoder model

In [4]:
class VAE(nn.Module):
    """Fully connected variational autoencoder.

    Encoder: ``image_size -> h_dim -> (mu, log_var)``, each of width ``z_dim``.
    Decoder: ``z_dim -> h_dim -> image_size`` with a sigmoid output, so every
    reconstructed value lies in [0, 1].

    Args:
        image_size: Length of each flattened input vector.
        h_dim: Width of the hidden layer used by both encoder and decoder.
        z_dim: Dimensionality of the latent code.
    """

    def __init__(self, image_size=784, h_dim=400, z_dim=20):
        super().__init__()
        self.fc1 = nn.Linear(image_size, h_dim)  # encoder hidden layer
        self.fc2 = nn.Linear(h_dim, z_dim)       # head producing mu
        self.fc3 = nn.Linear(h_dim, z_dim)       # head producing log_var
        self.fc4 = nn.Linear(z_dim, h_dim)       # decoder hidden layer
        self.fc5 = nn.Linear(h_dim, image_size)  # decoder output layer

    def encode(self, x):
        """Map inputs of shape (N, image_size) to ``(mu, log_var)``, each of shape (N, z_dim)."""
        h = F.relu(self.fc1(x))
        return self.fc2(h), self.fc3(h)

    def reparameterize(self, mu, log_var):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)`` and ``std = exp(log_var / 2)``.

        The randomness lives only in ``eps``, so gradients flow through ``mu`` and ``log_var``.
        """
        std = torch.exp(log_var/2)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes of shape (N, z_dim) to outputs of shape (N, image_size) in [0, 1]."""
        h = F.relu(self.fc4(z))
        # torch.sigmoid replaces the deprecated F.sigmoid; the result is the same.
        return torch.sigmoid(self.fc5(h))

    def forward(self, x):
        """Encode, sample a latent code, decode; return ``(x_reconst, mu, log_var)``."""
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        x_reconst = self.decode(z)
        return x_reconst, mu, log_var

In [5]:
# Build the network from the notebook's hyper-parameters instead of the class defaults, so that
# editing image_size / h_dim / z_dim in the configuration cell actually changes the model.
model = VAE(image_size=image_size, h_dim=h_dim, z_dim=z_dim).to(device)
# Adam over all model parameters.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Create the output directory if it does not already exist

In [6]:
# Directory that receives the sampled and reconstructed image grids.
sample_dir = 'generated'
# exist_ok=True makes the cell idempotent and removes the check-then-create race.
os.makedirs(sample_dir, exist_ok=True)

# Training the model

In [7]:
for epoch in range(num_epochs):
    for i, (x, _) in enumerate(data_loader):
        # Forward pass on the flattened images (labels are not used)
        x = x.to(device).view(-1, image_size)
        x_reconst, mu, log_var = model(x)

        # Reconstruction loss: binary cross-entropy summed over every pixel of the batch.
        # reduction='sum' is the supported spelling of the deprecated size_average=False.
        reconst_loss = F.binary_cross_entropy(x_reconst, x, reduction='sum')
        # Closed-form KL divergence between N(mu, exp(log_var)) and N(0, I), summed over the batch.
        kl_div = - 0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

        # Backprop and optimize
        loss = reconst_loss + kl_div
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 10 == 0:
            print ("Epoch[{}/{}], Step [{}/{}], Reconstruction Loss: {:.4f}, KL Divergence: {:.4f}" 
                   .format(epoch+1, num_epochs, i+1, len(data_loader), reconst_loss.item(), kl_div.item()))

    # Write the image grids once per epoch. The file names depend only on the epoch, so doing this
    # inside the batch loop merely overwrote the same two files on every step.
    with torch.no_grad():
        # Save images decoded from latent codes drawn from the standard normal prior
        z = torch.randn(batch_size, z_dim).to(device)
        out = model.decode(z).view(-1, 1, 28, 28)
        save_image(out, os.path.join(sample_dir, 'sampled-{}.png'.format(epoch+1)))

        # Save the epoch's last batch side by side with its reconstruction (joined along the width)
        out, _, _ = model(x)
        x_concat = torch.cat([x.view(-1, 1, 28, 28), out.view(-1, 1, 28, 28)], dim=3)
        save_image(x_concat, os.path.join(sample_dir, 'reconstruction-{}.png'.format(epoch+1)))



Epoch[1/20], Step [10/79], Reconstruction Loss: 35030.8203, KL Divergence: 3198.4971
Epoch[1/20], Step [20/79], Reconstruction Loss: 29060.1992, KL Divergence: 1023.0688
Epoch[1/20], Step [30/79], Reconstruction Loss: 28055.2266, KL Divergence: 1231.1628
Epoch[1/20], Step [40/79], Reconstruction Loss: 27444.1230, KL Divergence: 692.6370
Epoch[1/20], Step [50/79], Reconstruction Loss: 26166.9492, KL Divergence: 659.7141
Epoch[1/20], Step [60/79], Reconstruction Loss: 25527.2930, KL Divergence: 781.9853
Epoch[1/20], Step [70/79], Reconstruction Loss: 24336.6777, KL Divergence: 1015.4268
Epoch[2/20], Step [10/79], Reconstruction Loss: 22391.8047, KL Divergence: 1194.1521
Epoch[2/20], Step [20/79], Reconstruction Loss: 22487.2285, KL Divergence: 1246.2677
Epoch[2/20], Step [30/79], Reconstruction Loss: 21691.9707, KL Divergence: 1392.3047
Epoch[2/20], Step [40/79], Reconstruction Loss: 20417.1562, KL Divergence: 1471.3330
Epoch[2/20], Step [50/79], Reconstruction Loss: 20049.6680, KL Diver

Epoch[14/20], Step [70/79], Reconstruction Loss: 11148.6445, KL Divergence: 2912.7747
Epoch[15/20], Step [10/79], Reconstruction Loss: 10723.8281, KL Divergence: 2995.0884
Epoch[15/20], Step [20/79], Reconstruction Loss: 11851.5518, KL Divergence: 2959.7539
Epoch[15/20], Step [30/79], Reconstruction Loss: 11223.9355, KL Divergence: 3135.8267
Epoch[15/20], Step [40/79], Reconstruction Loss: 10999.1504, KL Divergence: 2961.6819
Epoch[15/20], Step [50/79], Reconstruction Loss: 11471.6621, KL Divergence: 3142.6206
Epoch[15/20], Step [60/79], Reconstruction Loss: 11029.0059, KL Divergence: 2997.8335
Epoch[15/20], Step [70/79], Reconstruction Loss: 11279.1943, KL Divergence: 3055.9851
Epoch[16/20], Step [10/79], Reconstruction Loss: 11344.7568, KL Divergence: 2986.1458
Epoch[16/20], Step [20/79], Reconstruction Loss: 11140.2637, KL Divergence: 3057.3330
Epoch[16/20], Step [30/79], Reconstruction Loss: 11032.2070, KL Divergence: 3197.2778
Epoch[16/20], Step [40/79], Reconstruction Loss: 11376