In [2]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T

import os
import random
from PIL import Image

We want to implement the `Batch Hard` training strategy from the "In Defense of the Triplet Loss" paper (2017). Given the features from the backbone of shape `NxD`, where `N` is the number of examples in the batch and `D` is the embedding dimension, we wish to compute valid triples `(A,P,N)` where A is an anchor, `(A,P)` is a positive sample with the sample class and `(A,N)` is a negative sample with different classes. The batch hard strategy requires that we filter only those triples for which `d(A,P)` is as small as possible and `d(A,N)` is as large as possible for a given anchor `A`.

In [16]:
def compute_distance_matrix(features, distance_fn="L2", eps=1e-12):
    """Assumes features is an NxD matrix. Computes NxN pairwise distance matrix between the rows and returns it.
    Currently assumes that the distance function is the L2 distance function as the method below only works
    for this function. Alternative functions might require longer computation since easy vectorization may
    not be possible."""
    
    #To compute the distance matrix we use |a-b|^2 = |a|^2 - 2 a.b + |b|^2
    X = features @ features.T #X has shape NxN and X[i,j] = f_i . f_j
    sq_norm_features = torch.diagonal(X) #length N, f_i . f_i
    squared_dist = sq_norm_features.unsqueeze(1) - 2 * X + sq_norm_features.unsqueeze(0)
    
    #we need to ensure positivity before taking square root due to floating point errors
    squared_dist = F.relu(squared_dist)
    
    #for numerical stability we add epsilon whereever squared_dist is 0 before taking sqrt
    #this helps during backprop because the derivative of sqrt(x) is 1/2sqrt(x). 
    #See https://www.robots.ox.ac.uk/~albanie/notes/Euclidean_distance_trick.pdf
    
    zero_mask = squared_dist == 0.
    
    
    squared_dist += zero_mask.to(torch.float) * eps #whereever it is zero, add epsilon
    
    distance_matrix = torch.sqrt(squared_dist)
    distance_matrix[zero_mask] = 0.
    
    return distance_matrix
    


In [17]:
(torch.sqrt(torch.zeros((4,4))) == 0) *.4

tensor([[0.4000, 0.4000, 0.4000, 0.4000],
        [0.4000, 0.4000, 0.4000, 0.4000],
        [0.4000, 0.4000, 0.4000, 0.4000],
        [0.4000, 0.4000, 0.4000, 0.4000]])

In [18]:
a = torch.tensor([[3,2], [1,4], [5,6]], dtype=torch.float)

compute_distance_matrix(a)

tensor([[0.0000, 2.8284, 4.4721],
        [2.8284, 0.0000, 4.4721],
        [4.4721, 4.4721, 0.0000]])

In [21]:
np.sqrt((3-1)**2 + (2-4)**2), np.sqrt((3-5)**2 + (2-6)**2)

(2.8284271247461903, 4.47213595499958)

Now, given a distance matrix, and true labels, we want to generate a `Mx3` tensor of valid triples `(i,j,k)` corresponding to the batch hard strategy.

In [32]:
def get_valid_triples(distance_matrix, labels):
    """Given an NxN distance matrix between and a N label array returns a tensor of length Mx3 with i,j,k being valid triples
    such that for any i, (i,j) is the furthest apart positive example and (i,k) is the closest negative example.
    
    The sampling strategy ensures that every batch has at least K elements in one class and at least P diff. classes."""
    N = distance_matrix.shape[0]
    
    _, furthest_indices = torch.max(distance_matrix, dim=1, keepdim=True) #shape is Nx1 and furthest[i] = j if j is furthest from i
    max_val = torch.max(distance_matrix) + 1
    _, closest_indices = torch.min(distance_matrix + torch.diag(torch.ones(N) * max_val)) #we want to minimize d(i,j) for i=/=j
    
    

In [23]:
torch.max(torch.arange(6).reshape((2,3)), dim=1)

torch.return_types.max(
values=tensor([2, 5]),
indices=tensor([2, 2]))

In [27]:
torch.diag(torch.ones(3) * 4)

tensor([[4., 0., 0.],
        [0., 4., 0.],
        [0., 0., 4.]])

In [34]:
import random
def test_triplet_generator():
    dist_matrix = torch.rand((5,5)) * 5
    labels = torch.tensor(random.choices(range(5), k=5))
    print(f"the matrix is dist_matrix = {dist_matrix}, labels = {labels} \n\n\n\n\n" + "*"*50)
    
    print(get_valid_triples(dist_matrix, labels))
       

In [35]:
test_triplet_generator()

the matrix is dist_matrix = tensor([[2.1554, 3.8405, 4.9992, 3.6540, 1.8460],
        [4.5680, 3.2760, 4.7992, 2.3844, 4.5190],
        [2.6432, 2.3703, 1.1153, 2.0088, 3.0268],
        [4.9789, 2.1472, 0.0962, 2.6718, 0.7436],
        [0.2879, 2.2944, 0.4205, 2.9724, 2.8840]]), labels = tensor([0, 1, 1, 4, 3]) 




**************************************************


TypeError: iteration over a 0-d tensor