In [20]:
import argparse
import os
import numpy as np
import pandas as pd

from package.optim.eanet_trainer import EANetTrainer
from package.eval.np_distance import compute_dist, compute_dist_with_visibility
from package.utils.file import walkdir
from package.utils.rank_list import get_rank_list, save_rank_list_to_im
from package.data.datasets.market1501 import Market1501

import seaborn as sns
import matplotlib.pyplot as plt
from package.utils.image import read_im, save_im, make_im_grid
% matplotlib inline

## 配置信息

In [21]:
def parse_args():
    """Build the visualization/inference options and return them as a namespace.

    Parsing goes through ``parse_known_args``, so command-line entries that do
    not match any option below are ignored instead of raising an error.

    Returns:
        argparse.Namespace: one attribute per option declared below.
    """
    truthy = ("yes", "true", "t", "1")

    def str2bool(v):
        """Return True when ``v`` (case-insensitive) is one of the accepted truthy spellings."""
        return v.lower() in truthy

    # (flag, add_argument keyword arguments), registered in this order.
    options = [
        ('--q_im_dir', dict(
            type=str,
            default='./dataset/market1501/Market-1501-v15.09.15/query',
            help='Directory of query images.')),
        ('--g_im_dir', dict(
            type=str,
            default='./dataset/market1501/Market-1501-v15.09.15/bounding_box_test',
            help='Directory of gallery images.')),
        ('--num_queries', dict(
            type=int, default=16,
            help='How many query images to visualize.')),
        ('--rank_list_size', dict(
            type=int, default=10,
            help='How many top gallery images to visualize for each query.')),
        ('--save_dir', dict(
            type=str,
            default='./exp/vis_rank_list_PCB_M_to_M_id_aware/result',
            help='Where to save visualization result.')),
        ('--id_aware', dict(
            type=str2bool, default=False,
            help='Whether id and camera info are known for your images. If known, each gallery image in '
                 'rank list will have either green or red boundary, denoting same or different id as query.')),
        ('--pap_mask_provided', dict(
            type=str2bool, default=False,
            help='Do you provide pap_mask or not. If not, all PAP based models are not available '
                 'for inference, and you can only use GlobalPool and PCB.')),
        ('--exp_dir', dict(
            type=str,
            default='./exp/vis_rank_list_PCB_M_to_M_id_aware')),
        ('--cfg_file', dict(
            type=str,
            default='./package/config/default.py')),
        ('--ow_file', dict(
            type=str,
            default='./paper_configs/PCB.txt')),
        ('--ow_str', dict(
            type=str, default="cfg.only_infer = True")),
    ]

    parser = argparse.ArgumentParser()
    for flag, spec in options:
        parser.add_argument(flag, **spec)

    # Only the recognised options are kept; leftovers are discarded.
    known, _unused = parser.parse_known_args()
    return known

## 工具函数

In [22]:
def normalize(nparray, order=2, axis=0):
    """Scale an N-D numpy array to unit norm along ``axis``.

    Args:
        nparray: array to scale.
        order: norm order forwarded to ``np.linalg.norm`` as ``ord``.
        axis: axis along which the norm is taken.

    Returns:
        ``nparray`` divided by its norm (kept broadcastable via ``keepdims``).
    """
    # float32 machine epsilon keeps the division finite for all-zero slices.
    eps = np.finfo(np.float32).eps
    scale = eps + np.linalg.norm(nparray, ord=order, axis=axis, keepdims=True)
    return nparray / scale

def get_pap_mask(im_path):
    """Look up the PAP mask for ``im_path`` through the notebook-level ``market1501_loader``.

    The first two '/'-separated components of ``im_path`` are dropped and the
    remainder is passed to ``market1501_loader.dataset.get_pap_mask``.
    """
    path_tail = im_path.split('/')[2:]
    return market1501_loader.dataset.get_pap_mask('/'.join(path_tail))

def get_img_list(im_path):
    """Return the sorted list of '.jpg' / '.png' paths that ``walkdir`` yields for ``im_path``.

    Sorting is plain string ordering, so e.g. '101.png' comes before '41.png'.
    """
    image_paths = list(walkdir(im_path, exts=['.jpg', '.png']))
    image_paths.sort()
    return image_paths

