In [99]:
"""
    Basic imports
"""

import requests
from io import BytesIO
from PIL import Image
import os

from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from IPython.display import display
from IPython.display import clear_output
import ipywidgets as widgets

import pandas as pd
import numpy as np

import torch
import torchvision
import torch.nn as nn
import torchvision.transforms as T
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchmetrics.classification import BinaryF1Score

from sklearn.metrics import f1_score
from sklearn.metrics import confusion_matrix
from sklearn.linear_model import LogisticRegression

from DataHandlers import ImageDataset, InMemDataLoader

In [91]:
"""
    Wrapper class wraps the model and simplifies the testing, training and debugging routine.
    The provided model should have following methods in order to work correctly with the Wrapper:
        
        1. method _train: dict(loader) -> Unit
            Method takes only argument `loaders` - dict consisting of train, valid, test, submit loaders
            
            train, valid, test loaders provide batches of first images from the pair, second images and 
            labels whether those images are equal
            
        2. method predict: images1, images2 -> tensor of 0/1
        
"""

class Wrapper():

    """
        Construct loaders objects, set local variables.
    """
    def __init__(
                self, 
                model, 
                paths = {'train': 'data/train.csv', 'submit': 'data/submit.csv', 'small': 'data/small.csv'},
                transform = None,
                batch_size = 32,
            ):
        
        self.model = model
        self.batch_size = batch_size
        
        if transform is None:
            transform = T.Compose([
                T.Resize(200),
                T.CenterCrop(100),
                T.ToTensor(),
                T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ])

        train = pd.read_csv(paths['train'])  
        n = len(train)
        
        submit = pd.read_csv(paths['submit'])
        self.max_submit_id = 22661
        
        small = pd.read_csv(paths['small'])
        
        dataset_types = ['train', 'valid', 'test', 'submit', 'small']        
        self.datasets = {
            'train'  : ImageDataset(train[:int(n * 0.8)], transform=transform),
            'valid'  : ImageDataset(train[int(n * 0.8):int(n * 0.9)], transform=transform),
            'test'   : ImageDataset(train[int(n * 0.9):], transform=transform),
            'small' : ImageDataset(small, transform=transform),
            'submit' : ImageDataset(submit, transform=transform)
        }
        
        self.loaders = {
            dataset_type: DataLoader(dataset, self.batch_size, shuffle=(dataset_type=='train' or dataset_type=='small'))
            for dataset_type, dataset in self.datasets.items()
        } 

    """
        Wrapper function for model training.
    """
    def train(self):
        print('Started training model')
        self.model._train(self.loaders)
        print('Finished training model\n')


    """
        Calculates f-scores on samples for all the loaders.
    """
    def fscore(self, num_batches=10):
        def _count(loader):
            preds, truth = [], []
            for (images1, images2, equals), _ in zip(loader, range(num_batches)):
                preds.append(self.model.predict(images1, images2))
                truth.append(equals)

            preds = torch.cat(preds)
            truth = torch.cat(truth)

            preds_bin = (preds > 0.5).int() # todo 
            f1 = f1_score(truth.numpy(), preds_bin.numpy())
            return f1
    
        print('Started calculating f-score')
        print(f'Train       : {_count(self.loaders["train"]): .3f}')
        print(f'Small train : {_count(self.loaders["small"]): .3f}')
        print(f'Validation  : {_count(self.loaders["valid"]): .3f}') 
        print(f'Test        : {_count(self.loaders["test"]): .3f}\n') 


    """
        Shows example images from each loader available.
    """
    def _test_loaders(self):
        print('Examples of images from the supported loaders:')
        to_show = {}
        for loader_type, loader in self.loaders.items():
            image1, image2, _ = next(iter(loader))
            image1 = image1[0]
            image2 = image2[0]
            to_show[loader_type + '-1'] = image1
            to_show[loader_type + '-2'] = image2
        num_images = len(to_show)
        fig, axs = plt.subplots(1, num_images, figsize=(15, 15))

        for i, (name, img) in enumerate(to_show.items()):
            axs[i].imshow(img.numpy().transpose(1,2,0).clip(0, 254), cmap='gray')
            axs[i].set_title(name)
            axs[i].axis('off')

        plt.show()

        
    """
        Tries to find and show mislabeled images from the specified loader.
    """
    def mislabeled(self, loader='small'):
        loader = self.loaders[loader]

        def mislabeled_inner():
            for images1, images2, equal in loader:
                preds = self.model.predict(images1, images2)
                if (preds == equal).all(): 
                    continue
                else:
                    for index in (preds != equal).nonzero():
                        yield (images1[index], images2[index], preds[index], equal[index])

        button = widgets.Button(description="Next Images")
        output = widgets.Output()

        def on_button_clicked(b):
            with output:
                clear_output()
                image1, image2, pred, truth = next(mislabeled_gen)
                def torch2np(x): return x.squeeze().numpy().transpose(1,2,0).clip(0,244)
                image1 = torch2np(image1)
                image2 = torch2np(image2)
                fig, axs = plt.subplots(1, 3, figsize=(14,7))
                axs[0].imshow(image1)
                axs[1].imshow(image2)
                axs[2].imshow(np.abs(image1 - image2))
                suptitle = f'Predicted: {pred.item()}\nTruth: {truth.item()}'
                try:
                    suptitle += f'\nmodel.forward(): {self.model.forward(image1, image2)}'
                except:
                    pass
                fig.suptitle(suptitle)
                plt.show()

        button.on_click(on_button_clicked)
        display(button, output)
        mislabeled_gen = mislabeled_inner()
        on_button_clicked(None)  # show the first images


    """
        Saves predictions that should be submitted to kaggle.
    """
    def save_test_preds(self, path='res.csv'):
        print(f'Started saving test predictions to {path}')
        ids = []
        preds = []
            
        for images1, images2, id_ in tqdm(self.loaders['submit']):
            preds.extend(self.model.predict(images1, images2))
            ids.extend(id_)
            
        all_ids = pd.DataFrame({
            'ID': range(2, self.max_submit_id + 1),
        })
        res = pd.DataFrame({
            'ID': [obj.item() for obj in ids],
            'is_same': [obj.item() for obj in preds]
        }).drop_duplicates()

        res = all_ids.merge(res, on='ID', how='left').fillna(0)
        res.to_csv(path, index=False)
        print(f'Saved test predictions to {path}\n')

