In [1]:
import torch
from torch import nn
from torch.nn import functional as F

# Seed PyTorch's global RNG so the random tensors drawn below are reproducible.
torch.manual_seed(42)

<torch._C.Generator at 0x72b9840ad890>

In [2]:
# Number of embeddings (rows) per batch. Must be even: the targets built
# further down pair rows as (0, 1), (2, 3), ...
batch_size = 64
# Dimensionality of each embedding vector.
hidden_size = 128

In [3]:
# Identity matrix of size (batch_size, batch_size). Its boolean form is used
# later as a mask for the main diagonal (self-similarity) of the similarity matrix.
eye = torch.eye(batch_size)
eye

tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 1.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 1., 0., 0.],
        [0., 0., 0.,  ..., 0., 1., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.]])

In [4]:
# Model output: random stand-in for a batch of embeddings, one row per sample.
# Consecutive rows (0, 1), (2, 3), ... are treated as positive pairs by the
# targets constructed further down.
x = torch.randn(batch_size, hidden_size)
x

tensor([[ 1.9269,  1.4873,  0.9007,  ...,  0.3399,  0.7200,  0.4114],
        [ 1.9312,  1.0119, -1.4364,  ...,  0.5655,  0.5058,  0.2225],
        [-0.6855,  0.5636, -1.5072,  ...,  0.8541, -0.4901, -0.3595],
        ...,
        [-0.1587,  1.6984, -0.0560,  ...,  0.1716,  0.8127, -0.6369],
        [-1.3467,  0.6522, -1.3508,  ..., -0.4601,  0.1815,  0.1850],
        [ 0.7205, -0.2833,  0.0937,  ...,  0.5409,  0.6940,  1.8563]])

In [5]:
# Pairwise cosine similarity: broadcast a (1, B, H) view of x against a
# (B, 1, H) view so that entry [i, j] compares sample i with sample j.
similarity_matrix = F.cosine_similarity(x[None, :, :], x[:, None, :], dim=-1)
similarity_matrix

tensor([[ 1.0000,  0.1650, -0.0320,  ..., -0.0258, -0.0906,  0.0983],
        [ 0.1650,  1.0000,  0.1169,  ..., -0.1471, -0.0317,  0.0102],
        [-0.0320,  0.1169,  1.0000,  ...,  0.1656,  0.1017, -0.0095],
        ...,
        [-0.0258, -0.1471,  0.1656,  ...,  1.0000, -0.0422, -0.0437],
        [-0.0906, -0.0317,  0.1017,  ..., -0.0422,  1.0000,  0.0327],
        [ 0.0983,  0.0102, -0.0095,  ..., -0.0437,  0.0327,  1.0000]])

In [6]:
# Discard the main diagonal: a sample must never be selected as its own
# positive, so self-similarities are overwritten in place with -inf.
similarity_matrix.masked_fill_(eye.bool(), float("-inf"))
similarity_matrix

tensor([[   -inf,  0.1650, -0.0320,  ..., -0.0258, -0.0906,  0.0983],
        [ 0.1650,    -inf,  0.1169,  ..., -0.1471, -0.0317,  0.0102],
        [-0.0320,  0.1169,    -inf,  ...,  0.1656,  0.1017, -0.0095],
        ...,
        [-0.0258, -0.1471,  0.1656,  ...,    -inf, -0.0422, -0.0437],
        [-0.0906, -0.0317,  0.1017,  ..., -0.0422,    -inf,  0.0327],
        [ 0.0983,  0.0102, -0.0095,  ..., -0.0437,  0.0327,    -inf]])

In [7]:
# Target vector: for every row, the index of its pair partner
# (even i -> i + 1, odd i -> i - 1, i.e. 0<->1, 2<->3, ...).
target_matrix = torch.arange(batch_size)
target_matrix = torch.where(target_matrix % 2 == 0, target_matrix + 1, target_matrix - 1)
target_matrix

tensor([ 1,  0,  3,  2,  5,  4,  7,  6,  9,  8, 11, 10, 13, 12, 15, 14, 17, 16,
        19, 18, 21, 20, 23, 22, 25, 24, 27, 26, 29, 28, 31, 30, 33, 32, 35, 34,
        37, 36, 39, 38, 41, 40, 43, 42, 45, 44, 47, 46, 49, 48, 51, 50, 53, 52,
        55, 54, 57, 56, 59, 58, 61, 60, 63, 62])

In [8]:
# Partner indices as an int64 column vector of shape (batch_size, 1),
# the layout needed to use them as a scatter index along dim 1.
index = target_matrix.long().unsqueeze(1)
index

