## Dependencies

In [None]:
# Install dependencies
!pip install torch-lr-finder

## Imports

In [None]:
from functools import partial
from tqdm.notebook import tqdm
from collections import defaultdict, Counter
import copy
from dataclasses import field
from datetime import datetime

# Basic libraries for data manipulation
import torch
from PIL import Image
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
import os
import random
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, accuracy_score, confusion_matrix
import seaborn as sns
import pickle

# PyTorch parts
from torch import nn
from torch.nn import functional as F
from torch import optim
from torch.utils.data import DataLoader, Dataset
from torchvision import models
from torchvision import transforms

# lr_finder
from torch_lr_finder import LRFinder, TrainDataLoaderIter

In [None]:
# Seed

def seed_everything(seed: int):
    """Seed every random number generator used by this notebook.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and all CUDA
    devices), exports ``PYTHONHASHSEED`` and configures cuDNN for
    deterministic execution.

    Args:
        seed: Value used to seed every generator.
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current one; this is a no-op without CUDA.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode lets cuDNN auto-tune (and therefore vary) the convolution
    # algorithm between runs, which would defeat the deterministic flag above.
    torch.backends.cudnn.benchmark = False

seed_everything(seed=2022)

## Config

In [None]:
# Library whose calling convention TinyImagenetDataset uses to apply the augmentations.
AUGMENTATIONS_ENGINE = "torchvision"
# Number of classes kept from the dataset (those with the shortest label text).
N_CLASSES = 100
# Default probability that a randomly drawn image pair shares the same class.
PROBABILITY_SAME = 1 / 4
# Number of outputs of the classification head (a single similarity logit).
OUTPUT_SIZE = 1

In [None]:
# Use the GPU whenever at least one CUDA device is visible (any GPU count, not
# only exactly one or two); otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.device_count() > 0 else "cpu")
device

In [None]:
# Training hyperparameters
BATCH_SIZE = 64  # training batch size; the evaluation loaders use a multiple of it
N_EPOCHS_INITIAL = 7  # epochs of the first stage, run after freeze_model (backbone frozen)
N_EPOCHS_FULL = 50  # epochs of the second stage, run after unfreeze_model
WEIGHT_DECAY = 1e-3  # AdamW weight decay; None selects plain Adam instead
SHOW_PROGRESS = False  # whether the training loops wrap the dataloader in a tqdm bar

## Data

**Download the data**

In [None]:
%%time

# Retrieve data directly from Stanford data source
if not Path("tiny-imagenet-200.zip").exists():
    !wget http://cs231n.stanford.edu/tiny-imagenet-200.zip
    !unzip -qq "tiny-imagenet-200.zip"

In [None]:
DATA_DIR = Path("tiny-imagenet-200") # Original images come in shapes of [3, 64, 64]

# Define training and validation data paths
TRAIN_DIR = DATA_DIR / "train"
VALID_DIR = DATA_DIR / "val"

DATA_DIR

In [None]:
!ls {DATA_DIR}

In [None]:
val_labels = pd.read_csv(
    f'{VALID_DIR}/val_annotations.txt', 
    sep='\t', 
    header=None, 
    names=['File', 'Class', 'X', 'Y', 'H', 'W']
)
print(len(val_labels))
val_labels.head()

In [None]:
valid_filename_to_raw_label = dict(zip(val_labels["File"], val_labels["Class"]))
len(valid_filename_to_raw_label.keys()), list(valid_filename_to_raw_label.items())[:10]

In [None]:
FOLDERS = [folder for folder in TRAIN_DIR.iterdir()]
TRAIN_FILENAMES = [f for folder in FOLDERS for f in (folder / "images").iterdir()]
len(TRAIN_FILENAMES), TRAIN_FILENAMES[:5]

In [None]:
VALID_FILENAMES = [f for f in (VALID_DIR / "images").iterdir()]
len(VALID_FILENAMES), VALID_FILENAMES[:5]

In [None]:
raw_labels_in_train = list(set([item.parent.parent.name for item in TRAIN_FILENAMES]))
len(raw_labels_in_train), raw_labels_in_train[:10]

In [None]:
# Parse words.txt: one "<raw label>\t<description>" entry per line.
with open(DATA_DIR / "words.txt") as f:
    parsed_lines = [line.strip().split("\t") for line in f]

# Keep only the raw labels that actually occur in the training folders.
# A set makes each membership test O(1) instead of scanning the whole list per line.
train_raw_labels = set(raw_labels_in_train)
raw_label_mapping = { item[0]: item[1] for item in parsed_lines if item[0] in train_raw_labels }

# Only leave the first N_CLASSES ones in terms of length (shortest descriptions first)
top_sorted_labels_by_length = sorted(raw_label_mapping.values(), key=len)[:N_CLASSES]
kept_labels = set(top_sorted_labels_by_length)
raw_label_mapping = { key: value for key, value in raw_label_mapping.items() if value in kept_labels }

len(raw_label_mapping.keys()), list(raw_label_mapping.items())[:10]

In [None]:
# Filter out the classes that will not be used
TRAIN_FILENAMES = [item for item in TRAIN_FILENAMES if item.parent.parent.name in raw_label_mapping.keys()]
VALID_FILENAMES = [item for item in VALID_FILENAMES if valid_filename_to_raw_label[item.name] in raw_label_mapping.keys()]

len(TRAIN_FILENAMES), len(VALID_FILENAMES)

In [None]:
# Give every kept class description a consecutive integer id, in alphabetical order.
label_to_id = { label: index for index, label in enumerate(sorted(raw_label_mapping.values())) }
len(label_to_id), list(label_to_id.items())[:10]

In [None]:
def extract_label_train(filename, raw_label_to_label:dict=None, label_to_id:dict=None):
    """Return the integer class id of a training image.

    The raw label is the name of the directory two levels above the file
    (``<raw label>/images/<file>``).

    Args:
        filename: Path of the training image.
        raw_label_to_label: Maps a raw label to its class description.
        label_to_id: Maps a class description to its integer id.

    Returns:
        The integer id of the image's class.

    Raises:
        KeyError: If the raw label or the description is missing from a mapping.
    """
    raw_label = filename.parent.parent.name
    return label_to_id[raw_label_to_label[raw_label]]

extract_label_train_partial = partial(extract_label_train, raw_label_to_label=raw_label_mapping, label_to_id=label_to_id)

extract_label_train_partial(TRAIN_FILENAMES[0])

In [None]:
def extract_label_valid(filename, filename_to_raw_label:dict=None, raw_label_to_label:dict=None, label_to_id:dict=None):
    """Return the integer class id of a validation image.

    Validation images are not sorted into class folders, so the raw label is
    looked up by file name.

    Args:
        filename: Path of the validation image; only its ``name`` is used.
        filename_to_raw_label: Maps an image file name to its raw label.
        raw_label_to_label: Maps a raw label to its class description.
        label_to_id: Maps a class description to its integer id.

    Returns:
        The integer id of the image's class.

    Raises:
        KeyError: If any of the three lookups fails.
    """
    description = raw_label_to_label[filename_to_raw_label[filename.name]]
    return label_to_id[description]

extract_label_valid_partial = partial(
    extract_label_valid, filename_to_raw_label=valid_filename_to_raw_label,
    raw_label_to_label=raw_label_mapping, label_to_id=label_to_id
)

extract_label_valid_partial(VALID_FILENAMES[0])

In [None]:
from dataclasses import dataclass

@dataclass
class File:
    """An image path together with the integer id of its class."""

    filename: Path
    # Integer class id as returned by the label extractors (not the label text).
    label: int

    @classmethod
    def from_filename(cls, filename, label_extractor):
        """Build a ``File`` by deriving the label from the path.

        Args:
            filename: Path of the image.
            label_extractor: Callable taking ``filename`` and returning its label.

        Returns:
            A new ``File`` holding ``filename`` and the extracted label.
        """
        return cls(filename=filename, label=label_extractor(filename))

In [None]:
%%time

TRAIN_FILENAMES = [File.from_filename(filename=item, label_extractor=extract_label_train_partial) for item in TRAIN_FILENAMES]
VALID_FILENAMES = [File.from_filename(filename=item, label_extractor=extract_label_valid_partial) for item in VALID_FILENAMES]

TRAIN_FILENAMES[0], VALID_FILENAMES[0]

In [None]:
FILENAMES = TRAIN_FILENAMES + VALID_FILENAMES
len(FILENAMES)

In [None]:
labels = [item.label for item in FILENAMES]
len(labels)

In [None]:
# Set aside every tenth class (in label_to_id order) as classes the model never trains on
unseen_classes = list(label_to_id)[::10]
unseen_classes_dict = { name: label_to_id[name] for name in unseen_classes }
unseen_classes_dict.values()

In [None]:
# Split the files by whether their class id belongs to the held-out ("unseen") classes.
unseen_ids = set(unseen_classes_dict.values())
valid_unseen_files = [entry for entry in FILENAMES if entry.label in unseen_ids]
train_valid_files_tmp = [entry for entry in FILENAMES if entry.label not in unseen_ids]
len(train_valid_files_tmp), len(valid_unseen_files)

In [None]:
%%time

# Shuffle a copy so that train_valid_files_tmp keeps its original order.
train_valid_files = train_valid_files_tmp.copy()
random.Random(2022).shuffle(train_valid_files)  # dedicated RNG instance: the global random state is untouched

# Stratified split by label: 10/11 of the seen-class files for training, the rest for validation.
train_files, valid_files = train_test_split(
    train_valid_files, train_size=int(len(train_valid_files) / 11 * 10),
    stratify=[item.label for item in train_valid_files],
    random_state=2022,
)
# Sanity check: both parts together contain as many files as the input.
assert len(train_files) + len(valid_files) == len(train_valid_files_tmp)

len(train_files), len(valid_files)

In [None]:
tmp = Counter([item.label for item in train_files])
np.mean(list(tmp.values())), np.std(list(tmp.values())), len(tmp.keys())

In [None]:
tmp = Counter([item.label for item in valid_files])
np.mean(list(tmp.values())), np.std(list(tmp.values())), len(tmp.keys())

## Dataset

In [None]:
# Build label to files mapping
def get_labels_to_files_mapping(files):
    """Group files by their label.

    Args:
        files: Iterable of objects exposing a ``label`` attribute.

    Returns:
        A dict whose keys are the distinct labels in ascending order and whose
        values are the files carrying that label, in their original order.
    """
    # Single pass over the files instead of rescanning the whole list once per label.
    grouped = defaultdict(list)
    for file in files:
        grouped[file.label].append(file)

    return { label: grouped[label] for label in sorted(grouped) }

In [None]:
%%time

train_label_to_files = get_labels_to_files_mapping(files=train_files)
valid_label_to_files = get_labels_to_files_mapping(files=valid_files)
valid_unseen_label_to_files = get_labels_to_files_mapping(files=valid_unseen_files)

In [None]:
%%time

def draw_image(labels_to_files, label, probability_same=PROBABILITY_SAME):
    """Randomly draw a partner image for an image of class ``label``.

    Args:
        labels_to_files: Maps each label to the list of files of that class.
        label: Label of the image a partner is drawn for.
        probability_same: Probability of drawing the partner from the same class.

    Returns:
        A tuple ``(file, are_same)``: the drawn file and whether it was drawn
        from the class ``label``.
    """
    are_same = random.random() < probability_same

    candidate_labels = list(labels_to_files.keys())
    if are_same:
        chosen_label = label
    else:
        # Pick uniformly among all the other classes.
        chosen_label = random.choice([candidate for candidate in candidate_labels if candidate != label])

    return random.choice(labels_to_files[chosen_label]), are_same

out = draw_image(train_label_to_files, label=2, probability_same=PROBABILITY_SAME)
out

In [None]:
# Re-seed the global RNG so that the pre-drawn validation pairs are identical on every run.
random.seed(2022)

# One fixed partner per validation file. Only the drawn file ([0]) is kept; the
# "same" flag is recomputed from the two labels in TinyImagenetDataset._get_files_valid.
valid_paired_files = [
    draw_image(labels_to_files=valid_label_to_files, label=file.label, probability_same=PROBABILITY_SAME)[0] for file in valid_files
]
# Balanced variant: same-class and different-class partners are equally likely.
valid_paired_files_balanced = [
    draw_image(labels_to_files=valid_label_to_files, label=file.label, probability_same=0.5)[0] for file in valid_files
]
# Sparse variant: a same-class partner is drawn with probability 1 / (number of validation classes).
valid_paired_files_sparse = [
    draw_image(labels_to_files=valid_label_to_files, label=file.label, probability_same=1 / len(valid_label_to_files.keys()))[0] for file in valid_files
]
len(valid_paired_files), len(valid_paired_files_balanced), len(valid_paired_files_sparse)

In [None]:
# Re-seed the global RNG so that the pre-drawn pairs of the unseen classes are identical on every run.
random.seed(2022)

# One fixed partner per unseen-class file, drawn only among the unseen classes.
valid_unseen_paired_files = [
    draw_image(labels_to_files=valid_unseen_label_to_files, label=file.label, probability_same=PROBABILITY_SAME)[0] for file in valid_unseen_files
]
# Balanced variant: same-class and different-class partners are equally likely.
valid_unseen_paired_files_balanced = [
    draw_image(labels_to_files=valid_unseen_label_to_files, label=file.label, probability_same=0.5)[0] for file in valid_unseen_files
]
# Sparse variant: a same-class partner is drawn with probability 1 / (number of unseen classes).
valid_unseen_paired_files_sparse = [
    draw_image(
        labels_to_files=valid_unseen_label_to_files, label=file.label,
        probability_same=1 / len(valid_unseen_label_to_files.keys())
    )[0] for file in valid_unseen_files
]
len(valid_unseen_paired_files), len(valid_unseen_paired_files_balanced), len(valid_unseen_paired_files_sparse)

In [None]:
class TinyImagenetDataset(Dataset):
    """Dataset of image pairs labelled with whether both images share a class.

    For the ``"train"`` split the partner of each image is drawn at random on
    every access; for the validation splits the partner is taken from the
    pre-drawn ``paired_files`` so that evaluation is repeatable.
    """

    def __init__(
        self, files, split, labels_to_files, probability_same=PROBABILITY_SAME, paired_files=None,
        augmentations:transforms.Compose=None, augmentations_engine:str="torchvision"
    ):
        """
        Args:
            files: Files served as the first image of each pair.
            split: One of ``"train"``, ``"valid"``, ``"valid_unseen"``.
            labels_to_files: Maps each label to its files; used to draw training partners.
            probability_same: Probability that a training partner has the same class.
            paired_files: Pre-drawn partner per entry of ``files``; required for
                the validation splits.
            augmentations: Transform pipeline applied to both images, or ``None``.
            augmentations_engine: ``"torchvision"`` or ``"albumentations"``;
                selects how the pipeline is called.
        """
        assert split in ["train", "valid", "valid_unseen"], \
            "Variable split has invalid value."

        assert augmentations_engine in ["torchvision", "albumentations"], \
            "Variable augmentations_engine has to be one of torchvision, albumentations."

        # Fail at construction time rather than on the first __getitem__ call.
        if split != "train":
            assert paired_files is not None and len(paired_files) == len(files), \
                "Validation splits require exactly one paired file per file."

        self.files = files
        self.split = split
        self.labels_to_files = labels_to_files
        self.probability_same = probability_same
        self.paired_files = paired_files
        self.augmentations = augmentations
        self.augmentations_engine = augmentations_engine

    def __len__(self):
        """Return the number of pairs, i.e. the number of first images."""
        return len(self.files)

    def _get_files_train(self, idx):
        """Return ``(file1, file2, same)`` with a freshly drawn random partner."""
        file1 = self.files[idx]
        file2, same = draw_image(labels_to_files=self.labels_to_files, label=file1.label, probability_same=self.probability_same)

        return file1, file2, same

    def _get_files_valid(self, idx):
        """Return ``(file1, file2, same)`` using the pre-drawn partner at ``idx``."""
        file1 = self.files[idx]
        file2 = self.paired_files[idx]
        same = file1.label == file2.label

        return file1, file2, same

    @staticmethod
    def _load_image(path):
        """Load an image as RGB and release the underlying file handle."""
        # convert() returns an independent, fully loaded copy, so the file can be closed.
        with Image.open(path) as image:
            return image.convert("RGB")

    def _apply_augmentations(self, image):
        """Run the configured augmentation pipeline (if any) on one PIL image."""
        if self.augmentations is None:
            return image
        if self.augmentations_engine == "albumentations":
            # albumentations pipelines take NumPy arrays passed by keyword, not PIL images.
            return self.augmentations(image=np.array(image))["image"]
        return self.augmentations(image)

    def __getitem__(self, idx):
        """Return ``{"image1", "image2", "label"}``; label is 1.0 for a same-class pair, else 0.0."""
        if self.split == "train":
            file1, file2, same = self._get_files_train(idx)
        else:
            file1, file2, same = self._get_files_valid(idx)

        image1 = self._apply_augmentations(self._load_image(file1.filename))
        image2 = self._apply_augmentations(self._load_image(file2.filename))

        return { "image1": image1, "image2": image2, "label": torch.tensor(float(same)) }

In [None]:
# If using pre-trained ImageNet, normalize with mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]

# Training pipeline: resize to 224x224, apply random augmentations, convert to a tensor and normalize.
train_augmentations = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.3, saturation=0.2),
    # translate is given as a fraction of the image size, i.e. up to 8 px at 224x224
    transforms.RandomAffine(degrees=10, translate=(8/224, 8/224), scale=(0.95, 1.05), shear=5),
    transforms.ToTensor(), # Converting images to tensors
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Evaluation pipeline: same resize and normalization, but no random augmentation,
# so that every evaluation pass sees identical inputs.
valid_augmentations = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(), # Converting images to tensors
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [None]:
train_dataset = TinyImagenetDataset(
    files=train_files, split="train",
    labels_to_files=train_label_to_files, probability_same=PROBABILITY_SAME,
    augmentations=train_augmentations, augmentations_engine=AUGMENTATIONS_ENGINE
)

valid_dataset = TinyImagenetDataset(
    files=valid_files, split="valid",
    labels_to_files=valid_label_to_files, paired_files=valid_paired_files,
    augmentations=valid_augmentations, augmentations_engine=AUGMENTATIONS_ENGINE
)
valid_balanced_dataset = TinyImagenetDataset(
    files=valid_files, split="valid",
    labels_to_files=valid_label_to_files, paired_files=valid_paired_files_balanced,
    augmentations=valid_augmentations, augmentations_engine=AUGMENTATIONS_ENGINE
)
valid_sparse_dataset = TinyImagenetDataset(
    files=valid_files, split="valid",
    labels_to_files=valid_label_to_files, paired_files=valid_paired_files_sparse,
    augmentations=valid_augmentations, augmentations_engine=AUGMENTATIONS_ENGINE
)

valid_unseen_dataset = TinyImagenetDataset(
    files=valid_unseen_files, split="valid_unseen",
    labels_to_files=valid_unseen_label_to_files, paired_files=valid_unseen_paired_files,
    augmentations=valid_augmentations, augmentations_engine=AUGMENTATIONS_ENGINE
)
valid_unseen_balanced_dataset = TinyImagenetDataset(
    files=valid_unseen_files, split="valid_unseen",
    labels_to_files=valid_unseen_label_to_files, paired_files=valid_unseen_paired_files_balanced,
    augmentations=valid_augmentations, augmentations_engine=AUGMENTATIONS_ENGINE
)
valid_unseen_sparse_dataset = TinyImagenetDataset(
    files=valid_unseen_files, split="valid_unseen",
    labels_to_files=valid_unseen_label_to_files, paired_files=valid_unseen_paired_files_sparse,
    augmentations=valid_augmentations, augmentations_engine=AUGMENTATIONS_ENGINE
)

In [None]:
item = train_dataset[0]
print(item["label"])
plt.imshow(item["image1"].permute(1, 2, 0))
plt.show()
plt.imshow(item["image2"].permute(1, 2, 0))
plt.show()

In [None]:
item = valid_dataset[0]
print(item["label"])
plt.imshow(item["image1"].permute(1, 2, 0))
plt.show()
plt.imshow(item["image2"].permute(1, 2, 0))
plt.show()

In [None]:
print((
    len(train_dataset), len(valid_dataset), len(valid_balanced_dataset), len(valid_sparse_dataset),
    len(valid_unseen_dataset), len(valid_unseen_balanced_dataset), len(valid_unseen_sparse_dataset)
))

## Dataloaders

In [None]:
BATCH_SIZE_SCALE_FACTOR = 2

In [None]:
# Settings shared by every evaluation dataloader.
valid_kwargs = {
    "batch_size": BATCH_SIZE * BATCH_SIZE_SCALE_FACTOR, # We can increase the eval batch size since gradients aren't stored
    "shuffle": False, # fixed order, so predictions line up with dataset indices
    "drop_last": False,
    "num_workers": 2,
    "pin_memory": device.type == "cuda",
}

# The training loader reshuffles the files every epoch.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=False,
    num_workers=2,
    pin_memory=device.type == "cuda",
)

# Seen classes: default, balanced and sparse pairings.
valid_dataloader = DataLoader(valid_dataset, **valid_kwargs)
valid_balanced_dataloader = DataLoader(valid_balanced_dataset, **valid_kwargs)
valid_sparse_dataloader = DataLoader(valid_sparse_dataset, **valid_kwargs)

# Unseen classes: default, balanced and sparse pairings.
valid_unseen_dataloader = DataLoader(valid_unseen_dataset, **valid_kwargs)
valid_unseen_balanced_dataloader = DataLoader(valid_unseen_balanced_dataset, **valid_kwargs)
valid_unseen_sparse_dataloader = DataLoader(valid_unseen_sparse_dataset, **valid_kwargs)

In [None]:
dataloaders = {
    "train_dataloader": train_dataloader,
    "valid_dataloader": valid_dataloader,
    "valid_balanced_dataloader": valid_balanced_dataloader,
    "valid_sparse_dataloader": valid_sparse_dataloader,
    "valid_unseen_dataloader": valid_unseen_dataloader,
    "valid_unseen_balanced_dataloader": valid_unseen_balanced_dataloader,
    "valid_unseen_sparse_dataloader": valid_unseen_sparse_dataloader,
}

In [None]:
%%time

batch = next(iter(dataloaders["train_dataloader"]))
batch["image1"].shape, batch["image2"].shape, batch["label"].shape

In [None]:
batch["label"][:10]

## Model

In [None]:
class SimilarityModel(nn.Module):
    """Shared-backbone model scoring whether two images belong to the same class.

    Both images go through the same backbone. The two embeddings and features
    comparing them are fed to a small classification head that outputs
    ``output_size`` logits per pair.
    """

    def __init__(self, backbone_model, backbone_output_size, output_size=OUTPUT_SIZE):
        """
        Args:
            backbone_model: Module mapping an image batch to 2-D embeddings of
                width ``backbone_output_size``.
            backbone_output_size: Width of one embedding.
            output_size: Number of logits produced per pair.
        """
        super().__init__()

        self.backbone_model = backbone_model

        # Both embeddings + their squared difference (3 * width),
        # plus cosine similarity and Euclidean distance (2 scalars).
        input_size = 3 * backbone_output_size + 2
        self.classifier = nn.Sequential(
            nn.BatchNorm1d(num_features=input_size),
            nn.Dropout(p=0.25),
            nn.Linear(in_features=input_size, out_features=512, bias=True),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(num_features=512),
            nn.Dropout(p=0.5),
            nn.Linear(in_features=512, out_features=output_size, bias=True),
        )

    def forward(self, inputs):
        """Score a batch of image pairs.

        Args:
            inputs: Tuple ``(image1, image2)`` of image batches.

        Returns:
            Logits of shape ``(batch,)`` when ``output_size`` is 1, otherwise
            ``(batch, output_size)``.
        """
        image1, image2 = inputs
        out1 = self.backbone_model(image1)
        out2 = self.backbone_model(image2)

        # Use both vectors, their difference squared and the distance between them as input features
        # to the classification head
        diff_squared = (out1 - out2) ** 2
        cosine_similarity = F.cosine_similarity(out1, out2, dim=-1).unsqueeze(1)

        concatenated_features = torch.hstack([
            out1, out2,
            diff_squared, cosine_similarity,
            diff_squared.sum(dim=1, keepdim=True).sqrt()
        ])
        # Squeeze only the logit dimension: a bare squeeze() would also drop the
        # batch dimension for a batch of one, so the output would no longer
        # match the shape of the labels.
        out = self.classifier(concatenated_features).squeeze(-1)

        return out


def get_similarity_model(pretrained=True, output_size=OUTPUT_SIZE, device=device):
    """Build a ``SimilarityModel`` on top of an EfficientNet-B0 backbone.

    Args:
        pretrained: Forwarded to ``models.efficientnet_b0``.
        output_size: Number of logits of the similarity head.
        device: Device the assembled model is moved to.

    Returns:
        The model, already moved to ``device``.
    """
    backbone = models.efficientnet_b0(pretrained=pretrained)
    # Remember the width of the features feeding the original head, then drop
    # that head so that the backbone returns those features directly.
    embedding_size = backbone.classifier[1].in_features
    backbone.classifier = nn.Identity()

    similarity_model = SimilarityModel(
        backbone_model=backbone,
        backbone_output_size=embedding_size,
        output_size=output_size,
    )
    return similarity_model.to(device)

model = get_similarity_model(pretrained=False, output_size=OUTPUT_SIZE, device=device)
sum([p.numel() for p in model.parameters()]) # number of parameters in the model

In [None]:
%%time

with torch.no_grad():
    out = model((batch["image1"].to(device), batch["image2"].to(device)))
out.shape

In [None]:
def freeze_model(learner, freeze_all=False):
    """Freeze the backbone and, optionally, the classification head.

    Args:
        learner: Object whose ``model`` has ``backbone_model`` and ``classifier``.
        freeze_all: When ``False`` (default) the classifier stays trainable;
            when ``True`` the classifier is frozen as well.
    """
    # The backbone is always frozen.
    for param in learner.model.backbone_model.parameters():
        param.requires_grad = False

    # Set the head explicitly in both cases, so that freeze_all=True really
    # freezes the whole model instead of leaving the head untouched.
    for param in learner.model.classifier.parameters():
        param.requires_grad = not freeze_all
            

# TODO: maybe also freeze running stats.
# This is a debated topic.
def unfreeze_model(learner):
    """Make the whole model trainable except the backbone's BatchNorm2d layers.

    Args:
        learner: Object whose ``model`` has a ``backbone_model`` with a
            ``features`` sub-module.
    """
    network = learner.model
    for weight in network.parameters():
        weight.requires_grad = True

    # Freeze batch normalization layers
    batch_norm_layers = (
        layer for layer in network.backbone_model.features.modules()
        if isinstance(layer, nn.BatchNorm2d)
    )
    for layer in batch_norm_layers:
        for weight in layer.parameters():
            weight.requires_grad = False

In [None]:
def get_optimizer_and_scheduler(learner, lr, n_epochs, weight_decay=None):
    """Create the optimizer and a one-cycle learning rate schedule for a training run.

    Also resets ``learner.lrs`` to an empty list.

    Args:
        learner: Object providing ``model`` and ``dataloaders["train_dataloader"]``.
        lr: Optimizer learning rate, also used as the peak of the one-cycle schedule.
        n_epochs: Number of epochs the schedule spans.
        weight_decay: When given, AdamW with this weight decay is used;
            otherwise plain Adam.

    Returns:
        Tuple ``(optimizer, lr_scheduler)``.
    """
    learner.lrs = []

    # Instantiate the optimizer
    model_parameters = learner.model.parameters()
    if weight_decay is None:
        optimizer = optim.Adam(params=model_parameters, lr=lr)
    else:
        optimizer = optim.AdamW(model_parameters, lr=lr, weight_decay=weight_decay)

    # Instantiate the learning rate scheduler (stepped once per batch)
    batches_per_epoch = len(learner.dataloaders["train_dataloader"])
    lr_scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer=optimizer, max_lr=lr, epochs=n_epochs, steps_per_epoch=batches_per_epoch
    )

    return optimizer, lr_scheduler

In [None]:
@dataclass
class Metric:
    """
    Accumulates a per-sample metric over batches and stores its running average.

    The metric named "loss" accumulates the batch loss weighted by the batch
    size; any other name accumulates the number of correct predictions.
    """

    name: str
    total_value:float = 0  # sum of the metric over all samples seen so far
    total_samples:int = 0  # number of samples seen so far
    avg:float = 0  # total_value / total_samples

    def reset(self):
        """Forget everything accumulated so far."""
        self.total_value = 0
        self.total_samples = 0
        self.avg = 0

    def update(self, loss, predictions, is_correct, new_count=1):
        """Add one batch to the running totals and refresh ``avg``.

        Args:
            loss: Batch loss exposing ``.item()``; read only by the "loss" metric.
            predictions: Batch predictions (currently unused by both metrics).
            is_correct: Per-sample correctness tensor; read by every metric
                other than "loss".
            new_count: Number of samples in the batch.
        """
        new_value = self._compute_value(loss, predictions, is_correct, new_count)

        self.total_value += new_value
        self.total_samples += new_count
        # An empty first batch must not raise ZeroDivisionError.
        if self.total_samples > 0:
            self.avg = self.total_value / self.total_samples

    # Metric specializes
    def _compute_value(self, loss, predictions, is_correct, new_count):
        """Return the batch's contribution to ``total_value``."""
        if self.name == "loss": return loss.item() * new_count
        else: return is_correct.sum().item()


