# In [1]:
"""Landmark retrieval offline code for ILR2021."""

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

import pickle
import os
import time
import sys
sys.path.append('../input/util-code/local_matching')
from local_matching import load_superpointglue_model, generate_superpoint_superglue, get_num_inliers, get_total_score
from superpointglue_util import get_whole_cached_num_inliers, save_whole_cached_num_inliers


# pylint: disable=not-callable, invalid-name, line-too-long
# ########################## Cell 2 Configs ##################################

# Run mode; used as the sub-directory name in every temp/output path below.
MODE = 'retrieval'
# Global pipeline configuration read by the functions in this notebook.
setting = {
        # ############# General params ############
        'IMAGE_DIR': '../input/landmark-retrieval-2021/',
        'RAW_IMAGE_DIR': '/home/gongyou.zyq/datasets/google_landmark/',
        'OUTPUT_TEMP_DIR': f'../temp/{MODE}/',    # Scratch space; contents are lost after a reset
        'OUTPUT_DIR': f'../working/{MODE}/',    # Persisted across resets
        'MODEL_DIR': '../input/models/',
        'META_DIR': '../input/meta-data-final/',
        'PROBE_DIR': '../input/landmark-retrieval-2021/test/',
        'INDEX_DIR': '../input/landmark-retrieval-2021/index/',
        'FEAT_DIR': f'../temp/{MODE}/features/',
        'SIMS_DIR': f'../temp/{MODE}/sims/',
        # load_image_list() returns (None, None) when PROBE_DIR holds exactly
        # this many images and a single GPU is used.
        'SAMPLE_TEST_NUM': 1129,
        # ############# ReID params ############
        'REID_EXTRACT_FLAG': True,
        'FP16': True,    # Run feature extraction under torch.cuda.amp.autocast
        'DEBUG_FLAG': False,
        'MULTI_SCALE_FEAT': False,
        # ############# ReID model list ############
        # 'MODEL_LIST': ['R50', 'R101ibn', 'RXt101ibn', 'SER101ibn', 'ResNeSt101', 'ResNeSt269', 'EffNetB7'],
        # 'MODEL_LIST': ['SER101ibn'],
        'MODEL_LIST': ['SER101ibn', 'RXt101ibn', 'ResNeSt101', 'ResNeSt269'],
        # Presumably one weight per MODEL_LIST entry (lengths match) — confirm with the merging code.
        'MODEL_WEIGHT': [1.0, 1.0, 1.0, 1.0],
         #'MODEL_WEIGHT': [1.0,],
        'IMAGE_SIZE': 512,    # Square resize side used by ImageDataset, e.g. 256, 384, 448, 512
        'BATCH_SIZE': 32,    # Per-GPU batch size; multiplied by the GPU count with DataParallel
        'MODEL_PARAMS': {'R50': {'MODEL_NAME': 'R50_256.pth', 'BACKBONE': 'resnet50'},
                         'R101ibn': {'MODEL_NAME': 'R101ibn_384_finetune_c2x.pth', 'BACKBONE': 'resnet101_ibn_a'},
                         'RXt101ibn': {'MODEL_NAME': 'RXt101ibn_512_all.pth', 'BACKBONE': 'resnext101_ibn_a'},
                         'SER101ibn': {'MODEL_NAME': 'SER101ibn_512_all.pth', 'BACKBONE': 'se_resnet101_ibn_a'},
                         'ResNeSt101': {'MODEL_NAME': 'ResNeSt101_512_all.pth', 'BACKBONE': 'resnest101'},
                         'ResNeSt269': {'MODEL_NAME': 'ResNeSt269_512_all.pth', 'BACKBONE': 'resnest269'},
                         'EffNetB7': {'MODEL_NAME': 'efficientnet-b7_20_512_3796.pth', 'BACKBONE': 'efficientnet-b7'},
                         },
        # ############# Rerank params ############
        'KR_FLAG': False,
        'K1': 10,    # k-reciprocal neighbourhood size (slice_jaccard)
        'K2': 3,    # Query-expansion neighbourhood size (slice_jaccard)
        'INITIAL_RANK_FILE': f'../temp/{MODE}/initial_rank.npy',
        'NAME_LIST_FILE': f'../temp/{MODE}/name_list.pkl',
        'EUC_DIST_DIR': f'../temp/{MODE}/euc_dist/',
        'GRAPH_DIST_DIR': f'../temp/{MODE}/graph_dist/',
        'QE_DIST_DIR': f'../temp/{MODE}/qe_dist/',
        'JACCARD_DIR': f'../temp/{MODE}/jaccard/',
        'LAMBDA': 0.3,
        # ############# Category Rerank ############
        'CATEGORY_RERANK': 'before_merge',    # after_merge, before_merge or off
        'VOTE_NUM': 3,    # Nearest references that vote for a probe's category; soft voting did not seem to help
        'REF_SET_EXTRACT': False,    # Only needs to be cached once
        'REF_ALL_LIST': '../input/meta-data-all/cache_all_list.pkl',
        'REF_SET_LIST': '../input/meta-data-final/cache_index_train_list.pkl',    # full, index_train, all
        'REF_SET_META': f'../temp/{MODE}/ref_meta.pkl',
        'REF_SET_FEAT': '../input/meta-data-final/ref_feats.pkl',
        'REF_LOC_MAP': '../input/meta-data-final/gbid2country.pkl',
        'CATEGORY_THR': -1.0,
        'alpha': 1.0,    # Presumably the tag-match boost weight for category rerank — confirm
        'beta': 0.1,    # Presumably the location-match boost weight for category rerank — confirm
        # ############ LocalMatching Rerank ############
        'LOCAL_MATCHING': 'off',    # 'spg' or 'off'
        'SPG_MODEL_DIR': '../input/models/local_matching',
        'SPG_CACHE_DIR': f'../temp/{MODE}/local_matching_cache',
        'SPG_RERANK_NUM': 10,    # Number of top results re-ranked by local matching; larger is better but slower
        'LOCAL_WEIGHT': 0.15,
        'MAX_INLIERS': 90,
        'SPG_DO_CACHE': True,    # Whether to save the inliers cache after writing results
        }


# ########################## Cell 8 Get output file  #########################

def slice_jaccard(probe_feat, topk_index_feats, k1=None, k2=None):
    """K-reciprocal re-ranking restricted to one probe and its candidate index feats.

    Builds a joint (1 + G) x (1 + G) distance matrix from the probe and the G
    candidates, expands k-reciprocal neighbour sets, optionally applies query
    expansion, and returns the Jaccard distance from the probe to every
    candidate.

    Args:
        probe_feat: probe feature tensor. It is concatenated as rows in front of
            the candidates and treated as a single query, so it is expected to
            be shaped (1, D); callers pass ``probe_feat[None, :]``.
        topk_index_feats: (G, D) candidate index features on the same device.
            Features are assumed to be L2-normalised, so the dot product is a
            cosine similarity (TODO confirm at call sites).
        k1: k-reciprocal neighbourhood size; defaults to ``setting['K1']``.
        k2: query-expansion size; defaults to ``setting['K2']``. 1 disables QE.

    Returns:
        Half-precision tensor of shape (G,) with the probe-to-candidate Jaccard
        distances, on the same device as ``probe_feat``.
    """
    if k1 is None:
        k1 = setting['K1']
    if k2 is None:
        k2 = setting['K2']

    query_num = 1
    all_num = query_num + len(topk_index_feats)
    concat_feat = torch.cat([probe_feat, topk_index_feats])
    cos_sim = torch.matmul(concat_feat, concat_feat.T)    # (all_num, all_num)
    # Map cosine similarity in [-1, 1] to a distance in [0, 1].
    original_dist = 1.0 - (cos_sim + 1.0) / 2
    initial_rank = torch.argsort(original_dist, dim=1).cpu().numpy()
    original_dist = original_dist.cpu().numpy()

    V = np.zeros((all_num, all_num))
    half_k1 = int(np.around(k1 / 2.)) + 1
    for i in range(all_num):
        # k-reciprocal neighbours: j is in i's top-(k1+1) and i is in j's.
        forward_k_neigh_index = initial_rank[i, :k1 + 1]
        backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1]
        fi = np.where(backward_k_neigh_index == i)[0]
        k_reciprocal_index = forward_k_neigh_index[fi]
        k_reciprocal_expansion_index = k_reciprocal_index
        for candidate in k_reciprocal_index:
            candidate_forward_k_neigh_index = initial_rank[candidate, :half_k1]
            candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index, :half_k1]
            fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0]
            candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate]
            # Expand with the candidate's own reciprocal set when more than 2/3
            # of it already lies in i's reciprocal set.
            if len(np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2. / 3 * len(candidate_k_reciprocal_index):
                k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index, candidate_k_reciprocal_index)

        k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index)
        weight = np.exp(-original_dist[i, k_reciprocal_expansion_index])
        V[i, k_reciprocal_expansion_index] = weight / np.sum(weight)

    original_dist = original_dist[:query_num, ]
    if k2 != 1:
        # Query expansion: average each row of V over its k2 nearest rows.
        V_qe = np.zeros_like(V, dtype=np.float32)
        for i in range(all_num):
            V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0)
        V = V_qe

    # Inverted index: for every column, the rows of V with a non-zero entry.
    invIndex = [np.where(V[:, i] != 0)[0] for i in range(all_num)]

    jaccard_dist = np.zeros_like(original_dist, dtype=np.float32)
    for i in range(query_num):
        temp_min = np.zeros(shape=[1, all_num], dtype=np.float32)
        indNonZero = np.where(V[i, :] != 0)[0]
        indImages = [invIndex[ind] for ind in indNonZero]
        for j in range(len(indNonZero)):
            temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(V[i, indNonZero[j]], V[indImages[j], indNonZero[j]])
        jaccard_dist[i] = 1 - temp_min / (2 - temp_min)

    # Drop the probe-to-itself column; keep probe-to-candidate distances only.
    probe_to_candidates = jaccard_dist[:query_num, query_num:].flatten()
    return torch.tensor(probe_to_candidates, device=probe_feat.device).half()


