In [1]:
import torch
import torch.nn.functional as F

## NT-Xent Loss

### version 1

In [2]:
n_views = 2
batch_size = 2
feature_dim = 4
features = torch.rand(batch_size * n_views, feature_dim)
features

tensor([[0.0673, 0.6466, 0.4738, 0.9351],
        [0.3068, 0.9623, 0.8182, 0.1495],
        [0.7062, 0.9271, 0.8500, 0.2595],
        [0.5794, 0.8153, 0.4345, 0.6248]])

In [3]:
labels = torch.cat([torch.arange(batch_size) for i in range(n_views)], dim=0).to(features.device)
labels

tensor([0, 1, 0, 1])

In [4]:
labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
labels

tensor([[1., 0., 1., 0.],
        [0., 1., 0., 1.],
        [1., 0., 1., 0.],
        [0., 1., 0., 1.]])

In [5]:
norm_features = F.normalize(features, dim=1)
norm_features

tensor([[0.0546, 0.5242, 0.3841, 0.7581],
        [0.2345, 0.7355, 0.6253, 0.1142],
        [0.4819, 0.6325, 0.5800, 0.1771],
        [0.4610, 0.6487, 0.3457, 0.4971]])

In [6]:
similarity_matrix = torch.matmul(norm_features, norm_features.T)
similarity_matrix

tensor([[1.0000, 0.7251, 0.7149, 0.8749],
        [0.7251, 1.0000, 0.9611, 0.8582],
        [0.7149, 0.9611, 1.0000, 0.9210],
        [0.8749, 0.8582, 0.9210, 1.0000]])

In [7]:
mask = torch.eye(labels.shape[0], dtype=torch.bool, device=features.device)
mask

tensor([[ True, False, False, False],
        [False,  True, False, False],
        [False, False,  True, False],
        [False, False, False,  True]])

In [8]:
labels = labels[~mask].view(labels.shape[0], -1)
labels

tensor([[0., 1., 0.],
        [0., 0., 1.],
        [1., 0., 0.],
        [0., 1., 0.]])

In [9]:
similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)
similarity_matrix

tensor([[0.7251, 0.7149, 0.8749],
        [0.7251, 0.9611, 0.8582],
        [0.7149, 0.9611, 0.9210],
        [0.8749, 0.8582, 0.9210]])

In [10]:
positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)
negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)
positives, negatives

(tensor([[0.7149],
         [0.8582],
         [0.7149],
         [0.8582]]),
 tensor([[0.7251, 0.8749],
         [0.7251, 0.9611],
         [0.9611, 0.9210],
         [0.8749, 0.9210]]))

In [11]:
logits = torch.cat([positives, negatives], dim=1)
logits

tensor([[0.7149, 0.7251, 0.8749],
        [0.8582, 0.7251, 0.9611],
        [0.7149, 0.9611, 0.9210],
        [0.8582, 0.8749, 0.9210]])

In [12]:
temperature = 1
labels = torch.zeros(logits.shape[0], dtype=torch.long, device=features.device)
logits = logits / temperature
logits

tensor([[0.7149, 0.7251, 0.8749],
        [0.8582, 0.7251, 0.9611],
        [0.7149, 0.9611, 0.9210],
        [0.8582, 0.8749, 0.9210]])

In [13]:
labels

tensor([0, 0, 0, 0])

In [14]:
F.cross_entropy(logits, labels, reduction="mean")

tensor(1.1580)

### version 2

In [15]:
xcs = F.cosine_similarity(norm_features[None, :, :], norm_features[:, None, :], dim=-1)
xcs

tensor([[1.0000, 0.7251, 0.7149, 0.8749],
        [0.7251, 1.0000, 0.9611, 0.8582],
        [0.7149, 0.9611, 1.0000, 0.9210],
        [0.8749, 0.8582, 0.9210, 1.0000]])

In [16]:
xcs[torch.eye(xcs.size(0)).bool()] = float("-inf")
xcs

tensor([[  -inf, 0.7251, 0.7149, 0.8749],
        [0.7251,   -inf, 0.9611, 0.8582],
        [0.7149, 0.9611,   -inf, 0.9210],
        [0.8749, 0.8582, 0.9210,   -inf]])

In [17]:
target = torch.arange(batch_size * n_views)
target[0::2] += 1
target[1::2] -= 1
target

tensor([1, 0, 3, 2])

In [18]:
F.cross_entropy(xcs / temperature, target, reduction="mean")

tensor(1.1214)

