# Code for the training process of the Variational AutoEncoder network

## 1.Getting acess to the dataset on Google Drive



In [75]:
from google.colab import drive
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [0]:
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)

## 2.Installing Pyro and calling the modules

In [0]:
!pip3 install -q http://download.pytorch.org/whl/{accelerator}/torch-0.4.0-{platform}-linux_x86_64.whl
!pip3 install torchvision
!pip3 install pyro-ppl

[31mERROR: torch-0.4.0-{platform}-linux_x86_64.whl is not a valid wheel filename.[0m
Collecting pyro-ppl
[?25l  Downloading https://files.pythonhosted.org/packages/92/7a/4dc4d39d6db1aae0825a2a2ab60178fc4afb92efd9669be02715d3a16734/pyro_ppl-1.2.1-py3-none-any.whl (486kB)
[K     |████████████████████████████████| 491kB 2.8MB/s 
[?25hCollecting tqdm>=4.36
[?25l  Downloading https://files.pythonhosted.org/packages/cd/80/5bb262050dd2f30f8819626b7c92339708fe2ed7bd5554c8193b4487b367/tqdm-4.42.1-py2.py3-none-any.whl (59kB)
[K     |████████████████████████████████| 61kB 6.2MB/s 
Collecting pyro-api>=0.1.1
  Downloading https://files.pythonhosted.org/packages/c2/bc/6cdbd1929e32fff62a33592633c2cc0393c7f7739131ccc9c9c4e28ac8dd/pyro_api-0.1.1-py3-none-any.whl
Installing collected packages: tqdm, pyro-api, pyro-ppl
  Found existing installation: tqdm 4.28.1
    Uninstalling tqdm-4.28.1:
      Successfully uninstalled tqdm-4.28.1
Successfully installed pyro-api-0.1.1 pyro-ppl-1.2.1 tqdm-4.42.1

In [0]:
def setup_data_loader(images, classes, batch_size = 128, use_CUDA = False):
    '''
    Function that receives two arrays, an array of the data images and an array of it latents values
    and generate a DataLoader for train and test data.
    
    Input:
    :images: array of size [*, 64, 64] with images data
    :classes: array of size [*, 6] with latent variables of the images
    
    Output:
    :train_loader: torch data loader with train data (images and latents)
    :test_loader: torch data loader with test data (images and latents)
    '''
    index = np.random.permutation(images.shape[0])
    images = images[index].astype(np.float32)
    classes = classes[index].astype(np.float32)
    train_df = torch.utils.data.TensorDataset(torch.from_numpy(images[100:].reshape(-1, 4096)), torch.from_numpy(classes[100:]))
    test_df = torch.utils.data.TensorDataset(torch.from_numpy(images[:100].reshape(-1, 4096)), torch.from_numpy(classes[:100]))
    kwargs = {'num_workers': 1, 'pin_memory': use_CUDA}
    train_loader = torch.utils.data.DataLoader(train_df, batch_size, shuffle = False, **kwargs)
    test_loader = torch.utils.data.DataLoader(test_df, batch_size, shuffle = False, **kwargs)
    return train_loader, test_loader


In [0]:
import torch
import pyro
import numpy as np
from torch.nn import Module
from torch.nn.functional import one_hot
from networks import Decoder, Encoder
from torch import tensor
from pyro.distributions import OneHotCategorical, Normal, Bernoulli, Uniform
import matplotlib.pyplot as plt

class VAE(Module):
    '''
    Class that define the posterior distribution q(z|x) as the model 
    with the decoder and the prior distribution q(x|z) as the guide 
    using the encoder.
    
    Inputs:  
    :pimg_dim: dimension of image vector
    :label_dim: dimension of label vector
    :latent_dim: dimension of Z space, output
    '''
    def __init__(self, img_dim = 4096, label_dim = 114, latent_dim = 200, use_CUDA = False):
        super(VAE, self).__init__()
        #creating networks
        self.encoder = Encoder(img_dim, label_dim, latent_dim)
        self.decoder = Decoder(img_dim, label_dim, latent_dim)
        self.img_dim = img_dim
        self.label_dim = label_dim
        self.latent_dim = latent_dim
        if use_CUDA:
            self.cuda()
        self.use_CUDA = use_CUDA
  
    def label_variable(self, label):
        options = {'device': label.device, 'dtype': label.dtype}

        shape = pyro.sample("label_shape", 
            OneHotCategorical(torch.ones(label.shape[0], 3, **options) / (3.0)), 
            obs = one_hot((label[:, 0] - 1.).to(torch.int64), 3).to(torch.float32))
        scale = pyro.sample("label_scale",
            Uniform(torch.ones(label.shape[0], 1, **options)*0.5, torch.ones(label.shape[0], 1, **options)*1.0001).to_event(1),
            obs = label[:, 1].reshape(label.shape[0], 1))
        orien = pyro.sample("label_orien",
            Uniform(torch.zeros(label.shape[0], 1, **options), torch.ones(label.shape[0], 1, **options)*2*np.pi).to_event(1),
            obs = label[:, 2].reshape(label.shape[0], 1))
        posX = pyro.sample("label_posX",
            Uniform(torch.zeros(label.shape[0], 1, **options), torch.ones(label.shape[0], 1, **options)*1.0001).to_event(1),
            obs = label[:, 3].reshape(label.shape[0], 1))
        posY = pyro.sample("label_posY",
            Uniform(torch.zeros(label.shape[0], 1, **options), torch.ones(label.shape[0], 1, **options)*1.0001).to_event(1),
            obs = label[:, 4].reshape(label.shape[0], 1))

        new_label = torch.cat([shape, scale, orien, posX, posY], -1).to(dtype = torch.float32, device = label.device)
        return new_label

    def model(self, img, label):
        pyro.module("decoder", self.decoder)
        options = {'device': img.device, 'dtype': img.dtype}
        with pyro.plate("data", img.shape[0]):
            z_mean = torch.zeros(img.shape[0], self.latent_dim, **options)
            z_variance = torch.ones(img.shape[0], self.latent_dim, **options)
            z_sample = pyro.sample("latent", Normal(z_mean, z_variance).to_event(1))
            image = self.decoder.forward(z_sample, self.label_variable(label))
            pyro.sample("obs", Bernoulli(image).to_event(1), obs = img)


    def guide(self, img, label):
        pyro.module("encoder", self.encoder)
        with pyro.plate("data", img.shape[0]):
            z_mean, z_variance = self.encoder.forward(img, self.label_variable(label))
            pyro.sample("latent", Normal(z_mean, z_variance).to_event(1))

In [79]:
dataset_zip = np.load('/content/gdrive/My Drive/autoencoder/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz', allow_pickle = True, encoding = 'bytes')
print('Keys in the dataset:', dataset_zip.files)
imgs = dataset_zip['imgs']
latents_values = dataset_zip['latents_values']
latents_classes = dataset_zip['latents_classes']
latents_sizes = dataset_zip['metadata'][()][b'latents_sizes']
latents_names = dataset_zip['metadata'][()][b'latents_names']

Keys in the dataset: ['metadata', 'imgs', 'latents_classes', 'latents_values']


## 3.The training process
In the following training process, we use the Standart variational inference from Pyro. In each testing iteration the state of the model (parameters values of the network) are saved at the Google Drive with the number of epochs in the file name.

In [0]:
pyro.enable_validation(True)
pyro.clear_param_store()
use_CUDA = False
train_loader, test_loader = setup_data_loader(imgs[:600], latents_values[:, 1:6][:600], use_CUDA = use_CUDA)
vae = VAE(label_dim= 7, use_CUDA = use_CUDA)

In [120]:
#optimizer
optimizer = pyro.optim.Adam({"lr" : 1.0e-3})

#inference algorithm
elbo = pyro.infer.Trace_ELBO()
svi = pyro.infer.SVI(vae.model, vae.guide, optimizer, elbo)

train_elbo = []
test_elbo = []
num_epochs = 70
test_freq = 8
for epoch in range(num_epochs):
  epoch_loss = 0.
  for (img, label) in train_loader:
    if use_CUDA:
      img = img.cuda()
      label = label.to(img.device)
    epoch_loss += svi.step(img, label)
  total_epoch_loss_train = epoch_loss/len(train_loader)
  train_elbo.append(total_epoch_loss_train)
  print("epoch: " + str(epoch) + " average training loss: " + str(epoch_loss))

  if epoch % test_freq == 0:
    test_loss = 0
    for (img, label) in test_loader:
      if use_CUDA:
        img = img.cuda()
        label = label.to(img.device)
      test_loss += svi.evaluate_loss(img, label)
    total_epoch_loss_test  = epoch_loss/len(test_loader)
    test_elbo.append(total_epoch_loss_test)
    torch.save(vae.state_dict(), '/content/gdrive/My Drive/trained_movel_epoch_'+ str(epoch) + '.save')

{'label_posX', 'label_posY', 'label_shape', 'label_scale', 'label_orien'}
  guide_vars - aux_vars - model_vars))


epoch: 0 average training loss: 827067.0625
epoch: 1 average training loss: 246384.095703125
epoch: 2 average training loss: 189836.947265625
epoch: 3 average training loss: 174836.4140625
epoch: 4 average training loss: 166770.875
epoch: 5 average training loss: 162137.564453125
epoch: 6 average training loss: 158450.2734375
epoch: 7 average training loss: 156589.669921875
epoch: 8 average training loss: 155016.4453125
epoch: 9 average training loss: 153880.787109375
epoch: 10 average training loss: 153052.609375
epoch: 11 average training loss: 152473.029296875
epoch: 12 average training loss: 151970.373046875
epoch: 13 average training loss: 151336.587890625
epoch: 14 average training loss: 150737.7734375
epoch: 15 average training loss: 150046.619140625
epoch: 16 average training loss: 149186.673828125
epoch: 17 average training loss: 148301.677734375
epoch: 18 average training loss: 146738.9296875
epoch: 19 average training loss: 144529.38671875
epoch: 20 average training loss: 14

KeyboardInterrupt: ignored