In [None]:
@dataclass
class Learner:
    """
    Bundles the model with everything the training and evaluation loops need.
    """

    # Required at construction time
    model: nn.Module
    device: torch.device
    dataloaders: dict  # keyed "<stage>_dataloader"
    metric_values: dict  # keyed "<stage>_<metric>", one list of values per key

    # Name of the dataloader/metric prefix currently being processed
    stage: str = "train"

    # Assigned later, once a training run is configured
    optimizer: optim.Optimizer = None
    lr_scheduler: optim.lr_scheduler._LRScheduler = None

    # Learning rate recorded at every optimization step
    lrs: list = field(default_factory=list)


In [None]:
def iterate(learner, perform_backward_pass:bool=False, threshold:float=0.5, show_progress:bool=False):
    """Run one full pass over the dataloader of ``learner.stage``.

    Side effects: appends the pass's average loss, accuracy and F1 score to
    ``learner.metric_values["<stage>_loss" | "<stage>_acc" | "<stage>_f1"]``.
    With ``perform_backward_pass`` it also updates the model, steps the
    optimizer and scheduler once per batch and appends each step's learning
    rate to ``learner.lrs``.

    The caller is responsible for ``model.train()`` / ``model.eval()`` and for
    wrapping evaluation passes in ``torch.no_grad()``.

    Args:
        learner: Learner holding model, device, dataloaders and metric storage.
        perform_backward_pass: Whether to back-propagate and update the model.
        threshold: Probability at or above which a pair is predicted "similar".
        show_progress: Whether to display a tqdm progress bar.

    Returns:
        Tuple ``(all_labels, all_predictions)`` of CPU tensors covering the pass.
    """
    all_labels = []
    all_predictions = []

    metrics = [Metric(name="loss"), Metric(name="acc")]

    # The model outputs logits, so the probability threshold is moved to the
    # logit scale: sigmoid(x) >= t  <=>  x >= log(t / (1 - t)).
    # Loop-invariant, hence computed once instead of once per batch.
    logit_threshold = torch.log(torch.tensor(threshold / (1 - threshold)))

    dl = learner.dataloaders[f"{learner.stage}_dataloader"]
    loop = tqdm(dl, total=len(dl), leave=False) if show_progress else dl
    for batch in loop:
        inputs1 = batch["image1"].to(learner.device)
        inputs2 = batch["image2"].to(learner.device)
        labels = batch["label"].to(learner.device)

        outputs = learner.model((inputs1, inputs2))

        # Calculate the loss
        loss = F.binary_cross_entropy_with_logits(outputs, labels)

        if perform_backward_pass:
            learner.optimizer.zero_grad()
            loss.backward()

            learner.optimizer.step()

            # Record the learning rate used for this step before advancing the schedule.
            learner.lrs.append(learner.lr_scheduler.get_last_lr()[0])
            learner.lr_scheduler.step()

        predictions = outputs.detach() >= logit_threshold
        is_correct = (predictions == labels).long()
        for metric in metrics:
            metric.update(loss=loss, predictions=predictions, is_correct=is_correct, new_count=batch["label"].shape[0])

        all_labels.append(labels)
        all_predictions.append(predictions)

    # Calculate loss and acc
    for metric in metrics:
        learner.metric_values[f"{learner.stage}_{metric.name}"].append(metric.avg)

    # Calculate the f1 metric
    all_labels, all_predictions = torch.hstack(all_labels).cpu(), torch.hstack(all_predictions).cpu()
    f1 = f1_score(all_labels, all_predictions)
    learner.metric_values[f"{learner.stage}_f1"].append(f1)

    return all_labels, all_predictions

