In [1]:
from tensorflow.keras import models, layers
from ngdlm import models as ngdlmodels
from ngdlm import utils as ngdlutils
from tensorflow.keras.datasets import mnist
import numpy as np

# Train and test data.

In [2]:
# Load the MNIST train/test splits; each split is an (images, labels) pair.
(x_input_train, y_output_train), (x_input_test, y_output_test) = mnist.load_data()

# Cast both image arrays to float32 and divide by 255.0
# (presumably maps 8-bit pixel values into [0, 1] — confirm against the dataset docs).
x_input_train, x_input_test = (
    images.astype("float32") / 255.0 for images in (x_input_train, x_input_test)
)

# Sanity check: show the shape of each image array, one per line.
print(x_input_train.shape, x_input_test.shape, sep="\n")

(60000, 28, 28)
(10000, 28, 28)


# Triplet loss.

In [None]:
# Number of units in the final layer of the base network (the embedding size).
latent_dim = 8

# Base network: flatten each 28x28 input, run it through three ReLU dense
# layers of decreasing width, and finish with a Dense layer of `latent_dim`
# units that is given no activation argument.
base_input = layers.Input(shape=(28, 28))
base_output = layers.Flatten()(base_input)
for units in (512, 256, 128):
    base_output = layers.Dense(units, activation="relu")(base_output)
base_output = layers.Dense(latent_dim)(base_output)
base = models.Model(base_input, base_output)

# Wrap the base network in ngdlm's triplet-loss model and compile it.
tl = ngdlmodels.TL(base)
tl.compile(optimizer="rmsprop", triplet_loss="euclidean")
# Optional: call tl.summary() here to inspect the resulting architecture.

# Fit on the training split and validate against the test split.
# NOTE(review): `minibatch_size` is an ngdlm-specific argument whose exact
# meaning is not visible from this notebook — confirm against TL.fit.
# NOTE(review): in the recorded run the logged training loss (~0.0015) is
# about two orders of magnitude below val_loss (~0.2-0.4); worth checking
# whether the two values are computed on the same scale.
print("Train...")
history = tl.fit(
    x_input_train,
    y_output_train,
    batch_size=128,
    minibatch_size=10,
    epochs=1000,
    steps_per_epoch=1000,
    shuffle=True,
    validation_data=(x_input_test, y_output_test),
    validation_steps=500,
)

Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
Train...
Epoch 1/1000...
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
loss [0.001776958405273035]
val_loss [0.28968775272369385]
Epoch 2/1000...
loss [0.002040584284812212]
val_loss [0.42852091789245605]
Epoch 3/1000...
loss [0.0018833374110981823]
val_loss [0.25745901465415955]
Epoch 4/1000...
loss [0.0015206384893972427]
val_loss [0.3331780433654785]
Epoch 5/1000...
loss [0.0013796454295516014]
val_loss [0.31549692153930664]
Epoch 6/1000...
loss [0.0016841064137406647]
val_loss [0.3355473279953003]
Epoch 7/1000...
loss [0.0014051316061522811]
val_loss [0.2461884468793869]
Epoch 8/1000...
loss [0.0013360956518445163]
val_loss [0.32527977228164673]
Epoch 9/1000...
loss [0.0015073403760325163]
val_loss [0.21918249130249023]
Epoch 10/1000...
loss [0.0013042997615411878]
val_loss [0.2221946120262146]
Epoch 11/1000...
loss [0

# Visualizing the triplet loss.

In [None]:
# Render the history object returned by tl.fit (no progress message is printed
# for this step).
ngdlutils.render_history(history)

# Render the base network's encodings, first for the training split and then
# for the test split.
print("Rendering encodings...")
encoder = tl.base
ngdlutils.render_encodings(encoder, x_input_train, y_output_train)
ngdlutils.render_encodings(encoder, x_input_test, y_output_test)

# TODO: Visualize the triplets.