## Imports

In [None]:
from functools import partial
from tqdm.notebook import tqdm
from collections import defaultdict, Counter
import copy
from dataclasses import dataclass, field
from datetime import datetime
from typing import Callable, Optional

# Basic libraries for data manipulation
import torch
from PIL import Image
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
import os
import random
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, accuracy_score, confusion_matrix
import seaborn as sns
import pickle

# PyTorch parts
from torch import nn
from torch.nn import functional as F
from torch import optim
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms, datasets

In [None]:
def seed_everything(seed: int) -> None:
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and every CUDA device) and
    configures cuDNN for deterministic, non-benchmarked kernel selection so that
    repeated runs produce the same results.

    Args:
        seed: the seed value applied to every generator.
    """
    random.seed(seed)
    # NOTE: the interpreter's hash seed is fixed at startup, so this only affects
    # processes started after this point, not the current one.
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed only seeds the current GPU; seed all of them.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN choose kernels by timing them, which can differ
    # between runs and defeats the deterministic setting above.
    torch.backends.cudnn.benchmark = False


seed_everything(seed=2022)

## Config

In [None]:
# Which augmentation API CustomDataset calls: "torchvision" or "albumentations".
AUGMENTATIONS_ENGINE = "torchvision"
# Number of logits produced by the similarity head (1 = one "same class" logit per pair).
OUTPUT_SIZE = 1

In [None]:
# Use a GPU whenever one is available: "cuda" for a single GPU, the first device
# ("cuda:0") when there are several, and the CPU only when there is none.
n_cuda_devices = torch.cuda.device_count()
if n_cuda_devices == 0:
    device = "cpu"
elif n_cuda_devices == 1:
    device = "cuda"
else:
    device = "cuda:0"
device = torch.device(device)
device

In [None]:
# Training hyperparameters
# Pairs per batch for the validation dataloaders.
BATCH_SIZE = 128
# NOTE(review): not referenced by the visible cells, which pass show_progress explicitly.
SHOW_PROGRESS = False

## Data

**Download the data**

In [None]:
# Switch between CIFAR-10 and CIFAR-100 by swapping datasets.CIFAR10 for
# datasets.CIFAR100; N_LABELS is derived from the dataset so it always matches.

dataset = datasets.CIFAR10(root=".", train=False, download=True)
N_LABELS = len(dataset.classes)
len(dataset)

In [None]:
# Sanity-check the download: print the class label of sample 9 and display its image.
x, y = dataset[9]
print(y)
plt.imshow(x)
plt.show()

In [None]:
%%time

# Map each class label to the dataset indices of that class. dataset[idx] builds
# the whole (image, label) sample, so scan the dataset once instead of once per
# label. Keys are 0..N_LABELS-1 in order; labels outside that range are skipped.
label_to_files = {label: [] for label in range(N_LABELS)}
for idx in tqdm(range(len(dataset))):
    sample_label = dataset[idx][1]
    if sample_label in label_to_files:
        label_to_files[sample_label].append(idx)
label_to_files.keys()

In [None]:
# Distinct class sizes: a single value means every class has the same number of samples.
{len(files) for files in label_to_files.values()}

In [None]:
# First ten dataset indices belonging to class 0.
label_to_files[0][:10]

## Dataset

In [None]:
%%time

def draw_image(label_to_files, label, probability_same=0.5):
    """Draw a partner dataset index for a sample of class ``label``.

    With probability ``probability_same`` the partner comes from the same class.
    Otherwise a different class is picked uniformly from the other keys of
    ``label_to_files``. The partner index is then drawn uniformly from the chosen
    class's index list.

    Args:
        label_to_files: mapping label -> list of dataset indices of that class.
        label: class label of the anchor sample.
        probability_same: chance that the partner shares the anchor's class.

    Returns:
        ``(partner_index, are_same)``.
    """
    are_same = random.random() < probability_same

    target_label = label
    if not are_same:
        other_labels = [candidate for candidate in label_to_files if candidate != label]
        target_label = random.choice(other_labels)

    return random.choice(label_to_files[target_label]), are_same

# Smoke test: draw a partner for a class-0 sample; the result is (partner_index, are_same).
out = draw_image(label_to_files=label_to_files, label=0, probability_same=0.5)
out

In [None]:
random.seed(2022)

# Look up every sample's label once. dataset[idx] builds the whole (image, label)
# sample, so reusing the labels avoids doing that twice per sample below. The
# draws consume Python's RNG in the same order as before (balanced pass first).
sample_labels = [dataset[idx][1] for idx in range(len(dataset))]

# Partner index for every sample: ~50% same-class pairs in the balanced split,
# ~1/N_LABELS same-class pairs in the sparse split.
balanced_pairs = [
    draw_image(label_to_files=label_to_files, label=lbl, probability_same=0.5)[0] for lbl in sample_labels
]
sparse_pairs = [
    draw_image(label_to_files=label_to_files, label=lbl, probability_same=1 / N_LABELS)[0] for lbl in sample_labels
]
len(balanced_pairs), len(sparse_pairs)

In [None]:
# Share of balanced pairs whose two samples have the same class (~0.5 by construction).
np.mean([dataset[i][1] == dataset[j][1] for i, j in enumerate(balanced_pairs)])

In [None]:
# Share of sparse pairs whose two samples have the same class (~1/N_LABELS by construction).
np.mean([dataset[i][1] == dataset[j][1] for i, j in enumerate(sparse_pairs)])

In [None]:
class CustomDataset(Dataset):
    """Dataset of image pairs for same-class / different-class classification.

    Item ``idx`` pairs ``dataset[idx]`` with ``dataset[paired_files[idx]]``. Its
    label is 1.0 when the two samples share a class label and 0.0 otherwise.

    Args:
        dataset: indexable dataset returning ``(image, label)`` tuples.
        label_to_files: mapping label -> list of dataset indices. Stored on the
            instance but not used to build items.
        paired_files: partner index for every index of ``dataset``.
        augmentations: optional transform applied separately to each image.
        augmentations_engine: ``"torchvision"`` calls ``augmentations(image)``;
            ``"albumentations"`` calls ``augmentations(image=array)["image"]``.
    """

    def __init__(
        self, dataset, label_to_files, paired_files,
        augmentations: Optional[Callable] = None, augmentations_engine: str = "torchvision"
    ):
        assert augmentations_engine in ["torchvision", "albumentations"], \
            "Variable augmentations_engine has to be one of torchvision, albumentations."

        self.dataset = dataset
        self.label_to_files = label_to_files
        self.paired_files = paired_files
        self.augmentations = augmentations
        self.augmentations_engine = augmentations_engine

    def __len__(self):
        return len(self.dataset)

    def _get_files_valid(self, idx):
        """Return ``(image1, image2, same)`` for pair ``idx``, without augmentations."""
        x1, y1 = self.dataset[idx]
        x2, y2 = self.dataset[self.paired_files[idx]]
        same = y1 == y2

        return x1, x2, same

    def _augment(self, image):
        """Apply ``self.augmentations`` to a single image using the configured engine."""
        if self.augmentations_engine == "albumentations":
            # Albumentations works on NumPy arrays, but the source dataset may yield
            # PIL images; np.asarray converts those and is a no-op for arrays.
            return self.augmentations(image=np.asarray(image))["image"]
        return self.augmentations(image)

    def __getitem__(self, idx):
        image1, image2, same = self._get_files_valid(idx)

        if self.augmentations is not None:
            image1 = self._augment(image1)
            image2 = self._augment(image2)

        return {"image1": image1, "image2": image2, "label": torch.tensor(float(same))}

In [None]:
# Deterministic evaluation preprocessing: resize to 224x224, convert to a tensor,
# then normalize per channel (these mean/std values look like the standard
# ImageNet statistics used with pretrained backbones).
valid_augmentations = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

In [None]:
# Both validation datasets share the source data and preprocessing; they differ
# only in the pre-drawn partner indices (balanced vs. sparse same-class pairs).
make_valid_dataset = partial(
    CustomDataset,
    dataset=dataset,
    label_to_files=label_to_files,
    augmentations=valid_augmentations,
    augmentations_engine=AUGMENTATIONS_ENGINE,
)
valid_balanced_dataset = make_valid_dataset(paired_files=balanced_pairs)
valid_sparse_dataset = make_valid_dataset(paired_files=sparse_pairs)

In [None]:
# Show the first balanced pair and its label. Tensors are CHW, so permute to HWC
# for matplotlib; the images are normalized, so matplotlib clips values outside [0, 1].
item = valid_balanced_dataset[0]
print(item["label"])
for image_key in ("image1", "image2"):
    plt.imshow(item[image_key].permute(1, 2, 0))
    plt.show()

In [None]:
# Show the first sparse pair and its label. Tensors are CHW, so permute to HWC
# for matplotlib; the images are normalized, so matplotlib clips values outside [0, 1].
item = valid_sparse_dataset[0]
print(item["label"])
for image_key in ("image1", "image2"):
    plt.imshow(item[image_key].permute(1, 2, 0))
    plt.show()

In [None]:
# Both validation datasets wrap the same underlying dataset, so their lengths match.
len(valid_balanced_dataset), len(valid_sparse_dataset)

## Dataloaders

In [None]:
valid_kwargs = dict(
    batch_size=BATCH_SIZE,
    # Keep dataset order: later cells map prediction positions back to dataset indices.
    shuffle=False,
    drop_last=False,
    num_workers=2,
    # Page-locked host memory only speeds up host-to-GPU copies.
    pin_memory=device.type == "cuda",
)

valid_balanced_dataloader, valid_sparse_dataloader = (
    DataLoader(pair_dataset, **valid_kwargs)
    for pair_dataset in (valid_balanced_dataset, valid_sparse_dataset)
)

In [None]:
# Keyed as f"{stage}_dataloader" so helpers can look loaders up from learner.stage.
dataloaders = dict(
    valid_balanced_dataloader=valid_balanced_dataloader,
    valid_sparse_dataloader=valid_sparse_dataloader,
)

In [None]:
%%time

# Fetch one balanced batch to check the tensor shapes produced by the dataloader.
batch = next(iter(dataloaders["valid_balanced_dataloader"]))
batch["image1"].shape, batch["image2"].shape, batch["label"].shape

In [None]:
# Labels of the first ten pairs (1.0 = same class, 0.0 = different classes).
batch["label"][:10]

## Model

In [None]:
class SimilarityModel(nn.Module):
    """Siamese classifier scoring whether two images belong to the same class.

    One shared backbone embeds both images. An MLP head then classifies the
    concatenation ``[emb1, emb2, (emb1 - emb2)**2, cos(emb1, emb2), ||emb1 - emb2||]``.

    Args:
        backbone_model: module mapping an image batch to embeddings of width
            ``backbone_output_size`` (assumed 2-D ``(batch, width)`` — TODO confirm
            for backbones other than the one built below).
        backbone_output_size: embedding width produced by the backbone.
        output_size: number of logits produced per pair.
    """

    def __init__(self, backbone_model, backbone_output_size, output_size=OUTPUT_SIZE):
        super().__init__()

        self.backbone_model = backbone_model

        # emb1, emb2 and their squared difference (3 * width), plus cosine
        # similarity and Euclidean distance (1 each).
        input_size = 3 * backbone_output_size + 2
        self.classifier = nn.Sequential(
            nn.BatchNorm1d(num_features=input_size),
            nn.Dropout(p=0.25),
            nn.Linear(in_features=input_size, out_features=512, bias=True),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(num_features=512),
            nn.Dropout(p=0.5),
            nn.Linear(in_features=512, out_features=output_size, bias=True),
        )

    def forward(self, inputs):
        """Return logits for ``inputs = (image1_batch, image2_batch)``.

        The shape is ``(batch,)`` when ``output_size == 1``, otherwise
        ``(batch, output_size)``.
        """
        image1, image2 = inputs
        # The same backbone (shared weights) embeds both images.
        out1 = self.backbone_model(image1)
        out2 = self.backbone_model(image2)

        diff_squared = (out1 - out2) ** 2
        cosine_similarity = F.cosine_similarity(out1, out2, dim=-1).unsqueeze(1)
        euclidean_distance = diff_squared.sum(dim=1, keepdim=True).sqrt()

        concatenated_features = torch.hstack([
            out1, out2,
            diff_squared, cosine_similarity,
            euclidean_distance,
        ])
        # squeeze(-1) rather than squeeze(): only the logit axis is dropped, so a
        # batch of one still yields shape (1,) matching its labels, not a 0-d tensor.
        out = self.classifier(concatenated_features).squeeze(-1)

        return out


def get_similarity_model(pretrained=True, output_size=OUTPUT_SIZE, device=device):
    """Build a SimilarityModel on an EfficientNet-B0 backbone and move it to ``device``.

    The backbone's classification head is replaced by ``nn.Identity`` so it returns
    the features that head consumed. Their width is read from the head's
    ``Linear`` layer before it is removed.

    Args:
        pretrained: load ImageNet weights for the backbone when True.
        output_size: number of logits produced per pair.
        device: device the model is moved to.

    Returns:
        The SimilarityModel, already on ``device``.
    """
    weights_enum = getattr(models, "EfficientNet_B0_Weights", None)
    if weights_enum is None:
        # Older torchvision (< 0.13) only supports the boolean `pretrained` flag.
        backbone_model = models.efficientnet_b0(pretrained=pretrained)
    else:
        # torchvision >= 0.13 deprecates `pretrained` in favour of `weights`;
        # for EfficientNet-B0, DEFAULT is the ImageNet-1k checkpoint.
        backbone_model = models.efficientnet_b0(weights=weights_enum.DEFAULT if pretrained else None)
    backbone_output_size = backbone_model.classifier[1].in_features
    backbone_model.classifier = nn.Identity()

    model = SimilarityModel(
        backbone_model=backbone_model, backbone_output_size=backbone_output_size, output_size=output_size
    )
    model.to(device)

    return model

# Randomly initialised model, used only to check shapes and size.
model = get_similarity_model(pretrained=False, output_size=OUTPUT_SIZE, device=device)
n_parameters = sum(param.numel() for param in model.parameters())
n_parameters

In [None]:
%%time

# Forward one batch to check the output shape. NOTE: the model is still in train
# mode here, so BatchNorm running statistics are updated by this call.
pair = (batch["image1"].to(device), batch["image2"].to(device))
with torch.no_grad():
    out = model(pair)
out.shape

In [None]:
@dataclass
class Metric:
    """
    Running, sample-weighted average of a per-batch metric.

    ``name == "loss"`` accumulates ``loss * batch_size``; any other name accumulates
    the number of correct predictions, so ``avg`` is the accuracy.
    """

    name: str
    total_value: float = 0
    total_samples: int = 0
    avg: float = 0

    def reset(self):
        """Clear all accumulated values."""
        self.total_value = 0
        self.total_samples = 0
        self.avg = 0

    def update(self, loss, predictions, is_correct, new_count=1):
        """Add one batch of ``new_count`` samples and refresh ``avg``."""
        new_value = self._compute_value(loss, predictions, is_correct, new_count)

        self.total_value += new_value
        self.total_samples += new_count
        # An empty batch before any samples were seen would otherwise divide by zero.
        if self.total_samples > 0:
            self.avg = self.total_value / self.total_samples

    def _compute_value(self, loss, predictions, is_correct, new_count):
        """Return this batch's contribution to ``total_value``."""
        if self.name == "loss":
            # Assumes `loss` is a batch mean: scale back to a sum so batches of
            # different sizes are weighted by their sample count.
            return loss.item() * new_count
        return is_correct.sum().item()