In [None]:
def fit_model(learner, n_epochs, show_progress=False, save_every=10):
    """Train for ``n_epochs`` epochs, evaluating and logging after each one.

    Every epoch runs one training pass followed by evaluation passes on the
    "valid_balanced" and "valid_unseen_balanced" stages, prints the latest
    loss and F1 of the three stages, and every ``save_every`` epochs writes a
    checkpoint to the working directory.

    Args:
        learner: Learner with model, optimizer, lr_scheduler, dataloaders and
            metric storage already set up.
        n_epochs: Number of epochs to train.
        show_progress: Whether the passes display a tqdm progress bar.
        save_every: Checkpoint interval in epochs.
    """
    for epoch in range(n_epochs):
        learner.model.train()
        learner.stage = "train"
        iterate(learner=learner, perform_backward_pass=True, threshold=0.5, show_progress=show_progress)

        learner.model.eval()

        for stage in ["valid_balanced", "valid_unseen_balanced"]:
            learner.stage = stage
            with torch.no_grad():
                iterate(learner=learner, perform_backward_pass=False, threshold=0.5, show_progress=show_progress)

        print(f"Epoch: {epoch + 1}/{n_epochs}: ", end="")
        metrics_to_log = [
            "train_loss", "train_f1", "valid_balanced_loss", "valid_balanced_f1", "valid_unseen_balanced_loss", "valid_unseen_balanced_f1",
        ]
        # Read the metrics from the learner that was just updated, not from a
        # notebook-level global that merely happens to be the same dict.
        print(", ".join([f"{metric_name}: {learner.metric_values[metric_name][-1]:.2f}" for metric_name in metrics_to_log]))

        if (epoch + 1) % save_every == 0:
            date_and_time = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
            # Note: the stored 'epoch' is 0-based, while the file name uses the 1-based epoch number.
            torch.save({
                'epoch': epoch,
                'model': learner.model.state_dict(),
                'optimizer': learner.optimizer.state_dict(),
                'scheduler': learner.lr_scheduler.state_dict(),
                'metric_values': learner.metric_values,
            }, f"checkpoint_{epoch+1}_{date_and_time}.pt")



