In [None]:
"""
    Basic imports
"""

import requests
from io import BytesIO
from PIL import Image
import os

from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from IPython.display import display
from IPython.display import clear_output
import ipywidgets as widgets

import pandas as pd
import numpy as np

import torch
import torchvision
import torch.nn as nn
import torchvision.transforms as T
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchmetrics.classification import BinaryF1Score

from sklearn.metrics import f1_score
from sklearn.metrics import confusion_matrix
from sklearn.linear_model import LogisticRegression

from DataHandlers import ImageDataset, InMemDataLoader

In [None]:
def split_dataframe(df, ratio, shuffle=True, random_state=None):
    """Split a DataFrame row-wise into consecutive pieces of given proportions.

    Args:
        df: DataFrame to split.
        ratio: Sequence of fractions, one per output piece. Must sum to 1
            (up to floating-point rounding).
        shuffle: If True, the rows are randomly permuted before splitting.
        random_state: Optional seed forwarded to ``DataFrame.sample`` so the
            shuffle, and therefore the split, can be reproduced.

    Returns:
        A list of ``len(ratio)`` DataFrames that together contain every row
        of ``df`` exactly once.

    Raises:
        AssertionError: If the fractions do not sum to 1.
    """
    ratio = list(ratio)
    # Exact float equality would reject valid inputs such as (0.1,) * 10,
    # whose sum is 0.9999999999999999.
    assert np.isclose(sum(ratio), 1.0), f'ratio must sum to 1, got {sum(ratio)}'
    if shuffle:
        df = df.sample(frac=1, random_state=random_state)

    n_rows = df.shape[0]
    # The small epsilon keeps products like 7.999999999999999 from being
    # truncated to 7 when the intended boundary is 8.
    bounds = np.floor(np.cumsum(ratio[:-1]) * n_rows + 1e-9).astype(int).tolist()
    starts = [0] + bounds
    stops = bounds + [n_rows]
    # Positional slicing replaces np.split, which relies on the deprecated
    # DataFrame.swapaxes on recent pandas versions.
    return [df.iloc[start:stop] for start, stop in zip(starts, stops)]

def denormalize_tensor(img):
    """Prepare a normalised image tensor for plotting.

    Moves the first axis to the end (``permute(1, 2, 0)``) and applies
    ``(x + 1) / 2``, which is the inverse of a mean-0.5 / std-0.5
    normalisation.

    Args:
        img: 3-D tensor; the first axis becomes the last axis of the result.

    Returns:
        The permuted tensor with every value shifted by 1 and halved.
    """
    channels_last = img.permute(1, 2, 0)
    shifted = channels_last + 1
    return shifted / 2

In [None]:
# Number of samples per batch; passed to every DataLoader in this notebook.
batch_size = 32
# NOTE(review): presumably the largest ID expected in the submission file
# (cf. the `max_submit_id` argument of save_test_preds) — confirm.
max_submit_id = 22661

