Variational AutoEncoder (VAE) with CelebA
=========================================
---
Formation Introduction au Deep Learning (FIDLE) - S. Arias, E. Maldonado, JL. Parouty - CNRS/SARI/DEVLOG - 2020  

## Episode 1 - Train a model

 - Define a VAE model
 - Build the model
 - Train it
 - Follow the learning process with TensorBoard


## Step 1 - Init Python stuff

In [66]:
import numpy as np
import sys, importlib

import modules.vae
import modules.data_generator

# Reload the modules so that edits to their source files are picked up
# without restarting the kernel
importlib.reload(modules.vae)
importlib.reload(modules.data_generator)

from modules.vae import VariationalAutoencoder
from modules.data_generator import Data_generator

# Print version information
VariationalAutoencoder.about()
Data_generator.about()


FIDLE 2020 - Variational AutoEncoder (VAE)
TensorFlow version   : 2.0.0
VAE version          : 1.24

FIDLE 2020 - Data_generator
Version              : 0.1


## Step 2 - Prepare data
### 2.1 - Dataset location

In [67]:
dataset_dir  = '/bettik/PROJECTS/pr-fidle/datasets/celeba'

### 2.2 - Testing our data generator (Keras Sequence)
A quick test to understand how our data generator works:

In [68]:
# ---- A very small dataset

clusters_dir = f'{dataset_dir}/clusters-test'

# ---- Instantiate (batch size = 32)

data_gen = Data_generator(clusters_dir, 32, debug=True)

# ---- Iterate over all batches and record their sizes

batch_sizes = []
for i in range(len(data_gen)):
    x, y = data_gen[i]
    batch_sizes.append(len(x))

print(f'\n\ntotal number of items : {sum(batch_sizes)}')
print(f'batch sizes : {batch_sizes}')

Clusters nb  : 10 files
Dataset size : 932
Batch size   : 32

[shuffle!]

[Load 00,s=100] (32) (32) (32) (4..) 
[Load 01,s=100] (..28) (32) (32) (8..) 
[Load 02,s=100] (..24) (32) (32) (12..) 
[Load 03,s=100] (..20) (32) (32) (16..) 
[Load 04,s=100] (..16) (32) (32) (20..) 
[Load 05,s=100] (..12) (32) (32) (24..) 
[Load 06,s= 32] (..8) (24..) 
[Load 07,s=100] (..8) (32) (32) (28..) 
[Load 08,s=100] (..4) (32) (32) (32) (0..) 
[Load 09,s=100] (..32) (32) (32) 

total number of items : 928
batch sizes : [32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32]
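The debug log shows how batches straddle cluster files: the items left at the end of one cluster (e.g. `(4..)`) are completed from the next one (`(..28)`). The dataset holds 932 items, 932 // 32 = 29 full batches, i.e. 928 items, so the last 4 items are not served. The sketch below reproduces that arithmetic with NumPy only. It is a hypothetical stand-in, not the actual `Data_generator` code: `stream_batches` is an illustrative name, the dummy arrays replace real images, and the shuffling shown as `[shuffle!]` is not modeled.

```python
import numpy as np

def stream_batches(clusters, batch_size):
    """Hypothetical cluster-based batcher (not the real Data_generator).

    Clusters are read one after the other; items left over at the end of
    a cluster are carried into the next batch, and the final incomplete
    batch is dropped.
    """
    carry = clusters[0][:0]  # empty array with the right item shape and dtype
    for cluster in clusters:
        carry = np.concatenate([carry, cluster])
        while len(carry) >= batch_size:
            yield carry[:batch_size]
            carry = carry[batch_size:]
    # Fewer than batch_size items remain in `carry` here; they are discarded

# Cluster sizes taken from the debug log: clusters 00-05 and 07-09 hold 100 items, 06 holds 32
sizes = [100] * 6 + [32] + [100] * 3
clusters = [np.zeros((n, 4), dtype=np.float32) for n in sizes]
batches = [len(b) for b in stream_batches(clusters, 32)]
print(len(batches), sum(batches))   # 29 928
```

Dropping the incomplete tail keeps every batch at exactly `batch_size` items, which matches the all-32 list printed above.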


### 2.3 - Get data

In [None]:
# NOTE: load_MNIST() is not imported in Step 1; it must be provided by a helper module
(x_train, y_train), (x_test, y_test) = load_MNIST()
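`load_MNIST()` is called here but is not defined or imported in this notebook. A minimal sketch of such a helper is given below, assuming it should return MNIST scaled to [0, 1] with a trailing channel axis so that the images match `input_shape = (28,28,1)` used for the model. Both `load_MNIST` and `prepare_images` are hypothetical; only `tf.keras.datasets.mnist.load_data()` is a real TensorFlow call (it downloads the data on first use).

```python
import numpy as np

def prepare_images(x):
    """Scale uint8 pixels to [0, 1] and add a channel axis: (n, 28, 28) -> (n, 28, 28, 1)."""
    x = x.astype('float32') / 255.0
    return np.expand_dims(x, axis=-1)

def load_MNIST():
    """Hypothetical helper returning ((x_train, y_train), (x_test, y_test))."""
    import tensorflow as tf  # imported lazily so prepare_images works without TensorFlow
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    return (prepare_images(x_train), y_train), (prepare_images(x_test), y_test)
```

Usage: `(x_train, y_train), (x_test, y_test) = load_MNIST()` then gives `x_train.shape == (60000, 28, 28, 1)`.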

## Step 3 - Get VAE model

In [None]:
tag = '001'

input_shape = (28,28,1)
z_dim       = 2
verbose     = 0

encoder= [ {'type':'Conv2D',          'filters':32, 'kernel_size':(3,3), 'strides':1, 'padding':'same', 'activation':'relu'},
           {'type':'Conv2D',          'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},
           {'type':'Conv2D',          'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},
           {'type':'Conv2D',          'filters':64, 'kernel_size':(3,3), 'strides':1, 'padding':'same', 'activation':'relu'}
         ]

decoder= [ {'type':'Conv2DTranspose', 'filters':64, 'kernel_size':(3,3), 'strides':1, 'padding':'same', 'activation':'relu'},
           {'type':'Conv2DTranspose', 'filters':64, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},
           {'type':'Conv2DTranspose', 'filters':32, 'kernel_size':(3,3), 'strides':2, 'padding':'same', 'activation':'relu'},
           {'type':'Conv2DTranspose', 'filters':1,  'kernel_size':(3,3), 'strides':1, 'padding':'same', 'activation':'sigmoid'}
         ]

vae = modules.vae.VariationalAutoencoder(input_shape    = input_shape, 
                                         encoder_layers = encoder, 
                                         decoder_layers = decoder,
                                         z_dim          = z_dim, 
                                         verbose        = verbose,
                                         run_tag        = tag)
vae.save(model=None)
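With `padding='same'`, a Keras `Conv2D` outputs `ceil(size / stride)` and a `Conv2DTranspose` outputs `size * stride`. Tracing the strides of the layer lists above shows that the encoder goes 28 → 14 → 7 and the decoder mirrors it back to 28, so the output matches `input_shape`. The sketch below checks this. It assumes (not shown in this notebook) that `modules.vae` flattens the last encoder feature map before projecting it to `z_dim`, and that the decoder reshapes back to that same 7×7×64 map; the helper names are illustrative.

```python
import math

def conv_same(size, stride):
    # Conv2D with padding='same': output size = ceil(input size / stride)
    return math.ceil(size / stride)

def conv_transpose_same(size, stride):
    # Conv2DTranspose with padding='same': output size = input size * stride
    return size * stride

def trace(size, strides, step):
    """Return the spatial size before and after each layer."""
    sizes = [size]
    for s in strides:
        sizes.append(step(sizes[-1], s))
    return sizes

# Strides and last filter counts copied from the encoder / decoder lists above
enc = trace(28, [1, 2, 2, 1], conv_same)             # [28, 28, 14, 7, 7]
flat = enc[-1] * enc[-1] * 64                        # 7 * 7 * 64 = 3136 values
dec = trace(enc[-1], [1, 2, 2, 1], conv_transpose_same)  # [7, 7, 14, 28, 28]
print(enc, flat, dec)
```

The decoder's last layer has 1 filter with a sigmoid, giving a (28, 28, 1) output in [0, 1].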

## Step 4 - Compile it

In [None]:
learning_rate = 0.0005
r_loss_factor = 1000

vae.compile(learning_rate, r_loss_factor)
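The internals of `vae.compile()` are not shown here. A common VAE formulation, which `r_loss_factor` suggests, weights a per-image mean squared reconstruction error against the KL divergence between the encoder's Gaussian $\mathcal{N}(\mu, \sigma^2)$ and $\mathcal{N}(0, 1)$, with latent points drawn through the reparameterization trick $z = \mu + e^{\log\sigma^2/2}\,\varepsilon$. The NumPy sketch below assumes that formulation; it is not the module's actual code, and `sample_z` / `vae_loss` are hypothetical names.

```python
import numpy as np

def sample_z(mu, log_var, rng):
    # Reparameterization trick: z = mu + sigma * eps, with sigma = exp(log_var / 2)
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(0.5 * log_var) * eps

def vae_loss(x, x_hat, mu, log_var, r_loss_factor=1000):
    # Reconstruction term: mean squared error per image, over (height, width, channels)
    r_loss = np.mean(np.square(x - x_hat), axis=(1, 2, 3))
    # KL(N(mu, sigma^2) || N(0, 1)), summed over the z_dim latent dimensions
    kl_loss = -0.5 * np.sum(1 + log_var - np.square(mu) - np.exp(log_var), axis=1)
    # Batch average of the weighted sum
    return np.mean(r_loss_factor * r_loss + kl_loss)

rng = np.random.default_rng(0)
mu, log_var = np.zeros((4, 2)), np.zeros((4, 2))    # batch of 4, z_dim = 2
z = sample_z(mu, log_var, rng)                      # shape (4, 2)
x = np.ones((4, 28, 28, 1))
print(vae_loss(x, np.zeros_like(x), mu, log_var))   # 1000.0
```

A large `r_loss_factor` such as 1000 makes reconstruction dominate; lowering it gives the KL term more weight and a more regular latent space.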

## Step 5 - Train

In [None]:
batch_size        = 100
epochs            = 100
image_periodicity = 1      # save sample images every epoch
chkpt_periodicity = 2      # save a checkpoint every 2 epochs
initial_epoch     = 0
dataset_size      = 1      # presumably the fraction of the dataset to use (1 = all)

In [None]:
vae.train(x_train,
          x_test,
          batch_size        = batch_size, 
          epochs            = epochs,
          image_periodicity = image_periodicity,
          chkpt_periodicity = chkpt_periodicity,
          initial_epoch     = initial_epoch,
          dataset_size      = dataset_size
         )
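To estimate the cost of this run before launching it, the sketch below counts optimizer steps and saved artifacts. It assumes that `dataset_size` is the fraction of the training set used, that images and checkpoints are written at the end of every N-th epoch, and that epochs run from `initial_epoch` up to `epochs` as in Keras `fit()`. The helper `training_schedule` is hypothetical; 60000 is the size of the MNIST training set.

```python
import math

def training_schedule(n_items, batch_size, epochs, initial_epoch=0,
                      dataset_size=1.0, image_periodicity=1, chkpt_periodicity=2):
    """Hypothetical helper: steps per epoch and number of saved images / checkpoints."""
    n_used = int(n_items * dataset_size)
    steps_per_epoch = math.ceil(n_used / batch_size)   # last partial batch still counts
    run = range(initial_epoch + 1, epochs + 1)         # 1-based epoch numbers actually run
    n_images = sum(1 for e in run if e % image_periodicity == 0)
    n_chkpts = sum(1 for e in run if e % chkpt_periodicity == 0)
    return steps_per_epoch, n_images, n_chkpts

print(training_schedule(60000, batch_size=100, epochs=100))   # (600, 100, 50)
```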