In [1]:
import pickle

In [2]:
import numpy as np
from sklearn.utils import shuffle

In [3]:
TRAIN_NUM = 1150
VALIDATION = 50

def test_val_train_split(data_path, seed=None):
    X_all = pickle.load(open(data_path, 'rb'))
    
    if seed:
        np.random.seed(seed)

    indices = np.random.permutation(len(X_all))
    
    return { 
        "train" : X_all[indices[:TRAIN_NUM]],
        "val": X_all[indices[TRAIN_NUM:TRAIN_NUM+VALIDATION]],
        "test": X_all[indices[TRAIN_NUM+VALIDATION:]]
    }

#data = test_val_train_split('/data/X_all.pickle', 4)

In [4]:
#X = data['train']

In [5]:
# X.shape[0]

In [6]:
class DataFeeder:
    """
    For loading batches and testing tasks to a siamese net.

    Every split in ``self.data`` is unpacked as
    ``(n_classes, n_examples, w, h) = X.shape`` and indexed as ``X[class, example]``.
    Pair targets use ``+1`` = same class and ``-1`` = different class.
    All sampling draws from NumPy's global RNG.
    """

    def __init__(self, data_path, seed):
        """Load the pickled data at ``data_path`` and split it via ``test_val_train_split(data_path, seed)``."""
        self.data = test_val_train_split(data_path, seed)

    def get_batch(self, n, type_="train"):
        """
        Create batch of n pairs, half different class, half same class.

        Parameters
        ----------
        n : int
            Number of pairs. Must not exceed the number of classes in the split,
            because the first image of each pair comes from a distinct class.
        type_ : str
            Split to draw from ("train", "val" or "test").

        Returns
        -------
        pairs : list of two np.ndarray, each of shape (n, w, h, 1)
        targets : np.ndarray of shape (n,)
            -1 for the first n//2 pairs (different classes), +1 for the rest
            (same class).
        """
        X = self.data[type_]
        n_classes, n_examples, w, h = X.shape

        # n distinct classes for the first image of each pair
        categories = np.random.choice(n_classes, size=(n,), replace=False)
        # Buffers are (w, h, 1) per image to match X[c, e].reshape(w, h, 1) below.
        pairs = [np.zeros((n, w, h, 1)) for _ in range(2)]
        n_negative = n // 2
        targets = np.ones((n,))
        targets[:n_negative] = -1

        for i in range(n):
            category_1 = categories[i]
            idx_1 = np.random.randint(0, n_examples)
            pairs[0][i, :, :, :] = X[category_1, idx_1].reshape(w, h, 1)

            # Must agree with `targets`: different class for the first n//2 pairs,
            # same class afterwards. An offset in [1, n_classes) never maps back
            # onto category_1.
            if i < n_negative:
                category_2 = (category_1 + np.random.randint(1, n_classes)) % n_classes
            else:
                category_2 = category_1
            idx_2 = np.random.randint(0, n_examples)
            pairs[1][i, :, :, :] = X[category_2, idx_2].reshape(w, h, 1)

        return pairs, targets

    def get_1_shot_batch(self, N, type_="val", category=None):
        """Create a batch of (test image, support image) pairs for N-way one-shot testing.

        Parameters
        ----------
        N : int
            Number of ways (distinct classes in the support set); must not
            exceed the number of classes in the split.
        type_ : str
            Split to draw from.
        category : int, optional
            Class of the test image, in ``range(0, n_classes)``; random if None.

        Returns
        -------
        pairs : [test_image, support_set]
            Both of shape (N, w, h, 1); ``test_image`` repeats one image N times.
        targets : np.ndarray of shape (N,)
            +1 at the (shuffled) position of the true-class support image, -1 elsewhere.
        true_category : int
            Class of the test image.

        Raises
        ------
        ValueError
            If ``category`` is outside ``range(0, n_classes)``.
        """
        X = self.data[type_]
        n_classes, n_examples, w, h = X.shape

        if category is None:
            # N distinct classes; the first one is the true class.
            categories = np.random.choice(range(n_classes), size=(N,), replace=False)
            true_category = categories[0]
        else:
            if not 0 <= category < n_classes:
                raise ValueError("The category should be in range(0, %d)" % n_classes)
            true_category = category
            # Distractors come only from the *other* classes, so the true class
            # appears exactly once in the support set.
            others = np.delete(np.arange(n_classes), true_category)
            distractors = np.random.choice(others, size=(N - 1,), replace=False)
            categories = np.concatenate(([true_category], distractors))

        # One random example index per support-set class
        indices = np.random.randint(0, n_examples, size=(N,))

        # Two distinct examples of the true class: one for the test image and one
        # for its matching support image (needs n_examples >= 2).
        fst_img_indx, snd_img_indx = np.random.choice(n_examples, replace=False, size=(2,))

        # The same test image repeated once per pair
        test_image = np.repeat(X[true_category, fst_img_indx][np.newaxis], N, axis=0).reshape(N, w, h, 1)

        # Support set (a copy, via fancy indexing); slot 0 holds the true-class image.
        support_set = X[categories, indices, :, :]
        support_set[0, :, :] = X[true_category, snd_img_indx]
        support_set = support_set.reshape(N, w, h, 1)

        targets = -1 * np.ones((N,))
        targets[0] = 1  # slot 0 is the true match

        # Shuffle support images and targets together so the match is not always first.
        support_set, targets = shuffle(support_set, targets)

        pairs = [test_image, support_set]
        return pairs, targets, true_category

    def test_oneshot(self, model, m, N, type_="test", verbose=0, category=None):
        """
        Test average N way one-shot learning accuracy of a siamese neural net over m one-shot tasks.

        ``model.predict(pairs)`` must return one score per pair. A task counts as
        correct when the highest score falls on the true-class pair.

        Returns
        -------
        float
            Percentage of correctly solved tasks.

        Raises
        ------
        ValueError
            If ``m`` is not positive.
        """
        if m <= 0:
            raise ValueError("m must be a positive number of tasks")
        n_correct = 0
        if verbose:
            print("Evaluating model on {} unique {} way one-shot learning tasks ...".format(m, N))
        for _ in range(m):
            pairs, targets, _ = self.get_1_shot_batch(N, type_, category=category)
            probs = model.predict(pairs)
            if np.argmax(probs) == np.argmax(targets):
                n_correct += 1
        percent_correct = 100.0 * n_correct / m
        if verbose:
            print("Got an average of {}% {} way one-shot learning accuracy".format(percent_correct, N))
        return percent_correct

