In [None]:
%load_ext autoreload
%autoreload 2

import pandas as pd
import yaml
from IPython.core.display import HTML
from IPython.display import display
import torch
import random

from oml.lightning.pipelines.validate import extractor_validation_pipeline
from oml.lightning.callbacks.metric import MetricValCallback

display(HTML("<style>.container { width:100% !important; }</style>"))
pd.set_option('display.max_rows', 330)

%matplotlib inline


In [None]:
# Base YAML config for validating the extractor alone (no postprocessing).
# Replace the `/path/to/...` placeholders before running.
# This is a plain string rather than an f-string: nothing is interpolated, and an f-prefix
# would break on YAML flow-mapping braces `{...}` if they were ever added.
# Keep the 4-space indent: the postprocessor config later in the notebook is appended
# to this string with the same indentation.
cfg = """
    accelerator: gpu
    precision: 32
    devices: 1

    dataset_root: /path/to/dataset
    dataframe_name: df.csv
    bs_val: 128
    num_workers: 10

    transforms_val:
      name: norm_resize_hypvit_torch
      args:
        im_size: 224
        crop_size: 224

    model:
      name: vit
      args:
        arch: vits16
        normalise_features: True
        use_multi_scale: False
        weights: /path/to/extractor.ckpt

    metric_args:
      cmc_top_k: [1, 10, 20, 30, 100]
      map_top_k: [5, 10]

"""

In [None]:
# Baseline run: validate the extractor alone, then keep its metric callback so that
# per-query metrics can be compared against the postprocessed run.
# safe_load is sufficient: the config holds only plain scalars, lists and mappings.
trainer, ret_dict = extractor_validation_pipeline(yaml.safe_load(cfg))
clb_metric = next(
    (callback for callback in trainer.callbacks if isinstance(callback, MetricValCallback)),
    None,
)
# Fail with a clear message instead of an opaque IndexError from `[...][0]`.
assert clb_metric is not None, "MetricValCallback not found in trainer.callbacks"


In [None]:
# Same config as `cfg`, extended with a pairwise re-ranking postprocessor.
# The appended YAML uses the same 4-space indent as `cfg` so the two parts merge into
# one top-level mapping. Plain string: nothing is interpolated, so no f-prefix is needed.
cfg_p = cfg + """
    postprocessor:
      name: pairwise_reranker
      args:
        top_n: 5
        pairwise_model:
          name: concat_siamese
          args:
            mlp_hidden_dims: [192]
            weights: /path/to/postprocessor.ckpt
            extractor:
              name: vit
              args:
                arch: vits16
                normalise_features: False
                use_multi_scale: False
                weights: null
        num_workers: 10
        batch_size: 128
        verbose: True

"""

In [None]:
# Postprocessed run: same validation with the pairwise re-ranker enabled; keep its
# metric callback for the per-query comparison below.
# safe_load is sufficient: the config holds only plain scalars, lists and mappings.
trainer_p, ret_dict_p = extractor_validation_pipeline(yaml.safe_load(cfg_p))
clb_metric_p = next(
    (callback for callback in trainer_p.callbacks if isinstance(callback, MetricValCallback)),
    None,
)
# Fail with a clear message instead of an opaque IndexError from `[...][0]`.
assert clb_metric_p is not None, "MetricValCallback not found in trainer_p.callbacks"


In [None]:
# Unreduced (per-query) CMC values from the "OVERALL" category, for the baseline run
# and the postprocessed run respectively.
# NOTE(review): the `[1]` key presumably selects CMC@1 (1 is in cmc_top_k) — confirm.
cmc_1, cmc_1_p = (
    callback.metric.metrics_unreduced["OVERALL"]["cmc"][1]
    for callback in (clb_metric, clb_metric_p)
)


In [None]:
# Visualize queries where the postprocessor improved CMC@1: for each sampled query,
# plot the baseline retrieval first, then the postprocessed retrieval.
N_EXAMPLES = 10  # maximum number of improved queries to plot
N_INSTANCES = 4  # forwarded as `n_instances` to get_plot_for_queries
SEED = 42  # fixed so that Restart & Run All shows the same queries

# NOTE(review): assumes cmc_1 / cmc_1_p are 1-D tensors with one value per query.
# flatten() (not squeeze()) keeps the result 1-D even when exactly one query improved,
# so tolist() always returns a list of indices instead of a bare int.
improved_ids = torch.nonzero(cmc_1_p > cmc_1).flatten().tolist()
# Cap k: random.sample raises ValueError when k exceeds the population size.
# A local Random instance avoids touching the global `random` state.
ids = random.Random(SEED).sample(improved_ids, min(N_EXAMPLES, len(improved_ids)))

for idx in ids:
    clb_metric.metric.get_plot_for_queries([idx], n_instances=N_INSTANCES, verbose=False)
    clb_metric_p.metric.get_plot_for_queries([idx], n_instances=N_INSTANCES, verbose=False)
