In [1]:
import os
import pickle

In [2]:
import numpy as np

In [3]:
import matplotlib.pyplot as plt

In [3]:
# Directory holding the pre-extracted photo/sketch feature pickles.
# Set the SBIR_FEATURE_ROOT environment variable to point elsewhere instead of
# editing this machine-specific default path.
FEATURE_ROOT = os.environ.get(
    'SBIR_FEATURE_ROOT',
    '/home/zzl/project/selfSBIR/extracted_feature/canny200',
)

In [4]:
# Pickled feature files for the photo gallery and the sketch queries.
PHOTO_FEATURE, SKETCH_FEATURE = (
    os.path.join(FEATURE_ROOT, 'photo.pkl'),
    os.path.join(FEATURE_ROOT, 'sketch.pkl'),
)

In [5]:
# NOTE: pickle.load can execute arbitrary code -- only load feature files you trust.
# The context manager closes the file handle (the original left it open).
with open(PHOTO_FEATURE, 'rb') as photo_file:
    photo_data = pickle.load(photo_file)

In [6]:
# NOTE: pickle.load can execute arbitrary code -- only load feature files you trust.
# The context manager closes the file handle (the original left it open).
with open(SKETCH_FEATURE, 'rb') as sketch_file:
    sketch_data = pickle.load(sketch_file)

In [7]:
# Pull the photo names and their feature rows out of the loaded dict
# (presumed aligned: photo_name[i] labels row i of photo_feature -- verify).
photo_name, photo_feature = (photo_data[key] for key in ('name', 'feature'))

In [8]:
# Pull the sketch names and their feature rows out of the loaded dict
# (presumed aligned: sketch_name[i] labels row i of sketch_feature -- verify).
sketch_name, sketch_feature = (sketch_data[key] for key in ('name', 'feature'))

In [9]:
def euclidean_distances(x, y, squared=True):
    """Pairwise Euclidean distances between the rows of ``x`` and the rows of ``y``.

    Uses the expansion ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2, so a single
    matrix product replaces materialising every row difference.

    Parameters
    ----------
    x : array_like, shape (n_x, d) or (d,)
        Query vectors, one per row. A 1-D input is treated as a single row.
    y : array_like, shape (n_y, d) or (d,)
        Reference vectors, one per row. A 1-D input is treated as a single row.
    squared : bool, default True
        If True, return squared distances; otherwise return their square roots.
        (The flag used to be ignored and the root was always taken; the
        nearest-neighbour ranking is identical either way.)

    Returns
    -------
    ndarray, shape (n_x, n_y)
        Entry [i, j] is the (squared) distance between x[i] and y[j]. Float32
        inputs stay float32; integer inputs are promoted to a float type.
    """
    x = np.atleast_2d(np.asarray(x))
    y = np.atleast_2d(np.asarray(y))
    # In-place float ops and sqrt cannot write into integer arrays, so promote
    # ints to float while leaving float32 features at float32 to save memory.
    dtype = np.result_type(x.dtype, y.dtype, np.float32)
    x = x.astype(dtype, copy=False)
    y = y.astype(dtype, copy=False)

    x_square = np.einsum('ij,ij->i', x, x)[:, np.newaxis]  # (n_x, 1)
    y_square = np.einsum('ij,ij->i', y, y)[np.newaxis, :]  # (1, n_y)

    distances = np.dot(x, y.T)
    distances *= -2
    distances += x_square
    distances += y_square
    # Floating-point cancellation can leave tiny negative values; clamp to 0.
    np.maximum(distances, 0, out=distances)
    if not squared:
        np.sqrt(distances, out=distances)
    return distances