In [None]:
@dataclass
class Learner:
    """
    Container for various model/training variables.

    The loop helpers look up ``dataloaders[f"{stage}_dataloader"]`` and append to
    ``metric_values[f"{stage}_{metric}"]``. ``optimizer`` and ``lr_scheduler`` are
    only needed for passes that update the model; ``lrs`` records the learning
    rate at each scheduler step.
    """

    model: nn.Module
    device: torch.device
    dataloaders: dict
    metric_values: dict
    stage: str = "train"
    optimizer: Optional[optim.Optimizer] = None
    lr_scheduler: Optional[optim.lr_scheduler._LRScheduler] = None
    lrs: list = field(default_factory=list)  # fresh list per instance


In [None]:
def iterate(learner, perform_backward_pass:bool=False, threshold:float=0.5, show_progress:bool=False):
    """Run one pass over the dataloader for ``learner.stage`` and record metrics.

    For each batch the model scores ``(image1, image2)`` pairs and the
    BCE-with-logits loss is computed against ``label``. When
    ``perform_backward_pass`` is True an optimizer step is taken, and a scheduler
    step too if ``learner.lr_scheduler`` is set. A pair is predicted "similar"
    when ``sigmoid(logit) >= threshold``.

    After the pass, the average loss, the accuracy and the F1 score are appended
    to ``learner.metric_values`` under ``f"{stage}_loss"``, ``f"{stage}_acc"`` and
    ``f"{stage}_f1"``.

    Returns:
        ``(all_labels, all_predictions)`` concatenated over the pass, on the CPU.
    """
    all_labels = []
    all_predictions = []

    metrics = [Metric(name="loss"), Metric(name="acc")]

    # logit >= logit(threshold)  <=>  sigmoid(logit) >= threshold. torch.logit maps
    # 0 / 1 to -inf / +inf, whereas threshold / (1 - threshold) raises at 1.0.
    logit_threshold = torch.logit(torch.tensor(threshold, dtype=torch.float64))

    dl = learner.dataloaders[f"{learner.stage}_dataloader"]
    loop = tqdm(dl, total=len(dl), leave=False) if show_progress else dl
    for batch in loop:
        inputs1 = batch["image1"].to(learner.device)
        inputs2 = batch["image2"].to(learner.device)
        labels = batch["label"].to(learner.device)

        outputs = learner.model((inputs1, inputs2))

        # Mean loss over the batch.
        loss = F.binary_cross_entropy_with_logits(outputs, labels)

        if perform_backward_pass:
            learner.optimizer.zero_grad()
            loss.backward()

            learner.optimizer.step()

            # The scheduler is optional on Learner; only track and step it when set.
            if learner.lr_scheduler is not None:
                learner.lrs.append(learner.lr_scheduler.get_last_lr()[0])
                learner.lr_scheduler.step()

        predictions = outputs.detach() >= logit_threshold
        is_correct = (predictions == labels).long()
        for metric in metrics:
            metric.update(loss=loss, predictions=predictions, is_correct=is_correct, new_count=labels.shape[0])

        all_labels.append(labels)
        all_predictions.append(predictions)

    # Record loss and accuracy for this pass.
    for metric in metrics:
        learner.metric_values[f"{learner.stage}_{metric.name}"].append(metric.avg)

    # F1 needs the predictions of the whole pass, not per-batch averages.
    all_labels, all_predictions = torch.hstack(all_labels).cpu(), torch.hstack(all_predictions).cpu()
    f1 = f1_score(all_labels, all_predictions)
    learner.metric_values[f"{learner.stage}_f1"].append(f1)

    return all_labels, all_predictions

