In [1]:
# Expose only physical GPU 1 to this kernel. This must run before any CUDA work so
# torch sees that GPU as its only device.
%env CUDA_VISIBLE_DEVICES 1

env: CUDA_VISIBLE_DEVICES=1


In [2]:
import torch
import torchvision
from torchvision import transforms as T
import numpy as np
from tqdm import tqdm_notebook
from PIL import Image
from glob import glob
import shutil
from sklearn.cluster import KMeans

In [3]:
from model import build_model
from dataset import make_dataloader
from config import cfg

In [4]:
# Merge the experiment settings into the global `cfg`. Later cells pass cfg to
# make_dataloader/build_model and read cfg.INPUT.* for test-time preprocessing.
cfg.merge_from_file('configs/cvwc1.yml')

In [5]:
# Only num_classes is kept; the loaders are discarded. It is passed to build_model below,
# presumably so the rebuilt network matches the checkpoint's classifier shape (TODO confirm).
_, _, _, num_classes = make_dataloader(cfg)

=> Data loaded
Dataset statistics:
  ----------------------------------------
  subset   | # ids | # images | # cameras
  ----------------------------------------
  train    |   107 |     1887 |         1
  query    |    20 |       40 |         1
  gallery  |    20 |      438 |         1
  ----------------------------------------


In [6]:
# Rebuild the network and restore the epoch-45 weights for inference.
# map_location='cpu' makes deserialization independent of the GPU the checkpoint was
# saved on. Without it, tensors saved on e.g. cuda:1 fail to load when only one device
# is visible. The model is moved to the current GPU explicitly afterwards.
model = build_model(cfg, num_classes)
para_dict = torch.load('outputs/cvwc_local/resnet50_epoch45.pth', map_location='cpu')
model.load_state_dict(para_dict)
model.cuda()
model.eval();

In [7]:
# Test-time preprocessing pipeline: resize to the configured test size, convert the PIL
# image to a tensor, then normalise with the configured per-channel mean/std.
trm = T.Compose(
    [
        T.Resize(cfg.INPUT.SIZE_TEST),
        T.ToTensor(),
        T.Normalize(cfg.INPUT.PIXEL_MEAN, cfg.INPUT.PIXEL_STD),
    ]
)

In [8]:
# Every detected test crop serves as both query and gallery (all-vs-all retrieval).
# The pattern is defined once so the two lists cannot silently diverge.
TEST_BB_GLOB = '/home/zbc/data/cvwc/detection_test/test_bb/*.jpg'
gallery_list = sorted(glob(TEST_BB_GLOB))
query_list = list(gallery_list)  # same files, independent list object
# Fail here with a clear message rather than later inside feature extraction.
assert gallery_list, 'no test crops found matching {}'.format(TEST_BB_GLOB)
img_list = query_list
gallery_np = np.array(gallery_list)
query_np = np.array(query_list)
len(query_list), len(gallery_list)

(4334, 4334)

In [9]:
# Embed every crop with the loaded model, one image at a time. The result `feats`
# is a CPU tensor with one row per entry of img_list, in the same order.
feats = []
with torch.no_grad():  # inference only: don't build an autograd graph per image
    for img_path in tqdm_notebook(img_list):
        # The context manager closes the file handle. convert('RGB') guarantees
        # 3 channels, so the 3-value PIXEL_MEAN/PIXEL_STD normalisation also works
        # for grayscale or RGBA crops.
        with Image.open(img_path) as im:
            batch = trm(im.convert('RGB')).unsqueeze(0).cuda()
        feats.append(model(batch).cpu())
feats = torch.cat(feats)

HBox(children=(IntProgress(value=0, max=4334), HTML(value='')))




In [10]:
# Query and gallery are the same image set in this notebook, so both names
# refer to the one feature tensor.
query_feat = gallery_feat = feats

In [11]:
# Sanity check: one feature row per crop. Query and gallery are the same tensor here.
query_feat.shape, gallery_feat.shape

(torch.Size([4334, 2048]), torch.Size([4334, 2048]))

In [12]:
from evaluate import euclidean_dist
from evaluate import re_rank

In [13]:
# Query x gallery distance matrix after re-ranking (evaluate.re_rank). The plain,
# un-re-ranked alternative is euclidean_dist(query_feat, gallery_feat).
distmat = re_rank(query_feat, gallery_feat)
# Row i holds gallery indices for query i, nearest first.
ind = distmat.argsort(axis=1)

In [14]:
# Ten smallest distances for query 0. The first (~0) is the query matched with itself.
distmat[0][ind[0]][:10]

array([3.9988452e-17, 2.1281974e-02, 2.4117138e-02, 3.0147396e-02,
       3.1368878e-02, 3.4097627e-02, 4.3901734e-02, 4.6741121e-02,
       5.3687580e-02, 5.5294842e-02], dtype=float32)

In [15]:
# File paths of query 0's ten nearest gallery entries. Rank 0 is the query image itself.
gallery_np[ind[0]][:10]

array(['/home/zbc/data/cvwc/detection_test/test_bb/0001_b00_685_378_643_356.jpg',
       '/home/zbc/data/cvwc/detection_test/test_bb/1679_b00_708_366_658_360.jpg',
       '/home/zbc/data/cvwc/detection_test/test_bb/4248_b00_667_336_653_348.jpg',
       '/home/zbc/data/cvwc/detection_test/test_bb/2218_b00_751_315_677_353.jpg',
       '/home/zbc/data/cvwc/detection_test/test_bb/2639_b00_640_316_675_347.jpg',
       '/home/zbc/data/cvwc/detection_test/test_bb/1281_b00_657_293_622_341.jpg',
       '/home/zbc/data/cvwc/detection_test/test_bb/3217_b00_730_324_710_382.jpg',
       '/home/zbc/data/cvwc/detection_test/test_bb/2149_b00_652_338_573_332.jpg',
       '/home/zbc/data/cvwc/detection_test/test_bb/3093_b00_521_334_701_363.jpg',
       '/home/zbc/data/cvwc/detection_test/test_bb/4265_b00_630_292_622_337.jpg'],
      dtype='<U72')

