In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.distributions as dist
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms, utils

# Implementing a classic VAE

We start by defining the encoder, the decoder, and the VAE class that combines them.

In [2]:
class Encoder(nn.Module):
    """
    Inference network of the VAE: a two-hidden-layer MLP with ELU activations
    followed by two parallel linear heads.

    Params:
    - input_dim: Size of the flattened input vector
    - hidden_dim1: Width of the first hidden layer
    - hidden_dim2: Width of the second hidden layer
    - latent_dim: Number of latent dimensions (output size of each head)
    """

    def __init__(self, input_dim, hidden_dim1, hidden_dim2, latent_dim):
        super().__init__()

        # Remember the layer sizes on the instance
        self.input_dim, self.hidden_dim1 = input_dim, hidden_dim1
        self.hidden_dim2, self.latent_dim = hidden_dim2, latent_dim

        # Shared trunk
        self.fc1 = nn.Linear(input_dim, hidden_dim1)
        self.fc2 = nn.Linear(hidden_dim1, hidden_dim2)

        # One head per Gaussian parameter
        self.fc3_mean = nn.Linear(hidden_dim2, latent_dim)
        self.fc3_var = nn.Linear(hidden_dim2, latent_dim)

    def forward(self, x):
        """
        Encode `x` and return the pair `(mean, var)`.

        `var` is an unconstrained vector with one entry per latent dimension: the
        covariance is taken to be diagonal (mean-field assumption), so a vector is
        enough to describe it. Note that VAE.loss plugs this output in as a
        log-variance (it uses both `var` and `var.exp()`).
        """
        hidden = F.elu(self.fc2(F.elu(self.fc1(x))))

        return self.fc3_mean(hidden), self.fc3_var(hidden)

In [3]:
class Decoder(nn.Module):
    """
    Generative network of the VAE: mirrors the encoder's layer widths in reverse
    (latent_dim -> hidden_dim2 -> hidden_dim1 -> input_dim).

    Params:
    - input_dim: Size of the reconstructed (flattened) output vector
    - hidden_dim1: Width of the last hidden layer
    - hidden_dim2: Width of the first hidden layer
    - latent_dim: Size of the latent vector fed to the decoder
    """

    def __init__(self, input_dim, hidden_dim1, hidden_dim2, latent_dim):
        super().__init__()

        # Remember the layer sizes on the instance
        self.input_dim, self.hidden_dim1 = input_dim, hidden_dim1
        self.hidden_dim2, self.latent_dim = hidden_dim2, latent_dim

        # Layers, from the latent space back up to the input space
        self.fc1 = nn.Linear(latent_dim, hidden_dim2)
        self.fc2 = nn.Linear(hidden_dim2, hidden_dim1)
        self.fc3 = nn.Linear(hidden_dim1, input_dim)

    def forward(self, z):
        """Decode the latent vector `z`; the final sigmoid keeps every output in (0, 1)."""
        hidden = z
        for layer in (self.fc1, self.fc2):
            hidden = F.elu(layer(hidden))

        return torch.sigmoid(self.fc3(hidden))

In [4]:
class VAE(nn.Module):
    """
    Variational auto-encoder built from an encoder and a decoder module.

    The encoder must return a pair `(mean, var)` where `var` is interpreted as the
    log-variance of a diagonal Gaussian; the decoder maps a latent sample back to
    values in [0, 1] (required by the binary cross-entropy in `loss`).

    Params:
    - encoder: Module returning `(mean, var)` for an input batch
    - decoder: Module reconstructing the input from a latent sample
    """

    def __init__(self, encoder, decoder):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder

    def reparameterization(self, mean, var):
        """
        Draw z = mean + std * eps with eps ~ N(0, I) (reparameterization trick).

        `var` is a log-variance, so std = exp(0.5 * var). This keeps the sampling
        distribution consistent with the KL term in `loss`, which uses `var.exp()`
        as the variance, and the exponential guarantees a positive scale.
        NOTE(review): weights trained with the previous `std = exp(var)` sampling
        were fitted to a different sampling scale — consider retraining.
        """
        std = torch.exp(0.5 * var)
        noise = torch.randn_like(std)

        return mean + std * noise

    def loss(self, recon_x, x, mean, var):
        """
        Negative ELBO summed over every element of the batch.

        Params:
        - recon_x: Decoder output, values in [0, 1]
        - x: Target tensor with the same shape as `recon_x`
        - mean: Mean of the approximate posterior
        - var: Log-variance of the approximate posterior
        """
        # We use cross entropy as pixels in binary MNIST follow a Bernoulli distribution
        reconstruction_loss = nn.functional.binary_cross_entropy(recon_x, x, reduction='sum')
        # Closed-form KL( N(mean, diag(exp(var))) || N(0, I) )
        kl_divergence = - 0.5 * torch.sum(1 + var - mean.pow(2) - var.exp())

        return reconstruction_loss + kl_divergence

    def forward(self, x):
        """Encode `x`, sample a latent vector, decode it; returns `(recon_x, mean, var)`."""
        mean, var = self.encoder(x)
        z = self.reparameterization(mean, var)
        recon_x = self.decoder(z)

        return recon_x, mean, var

