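Related to https://github.com/google-research/ibc: compares a KL-divergence-based InfoNCE loss (`info_nce`) with a simpler log-softmax formulation (`simple_info_nce`).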

In [1]:
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

import tensorflow as tf

In [2]:
def info_nce(predictions):
    """Return the per-example InfoNCE loss of a batch of scores.

    Every row of `predictions` is turned into a distribution with a softmax
    and compared, via KL divergence, against a one-hot target that puts all
    of its mass on the last column. The last column is therefore treated as
    the positive example and the remaining columns as counter-examples.
    The softmax is applied to the values as given (no sign flip), so a
    larger value in the last column yields a lower loss.

    Args:
        predictions: Rank-2 tensor laid out as [batch, num_negatives + 1];
            both dimensions are read from its static shape.

    Returns:
        One unreduced loss value per row (the KL loss is built with
        `Reduction.NONE`).
    """
    num_rows, num_columns = predictions.shape
    positive_column = num_columns - 1

    # Target distribution: [B x n+1] with a 1 in column [:, -1].
    positive_indices = tf.fill((num_rows,), positive_column)
    target = tf.one_hot(positive_indices, depth=num_columns)

    # Temperature of 1.0 leaves the scores unscaled before the softmax.
    temperature = 1.0
    probabilities = tf.nn.softmax(predictions / temperature, axis=-1)

    kl_divergence = tf.keras.losses.KLDivergence(
        reduction=tf.keras.losses.Reduction.NONE)
    return kl_divergence(target, probabilities)


def simple_info_nce(energies):
    """Return the per-example InfoNCE loss computed directly from energies.

    The energies are negated and passed through a log-softmax (a
    "log-softmin" over the energies); the loss of each row is the negative
    log-probability assigned to the last column, which holds the positive
    sample.

    Args:
        energies: Rank-2 tensor laid out as [batch, num_negatives + 1];
            lower energy means more likely.

    Returns:
        One loss value per row.
    """
    _, num_columns = energies.shape
    positive_column = num_columns - 1
    log_probabilities = tf.nn.log_softmax(-energies, axis=-1)
    return -log_probabilities[:, positive_column]

In [3]:
# Small example: 3 rows, each with 4 candidate energies (positive in the
# last column, as both loss functions expect).
batch_size = 3
num_neg_and_pos = 4
# Seed TensorFlow's global RNG before drawing the energies.
tf.random.set_seed(0)
energies = tf.random.uniform(shape=(batch_size, num_neg_and_pos))

# info_nce softmaxes its input as-is, so feeding raw energies (a) does not
# match simple_info_nce (b); only the negated energies (ahat) do, up to
# small numerical differences.
a = info_nce(energies)
ahat = info_nce(-energies)  # :(
b = simple_info_nce(energies)

# Printed in the order a, ahat, b.
print(a)
print(ahat)
print(b)

tf.Tensor([1.235404  1.0950543 1.2996238], shape=(3,), dtype=float32)
tf.Tensor([1.5604469 1.7325498 1.5327481], shape=(3,), dtype=float32)
tf.Tensor([1.5604513 1.7325542 1.5327525], shape=(3,), dtype=float32)
