In [1]:
# tensorflow 2.0
import tensorflow as tf
size = 32
embedding_dim = 16
num_positive_samples = 3
num_negative_samples = 6
batch_size = 8

In [2]:
# generate data
tf.random.set_seed(1234)
# positive
subset = tf.random.uniform((batch_size, num_positive_samples + 1), 
                  maxval=size, dtype=tf.int32)
src, pos = tf.split(subset, [1, -1], axis=-1)
assert src.shape == (batch_size,  1)
assert pos.shape == (batch_size, num_positive_samples)
# negative
negs = tf.random.uniform((batch_size, num_negative_samples), 
                  maxval=size, dtype=tf.dtypes.int32)
assert negs.shape == (batch_size, num_negative_samples)

In [3]:
# embed_table
embed_table = tf.random.uniform((size, embedding_dim))

In [4]:
# encode
embedding = tf.nn.embedding_lookup(embed_table, src)
assert embedding.shape == (batch_size,  1, embedding_dim)
pos_embedding = tf.nn.embedding_lookup(embed_table, pos)
assert pos_embedding.shape == (batch_size,  num_positive_samples, embedding_dim)
negs_embedding = tf.nn.embedding_lookup(embed_table, negs)
assert negs_embedding.shape == (batch_size,  num_negative_samples, embedding_dim)

logits = tf.matmul(embedding, pos_embedding, transpose_b=True)
assert logits.shape == (batch_size,  1, num_positive_samples)
negs_logits = tf.matmul(embedding, negs_embedding, transpose_b=True)
assert negs_logits.shape == (batch_size,  1, num_negative_samples)

In [5]:
# compute mrr
mrr_all = tf.concat((negs_logits, logits), axis=2)
mrr_size = mrr_all.shape[2]

In [6]:
_, indices_of_ranks = tf.nn.top_k(mrr_all, k=mrr_size)
_, ranks = tf.nn.top_k(-indices_of_ranks, k=mrr_size)
mrr = tf.reduce_mean(tf.math.reciprocal(tf.cast(ranks[:, :, -1] + 1, tf.float32)))
mrr

<tf.Tensor: id=46, shape=(), dtype=float32, numpy=0.4496528>

In [7]:
# compute loss, method one
pos_xent = tf.nn.sigmoid_cross_entropy_with_logits(
    labels=tf.ones_like(logits), logits=logits)
negs_xent = tf.nn.sigmoid_cross_entropy_with_logits(
          labels=tf.zeros_like(negs_logits), logits=negs_logits)
loss1 = tf.reduce_sum(pos_xent) + tf.reduce_sum(negs_xent)

In [8]:
# compute loss, method two
pos_loss = tf.math.log_sigmoid(tf.reduce_sum(logits, axis=2))
negs_loss = tf.math.log_sigmoid(tf.reduce_sum(-negs_logits, axis=2))
loss2 = -tf.math.reduce_sum(pos_loss + negs_loss)

In [9]:
loss1, loss2

(<tf.Tensor: id=75, shape=(), dtype=float32, numpy=216.56586>,
 <tf.Tensor: id=90, shape=(), dtype=float32, numpy=215.28963>)