## Siamese Network
- Siamese Network is a type of neural network that uses the same weights while working in tandem on two different input vectors to compute comparable output vectors.

In [69]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models

input: 64x3x32x32
->
conv2d(3, 64, 3, 1, 1): 64x64x32x32
->
ReLU: 64x64x32x32
->
MaxPool2d(2, 2): 64x64x16x16

In [70]:
'''
Simple CNN
'''
class FeatureExtractorCNN(nn.Module):
    """Small three-stage CNN that maps an image batch to 128-d embeddings.

    Every stage is a 3x3 convolution (stride 1, padding 1) followed by ReLU
    and 2x2 max pooling. Channel widths go 3 -> 32 -> 64 -> 128. The linear
    head expects 128 * 8 * 8 flattened features, i.e. an 8x8 feature map
    after the three poolings (e.g. a 64x64 input).
    """

    def __init__(self):
        """Create the layers, then initialise their weights.

        Layers:
            conv1: 3 -> 32 channels, kernel 3, stride 1, padding 1
            conv2: 32 -> 64 channels, same kernel/stride/padding
            conv3: 64 -> 128 channels, same kernel/stride/padding
            pool: 2x2 max pooling, shared by all three stages
            fc: linear layer, 128 * 8 * 8 -> 128
        """
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc = nn.Linear(128 * 8 * 8, 128)
        self._intialize_weight_()

    def _intialize_weight_(self):
        """Kaiming-normal init for conv/linear weights; biases set to zero."""
        for layer in self.modules():
            if not isinstance(layer, (nn.Conv2d, nn.Linear)):
                continue
            nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')
            # Linear biases are zeroed unconditionally; conv biases only
            # when the layer actually has one.
            if isinstance(layer, nn.Linear) or layer.bias is not None:
                nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        """Run the three conv stages and project to the embedding.

        Args:
            x: input batch; the first convolution expects 3 channels
        Returns:
            tensor with one 128-element row per batch item
        """
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.pool(F.relu(conv(x)))
        # Collapse everything except the batch dimension before the fc layer.
        batch_size = x.size(0)
        return self.fc(x.view(batch_size, -1))

In [71]:
# resnet = models.resnet18(pretrained=True)
# print(*list(resnet.children())[:-1])

In [72]:
# """
# Use Resnet18 as feature extractor
# """
# class FeatureExtractorResnet18(nn.Module):
#     def __init__(self):
#         super(FeatureExtractorResnet18, self).__init__()
#         """
#         Args:
#             resnet: use resnet18 model
#             feature_extractor: Sequential
#         """
#         resnet = models.resnet18(pretrained=True)
#         self.feature_extractor = nn.Sequential(*list(resnet.children())[:-1]) # drop the final fc layer
#         self.fc = nn.Linear(512, 128)
#     def forward(self,x):
#         x = self.feature_extractor(x)
#         x = x.view(x.size(0), -1)
#         x = self.fc(x)
#         return x

In [73]:
"""
Siamese Network
"""
class SiameseNetwork(nn.Module):
    """Siamese model: a single shared encoder is applied to both inputs."""

    def __init__(self):
        """Create the shared feature extractor and a 128 -> 1 linear layer.

        Attributes:
            feature_extractor: FeatureExtractorCNN used for both inputs
                (another encoder, e.g. a ResNet-based one, could be swapped in)
            fc: linear layer, 128 -> 1
        """
        super().__init__()
        self.feature_extractor = FeatureExtractorCNN()
        # forward() scores pairs with cosine similarity and does not call
        # this layer; it still contributes parameters to the module.
        self.fc = nn.Linear(128, 1)

    def forward(self, input_1, input_2):
        """Embed both inputs with the shared encoder and compare them.

        Args:
            input_1: first input batch
            input_2: second input batch
        Returns:
            cosine similarity between the two embeddings, taken along dim 1
        """
        first_embedding, second_embedding = (
            self.feature_extractor(batch) for batch in (input_1, input_2)
        )
        return F.cosine_similarity(first_embedding, second_embedding, dim=1)

In [74]:
"""
deep metric learning
    - contrastive loss
"""
class ContrastiveLoss(nn.Module):
    """Contrastive-style loss computed from a score tensor and a label tensor."""

    def __init__(self, margin = 1):
        """Store the margin.

        Args:
            margin: value the score is subtracted from in the label-weighted
                hinge term
        """
        super().__init__()
        self.margin = margin

    def forward(self, sim, label):
        """Return the mean of the element-wise loss.

        Per element the loss is
            (1 - label) * sim^2 + label * max(margin - sim, 0)^2

        Args:
            sim: score tensor
            label: tensor weighting the two terms (1 selects the margin term,
                0 selects the squared-score term)
        Returns:
            mean of the element-wise loss
        """
        shortfall = (self.margin - sim).clamp(min=0)
        squared_score_term = (1 - label) * sim.pow(2)
        margin_term = label * shortfall.pow(2)
        return (squared_score_term + margin_term).mean()

In [75]:
"""
Initialize Model
"""
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = SiameseNetwork()
model = model.to(device)
criterion = ContrastiveLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [76]:
"""
Dummy Inputs &  Labels
"""
# Two random batches of 4 images each (3 channels, 64x64).
dummy_input1 = torch.randn((4, 3, 64, 64)).to(device=device)
dummy_input2 = torch.randn((4, 3, 64, 64)).to(device=device)
# One label per pair: 1 = similar images, 0 = different images.
dummy_labels = torch.tensor([1, 0, 1, 0], dtype=torch.float32).to(device=device)

In [77]:
"""
forward & loss
"""
# Score the dummy pairs, then evaluate the loss on the squeezed scores.
output = model(dummy_input1, dummy_input2)
squeezed = output.squeeze()
print("output:", output)
print("output.squeeze():", squeezed)
loss = criterion(squeezed, dummy_labels)
print("Sim score:", squeezed.detach().cpu().numpy())
print("Loss:", loss.item())

output: tensor([0.9454, 0.9413, 0.9394, 0.9425], grad_fn=<SumBackward1>)
output.squeeze(): tensor([0.9454, 0.9413, 0.9394, 0.9425], grad_fn=<SqueezeBackward0>)
Sim score: [0.94535786 0.94132346 0.9393565  0.94249415]
Loss: 0.4452621340751648