In [19]:
def NTXent_Loss(features, temperature=1, normalize=True):
    # features: [2N, D], N is the number of samples, D is the dimension of feature, 2k-1 is the positive pair of 2k
    features = F.normalize(features, dim=1) if normalize else features
    similarity_matrix = torch.matmul(features, features.T)  # Cosine similarity, [2N, 2N]
    similarity_matrix[torch.eye(features.size(0)).bool()] = float("-inf")

    ''' Ground truth labels, 2k-1 is the positive pair of 2k
    tensor([[0, 1, 0, 0, 0, 0],
            [1, 0, 0, 0, 0, 0],
            [0, 0, 0, 1, 0, 0],
            [0, 0, 1, 0, 0, 0],
            [0, 0, 0, 0, 0, 1],
            [0, 0, 0, 0, 1, 0]])
    '''
    target = torch.arange(features.size(0))
    target[0::2] += 1
    target[1::2] -= 1

    # Standard cross entropy loss
    return F.cross_entropy(similarity_matrix / temperature, target, reduction="mean")

In [20]:
NTXent_Loss(features)

tensor(1.1214)

#### version 1 and version 2 are equivalent but here the computed loss value is different because they require different input. For version 1, the first N of 2N samples are the positive pairs of the last N samples, while for version 2, the positive pairs are the adjacent samples.

## SupConLoss

In [259]:
n_views = 1
batch_size = 4
feature_dim = 4
features = torch.rand(batch_size * n_views, feature_dim)
features

tensor([[0.5695, 0.5705, 0.7160, 0.8298],
        [0.0652, 0.1614, 0.3072, 0.0560],
        [0.1572, 0.9189, 0.3612, 0.0158],
        [0.3656, 0.4852, 0.6520, 0.4377]])

In [249]:
f1, f2 = torch.split(features, [batch_size, batch_size], dim=0)
f1, f2

RuntimeError: start (4) + length (4) exceeds dimension size (4).

In [251]:
features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1)
features

tensor([[[0.0051, 0.0207, 0.2768, 0.9484],
         [0.7537, 0.9323, 0.4553, 0.2884]],

        [[0.5589, 0.2835, 0.1780, 0.9505],
         [0.8993, 0.3663, 0.9877, 0.9788]],

        [[0.1532, 0.1470, 0.2990, 0.3349],
         [0.9127, 0.7172, 0.8151, 0.2047]],

        [[0.1736, 0.0940, 0.7889, 0.4471],
         [0.0126, 0.0987, 0.7025, 0.4047]]])

In [263]:
mask = None
labels = torch.tensor([1, 0, 0, 1])
temperature = 1
base_temperature = 1
contrast_mode = 'all'

In [264]:
features = features.unsqueeze(1) if len(features.shape) < 3 else features
features

tensor([[[0.5695, 0.5705, 0.7160, 0.8298]],

        [[0.0652, 0.1614, 0.3072, 0.0560]],

        [[0.1572, 0.9189, 0.3612, 0.0158]],

        [[0.3656, 0.4852, 0.6520, 0.4377]]])

In [265]:
batch_size = features.shape[0]
device = (torch.device('cuda')
          if features.is_cuda
          else torch.device('cpu'))
if labels is not None and mask is not None:
    raise ValueError('Cannot define both `labels` and `mask`')
elif labels is None and mask is None:
    mask = torch.eye(batch_size, dtype=torch.float32).to(device)
elif labels is not None:
    labels = labels.contiguous().view(-1, 1)
    if labels.shape[0] != batch_size:
        raise ValueError('Num of labels does not match num of features')
    mask = torch.eq(labels, labels.T).float().to(device)
else:
    mask = mask.float().to(device)
device, labels, mask

(device(type='cpu'),
 tensor([[1],
         [0],
         [0],
         [1]]),
 tensor([[1., 0., 0., 1.],
         [0., 1., 1., 0.],
         [0., 1., 1., 0.],
         [1., 0., 0., 1.]]))

In [266]:
torch.unbind(features, dim=1)

(tensor([[0.5695, 0.5705, 0.7160, 0.8298],
         [0.0652, 0.1614, 0.3072, 0.0560],
         [0.1572, 0.9189, 0.3612, 0.0158],
         [0.3656, 0.4852, 0.6520, 0.4377]]),)

In [268]:
contrast_count = features.shape[1]
contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
if contrast_mode == 'one':
    anchor_feature = features[:, 0]
    anchor_count = 1
elif contrast_mode == 'all':
    anchor_feature = contrast_feature
    anchor_count = contrast_count
else:
    raise ValueError('Unknown mode: {}'.format(contrast_mode))
anchor_count, contrast_count, anchor_feature, contrast_feature

(1,
 1,
 tensor([[0.5695, 0.5705, 0.7160, 0.8298],
         [0.0652, 0.1614, 0.3072, 0.0560],
         [0.1572, 0.9189, 0.3612, 0.0158],
         [0.3656, 0.4852, 0.6520, 0.4377]]),
 tensor([[0.5695, 0.5705, 0.7160, 0.8298],
         [0.0652, 0.1614, 0.3072, 0.0560],
         [0.1572, 0.9189, 0.3612, 0.0158],
         [0.3656, 0.4852, 0.6520, 0.4377]]))

