In [36]:
import matplotlib.pyplot as plt
import matplotlib.transforms as mtrans
import os
from PIL import Image
import numpy as np

from tqdm.auto import tqdm

from pyretri.config import get_defaults_cfg, setup_cfg
from pyretri.datasets import build_transformers
from pyretri.models import build_model
from pyretri.extract import build_extract_helper
from pyretri.index import build_index_helper, feature_loader

In [2]:
# Class-level experiment: overlay the YAML config on top of the library defaults.
# The trailing '' is passed through to setup_cfg unchanged as its third argument.
cfg = setup_cfg(
    get_defaults_cfg(),
    '/home/artem/data/COCO/CBIR_data/CBIR_test/class/config.yaml',
    '',
)

In [15]:
def get_index_structures(cfg):
    """Build every object needed to query the gallery and bundle them in a dict.

    Args:
        cfg: configuration object; the ``datasets.transformers``, ``model``,
            ``extract`` and ``index`` sections are read.

    Returns:
        dict with the keys ``'transformers'``, ``'extract_helper'``,
        ``'gallery_fea'``, ``'gallery_info'`` and ``'index_helper'``.
    """
    # Image preprocessing pipeline.
    query_transformers = build_transformers(cfg.datasets.transformers)

    # The model is only needed to construct the extraction helper.
    extractor = build_extract_helper(build_model(cfg.model), cfg.extract)

    # Gallery features and their per-image records; the third value
    # returned by the loader is not used here.
    fea, info, _ = feature_loader.load(cfg.index.gallery_fea_dir, cfg.index.feature_names)

    indexer = build_index_helper(cfg.index)

    return {
        'transformers': query_transformers,
        'extract_helper': extractor,
        'gallery_fea': fea,
        'gallery_info': info,
        'index_helper': indexer,
    }

[LoadFeature] Success, total 5880 images, 
 feature names: dict_keys(['pool5_GeM'])
[LoadFeature] Success, total 5880 images, 
 feature names: dict_keys(['pool5_GeM'])


In [34]:
def index_and_save(queries, output_path, index_structures):
    """Retrieve the top-10 gallery neighbours of each query and save them as one figure.

    The figure is a fixed 3 x 11 grid: one row per query, with the query image
    in the first column followed by its ten highest-ranked gallery images.
    A neighbour's title is green when its gallery label equals the query label
    (the name of the query file's parent directory) and red otherwise.

    Args:
        queries: iterable of query image paths, at most 3 (one per grid row).
            Each file must live in a directory named after its label.
        output_path: destination image file; a missing parent directory is
            created before anything is drawn.
        index_structures: dict providing the keys 'transformers',
            'extract_helper', 'gallery_fea', 'gallery_info' and 'index_helper'
            (the layout returned by get_index_structures).

    Raises:
        IndexError: if more than 3 queries are given (the grid has 3 rows).

    Note:
        The feature names are read from the notebook-level ``cfg`` variable,
        so ``cfg`` must be the configuration ``index_structures`` was built from.
    """

    def index_img(path, top_k=10):
        """Return the gallery indices of the ``top_k`` best matches for one image."""
        # Force three channels before the transformers are applied.
        img = Image.open(path).convert("RGB")
        img_tensor = index_structures['transformers'](img)
        img_fea_info = index_structures['extract_helper'].do_single_extract(img_tensor)
        stacked_feature = []
        for name in cfg.index.feature_names:
            assert name in img_fea_info[0], "invalid feature name: {} not in {}!".format(name, img_fea_info[0].keys())
            stacked_feature.append(img_fea_info[0][name].cpu())
        # Concatenate the selected features along axis 1 into one query descriptor.
        img_fea = np.concatenate(stacked_feature, axis=1)
        index_result_info, _, _ = index_structures['index_helper'].do_index(
            img_fea, img_fea_info, index_structures['gallery_fea']
        )
        return index_result_info[0]['ranked_neighbors_idx'][:top_k]

    def load_rgb(path):
        """Load an image as an RGB array for display."""
        # imshow applies a colormap to 2-D (single-channel) arrays, which would
        # false-colour grayscale files; converting to RGB shows them as they are.
        with Image.open(path) as img:
            return np.array(img.convert("RGB"))

    def single_imshow(ax, img, title, color='black'):
        """Show ``img`` on ``ax`` without axes, under a bold coloured title."""
        ax.imshow(img)
        ax.set_axis_off()
        ax.set_title(title, color=color, fontweight='bold')

    def visualise_index_result(top_k_idx, query_path, ax):
        """Fill one grid row: the query in ax[0], ranked neighbours in ax[1:]."""
        # The label is the name of the directory that contains the query file.
        query_label = os.path.basename(os.path.dirname(query_path))
        single_imshow(ax[0], load_rgb(query_path), f'QUERY\n{query_label}')

        for i, idx in enumerate(top_k_idx):
            idx_info = index_structures['gallery_info'][idx]
            label = idx_info['label']
            text_color = 'green' if query_label == label else 'red'
            single_imshow(ax[i + 1], load_rgb(idx_info['path']), f'TOP {i + 1}\n{label}', text_color)

    # savefig does not create directories; make sure the target one exists
    # up front so a long indexing run is not lost at the very last step.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    f, ax = plt.subplots(nrows=3, ncols=11, figsize=(20, 20), gridspec_kw={'wspace': 0.15, 'hspace': -0.83})
    f.set_tight_layout(False)

    try:
        for i, path in enumerate(queries):
            top_k_idx = index_img(path)
            visualise_index_result(top_k_idx, path, ax[i])

        # Two horizontal rules and one vertical rule in figure-fraction
        # coordinates. NOTE(review): the positions look hand-tuned for this exact
        # 3 x 11 layout (presumably row separators and a divider after the
        # query column) -- re-tune them if the grid or figsize changes.
        separators = (
            ([0.125, 0.9], [0.553, 0.553]),
            ([0.125, 0.9], [0.455, 0.455]),
            ([0.1915, 0.1915], [0.3685, 0.65]),
        )
        for xs, ys in separators:
            plt.plot(xs, ys, color='black', lw=3, transform=f.transFigure, clip_on=False)

        f.savefig(output_path, bbox_inches='tight', pad_inches=0)
    finally:
        # Always release the figure, even when indexing or drawing fails;
        # otherwise figures pile up when this is called in a loop.
        plt.close(f)
    
