# Contrastive Loss

---
In this session, we are going to implement the SimCLR loss function (https://arxiv.org/abs/2002.05709).

The loss follows InfoNCE: two different augmented views of the same image form the positive pair, while the views of all the other images in the batch act as negative samples. The batch construction follows the multi-class N-pair (N-pair-mc) loss.


In [2]:
import os
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from torchvision.io import read_image

import torchvision
import torchvision.models as models
import torchvision.transforms as transforms

from PIL import Image
import time

In [23]:
# Base version with loops
class ContrastiveLossBase(nn.Module):
    """SimCLR / NT-Xent contrastive loss, built row by row with a Python loop.

    Deliberately simple (and slow) reference implementation: every row of the
    logits holds the similarity of the positive pair in column 0, followed by
    the similarities with the 2N - 2 negatives.
    """

    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature
        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, features, batch_size=None):
        """Compute the contrastive loss for a batch of paired views.

        Args:
            features: tensor of shape (2N, D) built as
                ``torch.cat((x1, x2), dim=0)``, so rows i and i + N are the
                two views of the same image.
            batch_size: N, the number of images per view. If None it is
                inferred as ``features.shape[0] // 2``.

        Returns:
            tuple ``(time_taken, loss)``: seconds spent building the logits
            (similarity matrix excluded) and the scalar cross-entropy loss.

        Raises:
            ValueError: if ``features`` does not have a positive, even number
                of rows, or if ``batch_size`` does not match it.
        """
        n_rows = features.shape[0]
        if n_rows < 2 or n_rows % 2:
            raise ValueError(
                f"features must have a positive, even number of rows (2N), got {n_rows}"
            )
        if batch_size is None:
            batch_size = n_rows // 2
        elif 2 * batch_size != n_rows:
            raise ValueError(
                f"batch_size={batch_size} does not match features with {n_rows} rows "
                f"(expected 2 * batch_size rows)"
            )

        # normalize so that the dot product equals the cosine similarity
        features = F.normalize(features, dim=1)
        similarity_matrix = torch.matmul(features, features.T)

        start = time.time()
        # Build the (2N, 2N - 1) logits: column 0 is the positive pair, the other
        # 2N - 2 columns are the negatives. Rows are assembled from slices of
        # similarity_matrix (not copied into a pre-allocated buffer) so autograd
        # can propagate gradients back to the encoder.
        rows = []
        for idx in range(n_rows):
            # the other view of the same image: i + N in the first half, i - N in the second
            pos_idx = idx + batch_size if idx < batch_size else idx - batch_size
            keep = torch.ones(n_rows, dtype=torch.bool, device=similarity_matrix.device)
            keep[idx] = False      # self-similarity
            keep[pos_idx] = False  # positive, already placed in column 0
            row = similarity_matrix[idx]
            rows.append(torch.cat((row[pos_idx:pos_idx + 1], row[keep])))
        logits = torch.stack(rows)
        print(f"Logits shape: {logits.shape}")

        logits = logits / self.temperature

        # The positive similarity is always in column 0, so the cross-entropy
        # target is index 0 for every row (one of several possible layouts).
        gt = torch.zeros(n_rows, dtype=torch.long, device=logits.device)
        end = time.time()
        time_taken = end - start
        return time_taken, self.criterion(logits, gt)

