In [None]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("/Users/ashabanov/code/metric_learning/open-metric-learning")

from IPython.core.display import HTML
from IPython.display import display
import pandas as pd
from tqdm import tqdm
from pathlib import Path
import torch

from oml.datasets.base import BaseDataset
from oml.models.vit.vit import ViTExtractor
from oml.transforms.images.torchvision.transforms import get_normalisation_resize_hypvit
from oml.transforms.images.utils import get_im_reader_for_transforms
from oml.metrics.embeddings import EmbeddingMetrics
from oml.models.siamese import SimpleSiamese
from oml.postprocessors.pairwise_embeddings import PairwiseEmbeddingsPostprocessor


# Stretch the notebook container to the full width of the browser window.
full_width_css = HTML("<style>.container { width:100% !important; }</style>")
display(full_width_css)

# Let pandas render up to 330 rows before it truncates a dataframe.
pd.set_option("display.max_rows", 330)

%matplotlib inline


In [None]:
# Root folder of the dataset; `df.csv` is read from it and the embeddings cache is stored in it.
dataset_root = Path("/nydl/data/Stanford_Online_Products/")
# Name of the pretrained checkpoint; the part before "_" is passed to the extractor as `arch`.
weights = "vits16_sop"

# Cache with the embeddings of every row of `df.csv` (the same path is loaded by `get_datasets`).
embeddings_path = dataset_root / f"embeddings_{weights}.pkl"

# Extract the features only when the cache is missing, so that a fresh
# "Restart & Run All" works without flipping a flag by hand, while an existing
# cache is never recomputed.
if not embeddings_path.exists():  # save features
    batch_size = 1024

    df = pd.read_csv(dataset_root / "df.csv")

    transform = get_normalisation_resize_hypvit(im_size=224, crop_size=224)
    im_reader = get_im_reader_for_transforms(transform)

    dataset = BaseDataset(df=df, transform=transform, f_imread=im_reader)
    model = ViTExtractor(weights, arch=weights.split("_")[0], normalise_features=True).eval().cuda()
    # shuffle=False keeps the batches in dataframe order, so the i-th embedding belongs to the i-th row of `df`.
    train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=20)

    embeddings = torch.zeros((len(df), model.feat_dim))

    with torch.no_grad():
        ia = 0
        for batch in tqdm(train_loader):
            embs = model(batch["input_tensors"].cuda()).cpu()
            # Advance by the real batch length: the last batch may be shorter than `batch_size`.
            ib = ia + len(embs)
            embeddings[ia:ib, :] = embs
            ia = ib

    # Every row of the preallocated tensor must have been written before caching it.
    assert ia == len(embeddings), f"Filled {ia} rows out of {len(embeddings)}"

    torch.save(embeddings, embeddings_path)
    
    

def get_datasets():
    """Load the cached embeddings and split them, together with the dataframe, into train and validation parts.

    Reads `embeddings_{weights}.pkl` and `df.csv` from `dataset_root`. Rows whose
    `split` column equals "train" form the train part; all other rows form the
    validation part.

    Returns:
        Tuple `(emb_train, emb_val, df_train, df_val)`. Both dataframes are
        re-indexed from 0; the previous index is kept as an extra column
        (the default behaviour of `DataFrame.reset_index`).

    Raises:
        AssertionError: if the number of embeddings differs from the number of rows in `df.csv`.
    """
    embeddings = torch.load(dataset_root / f"embeddings_{weights}.pkl")
    df = pd.read_csv(dataset_root / "df.csv")
    assert len(embeddings) == len(df), (
        f"Embeddings cache has {len(embeddings)} rows, but df.csv has {len(df)} rows"
    )

    is_train = (df["split"] == "train").to_numpy()
    # Index the tensor with an explicit boolean tensor rather than a pandas Series,
    # so the result does not depend on how torch interprets a Series object.
    is_train_mask = torch.tensor(is_train)

    emb_train = embeddings[is_train_mask]
    emb_val = embeddings[~is_train_mask]

    # `reset_index()` returns a new frame; no in-place mutation of a filtered slice.
    df_train = df[is_train].reset_index()
    df_val = df[~is_train].reset_index()

    return emb_train, emb_val, df_train, df_val


# Load the train / validation embeddings and their dataframes once, at notebook level.
emb_train, emb_val, df_train, df_val = get_datasets()


In [None]:
# Training
# TODO: the training step is not implemented yet.


In [None]:
# Validation with postprocessing

# Embeddings and dataframe rows are matched by position, so their lengths must agree.
assert len(emb_val) == len(df_val), "emb_val and df_val must have the same number of rows"

# Take the feature size from the embeddings instead of hard-coding it, so the cell
# keeps working when `weights` refers to an extractor with another output size.
model = SimpleSiamese(feat_dim=emb_val.shape[1], identity_init=True)
processor = PairwiseEmbeddingsPostprocessor(model, top_n=5)

calculator = EmbeddingMetrics(
    cmc_top_k=(1, 5, 10),
    postprocessor=processor
)
calculator.setup(len(df_val))
# Convert the columns through numpy explicitly: building a tensor straight from a
# pandas Series reads it element by element and only works for a 0..n-1 index.
calculator.update_data({
    "embeddings": emb_val,
    "is_query": torch.tensor(df_val["is_query"].astype(bool).to_numpy()),
    "is_gallery": torch.tensor(df_val["is_gallery"].astype(bool).to_numpy()),
    "labels": torch.tensor(df_val["label"].to_numpy()).long()
})
metrics = calculator.compute_metrics()