## Learner setup

In [None]:
# Stage names; each one selects f"{stage}_dataloader" and f"{stage}_<metric>" entries.
validation_stages = [f"valid_{pairing}" for pairing in ("balanced", "sparse")]

In [None]:
# One history list per (stage, metric) pair, e.g. "valid_balanced_acc".
metric_values = {}
for stage_name in validation_stages:
    for tracked_metric in ["loss", "acc", "f1"]:
        metric_values[f"{stage_name}_{tracked_metric}"] = []
metric_values

In [None]:
# Evaluation-only learner: no optimizer or scheduler is attached.
similarity_model = get_similarity_model(pretrained=False, output_size=OUTPUT_SIZE, device=device)
learner = Learner(model=similarity_model, device=device, dataloaders=dataloaders, metric_values=metric_values)

## Evaluate

In [None]:
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine (and
# onto whichever device the learner uses). NOTE: torch.load unpickles the file —
# only load checkpoints you trust; on torch >= 1.13 weights_only=True is safer
# for a plain state_dict.
checkpoint = torch.load("path/to/model/weights", map_location=learner.device)
learner.model.load_state_dict(checkpoint)
learner.model.eval();

In [None]:
%%time

# Tick labels for pair label 0 / 1.
classes = ["dissimilar", "similar"]

# Confusion matrix on the balanced validation pairs at threshold 0.5, normalized
# per true class (each row sums to 1). iterate() also appends loss/acc/f1 for the
# stage to learner.metric_values.
fig, axes = plt.subplots(1, 1, figsize=(5, 5))
for stage, ax in zip(["valid_balanced"], [axes]):
    print(stage)
    learner.stage = stage
    with torch.no_grad():
        y_true, y_predicted = iterate(learner=learner, perform_backward_pass=False, threshold=0.5, show_progress=True)
    conf_matrix = confusion_matrix(y_true, y_predicted, normalize='true')

    sns.heatmap(
        conf_matrix, ax=ax, xticklabels=classes, yticklabels=classes, annot=True, cbar=False, square=True)
    ax.set_title(stage)
    
fig.tight_layout()
plt.show()

In [None]:
# Print the latest recorded value of every metric, grouped by metric type
# (acc, f1, loss); metrics with no history yet are skipped.
for metric_name in sorted(learner.metric_values, key=lambda key: key.rsplit("_", 1)[-1]):
    history = learner.metric_values[metric_name]
    if history:
        print(f"{metric_name}: {history[-1]:.4f}")

Inspect the most confidently misclassified similar pairs, and the correctly classified pairs closest to the decision boundary:

In [None]:
def predict(learner, show_progress: bool = True):
    """Return sigmoid probabilities and labels for every pair of ``learner.stage``.

    Runs the model under ``torch.no_grad()`` over
    ``learner.dataloaders[f"{learner.stage}_dataloader"]`` without recording
    metrics. The model's train/eval mode is left to the caller.

    Args:
        learner: object exposing ``model``, ``device``, ``stage`` and ``dataloaders``.
        show_progress: wrap the dataloader in a tqdm progress bar (default True).

    Returns:
        ``(all_outputs, all_labels)`` on ``learner.device``, joined with
        ``torch.hstack`` (assumes one logit per pair, i.e. 1-D per-batch outputs).
    """
    all_labels = []
    all_outputs = []

    with torch.no_grad():
        dl = learner.dataloaders[f"{learner.stage}_dataloader"]
        loop = tqdm(dl, total=len(dl), leave=False) if show_progress else dl
        for batch in loop:
            inputs1 = batch["image1"].to(learner.device)
            inputs2 = batch["image2"].to(learner.device)
            labels = batch["label"].to(learner.device)
            all_labels.append(labels)

            outputs = learner.model((inputs1, inputs2))
            # Convert logits to probabilities of "same class".
            outputs = outputs.sigmoid()
            all_outputs.append(outputs)

    all_labels = torch.hstack(all_labels)
    all_outputs = torch.hstack(all_outputs)

    return all_outputs, all_labels


In [None]:
# Number of example pairs plotted per validation stage in the inspection cells.
N_IMAGES_TO_SHOW = 5

In [None]:
# For each validation stage, plot the similar pairs the model got wrong with the
# highest confidence (largest |p - 0.5|), i.e. its most confident false negatives.
for stage in validation_stages:
    print("-"*30)
    print(stage)
    learner.stage = stage
    
    all_outputs, all_labels = predict(learner)
    all_outputs = all_outputs.cpu()
    all_labels = all_labels.cpu()
    
    is_correct = (all_outputs >= 0.5) == all_labels
    not_correct_and_similar = (~is_correct) & (all_labels == 1)
    print(f"{is_correct.float().mean().item():.4f}")  # accuracy
    print(f"{(~is_correct).long().sum().item()}/{len(is_correct)}")  # errors / total
    print(f"{not_correct_and_similar.long().sum().item()}")  # false negatives
    
    # Confidence = distance of the predicted probability from the 0.5 boundary.
    confidence = (all_outputs - 0.5).abs()
    # Keep sample indices in an integer tensor (instead of stacking them into the
    # float confidence tensor) so they stay exact for any dataset size.
    candidate_indices = not_correct_and_similar.nonzero(as_tuple=True)[0]
    order = confidence[candidate_indices].argsort(descending=True)
    initial_indices = candidate_indices[order]
    
    pair_dataset = learner.dataloaders[f"{learner.stage}_dataloader"].dataset
    for idx in initial_indices[:N_IMAGES_TO_SHOW].tolist():
        x1, x2, same = pair_dataset._get_files_valid(idx)
        
        fig, axes = plt.subplots(1, 2)
        for image, ax in zip([x1, x2], axes.flatten()):
            ax.imshow(image)
            ax.axis("off")
        
        fig.suptitle("similar" if same else "dissimilar")
        fig.tight_layout()
        plt.show()


In [None]:
# For each validation stage, plot correctly classified pairs whose probability is
# closest to 0.5, i.e. the correct predictions the model was least sure about.
# NOTE(review): the "least confused" wording may intend the most confident correct
# pairs instead (that would need descending order) — confirm which is wanted.
for stage in validation_stages:
    print("-"*30)
    print(stage)
    learner.stage = stage
    
    all_outputs, all_labels = predict(learner)
    all_outputs = all_outputs.cpu()
    all_labels = all_labels.cpu()
    
    is_correct = (all_outputs >= 0.5) == all_labels
    print(f"{is_correct.float().mean().item():.4f}")  # accuracy
    print(f"{is_correct.long().sum().item()}/{len(is_correct)}")  # correct / total
    
    # Confidence = distance of the predicted probability from the 0.5 boundary.
    confidence = (all_outputs - 0.5).abs()
    # Keep sample indices in an integer tensor (instead of stacking them into the
    # float confidence tensor) so they stay exact for any dataset size.
    candidate_indices = is_correct.nonzero(as_tuple=True)[0]
    order = confidence[candidate_indices].argsort()
    initial_indices = candidate_indices[order]
    
    pair_dataset = learner.dataloaders[f"{learner.stage}_dataloader"].dataset
    for idx in initial_indices[:N_IMAGES_TO_SHOW].tolist():
        x1, x2, same = pair_dataset._get_files_valid(idx)
        
        fig, axes = plt.subplots(1, 2)
        for image, ax in zip([x1, x2], axes.flatten()):
            ax.imshow(image)
            ax.axis("off")
        
        fig.suptitle("similar" if same else "dissimilar")
        fig.tight_layout()
        plt.show()

Assess the influence of the decision threshold on accuracy and F1 score:

In [None]:
%%time

thresholds = np.linspace(start=0.1, stop=0.9, num=9)
print(thresholds)

for stage in ["valid_balanced"]:
    print(stage)
    learner.stage = stage

    # Run one evaluation pass per threshold and derive every metric from it,
    # rather than re-running the whole validation set once per metric. Each
    # iterate() call also appends loss/acc/f1 to learner.metric_values.
    metric_functions = {"accuracy": accuracy_score, "f1 score": f1_score}
    values_per_metric = {label: [] for label in metric_functions}
    for th in tqdm(thresholds, total=len(thresholds), leave=False):
        with torch.no_grad():
            y_true, y_predicted = iterate(learner=learner, perform_backward_pass=False, threshold=th, show_progress=False)
        for label, metric_function in metric_functions.items():
            values_per_metric[label].append(metric_function(y_true, y_predicted))

    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    for (label, values), ax in zip(values_per_metric.items(), axes.flatten()):
        print(label)
        ax.plot(thresholds, values, "o-")
        ax.grid()
        ax.set_title(label)

    fig.suptitle(stage)
    fig.tight_layout()
    plt.show()

In [None]:
# Accuracy recorded by the threshold sweep: the sweep appends its entries last,
# one per threshold, so take the tail (the head holds earlier evaluations).
learner.metric_values["valid_balanced_acc"][-len(thresholds):]

In [None]:
# F1 recorded by the threshold sweep, in the same order as `thresholds`.
learner.metric_values["valid_balanced_f1"][-len(thresholds):]