In [7]:
feeder = DataFeeder(data_path='/data/X_all.pickle', seed=4)  # load + split the pickled dataset

In [8]:
class RandomModel:
    """Chance-level baseline with the same ``predict`` interface as a siamese model.

    Picks one pair uniformly at random per call, so under ``test_oneshot`` it
    is expected to score about 100/N percent on N-way tasks.
    """

    def predict(self, pairs):
        """Return one-hot scores of shape (n, 1) with a single 1 at a random row.

        ``pairs`` is a 2-sequence ``(left, right)``. Only ``len(left)`` is used,
        so ``left`` may have any rank as long as the batch is along axis 0.
        An empty batch yields an empty (0, 1) array.
        """
        left, _ = pairs
        n = len(left)
        probs = np.zeros((n, 1))
        if n:
            probs[np.random.randint(0, n)] = 1
        return probs

In [9]:
pairs, targets = feeder.get_batch(n=10, type_="train")  # 10 training pairs

In [10]:
(left, right), targets, category = feeder.get_1_shot_batch(N=10, type_="val")  # one 10-way task

In [11]:
model = RandomModel()  # chance-level baseline
model.predict(pairs=(left, right))

array([[ 0.],
       [ 0.],
       [ 0.],
       [ 0.],
       [ 0.],
       [ 0.],
       [ 0.],
       [ 0.],
       [ 0.],
       [ 1.]])

In [12]:
feeder.test_oneshot(model, m=100, N=10, type_="test", verbose=1);  # 100 random 10-way tasks

Evaluating model on 100 unique 10 way one-shot learning tasks ...
Got an average of 12.0% 10 way one-shot learning accuracy
