# Session 1: Variational Autoencoder (VAE)

* This week, the task is to build a Variational Autoencoder in **PyTorch**. 

* You may build your model any way you wish, but I encourage you to structure it so that the architecture can be changed with ease, perhaps by using classes...? 

* The loss function may be implemented using PyTorch's built-in KL divergence function, but to get the most out of this session, try to build it from scratch.



---



### Import Dependencies

In [None]:
import torch 
import torch.nn as nn
import torch.nn.functional as F



---



### Task 1: Define a class for the VAE model

In [None]:
# Class derived from the PyTorch nn.Module class
class VAE(nn.Module):
    def __init__(self, x_dim, h_dim1, h_dim2, z_dim):
        super(VAE, self).__init__()
        self.x_dim = x_dim  # stored so forward() does not hard-code 784
        
        # encoder part: x -> h1 -> h2 -> (mu, log_var)
        self.fc1 = nn.Linear(x_dim, h_dim1)
        self.fc2 = nn.Linear(h_dim1, h_dim2)
        self.mu = nn.Linear(h_dim2, z_dim)
        self.log_var = nn.Linear(h_dim2, z_dim)
        # decoder part: z -> h2 -> h1 -> x
        self.fc4 = nn.Linear(z_dim, h_dim2)
        self.fc5 = nn.Linear(h_dim2, h_dim1)
        self.fc6 = nn.Linear(h_dim1, x_dim)
        
    def encoder(self, x):
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        return self.mu(h), self.log_var(h)
    
    def sampling(self, mu, log_var):
        # reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, I)
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)  # return z sample
        
    def decoder(self, z):
        h = F.relu(self.fc4(z))
        h = F.relu(self.fc5(h))
        return torch.sigmoid(self.fc6(h))  # pixel intensities in [0, 1]; F.sigmoid is deprecated
    
    def forward(self, x):
        mu, log_var = self.encoder(x.view(-1, self.x_dim))
        z = self.sampling(mu, log_var)
        return self.decoder(z), mu, log_var
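The `VAE` class above fixes the number of hidden layers at two. Following the suggestion to make the architecture easy to change, one option is to pass the hidden sizes as a list and build the layers in a loop. The sketch below is a hypothetical variant (`FlexibleVAE` and `hidden_dims` are names introduced here, not part of the exercise); with `hidden_dims=[512, 256]` it has the same layer sizes as the class above, and the decoder mirrors the encoder.

```python
import torch
import torch.nn as nn


class FlexibleVAE(nn.Module):
    """Hypothetical fully connected VAE whose hidden sizes are given as a list."""

    def __init__(self, x_dim, hidden_dims, z_dim):
        super().__init__()
        self.x_dim = x_dim

        # Encoder: x_dim -> hidden_dims[0] -> ... -> hidden_dims[-1]
        enc_layers = []
        in_dim = x_dim
        for h in hidden_dims:
            enc_layers += [nn.Linear(in_dim, h), nn.ReLU()]
            in_dim = h
        self.encoder_body = nn.Sequential(*enc_layers)
        self.mu = nn.Linear(in_dim, z_dim)
        self.log_var = nn.Linear(in_dim, z_dim)

        # Decoder mirrors the encoder: z_dim -> hidden_dims[-1] -> ... -> x_dim
        dec_layers = []
        in_dim = z_dim
        for h in reversed(hidden_dims):
            dec_layers += [nn.Linear(in_dim, h), nn.ReLU()]
            in_dim = h
        dec_layers += [nn.Linear(in_dim, x_dim), nn.Sigmoid()]
        self.decoder = nn.Sequential(*dec_layers)

    def encode(self, x):
        h = self.encoder_body(x)
        return self.mu(h), self.log_var(h)

    def reparameterize(self, mu, log_var):
        # z = mu + sigma * eps with eps ~ N(0, I); z stays differentiable w.r.t. mu and log_var
        std = torch.exp(0.5 * log_var)
        return mu + torch.randn_like(std) * std

    def forward(self, x):
        mu, log_var = self.encode(x.view(-1, self.x_dim))
        z = self.reparameterize(mu, log_var)
        return self.decoder(z), mu, log_var


model = FlexibleVAE(x_dim=784, hidden_dims=[512, 256], z_dim=2)
recon, mu, log_var = model(torch.rand(4, 1, 28, 28))
print(recon.shape, mu.shape)  # torch.Size([4, 784]) torch.Size([4, 2])
```

Trying a deeper network then only means changing the list, e.g. `hidden_dims=[256, 128, 64]`; the forward pass returns the same `(recon_x, mu, log_var)` triple, so the loss and training functions can stay unchanged.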



---



### Task 2: Build the loss function

Remember that the loss function is a **combination** (sum) of:
* binary cross-entropy (assessing how accurately the decoder reconstructs the input), and
* KL divergence (measuring how far the encoder's Gaussian distribution q(z|x) is from the standard normal prior p(z) = N(0, I)).
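Written out for a single input with $D$ pixels (784 for MNIST) and a $J$-dimensional latent space (`z_dim`), the two terms are as follows, where $\hat{x}$ is the decoder output `recon_x`, $\log\sigma_j^2$ is `log_var`, and $\sigma_j^2$ is `log_var.exp()`. The KL term has this closed form because both distributions are Gaussian; in the code, both terms are additionally summed over the mini-batch.

$$
\mathcal{L}(x) = -\sum_{i=1}^{D}\Big[x_i \log \hat{x}_i + (1 - x_i)\log(1 - \hat{x}_i)\Big] + D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \operatorname{diag}(\sigma^2)) \,\|\, \mathcal{N}(0, I)\big)
$$

$$
D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \operatorname{diag}(\sigma^2)) \,\|\, \mathcal{N}(0, I)\big) = -\frac{1}{2}\sum_{j=1}^{J}\left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)
$$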

In [None]:
def loss_function(recon_x, x, mu, log_var):
    # reconstruction term: binary cross-entropy summed over pixels and batch
    bce = F.binary_cross_entropy(recon_x, x.view_as(recon_x), reduction='sum')
    # KL divergence between N(mu, sigma^2) and N(0, I), summed over latent dims and batch
    kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    loss = bce + kl
    return loss
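A quick way to check a from-scratch KL term is to compare it with a built-in one. The sketch below assumes `torch.distributions` as the built-in option (other routes exist, e.g. `torch.nn.functional.kl_div`, which works on log-probabilities instead); it evaluates the same closed-form expression used in `loss_function` on random `mu` and `log_var` and compares it with `kl_divergence` between two `Normal` distributions.

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
mu = torch.randn(5, 2)        # batch of 5, latent dimension 2
log_var = torch.randn(5, 2)   # log sigma^2

# Closed-form KL, the same expression as in loss_function
kl_manual = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

# Built-in: per-dimension KL(N(mu, sigma^2) || N(0, 1)), then summed over everything
q = Normal(mu, torch.exp(0.5 * log_var))   # Normal takes the standard deviation, not the variance
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
kl_builtin = kl_divergence(q, p).sum()

print(torch.allclose(kl_manual, kl_builtin, atol=1e-5))  # True
```

Note the argument order: `kl_divergence(q, p)` computes KL(q ‖ p), i.e. from the encoder's distribution to the prior, which is the direction used in the VAE loss.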



---



### Task 3: Define data, train and test functions

Now we are almost ready to build a VAE, load in some data and train!


First, we'll define some functions that we can reuse later. We'll start with MNIST, but once you get the hang of it, I highly recommend trying out other datasets and formats, such as:


1.   Fashion MNIST, Flickr: also available from torchvision
2.   MIMIC ECG data: https://physionet.org/content/mimicdb/1.0.0/
3.   If you are feeling adventurous, why not try scraping text from Wikipedia on a particular topic and seeing whether similar trends/groups emerge automatically in latent space!



In [None]:
from torchvision import datasets, transforms

# Function to load MNIST
# You may decide to extend this to include a text string as a parameter and gather 
# the appropriate dataset
def get_dataset(batch_size=10):
    
    # MNIST Dataset (download=True is a no-op once the files are already present)
    train_dataset = datasets.MNIST(root='./mnist_data/', train=True, transform=transforms.ToTensor(), download=True)
    test_dataset = datasets.MNIST(root='./mnist_data/', train=False, transform=transforms.ToTensor(), download=True)

    # Data Loader (Input Pipeline)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader

Notice that the `transform` argument lets us apply transformations, including data augmentation, automatically using torchvision's built-in `transforms` module. I recommend using them when moving away from MNIST.

Next, define the training function. It is derived from the PyTorch CIFAR-10 tutorial, which explains each line:
https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py

In [None]:
def train(model, epoch, train_loader, optimizer, device='cpu'):
    model.to(device)
    model.train()   # training mode: affects layers such as dropout/batch norm (weights are updated by optimizer.step())
    train_loss = 0
    
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)    
        optimizer.zero_grad()
        
        # Call your model and loss functions here
        recon_x, mu, log_var = model(data)
        loss = loss_function(recon_x, data, mu, log_var)
        
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item() / len(data)))
            
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, train_loss / len(train_loader.dataset)))

In [None]:
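def test(model, test_loader, device='cpu'):
    model.to(device)  # make sure model and data are on the same device
    model.eval()      # evaluation mode: affects layers such as dropout/batch norm
    test_loss = 0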
    
    with torch.no_grad():
        for data, _ in test_loader:
            data = data.to(device)
            recon, mu, log_var = model(data)
            
            # sum up batch loss
            test_loss += loss_function(recon, data, mu, log_var).item()
        
    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))
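The `test` function reports only the combined loss. Tracking the reconstruction and KL terms separately can show which one dominates, for instance whether the KL term collapses towards zero. A minimal sketch, using a hypothetical helper `loss_components` that computes the same two terms as `loss_function` but returns them separately:

```python
import torch
import torch.nn.functional as F


def loss_components(recon_x, x, mu, log_var):
    """Hypothetical helper: return (reconstruction, KL), each summed over the batch."""
    recon = F.binary_cross_entropy(recon_x, x.view_as(recon_x), reduction='sum')
    kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return recon, kl


# Toy batch: 4 images of 28x28 and a 2-D latent code
x = torch.rand(4, 1, 28, 28)
recon_x = torch.rand(4, 784).clamp(1e-4, 1 - 1e-4)
mu, log_var = torch.zeros(4, 2), torch.zeros(4, 2)

recon, kl = loss_components(recon_x, x, mu, log_var)
# With mu = 0 and log_var = 0 the encoder distribution equals the prior, so kl is zero.
print(recon.item(), kl.item())
```

Inside `test`, one could keep two running totals, `recon.item()` and `kl.item()`, in place of the single `test_loss`, and divide each by `len(test_loader.dataset)` at the end; their sum equals the value currently printed.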

Let's try out all the functions we have created above!

In [None]:
device_type = 'cuda' if torch.cuda.is_available() else 'cpu'  # use the GPU when one is available
vae_model = VAE(x_dim=784, h_dim1=512, h_dim2=256, z_dim=2).to(device_type)
optimizer = torch.optim.Adam(vae_model.parameters(), lr=0.001)
num_epochs = 50

train_data, test_data = get_dataset()

for epoch in range(1, num_epochs+1):
    train(vae_model, epoch, train_data, optimizer, device=device_type)
    test(vae_model, test_data, device=device_type)



====> Epoch: 1 Average loss: 161.5130
====> Test set loss: 153.7311
====> Epoch: 2 Average loss: 150.0359
====> Test set loss: 148.2008
====> Epoch: 3 Average loss: 147.1011
====> Test set loss: 146.9544
====> Epoch: 4 Average loss: 145.5331
====> Test set loss: 145.2405
====> Epoch: 5 Average loss: 144.7718
====> Test set loss: 144.2242
====> Epoch: 6 Average loss: 143.8425
====> Test set loss: 143.2982
====> Epoch: 7 Average loss: 143.2632
====> Test set loss: 143.8365
====> Epoch: 8 Average loss: 142.9037
====> Test set loss: 143.5956
====> Epoch: 9 Average loss: 142.3711
====> Test set loss: 143.4587
====> Epoch: 10 Average loss: 142.0456
====> Test set loss: 142.6769
====> Epoch: 11 Average loss: 141.7074
====> Test set loss: 141.8176
====> Epoch: 12 Average loss: 141.8221
====> Test set loss: 143.8194
====> Epoch: 13 Average loss: 141.4961
====> Test set loss: 142.6167
====> Epoch: 14 Average loss: 141.1036
====> Test set loss: 143.0168
====> Epoch: 15 Average loss: 141.0619
====



---



We have now trained and tested a VAE in PyTorch.

Next, go and build some visualization functions to test your model further...
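One starting point, since this model uses `z_dim=2`, is to decode a regular grid of latent points and tile the outputs into a single image, which shows how the decoder's output changes across latent space. The sketch below assumes a 2-D latent space and 28x28 outputs; `decode_latent_grid` is a hypothetical name, and a stand-in decoder is used so that it runs on its own (with the trained model you would pass `vae_model.decoder`).

```python
import torch
import torch.nn as nn


@torch.no_grad()
def decode_latent_grid(decoder, n=10, lim=3.0, img_shape=(28, 28), device='cpu'):
    """Decode an n x n grid of 2-D latent points into one (n*h, n*w) image tensor.

    Assumes z_dim=2 and a decoder callable mapping (N, 2) -> (N, h*w),
    e.g. the decoder method of the VAE class.
    """
    h, w = img_shape
    coords = torch.linspace(-lim, lim, n)
    # Tile at (row r, column c) shows z = (coords[c], coords[r])
    zy, zx = torch.meshgrid(coords, coords, indexing='ij')
    z = torch.stack([zx.reshape(-1), zy.reshape(-1)], dim=1).to(device)  # (n*n, 2)
    images = decoder(z).cpu().view(n, n, h, w)
    # (row, col, y, x) -> (row, y, col, x) -> (n*h, n*w)
    return images.permute(0, 2, 1, 3).reshape(n * h, n * w)


# Stand-in decoder so the sketch is self-contained; replace with vae_model.decoder
dummy_decoder = nn.Sequential(nn.Linear(2, 784), nn.Sigmoid())
grid = decode_latent_grid(dummy_decoder, n=8)
print(grid.shape)  # torch.Size([224, 224])
```

With matplotlib (not used elsewhere in this notebook), `plt.imshow(decode_latent_grid(vae_model.decoder, device=device_type).numpy(), cmap='gray')` would display the grid. Decoding random draws from the prior, e.g. `vae_model.decoder(torch.randn(16, 2, device=device_type))`, is another quick way to generate new samples.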