# InfoNCE Loss

$$
\mathcal{L}_{info} = -\frac{1}{N}\sum_{i=1}^N \log\frac{\exp(\mathrm{sim}(x_i, x_i^+)/\tau)}{\exp(\mathrm{sim}(x_i, x_i^+)/\tau) + \sum_{k=1}^{K} \exp(\mathrm{sim}(x_i, x_k^{-})/\tau)}
$$

In [None]:
import torch
import torch.nn.functional as F

def info_nce_loss(pairs, temperature=0.07):
    """
    Compute the InfoNCE loss using in-batch samples as negatives.

    For row ``i``, the positive is the second element of pair ``i``; the
    second elements of every other pair in the batch act as negatives.

    Args:
        pairs (torch.Tensor): Input sample pairs, shape
            [batch_size, 2, feature_dim].
        temperature (float): Temperature coefficient dividing the similarities.

    Returns:
        torch.Tensor: Scalar tensor, the loss averaged over the batch.
    """
    # After L2 normalization the dot product equals cosine similarity.
    anchors = F.normalize(pairs[:, 0], p=2, dim=1)
    candidates = F.normalize(pairs[:, 1], p=2, dim=1)

    # logits[i, j] = cos(anchor_i, candidate_j) / temperature
    logits = torch.matmul(anchors, candidates.T) / temperature

    # The diagonal holds each anchor's similarity to its own positive.
    positive_logits = torch.diagonal(logits)

    # Log of the denominator: every candidate in the row, positive included.
    log_denominator = torch.logsumexp(logits, dim=1)

    # -log(exp(pos) / sum_j exp(logits_j)) == log_denominator - pos
    per_sample_loss = log_denominator - positive_logits
    return per_sample_loss.mean()