In [None]:
class ContrastiveLoss(nn.Module):
    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, distance, label):
        loss_contrastive = torch.mean(label * torch.pow(distance, 2) +
                                      (1 - label) * torch.pow(torch.clamp(self.margin - distance, min=0), 2))

        return loss_contrastive

class SiameseNetworkClassifier(nn.Module):
    def __init__(self, device='mps'):
        super(SiameseNetworkClassifier, self).__init__()
        
        self.layers = nn.Sequential(
            nn.Conv2d(3, 16, 10),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 4),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(67712, 50)
        )
        self.threshold = torch.tensor(0.)
        
        self.device = torch.device(device)
        self.to(self.device)

    def forward(self, images1, images2):
        output1 = self.layers(images1)
        output2 = self.layers(images2)
        return F.pairwise_distance(output1, output2)

    # TODO refactor this method so we don't have to call .to(self.device) ? 
    def predict(self, images1, images2):
        images1 = images1.to(self.device)
        images2 = images2.to(self.device)
        distances = self.forward(images1, images2)
        return (distances < self.threshold).int().cpu()
        
    def _update_threshold(self, loader, max_batches=50):
        self.eval()
        with torch.no_grad():
            distances = []
            labels = []
            for (images1, images2, equals), _ in zip(loader, range(max_batches)):
                distance = self.forward(images1.to(self.device), images2.to(self.device))
                distances.append(distance.cpu())
                labels.append(equals)
    
            distances = torch.cat(distances)
            labels = torch.cat(labels)
            log_reg = LogisticRegression(penalty=None)
            log_reg.fit(distances.reshape((-1, 1)), labels)
            self.threshold = (-log_reg.intercept_ / log_reg.coef_).item()

    def _evaluate(self, loader, max_batches = 50):
        super().eval()
        with torch.no_grad():
            pos_f1 = BinaryF1Score()
            neg_f1 = BinaryF1Score()
            for (images1, images2, label), _ in zip(loader, range(max_batches)):
                distance = self.forward(images1.to(self.device), images2.to(self.device)).cpu()
                pos_f1.update(distance < self.threshold, label)
                neg_f1.update(distance > self.threshold, 1 - label)
        return (pos_f1.compute() + neg_f1.compute()) / 2


    # TODO make train_loader and valid_loader args instead of loaders ? 
    def _train(self, loaders, lr=1e-4, batch_size=32):
        print("Debug: Initializing ContrastiveLoss and Optimizer")
        criterion = ContrastiveLoss().to(self.device)
        optimizer = torch.optim.Adam(self.parameters(), lr=lr)
    
        for epoch in range(10):
            super().train()
            for images1, images2, label in tqdm(loaders['small']):
                images1 = images1.to(self.device)
                images2 = images2.to(self.device)
                label = label.to(self.device)
    
                optimizer.zero_grad()
                outputs = self.forward(images1, images2)
                loss = criterion(outputs, label)
                loss.backward()
                optimizer.step()
    
            self._update_threshold(loaders['small'], batch_size)
            
            print(f'Epoch {epoch} | Loss:{loss.item()}')
            print(f'Train accuracy: {self._evaluate(loaders["train"], batch_size)}')
            print(f'Small train dataset accuracy: {self._evaluate(loaders["small"], batch_size)}')
            print(f'Valid accuracy: {self._evaluate(loaders["valid"], batch_size)}')