def kr_rerank_fast(probe_feats, index_feats, topk=30, fast_lambda=0.5):
    """Memory-efficient k-reciprocal re-ranking.

    For each probe, only its top-``topk`` index neighbours plus each of those
    neighbours' own top-``topk`` neighbours are re-ranked with
    ``slice_jaccard``; every other index entry keeps a Jaccard distance of 1.

    Args:
        probe_feats: (Q, D) probe features (GPU tensors in this pipeline).
        index_feats: (N, D) index features on the same device.
        topk: neighbourhood size; clipped to N so small galleries do not make
            ``torch.topk`` fail.
        fast_lambda: weight of the original distance in
            ``lambda * original + (1 - lambda) * jaccard``.

    Returns:
        (Q, N) half-precision tensor of ``1 - blended_distance``
        (higher means more similar).
    """
    cos_sims = torch.matmul(probe_feats, index_feats.T)
    # Map cosine similarity in [-1, 1] to a distance in [0, 1].
    original_dists = 1.0 - (cos_sims + 1.0) / 2
    query_num = len(probe_feats)
    gallery_num = len(index_feats)
    device = index_feats.device
    # torch.topk raises when k exceeds the size of the searched dimension.
    k = min(topk, gallery_num)
    final_dists = torch.zeros((query_num, gallery_num), device=device).half()
    for i, probe_feat in enumerate(probe_feats):
        _, top_indices = torch.topk(cos_sims[i], k)
        # Second-order neighbours: top-k of each first-order neighbour.
        neighbour_sims = torch.matmul(index_feats[top_indices], index_feats.T)    # (k, N)
        _, neighbour_top_indices = torch.topk(neighbour_sims, k)    # (k, k)
        candidate_indices = torch.unique(torch.cat([top_indices, neighbour_top_indices.flatten()]))
        jaccard_dist = slice_jaccard(probe_feat[None, :], index_feats[candidate_indices])
        # Entries outside the candidate set get the maximal Jaccard distance.
        expand_jaccard = torch.ones((gallery_num,), device=device).half()
        expand_jaccard[candidate_indices] = jaccard_dist
        final_dists[i] = original_dists[i] * fast_lambda
        final_dists[i] += expand_jaccard * (1 - fast_lambda)
    return 1.0 - final_dists


def merge_tags(tags_list, scores, weight):
    """Fuse per-model tag predictions into one tag per probe by score voting.

    Args:
        tags_list: per-model tag rows; row m holds model m's tag for each probe.
        scores: per-model confidence rows shaped like ``tags_list``, or None.
        weight: optional per-model weights multiplied into the scores.

    Returns:
        ``tags_list[-1]`` when there is one model or no scores. Otherwise a list
        with one entry per probe: the unanimous tag, -1 when every model
        predicts a different tag, or else the tag with the highest summed
        (weighted) score.
    """
    if len(tags_list) == 1 or scores is None:
        return tags_list[-1]

    tag_mat = torch.tensor(tags_list)
    score_mat = torch.tensor(scores)
    model_weight = None if weight is None else torch.tensor(weight)
    print(tag_mat.shape, score_mat.shape)

    num_models = tag_mat.shape[0]
    merged = []
    ambiguous = 0
    for column in range(tag_mat.shape[1]):
        column_tags = tag_mat[:, column]
        candidates = torch.unique(column_tags)
        column_scores = score_mat[:, column]
        if model_weight is not None:
            column_scores = column_scores * model_weight

        if len(candidates) == 1:
            merged.append(candidates[0])
            continue
        if len(candidates) == num_models:
            # No two models agree: give up on this probe.
            merged.append(-1)
            continue

        votes = torch.tensor([torch.sum(column_scores[column_tags == cand]) for cand in candidates])
        merged.append(candidates[torch.argmax(votes)])
        ambiguous += 1
    print(f'{ambiguous} low constancy tags')
    return merged

def get_probe_tags_index(probe_feats, index_feats, index_tags, mode='avg'):
    """Assign each probe the tag of its most similar index entry (or tag centroid).

    Args:
        probe_feats: (Q, D) probe features.
        index_feats: (N, D) index features.
        index_tags: (N,) tensor with the tag of every index entry.
        mode: ``'avg'`` compares probes with the L2-normalised mean feature of
            each unique tag; ``'single'`` compares with individual index
            features.

    Returns:
        (probe_tags, probe_scores): the best tag per probe and its similarity.

    Raises:
        ValueError: if ``mode`` is neither ``'avg'`` nor ``'single'``
            (previously this silently returned empty lists).
    """
    if mode == 'avg':
        index_tags_unique = torch.unique(index_tags)
        tag_mean_feats = []
        for index_tag in index_tags_unique:
            same_tags = torch.where(index_tags == index_tag)
            same_tag_feats = torch.mean(index_feats[same_tags], dim=0, keepdim=True)
            tag_mean_feats.append(same_tag_feats / torch.norm(same_tag_feats, 2, 1))
        gallery_feats = torch.cat(tag_mean_feats, dim=0)
        gallery_tags = index_tags_unique
    elif mode == 'single':
        gallery_feats = index_feats
        gallery_tags = index_tags
    else:
        raise ValueError(f"Unknown mode '{mode}', expected 'avg' or 'single'")

    probe_tags = []
    probe_scores = []
    for query_feat in tqdm(probe_feats):
        sim = torch.matmul(gallery_feats, query_feat[:, None]).flatten()
        _, indices = torch.topk(sim, 1)
        probe_tags.append(gallery_tags[indices[0]])
        probe_scores.append(sim[indices[0]])
    return probe_tags, probe_scores

def get_probe_tags_avg(probe_feats, ref_info):
    """Predict one global landmark id per probe via the nearest category centroid.

    Each global id's centroid is the L2-normalised mean of its reference
    features; a probe gets the id whose centroid has the highest dot product
    with it.

    Args:
        probe_feats: (Q, D) probe features.
        ref_info: dict with ``'globalid2refindex'`` (global id -> reference
            row indexes) and ``'ref_feats'`` (R, D).

    Returns:
        List of predicted global ids (keys of ``globalid2refindex``), one per
        probe. Previously the centroid row position was returned, which equals
        the global id only when the ids are exactly 0..K-1.
    """
    globalid2refindex = ref_info['globalid2refindex']
    ref_feats = ref_info['ref_feats']
    sorted_global_ids = sorted(globalid2refindex.keys())
    category_mean_feats = []
    for globalid in sorted_global_ids:
        same_id_feats = torch.mean(ref_feats[globalid2refindex[globalid]], dim=0, keepdim=True)
        category_mean_feats.append(same_id_feats / torch.norm(same_id_feats, 2, 1))
    # Row r of this matrix is the centroid of sorted_global_ids[r].
    category_mean_feats = torch.cat(category_mean_feats, dim=0)
    probe_tags = []

    print(f'computing probe tags, total probes:{probe_feats.shape[0]}, refs: {ref_feats.shape[0]}')
    for query_feat in tqdm(probe_feats):
        ref_sim = torch.matmul(category_mean_feats, query_feat[:, None]).flatten()
        # Only the best centroid is used, so k=1 (also works with fewer
        # categories than VOTE_NUM).
        _, best_row = torch.topk(ref_sim, 1)
        probe_tags.append(sorted_global_ids[int(best_row[0])])

    return probe_tags


def get_probe_tags_topk(probe_feats, ref_info):
    """Vote landmark ids and locations for every probe from its nearest references.

    References are ranked with ``kr_rerank_fast``; the top
    ``setting['VOTE_NUM']`` vote with their scores, which are summed per
    distinct id / location and normalised to sum to 1.

    Args:
        probe_feats: (Q, D) probe features (GPU tensors in this pipeline).
        ref_info: dict with ``'ref_gbid'``, ``'ref_loc'`` and ``'ref_feats'``.

    Returns:
        (probe_tags, probe_tag_scores, probe_locs, probe_loc_scores): four
        lists with one entry per probe; each entry is a tensor of distinct ids
        (or locations) and the tensor of their normalised vote scores.
    """
    ref_gbid = ref_info['ref_gbid']
    ref_loc = ref_info['ref_loc']
    ref_feats = ref_info['ref_feats']

    def tally(labels, label_scores):
        # Distinct labels with their summed scores, normalised to sum to 1.
        distinct = torch.unique(labels)
        totals = torch.tensor([torch.sum(label_scores[torch.where(labels == label)[0]])
                               for label in distinct])
        return distinct, totals / torch.sum(totals)

    probe_tags, probe_tag_scores = [], []
    probe_locs, probe_loc_scores = [], []
    print(f'computing probe tags, total probes:{probe_feats.shape[0]}, refs: {ref_feats.shape[0]}')
    for query_feat in tqdm(probe_feats):
        ref_sim = kr_rerank_fast(query_feat[None, :], ref_feats).flatten()
        _, ref_indices = torch.topk(ref_sim, setting['VOTE_NUM'])
        top_scores = ref_sim[ref_indices]

        tags, tag_scores = tally(ref_gbid[ref_indices], top_scores)
        locs, loc_scores = tally(ref_loc[ref_indices], top_scores)

        probe_tags.append(tags)
        probe_tag_scores.append(tag_scores)
        probe_locs.append(locs)
        probe_loc_scores.append(loc_scores)

    return probe_tags, probe_tag_scores, probe_locs, probe_loc_scores