In [26]:
# Optimized Version
class ContrastiveLoss(nn.Module):
    """Vectorized SimCLR / NT-Xent contrastive loss.

    Produces the same logits layout as ``ContrastiveLossBase`` (positive
    similarity in column 0, the 2N - 2 negatives after it) using boolean masks
    instead of a Python loop.
    """

    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature
        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, features, batch_size=64):
        """Compute the contrastive loss for a batch of paired views.

        Args:
            features: tensor of shape (2N, D) built as
                ``torch.cat((x1, x2), dim=0)``, so rows i and i + N are the
                two views of the same image.
            batch_size: unused; N is always inferred from ``features`` (so a
                smaller last batch works). Kept for call compatibility with
                ``ContrastiveLossBase.forward``.

        Returns:
            tuple ``(time_taken, loss)``: seconds spent building the logits
            (similarity matrix excluded) and the scalar cross-entropy loss.

        Raises:
            ValueError: if ``features`` does not have a positive, even number of rows.
        """
        n_rows = features.shape[0]
        if n_rows < 2 or n_rows % 2:
            raise ValueError(
                f"features must have a positive, even number of rows (2N), got {n_rows}"
            )
        # every helper tensor lives on the same device as the features
        device = features.device

        # normalize so that the dot product equals the cosine similarity
        features = F.normalize(features, dim=1)
        similarity_matrix = torch.matmul(features, features.T)

        start = time.time()
        # image id of every row: [0, ..., N-1, 0, ..., N-1]
        ids = torch.arange(n_rows // 2, device=device).repeat(2)
        # same_image[i, j] is True when rows i and j are views of the same image:
        # the main diagonal plus the diagonals of the two off-diagonal N x N blocks
        same_image = ids.unsqueeze(0) == ids.unsqueeze(1)
        # drop the main diagonal (self-similarity): each row keeps 2N - 1 entries
        not_self = ~torch.eye(n_rows, dtype=torch.bool, device=device)
        same_image = same_image[not_self].view(n_rows, -1)
        similarity_matrix = similarity_matrix[not_self].view(n_rows, -1)

        # with two views there is exactly one positive per row
        positives = similarity_matrix[same_image].view(n_rows, -1)
        negatives = similarity_matrix[~same_image].view(n_rows, -1)

        logits = torch.cat((positives, negatives), dim=1) / self.temperature

        # positives sit in column 0, so every cross-entropy target is 0
        gt = torch.zeros(n_rows, dtype=torch.long, device=device)
        end = time.time()
        time_taken = end - start
        return time_taken, self.criterion(logits, gt)

Let's now use the Dataset that creates two augmented views of each image and the Siamese network from the past lab sessions [1](https://colab.research.google.com/drive/1NJwAFbRiD4MdwWf__6P2Lm0xYk_DNdVu?usp=sharing) and [2](https://colab.research.google.com/drive/1AMkh0q8L5nJScx7v6cMWoK336zqOqDY6?usp=sharing), and write a training loop.

In [5]:
class SiameseNetSIM(nn.Module):
    """Siamese network: one shared encoder applied to every augmented view.

    Args:
        backbone: encoder module, e.g. a ResNet whose ``fc`` layer has been
            replaced by ``nn.Identity``.
    """

    def __init__(self, backbone):
        super().__init__()
        self.encoder = backbone

    def forward(self, x, x2=None):
        """Encode a batch of views with the shared encoder.

        Args:
            x: batch of images. When ``x2`` is None, ``x`` is expected to
                already hold both views concatenated along dim 0.
            x2: optional second batch of views. When given, it is concatenated
                after ``x`` so both views go through the encoder in one pass.

        Returns:
            Encoder output for ``torch.cat((x, x2), 0)``, or for ``x`` alone.
        """
        if x2 is not None:
            x = torch.cat((x, x2), 0)
        return self.encoder(x)

In [27]:
class CustomImageDataset(Dataset):
    """Dataset that returns two separately transformed views of each image.

    Args:
        data: indexable collection of images. An element that is a ``str`` is
            treated as a file path and loaded with ``read_image``.
        targets: one label per image; its length defines the dataset length.
        transform: optional callable, applied twice (two separate calls) to
            build the two views. If not set, both views are the raw image.
        target_transform: optional callable applied to the label.
    """

    def __init__(self, data, targets, transform=None, target_transform=None):
        self.imgs = data
        self.targets = targets
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        """Return ``(view1, view2, label)`` for the sample at position ``idx``."""
        image = self.imgs[idx]
        if isinstance(image, str):
            image = read_image(image)

        if self.transform:
            # two independent calls, so a random transform can yield two different views
            first_view = self.transform(image)
            second_view = self.transform(image)
        else:
            first_view = second_view = image

        target = self.targets[idx]
        if self.target_transform:
            target = self.target_transform(target)
        return first_view, second_view, target


data = torchvision.datasets.CIFAR10(root='./data', train=True, download=True)
size = 32
s = 1  # color-jitter strength
color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
# GaussianBlur requires an odd, positive kernel size: take ~10% of the image
# side and force it to be odd (3 for 32x32 CIFAR-10 images, 23 for 224x224).
blur_kernel_size = int(0.1 * size) // 2 * 2 + 1
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.RandomResizedCrop(size=size),
                                transforms.RandomHorizontalFlip(),
                                transforms.RandomApply([color_jitter], p=0.8),
                                transforms.RandomGrayscale(p=0.2),
                                transforms.GaussianBlur(kernel_size=blur_kernel_size)])

