### Comparing InfoNCE and Adjusted InfoNCE Losses

In [5]:
import torch

class InfoNceLoss(torch.nn.Module):
    def __init__(self, temperature):
        super(InfoNceLoss, self).__init__()
        self.temperature = temperature

    def forward(self, z_recovered, z_sim_recovered, z_neg_recovered):
        # Compute the dot products between each z and each "negative" sample
        neg = torch.einsum("ij,kj -> ik", z_recovered, z_neg_recovered)

        # Compute the dot product between each z and recovered (positive sample)
        pos = torch.einsum("ij,ij -> i", z_recovered, z_sim_recovered)

        neg_and_pos = torch.cat((neg, pos.unsqueeze(1)), dim=1)

        loss_pos = -pos / self.temperature
        loss_neg = torch.logsumexp(neg_and_pos / self.temperature, dim=1)
        
        total_loss = (loss_pos + loss_neg).mean()
        
        # Split for monitoring
        pos_component = loss_pos.mean()
        neg_component = loss_neg.mean()

        return pos_component.detach().item(), neg_component.detach().item(), total_loss

In [59]:
class InfoNceLossAdjusted(torch.nn.Module):
    def __init__(self, temperature):
        super(InfoNceLossAdjusted, self).__init__()
        self.temperature = temperature
    
    def forward(self, z_recovered, z_enc_sim, z_enc_neg):
        """
        Standard InfoNCE formulation: -log(exp(pos/τ) / (exp(pos/τ) + Σexp(neg/τ)))
        """
        # Positive similarity
        pos_sim = (z_recovered * z_enc_sim).sum(dim=-1) / self.temperature  # [N]
        
        # Negative similarities  
        neg_sim = (z_recovered.unsqueeze(1) * z_enc_neg).sum(dim=-1) / self.temperature  # [N, M]
        neg_sim_exp_sum = torch.exp(neg_sim).sum(dim=-1)
        neg_sim_log = torch.log(neg_sim_exp_sum + torch.exp(pos_sim))
        
        loss = -pos_sim + neg_sim_log
        
        # Split for monitoring
        pos_component = -pos_sim.mean()
        neg_component = neg_sim_log.mean()
        total_loss = loss.mean()
        
        return pos_component.detach().item(), neg_component.detach().item(), total_loss

In [61]:
from spaces import NSphereSpace


def compute_orthogonal_transformation_loss(tau, kappa, sample_pair, batch_size, latent_dim=3):
        z, z_aug = sample_pair(batch_size, kappa)

        z_neg = torch.nn.functional.normalize(
            torch.randn((batch_size, batch_size, latent_dim), device=z.device), p=2, dim=-1
        )

        pos = - torch.sum(z * z_aug, dim=-1).mean() / tau
        neg = torch.log(torch.exp((z.unsqueeze(1) * z_neg).sum(dim=-1) / tau).sum(-1)).mean()

        return (pos + neg).item()

tau = 0.3
kapp = 1 / tau

full_sphere = NSphereSpace(3)
sub_sphere = NSphereSpace(2)

normal_loss = InfoNceLoss(tau)
adjusted_loss = InfoNceLossAdjusted(tau)

batch = 6144

z = full_sphere.uniform(batch)
z, z_sim = full_sphere.sample_pair_vmf(batch, kapp)
z_neg = full_sphere.uniform(batch)

print("Normal:", normal_loss(z, z_sim, z_neg)[2].item())
z_neg_expanded = z_neg.unsqueeze(0).expand(z_neg.shape[0], -1, -1)  # Shape: (N, N, d)

print("Adjusted:", adjusted_loss(z, z_sim, z_neg_expanded)[2].item())
print("Orhtogonal:", compute_orthogonal_transformation_loss(tau, kapp, full_sphere.sample_pair_vmf, batch))

print(z_neg_expanded[0,0])
print(z_neg_expanded[0,1])
print(z_neg_expanded[0,2])

# print(adjusted_loss(z, z_sim, z_neg))

Normal: 7.836532115936279
Adjusted: 7.836532115936279
Orhtogonal: 7.800528526306152
tensor([-0.1155,  0.1294,  0.9848])
tensor([ 0.6504, -0.5285, -0.5456])
tensor([-0.9113, -0.1109, -0.3966])
