This paper is very much like the SimCLR paper, but in a supervised learning framework.

Motivation: contrastive loss and triplet loss often suffer from slow convergence, partly because each update employs only one negative example.

Hard negative class mining (as opposed to hard negative instance mining): greedily forms a batch by adding examples from a class that violates the constraint with the classes previously selected for the batch.

Loss function:
$$L(\{x,x^+, \{x_i\}_{i=1}^{N-1}\})=\log\left(1+\sum_{i=1}^{N-1}\exp(f^{\top} f_i-f^{\top} f^+)\right)$$

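This is similar to the multi-class logistic loss (softmax loss). In fact, it equals softmax cross-entropy with the logits $f^{\top} f^+, f^{\top} f_1, \dots, f^{\top} f_{N-1}$ and the positive as the target class, since $\log\left(1+\sum_i e^{f^{\top} f_i - f^{\top} f^+}\right) = -\log \frac{e^{f^{\top} f^+}}{e^{f^{\top} f^+} + \sum_i e^{f^{\top} f_i}}$.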

https://github.com/KevinMusgrave/pytorch-metric-learning/tree/master/src/pytorch_metric_learning/losses

In [None]:
def get_all_pairs_indices(labels, ref_labels=None):
    """Enumerate every positive and negative index pair between two label sets.

    Rows index ``labels`` and columns index ``ref_labels`` (which defaults to
    ``labels`` itself). When both are the same tensor object, self-pairs on
    the diagonal are excluded from the positive pairs.

    Returns:
        ``(a1_idx, p_idx, a2_idx, n_idx)``: ``(a1_idx[k], p_idx[k])`` are
        index pairs with equal labels, and ``(a2_idx[k], n_idx[k])`` are
        index pairs with different labels.
    """
    reference = labels if ref_labels is None else ref_labels
    # Broadcast to an all-vs-all equality grid (1 = same label, 0 = different).
    same = (labels[:, None] == reference[None, :]).byte()
    different = same ^ 1
    # Computed after `different`, so only the positive grid loses its diagonal.
    if reference is labels:
        same.fill_diagonal_(0)
    pos_rows, pos_cols = torch.where(same)
    neg_rows, neg_cols = torch.where(different)
    return pos_rows, pos_cols, neg_rows, neg_cols


def convert_to_pairs(indices_tuple, labels):
    """Normalize any supported index tuple to anchor-positive / anchor-negative form.

    Args:
        indices_tuple: ``None`` (use every pair in the batch), a 4-tuple
            ``(a1, p, a2, n)`` that is returned unchanged, or a triplet
            ``(a, p, n)`` whose anchors are reused for both pair kinds.
            Each element is a 1d tensor of indices within the batch.
        labels: a tensor holding the label of each element in the batch.

    Returns:
        A 4-element sequence ``(a1, p, a2, n)``.
    """
    if indices_tuple is None:
        return get_all_pairs_indices(labels)
    if len(indices_tuple) == 4:
        return indices_tuple
    # Triplet form: the same anchor serves its positive and its negative.
    anchors, positives, negatives = indices_tuple
    return anchors, positives, anchors, negatives


def convert_to_pos_pairs_with_unique_labels(indices_tuple, labels):
    """Return anchor/positive index tensors, keeping one pair per anchor label.

    Positive pairs come from ``convert_to_pairs``. Among them, only the first
    pair (in their original order) for each distinct anchor label is kept.
    The result is ordered by ascending label value.

    Args:
        indices_tuple: ``None``, a triplet ``(a, p, n)``, or a pair tuple
            ``(a1, p, a2, n)`` of 1d index tensors, as accepted by
            ``convert_to_pairs``.
        labels: a tensor holding the label of each element in the batch.

    Returns:
        ``(anchors, positives)``: the filtered anchor and positive indices.
    """
    # Local import: `np` is otherwise never imported in this notebook, which
    # made this helper raise NameError when called.
    import numpy as np

    a, p, _, _ = convert_to_pairs(indices_tuple, labels)
    # return_index=True yields the position of each distinct label's first
    # occurrence; np.unique sorts the distinct labels ascending.
    _, unique_idx = np.unique(labels[a].cpu().numpy(), return_index=True)
    return a[unique_idx], p[unique_idx]

In [None]:
import torch

from ..distances import DotProductSimilarity
from ..utils import common_functions as c_f
from ..utils import loss_and_miner_utils as lmu
from .base_metric_loss_function import BaseMetricLossFunction


class NPairsLoss(BaseMetricLossFunction):
    """N-pairs loss, computed as softmax cross-entropy over anchor/positive scores.

    Anchor and positive indices are selected with
    ``lmu.convert_to_pos_pairs_with_unique_labels``. Row ``i`` of the
    anchor-vs-positive score matrix is used as the logits for anchor ``i``.
    Column ``i`` (its own positive) is the target class, so the positives of
    the other pairs act as its negatives.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Record how many pairs were used on each call as a stat.
        self.add_to_recordable_attributes(name="num_pairs", is_stat=True)
        # Unreduced: one loss value per anchor; reduction is left to the
        # "reduction_type" entry returned by compute_loss.
        self.cross_entropy = torch.nn.CrossEntropyLoss(reduction="none")

    def compute_loss(self, embeddings, labels, indices_tuple):
        """Return the per-anchor loss dict, or ``self.zero_losses()`` if no pairs exist."""
        anchor_idx, positive_idx = lmu.convert_to_pos_pairs_with_unique_labels(
            indices_tuple, labels
        )
        num_pairs = len(anchor_idx)
        self.num_pairs = num_pairs
        if num_pairs == 0:
            return self.zero_losses()
        anchor_emb = embeddings[anchor_idx]
        positive_emb = embeddings[positive_idx]
        # Anchor i's target is column i, i.e. its own positive.
        targets = c_f.to_device(torch.arange(num_pairs), embeddings)
        scores = self.distance(anchor_emb, positive_emb)
        # Logits must grow with similarity; negate when the distance is not
        # inverted (presumably meaning smaller = closer — see self.distance).
        if not self.distance.is_inverted:
            scores = -scores
        per_anchor_losses = self.cross_entropy(scores, targets)
        loss_entry = {
            "losses": per_anchor_losses,
            "indices": anchor_idx,
            "reduction_type": "element",
        }
        return {"loss": loss_entry}

    def get_default_distance(self):
        """Default to ``DotProductSimilarity`` when no distance is configured."""
        return DotProductSimilarity()