In [92]:
"""
    Create the model and wrapper instances
"""

model = SiameseNetworkClassifier()
wrapper = Wrapper(model, batch_size=64)

In [93]:
"""
    Train the model.
    TODO specify loaders for the train method ?
"""

wrapper.train()

Started training model
Debug: Initializing ContrastiveLoss and Optimizer


  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 0 | Loss:0.14787566661834717
Train accuracy: 0.9977034330368042
Small train dataset accuracy: 0.9891458749771118
Valid accuracy: 0.9983172416687012


  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 1 | Loss:0.16687117516994476
Train accuracy: 0.9994363784790039
Small train dataset accuracy: 0.9960115551948547
Valid accuracy: 0.9988795518875122


  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 2 | Loss:0.026339977979660034
Train accuracy: 0.9983682632446289
Small train dataset accuracy: 0.9965192675590515
Valid accuracy: 1.0


  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 3 | Loss:0.028118792921304703
Train accuracy: 1.0
Small train dataset accuracy: 0.9990079402923584
Valid accuracy: 1.0


  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 4 | Loss:0.07012743502855301
Train accuracy: 0.9994143843650818
Small train dataset accuracy: 0.9990008473396301
Valid accuracy: 0.999439537525177


  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 5 | Loss:0.05376338213682175
Train accuracy: 0.9994193315505981
Small train dataset accuracy: 0.9955447912216187
Valid accuracy: 1.0


  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 6 | Loss:0.027552222833037376
Train accuracy: 1.0
Small train dataset accuracy: 0.9995027780532837
Valid accuracy: 1.0


  0%|          | 0/94 [00:00<?, ?it/s]

Epoch 7 | Loss:0.019300859421491623


KeyboardInterrupt: 

In [95]:
"""
    Check on what's wrong with our model
"""

wrapper.mislabeled(loader='small')

Button(description='Next Images', style=ButtonStyle())

Output()

In [None]:
"""
    Calculate final fscore on the whole dataset
"""

wrapper.fscore(num_batches=500)