# CASE 1: Siamese network trained with contrastive loss on image pairs

In [1]:
import itertools
import os
import random
from collections import defaultdict

import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader, Dataset

In [2]:
class SiameseNetwork(nn.Module):
    """Siamese network: a shared ResNet-50 trunk followed by an MLP projection head.

    Both inputs of ``forward`` pass through the *same* weights, so the two
    returned embeddings are directly comparable (e.g. with ``ContrastiveLoss``).

    Args:
        pretrained: Load ImageNet weights for the ResNet-50 trunk. Defaults to
            ``True`` (the original behavior); pass ``False`` to skip the weight
            download, e.g. offline or when a checkpoint is loaded afterwards.
        embedding_dim: Size of each output embedding. Defaults to 128.
    """

    def __init__(self, pretrained=True, embedding_dim=128):
        super(SiameseNetwork, self).__init__()
        # NOTE: torchvision >= 0.13 deprecates ``pretrained=`` in favour of
        # ``weights=``; kept here so the file works on older torchvision too.
        backbone = models.resnet50(pretrained=pretrained)
        # width of the pooled features that fed the original classifier
        trunk_features = backbone.fc.in_features
        # remove the final classification layer, keep everything before it
        self.resnet = nn.Sequential(*list(backbone.children())[:-1])
        # projection head mapping trunk features to the embedding space
        self.fc = nn.Sequential(
            nn.Linear(trunk_features, 512),
            nn.ReLU(),
            nn.Linear(512, embedding_dim),
        )

    def forward_one(self, x):
        """Embed a single batch of images.

        Assumes ``x`` is an image batch of shape (N, 3, H, W) — TODO confirm
        against the transforms used by callers. Returns (N, embedding_dim).
        """
        x = self.resnet(x)
        # flatten the pooled feature map to (N, features)
        x = torch.flatten(x, 1)
        return self.fc(x)

    def forward(self, x1, x2):
        """Embed both inputs with shared weights and return ``(emb1, emb2)``."""
        output1 = self.forward_one(x1)
        output2 = self.forward_one(x2)
        return output1, output2

In [3]:
class ContrastiveLoss(nn.Module):
    """Contrastive loss over pairs of embeddings.

    For Euclidean distance ``d`` between the two embeddings and a per-pair
    ``label``, the loss is the batch mean of::

        (1 - label) * d**2 + label * max(margin - d, 0)**2

    so ``label == 0`` pulls a pair together and ``label == 1`` pushes a pair
    apart until it is at least ``margin`` away.

    Args:
        margin: Minimum distance enforced between pairs labelled 1.
    """

    def __init__(self, margin=2.0):
        super().__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        """Return the scalar mean contrastive loss for the batch."""
        dist = nn.functional.pairwise_distance(output1, output2)
        # attraction term for same-class pairs (label 0)
        attract = (1 - label) * torch.pow(dist, 2)
        # hinge term: only penalizes different-class pairs closer than margin
        shortfall = torch.clamp(self.margin - dist, min=0.0)
        repel = label * torch.pow(shortfall, 2)
        return torch.mean(attract + repel)

In [6]:
class SiameseDataset(Dataset):
    """Dataset of image pairs for a Siamese network trained with ``ContrastiveLoss``.

    ``image_folder`` must contain one sub-directory per class, each holding that
    class's image files. Each item is ``(img1, img2, label)``, where ``label`` is
    a float32 scalar tensor: ``0.0`` for two images of the same class and
    ``1.0`` for images of different classes.

    Args:
        image_folder: Root directory whose sub-directories are the classes.
        transform: Optional callable applied to each RGB PIL image.

    Raises:
        ValueError: If fewer than two class folders contain images, because
            negative (different-class) pairs cannot be formed.
    """

    def __init__(self, image_folder, transform=None):
        self.image_folder = image_folder
        self.transform = transform
        self.image_paths, self.labels = self._prepare_data()
        self.image_pairs = self._create_pairs()

    def _prepare_data(self):
        """Return parallel lists of image paths and integer class labels.

        Class folders and their files are visited in sorted order so a given
        folder gets the same label on every run (``os.listdir`` order is
        arbitrary). Non-directory entries in the root and sub-directories
        inside a class folder are skipped.
        """
        image_paths = []
        labels = []
        class_folders = sorted(
            entry
            for entry in os.listdir(self.image_folder)
            if os.path.isdir(os.path.join(self.image_folder, entry))
        )
        for label, class_folder in enumerate(class_folders):
            class_path = os.path.join(self.image_folder, class_folder)
            for img_name in sorted(os.listdir(class_path)):
                img_path = os.path.join(class_path, img_name)
                if not os.path.isfile(img_path):
                    continue
                image_paths.append(img_path)
                labels.append(label)
        return image_paths, labels

    def _create_pairs(self):
        """Build a shuffled list of ``(path1, path2, label)`` triples.

        Positive pairs (label 0): every unordered pair of two *different*
        images from the same class. Negative pairs (label 1): each image paired
        once with a random image from a randomly chosen other class.

        NOTE: positives grow quadratically with class size while negatives grow
        linearly, so large classes give a positive-heavy dataset.
        """
        # only classes that actually have images appear as keys, so empty
        # class folders cannot cause a KeyError or an empty random.choice
        class_to_images = defaultdict(list)
        for img_path, label in zip(self.image_paths, self.labels):
            class_to_images[label].append(img_path)

        if len(class_to_images) < 2:
            raise ValueError(
                "SiameseDataset needs at least two non-empty class folders "
                f"under {self.image_folder!r} to form negative pairs"
            )

        pairs = []
        for images in class_to_images.values():
            # positive pairs (same class, distinct images)
            pairs.extend((a, b, 0) for a, b in itertools.combinations(images, 2))

        for label, images in class_to_images.items():
            other_labels = [other for other in class_to_images if other != label]
            for img_path in images:
                neg_label = random.choice(other_labels)
                neg_image = random.choice(class_to_images[neg_label])
                # negative pair (different classes)
                pairs.append((img_path, neg_image, 1))

        random.shuffle(pairs)
        return pairs

    @staticmethod
    def _load_rgb(path):
        """Load ``path`` as an RGB PIL image and close the underlying file."""
        with Image.open(path) as img:
            # convert() returns a new, fully loaded image independent of the file
            return img.convert("RGB")

    def __len__(self):
        return len(self.image_pairs)

    def __getitem__(self, idx):
        """Return ``(img1, img2, label)`` for pair ``idx``; label is float32."""
        img1_path, img2_path, label = self.image_pairs[idx]
        img1 = self._load_rgb(img1_path)
        img2 = self._load_rgb(img2_path)

        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)

        return img1, img2, torch.tensor(label, dtype=torch.float32)

