# <center>VAE-baseline for CIFAR10 </center>

## Load Data

In [1]:
# Upgrade torch and torchvision in the notebook runtime. No versions are pinned,
# so this installs whatever release is current when the cell is run.
!pip3 install --upgrade torch torchvision

Collecting torch
  Using cached https://files.pythonhosted.org/packages/69/43/380514bd9663f1bf708abeb359b8b48d3fabb1c8e95bb3427a980a064c57/torch-0.4.0-cp36-cp36m-manylinux1_x86_64.whl
tcmalloc: large alloc 1073750016 bytes == 0x5b0fa000 @  0x7fa6645781c4 0x46d6a4 0x5fcbcc 0x4c494d 0x54f3c4 0x553aaf 0x54e4c8 0x54f4f6 0x553aaf 0x54efc1 0x54f24d 0x553aaf 0x54efc1 0x54f24d 0x553aaf 0x54efc1 0x54f24d 0x551ee0 0x54e4c8 0x54f4f6 0x553aaf 0x54efc1 0x54f24d 0x551ee0 0x54efc1 0x54f24d 0x551ee0 0x54e4c8 0x54f4f6 0x553aaf 0x54e4c8
Requirement already up-to-date: torchvision in /usr/local/lib/python3.6/dist-packages (0.2.1)
Requirement not upgraded as not directly required: numpy in /usr/local/lib/python3.6/dist-packages (from torchvision) (1.14.3)
Requirement not upgraded as not directly required: six in /usr/local/lib/python3.6/dist-packages (from torchvision) (1.11.0)
Requirement not upgraded as not directly required: pillow>=4.1.1 in /usr/local/lib/python3.6/dist-packages (from torchvision) (5.

In [0]:
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.autograd import Variable

import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from IPython.display import Image
from google.colab import files


In [3]:
# Fix the global PyTorch RNG seed so that weight initialisation and sampling
# are repeatable from one run to the next.
SEED = 512
torch.manual_seed(SEED)

<torch._C.Generator at 0x7fbb60cd40b0>

In [4]:
# Fetch the CIFAR10 training split into ./data/cifar/ (download=True lets
# torchvision download it when it is not already on disk).
cifar = datasets.CIFAR10(root='./data/cifar/', train=True, download=True)

Files already downloaded and verified


In [0]:
# Serve the training split in shuffled mini-batches of 128 images.
# ToTensor converts each PIL image to a float tensor with values in [0, 1];
# no further normalisation is applied, which matches the sigmoid output of the
# decoder and the binary cross-entropy reconstruction loss used below.
train_transform = transforms.Compose([transforms.ToTensor()])

train_dataset = datasets.CIFAR10(
    './data/cifar/',
    train=True,
    download=False,
    transform=train_transform,
)

train_images = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
)

## Model

