# VAE & Triplet network

# Imports

In [None]:
# Let TensorFlow allocate GPU memory on demand instead of reserving all of it up front.
# Credit: https://stackoverflow.com/questions/34199233/how-to-prevent-tensorflow-from-allocating-the-totality-of-a-gpu-memory
import tensorflow as tf

# `tf.config.list_physical_devices` is the stable endpoint; the
# `tf.config.experimental.` spelling is a deprecated alias.
gpus = tf.config.list_physical_devices('GPU')
try:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as e:
    # Memory growth can only be configured before the GPUs are initialized
    # (e.g. this cell was re-run after training started); report instead of
    # crashing the notebook.
    print(e)

In [None]:
from tensorflow import keras
from loaders import TripletSequence
from loaders import PairSequence
import data
import nets

# Globals

In [None]:
# Hyperparameters shared by the model / loss / metric cells of this notebook.
K = keras.backend  # shorthand handle to the Keras backend module
codings_size: int = 10  # size passed to make_encoder_vae / make_decoder (presumably the latent dimension)
alpha: float = 0.5  # passed to get_triplet_loss and TripletAccuracy (presumably the triplet margin — confirm in nets)

# Data

In [None]:
# Load the balanced EMNIST dataset through the project's `data` module and
# unpack its three partitions; each is an (x, y) pair (presumably images, labels).
emnist = data.load_dataset('emnist/balanced')
splits = emnist.split()
(train_x, train_y), (valid_x, valid_y), (test_x, test_y) = splits

In [None]:
# Settings shared by every split's triplet sequence, defined once so the
# train / validation / test sequences cannot silently diverge.
# Meaning of `samples_per_label`, `can_shuffle` and `is_generator` is defined in
# loaders.TripletSequence — NOTE(review): the test split is shuffled too;
# presumably triplets are resampled, which is why evaluation is repeated.
triplet_seq_kwargs = dict(
    samples_per_label=1,
    batch_size=1024,
    can_shuffle=True,
    is_generator=True,
)

train_seq = TripletSequence(train_x, train_y, **triplet_seq_kwargs)
valid_seq = TripletSequence(valid_x, valid_y, **triplet_seq_kwargs)
test_seq = TripletSequence(test_x, test_y, **triplet_seq_kwargs)

# Model

In [None]:
# Build the VAE encoder/decoder pair and wrap both, together with the
# distance function, in the project's triplet network.
encoder = nets.make_encoder_vae(codings_size)
decoder = nets.make_decoder(codings_size)
# NOTE: "ecludean" is how the function is spelled in `nets`; rename it there first if fixing.
distance_fn = nets.ecludean_distance
tnet = nets.TripletNetVAE(encoder, decoder, distance_fn)

In [None]:
# Loss functions for the model's 'distance', 'recon' and 'mean-var' outputs
# (wired up by name in tnet.compile). The tuple is evaluated left to right, so
# the factories are called in the same order as before.
triplet_loss, recon_loss, kld_loss = (
    nets.get_triplet_loss(alpha),
    nets.get_recon_loss(),
    # image_count=3: presumably one per triplet member — confirm in nets.get_kld_loss
    nets.get_kld_loss(image_count=3),
)

In [None]:
# SGD with Nesterov momentum.
# NOTE(review): a learning rate of 0.01 was previously recorded next to this
# line — presumably an alternative that was tried; confirm 0.1 is intended.
opt = keras.optimizers.SGD(
    nesterov=True,
    momentum=0.9,
    learning_rate=0.1,
)

In [None]:
# Each output gets its own loss; the weighted sum is what SGD minimizes.
# Only the 'distance' output is tracked with an extra metric.
tnet.compile(
    optimizer=opt,
    loss={'recon': recon_loss, 'mean-var': kld_loss, 'distance': triplet_loss},
    metrics={'distance': nets.TripletAccuracy(alpha)},
    loss_weights={'recon': 2.0, 'mean-var': 1.0, 'distance': 3.0},
    # NOTE(review): eager execution is much slower than graph mode; presumably
    # required by the custom model/losses — confirm before disabling.
    run_eagerly=True,
)

In [None]:
# Stop once val_loss has not improved by at least 1e-5 for 20 consecutive
# epochs, then roll the model back to the weights of its best epoch.
early_stopping = keras.callbacks.EarlyStopping(
    monitor='val_loss',
    min_delta=1e-5,
    patience=20,
    restore_best_weights=True,
)

In [None]:
print("Fitting")
# Keep the returned History so per-epoch losses/metrics remain available
# (e.g. for plotting) after training finishes.
history = tnet.fit(
    train_seq,
    epochs=500,  # upper bound; early_stopping may end training earlier
    # NOTE(review): `workers` is a TF 2.x / Keras 2 `Model.fit` argument and
    # does not exist in Keras 3 — confirm the installed version.
    workers=26,
    validation_data=valid_seq,
    callbacks=[early_stopping],
)

# Evaluation

In [None]:
print("Evaluating training set")
# Evaluated several times because the sequences were built with shuffling
# enabled, so results may vary between passes.
# `[-1]` is the last value evaluate() returns; with TripletAccuracy as the only
# compiled metric this is presumably that accuracy.
for _ in range(3):
    print(tnet.evaluate(train_seq, verbose=0)[-1])

In [None]:
print("Evaluating validation set")
# Evaluated several times because the sequences were built with shuffling
# enabled, so results may vary between passes.
# `[-1]` is the last value evaluate() returns; with TripletAccuracy as the only
# compiled metric this is presumably that accuracy.
for _ in range(3):
    print(tnet.evaluate(valid_seq, verbose=0)[-1])

In [None]:
print("Evaluating testing set")
# Evaluated several times because the sequences were built with shuffling
# enabled, so results may vary between passes.
# `[-1]` is the last value evaluate() returns; with TripletAccuracy as the only
# compiled metric this is presumably that accuracy.
for _ in range(3):
    print(tnet.evaluate(test_seq, verbose=0)[-1])

In [None]:
# NOTE(review): the path has no file extension — TF 2.x / Keras 2 writes a
# SavedModel directory here, whereas Keras 3 requires a `.keras`/`.h5` suffix.
save_path = "vae_tn_model"
print("Saving Model")
tnet.save(save_path)