In [1]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import time
import math
import torch
import torch.nn as nn
import torch.utils 
from torch.nn import Sigmoid
from torch.optim import Adam
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

In [2]:
# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [3]:
# Data loading: MNIST is fetched into `dataset_path` when it is not already there.
dataset_path = '~/datasets'
batch_size = 100

# The only preprocessing step is the PIL-image -> tensor conversion.
mnist_transform = transforms.Compose([transforms.ToTensor()])

# Extra DataLoader options shared by all three loaders below.
kwargs = dict(num_workers=1, pin_memory=True)

train_dataset = MNIST(
    dataset_path,
    train=True,
    transform=mnist_transform,
    download=True,
)
test_dataset = MNIST(
    dataset_path,
    train=False,
    transform=mnist_transform,
    download=True,
)

# Split the official training set: first 50,000 images for training,
# the remaining 10,000 for validation.
train_data = torch.utils.data.Subset(train_dataset, range(50000))
val_data = torch.utils.data.Subset(train_dataset, range(50000, 60000))

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, **kwargs)
vali_loader = DataLoader(val_data, batch_size=batch_size, shuffle=True, **kwargs)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, **kwargs)

In [None]:

class Reshape(nn.Module):
    """Layer that reshapes its input to a fixed target shape.

    The target shape is given as positional integers, e.g.
    ``Reshape(-1, 64, 8, 8)``; a ``-1`` entry is inferred from the input size.
    """

    def __init__(self, *args):
        super().__init__()
        # Tuple of target dimensions, stored exactly as passed in.
        self.shape = args

    def forward(self, x):
        """Return ``x`` with shape ``self.shape``.

        ``reshape`` is used instead of ``view`` so that non-contiguous inputs
        (e.g. the result of a transpose) are accepted; for contiguous inputs
        the result is the same view as before.
        """
        return x.reshape(self.shape)
    
class Trim(nn.Module):
    """Layer that crops the last two axes of a 4-D tensor to at most 28 x 28.

    Any constructor arguments are accepted and ignored.
    """

    def __init__(self, *args):
        super().__init__()

    def forward(self, x):
        """Return ``x`` with axes 2 and 3 cut down to their first 28 entries."""
        keep = slice(None, 28)
        everything = slice(None)
        return x[everything, everything, keep, keep]



In [None]:
class VAE(nn.Module):
    """Convolutional variational autoencoder for 1-channel 28x28 images.

    The encoder maps an image to the mean and variance of a diagonal Gaussian
    over the latent code; a sample from that Gaussian is decoded back into a
    1-channel 28x28 image.

    Args:
        latent_dim: size of the latent code (defaults to 10).
    """

    def __init__(self, latent_dim=10):
        super().__init__()
        self.latent_dim = latent_dim

        # Spatial size for a 28x28 input: 28 -> 26 -> 24 -> 8 -> 8,
        # so the flattened feature vector has 8*8*64 entries.
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=0),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=0),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(3, 3), padding=0),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Flatten()
        )

        # Two heads on the shared features: the latent mean, and the latent
        # variance (Softplus keeps the variance from going negative).
        self.z_mean = torch.nn.Linear(8*8*64, latent_dim)
        self.z_var = nn.Sequential(
            torch.nn.Linear(8*8*64, latent_dim),
            torch.nn.Softplus())

        # Mirror of the encoder. Spatial size: 8 -> 8 -> 24 -> 26 -> 28.
        self.decoder = nn.Sequential(
            torch.nn.Linear(latent_dim, 8*8*64),
            Reshape(-1, 64, 8, 8),
            nn.ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(3, 3), padding=0),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=0),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            # Single output channel holding the reconstructed image.
            nn.ConvTranspose2d(32, 1, kernel_size=(3, 3), stride=(1, 1), padding=0),
            # Crop to at most 28x28, then squash to non-negative values.
            Trim(),
            nn.Softplus()
        )

    def reparametization(self, z_mu, z_var):
        """Sample ``z = z_mu + sqrt(z_var) * eps`` with ``eps ~ N(0, I)``.

        Args:
            z_mu: mean of the latent Gaussian.
            z_var: variance of the latent Gaussian (same shape as ``z_mu``).

        Returns:
            A latent sample with the same shape as ``z_var``.
        """
        # randn_like puts the noise on the same device/dtype as z_var, so the
        # module does not depend on any global `device` variable.
        eps = torch.randn_like(z_var)
        return z_mu + torch.sqrt(z_var) * eps

    def forward(self, x):
        """Encode ``x``, sample a latent code, and decode it.

        Returns:
            A tuple ``(decoded, z_mean, z_var)``. Note that the third element
            is the latent *variance*, not a log standard deviation.
        """
        x = self.encoder(x)
        z_mean, z_var = self.z_mean(x), self.z_var(x)
        encoded = self.reparametization(z_mean, z_var)
        decoded = self.decoder(encoded)
        return decoded, z_mean, z_var

