In [6]:
from IPython.core.display import HTML
from IPython.display import display

# Inject a CSS rule that forces `.container` elements to 100% width
# (presumably so the notebook cells span the full browser width).
display(HTML("<style>.container { width:100% !important; }</style>"))


from pprint import pprint
from collections import Counter

import torch
from torch.nn import BCEWithLogitsLoss
from torch.utils.data import DataLoader

from oml.datasets.base import DatasetWithLabels, DatasetQueryGallery
from oml.inference.flat import inference_on_dataframe
from oml.metrics.embeddings import EmbeddingMetrics
from oml.miners.pairs import PairsMiner
from oml.models.siamese import ConcatSiamese
from oml.models.vit.vit import ViTExtractor
from oml.retrieval.postprocessors.pairwise import PairwiseImagesPostprocessor
from oml.samplers.category_balance import CategoryBalanceSampler
from oml.transforms.images.torchvision.transforms import get_normalisation_resize_torch
from oml.utils.download_mock_dataset import download_mock_dataset
from oml.interfaces.datasets import IDatasetWithLabels
from oml import const

from source import BioDatasetWithLabels, SimpleSiamese, SimpleExtractor, PairsSamplerTwoModalities

"""
we have 4 hid

data structure:

hid u1 u2   | v1  v2  v3 
0   5.4 3.2 | 5.3 5.4 9.0
0   5.4 6.2 | 2.3 5.3 9.0
1   5.4 3.2 | 5.3 9.4 9.0
1   5.2 3.1 | 5.2 5.1 9.0


2   5.4 3.2 | 5.3 5.4 9.0
2   5.4 6.2 | 2.3 5.3 9.0
3   5.4 3.2 | 5.3 9.4 9.0
3   5.2 3.1 | 5.2 5.1 9.0


e1
e2
siamese(e1(x1),e2(x2))

"""

# I assume that descriptors of both types will have the same size after PCA
feat_dim_after_pca = 64

# Fix the RNG seed so the synthetic descriptors (and anything initialised from
# torch's global RNG afterwards) are identical on every "Restart & Run All".
SEED = 0
torch.manual_seed(SEED)

# Toy dataset: one entry per sample, all four sequences are index-aligned.
#   labels        - identity of the sample (two samples per label)
#   categories    - group the label belongs to
#   is_first_type - 1/0 flag used by the training loop to split a batch into two descriptor groups
#   descriptors   - random stand-ins for the real feature vectors, one row per sample
labels =        [0, 0, 1, 1, 2, 2, 3, 3]
categories =    [0, 0, 0, 0, 1, 1, 1, 1]  # this is a hospital id
is_first_type = [1, 0, 1, 0, 1, 0, 1, 0]
descriptors =   torch.randn((len(labels), feat_dim_after_pca))

# Guard against the parallel lists drifting out of sync when the toy data is edited.
assert len(labels) == len(categories) == len(is_first_type) == len(descriptors)





In [2]:
# NOTE(review): `train_dataset` is only created further down, in the training cell,
# so this inspection cell fails on a fresh "Restart & Run All". Move it below that cell.
train_dataset[0]

{'input_tensors': tensor([ 0.8439,  1.1075, -1.8287, -1.8799,  1.0038, -0.4417,  0.1399, -1.5350,
          0.8361,  0.6795, -1.1870,  0.9015, -0.5298, -0.7333,  0.8621,  0.5923,
         -1.2551,  0.5639,  3.1262,  0.9393, -0.3574, -0.8543, -1.0328, -0.4436,
         -2.1585, -0.8695, -1.7186,  0.5806,  1.6303, -1.8406, -1.5857, -0.3112,
         -0.3475, -0.3353,  0.1391, -1.8080,  0.5724,  0.2229, -0.2842,  0.4157,
          0.2724, -1.0671,  0.4637,  0.4134,  0.3679,  1.1582, -1.2600, -0.0702,
          1.0421,  1.6099, -0.0864,  1.1118, -0.8526, -0.9129,  1.3778, -0.5983,
          1.0534,  0.6310,  0.4721,  0.3414,  0.2091, -0.0742, -0.2267, -0.1156]),
 'labels': 0,
 'categories': 0,
 'is_first_type': 1}

