In [None]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("/Users/ashabanov/code/metric_learning/open-metric-learning")

from IPython.core.display import HTML
from IPython.display import display
import pandas as pd
from tqdm import tqdm
from pathlib import Path
import torch
import matplotlib.pyplot as plt

from oml.datasets.base import BaseDataset
from oml.models.vit.vit import ViTExtractor
from oml.transforms.images.torchvision.transforms import get_normalisation_resize_hypvit
from oml.transforms.images.utils import get_im_reader_for_transforms
from oml.metrics.embeddings import EmbeddingMetrics
from oml.postprocessors.pairwise_embeddings import PairwiseEmbeddingsPostprocessor
from oml.samplers.balance import BalanceSampler
from oml.miners.inbatch_all_tri import AllTripletsMiner

from source import TensorsWithLabels

# Inject CSS that forces elements with the `container` class to span the full page width.
display(HTML("<style>.container { width:100% !important; }</style>"))
# Let pandas show up to 330 rows before it truncates a printed dataframe.
pd.set_option('display.max_rows', 330)

%matplotlib inline


In [None]:
# Folder that holds `df.csv`; the cached embeddings file is written to / read from it as well.
dataset_root = Path("/nydl/data/Stanford_Online_Products/")
# Passed to ViTExtractor as its first argument and also used in the embeddings file name.
weights = "vits16_sop"

# One-off feature extraction, switched off with `if False`: nothing below runs unless the
# condition is flipped by hand. When enabled, it embeds every row of `df.csv` (all splits)
# and stores the result as `embeddings_{weights}.pkl` inside `dataset_root`. Requires CUDA.
if False:  # save features
    batch_size = 1024

    df = pd.read_csv(dataset_root / "df.csv")

    transform = get_normalisation_resize_hypvit(im_size=224, crop_size=224)
    im_reader = get_im_reader_for_transforms(transform)

    dataset = BaseDataset(df=df, transform=transform, f_imread=im_reader)
    # The architecture name is the part of `weights` before the first underscore ("vits16").
    model = ViTExtractor(weights, arch=weights.split("_")[0], normalise_features=True).eval().cuda()
    # shuffle=False keeps batches in dataframe order, which the positional writes below
    # (and the row-wise masking in `get_datasets`) rely on.
    train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=20)

    # Output buffer on CPU: one row per dataframe row, `model.feat_dim` columns.
    embeddings = torch.zeros((len(df), model.feat_dim))

    with torch.no_grad():
        for i, batch in enumerate(tqdm(train_loader)):
            embs = model(batch["input_tensors"].cuda()).detach().cpu()
            # Row range covered by the i-th batch; the upper bound is clamped because the
            # last batch may hold fewer than `batch_size` items.
            ia = i * batch_size
            ib = min(len(embeddings), (i + 1) * batch_size)
            embeddings[ia:ib, :] = embs

    torch.save(embeddings, dataset_root / f"embeddings_{weights}.pkl")
    
    

def get_datasets():
    """
    Load the cached embeddings together with the dataset table and split both
    into a train part and a validation part.

    Reads ``embeddings_{weights}.pkl`` and ``df.csv`` from ``dataset_root``
    (``dataset_root`` and ``weights`` are module-level names). Rows whose
    ``split`` column equals ``"train"`` make up the train part, every other row
    goes to the validation part. The same row mask is applied to the embeddings,
    so embeddings and dataframe rows are expected to be aligned by position.

    Returns:
        A tuple ``(emb_train, emb_val, df_train, df_val)``. Both dataframes are
        re-indexed from zero; the previous index is kept as a regular column.
    """
    all_embeddings = torch.load(dataset_root / f"embeddings_{weights}.pkl")
    table = pd.read_csv(dataset_root / "df.csv")

    is_train = table["split"] == "train"
    is_val = ~is_train

    # `reset_index()` without `drop=True` keeps the old index as a column.
    train_table = table[is_train].reset_index()
    val_table = table[is_val].reset_index()

    return all_embeddings[is_train], all_embeddings[is_val], train_table, val_table


# Load the cached embeddings and the dataframe once, already split into train / validation parts.
emb_train, emb_val, df_train, df_val = get_datasets()


In [None]:
class AllPairsMiner:
    """
    Builds labelled pairs of features from the index triplets produced by
    ``AllTripletsMiner``.

    Every mined triplet ``(a, p, n)`` contributes two pairs: ``(a, p)`` marked as
    "same" and ``(a, n)`` marked as "not same". Presumably ``a``/``p`` share a
    label and ``n`` has a different one - confirm against ``AllTripletsMiner``.
    """

    def __init__(self):
        self.miner = AllTripletsMiner()

    def sample(self, features, labels):
        """
        Args:
            features: Indexable collection of features (indexed with lists of ids).
            labels: Labels forwarded to the underlying triplet miner.

        Returns:
            ``(first, second, is_same)``: the two sides of each pair and a boolean
            tensor which is ``True`` for the first half of the pairs (anchor with
            the second index of the triplet) and ``False`` for the second half
            (anchor with the third index of the triplet).
        """
        # Only the labels are needed to enumerate the triplets, so no features are passed.
        anchor_ids, positive_ids, negative_ids = self.miner._sample(None, labels=labels)
        n_triplets = len(anchor_ids)

        first_ids = [*anchor_ids, *anchor_ids]
        second_ids = [*positive_ids, *negative_ids]

        is_same = torch.cat([torch.ones(n_triplets), torch.zeros(n_triplets)]).bool()

        return features[first_ids], features[second_ids], is_same

    