In [10]:
def partition_arg_topK(matrix, K, axis=0):
    """Indices of the ``K`` smallest entries of ``matrix`` along ``axis``, ascending.

    ``np.argpartition`` selects the K candidates in linear time; only those K
    candidates are then fully sorted, instead of the whole axis.

    Parameters
    ----------
    matrix : array_like
        Values to rank (e.g. a distance matrix). Any number of dimensions >= 1.
    K : int
        Number of indices to return. Values larger than ``matrix.shape[axis]``
        are clipped to that length; ``K <= 0`` gives an empty result along
        ``axis``.
    axis : int, default 0
        Axis to rank along. Negative values count from the end.

    Returns
    -------
    ndarray of intp
        Same shape as ``matrix`` except of length ``K`` along ``axis``. Position
        0 along ``axis`` holds the index of the smallest value.
    """
    matrix = np.asarray(matrix)
    k = min(K, matrix.shape[axis])
    if k <= 0:
        empty_shape = list(matrix.shape)
        empty_shape[axis] = 0
        return np.empty(empty_shape, dtype=np.intp)
    # kth = k - 1 already places the k smallest values in positions [0, k),
    # and unlike kth = k it stays in bounds when k equals the axis length.
    candidates = np.argpartition(matrix, k - 1, axis=axis)
    candidates = np.take(candidates, np.arange(k), axis=axis)
    # Sort just the k candidates by their values, then map back to indices.
    candidate_values = np.take_along_axis(matrix, candidates, axis=axis)
    order = np.argsort(candidate_values, axis=axis)
    return np.take_along_axis(candidates, order, axis=axis)

In [11]:
# Sketch-vs-photo distance matrix: one row per sketch, one column per photo.
# Only the per-row ranking of these values is used downstream.
distances = euclidean_distances(x=sketch_feature, y=photo_feature)

In [12]:
# Ground truth: for every sketch, the index into photo_name (= column of
# `distances`) of its matching photo. The match key is the basename up to the
# first '-' for a sketch and up to the first '.' for a photo. Sketches are
# expected in the same order as photo_name, so one forward pointer suffices.
gt_list = []
index = 0
n_photos = len(photo_name)
for s_name in sketch_name:
    query_name = s_name.split('/')[-1].split('-')[0]
    # Advance until the current photo matches. The original stepped at most once
    # without re-checking, so a photo with no sketches (or any misalignment)
    # silently shifted every later ground-truth index.
    while index < n_photos and photo_name[index].split('/')[-1].split('.')[0] != query_name:
        index += 1
    if index == n_photos:
        raise ValueError(
            f'No photo matches sketch {s_name!r} (key {query_name!r}); '
            'are sketch_name and photo_name ordered consistently?'
        )
    gt_list.append(index)

In [13]:
# Number of sketch queries: gt_list holds exactly one entry per sketch.
test_item = np.shape(gt_list)[0]

In [14]:
# Ground-truth photo indices as an integer array. The explicit dtype keeps it an
# index array even if the list is empty (np.asarray([]) would be float64).
gt_list = np.asarray(gt_list, dtype=np.intp)

In [15]:
# Column vector (n_sketches, 1) so it broadcasts against topK[:, 0, None].
# Deriving the length with -1 avoids depending on a possibly stale `test_item`.
gt_list = np.reshape(gt_list, (-1, 1))

In [16]:
# Integer class id per sketch; the class is the first path component of the
# sketch name. Ids are assigned in order of first appearance.
cls_list = []
cls_name = dict()
for s_name in sketch_name:
    query_class = s_name.split('/')[0]
    # setdefault reuses the id of an already-seen class, so a class whose
    # sketches are not contiguous keeps a single id (the original minted a new
    # id and overwrote cls_name in that case).
    cls_list.append(cls_name.setdefault(query_class, len(cls_name)))

In [17]:
# Class ids as an integer array. The explicit dtype keeps it an integer array
# even if the list is empty (np.asarray([]) would be float64).
cls_list = np.asarray(cls_list, dtype=np.intp)

In [18]:
# Number of sketches per class id. Iterates over the ids actually assigned in
# cls_name rather than a hard-coded range(125), so datasets with a different
# number of classes are counted correctly.
cls_count = {cls: np.sum(cls_list == cls) for cls in cls_name.values()}

In [19]:
# For every sketch (row of `distances`), indices of the 10 closest photos,
# ordered from closest to farthest.
topK = partition_arg_topK(matrix=distances, K=10, axis=1)

