## Import the required packages

In [3]:
import numpy as np

import random
from keras.datasets import mnist
from keras.models import Sequential, Model
from keras.layers.core import Activation, Flatten
from keras.layers.convolutional import Convolution2D, MaxPooling2D
from keras.layers import Dense, Dropout, Input, Lambda
from keras.optimizers import SGD, RMSprop
from keras import backend as K

In [29]:
def euclidean_distance(vects):
    '''Euclidean distance between the two embedding batches in `vects`.

    Args:
        vects: pair ``(x, y)`` of tensors with matching shapes; squared
            differences are summed over axis 1.

    Returns:
        Tensor of distances with axis 1 kept (size 1).
    '''
    x, y = vects
    sum_square = K.sum(K.square(x - y), axis=1, keepdims=True)
    # d sqrt(s)/ds is infinite at s == 0, so identical embeddings would yield
    # NaN gradients; clamping at K.epsilon() keeps the gradient finite.
    return K.sqrt(K.maximum(sum_square, K.epsilon()))


def eucl_dist_output_shape(shapes):
    '''Output shape of euclidean_distance, for the Lambda layer.

    Args:
        shapes: pair of input shapes ``(shape1, shape2)``, batch axis first.

    Returns:
        ``(batch, 1)``: euclidean_distance reduces axis 1 with keepdims, so the
        result has one distance per row rather than the inputs' embedding shape.
    '''
    shape1, shape2 = shapes
    return (shape1[0], 1)


def contrastive_loss(y_true, y_pred, margin=1):
    '''Contrastive loss from Hadsell-et-al.'06
    http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf

    Args:
        y_true: pair labels, 1 for a same-class pair and 0 otherwise (the
            convention produced by create_pairs in this notebook).
        y_pred: predicted distances between the two embeddings of each pair.
        margin: distance beyond which dissimilar pairs contribute no loss.
            Callers passing only ``(y_true, y_pred)`` get the margin of 1.

    Returns:
        Mean over all elements of
        ``y_true * d**2 + (1 - y_true) * max(margin - d, 0)**2``.
    '''
    square_pred = K.square(y_pred)
    # Dissimilar pairs are only penalised while they are closer than `margin`.
    margin_square = K.square(K.maximum(margin - y_pred, 0))
    return K.mean(y_true * square_pred + (1 - y_true) * margin_square)


def create_pairs(x, digit_indices):
    '''Positive and negative pair creation.

    Alternates between positive and negative pairs: for every class ``d`` and
    position ``i`` it emits the positive pair
    ``(x[digit_indices[d][i]], x[digit_indices[d][i + 1]])`` followed by a
    negative pair that combines the same first sample with sample ``i`` of a
    different, randomly chosen class (drawn with the ``random`` module).

    Args:
        x: indexable collection of samples.
        digit_indices: one sequence of sample indices per class; the number of
            classes is ``len(digit_indices)``.

    Returns:
        ``(pairs, labels)`` as NumPy arrays; ``pairs[k]`` holds two samples and
        ``labels[k]`` is 1 for a positive pair and 0 for a negative pair.

    Raises:
        ValueError: if fewer than two classes are given (a negative pair
            needs a second class).
    '''
    num_classes = len(digit_indices)
    if num_classes < 2:
        raise ValueError('create_pairs needs at least two classes, got %d' % num_classes)
    pairs = []
    labels = []
    # `- 1` because each positive pair also reads index i + 1; using the
    # smallest class keeps every class equally represented.
    n = min(len(indices) for indices in digit_indices) - 1
    for d in range(num_classes):
        for i in range(n):
            z1, z2 = digit_indices[d][i], digit_indices[d][i + 1]
            pairs += [[x[z1], x[z2]]]
            # An offset in [1, num_classes) guarantees dn != d.
            inc = random.randrange(1, num_classes)
            dn = (d + inc) % num_classes
            z1, z2 = digit_indices[d][i], digit_indices[dn][i]
            pairs += [[x[z1], x[z2]]]
            labels += [1, 0]
    return np.array(pairs), np.array(labels)



def compute_accuracy(predictions, labels):
    '''Compute classification accuracy with a fixed threshold on distances.

    A distance below 0.5 is read as a "same class" prediction; the result is
    the fraction of pairs whose prediction matches ``labels``.

    NOTE(review): ``compute_accuracy`` is defined again later in this cell
    with the argument order ``(y_true, y_pred)``; that later definition
    replaces this one before any call in this notebook runs.
    '''
    predicted_same = predictions.ravel() < 0.5
    # Compare every pair against its label, not only the predicted-positive
    # ones (that would be precision, and NaN when nothing is predicted positive).
    return np.mean(predicted_same == labels)


# NOTE(review): the string literal below holds a disabled, fully connected
# alternative to create_base_network. It is an unassigned module-level string,
# so its contents are never executed; delete it or turn it into a real
# function if it is still needed.
'''
def create_base_network_dense(input_dim):
	#Base network to be shared (eq. to feature extraction).
	
	seq = Sequential()
	seq.add(Dense(128, input_shape=(input_dim,), activation='relu'))
	seq.add(Dropout(0.1))
	seq.add(Dense(128, activation='relu'))
	seq.add(Dropout(0.1))
	seq.add(Dense(128, activation='relu'))
	return seq
'''


