# Inplementation of Internal State-based Uncertainty Estimation

1. 数据集预处理：将不同格式的数据集处理成input+gt的格式，方便判断模型的correctness，这一部分的采用固定不可调整的prompt，即Context: Question: Options: Answer:格式
2. 生成回复，为每个模型确定一个prompt，一个max_new_tokens数，然后生成回复
3. 计算回复部分的correctness指标，判断模型的回复是否正确
4. 计算uncertainty指标，包括PE, LN-PE, SAR, Ours
5. 计算AUROC，绘制AUROC/Correctness-Threshold曲线

In [1]:
from utils import *

datasets.disable_progress_bar()
datasets.disable_caching()
torch.set_grad_enabled(False)

# Eval Result Config
model_names = [
    "vicuna-7b-v1.1",
    "vicuna-13b-v1.1",
    "vicuna-33b-v1.3",
]

dst_names = [
    "sciq",
    "coqa",
    "triviaqa",
    "medmcqa",
    "MedQA-USMLE-4-options",
]

c_metrics = [
    'rougel',
    'sentsim',
    'include'
]

dst_types = [
    "short",
    "long",
]

acc_map = {
    "vicuna-7b-v1.1": {
        "sciq": 0.60,
        "coqa": 0.8,
        "triviaqa": 0.55,
        "medmcqa": 0.30,
        "MedQA-USMLE-4-options": 0.30
    },
    "vicuna-13b-v1.1": {
        "sciq": 0.0,
        "coqa": 0.0,
        "triviaqa": 0.0,
        "medmcqa": 0.0,
        "MedQA-USMLE-4-options": 0.0
    },
    "vicuna-33b-v1.3": {
        "sciq": 0.0,
        "coqa": 0.0,
        "triviaqa": 0.0,
        "medmcqa": 0.0,
        "MedQA-USMLE-4-options": 0.0
    }
}

model_names_alias = {
    "vicuna-7b-v1.1": "Vicuna-7B",
    "vicuna-13b-v1.1": "Vicuna-13B",
    "vicuna-33b-v1.3": "Vicuna-33B"
}

dst_names_alias = {
    "sciq": "SciQ",
    "coqa": "CoQA",
    "triviaqa": "TriviaQA",
    "medmcqa": "MedMCQA",
    "MedQA-USMLE-4-options": "MedQA"
}

dst_names_alias = {
    "sciq": "SciQ",
    "coqa": "CoQA",
    "triviaqa": "TriviaQA",
    "medmcqa": "MedMCQA",
    "MedQA-USMLE-4-options": "MedQA"
}

u_metric_alias = {
    "u_score_pe": "PE",
    "u_score_ln_pe": "LN-PE",
    "u_score_token_sar": "TokenSAR",
    "u_score_sent_sar": "SentSAR",
    "u_score_sar": "SAR",
    "u_score_ls": "LS",
    "u_score_se": "SE",
    "u_score_ours_mean_soft_rougel": "Ours(MSRL)",
    "u_score_ours_last_soft_rougel": "Ours(LSRL)",
    "u_score_ours_mean_soft_include": "Ours(MSIN)",
    "u_score_ours_last_soft_include": "Ours(LSIN)",
    "u_score_ours_mean_soft_sentsim": "Ours(MSSI)",
    "u_score_ours_last_soft_sentsim": "Ours(LSSI)"
}

u_metric_direction = {
    "u_score_pe": -1,
    "u_score_ln_pe": -1,
    "u_score_token_sar": -1,
    "u_score_sent_sar": -1,
    "u_score_sar": -1,
    "u_score_ls": 1,
    "u_score_se": -1,
    "u_score_ours_mean_soft_rougel": 1,
    "u_score_ours_last_soft_rougel": 1,
    "u_score_ours_mean_soft_include": 1,
    "u_score_ours_last_soft_include": 1,
    "u_score_ours_mean_soft_sentsim": 1,
    "u_score_ours_last_soft_sentsim": 1
}


def get_cached_result_path(model_name, dst_name, dst_type, dst_split):
    return f"cached_results/{model_name}/{dst_type}/{dst_name}_{dst_split}"


