In [1]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("/Users/ashabanov/code/metric_learning/open-metric-learning")

from IPython.core.display import HTML
from IPython.display import display
import pandas as pd
from tqdm import tqdm
from pathlib import Path
import torch
import matplotlib.pyplot as plt

from oml.datasets.base import BaseDataset
from oml.models.vit.vit import ViTExtractor
from oml.transforms.images.torchvision.transforms import get_normalisation_resize_hypvit
from oml.transforms.images.utils import get_im_reader_for_transforms
from oml.metrics.embeddings import EmbeddingMetrics
from oml.postprocessors.pairwise_embeddings import PairwiseEmbeddingsPostprocessor
from oml.postprocessors.pairwise_images import PairwiseImagesPostprocessor
from oml.samplers.balance import BalanceSampler
from oml.miners.inbatch_all_tri import AllTripletsMiner
from oml.miners.inbatch_hard_tri import HardTripletsMiner
from oml.utils.misc_torch import elementwise_dist
from torchvision.models import resnet50
from oml.models.vit.vit import ViTExtractor
from oml.transforms.images.torchvision.transforms import get_normalisation_resize_hypvit
from oml.datasets.base import DatasetWithLabels
    
from source import TensorsWithLabels

# Stretch the notebook container to the full browser width.
display(HTML("<style>.container { width:100% !important; }</style>"))
# Let pandas print up to 330 rows before truncating a frame.
pd.set_option('display.max_rows', 330)

%matplotlib inline


In [2]:
# NOTE(review): absolute, machine-specific path; consider making it configurable.
dataset_root = Path("/nydl/data/DeepFashion_InShop/")
weights = "vits16_dino"

# Location of the cached embeddings (one row per row of df.csv).
# `get_datasets()` reads the file back from this same location.
embeddings_path = dataset_root / f"embeddings_{weights}.pkl"

# Extract and cache the embeddings only when the cache file is missing, so that
# a fresh "Restart & Run All" can rebuild it without editing the notebook.
if not embeddings_path.exists():  # save features
    batch_size = 1024

    df = pd.read_csv(dataset_root / "df.csv")

    transform = get_normalisation_resize_hypvit(im_size=224, crop_size=224)
    im_reader = get_im_reader_for_transforms(transform)

    dataset = BaseDataset(df=df, transform=transform, f_imread=im_reader)
    # The architecture name is the part of `weights` before the first underscore.
    model = ViTExtractor(weights, arch=weights.split("_")[0], normalise_features=True).eval().cuda()
    # shuffle=False keeps the embedding rows aligned with the rows of `df`.
    train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=20)

    embeddings = torch.zeros((len(df), model.feat_dim))

    with torch.no_grad():
        for i, batch in enumerate(tqdm(train_loader)):
            embs = model(batch["input_tensors"].cuda()).cpu()
            # The last batch may be shorter, so size the slice by what came back.
            ia = i * batch_size
            embeddings[ia:ia + len(embs), :] = embs

    torch.save(embeddings, embeddings_path)
    
    
def get_datasets():
    """Load the cached embeddings and the metadata table, split into train / val.

    Reads `embeddings_{weights}.pkl` and `df.csv` from `dataset_root` (both are
    notebook-level globals). Rows whose `split` column equals "train" form the
    train part; all remaining rows form the validation part.

    Returns:
        Tuple `(emb_train, emb_val, df_train, df_val)`. Both frames have a fresh
        0..n-1 index; the previous index is kept as a column because
        `reset_index` is called with its default `drop=False`.
    """
    # NOTE(review): torch.load unpickles the file - only load caches you trust.
    all_embeddings = torch.load(dataset_root / f"embeddings_{weights}.pkl")
    meta = pd.read_csv(dataset_root / "df.csv")

    is_train = meta["split"] == "train"

    train_part = (all_embeddings[is_train], meta[is_train].reset_index())
    val_part = (all_embeddings[~is_train], meta[~is_train].reset_index())

    return train_part[0], val_part[0], train_part[1], val_part[1]


