In [2]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision import transforms, datasets
from torch.optim import SGD

import matplotlib.pyplot as plt

In [4]:
class GaussianBlur:
    def __init__(self):
        pass

class DataAugmentTransform:
    """
    Perform Data Augmentations to create positive and negative pairs
    """
    def __init__(self,
                 input_height: int = 224,
                 guassian_blur: bool = False,
                 jitter_strength: float = 1.) -> None:
        self.input_height = input_height
        self.guassian_blur = guassian_blur
        self.jitter_strength = jitter_strength

    def simclr_transform_pipeline(self):
        color_jitter = transforms.ColorJitter(
            brightness=0.8 * self.jitter_strength,
            contrast=0.8 * self.jitter_strength,
            hue=0.8 * self.jitter_strength,
            saturation=0.2 * self.jitter_strength)

        transforms_list = [transforms.RandomResizedCrop(size=self.input_height),
                           transforms.RandomHorizontalFlip(p=0.5),
                           transforms.RandomApply([color_jitter], p=0.8),
                           transforms.RandomGrayscale(p=0.2),
                           ]
        if self.guassian_blur:
            transforms_list.append(GaussianBlur(kernel_size=int(0.1 * self.input_height, p=0.5)))
        #until this point, all transforms are applied on PIL images, so here we will convert the transformed PIL images to Tensors
        transforms_list.append(transforms.ToTensor())
        self.train_transform = transforms.Compose(transforms_list) #final training transform list

    def __call__(self, name, n_views):
        #we perfprm transforms on data here
        pass




In [17]:
#Contrastive Loss NT-Xent without vectorization
class ContrastiveLossv1(nn.Module):
    def __init__(self, batch_size, temperature=0.5, verbose=False):
        super().__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.verbose = verbose

    def forward(self, emb_i, emb_j):
        if self.verbose:
            print(f"Embeddings shape emb_i: {emb_i.shape}, emb_j: {emb_j.shape}")
        #Normalize the embeddings
        z_i = F.normalize(emb_i, dim=1)
        z_j = F.normalize(emb_j, dim=1)

        #Concatenate the representations of both the views batches
        projection_representations = torch.cat([z_i, z_j], dim=0)
        if self.verbose:
            print(f"Concatenated projection representations z shape:\
             {projection_representations.shape}")
            print(f"Projection representations unsqueeze with dim=1 shape:\
             {projection_representations.unsqueeze(dim=1).shape}")
            print(f"Projection representations unsqueeze with dim=0 shape:\
             {projection_representations.unsqueeze(dim=0).shape}")

        #Construct the similarity matrix
        sim_matrix = F.cosine_similarity(
            x1=projection_representations.unsqueeze(dim=1),
            x2=projection_representations.unsqueeze(dim=0),
            dim=2
        )
        if self.verbose:
            print(f"Similarity matrix shape:{sim_matrix.shape}")
            print(f"Similarity matrix :{sim_matrix}")

        #Calculating the loss function l(i, j)
        #(in one view persepctive l(i, j))
        def l_ij(i, j):
            if self.verbose:
                print(f"l_ij batch sample indexes {i}, {j}")
            #get the representations using the indexez i, j
            z_i_, z_j_ = projection_representations[i], projection_representations[j]

            #Get the similarity value of these 2 representation we already calculated
            sim_i_j = sim_matrix[i, j]
            if self.verbose:
                print(f"sim({i}, {j})={sim_i_j}")
            #l_ij loss for each sample
            #Calculate the numerator of l_ij loss
            numerator = torch.exp(sim_i_j / self.temperature)
            if self.verbose: print("Numerator", numerator)
            #Calculate the mask for denomiator to implement the indicator function to exclude same samples from sim matrix
            one_k_not_i = torch.ones((2 * self.batch_size)).scatter(0, torch.tensor([i]), 0.0)
            if self.verbose:
                print(f"1{{k!={i}}} shape", one_k_not_i.shape)
                print(f"1{{k!={i}}}", one_k_not_i)

            #Calculate the denominator
            denominator = torch.sum(one_k_not_i * torch.exp(sim_matrix[i, :] / self.temperature))
            if self.verbose: print("Denominator", denominator)

            #loss l_i_j
            loss_ij = - torch.log(numerator / denominator)
            if self.verbose:
                print(f"loss({i},{j})={loss_ij}\n")
            return loss_ij.squeeze()

        N = self.batch_size
        loss = 0.0 # we accumulate loss for all the samples in the batch
        for k in range(0, N):
            print(k+N)
            print(k)
            loss += l_ij(k, k + N) + l_ij(k + N, k)
            # l_ij = l_ij(k, k+N)
            # l_ji = l_ij(k+N, k)
            # l_i = l_ij + l_ji
            # loss += l_i
        final_loss = (1.0 / 2*N) * loss
        return final_loss



