# The siamese network

# Imports

In [None]:
# Credit: https://stackoverflow.com/questions/34199233/how-to-prevent-tensorflow-from-allocating-the-totality-of-a-gpu-memory
# Turn on memory growth for every visible GPU so that TensorFlow allocates
# GPU memory as it is needed rather than claiming all of it at start-up.
import tensorflow as tf
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

In [None]:
from tensorflow import keras
from loaders import TripletSequence
from loaders import PairSequence
import nets
import data

# Globals

In [None]:
# Short alias for the Keras backend module.
K = keras.backend
# Handed to nets.make_conv_base when the convolutional base is built
# (presumably the size of the embedding vector -- confirm in nets).
codings_size = 10
# Shared by the contrastive loss, the PairAccuracy metric and the triplet
# evaluation (presumably the margin / distance threshold -- confirm in nets).
alpha = 0.5

# Data

In [None]:
# Load the 'emnist/balanced' dataset and unpack its three (x, y) splits.
emnist = data.load_dataset('emnist/balanced')
train_split, valid_split, test_split = emnist.split()
train_x, train_y = train_split
valid_x, valid_y = valid_split
test_x, test_y = test_split

In [None]:
# All three pair sequences share the same settings; only the data split differs.
pair_seq_options = dict(
    samples_per_label=1,
    batch_size=1024,
    can_shuffle=True,
    output='y',
)
train_seq = PairSequence(train_x, train_y, **pair_seq_options)
valid_seq = PairSequence(valid_x, valid_y, **pair_seq_options)
test_seq = PairSequence(test_x, test_y, **pair_seq_options)

# Model

In [None]:
# Build the convolutional base with nets.make_conv_base (given codings_size)
# and wrap it in the project's SiameseNet model.
conv_base   = nets.make_conv_base(codings_size)
siamese_net = nets.SiameseNet(conv_base)

In [None]:
# Obtain the loss used for training from nets.get_contrastive_loss,
# parameterised by the global alpha.
my_siamese_loss = nets.get_contrastive_loss(alpha)

In [None]:
# Stochastic gradient descent with Nesterov momentum.
opt = keras.optimizers.SGD(
    learning_rate=0.1,
    momentum=0.9,
    nesterov=True,
)

In [None]:
# Compile the siamese model with the contrastive loss and the pair-accuracy
# metric. `metrics` is passed as a list, which is the form Keras documents;
# newer Keras releases reject a bare metric object.
# run_eagerly=True makes Keras run the train/test steps eagerly rather than
# as a compiled graph. NOTE(review): presumably needed by the custom loss or
# metric -- confirm before removing, since eager execution is slower.
siamese_net.compile(
    loss=my_siamese_loss,
    metrics=[nets.PairAccuracy(alpha)],
    optimizer=opt,
    run_eagerly=True,
)

In [None]:
# Stop training once val_loss has failed to improve by at least 1e-5 for
# 20 epochs in a row, then roll back to the best weights seen.
early_stopping = keras.callbacks.EarlyStopping(
    monitor='val_loss',
    min_delta=1e-5,
    patience=20,
    restore_best_weights=True,
)

In [None]:
# Train for at most 500 epochs on the pair sequences; the early-stopping
# callback decides when to stop. Batches are loaded by 26 workers.
siamese_net.fit(
    train_seq,
    validation_data=valid_seq,
    epochs=500,
    callbacks=[early_stopping],
    workers=26,
)

# Evaluation

In [None]:
# All three triplet sequences share the same settings; only the data split differs.
triplet_seq_options = dict(
    samples_per_label=1,
    batch_size=1024,
    can_shuffle=True,
    is_generator=True,
)
train_triplet_seq = TripletSequence(train_x, train_y, **triplet_seq_options)
valid_triplet_seq = TripletSequence(valid_x, valid_y, **triplet_seq_options)
test_triplet_seq = TripletSequence(test_x, test_y, **triplet_seq_options)

In [None]:
print("Evaluating training set:")
# Run the triplet evaluation three times on the training split.
# NOTE(review): presumably repeated because the shuffled sequence yields
# different triplets on each pass -- confirm in loaders.
for attempt in range(3):
    nets.evaluate_siamese_on_triplets(
        train_triplet_seq,
        siamese_net,
        alpha=alpha,
    )

In [None]:
print("Evaluating validation set:")
# Run the triplet evaluation three times on the validation split.
# NOTE(review): presumably repeated because the shuffled sequence yields
# different triplets on each pass -- confirm in loaders.
for attempt in range(3):
    nets.evaluate_siamese_on_triplets(
        valid_triplet_seq,
        siamese_net,
        alpha=alpha,
    )

In [None]:
print("Evaluating testing set:")
# Run the triplet evaluation three times on the test split.
# NOTE(review): presumably repeated because the shuffled sequence yields
# different triplets on each pass -- confirm in loaders.
for attempt in range(3):
    nets.evaluate_siamese_on_triplets(
        test_triplet_seq,
        siamese_net,
        alpha=alpha,
    )

In [None]:
# Persist the trained siamese model under the path "sn_model".
print("Saving Model")
siamese_net.save("sn_model")