In [1]:
import torch
import torch.nn.functional as F

def simclr_loss(z_i: torch.Tensor, z_j: torch.Tensor, temperature: float = 0.5) -> torch.Tensor:
    """
    Computes the SimCLR (NT-Xent) loss for a batch of paired feature vectors.

    Both views are L2-normalised along dim 1 and stacked into a (2N, D) matrix. For row k of
    that matrix, the positive is the same sample's other view (row k + N for k < N, row k - N
    otherwise). The remaining 2N - 2 rows are negatives, and the row's similarity with itself
    is excluded. The loss is the cross-entropy of the temperature-scaled cosine similarities
    against the positive index, averaged over all 2N rows.

    Args:
        z_i (torch.Tensor): A tensor of shape (N, D) representing the feature vectors of the first views.
        z_j (torch.Tensor): A tensor of shape (N, D) representing the feature vectors of the second views.
            Row k of z_j is the positive partner of row k of z_i.
        temperature (float): The temperature parameter for the softmax operation. Must be > 0. Default: 0.5.

    Returns:
        torch.Tensor: A scalar tensor representing the simCLR loss for the given batch of paired feature vectors.

    Raises:
        ValueError: If z_i or z_j is not 2-D (N, D), if z_i and z_j do not have the same shape,
            if the batch is empty (N == 0), or if temperature is not positive.
    """
    if z_i.dim() != 2 or z_j.dim() != 2:
        raise ValueError(
            f"z_i and z_j must be 2-D (N, D) tensors, got shapes {tuple(z_i.shape)} and {tuple(z_j.shape)}"
        )
    if z_i.shape != z_j.shape:
        raise ValueError(
            f"z_i and z_j must have the same shape, got {tuple(z_i.shape)} and {tuple(z_j.shape)}"
        )
    if z_i.size(0) == 0:
        raise ValueError("simclr_loss needs at least one pair of feature vectors (N > 0)")
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    n = z_i.size(0)

    # Normalize so that dot products are cosine similarities, then stack the views: (2N, D).
    z = torch.cat([F.normalize(z_i, dim=1), F.normalize(z_j, dim=1)], dim=0)

    # Temperature-scaled pairwise similarities: (2N, 2N).
    similarities = torch.matmul(z, z.t()) / temperature

    # Exclude each vector's similarity with itself from the softmax. A boolean identity mask is
    # used directly; the previous float mask was negated to -1, so `== 1` never matched.
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    similarities = similarities.masked_fill(self_mask, float('-inf'))

    # Positive index for each row: the first N rows pair with rows N..2N-1, and vice versa.
    idx = torch.arange(n, device=z.device)
    targets = torch.cat([idx + n, idx])

    # cross_entropy applies log-softmax (log-sum-exp), avoiding the overflow of an explicit
    # exp/sum/log. It picks only the positive's log-probability per row, which is what NT-Xent needs.
    return F.cross_entropy(similarities, targets)


In [None]:
import unittest
import torch
from simclr_loss import simclr_loss

class TestSimCLRLoss(unittest.TestCase):
    """Unit tests for ``simclr_loss``.

    Covers the output contract, a hand-computed value, view symmetry, gradient flow through
    the self-similarity mask, and input validation.
    """

    def test_simclr_loss(self):
        # Set the random seed for reproducibility
        torch.manual_seed(0)

        # Create a batch of feature vectors
        batch_size = 32
        embedding_dim = 128
        z_i = torch.randn(batch_size, embedding_dim)
        z_j = torch.randn(batch_size, embedding_dim)

        # Compute the simCLR loss
        loss = simclr_loss(z_i, z_j)

        # Check that the loss is a scalar tensor
        self.assertIsInstance(loss, torch.Tensor)
        self.assertEqual(loss.dim(), 0)

        # Check that the loss is non-negative
        self.assertGreaterEqual(loss.item(), 0)

        # Check that the loss is finite
        self.assertTrue(torch.isfinite(loss).all())

    def test_known_value(self):
        # Take z_i == z_j == I_2 with temperature 0.5. Each row's positive has scaled
        # similarity 2, its two negatives have 0, and the self-similarity is excluded,
        # so every row's cross-entropy is log(e^2 + 2) - 2.
        z = torch.eye(2)
        loss = simclr_loss(z, z.clone(), temperature=0.5)
        expected = (torch.tensor(2.0).exp() + 2.0).log().item() - 2.0
        self.assertAlmostEqual(loss.item(), expected, places=5)

    def test_symmetric_in_views(self):
        # Swapping the two views only permutes rows/columns consistently, so the loss is unchanged.
        torch.manual_seed(0)
        z_i = torch.randn(8, 16)
        z_j = torch.randn(8, 16)
        torch.testing.assert_close(simclr_loss(z_i, z_j), simclr_loss(z_j, z_i))

    def test_gradients_are_finite(self):
        # The -inf self-similarity mask must not leak NaN/inf into the backward pass.
        torch.manual_seed(0)
        z_i = torch.randn(8, 16, requires_grad=True)
        z_j = torch.randn(8, 16, requires_grad=True)
        simclr_loss(z_i, z_j).backward()
        self.assertTrue(torch.isfinite(z_i.grad).all())
        self.assertTrue(torch.isfinite(z_j.grad).all())

    def test_invalid_shapes_raise(self):
        # The docstring promises ValueError for mismatched or non-(N, D) inputs.
        with self.assertRaises(ValueError):
            simclr_loss(torch.randn(4, 8), torch.randn(5, 8))
        with self.assertRaises(ValueError):
            simclr_loss(torch.randn(8), torch.randn(8))

if __name__ == '__main__':
    import sys

    # Heuristic: ipykernel is loaded when this cell runs inside a Jupyter kernel. There,
    # sys.argv carries the kernel's launch flags (e.g. `-f <connection file>`), which
    # unittest would parse as its own options/test names. Also, the default exit=True
    # raises SystemExit, which surfaces as an error in the cell. Outside a kernel, keep
    # the normal command-line behaviour.
    _in_kernel = 'ipykernel' in sys.modules
    if _in_kernel:
        unittest.main(argv=[sys.argv[0]], exit=False)
    else:
        unittest.main()
