In [1]:
import numpy as np
import torch
import torch.nn.functional as F

In [117]:
from pytorch_metric_learning.losses import NTXentLoss

In [2]:
CE = torch.nn.CrossEntropyLoss()

In [3]:
def contrastive_loss(v1, v2):
    logits = torch.matmul(v1, torch.transpose(v2, 0, 1))
    labels = torch.arange(logits.shape[0], device=v1.device)
    return CE(logits, labels) + CE(torch.transpose(logits, 0, 1), labels)

In [121]:
# Create two random batch embeddings
batch_size = 16
embedding_size = 128

torch.manual_seed(0)

v1 = torch.randn(batch_size, embedding_size)
v2 = torch.randn(batch_size, embedding_size)


# Compute InfoNCE loss
loss = contrastive_loss(v1, v2)
print(loss)

labels = torch.arange(batch_size, device=v1.device)
loss = NTXentLoss()(v1, v2, labels)

tensor(37.6935)


ValueError: labels must be a 1D tensor of shape (batch_size,)

In [119]:

loss = NTXentLoss()
temperature = 1.0
logits = torch.matmul(v1, v2.t()) / temperature
labels = torch.diagonal(logits)

loss(logits, labels)

tensor(0.)

In [None]:
# info_nce_loss = NTXentLoss()

def info_nce_loss__(v1, v2, temperature=0.7):
    logits = torch.matmul(v1, torch.transpose(v2, 0, 1))

    # Generate labels
    labels = torch.arange(logits.shape[0], device=v1.device)

    # Cross-Entropy Loss for positive pairs
    ce_positive = CE(logits, labels)

    # Cross-Entropy Loss for negative pairs
    ce_negative = CE(torch.transpose(logits, 0, 1), labels)

    # Compute mutual information
    prob_positive = torch.nn.functional.softmax(logits / temperature, dim=-1)
    prob_negative = torch.nn.functional.softmax(torch.transpose(logits, 0, 1) / temperature, dim=-1)

    mi_loss = -torch.sum(prob_positive * torch.log(prob_positive / prob_negative)) / logits.shape[0]

    # InfoCNE Loss
    infocne_loss = ce_positive + ce_negative + mi_loss

    return infocne_loss

def infoCNE_loss(v1, v2, temperature=0.7):
    # Normalize embeddings
    v1_normalized = F.normalize(v1, dim=-1, p=2)
    v2_normalized = F.normalize(v2, dim=-1, p=2)

    # Compute cosine similarity
    sim_matrix = torch.matmul(v1_normalized, v2_normalized.t()) / temperature

    # Compute positive pair similarity (diagonal of the similarity matrix)
    pos_pair_sim = torch.diagonal(sim_matrix)

    # Compute negative pair similarity
    neg_pair_sim = torch.logsumexp(sim_matrix, dim=1) - pos_pair_sim

    # Compute InfoCNE loss
    loss = -torch.mean(pos_pair_sim - neg_pair_sim)

    return loss