# Load the cached embeddings and metadata once; the validation and training
# cells below read these four names as notebook-level globals.
emb_train, emb_val, df_train, df_val = get_datasets()


In [None]:
class PairsMiner:
    """Turns mined triplets into a flat list of positive and negative pairs."""

    def __init__(self):
        # AllTripletsMiner() can be swapped in here instead of the hard miner.
        self.miner = HardTripletsMiner()

    @staticmethod
    def _unique_unordered_pairs(first_ids, second_ids):
        """Deduplicate (i, j) pairs ignoring order; return the two id columns."""
        ordered = [tuple(sorted([a, b])) for a, b in zip(first_ids, second_ids)]
        left, right = zip(*list(set(ordered)))
        return left, right

    def sample(self, features, labels):
        """Mine triplets and split them into unique pairs.

        Args:
            features: tensor indexed along dim 0 by the ids the miner returns.
            labels: labels passed through to the underlying triplet miner.

        Returns:
            `(x1, x2, gt_distance)`: the two sides of every pair (positive
            pairs first, then negative pairs) and a float target that is 0 for
            anchor-positive pairs and 1 for anchor-negative pairs.
        """
        anchors, positives, negatives = self.miner._sample(features, labels=labels)

        pos_left, pos_right = self._unique_unordered_pairs(anchors, positives)
        neg_left, neg_right = self._unique_unordered_pairs(anchors, negatives)

        # Positives come first, so the target is a run of zeros followed by ones.
        gt_distance = torch.cat([torch.zeros(len(pos_left)), torch.ones(len(neg_left))])

        left_ids = list(pos_left + neg_left)
        right_ids = list(pos_right + neg_right)

        return features[left_ids], features[right_ids], gt_distance
    
    
# Smoke test: 10 identical 24-dim feature rows with two labels (4 + 6 samples).
# The trailing `;` suppresses the cell output.
miner = PairsMiner()
miner.sample(torch.ones(10, 24), torch.tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 1]));


In [None]:
class Siamese(torch.nn.Module):
    """Pairwise model: a shared linear projection followed by a cosine distance.

    Args:
        feat_dim: dimensionality of the input (and projected) features.
        identity_init: when True the projection starts as the identity map
            (weight = I, bias = 0), so the initial output is the plain cosine
            distance of the inputs.
    """

    def __init__(self, feat_dim: int, identity_init: bool = True):
        super().__init__()
        self.feat_dim = feat_dim

        self.proj = torch.nn.Linear(feat_dim, feat_dim, bias=True)

        if identity_init:
            ini_dict = {"weight": torch.eye(feat_dim), "bias": torch.zeros(feat_dim)}
            self.proj.load_state_dict(ini_dict)

    @staticmethod
    def _unit_norm(x, eps: float = 1e-12):
        """Scale rows to unit L2 norm; the norm itself is excluded from autograd."""
        norm = torch.linalg.norm(x, 2, dim=1, keepdim=True).detach()
        # Guard against all-zero rows, which would otherwise produce NaN.
        return x / torch.clamp(norm, min=eps)

    def forward(self, x1, x2):
        """Return one distance in [0, 1] per row pair of `x1` and `x2`.

        0 means same direction after projection, 0.5 orthogonal, 1 opposite.
        """
        x1 = self._unit_norm(self.proj(x1))
        x2 = self._unit_norm(self.proj(x2))

        cos = (x1 * x2).sum(dim=1)
        dist = (1 - cos) / 2

        # Rounding can push the cosine marginally outside [-1, 1]; BCELoss
        # rejects inputs outside [0, 1], so keep the distance inside that range.
        return torch.clamp(dist, min=0.0, max=1.0)
    
    
