## 汇总结果

In [27]:
import json
import pandas as pd
from pathlib import Path

mtl_albef_paths = list(Path("./output/trecis").glob("mtl_albef_*"))

res_list = []
name_template = "{}-{} {}:{}, {}:{}, {}"
for output_dir in mtl_albef_paths:
    val_log_fpath = output_dir / "val_log.txt"
    with open(val_log_fpath, "r", encoding="utf8") as file:
        desc_list = output_dir.name.split("_")
        tmp_dict = dict(
            config = "-".join(desc_list[:2]),
            only_text = desc_list[3],
            ckpt = desc_list[5],
            task = desc_list[7:],
        )
        tmp_dict.update(json.loads(file.read()))
        res_list.append(tmp_dict)

res_df = pd.DataFrame(res_list)[
    ["config", "only_text", 
     "val_info_type_cls_f1", "val_info_type_cls_acc",
     "val_priority_regression_f1", "val_priority_regression_acc",
     "ckpt", "task", "epoch"
     ]]
res_df.sort_values(by="only_text").sort_values(by="task")

Unnamed: 0,config,only_text,val_info_type_cls_f1,val_info_type_cls_acc,val_priority_regression_f1,val_priority_regression_acc,ckpt,task,epoch
2,mtl-albef,False,0.496,0.9271,,,albef,[itc],9
4,mtl-albef,True,0.509,0.9267,,,bert,[itc],11
5,mtl-albef,False,0.4082,0.9312,0.3976,0.4401,albef,"[itc, pr]",6
0,mtl-albef,True,0.3695,0.931,0.4166,0.5214,bert,"[itc, pr]",5
1,mtl-albef,False,,,0.4394,0.6244,albef,[pr],10
3,mtl-albef,True,,,0.4346,0.685,bert,[pr],11


## 含有图片的样本比较研究

In [56]:
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from dataset.trecis_dataset import INFO_TYPE_CATEGORIES

val_json_fpath = Path("./trecis/val.json")

# 加载验证集样本信息
with open(val_json_fpath, "r", encoding="utf8") as file:
    sample_df = pd.DataFrame(json.load(file))
    
id2imgs_dict = {}
for fpath in Path("../../../my_dataset/trecis_images/").glob("*.jpg"):
    post_id = fpath.stem.split("_")[0]
    id2imgs_dict.setdefault(post_id, [])
    id2imgs_dict[post_id].append(fpath)

# 图片数目计数
def count_imgs(post_id):
    return len(id2imgs_dict.get(post_id, []))

sample_df["img_num"] = sample_df["post_id"].apply(count_imgs)
sample_df.head()

# 有无图片分隔
no_img_df = sample_df.query("img_num == 0")
own_img_df = sample_df.query("img_num > 0")

print("验证集中总样本数目为：{}".format(sample_df.shape[0]))
print("验证集中 没有 图片的样本数目为：{:>5}".format(no_img_df.shape[0]))
print("验证集中 拥有 图片的样本数目为：{:>5}".format(own_img_df.shape[0]))

# info type label
label_dict = {info_type:[] for info_type in INFO_TYPE_CATEGORIES}
for label_str in sample_df["info_type"]:
    labels = label_str.split(",")
    for info_type in label_dict.keys():
        label_dict[f"{info_type}"].append(1 if info_type in labels else 0)
        
label_df = pd.DataFrame(label_dict)
label_df["post_id"] = sample_df["post_id"]
label_df["priority"] = sample_df["priority"]
# label_df = label_df.set_index("post_id")


no_img_p_collect = {}
own_img_p_collect = {}

ranks = ["Critical", "High", "Medium", "Low"]
for rank in ranks:
    no_img_p_collect.update({rank: label_df[sample_df["img_num"] == 0].query(f"priority == '{rank}'").shape[0]})
    own_img_p_collect.update({rank: label_df[sample_df["img_num"] > 0].query(f"priority == '{rank}'").shape[0]})
   
    
no_img_ifc_collect = label_df[sample_df["img_num"] == 0][INFO_TYPE_CATEGORIES].sum()    
own_img_ifc_collect = label_df[sample_df["img_num"] > 0][INFO_TYPE_CATEGORIES].sum()

import matplotlib.pyplot as plt

