In [8]:
import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
import os
import pandas as pd
from torchvision.transforms import ToTensor, Resize
from PIL import Image

In [9]:
class SiameseDataset(Dataset):
    """Dataset of image pairs stored under the same file name in two folders.

    Every ``.jpeg`` file found in ``folder1`` is paired with the file of the
    same name in ``folder2``. The pair's label is read from
    ``dataframe.loc[int(<file stem>), "is_same"]``.

    Args:
        folder1: Directory that is listed to discover the pair file names.
        folder2: Directory holding the second image of each pair.
        dataframe: Table indexed by the integer file stem, with an
            ``"is_same"`` column.
        image_size: Optional size passed to ``Resize``. When omitted, the
            resize is resolved lazily (see ``_get_resize``) so instances
            pickled before this parameter existed keep working.
    """

    # Size used when no image_size was given and the module defines no
    # ``resize_transform``; mirrors the value the notebook sets up.
    DEFAULT_IMAGE_SIZE = (600, 800)

    def __init__(self, folder1, folder2, dataframe, image_size=None):
        self.folder1 = folder1
        self.folder2 = folder2
        self.dataframe = dataframe
        self.transform = ToTensor()
        if image_size is not None:
            self.resize = Resize(image_size)

        self.image_pairs = self.get_image_pairs()

    def get_image_pairs(self):
        """Build the list of ``(path1, path2, is_same, image_name)`` tuples.

        Raises:
            ValueError: If a ``.jpeg`` file stem is not an integer.
            KeyError: If the stem is missing from the dataframe index.
        """
        image_pairs = []
        # os.listdir order is arbitrary; sort so the index -> pair mapping is
        # the same on every run and every machine.
        for image_name in sorted(os.listdir(self.folder1)):
            if not image_name.endswith(".jpeg"):
                continue
            image_name_without_ext = os.path.splitext(image_name)[0]
            # Normalise Windows separators to forward slashes.
            image_path1 = os.path.join(self.folder1, image_name).replace("\\", "/")
            image_path2 = os.path.join(self.folder2, image_name).replace("\\", "/")
            is_same = self.dataframe.loc[int(image_name_without_ext), "is_same"]
            image_pairs.append((image_path1, image_path2, is_same, image_name))
        return image_pairs

    def __len__(self):
        """Return the number of image pairs."""
        return len(self.image_pairs)

    def _get_resize(self):
        """Return the resize transform to apply to every image.

        Resolution order: the instance's own ``resize`` (set when
        ``image_size`` was passed), then a module-level ``resize_transform``
        if one exists (legacy notebook behaviour), then a ``Resize`` of
        ``DEFAULT_IMAGE_SIZE``. The last step avoids the NameError that
        occurred when the dataset was indexed before the global was defined.
        """
        resize = getattr(self, "resize", None)
        if resize is None:
            resize = globals().get("resize_transform")
        if resize is None:
            resize = Resize(self.DEFAULT_IMAGE_SIZE)
        return resize

    def _load_image(self, path):
        """Open ``path`` and return it as a resized 3-channel tensor."""
        # The context manager closes the file handle deterministically.
        with Image.open(path) as img:
            # Force RGB so grayscale/CMYK JPEGs give the same channel count
            # as the rest of the batch.
            tensor = self.transform(img.convert("RGB"))
        # Resize after ToTensor, as before, so every image has the same size.
        return self._get_resize()(tensor)

    def __getitem__(self, index):
        """Return ``(image1, image2, is_same, image_name)`` for pair ``index``."""
        image_path1, image_path2, is_same, image_name = self.image_pairs[index]
        image1 = self._load_image(image_path1)
        image2 = self._load_image(image_path2)
        return image1, image2, is_same, image_name

In [10]:
# Step 1: Prepare the dataset
# Location of the serialized dataset; edit here when running on another machine.
DATASET_PATH = 'I:/CSC Hackathon/siamese_dataset.pt'

# NOTE(security): torch.load unpickles arbitrary Python objects, which can run
# code on load. Only load files you created or trust.
# NOTE(review): presumably the file holds a pickled SiameseDataset instance (it
# is passed straight to DataLoader later), so the SiameseDataset class must
# already be defined in this kernel - confirm.
try:
    # PyTorch >= 2.6 defaults to weights_only=True, which refuses full pickled
    # objects, so request full unpickling explicitly.
    loaded_dataset = torch.load(DATASET_PATH, weights_only=False)
except TypeError:
    # Older PyTorch releases do not accept the weights_only keyword.
    loaded_dataset = torch.load(DATASET_PATH)

