In [None]:
import pandas as pd
import torch
import torchvision.transforms as T
from torch.utils.data import DataLoader, Dataset

from PIL import Image

In [None]:
# Split the full table 90/10 by row order: the leading rows become the training
# partition and the remaining rows the validation partition (no shuffling here).
full_data = pd.read_csv('data/full.csv')
train_num = int(len(full_data) * 0.9)
train_data = full_data.iloc[:train_num]
valid_data = full_data.iloc[train_num:]

# Persist both partitions. to_csv is called with its defaults, so the DataFrame
# index is written out as an extra leading column in each file.
train_data.to_csv('data/train.csv')
valid_data.to_csv('data/valid.csv')

In [None]:
from DataHandlers import ImageDataset, InMemDataLoader

# Preprocessing applied to every image: centre-crop to 200x200, convert to a
# tensor, then normalise each of the three channels as (x - 0.5) / 0.5.
transform = T.Compose([
    T.CenterCrop((200, 200)),
    T.ToTensor(),
    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

batch_size = 10

train_dataset = ImageDataset(pd.read_csv('data/train.csv'), transform=transform)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    # Reshuffle on every pass. The training loop consumes only a capped number
    # of batches per epoch, so an unshuffled loader would replay the same
    # leading rows each epoch and never reach the rest of the training set.
    shuffle=True,
)

# The validation set is only wrapped in a DataLoader at evaluation time.
valid_dataset = ImageDataset(pd.read_csv('data/valid.csv'), transform=transform)

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class SiameseNetwork(nn.Module):
    """Twin-branch CNN that scores a pair of inputs by embedding distance.

    Both inputs are passed through the same ``self.layers`` stack (shared
    weights): three ``Conv2d -> ReLU -> MaxPool2d(2)`` stages
    (3->64 with kernel 10, 64->128 with kernel 7, 128->256 with kernel 4)
    followed by ``Flatten``.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, kernel_size) of each convolutional stage.
        conv_specs = ((3, 64, 10), (64, 128, 7), (128, 256, 4))
        stages = []
        for in_channels, out_channels, kernel_size in conv_specs:
            stages.append(nn.Conv2d(in_channels, out_channels, kernel_size))
            stages.append(nn.ReLU())
            stages.append(nn.MaxPool2d(2))
        stages.append(nn.Flatten())
        self.layers = nn.Sequential(*stages)

    def forward_one(self, x):
        """Embed a single batch ``x`` by running it through the shared stack."""
        return self.layers(x)

    def forward(self, input0, input1):
        """Return ``F.pairwise_distance`` between the embeddings of both inputs."""
        embeddings = [self.forward_one(branch) for branch in (input0, input1)]
        return F.pairwise_distance(*embeddings)

In [None]:
class ContrastiveLoss(nn.Module):
    """Contrastive loss computed on a precomputed pair distance.

    Per-pair terms, averaged over all elements:
      * ``label == 0`` contributes ``distance ** 2`` (pulls the pair together);
      * ``label == 1`` contributes ``max(margin - distance, 0) ** 2`` (pushes the
        pair apart until the distance reaches ``margin``).
    """

    def __init__(self, margin=2.0):
        super().__init__()
        self.margin = margin

    def forward(self, distance, label):
        """Return the mean contrastive loss for ``distance`` and ``label``."""
        pull_term = (1 - label) * distance.pow(2)
        # Hinge: pairs already at least `margin` apart contribute nothing.
        shortfall = (self.margin - distance).clamp(min=0)
        push_term = label * shortfall.pow(2)
        return (pull_term + push_term).mean()

In [None]:
# Pick the best available accelerator instead of assuming Apple Metal (MPS):
# MPS stays the first choice, then CUDA, then CPU, so the notebook also runs
# on machines where `.to('mps')` would fail.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
from tqdm import tqdm

# Siamese network and contrastive loss, both moved to the selected device.
model = SiameseNetwork()
model = model.to(device)
criterion = ContrastiveLoss()
criterion = criterion.to(device)

# Adam over all model parameters with learning rate 1e-4.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

# Distance cut-off passed to evaluate() when scoring validation pairs.
threshold = 0.2

def evaluate(model, dataset, batch_size, threshold):
    """Return the fraction of pairs in ``dataset`` that ``model`` classifies correctly.

    The model is switched to eval mode and run without gradient tracking over
    ``dataset`` in batches of ``batch_size``. A pair is predicted as label 1
    when its distance is at least ``threshold`` and as label 0 otherwise.

    Args:
        model: Callable as ``model(img0, img1)`` returning one distance per pair.
        dataset: Dataset yielding ``(img0, img1, label)`` items; must support ``len()``.
        batch_size: Batch size for the temporary evaluation ``DataLoader``.
        threshold: Distance cut-off separating label 0 from label 1.

    Returns:
        Number of correct predictions divided by ``len(dataset)`` (a 0-dim
        tensor once at least one batch has been processed).

    Note:
        Batches are moved to the module-level ``device``. The model is left in
        eval mode on return.
    """
    model.eval()
    with torch.no_grad():
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size
        )
        correct = 0
        for img0, img1, label in tqdm(dataloader):
            img0 = img0.to(device)
            img1 = img1.to(device)
            label = label.to(device)
            distance = model(img0, img1)
            # ContrastiveLoss pulls label-0 pairs towards distance 0 and pushes
            # label-1 pairs out to the margin, so label 1 must be predicted for
            # LARGE distances. The previous `distance < threshold` test scored
            # the inverse of what the model is trained to produce.
            predicted = distance >= threshold
            correct += (predicted == label).count_nonzero()
    return correct / len(dataset)

# Training schedule: number of epochs and the cap on batches consumed per epoch.
num_epochs = 10
max_batches_per_epoch = 50

for epoch in range(num_epochs):
    print(epoch)
    model.train()
    epoch_losses = []

    # range() is placed first in zip() so iteration stops BEFORE an extra batch
    # is pulled from the dataloader; with the dataloader first, one additional
    # batch was loaded and thrown away every epoch.
    capped_batches = zip(range(max_batches_per_epoch), train_dataloader)
    # zip() has no length, so tell tqdm how many steps to expect.
    steps = min(max_batches_per_epoch, len(train_dataloader))
    for _, (img0, img1, label) in tqdm(capped_batches, total=steps):
        img0 = img0.to(device)
        img1 = img1.to(device)
        label = label.to(device)
        optimizer.zero_grad()
        outputs = model(img0, img1)
        loss = criterion(outputs, label)
        loss.backward()
        optimizer.step()
        epoch_losses.append(loss.item())

    # Report the mean loss over the epoch rather than only the last batch's
    # value; NaN is reported if no batch was produced (previously a NameError).
    mean_loss = sum(epoch_losses) / len(epoch_losses) if epoch_losses else float('nan')
    print(f'Epoch {epoch} | Loss:{mean_loss} | Eval accuracy: {evaluate(model, valid_dataset, batch_size, threshold)}')