In [2]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T

import os
import random
from PIL import Image

random.seed(0) # seeds Python's built-in `random` module for reproducibility (this call does not seed torch's RNG)

We want to implement the `Batch Hard` training strategy from the "In Defense of the Triplet Loss for Person Re-Identification" paper (2017). Given the features from the backbone of shape `NxD`, where `N` is the number of examples in the batch and `D` is the embedding dimension, we wish to compute valid triples `(A,P,N)` where `A` is an anchor, `(A,P)` is a positive pair whose members share the same class, and `(A,N)` is a negative pair whose members belong to different classes. For a given anchor `A`, the batch hard strategy keeps only the hardest triple: the positive for which `d(A,P)` is as large as possible and the negative for which `d(A,N)` is as small as possible.

In [16]:
def compute_distance_matrix(features, distance_fn="L2", eps=1e-12):
    """Return the NxN matrix of pairwise Euclidean (L2) distances between the rows of `features`.

    Args:
        features: 2-D floating point tensor of shape NxD, one embedding per row.
        distance_fn: Kept for interface compatibility. Only the L2 distance is
            implemented, so this argument is currently ignored.
        eps: Small constant substituted for exactly-zero squared distances before
            the square root is taken, so that the gradient of sqrt stays finite.

    Returns:
        NxN tensor `dist` with dist[i, j] = ||f_i - f_j||_2. Entries whose squared
        distance is exactly zero (e.g. the diagonal) are returned as exactly 0.

    Raises:
        ValueError: if `features` is not a 2-D tensor.

    No tensor is modified in place, so the result can be backpropagated through.
    (In-place edits of the outputs of relu/sqrt invalidate the values autograd saved
    for their backward pass and make `backward()` raise.)
    """
    if features.dim() != 2:
        raise ValueError(
            f"features must be a 2-D (NxD) tensor, got {features.dim()} dimension(s)"
        )

    # To compute the distance matrix we use |a-b|^2 = |a|^2 - 2 a.b + |b|^2
    gram = features @ features.T  # NxN, gram[i, j] = f_i . f_j
    sq_norms = torch.diagonal(gram)  # length N, f_i . f_i
    squared_dist = sq_norms.unsqueeze(1) - 2 * gram + sq_norms.unsqueeze(0)

    # Floating point error can make tiny squared distances slightly negative;
    # clamp them to zero before taking the square root.
    squared_dist = F.relu(squared_dist)

    # The derivative of sqrt(x) is 1 / (2 sqrt(x)), which is infinite at x = 0.
    # Replace exact zeros by eps before the sqrt and put the zeros back afterwards.
    # See https://www.robots.ox.ac.uk/~albanie/notes/Euclidean_distance_trick.pdf
    zero_mask = squared_dist == 0.
    safe_squared_dist = torch.where(
        zero_mask, torch.full_like(squared_dist, eps), squared_dist
    )

    distance_matrix = torch.sqrt(safe_squared_dist)

    # torch.where routes a zero gradient to the masked entries, so the large but
    # finite sqrt gradient at eps never reaches `features`.
    return torch.where(zero_mask, torch.zeros_like(distance_matrix), distance_matrix)
    


In [18]:
# Sanity check on three 2-D points. By hand, the squared distances are 8 (rows 0 and 1)
# and 20 (rows 0 and 2, rows 1 and 2), so the matrix should show sqrt(8) and sqrt(20).
a = torch.tensor([[3,2], [1,4], [5,6]], dtype=torch.float)

compute_distance_matrix(a)

tensor([[0.0000, 2.8284, 4.4721],
        [2.8284, 0.0000, 4.4721],
        [4.4721, 4.4721, 0.0000]])

In [21]:
# Hand-computed L2 distances from the point (3,2) to (1,4) and to (5,6), for comparison
# with the first row of the distance matrix.
np.sqrt((3-1)**2 + (2-4)**2), np.sqrt((3-5)**2 + (2-6)**2)

(2.8284271247461903, 4.47213595499958)

Now, given a distance matrix and the true labels, we want to generate an `Mx3` tensor of valid triples `(i,j,k)` corresponding to the batch hard strategy, where `i` is the anchor, `j` is its farthest positive and `k` is its closest negative.