In [269]:
# compute logits
anchor_dot_contrast = torch.div(torch.matmul(anchor_feature, contrast_feature.T), temperature)
anchor_dot_contrast

tensor([[1.8510, 0.3956, 0.8854, 1.3151],
        [0.3956, 0.1278, 0.2704, 0.3269],
        [0.8854, 0.2704, 0.9997, 0.7457],
        [1.3151, 0.3269, 0.7457, 0.9858]])

In [270]:
# for numerical stability
logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
logits_max

tensor([[1.8510],
        [0.3956],
        [0.9997],
        [1.3151]])

In [271]:
logits = anchor_dot_contrast - logits_max.detach()
logits

tensor([[ 0.0000, -1.4554, -0.9656, -0.5359],
        [ 0.0000, -0.2678, -0.1252, -0.0687],
        [-0.1143, -0.7293,  0.0000, -0.2540],
        [ 0.0000, -0.9882, -0.5694, -0.3293]])

In [272]:
# tile mask
mask = mask.repeat(anchor_count, contrast_count)
mask

tensor([[1., 0., 0., 1.],
        [0., 1., 1., 0.],
        [0., 1., 1., 0.],
        [1., 0., 0., 1.]])

In [273]:
logits_mask = torch.scatter(torch.ones_like(mask), 1, torch.arange(batch_size * anchor_count).view(-1, 1).to(device), 0)
logits_mask

tensor([[0., 1., 1., 1.],
        [1., 0., 1., 1.],
        [1., 1., 0., 1.],
        [1., 1., 1., 0.]])

In [274]:
# mask-out self-contrast cases
mask = mask * logits_mask
mask

tensor([[0., 0., 0., 1.],
        [0., 0., 1., 0.],
        [0., 1., 0., 0.],
        [1., 0., 0., 0.]])

In [275]:
# compute log_prob
exp_logits = torch.exp(logits) * logits_mask
exp_logits

tensor([[0.0000, 0.2333, 0.3808, 0.5851],
        [1.0000, 0.0000, 0.8823, 0.9336],
        [0.8920, 0.4822, 0.0000, 0.7757],
        [1.0000, 0.3723, 0.5659, 0.0000]])

In [276]:
exp_logits.sum(1, keepdim=True)

tensor([[1.1992],
        [2.8160],
        [2.1499],
        [1.9381]])

In [277]:
torch.log(exp_logits.sum(1, keepdim=True))

tensor([[0.1817],
        [1.0353],
        [0.7654],
        [0.6617]])

In [278]:
log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
log_prob

tensor([[-0.1817, -1.6371, -1.1472, -0.7176],
        [-1.0353, -1.3031, -1.1605, -1.1040],
        [-0.8797, -1.4948, -0.7654, -1.0194],
        [-0.6617, -1.6499, -1.2311, -0.9910]])

In [279]:
# compute mean of log-likelihood over positive
mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
mean_log_prob_pos

tensor([-0.7176, -1.1605, -1.4948, -0.6617])

In [280]:
loss = - (temperature / base_temperature) * mean_log_prob_pos
loss = loss.view(anchor_count, batch_size).mean()
loss

tensor(1.0086)

## SimCSE

In [326]:
features = torch.randint(2, (8, feature_dim)).float()
features

tensor([[0., 0., 0., 1.],
        [0., 1., 0., 0.],
        [1., 0., 1., 1.],
        [1., 0., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 0.],
        [1., 0., 1., 1.],
        [0., 1., 1., 0.]])

In [327]:
z1, z2 = torch.split(features, features.shape[0] // 2, dim=0)
z1, z2

(tensor([[0., 0., 0., 1.],
         [0., 1., 0., 0.],
         [1., 0., 1., 1.],
         [1., 0., 0., 0.]]),
 tensor([[0., 0., 1., 0.],
         [0., 0., 0., 0.],
         [1., 0., 1., 1.],
         [0., 1., 1., 0.]]))

In [328]:
F.cosine_similarity(z1.unsqueeze(1), z2.unsqueeze(0), dim=2)

tensor([[0.0000, 0.0000, 0.5774, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.7071],
        [0.5774, 0.0000, 1.0000, 0.4082],
        [0.0000, 0.0000, 0.5774, 0.0000]])

In [329]:
labels = torch.arange(0, z1.shape[0], device=features.device)
similarity_matrix = F.cosine_similarity(z1.unsqueeze(1), z2.unsqueeze(0), dim=2)
similarity_matrix = similarity_matrix / temperature
similarity_matrix

tensor([[0.0000, 0.0000, 0.5774, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.7071],
        [0.5774, 0.0000, 1.0000, 0.4082],
        [0.0000, 0.0000, 0.5774, 0.0000]])

In [330]:
loss = F.cross_entropy(similarity_matrix, labels)
loss

tensor(1.4227)