In [1]:
from models.VAE import VariationalAutoencoder
from utils.loaders import load_mnist

Using TensorFlow backend.


In [12]:
import pathlib
import os
import numpy as np

In [3]:
# run params
SECTION = 'vae'
RUN_ID = '0002'
DATA_NAME = 'digits'
RUN_FOLDER = './run/{}/'.format(SECTION)
RUN_FOLDER += '_'.join([RUN_ID, DATA_NAME])

p = pathlib.Path(RUN_FOLDER)

if not p.exists():
    p.mkdir(parents=True, exist_ok=False)
    (p/'viz').mkdir(parents=True, exist_ok=False)
    (p/'images').mkdir(parents=True, exist_ok=False)
    (p/'weights').mkdir(parents=True, exist_ok=False)
    
MODE =  'build' #'load' #

# data

In [4]:
(x_train, y_train), (x_test, y_test) = load_mnist()

# VAE architecture

In [7]:
vae = VariationalAutoencoder(
    input_dim = (28,28,1)
    , encoder_conv_filters = [32,64,64, 64]
    , encoder_conv_kernel_size = [3,3,3,3]
    , encoder_conv_strides = [1,2,2,1]
    , decoder_conv_t_filters = [64,64,32,1]
    , decoder_conv_t_kernel_size = [3,3,3,3]
    , decoder_conv_t_strides = [1,2,2,1]
    , z_dim = 2
)

if MODE == 'build':
    vae.save(RUN_FOLDER)
else:
    vae.load_weights(os.path.join(RUN_FOLDER, 'weights/weights.h5'))

In [8]:

vae.encoder.summary()

Model: "model_6"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
encoder_input (InputLayer)      (None, 28, 28, 1)    0                                            
__________________________________________________________________________________________________
encoder_conv_0 (Conv2D)         (None, 28, 28, 32)   320         encoder_input[0][0]              
__________________________________________________________________________________________________
leaky_re_lu_8 (LeakyReLU)       (None, 28, 28, 32)   0           encoder_conv_0[0][0]             
__________________________________________________________________________________________________
encoder_conv_1 (Conv2D)         (None, 14, 14, 64)   18496       leaky_re_lu_8[0][0]              
____________________________________________________________________________________________

In [9]:
vae.decoder.summary()