def get_probe_tags(probe_feats, ref_info):
    """Predict one global landmark id per probe by score-weighted kNN voting.

    The top ``setting['VOTE_NUM']`` references (by dot product) vote; when
    they disagree, the id with the highest summed similarity wins.

    Args:
        probe_feats: (Q, D) probe features.
        ref_info: dict with ``'refindex2globalid'`` and ``'ref_feats'`` (R, D).
            The id list is built in the dict's iteration order, so that order is
            assumed to match the rows of ``ref_feats`` (TODO confirm).

    Returns:
        (probe_tags, probe_scores): the voted id per probe and its summed score.
    """
    refindex2globalid = ref_info['refindex2globalid']
    ref_gbid = torch.tensor([refindex2globalid[key] for key in refindex2globalid]).cuda()
    ref_feats = ref_info['ref_feats']
    probe_tags = []
    probe_scores = []

    print(f'computing probe tags, total probes:{probe_feats.shape[0]}, refs: {ref_feats.shape[0]}')
    for query_feat in tqdm(probe_feats):
        ref_sim = torch.matmul(ref_feats, query_feat[:, None]).flatten()
        _, ref_indices = torch.topk(ref_sim, setting['VOTE_NUM'])

        voter_ids = torch.tensor([ref_gbid[idx] for idx in ref_indices])
        voter_scores = torch.tensor([ref_sim[idx] for idx in ref_indices])
        distinct_ids = torch.unique(voter_ids)

        if len(distinct_ids) == 1:
            # All voters agree (the common case).
            best_id = voter_ids[0]
            best_score = torch.sum(voter_scores)
        else:
            group_scores = torch.tensor([torch.sum(voter_scores[torch.where(voter_ids == gid)[0]])
                                         for gid in distinct_ids])
            winner = torch.argmax(group_scores)
            best_id = distinct_ids[winner]
            best_score = group_scores[winner]

        probe_tags.append(best_id)
        probe_scores.append(best_score)
    return probe_tags, probe_scores