class Siamese2(torch.nn.Module):
    """Pairwise MLP that maps two feature batches to one score in (0, 1) per pair.

    Four interaction features (sum, both differences, element-wise product) are
    each passed through a linear layer + ReLU, concatenated, fused by another
    linear layer + ReLU and reduced to a single sigmoid output.

    Args:
        feat_dim: dimensionality of the input features.
    """

    def __init__(self, feat_dim: int):
        super().__init__()
        self.feat_dim = feat_dim

        self.fc1 = torch.nn.Linear(feat_dim, feat_dim, bias=True)
        self.fc2 = torch.nn.Linear(feat_dim, feat_dim, bias=True)
        self.fc3 = torch.nn.Linear(feat_dim, feat_dim, bias=True)
        self.fc = torch.nn.Linear(4 * feat_dim, feat_dim, bias=True)
        self.last = torch.nn.Linear(feat_dim, 1, bias=False)

    def forward(self, x1, x2):
        """Return a tensor of shape (batch,) for inputs of shape (batch, feat_dim)."""
        y1 = torch.relu(self.fc1(x1 + x2))
        # fc2 is applied to both difference directions.
        # NOTE(review): the weight sharing looks intentional (three layers, four
        # branches) - confirm.
        y2 = torch.relu(self.fc2(x1 - x2))
        y3 = torch.relu(self.fc2(x2 - x1))
        y4 = torch.relu(self.fc3(x1 * x2))

        y = torch.relu(self.fc(torch.cat([y1, y2, y3, y4], dim=1)))

        # Drop only the trailing size-1 dim; a bare squeeze() would also remove
        # the batch dim for a batch of one and return a 0-dim tensor.
        y = torch.sigmoid(self.last(y)).squeeze(-1)

        return y
    
    
# Smoke test: push a batch of two 3-dim pairs through a freshly initialised
# Siamese2 and print the shape of the result.
y = Siamese2(3)(torch.tensor([[+1.0, 0, 0], [1, 0, 0]]), torch.tensor([[-1.0, 0, 0.0], [0, 1, 0]]))
print(y.shape)


In [6]:
class ImagesSiamese(torch.nn.Module):
    """Pairwise model on raw inputs: shared ViT extractor + cosine distance.

    The extractor is built from the notebook-level `weights` name.
    """

    def __init__(self) -> None:
        super().__init__()
        self.model = ViTExtractor(weights=weights, arch="vits16", normalise_features=True)
        # NOTE(review): `fc` is never used in forward(). It is kept so the
        # parameter set and state_dict keys stay unchanged.
        self.fc = torch.nn.Linear(in_features=self.model.feat_dim * 2, out_features=1)

    @staticmethod
    def _unit_norm(x, eps: float = 1e-12):
        """Scale rows to unit L2 norm; the norm itself is excluded from autograd."""
        norm = torch.linalg.norm(x, 2, dim=1, keepdim=True).detach()
        # Guard against all-zero rows, which would otherwise produce NaN.
        return x / torch.clamp(norm, min=eps)

    def forward(self, x1, x2):
        """Return one distance in [0, 1] per pair (0 = same direction, 1 = opposite)."""
        # The extractor is created with normalise_features=True, so this second
        # normalisation is presumably a near no-op - TODO confirm.
        x1 = self._unit_norm(self.model(x1))
        x2 = self._unit_norm(self.model(x2))

        cos = (x1 * x2).sum(dim=1)
        dist = (1 - cos) / 2

        # Rounding can push the cosine marginally outside [-1, 1]; BCELoss
        # rejects inputs outside [0, 1], so keep the distance inside that range.
        return torch.clamp(dist, min=0.0, max=1.0)
    

In [7]:
def val(pairwise_model):
    """Compute retrieval metrics on the validation split and print them.

    Uses the notebook-level `emb_val` and `df_val`. When `pairwise_model` is
    truthy it is wrapped into a `PairwiseImagesPostprocessor` (top_n=5) and
    handed to the metrics calculator; otherwise no postprocessor is used.
    `PairwiseEmbeddingsPostprocessor(pairwise_model, top_n=5)` is the
    embeddings-based alternative that can be swapped in.

    Args:
        pairwise_model: model for the postprocessor, or a falsy value to
            evaluate the raw embeddings only.

    Returns:
        The `EmbeddingMetrics` calculator after `compute_metrics()` has run.
    """
    processor = (
        PairwiseImagesPostprocessor(
            pairwise_model,
            top_n=5,
            image_transforms=get_normalisation_resize_hypvit(224, 224),
        )
        if pairwise_model
        else None
    )

    calculator = EmbeddingMetrics(cmc_top_k=(1, 5), postprocessor=processor, extra_keys=("paths",))
    calculator.setup(len(df_val))

    # The whole validation split is fed to the calculator as one batch.
    val_batch = {
        "embeddings": emb_val,
        "is_query": torch.tensor(df_val["is_query"]).bool(),
        "is_gallery": torch.tensor(df_val["is_gallery"]).bool(),
        "labels": torch.tensor(df_val["label"]).long(),
        "paths": df_val["path"].tolist(),
    }
    calculator.update_data(val_batch)

    metrics = calculator.compute_metrics()
    print(metrics)

    return calculator
    
    
