In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

import tqdm
import pickle
import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split

from sklearn.datasets import make_moons
from torchvision import datasets, transforms

## Define a basic Variational Autoencoder and the necessary training routines

In [2]:
# variational autoencoder skeleton

class VAE(nn.Module):
    """Fully-connected variational autoencoder that owns its own Adam optimizer.

    The encoder maps each input row through ReLU hidden layers to one linear
    head of width ``2 * bottleneck_size``. ``forward`` splits that head into
    the posterior mean ``mu`` and log-variance ``log_var``, draws a latent
    sample with the reparameterisation trick, and decodes it back to
    ``input_size`` features.

    Args:
        input_size: number of features per input row.
        num_hidden_layers: number of ReLU hidden layers in the encoder and in
            the decoder. The input projection counts as the first one, so
            values below 1 still yield one hidden layer.
        hidden_layer_size: width of every hidden layer.
        bottleneck_size: dimensionality of the latent space.
        mse_weight: relative weight ``w`` of the reconstruction term; the loss
            is ``w/(w+1) * MSE + 1/(w+1) * KL``.
        lr: learning rate of the Adam optimizer stored in ``self.optimizer``.
    """

    def __init__(self, input_size, num_hidden_layers, hidden_layer_size, bottleneck_size,
                 mse_weight=30, lr=0.001):
        super().__init__()

        # define architecture of the vae
        self.input_size = input_size
        self.num_hidden_layers = num_hidden_layers
        self.hidden_layer_size = hidden_layer_size
        self.bottleneck_size = bottleneck_size

        # construct the encoder and decoder (reparameterise stays a plain method)
        self.encoder = self.construct_encoder()
        self.decoder = self.construct_decoder()

        # reconstruction loss; the KL term is computed analytically in kl_loss_function
        self.mse_loss_function = nn.MSELoss(reduction='mean')
        self.mse_weight = mse_weight

        # Posterior parameters from the most recent forward pass. loss_function
        # falls back to these when mu/log_var are not passed explicitly, so a
        # training loop can call loss_function(output, batch).
        self.last_mu = None
        self.last_log_var = None

        # Created last so that every submodule's parameters are registered.
        self.optimizer = optim.Adam(self.parameters(), lr=lr)

    def loss_function(self, x, x_hat, mu=None, log_var=None):
        """Weighted sum of reconstruction MSE and KL divergence.

        Args:
            x, x_hat: tensors compared with mean-reduced MSE (order does not
                matter, MSE is symmetric).
            mu, log_var: posterior parameters; default to those cached by the
                last call to ``forward``.

        Returns:
            Scalar tensor ``w/(w+1) * MSE + 1/(w+1) * KL`` with ``w = mse_weight``.

        Raises:
            ValueError: if mu/log_var are omitted and ``forward`` has not run yet.
        """
        if mu is None:
            mu = self.last_mu
        if log_var is None:
            log_var = self.last_log_var
        if mu is None or log_var is None:
            raise ValueError("mu/log_var not given and no forward pass has been run yet")

        mse_loss = self.mse_loss_function(x, x_hat)
        kl_loss = self.kl_loss_function(mu, log_var)
        total_weight = self.mse_weight + 1
        # True division: the previous `1 // (w + 1)` evaluated to 0 and dropped the KL term.
        return (self.mse_weight / total_weight) * mse_loss + (1 / total_weight) * kl_loss

    @staticmethod
    def kl_loss_function(mu, log_var):
        """KL(N(mu, exp(log_var)) || N(0, I)), summed over the last dim and averaged over the rest."""
        kl_per_row = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=-1)
        return kl_per_row.mean()

    def construct_encoder(self):
        """Build the encoder: ReLU MLP ending in one head that emits [mu | log_var]."""
        encoder = nn.Sequential()

        encoder.add_module('input', nn.Linear(self.input_size, self.hidden_layer_size))
        encoder.add_module('relu_input', nn.ReLU())

        for i in range(self.num_hidden_layers - 1):
            encoder.add_module('hidden_' + str(i), nn.Linear(self.hidden_layer_size, self.hidden_layer_size))
            encoder.add_module('relu_' + str(i), nn.ReLU())

        # Two separate heads cannot sit side by side in nn.Sequential (they would be
        # chained), so a single head outputs both halves and forward() splits them.
        encoder.add_module('mu_log_var', nn.Linear(self.hidden_layer_size, 2 * self.bottleneck_size))

        return encoder

    def reparameterise(self, mu, log_var):
        """Sample z = mu + eps * exp(0.5 * log_var) with eps ~ N(0, I), keeping gradients w.r.t. mu/log_var."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def construct_decoder(self):
        """Build the decoder: ReLU MLP from the latent space back to input_size (linear output)."""
        decoder = nn.Sequential()

        decoder.add_module('input', nn.Linear(self.bottleneck_size, self.hidden_layer_size))
        decoder.add_module('relu_input', nn.ReLU())

        for i in range(self.num_hidden_layers - 1):
            decoder.add_module('hidden_' + str(i), nn.Linear(self.hidden_layer_size, self.hidden_layer_size))
            decoder.add_module('relu_' + str(i), nn.ReLU())

        decoder.add_module('output', nn.Linear(self.hidden_layer_size, self.input_size))

        return decoder

    def forward(self, x):
        """Encode, sample a latent via reparameterisation, and return the reconstruction.

        The posterior parameters are cached in ``last_mu``/``last_log_var`` for
        ``loss_function``. They keep a reference to this pass's autograd graph
        until the next call.
        """
        encoded = self.encoder(x)
        mu, log_var = torch.chunk(encoded, 2, dim=-1)
        self.last_mu, self.last_log_var = mu, log_var
        z = self.reparameterise(mu, log_var)
        return self.decoder(z)

In [4]:
# define training loop

def train_vae(model, X_train, X_val, n_epochs=10, batchsize=32, verbose=True):
    """Train ``model`` with in-order mini-batches and record per-epoch losses.

    ``model`` must be callable on a batch, support ``train()``/``eval()``,
    expose an ``optimizer`` attribute, and provide
    ``loss_function(output, batch)`` returning a scalar tensor.

    Args:
        model: the network to train (updated in place).
        X_train, X_val: tensors sliced along dim 0 into batches, in order
            (no shuffling).
        n_epochs: number of passes over ``X_train``.
        batchsize: rows per mini-batch; the last batch may be smaller.
        verbose: if True, show tqdm progress bars and the latest losses.

    Returns:
        dict with keys 'train' and 'val', each a list with one entry per epoch
        holding the row-weighted average of the batch losses. An empty split
        yields NaN for that epoch.
    """
    def batch_starts(n_rows):
        # Inner bars are transient (leave=False) so only the epoch bar remains.
        starts = range(0, n_rows, batchsize)
        return tqdm.tqdm(starts, leave=False) if verbose else starts

    loss_history = {
        'train': [],
        'val': []
    }

    epochs = tqdm.tqdm(range(n_epochs)) if verbose else range(n_epochs)

    for epoch in epochs:
        # training
        model.train()
        n_train = X_train.shape[0]
        train_total = 0.0
        for i in batch_starts(n_train):
            batch = X_train[i:i + batchsize]
            model.optimizer.zero_grad()
            output = model(batch)
            loss = model.loss_function(output, batch)
            loss.backward()
            model.optimizer.step()
            # Assumes loss_function returns a per-batch mean (as MSELoss(reduction='mean')
            # does); weighting by rows makes the epoch value a true per-row average.
            train_total += loss.item() * batch.shape[0]
        train_loss = train_total / n_train if n_train else float('nan')
        loss_history['train'].append(train_loss)

        # validation
        model.eval()
        n_val = X_val.shape[0]
        val_total = 0.0
        with torch.no_grad():
            for i in batch_starts(n_val):
                batch = X_val[i:i + batchsize]
                output = model(batch)
                loss = model.loss_function(output, batch)
                val_total += loss.item() * batch.shape[0]
        val_loss = val_total / n_val if n_val else float('nan')
        loss_history['val'].append(val_loss)

        # print loss
        if verbose:
            epochs.set_description('Epoch: {}/{}, train loss: {:.4f}, val loss: {:.4f}'.format(epoch+1, n_epochs, train_loss, val_loss))

    # Returned after all epochs (previously this sat inside the loop and stopped after epoch 1).
    return loss_history