## Training

In [None]:
validation_stages = ["valid", "valid_balanced", "valid_sparse", "valid_unseen", "valid_unseen_balanced", "valid_unseen_sparse"]

In [None]:
metric_values = {
    "_".join([split, metric]): [] for split in (["train"] + validation_stages) for metric in ["loss", "acc", "f1"]
}
metric_values

In [None]:
learner = Learner(
    model=get_similarity_model(pretrained=True, output_size=OUTPUT_SIZE, device=device),
    device=device,
    dataloaders=dataloaders,
    metric_values=metric_values,
)

In [None]:
# Freeze the whole model except for the last layer
freeze_model(learner=learner, freeze_all=False)

In [None]:
class TrainIter(TrainDataLoaderIter):
    """Adapter that turns this notebook's dict batches into ``(inputs, labels)`` for the LR finder."""

    def inputs_labels_from_batch(self, batch):
        """Return ``((image1, image2), label)`` extracted from a batch dict."""
        image_pair = (batch['image1'], batch['image2'])
        targets = batch['label']
        return image_pair, targets

In [None]:
%%time

model = learner.model
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3) if WEIGHT_DECAY is None else optim.AdamW(model.parameters(), lr=1e-3, weight_decay=WEIGHT_DECAY)
dl = TrainIter(learner.dataloaders["train_dataloader"])
lr_finder = LRFinder(model, optimizer, criterion, device=device.type)
lr_finder.range_test(dl, start_lr=1e-6, end_lr=1e-1, num_iter=100)
lr_finder.plot() # to inspect the loss-learning rate graph
lr_finder.reset() # to reset the model and optimizer to their initial state