In [5]:
# these extractors were trained on a first stage (for example, as a part of CLIP)
extractor1 = SimpleExtractor(in_dim=feat_dim_after_pca, out_dim=18)
extractor2 = SimpleExtractor(in_dim=feat_dim_after_pca, out_dim=24)

# NOTE(review): `dataset_root`, `extractor` and `transform` are not assigned in any visible cell,
# so this call raises NameError on a fresh kernel. Its outputs (`embeddings_val`, `df_val`) are only
# consumed by the validation cell. Define the inputs or drop the call - confirm intent.
embeddings_train, embeddings_val, df_train, df_val = \
    inference_on_dataframe(dataset_root, "df.csv", extractor=extractor, transforms_extraction=transform)

siamese = SimpleSiamese(extractor1=extractor1, extractor2=extractor2)
optimizer = torch.optim.SGD(siamese.parameters(), lr=1e-6)
miner = PairsSamplerTwoModalities()
criterion = BCEWithLogitsLoss()

train_dataset = BioDatasetWithLabels(labels, categories, is_first_type, descriptors)

# The sampler needs to know which category every label belongs to, and (without label
# resampling) cannot draw more labels per category than the smallest category contains.
label2category = dict(zip(labels, categories))
n_labels_per_category = min(Counter(label2category.values()).values())
batch_sampler = CategoryBalanceSampler(
    train_dataset.get_labels(),
    label2category=label2category,
    n_categories=1,
    n_labels=n_labels_per_category,
    n_instances=2,
)
train_loader = DataLoader(train_dataset, batch_sampler=batch_sampler)

for batch in train_loader:
    # The dataset yields `is_first_type` as a Python int, which collation turns into an integer
    # tensor. Indexing with it would pick rows 0/1 and `~` would be a bitwise NOT, so cast it
    # to a boolean mask before splitting the batch.
    is_first = torch.as_tensor(batch["is_first_type"], dtype=torch.bool)
    batch_features = batch[const.INPUT_TENSORS_KEY]
    batch_labels = batch[const.LABELS_KEY]

    features_a, features_b = batch_features[is_first], batch_features[~is_first]
    labels_a, labels_b = batch_labels[is_first], batch_labels[~is_first]

    ids1, ids2, is_negative_pair = miner.sample(features_a, features_b, labels_a, labels_b)

    # NOTE(review): the miner only sees the split tensors, so `ids1` / `ids2` are assumed to index
    # `features_a` / `features_b` respectively (not the unsplit batch) - confirm against
    # PairsSamplerTwoModalities.sample.
    pred_dist = siamese(x1=features_a[ids1], x2=features_b[ids2])
    # BCEWithLogitsLoss applies the sigmoid itself, so the model output is passed through as-is.
    loss = criterion(pred_dist, is_negative_pair.float())

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()


TypeError: __init__() missing 2 required positional arguments: 'in_dim' and 'out_dim'

In [None]:
# Siamese re-ranks top-n retrieval outputs of the original model performing inference on pairs (query, output_i)
# NOTE(review): `df_val` and `embeddings_val` are produced by the `inference_on_dataframe` call in the
# training cell, and `transform` is not assigned in any visible cell, so this cell cannot run on a
# fresh kernel until those exist - confirm where they are meant to come from.
val_dataset = DatasetQueryGallery(df=df_val, extra_data={"embeddings": embeddings_val}, transform=transform)
# Fixed iteration order (shuffle=False), 4 items per batch.
valid_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)

# NOTE(review): an *images* postprocessor with `transforms` is used here, while the training cell feeds
# precomputed descriptor vectors to the siamese model - verify this is the intended postprocessor.
postprocessor = PairwiseImagesPostprocessor(top_n=3, pairwise_model=siamese, transforms=transform)
calculator = EmbeddingMetrics(postprocessor=postprocessor)
# The calculator is told the total number of validation samples before any batch is fed in.
calculator.setup(num_samples=len(val_dataset))

# Feed every validation batch to the calculator.
for batch in valid_loader:
    calculator.update_data(data_dict=batch)

pprint(calculator.compute_metrics())  # Pairwise inference happens here