def create_base_network(input_dim):
    '''Build the convolutional embedding network shared by both siamese branches.

    Args:
        input_dim: 3-entry image shape, used unchanged as the first layer's
            ``input_shape``. It must already be ordered for the backend's
            image data format, e.g. ``(1, 28, 28)`` for channels_first or
            ``(28, 28, 1)`` for channels_last.

    Returns:
        An uncompiled ``Sequential`` model ending in a 32-unit ReLU layer.

    Raises:
        ValueError: if ``input_dim`` does not have exactly three entries.
    '''
    input_shape = tuple(input_dim)
    if len(input_shape) != 3:
        raise ValueError('input_dim must have 3 entries (image shape), got %r' % (input_dim,))

    # number of convolutional filters to use
    nb_filters = 32
    # size of pooling area for max pooling
    nb_pool = 2
    # convolution kernel size
    nb_conv = 3
    model = Sequential()

    model.add(Convolution2D(nb_filters, (nb_conv, nb_conv),
                            padding='valid',
                            input_shape=input_shape))
    model.add(Activation('relu'))
    model.add(Convolution2D(nb_filters, (nb_conv, nb_conv)))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(nb_pool, nb_pool)))
    # Dropout was tried after pooling; the author observed the loss going to
    # NaN with too much dropout, so none is applied.

    model.add(Flatten())

    # Only the first layer of a Sequential model needs an input shape.
    model.add(Dense(64, activation='relu'))
    model.add(Dense(32, activation='relu'))

    return model

def compute_accuracy(y_true, y_pred):
    '''Fraction of pairs classified correctly by thresholding the distance.

    Distances strictly below 0.5 count as "same class"; these boolean
    predictions are compared element-wise with ``y_true`` and averaged.
    '''
    is_same = y_pred.ravel() < 0.5
    hits = is_same == y_true
    return np.mean(hits)


def accuracy(y_true, y_pred):
    '''Keras metric: share of pairs whose thresholded distance matches the label.

    A predicted distance below 0.5 becomes 1 ("same class"), otherwise 0,
    cast to the dtype of ``y_true`` before the element-wise comparison.
    '''
    predicted_same = K.cast(y_pred < 0.5, y_true.dtype)
    matches = K.equal(y_true, predicted_same)
    return K.mean(matches)


In [38]:
# the data, split between train and test sets
# Tunable settings gathered in one place.
SEED = 1337
epochs = 10
batch_size = 128
num_classes = 10
img_rows, img_cols = 28, 28

# create_pairs picks negative-pair classes with Python's `random`; seed it (and
# NumPy) so the generated pairs are reproducible between runs.
# NOTE(review): Keras layer initialisers presumably draw from the backend's RNG
# rather than these two; seed the backend as well if fully reproducible
# training is required — confirm for the installed Keras/backend version.
random.seed(SEED)
np.random.seed(SEED)

(x_train, y_train), (x_test, y_test) = mnist.load_data()

if K.image_data_format() == 'channels_first':
    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
    input_shape = (1, img_rows, img_cols)
else:
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
    input_shape = (img_rows, img_cols, 1)

# Scale pixel values from [0, 255] to [0, 1].
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
print(input_shape)

# create training+test positive and negative pairs
digit_indices = [np.where(y_train == i)[0] for i in range(num_classes)]
tr_pairs, tr_y = create_pairs(x_train, digit_indices)

digit_indices = [np.where(y_test == i)[0] for i in range(num_classes)]
te_pairs, te_y = create_pairs(x_test, digit_indices)

# network definition
base_network = create_base_network(input_shape)

input_a = Input(shape=input_shape)
input_b = Input(shape=input_shape)

# because we re-use the same instance `base_network`,
# the weights of the network
# will be shared across the two branches
processed_a = base_network(input_a)
processed_b = base_network(input_b)

distance = Lambda(euclidean_distance,
                  output_shape=eucl_dist_output_shape)([processed_a, processed_b])

model = Model([input_a, input_b], distance)

# train
rms = RMSprop()
model.compile(loss=contrastive_loss, optimizer=rms, metrics=[accuracy])
model.fit([tr_pairs[:, 0], tr_pairs[:, 1]], tr_y,
          batch_size=batch_size,
          epochs=epochs,
          validation_data=([te_pairs[:, 0], te_pairs[:, 1]], te_y))

# compute final accuracy on training and test sets
# (compute_accuracy here is the (y_true, y_pred) definition, the last one made)
y_pred = model.predict([tr_pairs[:, 0], tr_pairs[:, 1]])
tr_acc = compute_accuracy(tr_y, y_pred)
y_pred = model.predict([te_pairs[:, 0], te_pairs[:, 1]])
te_acc = compute_accuracy(te_y, y_pred)