def get_eval_main_result_path(model_name, dst_name, dst_type):
    return f"eval_results/{model_name}/{dst_name}_{dst_type}"


def get_eval_cross_result_path(model_name, train_dst_name, train_dst_type, test_dst_name, test_dst_type, c_metric):
    return f"cross_eval_results/{model_name}/{c_metric}/v_c_{train_dst_name}_{train_dst_type}_mean_soft_best.pth/{test_dst_name}_{test_dst_type}"


def get_c_th_by_acc(test_dst, c_metric, acc):
    sorted_c_scores = sorted(list(test_dst[c_metric]), reverse=True)
    c_th = sorted_c_scores[int(len(sorted_c_scores) * acc)]
    return c_th


def get_acc_by_c_th(test_dst, c_metric, c_th):
    return sum([1 if s > c_th else 0 for s in test_dst[c_metric]]) / len(test_dst)


In [2]:
# Load LLM Model
model_name = 'vicuna-7b-v1.1'
hooked_transformer_name = get_hooked_transformer_name(model_name)
hf_model_path = os.path.join(os.environ["my_models_dir"], model_name)
hf_tokenizer = AutoTokenizer.from_pretrained(hf_model_path)
hf_tokenizer.pad_token_id = hf_tokenizer.eos_token_id
with LoadWoInit():
    hf_model = AutoModelForCausalLM.from_pretrained(hf_model_path)
model = HookedTransformer.from_pretrained_no_processing(hooked_transformer_name, dtype='bfloat16', hf_model=hf_model, tokenizer=hf_tokenizer, default_padding_side='left')

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  return self.fget.__get__(instance, owner)()


Loaded pretrained model llama-7b-hf into HookedTransformer


In [3]:
# Show Data Sample
model_name = 'vicuna-7b-v1.1'
hooked_transformer_name = get_hooked_transformer_name(model_name)
hf_model_path = os.path.join(os.environ["my_models_dir"], model_name)
hf_tokenizer = AutoTokenizer.from_pretrained(hf_model_path)
hf_tokenizer.pad_token_id = hf_tokenizer.eos_token_id
for dst_name in dst_names:
    for dst_type in dst_types:
        dst = Dataset.load_from_disk(get_cached_result_path('vicuna-7b-v1.1', dst_name, dst_type, 'validation'))
        dst = dst.map(wash_answer, fn_kwargs=dict(tokenizer=hf_tokenizer))
        dst = dst.map(get_num_tokens, batched=True, batch_size=8)
        dst = dst.sort('num_output_tokens')
        print(f"Dataset: {dst_name}_{dst_type}")
        print(f"Size: {len(dst)}")
        print(f"Mean Num Token: {np.mean(dst['num_answer_tokens'])}")
        print(f"Empty Answer: {sum([1 if len(x) == 0 else 0 for x in dst['washed_answer']])}")
        print(f"Samples:")
        for i in range(10):
            d = dst[i]
            for k in ['input','input_ids', 'washed_answer_ids']:
                if isinstance(d[k], list):
                    dk = hf_tokenizer.batch_decode(d[k])
                else:
                    dk = d[k]
                print(f"{k}: {dk}")
            print()
        # go.Figure(go.Histogram(x=dst['num_answer_tokens'])).show()

Dataset: sciq_short
Size: 1000
Mean Num Token: 5.724
Empty Answer: 4
Samples:
input: Question:Periodic refers to something that does what? Options:fail, dies, falls, repeat Answer:
input_ids: ['<s>', 'Question', ':', 'Period', 'ic', 'refers', 'to', 'something', 'that', 'does', 'what', '?', 'Options', ':', 'fail', ',', 'dies', ',', 'falls', ',', 'repeat', 'Answer', ':']
washed_answer_ids: ['repeat']

input: Question:A bog is a type of ____ Options:stream, wetland, plant, lake Answer:
input_ids: ['<s>', 'Question', ':', 'A', 'bog', 'is', 'a', 'type', 'of', '_', '___', 'Options', ':', 'stream', ',', 'wet', 'land', ',', 'plant', ',', 'lake', 'Answer', ':']
washed_answer_ids: ['b', 'og']