def get_index_tags(index_feats, ref_info, batch_size=128):
    """Label every index image with the global id and location of its nearest reference.

    Args:
        index_feats: (N, D) index features.
        ref_info: dict with ``'ref_feats'`` (R, D) plus ``'ref_gbid'`` and
            ``'ref_loc'``, both indexable by reference row.
        batch_size: number of index features scored per matmul.

    Returns:
        (index_gbid, index_locs): concatenated per-index global ids and
        locations, each of length N.
    """
    ref_feats = ref_info['ref_feats']
    ref_gbid = ref_info['ref_gbid']
    ref_loc = ref_info['ref_loc']

    print(f'computing index tags, total index:{index_feats.shape[0]}, refs: {ref_feats.shape[0]}')
    # Ceil division, so no empty trailing batch is scheduled when N is a
    # multiple of batch_size; keep at least one batch so torch.cat has input.
    num_batches = max(1, (len(index_feats) + batch_size - 1) // batch_size)
    index_gbid = []
    index_locs = []
    for batch_idx in tqdm(range(num_batches)):
        batch_data = index_feats[batch_idx * batch_size:(batch_idx + 1) * batch_size]
        ref_sim = torch.matmul(batch_data, ref_feats.T)
        _, ref_indices = torch.topk(ref_sim, 1, dim=1)
        ref_indices = ref_indices.flatten()    # (batch_size, )
        index_gbid.append(ref_gbid[ref_indices])
        index_locs.append(ref_loc[ref_indices])
    index_gbid = torch.cat(index_gbid)    # (num_index, )
    index_locs = torch.cat(index_locs)

    return index_gbid, index_locs

def rerank_tag_and_loc(sim, probe_tags, probe_locs, probe_tag_scores, probe_loc_scores, index_tags, index_locs, alpha=1.0, beta=1.0):
    """Boost similarities of index entries sharing a probe's tags or locations.

    For every (tag, score) pair, ``alpha * score`` is added to the entries of
    ``sim`` whose ``index_tags`` equal the tag; then for every
    (location, score) pair, ``beta * score`` is added where ``index_locs``
    matches. ``sim`` is modified in place and also returned.
    """
    boost_passes = (
        (index_tags, probe_tags, probe_tag_scores, alpha),
        (index_locs, probe_locs, probe_loc_scores, beta),
    )
    for index_labels, labels, label_scores, factor in boost_passes:
        for label, label_score in zip(labels, label_scores):
            matched = torch.where(index_labels == label)
            sim[matched] += factor * label_score
    return sim

def recomputing_sim(sim, indexes, score, weight):
    """Add ``weight * score`` to ``sim`` at ``indexes`` in place and return ``sim``."""
    bonus = weight * score
    sim[indexes] += bonus
    return sim

def category_rerank_after_merge(sims, probe_tags, probe_locs, probe_tag_scores, probe_loc_scores, index_tags, index_locs, sim_thr=0.1, alpha=1.0, beta=0.1):
    """Category rerank applied to already-merged similarities.

    For each probe row of ``sims``, adds ``alpha * score`` to index entries whose
    tag matches one of the probe's voted tags and ``beta * score`` to entries
    whose location matches one of its voted locations.

    Args:
        sims: (num_probes, num_index) similarity tensor; left unmodified.
        probe_tags, probe_locs, probe_tag_scores, probe_loc_scores: per-probe
            voted tags/locations and their scores (one entry per probe).
        index_tags, index_locs: tags/locations of the index entries.
        sim_thr: only printed; not used for filtering here.
        alpha, beta: boost weights for tag and location matches.

    Returns:
        New tensor shaped like ``sims`` holding the boosted similarities.
    """
    print('Category Reranking after merge......')
    print(f'Category Thr is {sim_thr}')
    rerank_sims = torch.zeros_like(sims)
    print(f'rerank sims by {alpha}, {beta}')
    per_probe = zip(probe_tags, probe_locs, probe_tag_scores, probe_loc_scores)
    for probe_index, (probe_tag, probe_loc, probe_tag_score, probe_loc_score) in enumerate(tqdm(per_probe)):
        # rerank_tag_and_loc adds in place and flatten() of a contiguous row is a
        # view, so clone to avoid silently mutating the caller's `sims`.
        raw_sim = sims[probe_index].flatten().clone()
        rerank_sims[probe_index] = rerank_tag_and_loc(raw_sim,
                                                      probe_tag, probe_loc, probe_tag_score, probe_loc_score,
                                                      index_tags, index_locs,
                                                      alpha=alpha, beta=beta)
    return rerank_sims


def _vote_global_id(query_feat, ref_feats, ref_gbid, vote_num):
    """Return the global id with the highest summed similarity among the probe's top ``vote_num`` references."""
    ref_sim = torch.matmul(ref_feats, query_feat[:, None]).flatten()
    _, ref_indices = torch.topk(ref_sim, vote_num)
    pred_id_list = torch.tensor([ref_gbid[ref_index] for ref_index in ref_indices])
    pred_score_list = torch.tensor([ref_sim[ref_index] for ref_index in ref_indices])
    unique_id_list = torch.unique(pred_id_list)
    if len(unique_id_list) == 1:
        # All voters agree (the common case).
        return pred_id_list[0]
    unique_score_list = torch.tensor([torch.sum(pred_score_list[torch.where(pred_id_list == item)[0]])
                                      for item in unique_id_list])
    return unique_id_list[torch.argmax(unique_score_list)]


def category_rerank_before_merge(probe_feats, index_feats, ref_info):
    """Category rerank: move index images of the probe's voted landmark to the front.

    1. Label every index image with the global id of its nearest reference.
    2. For each probe, vote a global id from its ``setting['VOTE_NUM']`` nearest
       references (score-summed when they disagree).
    3. Re-order the probe's raw dot-product ranking so that index images with
       the voted id come first (each group keeps its raw order), then encode
       the order as descending pseudo-similarities 1, 1 - 1/n, ..., 1/n.

    Args:
        probe_feats: (Q, D) probe features.
        index_feats: (N, D) index features.
        ref_info: dict with ``'ref_feats'`` (R, D) and ``'refindex2globalid'``;
            the id list is built in dict iteration order, assumed to match the
            rows of ``ref_feats`` (TODO confirm).

    Returns:
        float32 numpy array (Q, N) of pseudo-similarities.
    """
    print('Category Reranking before merge......')
    ref_feats = ref_info['ref_feats']
    print('ref, ', ref_feats.shape)
    rerank_sims = np.zeros((len(probe_feats), len(index_feats)),
                           dtype=np.float32)
    refindex2globalid = ref_info['refindex2globalid']
    ref_gbid = torch.tensor([refindex2globalid[refindex] for refindex in refindex2globalid]).cuda()

    print('Get label for each index image')
    batch_size = 128
    # Ceil division (no empty trailing batch when N is a multiple of
    # batch_size); keep at least one batch so torch.cat has input.
    num_batches = max(1, (len(index_feats) + batch_size - 1) // batch_size)
    index_gbid = []
    for batch_idx in tqdm(range(num_batches)):
        batch_data = index_feats[batch_idx * batch_size:(batch_idx + 1) * batch_size]
        ref_sim = torch.matmul(batch_data, ref_feats.T)
        _, ref_indices = torch.topk(ref_sim, 1, dim=1)
        index_gbid.append(ref_gbid[ref_indices.flatten()])    # (batch_size, )
    index_gbid = torch.cat(index_gbid)    # (num_index, )

    for probe_index, query_feat in enumerate(tqdm(probe_feats)):
        pred_global_id = _vote_global_id(query_feat, ref_feats, ref_gbid, setting['VOTE_NUM'])

        raw_sim = torch.matmul(index_feats, query_feat[:, None]).flatten()
        raw_orders = torch.argsort(-raw_sim).cpu().numpy()
        good_indexes = torch.where(index_gbid == pred_global_id)[0].cpu().numpy()
        is_good = np.isin(raw_orders, good_indexes)
        merged_order = np.concatenate([raw_orders[is_good], raw_orders[~is_good]])
        rank_sim = 1.0 - np.arange(len(merged_order)) / float(len(merged_order))
        rerank_sims[probe_index, merged_order] = rank_sim
    return rerank_sims


def category_expansion(probe_feats, index_feats, ref_info, sim_thr=0.6):
    """Category query expansion.

    For each probe, gathers every reference image belonging to the global ids
    of its ``setting['VOTE_NUM']`` nearest references. Index images whose
    maximum similarity to that set exceeds ``sim_thr`` are moved to the front
    of the probe's raw dot-product ranking (2019 GLR retrieval-style rerank),
    and the order is encoded as descending pseudo-similarities.

    Args:
        probe_feats: (Q, D) probe features.
        index_feats: (N, D) index features.
        ref_info: dict with ``'ref_feats'``, ``'globalid2refindex'`` and
            ``'refindex2globalid'``.
        sim_thr: similarity above which an index image counts as matching the
            expanded category (previously hard-coded to 0.6).

    Returns:
        float32 numpy array (Q, N) of pseudo-similarities.
    """
    ref_feats = ref_info['ref_feats']
    globalid2refindex = ref_info['globalid2refindex']
    refindex2globalid = ref_info['refindex2globalid']
    rerank_sims = np.zeros((len(probe_feats), len(index_feats)),
                           dtype=np.float32)
    for probe_index, query_feat in enumerate(tqdm(probe_feats)):
        ref_sim = torch.matmul(ref_feats, query_feat[:, None]).flatten()
        _, ref_indices = torch.topk(ref_sim, setting['VOTE_NUM'])
        same_cat_indexes = np.concatenate([globalid2refindex[refindex2globalid[ref_index]]
                                           for ref_index in ref_indices.cpu().numpy()])
        cat_feats = ref_feats[torch.tensor(same_cat_indexes)]
        # Max similarity of each index image to any reference of the voted categories.
        cat2pred_sim = torch.matmul(cat_feats, index_feats.T)    # (C, index_num)
        cat2pred_sim, _ = torch.max(cat2pred_sim, dim=0)    # (index_num, )
        good_indexes = (cat2pred_sim > sim_thr).nonzero().cpu().numpy().flatten()

        raw_sim = torch.matmul(index_feats, query_feat[:, None]).flatten()
        raw_orders = torch.argsort(-raw_sim).cpu().numpy()
        is_good = np.isin(raw_orders, good_indexes)
        merged_order = np.concatenate([raw_orders[is_good], raw_orders[~is_good]])
        rank_sim = 1.0 - np.arange(len(merged_order)) / float(len(merged_order))
        rerank_sims[probe_index, merged_order] = rank_sim
        # Alternatives tried before:
        #   * use cat2pred_sim directly as the probe's similarity row;
        #   * QE with the L2-normalised mean of the category features.
    return rerank_sims


def rerank_local_matching(spg_model, num_inliers_dict, probe_name, probe_dir, index_name_list, index_dir, sims, local_weight, max_inliers, cache_dir, do_cache, ignore_global_score=False):
    """Re-rank candidate index images with SuperPoint+SuperGlue inlier counts.

    Args:
        spg_model: model returned by ``load_superpointglue_model``.
        num_inliers_dict: ``{(probe_name, index_name): num_inliers}`` cache;
            newly computed counts are added in place (persisting it is left to
            the caller).
        probe_name, index_name_list: image ids; images live under
            ``<dir>/<c0>/<c1>/<c2>/<id>.jpg``. ``index_name_list`` must be a
            numpy array (it is fancy-indexed on return).
        probe_dir, index_dir: image root directories.
        sims: global similarities aligned with ``index_name_list``.
        local_weight, max_inliers: forwarded to ``get_total_score``.
        cache_dir: keypoint cache directory passed to the matcher.
        do_cache: create ``cache_dir`` if needed.
        ignore_global_score: score by inliers only, ignoring ``sims``.

    Returns:
        ``index_name_list`` re-ordered by descending total score.
    """
    if do_cache:
        # exist_ok avoids the check-then-create race of os.path.exists.
        os.makedirs(cache_dir, exist_ok=True)

    def _image_path(root, name):
        # Images are sharded by the first three characters of their id.
        return f'{root}/{name[0]}/{name[1]}/{name[2]}/{name}.jpg'

    probe_path = _image_path(probe_dir, probe_name)
    probe_image_cache = {}
    scores = []
    for idx, index_name in enumerate(index_name_list):
        key = (probe_name, index_name)
        if key not in num_inliers_dict:
            index_path = _image_path(index_dir, index_name)
            pred, _, _ = generate_superpoint_superglue(probe_path, probe_name, index_path, index_name,
                                                       spg_model, cache_dir, False, probe_image_cache)
            num_inliers_dict[key] = get_num_inliers(pred)
        num_inliers = num_inliers_dict[key]

        if ignore_global_score:
            total_score = get_total_score(num_inliers, 0.)
        else:
            total_score = get_total_score(num_inliers, sims[idx], weight=local_weight, max_inlier_score=max_inliers)
        scores.append(total_score)

    # Highest total score first.
    rerank_sort = np.argsort(np.asarray(scores))[::-1]
    return index_name_list[rerank_sort]



def write_csv(probe_name_list, index_name_list, sims, output_path='submission.csv'):
    """Write the top-100 retrieval results of every probe to a submission CSV.

    When ``setting['LOCAL_MATCHING'] == 'spg'``, the first
    ``setting['SPG_RERANK_NUM']`` results are re-ordered by local matching and,
    if ``setting['SPG_DO_CACHE']``, the inlier cache is saved afterwards.

    Args:
        probe_name_list: probe image ids, aligned with the rows of ``sims``.
        index_name_list: index image ids, aligned with the columns of ``sims``.
        sims: per-probe similarity rows (higher is better).
        output_path: destination CSV path (columns ``id`` and ``images``).
    """
    use_spg = setting['LOCAL_MATCHING'] == 'spg'
    if use_spg:
        spg_model = load_superpointglue_model(setting['SPG_MODEL_DIR'])
        num_inliers_dict = get_whole_cached_num_inliers(setting['SPG_CACHE_DIR'])
        rerank_num = setting['SPG_RERANK_NUM']
    index_name_list = np.array(index_name_list)
    id_list = []
    res_list = []
    print('Start output csv files')
    for probe_index, probe_name in enumerate(tqdm(probe_name_list)):
        id_list.append(probe_name)
        sim = sims[probe_index]
        orders = np.argsort(-sim)
        if use_spg:
            sorted_name_list_topk = rerank_local_matching(spg_model, num_inliers_dict,
                                                          probe_name, setting['PROBE_DIR'],
                                                          index_name_list[orders[:rerank_num]], setting['INDEX_DIR'],
                                                          sim[orders[:rerank_num]], setting['LOCAL_WEIGHT'], setting['MAX_INLIERS'],
                                                          cache_dir=setting['SPG_CACHE_DIR'], do_cache=setting['SPG_DO_CACHE'])
            sorted_name_list = sorted_name_list_topk.tolist() + index_name_list[orders[rerank_num:100]].tolist()
        else:
            sorted_name_list = index_name_list[orders[:100]]
        # Space-separated ids; join avoids quadratic string concatenation.
        res_list.append(' '.join(sorted_name_list))
    if use_spg and setting['SPG_DO_CACHE']:
        save_whole_cached_num_inliers(setting['SPG_CACHE_DIR'], num_inliers_dict)
    df = pd.DataFrame({'id': id_list, 'images': res_list})
    df.to_csv(output_path, index=False)
    print('Finish output csv files')
    print('Finish output csv files')

# In [2]:
"""Base code for ILR2021"""

import gc
import gzip
import os
import pickle
import shutil
import sys
import time

import cv2
import psutil
import torch
import torch.nn as nn
from torch.cuda import amp
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
from collections import OrderedDict

sys.path.append('../input/util-code/')
from make_model import make_model

# pylint: disable=not-callable, invalid-name, line-too-long
# ########################## Cell 1 Basic module test ########################

# Detect GPUs. gpu_num is always defined (0 when CUDA is unavailable) so later
# code that reads it does not hit a NameError; explicit checks instead of
# `assert` keep this working under `python -O`.
gpu_num = torch.cuda.device_count() if torch.cuda.is_available() else 0
if gpu_num > 0:
    gpu_memory = torch.cuda.get_device_properties(0).total_memory
    print(f'Total {gpu_num} gpu cards with {gpu_memory} memory')
else:
    print('Fail to set gpu')

# ########################## Cell 3 Load all image list  #####################


def load_image_list():
    """Collect probe (query) image paths followed by index image paths.

    Returns:
        ``(all_image_list, query_count)`` where the first ``query_count`` paths
        are probe images; ``(None, None)`` when the probe directory holds
        exactly ``setting['SAMPLE_TEST_NUM']`` files and one GPU is in use
        (presumably the public sample-test run — confirm intent).
    """
    def walk_files(root):
        # Every file below root, in os.walk order.
        return [os.path.join(dirname, filename)
                for dirname, _, filenames in os.walk(root)
                for filename in filenames]

    all_image_list = walk_files(setting['PROBE_DIR'])
    query_count = len(all_image_list)
    if query_count == setting['SAMPLE_TEST_NUM'] and gpu_num == 1:
        return None, None

    index_images = walk_files(setting['INDEX_DIR'])
    index_count = len(index_images)
    all_image_list.extend(index_images)
    print(f'query num: {query_count} and index num: {index_count}')
    return all_image_list, query_count

# ########################## Cell 4 ReID inference  ##########################


class ImageDataset(Dataset):
    """Dataset of image paths; each item is ``(img, img_path)``.

    ``img`` is the image resized to ``setting['IMAGE_SIZE']`` x
    ``setting['IMAGE_SIZE']`` as a HWC tensor in RGB channel order.
    """

    def __init__(self, dataset, transforms):
        # ``transforms`` is accepted for interface compatibility but unused.
        _ = transforms
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        img_path = self.dataset[index]
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread returns None for missing/unreadable files instead of
            # raising; fail here with the offending path rather than inside resize.
            raise OSError(f'Failed to read image: {img_path}')
        img = cv2.resize(img, (setting['IMAGE_SIZE'], setting['IMAGE_SIZE']))
        img = torch.tensor(img)
        # OpenCV decodes as BGR; reorder channels to RGB.
        img = img[:, :, [2, 1, 0]]
        return img, img_path


def val_collate_fn(batch):
    """Collate ``(image, path)`` samples into a stacked image tensor and a tuple of paths."""
    images, paths = zip(*batch)
    stacked = torch.stack(images, dim=0)
    return stacked, paths


class ReID_Inference:
    """Batched, L2-normalised feature extraction with one trained backbone.

    The model is built by ``make_model`` from
    ``setting['MODEL_PARAMS'][backbone]``, loaded from ``setting['MODEL_DIR']``,
    moved to CUDA and put in eval mode. With more than one GPU it is wrapped in
    ``nn.DataParallel`` and the batch size is multiplied by the GPU count.
    """

    def __init__(self, backbone):
        """Load the model registered under ``backbone`` in MODEL_PARAMS."""

        self.model = make_model(setting['MODEL_PARAMS'][backbone]['BACKBONE'])
        model_name = setting['MODEL_PARAMS'][backbone]['MODEL_NAME']
        model_path = os.path.join(setting['MODEL_DIR'], model_name)
        self.model.load_param(model_path)
        self.batch_size = setting['BATCH_SIZE']
        if gpu_num > 1:
            print(f'Using {gpu_num} gpu for inference')
            self.model = nn.DataParallel(self.model)
            self.batch_size = setting['BATCH_SIZE'] * gpu_num
        self.model.to('cuda')
        self.model.eval()
        # Per-channel pixel mean/std in RGB order (ImageDataset swaps
        # BGR->RGB); presumably the statistics used at training time.
        self.mean = torch.tensor([123.675, 116.280, 103.530]).to('cuda')
        self.std = torch.tensor([57.0, 57.0, 57.0]).to('cuda')

    def extract(self, imgpath_list):
        """Extract one L2-normalised feature per image path.

        Args:
            imgpath_list: image file paths; may be empty.

        Returns:
            OrderedDict mapping each path to its feature (numpy array), in
            input order (the loader does not shuffle).
        """

        val_set = ImageDataset(imgpath_list, None)

        # NOTE: pinned memory and many workers only on multi-GPU machines;
        # the single-GPU path keeps pin_memory off to save host memory.
        if gpu_num > 1:
            pin_memory = True
            num_workers = 32
        else:
            pin_memory = False
            num_workers = 2
        val_loader = DataLoader(
            val_set, batch_size=self.batch_size, shuffle=False,
            num_workers=num_workers, collate_fn=val_collate_fn,
            pin_memory=pin_memory
        )

        batch_res_dic = OrderedDict()
        for (batch_data, batch_path) in tqdm(val_loader,
                                             total=len(val_loader)):
            with torch.no_grad():
                batch_data = batch_data.to('cuda')
                # Normalise the NHWC batch per channel, then convert to NCHW.
                batch_data = (batch_data - self.mean) / self.std
                batch_data = batch_data.permute(0, 3, 1, 2)
                batch_data = batch_data.float()
                if not setting['MULTI_SCALE_FEAT']:
                    if setting['FP16']:
                        # NOTE: DO NOT use model.half() because of underflow
                        with amp.autocast():
                            feat = self.model(batch_data)
                    else:
                        feat = self.model(batch_data)
                else:
                    # Average features over three input scales, each rounded
                    # down to a multiple of 16 pixels.
                    # Ref: https://github.com/lyakaap/Landmark2019-1st-and-3rd-Place-Solution/blob/3fee857dd2b2927ede70c43bffd99b41eb394507/cirtorch/networks/imageretrievalnet.py#L309
                    scales = [0.707, 1.0, 1.414]
                    raw_size = batch_data.shape[2]
                    feat = None
                    for s in scales:
                        new_size = int(((raw_size * s) // 16) * 16)
                        scale_data = nn.functional.interpolate(
                                batch_data, size=new_size, mode='bilinear',
                                align_corners=False)
                        with amp.autocast():
                            scale_feat = self.model(scale_data)
                        if feat is None:
                            # fp16 accumulator shaped like the model output,
                            # instead of assuming a 512-d embedding.
                            feat = torch.zeros_like(scale_feat,
                                                    dtype=torch.float16)
                        feat += scale_feat
                    feat = feat / len(scales)
                feat = feat / torch.norm(feat, 2, 1, keepdim=True)
                feat = feat.cpu().detach().numpy()

            for index, imgpath in enumerate(batch_path):
                batch_res_dic[imgpath] = feat[index]
        # Per-batch tensors are released on return; `feat`/`batch_data` do not
        # exist at all when imgpath_list is empty, so only free the loader.
        del val_loader, val_set
        return batch_res_dic


def debug_reid_inference(image_list):
    """Smoke-test extraction: run the R50 model on the first 20 images and print."""

    model = ReID_Inference('R50')
    preview = model.extract(image_list[:20])
    print(preview)
    del model, preview

# ########################## Cell 5 Extract feature  #########################


def save_feature(all_feature_dic, backbone):
    """Split extracted features into probe/index sets and pickle each set.

    An image is a probe (query) when the fifth-from-last component of its path
    is ``'test'``; every other image goes to the index set. Each set is written
    to ``FEAT_DIR/{probe,index}_feats_{backbone}.pkl`` as a dict holding a name
    array and a stacked feature array. ``all_feature_dic`` is cleared
    afterwards to release memory.

    Args:
        all_feature_dic: mapping image path -> feature vector.
        backbone: model key used in the output file names.
    """

    feat_dir = setting['FEAT_DIR']
    if not os.path.exists(feat_dir):
        os.makedirs(feat_dir)
    # mode -> (image names, features)
    buckets = {'probe': ([], []), 'index': ([], [])}
    # NOTE: attention the order! Sorting by path fixes the name-list order
    # that later stages rely on.
    for image_path, sample_feat in sorted(all_feature_dic.items()):
        image_name = os.path.basename(image_path).split('.jpg')[0]
        is_probe = image_path.split('/')[-5] == 'test'
        names, feats = buckets['probe' if is_probe else 'index']
        names.append(image_name)
        feats.append(sample_feat)

    for mode in ('probe', 'index'):
        names, feats = buckets[mode]
        pkl_name = os.path.join(feat_dir, f'{mode}_feats_{backbone}.pkl')
        mode_dic = {f'{mode}_name_list': np.array(names),
                    f'{mode}_feats': np.array(feats)}
        with open(pkl_name, 'wb') as f_pkl:
            pickle.dump(mode_dic, f_pkl, pickle.HIGHEST_PROTOCOL)
        print('Save pickle in %s' % pkl_name)
        mode_dic.clear()
    all_feature_dic.clear()
    buckets.clear()
    gc.collect()


def load_feat(mode, backbone):
    """Load a feature split pickled by ``save_feature``.

    Args:
        mode: ``'probe'`` or ``'index'``.
        backbone: model key used in the file name.

    Returns:
        ``(name_array, feature_array)`` for that split.
    """

    pkl_path = f"{setting['FEAT_DIR']}/{mode}_feats_{backbone}.pkl"
    with open(pkl_path, 'rb') as f_pkl:
        mode_dic = pickle.load(f_pkl)
    print(f'load {backbone} feat, memory : {psutil.virtual_memory().percent}')
    names = mode_dic[f'{mode}_name_list']
    feats = mode_dic[f'{mode}_feats']
    return names, feats


def save_numpy(data_path, data, save_disk_flag=True):
    """Save a numpy array, optionally gzip-compressed.

    Args:
        data_path: target path. With ``save_disk_flag`` the array is written
            to ``f'{data_path}.gz'``; otherwise ``np.save`` writes
            ``data_path`` directly (appending ``.npy`` if it is missing).
        data: array to save.
        save_disk_flag: trade speed for disk space via gzip compression.
    """

    if save_disk_flag:
        # Save space but slow. The context manager flushes and closes the
        # gzip stream even if np.save raises.
        with gzip.GzipFile(f"{data_path}.gz", "w") as f_data:
            np.save(file=f_data, arr=data)
    else:
        np.save(data_path, data)


def load_numpy(data_path, save_disk_flag=True):
    """Load an array written by ``save_numpy`` with the same flag.

    Args:
        data_path: path given to ``save_numpy``; with ``save_disk_flag`` the
            file actually read is ``f'{data_path}.gz'``.
        save_disk_flag: whether the file was gzip-compressed.

    Returns:
        The loaded numpy array.
    """

    if save_disk_flag:
        # Save space but slow. np.load reads the whole .npy payload, so the
        # gzip handle can be closed right away; this function runs in hot
        # loops, so leaving handles to the GC would leak descriptors.
        with gzip.GzipFile(f'{data_path}.gz', "r") as f_data:
            data = np.load(f_data)
    else:
        data = np.load(data_path)
    return data


# ########################## Cell 6 KR rerank sims  #########################


def build_graph(initial_rank):
    """Build and cache the k-reciprocal neighbour-weight row of every sample.

    For each sample ``i`` (queries first, then index images, in the order
    concatenated by ``cache_expand_sims``):

    1. take its top ``K1 + 1`` ranked samples and keep those whose own top
       ``K1 + 1`` list contains ``i`` (k-reciprocal neighbours);
    2. for each such neighbour, compute its reciprocal set using
       ``int(np.around(K1 / 2)) + 1`` candidates and merge it in when more
       than 2/3 of it already lies in ``i``'s reciprocal set;
    3. give every sample in the expanded set the weight ``exp(-dist)``,
       normalised to sum to 1; all other entries stay 0.

    Args:
        initial_rank: integer array; row ``i`` holds sample indices sorted by
            ascending cached distance to sample ``i``.

    Side effects:
        Reads ``EUC_DIST_DIR/{i:08d}.npy`` (uncompressed) and writes the float16
        row via ``save_numpy`` to ``GRAPH_DIST_DIR/{i:08d}.npy`` (gzipped).
    """

    K1 = setting['K1']
    if not os.path.exists(setting['GRAPH_DIST_DIR']):
        os.makedirs(setting['GRAPH_DIST_DIR'])

    torch.cuda.empty_cache()  # empty GPU memory
    gc.collect()
    print(f'Start build graph, memory: {psutil.virtual_memory().percent}')
    all_num = initial_rank.shape[0]
    for i in tqdm(range(all_num)):
        original_dist = load_numpy(os.path.join(setting['EUC_DIST_DIR'],
                                                f'{i:08d}.npy'),
                                   save_disk_flag=False)
        # Dense row over all samples; only expanded neighbours get a weight.
        V = np.zeros_like(original_dist, dtype=np.float16)
        # k-reciprocal neighbors
        forward_k_neigh_index = initial_rank[i, :K1+1]
        backward_k_neigh_index = initial_rank[forward_k_neigh_index, :K1+1]
        # Rows (forward neighbours) whose own top-(K1+1) list contains i.
        fi = np.where(backward_k_neigh_index == i)[0]
        k_reciprocal_index = forward_k_neigh_index[fi]
        k_reciprocal_expansion_index = k_reciprocal_index
        for j in range(len(k_reciprocal_index)):
            # Same reciprocal test for each neighbour, with a half-size list.
            candidate = k_reciprocal_index[j]
            candidate_forward_k_neigh_index = initial_rank[candidate, :int(np.around(K1/2.))+1]
            candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index, :int(np.around(K1/2.))+1]
            fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0]
            candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate]
            # Expand only with candidate sets that mostly (> 2/3) overlap i's set.
            if len(np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2./3*len(candidate_k_reciprocal_index):
                k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index, candidate_k_reciprocal_index)

        k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index)
        # exp(-dist) weights normalised to sum to 1 over the expanded set.
        weight = np.exp(-original_dist[k_reciprocal_expansion_index])
        norm_weight = weight/np.sum(weight)
        V[k_reciprocal_expansion_index] = norm_weight
        save_numpy(os.path.join(setting['GRAPH_DIST_DIR'], f'{i:08d}.npy'), V)
    print(f'Finish build graph, memory: {psutil.virtual_memory().percent}')