Files already downloaded and verified


In [25]:
trainset = CustomImageDataset(data.data, data.targets, transform=transform)
dataloader = DataLoader(trainset, batch_size=64, shuffle=True)

# you can use a resnet18 as backbone
backbone = models.resnet18()

#! remember to delete the fc layer (we need just the CNN layers + flatten)
backbone.fc = nn.Identity()
model = SiameseNetSIM(backbone)
optimizer = optim.Adam(model.parameters())
criterion = ContrastiveLossBase()

# The loop variable is `batch`, not `data`: rebinding `data` would overwrite the
# CIFAR-10 dataset that later cells read through `data.data` / `data.targets`.
for idx, batch in enumerate(dataloader):
    views1, views2, targets = batch
    print(views1.shape)
    print(views2.shape)

    optimizer.zero_grad()
    # both views go through the shared encoder in one pass:
    # rows [0, N) are the first views, rows [N, 2N) the second views
    output = model(torch.cat((views1, views2), 0))
    # pass the actual batch size: the last batch of an epoch can be smaller than 64
    time_taken, loss = criterion(output, views1.shape[0])
    loss.backward()
    optimizer.step()

    print(f"Base loss time: {time_taken:.3f}, loss value: {loss.item():.3f}")

    # only a few iterations: this cell just times the loss implementation
    if idx == 3:
        break

torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
Logits shape: torch.Size([128, 127])
Base loss time: 0.047, loss value: 5.199
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
Logits shape: torch.Size([128, 127])
Base loss time: 0.086, loss value: 5.461
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
Logits shape: torch.Size([128, 127])
Base loss time: 0.047, loss value: 6.197
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
Logits shape: torch.Size([128, 127])
Base loss time: 0.051, loss value: 6.390


In [28]:
trainset = CustomImageDataset(data.data, data.targets, transform=transform)
dataloader = DataLoader(trainset, batch_size=64, shuffle=True)

# you can use a resnet18 as backbone
backbone = models.resnet18()

#! remember to delete the fc layer (we need just the CNN layers + flatten)
backbone.fc = nn.Identity()
model = SiameseNetSIM(backbone)
optimizer = optim.Adam(model.parameters())
criterion = ContrastiveLoss()

# The loop variable is `batch`, not `data`, so the CIFAR-10 dataset bound to
# `data` stays available to any cell that runs after this one.
for idx, batch in enumerate(dataloader):
    views1, views2, targets = batch
    print(views1.shape)
    print(views2.shape)

    optimizer.zero_grad()
    # both views go through the shared encoder in one pass:
    # rows [0, N) are the first views, rows [N, 2N) the second views
    output = model(torch.cat((views1, views2), 0))
    # pass the actual batch size: the last batch of an epoch can be smaller than 64
    time_taken, loss = criterion(output, views1.shape[0])
    loss.backward()
    optimizer.step()

    print(f"Optimized loss time: {time_taken:.3f}, loss value: {loss.item():.3f}")

    # only a few iterations: this cell just times the loss implementation
    if idx == 3:
        break

torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
Optimized loss time: 0.004, loss value: 5.581
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
Optimized loss time: 0.001, loss value: 5.328
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
Optimized loss time: 0.001, loss value: 5.092
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
Optimized loss time: 0.002, loss value: 5.390


In [None]:
# Doing operations in place improves the timings
# try to find other implementations
# For the assignment: present the loss, then a diagram showing the matrices