In [None]:
import numpy as np
import keras
import pickle as pk
import matplotlib.pyplot as plt
import os
from keras.applications.resnet50 import ResNet50
from keras.applications.vgg16 import VGG16
from keras.preprocessing import image
from keras.applications.resnet50 import preprocess_input
from keras.layers import UpSampling2D, Input
from keras.models import Model, Sequential

### Define SiameseNet model

In [None]:
### Define model

# Shared ResNet50 trunk used as a feature extractor by both branches.
# `include_top=False, pooling='avg'` is the documented way to obtain the
# globally average-pooled feature vector without the ImageNet classifier
# head; it replaces manual `layers.pop()` surgery, which is not reliable
# across Keras versions.

#TODO: add dropout

# To start from pretrained weights, pass weights='imagenet' instead of None.
resnet = ResNet50(weights=None, include_top=False, pooling='avg')

# Def right and left inputs.
# The 32x32 images are upsampled by 7 to 224x224 before entering ResNet50.
left_img = Input(shape=(32, 32, 3), name='left_input')
right_img = Input(shape=(32, 32, 3), name='right_input')
# Each branch must be fed from its own input (left -> left, right -> right).
left_inp = UpSampling2D(size=(7, 7))(left_img)
right_inp = UpSampling2D(size=(7, 7))(right_img)

# Def model: the same `resnet` instance is called twice, so both branches
# share their weights.
left_out = resnet(left_inp)
right_out = resnet(right_inp)
# NOTE(review): this is a *signed* difference, whereas `cust_dist` applies
# np.abs() to the feature difference -- confirm which one is intended.
diff = keras.layers.subtract([left_out, right_out])
prediction = keras.layers.Dense(1, activation='sigmoid', name='dist')(diff)
# Keras 2 keyword names are `inputs` / `outputs` (not `input` / `output`).
siamese_net = Model(inputs=[left_img, right_img], outputs=prediction)

# Compile model. Metrics are declared here; `fit()` does not accept them.
adadelta = keras.optimizers.Adadelta()
siamese_net.compile(loss='binary_crossentropy', optimizer=adadelta, metrics=['accuracy'])

# Print information about the model
siamese_net.summary()

### Feature extractor

In [None]:
# Reuse the ResNet trunk that lives inside the siamese network, so the
# extractor shares whatever weights `siamese_net` currently has.
# The trunk is the only nested Model among siamese_net's layers; looking it
# up by type avoids depending on its position in `siamese_net.layers`.
resnet_tuned = next(layer for layer in siamese_net.layers if isinstance(layer, Model))

# Single-image input, upsampled by 7 exactly like the siamese branches.
img = Input(shape=(32, 32, 3), name='input')
img2 = UpSampling2D(size=(7, 7))(img)

# Def model (Keras 2 keyword names are `inputs` / `outputs`).
feat = resnet_tuned(img2)
extractor = Model(inputs=img, outputs=feat)
extractor.summary()

### Define custom distance for our k-NN

In [4]:
# Extract distance layer from our model, and use it for the k-NN.
# For a Dense layer, get_weights() returns the list [kernel, bias]; the list
# itself cannot be passed to np.dot, so `cust_dist` uses the kernel only.
dist_layer = siamese_net.get_layer('dist').get_weights()
def cust_dist(x, y):
    """Distance using the metric learned by our SiameseNet
    Args:
        x, y: (array (n_features,)): the features to use; n_features must
            match the number of inputs of the 'dist' layer
    return:
        dist: (float) kernel-weighted sum of the absolute feature differences
    """
    # Kernel has shape (n_features, 1); flatten it to a plain weight vector.
    kernel = np.asarray(dist_layer[0]).ravel()
    # NOTE(review): the network feeds a signed difference to the 'dist'
    # layer, while this uses the absolute difference -- confirm intent.
    # NOTE(review): training uses label 1 for same-class pairs, so a larger
    # score presumably means "more similar"; check the sign convention
    # before using this as a k-NN distance.
    temp = np.abs(np.asarray(x) - np.asarray(y))
    return float(np.dot(kernel, temp))


In [27]:
# Class ids split into two groups. A negative pair is made of two images
# from *different* classes of the *same* group.
machine = [0,1,8,9]
animal = [2,3,4,5,6,7]