def expand_query(initial_rank):
    """Query expansion: average each sample's graph row over its top-K2 ranks.

    For sample ``i``, loads the cached graph rows of ``initial_rank[i, :K2]``
    (its K2 nearest samples by cached distance) from ``GRAPH_DIST_DIR``,
    averages them element-wise and writes the result gzipped to
    ``QE_DIST_DIR/{i:08d}.npy``.

    Args:
        initial_rank: integer array of per-sample rankings from
            ``cache_expand_sims``.
    """

    K2 = setting['K2']
    print(f'Start QE, memory usage: {psutil.virtual_memory().percent}')
    if not os.path.exists(setting['QE_DIST_DIR']):
        os.makedirs(setting['QE_DIST_DIR'])

    graph_dir = setting['GRAPH_DIST_DIR']
    for i in tqdm(range(len(initial_rank))):
        neighbour_rows = np.array([
            load_numpy(os.path.join(graph_dir, f'{j:08d}.npy'))
            for j in initial_rank[i, :K2]
        ])
        save_numpy(os.path.join(setting['QE_DIST_DIR'], f'{i:08d}.npy'),
                   np.mean(neighbour_rows, axis=0))
    print(f'Finish QE, memory usage: {psutil.virtual_memory().percent}')


def compute_jaccard(query_num, all_num):
    """Compute each query's Jaccard distance to all samples from QE rows.

    Every sample's (query-expanded) neighbour-weight row ``V`` is read from
    ``QE_DIST_DIR``. For query ``i`` and sample ``k`` the distance is
    ``1 - m / (2 - m)`` with ``m = sum_g min(V_i[g], V_k[g])`` over columns
    ``g`` where both rows are non-zero. Each query's distance row is written
    gzipped to ``JACCARD_DIR/{i:08d}.npy``.

    Args:
        query_num: number of queries (the first ``query_num`` samples).
        all_num: total number of samples (queries + index).
    """

    JACCARD_DIR = setting['JACCARD_DIR']
    QE_DIST_DIR = setting['QE_DIST_DIR']
    if not os.path.exists(JACCARD_DIR):
        os.makedirs(JACCARD_DIR)

    gc.collect()
    print(f'Start Jaccard, memory usage: {psutil.virtual_memory().percent}')

    def load_qe_row(sample_idx):
        """Load the cached QE row of one sample."""
        return load_numpy(os.path.join(QE_DIST_DIR, f'{sample_idx:08d}.npy'))

    # Inverted index: column g -> samples whose QE row is non-zero at g.
    inv_index = [[] for _ in range(all_num)]
    # Non-zero columns of each query row; only queries are needed below.
    query_nonzero = []
    for k in range(all_num):
        indexes = np.where(load_qe_row(k) != 0)[0]
        for gal in indexes:
            if gal < all_num:
                inv_index[gal].append(k)
        if k < query_num:
            query_nonzero.append(indexes)

    for i in tqdm(range(query_num)):
        temp_min = np.zeros(shape=[1, all_num], dtype=np.float16)
        # Loop-invariant: load the query's own row once, not once per column.
        query_dist = load_qe_row(i)
        for col in query_nonzero[i]:
            neighbours = inv_index[col]
            neighbour_vals = np.array([load_qe_row(ind)[col]
                                       for ind in neighbours])
            temp_min[0, neighbours] = temp_min[0, neighbours] + \
                np.minimum(query_dist[col], neighbour_vals)
        jaccard_dist = 1-temp_min/(2-temp_min)
        jaccard_dist = jaccard_dist.flatten()
        save_numpy(os.path.join(JACCARD_DIR, f'{i:08d}.npy'), jaccard_dist)
    print(f'Finish Jaccard, memory usage: {psutil.virtual_memory().percent}')


