# Siamese networks with TensorFlow 2.0/Keras

In this example, we'll implement a simple siamese network that verifies whether a pair of MNIST images belongs to the same class (true) or not (false).

_This example is partially based on_ [https://github.com/keras-team/keras/blob/master/examples/mnist_siamese.py](https://github.com/keras-team/keras/blob/master/examples/mnist_siamese.py)


Let's start with the imports:

In [1]:
import random

import numpy as np
import tensorflow as tf

We'll continue with the `create_pairs` function, which builds a dataset with an equal number of true and false pairs for each MNIST class.

In [2]:
def create_pairs(inputs: np.ndarray, labels: np.ndarray):
    """Create equal number of true/false pairs of samples"""

    num_classes = 10

    # Indices of the samples that belong to each digit class
    digit_indices = [np.where(labels == i)[0] for i in range(num_classes)]
    pairs = list()
    pair_labels = list()
    # The smallest class limits the number of pairs per class
    n = min([len(digit_indices[d]) for d in range(num_classes)]) - 1
    for d in range(num_classes):
        for i in range(n):
            # True pair: two consecutive samples of the same class
            z1, z2 = digit_indices[d][i], digit_indices[d][i + 1]
            pairs += [[inputs[z1], inputs[z2]]]
            # False pair: the same sample and a sample of a random other class
            inc = random.randrange(1, num_classes)
            dn = (d + inc) % num_classes
            z1, z2 = digit_indices[d][i], digit_indices[dn][i]
            pairs += [[inputs[z1], inputs[z2]]]
            pair_labels += [1, 0]
    return np.array(pairs), np.array(pair_labels, dtype=np.float32)
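
Two details of the pairing scheme are easy to check in isolation: the modular offset never selects the same class for a false pair, and the number of pairs depends only on the size of the smallest class. The sketch below is a minimal illustration in plain Python. The helper `expected_num_pairs` is hypothetical (it is not part of the notebook), and the class sizes 5421 and 892 are assumptions chosen to be consistent with the 108400 training and 17820 validation pairs reported in the training log.

```python
import random

num_classes = 10

# An offset in [1, num_classes - 1] can never wrap back to d itself,
# so the second image of a false pair always comes from another class.
random.seed(0)  # seed used only to make this sketch repeatable
for d in range(num_classes):
    for _ in range(100):
        inc = random.randrange(1, num_classes)
        dn = (d + inc) % num_classes
        assert dn != d


def expected_num_pairs(smallest_class_size, num_classes=10):
    """Hypothetical helper: total pairs produced by the pairing scheme."""
    n = smallest_class_size - 1      # pairs of each kind per class
    return 2 * n * num_classes       # one true and one false pair per step


print(expected_num_pairs(5421))  # 108400
print(expected_num_pairs(892))   # 17820
```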

Next, we'll define the base network of the siamese system:

In [3]:
def create_base_network():
    """The shared encoding part of the siamese network"""

    return tf.keras.models.Sequential([
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dropout(0.1),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dropout(0.1),
        tf.keras.layers.Dense(64, activation='relu'),
    ])
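
To get a feel for the size of the shared encoder, its trainable parameters can be counted by hand. The following is a rough sketch that assumes 28x28 inputs (the MNIST image size); `dense_params` is a hypothetical helper, and the `Flatten` and `Dropout` layers contribute no parameters.

```python
def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out


flattened = 28 * 28          # Flatten turns a 28x28 image into 784 values
layer_sizes = [128, 128, 64]  # the three Dense layers of the base network

total = 0
n_in = flattened
for n_out in layer_sizes:
    total += dense_params(n_in, n_out)
    n_in = n_out

print(total)  # 125248
```

Because both branches of the siamese system reuse the same base network, these parameters are counted once, not once per branch.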

Next, let's load the standard MNIST training and test sets and create true/false pairs from them. The test pairs will serve as validation data during training:

In [4]:
# Load the train and test MNIST datasets
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.astype(np.float32)
x_test = x_test.astype(np.float32)
x_train /= 255
x_test /= 255
input_shape = x_train.shape[1:]

# Create true/false training and testing pairs
train_pairs, tr_labels = create_pairs(x_train, y_train)
test_pairs, test_labels = create_pairs(x_test, y_test)
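
The pair arrays returned by `create_pairs` carry an extra axis of length 2 that separates the two images of each pair. The sketch below uses a small zero-filled stand-in array (an assumption for illustration, not real MNIST data) to show how indexing that axis yields the two model inputs.

```python
import numpy as np

# Stand-in for the output of create_pairs: 6 pairs of 28x28 images
pairs = np.zeros((6, 2, 28, 28), dtype=np.float32)
pairs[:, 1] = 1.0  # mark the second image of each pair to make the split visible

# Index 0 on the second axis selects the first image of every pair, index 1 the second
first, second = pairs[:, 0], pairs[:, 1]

print(first.shape, second.shape)  # (6, 28, 28) (6, 28, 28)
```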

Then, we'll build the siamese system, which consists of the shared `base_network`, the two siamese branches `encoder_a` and `encoder_b`, the `l1_dist` measure, and the combined `model`:

In [5]:
# Create the siamese network
# Start from the shared layers
base_network = create_base_network()

# Create first half of the siamese system
input_a = tf.keras.layers.Input(shape=input_shape)

# Note how we reuse the base_network in both halves
encoder_a = base_network(input_a)

# Create the second half of the siamese system
input_b = tf.keras.layers.Input(shape=input_shape)
encoder_b = base_network(input_b)

# Create the distance measure
l1_dist = tf.keras.layers.Lambda(
    lambda embeddings: tf.keras.backend.abs(embeddings[0] - embeddings[1])) \
    ([encoder_a, encoder_b])

# Final fc layer with a single logistic output for the binary classification
flattened_weighted_distance = tf.keras.layers.Dense(1, activation='sigmoid') \
    (l1_dist)

# Build the model
model = tf.keras.models.Model([input_a, input_b], flattened_weighted_distance)
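
The part of the model that follows the two encoders is small enough to reproduce by hand. The sketch below mirrors it in NumPy under stated assumptions: `siamese_head` is a hypothetical helper, and the weights and bias are random placeholder values, not trained parameters. It illustrates two properties of this design: identical embeddings leave only the bias term, and the output does not depend on the order of the two inputs.

```python
import numpy as np


def siamese_head(emb_a, emb_b, weights, bias):
    """Sigmoid of a weighted sum of the element-wise absolute difference."""
    l1 = np.abs(emb_a - emb_b)          # same role as the Lambda layer
    logit = l1 @ weights + bias         # same role as Dense(1)
    return 1.0 / (1.0 + np.exp(-logit))  # sigmoid activation


rng = np.random.default_rng(0)
emb = rng.normal(size=64)       # 64 matches the encoder output size
other = rng.normal(size=64)
weights = rng.normal(size=64)   # placeholder values, not trained weights
bias = 0.5

# Identical embeddings give a zero distance vector, so only the bias remains
same = siamese_head(emb, emb, weights, bias)
print(np.isclose(same, 1.0 / (1.0 + np.exp(-bias))))  # True

# The absolute difference is symmetric, so swapping the inputs changes nothing
ab = siamese_head(emb, other, weights, bias)
ba = siamese_head(other, emb, weights, bias)
print(np.isclose(ab, ba))  # True
```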

Finally, we can train the model and check the validation accuracy, which reached 99.37% in the original run:

In [6]:
# Train
model.compile(loss='binary_crossentropy',
              optimizer=tf.keras.optimizers.Adam(),
              metrics=['accuracy'])

model.fit([train_pairs[:, 0], train_pairs[:, 1]], tr_labels,
          batch_size=128,
          epochs=20,
          validation_data=([test_pairs[:, 0], test_pairs[:, 1]], test_labels))

Train on 108400 samples, validate on 17820 samples
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<tensorflow.python.keras.callbacks.History at 0x7f455027b4a8>
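
The sigmoid output of the model is a probability that the two images show the same digit. A simple way to turn it into a true/false decision is to threshold it. The sketch below assumes a threshold of 0.5, which is, to the best of my knowledge, also what the Keras `accuracy` metric uses for a binary output. The helper `pair_accuracy` and the probabilities are hypothetical and do not come from the trained model.

```python
def pair_accuracy(probabilities, labels, threshold=0.5):
    """Fraction of pairs whose thresholded prediction matches the label."""
    correct = sum(
        int(p > threshold) == int(y)
        for p, y in zip(probabilities, labels)
    )
    return correct / len(labels)


probs = [0.97, 0.08, 0.61, 0.45]   # hypothetical model outputs
labels = [1.0, 0.0, 0.0, 1.0]      # 1 = same class, 0 = different classes

print(pair_accuracy(probs, labels))  # 0.5
```

With the trained model, the probabilities would come from calling `model.predict` on the two halves of the test pairs.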