In [1]:
import torch

In [2]:
BATCH_SIZE = 8
DIMS = 128
NEGS = 4

queries = torch.randn(BATCH_SIZE, DIMS)
docs = torch.randn(BATCH_SIZE, NEGS, DIMS)
labels = torch.randint(0, NEGS, (BATCH_SIZE,))


## Using cosine similarity

In [3]:
import torch.nn.functional as F

In [4]:
similarities = F.cosine_similarity(queries.unsqueeze(1), docs, dim=2)
print(similarities)


tensor([[-0.0301, -0.0199,  0.0759, -0.0753],
        [-0.1443, -0.0074,  0.0205,  0.0678],
        [-0.0662, -0.1729, -0.0794, -0.0704],
        [-0.0774,  0.1044, -0.0715,  0.0125],
        [ 0.0147, -0.0013, -0.0150, -0.0045],
        [-0.0072,  0.0348,  0.0216, -0.0874],
        [ 0.1636, -0.0025,  0.1079,  0.0734],
        [-0.1126, -0.0049, -0.1203,  0.0408]])


## Manual

In [5]:
similarities_2 = torch.sum(
    F.normalize(queries.unsqueeze(1)) * F.normalize(docs),
    dim=2,
)
print(similarities_2.shape)


torch.Size([8, 4])


In [6]:
torch.allclose(similarities, similarities_2)

False

In [7]:
similarities_2

tensor([[ -2.6535,  -1.7800,   2.1008,  -5.1477],
        [ -7.7001,   3.3407,   5.0736,   3.3899],
        [  2.9544,  -7.6617,  -0.9367,  -2.0208],
        [ -5.7524,   9.2913,   1.8645,  -4.1168],
        [  0.5855,  -6.8271,  -2.5184,  -1.4178],
        [  2.7816,   2.4713,   0.4089,  -2.1736],
        [  8.5400,   1.5806,   4.0709,   6.7426],
        [ -3.7116,  -0.9620, -15.0059,   4.9379]])

In [8]:
similarities

tensor([[-0.0301, -0.0199,  0.0759, -0.0753],
        [-0.1443, -0.0074,  0.0205,  0.0678],
        [-0.0662, -0.1729, -0.0794, -0.0704],
        [-0.0774,  0.1044, -0.0715,  0.0125],
        [ 0.0147, -0.0013, -0.0150, -0.0045],
        [-0.0072,  0.0348,  0.0216, -0.0874],
        [ 0.1636, -0.0025,  0.1079,  0.0734],
        [-0.1126, -0.0049, -0.1203,  0.0408]])

## Loss Calculation

In [9]:
mask = torch.ones_like(similarities)
mask[torch.arange(BATCH_SIZE), labels] = 0

In [10]:
pos_sim = similarities[mask == 0].view(BATCH_SIZE, 1)
pos_sim


tensor([[-0.0301],
        [ 0.0205],
        [-0.0794],
        [-0.0715],
        [-0.0150],
        [-0.0072],
        [ 0.0734],
        [ 0.0408]])

In [11]:

neg_sim = similarities[mask == 1].view(BATCH_SIZE, -1)
neg_sim


tensor([[-0.0199,  0.0759, -0.0753],
        [-0.1443, -0.0074,  0.0678],
        [-0.0662, -0.1729, -0.0704],
        [-0.0774,  0.1044,  0.0125],
        [ 0.0147, -0.0013, -0.0045],
        [ 0.0348,  0.0216, -0.0874],
        [ 0.1636, -0.0025,  0.1079],
        [-0.1126, -0.0049, -0.1203]])

In [12]:
temperature = 0.07

In [13]:
pos_exp = torch.exp(pos_sim / temperature)
neg_exp = torch.exp(neg_sim / temperature)

loss = -torch.log(pos_exp / (neg_exp.sum(dim=1)))
loss.mean()


tensor(1.3339)