def merge_sims(query_num, all_num):
    """Blend Jaccard and original distances per query and convert to sims.

    For query ``i`` the cached original distance row (plain ``.npy``) and the
    Jaccard row (gzipped) are combined as
    ``jaccard * (1 - LAMBDA) + original * LAMBDA``. Only the index columns
    (after the first ``query_num``) are kept.

    Returns:
        float16 array of shape ``(query_num, all_num - query_num)`` holding
        ``1 - merged distance``.
    """

    print(f'Start merge sim, memory usage: {psutil.virtual_memory().percent}')
    euc_dir = setting['EUC_DIST_DIR']
    jac_dir = setting['JACCARD_DIR']
    lam = setting['LAMBDA']

    merged_dist = np.zeros((query_num, all_num - query_num), dtype=np.float16)
    for row in range(query_num):
        file_name = f'{row:08d}.npy'
        orig_row = load_numpy(os.path.join(euc_dir, file_name),
                              save_disk_flag=False)
        jac_row = load_numpy(os.path.join(jac_dir, file_name))
        blended = jac_row*(1-lam) + orig_row*lam
        merged_dist[row] = blended[query_num:]
    print(f'Finish merge sim, memory usage: {psutil.virtual_memory().percent}')
    return 1.0 - merged_dist


def get_origin_sims(query_num, all_num):
    """Turn cached original distance rows into query x index similarities.

    Returns:
        float16 array of shape ``(query_num, all_num - query_num)`` holding
        ``1 - distance`` for the index columns of each query row.
    """

    print(f'Start original, memory usage: {psutil.virtual_memory().percent}')
    euc_dir = setting['EUC_DIST_DIR']

    merged_dist = np.zeros((query_num, all_num - query_num), dtype=np.float16)
    for row in range(query_num):
        row_path = os.path.join(euc_dir, f'{row:08d}.npy')
        merged_dist[row] = load_numpy(row_path,
                                      save_disk_flag=False)[query_num:]
    print(f'Finish original, memory usage: {psutil.virtual_memory().percent}')
    return 1.0 - merged_dist