def get_file_name(im_path):
    """Return the bare file names of a list of image paths as a numpy array.

    For each path, the last '/'-separated component is taken and cut at its
    first '.', e.g. 'test_distance/zhuoshi/101.png' -> '101'.

    Args:
        im_path: iterable of path strings.

    Returns:
        1-D numpy array of name strings, in the same order as ``im_path``.
    """
    # A list comprehension (not map) so numpy gets a real sequence: on Python 3
    # ``np.array(map(...))`` would wrap the lazy map object in a 0-d object array.
    return np.array([p.split('/')[-1].split('.')[0] for p in im_path])

## 加载参数和模型

In [23]:
# Load the run configuration (argparse defaults; unrecognised arguments are ignored).
args = parse_args()
# Instantiate the model/trainer from that configuration.
eanet_trainer = EANetTrainer(args=args)

# This test uses the uniform split: the image is divided top-to-bottom into 6 equal
# parts, i.e. the PCB scheme. The mask-based variant is not used because the mask
# has to be provided from human body keypoints.
market1501_loader = eanet_trainer.create_dataloader('test', 'market1501', 'query')
parse_im_path = Market1501.parse_im_path

# Directory holding the test images.
test_img_zhuoshi_dir = 'test_distance/zhuoshi'

In [24]:
# Gallery image list (sorted paths of every .jpg/.png that walkdir finds).
g_im_paths = get_img_list(test_img_zhuoshi_dir)

# Labels: the bare file name of each gallery image (no directory, no extension).
# Built with a list comprehension so the result is a proper 1-D array on both
# Python 2 and Python 3 (np.array(map(...)) yields a 0-d object array on Python 3).
columns = np.array([p.split('/')[-1].split('.')[0] for p in g_im_paths])

# Show the labels and the first few paths.
columns, g_im_paths[:10]

(array(['101', '102', '103', '104', '105', '106', '107', '108', '41', '42',
        '43', '44', '45', '46', '47', '48', '51', '52', '53', '54', '55',
        '56', '57', '58', '61', '62', '63', '64', '65', '66', '67', '68',
        '71', '72', '73', '74', '75', '76', '77', '78', '81', '82', '83',
        '84', '85', '86', '87', '88', '91', '92', '93', '94', '95', '96',
        '97', '98', '1', '2', '107', '41', '51', '61', '77', '88', '92',
        '107', '41', '51', '61', '77', '84', '92'], dtype='|S3'),
 ['test_distance/zhuoshi/101.png',
  'test_distance/zhuoshi/102.png',
  'test_distance/zhuoshi/103.png',
  'test_distance/zhuoshi/104.png',
  'test_distance/zhuoshi/105.png',
  'test_distance/zhuoshi/106.png',
  'test_distance/zhuoshi/107.png',
  'test_distance/zhuoshi/108.png',
  'test_distance/zhuoshi/41.png',
  'test_distance/zhuoshi/42.png'])

## 计算每一个图片的feature map

In [25]:
# Run inference on every gallery image; the PAP-mask getter is only passed
# when args.pap_mask_provided is set.
model_feat = eanet_trainer.infer_im_list(g_im_paths,
                                             get_pap_mask=get_pap_mask if args.pap_mask_provided else None)

# The 'feat' entry holds the features, one 1536-wide row per image.
# NOTE(review): the original comment gave the size as (56, 1536), but the
# recorded output of model_feat['feat'].shape further down is (72, 1536).
model_feat['feat']

array([[0.07344911, 0.03218788, 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.0555134 , 0.02440583, 0.00219755, ..., 0.06656636, 0.        ,
        0.00034353],
       [0.09700681, 0.        , 0.        , ..., 0.        , 0.00641874,
        0.        ],
       ...,
       [0.00850786, 0.        , 0.        , ..., 0.10077368, 0.02864657,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.01660722,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.17458431, 0.0151435 ,
        0.00154521]], dtype=float32)

## 计算距离矩阵

#### 余弦距离

In [26]:
# Pairwise distances between all gallery features: a square
# (num_gallery, num_gallery) matrix (the original note said (56, 56)).
# NOTE(review): presumably cosine distance, per the section heading and
# compute_dist's default dist_type — confirm.
dist_mat = compute_dist(model_feat['feat'], model_feat['feat'])

