# Siamese Neural Networks for One-shot Image Recognition

In [1]:
from torchvision import datasets, transforms
from torch.utils.data import Dataset
import random
import torch

In [2]:
from torchvision import datasets, transforms
from torch.utils.data import Dataset
import torch
import random

class SiameseMNIST(Dataset):
    """MNIST wrapper that serves image pairs for Siamese training.

    Each item is a tuple ``(img1, img2, target)``. Both images get a leading
    channel axis, are cast to float and divided by 255. ``target`` is a
    float32 scalar: 0.0 when the two labels match, 1.0 when they differ.
    """

    def __init__(self, train=True):
        self.train = train
        self.mnist = datasets.MNIST(
            root="./data",
            train=train,
            download=True,
            transform=transforms.ToTensor(),
        )
        # Pairs are built straight from the wrapped dataset's `.data` and
        # `.targets` tensors rather than through its own __getitem__.
        self.data = self.mnist.data
        self.targets = self.mnist.targets

    def __len__(self):
        """Return the number of anchor images available."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the image at ``index`` paired with a randomly drawn partner."""
        anchor_label = self.targets[index]

        # Coin flip: truthy -> draw a same-class partner, falsy -> a
        # different-class partner.
        want_same = random.randint(0, 1)

        # Rejection sampling: keep drawing until the candidate satisfies the
        # chosen relation (a same-class partner must not be the anchor itself).
        last = len(self.data) - 1
        while True:
            partner = random.randint(0, last)
            same_class = self.targets[partner] == anchor_label
            if want_same:
                if same_class and partner != index:
                    break
            elif not same_class:
                break

        first, second = (
            self.data[i].unsqueeze(0).float() / 255.0 for i in (index, partner)
        )

        # Similar = 0, Dissimilar = 1
        differs = int(anchor_label != self.targets[partner])
        target = torch.tensor(differs, dtype=torch.float32)

        return first, second, target


In [3]:
# Siamese Network
import torch.nn as nn
import torch.nn.functional as F

In [4]:
class SiameseNet(nn.Module):
    """Twin-branch CNN: both inputs pass through the same weights.

    A call returns one 128-dimensional embedding per input image.
    """

    def __init__(self):
        super().__init__()

        # Feature extractor: two conv -> ReLU -> 2x2 max-pool stages.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )

        # 32 channels of 5x5 is what the conv stack yields for a 28x28 input
        # (28 -> 26 -> 13 -> 11 -> 5).
        flat_features = 32 * 5 * 5
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=flat_features, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=128),
        )

    def forward_once(self, x):
        """Embed one batch of images with the shared branch."""
        return self.fc(self.conv(x))

    def forward(self, img1, img2):
        """Embed both batches with the same weights and return the pair."""
        return self.forward_once(img1), self.forward_once(img2)

In [5]:
class ContrastiveLoss(nn.Module):
    """Contrastive loss over pairs of embeddings.

    With ``label == 0`` a pair is penalised by its squared distance; with
    ``label == 1`` it is penalised by the squared shortfall of its distance
    below ``margin`` (zero once the distance reaches the margin).
    """

    def __init__(self, margin=2.0):
        super().__init__()
        self.margin = margin

    def forward(self, feat1, feat2, label):
        """Return the mean contrastive loss over the batch."""
        dist = F.pairwise_distance(feat1, feat2)

        # Term active for label 0: pull the embeddings together.
        pull = dist.pow(2)
        # Term active for label 1: push apart until `margin` is reached.
        push = (self.margin - dist).clamp(min=0.0).pow(2)

        per_pair = (1 - label) * pull + label * push
        return per_pair.mean()


In [6]:
from torch.utils.data import DataLoader
# Use the GPU when one is visible, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Shuffled mini-batches of 64 pairs drawn from the training split.
dataset = SiameseMNIST(train=True)
loader = DataLoader(dataset, batch_size=64, shuffle=True)

model = SiameseNet().to(device)
criterion = ContrastiveLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for epoch in range(10):
    total_loss = 0
    for batch in loader:
        # Each batch is (first images, second images, pair labels).
        img1, img2, label = (t.to(device) for t in batch)

        # Clear stale gradients, then forward / backward / update.
        optimizer.zero_grad()
        feat1, feat2 = model(img1, img2)
        loss = criterion(feat1, feat2, label)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Report the mean per-batch loss for this epoch.
    print(f"Epoch {epoch+1} | Loss: {total_loss/len(loader):.4f}")

Epoch 1 | Loss: 0.2333
Epoch 2 | Loss: 0.1071
Epoch 3 | Loss: 0.0822
Epoch 4 | Loss: 0.0680
Epoch 5 | Loss: 0.0546
Epoch 6 | Loss: 0.0482
Epoch 7 | Loss: 0.0429
Epoch 8 | Loss: 0.0361
Epoch 9 | Loss: 0.0331
Epoch 10 | Loss: 0.0310


In [7]:
def is_same(model, img1, img2, threshold=1.0):
    """Decide whether two images are closer than ``threshold`` in embedding space.

    Args:
        model: Module that maps a pair of batched inputs to a pair of
            embeddings. It is switched to eval mode and left that way.
        img1: First image as a single, unbatched tensor; a leading batch
            dimension is added before the forward pass.
        img2: Second image, same layout as ``img1``.
        threshold: Distance below which the pair counts as the same.

    Returns:
        bool: ``True`` when the pairwise distance between the two embeddings
        is strictly less than ``threshold``.
    """
    model.eval()

    # Run the inputs on whatever device the model's weights live on, so the
    # helper does not depend on a notebook-level `device` variable. A model
    # without parameters leaves the inputs where they already are.
    first_param = next(model.parameters(), None)
    target = first_param.device if first_param is not None else img1.device

    with torch.no_grad():
        f1, f2 = model(img1.unsqueeze(0).to(target), img2.unsqueeze(0).to(target))
        # .item() must be *called* to turn the 1-element tensor into a float.
        dist = F.pairwise_distance(f1, f2).item()
    return dist < threshold