In [11]:
# Step 3: Define the Siamese network loss
class ContrastiveLoss(nn.Module):
    """Margin-based contrastive loss over pairs of embeddings.

    For a pair with Euclidean distance ``d`` and label ``target`` the per-pair
    loss is ``(1 - target) * d**2 + target * max(margin - d, 0)**2``; the
    result is the mean over all pairs.

    Label convention (follows from the formula): ``target == 0`` penalises
    distance, pulling the pair together; ``target == 1`` penalises closeness,
    pushing the pair apart until it is at least ``margin`` away. A caller that
    holds an "is the same" flag must therefore pass ``1 - is_same``.

    Args:
        margin: Distance beyond which a ``target == 1`` pair adds no loss.
    """

    def __init__(self, margin):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, embedding1, embedding2, target):
        """Return the mean contrastive loss for a batch of embedding pairs.

        Args:
            embedding1: First embedding of each pair.
            embedding2: Second embedding of each pair, same shape.
            target: Per-pair labels (bool, integer or float), 1 for pairs
                that should be far apart and 0 for pairs that should be close.

        Returns:
            Scalar tensor holding the batch-mean loss.
        """
        euclidean_distance = nn.functional.pairwise_distance(embedding1, embedding2)

        # `1 - target` is not defined for bool tensors, and labels may arrive
        # on another device; normalise dtype and device first.
        target = torch.as_tensor(
            target, dtype=euclidean_distance.dtype, device=euclidean_distance.device
        )
        # A (N, 1) label column would broadcast against the (N,) distances
        # into an (N, N) matrix and silently corrupt the mean.
        if target.numel() == euclidean_distance.numel():
            target = target.reshape(euclidean_distance.shape)

        pull_term = (1 - target) * torch.pow(euclidean_distance, 2)
        push_term = target * torch.pow(
            torch.clamp(self.margin - euclidean_distance, min=0.0), 2
        )
        loss_contrastive = torch.mean(pull_term + push_term)
        return loss_contrastive

class SiameseNetwork(nn.Module):
    """Siamese encoder: both inputs go through one shared ResNet-50.

    The pretrained network's final ``fc`` layer is replaced by
    ``nn.Identity``, so the backbone returns the features that used to feed
    the classifier. ``embedding_size`` records that layer's input width.
    """

    def __init__(self):
        super().__init__()
        resnet = models.resnet50(pretrained=True)
        # Read the classifier's input width before the layer is swapped out.
        self.embedding_size = resnet.fc.in_features
        resnet.fc = nn.Identity()
        self.backbone = resnet

    def forward(self, x1, x2):
        """Embed ``x1`` then ``x2`` with the shared backbone; return both."""
        return self.backbone(x1), self.backbone(x2)

In [17]:
# Step 4: Build the batching pipeline
# Module-level resize that SiameseDataset.__getitem__ looks up by name.
resize_transform = Resize(size=(600, 800))
batch_size = 8  # 32
data_loader = DataLoader(
    dataset=loaded_dataset,
    shuffle=True,
    batch_size=batch_size,
)

# Step 5: Create the model, loss and optimizer on the best available device
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

model = SiameseNetwork()
model = model.to(device)
criterion = ContrastiveLoss(margin=1.0)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
num_epochs = 5 # change it to 10 / 20

# Explicitly select training mode so BatchNorm layers update their statistics
# even if an earlier cell switched the model to eval().
model.train()

for epoch in range(num_epochs):
    for batch in data_loader:
        # Each batch is (image1, image2, is_same, image_name); names are unused.
        images1, images2, is_same, _ = batch
        images1 = images1.to(device)
        images2 = images2.to(device)

        # The criterion computes (1 - t) * d^2 + t * max(margin - d, 0)^2, so
        # t == 1 pushes a pair apart and t == 0 pulls it together. Passing
        # `is_same` directly would train the opposite objective, so feed the
        # complement. The float cast also keeps bool labels usable in `1 - t`.
        # NOTE(review): assumes is_same is 1/True for matching pairs - confirm
        # against the dataframe.
        targets = 1.0 - torch.as_tensor(is_same, dtype=torch.float32, device=device)

        optimizer.zero_grad()
        embeddings1, embeddings2 = model(images1, images2)
        loss = criterion(embeddings1, embeddings2, targets)
        loss.backward()
        optimizer.step()

        # Print training progress
        print(f"Epoch [{epoch+1}/{num_epochs}], Batch Loss: {loss.item():.4f}")