# Label rows and columns with the gallery file names for display.
result = pd.DataFrame(dist_mat, index=columns, columns=columns)
result

Unnamed: 0,101,102,103,104,105,106,107,108,41,42,...,77,88,92,107.1,41.1,51,61,77.1,84,92.1
101,1.788139e-07,3.240373e-01,2.286093e-01,2.448016e-01,2.716789e-01,2.480321e-01,3.014180e-01,2.866942e-01,4.800600e-01,3.938867e-01,...,0.502270,5.122590e-01,5.010944e-01,3.014180e-01,4.800600e-01,5.544462e-01,4.869937e-01,4.430389e-01,3.005176e-01,5.221819e-01
102,3.240373e-01,1.192093e-07,3.827962e-01,3.532566e-01,3.965588e-01,3.437737e-01,3.626719e-01,3.681582e-01,3.940219e-01,3.367617e-01,...,0.532582,5.286825e-01,5.130423e-01,3.626719e-01,3.940219e-01,5.486875e-01,5.081195e-01,4.067759e-01,3.427886e-01,4.813970e-01
103,2.286093e-01,3.827962e-01,4.768372e-07,2.378076e-01,2.355905e-01,2.216783e-01,2.964673e-01,3.120916e-01,4.985716e-01,4.079981e-01,...,0.509665,5.185061e-01,4.862748e-01,2.964673e-01,4.985716e-01,5.241501e-01,4.864488e-01,4.384277e-01,3.006490e-01,5.219625e-01
104,2.448016e-01,3.532566e-01,2.378076e-01,3.576279e-07,2.839249e-01,2.319751e-01,3.057888e-01,2.663612e-01,5.557510e-01,4.164661e-01,...,0.482440,4.852901e-01,4.954498e-01,3.057888e-01,5.557510e-01,5.616496e-01,5.539894e-01,4.357958e-01,3.407988e-01,5.265098e-01
105,2.716789e-01,3.965588e-01,2.355905e-01,2.839249e-01,2.384186e-07,1.495442e-01,1.718857e-01,1.851532e-01,4.999082e-01,4.255439e-01,...,0.516587,4.995295e-01,5.070130e-01,1.718857e-01,4.999082e-01,5.636470e-01,5.298257e-01,3.876382e-01,2.822745e-01,5.128813e-01
106,2.480321e-01,3.437737e-01,2.216783e-01,2.319751e-01,1.495442e-01,2.384186e-07,1.944891e-01,1.863247e-01,4.860609e-01,4.204757e-01,...,0.470961,4.838576e-01,4.596638e-01,1.944891e-01,4.860609e-01,5.512872e-01,4.731127e-01,3.624485e-01,3.041402e-01,4.915258e-01
107,3.014180e-01,3.626719e-01,2.964673e-01,3.057888e-01,1.718857e-01,1.944891e-01,1.192093e-07,1.595769e-01,4.497774e-01,3.701516e-01,...,0.499716,5.002929e-01,4.844366e-01,1.192093e-07,4.497774e-01,5.755353e-01,5.959347e-01,3.759044e-01,2.491937e-01,4.783204e-01
108,2.866942e-01,3.681582e-01,3.120916e-01,2.663612e-01,1.851532e-01,1.863247e-01,1.595769e-01,3.576279e-07,4.839672e-01,4.113102e-01,...,0.467427,4.741701e-01,4.597092e-01,1.595769e-01,4.839672e-01,5.724140e-01,5.619332e-01,3.372169e-01,3.201514e-01,5.113811e-01
41,4.800600e-01,3.940219e-01,4.985716e-01,5.557510e-01,4.999082e-01,4.860609e-01,4.497774e-01,4.839672e-01,1.788139e-07,2.364717e-01,...,0.565469,5.932399e-01,5.786829e-01,4.497774e-01,1.788139e-07,4.908005e-01,5.181277e-01,4.406723e-01,3.501391e-01,4.530488e-01
42,3.938867e-01,3.367617e-01,4.079981e-01,4.164661e-01,4.255439e-01,4.204757e-01,3.701516e-01,4.113102e-01,2.364717e-01,5.960464e-08,...,0.546044,5.416141e-01,5.573249e-01,3.701516e-01,2.364717e-01,4.224590e-01,4.566120e-01,4.446062e-01,3.634145e-01,4.892944e-01