We load MNIST and binarize it (pixels with intensity ≥ 0.5 become 1, all others 0) to obtain binary MNIST.

In [6]:
batch_size = 100


def binarize(image):
    """Threshold an image tensor at 0.5: values >= 0.5 become 1.0, all others 0.0 (float32)."""
    return (image >= 0.5).to(torch.float32)


# To transform MNIST into binary MNIST
# (a named function instead of a lambda keeps the transform picklable for multi-worker loaders)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(binarize),
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
n_class = len(train_dataset.classes) # Useful for the multi_encoder VAE

# Reshuffle the training set every epoch so the optimizer does not see the same batch sequence each time;
# the test set keeps its fixed order
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size)

We define the training loop, then either train the VAE or load previously trained weights.

In [2]:
def train(vae_model, input_dim, loader, optimizer, n_epochs=20):
    """
    Train the VAE

    Params:
    - vae_model: An instance of VAE
    - input_dim: The dimension of the input
    - loader: The loader on which we will train the model
    - optimizer: The optimizer used to train
    - n_epochs: The number of training epochs

    Returns:
    - The loss summed over all batches of the last epoch, as a 0-dim tensor
      detached from the autograd graph (zero if no batch was processed)
    """
    vae_model.train()

    # Defined up front so the function still returns a value when n_epochs == 0
    overall_loss = torch.zeros(())

    for epoch in range(n_epochs):
        overall_loss = torch.zeros(())
        n_samples = 0

        for x, _ in loader:
            # Flatten each sample; -1 infers the batch size so a smaller last batch also works
            x = x.view(-1, input_dim)

            optimizer.zero_grad()
            recon_x, mean, var = vae_model(x)
            loss = vae_model.loss(recon_x, x, mean, var)

            loss.backward()
            optimizer.step()

            # Accumulate a detached copy: only the value is needed for reporting,
            # not the autograd history of every batch
            overall_loss = overall_loss + loss.detach()
            n_samples += x.size(0)

        # Average over the number of samples actually seen (guarding against an empty loader)
        print("\tEpoch", epoch + 1, "\tAverage Loss: ", (overall_loss / max(n_samples, 1)).item())

    return overall_loss

In [None]:
train_classic_vae = False # Change to True if you want to train the VAE instead of loading the trained one

# Define the dimension of the layers
input_dim = 28 * 28
hidden_dim1 = hidden_dim2 = 200
latent_dim = 50

# Where the trained weights are stored / loaded from
checkpoint_path = 'save/classic_vae.pth'

# Create the classic VAE
encoder = Encoder(input_dim, hidden_dim1, hidden_dim2, latent_dim)
decoder = Decoder(input_dim, hidden_dim1, hidden_dim2, latent_dim)
classic_vae = VAE(encoder, decoder)

# Training (if train_classic_vae is True), otherwise just load the already trained model
optimizer = torch.optim.Adam(list(classic_vae.encoder.parameters()) + list(classic_vae.decoder.parameters()), lr=0.001)

if train_classic_vae:
    # Create the output directory before training so a missing folder fails fast instead of after the run
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
    train(classic_vae, input_dim, train_loader, optimizer)
    torch.save(classic_vae.state_dict(), checkpoint_path)
else:
    classic_vae.load_state_dict(torch.load(checkpoint_path))

# Freeze every weight (encoder and decoder) of the classic VAE
for param in classic_vae.parameters():
    param.requires_grad = False

classic_vae.eval()