# Siamese Network



---

In this session, we are going to implement a Siamese Network.

It takes as input two augmented versions (views) of the same image and outputs two feature vectors, one for each view.

For simplicity, as in the SimCLR paper, we will use the same backbone to process both views.



In [None]:
import numpy as np

import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim

from torchvision.io import read_image

import torchvision
import torchvision.models as models
import torchvision.transforms as transforms

In [None]:
# Encoder shared by the Siamese models below: a torchvision ResNet-18.
backbone = models.resnet18()

# Drop the classification head. With `fc` swapped for an identity, the network
# stops after its convolutional layers + pooling/flatten and returns the
# 512-d feature vector rather than class scores.
backbone.fc = nn.Identity()

print(backbone)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [None]:
class SiameseNetSIM(nn.Module):
    """Siamese network whose two views go through one shared encoder.

    Both views are concatenated along dim 0 and encoded in a single forward
    call. The first ``len(x1)`` output rows are the features of ``x1``; the
    remaining rows are those of ``x2``. Because it is a single call, layers such
    as BatchNorm see both views in the same batch.

    Args:
        backbone: encoder module applied to every image.
    """

    def __init__(self, backbone):
        super().__init__()
        self.encoder = backbone

    def forward(self, x1, x2=None):
        """Encode the views and return their features stacked along dim 0.

        Args:
            x1: batch of first views.
            x2: batch of second views. When omitted, ``x1`` is assumed to
                already contain both views concatenated along dim 0 and is
                encoded as is.

        Returns:
            ``self.encoder(torch.cat((x1, x2), 0))``, or ``self.encoder(x1)``
            when ``x2`` is None.
        """
        images = x1 if x2 is None else torch.cat((x1, x2), 0)
        return self.encoder(images)

class SiameseNetASIM(nn.Module):
    """Siamese network with a separate encoder for each view.

    Args:
        backbone: encoder applied to the first view.
        backbone2: encoder applied to the second view.
    """

    def __init__(self, backbone, backbone2):
        super().__init__()
        self.encoder, self.encoder2 = backbone, backbone2

    def forward(self, x1, x2):
        """Encode each view with its own encoder and stack the results on dim 0.

        Returns:
            Tensor whose leading rows are ``self.encoder(x1)`` followed by
            ``self.encoder2(x2)``.
        """
        per_view = [self.encoder(x1), self.encoder2(x2)]
        return torch.cat(per_view, 0)


# Check output shape: each variant maps two batches of 5 images (3x32x32) to a
# single feature tensor with both views stacked along dim 0.
# Only the shape is inspected, so skip building an autograd graph.
with torch.no_grad():
    features_sim = SiameseNetSIM(backbone)(torch.randn(5, 3, 32, 32), torch.randn(5, 3, 32, 32))
print(features_sim.shape)

# A second, separately constructed ResNet-18 (classification head removed)
# for the asymmetric variant.
backbone2 = models.resnet18()
backbone2.fc = nn.Identity()
with torch.no_grad():
    features_asim = SiameseNetASIM(backbone, backbone2)(torch.randn(5, 3, 32, 32), torch.randn(5, 3, 32, 32))
print(features_asim.shape)

torch.Size([10, 512])
torch.Size([10, 512])


Let's now reuse the Dataset from the [past lab session](https://colab.research.google.com/drive/1NJwAFbRiD4MdwWf__6P2Lm0xYk_DNdVu?usp=sharing), which creates two augmented views of each image, and write a loop that runs the forward pass.

In [None]:
class CustomImageDataset(Dataset):
    """Dataset returning two separately transformed views of each image.

    Args:
        data: indexable collection of images. An element that is a ``str`` is
            treated as a file path and loaded with ``read_image``.
        targets: labels aligned with ``data``; its length is the dataset size.
        transform: callable applied once per view. When falsy, both returned
            views are the unmodified image object itself.
        target_transform: optional callable applied to the label.

    NOTE(review): ``read_image`` returns a tensor, whereas in-memory entries
    may be e.g. numpy arrays; a transform that starts with
    ``transforms.ToTensor`` may reject tensor input — confirm path inputs are
    compatible with the transform in use.
    """

    def __init__(self, data, targets, transform=None, target_transform=None):
        self.imgs, self.targets = data, targets
        self.transform, self.target_transform = transform, target_transform

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        source = self.imgs[idx]
        if isinstance(source, str):
            source = read_image(source)

        label = self.targets[idx]

        if not self.transform:
            view_a = view_b = source
        else:
            # Two separate calls, so random augmentations differ per view.
            view_a = self.transform(source)
            view_b = self.transform(source)

        if self.target_transform:
            label = self.target_transform(label)
        return view_a, view_b, label


data = torchvision.datasets.CIFAR10(root='./data', train=True, download=True)
size = 32  # side length produced by RandomResizedCrop

# SimCLR data-augmentation pipeline; `s` scales the colour-jitter strength.
s = 1
color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)

# Blur kernel of ~10% of the image side. GaussianBlur requires an odd, positive
# kernel size, so an even value (or 0) is bumped to the next odd number;
# for size=32 this stays int(3.2) = 3.
blur_kernel_size = int(0.1 * size)
if blur_kernel_size % 2 == 0:
    blur_kernel_size += 1

# NOTE(review): SimCLR reportedly applies blur with p=0.5 (and omits it for
# CIFAR-10); here it is always applied — confirm this is intended.
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.RandomResizedCrop(size=size),
                                transforms.RandomHorizontalFlip(),
                                transforms.RandomApply([color_jitter], p=0.8),
                                transforms.RandomGrayscale(p=0.2),
                                transforms.GaussianBlur(kernel_size=blur_kernel_size)])

# create training set from CustomDataset
trainset = CustomImageDataset(data.data, data.targets, transform=transform)

Files already downloaded and verified


In [None]:
dataloader = DataLoader(trainset, batch_size=64, shuffle=True)

model = SiameseNetSIM(backbone)

# Run the forward pass on the first 4 batches and print the shapes.
# The batch is unpacked in the loop header instead of being bound to `data`,
# which would clobber the CIFAR-10 dataset object of that name.
# no_grad: nothing is backpropagated, so don't build an autograd graph.
with torch.no_grad():
    for idx, (views1, views2, targets) in enumerate(dataloader):
        print(views1.shape)
        print(views2.shape)

        output = model(views1, views2)
        print(output.shape)

        if idx == 3:
            break

torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
torch.Size([128, 512])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
torch.Size([128, 512])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
torch.Size([128, 512])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
torch.Size([128, 512])


In [None]:
model = SiameseNetASIM(backbone, backbone2)

# Run the forward pass on the first 4 batches and print the shapes.
# The batch is unpacked in the loop header instead of being bound to `data`,
# which would clobber the CIFAR-10 dataset object of that name.
# no_grad: nothing is backpropagated, so don't build an autograd graph.
with torch.no_grad():
    for idx, (views1, views2, targets) in enumerate(dataloader):
        print(views1.shape)
        print(views2.shape)

        output = model(views1, views2)
        print(output.shape)

        if idx == 3:
            break

torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
torch.Size([128, 512])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
torch.Size([128, 512])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
torch.Size([128, 512])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
torch.Size([128, 512])