In [27]:
# f, ax = plt.subplots(figsize = (13,13),nrows=1)

# sns.heatmap(result, linewidths = 0.08, ax = ax, vmax=0.8, vmin=0, cmap='rainbow') 


![](test_distance/zhuoshi/mat/1.png)

#### 欧氏距离

In [28]:
# dist_mat_euclidean = compute_dist(model_feat['feat'], model_feat['feat'], dist_type='euclidean')
# result_euclidean = pd.DataFrame(dist_mat_euclidean, index=columns, columns=columns)

# f, ax = plt.subplots(figsize = (13,13),nrows=1)
# sns.heatmap(result_euclidean, linewidths = 0.08, ax = ax, vmax=0.8, vmin=0, cmap='rainbow')


![](test_distance/zhuoshi/mat/2.png)

## 下面提取图片,查看效果

In [29]:
# Directory holding the query images.
query_img_zhuoshi_dir = 'test_distance/zhuoshi/query'
q_im_paths = get_img_list(query_img_zhuoshi_dir)
# Same inference call as for the gallery; the PAP-mask getter is only passed
# when args.pap_mask_provided is set.
query_feat = eanet_trainer.infer_im_list(q_im_paths,
                                             get_pap_mask=get_pap_mask if args.pap_mask_provided else None)
query_feat['feat'].shape

(7, 1536)

In [30]:
# Gallery feature matrix shape, for comparison with the query features above.
model_feat['feat'].shape

(72, 1536)

In [31]:
def get_file_name(path):
    """Return the bare file names of a list of image paths as a numpy array.

    NOTE: this redefines the ``get_file_name`` helper declared in the utility
    cell earlier in the notebook; the later definition is the one in effect.

    For each entry, the last '/'-separated component is taken and cut at its
    first '.', e.g. 'test_distance/zhuoshi/query/41.png' -> '41'.

    Args:
        path: iterable of path strings.

    Returns:
        1-D numpy array of name strings, in the same order as ``path``.
    """
    # A list comprehension (not map) so numpy gets a real sequence: on Python 3
    # ``np.array(map(...))`` would wrap the lazy map object in a 0-d object array.
    return np.array([p.split('/')[-1].split('.')[0] for p in path])

In [32]:
# Distances between every query feature and every gallery feature:
# a (num_queries, num_gallery) matrix.
# NOTE(review): the original note said (7, 56); the shapes recorded above are
# (7, 1536) and (72, 1536), which would give (7, 72).
dist_mat = compute_dist(query_feat['feat'], model_feat['feat'])

# Rows are labelled by query file name, columns by gallery file name.
result = pd.DataFrame(dist_mat, index=get_file_name(q_im_paths), columns=get_file_name(g_im_paths))
result


Unnamed: 0,101,102,103,104,105,106,107,108,41,42,...,77,88,92,107.1,41.1,51,61,77.1,84,92.1
107,0.301418,0.362672,0.296467,0.305789,0.171886,0.194489,1.192093e-07,0.159577,0.4497774,0.370152,...,0.499716,0.500293,0.484437,1.192093e-07,0.4497774,0.5755353,0.5959347,0.3759044,0.2491937,0.4783204
41,0.48006,0.394022,0.498572,0.555751,0.499908,0.486061,0.4497774,0.483967,1.788139e-07,0.236472,...,0.565469,0.59324,0.578683,0.4497774,1.788139e-07,0.4908005,0.5181277,0.4406723,0.3501391,0.4530488
51,0.554446,0.548687,0.52415,0.56165,0.563647,0.551287,0.5755353,0.572414,0.4908005,0.422459,...,0.629344,0.600649,0.631327,0.5755353,0.4908005,1.192093e-07,0.4650697,0.5383426,0.5218973,0.5692091
61,0.486994,0.508119,0.486449,0.553989,0.529826,0.473113,0.5959347,0.561933,0.5181277,0.456612,...,0.564403,0.552752,0.562376,0.5959347,0.5181277,0.4650697,4.172325e-07,0.5109685,0.5911011,0.5975151
77,0.443039,0.406776,0.438428,0.435796,0.387638,0.362449,0.3759044,0.337217,0.4406723,0.444606,...,0.404149,0.465292,0.392964,0.3759044,0.4406723,0.5383426,0.5109685,2.384186e-07,0.392673,0.4496529
84,0.300518,0.342789,0.300649,0.340799,0.282275,0.30414,0.2491937,0.320151,0.3501391,0.363414,...,0.508467,0.512509,0.494909,0.2491937,0.3501391,0.5218973,0.5911011,0.392673,2.384186e-07,0.4482725
92,0.522182,0.481397,0.521962,0.52651,0.512881,0.491526,0.4783204,0.511381,0.4530488,0.489294,...,0.515638,0.523984,0.509392,0.4783204,0.4530488,0.5692091,0.5975151,0.4496529,0.4482725,4.768372e-07


