In [2]:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Flatten, Dense, Dropout, Lambda
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.datasets import fashion_mnist
from tensorflow.keras.utils import plot_model
from tensorflow.keras import backend as K

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageFont, ImageDraw
import random

## Prepare the Dataset
First define a few utilities for preparing and visualizing your dataset.

In [3]:
# Create pairs of images and label each pair: 1 if both images are from the same class, 0 otherwise
def create_pairs(x, digit_indices):
    '''Build alternating positive and negative sample pairs.

    For every class ``d`` and every position ``i`` below the per-class pair
    budget, two pairs are emitted back to back:

    * a positive pair ``(x[idx[d][i]], x[idx[d][i + 1]])`` from the same class;
    * a negative pair ``(x[idx[d][i]], x[idx[dn][i]])`` where ``dn`` is a
      different class chosen at random.

    Args:
        x: Indexable collection of samples (e.g. an array of images).
        digit_indices: Sequence with one entry per class; entry ``d`` holds the
            indices into ``x`` of the samples belonging to class ``d``. The
            number of classes is taken from ``len(digit_indices)``.

    Returns:
        A tuple ``(pairs, labels)`` of NumPy arrays. ``pairs[k]`` holds the two
        samples of pair ``k``; ``labels[k]`` is 1 for a positive pair and 0
        for a negative pair.

    Raises:
        ValueError: If fewer than two classes are supplied, because no
            negative pair can be formed.

    Note:
        Negative classes are drawn with the global ``random`` module; call
        ``random.seed(...)`` beforehand to get reproducible pairs.
    '''
    num_classes = len(digit_indices)
    if num_classes < 2:
        raise ValueError(
            'create_pairs needs at least 2 classes, got {}'.format(num_classes))

    pairs = []
    labels = []
    # Each class must be able to supply index i and i + 1, so the smallest
    # class determines how many pairs every class contributes.
    n = min(len(indices) for indices in digit_indices) - 1

    for d in range(num_classes):
        for i in range(n):
            # Positive pair: two consecutive samples of the same class
            z1, z2 = digit_indices[d][i], digit_indices[d][i + 1]
            pairs.append([x[z1], x[z2]])
            # An offset in [1, num_classes - 1] guarantees dn != d
            inc = random.randrange(1, num_classes)
            dn = (d + inc) % num_classes
            # Negative pair: same anchor sample, partner from another class
            z1, z2 = digit_indices[d][i], digit_indices[dn][i]
            pairs.append([x[z1], x[z2]])
            labels += [1, 0]

    return np.array(pairs), np.array(labels)

# A helper function for generating pairs and their similarity labels
def create_pairs_on_set(images, labels):
    """
    Build image pairs and their similarity labels for one dataset split.

    The positions of every label value 0 through 9 are collected from
    `labels` and handed, together with `images`, to `create_pairs`.

    Returns a tuple `(pairs, y)`: the pairs produced by `create_pairs` and
    their labels cast to float32 (1 for positive pairs, 0 for negative pairs).
    """
    # One index array per class id, in ascending class order
    per_class_indices = []
    for class_id in range(10):
        per_class_indices.append(np.where(labels == class_id)[0])

    pair_array, similarity = create_pairs(images, per_class_indices)

    return pair_array, similarity.astype(np.float32)

# Showing an image
def show_image(image):
    """Display `image` in a new figure with a colorbar and the grid turned off."""
    figure = plt.figure()
    axes = figure.gca()
    # Keep the handle returned by imshow so the colorbar is tied to this image
    rendered = axes.imshow(image)
    figure.colorbar(rendered, ax=axes)
    axes.grid(False)
    plt.show()

We can now download and prepare our train and test sets. We will also create pairs of images that will go into the multi-input model.

In [4]:
# Load the dataset
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# Cast to float32 and divide the pixel values by 255
train_images = train_images.astype('float32') / 255.0
test_images = test_images.astype('float32') / 255.0

# Negative pairs are drawn with the `random` module; seed it so that the
# generated pairs are the same on every Restart & Run All.
PAIR_SEED = 42
random.seed(PAIR_SEED)

# Create pairs on train and test dataset
tr_pairs, tr_y = create_pairs_on_set(train_images, train_labels)
ts_pairs, ts_y = create_pairs_on_set(test_images, test_labels)

Now we can plot some of the pairs together with their `y` values, which indicate whether the two images in a pair belong to the same class of the `fashion_mnist` dataset.