In [None]:
# Enable IPython's autoreload extension; mode 2 re-imports all modules before
# each cell is executed, so edits to the project code are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

from pathlib import Path
import os
import pandas as pd
import pytorch_lightning as pl
from IPython.core.display import HTML
from IPython.display import display
import matplotlib.pyplot as plt

from oml.datasets.retrieval import get_retrieval_datasets
from oml.lightning.callbacks.metric import MetricValCallback
from oml.lightning.modules.retrieval import RetrievalModule
from pytorch_lightning.plugins import DDPPlugin
from torch.utils.data import DataLoader
from oml.utils.visualisation import RetrievalVisualizer
from oml.metrics.embeddings import EmbeddingMetrics
from oml.registry.models import get_extractor_by_cfg
from oml.utils.images.images import tensor_to_numpy_image, imread_cv2
from oml.utils.images.augs import get_default_transforms_albu


display(HTML("<style>.container { width:100% !important; }</style>"))
pd.set_option('display.max_rows', 330)

%matplotlib inline

# Run configuration. Setting the TEST_RUN environment variable (any value)
# switches the notebook to a CPU-only run on the repository's test data.
if "TEST_RUN" not in os.environ:
    # Full run: DeepFashion InShop data and a trained checkpoint.
    dataset_root = Path("/nydl/data/DeepFashion_InShop")
    weights = "/nydl/logs/cur/ml/deepfashion/2022-05-27_13-54-34_deepfashion/checkpoints/best.ckpt"

    # Alternative setup, kept for reference:
    #     dataset_root = Path("/nydl/data/whales")
    #     weights = "pretrained_dino"

    gpus, n_workers = 1, 10
    strategy = DDPPlugin(find_unused_parameters=False)
else:
    # Test run: small bundled dataset, "random" weights, no GPU, no workers.
    dataset_root = Path("../tests/test_examples/data/")
    weights = "random"

    gpus, n_workers = 0, 0
    strategy = None

In [None]:
# Build the train / validation datasets from `dataset_root`.
# Padding ratios are set to 0 for both splits; only the train split gets the
# default albumentations transforms passed explicitly.
train_dataset, valid_dataset = get_retrieval_datasets(
    dataset_root,
    im_size=512,
    pad_ratio_train=0,
    pad_ratio_val=0,
    train_transform=get_default_transforms_albu(),
)



In [None]:
import albumentations as albu
import hydra
from omegaconf import DictConfig

from oml.lightning.entrypoints.train import main
from oml.registry.transforms import AUGS_REGISTRY
from oml.utils.images.augs import get_all_augs

# Preview the augmentation pipeline: for three consecutive training images,
# draw one row with the untouched image followed by several augmented versions.
augs = get_all_augs()

n_aug_views = 6                 # augmented versions shown per image
n_panels = n_aug_views + 1      # plus one panel for the original
first_image_idx = 878           # index of the first image to preview

for offset in range(3):
    plt.figure(figsize=(40, 40))
    original = train_dataset.read_image(first_image_idx + offset)

    for panel in range(n_panels):
        plt.subplot(1, n_panels, panel + 1)
        # Panel 0 is the original; every other panel is a fresh call to `augs`.
        shown = original if panel == 0 else augs(image=original)['image']
        plt.imshow(shown)
        plt.axis('off')

    plt.show()
    


In [None]:
from collections import Counter
# Label-size statistics of the training set: which share of labels has at most
# `k` images, and how far those labels fall short of `k` on average.
df = train_dataset.df

# Images per label. `size()` counts the rows of each group directly, so no
# helper column has to be written into the frame returned by
# `train_dataset.df` -- that frame is left unmodified.
label_sizes = df.groupby("label").size()

covered, uncovered, auged = 0, 0, 0
k = 8

# Counter maps a group size to the number of labels that have this size.
for sz, count in Counter(label_sizes).items():
    if sz <= k:
        covered += count
        # Shortfall (k - sz) / k of each such label, weighted by label count.
        # Presumably the share of a k-sized group that would have to be filled
        # with augmented copies -- TODO confirm the intended interpretation.
        auged += ((k - sz) / k) * count
    else:
        uncovered += count

n_labels = covered + uncovered
print(covered / n_labels)  # fraction of labels with at most k images
print(auged / n_labels)    # shortfall (k - sz) / k averaged over all labels

In [None]:
import matplotlib.pyplot as plt
# Show a single item exactly as the training dataset yields it: the
# "input_tensors" entry converted back to an image array for plotting.
sample_idx = 105
sample = train_dataset[sample_idx]
plt.imshow(tensor_to_numpy_image(sample["input_tensors"]))
plt.show()

In [None]:
# Extractor description handed to `get_extractor_by_cfg`; `weights` comes from
# the run configuration cell.
cfg = {"name": "vit",
        "args": {
          "arch": "vits8",
          "normalise_features": False,
          "use_multi_scale": True,
          "weights": weights,
          "strict_load": True
        }}

# Alternative extractor, kept for reference:
# cfg = {"name": "resnet",
#         "args": {
#           "arch": "resnet50",
#           "normalise_features": True,
#           "remove_fc": True,
#           "weights": "pretrained",
#           "strict_load": False,
#           "gem_p": 7,
#           "hid_dim": None,
#            "out_dim": None
#         }}

# Use the worker count from the run configuration (0 in TEST_RUN mode) instead
# of a hardcoded value, so test runs do not spawn loader subprocesses.
val_loader = DataLoader(dataset=valid_dataset, batch_size=20, num_workers=n_workers, drop_last=False)


model = get_extractor_by_cfg(cfg)

# Validation only: no criterion or optimizer is attached to the module.
pl_model = RetrievalModule(model=model, criterion=None, optimizer=None)

# The callback owns the metric object; it is read back after validation via
# `clb_metric.metric` to build the visualizer.
clb_metric = MetricValCallback(metric=EmbeddingMetrics(extra_keys=("paths", "x1", "x2", "y1", "y2")))

trainer = pl.Trainer(gpus=gpus,
                     num_nodes=1,
                     strategy=strategy,
                     replace_sampler_ddp=False,
                     callbacks=[clb_metric],
                    )

ret = trainer.validate(dataloaders=val_loader,
                       verbose=True,
                       model=pl_model
                      )


In [None]:
# Metric object filled in by the validation run of the previous cell.
embeddings_metric = clb_metric.metric

# Visualizer built on top of that metric.
calc = RetrievalVisualizer.from_embeddings_metric(embeddings_metric)

# First dimension of the distance matrix; used below as the number of queries.
n_query = embeddings_metric.distance_matrix.shape[0]

In [None]:
# Visualise the top-1 result for up to 150 consecutive queries, passing
# `skip_no_errors=True` to the visualizer.
# The window starts at query 500 when that many queries exist, otherwise at 0,
# and is clamped to `n_query` so no index past the last query is requested
# (assumes `query_idx` must be smaller than `n_query` -- TODO confirm).
first_query = 500 if n_query > 500 else 0
last_query = min(first_query + 150, n_query)

for query_idx in range(first_query, last_query):
    calc.visualise(query_idx=query_idx, top_k=1, skip_no_errors=True)