In [33]:
# Rebind `result` to the same distance matrix, now labelled with full image
# paths (rows: queries, columns: gallery) so that later cells can read the
# paths straight from the row name and column index.
result = pd.DataFrame(dist_mat, index=q_im_paths, columns=g_im_paths)

In [34]:
def get_rank_list(dist_vec, label, pic_num = 7, im_paths=None):
    """Return the closest gallery images for one query, nearest first.

    NOTE: this definition shadows the ``get_rank_list`` imported from
    ``package.utils.rank_list`` at the top of the notebook.

    Two images are treated as the same identity when their file names, with
    the extension and the final character removed, are equal (e.g. '41' and
    '48' both reduce to '4').

    Args:
        dist_vec: 1-D sequence of distances from the query to each gallery
            image, aligned with ``im_paths``.
        label: path (or bare file name) of the query image.
        pic_num: maximum number of gallery images to return.
        im_paths: gallery image paths; defaults to the notebook-level
            ``g_im_paths`` when omitted.

    Returns:
        (rank_list, same_id): the selected gallery paths in ascending distance
        order, and a parallel list of booleans that are True where the gallery
        image has the same identity as the query.
    """
    if im_paths is None:
        im_paths = g_im_paths

    def _identity(path):
        # File name without directory, extension and last character.
        return path.split('/')[-1].split('.')[0][:-1]

    query_identity = _identity(label)
    rank_list = []
    same_id = []
    # The sorted indices select the gallery entries; the original zipped the
    # indices against the unsorted path list, so nothing was actually ranked.
    for ind in np.argsort(dist_vec)[:max(pic_num, 0)]:
        gallery_path = im_paths[ind]
        rank_list.append(gallery_path)
        same_id.append(_identity(gallery_path) == query_identity)
    return rank_list, same_id

In [35]:
# Number of gallery images kept per query.
top_k = 7

rank_list = []
same_id = []

# One pass per query row (previously hard-coded to 7 rows).
for i in range(len(result)):
    # Sort the gallery by distance and keep positions 1..top_k; position 0 (the
    # closest entry) is skipped.
    # NOTE(review): this presumably assumes the closest entry is the query image
    # itself, yet the recorded output still lists 'query/41.png' for one query,
    # so other copies of a query image can remain in the list — confirm.
    dist_vect = result.iloc[i].sort_values().iloc[1:top_k + 1]
    # Row label: the query image path.
    dist_vect_name = dist_vect.name

    # Column labels: the selected gallery image paths, nearest first.
    sort_inds = list(dist_vect.index)
    # Identity = file name minus its last 5 characters (a 4-character extension
    # such as '.png' plus the final character of the name).
    query_identity = dist_vect_name.split('/')[-1][:-5]
    # Iterate over what was actually selected, so fewer than top_k candidates
    # no longer raises IndexError.
    temp = [query_identity == gallery_path.split('/')[-1][:-5] for gallery_path in sort_inds]
    rank_list.append(sort_inds)
    same_id.append(temp)
rank_list, same_id