def cache_expand_sims(probe_feats, index_feats):
    """Cache per-sample distance rows over queries + index for KR re-ranking.

    Probe and index features are concatenated (queries first). For every
    sample, the distance ``(1 - cos_sim) / 2`` to all samples is written
    uncompressed to ``EUC_DIST_DIR/{i:08d}.npy``, and the indices of its
    nearest samples are recorded.

    Args:
        probe_feats: query feature tensor (one row per query), presumably
            L2-normalised so the dot product is a cosine similarity.
        index_feats: index feature tensor with the same feature dimension.

    Returns:
        int32 array ``initial_rank`` of shape
        ``(all_num, min(K1 + 10, all_num))``. Row ``i`` lists sample indices
        by ascending distance to sample ``i``. It is also saved (gzipped) to
        ``INITIAL_RANK_FILE``.
    """

    if not os.path.exists(setting['EUC_DIST_DIR']):
        os.makedirs(setting['EUC_DIST_DIR'])

    query_num = probe_feats.shape[0]
    index_num = index_feats.shape[0]
    all_num = query_num + index_num
    # Keep K1 + 10 neighbours, but never more than there are samples:
    # a shorter argsort row could not be assigned into a wider array.
    rank_width = min(setting['K1'] + 10, all_num)
    initial_rank = np.zeros((all_num, rank_width), dtype=np.int32)
    concat_feat = torch.cat([probe_feats, index_feats])
    del probe_feats, index_feats
    torch.cuda.empty_cache()  # empty GPU memory
    gc.collect()
    print(f'Load feats memory usage: {psutil.virtual_memory().percent}')
    for sample_index in tqdm(range(all_num)):
        cos_sim = torch.matmul(concat_feat[sample_index][None, :],
                               concat_feat.T)
        # Custom distance without sqrt/norm: maps cos in [-1, 1] to [1, 0].
        euc_dist_gpu = 1.0 - (cos_sim + 1.0)/2
        euc_dist_gpu = euc_dist_gpu[0]
        euc_dist_cpu = euc_dist_gpu.cpu().numpy()
        orders = torch.argsort(euc_dist_gpu)
        orders = orders.cpu().numpy()[:rank_width]
        initial_rank[sample_index, :] = orders
        save_numpy(os.path.join(setting['EUC_DIST_DIR'],
                                f'{sample_index:08d}.npy'),
                   euc_dist_cpu, save_disk_flag=False)
        del cos_sim, euc_dist_gpu, euc_dist_cpu, orders
    save_numpy(setting['INITIAL_RANK_FILE'], initial_rank)
    del concat_feat
    torch.cuda.empty_cache()  # empty GPU memory
    gc.collect()
    print(f'Finish cache sim, memory usage: {psutil.virtual_memory().percent}')
    return initial_rank


def kr_rerank_disk(probe_feats, index_feats):
    """Memory-efficient k-reciprocal re-ranking backed by on-disk caches.

    Pipeline: cached distance rows -> k-reciprocal graph rows -> (optional)
    query expansion -> Jaccard distances -> blend with original distances.

    Args:
        probe_feats: query feature tensor.
        index_feats: index feature tensor.

    Returns:
        Query x index similarity array from ``merge_sims``.
    """

    print('Starting re_ranking')
    initial_rank = cache_expand_sims(probe_feats, index_feats)
    query_num = len(probe_feats)
    all_num = len(initial_rank)
    build_graph(initial_rank)
    if setting['K2'] != 1:
        expand_query(initial_rank)
    else:
        # compute_jaccard always reads QE_DIST_DIR. Without query expansion
        # it must see the un-expanded graph rows, so copy them over
        # (save_numpy stores them as '<name>.npy.gz').
        import shutil
        graph_dir = setting['GRAPH_DIST_DIR']
        qe_dir = setting['QE_DIST_DIR']
        if os.path.abspath(graph_dir) != os.path.abspath(qe_dir):
            os.makedirs(qe_dir, exist_ok=True)
            for i in range(all_num):
                file_name = f'{i:08d}.npy.gz'
                shutil.copyfile(os.path.join(graph_dir, file_name),
                                os.path.join(qe_dir, file_name))
    print(f'Memory usage: {psutil.virtual_memory().percent}')
    del initial_rank
    gc.collect()
    print(f'Memory usage: {psutil.virtual_memory().percent}')
    compute_jaccard(query_num, all_num)
    sims = merge_sims(query_num, all_num)
    return sims


# ########################## Cell 7 Refset extraction  #######################


def load_meta(backbone):
    """Load ref-set metadata plus ``backbone``'s ref features onto the GPU.

    Returns the dict pickled by ``prepare_meta``, with ``ref_name_list``
    removed and these entries added:

    * ``ref_feats``: fp16 CUDA tensor, one row per ref image in
      ``ref_name_list`` order;
    * ``ref_gbid``: CUDA tensor of global ids, iterated in
      ``refindex2globalid`` order;
    * ``ref_loc``: CUDA tensor of ``REF_LOC_MAP[global_id]`` in the same order.
    """

    with open(setting['REF_SET_META'], 'rb') as f_meta:
        ref_meta = pickle.load(f_meta)
    print(f'Load ref_set_meta, memory: {psutil.virtual_memory().percent}')
    pkl_name = setting['REF_SET_FEAT'].replace('.pkl', f'_{backbone}.pkl')
    # Context manager so the (large) feature file handle is closed promptly.
    with open(pkl_name, 'rb') as f_feat:
        refset_feat_dic = pickle.load(f_feat)
    print(f'Load raw ref_set_feat, memory: {psutil.virtual_memory().percent}')
    ref_feats_gpu = []
    ref_feats = []
    batch_size = 128
    # Move features to the GPU in chunks of batch_size, and drop each entry
    # from the host dict once consumed, to cap host memory.
    for ref_name in ref_meta['ref_name_list']:
        ref_feats.append(refset_feat_dic[ref_name])
        if len(ref_feats) % batch_size == 0:
            ref_feats = np.array(ref_feats)
            ref_feats = torch.tensor(ref_feats).cuda().half()
            ref_feats_gpu.append(ref_feats)
            ref_feats = []
        del refset_feat_dic[ref_name]
    if len(ref_feats) > 0:
        ref_feats = np.array(ref_feats)
        ref_feats = torch.tensor(ref_feats).cuda().half()
        ref_feats_gpu.append(ref_feats)
    del refset_feat_dic
    del ref_meta['ref_name_list']
    ref_meta['ref_feats'] = torch.cat(ref_feats_gpu, dim=0)
    print(f'Convert ref_set_feat, memory: {psutil.virtual_memory().percent}')

    with open(setting['REF_LOC_MAP'], 'rb') as f:
        loc_map = pickle.load(f)
    ref_gbid = []
    ref_loc = []
    refindex2globalid = ref_meta['refindex2globalid']
    for refindex in refindex2globalid:
        gbid = refindex2globalid[refindex]
        ref_gbid.append(gbid)
        loc = loc_map[gbid]
        ref_loc.append(loc)
    ref_gbid = torch.tensor(ref_gbid).cuda()
    ref_loc = torch.tensor(ref_loc).cuda()
    ref_meta['ref_gbid'] = ref_gbid
    ref_meta['ref_loc'] = ref_loc
    print(f'Convert other metas, memory: {psutil.virtual_memory().percent}')
    return ref_meta


def prepare_meta():
    """Build and pickle the reference-set metadata shared by all backbones.

    Reads ``(path, global_id)`` items from ``REF_SET_LIST``, skips
    non-landmarks (negative ids) and assigns consecutive ref indices in list
    order. Writes a dict to ``REF_SET_META`` with:

    * ``globalid2refindex``: global id -> list of ref indices;
    * ``refindex2globalid``: ref index -> global id;
    * ``ref_name_list``: image names (basename without ``.jpg``) by ref index.
    """

    # NOTE: The order for ref_path_list, global_id serve as category name
    print('Using %s as ref set' % setting['REF_SET_LIST'])
    with open(setting['REF_SET_LIST'], 'rb') as f_ref:
        ref_set_list = pickle.load(f_ref)
    refindex2globalid, globalid2refindex = OrderedDict(), OrderedDict()
    ref_name_list = []
    global_count = 0
    for item in ref_set_list:
        ref_path = item[0]
        image_name = os.path.basename(ref_path).split('.jpg')[0]
        global_id = item[1]
        # Ignore non-landmarks
        if global_id < 0:
            continue
        # Sanity check on the id range (203094 is presumably the number of
        # landmark classes - TODO confirm).
        assert global_id < 203094
        ref_name_list.append(image_name)
        globalid2refindex.setdefault(global_id, []).append(global_count)
        refindex2globalid[global_count] = global_id
        global_count += 1
    print(f'{len(globalid2refindex)} unique global ids')
    print(f'{global_count} ref images')
    save_dic = {'globalid2refindex': globalid2refindex,
                'refindex2globalid': refindex2globalid,
                'ref_name_list': ref_name_list}
    dirname = os.path.dirname(setting['REF_SET_META'])
    # A bare file name has dirname '' and os.makedirs('') would raise.
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    with open(setting['REF_SET_META'], 'wb') as f_meta:
        pickle.dump(save_dic, f_meta, pickle.HIGHEST_PROTOCOL)