In [16]:
import ipywidgets as widgets
from ipywidgets import interact, interact_manual

In [17]:
from utils.vistools import read_im, make_im_grid, save_im

In [18]:
def show(q_id=0, show_num=10):
    """Return an image grid of query ``q_id`` followed by its top ``show_num`` matches.

    Gallery rank 0 is skipped because query and gallery are the same image list in
    this notebook, so the nearest hit is the query itself.

    Args:
        q_id: row index into ``query_np`` / ``ind``.
        show_num: number of ranked gallery images to show after the query.

    Returns:
        A PIL image whose pixels are already loaded into memory.
    """
    import math
    import os
    import tempfile

    ims = [read_im(query_np[q_id])]
    ims.extend(read_im(p) for p in gallery_np[ind[q_id]][1:show_num + 1])
    # Use the smallest square grid that fits every image. int(sqrt(n)) + 1 would add
    # an empty row and column whenever n is a perfect square.
    side = math.ceil(math.sqrt(len(ims)))
    grid = make_im_grid(ims, side, side, 4, 255)
    # Round-trip through a private temporary directory instead of a fixed
    # './temp.png', which could collide or clobber files. save_im writes
    # make_im_grid's output to disk. Image.open is lazy, so the pixels are copied
    # into memory before the directory is removed.
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = os.path.join(tmp_dir, 'temp.png')
        save_im(grid, tmp_path)
        with Image.open(tmp_path) as img:
            return img.copy()

In [19]:
@interact
def inter_show(q_id=range(len(query_np)), show_num=range(1, 101)):
    """Interactive viewer over every query index and 1-100 displayed matches.

    The q_id range is derived from query_np rather than hard-coded, so every query
    is reachable.
    """
    return show(q_id, show_num)

interactive(children=(Dropdown(description='q_id', options=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, …

In [20]:
import json
import ntpath
import re

In [21]:
# Parses a crop filename such as '0001_b00_685_378_643_356' into six numeric fields:
# the image id, the 'bNN' field (presumably the box index within that image), then the
# four values the bbox table unpacks as x, y, w, h.
pattern = re.compile(r'([\d]+)_b([\d]+)_([\d]+)_([\d]+)_([\d]+)_([\d]+)')

In [22]:
# Key format used for the bbox id lookup: the basename minus its 4-character '.jpg' suffix.
ntpath.basename(img_list[0])[:-4]

'0001_b00_685_378_643_356'

In [23]:
# Check the filename regex on one path: (image id, 'bNN' field, x, y, w, h) as strings.
pattern.search(img_list[0]).groups()

('0001', '00', '685', '378', '643', '356')

In [24]:
# Build the submission's box table (bbox_id = position in img_list) and a
# basename -> bbox_id lookup used to turn ranked gallery paths back into ids.
bboxs = []
bbox_2_id = {}
for bb_idx, bbx in enumerate(img_list):
    stem = ntpath.basename(bbx)[:-4]  # filename without the '.jpg' suffix
    # Match on the basename only, so a directory name can never be parsed by accident,
    # and fail with the offending path instead of an opaque AttributeError on None.
    match = pattern.search(stem)
    if match is None:
        raise ValueError('unexpected crop filename: %r' % bbx)
    img_id, _, x, y, w, h = map(int, match.groups())
    bboxs.append({
        'bbox_id': bb_idx,
        'image_id': img_id,
        'pos': [x, y, w, h],
    })
    bbox_2_id[stem] = bb_idx

In [25]:
# Submission skeleton: the box table plus a reid_result list that is filled further down.
final = dict(bboxs=bboxs, reid_result=[])

In [26]:
# Spot-check on query 0: map its ranked gallery paths (rank-0 self-match skipped)
# back to bbox ids.
ans_list = gallery_np[ind[0]][1:]
ans_ids = [bbox_2_id[ntpath.basename(ans)[:-4]] for ans in ans_list]

In [27]:
# Spot-check the id mapping: bbox id 1676 should be the '1679_b00_...' crop that ranked
# right after the self-match for query 0 in the path listing above.
img_list[1676]

'/home/zbc/data/cvwc/detection_test/test_bb/1679_b00_708_366_658_360.jpg'

In [28]:
# Fill the submission's reid_result: for every query, the ids of all other boxes,
# nearest first. The list is rebuilt and assigned rather than appended to, so
# re-running this cell cannot duplicate entries.
reid_result = []
for idx, im in enumerate(img_list):
    query_id = bbox_2_id[ntpath.basename(im)[:-4]]
    ranked_ids = (bbox_2_id[ntpath.basename(g)[:-4]] for g in gallery_np[ind[idx]])
    # Drop the query itself by id rather than assuming it sits at rank 0. A tie at
    # distance ~0 (e.g. a duplicate crop) would otherwise leak the query into its own
    # answers and silently drop a real match.
    ans_ids = [g_id for g_id in ranked_ids if g_id != query_id]
    reid_result.append({
        'query_id': query_id,
        'ans_ids': ans_ids,
    })
final['reid_result'] = reid_result

In [29]:
# Write the submission. The context manager guarantees the file is flushed and closed.
with open('wild_reid_submit_rerank_all100.json', 'w') as fp:
    json.dump(final, fp)