In [None]:
%%time

# Instantiate the optimizer and the lr_scheduler
learner.optimizer, learner.lr_scheduler = get_optimizer_and_scheduler(learner=learner, lr=1e-3, n_epochs=N_EPOCHS_INITIAL, weight_decay=WEIGHT_DECAY)

fit_model(learner=learner, n_epochs=N_EPOCHS_INITIAL, show_progress=SHOW_PROGRESS)

In [None]:
# Save things

torch.save(model.state_dict(), "model_initial.pth")

df = pd.DataFrame.from_dict({ key: value for key, value in learner.metric_values.items() if len(value) > 0 })
df.to_csv("metric_values_initial.csv", index=False)

In [None]:
# Unfreeze the whole model
unfreeze_model(learner=learner)

In [None]:
%%time

# Learning rate range test for the second stage (backbone unfrozen).
model = learner.model
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3) if WEIGHT_DECAY is None else optim.AdamW(model.parameters(), lr=1e-3, weight_decay=WEIGHT_DECAY)
dl = TrainIter(learner.dataloaders["train_dataloader"])
# Use the configured device instead of a hard-coded "cuda", so that the cell
# also runs on CPU-only machines.
lr_finder = LRFinder(model, optimizer, criterion, device=device.type)
lr_finder.range_test(dl, start_lr=1e-6, end_lr=1e-1, num_iter=100)
lr_finder.plot() # to inspect the loss-learning rate graph
lr_finder.reset() # to reset the model and optimizer to their initial state