def extract_refset(reid, backbone):
    """Extract and pickle features for every image in ``REF_ALL_LIST``.

    Args:
        reid: a ``ReID_Inference`` instance for ``backbone``.
        backbone: model key; the output is
            ``REF_SET_FEAT`` with ``.pkl`` replaced by ``_{backbone}.pkl``,
            holding a dict of image name (basename without ``.jpg``) -> feature.
    """

    with open(setting['REF_ALL_LIST'], 'rb') as f_ref:
        ref_set_list = pickle.load(f_ref)
    ref_path_list = [entry[0] for entry in ref_set_list]

    tic = time.time()
    features_by_path = reid.extract(ref_path_list)
    save_dic = {os.path.basename(path).split('.jpg')[0]: feat
                for path, feat in features_by_path.items()}
    print('%.4f s' % (time.time() - tic))
    out_name = setting['REF_SET_FEAT'].replace('.pkl', f'_{backbone}.pkl')
    with open(out_name, 'wb') as f_pkl:
        pickle.dump(save_dic, f_pkl, pickle.HIGHEST_PROTOCOL)
    print(f'Extract refset Memory: {psutil.virtual_memory().percent}')


# ########################## Cell 8 Get output file  #########################


def compute_sim(backbone):
    """Compute and cache the query x index similarity matrix for one backbone.

    Loads the pickled probe/index features, moves them to the GPU in fp16 and
    computes similarities in one of three ways:

    * ``CATEGORY_RERANK == 'before_merge'``: dot-product similarities
      re-ranked with reference-set tags and locations;
    * otherwise, if ``KR_FLAG``: disk-backed k-reciprocal re-ranking;
    * otherwise: plain dot-product similarities.

    Writes ``SIMS_DIR/{backbone}_sims.pkl`` (numpy array) and the probe/index
    name lists to ``NAME_LIST_FILE``.
    """

    torch.cuda.empty_cache()  # empty GPU memory
    gc.collect()
    print(f'Start cache sim, memory usage: {psutil.virtual_memory().percent}')
    sims_dir = setting['SIMS_DIR']
    if not os.path.exists(sims_dir):
        os.makedirs(sims_dir)

    probe_name_list, probe_feats = load_feat('probe', backbone)
    index_name_list, index_feats = load_feat('index', backbone)
    print(len(probe_name_list), len(index_name_list))

    probe_feats = torch.tensor(probe_feats).cuda().half()
    index_feats = torch.tensor(index_feats).cuda().half()

    kr_flag = setting['KR_FLAG']
    category_mode = setting['CATEGORY_RERANK']
    if category_mode == 'before_merge':
        ref_info = load_meta(backbone)
        ref_info['probe_name_list'] = probe_name_list
        ref_info['index_name_list'] = index_name_list

        sims = torch.matmul(probe_feats, index_feats.T)
        index_tags, index_locs = get_index_tags(index_feats, ref_info,
                                                batch_size=128)
        (probe_tags, probe_tag_scores,
         probe_locs, probe_loc_scores) = get_probe_tags_topk(probe_feats,
                                                             ref_info)
        sims = category_rerank_after_merge(sims, probe_tags, probe_locs,
                                           probe_tag_scores, probe_loc_scores,
                                           index_tags, index_locs,
                                           sim_thr=setting['CATEGORY_THR'],
                                           alpha=setting['alpha'],
                                           beta=setting['beta'])
    elif kr_flag:
        # NOTE: KR rerank seems not suitable here.
        sims = kr_rerank_disk(probe_feats, index_feats)
    else:
        sims = torch.matmul(probe_feats, index_feats.T).cpu().numpy()

    if torch.is_tensor(sims):
        sims = sims.cpu().numpy()
    with open(os.path.join(sims_dir, f'{backbone}_sims.pkl'), 'wb') as f_sims:
        pickle.dump(sims, f_sims, pickle.HIGHEST_PROTOCOL)
    # NOTE: It is important to fix this order for all models.
    with open(setting['NAME_LIST_FILE'], 'wb') as f_name:
        pickle.dump([probe_name_list, index_name_list], f_name)


def get_output():
    """Fuse per-backbone similarity matrices and write the submission CSV.

    Loads ``SIMS_DIR/{backbone}_sims.pkl`` for each entry of ``MODEL_LIST``
    and sums them with the matching ``MODEL_WEIGHT``. When
    ``CATEGORY_RERANK == 'after_merge'``, the fused matrix is re-ranked on the
    GPU once per backbone using that backbone's features and ref-set metadata.

    Raises:
        ValueError: if ``MODEL_LIST`` yields no similarity matrix to fuse.
    """
    print(f'Get output start, memory: {psutil.virtual_memory().percent}')
    with open(setting['NAME_LIST_FILE'], 'rb') as f_name:
        [probe_name_list, index_name_list] = pickle.load(f_name)
    sims = None
    for backbone, weight in zip(setting['MODEL_LIST'], setting['MODEL_WEIGHT']):
        pkl_name = os.path.join(setting['SIMS_DIR'], f'{backbone}_sims.pkl')
        with open(pkl_name, 'rb') as f_sims:
            backbone_sims = pickle.load(f_sims)
        print(f"backbone: {backbone}, weight: {weight}")
        if sims is None:
            sims = weight * backbone_sims
        else:
            sims += weight * backbone_sims
        del backbone_sims
    if sims is None:
        raise ValueError('No similarity matrices to fuse: MODEL_LIST is empty')
    print(f'Sim Fusion Done, memory: {psutil.virtual_memory().percent}')

    if setting['CATEGORY_RERANK'] == 'after_merge':
        sims = torch.tensor(sims).cuda().half()
        for backbone in setting['MODEL_LIST']:
            print('Computing category rerank after merge')
            probe_name_list, probe_feats = load_feat('probe', backbone)
            index_name_list, index_feats = load_feat('index', backbone)
            print(len(probe_name_list), len(index_name_list))
            ref_info = load_meta(backbone)
            ref_info['probe_name_list'] = probe_name_list
            ref_info['index_name_list'] = index_name_list
            probe_feats = torch.tensor(probe_feats).cuda().half()
            index_feats = torch.tensor(index_feats).cuda().half()
            index_tags, index_locs = get_index_tags(index_feats, ref_info, batch_size=128)
            probe_tags, probe_tag_scores, probe_locs, probe_loc_scores = get_probe_tags_topk(probe_feats, ref_info)
            sims = category_rerank_after_merge(sims, probe_tags, probe_locs,
                                               probe_tag_scores, probe_loc_scores,
                                               index_tags, index_locs,
                                               sim_thr=setting['CATEGORY_THR'],
                                               alpha=setting['alpha'],
                                               beta=setting['beta'])
            # Release this backbone's GPU tensors before the next backbone's
            # features and ref-set metadata are loaded.
            del ref_info, probe_feats, index_feats, index_tags, index_locs
            del probe_tags, probe_tag_scores, probe_locs, probe_loc_scores
            torch.cuda.empty_cache()
        print(f'Tag Rerank for each model done!, memory: {psutil.virtual_memory().percent}')
        sims = sims.cpu().numpy()

    print(f'Get sims Memory: {psutil.virtual_memory().percent}')
    write_csv(probe_name_list, index_name_list, sims)


def main():
    """Run the pipeline: features -> per-backbone similarities -> submission.

    Copies ``sample_submission.csv`` instead when ``load_image_list`` signals
    a dummy run. For each backbone the model is loaded at most once, used for
    query/index features and (if enabled) the reference set, then released
    before ``compute_sim`` runs.
    """

    print(f'Init Memory usage: {psutil.virtual_memory().percent}')
    image_list, query_count = load_image_list()
    print(f'load image Memory usage: {psutil.virtual_memory().percent}')
    if image_list is None and query_count is None:
        print('Dummy submission!')
        shutil.copyfile(os.path.join(setting['IMAGE_DIR'],
                                     'sample_submission.csv'),
                        'submission.csv')
        return
    if setting['DEBUG_FLAG']:
        debug_reid_inference(image_list)
    if setting['CATEGORY_RERANK'] != 'off':
        # meta info shared by all models
        prepare_meta()
    for backbone in setting['MODEL_LIST']:
        # One model instance per backbone, shared by both extraction steps,
        # so the same weights are never resident on the GPU twice.
        reid = None
        if setting['REID_EXTRACT_FLAG']:
            reid = ReID_Inference(backbone)
            print(f'Load model, memory: {psutil.virtual_memory().percent}')
            start_time = time.time()
            feature_dic = reid.extract(image_list)
            print('%.4f s for %s' % ((time.time() - start_time), backbone))
            print(f'Extract feature Memory: {psutil.virtual_memory().percent}')
            save_feature(feature_dic, backbone)
            print(f'Save feature Memory: {psutil.virtual_memory().percent}')
        if setting['REF_SET_EXTRACT']:
            # These should be offline calculated.
            print('Extract refset feature')
            if reid is None:
                reid = ReID_Inference(backbone)
                print(f'Load model, memory: {psutil.virtual_memory().percent}')
            extract_refset(reid, backbone)
            print(f'Extract refset Memory: {psutil.virtual_memory().percent}')
        # Drop the model before compute_sim (which empties the CUDA cache)
        # and before the next backbone is loaded.
        del reid
        compute_sim(backbone)
    print(f'Compute sim Memory: {psutil.virtual_memory().percent}')
    get_output()


# Script entry point: run the whole pipeline when executed directly.
if __name__ == '__main__':
    main()

Total 1 gpu cards with 17071734784 memory
Init Memory usage: 5.3
load image Memory usage: 5.3
Dummy submission!
