In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import MNIST
from torch import optim

In [16]:
# Load (downloading on first use) the MNIST training split, then the test
# split, both stored under ./data.
mnist_train, mnist_test = (
    MNIST(root='./data', train=is_train_split, download=True)
    for is_train_split in (True, False)
)

# Shared preprocessing pipeline; its only step is the ToTensor conversion.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
    ]
)

In [17]:
class SiameseDataset(Dataset):
    """Serve random image pairs for Siamese (contrastive) training.

    Wraps a sized, indexable dataset of ``(image, label)`` items. Item
    ``index`` pairs the image stored at ``index`` with a randomly drawn
    partner that has, with equal probability, the same label or a different
    label.

    Labels are indexed once at construction so partner selection never has to
    load candidate images just to inspect (and usually reject) their labels.

    Args:
        data: Sized, indexable collection whose items are ``(image, label)``.
        transform: Optional callable applied to both images of a pair.
    """

    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

        # Label of every item (by position) and the positions of each label.
        self._labels = self._collect_labels(data)
        self._indices_by_label = {}
        for position, label in enumerate(self._labels):
            self._indices_by_label.setdefault(label, []).append(position)

    @staticmethod
    def _as_python_label(label):
        """Turn tensor/NumPy scalar labels into plain, hashable Python values."""
        return label.item() if hasattr(label, "item") else label

    @classmethod
    def _collect_labels(cls, data):
        """Return the label of every item in ``data``, in index order."""
        size = len(data)

        # Fast path: datasets that expose their labels as a ``targets``
        # sequence (as torchvision's MNIST does) can be indexed without
        # loading any image. It is only trusted when no ``target_transform``
        # is set, the length matches, and the first entry agrees with what
        # ``data[0]`` actually returns.
        targets = getattr(data, "targets", None)
        if size and targets is not None and getattr(data, "target_transform", None) is None:
            raw = targets.tolist() if hasattr(targets, "tolist") else list(targets)
            if len(raw) == size and raw[0] == cls._as_python_label(data[0][1]):
                return [cls._as_python_label(label) for label in raw]

        # Generic path: read every item once.
        return [cls._as_python_label(data[i][1]) for i in range(size)]

    def __len__(self):
        """Number of anchor items, i.e. the length of the wrapped dataset."""
        return len(self.data)

    def __getitem__(self, index):
        """Build one training pair anchored at ``index``.

        Returns:
            ``(imgA, imgB, pair_label)`` where ``pair_label`` is a float32
            tensor of shape ``(1,)`` holding ``0.0`` when both images share a
            label and ``1.0`` when their labels differ.
        """
        imgA, labelA = self.data[index]
        labelA = self._as_python_label(labelA)

        # With a single class there is nothing to draw a negative from;
        # fall back to a same-class pair instead of searching forever.
        has_other_class = len(self._indices_by_label) > 1
        same_class_flag = bool(random.randint(0, 1)) or not has_other_class

        if same_class_flag:
            # Uniform over all items sharing the anchor's label
            # (the anchor itself may be drawn).
            partner = random.choice(self._indices_by_label[labelA])
        else:
            # Uniform over all items with a different label. Rejection only
            # touches the cached labels, never the images.
            total = len(self._labels)
            partner = random.randrange(total)
            while self._labels[partner] == labelA:
                partner = random.randrange(total)

        imgB, _ = self.data[partner]

        if self.transform:
            imgA = self.transform(imgA)
            imgB = self.transform(imgB)

        pair_label = torch.tensor([not same_class_flag], dtype=torch.float32)

        return imgA, imgB, pair_label

In [18]:
# Wrap the train split, then the test split, as pair datasets that share the
# same preprocessing transform.
siamese_train, siamese_test = (
    SiameseDataset(split, transform=transform)
    for split in (mnist_train, mnist_test)
)

