In [None]:
from itertools import islice

import numpy as np
import pandas as pd
import torch
import torchvision.transforms as T
from torch.utils.data import DataLoader

In [None]:
def split_dataframe(df, ratio):
    """Split ``df`` row-wise into consecutive, non-overlapping parts.

    Rows are NOT shuffled: each part is a contiguous slice of ``df`` in its
    current order, so shuffle beforehand if the source file is sorted.

    Args:
        df: DataFrame to split.
        ratio: Sequence of fractions, one per output part. They must sum to 1,
            compared with a floating-point tolerance.

    Returns:
        A list of ``len(ratio)`` DataFrames, with the original index preserved.
        Cut points are ``int(cumulative_fraction * len(df))``, i.e. truncated.

    Raises:
        ValueError: if ``ratio`` does not sum to 1.
    """
    ratio = tuple(ratio)
    # Exact float equality would reject valid inputs such as (0.7, 0.2, 0.1),
    # whose float sum is 0.9999999999999999.
    if not np.isclose(sum(ratio), 1.0):
        raise ValueError(f'ratio must sum to 1, got {sum(ratio)}')
    n_rows = df.shape[0]
    cut_points = (np.cumsum(ratio[:-1]) * n_rows).astype(int).tolist()
    edges = [0, *cut_points, n_rows]
    # Positional .iloc slicing instead of np.split: np.split on a DataFrame
    # goes through DataFrame.swapaxes, which pandas 2.1 deprecated.
    return [df.iloc[start:stop] for start, stop in zip(edges[:-1], edges[1:])]

In [None]:
# Load all labelled pairs, then cut them (in file order, no shuffling) into
# three consecutive parts.
full_data = pd.read_csv('data/full.csv')
(
    train_data,            # 87%: fed to the training loop
    train_threshold_data,  # 3%: used to fit the distance threshold
    valid_data,            # 10%: held-out evaluation
) = split_dataframe(full_data, (0.87, 0.03, 0.1))

In [None]:
from DataHandlers import ImageDataset, InMemDataLoader

