In [1]:
%load_ext autoreload
%autoreload 2

from IPython.core.display import HTML
from IPython.display import display

display(HTML("<style>.container { width:100% !important; }</style>"))


from pprint import pprint
from tqdm.auto import tqdm

import torch
from torch.nn import BCEWithLogitsLoss
from torch.utils.data import DataLoader

from oml.datasets.base import DatasetQueryGallery
from oml.metrics.embeddings import EmbeddingMetrics
from oml.retrieval.postprocessors.pairwise import PairwiseEmbeddingsPostprocessor
from oml.samplers.category_balance import CategoryBalanceSampler
from oml.samplers.balance import BalanceSampler
from oml import const

from source import BioDatasetWithLabels, BioDatasetQueryGallery, SimpleSiamese, SimpleExtractor, PairsSamplerTwoModalities


In [2]:
# DATA

# I assume that descriptors of both types will have the same size after PCA
feat_dim_after_pca = 64

# index         0  1  2  3  4  5  6  7
labels =        [0, 0, 1, 1, 2, 2, 3, 3]
categories =    [0, 0, 0, 0, 1, 1, 1, 1]  # this is a hospital id
is_first_type = [1, 0, 1, 0, 1, 0, 1, 0]
descriptors =   torch.randn((len(labels), feat_dim_after_pca))

labels2category = dict(zip(labels, categories))


In [3]:
# these extractors were trained on a first stage (for example, as a part of CLIP)
extractor1 = SimpleExtractor(in_dim=feat_dim_after_pca, out_dim=18)
extractor2 = SimpleExtractor(in_dim=feat_dim_after_pca, out_dim=18)

siamese = SimpleSiamese(extractor1=extractor1, extractor2=extractor2).train()
optimizer = torch.optim.SGD(siamese.parameters(), lr=1e-2)
miner = PairsSamplerTwoModalities(hard=False)
criterion = BCEWithLogitsLoss()

train_dataset = BioDatasetWithLabels(labels, categories, is_first_type, descriptors)
# batch_sampler = CategoryBalanceSampler(train_dataset.get_labels(), label2category=labels2category, n_labels=2, n_instances=2, n_categories=2)
batch_sampler = BalanceSampler(train_dataset.get_labels(), n_labels=4, n_instances=2)
train_loader = DataLoader(train_dataset, batch_sampler=batch_sampler)

n_epochs = 50000

for _ in tqdm(range(n_epochs)):
    for batch in train_loader:
        features_a = batch[const.INPUT_TENSORS_KEY][batch["is_first_type"]]
        features_b = batch[const.INPUT_TENSORS_KEY][~batch["is_first_type"]]
        labels_a = batch[const.LABELS_KEY][batch["is_first_type"]]
        labels_b = batch[const.LABELS_KEY][~batch["is_first_type"]]
        
        ids1, ids2, is_negative_pair = miner.sample(features_a, features_b, labels_a, labels_b)
        probs = siamese(x1=features_a[ids1], x2=features_b[ids2])
        loss = criterion(probs, is_negative_pair.float())

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

print(loss)


  0%|          | 0/50000 [00:00<?, ?it/s]

tensor(3.7413e-05, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


In [4]:
# test
for j in range(0, 8, 2):
    for i in range(1, 9, 2):
        assert is_first_type[i] != is_first_type[j]
        dist = torch.nn.Sigmoid()(siamese(descriptors[j], descriptors[i]))
        if not bool(torch.round(dist)) == (labels[i] != labels[j]):
            print("broken ", i, j)
        


In [5]:
siamese.eval()

val_dataset = BioDatasetQueryGallery(labels, is_first_type, descriptors)
valid_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)

postprocessor = PairwiseEmbeddingsPostprocessor(top_n=10, pairwise_model=siamese, num_workers=0, batch_size=4)
calculator = EmbeddingMetrics(
    cmc_top_k=(1,5),
    postprocessor=postprocessor
)
calculator.setup(num_samples=len(val_dataset))

for batch in valid_loader:
    calculator.update_data(data_dict=batch)

# Note! I don't apply sigmoid on top of siamese model's output. I will fix it in the next PR soon!!!
calculator.compute_metrics();  # Pairwise inference happens here



Metrics:
{'OVERALL': {'cmc': {1: tensor(1.), 5: tensor(1.)},
             'map': {5: tensor(1.)},
             'pcf': {0.5: tensor(0.0469)},
             'precision': {5: tensor(1.)}}}