In [19]:
class SiameseNetwork(nn.Module):
    """Twin network that embeds two image batches with shared weights.

    ``cnn`` stacks three conv -> ReLU -> 2x2 max-pool stages (1 -> 64 -> 128
    -> 256 channels); ``fc`` maps the flattened features through
    1024 -> 256 -> 2 units. The first linear layer expects 256 * 3 * 3
    features, which is what single-channel 28x28 inputs produce
    (28 -> 14 -> 7 -> 3 after the three pools).
    """

    def __init__(self):
        super().__init__()

        # One row per convolutional stage:
        # (in_channels, out_channels, kernel_size, padding).
        conv_specs = [
            (1, 64, 5, 2),
            (64, 128, 5, 2),
            (128, 256, 3, 1),
        ]
        conv_layers = []
        for in_channels, out_channels, kernel, pad in conv_specs:
            conv_layers += [
                nn.Conv2d(in_channels, out_channels, kernel_size=kernel, stride=1, padding=pad),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2, stride=2),
            ]
        self.cnn = nn.Sequential(*conv_layers)

        # Fully connected head: a ReLU sits between consecutive linear
        # layers, and the final 2-unit layer has no activation.
        widths = [256 * 3 * 3, 1024, 256, 2]
        fc_layers = []
        for position, (n_in, n_out) in enumerate(zip(widths, widths[1:])):
            if position > 0:
                fc_layers.append(nn.ReLU(inplace=True))
            fc_layers.append(nn.Linear(n_in, n_out))
        self.fc = nn.Sequential(*fc_layers)

    def forward_once(self, x):
        """Embed a single batch: conv features, flatten per sample, then ``fc``."""
        features = self.cnn(x)
        flat = features.view(features.size(0), -1)
        return self.fc(flat)

    def forward(self, inputA, inputB):
        """Embed both batches with the same weights and return ``(outputA, outputB)``."""
        embeddingA = self.forward_once(inputA)
        embeddingB = self.forward_once(inputB)
        return embeddingA, embeddingB

In [20]:
class ContrastiveLoss(torch.nn.Module):
    """Contrastive loss over pairs of embeddings.

    For a pair with Euclidean distance ``d`` and label ``y`` (``0`` = same
    class, ``1`` = different class) the per-pair loss is

        (1 - y) * d**2 + y * max(margin - d, 0)**2

    and the result is the mean over all pairs.

    Args:
        margin: Distance beyond which different-class pairs stop contributing.
    """

    def __init__(self, margin=2.0):
        super().__init__()
        self.margin = margin

    def forward(self, outputA, outputB, y):
        """Return the mean contrastive loss for a batch of embedding pairs.

        Args:
            outputA: Embeddings of the first images, one row per pair.
            outputB: Embeddings of the second images, same shape as ``outputA``.
            y: Pair labels (0 = same class, 1 = different class), one per
                pair. Both ``(batch, 1)`` and ``(batch,)`` layouts are
                accepted, as are boolean/integer labels.

        Returns:
            Scalar tensor holding the batch-mean loss.
        """
        euclidean_distance = F.pairwise_distance(outputA, outputB, keepdim=True)

        # The distance keeps a trailing dimension (keepdim=True). A label
        # tensor with one entry per pair but a different layout, e.g.
        # (batch,), would otherwise broadcast to (batch, batch) and silently
        # average over every anchor/label combination.
        if y.shape != euclidean_distance.shape and y.numel() == euclidean_distance.numel():
            y = y.reshape(euclidean_distance.shape)

        # `1 - y` is not defined for boolean tensors; compute in the
        # distance's floating-point dtype for any non-float labels.
        if not torch.is_floating_point(y):
            y = y.to(euclidean_distance.dtype)

        same_class_loss = (1 - y) * euclidean_distance ** 2
        diff_class_loss = y * torch.clamp(self.margin - euclidean_distance, min=0.0) ** 2

        return torch.mean(same_class_loss + diff_class_loss)

In [25]:
# Shuffled batches of 64 training pairs, loaded in the main process
# (num_workers=0).
train_dataloader = DataLoader(
    dataset=siamese_train,
    batch_size=64,
    shuffle=True,
    num_workers=0,
)

# Freshly initialised network, contrastive loss with its default margin,
# and an Adam optimizer over all of the network's parameters.
model = SiameseNetwork()
criterion = ContrastiveLoss()
optimizer = optim.Adam(params=model.parameters(), lr=0.001)

In [None]:
# Train for 5 epochs with the contrastive loss.
model.train()

# Run every batch on whatever device the model's parameters live on, so the
# loop works unchanged on CPU or after the model has been moved to a GPU.
device = next(model.parameters()).device

for epoch in range(5):
    # Sum (not mean) of the per-batch losses seen during this epoch.
    total_loss = 0

    for imgA, imgB, label in train_dataloader:
        imgA, imgB, label = imgA.to(device), imgB.to(device), label.to(device)

        optimizer.zero_grad()
        outputA, outputB = model(imgA, imgB)
        loss_contrastive = criterion(outputA, outputB, label)
        loss_contrastive.backward()
        optimizer.step()

        total_loss += loss_contrastive.item()

    print(f"Epoch {epoch}; Loss {total_loss}")