# Variational Auto-Encoder Neural Network (2 Hidden Layers)

In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_image

In [2]:
# Define hyperparameters
image_size = 784        # 28 x 28 MNIST images, flattened
hidden_dim_1 = 400
hidden_dim_2 = 200
latent_dim = 20
batch_size = 128
epochs = 10

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# MNIST dataset. ToTensor converts images to float tensors in [0, 1], which is
# the range binary cross-entropy expects for its targets.
train_dataset = torchvision.datasets.MNIST(root='../../data',
                                           train=True,
                                           transform=transforms.ToTensor(),
                                           download=True)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)

test_dataset = torchvision.datasets.MNIST(root='../../data',
                                          train=False,
                                          transform=transforms.ToTensor(),
                                          download=True)

# Evaluation does not need shuffling. A fixed order also means the per-epoch
# reconstruction images (taken from the first test batch) show the same digits,
# so they can be compared across epochs.
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)

# Directory for the reconstructed and sampled images. exist_ok=True avoids the
# check-then-create race of os.path.exists followed by os.makedirs.
sample_dir = 'results'
os.makedirs(sample_dir, exist_ok=True)

![vae](https://user-images.githubusercontent.com/30661597/78418103-a2047200-766b-11ea-8205-c7e5712715f4.png)

In [3]:
# VAE model
class VAE(nn.Module):
    """Fully connected variational auto-encoder with two hidden layers per side.

    Encoder: image_size -> hidden_dim_1 -> hidden_dim_2 -> (mu, logvar), each latent_dim wide.
    Decoder: latent_dim -> hidden_dim_2 -> hidden_dim_1 -> image_size.
    Layer sizes are read from the module-level hyperparameters when the model is
    constructed; forward() reads image_size when it flattens its input.
    """

    def __init__(self):
        super().__init__()

        # Encoder layers
        self.fc1 = nn.Linear(image_size, hidden_dim_1)
        self.fc2 = nn.Linear(hidden_dim_1, hidden_dim_2)
        self.fc3_mean = nn.Linear(hidden_dim_2, latent_dim)      # mean of q(z|x)
        self.fc3_logvar = nn.Linear(hidden_dim_2, latent_dim)    # log(variance) of q(z|x)
        # Decoder layers
        self.fc4 = nn.Linear(latent_dim, hidden_dim_2)
        self.fc5 = nn.Linear(hidden_dim_2, hidden_dim_1)
        self.fc6 = nn.Linear(hidden_dim_1, image_size)

    def encode(self, x):
        """Map flattened images (batch, image_size) to (mu, log_var), each (batch, latent_dim)."""
        hidden1 = F.relu(self.fc1(x))
        hidden2 = F.relu(self.fc2(hidden1))
        mu = self.fc3_mean(hidden2)
        log_var = self.fc3_logvar(hidden2)
        return mu, log_var

    def reparameterize(self, mu, logvar):
        """Sample z ~ N(mu, exp(logvar)) as mu + eps * std, with eps ~ N(0, I).

        Writing the sample this way keeps z differentiable with respect to mu and logvar.
        """
        std_dev = torch.exp(logvar / 2)        # sigma = exp(log(sigma^2) / 2)
        # Standard-normal noise N(0, 1) with the same shape, dtype and device as std_dev;
        # the scaling by std_dev happens on the next line.
        epsilon = torch.randn_like(std_dev)
        return mu + epsilon * std_dev

    def decode(self, z):
        """Map latent codes (batch, latent_dim) to pixel values (batch, image_size)."""
        hidden3 = F.relu(self.fc4(z))
        hidden4 = F.relu(self.fc5(hidden3))
        # Sigmoid keeps outputs in (0, 1), matching the MNIST pixel range and the BCE loss.
        out = torch.sigmoid(self.fc6(hidden4))
        return out

    def forward(self, x):
        """Encode, sample and decode a batch of images.

        x: (batch_size, 1, 28, 28) images, flattened here to (batch_size, image_size).
        Returns (reconstructed, mu, logvar); reconstructed is (batch_size, image_size).
        """
        mu, logvar = self.encode(x.view(-1, image_size))   # flatten x
        z = self.reparameterize(mu, logvar)
        reconstructed = self.decode(z)
        return reconstructed, mu, logvar

# Build the VAE, move its parameters to the selected device, and optimize
# all of them with Adam at a learning rate of 1e-3.
model = VAE()
model = model.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

#### $\boxed{Loss = -E[\log P(X | z)]+D_{K L}[N(\mu(X), \Sigma(X)) \| N(0,1)]}$ 

#### where $D_{KL}[N(\mu(X), \Sigma(X)) \| N(0,1)]$ is the KL Divergence.

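#### $D_{KL}[N(\mu(X), \Sigma(X)) \| N(0,1)]=\frac{1}{2} \sum_{k}\left(\sigma_k^2(X)+\mu_k^2(X)-1-\log \sigma_k^2(X)\right)$

#### where $\Sigma(X)=\operatorname{diag}(\sigma_1^2(X), \ldots, \sigma_K^2(X))$; the encoder predicts $\log \sigma_k^2(X)$ (`logvar` in the code), so $\sigma_k^2(X)=\exp(\text{logvar}_k)$.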

In [4]:
# Define Loss
def loss_function(reconstructed_image, original_image, mu, logvar):
    """Return the VAE loss (negative ELBO) summed over the batch.

    Args:
        reconstructed_image: decoder output, already flattened to
            (batch, pixels), with values in (0, 1).
        original_image: the input batch in any shape holding the same number
            of elements (e.g. (batch, 1, 28, 28)). It is reshaped to match
            ``reconstructed_image``.
        mu: encoder means, (batch, latent_dim).
        logvar: encoder log-variances, same shape as ``mu``.

    Returns:
        Scalar tensor equal to BCE(reconstruction) + D_KL(N(mu, exp(logvar)) || N(0, I)),
        both summed over pixels/latent dimensions and over the batch.
        Divide by the batch size to get a per-image value.
    """
    # Flatten the target to the reconstruction's shape instead of hard-coding 784.
    # reduction='sum' puts the reconstruction term on the same summed scale as the KL term.
    bce = F.binary_cross_entropy(reconstructed_image,
                                 original_image.reshape_as(reconstructed_image),
                                 reduction='sum')
    # Closed-form KL divergence between the diagonal Gaussian posterior and the
    # standard normal prior: 0.5 * sum(sigma^2 + mu^2 - 1 - log(sigma^2)).
    kld = 0.5 * torch.sum(logvar.exp() + mu.pow(2) - 1 - logvar)
    return bce + kld
    

# Train function
def train(epoch):
    """Run one optimization pass over train_loader and print progress.

    Every 100th batch, prints that batch's summed loss divided by its number of
    images (the mean loss per image in the batch). After the pass, prints the
    epoch's total summed loss divided by the number of training images.
    """
    model.train()
    running_total = 0
    for batch_idx, (images, _labels) in enumerate(train_loader):
        images = images.to(device)
        reconstructed, mu, logvar = model(images)
        batch_loss = loss_function(reconstructed, images, mu, logvar)  # summed over the batch

        optimizer.zero_grad()
        batch_loss.backward()
        running_total += batch_loss.item()
        optimizer.step()

        if batch_idx % 100 == 0:
            per_image = batch_loss.item() / len(images)
            print("Train Epoch {} [Batch {}/{}]\tLoss: {:.3f}".format(epoch, batch_idx, len(train_loader), per_image))

    epoch_average = running_total / len(train_loader.dataset)
    print('=====> Epoch {}, Average Loss: {:.3f}'.format(epoch, epoch_average))


# Test function
def test(epoch):
    """Evaluate the model on test_loader and save one reconstruction grid.

    Prints the summed test loss divided by the number of test images. For the
    first test batch, saves up to 5 original images (top row) above their
    reconstructions (bottom row) to <sample_dir>/reconstruction_<epoch>.png.
    """
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for batch_idx, (images, _) in enumerate(test_loader):
            images = images.to(device)
            reconstructed, mu, logvar = model(images)
            test_loss += loss_function(reconstructed, images, mu, logvar).item()
            if batch_idx == 0:
                # Reshape using the actual batch rather than the configured batch_size,
                # so a first batch shorter than batch_size does not crash the view.
                n = min(images.size(0), 5)
                comparison = torch.cat([images[:n], reconstructed.view_as(images)[:n]])
                save_image(comparison.cpu(),
                           os.path.join(sample_dir, 'reconstruction_' + str(epoch) + '.png'),
                           nrow=n)

    print('=====> Average Test Loss: {:.3f}'.format(test_loss/len(test_loader.dataset)))

In [5]:
# Main function
for epoch in range(1, epochs + 1):
    train(epoch)
    test(epoch)
    with torch.no_grad():
        # Skip the encoder: draw z from the standard normal prior N(0, I) and feed it
        # to the decoder to generate 64 new images for this epoch.
        sample = torch.randn(64, latent_dim, device=device)
        generated = model.decode(sample).cpu()
        save_image(generated.view(64, 1, 28, 28),
                   os.path.join(sample_dir, 'sample_' + str(epoch) + '.png'))

Train Epoch 1 [Batch 0/469]	Loss: 544.412
Train Epoch 1 [Batch 100/469]	Loss: 196.123
Train Epoch 1 [Batch 200/469]	Loss: 182.577
Train Epoch 1 [Batch 300/469]	Loss: 161.037
Train Epoch 1 [Batch 400/469]	Loss: 146.293
=====> Epoch 1, Average Loss: 179.996
=====> Average Test Loss: 140.696
Train Epoch 2 [Batch 0/469]	Loss: 135.242
Train Epoch 2 [Batch 100/469]	Loss: 135.302
Train Epoch 2 [Batch 200/469]	Loss: 132.059
Train Epoch 2 [Batch 300/469]	Loss: 132.334
Train Epoch 2 [Batch 400/469]	Loss: 125.177
=====> Epoch 2, Average Loss: 130.468
=====> Average Test Loss: 121.539
Train Epoch 3 [Batch 0/469]	Loss: 123.840
Train Epoch 3 [Batch 100/469]	Loss: 129.982
Train Epoch 3 [Batch 200/469]	Loss: 114.330
Train Epoch 3 [Batch 300/469]	Loss: 120.390
Train Epoch 3 [Batch 400/469]	Loss: 109.349
=====> Epoch 3, Average Loss: 118.587
=====> Average Test Loss: 114.332
Train Epoch 4 [Batch 0/469]	Loss: 109.887
Train Epoch 4 [Batch 100/469]	Loss: 111.533
Train Epoch 4 [Batch 200/469]	Loss: 110.340
