Adapted from: https://github.com/Jackson-Kang/Pytorch-VAE-tutorial/blob/master/01_Variational_AutoEncoder.ipynb

In [71]:
import torch
import torch.nn as nn

import numpy as np

from tqdm import tqdm
from torchvision.utils import save_image, make_grid

In [72]:
# Model hyperparameters

# Root directory handed to the MNIST dataset constructor (download target).
dataset_path = '~/datasets'

# NOTE(review): `cuda` is not read by any visible cell; device selection
# actually goes through DEVICE below.
cuda = True
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Mini-batch size shared by the train and test loaders.
batch_size = 100

# Layer widths of the VAE.
x_dim = 784        # length of one flattened input image (784 = 28 * 28)
hidden_dim = 400   # width of the hidden fully-connected layers
latent_dim = 200   # dimensionality of the latent code z

# Optimisation settings.
lr = 1e-3          # Adam learning rate
epochs = 30        # number of full passes over the training set

In [73]:
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader


# ToTensor turns each PIL image into a float tensor; the resulting pixel
# values are used directly as binary-cross-entropy targets during training.
mnist_transform = transforms.Compose([
        transforms.ToTensor(),
])

# Page-locked (pinned) host memory only speeds up host-to-GPU copies, so it is
# requested only when CUDA is present; asking for it on a CPU-only machine
# merely produces a DataLoader warning.
kwargs = {'num_workers': 1, 'pin_memory': torch.cuda.is_available()}

# Both splits are downloaded into `dataset_path` on first use.
train_dataset = MNIST(dataset_path, transform=mnist_transform, train=True, download=True)
test_dataset  = MNIST(dataset_path, transform=mnist_transform, train=False, download=True)

# Training batches are reshuffled every epoch; test batches keep dataset order.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, **kwargs)
test_loader  = DataLoader(dataset=test_dataset,  batch_size=batch_size, shuffle=False, **kwargs)

