In [2]:
# 分析major lbael在debias之后的精度相较于other labels有什么不同
# 用来帮助判断是否存在over-correcting的情况

import torch
import pandas as pd
import numpy as np
# 1. 获取在这些数据集上debiased前后的预测值，并找出来debiased后回答不对的样本
# 1. 获取在这些数据集上debiased前后的预测值，通过label来划分成most label和other labels，并且统计各自debias后的平均精度
def get_rectified_examples(root_dir, model, prompt):
    """统计不同数据集的recall ratio"""
    # k = [0,1,2,4,8,16,32]
    num = 0

    
    if prompt != "optiprompt":
        data_dir = f"{root_dir}/manual_prompt/filter_out_{num}_biased_tokens/{model}/{prompt}/debias_answer_type_tokens/origin_embedding/ablations/no_norm_no_rescale/preds.pt"
    else:
        data_dir = f"{root_dir}/continue_prompt/filter_out_{num}_biased_tokens/{model}/{prompt}_5/debias_answer_type_tokens/origin_embedding/ablations/no_norm_no_rescale/exp_1/preds.pt"
    # 数据集数量变化
    raw_data = torch.load(data_dir)
    datasets = ["LAMA", "LAMA-WHU"]
    
    results = {}

    for _ in datasets:
        dataset = raw_data[_]
        # 调整数据格式
        preds_df = {}
        fine_grain_label_dataset = {}

        for rel in dataset.keys():
            data = dataset[rel]['data']
            data.pop("raw_probs",None)
            data.pop("debiased_probs",None)
            preds_df[rel] = pd.DataFrame({key: value if type(value)!=torch.Tensor else value.cpu() for key,value in data.items()})
            
            # 找到most label
            major_label = preds_df[rel]["obj_labels"].value_counts().idxmax()
            # 按照label分组成多个小数据集
            # TODO，设置最小要求，比如大于等于5
            fine_grain_dict = preds_df[rel].groupby('obj_labels').apply(lambda x: x.index.tolist()).to_dict()
            for key in fine_grain_dict.keys():
                index_list = fine_grain_dict[key]
                selected_rows = preds_df[rel].loc[index_list]
                raw_acc = (selected_rows['raw_preds'] == selected_rows['obj_labels']).sum() / len(index_list)
                debias_acc = (selected_rows['debiased_preds'] == selected_rows['obj_labels']).sum() / len(index_list)
                fine_grain_dict[key] = {"debias_acc": debias_acc, "raw_acc":raw_acc}
            
            fine_grain_label_dataset[rel] = {
                "major_label":major_label,
                "data": fine_grain_dict
            }
        # print(preds_df)

        

        results[_] = fine_grain_label_dataset
    return results



  from .autonotebook import tqdm as notebook_tqdm


In [12]:
# 2. 计算不同类别的均值和方差
top_k = 0 # bias最大token
top_k = -1 # 超过unifrom的token
root_dir = "/mnt/code/users/xuziyang/PromptBias/outputs/openprompt"