In [18]:
#Test the function

#Take 3 images in a batch wirth representation size of 2 for each image
U = torch.tensor([
    [1.0, 2.0],
    [3.0, -2.0],
    [1.0, 5.0]
]) #first augmentation batch
V = torch.tensor([
    [1.0, 0.75],
    [2.8, -1.75],
     [1.0, 4.7]
]) #second augmentation
lossobj = ContrastiveLossv1(batch_size=3, temperature=1.0, verbose=True)
lossobj(U, V)

Embeddings shape emb_i: torch.Size([3, 2]), emb_j: torch.Size([3, 2])
Concatenated projection representations z shape:             torch.Size([6, 2])
Projection representations unsqueeze with dim=1 shape:             torch.Size([6, 1, 2])
Projection representations unsqueeze with dim=0 shape:             torch.Size([1, 6, 2])
Similarity matrix shape:torch.Size([6, 6])
Similarity matrix :tensor([[ 1.0000, -0.1240,  0.9648,  0.8944, -0.0948,  0.9679],
        [-0.1240,  1.0000, -0.3807,  0.3328,  0.9996, -0.3694],
        [ 0.9648, -0.3807,  1.0000,  0.7452, -0.3534,  0.9999],
        [ 0.8944,  0.3328,  0.7452,  1.0000,  0.3604,  0.7533],
        [-0.0948,  0.9996, -0.3534,  0.3604,  1.0000, -0.3419],
        [ 0.9679, -0.3694,  0.9999,  0.7533, -0.3419,  1.0000]])
3
0
l_ij batch sample indexes 0, 3
sim(0, 3)=0.8944272994995117
Numerator tensor(2.4459)
1{k!=0} shape torch.Size([6])
1{k!=0} tensor([0., 1., 1., 1., 1., 1.])
Denominator tensor(9.4954)
loss(0,3)=1.3563847541809082

l_ij bat

tensor(10.1943)

In [57]:
class ContrastiveLossNTXent(nn.Module):
    def __init__(self, batch_size, temperature=0.5, verbose=False):
        super().__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.verbose = verbose

    def forward(self, emb_z_i, emb_z_j):
        """
        Calculate NT_Xent Loss

        emb_z_i : (b, 128) Projecttion head embeddings of one flow of augmentations z_i of a batch
        emb_z_j : (b, 128) Projecttion head embeddings of second flow of augmentations z_j of a batch
        """
        ## Normalize the embeddings first
        z_i = F.normalize(emb_z_i, 1)
        z_j = F.normalize(emb_z_j, 1)

        #Contentanate both of them to further use it to construct similarity matrix
        representations = torch.cat([z_i, z_j], dim=0)

        #Similarity Matrix using cosine similarity manual
        sim_matrix = torch.matmul(representations, representations.T)

        if self.verbose:
            print(f"Similarity Matrix: \n {sim_matrix}")

        #Denominator similarity
        #creating mask
        num_samples = len(representations)
        mask = ~torch.eye(num_samples).bool()
        sim_matrix_neg = torch.exp(sim_matrix / self.temperature)
        loss_denominator_neg = sim_matrix_neg.masked_select(mask).view(num_samples, -1).sum(dim=-1)
        print("Mask"+str(mask))
        print("loss denominator"+str(loss_denominator_neg))

        #Positive examples Numerator similarity
        loss_numerator_pos = torch.exp(torch.sum(z_i * z_j, dim=-1) / self.temperature)
        print("loss loss_numerator_pos"+str(loss_numerator_pos))
        loss_numerator_pos = torch.cat([loss_numerator_pos, loss_numerator_pos], dim=0)
        print("Concat loss_numerator_pos"+str(loss_numerator_pos))

        final_loss = - torch.log( loss_numerator_pos / loss_denominator_neg ).mean()
        print(torch.log( loss_numerator_pos / loss_denominator_neg ))
        print(final_loss)
        return final_loss.squeeze()