In [74]:
class Encoder(nn.Module):
    """Fully-connected encoder that maps an input to Gaussian posterior parameters.

    Two LeakyReLU-activated hidden layers feed two parallel linear heads: one
    producing the mean and one producing the log-variance of the latent code.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        hidden_dim: Width of both hidden layers.
        latent_dim: Size of the mean / log-variance vectors.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()

        # Shared trunk.
        self.FC_L1 = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.FC_L2 = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)

        # Parallel output heads (FC_var emits the *log*-variance, see forward()).
        self.FC_mean = nn.Linear(in_features=hidden_dim, out_features=latent_dim)
        self.FC_var = nn.Linear(in_features=hidden_dim, out_features=latent_dim)

        # Activation with negative slope 0.2, reused after each trunk layer.
        self.LeakyReLU = nn.LeakyReLU(0.2)

        # Explicitly mark the module as being in training mode.
        self.training = True

    def forward(self, x):
        """Return ``(mean, log_var)`` computed from input ``x``."""
        hidden = self.FC_L1(x)
        hidden = self.LeakyReLU(hidden)
        hidden = self.FC_L2(hidden)
        hidden = self.LeakyReLU(hidden)

        mean = self.FC_mean(hidden)
        log_var = self.FC_var(hidden)
        return mean, log_var

In [75]:
class Decoder(nn.Module):
    """Fully-connected decoder that maps a latent code back to input space.

    Two LeakyReLU-activated hidden layers are followed by a linear output
    layer whose result is squashed with a sigmoid, so every output element
    lies strictly between 0 and 1.

    Args:
        latent_dim: Size of the last dimension of the latent input.
        hidden_dim: Width of both hidden layers.
        output_dim: Size of the reconstructed output vector.
    """

    def __init__(self, latent_dim, hidden_dim, output_dim):
        super().__init__()
        self.FC_h1 = nn.Linear(in_features=latent_dim, out_features=hidden_dim)
        self.FC_h2 = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
        self.FC_output = nn.Linear(in_features=hidden_dim, out_features=output_dim)

        # Activation with negative slope 0.2, reused after each hidden layer.
        self.LeakyReLU = nn.LeakyReLU(0.2)

    def forward(self, x):
        """Return the reconstruction ``x_hat`` for latent input ``x``."""
        hidden = self.FC_h1(x)
        hidden = self.LeakyReLU(hidden)
        hidden = self.FC_h2(hidden)
        hidden = self.LeakyReLU(hidden)

        logits = self.FC_output(hidden)
        return torch.sigmoid(logits)
        

In [76]:
class Model(nn.Module):
    """Variational autoencoder assembled from an encoder and a decoder module.

    Args:
        Encoder: Module whose forward pass returns ``(mean, log_var)``.
        Decoder: Module whose forward pass maps a latent sample to a
            reconstruction.
    """

    def __init__(self, Encoder, Decoder):
        super(Model, self).__init__()
        self.Encoder = Encoder
        self.Decoder = Decoder

    def reparameterization(self, mean, var):
        """Draw ``z = mean + var * epsilon`` with ``epsilon ~ N(0, I)``.

        Args:
            mean: Mean of the latent Gaussian.
            var: Scale applied to the noise. Despite the parameter name (kept
                for backward compatibility), ``forward`` passes the standard
                deviation ``exp(0.5 * log_var)`` here, not the variance.

        Returns:
            A sample with the same shape as ``var``, on the same device.
        """
        # randn_like already creates the noise with var's dtype and device, so
        # no explicit device transfer (and no dependency on a global) is needed.
        epsilon = torch.randn_like(var)
        z = mean + var * epsilon                        # reparameterization trick
        return z

    def forward(self, x):
        """Encode ``x``, sample a latent code, and decode it.

        Returns:
            Tuple ``(x_hat, mean, log_var)``: the reconstruction plus the
            encoder outputs needed for the KL term of the loss.
        """
        mean, log_var = self.Encoder(x)
        std = torch.exp(0.5 * log_var)                  # log-variance -> standard deviation
        z = self.reparameterization(mean, std)
        x_hat = self.Decoder(z)

        return x_hat, mean, log_var

In [77]:
# Build the two halves of the VAE from the hyperparameters defined earlier.
encoder = Encoder(x_dim, hidden_dim, latent_dim)
decoder = Decoder(latent_dim, hidden_dim, x_dim)

# Combine them into a single module and place its parameters on DEVICE.
model = Model(encoder, decoder)
model = model.to(DEVICE)

In [78]:
from torch.optim import Adam

# Binary cross-entropy criterion object (loss_function below calls the
# functional form directly and does not use this instance).
BCE_loss = nn.BCELoss()


def loss_function(x, x_hat, mean, log_var):
    """Return the summed VAE loss: reconstruction term plus KL divergence.

    Args:
        x: Target tensor for the binary cross-entropy.
        x_hat: Predicted probabilities, same shape as ``x``.
        mean: Mean of the approximate posterior.
        log_var: Log-variance of the approximate posterior.

    Returns:
        Scalar tensor: binary cross-entropy summed over all elements, plus
        ``-0.5 * sum(1 + log_var - mean^2 - exp(log_var))``.
    """
    recon_term = nn.functional.binary_cross_entropy(x_hat, x, reduction='sum')

    kl_inner = 1 + log_var - mean.pow(2) - log_var.exp()
    kl_term = -0.5 * torch.sum(kl_inner)

    return recon_term + kl_term


# Adam over every trainable parameter of the VAE (encoder and decoder alike).
optimizer = Adam(params=model.parameters(), lr=lr)

In [79]:
print("Training VAE...")
model.train()

for epoch in range(epochs):
    overall_loss = 0   # loss summed over every sample seen this epoch
    num_samples = 0    # samples seen this epoch, for a true per-sample average

    for x, _ in train_loader:
        # Flatten each image to a vector. The batch dimension is taken from the
        # tensor itself so a short final batch still reshapes correctly, and the
        # batch is moved to the device the model lives on.
        x = x.view(x.size(0), x_dim).to(DEVICE)

        optimizer.zero_grad()

        # The third output of the model is the encoder's log-variance.
        x_hat, mean, log_var = model(x)
        loss = loss_function(x, x_hat, mean, log_var)

        overall_loss += loss.item()
        num_samples += x.size(0)

        loss.backward()
        optimizer.step()

    # loss_function sums over the batch, so dividing by the number of samples
    # gives the mean loss per sample; max() guards against an empty loader.
    print("\tEpoch", epoch + 1, "complete!", "\tAverage Loss: ", overall_loss / max(num_samples, 1))

print("Training Completed.")

Training VAE...
	Epoch 1 complete! 	Average Loss:  173.78094217902233
	Epoch 2 complete! 	Average Loss:  128.34900460728818
	Epoch 3 complete! 	Average Loss:  117.53541030297892
	Epoch 4 complete! 	Average Loss:  113.38513693069177
	Epoch 5 complete! 	Average Loss:  110.14743812930405
	Epoch 6 complete! 	Average Loss:  108.29960833159433
	Epoch 7 complete! 	Average Loss:  106.95295415862375
	Epoch 8 complete! 	Average Loss:  105.98925810595784
	Epoch 9 complete! 	Average Loss:  105.19710383190734
	Epoch 10 complete! 	Average Loss:  104.57479912810413
	Epoch 11 complete! 	Average Loss:  104.09212481414336
	Epoch 12 complete! 	Average Loss:  103.69582861083576
	Epoch 13 complete! 	Average Loss:  103.35352366248435
	Epoch 14 complete! 	Average Loss:  102.98651556304779
	Epoch 15 complete! 	Average Loss:  102.71266934148059
	Epoch 16 complete! 	Average Loss:  102.42430217354445
	Epoch 17 complete! 	Average Loss:  102.19759897681031
	Epoch 18 complete! 	Average Loss:  101.95291426465985
	Ep