models = ["bert-base-cased",]
prompts = ["LAMA",]
benchmark = "LAMA"
results = {}
for model in models:
    results[model] = {}
    for prompt in prompts:
        data = get_rectified_examples(root_dir, model,prompt)[benchmark]

        all_acc = {}
        for rel in data.keys():
            all_acc[rel] = {}
            # 区分major acc和其他acc的变化
            major_label = data[rel]["major_label"]
            acc_data = data[rel]["data"]

            raw_other_acc = []
            debias_other_acc = []
            for key in acc_data.keys():
                if key!=major_label: # other label
                    raw_other_acc.append(acc_data[key]["raw_acc"])
                    debias_other_acc.append(acc_data[key]["debias_acc"])
            
            import numpy as np
            raw_micro_avg_other_acc = np.mean(raw_other_acc)
            debias_micro_avg_other_acc = np.mean(debias_other_acc)
            std_other_acc = np.std(raw_other_acc)
            debias_std_other_acc = np.std(debias_other_acc)
            major_acc = acc_data[major_label]["raw_acc"]
            debias_major_acc = acc_data[major_label]["debias_acc"]
            # all_acc[rel]["raw_micro_avg_other_acc"] = raw_micro_avg_other_acc
            # all_acc[rel]["raw_var_other_acc"] = var_other_acc
            all_acc[rel]["debias_micro_avg_other_acc"] = debias_micro_avg_other_acc
            all_acc[rel]["debias_std_other_acc"] = debias_std_other_acc
            all_acc[rel]["debias_major_acc"] = debias_major_acc
            all_acc[rel]["raw_major_acc"] = major_acc
        

        total_rels = len(list(all_acc.keys()))
        valid_num = 0
        for rel in all_acc.keys():
            if all_acc[rel]["debias_major_acc"] >= all_acc[rel]["raw_major_acc"]:
                valid_num += 1
                continue
            else:
                if all_acc[rel]["debias_major_acc"] > all_acc[rel]["debias_micro_avg_other_acc"] - 3 * all_acc[rel]["debias_std_other_acc"]:
                    valid_num += 1
        print(valid_num/total_rels)

        data_list = [  {"relation":key,**value}  for key,value in all_acc.items()]
        data = sorted(data_list, key=lambda x:x["debias_major_acc"], reverse=True)
        for i in data:
            for key in i.keys():
                if type(i[key]) is not str:
                    i[key]="{:.2f}".format(i[key])
        # 创建表格对象
        from prettytable import PrettyTable
        table = PrettyTable()

        # 添加列名
        table.field_names = data[0].keys()
        # 添加数据行
        for i,row in enumerate(data):
            if i in [2,3,7,9,13,17]:
                table.add_row(row.values())
        # 打印表格
        print(table)
        results[model][prompt] = all_acc

        




1.0
+----------+----------------------------+----------------------+------------------+---------------+
| relation | debias_micro_avg_other_acc | debias_std_other_acc | debias_major_acc | raw_major_acc |
+----------+----------------------------+----------------------+------------------+---------------+
|   P176   |            0.76            |         0.34         |       0.84       |      0.86     |
|   P36    |            0.57            |         0.45         |       0.80       |      0.90     |
|  P1001   |            0.87            |         0.27         |       0.64       |      0.64     |
|   P276   |            0.41            |         0.42         |       0.53       |      0.62     |
|   P17    |            0.37            |         0.34         |       0.42       |      0.68     |
|   P159   |            0.31            |         0.38         |       0.30       |      0.41     |
+----------+----------------------------+----------------------+------------------+-------------

In [8]:
data = [
    {'debias_micro_avg_other_acc': 0.06020544099247803, 'debias_std_other_acc': 0.19564920305459038, 'debias_major_acc': 0.0, 'raw_major_acc': 0.7457627118644068},
    {'debias_micro_avg_other_acc': 0.13350671106146228, 'debias_std_other_acc': 0.27110992659289834, 'debias_major_acc': 0.0, 'raw_major_acc': 0.35768261964735515},
]

# 创建表格对象，并设置 float_format
table = PrettyTable()
float_format_str = ".3"  # 三位小数
table.float_format = float_format_str + "f"

# 添加列名
table.field_names = data[0].keys()

# 添加数据行
for row in data:
    table.add_row(row.values())

# 打印表格
print(table)

+----------------------------+----------------------+------------------+---------------------+
| debias_micro_avg_other_acc | debias_std_other_acc | debias_major_acc |    raw_major_acc    |
+----------------------------+----------------------+------------------+---------------------+
|    0.06020544099247803     | 0.19564920305459038  |       0.0        |  0.7457627118644068 |
|    0.13350671106146228     | 0.27110992659289834  |       0.0        | 0.35768261964735515 |
+----------------------------+----------------------+------------------+---------------------+