# Center-crop to 200x200, convert to a tensor and normalise each of the 3
# channels with mean 0.5 / std 0.5 (maps [0, 1] to [-1, 1], assuming
# ImageDataset yields PIL images or uint8 arrays — TODO confirm; the 3-value
# Normalize also presumes RGB input).
# NOTE: the crop size must match the input size SiameseNetwork's final Linear
# layer is sized for.
transform = T.Compose([
    T.CenterCrop((200, 200)),
    T.ToTensor(),
    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

batch_size = 10

train_dataset = ImageDataset(train_data, transform=transform, path='data')
# shuffle=True: the split is contiguous and unshuffled, so without this every
# epoch would see the training pairs in the same file order.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

valid_dataset = ImageDataset(valid_data, transform=transform, path='data')
train_threshold_dataset = ImageDataset(train_threshold_data, transform=transform, path='data')

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class SiameseNetwork(nn.Module):
    """Twin CNN that embeds two image batches with shared weights.

    ``forward`` returns the Euclidean (p=2) pairwise distance between the two
    embeddings of each pair, one value per pair.

    Args:
        input_size: ``(height, width)`` of the 3-channel input images. The
            default ``(200, 200)`` reproduces the original ``Linear(67712, 50)``.
        embedding_dim: Size of the embedding produced by ``forward_one``.
    """

    def __init__(self, input_size=(200, 200), embedding_dim=50):
        super().__init__()
        conv_stack = [
            nn.Conv2d(3, 16, 10),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 4),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten(),
        ]
        # Infer the flattened feature count with a dry run instead of
        # hard-coding it, so a different crop size does not break the model.
        with torch.no_grad():
            probe = torch.zeros(1, 3, *input_size)
            for module in conv_stack:
                probe = module(probe)
        n_features = probe.shape[1]
        # Same module order and indices as before (layers.0 ... layers.7), so
        # existing state_dict checkpoints still load.
        self.layers = nn.Sequential(*conv_stack, nn.Linear(n_features, embedding_dim))

    def forward_one(self, x):
        """Embed a single batch of images with the shared tower."""
        x = self.layers(x)
        return x

    def forward(self, input0, input1):
        """Return the pairwise distance between the embeddings of both inputs."""
        output0 = self.forward_one(input0)
        output1 = self.forward_one(input1)
        return F.pairwise_distance(output0, output1)

In [None]:
class ContrastiveLoss(nn.Module):
    """Contrastive loss on precomputed pair distances.

    For each pair, a label of 0 contributes ``distance ** 2`` (pulls the pair
    together). A label of 1 contributes ``max(margin - distance, 0) ** 2``
    (pushes the pair at least ``margin`` apart). The result is the mean over
    all pairs.

    Args:
        margin: Minimum distance targeted for label-1 pairs.
    """

    def __init__(self, margin=2.0):
        super().__init__()
        self.margin = margin

    def forward(self, distance, label):
        """Return the mean contrastive loss for ``distance`` and ``label``."""
        pull_term = (1 - label) * distance.pow(2)
        hinge = (self.margin - distance).clamp(min=0)
        push_term = label * hinge.pow(2)
        return (pull_term + push_term).mean()

In [None]:
# Prefer Apple-silicon MPS (the original target), then CUDA, then CPU, so the
# notebook also runs where MPS is unavailable instead of failing at `.to(device)`.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:

from tqdm import tqdm
from torchmetrics.classification import BinaryF1Score
from sklearn.linear_model import LogisticRegression

def get_theshold(model, train_theshold_dataset, batch_size):
    """Fit a distance threshold that separates the two pair labels.

    Runs ``model`` over ``train_theshold_dataset``, which yields
    ``(img0, img1, label)`` triples. It then fits an unpenalised one-feature
    logistic regression of ``label`` on the predicted distance. The returned
    threshold is that regression's decision boundary: the distance where its
    predicted probability is 0.5, i.e. ``-intercept / coef``.

    Args:
        model: Siamese model mapping two image batches to per-pair distances.
        train_theshold_dataset: Held-out calibration pairs.
        batch_size: Batch size used for inference.

    Returns:
        The threshold as a Python float.

    Raises:
        ValueError: if the dataset yields no batches.
    """
    # Use the model's own device rather than a `device` global from another cell.
    model_device = next(model.parameters()).device
    model.eval()
    dataloader = DataLoader(train_theshold_dataset, batch_size=batch_size)
    distances = []
    labels = []
    with torch.no_grad():
        for img0, img1, label in tqdm(dataloader):
            distance = model(img0.to(model_device), img1.to(model_device))
            distances.append(distance.cpu())
            labels.append(label)

    if not distances:
        raise ValueError('train_theshold_dataset yielded no batches')

    features = torch.cat(distances).numpy().reshape(-1, 1)
    targets = torch.cat(labels).numpy()
    log_reg = LogisticRegression(penalty=None)
    log_reg.fit(features, targets)
    # sigmoid(coef * d + intercept) == 0.5 exactly where coef * d + intercept == 0.
    theshold = (-log_reg.intercept_ / log_reg.coef_).item()
    return theshold

def evaluate(model, dataset, threshold, batch_size, max_batches = 50):
    """Macro-averaged F1 of distance-threshold classification on ``dataset``.

    A pair is predicted as label 1 when its distance exceeds ``threshold``,
    and as label 0 otherwise. This matches what ``ContrastiveLoss`` trains
    for: label-1 pairs are pushed beyond the margin, label-0 pairs are pulled
    together. The previous version predicted label 1 for *small* distances,
    which inverted the metric.

    Args:
        model: Siamese model mapping two image batches to per-pair distances.
        dataset: Dataset yielding ``(img0, img1, label)`` triples.
        threshold: Distance cut-off, e.g. from ``get_theshold``.
        batch_size: Batch size used for inference.
        max_batches: Only the first ``max_batches`` batches are scored;
            ``None`` scores the whole dataset.

    Returns:
        Scalar tensor: the mean of the F1 for label 1 and the F1 for label 0.
    """
    model_device = next(model.parameters()).device
    model.eval()
    dataloader = DataLoader(dataset, batch_size=batch_size)
    n_batches = len(dataloader) if max_batches is None else min(max_batches, len(dataloader))
    pos_f1 = BinaryF1Score()
    neg_f1 = BinaryF1Score()
    with torch.no_grad():
        # islice stops exactly at n_batches; zip(..., range(...)) fetched one extra batch.
        for img0, img1, label in tqdm(islice(dataloader, n_batches), total=n_batches):
            distance = model(img0.to(model_device), img1.to(model_device)).cpu()
            is_far = distance > threshold
            pos_f1.update(is_far, label)
            neg_f1.update(~is_far, 1 - label)
    return (pos_f1.compute() + neg_f1.compute()) / 2

model = SiameseNetwork().to(device)
criterion = ContrastiveLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

torch.set_printoptions(precision=3)

n_epochs = 20
for epoch in range(n_epochs):
    model.train()
    # Accumulate on-device; calling .item() every batch would force a host sync.
    epoch_loss = torch.zeros((), device=device)
    n_batches = 0
    for img0, img1, label in tqdm(train_dataloader):
        img0 = img0.to(device)
        img1 = img1.to(device)
        label = label.to(device)
        optimizer.zero_grad()
        outputs = model(img0, img1)
        loss = criterion(outputs, label)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach()
        n_batches += 1

    # The embedding changes every epoch, so the threshold is re-fitted each time.
    theshold = get_theshold(model, train_threshold_dataset, batch_size)
    # evaluate() returns the mean of two F1 scores (not accuracy) and, by
    # default, scores only the first 50 batches of each dataset.
    valid_f1 = float(evaluate(model, valid_dataset, theshold, batch_size))
    train_f1 = float(evaluate(model, train_dataset, theshold, batch_size))
    mean_loss = epoch_loss.item() / max(n_batches, 1)
    print(f'Epoch {epoch} | Mean loss: {mean_loss:.4f} | Valid F1: {valid_f1:.3f} | Train F1: {train_f1:.3f}')