# Baseline: evaluate with a newly constructed ImagesSiamese (before any training
# in this notebook) as the pairwise postprocessor. Requires a CUDA device.
calc = val(ImagesSiamese().cuda())
calc
    

https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth
Checkpoint is already here.


  0%|          | 0/556 [00:02<?, ?it/s]


Metrics:
{'OVERALL': {'cmc': {1: tensor(0.4602), 5: tensor(0.6329)},
             'map': {5: tensor(0.2762)},
             'pcf': {0.5: tensor(0.0339)},
             'precision': {5: tensor(0.3315)}}}
{'OVERALL': {'cmc': {1: tensor(0.4602), 5: tensor(0.6329)}, 'precision': {5: tensor(0.3315)}, 'map': {5: tensor(0.2762)}, 'pcf': {0.5: tensor(0.0339)}}}


<oml.metrics.embeddings.EmbeddingMetrics at 0x7f6aac8bed30>

In [None]:
# Training

model = ImagesSiamese().cuda()
model.train()

# NOTE(review): the dataset is built from the precomputed embeddings
# (`emb_train`), while ImagesSiamese passes its inputs through a ViT extractor.
# Confirm that batch["input_tensors"] has the shape the model expects.
dataset = TensorsWithLabels(df_train, emb_train)

n_labels, n_instances = 100, 4
loader = torch.utils.data.DataLoader(
    batch_sampler=BalanceSampler(labels=dataset.get_labels(), n_labels=n_labels, n_instances=n_instances),
    dataset=dataset, num_workers=0
)

# A single batch repeated 50 times: every "epoch" iterates over the same data.
repeated = [next(iter(loader))] * 50

optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-3)

pairs_miner = PairsMiner()
# BCELoss requires predictions in [0, 1]; the targets are 0 (positive pair)
# and 1 (negative pair).
criterion = torch.nn.BCELoss(reduction="mean")

n_epochs = 20
val_period = 1  # validate every `val_period` epochs
acc_threshold = 0.025  # a predicted distance above this counts as "negative pair"

losses = []
acc = []

for i_epoch in range(n_epochs):
    tqdm_loader = tqdm(repeated)
    for batch in tqdm_loader:
        x1, x2, gt_dist = pairs_miner.sample(batch["input_tensors"], batch["labels"])
        x1, x2, gt_dist = x1.cuda(), x2.cuda(), gt_dist.cuda()

        pred_dist = model(x1=x1, x2=x2)
        loss = criterion(pred_dist, gt_dist)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # logs
        accuracy = ((pred_dist > acc_threshold) == gt_dist).float().mean().item()
        tqdm_loader.set_postfix({"acc": accuracy, "loss": loss.item()})
        losses.append(loss.item())
        acc.append(accuracy)

    if i_epoch % val_period == 0:
        # Validate in eval mode, then switch back so training continues in
        # train mode.
        model.eval()
        val(model)
        model.train()


In [None]:
# Histogram of the predicted distances from the last training step
# (`pred_dist` is left over from the training cell, so that cell must run first).
plt.hist(pred_dist.detach().cpu().numpy(), bins=20)

In [None]:
# Per-step training loss collected by the training cell.
plt.plot(losses)
plt.show()
# Per-step pair-classification accuracy collected by the training cell.
plt.plot(acc)
plt.show()