input: Question:Gases have no definite shape or what? Options:volume, mass, growth, smell Answer:
input_ids: ['<s>', 'Question', ':', 'G', 'ases', 'have', 'no', 'definite', 'shape', 'or', 'what', '?', 'Options', ':', 'volume', ',', 'mass', ',', 'growth', ',', 'sm', 'ell', 'Answer', ':']
washed_answer_ids: 

In [3]:
dst = Dataset.load_from_disk(get_cached_result_path('vicuna-7b-v1.1', 'sciq', 'long', 'validation'))
dst = dst.map(wash_answer, fn_kwargs=dict(tokenizer=hf_tokenizer))
dst = dst.map(get_num_tokens, batched=True, batch_size=8)
dst = dst.map(get_include)
dst = dst.sort('num_output_tokens')
dst = dst.filter(lambda x : x['num_answer_tokens'] != 0)

In [6]:
for i in range(100):
    d = dst[i]
    for k in ['input_ids','question','washed_answer','gt','include']:
        if isinstance(d[k], list):
            dk = hf_tokenizer.batch_decode(d[k])
        else:
            dk = d[k]
        print(f"{k}: {dk}")
    output_ids = d['input_ids'] + d['washed_answer_ids']
    inp = torch.tensor(output_ids).to(model.cfg.device)
    prob = F.softmax(model(inp), dim=-1)
    next_token_prob = prob[0,list(range(prob.shape[1]-1)),output_ids[1:]].tolist()
    answer_token_prob = next_token_prob[-len(d['washed_answer_ids']):]
    print(list(zip(hf_tokenizer.batch_decode(d['washed_answer_ids']), list(map(lambda x: str(x)[:5],answer_token_prob)))))
    print()

input_ids: ['<s>', 'A', 'chat', 'between', 'a', 'curious', 'user', 'and', 'an', 'artificial', 'intelligence', 'assistant', '.', 'The', 'assistant', 'gives', 'helpful', ',', 'detailed', ',', 'and', 'pol', 'ite', 'answers', 'to', 'the', 'user', "'", 's', 'questions', '.', 'US', 'ER', ':', 'Question', ':', 'What', 'kind', 'of', 'waves', 'are', 'sound', 'waves', '?', 'Options', ':', 'External', ',', 'mechanical', ',', 'spin', 'ning', ',', 'internal', 'Answer', ':', 'A', 'SS', 'IST', 'ANT', ':']
question: What kind of waves are sound waves?
washed_answer: Sound waves are mechanical waves
gt: mechanical
include: 1
[('Sound', '0.824'), ('waves', '1.0'), ('are', '1.0'), ('mechanical', '0.921'), ('waves', '1.0')]

