In [35]:
import os
import os.path as osp
import sys
sys.path.append('/home/wangxiao13/annotation/ilearnmm')
import json
import pickle
from tqdm import tqdm

import numpy as np
from scipy.special import kl_div
import pandas as pd
import matplotlib.pyplot as plt

from multimodal.core.evaluation import mmit_mean_average_precision, top_k_recall, top_k_precision, mean_average_precision
from multimodal.core.evaluation import top_k_accuracy, mean_class_accuracy, confusion_matrix

In [2]:
def sigmoid(x):
    """Element-wise logistic function ``1 / (1 + exp(-x))``, computed stably.

    The naive form evaluates ``exp(-x)``, which overflows (and emits a
    RuntimeWarning) for large negative inputs. Here ``exp`` is only ever
    applied to non-positive values, so it can underflow to 0 but never
    overflow.

    Args:
        x: Scalar or array-like of real values (e.g. logits).

    Returns:
        A NumPy scalar for scalar input, otherwise an ndarray with the same
        shape as ``x``; every value lies in ``[0, 1]``.
    """
    x = np.asarray(x)
    if not np.issubdtype(x.dtype, np.floating):
        # Integer/bool input: promote so the division below is done in floats.
        x = x.astype(float)
    decay = np.exp(-np.abs(x))  # always in (0, 1]
    # x >= 0: 1 / (1 + e^-x);  x < 0: e^x / (1 + e^x) -- the same function.
    out = np.where(x >= 0, 1 / (1 + decay), decay / (1 + decay))
    # `[()]` turns a 0-d result back into a scalar and leaves arrays untouched.
    return out[()]

In [41]:
def class_accuracy(scores, labels):
    """Compute the confusion matrix and the accuracy of every class.

    Unlike ``mean_class_accuracy`` this does not average over classes; it
    returns the per-class values so individual categories can be inspected.

    Args:
        scores (list[np.ndarray]): Prediction scores for each class, one
            score vector per sample.
        labels (list[int]): Ground truth labels.

    Returns:
        tuple: ``(cf_mat, class_acc)`` where ``cf_mat`` is the float
        confusion matrix returned by ``confusion_matrix`` and ``class_acc``
        is a list with one entry per matrix row: diagonal / row-sum, or
        ``0.0`` when the row-sum is zero.
    """
    predicted = np.argmax(scores, axis=1)
    cf_mat = confusion_matrix(predicted, labels).astype(float)

    # NOTE(review): the row-sum is used as the per-class sample count, i.e.
    # rows are presumably ground-truth classes -- confirm against
    # multimodal.core.evaluation.confusion_matrix.
    row_totals = cf_mat.sum(axis=1)  # [K]
    row_correct = np.diag(cf_mat)  # [K]

    class_acc = []
    for total, correct in zip(row_totals, row_correct):
        # Rows with no samples get accuracy 0.0 instead of dividing by zero.
        class_acc.append(correct / total if total else 0.0)

    return cf_mat, class_acc

In [14]:
# Workspace root; every input below is resolved relative to it.
root_path = '/home/wangxiao13/annotation'
# In order: test annotation list, class-index -> category-name map,
# and the model's dumped test predictions.
test_label_path, label_map_path, result_path = (
    osp.join(root_path, rel_path)
    for rel_path in (
        'data/hetu600/hetu600_text_test_list.txt',
        'data/hetu600/label_map_hetu600.txt',
        'ilearntext2/work_dirs/bert_hetu600/test_results.json',
    )
)

# Load labels

In [8]:
# Build the two-way mapping between class index and category name.
# The label map holds one category per line; the 0-based line number is the
# class index.
idx2cat = {}
cat2idx = {}
# Category names are non-ASCII, so decode explicitly instead of relying on the
# platform's locale default (assumes the file is UTF-8 -- TODO confirm).
with open(label_map_path, 'r', encoding='utf-8') as F:
    for idx, line in enumerate(F):
        cat = line.strip()
        idx2cat[idx] = cat
        cat2idx[cat] = idx
