In [None]:
%load_ext autoreload
%autoreload 2

from IPython.display import HTML, display

# Inject a CSS rule that sets the `.container` element width to 100%.
display(HTML("<style>.container { width:100% !important; }</style>"))


from pprint import pprint

import torch
from torch.nn import BCEWithLogitsLoss
from torch.utils.data import DataLoader

from oml.datasets.base import DatasetQueryGallery
from oml.metrics.embeddings import EmbeddingMetrics
from oml.retrieval.postprocessors.pairwise import PairwiseImagesPostprocessor
from oml.samplers.category_balance import CategoryBalanceSampler
from oml import const

from source import BioDatasetWithLabels, SimpleSiamese, SimpleExtractor, PairsSamplerTwoModalities

"""
We have 4 distinct hid values (0-3), two rows per hid.

Data structure:

hid u1 u2   | v1  v2  v3 
0   5.4 3.2 | 5.3 5.4 9.0
0   5.4 6.2 | 2.3 5.3 9.0
1   5.4 3.2 | 5.3 9.4 9.0
1   5.2 3.1 | 5.2 5.1 9.0


2   5.4 3.2 | 5.3 5.4 9.0
2   5.4 6.2 | 2.3 5.3 9.0
3   5.4 3.2 | 5.3 9.4 9.0
3   5.2 3.1 | 5.2 5.1 9.0


e1
e2
siamese(e1(x1),e2(x2))

"""

# Both descriptor types are assumed to have the same size after PCA.
feat_dim_after_pca = 64

# Seed torch's global RNG so the random toy descriptors below are identical
# on every "Restart Kernel -> Run All".
torch.manual_seed(0)

labels =        [0, 0, 1, 1, 2, 2, 3, 3]
categories =    [0, 0, 0, 0, 1, 1, 1, 1]  # this is a hospital id
is_first_type = [1, 0, 1, 0, 1, 0, 1, 0]  # flag telling the two descriptor types apart

# The three lists are parallel (one entry per sample); fail early if they drift apart.
assert len(labels) == len(categories) == len(is_first_type), "per-sample lists must have equal length"

# One random descriptor row per sample.
descriptors =   torch.randn((len(labels), feat_dim_after_pca))

# Maps every label to its category; consumed by the category-balanced sampler.
labels2category = dict(zip(labels, categories))


In [None]:
# These extractors were trained on a first stage (for example, as a part of CLIP).
extractor1 = SimpleExtractor(in_dim=feat_dim_after_pca, out_dim=18)
extractor2 = SimpleExtractor(in_dim=feat_dim_after_pca, out_dim=24)

siamese = SimpleSiamese(extractor1=extractor1, extractor2=extractor2)
optimizer = torch.optim.SGD(siamese.parameters(), lr=1e-6)
miner = PairsSamplerTwoModalities()
criterion = BCEWithLogitsLoss()

train_dataset = BioDatasetWithLabels(labels, categories, is_first_type, descriptors)
batch_sampler = CategoryBalanceSampler(
    train_dataset.get_labels(),
    label2category=labels2category,
    n_labels=2,
    n_instances=2,
    n_categories=1,
)
train_loader = DataLoader(train_dataset, batch_sampler=batch_sampler)

for batch in train_loader:
    # Coerce the modality flag to a boolean mask. With an integer 0/1 tensor,
    # `tensor[flag]` would gather rows 0 and 1 instead of masking, and `~flag`
    # would be a bitwise NOT (-1 / -2) rather than a logical negation.
    is_first = torch.as_tensor(batch["is_first_type"], dtype=torch.bool)

    features_a = batch[const.INPUT_TENSORS_KEY][is_first]
    features_b = batch[const.INPUT_TENSORS_KEY][~is_first]
    labels_a = batch[const.LABELS_KEY][is_first]
    labels_b = batch[const.LABELS_KEY][~is_first]

    ids1, ids2, is_negative_pair = miner.sample(features_a, features_b, labels_a, labels_b)

    # The miner only sees the per-modality tensors, so the returned ids are used
    # to index `features_a` / `features_b`, not the full mixed batch.
    # NOTE(review): assumes `PairsSamplerTwoModalities.sample` returns positions
    # relative to its inputs; confirm against source.py.
    logits = siamese(x1=features_a[ids1], x2=features_b[ids2])

    # BCEWithLogitsLoss applies the sigmoid itself, so the siamese output is
    # consumed as raw scores here.
    loss = criterion(logits, is_negative_pair.float())

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # Print the scalar value rather than the tensor repr with its grad_fn.
    print(loss.item())


In [None]:
# The siamese re-ranks the top-n retrieval outputs of the original model by
# running inference on (query, output_i) pairs.
# NOTE(review): `df_val`, `embeddings_val` and `transform` are not defined in any
# cell visible here; they must already exist in the kernel for this cell to run.
val_dataset = DatasetQueryGallery(
    df=df_val,
    extra_data={"embeddings": embeddings_val},
    transform=transform,
)
valid_loader = DataLoader(val_dataset, shuffle=False, batch_size=4)

postprocessor = PairwiseImagesPostprocessor(
    pairwise_model=siamese,
    top_n=3,
    transforms=transform,
)
calculator = EmbeddingMetrics(postprocessor=postprocessor)
calculator.setup(num_samples=len(val_dataset))

# Feed every validation batch to the calculator before computing metrics.
for batch in valid_loader:
    calculator.update_data(data_dict=batch)

# Pairwise inference happens inside compute_metrics().
metrics = calculator.compute_metrics()
pprint(metrics)