In [None]:
import json
import numpy as np
from tqdm import tqdm
from pprint import pprint
from sklearn.cluster import KMeans

from label_processor import CStdLib

In [None]:
# Read the per-product logits dump (UTF-8 encoded JSON) into memory.
with open('result_logits.json', encoding='utf-8') as logits_file:
    logits_json = json.load(logits_file)

In [None]:
# Tag embedder from the project-local `label_processor` module. Later in this notebook
# it is called as `embed(tag)` and the result is used to index columns of the logits.
# NOTE(review): the meaning of `single=False` is not visible here -- confirm in CStdLib.
embed = CStdLib(single=False)

In [None]:
# For every product: cluster its images by their logits, score every cluster against
# every optional tag, and give each cluster its own tag (one-to-one assignment).
#
# NOTE: this cell replaces the logit strings stored in `logits_json` with the chosen
# tag, so it is not idempotent -- reload `logits_json` before running it a second time.
# NOTE(review): KMeans raises if a product has fewer images than optional tags; that
# case is not handled here.
pbar = tqdm(logits_json.items())
for key, prod_dict in pbar:
    optional_tags = prod_dict['optional_tags']
    num_class = len(optional_tags)
    imgs_tags = prod_dict['imgs_tags']

    # Entry number n maps '<key>_<n>.jpg' to a space-separated string of floats.
    imgs_logits = np.array([
        np.fromstring(img_entry[key + '_{}.jpg'.format(img_idx)], dtype=np.float32, sep=' ')
        for img_idx, img_entry in enumerate(imgs_tags)
    ])

    # Cluster the images into as many groups as there are optional tags.
    cluster = KMeans(n_clusters=num_class, random_state=42).fit(imgs_logits)
    labels = cluster.labels_

    # match_score[c, t] is how well k-means cluster c matches optional_tags[t].
    match_score = np.zeros((num_class, num_class))
    embeded_tags = [embed(tag) for tag in optional_tags]
    not_matched = False
    for cluster_id in range(num_class):
        cluster_logits = imgs_logits[labels == cluster_id]
        for tag_id, logits_ids in enumerate(embeded_tags):
            if len(logits_ids) == 0:
                # The tag has no logit indices: use a placeholder score and remember
                # to print diagnostics for this product.
                not_matched = True
                match_score[cluster_id, tag_id] = 1
            elif len(cluster_logits) > 0:
                match_score[cluster_id, tag_id] = np.mean(cluster_logits[:, logits_ids])
            # else: the cluster is empty (k-means can return fewer distinct clusters
            # than requested when images share identical logits). Keep the score at 0:
            # averaging an empty slice would yield NaN, and the column normalisation
            # below would then spread that NaN over the whole matrix.

    # Normalise so that every column has unit L2 norm, then take the per-row maximum as
    # that cluster's tag. Normalising by column and picking by row makes the selection
    # two-way. A higher-order norm raises the relative weight of tags that were not
    # matched, which helps acc/em.
    match_score = np.maximum(match_score, 0.) + 1e-7  # keep positives only; 1e-7 avoids a divide warning
    match_score_norm = match_score / np.linalg.norm(match_score, axis=0, ord=2)
    label_id = np.argmax(match_score_norm, axis=1)
    label_score = np.max(match_score_norm, axis=1)

    # Force every tag to be used exactly once: when several clusters pick the same
    # tag, the one with the highest score keeps it (the later cluster wins a tie) and
    # the others are marked -1.
    selected_label_id = set()
    for row in range(num_class):
        if label_id[row] == -1:
            continue
        selected_label_id.add(int(label_id[row]))
        for other in range(row + 1, num_class):
            if label_id[row] != label_id[other]:
                continue
            if label_score[row] > label_score[other]:
                label_id[other] = -1
            else:
                label_id[row] = -1
                break  # this row lost its tag; nothing further to compare

    # Hand the tags nobody picked to the clusters that lost theirs, in ascending order.
    unselected = sorted(set(range(num_class)) - selected_label_id)
    next_free = 0
    for row in range(num_class):
        if label_id[row] == -1:
            label_id[row] = unselected[next_free]
            next_free += 1

    # Replace each image's logits string with the tag assigned to its cluster.
    for cluster_id in range(num_class):
        for img_idx in np.where(labels == cluster_id)[0]:
            logits_json[key]['imgs_tags'][img_idx][key + '_{}.jpg'.format(img_idx)] = optional_tags[label_id[cluster_id]]
    if not_matched:
        print(key, '\n', match_score, '\n', match_score_norm)
        pprint(logits_json[key])

In [None]:
# Persist the relabelled products. `ensure_ascii=False` writes non-ASCII tags verbatim,
# so the file encoding is pinned to UTF-8 (the same encoding the input was read with)
# instead of relying on the platform's locale default, which can raise
# UnicodeEncodeError or produce a file in a different encoding.
with open('./result_labels.json', 'w', encoding='utf-8') as f:
    json.dump(logits_json, f, indent=4, ensure_ascii=False)