In [6]:
import json
import numpy as np
from scipy.special import softmax
from tqdm import tqdm
from pprint import pprint
from sklearn.cluster import KMeans
from munkres import Munkres

from label_processor import CStdLib

In [7]:
# Read the per-product JSON results (parsed as UTF-8) into `logits_json`;
# the cells below iterate over its items and later overwrite its image entries.
with open('result_logits.json', encoding='utf-8') as logits_file:
    logits_json = json.load(logits_file)

In [8]:
# Tag lookup used by the matching cell: `embed(tag)` is called on every optional
# tag, and its result is used to index columns of the per-image logits array.
# NOTE(review): CStdLib comes from the project-local `label_processor` module; what
# `single=False` toggles is not visible from this notebook — confirm against that module.
embed = CStdLib(single=False)

In [9]:
# For every product: cluster its images by their logits, score each cluster against
# each candidate tag, solve the one-to-one cluster/tag assignment, and overwrite the
# stored logits string of every image with the tag assigned to its cluster.
pbar = tqdm(logits_json.items())
munkres = Munkres()
for key, prod_dict in pbar:
    optional_tags = prod_dict['optional_tags']
    imgs_tags = prod_dict['imgs_tags']
    num_class = len(optional_tags)

    # Each image entry maps '<key>_<index>.jpg' to a space-separated string of floats.
    imgs_logits = np.array([
        np.fromstring(img_entry[key + '_{}.jpg'.format(idx)], dtype=np.float32, sep=' ')
        for idx, img_entry in enumerate(imgs_tags)
    ])

    # Clustering: one cluster per candidate tag.
    labels = KMeans(n_clusters=num_class, random_state=42).fit(imgs_logits).labels_
    # Image indices belonging to each cluster, computed once and reused below.
    cluster_members = [np.where(labels == cluster_id)[0] for cluster_id in range(num_class)]

    # Matching scores: match_score[i, j] is the affinity between KMeans cluster i
    # and optional_tags[j].
    embeded_tags = [embed(tag) for tag in optional_tags]
    match_score = np.zeros((num_class, num_class))
    not_matched = False
    for cluster_id, member_ids in enumerate(cluster_members):
        cluster_logits = imgs_logits[member_ids]
        for tag_id, tag_columns in enumerate(embeded_tags):
            if len(tag_columns) == 0:
                # The tag has no logit columns: give every cluster the same neutral
                # score and remember to print diagnostics for this product.
                not_matched = True
                match_score[cluster_id, tag_id] = 1
                continue
            # Mean logit over the cluster's images, restricted to the tag's columns.
            match_score[cluster_id, tag_id] = np.mean(cluster_logits[:, tag_columns])

    # ReLU plus a small epsilon, then normalise each column so that it sums to 1.
    match_score = np.maximum(match_score, 0.) + 1e-7
    column_totals = np.linalg.norm(match_score, axis=0, ord=1)
    match_score_norm = match_score / column_totals

    # Hungarian algorithm on the complementary cost gives the optimal assignment.
    risk_matrix = 1. - match_score_norm
    munkres_result = munkres.compute(risk_matrix)

    # Finally write the result back: `imgs_tags` is the very list stored under
    # logits_json[key]['imgs_tags'], so this updates logits_json in place.
    for cluster_id, member_ids in enumerate(cluster_members):
        label = optional_tags[munkres_result[cluster_id][1]]
        for img_idx in member_ids:
            imgs_tags[img_idx][key + '_{}.jpg'.format(img_idx)] = label

    if not_matched:
        print(key, '\n', match_score, '\n', match_score_norm)
        pprint(logits_json[key])

  5%|▌         | 279/5331 [00:03<01:09, 73.08it/s] 

624952669722 
 [[3.42698036 1.0000001 ]
 [5.39661608 1.0000001 ]] 
 [[0.38838816 0.5       ]
 [0.61161184 0.5       ]]
{'imgs_tags': [{'624952669722_0.jpg': '珊瑚色组合'},
               {'624952669722_1.jpg': '珊瑚色组合'},
               {'624952669722_2.jpg': '橄榄色组合'},
               {'624952669722_3.jpg': '橄榄色组合'},
               {'624952669722_4.jpg': '橄榄色组合'}],
 'optional_tags': ['珊瑚色组合', '橄榄色组合']}


  7%|▋         | 348/5331 [00:03<00:47, 105.65it/s]