# Function to load a batch into memory
def load_batch(data_dir, batch_id):
    """Load one pickled batch file named ``data_batch_<batch_id>``.

    Args:
        data_dir: directory containing the batch files.
        batch_id: integer suffix of the batch file to read.
    return:
        feats: array of shape (n_images, 32, 32, 3), obtained by reshaping
            each row of ``batch['data']`` to (3, 32, 32) and moving the
            channel axis last.
        lbls: ``batch['labels']`` unchanged.
    """
    batch_path = os.path.join(data_dir, 'data_batch_%i' % batch_id)
    # SECURITY: unpickling can execute arbitrary code -- only load batch
    # files that come from a trusted source.
    with open(batch_path, mode='rb') as file:
        batch = pk.load(file, encoding='latin1')
    feats = batch['data'].reshape((len(batch['data']), 3, 32, 32)).transpose(0, 2, 3, 1)
    lbls = batch['labels']
    return feats, lbls

# Generate pairs according to a certain percentage
def get_pairs(data_path, positive_percentage=0.5, cifar_batch_id=1, batch_size=32, plt_fig=False):
    """Build image pairs from one batch file.

    ``batch_size`` images are drawn at random (with replacement) as the left
    element of each pair. The right element is drawn from the images that
    were *not* sampled: the first ``int(batch_size * positive_percentage)``
    pairs are positive (same class); the remaining ones are negative (a
    different class from the same `machine` / `animal` group).

    Args:
        data_path: directory containing the batch files.
        positive_percentage: fraction of positive pairs; the resulting count
            is clamped to [0, batch_size].
        cifar_batch_id: id of the batch file to load.
        batch_size: number of pairs to generate.
        plt_fig: if True, plot every generated pair.
    return:
        [left, right]: two arrays with one image per pair.
        y: float array of length batch_size, 1 for positive pairs, 0 otherwise.
    raises:
        ValueError: if a label belongs to neither group, or if no candidate
            image is left to complete a pair.
    """
    feats, labels = load_batch(data_path, cifar_batch_id)
    labels = np.asarray(labels)
    # Clamp so out-of-range percentages cannot index past the label vector.
    nb_positive = min(max(int(batch_size * positive_percentage), 0), batch_size)

    # Left images: random draw with replacement.
    indexes = np.random.randint(len(labels), size=batch_size)
    sample_images = feats[indexes]
    sample_labels = labels[indexes]

    # Candidate pool for the right images: everything that was not sampled.
    other_indexes = np.delete(np.arange(len(labels)), indexes)
    other_feats = feats[other_indexes]
    other_labels = labels[other_indexes]

    left = []
    right = []
    for i in range(batch_size):
        label = sample_labels[i]
        if i < nb_positive:
            # Positive pair: same class.
            candidates = np.where(other_labels == label)[0]
        else:
            # Negative pair: another class of the same group.
            if label in machine:
                group = machine
            elif label in animal:
                group = animal
            else:
                raise ValueError('label %r belongs to neither group' % (label,))
            negative_classes = [c for c in group if c != label]
            candidates = np.where(np.isin(other_labels, negative_classes))[0]
        if len(candidates) == 0:
            raise ValueError(
                'no candidate image left to pair with sample %d (label %r); '
                'try a smaller batch_size' % (i, label))
        index = candidates[np.random.randint(len(candidates))]
        other_image = other_feats[index]
        if plt_fig:
            # Integer position spec: string specs such as "121" are rejected
            # by recent Matplotlib. Left/right match the returned order.
            plt.figure()
            plt.subplot(1, 2, 1)
            plt.imshow(sample_images[i])
            plt.subplot(1, 2, 2)
            plt.imshow(other_image)
        left.append(sample_images[i])
        right.append(other_image)

    #Generate pair labels
    y = np.zeros(batch_size)
    y[:nb_positive] = 1

    return [np.array(left), np.array(right)], y

In [None]:
# Create toy dataset (to be improved)
# Pairs are generated independently from each batch file listed in
# `cifar_batch_id`, then concatenated into one training set.
data_path = '../data/cifar10/'
dataset = [[], []]
dataset_labels = []
cifar_batch_id = [1, 2, 3, 4 , 5]
# NOTE(review): this constant is not used below; the call requests 20000
# pairs per batch. Confirm which value is intended and wire it in.
nb_pairs_per_batch = 10000
for batch in cifar_batch_id:
    # Pass the loop variable as the batch id so every batch file is used
    # (a hard-coded 1 would sample the same file on every iteration).
    data, y = get_pairs(data_path, 0.5, batch, batch_size=20000)
    dataset[0].append(data[0])
    dataset[1].append(data[1])
    dataset_labels.append(y)
# Merge the per-batch chunks: dataset = [left_images, right_images].
dataset[0] = np.concatenate(dataset[0])
dataset[1] = np.concatenate(dataset[1])
dataset_labels = np.concatenate(dataset_labels)

### Siamese Networks training

In [None]:
# Train on the [left, right] image pairs.
# The keyword is `epochs` (plural), and `metrics` is a compile()-time
# argument that fit() does not accept, so it is not passed here.
siamese_net.fit(dataset, dataset_labels, epochs=10, batch_size=32)