Variational AutoEncoder (VAE) with CelebA
=========================================
---
Formation Introduction au Deep Learning  (FIDLE) - S. Arias, E. Maldonado, JL. Parouty - CNRS/SARI/DEVLOG - 2020  

## Episode 1 - Train a model

 - Define a VAE model
 - Build the model
 - Train it
 - Follow the learning process with TensorBoard


## Step 1 - Python setup

In [1]:
import numpy as np
import sys
from importlib import reload

import modules.vae
import modules.data_generator

reload(modules.data_generator)
reload(modules.vae)

from modules.vae  import VariationalAutoencoder
from modules.data_generator import DataGenerator

VariationalAutoencoder.about()
DataGenerator.about()


FIDLE 2020 - Variational AutoEncoder (VAE)
TensorFlow version   : 2.0.0
VAE version          : 1.24

FIDLE 2020 - DataGenerator
Version              : 0.4


## Step 2 - Prepare data
### 2.1 - Dataset location

In [2]:
dataset_dir  = '/bettik/PROJECTS/pr-fidle/datasets/celeba'

train_dir    = f'{dataset_dir}/clusters-xs.train'
test_dir     = f'{dataset_dir}/clusters-xs.test'

### 2.2 - DataGenerator
Now let's instantiate our generator over the entire training set (`k_size=1`), with a batch size of 32.

In [3]:
data_gen = DataGenerator(train_dir, 32, k_size=1)
print(f'Datasize : {len(data_gen)}')

Datasize : 31
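
The `DataGenerator` class comes from `modules.data_generator` and its internals are not shown in this notebook. As a rough illustration of the general idea only, here is a minimal sketch of a batch generator over an in-memory array. The class name `ArrayBatchGenerator` is hypothetical, and reading `k_size` as "fraction of the dataset to use" is an assumption, not something confirmed by the module.

```python
import numpy as np

class ArrayBatchGenerator:
    """Hypothetical stand-in for a data generator: serves (x, x) batches."""

    def __init__(self, images, batch_size=32, k_size=1.0):
        # Assumption: k_size is the fraction of the dataset to keep
        n_used = int(len(images) * k_size)
        self.images = images[:n_used]
        self.batch_size = batch_size

    def __len__(self):
        # Number of full batches per epoch (an incomplete last batch is dropped)
        return len(self.images) // self.batch_size

    def __getitem__(self, idx):
        start = idx * self.batch_size
        batch = self.images[start:start + self.batch_size]
        # An autoencoder learns to reproduce its input, so target == input
        return batch, batch

# Small random images keep the example light
images = np.random.rand(100, 8, 8, 3).astype("float32")
gen = ArrayBatchGenerator(images, batch_size=32, k_size=1)

print(len(gen))        # 100 // 32 = 3 full batches
x, y = gen[0]
print(x.shape)         # (32, 8, 8, 3)
```

In this sketch, `len()` counts batches per epoch rather than images, which is the convention Keras expects from a batch sequence.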


In [4]:
# Load the first cluster of each split directly, reusing the paths defined above
x_train = np.load(f'{train_dir}/images-000.npy')
x_test  = np.load(f'{test_dir}/images-000.npy')
x_train.shape

(100, 128, 128, 3)

## Step 3 - Get VAE model

In [5]:
tag = 'CelebA.000'

input_shape = (128, 128, 3)
z_dim       = 200
verbose     = 0

encoder= [ {'type':'Conv2D',          'filters':32, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},
           {'type':'Conv2D',          'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},
           {'type':'Conv2D',          'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},
           {'type':'Conv2D',          'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'}
         ]

decoder= [ {'type':'Conv2DTranspose', 'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},
           {'type':'Conv2DTranspose', 'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},
           {'type':'Conv2DTranspose', 'filters':32, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},
           {'type':'Conv2DTranspose', 'filters':3,  'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'sigmoid'}
         ]

vae = VariationalAutoencoder(input_shape    = input_shape, 
                             encoder_layers = encoder, 
                             decoder_layers = decoder,
                             z_dim          = z_dim, 
                             verbose        = verbose,
                             run_tag        = tag)
# No trained model yet: only the configuration is saved at this point
vae.save(model=None)

Model initialized.
Outputs will be in  : ./run/CelebA.000
Config saved in     : ./run/CelebA.000/models/vae_config.json
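
To check that the decoder mirrors the encoder, the spatial sizes can be traced by hand. The sketch below relies only on the standard Keras rules for `padding='same'`: a `Conv2D` layer outputs `ceil(size / stride)` and a `Conv2DTranspose` layer outputs `size * stride`. It assumes that the decoder starts from the encoder's last feature map shape, which is a common design but is not confirmed here for `modules.vae`. The helper `trace_shapes` is hypothetical.

```python
import math

def trace_shapes(input_shape, layers):
    """Return the (height, width, channels) shape after each layer."""
    h, w, c = input_shape
    shapes = [(h, w, c)]
    for layer in layers:
        s = layer['strides']
        if layer['type'] == 'Conv2D':
            # padding='same': output = ceil(input / stride)
            h, w = math.ceil(h / s), math.ceil(w / s)
        elif layer['type'] == 'Conv2DTranspose':
            # padding='same': output = input * stride
            h, w = h * s, w * s
        c = layer['filters']
        shapes.append((h, w, c))
    return shapes

# Same filters and strides as the layer lists above
encoder_spec = [{'type': 'Conv2D', 'filters': f, 'strides': 2}
                for f in (32, 64, 64, 64)]
decoder_spec = [{'type': 'Conv2DTranspose', 'filters': f, 'strides': 2}
                for f in (64, 64, 32, 3)]

enc_shapes = trace_shapes((128, 128, 3), encoder_spec)
print(enc_shapes)
# [(128, 128, 3), (64, 64, 32), (32, 32, 64), (16, 16, 64), (8, 8, 64)]

h, w, c = enc_shapes[-1]
flat_size = h * w * c
print(flat_size)   # 4096 values feed the z_dim=200 latent layer

# Assumption: the decoder starts from the encoder's last feature map shape
dec_shapes = trace_shapes(enc_shapes[-1], decoder_spec)
print(dec_shapes[-1])   # (128, 128, 3), same as the input shape
```

Four stride-2 convolutions divide each spatial dimension by 16, and four stride-2 transposed convolutions multiply it back, so the reconstruction has the same shape as the input.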


## Step 4 - Compile it

In [6]:
learning_rate = 0.0005
r_loss_factor = 10000

vae.compile(learning_rate, r_loss_factor)

Compiled.
Optimizer is Adam with learning_rate=0.0005
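
The role of `r_loss_factor` is easier to see with the loss written out. The NumPy sketch below uses a common VAE formulation: a mean squared reconstruction error scaled by `r_loss_factor`, plus the KL divergence between the encoder's Gaussian and a standard normal. This is an assumption about the general technique, not a copy of what `vae.compile` does internally, and the names `vae_loss`, `z_mean` and `z_log_var` are chosen for illustration.

```python
import numpy as np

def vae_loss(x, x_hat, z_mean, z_log_var, r_loss_factor=10000):
    """Per-image VAE loss: weighted reconstruction error plus KL divergence."""
    # Reconstruction term: mean squared error over height, width, channels
    r_loss = r_loss_factor * np.mean((x - x_hat) ** 2, axis=(1, 2, 3))
    # KL divergence between N(z_mean, exp(z_log_var)) and N(0, I)
    kl_loss = -0.5 * np.sum(1 + z_log_var - z_mean ** 2 - np.exp(z_log_var),
                            axis=1)
    return r_loss + kl_loss

batch, z_dim = 4, 200
x = np.zeros((batch, 16, 16, 3))      # small images keep the example light

# Perfect reconstruction with a standard normal latent code: zero loss
zero = vae_loss(x, x, np.zeros((batch, z_dim)), np.zeros((batch, z_dim)))
print(zero)

# Constant pixel error of 0.1 -> MSE of 0.01 -> reconstruction term near 100
r_only = vae_loss(x, x + 0.1, np.zeros((batch, z_dim)), np.zeros((batch, z_dim)))
print(r_only)

# Latent mean of 1 in all 200 dimensions -> KL term of 0.5 * 200 = 100
kl_only = vae_loss(x, x, np.ones((batch, z_dim)), np.zeros((batch, z_dim)))
print(kl_only)
```

With pixel values in [0, 1], the raw mean squared error is small compared with the KL term, so a large factor such as 10000 keeps the reconstruction term from being dominated by the regularization.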


## Step 5 - Train

In [7]:
batch_size        = 32
epochs            = 100
image_periodicity = 1      # period, in epochs
chkpt_periodicity = 2      # period, in epochs
initial_epoch     = 0
dataset_size      = 1

In [8]:
vae.train(data_generator    = data_gen,    # batches come from the generator, not from x_train
          x_test            = x_test,
          epochs            = epochs,
          image_periodicity = image_periodicity,
          chkpt_periodicity = chkpt_periodicity,
          initial_epoch     = initial_epoch
         )

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100

KeyboardInterrupt: 
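
The run above was interrupted during epoch 51. To resume such a run, one needs to know which checkpoint is the most recent and which value to give `initial_epoch`. The sketch below works this out under two assumptions that are not confirmed for `modules.vae`: a checkpoint is written at the end of every epoch whose 1-based number is a multiple of `chkpt_periodicity`, and `initial_epoch` counts epochs already completed, as it does in Keras `fit`. Both helper functions are hypothetical.

```python
def checkpoint_epochs(last_completed_epoch, chkpt_periodicity):
    """Epochs (1-based) at which a checkpoint would have been written."""
    # Assumption: one checkpoint at the end of each multiple of the period
    return [e for e in range(1, last_completed_epoch + 1)
            if e % chkpt_periodicity == 0]

def resume_epoch(last_completed_epoch, chkpt_periodicity):
    """Value to pass as initial_epoch when restarting from the last checkpoint."""
    saved = checkpoint_epochs(last_completed_epoch, chkpt_periodicity)
    # No checkpoint yet means starting again from scratch
    return saved[-1] if saved else 0

# Epoch 51 was interrupted, so 50 epochs were completed
print(resume_epoch(50, 2))   # 50
# With a period of 5, the last checkpoint after 48 completed epochs is at 45
print(resume_epoch(48, 5))   # 45
```

Under these assumptions, restarting with `initial_epoch = 50` after reloading the latest checkpoint would continue the log at `Epoch 51/100`.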