In [None]:
# Preprocessing shared by every dataset: resize, crop to 200x200, convert to
# a tensor and normalise each channel with mean 0.5 / std 0.5.
transform = T.Compose([
    T.Resize(200),
    T.CenterCrop(200),
    T.ToTensor(),
    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Random 90% / 1% / 4% / 5% split into train / threshold-fitting / validation / test.
df_slices = list(split_dataframe(pd.read_csv('data/train.csv'), (0.9, 0.01, 0.04, 0.05)))

datasets = [ImageDataset(df, transform=transform) for df in df_slices]
train_dataset, train_threshold_dataset, valid_dataset, test_dataset = datasets

# Only the training loader (index 0) is shuffled. Shuffling the others would
# make any evaluation limited by `max_batches` score a different random
# subset on every call, so metrics would not be comparable between epochs.
loaders = [
    DataLoader(dataset, batch_size=batch_size, shuffle=(index == 0))
    for index, dataset in enumerate(datasets)
]
train_loader, train_threshold_loader, valid_loader, test_loader = loaders

submit_df = pd.read_csv('data/submit.csv')
submit_dataset = ImageDataset(submit_df, transform=transform)
# Inference only: batch order does not matter, so no shuffling is needed.
submit_loader = DataLoader(submit_dataset, batch_size=batch_size, shuffle=False)


In [None]:
def test_loader_images(loader):
    """Plot the first image pair of the loader's first batch side by side.

    Args:
        loader: Iterable whose batches unpack into three items; the first two
            are indexable batches of images, the third is ignored.
    """
    print('Examples of images from loader:')
    left_batch, right_batch, _ = next(iter(loader))
    panels = (('Image 1', left_batch[0]), ('Image 2', right_batch[0]))
    _, axs = plt.subplots(1, 2, figsize=(15, 15))

    for ax, (title, image) in zip(axs, panels):
        ax.imshow(denormalize_tensor(image))
        ax.set_title(title)
        ax.axis('off')
    plt.show()

        
"""
    Finds image pairs from the specified loader where the model's prediction disagrees with the label and shows them one at a time.
"""
def mislabeled(model, loader):
    """Interactively browse pairs from ``loader`` that ``model`` gets wrong.

    Displays a "Next Images" button and an output area. Each click shows the
    next pair whose prediction differs from its label: both images plus their
    absolute per-pixel difference, with prediction and label in the title.
    When no such pair is left, a message is shown and the button is disabled.

    Args:
        model: Object exposing ``predict(images1, images2)`` that returns one
            prediction per pair.
        loader: Iterable of ``(images1, images2, equal)`` batches.
    """
    def mislabeled_inner():
        # Lazily yield (image1, image2, prediction, label) for every pair the
        # model classifies differently from its label.
        for images1, images2, equal in loader:
            preds = model.predict(images1, images2)
            for index in (preds != equal).nonzero().reshape(-1).tolist():
                yield (images1[index], images2[index], preds[index], equal[index])

    mislabeled_gen = mislabeled_inner()
    button = widgets.Button(description="Next Images")
    output = widgets.Output()

    def on_button_clicked(b):
        with output:
            clear_output()
            pair = next(mislabeled_gen, None)
            if pair is None:
                # The generator is exhausted; a bare next() would raise
                # StopIteration out of the click handler.
                button.disabled = True
                print('No more mislabeled images in this loader.')
                return

            image1, image2, pred, truth = pair
            shown1 = denormalize_tensor(image1)
            shown2 = denormalize_tensor(image2)
            fig, axs = plt.subplots(1, 3, figsize=(14, 7))
            axs[0].imshow(shown1)
            axs[1].imshow(shown2)
            # Subtract after denormalising so identical pixels render black;
            # denormalising the difference would shift every value by 0.5.
            axs[2].imshow((shown1 - shown2).abs())
            fig.suptitle(f'Predicted: {pred.item()}\nTruth: {truth.item()}')
            plt.show()

    button.on_click(on_button_clicked)
    display(button, output)
    on_button_clicked(None)  # show the first images


"""
    Saves the predictions to a CSV file for submission to Kaggle.
"""
def save_test_preds(model, loader, max_submit_id, path='res.csv', min_submit_id=2):
    """Predict every pair in ``loader`` and write an ``ID,is_same`` CSV.

    The file contains one row for each ID from ``min_submit_id`` to
    ``max_submit_id`` inclusive. IDs that the loader did not yield get
    ``is_same = 0``.

    Args:
        model: Object exposing ``predict(images1, images2)``.
        loader: Iterable of ``(images1, images2, ids)`` batches.
        max_submit_id: Last ID (inclusive) written to the file.
        path: Destination CSV path.
        min_submit_id: First ID (inclusive) written to the file.
    """
    print(f'Started saving test predictions to {path}')
    ids = []
    preds = []

    for images1, images2, id_ in tqdm(loader):
        preds.extend(pred.item() for pred in model.predict(images1, images2))
        ids.extend(obj.item() for obj in id_)

    all_ids = pd.DataFrame({
        'ID': range(min_submit_id, max_submit_id + 1),
    })
    # Explicit dtypes keep the merge key an integer even when the loader
    # yields nothing; de-duplicating on ID guarantees one row per ID.
    res = pd.DataFrame({
        'ID': pd.Series(ids, dtype='int64'),
        'is_same': pd.Series(preds, dtype='int64'),
    }).drop_duplicates(subset='ID')

    res = all_ids.merge(res, on='ID', how='left').fillna({'is_same': 0})
    # The left merge turns the column into float when IDs are missing; cast
    # back so the file holds 0/1 rather than 0.0/1.0.
    res['is_same'] = res['is_same'].astype(int)
    res.to_csv(path, index=False)
    print(f'Saved test predictions to {path}\n')

In [None]:
from itertools import islice

class ContrastiveLoss(nn.Module):
    """Contrastive loss computed from pair distances and 0/1 labels.

    Pairs with label 1 are penalised by their squared distance; pairs with
    label 0 are penalised by the squared amount by which their distance
    falls short of ``margin``. The result is the mean over all pairs.
    """

    def __init__(self, margin=2.0):
        super().__init__()
        self.margin = margin

    def forward(self, distance, label):
        """Return the mean contrastive loss for the given distances and labels."""
        same_term = label * distance.pow(2)
        shortfall = torch.clamp(self.margin - distance, min=0)
        different_term = (1 - label) * shortfall.pow(2)
        return torch.mean(same_term + different_term)

class SiameseNetworkClassifier(nn.Module):
    """Siamese CNN that classifies an image pair by its embedding distance.

    Both images go through the same convolutional encoder. ``forward``
    returns the pairwise distance between the two 50-dimensional embeddings,
    and ``predict`` labels a pair 1 when that distance is below
    ``self.threshold``.

    Args:
        device: Device the module is moved to and on which inputs are
            processed (anything accepted by ``torch.device``).
    """

    def __init__(self, device='mps'):
        super().__init__()

        self.layers = nn.Sequential(
            nn.Conv2d(3, 16, 10),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 4),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten(),
            # 67712 = 32 * 46 * 46, the flattened size for 3x200x200 inputs.
            nn.Linear(67712, 50)
        )
        # Decision boundary on the distance; replaced by update_threshold().
        self.threshold = torch.tensor(0.)

        self.device = torch.device(device)
        self.to(self.device)

    def forward(self, images1, images2):
        """Return the distance between the embeddings of each image pair."""
        output1 = self.layers(images1)
        output2 = self.layers(images2)
        return F.pairwise_distance(output1, output2)

    def update_threshold(self, loader, max_batches=None):
        """Fit the decision threshold on distances computed from ``loader``.

        A one-feature logistic regression is fitted on (distance, label) and
        the threshold is set to the distance at which its decision function
        is zero. If the sample cannot support a fit (no batches, only one
        label value, or a zero coefficient) the previous threshold is kept
        and a warning is issued.

        Args:
            loader: Iterable of ``(images1, images2, labels)`` batches.
            max_batches: Optional cap on the number of batches consumed.
        """
        import warnings

        self.eval()
        distances = []
        labels = []
        with torch.no_grad():
            for images1, images2, equals in islice(tqdm(loader, desc='Calculating threshold'), max_batches):
                distance = self(images1.to(self.device), images2.to(self.device))
                distances.append(distance.cpu())
                labels.append(equals)

        if not distances:
            warnings.warn('update_threshold: no batches available, threshold left unchanged')
            return

        distances = torch.cat(distances)
        labels = torch.cat(labels)
        # LogisticRegression.fit raises when the sample has a single class,
        # which is likely when only a few batches are used.
        if torch.unique(labels).numel() < 2:
            warnings.warn('update_threshold: sample contains a single class, threshold left unchanged')
            return

        log_reg = LogisticRegression(penalty=None)
        log_reg.fit(distances.reshape((-1, 1)).numpy(), labels.numpy())
        coef = log_reg.coef_.item()
        if coef == 0:
            warnings.warn('update_threshold: distance has zero weight, threshold left unchanged')
            return
        # Decision function intercept + coef * d is zero at d = -intercept / coef.
        self.threshold = -log_reg.intercept_.item() / coef

    # TODO refactor this method so we don't have to call .to(self.device) ?
    def predict(self, images1, images2):
        """Return an int tensor on the CPU: 1 where distance < threshold, else 0."""
        self.eval()
        with torch.no_grad():
            images1 = images1.to(self.device)
            images2 = images2.to(self.device)
            distances = self(images1, images2)
            return (distances < self.threshold).int().cpu()

def evaluate(model, loader, max_batches=None):
    """Return the mean of the positive-class and negative-class F1 scores.

    Predictions come from ``model.predict`` so the metric uses exactly the
    same decision rule as inference. Comparing ``distance > threshold`` for
    the negative class would count a pair lying exactly on the threshold
    for neither class.

    Args:
        model: Module exposing ``eval()`` and ``predict(images1, images2)``
            that returns 0/1 predictions.
        loader: Iterable of ``(images1, images2, label)`` batches with 0/1
            labels.
        max_batches: Optional cap on the number of batches evaluated.

    Returns:
        ``(F1 of class 1 + F1 of class 0) / 2`` as returned by the metrics.
    """
    model.eval()
    pos_f1 = BinaryF1Score()
    neg_f1 = BinaryF1Score()
    for images1, images2, label in islice(tqdm(loader, desc='Evaluating model'), max_batches):
        preds = model.predict(images1, images2)
        pos_f1.update(preds, label)
        # Flip both predictions and labels to score class 0 as the positive class.
        neg_f1.update(1 - preds, 1 - label)
    return (pos_f1.compute() + neg_f1.compute()) / 2
    
def train(model, train_loader, train_threshold_loader, valid_loader, test_loader, epochs=20, lr=1e-4, max_batches=None):
    """Train ``model`` with contrastive loss and report F1 after each epoch.

    Each epoch optimises on ``train_loader``, refits the decision threshold
    on ``train_threshold_loader``, then prints the mean training loss and the
    ``evaluate`` score on the training and validation loaders. The test score
    is printed once after the last epoch.

    Args:
        model: Module exposing ``device``, ``update_threshold`` and whatever
            ``evaluate`` requires.
        train_loader: Batches of ``(images1, images2, label)`` to optimise on.
        train_threshold_loader: Batches used only to fit the threshold.
        valid_loader: Batches scored after every epoch.
        test_loader: Batches scored once at the end.
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.
        max_batches: Optional cap on the batches used per loader in every
            step (training, threshold fitting and evaluation).
    """
    criterion = ContrastiveLoss().to(model.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        # update_threshold() and evaluate() leave the model in eval mode, so
        # training mode is restored once at the start of every epoch.
        model.train()
        loss_sum = 0.0
        batches_seen = 0
        for images1, images2, label in islice(tqdm(train_loader, desc='Training model'), max_batches):
            images1 = images1.to(model.device)
            images2 = images2.to(model.device)
            label = label.to(model.device)

            optimizer.zero_grad()
            outputs = model(images1, images2)
            loss = criterion(outputs, label)
            loss.backward()
            optimizer.step()

            loss_sum += loss.item()
            batches_seen += 1

        model.update_threshold(train_threshold_loader, max_batches=max_batches)

        # Report the epoch mean rather than the last batch; nan when no batch
        # was processed (previously a NameError on the undefined `loss`).
        mean_loss = loss_sum / batches_seen if batches_seen else float('nan')
        print(f'Epoch {epoch} | Loss:{mean_loss}')
        # evaluate() returns an averaged F1 score, not accuracy.
        print(f'Train macro-F1: {evaluate(model, train_loader, max_batches=max_batches)}')
        print(f'Valid macro-F1: {evaluate(model, valid_loader, max_batches=max_batches)}')
    print(f'Test macro-F1: {evaluate(model, test_loader, max_batches=max_batches)}')

In [None]:
# Quick end-to-end run: a single epoch limited to one batch per loader.
model = SiameseNetworkClassifier()
train(
    model,
    train_loader,
    train_threshold_loader,
    valid_loader,
    test_loader,
    epochs=1,
    max_batches=1,
)

In [None]:
"""
    Check on what's wrong with our model
"""

# Browse validation pairs where the model's prediction differs from the label.
mislabeled(model, valid_loader)