614043736149 
 [[ 5.664219    6.80937682  3.0382763   1.0000001   5.91158305]
 [ 3.8386942   3.8388711  11.31113253  1.0000001   3.06997309]
 [ 3.59171877 11.78814326  4.34701214  1.0000001   2.77561102]
 [12.11357317  3.26983796  3.10478507  1.0000001   2.78206406]
 [ 4.20342074  3.02493964  2.04442153  1.0000001  12.64851389]] 
 [[0.19258436 0.23700313 0.1274144  0.2        0.21743558]
 [0.13051622 0.13361347 0.47434828 0.2        0.11291753]
 [0.12211901 0.41029111 0.18229808 0.2        0.10209052]
 [0.41186343 0.11380804 0.13020354 0.2        0.10232787]
 [0.14291698 0.10528425 0.0857357  0.2        0.4652285 ]]
{'imgs_tags': [{'614043736149_0.jpg': '条纹'},
               {'614043736149_1.jpg': '桔色'},
               {'614043736149_2.jpg': '红色'},
               {'614043736149_3.jpg': '桔色'},
               {'614043736149_4.jpg': '红色'},
               {'614043736149_5.jpg': '黑色'},
               {'614043736149_6.jpg': '条纹'},
               {'614043736149_7.jpg': '黑色'},
               {

 15%|█▌        | 813/5331 [00:08<00:50, 88.91it/s] 

565607486988 
 [[13.13959322  1.0000001 ]
 [12.38072119  1.0000001 ]] 
 [[0.514868 0.5     ]
 [0.485132 0.5     ]]
{'imgs_tags': [{'565607486988_0.jpg': '条纹'},
               {'565607486988_1.jpg': '条纹'},
               {'565607486988_2.jpg': '条纹'},
               {'565607486988_3.jpg': '黑色'},
               {'565607486988_4.jpg': '黑色'},
               {'565607486988_5.jpg': '黑色'},
               {'565607486988_6.jpg': '黑色'}],
 'optional_tags': ['黑色', '条纹']}


 30%|██▉       | 1591/5331 [00:16<00:35, 106.08it/s]

623845128856 
 [[ 5.21820126  1.0000001  15.01067553]
 [ 8.65206538  1.0000001   4.14062605]
 [ 7.18516217  1.0000001  14.37576685]] 
 [[0.24783163 0.33333333 0.4477181 ]
 [0.41091851 0.33333333 0.12350099]
 [0.34124986 0.33333333 0.42878091]]
{'imgs_tags': [{'623845128856_0.jpg': '黑色'},
               {'623845128856_1.jpg': '黑色'},
               {'623845128856_2.jpg': '条纹'},
               {'623845128856_3.jpg': '白色'},
               {'623845128856_4.jpg': '条纹'},
               {'623845128856_5.jpg': '白色'}],
 'optional_tags': ['黑色', '条纹', '白色']}
605311018824 
 [[ 5.49235306  1.0000001   8.07947264]
 [11.20709239  1.0000001   4.11887322]
 [11.5468627   1.0000001   7.90635548]] 
 [[0.19444499 0.33333333 0.40186982]
 [0.39676309 0.33333333 0.20487115]
 [0.40879193 0.33333333 0.39325904]]
{'imgs_tags': [{'605311018824_0.jpg': '黑色'},
               {'605311018824_1.jpg': '蓝色'},
               {'605311018824_2.jpg': '条纹'},
               {'605311018824_3.jpg': '蓝色'},
               {'605311

 31%|███▏      | 1679/5331 [00:17<00:37, 97.46it/s] 

621875084587 
 [[12.77419577  1.0000001 ]
 [11.76983366  1.0000001 ]] 
 [[0.52046042 0.5       ]
 [0.47953958 0.5       ]]
{'imgs_tags': [{'621875084587_0.jpg': '条纹'},
               {'621875084587_1.jpg': '条纹'},
               {'621875084587_2.jpg': '条纹'},
               {'621875084587_3.jpg': '条纹'},
               {'621875084587_4.jpg': '黑色'},
               {'621875084587_5.jpg': '黑色'},
               {'621875084587_6.jpg': '黑色'}],
 'optional_tags': ['黑色', '条纹']}


 43%|████▎     | 2298/5331 [00:23<00:25, 120.65it/s]

576888702197 
 [[ 8.49875937  1.0000001   5.05461607 11.74309931]
 [ 5.26506291  1.0000001  13.26082716  8.40398989]
 [ 5.55389033  1.0000001   7.07131587 14.75581751]
 [12.62910185  1.0000001   6.03299103  8.25590716]] 
 [[0.26602838 0.25       0.16087385 0.27209041]
 [0.16480713 0.25       0.42205387 0.19472245]
 [0.17384802 0.25       0.22505958 0.34189581]
 [0.39531647 0.25       0.1920127  0.19129134]]
{'imgs_tags': [{'576888702197_0.jpg': '黑色'},
               {'576888702197_1.jpg': '黑色'},
               {'576888702197_2.jpg': '条纹色'},
               {'576888702197_3.jpg': '白色'},
               {'576888702197_4.jpg': '灰色'},
               {'576888702197_5.jpg': '灰色'},
               {'576888702197_6.jpg': '条纹色'},
               {'576888702197_7.jpg': '黑色'},
               {'576888702197_8.jpg': '白色'}],
 'optional_tags': ['黑色', '条纹色', '灰色', '白色']}


 46%|████▋     | 2476/5331 [00:25<00:24, 114.62it/s]

603281751757 
 [[12.39502726  1.0000001 ]
 [13.52659331  1.0000001 ]] 
 [[0.47817332 0.5       ]
 [0.52182668 0.5       ]]
{'imgs_tags': [{'603281751757_0.jpg': '黑色'},
               {'603281751757_1.jpg': '黑色'},
               {'603281751757_2.jpg': '黑色'},
               {'603281751757_3.jpg': '拼色'},
               {'603281751757_4.jpg': '黑色'},
               {'603281751757_5.jpg': '拼色'},
               {'603281751757_6.jpg': '拼色'}],
 'optional_tags': ['黑色', '拼色']}
617278066189 
 [[6.24258242 1.0000001 ]
 [5.61380206 1.0000001 ]] 
 [[0.52651653 0.5       ]
 [0.47348347 0.5       ]]
{'imgs_tags': [{'617278066189_0.jpg': '条纹套装'},
               {'617278066189_1.jpg': '白色套装'},
               {'617278066189_2.jpg': '条纹套装'},
               {'617278066189_3.jpg': '白色套装'},
               {'617278066189_4.jpg': '白色套装'},
               {'617278066189_5.jpg': '白色套装'}],
 'optional_tags': ['白色套装', '条纹套装']}


 53%|█████▎    | 2826/5331 [00:28<00:21, 116.00it/s]

617521964984 
 [[ 3.91663704 11.03135024  1.0000001   6.58475028]
 [12.7886859   4.51802931  1.0000001   4.60156498]
 [ 4.80758391  7.85608731  1.0000001  10.33585177]
 [ 8.73006068 10.50558005  1.0000001   7.83476887]] 
 [[0.12950571 0.32530256 0.25       0.22429964]
 [0.42286478 0.13323179 0.25       0.15674541]
 [0.15896535 0.2316675  0.25       0.35207529]
 [0.28866416 0.30979816 0.25       0.26687965]]
{'imgs_tags': [{'617521964984_0.jpg': '灰色'},
               {'617521964984_1.jpg': '黑色'},
               {'617521964984_2.jpg': '黑色'},
               {'617521964984_3.jpg': '白色'},
               {'617521964984_4.jpg': '条纹'},
               {'617521964984_5.jpg': '白色'},
               {'617521964984_6.jpg': '灰色'},
               {'617521964984_7.jpg': '条纹'}],
 'optional_tags': ['黑色', '灰色', '条纹', '白色']}


 58%|█████▊    | 3088/5331 [00:30<00:18, 118.79it/s]

614491314779 
 [[13.74386893  1.0000001 ]
 [13.38741122  1.0000001 ]] 
 [[0.50656913 0.5       ]
 [0.49343087 0.5       ]]
{'imgs_tags': [{'614491314779_0.jpg': '黑白点'},
               {'614491314779_1.jpg': '竖条纹'},
               {'614491314779_2.jpg': '黑白点'},
               {'614491314779_3.jpg': '竖条纹'},
               {'614491314779_4.jpg': '竖条纹'},
               {'614491314779_5.jpg': '竖条纹'},
               {'614491314779_6.jpg': '黑白点'}],
 'optional_tags': ['黑白点', '竖条纹']}


 65%|██████▍   | 3442/5331 [00:33<00:16, 116.18it/s]

621931894053 
 [[ 1.0000001  12.00257979  8.67579184]
 [ 1.0000001   7.48503695 12.75935851]
 [ 1.0000001   8.4399129   9.68032275]] 
 [[0.33333333 0.42977592 0.27882564]
 [0.33333333 0.26801644 0.41006474]
 [0.33333333 0.30220764 0.31110961]]
{'imgs_tags': [{'621931894053_0.jpg': '白色半裙'},
               {'621931894053_1.jpg': '白色半裙'},
               {'621931894053_2.jpg': '白色半裙'},
               {'621931894053_3.jpg': '黑色半裙'},
               {'621931894053_4.jpg': '条纹衬衫'},
               {'621931894053_5.jpg': '白色半裙'},
               {'621931894053_6.jpg': '黑色半裙'}],
 'optional_tags': ['条纹衬衫', '黑色半裙', '白色半裙']}


 70%|██████▉   | 3714/5331 [00:35<00:13, 118.37it/s]

601450019157 
 [[1.0000001]] 
 [[1.]]
{'imgs_tags': [{'601450019157_0.jpg': '1906纯色'},
               {'601450019157_1.jpg': '1906纯色'},
               {'601450019157_2.jpg': '1906纯色'}],
 'optional_tags': ['1906纯色']}


 82%|████████▏ | 4381/5331 [00:41<00:08, 118.39it/s]

625820335127 
 [[11.47327147  1.0000001   8.58770476]
 [ 6.56241618  1.0000001  12.8673564 ]
 [13.04104815  1.0000001   7.14513979]] 
 [[0.36919165 0.33333333 0.30026729]
 [0.21116813 0.33333333 0.44990441]
 [0.41964022 0.33333333 0.24982831]]
{'imgs_tags': [{'625820335127_0.jpg': '白色'},
               {'625820335127_1.jpg': '白色'},
               {'625820335127_2.jpg': '条纹'},
               {'625820335127_3.jpg': '黑色'},
               {'625820335127_4.jpg': '条纹'},
               {'625820335127_5.jpg': '白色'},
               {'625820335127_6.jpg': '条纹'},
               {'625820335127_7.jpg': '黑色'}],
 'optional_tags': ['黑色', '条纹', '白色']}


 85%|████████▍ | 4529/5331 [00:42<00:06, 118.79it/s]

577541680180 
 [[14.86534319  1.0000001   6.96493731  2.01583348]
 [ 2.39070569  1.0000001   2.37925969 11.59864817]
 [ 4.04465399  1.0000001  11.47464762  3.51491008]
 [ 3.49112401  1.0000001   4.32417927  9.89799605]] 
 [[0.59960661 0.25       0.27701271 0.07458484]
 [0.0964312  0.25       0.09462902 0.42914425]
 [0.16314465 0.25       0.456375   0.13004994]
 [0.14081754 0.25       0.17198326 0.36622097]]
{'imgs_tags': [{'577541680180_0.jpg': '红色'},
               {'577541680180_1.jpg': '红色'},
               {'577541680180_2.jpg': '蓝色'},
               {'577541680180_3.jpg': '蓝色'},
               {'577541680180_4.jpg': '橙色'},
               {'577541680180_5.jpg': '深色'},
               {'577541680180_6.jpg': '橙色'},
               {'577541680180_7.jpg': '蓝色'},
               {'577541680180_8.jpg': '蓝色'}],
 'optional_tags': ['橙色', '深色', '红色', '蓝色']}


 90%|████████▉ | 4781/5331 [00:44<00:04, 124.01it/s]

612126042525 
 [[ 1.0000001  13.50281916]
 [ 1.0000001  12.25867472]] 
 [[0.5        0.52414737]
 [0.5        0.47585263]]
{'imgs_tags': [{'612126042525_0.jpg': '957黑色'},
               {'612126042525_1.jpg': '957黑色'},
               {'612126042525_2.jpg': '957黑色'},
               {'612126042525_3.jpg': '6802格子'},
               {'612126042525_4.jpg': '957黑色'},
               {'612126042525_5.jpg': '957黑色'},
               {'612126042525_6.jpg': '957黑色'}],
 'optional_tags': ['6802格子', '957黑色']}


100%|██████████| 5331/5331 [00:49<00:00, 107.53it/s]


In [10]:
# Persist the relabelled products. The tags contain non-ASCII text and
# `ensure_ascii=False` writes them verbatim, so the file is opened explicitly as
# UTF-8 (matching how result_logits.json is read) instead of relying on the
# platform's default encoding, which may be unable to encode them.
with open('./result_labels.json', 'w', encoding='utf-8') as f:
    json.dump(logits_json, f, indent=4, ensure_ascii=False)