In [None]:
%%time

# Instantiate the optimizer and the lr_scheduler
learner.optimizer, learner.lr_scheduler = get_optimizer_and_scheduler(learner=learner, lr=1e-4, n_epochs=N_EPOCHS_FULL, weight_decay=WEIGHT_DECAY)

fit_model(learner=learner, n_epochs=N_EPOCHS_FULL, show_progress=SHOW_PROGRESS)

In [None]:
# Save things

torch.save(model.state_dict(), "model_final.pth")

df = pd.DataFrame.from_dict({ key: value for key, value in learner.metric_values.items() if len(value) > 0 })
df.to_csv("metric_values.csv", index=False)

## Evaluate

In [None]:
len(metric_values["train_loss"])

In [None]:
def plot_metric(metric_values, metric_name):
    """Plot the per-epoch curves of one metric for the train and the two balanced validation stages.

    Args:
        metric_values: Dict with the keys ``train_<metric_name>``,
            ``valid_balanced_<metric_name>`` and ``valid_unseen_balanced_<metric_name>``.
        metric_name: Metric suffix, e.g. ``"loss"``.
    """
    epochs = range(1, len(metric_values[f"train_{metric_name}"]) + 1)
    curves = [
        ("train", f"train_{metric_name}"),
        ("validation", f"valid_balanced_{metric_name}"),
        ("validation_unseen", f"valid_unseen_balanced_{metric_name}"),
    ]
    for legend_label, key in curves:
        plt.plot(epochs, metric_values[key], label=legend_label)

    plt.xlabel("epoch")
    plt.ylabel(metric_name)
    plt.grid()
    plt.legend()
    plt.show()