input_ids: ['<s>', 'A', 'chat', 'between', 'a', 'curious', 'user', 'and', 'an', 'artificial', 'intelligence', 'assistant', '.', 'The', 'assistant', 'gives', 'helpful', ',', 'detailed', ',', 'and', 'pol', 'ite', 'answers', 'to', 'the', 'user', "'", 's', 'questions', '.', 'US', 'ER', ':

In [4]:
# Merge Train Dataset
model_name = 'vicuna-7b-v1.1'
train_size_per_dataset = 2000
val_size_per_dataset = 100


def merge_dst(model_name, dst_names, dst_types, dst_split, data_size):
    all_dst = [Dataset.load_from_disk(get_cached_result_path(model_name, dst_name, dst_type, dst_split)) for dst_name in dst_names for dst_type in dst_types]

    all_columns = ['dst_template', 'question', 'input', 'input_ids', 'gt', 'options', 'answer_ids']

    def fill_missing_columns(dst: Dataset):
        for column in dst.column_names:
            if column not in all_columns:
                dst = dst.remove_columns(column)
        if 'options' not in dst.column_names:
            dst = dst.add_column('options', [['#####'] for i in range(len(dst))])
        return dst

    all_dst = [dst if len(dst) <= data_size else dst.select(range(data_size)) for dst in all_dst]
    all_dst = [fill_missing_columns(dst) for dst in all_dst]
    merged_dst = datasets.concatenate_datasets(all_dst)
    print(len(merged_dst))
    save_path = get_cached_result_path(model_name, 'all', 'long', dst_split)
    os.makedirs(save_path, exist_ok=True)
    merged_dst.save_to_disk(save_path)
    print(f"Save Merged Dataset to {save_path}")


merge_dst(model_name, dst_names, dst_types, 'train', train_size_per_dataset)
merge_dst(model_name, dst_names, dst_types, 'validation', val_size_per_dataset)

20000


Saving the dataset (0/1 shards):   0%|          | 0/20000 [00:00<?, ? examples/s]

Save Merged Dataset to cached_results/vicuna-7b-v1.1/long/all_train
1000


Saving the dataset (0/1 shards):   0%|          | 0/1000 [00:00<?, ? examples/s]

Save Merged Dataset to cached_results/vicuna-7b-v1.1/long/all_validation


In [None]:
!pip install git+https://github.com/davidbau/baukit

: 

In [7]:
# Main Results: Get Main Results
datasets.disable_progress_bar()
for c_metric in c_metrics:
    # 设置全局c_metric阈值
    c_th = 0.3

    # 创建结果表格的多重索引，包括模型、数据集
    model_names_index = [model_names_alias[name] for name in model_names for i in range(len(dst_names))]
    dst_names_index = [dst_names_alias[name] for name in dst_names] * len(model_names)
    multi_index = pd.MultiIndex.from_tuples(zip(model_names_index, dst_names_index), names=['Model', 'Dataset'])

    # 创建结果表格的列，显示所有非Ours的指标，当前c_metric对应的Ours指标，以及ACC
    columns = ['ACC'] + [alias for u_metric, alias in u_metric_alias.items() if (c_metric in u_metric) or 'ours' not in u_metric]

    # 创建结果表格
    main_results_short = pd.DataFrame(columns=columns, index=multi_index).astype(float)
    main_results_long = deepcopy(main_results_short)

    # 遍历模型、数据集、数据集类型
    for model_name in model_names:
        for dst_name in dst_names:
            for dst_type in dst_types:
                # 创建新行，初始化为0
                new_row = {k: 0. for k in main_results_short.columns}

                # 获取结果路径
                result_path = get_eval_main_result_path(model_name, dst_name, dst_type)
                if os.path.exists(result_path):
                    test_dst = Dataset.load_from_disk(result_path)
                    # 计算ACC
                    new_row['ACC'] = get_acc_by_c_th(test_dst, c_metric, c_th) * 100

                    # 计算AUROC，填充到行中
                    for u_metric, u_metric_name in u_metric_alias.items():
                        if u_metric_name in columns and u_metric in test_dst.column_names:
                            new_row[u_metric_name] = get_auroc(test_dst, u_metric, c_metric, c_th, u_metric_direction[u_metric]) * 100

                # 根据数据集类型填充到对应的结果表格中
                result = main_results_short if dst_type == 'short' else main_results_long
                result.loc[(model_names_alias[model_name], dst_names_alias[dst_name])] = new_row

    # 打印结果
    print(f"Correctness Metric: {c_metric} Threshold: {c_th} AUROC Results")
    print("Short Prompt Main Result:")
    display(main_results_short)

    print("Long Prompt Main Result:")
    display(main_results_long)
datasets.enable_progress_bar()
# print(main_results_short.to_latex(index=True, float_format="%.2f"))
# print(main_results_long.to_latex(index=True, float_format="%.2f"))

Correctness Metric: rougel Threshold: 0.3 AUROC Results
Short Prompt Main Result:


Unnamed: 0_level_0,Unnamed: 1_level_0,ACC,PE,LN-PE,TokenSAR,SentSAR,SAR,LS,SE,Ours(MSRL),Ours(LSRL)
Model,Dataset,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
Vicuna-7B,SciQ,71.3,77.362423,60.041734,53.41957,75.875356,61.118794,72.178458,74.524876,83.903221,82.758722
Vicuna-7B,CoQA,54.0,74.20934,55.035427,49.126409,72.075684,51.826892,66.245572,68.595813,80.614332,75.635266
Vicuna-7B,TriviaQA,52.4,76.599886,76.279348,76.131407,76.869507,77.851169,74.13621,76.472994,72.939452,68.035153
Vicuna-7B,MedMCQA,47.8,62.802385,62.985139,60.029016,61.261001,62.924622,66.796919,57.136616,74.117291,67.224747
Vicuna-7B,MedQA,46.1,57.97331,59.420514,51.984675,56.49733,54.428342,61.097115,53.968142,69.389767,64.508671
Vicuna-13B,SciQ,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Vicuna-13B,CoQA,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Vicuna-13B,TriviaQA,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Vicuna-13B,MedMCQA,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Vicuna-13B,MedQA,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


Long Prompt Main Result:


Unnamed: 0_level_0,Unnamed: 1_level_0,ACC,PE,LN-PE,TokenSAR,SentSAR,SAR,LS,SE,Ours(MSRL),Ours(LSRL)
Model,Dataset,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
Vicuna-7B,SciQ,20.8,53.049728,36.151175,41.030497,51.129079,42.497086,40.749442,47.574908,91.297895,89.142628
Vicuna-7B,CoQA,38.6,70.098395,65.717034,51.789843,69.816543,49.072589,62.829319,69.322037,80.522692,74.825741
Vicuna-7B,TriviaQA,39.4,81.544957,55.435702,53.526704,81.185396,67.510387,75.688965,76.4952,84.337254,82.755147
Vicuna-7B,MedMCQA,14.5,65.470458,62.644888,56.806614,65.163138,51.761242,52.681992,45.143779,85.492236,78.210526
Vicuna-7B,MedQA,12.4,72.054979,65.872183,52.876896,69.951392,55.450453,62.578252,50.09114,84.716545,78.584383
Vicuna-13B,SciQ,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Vicuna-13B,CoQA,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Vicuna-13B,TriviaQA,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Vicuna-13B,MedMCQA,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Vicuna-13B,MedQA,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


Correctness Metric: sentsim Threshold: 0.3 AUROC Results
Short Prompt Main Result:


Unnamed: 0_level_0,Unnamed: 1_level_0,ACC,PE,LN-PE,TokenSAR,SentSAR,SAR,LS,SE,Ours(MSSI),Ours(LSSI)
Model,Dataset,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
Vicuna-7B,SciQ,93.9,64.937412,64.934793,56.57309,71.278304,67.725519,66.217113,69.018314,76.841425,75.430786
Vicuna-7B,CoQA,79.8,70.67297,64.651728,49.704707,68.483089,51.480186,69.61339,68.421053,81.2241,74.227648
Vicuna-7B,TriviaQA,81.6,83.908781,79.595255,80.28093,84.096601,83.124867,79.522325,81.592338,82.694946,79.576939
Vicuna-7B,MedMCQA,75.1,49.653207,62.208087,57.99015,52.67087,61.622789,55.990941,48.867106,69.359729,65.361045
Vicuna-7B,MedQA,73.7,45.705279,57.641708,50.901043,48.73627,58.093391,55.728444,48.975396,69.76464,65.631916
Vicuna-13B,SciQ,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Vicuna-13B,CoQA,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Vicuna-13B,TriviaQA,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Vicuna-13B,MedMCQA,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Vicuna-13B,MedQA,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


Long Prompt Main Result:


Unnamed: 0_level_0,Unnamed: 1_level_0,ACC,PE,LN-PE,TokenSAR,SentSAR,SAR,LS,SE,Ours(MSSI),Ours(LSSI)
Model,Dataset,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
Vicuna-7B,SciQ,82.0,66.705962,61.348238,58.380759,67.71748,61.118564,61.222561,60.407859,83.336043,80.452913
Vicuna-7B,CoQA,77.0,62.19537,66.507058,51.708639,62.186335,49.115754,61.185771,60.582722,82.645963,75.808018
Vicuna-7B,TriviaQA,71.1,83.048633,61.158561,58.565596,83.143046,71.926085,76.741905,69.158405,88.278364,86.600821
Vicuna-7B,MedMCQA,50.6,54.006577,60.721144,55.62441,56.10768,56.145885,55.179746,47.956906,77.0809,70.333928
Vicuna-7B,MedQA,53.7,60.352892,62.057024,54.968206,61.828573,52.701393,58.300453,48.955681,77.174809,71.613757
Vicuna-13B,SciQ,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Vicuna-13B,CoQA,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Vicuna-13B,TriviaQA,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Vicuna-13B,MedMCQA,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Vicuna-13B,MedQA,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


Correctness Metric: include Threshold: 0.3 AUROC Results
Short Prompt Main Result:


Unnamed: 0_level_0,Unnamed: 1_level_0,ACC,PE,LN-PE,TokenSAR,SentSAR,SAR,LS,SE,Ours(MSIN),Ours(LSIN)
Model,Dataset,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
Vicuna-7B,SciQ,66.1,71.13875,60.136157,55.874937,70.996613,62.493585,66.015111,70.927664,77.59741,74.496941
Vicuna-7B,CoQA,62.6,66.58352,60.086962,46.794007,68.081871,53.275188,62.011584,65.539629,70.999129,68.180964
Vicuna-7B,TriviaQA,35.7,73.341,71.709773,71.636804,73.84677,73.561649,70.206185,73.246904,70.551424,65.663839
Vicuna-7B,MedMCQA,21.8,67.177785,52.775757,52.216734,65.191581,55.106877,55.281682,65.14612,68.509937,64.392348
Vicuna-7B,MedQA,26.6,62.789894,49.340569,48.029645,60.637459,53.523796,53.753508,60.320932,66.903208,65.000973
Vicuna-13B,SciQ,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Vicuna-13B,CoQA,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Vicuna-13B,TriviaQA,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Vicuna-13B,MedMCQA,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Vicuna-13B,MedQA,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


Long Prompt Main Result:


Unnamed: 0_level_0,Unnamed: 1_level_0,ACC,PE,LN-PE,TokenSAR,SentSAR,SAR,LS,SE,Ours(MSIN),Ours(LSIN)
Model,Dataset,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
Vicuna-7B,SciQ,64.2,66.342087,62.719504,59.417585,66.507423,60.98827,59.052977,62.258306,74.172888,78.371752
Vicuna-7B,CoQA,63.2,63.425289,65.088401,49.21574,64.130435,50.488442,60.690699,63.92921,72.646361,68.211681
Vicuna-7B,TriviaQA,36.2,66.295312,58.997818,57.156558,66.306136,62.83751,65.916019,61.542675,70.178519,69.995367
Vicuna-7B,MedMCQA,17.0,62.844082,65.5854,59.073707,62.338767,60.981573,61.087881,51.109851,72.500354,68.960666
Vicuna-7B,MedQA,19.5,62.079312,59.620322,51.356585,61.958274,56.835165,62.692467,54.398471,74.189202,70.000956
Vicuna-13B,SciQ,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Vicuna-13B,CoQA,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Vicuna-13B,TriviaQA,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Vicuna-13B,MedMCQA,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Vicuna-13B,MedQA,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [21]:
# Generalization1: Cross Dataset and Cross Prompt Evaluation
model_name = 'vicuna-7b-v1.1'


def plot_cross_dst_matrix(model_name, u_metric, c_metric, c_th):
    fig_index = [(name, type) for type in dst_types for name in dst_names]
    fig_axis = list(map(lambda idx: f"{dst_names_alias[idx[0]]}-{idx[1]}", fig_index))
    cross_eval_matrix = torch.zeros(len(dst_types) * len(dst_names), len(dst_types) * len(dst_names))
    fig = go.Figure()
    annotations = []

    for i, (train_dst_name, train_dst_type) in enumerate(fig_index):
        for j, (test_dst_name, test_dst_type) in enumerate(fig_index):
            result_path = get_eval_cross_result_path(model_name, train_dst_name, train_dst_type, test_dst_name, test_dst_type, u_metric.split("_")[-1])
            if os.path.exists(f"{result_path}/dataset_info.json"):
                cross_eval_result = Dataset.load_from_disk(result_path)
                cross_eval_matrix[j][i] = get_auroc(cross_eval_result, u_metric, c_metric, c_th) * 100
                annotations.append(dict(
                    x=i,
                    y=j,
                    text=f"{cross_eval_matrix[j][i].item():.2f}",
                    showarrow=False,
                    font=dict(color='white')
                ))
    fig.add_trace(go.Heatmap(z=cross_eval_matrix, x=fig_axis, y=fig_axis, colorscale='Inferno'))
    fig.update_layout(
        title_text=f"Model: {model_names_alias[model_name]} Correctness Metric: {c_metric} Method: {u_metric} Cross Eval Results",
        xaxis_title="Train Dataset",
        yaxis_title="Test Dataset",
        width=1000,
        height=1000,
        annotations=annotations
    )
    fig.show()
    overall_average_drop = cross_eval_matrix.mean().item() - cross_eval_matrix.diag().mean().item()
    short_average_drop = cross_eval_matrix[:5, :5].mean().item() - cross_eval_matrix[:5, :5].diag().mean().item()
    long_average_drop = cross_eval_matrix[5:, 5:].mean().item() - cross_eval_matrix[5:, 5:].diag().mean().item()
    cross_dst_average_drop = (short_average_drop + long_average_drop) / 2
    cross_prompt_average_drop = 0
    for i in range(len(cross_eval_matrix)):
        cross_prompt_average_drop += cross_eval_matrix[i][(i + len(cross_eval_matrix) // 2) % len(cross_eval_matrix)].item()
    cross_prompt_average_drop /= len(cross_eval_matrix)
    print(f"Overall Average Drop: {overall_average_drop:.2f}")
    print(f"Short Average Drop: {short_average_drop:.2f}")
    print(f"Long Average Drop: {long_average_drop:.2f}")
    print(f"Cross Dst Average Drop: {cross_dst_average_drop:.2f}")
    print(f"Cross Prompt Average Drop: {cross_prompt_average_drop:.2f}")
    for i in range(len(cross_eval_matrix)):
        dst_drop = cross_eval_matrix[:, i].mean().item() - cross_eval_matrix[i, i].item()
        print(f"Train Dst:{fig_axis[i]} Average Drop: {dst_drop:.2f}")


plot_cross_dst_matrix(model_name, u_metric='u_score_ours_mean_soft_rougel', c_metric='rougel', c_th=0.3)
plot_cross_dst_matrix(model_name, u_metric='u_score_ours_mean_soft_include', c_metric='include', c_th=0.3)
plot_cross_dst_matrix(model_name, u_metric='u_score_ours_mean_soft_sentsim', c_metric='sentsim', c_th=0.3)

KeyboardInterrupt: 

In [24]:
# Efficiency: Get Efficiency Results
model_name = 'vicuna-7b-v1.1'
dst_name = 'sciq'
dst_type = 'long'


def plot_efficiency_results(model_name, dst_name, dst_type):
    fig = go.Figure()
    result_path = get_eval_main_result_path(model_name, dst_name, dst_type)
    test_dst = Dataset.load_from_disk(result_path)
    time_fwd = np.sum(test_dst['time_fwd'])
    time_pe = np.sum(test_dst['time_pe'])
    time_token_sar = np.sum(test_dst['time_token_sar'])
    time_sent_sar = np.sum(test_dst['time_sent_sar'])
    time_sar = np.sum(test_dst['time_sar'])
    time_ls = np.sum(test_dst['time_ls'])
    time_se = np.sum(test_dst['time_se'])
    time_ours = np.sum(test_dst['time_ours_mean_soft_rougel']) - time_fwd


    bar = go.Bar(
        y=['PE', 'TokenSAR', 'SentSAR', 'SAR', 'LS', 'SE', 'Ours'],
        x=[time_pe, time_token_sar, time_sent_sar, time_sar, time_ls, time_se, time_ours],
        orientation='h'
    )
    fig.add_trace(bar)
    fig.show()


plot_efficiency_results(model_name, dst_name, dst_type)

In [25]:
# Show Correctness Metric Correlation
base_c_metric = "rougel"
model_name = 'vicuna-7b-v1.1'

for dst_name in dst_names:
    for dst_type in dst_types:
        result_path = get_eval_main_result_path(model_name, dst_name, dst_type)
        if os.path.exists(result_path):
            test_dst = Dataset.load_from_disk(result_path).select(range(500))
            # test_dst = test_dst.add_column("idx", list(range(len(test_dst))))
            test_dst = test_dst.sort(base_c_metric)
            test_dst = test_dst.map(lambda x: dict(include=x['include'] - 0.04) if x['include'] == 0 else dict(include=x['include'] + 0.04))
            test_dst = test_dst.map(lambda x: dict(sentsim=x['sentsim'] + 0.02) if x['sentsim'] > 0.99 else dict(sentsim=x['sentsim']))
            # test_dst = test_dst.map(lambda x: dict(rougel=x['rougel']+0.02) if x['rougel'] > 0.99 else dict(rougel=x['rougel']))
            fig = go.Figure()
            for c_metric in c_metrics:
                fig.add_trace(go.Scatter(x=list(range(len(test_dst))), y=test_dst[c_metric], mode='markers', name=c_metric))
            fig.update_layout(title_text=f"Model: {model_names_alias[model_name]} Dataset: {dst_names_alias[dst_name]}-{dst_type} Correctness Metric Similarity",
                              xaxis_title=base_c_metric,
                              yaxis_title="Other Correctness Metric",
                              width=1000,
                              height=500)
            fig.show()


Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

In [None]:
# Merged Training: Get Merged Training Eval Results


In [None]:
# Ablation Study: Get Ablation Results

In [None]:
# Sensitivity Analysis : Get Sensitivity Results

In [None]:
# Case Study: show token level u_score
example = test_dst.filter(lambda x: x['rougel'] < 0.1)[1]
example = test_dst[2]
print(f"gt:{example['gt']}")
print(f"options:{example['options']}")

str_tokens = model.to_str_tokens(f":{example['washed_answer']}", prepend_bos=False)[1:]
fig = make_subplots(rows=2, cols=1, subplot_titles=("Token Level", "Sentence Level"), row_heights=[0.5, 0.5])

fig.add_trace(go.Scatter(x=list(range(len(str_tokens))), y=example['u_score_pe_all'], mode='lines+markers'), row=1, col=1)
fig.update_xaxes(title_text='Token', tickvals=list(range(len(str_tokens))), ticktext=str_tokens, row=1, col=1)

sentence_u_score_pe_all = []
indices = [0] + [i for i, x in enumerate(str_tokens) if x == '.'] + [-1]
spans = [(indices[i], indices[i + 1]) for i in range(len(indices) - 1)]
print(len(indices))
for span in spans:
    sentence_score = sum(example['u_score_pe_all'][span[0]:span[1]]) / (span[1] - span[0])
    sentence_u_score_pe_all.extend([sentence_score] * (span[1] - span[0]))
sentence_u_score_pe_all.append(sentence_u_score_pe_all[-1])
# print(str_tokens)
for i, sentence in enumerate(example['washed_answer'].split(".")):
    print(i + 1, sentence.replace("\n", ' ').strip())
# print(len(example['u_score_pe_all']))
# print(sentence_u_score_pe_all)
# print(len(sentence_u_score_pe_all))

fig.add_trace(go.Scatter(x=list(range(len(str_tokens))), y=sentence_u_score_pe_all, mode='lines+markers'), row=2, col=1)
fig.update_xaxes(title_text='Sentence', tickvals=list(range(len(str_tokens))), ticktext=str_tokens, row=2, col=1)

fig.update_layout(height=1000, width=2500, margin=dict(l=0, r=0, b=50, t=50), title_text=example['washed_answer'])
fig.show()