In [None]:
%load_ext autoreload
%autoreload 2

import yaml
from IPython.core.display import HTML
from IPython.display import display
import pandas as pd
import torch
import numpy as np
import matplotlib.pyplot as plt
import cv2

from oml.models.siamese import ConcatSiamese
from oml.models.vit.vit import ViTExtractor
from oml.const import MOCK_DATASET_PATH
from oml.transforms.images.torchvision.transforms import get_normalisation_resize_hypvit
from oml.utils.images.images import imread_cv2, imread_pillow

# Stretch the notebook's main container to the full browser width and let
# pandas print up to 330 rows before truncating.
_full_width_css = "<style>.container { width:100% !important; }</style>"
display(HTML(_full_width_css))
pd.set_option("display.max_rows", 330)

%matplotlib inline


In [None]:
# Both checkpoints live in the same local project directory.
# NOTE(review): absolute local path — consider making this configurable.
_ckpt_dir = "/home/daloro/python_projects/open-metric-learning"

# Backbone with arch "vits16"; features are left un-normalised.
extractor = ViTExtractor(
    arch="vits16",
    normalise_features=False,
    weights=f"{_ckpt_dir}/embedder.ckpt",
)

# Pair model wrapping the same `extractor` object; `mlp_hidden_dims=[192]`
# presumably means a single hidden layer of width 192 — TODO confirm.
siamese = ConcatSiamese(
    extractor=extractor,
    mlp_hidden_dims=[192],
    weights=f"{_ckpt_dir}/postprocessor.ckpt",
)

In [None]:
# Load the annotation table and keep only rows whose "split" is "validation",
# re-indexed from 0. Written as one method chain instead of an in-place
# `reset_index(inplace=True)` mutation.
df = (
    pd.read_csv("/home/daloro/data/DeepFashion_InShop/df.csv")
    .loc[lambda frame: frame["split"] == "validation"]
    .reset_index(drop=True)
)


In [None]:
def compare_old(path_1, path_2, size=224):
    """Print siamese scores for an image pair in both orders and plot attention maps.

    Args:
        path_1: Path to the first image.
        path_2: Path to the second image.
        size: Side length in pixels that images are resized to (for both the
            attention visualisation and the model input transform). Defaults to
            224, the value previously hard-coded.

    Side effects:
        Prints ``siamese.predict`` for the (1, 2) and (2, 1) orders — both are
        printed, presumably to check whether the pair model is order-sensitive.
        Then shows two figures: the per-image ``extractor`` attention maps side
        by side, and the ``extractor`` attention map of the concatenated pair.

    Relies on the notebook-level ``extractor`` and ``siamese`` objects.
    """
    im1 = cv2.resize(imread_cv2(path_1), (size, size))
    im2 = cv2.resize(imread_cv2(path_2), (size, size))
    im_pair = np.concatenate([im1, im2], axis=1)

    # Build the preprocessing pipeline once and reuse it for both images.
    transform = get_normalisation_resize_hypvit(size, size)
    tensor1 = transform(imread_pillow(path_1)).unsqueeze(0)
    tensor2 = transform(imread_pillow(path_2)).unsqueeze(0)

    # Inference only: avoid recording an autograd graph for these forward passes.
    with torch.no_grad():
        out12 = siamese.predict(tensor1, tensor2)
        out21 = siamese.predict(tensor2, tensor1)
    print(out12.item(), out21.item())

    attn1 = extractor.draw_attention(im1)
    attn2 = extractor.draw_attention(im2)
    plt.imshow(np.concatenate([attn1, attn2], axis=1))
    plt.show()

    plt.imshow(extractor.draw_attention(im_pair))
    plt.show()

In [None]:
def compare(path_1, path_2, size=224):
    """Plot attention maps for two images, separately and as one concatenated image.

    Args:
        path_1: Path to the first image.
        path_2: Path to the second image.
        size: Side length in pixels that each image is resized to before
            drawing attention. Defaults to 224, the value previously hard-coded.

    Shows two figures:
        1. The ``extractor`` attention maps of both images, placed side by side.
        2. The ``siamese.extractor`` attention map of the horizontally
           concatenated pair (``size`` rows by ``2 * size`` columns).

    Relies on the notebook-level ``extractor`` and ``siamese`` objects.
    """
    im1 = cv2.resize(imread_cv2(path_1), (size, size))
    im2 = cv2.resize(imread_cv2(path_2), (size, size))

    attn1 = extractor.draw_attention(im1)
    attn2 = extractor.draw_attention(im2)
    plt.imshow(np.concatenate([attn1, attn2], axis=1))
    plt.title("Extractor attention, per image")
    plt.show()

    im_concat = np.concatenate([im1, im2], axis=1)
    plt.imshow(siamese.extractor.draw_attention(im_concat))
    plt.title("Siamese extractor attention, concatenated pair")
    plt.show()

In [None]:
# Two views of the same item: both files sit in the id_00002368 folder.
_item_dir = "/home/daloro/data/DeepFashion_InShop/img_highres/WOMEN/Leggings/id_00002368"
p1 = f"{_item_dir}/01_1_front.jpg"
p2 = f"{_item_dir}/01_2_side.jpg"

compare(p1, p2)

In [None]:
# One query image (p1) compared against a list of gallery images (p2).
_women_dir = "/home/daloro/data/DeepFashion_InShop/img_highres/WOMEN"
p1 = f"{_women_dir}/Dresses/id_00003264/01_7_additional.jpg"
p2 = [
    f"{_women_dir}/Pants/id_00005705/02_3_back.jpg",  # different item id
    f"{_women_dir}/Dresses/id_00003264/02_4_full.jpg",  # same item id as p1
]

for g in p2:
    compare(p1, g)