In [None]:
# Build the model, then move its parameters and buffers to the chosen device.
vae = VAE()
vae = vae.to(device)

def loss_function(x, x_reconstr, mu, log_sigma):
    """Return ``(total, reconstruction, KL)`` losses, each summed over the batch.

    Args:
        x: original inputs.
        x_reconstr: reconstructions, compared to ``x`` with a summed squared error.
        mu: latent means.
        log_sigma: latent log standard deviations; the KL term is evaluated as
            ``0.5 * sum(mu^2 + exp(2*log_sigma) - 2*log_sigma - 1)``.

    Returns:
        Three scalar tensors: ``reconstr_loss + kl_loss``, ``reconstr_loss``
        and ``kl_loss``.
    """
    reconstr_loss = torch.nn.functional.mse_loss(x_reconstr, x, reduction='sum')

    log_variance = 2*log_sigma
    kl_per_element = mu.pow(2) + log_variance.exp() - log_variance - 1
    kl_loss = 0.5 * torch.sum(kl_per_element)

    return reconstr_loss + kl_loss, reconstr_loss, kl_loss

# Adam over every parameter of the model, with a fixed learning rate.
optimizer = Adam(params=vae.parameters(), lr=5e-3)

In [None]:
epochs = 50
# Floor applied to the latent variance before taking its log, so a variance
# that underflows to exactly 0 cannot produce an infinite KL term.
var_eps = 1e-8


def _batch_losses(x):
    """Forward one batch through `vae` and return (total, reconstruction, KL) losses."""
    x_reconstr, mu, z_var = vae(x)
    # The model returns the latent *variance*, while loss_function expects the
    # log standard deviation: log(sigma) = 0.5 * log(var).
    log_sigma = 0.5 * torch.log(z_var.clamp_min(var_eps))
    return loss_function(x, x_reconstr, mu, log_sigma)


print("Start training VAE...")
start_time = time.time()

# Per-sample loss after every epoch (training / validation).
train_ELBO = []
validation_ELBO = []

for epoch in range(epochs):

    # ---- training pass ----
    # Re-enter train mode every epoch because validation below switches to eval mode.
    vae.train()
    overall_loss = 0
    overall_reconstr_loss = 0
    overall_kl_loss = 0
    n_train = 0  # exact number of samples seen, including a smaller final batch
    for x, _ in train_loader:
        optimizer.zero_grad()
        x = x.to(device)
        loss, reconstr_loss, kl_loss = _batch_losses(x)
        loss.backward()
        optimizer.step()

        overall_loss += loss.item()
        overall_reconstr_loss += reconstr_loss.item()
        overall_kl_loss += kl_loss.item()
        n_train += x.size(0)

    train_ELBO.append(overall_loss / n_train)

    # ---- validation pass ----
    # Eval mode: BatchNorm uses its running statistics and does not update
    # them with validation data.
    vae.eval()
    validation_loss = 0
    validation_reconstr_loss = 0
    validation_kl_loss = 0
    n_val = 0
    with torch.no_grad():
        for x, _ in vali_loader:
            x = x.to(device)
            loss, reconstr_loss, kl_loss = _batch_losses(x)

            validation_loss += loss.item()
            validation_reconstr_loss += reconstr_loss.item()
            validation_kl_loss += kl_loss.item()
            n_val += x.size(0)

    validation_ELBO.append(validation_loss / n_val)

    # Early stopping: the per-sample training loss moved by at most 0.05
    # compared with the previous epoch.
    if epoch != 0 and np.absolute(train_ELBO[epoch] - train_ELBO[epoch-1]) <= 0.05:
        print(train_ELBO)
        break

    # Training sums are normalised by the training sample count.
    print("\tEpoch", epoch + 1, "\tAverage Loss: ", overall_loss / n_train, "\tReconstruction Loss:", overall_reconstr_loss / n_train, "\tKL Loss:", overall_kl_loss / n_train)

