In [1]:
import os
import pickle
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

In [2]:
def euclidean_distances(x, y, squared=True):
    """Pairwise Euclidean distances between the rows of ``x`` and the rows of ``y``.

    Uses the expansion ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2 so the whole
    matrix is obtained from a single matrix product.

    Parameters
    ----------
    x : array_like, shape (n, d)
    y : array_like, shape (m, d)
    squared : bool, default True
        If True, return squared distances; if False, return plain Euclidean
        distances. (The flag used to be ignored and the square root was always
        taken; rankings are identical either way since sqrt is monotone.)

    Returns
    -------
    ndarray of float, shape (n, m)
        ``out[i, j]`` is the (squared) distance between ``x[i]`` and ``y[j]``.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    # Compute in floating point (at least float32, keeping float32 inputs as
    # float32): the in-place sqrt below cannot write into an integer array.
    dtype = np.result_type(x.dtype, y.dtype, np.float32)
    x = x.astype(dtype, copy=False)
    y = y.astype(dtype, copy=False)

    x_square = np.einsum('ij,ij->i', x, x)[:, np.newaxis]  # (n, 1)
    y_square = np.einsum('ij,ij->i', y, y)[np.newaxis, :]  # (1, m)

    distances = np.dot(x, y.T)
    distances *= -2
    distances += x_square
    distances += y_square
    # Floating-point cancellation in the expansion can yield tiny negatives.
    np.maximum(distances, 0, out=distances)
    if not squared:
        np.sqrt(distances, out=distances)
    return distances

In [3]:
def partition_arg_topK(matrix, K, axis=0):
    """Indices of the ``K`` smallest entries of a 2-D array along ``axis``, sorted ascending.

    ``np.argpartition`` first isolates the K smallest candidates (linear time);
    only those K candidates are then fully sorted, which is cheaper than
    argsorting the whole axis.

    Parameters
    ----------
    matrix : array_like, 2-D
    K : int
        Number of indices to keep per row/column. Must be non-negative; values
        larger than the axis length are clamped to it.
    axis : {0, 1, -2, -1}
        0 (or -2): rank within each column, result shape (K, matrix.shape[1]).
        1 (or -1): rank within each row, result shape (matrix.shape[0], K).

    Returns
    -------
    ndarray of int
        Indices into ``matrix`` along ``axis``, ordered from smallest value upward.

    Raises
    ------
    ValueError
        If ``matrix`` is not 2-D, ``axis`` is out of range, or ``K`` is negative.
    """
    matrix = np.asarray(matrix)
    if matrix.ndim != 2:
        raise ValueError('partition_arg_topK expects a 2-D array, got ndim=%d' % matrix.ndim)
    if axis not in (0, 1, -2, -1):
        raise ValueError('axis must be 0 or 1 (or -2/-1), got %r' % (axis,))
    axis %= 2
    if K < 0:
        raise ValueError('K must be non-negative, got %r' % (K,))
    K = min(K, matrix.shape[axis])
    if K == 0:
        out_shape = list(matrix.shape)
        out_shape[axis] = 0
        return np.empty(out_shape, dtype=np.intp)

    # kth=K-1 leaves the K smallest entries in positions 0..K-1 (unordered) and,
    # unlike kth=K, is still valid when K equals the axis length.
    candidates = np.argpartition(matrix, K - 1, axis=axis)
    candidates = candidates[:K] if axis == 0 else candidates[:, :K]
    # Sort just the K candidates by their values.
    candidate_values = np.take_along_axis(matrix, candidates, axis=axis)
    order = np.argsort(candidate_values, axis=axis)
    return np.take_along_axis(candidates, order, axis=axis)

In [4]:
# Directory holding the pre-extracted feature pickles. Override with the
# SKETCHY_FEATURE_ROOT environment variable instead of editing this path.
FEATURE_ROOT = os.environ.get('SKETCHY_FEATURE_ROOT', '/data1/zzl/ICCV/features/imgnet_cls/sketchydb/')

In [5]:
# Paths of the photo (gallery) and sketch (query) feature pickles under FEATURE_ROOT.
PHOTO_FEATURE, SKETCH_FEATURE = (
    os.path.join(FEATURE_ROOT, pkl_file) for pkl_file in ('photo.pkl', 'sketch.pkl')
)

In [6]:
# Load the pre-extracted features. `with` closes each file once it is read
# (the previous bare open() calls leaked both handles).
# NOTE: pickle.load can execute arbitrary code -- only load files you produced/trust.
with open(PHOTO_FEATURE, 'rb') as photo_file:
    photo_data = pickle.load(photo_file)
with open(SKETCH_FEATURE, 'rb') as sketch_file:
    sketch_data = pickle.load(sketch_file)

In [7]:
# Each pickle is a dict with 'name' and 'feature' entries; presumably the names
# line up row-for-row with the feature matrix (later cells index both by row).
photo_name, photo_feature = (photo_data[field] for field in ('name', 'feature'))
sketch_name, sketch_feature = (sketch_data[field] for field in ('name', 'feature'))

In [8]:
# Pairwise distance matrix: row i = sketch query i, column j = photo j.
distances = euclidean_distances(x=sketch_feature, y=photo_feature)

In [9]:
# Ground truth: for every sketch, the index (into photo_name) of the photo it was
# drawn from. Sketch files look like '<stem>-<k>.png' (see printed names below);
# photo files are assumed to be '<stem>.<ext>'. The shared <stem> is the join key.
# A dict lookup does not depend on sketches and photos being listed in the same
# order and raises KeyError instead of silently misaligning labels (the previous
# sequential walk advanced the photo index without checking the match).
photo_stem_to_index = {
    p_path.split('/')[-1].split('.')[0]: p_idx for p_idx, p_path in enumerate(photo_name)
}
assert len(photo_stem_to_index) == len(photo_name), 'photo file stems are not unique'

gt_list = [
    photo_stem_to_index[s_path.split('/')[-1].split('-')[0]] for s_path in sketch_name
]

In [10]:
# Number of sketch queries being evaluated.
test_item = np.shape(gt_list)[0]

In [11]:
# Convert the ground-truth indices to an ndarray for vectorised comparison.
gt_list = np.asanyarray(gt_list)

In [12]:
# Column vector (test_item, 1) so it broadcasts against each row of topK.
gt_list = gt_list.reshape(test_item, 1)

In [13]:
# For each sketch (row), indices of the 10 closest photos, nearest first.
topK = partition_arg_topK(distances, K=10, axis=1)

In [14]:
# Per-query hit at rank 1: True when the nearest photo is the ground-truth one.
# Computed the same way as recall_5 so both are 1-D boolean arrays over queries
# (the previous form produced an (n, 1) column instead).
recall_1 = np.any(topK[:, :1] == gt_list, axis=1)

In [15]:
# Per-query hit within the top 5 retrieved photos.
recall_5 = (topK[:, :5] == gt_list).any(axis=1)

In [16]:
# Recall@5 over all sketch queries.
recall_5.mean()

0.873424890273255

In [17]:
# Indices of queries whose ground-truth photo is not in the top 5.
negative_5_index = np.nonzero(np.logical_not(recall_5))

In [18]:
# List every top-5 miss as "<query index> <sketch file>".
for query_idx in negative_5_index[0]:
    query_file = sketch_name[query_idx]
    print(query_idx, query_file)

6 airplane/n02691156_10151-7.png
14 airplane/n02691156_11286-1.png
18 airplane/n02691156_11286-5.png
19 airplane/n02691156_1512-1.png
21 airplane/n02691156_1512-3.png
22 airplane/n02691156_1512-4.png
23 airplane/n02691156_1512-5.png
25 airplane/n02691156_1512-7.png
27 airplane/n02691156_1512-9.png
28 airplane/n02691156_1692-1.png
29 airplane/n02691156_1692-2.png
30 airplane/n02691156_1692-3.png
38 airplane/n02691156_43250-5.png
70 alarm_clock/n02694662_12296-3.png
77 alarm_clock/n02694662_13497-5.png
83 alarm_clock/n02694662_14927-1.png
84 alarm_clock/n02694662_14927-2.png
85 alarm_clock/n02694662_14927-3.png
86 alarm_clock/n02694662_14927-4.png
87 alarm_clock/n02694662_14927-5.png
88 alarm_clock/n02694662_14927-6.png
94 alarm_clock/n02694662_1564-6.png
127 ant/n02219486_21712-3.png
140 ant/n02219486_25623-1.png
141 ant/n02219486_25623-2.png
145 ant/n02219486_26238-1.png
158 ant/n02219486_28856-4.png
161 ant/n02219486_28996-1.png
192 ape/n02480495_788-3.png
205 ape/n02481823_6275-4.png

In [15]:
# Recall@1 over all sketch queries.
recall_1.mean()

0.5385813393742036

In [16]:
# Indices of queries whose nearest photo is not the ground truth.
negative_index = np.nonzero(np.logical_not(recall_1))

In [17]:
# List the sketch file of every rank-1 miss.
for missed_file in (sketch_name[row] for row in negative_index[0]):
    print(missed_file)

airplane/n02691156_10151-1.png
airplane/n02691156_10151-2.png
airplane/n02691156_10151-7.png
airplane/n02691156_10391-2.png
airplane/n02691156_10391-5.png
airplane/n02691156_10391-6.png
airplane/n02691156_11286-1.png
airplane/n02691156_11286-5.png
airplane/n02691156_1512-1.png
airplane/n02691156_1512-2.png
airplane/n02691156_1512-3.png
airplane/n02691156_1512-4.png
airplane/n02691156_1512-5.png
airplane/n02691156_1512-7.png
airplane/n02691156_1512-8.png
airplane/n02691156_1512-9.png
airplane/n02691156_1692-1.png
airplane/n02691156_1692-2.png
airplane/n02691156_1692-3.png
airplane/n02691156_1692-4.png
airplane/n02691156_1692-5.png
airplane/n02691156_1692-6.png
airplane/n02691156_43250-4.png
airplane/n02691156_43250-5.png
airplane/n02691156_43250-8.png
airplane/n02691156_47926-1.png
airplane/n02691156_47926-3.png
airplane/n02691156_47926-6.png
airplane/n02691156_47926-8.png
airplane/n02691156_5740-3.png
airplane/n02691156_7639-1.png
airplane/n02691156_7639-3.png
airplane/n02691156_7639-4

In [18]:
# Retrieved photo indices (nearest first) for the first rank-1 miss.
first_miss = negative_index[0][0]
topK[first_miss]

array([  8, 523,   9,   7,   0,   4, 941, 527,   1,   2])