# Creating PK batches for triplet loss training
The concept of triplet loss training is best described in [this blog post by Olivier Moindrot](https://omoindrot.github.io/triplet-loss), and [this TensorFlow tutorial](https://www.tensorflow.org/addons/tutorials/losses_triplet) shows how to train a model with it.

This tutorial is based on [In Defense of the Triplet Loss for Person Re-Identification](https://arxiv.org/pdf/1703.07737.pdf), a paper which shows how to avoid some of the difficulties of triplet loss training. In it, the authors build each batch from P classes with K examples from each class. Every batch then contains many valid triplets (anchor and positive from the same class, negative from another), which helps to maximise training speed.
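For intuition, here is a minimal NumPy sketch of PK sampling when all labels fit in memory. `sample_pk_batch` is a hypothetical helper written for illustration, not code from the paper, and it assumes every class has at least K examples. The rest of this post builds the same batch structure with `tf.data` from TFRecord files.

```python
import numpy as np

def sample_pk_batch(labels, p, k, rng):
    """Return P*K indices into `labels`: K random examples from each of P random classes.

    Hypothetical helper; assumes every class has at least K examples.
    """
    labels = np.asarray(labels)
    classes = np.unique(labels)
    chosen = rng.choice(classes, size=p, replace=False)      # P distinct classes
    batch = []
    for c in chosen:
        idx = np.flatnonzero(labels == c)                     # positions of class c
        batch.extend(rng.choice(idx, size=k, replace=False))  # K distinct examples of c
    return np.array(batch)

rng = np.random.default_rng(0)
labels = np.repeat(np.arange(10), 50)    # 10 classes x 50 examples, like the dummy data below
batch = sample_pk_batch(labels, p=5, k=5, rng=rng)
print(labels[batch].reshape(5, 5))       # each row is one class repeated K times
```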

What I show here is how to read TFRecords into a TensorFlow dataset whose classes and samples are arranged in PK batches, so you can optimise your triplet loss training.

## Creating some dummy data

To start, we need a TFRecord file for each class containing (sample, label) examples. I assume one file per class because that is what I had. If all your data is in a single file, the principle for creating PK batches is similar. I'll generate dummy TFRecords here.

```python
import os
import tensorflow as tf
import numpy as np
from tensorflow.python.lib.io import file_io
import json
```

```python
NUM_CLASSES = 10
NUM_SAMPLES = 50
DATA_FOLDER = "dummy_data"
FILE_STEM = "train_id"
```

```python
try:
    os.mkdir(DATA_FOLDER)
except FileExistsError:
    print("Directory already exists")
```

```python
# Some helper functions to make our dummy data
def _floats_matrix_feature(value):
    # Flatten the array, since a FloatList only holds a 1-D list of floats
    return tf.train.Feature(float_list=tf.train.FloatList(value=value.reshape(-1)))

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def create_example(label_idx):
    # In my case these were spectrograms of varying size; images or any other array work just as well
    array = np.random.rand(2, 2)
    height = array.shape[0]
    width = array.shape[1]
    example = tf.train.Example(features=tf.train.Features(feature={
        'height': _int64_feature(height),
        'width': _int64_feature(width),
        'label': _int64_feature(label_idx),
        'array': _floats_matrix_feature(array)}))
    return example
```

When processing data and creating datasets, I usually save information about each class and about the data as a whole. For example, I save the number of examples per class, the mean and the standard deviation. This helps me understand the data I'm training the model with. (Saving the mean and standard deviation of the dataset is also handy because you can easily normalise the inputs.)

In this case I only save the number of examples per class to `stats_dict`. This is used to calculate the multiple of K samples to take for each epoch. I don't actually need it here because we just created the data, but I include it in case the concept is useful to others.
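If you also want the per-class mean and standard deviation mentioned above, a sketch along these lines should work. It assumes the arrays for each class fit in memory. `compute_class_stats` is a hypothetical helper, and its nested output differs from the flat, count-only `stats_dict` used in the rest of this post.

```python
import json
import numpy as np

def compute_class_stats(arrays_by_class):
    """Per-class example count, mean and standard deviation (hypothetical helper).

    arrays_by_class maps a key such as 'id_0' to a list of NumPy arrays for that class.
    """
    stats = {}
    for key, arrays in arrays_by_class.items():
        values = np.concatenate([a.reshape(-1) for a in arrays])  # pool every value in the class
        stats[key] = {
            'count': len(arrays),
            'mean': float(values.mean()),  # cast to float so json.dump accepts it
            'std': float(values.std()),
        }
    return stats

rng = np.random.default_rng(0)
arrays_by_class = {'id_{}'.format(c): [rng.random((2, 2)) for _ in range(50)] for c in range(3)}
stats = compute_class_stats(arrays_by_class)
print(json.dumps(stats['id_0']))
```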

```python
stats_dict = {}
for label in range(NUM_CLASSES):
    filename = '{0}/{1}{2}.tfrecords'.format(DATA_FOLDER, FILE_STEM, label)
    record_writer = tf.io.TFRecordWriter(filename)
    for sample in range(NUM_SAMPLES):
        example = create_example(label)
        record_writer.write(example.SerializeToString())
    record_writer.close()
    stats_dict['id_{0}'.format(label)] = NUM_SAMPLES

with open('{0}/stats_dict.json'.format(DATA_FOLDER), 'w') as json_file:
    json.dump(stats_dict, json_file)
```



### Side Note

You can't look up values in a Python dictionary using TensorFlow string tensors, for example inside a `Dataset.map` function where the tensors are symbolic. To use `stats_dict` in TensorFlow code, we would need to convert it to a TensorFlow hash table. The function below does that.

This would be useful if we needed to compute values from the data statistics at execution time.

```python
def set_PID_hashtable():
    """Sets a global TensorFlow hash table mapping each label key to its number of examples

    The hash table is used to determine if the number of examples of a
    class needs to be restricted so that the dataset is balanced.
    Why a hash table instead of a dictionary? The TensorFlow graph is built
    before it runs, so lookups use symbolic string tensors, which a Python
    dict cannot be indexed with.
    """
    keys = list(stats_dict.keys())
    values = list(stats_dict.values())
    global PID_INFO
    PID_INFO = tf.lookup.StaticHashTable(
        tf.lookup.KeyValueTensorInitializer(
            keys, values, key_dtype=None,
            value_dtype=None, name=None),
        default_value=-1)  # returned for keys that are not in the table
    print('Created hash table: {}'.format(PID_INFO))
```

# Creating PK Batches
Using this fake data, we can create a training dataset organised into PK batches. We will set up the dataset for a **TensorFlow Estimator** model. Estimators have perks for quick experiments. For example, they can (most of the time) easily be deployed to train in a distributed manner on Google Cloud ML Engine.



To train an estimator, we call `tf.estimator.train_and_evaluate(an_estimator_model, train_spec, eval_spec)`, which needs three things: the estimator, a `TrainSpec` and an `EvalSpec`.

Both the [TrainSpec](https://www.tensorflow.org/api_docs/python/tf/estimator/TrainSpec) and [EvalSpec](https://www.tensorflow.org/api_docs/python/tf/estimator/EvalSpec) need an input function, which provides the data used in training or evaluation. This is where we need to set up the PK batching. We will be creative and call this function `pk_input_fn`.

`pk_input_fn` chains several steps, and each step needs helper functions. We'll go through these now before getting to the grand finale.

## Create a dataset mapping each example to its k batch

The first thing we do is create a second dataset that acts as a label to organise the dataset into batches. Say one class has 10 samples `[x_1, x_2, ..., x_10]` and we want to create batches with K=3. We only keep a multiple of K examples, so the second dataset labels the first 9 samples with the k batch each belongs to: `[k_1, k_1, k_1, k_2, k_2, k_2, k_3, k_3, k_3]`. There is no key for `x_10`. We later join the two datasets, `[(x_1, k_1), ..., (x_9, k_3)]`, and the `k` labels let us shuffle the dataset and create PK batches.
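The same idea for a single class can be sketched in plain Python. Keys here start at 0 rather than `k_1`. The integer division `i // K` produces the same keys as the counting loop in `create_interleave_keys` further down. Python's `zip` stops at the shorter input, so `x_10` gets no key.

```python
K = 3
samples = ['x_{}'.format(i) for i in range(1, 11)]  # x_1 ... x_10

# Keep only a multiple of K examples, then label example i with k batch i // K
num_kept = len(samples) // K * K                     # 9
k_keys = [i // K for i in range(num_kept)]           # [0, 0, 0, 1, 1, 1, 2, 2, 2]

# zip stops at the shorter list, so x_10 is left without a key
pairs = list(zip(samples, k_keys))
print(pairs)
```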

`PK_CEILING` is a variable I use to keep the PK batches balanced. It's the largest multiple of K that is less than or equal to the number of examples in the smallest class. Imagine we have 1000 examples of id0, 50 of id1, and K=4. Then `PK_CEILING` = 48.

In production it would be better to calculate this dynamically based on the data and the variables set.
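A minimal sketch of computing it from `stats_dict` instead of hard-coding it. `compute_pk_ceiling` is a hypothetical helper, and it assumes `stats_dict` maps each class key to its example count, as in the cell above.

```python
def compute_pk_ceiling(stats_dict, batch_k):
    """Largest multiple of batch_k that is <= the size of the smallest class (hypothetical helper)."""
    smallest = min(stats_dict.values())
    return smallest // batch_k * batch_k

print(compute_pk_ceiling({'id_0': 1000, 'id_1': 50}, batch_k=4))                  # 48, the example above
print(compute_pk_ceiling({'id_{}'.format(i): 50 for i in range(10)}, batch_k=5))  # 50, our dummy data
```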

```python
BATCH_P = 5  # The number of classes in each batch
BATCH_K = 5  # The number of samples per class in each batch
PK_CEILING = 50  # Largest multiple of K <= smallest class size (50 // 5 * 5)
```

```python
def create_interleave_dataset(demo=False):
    keys_list = [create_interleave_keys(i, demo) for i in range(NUM_CLASSES)]
    keys_datasets = [tf.data.Dataset.from_tensor_slices(x) for x in keys_list]
    interleave_keys = keys_datasets[0]
    for i in range(1, len(keys_datasets)):
        interleave_keys = interleave_keys.concatenate(keys_datasets[i])
    return interleave_keys

def create_interleave_keys(pid, demo=False):
    """
    pid: the shorthand for "person identity" in the paper (i.e. the label)
    """

    # Ensure the number of examples from this person is a multiple of K
    num_elements = stats_dict['id_{}'.format(pid)]
    max_elements = np.floor_divide(num_elements, BATCH_K) * BATCH_K
    print('max elements: {0}'.format(max_elements))
    # Label BATCH_K consecutive examples with the same key k, then move on to k + 1
    k = 0
    k_count = 0
    k_keys = []
    for i in range(max_elements):
        if k_count < BATCH_K:
            k_keys.append(k)
            k_count += 1
        else:
            k += 1
            k_keys.append(k)
            k_count = 1
    if demo:
        print('{} keys for id_{}:'.format(len(k_keys), pid))
        print(k_keys)
        print()
    return k_keys
```

```python
x = create_interleave_keys(0, True)
print(len(x))
```

```
max elements: 50
50 keys for id_0:
[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9]

50
```


## Parse the TFRecord files
When we read a TFRecord file, each element is a serialized TensorFlow `Example`. We need to convert these into arrays and labels to train the model. We do this by mapping a parse function over the dataset.

```python
def tfr_input_parser(example_proto):
    """Parser to convert a serialized tf.train.Example into (array, label)

    Args:
        example_proto: A single serialized example from the TFRecord file

    Returns:
        array: here dummy data, but could be an image, spectrogram, etc.
        int32: class label
    """
    features = {
        # height, width and label are scalars, so FixedLenFeature([]) works.
        # The flattened array can vary in length, so it uses VarLenFeature.
        'height': tf.io.FixedLenFeature([], tf.int64),
        'width': tf.io.FixedLenFeature([], tf.int64),
        'label': tf.io.FixedLenFeature([], tf.int64),
        'array': tf.io.VarLenFeature(tf.float32)}
    parsed_features = tf.io.parse_single_example(example_proto, features=features)
    height = tf.cast(parsed_features['height'], tf.int32)
    width = tf.cast(parsed_features['width'], tf.int32)
    label = tf.cast(parsed_features['label'], tf.int32)
    # VarLenFeature gives a SparseTensor: densify it, then restore the 2-D shape
    array = tf.sparse.to_dense(parsed_features['array'], default_value=0)
    array = tf.reshape(array, [height, width])
    array = tf.cast(array, tf.float32)
    return array, label
```
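The parser needs `height` and `width` because the `FloatList` stores the array flattened. The same round trip in plain NumPy looks like this. It is only an illustration: the real pipeline also converts the values to float32, which NumPy does not do here.

```python
import numpy as np

array = np.random.default_rng(0).random((3, 4))  # any 2-D array, e.g. a spectrogram
flat = array.reshape(-1)                          # what _floats_matrix_feature stores
height, width = array.shape                       # stored alongside as int64 features

restored = flat.reshape(height, width)            # what tf.reshape does in tfr_input_parser
print(np.array_equal(restored, array))            # True
```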

## Functions to shuffle our PK batches

```python
def shuffle_class_examples(key, elements, shuffle_size):
    """Shuffle so the order of class examples is different each epoch"""
    elements = elements.shuffle(shuffle_size)
    # Ensure we have the same number of examples per class for balanced datasets
    elements = elements.take(PK_CEILING)
    return elements

def shuffle_k_group(key, elements, shuffle_size, batch_k, demo=False):
    """Shuffle so each PK batch contains different classes"""
    print('Shuffle k group')  # printed once, when TensorFlow traces this function
    elements = elements.batch(batch_k)
    if not demo:
        elements = elements.shuffle(shuffle_size)
    return elements
```
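To see the ordering these two functions aim for, here is a pure-Python model of the grouping. It is a sketch of the idea, not the `tf.data` code. `simulate_pk_order` is a hypothetical helper, and Python's `random` stands in for the dataset shuffles. In this model each PK batch has P distinct classes only because P divides the number of classes. Otherwise, a batch can straddle two k keys and may contain the same class twice. The same caveat appears to apply to `pk_input_fn` below, since it batches the K groups in order.

```python
import random

def simulate_pk_order(num_classes, num_samples, batch_p, batch_k, seed=0):
    """Pure-Python model of the PK ordering (hypothetical helper, not the tf.data code).

    Returns a list of PK batches; each batch is a list of P groups of K (label, index) pairs.
    """
    rng = random.Random(seed)
    per_key = {}  # k -> list of K groups, one per class
    for label in range(num_classes):
        examples = [(label, i) for i in range(num_samples)]
        rng.shuffle(examples)                       # like shuffle_class_examples
        kept = num_samples // batch_k * batch_k     # keep a multiple of K
        for start in range(0, kept, batch_k):
            k = start // batch_k                    # the interleave key
            per_key.setdefault(k, []).append(examples[start:start + batch_k])
    groups = []
    for k in sorted(per_key):
        window = per_key[k]                         # one K group from every class
        rng.shuffle(window)                         # like shuffle_k_group
        groups.extend(window)
    # Like batch(batch_p): consecutive K groups form one PK batch
    return [groups[i:i + batch_p] for i in range(0, len(groups), batch_p)]

batches = simulate_pk_order(num_classes=10, num_samples=50, batch_p=5, batch_k=5)
print(len(batches))                                               # 20 PK batches per epoch
print([[label for label, _ in group] for group in batches[0]])    # 5 classes x 5 examples
```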

## Extra helper functions

```python
def create_dict(array, label):
    """Make the input to our estimator model a dictionary"""
    return dict(features=array), label

def remove_selectors(array_and_label, selector):
    """Remove the k labels created by create_interleave_dataset()"""
    return array_and_label[0], array_and_label[1]
```

## PK Input function to bring everything together 

```python
def pk_input_fn(batch_p, batch_k, num_classes, num_epochs=1,
                num_parallel_calls=1, demo=False):
    """Builds a TensorFlow dataset in PK batches from N TFRecord files, each containing examples of a single class

    Implements PK batching from https://arxiv.org/pdf/1703.07737.pdf
    Each batch contains P classes with K examples from each.
    Shuffles so each batch contains different classes and examples,
    and each epoch has different PK batches.

    A note about window size for the group_by_window functions
    window_size:
        A tf.int64 scalar tf.Tensor, representing the
        number of consecutive elements matching the same key
        to combine in a single batch, which will be passed
        to reduce_func.
    We choose each window_size as follows.
    We need to shuffle all the data of each class, so
        class_shuffle window_size: PK_CEILING
    Each k key groups one K batch from every class, hence
        k_shuffle window_size: num_classes * batch_k

    Args:
        batch_p (int): number of classes per batch
        batch_k (int): number of examples per class per batch
        num_classes (int): total number of classes
        num_epochs (int): number of times to repeat the dataset
        num_parallel_calls (int): How many parallelised calls to the dataset map function.
            TF recommends using the number of available CPU cores for its value.
        demo (bool): if True, skip the k group shuffle and keep the interleave keys

    Returns:
        Dataset: TensorFlow dataset where each batch contains P classes and K examples from each
    """
    # Set up the hash table to find the number of examples for a class
    set_PID_hashtable()

    print('Fetch all data')
    filenames = ['{0}/{1}{2}.tfrecords'.format(DATA_FOLDER, FILE_STEM, i) for i in range(num_classes)]
    dataset = tf.data.TFRecordDataset(filenames)
    print(dataset)
    interleave_keys = create_interleave_dataset()
    print('Process examples')
    dataset = dataset.map(tfr_input_parser,
                          num_parallel_calls=num_parallel_calls)
    print('Cache')
    dataset = dataset.cache()

    print('Put features into dictionary')
    # This was required for TF1, perhaps no longer needed in TF2
    dataset = dataset.map(create_dict,
                          num_parallel_calls=num_parallel_calls)
    print('Class shuffle')
    # Group by class label and shuffle the examples of each class
    class_shuffle = tf.data.experimental.group_by_window(
        lambda x, label: tf.cast(label, tf.int64),
        lambda key, x: shuffle_class_examples(key, x,
                                              shuffle_size=1000),
        window_size=PK_CEILING
    )
    dataset = dataset.apply(class_shuffle)

    print('Interleave and shuffle')
    # Group one K batch from each class under the same k key, then shuffle the K batches
    interleave_window = num_classes * batch_k
    key_dataset = tf.data.Dataset.zip((dataset, interleave_keys))
    k_shuffle = tf.data.experimental.group_by_window(
        lambda example_and_label, k: tf.cast(k, tf.int64),
        lambda key, example_and_label: shuffle_k_group(key,
                                                       example_and_label,
                                                       shuffle_size=interleave_window,
                                                       batch_k=batch_k,
                                                       demo=demo),
        window_size=interleave_window
    )
    key_dataset = key_dataset.apply(k_shuffle)
    print('Unbatch')
    # Demo only: unbatch so each element is ((features, label), k) and the keys stay visible
    dataset = key_dataset.unbatch()
    print('Remove selectors')
    if not demo:
        # Start again from the K batches (not the unbatched dataset) and drop the k keys,
        # so each element is (features[K], label[K]) and batch(batch_p) below gives [P, K]
        dataset = key_dataset.map(remove_selectors)

    # Create PK batches
    print('repeat')
    dataset = dataset.repeat(num_epochs)
    print('batch')
    dataset = dataset.batch(batch_p)
    dataset = dataset.prefetch(batch_p)
    return dataset
```



# Create PK datasets

```python
NUM_EPOCHS = 5
```

# Inspect the datasets we created

### Without shuffling the K batches, to see what's going on

For the demo, we do not shuffle the K batches within each k group, so the ordering of the IDs is the same each time. We also keep the interleave key so we can see how the PK batches were grouped together. Because the demo unbatches, each printed element is a batch of P individual examples. With P = K = 5, that happens to line up with one K group.

```python
demo_dataset = pk_input_fn(BATCH_P, BATCH_K, NUM_CLASSES,
                           num_epochs=NUM_EPOCHS, num_parallel_calls=1,
                           demo=True)
demo_dataset = demo_dataset.enumerate()
print('-'*10)
c = 0
for element in demo_dataset.as_numpy_iterator():
    c += 1
    # element is (index, ((features, label), k)): print the labels and the interleave keys
    print('Element:{}    ID:{}    Interleave key: {}'.format(c, element[1][0][1], element[1][1]))
```



```
Created hash table: <tensorflow.python.ops.lookup_ops.StaticHashTable object at 0x7f715c1ff710>
Fetch all data
<TFRecordDatasetV2 shapes: (), types: tf.string>
max elements: 50
max elements: 50
max elements: 50
max elements: 50
max elements: 50
max elements: 50
max elements: 50
max elements: 50
max elements: 50
max elements: 50
Process examples
Cache
Put features into dictionary
Class shuffle
Interleave and shuffle
Shuffle k group
Unbatch
Remove selectors
repeat
batch
----------
Element:1    ID:[0 0 0 0 0]    Interleave key: [0 0 0 0 0]
Element:2    ID:[1 1 1 1 1]    Interleave key: [0 0 0 0 0]
Element:3    ID:[2 2 2 2 2]    Interleave key: [0 0 0 0 0]
Element:4    ID:[3 3 3 3 3]    Interleave key: [0 0 0 0 0]
Element:5    ID:[4 4 4 4 4]    Interleave key: [0 0 0 0 0]
Element:6    ID:[5 5 5 5 5]    Interleave key: [0 0 0 0 0]
Element:7    ID:[6 6 6 6 6]    Interleave key: [0 0 0 0 0]
Element:8    ID:[7 7 7 7 7]    Interleave key: [0 0 0 0 0]
Element:9    ID:[8 8 8 8 8]    Interleave ke

Element:268    ID:[7 7 7 7 7]    Interleave key: [6 6 6 6 6]
Element:269    ID:[8 8 8 8 8]    Interleave key: [6 6 6 6 6]
Element:270    ID:[9 9 9 9 9]    Interleave key: [6 6 6 6 6]
Element:271    ID:[0 0 0 0 0]    Interleave key: [7 7 7 7 7]
Element:272    ID:[1 1 1 1 1]    Interleave key: [7 7 7 7 7]
Element:273    ID:[2 2 2 2 2]    Interleave key: [7 7 7 7 7]
Element:274    ID:[3 3 3 3 3]    Interleave key: [7 7 7 7 7]
Element:275    ID:[4 4 4 4 4]    Interleave key: [7 7 7 7 7]
Element:276    ID:[5 5 5 5 5]    Interleave key: [7 7 7 7 7]
Element:277    ID:[6 6 6 6 6]    Interleave key: [7 7 7 7 7]
Element:278    ID:[7 7 7 7 7]    Interleave key: [7 7 7 7 7]
Element:279    ID:[8 8 8 8 8]    Interleave key: [7 7 7 7 7]
Element:280    ID:[9 9 9 9 9]    Interleave key: [7 7 7 7 7]
Element:281    ID:[0 0 0 0 0]    Interleave key: [8 8 8 8 8]
Element:282    ID:[1 1 1 1 1]    Interleave key: [8 8 8 8 8]
Element:283    ID:[2 2 2 2 2]    Interleave key: [8 8 8 8 8]
Element:284    ID:[3 3 3
```

### The real thing with all shuffling

For the real thing, we keep the K batches together and perform all of the shuffling.

```python
real_dataset = pk_input_fn(BATCH_P, BATCH_K, NUM_CLASSES,
                           num_epochs=NUM_EPOCHS, num_parallel_calls=1)
print('-'*10)
c = 0
real_dataset = real_dataset.enumerate()
for element in real_dataset.as_numpy_iterator():
    c += 1
    print('PK batch {}'.format(c))
    # This prints the labels of the examples in our batch
    print(element[1][1])
    print()
```

```
Created hash table: <tensorflow.python.ops.lookup_ops.StaticHashTable object at 0x7f715c634610>
Fetch all data
<TFRecordDatasetV2 shapes: (), types: tf.string>
max elements: 50
max elements: 50
max elements: 50
max elements: 50
max elements: 50
max elements: 50
max elements: 50
max elements: 50
max elements: 50
max elements: 50
Process examples
Cache
Put features into dictionary
Class shuffle
Interleave and shuffle
Shuffle k group
Unbatch
Remove selectors
repeat
batch
----------
PK batch 1
[[2 2 2 2 2]
 [7 7 7 7 7]
 [8 8 8 8 8]
 [6 6 6 6 6]
 [5 5 5 5 5]]

PK batch 2
[[1 1 1 1 1]
 [4 4 4 4 4]
 [0 0 0 0 0]
 [9 9 9 9 9]
 [3 3 3 3 3]]

PK batch 3
[[8 8 8 8 8]
 [7 7 7 7 7]
 [6 6 6 6 6]
 [5 5 5 5 5]
 [9 9 9 9 9]]

PK batch 4
[[4 4 4 4 4]
 [0 0 0 0 0]
 [3 3 3 3 3]
 [2 2 2 2 2]
 [1 1 1 1 1]]

PK batch 5
[[8 8 8 8 8]
 [0 0 0 0 0]
 [5 5 5 5 5]
 [4 4 4 4 4]
 [2 2 2 2 2]]

PK batch 6
[[3 3 3 3 3]
 [7 7 7 7 7]
 [9 9 9 9 9]
 [6 6 6 6 6]
 [1 1 1 1 1]]

PK batch 7
[[1 1 1 1 1]
 [7 7 7 7 7]
 [0 0 0 0 0
```
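Each element of `real_dataset` holds labels shaped [P, K], as the printout shows, and features shaped [P, K, height, width] under the `features` key. Some triplet losses expect a flat batch. As far as I know, TF Addons' `TripletSemiHardLoss`, used in the tutorial linked at the top, takes 1-D integer labels of shape [batch_size]. In that case, the batch needs reshaping to P*K. Here is a NumPy sketch of that reshape on stand-in arrays. In the pipeline, the equivalent would presumably be a `tf.reshape` inside a `dataset.map`, which I have not tested here.

```python
import numpy as np

# Stand-in for one element of real_dataset: features [P, K, H, W] and labels [P, K]
P, K, H, W = 5, 5, 2, 2
features = np.random.default_rng(0).random((P, K, H, W)).astype(np.float32)
labels = np.repeat(np.array([2, 7, 8, 6, 5]), K).reshape(P, K)  # like PK batch 1 above

flat_features = features.reshape(P * K, H, W)  # [P*K, H, W]; row-major, so class groups stay contiguous
flat_labels = labels.reshape(P * K)            # [P*K]
print(flat_labels[:10])                        # [2 2 2 2 2 7 7 7 7 7]
```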

# Conclusion



That's the implementation of the dataset described in [In Defense of the Triplet Loss for Person Re-Identification](https://arxiv.org/pdf/1703.07737.pdf). It's a simple concept in the paper, but setting it up in TensorFlow isn't as straightforward. Still, the core ideas here should let you implement an input pipeline for your model to optimise triplet loss training.