print("Training complete!")
# Elapsed wall-clock time in seconds.
print(time.time() - start_time)

In [None]:
#Make the ELBO plots
# Left: values recorded in train_ELBO, right: values recorded in validation_ELBO
# (one point per completed epoch). Titles are needed because the two panels
# are otherwise indistinguishable.
fig, (ax1, ax2) = plt.subplots(1, 2)
ax1.plot(train_ELBO)
ax2.plot(validation_ELBO)
ax1.set_title('training')
ax2.set_title('validation')
ax1.set_xlabel('nr epochs')
ax2.set_xlabel('nr epochs')
ax1.set_ylabel('loss')
ax2.set_ylabel('loss')
# Keep the second panel's y label from overlapping the first panel.
fig.tight_layout();



In [None]:
#Making reconstructions based on the test set
vae.eval()

# One entry per test batch: inputs, labels and the matching reconstructions.
x_original_list = []
y_list = []
x_reconstr_list = []
with torch.no_grad():
    for x, y in tqdm(test_loader):
        # -1 lets the batch dimension follow the actual batch, so a final
        # batch smaller than `batch_size` does not raise.
        x = x.view(-1, 1, 28, 28)
        x = x.to(device)

        # Only the reconstruction is kept; the latent mean/variance are unused here.
        x_reconstr, _, _ = vae(x)
        x_original_list.append(x)
        y_list.append(y)
        x_reconstr_list.append(x_reconstr)

In [None]:
#Generating images
# Decode one batch of standard-normal latent codes (10 values each) with the
# decoder alone; gradients are not needed for sampling.
with torch.no_grad():
    noise = torch.randn(batch_size, 10)
    noise = noise.to(device)
    generated_images = vae.decoder(noise)

In [None]:
from datetime import datetime

#Making columns of images 
def show_images(x, ncols=16):
    """Show the first ``ncols`` 28x28 images of ``x`` in one row and save the figure.

    Args:
        x: tensor holding 28x28 images; it is reshaped to ``(-1, 28, 28)``.
        ncols: maximum number of images to draw. Fewer are drawn when ``x``
            holds fewer images.

    The figure is written to the working directory under the name
    ``figure: <HHMMSS_microseconds>``.
    """
    # Infer the image count from the tensor rather than the global batch size.
    x = x.reshape(-1, 28, 28)
    ncols = min(ncols, x.shape[0])

    # squeeze=False always yields a 2-D axes array, so ncols == 1 works too.
    fig, ax = plt.subplots(1, ncols, figsize=(40, 2), squeeze=False)

    for idx in range(ncols):
        panel = ax[0, idx]
        panel.imshow(x[idx].cpu().numpy(), cmap="Greys_r")
        panel.axis('off')

    # Microseconds keep back-to-back calls from overwriting each other's file;
    # the local is not named `time`, to avoid shadowing the imported module.
    stamp = datetime.now().strftime('%H%M%S_%f')
    fig.savefig(f'figure: {stamp}')

In [None]:
# Show originals, their reconstructions, and freshly generated samples,
# using the first stored test batch.
batch_idx = 0
for images in (
    x_original_list[batch_idx],
    x_reconstr_list[batch_idx],
    generated_images,
):
    show_images(images)