# No duplicated category name: a repeated name would collapse in cat2idx.
assert len(idx2cat) == len(cat2idx), 'duplicated category name in label map'

In [66]:
def calculate_hit_at(result, gt_labels):
    """Return, for every sample, the rank of its ground-truth class.

    Classes are ordered by descending score; the rank is 0-based, so ``0``
    means the ground-truth class was the top-1 prediction.

    Args:
        result: Sequence of per-sample score vectors (one score per class).
        gt_labels: Sequence of ground-truth class indices, same length as
            ``result``.

    Returns:
        list[int]: One rank per sample, in input order.

    Raises:
        AssertionError: If ``result`` and ``gt_labels`` differ in length.
        ValueError: If a label index does not occur among the class indices
            of its score vector (e.g. it is out of range).
    """
    assert len(result) == len(gt_labels)
    pd_hit_at = []
    # zip() has no length, so hand tqdm the total to get a real progress bar.
    for res, label_idx in tqdm(zip(result, gt_labels), total=len(result)):
        topk_idxs = np.argsort(res)[::-1]  # class indices, best score first
        hits = np.where(topk_idxs == label_idx)[0]
        if hits.size != 1:
            raise ValueError(
                'label index %r not found among %d classes' % (label_idx, len(topk_idxs)))
        pd_hit_at.append(hits.item())
    return pd_hit_at

# Load annotation

In [78]:
# Parallel per-sample columns used later to build the analysis table.
pd_pids = []
pd_sents = []
pd_labels = []
pd_hit_at = []  # placeholder; replaced by the calculate_hit_at(...) result

# Titles are non-ASCII text, so decode explicitly instead of relying on the
# platform's locale default (assumes the file is UTF-8 -- TODO confirm).
with open(test_label_path, 'r', encoding='utf-8') as F:
    lines = F.readlines()

gt_labels = []
for line in tqdm(lines):
    # Each record has exactly three tab-separated fields:
    # <pid> \t <sentence> \t <label index>
    pid, sent, label_idx = line.strip().split('\t')
    pid = int(pid)
    pd_pids.append(pid)
    pd_sents.append(sent)
    label_idx = int(label_idx)
    label_name = idx2cat[label_idx]  # human-readable category for the table
    pd_labels.append(label_name)
    gt_labels.append(label_idx)
gt_labels = np.array(gt_labels)

100%|██████████| 66036/66036 [00:00<00:00, 557770.85it/s]


# Load results

In [60]:
# Read the per-sample model outputs dumped by the test run.
with open(result_path, 'r') as F:
    result = np.array(json.load(F))
# Any value outside [0, 1] means the dump holds raw logits, not probabilities.
outside_unit_range = (result < 0) | (result > 1)
if outside_unit_range.any():
    # got logits
    result = sigmoid(result)

In [85]:
# 0-based rank of each sample's ground-truth class in the score-sorted
# class list (0 = the model's top-1 prediction was correct).
pd_hit_at = calculate_hit_at(result, gt_labels)

66036it [00:02, 26655.73it/s]


# Analysis

In [87]:
# One row per test sample: id, text, ground-truth category name, and the
# 0-based rank at which the ground-truth class was predicted.
stat_columns = {
    'pid': pd_pids,
    'title': pd_sents,
    'category_name': pd_labels,
    'hit@': pd_hit_at,
}
stat = pd.DataFrame(stat_columns)

In [32]:
# Headline metrics: mean class accuracy plus top-1..top-5 accuracy.
report_topk = (1, 2, 3, 4, 5)  # must match the five acc@k slots printed below
mCAcc = mean_class_accuracy(result, gt_labels)
accuracy = top_k_accuracy(result, gt_labels, topk=report_topk)
print('model:')
print('mean class accuracy: %.2f' % mCAcc)
print('acc@1 %.2f, acc@2 %.2f, acc@3 %.2f, acc@4 %.2f, acc@5 %.2f' % tuple(accuracy))

model:
mean class accuracy: 0.72
acc@1 0.74, acc@2 0.80, acc@3 0.82, acc@4 0.84, acc@5 0.85


