In [None]:
# Enable IPython's autoreload so edits to imported project modules (e.g. `oml`)
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

import os
from pathlib import Path

import pandas as pd
import pytorch_lightning as pl
from IPython.core.display import HTML
from IPython.display import display
from pytorch_lightning.plugins import DDPPlugin
from torch.utils.data import DataLoader

from oml.datasets.retrieval import get_retrieval_datasets
from oml.lightning.callbacks.metric import MetricValCallback
from oml.lightning.modules.retrieval import RetrievalModule
from oml.metrics.embeddings import EmbeddingMetrics
from oml.registry.models import get_extractor_by_cfg
from oml.utils.images.augs import get_default_transforms_albu
from oml.utils.visualisation import RetrievalVisualizer


display(HTML("<style>.container { width:100% !important; }</style>"))
pd.set_option('display.max_rows', 330)

%matplotlib inline

if "TEST_RUN" in os.environ:
    dataset_root = Path("../tests/test_examples/data/")
    weights = "random"
    gpus = 0
    strategy = None
    n_workers = 0
else:
    gpus = 1
    strategy = DDPPlugin(find_unused_parameters=False)
    n_workers = 10
    
    dataset_root = Path("/nydl/data/DeepFashion_InShop")
    weights = "/nydl/logs/cur/ml/deepfashion/2022-06-05_13-21-32_deepfashion/checkpoints/best.ckpt"

#     dataset_root = Path("/nydl/data/whales")
#     weights = "pretrained_dino"


In [None]:
def load_retrieval_datasets(dataframe_name):
    """Build the ``(train, valid)`` retrieval datasets described by ``dataframe_name``.

    The dataframe is looked up under the config cell's ``dataset_root``. Both splits
    use ``im_size=304`` and pad ratios of 0. The train split gets the transforms from
    ``get_default_transforms_albu()``.
    """
    return get_retrieval_datasets(
        dataset_root,
        im_size=304,
        pad_ratio_train=0,
        pad_ratio_val=0,
        train_transform=get_default_transforms_albu(),
        dataframe_name=dataframe_name,
    )


# Validation split of the regular dataframe; its train split is not needed here.
valid_dataset = load_retrieval_datasets("df.csv")[1]

# Validation split of the fixed-validation dataframe. `train_dataset` is taken from
# this call, matching the final binding of the previous copy-pasted version.
train_dataset, valid_dataset_fixed = load_retrieval_datasets("df_fixed_val.csv")

In [None]:
# Extractor configuration: a "vit" model with arch "vits8", loading the weights chosen
# in the config cell ("random" under TEST_RUN).
cfg = {
    "name": "vit",
    "args": {
        "arch": "vits8",
        "normalise_features": True,
        "use_multi_scale": False,
        "weights": weights,
        "strict_load": True,
    },
}

# num_workers follows the config cell (0 under TEST_RUN) instead of a hard-coded value,
# so the test run does not spawn worker processes.
val_loader = DataLoader(dataset=valid_dataset, batch_size=200, num_workers=n_workers, drop_last=False)

model = get_extractor_by_cfg(cfg)

# Only trainer.validate() is run below, so no criterion or optimizer is supplied.
pl_model = RetrievalModule(model=model, criterion=None, optimizer=None)

# extra_keys: presumably batch fields kept alongside the embeddings for later
# visualisation (image paths and box coordinates) — confirm against EmbeddingMetrics.
clb_metric = MetricValCallback(metric=EmbeddingMetrics(extra_keys=("paths", "x1", "x2", "y1", "y2")))

# replace_sampler_ddp=False keeps the DataLoader's own sampler (presumably so the
# query order matches the dataset order used for visualisation).
trainer = pl.Trainer(
    gpus=gpus,
    num_nodes=1,
    strategy=strategy,
    replace_sampler_ddp=False,
    callbacks=[clb_metric],
)

ret = trainer.validate(dataloaders=val_loader, verbose=True, model=pl_model)


In [None]:
# Build the retrieval visualiser from the metric filled during validation and record
# how many queries there are (rows of the distance matrix, presumably one per query).
metric_state = clb_metric.metric
calc = RetrievalVisualizer.from_embeddings_metric(metric_state)
n_query = metric_state.distance_matrix.shape[0]

In [None]:
# Visualise a window of up to N_QUERIES_TO_SHOW queries starting at FIRST_QUERY_IDX
# (top_k=1, skip_no_errors=True; presumably only queries with a top-1 error are drawn).
# The offset falls back to 0 when the validation set has fewer queries (e.g. the
# small TEST_RUN dataset), and the window is clipped to n_query so no index is out of range.
FIRST_QUERY_IDX = 9263
N_QUERIES_TO_SHOW = 100

start_idx = FIRST_QUERY_IDX if FIRST_QUERY_IDX < n_query else 0
for query_idx in range(start_idx, min(start_idx + N_QUERIES_TO_SHOW, n_query)):
    calc.visualise(query_idx=query_idx, top_k=1, skip_no_errors=True)
