# Siamese Neural Network

> An architecture for comparing pairs of inputs

In [None]:
#| default_exp siamese

In [None]:
#| hide
from nbdev.showdoc import *
from similarity_learning.utils import *

In [None]:
#| hide
# Helper imported from similarity_learning.utils (see the import cell above); presumably
# patches notebook widget/progress-bar rendering for docs export — confirm against its definition.
fix_notebook_widgets()

In [None]:
#| export
from torch import nn
from torch.nn.functional import normalize

from fastai.vision.all import *


def normalized_squared_euclidean_distance(x1, x2):
    r"""
    Squared Euclidean distance over normalized vectors:
    $$\| x_1/\|x_1\|-x_2/\|x_2\| \|^2 $$

    Both inputs are L2-normalized along their last dimension before the difference is
    taken, so every result lies in [0, 4]; for non-zero inputs it equals
    `2 - 2 * cos(x1, x2)`.

    Args:
        x1: a single vector (1-D) or a batch of row vectors (2-D).
        x2: same as `x1`; the two are combined with ordinary tensor broadcasting.

    Returns:
        Tensor of squared distances, reduced over the last dimension
        (a scalar for 1-D inputs, one value per row for 2-D inputs).
    """
    # Raw docstring above: `\|` in a normal string literal is an invalid escape sequence
    # and triggers a SyntaxWarning (DeprecationWarning before Python 3.12).
    assert x1.dim() <= 2, f"x1 must have at most 2 dimensions, got {x1.dim()}"
    assert x2.dim() <= 2, f"x2 must have at most 2 dimensions, got {x2.dim()}"
    x1 = normalize(x1, dim=-1)
    x2 = normalize(x2, dim=-1)
    return (x1 - x2).pow(2).sum(dim=-1)


class ContrastiveLoss(BaseLoss):
    """Contrastive loss over pair distances, built on `nn.HingeEmbeddingLoss`.

    The model output is treated as the distance between the two members of a pair.
    Targets are 1 for matching pairs and 0 for non-matching pairs; they are remapped to
    the 1/-1 convention of `nn.HingeEmbeddingLoss`, which penalizes matching pairs by
    their distance and non-matching pairs by `max(0, margin - distance)`.
    Keyword arguments (e.g. `margin`) are forwarded to `nn.HingeEmbeddingLoss`.
    """
    @delegates(nn.HingeEmbeddingLoss)
    def __init__(self, **kwargs):
        super().__init__(loss_cls=nn.HingeEmbeddingLoss, **kwargs)

    def __call__(self, input, target):
        # hinge_embedding_loss expects targets to be 1 or -1: map 1 -> 1 and 0 -> -1
        signed_target = 2*target - 1
        return super().__call__(input, signed_target)


class DistanceSiamese(Module):
    """Outputs the distance between two inputs in feature space.

    Both inputs go through the same (shared-weight) `backbone`; each embedding is
    flattened to one row per sample and the pair is compared with `distance_metric`.
    """
    # NOTE: no explicit super().__init__() — fastai's `Module` runs nn.Module.__init__
    # before this __init__ body executes.
    def __init__(self,
                 backbone: Module,  # shared network that embeds each input in a feature space
                 distance_metric = normalized_squared_euclidean_distance):  # callable(f1, f2) -> distances
        self.backbone = backbone
        self.distance_metric = distance_metric

    def forward(self, x):
        # `x` must hold exactly two inputs (e.g. a tuple of two batches); unpacking fails otherwise
        flatten = nn.Flatten()
        first, second = L(x).map(lambda inp: flatten(self.backbone(inp)))
        return self.distance_metric(first, second)


In [None]:
from fastai_datasets.all import *

In [None]:
# Pretrained ResNet-34; `cut=-1` keeps every child module except the last one (the
# classification head), so the body outputs features that DistanceSiamese flattens.
classifier = resnet34(weights=ResNet34_Weights.DEFAULT)
# `default_device()` picks the GPU when one is available (same as `.cuda()` there) and
# falls back to CPU instead of crashing; it is also the DataLoaders' default device.
siamese = DistanceSiamese(create_body(model=classifier, cut=-1)).to(default_device())

In [None]:
# NOTE(review): `.1` presumably selects ~10% of the dataset to build pairs from (the output
# below reports 473 train / 196 valid pairs of each kind) — confirm against Pairs' signature.
source = Imagenette(160)
pairs = Pairs(source, .1)
dls = pairs.dls(
    after_item=Resize(128),
    after_batch=Normalize.from_stats(*imagenet_stats),
)

Class map: scanning targets: 0it [00:00, ?it/s]

Generating positive pairs:   0%|          | 0/473 [00:00<?, ?it/s]

Generating negative pairs:   0%|          | 0/473 [00:00<?, ?it/s]

Class map: scanning targets: 0it [00:00, ?it/s]

Generating positive pairs:   0%|          | 0/196 [00:00<?, ?it/s]

Generating negative pairs:   0%|          | 0/196 [00:00<?, ?it/s]

With a pretrained backbone, positive pairs are already closer in feature space than negative pairs, even before any training:

In [None]:
import torch

x, y = dls.one_batch()
positive_pairs = x[0][y==1], x[1][y==1]
negative_pairs = x[0][y==0], x[1][y==0]
# The freshly built model is still in train mode: probing it there would normalize with
# batch statistics and overwrite the pretrained BatchNorm running stats before training.
# Switch to eval mode and skip autograd, since no gradients are needed for this probe.
siamese.eval()
with torch.no_grad():
    positive_distance = siamese(positive_pairs).mean().item()
    negative_distance = siamese(negative_pairs).mean().item()
positive_distance, negative_distance

(0.9166696071624756, 1.1699669361114502)

Train with contrastive loss:

In [None]:
# Normalized squared distances lie in [0, 4]; negative pairs stop contributing
# to the loss once they are farther apart than the margin.
contrastive = ContrastiveLoss(margin=1.5)
learn = Learner(dls, siamese, loss_func=contrastive)
learn.fit_one_cycle(3)

epoch,train_loss,valid_loss,time
0,0.441342,0.366875,00:09
1,0.355776,0.414326,00:09
2,0.301716,0.283433,00:09


In [None]:
import torch

x, y = dls.one_batch()
positive_pairs = x[0][y==1], x[1][y==1]
negative_pairs = x[0][y==0], x[1][y==0]
# Set eval mode explicitly (instead of relying on the mode training left the model in)
# and skip autograd, so this probe is a pure, side-effect-free measurement.
siamese.eval()
with torch.no_grad():
    positive_distance = siamese(positive_pairs).mean().item()
    negative_distance = siamese(negative_pairs).mean().item()
positive_distance, negative_distance

(0.2568347156047821, 1.5866713523864746)

In [None]:
#| hide
# Write the `#| export` cells of this notebook out to the library module
import nbdev
nbdev.nbdev_export()