class Siamese(torch.nn.Module):
    """
    Pairwise model that maps two batches of embeddings to one distance per pair.

    The first input goes through ``proj1`` and the second through ``proj2`` (two
    independent bias-free linear layers of size ``feat_dim x feat_dim``); the
    output is the row-wise L2 distance between the projected vectors.

    The class derives from ``torch.nn.Module``: the rest of the notebook calls
    ``model.cuda()``, ``model.parameters()`` and ``model(x1=..., x2=...)``, none of
    which exist on a plain Python class.
    """

    def __init__(self, feat_dim: int, identity_init: bool):
        """
        Args:
            feat_dim: Dimension of the input embeddings (the projections keep it unchanged).
            identity_init: If ``True``, both projections start as identity matrices, so the
                untrained model returns the plain L2 distance between its inputs.
        """
        super().__init__()
        self.feat_dim = feat_dim

        self.proj1 = torch.nn.Linear(in_features=feat_dim, out_features=feat_dim, bias=False)
        self.proj2 = torch.nn.Linear(in_features=feat_dim, out_features=feat_dim, bias=False)

        if identity_init:
            self.proj1.load_state_dict({"weight": torch.eye(feat_dim)})
            self.proj2.load_state_dict({"weight": torch.eye(feat_dim)})

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x1: First elements of the pairs, one embedding per row.
            x2: Second elements of the pairs, same shape as ``x1``.

        Returns:
            L2 distance between the projected ``x1`` and ``x2`` for every row
            (the last dimension is reduced).
        """
        x1 = self.proj1(x1)
        x2 = self.proj2(x2)
        # `Tensor` and `elementwise_dist` are not imported in the visible cells of this
        # notebook, so the annotations use `torch.Tensor` and the p=2 elementwise distance
        # is computed directly with torch.
        return (x1 - x2).norm(p=2, dim=-1)


In [None]:
# Training
# Fit the pairwise model on pairs mined from the precomputed train embeddings.
# NOTE(review): feat_dim=384 presumably equals the size of the stored embeddings - confirm.
model = Siamese(feat_dim=384, identity_init=True)
model.cuda()

dataset = TensorsWithLabels(df_train, emb_train)

n_labels, n_instances = 50, 4
loader = torch.utils.data.DataLoader(
    batch_sampler=BalanceSampler(labels=dataset.get_labels(), n_labels=n_labels, n_instances=n_instances),
    dataset=dataset
)

optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-5)

pairs_miner = AllPairsMiner()
criterion = torch.nn.BCEWithLogitsLoss(reduction="mean")

# Loss value of every optimisation step, kept for plotting.
losses = []

for _ in range(5):
    for batch in tqdm(loader):
        # `gt` is True for the anchor-positive pairs and False for the anchor-negative pairs.
        x1, x2, gt = pairs_miner.sample(batch["input_tensors"], batch["labels"])
        pred = model(x1=x1.cuda(), x2=x2.cuda())
        pred = pred / 2 # scale l2
        # The model predicts a distance, so the target has to be "is a different label":
        # 0 for same-label pairs, 1 for different-label pairs. Using `gt` itself as the target
        # would reward large distances between items that share a label.
        loss = criterion(pred, (~gt).float().cuda())
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        losses.append(loss.item())


In [None]:
# Loss of every optimisation step, in the order the steps were taken.
plt.plot(losses)
plt.xlabel("training step")
plt.ylabel("BCE loss")
plt.title("Siamese training loss");  # trailing `;` hides the repr of the last call

In [None]:
# Validation with postprocessing
# The trained pairwise model is wrapped into a postprocessor and handed to the metrics
# calculator, which is then fed the whole validation split in a single update.
# NOTE(review): `top_n=5` presumably limits re-ranking to the 5 closest gallery items per
# query - confirm against the OML documentation.
processor = PairwiseEmbeddingsPostprocessor(model, top_n=5)

calculator = EmbeddingMetrics(
    cmc_top_k=(1, 5, 10),
    postprocessor=processor
)
# The calculator is set up with the number of validation samples before any data is passed.
calculator.setup(len(df_val))
# Flags and labels come from the validation dataframe columns and are converted to
# bool / long tensors; the embeddings are the validation slice of the cached ones.
calculator.update_data({
    "embeddings": emb_val,
    "is_query": torch.tensor(df_val["is_query"]).bool(),
    "is_gallery": torch.tensor(df_val["is_gallery"]).bool(),
    "labels": torch.tensor(df_val["label"]).long()
})
# The trailing semicolon keeps the notebook from echoing the returned value.
metrics = calculator.compute_metrics();
