In [1]:
import torch
import torch.nn.functional as F

size = 32
embedding_dim = 16
num_positive_samples = 3
num_negative_samples = 6
batch_size = 8

In [2]:
# generate data
torch.manual_seed(1234)
# positive
subset = torch.randint(size, size=(batch_size, num_positive_samples + 1))
src, pos = subset.split((1, num_positive_samples), dim=-1)
assert src.size() == (batch_size,  1)
assert pos.size() == (batch_size, num_positive_samples)
# negative
negs = torch.randint(size, size=(batch_size, num_negative_samples))
assert negs.size() == (batch_size, num_negative_samples)

# embedder
embedder = torch.nn.Embedding(size, embedding_dim)
embedder.reset_parameters()

In [3]:
with torch.no_grad():
    # encode
    embedding = embedder(src)
    assert embedding.size() == (batch_size,  1, embedding_dim)
    pos_embedding = embedder(pos)
    assert pos_embedding.size() == (batch_size,  num_positive_samples, embedding_dim)
    negs_embedding = embedder(negs)
    assert negs_embedding.size() == (batch_size,  num_negative_samples, embedding_dim)
    
    logits = torch.sum(embedding * pos_embedding, dim=2)
    assert logits.shape == (batch_size, num_positive_samples)
    negs_logits = torch.sum(embedding * negs_embedding, dim=2)
    assert negs_logits.shape == (batch_size, num_negative_samples)
    
    # compute mrr
    mrr_all = torch.cat((negs_logits, logits), dim=-1)
    mrr_size = mrr_all.shape[-1]
    _, indices_of_ranks = mrr_all.topk(mrr_size)
    _, ranks = (-indices_of_ranks).topk(mrr_size)
    mrr = ranks[:,-1].float().reciprocal().mean()
    print('mrr', mrr)
    
    # compute loss
    pos_loss = F.logsigmoid(logits).sum(dim=-1)
    negs_loss = F.logsigmoid(-negs_logits).sum(dim=-1)
    loss = -(pos_loss + negs_loss).mean()
    print('loss', loss)

mrr tensor(0.3159)
loss tensor(20.9239)