In [38]:
def get_triple_indices(distance_matrix, labels):
    """Pick one batch-hard triple (anchor, positive, negative) per row of a distance matrix.

    For every anchor index i the triple (i, j, k) is chosen such that j is the
    farthest example carrying the same label as i and k is the closest example
    carrying a different label.

    Args:
        distance_matrix: NxN tensor of non-negative pairwise distances.
        labels: tensor holding N class labels. It is flattened to length N and
            moved to the device of `distance_matrix`.

    Returns:
        Integer (int64) tensor of shape Nx3 on the device of `distance_matrix`,
        whose i-th row is (i, j, k). An empty batch yields a 0x3 tensor.

    Note:
        No filtering is done. The batch sampler is expected to guarantee that every
        anchor has at least one other example of its class and at least one example
        of another class. If an anchor is the only member of its class, j == i;
        if the batch holds a single class, k points at a same-class example and
        the triple is not a valid one.
    """
    N = distance_matrix.shape[0]
    device = distance_matrix.device

    if N == 0:
        # The reductions below are undefined over an empty dimension.
        return torch.empty((0, 3), dtype=torch.long, device=device)

    # Accept labels shaped (N,) or (N, 1) and living on any device.
    labels = labels.reshape(-1).to(device)

    # same_class_mask[i, j] is True when i and j share a label (the diagonal included)
    same_class_mask = labels.unsqueeze(0) == labels.unsqueeze(1)

    # Hardest positive: overwrite different-class entries with -1 (below any valid
    # distance) so that they can never win the row-wise maximum.
    positive_candidates = torch.where(
        same_class_mask, distance_matrix, torch.full_like(distance_matrix, -1)
    )
    _, farthest = torch.max(positive_candidates, dim=1)  # length N

    # Hardest negative: overwrite same-class entries with M, a value larger than every
    # distance in the matrix, so that they can never win the row-wise minimum.
    M = torch.max(distance_matrix) + 1
    negative_candidates = torch.where(same_class_mask, M, distance_matrix)
    _, closest = torch.min(negative_candidates, dim=1)  # length N

    # Build the anchor indices on the same device as the mined indices; mixing
    # devices inside torch.stack raises.
    anchors = torch.arange(N, device=device)

    return torch.stack([anchors, farthest, closest], dim=1)
    

    

In [112]:
import random

# Fix both random number generators used by the demo below so its printout is repeatable.
random.seed(42)
torch.manual_seed(42)


def test_triplet_generator():
    """Print the batch-hard triples mined from a random 5x5 matrix and return that matrix.

    The matrix holds random values in [0, 5) with a zeroed diagonal. It is not made
    symmetric, so it only imitates a distance matrix. The five labels are drawn with
    replacement from the classes 0..4.
    """
    size = 5

    # x - 1 * x on the diagonal and x - 0 * x elsewhere: only the diagonal is cleared.
    scaled = torch.rand((size, size)) * 5
    dist_matrix = scaled - torch.eye(size) * scaled

    labels = torch.tensor(random.choices(range(size), k=size))

    separator = "*" * 50
    header = f"the matrix is \ndist_matrix = \n {dist_matrix},\n labels = \n{labels.view((-1,1))} \n\n\n\n\n"
    print(header + separator)

    triples = get_triple_indices(dist_matrix, labels)
    print(triples)

    return dist_matrix
       

In [113]:
# Run the demo once (it prints the matrix, the labels and the mined triples) and keep
# the returned 5x5 matrix in `dist_matrix`.
dist_matrix = test_triplet_generator()


the matrix is 
dist_matrix = 
 tensor([[0.0000, 4.5750, 1.9143, 4.7965, 1.9522],
        [3.0045, 0.0000, 3.9682, 4.7039, 0.6659],
        [4.6730, 2.9679, 0.0000, 2.8386, 3.7055],
        [2.1470, 4.4272, 2.8695, 0.0000, 3.1372],
        [1.3482, 2.2068, 1.4846, 4.1584, 0.0000]]),
 labels = 
tensor([[3],
        [0],
        [1],
        [1],
        [3]]) 




**************************************************
tensor([[0, 4, 2],
        [1, 1, 4],
        [2, 3, 1],
        [3, 2, 0],
        [4, 0, 2]])


In [114]:
# Indexing a 2-D tensor with a 1-D index tensor gathers whole rows, and an index may be
# repeated (row 1 is selected twice here). get_valid_triples uses this kind of indexing
# to gather anchor / positive / negative rows.
a = torch.rand((3,3))
print(a)

a[torch.tensor([0,1,1,2])]

tensor([[0.2695, 0.3588, 0.1994],
        [0.5472, 0.0062, 0.9516],
        [0.0753, 0.8860, 0.5832]])


tensor([[0.2695, 0.3588, 0.1994],
        [0.5472, 0.0062, 0.9516],
        [0.5472, 0.0062, 0.9516],
        [0.0753, 0.8860, 0.5832]])

In [115]:
def get_valid_triples(features, labels):
    """Mine batch-hard triplets and return the matching feature rows.

    Args:
        features: Nxd tensor holding one embedding per row.
        labels: tensor with the N class labels of those rows.

    Returns:
        Tuple (anchors, positives, negatives) of three Nxd tensors. Row i of each
        tensor belongs to the triple mined for anchor i: the anchor itself, its
        farthest same-class example and its closest different-class example.

    Note:
        One triple is returned per anchor and nothing is filtered out, so the batch
        is expected to contain, for every anchor, another example of its class and
        an example of a different class.
    """
    # Mining only yields integer indices, so no gradient can ever flow through the
    # distance matrix. Skip recording an autograd graph for it.
    with torch.no_grad():
        dist_matrix = compute_distance_matrix(features)
        indices = get_triple_indices(dist_matrix, labels)

    anchor_idx, positive_idx, negative_idx = indices.unbind(dim=1)

    # The gather happens outside no_grad so the returned rows stay connected to
    # `features` for backpropagation.
    return features[anchor_idx], features[positive_idx], features[negative_idx]
    