<a href="https://colab.research.google.com/github/AlessioChen/Computer-Vision-Class/blob/main/SSL_Siamese_Network.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Siamese Network



---

In this session, we are going to implement a Siamese Network.

It takes two augmented versions (views) of the same image as input and outputs two feature vectors, one for each view.

For simplicity, as in the SimCLR paper, we use the same backbone to process both views.



In [None]:
import copy

import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

import torchvision
import torchvision.models as models
import torchvision.transforms as transforms

In [None]:
# Backbone: a ResNet-18 built without requesting pretrained weights.
# Its classification head is swapped for an identity, so the network returns
# the flattened pooled CNN features (512-d for ResNet-18) instead of class logits.
backbone = models.resnet18()
backbone.fc = nn.Identity()


In [None]:
class SiameseNetSymmetric(nn.Module):
    def __init__(self, backbone):
        super().__init__()

        self.encoder = backbone
        self.encoder.fc = torch.nn.Identity()


    def forward(self, x1, x2):
        x = torch.concat((x1 , x2), dim=0)
        x = self.encoder(x)

        f1 = x[:x.shape[0]]
        f2 = x[x.shape[0]:]

        return f1, f2


# Sanity check: push two random batches of 5 CIFAR-sized images (3x32x32)
# through the network and inspect the feature shape returned for each view.
model = SiameseNetSymmetric(backbone)
f1, f2 = model(
    torch.randn(5, 3, 32, 32),
    torch.randn(5, 3, 32, 32),
)
print(f1.shape, f2.shape)

torch.Size([10, 512]) torch.Size([0, 512])


In [None]:
class SiameseNetAsymmetric(nn.Module):
    def __init__(self, backbone):
        super().__init__()
        self.encoder1 = backbone
        self.encoder1.fc = torch.nn.Identity()

        self.encoder2 = backbone
        self.encoder2.fc = torch.nn.Identity()

    def forward(self, x1, x2):

        f1 = self.encoder1(x1)
        f2 = self.encoder2(x2)

        return f1, f2


# Sanity check: push two random batches of 5 CIFAR-sized images (3x32x32)
# through the two-encoder network and inspect the feature shape of each view.
model = SiameseNetAsymmetric(backbone)
f1, f2 = model(
    torch.randn(5, 3, 32, 32),
    torch.randn(5, 3, 32, 32),
)
print(f1.shape, f2.shape)

torch.Size([5, 512]) torch.Size([5, 512])


Let's now reuse the Dataset from the [past lab session](https://colab.research.google.com/drive/1NJwAFbRiD4MdwWf__6P2Lm0xYk_DNdVu?usp=sharing), which creates two augmented views of each image, and write a loop that runs the forward pass.

In [None]:
class SSLDataset(Dataset):
    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

        self.imgs = dataset.data
        self.targets = dataset.targets

        self.id = 0

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        image, label = self.dataset[idx]

        if self.transform:
          img1 = self.transform(image)
          img2 = self.transform(image)

        else:
          img1 = image
          img2 = image

        label1 = label2 = self.id

        self.id += 1
        # same id = positive pair
        return img1, img2, label1, label2

In [None]:
dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True)

# SimCLR data-augmentation pipeline (Chen et al., 2020)
s = 1  # colour-jitter strength: scales every ColorJitter factor below
size = 32  # output crop size; CIFAR-10 images are 32x32
color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
# The blur kernel is ~10% of the image side. GaussianBlur requires a positive,
# odd kernel size, so round up to the next odd integer (3 for size=32).
blur_kernel = max(1, int(0.1 * size)) | 1
transform = transforms.Compose([transforms.RandomResizedCrop(size=size),
                                transforms.RandomHorizontalFlip(),
                                transforms.RandomApply([color_jitter], p=0.8),
                                transforms.RandomGrayscale(p=0.2),
                                transforms.GaussianBlur(kernel_size=blur_kernel),
                                transforms.ToTensor()])

# Wrap CIFAR-10 so every item yields two augmented views of the same image.
trainset = SSLDataset(dataset=dataset, transform=transform)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:02<00:00, 80.5MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data


## Contrastive loss

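Goal: for each view, maximise its similarity with its positive (the other view of the same image) relative to its similarities with all the negatives (views of the other images). Equivalently, minimise the negative log-likelihood of the positive pair under a softmax over all similarities (the InfoNCE / NT-Xent loss).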

In [None]:
from torch.nn import functional as F
class ContrastiveLoss(nn.Module):
    """NT-Xent / InfoNCE contrastive loss as used in SimCLR.

    ``forward`` expects ``features = torch.cat((f1, f2), dim=0)`` with shape
    ``(2B, D)``. Row ``i`` and row ``(i + B) % 2B`` are the two views of the
    same image (the positive pair); every other row is a negative.

    Args:
        temperature: value the cosine similarities are divided by before the
            softmax/cross-entropy.
    """

    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature
        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, features):
        """Return the mean contrastive loss over all ``2B`` anchor rows.

        Raises:
            ValueError: if ``features`` is empty or has an odd number of rows.
        """
        n = features.shape[0]  # 2 * batch size
        if n == 0 or n % 2 != 0:
            raise ValueError(
                "features must stack two views per image "
                f"(a non-zero, even number of rows), got {n}"
            )
        # Integer division: this value is used as an index offset below.
        batch_size = n // 2
        device = features.device

        # After L2 normalisation the dot product equals the cosine similarity.
        features = F.normalize(features, dim=1)
        similarity_matrix = features @ features.T  # (n, n)

        rows = torch.arange(n, device=device)
        pos_idx = (rows + batch_size) % n  # the other view of each anchor

        # Logits per row: [positive, negatives...] -> shape (n, n - 1).
        positives = similarity_matrix[rows, pos_idx].unsqueeze(1)
        keep = torch.ones(n, n, dtype=torch.bool, device=device)
        keep[rows, rows] = False     # drop the self-similarity
        keep[rows, pos_idx] = False  # drop the positive (already in column 0)
        # Boolean masking keeps the computation in the autograd graph.
        negatives = similarity_matrix[keep].view(n, n - 2)

        logits = torch.cat([positives, negatives], dim=1) / self.temperature

        # The positive sits in column 0 of every row, so the target class is 0.
        # N.B.: this is just one of the possible implementations!
        gt = torch.zeros(n, dtype=torch.long, device=device)
        return self.criterion(logits, gt)

In [None]:

# Every one of the 2*B stacked views gets the index of its source image:
# [0, 1, ..., B-1, 0, 1, ..., B-1] (here B = 5).
labels = torch.arange(5).repeat(2)
print(labels)
# Pairwise "same source image" matrix: 1 on the main diagonal (identity) and
# on the two off-diagonals shifted by B (the positive pairs), 0 elsewhere.
labels = (labels[None, :] == labels[:, None]).float()
print(labels)

tensor([0, 1, 2, 3, 4, 0, 1, 2, 3, 4])
tensor([[1., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 1.]])


In [None]:
from torch.nn import functional as F
class ContrastiveLoss2(nn.Module):
    """NT-Xent / InfoNCE contrastive loss, vectorised with boolean masks.

    ``forward`` expects ``features = torch.cat((f1, f2), dim=0)`` with shape
    ``(2B, D)``. Row ``i`` and row ``i + B`` (mod ``2B``) are the two views of
    the same image (the positive pair); every other row is a negative.

    Args:
        temperature: value the cosine similarities are divided by before the
            softmax/cross-entropy.
    """

    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature
        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, features):
        """Return the mean contrastive loss over all ``2B`` anchor rows.

        Raises:
            ValueError: if ``features`` is empty or has an odd number of rows.
        """
        n = features.shape[0]
        if n == 0 or n % 2 != 0:
            raise ValueError(
                "features must stack two views per image "
                f"(a non-zero, even number of rows), got {n}"
            )
        batch_size = n // 2
        device = features.device

        # After L2 normalisation the dot product equals the cosine similarity.
        features = F.normalize(features, dim=1)
        similarity_matrix = torch.matmul(features, features.T)

        # Source-image index of every row: [0..B-1, 0..B-1]. Equal indices mark
        # views of the same image.
        labels = torch.arange(batch_size, device=device).repeat(2)
        labels = labels.unsqueeze(0) == labels.unsqueeze(1)

        # Remove the diagonal (self-similarities) from both matrices.
        mask = torch.eye(n, dtype=torch.bool, device=device)
        labels = labels[~mask].view(n, -1)
        similarity_matrix = similarity_matrix[~mask].view(n, -1)

        positives = similarity_matrix[labels].view(n, -1)   # (n, 1)
        negatives = similarity_matrix[~labels].view(n, -1)  # (n, n - 2)

        # torch.cat takes a *sequence* of tensors as its first argument.
        logits = torch.cat([positives, negatives], dim=1) / self.temperature
        gt = torch.zeros(n, dtype=torch.long, device=device)  # positive in first position
        return self.criterion(logits, gt)

In [None]:

dataloader = DataLoader(trainset, batch_size=64, shuffle=True)

model = SiameseNetSymmetric(backbone)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = ContrastiveLoss()

# Short smoke-test training run: take a few optimisation steps, then stop.
max_steps = 4
model.train()
for idx, (img1, img2, _, _) in enumerate(dataloader):
    f1, f2 = model(img1, img2)

    # Stack both views as [view-1 batch; view-2 batch], the layout the loss expects.
    features = torch.cat((f1, f2), dim=0)
    loss = criterion(features)

    # backprop
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(f"step {idx}: loss = {loss.item():.4f}")

    if idx + 1 >= max_steps:
        break

IndexError: only integers, slices (`:`), ellipsis (`...`), None and long or byte Variables are valid indices (got float)