print('* Accuracy on training set: %0.2f%%' % (100 * tr_acc))
print('* Accuracy on test set: %0.2f%%' % (100 * te_acc))


(1, 28, 28)


Instructions for updating:
Use tf.cast instead.


Train on 108400 samples, validate on 17820 samples
Epoch 1/10


   128/108400 [..............................] - ETA: 59:35 - loss: 0.3418 - accuracy: 0.4688

   256/108400 [..............................] - ETA: 32:21 - loss: 0.4017 - accuracy: 0.5664

   384/108400 [..............................] - ETA: 23:22 - loss: 0.3381 - accuracy: 0.5964

   512/108400 [..............................] - ETA: 18:44 - loss: 0.2983 - accuracy: 0.6309

   640/108400 [..............................] - ETA: 16:03 - loss: 0.2760 - accuracy: 0.6422

   768/108400 [..............................] - ETA: 14:24 - loss: 0.2617 - accuracy: 0.6523

   896/108400 [..............................] - ETA: 13:01 - loss: 0.2521 - accuracy: 0.6562

  1024/108400 [..............................] - ETA: 12:00 - loss: 0.2443 - accuracy: 0.6631

  1152/108400 [..............................] - ETA: 11:11 - loss: 0.2355 - accuracy: 0.6771

  1280/108400 [..............................] - ETA: 10:31 - loss: 0.2259 - accuracy: 0.6937

  1408/108400 [..............................] - ETA: 10:04 - loss: 0.2209 - accuracy: 0.6974

  1536/108400 [..............................] - ETA: 9:39 - loss: 0.2139 - accuracy: 0.7083 

  1664/108400 [..............................] - ETA: 9:18 - loss: 0.2079 - accuracy: 0.7194

  1792/108400 [..............................] - ETA: 8:58 - loss: 0.2052 - accuracy: 0.7199

  1920/108400 [..............................] - ETA: 8:42 - loss: 0.2004 - accuracy: 0.7266

  2048/108400 [..............................] - ETA: 8:27 - loss: 0.1964 - accuracy: 0.7314

  2176/108400 [..............................] - ETA: 8:13 - loss: 0.1924 - accuracy: 0.7381

  2304/108400 [..............................] - ETA: 8:00 - loss: 0.1906 - accuracy: 0.7418

  2432/108400 [..............................] - ETA: 7:51 - loss: 0.1879 - accuracy: 0.7463

  2560/108400 [..............................] - ETA: 7:41 - loss: 0.1853 - accuracy: 0.7508

  2688/108400 [..............................] - ETA: 7:33 - loss: 0.1836 - accuracy: 0.7526

  2816/108400 [..............................] - ETA: 7:24 - loss: 0.1808 - accuracy: 0.7585

  2944/108400 [..............................] - ETA: 7:17 - loss: 0.1780 - accuracy: 0.7626

  3072/108400 [..............................] - ETA: 7:11 - loss: 0.1759 - accuracy: 0.7663

  3200/108400 [..............................] - ETA: 7:05 - loss: 0.1733 - accuracy: 0.7725

  3328/108400 [..............................] - ETA: 7:00 - loss: 0.1704 - accuracy: 0.7776

  3456/108400 [..............................] - ETA: 6:54 - loss: 0.1695 - accuracy: 0.7784

  3584/108400 [..............................] - ETA: 6:49 - loss: 0.1682 - accuracy: 0.7799

  3712/108400 [>.............................] - ETA: 6:44 - loss: 0.1660 - accuracy: 0.7831

  3840/108400 [>.............................] - ETA: 6:39 - loss: 0.1643 - accuracy: 0.7867

  3968/108400 [>.............................] - ETA: 6:37 - loss: 0.1629 - accuracy: 0.7891

  4096/108400 [>.............................] - ETA: 6:32 - loss: 0.1612 - accuracy: 0.7927

  4224/108400 [>.............................] - ETA: 6:30 - loss: 0.1595 - accuracy: 0.7943

  4352/108400 [>.............................] - ETA: 6:27 - loss: 0.1582 - accuracy: 0.7953

  4480/108400 [>.............................] - ETA: 6:24 - loss: 0.1575 - accuracy: 0.7969

  4608/108400 [>.............................] - ETA: 6:20 - loss: 0.1565 - accuracy: 0.7975

  4736/108400 [>.............................] - ETA: 6:17 - loss: 0.1555 - accuracy: 0.7992

  4864/108400 [>.............................] - ETA: 6:15 - loss: 0.1543 - accuracy: 0.8012

  4992/108400 [>.............................] - ETA: 6:12 - loss: 0.1532 - accuracy: 0.8037

  5120/108400 [>.............................] - ETA: 6:09 - loss: 0.1522 - accuracy: 0.8047

  5248/108400 [>.............................] - ETA: 6:07 - loss: 0.1516 - accuracy: 0.8062

  5376/108400 [>.............................] - ETA: 6:05 - loss: 0.1504 - accuracy: 0.8084

KeyboardInterrupt: 