# CASE 2: ResNet-50 embedding network trained with triplet loss

In [8]:
class EmbeddingNet(nn.Module):
    """ResNet-50 whose final ``fc`` classifier is replaced by an embedding layer.

    Args:
        pretrained: Load ImageNet weights for the backbone. Defaults to ``True``
            (the original behavior); pass ``False`` to skip the weight download.
        embedding_dim: Size of the output embedding. Defaults to 128.
    """

    def __init__(self, pretrained=True, embedding_dim=128):
        super(EmbeddingNet, self).__init__()
        # NOTE: torchvision >= 0.13 deprecates ``pretrained=`` in favour of
        # ``weights=``; kept here so the file works on older torchvision too.
        self.model = models.resnet50(pretrained=pretrained)
        # swap the classifier for a linear projection into the embedding space
        self.model.fc = nn.Linear(self.model.fc.in_features, embedding_dim)

    def forward(self, x):
        """Map an image batch to embeddings of shape (N, embedding_dim).

        Assumes ``x`` is (N, 3, H, W) — TODO confirm against callers' transforms.
        """
        return self.model(x)

In [9]:
class TripletLoss(nn.Module):
    """Triplet margin loss over (anchor, positive, negative) embeddings.

    Computes the batch mean of ``relu(d(a, p) - d(a, n) + margin)`` with
    Euclidean distance ``d``: zero once each negative is at least ``margin``
    farther from its anchor than the matching positive.

    Args:
        margin: Required gap between negative and positive distances.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        """Return the scalar mean triplet loss for the batch."""
        dist = torch.nn.functional.pairwise_distance
        d_pos = dist(anchor, positive, p=2)
        d_neg = dist(anchor, negative, p=2)
        # per-triplet hinge; relu keeps only violating triplets
        hinge = torch.relu(d_pos - d_neg + self.margin)
        return hinge.mean()

In [10]:
class TripletDataset(Dataset):
    """Dataset of (anchor, positive, negative) image triplets for ``TripletLoss``.

    The three path sequences are aligned by index: item ``i`` is built from
    ``anchor_images[i]``, ``positive_images[i]`` and ``negative_images[i]``.
    Presumably positives share the anchor's class and negatives do not — the
    caller is responsible for that pairing.

    Args:
        anchor_images: Sequence of anchor image paths.
        positive_images: Sequence of positive image paths, aligned with anchors.
        negative_images: Sequence of negative image paths, aligned with anchors.
        transform: Optional callable applied to each RGB PIL image.

    Raises:
        ValueError: If the three sequences do not have the same length.
    """

    def __init__(self, anchor_images, positive_images, negative_images, transform=None):
        lengths = {len(anchor_images), len(positive_images), len(negative_images)}
        if len(lengths) != 1:
            raise ValueError(
                "anchor_images, positive_images and negative_images must have "
                f"the same length, got {len(anchor_images)}, "
                f"{len(positive_images)} and {len(negative_images)}"
            )
        self.anchor_images = anchor_images
        self.positive_images = positive_images
        self.negative_images = negative_images
        self.transform = transform

    @staticmethod
    def _load_rgb(path):
        """Load ``path`` as an RGB PIL image and close the underlying file."""
        with Image.open(path) as img:
            # force 3 channels so grayscale/RGBA files match RGB ones
            return img.convert("RGB")

    def __len__(self):
        return len(self.anchor_images)

    def __getitem__(self, idx):
        """Return the ``(anchor, positive, negative)`` images for index ``idx``."""
        anchor = self._load_rgb(self.anchor_images[idx])
        positive = self._load_rgb(self.positive_images[idx])
        negative = self._load_rgb(self.negative_images[idx])

        if self.transform:
            anchor = self.transform(anchor)
            positive = self.transform(positive)
            negative = self.transform(negative)

        return anchor, positive, negative