In [None]:
plot_metric(metric_values=learner.metric_values, metric_name="loss")

In [None]:
plot_metric(metric_values=learner.metric_values, metric_name="acc")

In [None]:
plot_metric(metric_values=learner.metric_values, metric_name="f1")

In [None]:
for metric_name in sorted(learner.metric_values.keys(), key=lambda x: x.split("_")[-1]):
    if len(learner.metric_values[metric_name]) > 0:
        print(f"{metric_name}: {learner.metric_values[metric_name][-1]:.4f}")

In [None]:
# Load model
learner.model.load_state_dict(torch.load("model_final.pth", map_location=device))
learner.model.eval();

In [None]:
%%time

classes = ["dissimilar", "similar"]

fig, axes = plt.subplots(2, 3, figsize=(15, 10))
for stage, ax in zip(validation_stages, axes.flatten()):
    print(stage)
    learner.stage = stage
    with torch.no_grad():
        y_true, y_predicted = iterate(learner=learner, perform_backward_pass=False, threshold=0.5, show_progress=True)
    conf_matrix = confusion_matrix(y_true, y_predicted, normalize='true')

    sns.heatmap(
        conf_matrix, ax=ax, xticklabels=classes, yticklabels=classes, annot=True, cbar=False, square=True)
    ax.set_title(stage)
    
fig.tight_layout()
plt.show()

Inspect the most confidently misclassified similar pairs and the correctly classified pairs closest to the decision threshold:

In [None]:
id_to_label = { value: key for key, value in label_to_id.items() }
list(id_to_label.items())[:10]