# Demo: retrieve and plot results for three "airplane" queries of the class-level split.
queries = ['/home/artem/data/COCO/CBIR_data/CBIR_test/class/query/airplane/973_airplane_vehicle_0.9336.jpg',
           '/home/artem/data/COCO/CBIR_data/CBIR_test/class/query/airplane/66_airplane_vehicle_0.7957.jpg',
           '/home/artem/data/COCO/CBIR_data/CBIR_test/class/query/airplane/353_airplane_vehicle_0.5255.jpg']
output_path = 'search_result.png'

# Build the index structures from the current (class-level) cfg before using
# them. Without this line the call below only works with leftover kernel state
# and raises NameError after "Restart Kernel -> Run All".
index_structures = get_index_structures(cfg)
index_and_save(queries, output_path, index_structures)

In [45]:
# Switch the notebook to the superclass-level experiment: its query folder and
# config both live under BASE_DIR.
BASE_DIR = '/home/artem/data/COCO/CBIR_data/CBIR_test/superclass'
QUERY_DIR = BASE_DIR + '/query'
CFG_PATH = BASE_DIR + '/config.yaml'

# Rebind cfg and index_structures so that later cells use the superclass
# configuration and the gallery loaded for it.
cfg = setup_cfg(get_defaults_cfg(), CFG_PATH, '')
index_structures = get_index_structures(cfg)

[LoadFeature] Success, total 5964 images, 
 feature names: dict_keys(['pool5_GeM'])
[LoadFeature] Success, total 5964 images, 
 feature names: dict_keys(['pool5_GeM'])


In [46]:
# Write one result figure per query class into OUTPUT_DIR.
OUTPUT_DIR = 'search_results/superclass'
# Saving a figure does not create missing directories, so create it up front.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# os.listdir returns entries in arbitrary order; sorting keeps the progress
# order and the row order inside each figure reproducible between runs.
for class_name in tqdm(sorted(os.listdir(QUERY_DIR))):
    # Paths are joined with '/' on purpose: index_and_save reads the label from
    # the name of the query file's parent directory.
    # NOTE: index_and_save lays out a 3-row grid, so each class folder is
    # expected to hold at most three query images.
    queries = [
        f'{QUERY_DIR}/{class_name}/{file_name}'
        for file_name in sorted(os.listdir(f'{QUERY_DIR}/{class_name}'))
    ]
    output_path = f'{OUTPUT_DIR}/{class_name}.png'
    index_and_save(queries, output_path, index_structures)

  0%|          | 0/12 [00:00<?, ?it/s]