tensor([[ 1],
        [ 0],
        [ 3],
        [ 2],
        [ 5],
        [ 4],
        [ 7],
        [ 6],
        [ 9],
        [ 8],
        [11],
        [10],
        [13],
        [12],
        [15],
        [14],
        [17],
        [16],
        [19],
        [18],
        [21],
        [20],
        [23],
        [22],
        [25],
        [24],
        [27],
        [26],
        [29],
        [28],
        [31],
        [30],
        [33],
        [32],
        [35],
        [34],
        [37],
        [36],
        [39],
        [38],
        [41],
        [40],
        [43],
        [42],
        [45],
        [44],
        [47],
        [46],
        [49],
        [48],
        [51],
        [50],
        [53],
        [52],
        [55],
        [54],
        [57],
        [56],
        [59],
        [58],
        [61],
        [60],
        [63],
        [62]])

In [9]:
# One-hot labels matrix: row i carries a single 1 in the column of sample i's
# pair partner and 0 everywhere else.
zeros = torch.zeros(batch_size, batch_size, dtype=torch.long)
ones = torch.ones(batch_size, batch_size, dtype=torch.long)

ground_truth_labels = zeros.scatter(1, index, ones)
ground_truth_labels

tensor([[0, 1, 0,  ..., 0, 0, 0],
        [1, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 1],
        [0, 0, 0,  ..., 0, 1, 0]])

In [10]:
# Manual reference value: cross entropy over each row of similarities, using the
# partner index as the target class. No temperature scaling is applied here,
# which corresponds to a temperature of 1.
F.cross_entropy(similarity_matrix, target_matrix, reduction="mean")

tensor(4.1355)

In [11]:
# implementation
def nt_xent_loss(model_output, temperature):
    """Calculate the NT-Xent (normalized temperature-scaled cross entropy) loss.

    Rows of ``model_output`` are interpreted as embeddings laid out in
    consecutive positive pairs: rows ``2k`` and ``2k + 1`` are positives of
    each other, and every other row in the batch acts as a negative.

    Args:
        model_output (Tensor): Embeddings of shape ``(batch_size, hidden_size)``.
            ``batch_size`` must be even and at least 2.
        temperature (float): Positive value the cosine similarities are divided
            by before the softmax; smaller values sharpen the distribution.

    Returns:
        Tensor: Scalar NT-Xent loss, averaged over all rows of the batch.

    Raises:
        ValueError: If ``model_output`` is not 2-D, if its batch size is odd or
            smaller than 2, or if ``temperature`` is not strictly positive.
    """
    if model_output.dim() != 2:
        raise ValueError(
            "model_output must be 2-D (batch_size, hidden_size), "
            f"got shape {tuple(model_output.shape)}"
        )

    batch_size = model_output.shape[0]
    if batch_size < 2 or batch_size % 2 != 0:
        # With an odd batch the last row would be assigned an out-of-range partner.
        raise ValueError(
            f"batch size must be even and at least 2 to form positive pairs, got {batch_size}"
        )

    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    # Build helper tensors on the same device as the input so the function
    # also works for inputs that do not live on the CPU.
    device = model_output.device

    # Cosine similarity between every pair of rows of the *argument*
    # (not any notebook-level tensor): (1, B, H) broadcast against (B, 1, H).
    similarity_matrix = F.cosine_similarity(
        model_output.unsqueeze(0),
        model_output.unsqueeze(1),
        dim=-1,
    )

    # Discard the main diagonal: a sample may never be its own positive.
    self_mask = torch.eye(batch_size, dtype=torch.bool, device=device)
    similarity_matrix = similarity_matrix.masked_fill(self_mask, float("-inf"))

    # Labels: index of each row's pair partner (0<->1, 2<->3, ...).
    labels = torch.arange(batch_size, device=device)
    labels[0::2] += 1
    labels[1::2] -= 1

    # Compute cross entropy loss on the temperature-scaled similarities.
    return F.cross_entropy(similarity_matrix / temperature, labels, reduction="mean")

# Re-seed and draw a tensor of the same shape as `x`. Assuming the cells are run
# in order, this should reproduce the values of `x` (same seed, first draw after
# seeding), so the temperature-1.0 result can be compared with the manual
# cross-entropy value computed earlier in the notebook.
torch.manual_seed(42)
batch = torch.randn(batch_size, hidden_size)

# Sweep the temperature to show its effect on the loss.
for t in (0.01, 0.1, 0.5, 1.0, 10.0):
    print(f"Temperature: {t:.2f}, Loss: {nt_xent_loss(batch, temperature=t)}")

Temperature: 0.01, Loss: 18.72327995300293
Temperature: 0.10, Loss: 4.395903587341309
Temperature: 0.50, Loss: 4.135308742523193
Temperature: 1.00, Loss: 4.135462760925293
Temperature: 10.00, Loss: 4.142027854919434