([['test_distance/zhuoshi/107.png',
   'test_distance/zhuoshi/108.png',
   'test_distance/zhuoshi/105.png',
   'test_distance/zhuoshi/106.png',
   'test_distance/zhuoshi/75.png',
   'test_distance/zhuoshi/46.png',
   'test_distance/zhuoshi/76.png'],
  ['test_distance/zhuoshi/query/41.png',
   'test_distance/zhuoshi/42.png',
   'test_distance/zhuoshi/48.png',
   'test_distance/zhuoshi/44.png',
   'test_distance/zhuoshi/47.png',
   'test_distance/zhuoshi/45.png',
   'test_distance/zhuoshi/82.png'],
  ['test_distance/zhuoshi/51.png',
   'test_distance/zhuoshi/52.png',
   'test_distance/zhuoshi/57.png',
   'test_distance/zhuoshi/53.png',
   'test_distance/zhuoshi/54.png',
   'test_distance/zhuoshi/58.png',
   'test_distance/zhuoshi/56.png'],
  ['test_distance/zhuoshi/61.png',
   'test_distance/zhuoshi/63.png',
   'test_distance/zhuoshi/62.png',
   'test_distance/zhuoshi/64.png',
   'test_distance/zhuoshi/66.png',
   'test_distance/zhuoshi/67.png',
   'test_distance/zhuoshi/68.png'],
  ['te

In [36]:
# Directory where the rank-list images are written (overrides the --save_dir default).
args.save_dir = './test_distance/zhuoshi/misc'

In [37]:
def add_border(im, border_width, value):
    """Return a copy of a channel-first image with a solid border painted on all four sides.

    Args:
        im: numpy array of shape [3, H, W]; it is not modified.
        border_width: border thickness in pixels. A value <= 0 leaves the
            image unchanged.
        value: border color, either a scalar or a numpy array with one entry
            per channel.

    Returns:
        A new array of the same shape as ``im`` with the border applied.
    """
    assert (im.ndim == 3) and (im.shape[0] == 3)
    im = np.copy(im)

    # Guard: with a width of 0 the slice ``-0:`` equals ``0:`` and would paint
    # the whole image instead of nothing.
    if border_width <= 0:
        return im

    if isinstance(value, np.ndarray):
        # reshape to [3, 1, 1] so the color broadcasts over H and W
        value = value.flatten()[:, np.newaxis, np.newaxis]
    im[:, :border_width, :] = value   # top
    im[:, -border_width:, :] = value  # bottom
    im[:, :, :border_width] = value   # left
    im[:, :, -border_width:] = value  # right

    return im

def save_rank_list_to_im(rank_list, q_im_path, save_path, same_id=None, resize_h_w=(128, 64)):
    """Write one image showing a query followed by its ranked gallery images.

    NOTE: this definition shadows the ``save_rank_list_to_im`` imported from
    ``package.utils.rank_list`` at the top of the notebook.

    Args:
        rank_list: gallery image paths, in the order they should appear.
        q_im_path: path of the query image; its last '/'-component is reused
            as the output file name.
        save_path: directory the output image is written into.
        same_id: optional sequence parallel to ``rank_list``; a truthy entry
            gives that gallery image a green border, a falsy one a red border.
            When None, no borders are drawn.
        resize_h_w: size passed to ``read_im`` for every image.
    """
    def load(path):
        # Every tile is read with identical settings.
        return read_im(path, convert_rgb=True, resize_h_w=resize_h_w, transpose=True)

    tiles = [load(q_im_path)]
    for pos, gallery_path in enumerate(rank_list):
        tile = load(gallery_path)
        if same_id is not None:
            # Green border marks a true positive, red a false positive.
            if same_id[pos]:
                border_color = np.array([0, 255, 0])
            else:
                border_color = np.array([255, 0, 0])
            tile = add_border(tile, 3, border_color)
        tiles.append(tile)

    # One row holding the query plus every ranked image.
    # NOTE(review): the trailing 8 and 255 are presumably padding width and pad value — confirm.
    grid = make_im_grid(tiles, 1, len(rank_list) + 1, 8, 255)
    out_file = os.path.join(save_path, q_im_path.split('/')[-1] )
    save_im(grid, out_file, transpose=True)

In [38]:
# Write one rank-list image per query into args.save_dir; each query is paired
# with its ranked gallery paths and the matching same-identity flags.
for q_im_path, ranks, ids  in zip(q_im_paths, rank_list, same_id):
    save_rank_list_to_im(ranks, q_im_path, args.save_dir, same_id=ids)

![](test_distance/zhuoshi/misc/41.png)
![](test_distance/zhuoshi/misc/51.png)
![](test_distance/zhuoshi/misc/61.png)
![](test_distance/zhuoshi/misc/77.png)
![](test_distance/zhuoshi/misc/84.png)
![](test_distance/zhuoshi/misc/92.png)
![](test_distance/zhuoshi/misc/107.png)