Model: "model_7"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
decoder_input (InputLayer)   (None, 2)                 0         
_________________________________________________________________
dense_2 (Dense)              (None, 3136)              9408      
_________________________________________________________________
reshape_2 (Reshape)          (None, 7, 7, 64)          0         
_________________________________________________________________
decoder_conv_t_0 (Conv2DTran (None, 7, 7, 64)          36928     
_________________________________________________________________
leaky_re_lu_12 (LeakyReLU)   (None, 7, 7, 64)          0         
_________________________________________________________________
decoder_conv_t_1 (Conv2DTran (None, 14, 14, 64)        36928     
_________________________________________________________________
leaky_re_lu_13 (LeakyReLU)   (None, 14, 14, 64)        0   

In [11]:
x_train.shape

(60000, 28, 28, 1)

# data generator

In [138]:
import pdb
def data_set(xdata, ydata=None): # just returns a tuple
    return (xdata, ydata)

def sampler(dataset, bs, shuffle=True):
    '''
    [1] datasetlen : tuple: (xdata, ydata)
    [2] bs: batchsize
    [3] shuffle: Boolean
    '''
    n = len(dataset[0])
    idx = np.random.permutation(n) if shuffle else np.arange(n)
    while True:
        for i in range(0,n,bs):
            yield idx[i:i+bs]

def data_gen(dataset, samp):
    ''' Inputs
    [1] dataset: tuple: (xdata, ydata)
    [2] samp: generator object for the sampler function
    
    '''    
    idxs = next(samp)
    while True:
        yield (dataset[0][idxs], dataset[1][idxs])

In [154]:
bs = 32
dataset = data_set(x_train,x_train)
sampler_obj = sampler(dataset, bs, shuffle=True)
data_gen_obj = data_gen(dataset, sampler_obj)

# training

In [155]:
LEARNING_RATE = 0.0005
R_LOSS_FACTOR = 1000

In [156]:
vae.compile(LEARNING_RATE, R_LOSS_FACTOR)

In [157]:
BATCH_SIZE = 32
EPOCHS = 200
PRINT_EVERY_N_BATCHES = 100
INITIAL_EPOCH = 0
DATALEN = len(dataset[0])
STEPS_PER_EPOCH = DATALEN//BATCH_SIZE 

In [158]:
vae.train_with_generator(     
    data_gen_obj
    , epochs = EPOCHS
    , steps_per_epoch = STEPS_PER_EPOCH
    , run_folder = RUN_FOLDER
    , print_every_n_batches = PRINT_EVERY_N_BATCHES
    , initial_epoch = INITIAL_EPOCH
)

Epoch 1/200

Epoch 00001: saving model to ./run/vae/0002_digits\weights/weights-001-30.16.h5

Epoch 00001: saving model to ./run/vae/0002_digits\weights/weights.h5
Epoch 2/200

Epoch 00002: saving model to ./run/vae/0002_digits\weights/weights-002-10.45.h5

Epoch 00002: saving model to ./run/vae/0002_digits\weights/weights.h5
Epoch 3/200

Epoch 00003: saving model to ./run/vae/0002_digits\weights/weights-003-9.11.h5

Epoch 00003: saving model to ./run/vae/0002_digits\weights/weights.h5
Epoch 4/200

Epoch 00004: saving model to ./run/vae/0002_digits\weights/weights-004-8.40.h5

Epoch 00004: saving model to ./run/vae/0002_digits\weights/weights.h5
Epoch 5/200

Epoch 00005: saving model to ./run/vae/0002_digits\weights/weights-005-7.76.h5

Epoch 00005: saving model to ./run/vae/0002_digits\weights/weights.h5
Epoch 6/200

Epoch 00006: saving model to ./run/vae/0002_digits\weights/weights-006-7.36.h5

Epoch 00006: saving model to ./run/vae/0002_digits\weights/weights.h5
Epoch 7/200

Epoch 0


Epoch 00030: saving model to ./run/vae/0002_digits\weights/weights-030-5.85.h5

Epoch 00030: saving model to ./run/vae/0002_digits\weights/weights.h5
Epoch 31/200

Epoch 00031: saving model to ./run/vae/0002_digits\weights/weights-031-5.82.h5

Epoch 00031: saving model to ./run/vae/0002_digits\weights/weights.h5
Epoch 32/200

Epoch 00032: saving model to ./run/vae/0002_digits\weights/weights-032-5.81.h5

Epoch 00032: saving model to ./run/vae/0002_digits\weights/weights.h5
Epoch 33/200

Epoch 00033: saving model to ./run/vae/0002_digits\weights/weights-033-5.78.h5

Epoch 00033: saving model to ./run/vae/0002_digits\weights/weights.h5
Epoch 34/200

Epoch 00034: saving model to ./run/vae/0002_digits\weights/weights-034-5.77.h5

Epoch 00034: saving model to ./run/vae/0002_digits\weights/weights.h5
Epoch 35/200

Epoch 00035: saving model to ./run/vae/0002_digits\weights/weights-035-5.76.h5

Epoch 00035: saving model to ./run/vae/0002_digits\weights/weights.h5
Epoch 36/200

Epoch 00036: sa


Epoch 00058: saving model to ./run/vae/0002_digits\weights/weights-058-5.54.h5

Epoch 00058: saving model to ./run/vae/0002_digits\weights/weights.h5
Epoch 59/200

Epoch 00059: saving model to ./run/vae/0002_digits\weights/weights-059-5.51.h5

Epoch 00059: saving model to ./run/vae/0002_digits\weights/weights.h5
Epoch 60/200

Epoch 00060: saving model to ./run/vae/0002_digits\weights/weights-060-5.53.h5

Epoch 00060: saving model to ./run/vae/0002_digits\weights/weights.h5
Epoch 61/200

Epoch 00061: saving model to ./run/vae/0002_digits\weights/weights-061-5.49.h5

Epoch 00061: saving model to ./run/vae/0002_digits\weights/weights.h5
Epoch 62/200

Epoch 00062: saving model to ./run/vae/0002_digits\weights/weights-062-5.49.h5

Epoch 00062: saving model to ./run/vae/0002_digits\weights/weights.h5
Epoch 63/200

Epoch 00063: saving model to ./run/vae/0002_digits\weights/weights-063-5.48.h5

Epoch 00063: saving model to ./run/vae/0002_digits\weights/weights.h5
Epoch 64/200

Epoch 00064: sa


Epoch 00086: saving model to ./run/vae/0002_digits\weights/weights-086-5.41.h5

Epoch 00086: saving model to ./run/vae/0002_digits\weights/weights.h5
Epoch 87/200

Epoch 00087: saving model to ./run/vae/0002_digits\weights/weights-087-5.37.h5

Epoch 00087: saving model to ./run/vae/0002_digits\weights/weights.h5
Epoch 88/200

Epoch 00088: saving model to ./run/vae/0002_digits\weights/weights-088-5.36.h5

Epoch 00088: saving model to ./run/vae/0002_digits\weights/weights.h5
Epoch 89/200

Epoch 00089: saving model to ./run/vae/0002_digits\weights/weights-089-5.38.h5

Epoch 00089: saving model to ./run/vae/0002_digits\weights/weights.h5
Epoch 90/200

Epoch 00090: saving model to ./run/vae/0002_digits\weights/weights-090-5.35.h5

Epoch 00090: saving model to ./run/vae/0002_digits\weights/weights.h5
Epoch 91/200

Epoch 00091: saving model to ./run/vae/0002_digits\weights/weights-091-5.37.h5

Epoch 00091: saving model to ./run/vae/0002_digits\weights/weights.h5
Epoch 92/200

Epoch 00092: sa

KeyboardInterrupt: 