# 2.1 Train a VAE

## TODO

> Le reparametrization trick est-il necessaire?

> Comment faire du Binary Cross Entropy loss sur un output pas dans [0,1] ??? (sigmoid ajoutée..)

> Implementer ELBO loss (KL?)

> Evaluer le valid set avec ce ELBO


In [1]:
from torchvision.datasets import utils
import torch.utils.data as data_utils
import torch
import os
import numpy as np
from torch import nn
from torch.nn.modules import upsampling
from torch.functional import F
from torch.optim import Adam
from torch.autograd import Variable

In [2]:
def get_data_loader(dataset_location, batch_size):
    def lines_to_np_array(lines): return np.array([[int(i) for i in line.split()] for line in lines])
    splitdata = []
    for splitname in ["train", "valid", "test"]:
        filename = "binarized_mnist_%s.amat" % splitname
        filepath = os.path.join(dataset_location, filename)
        with open(filepath) as f: lines = f.readlines()
        x = lines_to_np_array(lines).astype('float32')
        x = x.reshape(x.shape[0], 1, 28, 28)
        dataset = data_utils.TensorDataset(torch.from_numpy(x))
        dataset_loader = data_utils.DataLoader(x, batch_size=batch_size, shuffle=splitname == "train")
        splitdata.append(dataset_loader)
    return splitdata

train, valid, test = get_data_loader("binMNIST/", 64)

In [3]:
import matplotlib.pyplot as plt
for x in train: plt.imshow(x[0, 0]); break;

## Implementing VAE

In [12]:
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        self.e1 =  nn.Sequential(
            nn.Conv2d(1, 32, kernel_size = (3, 3) ), # 26 x 26
            nn.ELU(),
            
            nn.AvgPool2d(kernel_size=2, stride=2), # 13 x 13
            nn.Conv2d(32, 64, kernel_size=(3, 3)), # 11 x 11
            nn.ELU(),
            
            nn.AvgPool2d(kernel_size=2, stride=2),  # 5 x 5
            nn.Conv2d(64, 256, kernel_size=(5, 5)), # 1 x 1
            nn.ELU()
        )
        self.e2 = nn.Linear(in_features=256, out_features= 200)
        
        
        self.d1 = nn.Linear(in_features=100, out_features=256) # 256 x 1
        
        self.d2 = nn.Sequential(
            nn.ELU(),
            
            nn.Conv2d(256, 64, kernel_size=(5, 5), padding=(4, 4)), # 1 > 5
            nn.ELU(),
            
            nn.UpsamplingBilinear2d(scale_factor=2), # 5 > 10
            nn.Conv2d(64, 32, kernel_size=(3, 3), padding=(2, 2)), # 10 > 12
            nn.ELU(),
            
            nn.UpsamplingBilinear2d(scale_factor=2), # 12 > 24
            nn.Conv2d(32, 16, kernel_size=(3, 3), padding=(2, 2)), # 24 > 26
            nn.ELU(),
            
            nn.Conv2d(16, 1, kernel_size=(3, 3), padding=(2, 2)), # 26 x 26 > 28 x 28
            
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.e1(x)
        x = x.view(-1,256)
        x = self.e2(x)
        
        mean, logvar = torch.split(x, 100, dim = 1) # mean: [minibatch, 100]
        x = torch.normal(mean=mean, std=logvar) # minibatch x 100
        
        x = self.d1(x)
        x = x.view(-1,256,1,1)
        x = self.d2(x)
        
        return x

In [13]:
model = VAE()
model(x[0:20]).size()



torch.Size([20, 1, 28, 28])

## Training the model

In [15]:
num_epochs = 20
learning_rate = 3e-4

model = VAE().cuda()
optimizer = torch.optim.Adam( model.parameters(), lr=learning_rate)
criterion = nn.BCELoss()

batch_size = 50


for epoch in range(num_epochs):
    for data in train:
        
        data = Variable(data).cuda()
        # ===================forward=====================
        output = model(data)
        loss = criterion(output, data)
        # ===================backward====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
            
    # ===================loss========================
    print('epoch [{}/{}], loss:{:.4f}'.format(epoch + 1, num_epochs, loss.item()) )



epoch [1/20], loss:0.2125
epoch [2/20], loss:0.1959
epoch [3/20], loss:0.1862
epoch [4/20], loss:0.2006
epoch [5/20], loss:0.1553
epoch [6/20], loss:0.1954
epoch [7/20], loss:0.1751
epoch [8/20], loss:0.1781
epoch [9/20], loss:0.1718
epoch [10/20], loss:0.1609
epoch [11/20], loss:0.1581
epoch [12/20], loss:0.2076
epoch [13/20], loss:0.1785
epoch [14/20], loss:0.1658
epoch [15/20], loss:0.1786
epoch [16/20], loss:0.1610
epoch [17/20], loss:0.1639
epoch [18/20], loss:0.1566
epoch [19/20], loss:0.1605
epoch [20/20], loss:0.1645


In [8]:
def ELBO(x, x_out):
    criterion = nn.BCELoss()
    l1 = criterion(x, x_out)
    KL = VAE_B2 * K.mean(1 + z_log_sigma_sq - K.square(z_mean) - K.exp(z_log_sigma_sq), axis=None)
    return l1 - KL

torch.Size([16, 1, 28, 28])