We will use the architecture suggested by [Radford et al.](https://arxiv.org/abs/1511.06434) for both the encoder and the decoder, with convolutional layers in the encoder and fractionally-strided convolutions in the decoder. In each convolutional layer of the encoder we double the number of filters present in the previous layer and use a convolutional stride of 2. In each convolutional layer of the decoder we use a fractional stride of 2 and halve the number of filters of the previous layer.

In [0]:
class VAE(nn.Module):
    """Convolutional variational autoencoder.

    The encoder is a stack of strided convolutions that reduces a 3-channel
    image to a 256-channel feature map; for a 32x32 input that map is 1x1, so
    it flattens to the 256 features expected by ``fc1``. Fully-connected
    layers then map those features to the mean and log-variance of the latent
    Gaussian. The decoder mirrors this: fully-connected layers expand a latent
    vector back to 256 features, and transposed convolutions upsample them to
    a 3x32x32 image with values in [0, 1] (final sigmoid).

    Args:
        image_size: Stored on the instance as ``self.image_size``; it is not
            used to size any layer (the convolution stack is fixed).
        hidden_dim: Width of the fully-connected hidden layers.
        encoding_dim: Dimensionality of the latent code ``z``.
    """

    def __init__(self, image_size ,  hidden_dim , encoding_dim):

        super(VAE, self).__init__()

        self.encoding_dim = encoding_dim
        self.image_size = image_size
        self.hidden_dim = hidden_dim

        # NOTE: sub-modules are created in the same order as before (decoder,
        # encoder, then linear layers) so that a fixed RNG seed yields the same
        # initial weights.

        # Decoder - fractionally-strided convolutions.
        # Spatial size: 1x1 -> 4x4 -> 8x8 -> 16x16 -> 32x32.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 4, 1, 0, bias = False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias = False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 32, 4, 2, 1, bias = False),
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            nn.ConvTranspose2d(32, 3, 4, 2, 1, bias = False),
            nn.Sigmoid()  # pixel values in [0, 1], as required by the BCE loss
        )

        # Encoder - strided convolutions.
        # Spatial size for a 32x32 input: 32 -> 16 -> 8 -> 4 -> 1.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, 2, 1, bias = False),
            nn.LeakyReLU(0.2, inplace = True),
            nn.Conv2d(32, 64, 4, 2, 1, bias = False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace = True),
            nn.Conv2d(64, 128, 4, 2, 1, bias = False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace = True),
            nn.Conv2d(128, 256, 4, 2, 0, bias = False),
            nn.Sigmoid()
        )

        # Fully-connected layers
        self.fc1 = nn.Linear(256, self.hidden_dim)                 # features -> hidden
        self.fc21 = nn.Linear(self.hidden_dim, self.encoding_dim)  # hidden -> mu
        self.fc22 = nn.Linear(self.hidden_dim, self.encoding_dim)  # hidden -> logvar
        self.fc3 = nn.Linear(self.encoding_dim, self.hidden_dim)   # z -> hidden
        self.fc4 = nn.Linear(self.hidden_dim, 256)                 # hidden -> features

    def encode(self, x):
        """Map a batch of images to the parameters of the latent Gaussian.

        Args:
            x: Image batch of shape ``(batch, 3, H, W)``; the encoder output
                must flatten to 256 features per sample (true for 32x32).

        Returns:
            Tuple ``(mu, logvar)``, each of shape ``(batch, encoding_dim)``.
        """
        features = self.encoder(x).view(x.size(0), -1)
        hidden = F.relu(self.fc1(features))
        return self.fc21(hidden), self.fc22(hidden)

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Sampling through this expression keeps ``z`` differentiable with
        respect to ``mu`` and ``logvar`` (the reparametrization trick).
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes to images.

        Args:
            z: Latent batch of shape ``(batch, encoding_dim)``.

        Returns:
            Tensor of shape ``(batch, 3, 32, 32)`` with values in [0, 1].
        """
        h3 = F.relu(self.fc3(z))
        # torch.sigmoid replaces the deprecated F.sigmoid (same computation).
        h4 = torch.sigmoid(self.fc4(h3))
        # Reshape the flat features to a (batch, 256, 1, 1) map for the
        # transposed-convolution stack.
        return self.decoder(h4.view(z.size(0), -1, 1, 1))

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode it.

        Returns:
            Tuple ``(decoded, mu, logvar)``: the reconstruction and the latent
            Gaussian parameters needed by the loss.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        decoded = self.decode(z)
        return decoded, mu , logvar

    
# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar):
    """Return the VAE loss: reconstruction term plus KL divergence.

    Args:
        recon_x: Reconstructed batch with values in [0, 1], same shape as ``x``.
        x: Target batch with values in [0, 1].
        mu: Mean of the approximate posterior.
        logvar: Log-variance of the approximate posterior, same shape as ``mu``.

    Returns:
        Scalar tensor ``BCE + KLD``, summed (not averaged) over every element
        and every sample in the batch.
    """
    # reduction='sum' is the supported spelling of the deprecated
    # size_average=False and yields the same value.
    bce = F.binary_cross_entropy(recon_x, x, reduction='sum')
    # Closed-form KL(N(mu, sigma^2) || N(0, I)) with sigma^2 = exp(logvar):
    #   -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return bce + kld

In [0]:
# Model hyper-parameters, in the order expected by VAE(image_size, hidden_dim,
# encoding_dim), and the optimizer step size.
IMAGE_SIZE, HIDDEN_DIM, ENCODING_DIM = 32, 128, 32
LEARNING_RATE = 1e-2

# Build the VAE on the GPU and optimise all of its parameters with the
# AMSGrad variant of Adam.
model = VAE(IMAGE_SIZE, HIDDEN_DIM, ENCODING_DIM).cuda()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, amsgrad=True)

#Train model
def train(epoch):
    """Run one optimisation pass over ``train_images`` and print progress.

    Uses the notebook-level ``model``, ``optimizer``, ``train_images`` and
    ``loss_function``. The per-sample loss of the current batch is printed
    every 50 batches, and the per-sample average loss over the whole dataset
    is printed at the end of the pass.

    Args:
        epoch: Epoch number; only used in the progress messages.
    """
    # Send every batch to whichever device the model lives on, instead of
    # hard-coding .cuda().
    device = next(model.parameters()).device

    # Ensure BatchNorm layers use batch statistics even if another cell left
    # the model in evaluation mode.
    model.train()

    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_images):
        # Tensors track gradients directly; the Variable wrapper is a
        # deprecated no-op and is no longer needed.
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % 50 == 0:
            # loss is a sum over the batch, so divide by the batch size to
            # report a per-sample value.
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_images.dataset),
                100. * batch_idx / len(train_images),
                loss.item() / len(data)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(train_images.dataset)))

In [0]:
num_epochs = 100
# range() excludes its upper bound, so "+ 1" is required to actually run
# num_epochs epochs (numbered 1..num_epochs).
for epoch in range(1, num_epochs + 1):
    train(epoch)

    sample_path = './sample_' + str(epoch) + '.png'
    checkpoint_path = "./vae_checkpoint_epoch_" + str(epoch) + ".pth"
    device = next(model.parameters()).device

    # Sample in evaluation mode: in training mode BatchNorm would normalise
    # with the statistics of the sampled batch and fold them into the running
    # statistics that are saved in the checkpoint below.
    model.eval()
    with torch.no_grad():
        sample = torch.randn(64, model.encoding_dim).to(device)
        sample = model.decode(sample)
    model.train()  # restore training mode for the next epoch

    # Save a grid of 64 generated images and download it from the Colab runtime.
    torchvision.utils.save_image(sample.view(64, 3, 32, 32), sample_path)
    files.download(sample_path)

    # Save a CPU copy of the weights so the checkpoint loads on any machine,
    # without moving the live model off its device and back.
    cpu_state = {name: tensor.cpu() for name, tensor in model.state_dict().items()}
    torch.save(cpu_state, checkpoint_path)
    files.download(checkpoint_path)