In [20]:
# Column of booleans: True where a sketch's top-ranked photo is its ground truth.
recall_1 = np.equal(topK[:, :1], gt_list)

In [21]:
# Class names in alphabetical order (iterating a dict yields its keys).
key_list = sorted(cls_name)

In [22]:
# Per-class recall@1: the fraction of a class's sketches whose top-ranked photo
# is their ground-truth photo. Each class's rows are selected with a boolean
# mask. The original sliced contiguous runs in alphabetical key_list order,
# which silently mixed classes whenever the sketch order differed from it.
recall_per_cls = dict()
for k in key_list:
    class_mask = np.asarray(cls_list) == cls_name[k]
    recall_per_cls[k] = np.mean(recall_1[class_mask])
    

In [23]:
# Display per-class recall@1: class name -> fraction of that class's sketches
# whose top-ranked photo is the ground-truth photo.
recall_per_cls

{'airplane': 0.6029411764705882,
 'alarm_clock': 0.5263157894736842,
 'ant': 0.46153846153846156,
 'ape': 0.75,
 'apple': 0.4576271186440678,
 'armor': 0.6229508196721312,
 'axe': 0.40384615384615385,
 'banana': 0.7017543859649122,
 'bat': 0.4918032786885246,
 'bear': 0.7164179104477612,
 'bee': 0.6037735849056604,
 'beetle': 0.6981132075471698,
 'bell': 0.375,
 'bench': 0.32075471698113206,
 'bicycle': 0.41818181818181815,
 'blimp': 0.55,
 'bread': 0.6981132075471698,
 'butterfly': 0.6909090909090909,
 'cabin': 0.5740740740740741,
 'camel': 0.7288135593220338,
 'candle': 0.48148148148148145,
 'cannon': 0.509090909090909,
 'car_(sedan)': 0.625,
 'castle': 0.5614035087719298,
 'cat': 0.4375,
 'chair': 0.7166666666666667,
 'chicken': 0.6379310344827587,
 'church': 0.6607142857142857,
 'couch': 0.4482758620689655,
 'cow': 0.59375,
 'crab': 0.7735849056603774,
 'crocodilian': 0.7272727272727273,
 'cup': 0.3793103448275862,
 'deer': 0.6785714285714286,
 'dog': 0.5454545454545454,
 'dolphin'

In [36]:
# NOTE(review): this repeats the earlier recall_per_cls display but carries a
# different execution count and different values, i.e. its output comes from
# another run/feature set. Remove the duplicate or label which run it shows, and
# re-run the notebook top to bottom before trusting either table.
recall_per_cls

{'airplane': 0.5735294117647058,
 'alarm_clock': 0.5087719298245614,
 'ant': 0.5384615384615384,
 'ape': 0.7,
 'apple': 0.4576271186440678,
 'armor': 0.6065573770491803,
 'axe': 0.4230769230769231,
 'banana': 0.7719298245614035,
 'bat': 0.5409836065573771,
 'bear': 0.6268656716417911,
 'bee': 0.6226415094339622,
 'beetle': 0.5849056603773585,
 'bell': 0.4107142857142857,
 'bench': 0.5094339622641509,
 'bicycle': 0.38181818181818183,
 'blimp': 0.55,
 'bread': 0.660377358490566,
 'butterfly': 0.6363636363636364,
 'cabin': 0.6666666666666666,
 'camel': 0.5932203389830508,
 'candle': 0.5,
 'cannon': 0.509090909090909,
 'car_(sedan)': 0.5714285714285714,
 'castle': 0.5614035087719298,
 'cat': 0.546875,
 'chair': 0.8,
 'chicken': 0.7241379310344828,
 'church': 0.6428571428571429,
 'couch': 0.4482758620689655,
 'cow': 0.75,
 'crab': 0.5849056603773585,
 'crocodilian': 0.7454545454545455,
 'cup': 0.4482758620689655,
 'deer': 0.6785714285714286,
 'dog': 0.6363636363636364,
 'dolphin': 0.671875,