In [None]:
def predict(learner, th=0.5):
    """Compute the similarity probability of every pair of the current stage.

    Args:
        learner: Learner providing model, device and the dataloader of ``learner.stage``.
        th: Accepted but not used by this function.

    Returns:
        Tuple ``(all_outputs, all_labels)``: sigmoid probabilities and labels,
        concatenated over all batches and left on ``learner.device``.
    """
    collected_labels, collected_probabilities = [], []

    with torch.no_grad():
        dl = learner.dataloaders[f"{learner.stage}_dataloader"]
        for batch in tqdm(dl, total=len(dl), leave=False):
            first_images = batch["image1"].to(learner.device)
            second_images = batch["image2"].to(learner.device)
            collected_labels.append(batch["label"].to(learner.device))

            # The model returns logits; convert them to probabilities.
            logits = learner.model((first_images, second_images))
            collected_probabilities.append(logits.sigmoid())

    return torch.hstack(collected_probabilities), torch.hstack(collected_labels)

In [None]:
# For every validation stage, show the similar pairs that the model rejected most confidently.
for stage in validation_stages:
    print("-"*30)
    print(stage)
    learner.stage = stage

    all_outputs, all_labels = predict(learner, th=0.5)
    all_outputs = all_outputs.cpu()
    all_labels = all_labels.cpu()

    # A pair is correct when its thresholded probability matches the label.
    is_correct = (all_outputs >= 0.5) == all_labels
    # Misclassified pairs whose true label is 1 (similar).
    not_correct_and_similar = (~is_correct) & (all_labels == 1)
    # Accuracy, number of errors / number of pairs, number of missed similar pairs.
    print(f"{is_correct.float().mean().item():.4f}")
    print(f"{(~is_correct).float().sum().item()}/{len(is_correct)}")
    print(f"{not_correct_and_similar.float().sum().item()}")

    # Row 0: distance of the probability from the 0.5 threshold.
    # Row 1: original index, so that pairs can still be located after filtering and sorting.
    most_confused = (all_outputs - 0.5).abs()
    most_confused = torch.vstack([most_confused, torch.arange(len(most_confused))])

    filtered_most_confused = most_confused[:, not_correct_and_similar]

    # Largest distance first, i.e. the most confident mistakes come first.
    sorted_most_confused = filtered_most_confused[:, filtered_most_confused[0, :].argsort(descending=True)]

    initial_indices = sorted_most_confused[1, :].long()

    for idx in initial_indices[:10]:
        # Fetch the fixed pair at this index and display both raw images side by side.
        file1, file2, same = learner.dataloaders[f"{learner.stage}_dataloader"].dataset._get_files_valid(idx)

        fig, axes = plt.subplots(1, 2)
        for file, ax in zip([file1, file2], axes.flatten()):
            image = Image.open(file.filename).convert("RGB")
            ax.imshow(image)
            ax.axis("off")
            ax.set_title(id_to_label[file.label])

        fig.suptitle("similar" if same else "dissimilar")
        fig.tight_layout()
        plt.show()

In [None]:
# For every validation stage, show the correctly classified pairs closest to the decision threshold.
for stage in validation_stages:
    print("-"*30)
    print(stage)
    learner.stage = stage

    all_outputs, all_labels = predict(learner, th=0.5)
    all_outputs = all_outputs.cpu()
    all_labels = all_labels.cpu()

    # A pair is correct when its thresholded probability matches the label.
    is_correct = (all_outputs >= 0.5) == all_labels
    # Accuracy, number of correct pairs / number of pairs.
    print(f"{is_correct.float().mean().item():.4f}")
    print(f"{is_correct.float().sum().item()}/{len(is_correct)}")

    # Row 0: distance of the probability from the 0.5 threshold.
    # Row 1: original index, so that pairs can still be located after filtering and sorting.
    least_confused = (all_outputs - 0.5).abs()
    least_confused = torch.vstack([least_confused, torch.arange(len(least_confused))])

    filtered_least_confused = least_confused[:, is_correct]

    # Ascending sort: the correct predictions nearest to the threshold come first.
    sorted_least_confused = filtered_least_confused[:, filtered_least_confused[0, :].argsort()]

    initial_indices = sorted_least_confused[1, :].long()

    for idx in initial_indices[:10]:
        # Fetch the fixed pair at this index and display both raw images side by side.
        file1, file2, same = learner.dataloaders[f"{learner.stage}_dataloader"].dataset._get_files_valid(idx)

        fig, axes = plt.subplots(1, 2)
        for file, ax in zip([file1, file2], axes.flatten()):
            image = Image.open(file.filename).convert("RGB")
            ax.imshow(image)
            ax.axis("off")
            ax.set_title(id_to_label[file.label])

        fig.suptitle("similar" if same else "dissimilar")
        fig.tight_layout()
        plt.show()

Assess influence of threshold:

In [None]:
%%time

# Decision thresholds (on the probability scale) to evaluate.
thresholds = np.linspace(start=0.025, stop=0.25, num=10)
print(thresholds)

for stage in ["valid_unseen_balanced"]:
    print(stage)
    learner.stage = stage

    # One panel per metric: accuracy on the left, F1 score on the right.
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))

    for metric_function, label, ax in zip([accuracy_score, f1_score], ["accuracy", "f1 score"], axes.flatten()):
        print(label)

        values = []
        for th in tqdm(thresholds, total=len(thresholds), leave=False):
            # NOTE: every iterate() call is a full pass over the dataloader and also
            # appends its loss/acc/f1 to learner.metric_values for this stage.
            with torch.no_grad():
                y_true, y_predicted = iterate(learner=learner, perform_backward_pass=False, threshold=th, show_progress=False)
            val = metric_function(y_true, y_predicted)
            values.append(val)

        ax.plot(thresholds, values, "o-")
        ax.grid()
        ax.set(xlabel="threshold", ylabel=label)

    fig.suptitle(stage)
    fig.tight_layout()
    plt.show()

In [None]:
learner.metric_values["valid_unseen_balanced_acc"][:10]

In [None]:
learner.metric_values["valid_unseen_balanced_f1"][-10:]