In [42]:
# Bind the matrix to `cf_mat`, NOT `confusion_matrix`: that name is the
# function imported from multimodal.core.evaluation which class_accuracy()
# calls. Rebinding it to an ndarray makes any re-run of this cell fail with
# "TypeError: 'numpy.ndarray' object is not callable".
cf_mat, class_acc = class_accuracy(result, gt_labels)
# Row sums of the confusion matrix, used as the per-class sample count.
cls_cnt = cf_mat.sum(axis=1)

In [43]:
# List the classes from lowest to highest accuracy as
# "<category>(<confusion-matrix row sum>) acc <accuracy>".
ascending_by_acc = np.argsort(class_acc)
for idx in ascending_by_acc:
    row = (idx2cat[idx], cls_cnt[idx], class_acc[idx])
    print('%s(%d) acc %.2f' % row)

竞技体育(13) acc 0.00
影视综艺(17) acc 0.06
丧系负能量(16) acc 0.06
宠物医美(31) acc 0.06
动作角色扮演游戏（ARPG）(85) acc 0.09
大型多人在线游戏（MMO）(86) acc 0.10
球类运动(101) acc 0.11
音乐游戏(104) acc 0.12
人物摄影(83) acc 0.12
宅物潮玩(62) acc 0.13
乐器(82) acc 0.15
枪战射击(68) acc 0.15
搭配展示(46) acc 0.15
美食日常(64) acc 0.17
休闲益智(88) acc 0.18
埙(16) acc 0.19
极限运动(106) acc 0.19
室内演唱(84) acc 0.20
宠物(74) acc 0.20
即兴舞蹈(132) acc 0.21
汽车买卖(99) acc 0.22
手艺(54) acc 0.22
网红(122) acc 0.23
绘画(104) acc 0.23
竞速游戏(95) acc 0.23
山水游(138) acc 0.23
玩具短剧(68) acc 0.24
宠物鸟(68) acc 0.26
城市游(122) acc 0.27
念珠制作(18) acc 0.28
4-12岁萌娃(53) acc 0.28
冒险游戏(108) acc 0.29
猫(72) acc 0.29
晒娃(126) acc 0.30
体育游戏(99) acc 0.30
对战卡牌类(108) acc 0.31
多人在线战术竞技游戏（MOBA）(111) acc 0.31
动画(81) acc 0.31
漫画(100) acc 0.32
育儿(59) acc 0.32
炫技(137) acc 0.33
亲子周边(108) acc 0.34
野生动物(70) acc 0.34
表演(96) acc 0.34
短剧(127) acc 0.35
健身训练(140) acc 0.35
烹饪(99) acc 0.35
魔幻短剧(76) acc 0.36
景物摄影(137) acc 0.36
二次元装扮(136) acc 0.36
阅读(38) acc 0.37
继承法(27) acc 0.37
说车评车(121) acc 0.37
动漫(109) acc 0.38
搞笑段子(145) 

In [99]:
# Samples of one category whose ground-truth class was not ranked first
# (hit@ is a 0-based rank, so > 0 means a top-1 miss).
in_category = stat['category_name'] == '采耳'
missed_top1 = stat['hit@'] > 0
stat[in_category & missed_top1]

Unnamed: 0,pid,title,category_name,hit@
3587,52174298154,经历了那么多的事事非非，我终于明白了，原来人也会变的，哈哈,采耳,252
12029,52098576198,#面部护理,采耳,14
23242,52066188721,#掏耳屎解压 一听价格，你走了于是你到处找更便宜的！被人坑了，你又回来了！最近很火的话：任...,采耳,2
34486,52058171934,上班中！,采耳,49
34949,52241781892,#小小喵,采耳,232
38254,47215399882,真是吓人啊，,采耳,156
43317,52058337319,贫民窟探店，经费有限 #,采耳,140
45148,41850274589,#感谢快手平台 #感谢推广小助手助我上热门,采耳,265
46900,52122662718,耳朵下雪了,采耳,2
49687,52224570737,#偷拍成功 #忙忙碌碌 #记录生活,采耳,12