In [56]:
torch.manual_seed(42)
emb_i = torch.randn(3, 1)
print(emb_i)
torch.manual_seed(0)
emb_j = torch.randn(3, 1)
print(emb_j)

contloss = ContrastiveLossNTXent(batch_size=2, verbose=True)
contloss(emb_i, emb_j)


tensor([[0.3367],
        [0.1288],
        [0.2345]])
tensor([[ 1.5410],
        [-0.2934],
        [-2.1788]])
Similarity Matrix: 
 tensor([[ 1.,  1.,  1.,  1., -1., -1.],
        [ 1.,  1.,  1.,  1., -1., -1.],
        [ 1.,  1.,  1.,  1., -1., -1.],
        [ 1.,  1.,  1.,  1., -1., -1.],
        [-1., -1., -1., -1.,  1.,  1.],
        [-1., -1., -1., -1.,  1.,  1.]])
Masktensor([[False,  True,  True,  True,  True,  True],
        [ True, False,  True,  True,  True,  True],
        [ True,  True, False,  True,  True,  True],
        [ True,  True,  True, False,  True,  True],
        [ True,  True,  True,  True, False,  True],
        [ True,  True,  True,  True,  True, False]])
loss denominatortensor([22.4378, 22.4378, 22.4378, 22.4378,  7.9304,  7.9304])
loss loss_numerator_postensor([7.3891, 0.1353, 0.1353])
Concat loss_numerator_postensor([7.3891, 0.1353, 0.1353, 7.3891, 0.1353, 0.1353])
tensor([-1.1107, -5.1107, -5.1107, -1.1107, -4.0707, -4.0707])
tensor(3.4307)


tensor(3.4307)

In [31]:
torch.manual_seed(42)
sim_matrix = torch.randn(6, 6)
mask = ~torch.eye(6).bool()
my_mat = sim_matrix.masked_select(mask)
transformed_mat = my_mat.view(len(sim_matrix), -1)
sum_mat = transformed_mat.sum(dim=-1) #axis=1
sim_matrix, mask, my_mat, transformed_mat, sum_mat

(tensor([[ 1.9269,  1.4873,  0.9007, -2.1055,  0.6784, -1.2345],
         [-0.0431, -1.6047, -0.7521,  1.6487, -0.3925, -1.4036],
         [-0.7279, -0.5594, -0.7688,  0.7624,  1.6423, -0.1596],
         [-0.4974,  0.4396,  0.3189, -0.4245,  0.3057, -0.7746],
         [ 0.0349,  0.3211,  1.5736, -0.8455, -1.2742,  2.1228],
         [-1.2347, -0.4879, -1.4181,  0.8963,  0.0499,  2.2667]]),
 tensor([[False,  True,  True,  True,  True,  True],
         [ True, False,  True,  True,  True,  True],
         [ True,  True, False,  True,  True,  True],
         [ True,  True,  True, False,  True,  True],
         [ True,  True,  True,  True, False,  True],
         [ True,  True,  True,  True,  True, False]]),
 tensor([ 1.4873,  0.9007, -2.1055,  0.6784, -1.2345, -0.0431, -0.7521,  1.6487,
         -0.3925, -1.4036, -0.7279, -0.5594,  0.7624,  1.6423, -0.1596, -0.4974,
          0.4396,  0.3189,  0.3057, -0.7746,  0.0349,  0.3211,  1.5736, -0.8455,
          2.1228, -1.2347, -0.4879, -1.4181, 