In [None]:
import json
import numpy as np
from scipy.special import softmax
from tqdm import tqdm
from pprint import pprint
from sklearn.cluster import KMeans
from munkres import Munkres

from label_processor import CStdLib

In [None]:
# Load the JSON file of per-product results. Later cells read, for every
# product key, the entries 'optional_tags' and 'imgs_tags' from this mapping.
with open('result_logits.json', encoding='utf-8') as logits_file:
    logits_json = json.load(logits_file)

In [None]:
# Instantiate CStdLib (imported from label_processor). Further down the
# object is called as embed(tag) for each optional tag, and the returned
# sequence is used as column indices into the per-image logits array.
# NOTE(review): the effect of single=False is defined inside CStdLib and is
# not visible in this notebook -- confirm against label_processor.
embed = CStdLib(single=False)

In [None]:
# Assign one of each product's optional tags to every image of that product:
#   1. cluster the per-image logits with k-means,
#   2. score every (cluster, tag) pair by the mean logit over the columns that
#      embed(tag) returns for that tag,
#   3. solve the cluster -> tag assignment with the Hungarian algorithm,
#   4. replace each image's logits string in `logits_json` with the chosen tag.
#
# NOTE: step 4 overwrites the raw logits in place, so this cell is not
# idempotent -- re-run the cell that loads 'result_logits.json' before running
# this one a second time.
pbar = tqdm(logits_json.items())
munkres = Munkres()
for key, prod_dict in pbar:
    optional_tags = prod_dict['optional_tags']
    num_class = len(optional_tags)
    imgs_tags = prod_dict['imgs_tags']
    num_imgs = len(imgs_tags)
    if num_class == 0 or num_imgs == 0:
        # Nothing to cluster or nothing to assign; KMeans would raise here and
        # abort the whole run, so report the product and move on.
        print(key, 'skipped: no optional tags or no images')
        continue

    # Entry number n of imgs_tags is a dict keyed by '<key>_<n>.jpg' whose
    # value is a space-separated string of logits.
    imgs_logits = np.array([
        np.fromstring(img_entry[key + '_{}.jpg'.format(img_idx)], dtype=np.float32, sep=' ')
        for img_idx, img_entry in enumerate(imgs_tags)
    ])

    # Clustering. KMeans requires n_clusters <= n_samples, so cap the number
    # of clusters at the number of images (unchanged whenever there are at
    # least as many images as tags).
    num_clusters = min(num_class, num_imgs)
    cluster = KMeans(n_clusters=num_clusters, random_state=42).fit(imgs_logits)
    labels = cluster.labels_
    cluster_members = [np.where(labels == cluster_idx)[0] for cluster_idx in range(num_clusters)]

    # Matching scores: match_score[c, t] is how well k-means cluster c matches
    # optional_tags[t].
    match_score = np.zeros((num_clusters, num_class))
    embeded_tags = [embed(tag) for tag in optional_tags]
    not_matched = False
    for tag_idx, logits_ids in enumerate(embeded_tags):
        if len(logits_ids) == 0:
            # embed() returned no columns for this tag: give every cluster the
            # same score of 1 and remember to dump debugging output below.
            not_matched = True
            match_score[:, tag_idx] = 1
            continue
        for cluster_idx, members in enumerate(cluster_members):
            if len(members) == 0:
                # An empty cluster has no logits to average; np.mean would
                # yield NaN and poison the normalisation, so keep the score 0.
                continue
            match_score[cluster_idx, tag_idx] = np.mean(imgs_logits[members][:, logits_ids])

    # Clamp negative scores to zero (ReLU), add a small epsilon so no column
    # sums to zero, then normalise every column by its L1 norm.
    match_score = np.maximum(match_score, 0.) + 1e-7
    match_score_norm = match_score / np.linalg.norm(match_score, axis=0, ord=1)
    # match_score_norm = softmax(match_score, axis=0)

    # Hungarian algorithm for the optimal assignment. The cost matrix is
    # passed as a plain list of lists so the solver works on its own copy
    # instead of mutating views of the NumPy array.
    risk_matrix = 1. - match_score_norm
    munkres_result = munkres.compute(risk_matrix.tolist())
    # Explicit row -> column lookup rather than relying on result ordering.
    tag_of_cluster = dict(munkres_result)

    # Write the chosen tag back into the JSON structure for every image.
    for cluster_idx, members in enumerate(cluster_members):
        label = optional_tags[tag_of_cluster[cluster_idx]]
        for img_idx in members:
            logits_json[key]['imgs_tags'][img_idx][key + '_{}.jpg'.format(img_idx)] = label
    if not_matched:
        # At least one tag had no logit columns: show the scores and the
        # resulting labels for manual inspection.
        print(key, '\n', match_score, '\n', match_score_norm)
        pprint(logits_json[key])

In [None]:
# Persist the labelled results. The encoding is given explicitly because
# ensure_ascii=False writes non-ASCII characters verbatim; without it open()
# uses the platform's locale encoding, which can raise UnicodeEncodeError or
# produce a file that is not UTF-8 (the input file is read as UTF-8).
with open('./result_labels.json', 'w', encoding='utf-8') as f:
    json.dump(logits_json, f, indent=4, ensure_ascii=False)