# Variational Auto-Encoder

In [None]:
# Credit: https://stackoverflow.com/questions/34199233/how-to-prevent-tensorflow-from-allocating-the-totality-of-a-gpu-memory
# Turn on memory growth for every visible GPU so that TensorFlow does not
# claim the whole of each GPU's memory at start-up.
import tensorflow as tf
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(device=gpu, enable=True)

In [None]:
from tensorflow import keras
from loaders import TripletSequence
from loaders import PairSequence
import data
import nets

In [None]:
# Notebook-wide settings.
# `codings_size` is handed to both the encoder and the decoder builders.
codings_size = 10
# `alpha` is forwarded to nets.evaluate_vae_on_triplets during evaluation
# (presumably the triplet margin -- confirm against `nets`).
alpha = 0.5
# Short alias for the Keras backend module.
K = keras.backend

# Data

In [None]:
# Load the 'emnist/balanced' dataset through the project's `data` module and
# unpack its three (inputs, labels) pairs: training, validation and test.
emnist = data.load_dataset('emnist/balanced')
(
    (train_x, train_y),
    (valid_x, valid_y),
    (test_x, test_y),
) = emnist.split()

# Model

In [None]:
# Build the encoder and decoder with the shared `codings_size`, then wrap
# both in the project's VAE model class.
# NOTE(review): creation order is kept as-is; Keras auto-generated layer names
# and weight initialisation can depend on it.
encoder = nets.make_encoder_vae(codings_size)
decoder = nets.make_decoder(codings_size)
vae = nets.VAE(encoder, decoder)

In [None]:
# Loss callables obtained from the project's `nets` module; they are wired to
# the 'recon' and 'mean-var' outputs when the model is compiled.
recon_loss = nets.get_recon_loss()
# NOTE(review): the argument 1 is presumably the weight applied to the KL
# divergence term -- confirm against nets.get_kld_loss.
kld_loss = nets.get_kld_loss(1)

In [None]:
# SGD with Nesterov momentum.
opt = keras.optimizers.SGD(
    learning_rate=0.1,
    momentum=0.9,
    nesterov=True,
)

In [None]:
# Attach one loss per named output ('recon' and 'mean-var') and compile the
# model to run eagerly rather than as a traced graph.
vae.compile(
    loss={'recon': recon_loss, 'mean-var': kld_loss},
    optimizer=opt,
    run_eagerly=True,
)

In [None]:
# Stop training once 'val_loss' has failed to improve by at least 1e-5 for
# 20 consecutive epochs, then roll back to the best weights seen.
early_stopping = keras.callbacks.EarlyStopping(
    monitor='val_loss',
    min_delta=1e-5,
    patience=20,
    restore_best_weights=True,
)

In [None]:
# Train the VAE as an auto-encoder: the inputs are also the targets, for both
# the training and the validation data. Training runs for at most 500 epochs
# and is cut short by the early-stopping callback.
#
# The former `workers=26` argument was dropped: Keras only uses `workers` for
# generator / `keras.utils.Sequence` inputs, which cannot be the case here
# because targets are passed explicitly, and newer Keras releases no longer
# accept the argument in `fit` at all.
print("Fitting")
vae.fit(
    train_x,
    train_x,
    epochs=500,
    batch_size=1024,
    validation_data=(valid_x, valid_x),
    callbacks=[early_stopping],
)

# Evaluation

In [None]:
# One TripletSequence per split, all configured identically; these feed the
# triplet-based evaluation of the trained VAE.
train_triplet_seq = TripletSequence(
    train_x,
    train_y,
    samples_per_label=1,
    batch_size=1024,
    can_shuffle=True,
    is_generator=True,
)
valid_triplet_seq = TripletSequence(
    valid_x,
    valid_y,
    samples_per_label=1,
    batch_size=1024,
    can_shuffle=True,
    is_generator=True,
)
test_triplet_seq = TripletSequence(
    test_x,
    test_y,
    samples_per_label=1,
    batch_size=1024,
    can_shuffle=True,
    is_generator=True,
)

In [None]:
# Run the triplet evaluation on the training split three times.
# NOTE(review): presumably repeated because the shuffling sequence yields
# different triplets on each pass -- confirm.
print("Evaluating training set:")
for _run in range(3):
    nets.evaluate_vae_on_triplets(train_triplet_seq, vae, alpha=alpha)

In [None]:
# Run the triplet evaluation on the validation split three times.
# NOTE(review): presumably repeated because the shuffling sequence yields
# different triplets on each pass -- confirm.
print("Evaluating validation set:")
for _run in range(3):
    nets.evaluate_vae_on_triplets(valid_triplet_seq, vae, alpha=alpha)

In [None]:
# Run the triplet evaluation on the test split three times.
# NOTE(review): presumably repeated because the shuffling sequence yields
# different triplets on each pass -- confirm.
print("Evaluating testing set:")
for _run in range(3):
    nets.evaluate_vae_on_triplets(test_triplet_seq, vae, alpha=alpha)

In [None]:
# Write the trained model to the path "vae_model".
# NOTE(review): the on-disk format is chosen by the installed Keras version;
# newer releases may require an explicit file extension -- confirm.
print("Saving Model")
vae.save("vae_model")