plt.style.use("ggplot")
plt.fill_between(no_img_p_collect.keys(), no_img_p_collect.values(), alpha=0.3)
plt.plot(no_img_p_collect.keys(), no_img_p_collect.values(), label="no img")
plt.fill_between(own_img_p_collect.keys(), own_img_p_collect.values(), alpha=0.3)
plt.plot(own_img_p_collect.keys(), own_img_p_collect.values(), label="own img")
plt.title("priority rank distribution for own/no image tweets", loc="left")
plt.yscale("log")
plt.legend()
plt.show()

plt.fill_between(no_img_ifc_collect.index, no_img_ifc_collect, alpha=0.3)
plt.plot(no_img_ifc_collect.index, no_img_ifc_collect, label="no img")
plt.fill_between(own_img_ifc_collect.index, own_img_ifc_collect, alpha=0.3)
plt.plot(own_img_ifc_collect.index, own_img_ifc_collect, label="own img")
plt.title("info type distribution for own/no image tweets", loc="left")
plt.yscale("log")
plt.xticks(rotation=90)
plt.legend()
plt.show()


验证集中总样本数目为：3381
验证集中 没有 图片的样本数目为： 2960
验证集中 拥有 图片的样本数目为：  421


UndefinedVariableError: name 'Critical' is not defined

In [34]:
from sklearn.metrics import f1_score, classification_report, accuracy_score, precision_recall_fscore_support
from typing import List, Tuple
import numpy as np

split_map = {
    "all": sample_df["img_num"] >= 0,
    "own_img": sample_df["img_num"] > 0,
    "no_img": sample_df["img_num"] == 0
}

task_map = {
    "itc": (INFO_TYPE_CATEGORIES, INFO_TYPE_CATEGORIES),
    "pr": ("priority", "priority_pred")
}

def metrics(split: str, task: str)-> Tuple[float, float, float, int]:
    sample_set_index = split_map[split]
    colum_true, colum_pred = task_map[task]
    
    y_true = label_df.loc[sample_set_index, colum_true]
    y_pred = pred_df.loc[sample_set_index, colum_pred] 
    
    assert y_true.shape == y_pred.shape, f"y_true {y_true.shape}, y_pred {y_pred.shape}"
    
    return precision_recall_fscore_support(
        y_true, y_pred, average='macro'
    )
    

split_report_list = []
for output_dir in tqdm(mtl_albef_paths):
    # load prediction
    pred_df = pd.read_csv(output_dir / "eval_res.csv")
    pred_df["post_id"] = sample_df["post_id"]

    desc_list = output_dir.name.split("_")
    tasks = desc_list[7:]
    
    tmp_dict = dict(
        config = "-".join(desc_list[:2]),
        only_text = desc_list[3],
        ckpt = desc_list[5],
        task = desc_list[7:],
        )

    for task in tasks:
        for split in ["all", "own_img", "no_img"]:
            p, r, f, _ = metrics(split, task)
            tmp_dict.update({
                f"{split}_{task}_p": p,
                f"{split}_{task}_r": r,
                f"{split}_{task}_f": f,
                    })

    split_report_list.append(tmp_dict)

split_report_df = pd.DataFrame(split_report_list).fillna("/")
split_report_df[[
    "config", "only_text", "all_itc_f", "own_img_itc_f", "no_img_itc_f", "all_pr_f", "own_img_pr_f", "no_img_pr_f", "task"
    ]].sort_values("only_text").sort_values("task").apply(lambda line: [f"{x:.2%}" if isinstance(x, float) else x for x in line], axis=0)

  0%|          | 0/6 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Unnamed: 0,config,only_text,all_itc_f,own_img_itc_f,no_img_itc_f,all_pr_f,own_img_pr_f,no_img_pr_f,task
2,mtl-albef,False,49.60%,47.01%,49.16%,/,/,/,[itc]
4,mtl-albef,True,50.90%,49.75%,50.32%,/,/,/,[itc]
5,mtl-albef,False,40.82%,41.21%,40.28%,39.76%,35.74%,39.92%,"[itc, pr]"
0,mtl-albef,True,36.95%,32.60%,37.04%,41.66%,32.90%,42.83%,"[itc, pr]"
1,mtl-albef,False,/,/,/,43.94%,39.24%,44.38%,[pr]
3,mtl-albef,True,/,/,/,43.46%,44.11%,42.90%,[pr]
