# Implementation of Internal State-based Uncertainty Estimation

1. 数据集预处理：将不同格式的数据集处理成 input+gt 的格式，方便判断模型的 correctness；这一部分采用固定、不可调整的 prompt，即 `Context: Question: Options: Answer:` 格式
2. 生成回复，为每个模型确定一个prompt，一个max_new_tokens数，然后生成回复
3. 计算回复部分的correctness指标，判断模型的回复是否正确
4. 计算uncertainty指标，包括PE, LN-PE, SAR, Ours
5. 计算AUROC，绘制AUROC/Correctness-Threshold曲线

In [1]:
from utils import *

# Turn off the `datasets` library cache for this session.
datasets.disable_caching()
# Disable autograd globally (presumably because this notebook only evaluates — confirm).
torch.set_grad_enabled(False)

# Eval Result Config
# Model identifiers; also used as directory names in the result paths.
model_names = ["vicuna-7b-v1.1", "vicuna-13b-v1.1", "vicuna-33b-v1.3"]

# Dataset identifiers; also used as path components.
dst_names = ["sciq", "coqa", "triviaqa", "medmcqa", "MedQA-USMLE-4-options"]

# Names of the correctness-score columns.
c_metrics = ['rougel', 'sentsim', 'include']

# Prompt variants each dataset is evaluated with.
dst_types = ["short", "long"]

# Per-model, per-dataset accuracy values. Only the 7B row carries non-zero
# entries; the other two models are zero-filled for every dataset.
acc_map = {
    "vicuna-7b-v1.1": {
        "sciq": 0.60,
        "coqa": 0.8,
        "triviaqa": 0.55,
        "medmcqa": 0.30,
        "MedQA-USMLE-4-options": 0.30,
    },
    "vicuna-13b-v1.1": {name: 0.0 for name in dst_names},
    "vicuna-33b-v1.3": {name: 0.0 for name in dst_names},
}

# Display names used in tables and figure titles.
model_names_alias = dict(zip(model_names, ["Vicuna-7B", "Vicuna-13B", "Vicuna-33B"]))

dst_names_alias = dict(zip(dst_names, ["SciQ", "CoQA", "TriviaQA", "MedMCQA", "MedQA"]))

# Uncertainty-score column -> display name. Insertion order matters: it
# determines the column order of the result tables built from this mapping.
u_metric_alias = dict([
    ("u_score_pe", "PE"),
    ("u_score_ln_pe", "LN-PE"),
    ("u_score_token_sar", "TokenSAR"),
    ("u_score_sent_sar", "SentSAR"),
    ("u_score_sar", "SAR"),
    ("u_score_ls", "LS"),
    ("u_score_se", "SE"),
    ("u_score_ours_mean_soft_rougel", "Ours(MSRL)"),
    ("u_score_ours_last_soft_rougel", "Ours(LSRL)"),
    ("u_score_ours_mean_soft_include", "Ours(MSIN)"),
    ("u_score_ours_last_soft_include", "Ours(LSIN)"),
    ("u_score_ours_mean_soft_sentsim", "Ours(MSSI)"),
    ("u_score_ours_last_soft_sentsim", "Ours(LSSI)"),
])


def get_cached_result_path(model_name, dst_name, dst_type, dst_split):
    """Build the relative path of a cached result set.

    Layout: ``cached_results/<model_name>/<dst_type>/<dst_name>_<dst_split>``.
    """
    leaf = f"{dst_name}_{dst_split}"
    return f"cached_results/{model_name}/{dst_type}/{leaf}"


def get_eval_main_result_path(model_name, dst_name, dst_type):
    """Build the relative path of a main evaluation result set.

    Layout: ``eval_results/<model_name>/<dst_name>_<dst_type>``.
    """
    leaf = f"{dst_name}_{dst_type}"
    return f"eval_results/{model_name}/{leaf}"


def get_eval_cross_result_path(model_name, train_dst_name, train_dst_type, test_dst_name, test_dst_type, c_metric):
    """Build the relative path of a cross-evaluation result set.

    Layout:
    ``cross_eval_results/<model_name>/<c_metric>/v_c_<train name>_<train type>_mean_soft_best.pth/<test name>_<test type>``.
    """
    checkpoint_part = f"v_c_{train_dst_name}_{train_dst_type}_mean_soft_best.pth"
    test_part = f"{test_dst_name}_{test_dst_type}"
    return f"cross_eval_results/{model_name}/{c_metric}/{checkpoint_part}/{test_part}"


def get_c_th_by_acc(test_dst, c_metric, acc):
    """Return the correctness threshold found at fraction ``acc`` of the ranked scores.

    The scores in ``test_dst[c_metric]`` are sorted in descending order and the
    score at position ``int(len(scores) * acc)`` is returned.

    Args:
        test_dst: Mapping-like object whose ``c_metric`` entry is an iterable
            of comparable scores.
        c_metric: Name of the correctness-score column.
        acc: Target fraction, normally in ``[0, 1]``.

    Returns:
        The score at the requested rank. Out-of-range ranks are clamped, so
        ``acc >= 1`` yields the smallest score and ``acc <= 0`` the largest.

    Raises:
        IndexError: If the column contains no scores.
    """
    sorted_c_scores = sorted(test_dst[c_metric], reverse=True)
    # int(n * acc) equals n for acc == 1.0, one past the last valid index, and
    # a negative acc would silently wrap around; clamp into the valid range.
    last_idx = len(sorted_c_scores) - 1
    idx = min(max(int(len(sorted_c_scores) * acc), 0), last_idx)
    return sorted_c_scores[idx]


def get_acc_by_c_th(test_dst, c_metric, c_th):
    """Return the fraction of rows whose ``c_metric`` score is strictly above ``c_th``.

    The count of scores in ``test_dst[c_metric]`` greater than ``c_th`` is
    divided by ``len(test_dst)``.
    """
    num_above = 0
    for score in test_dst[c_metric]:
        if score > c_th:
            num_above += 1
    return num_above / len(test_dst)

In [72]:
def show_dst_case(test_dst, num_cases=10):
    """Print the column names of ``test_dst`` followed by selected fields of its first rows.

    Args:
        test_dst: Dataset-like object exposing ``column_names``, ``len()`` and
            integer row indexing that yields a mapping.
        num_cases: Maximum number of leading rows to print (default 10).
    """
    shown_keys = ('question', 'washed_answer', 'input', 'output', 'gt', 'u_score_pe')
    print(test_dst.column_names)
    # Clamp to the dataset size so short datasets do not raise IndexError.
    for i in range(min(num_cases, len(test_dst))):
        row = test_dst[i]  # fetch the row once instead of once per printed field
        for k in shown_keys:
            print(f"{k}: {row[k]}")
        print()
# Preview the stored short-prompt SciQ evaluation results for vicuna-7b-v1.1.
show_dst_case(Dataset.load_from_disk(get_eval_main_result_path('vicuna-7b-v1.1', 'sciq', 'short')))

['question', 'distractor3', 'distractor1', 'distractor2', 'correct_answer', 'support', 'input', 'dst_template', 'options', 'gt', 'answer', 'washed_answer', 'output', 'washed_output', 'sampled_answer', 'washed_sampled_answer', 'sampled_output', 'washed_sampled_output', 'num_input_tokens', 'num_output_tokens', 'num_answer_tokens', 'answer_idxs', 'rougel', 'sentsim', 'include', 'answer_prob', 'time_fwd', 'sampled_answer_prob', 'u_score_len', 'u_score_pe_all', 'u_score_pe', 'u_score_ln_pe', 'time_pe', 'u_score_token_sar', 'time_token_sar', 'time_sent_sar', 'time_sar', 'u_score_sent_sar', 'u_score_sar', 'u_score_ls', 'time_ls', 'u_score_se', 'time_se', 'u_score_ours_mean_soft_rougel', 'time_ours_mean_soft_rougel', 'u_score_ours_last_soft_rougel', 'time_ours_last_soft_rougel', 'u_score_ours_mean_soft_include', 'time_ours_mean_soft_include', 'u_score_ours_last_soft_include', 'time_ours_last_soft_include']
question: Who proposed the theory of evolution by natural selection?
washed_answer: Darw

In [2]:
# Merge Train Dataset
model_name = 'vicuna-7b-v1.1'
# Row caps per source dataset, passed to merge_dst as data_size for the
# 'train' and 'validation' splits respectively.
train_size_per_dataset = 2000
val_size_per_dataset = 100


def merge_dst(model_name, dst_names, dst_types, dst_split, data_size):
    """Merge the cached results of several datasets into one dataset and save it.

    For every (dataset name, prompt type) pair the cached result set of
    ``dst_split`` is loaded, cut to at most ``data_size`` leading rows, reduced
    to a shared set of columns, and given a placeholder ``options`` column if
    it has none. The pieces are concatenated, the merged size is printed, and
    the result is saved under the cached-result path for name ``'all'`` and
    type ``'merged'``.
    """
    kept_columns = ['dst_template', 'question', 'input', 'gt', 'options', 'answer', 'output']

    def fill_missing_columns(dst: Dataset):
        # Drop every column outside the shared schema, one at a time.
        for extra in [name for name in dst.column_names if name not in kept_columns]:
            dst = dst.remove_columns(extra)
        if 'options' in dst.column_names:
            return dst
        # Datasets without options get a one-element placeholder list per row.
        placeholder = [['#####'] for _ in range(len(dst))]
        return dst.add_column('options', placeholder)

    pieces = []
    for dst_name in dst_names:
        for dst_type in dst_types:
            piece = Dataset.load_from_disk(get_cached_result_path(model_name, dst_name, dst_type, dst_split))
            if len(piece) > data_size:
                piece = piece.select(range(data_size))
            pieces.append(fill_missing_columns(piece))

    merged_dst = datasets.concatenate_datasets(pieces)
    print(len(merged_dst))

    save_path = get_cached_result_path(model_name, 'all', 'merged', dst_split)
    os.makedirs(save_path, exist_ok=True)
    merged_dst.save_to_disk(save_path)
    print(f"Save Merged Dataset to {save_path}")

# Build and save the merged train and validation sets, each capped per source dataset.
merge_dst(model_name, dst_names, dst_types, 'train', train_size_per_dataset)
merge_dst(model_name, dst_names, dst_types, 'validation', val_size_per_dataset)

In [None]:
# Main Results: Get Main Results

# Correctness column and threshold used for both ACC and AUROC below.
c_metric = 'include'
c_th = 0.3

# Row index: one (model alias, dataset alias) pair per model/dataset combination, model-major.
model_names_index = [model_names_alias[name] for name in model_names for i in range(len(dst_names))]
dst_names_index = [dst_names_alias[name] for name in dst_names] * len(model_names)
multi_index = pd.MultiIndex.from_tuples(zip(model_names_index, dst_names_index), names=['Model', 'Dataset'])
# Columns: ACC, every method whose key lacks 'ours', and the 'ours' variants whose key contains c_metric.
columns = ['ACC'] + [alias for u_metric, alias in u_metric_alias.items() if (c_metric in u_metric) or 'ours' not in u_metric]
# One table per prompt type; the long table starts as an independent copy of the short one.
main_results_short = pd.DataFrame(columns=columns, index=multi_index).astype(float)
main_results_long = deepcopy(main_results_short)

for model_name in model_names:
    for dst_name in dst_names:
        for dst_type in dst_types:
            # Every cell defaults to 0.0, so a missing result path yields an all-zero row.
            new_row = {k: 0. for k in main_results_short.columns}
            result_path = get_eval_main_result_path(model_name, dst_name, dst_type)
            if os.path.exists(result_path):
                test_dst = Dataset.load_from_disk(result_path)
                new_row['ACC'] = get_acc_by_c_th(test_dst, c_metric, c_th) * 100
                for u_metric, u_metric_name in u_metric_alias.items():
                    # Only fill methods that are table columns and present in this result set.
                    if u_metric_name in columns and u_metric in test_dst.column_names:
                        # get_auroc is not defined in this view (presumably from utils); result scaled by 100.
                        new_row[u_metric_name] = get_auroc(test_dst, u_metric, c_metric, c_th) * 100
            # Any dst_type other than 'short' is written to the long table.
            result = main_results_short if dst_type == 'short' else main_results_long
            result.loc[(model_names_alias[model_name], dst_names_alias[dst_name])] = new_row

print(f"Correctness Metric: {c_metric} Threshold: {c_th} AUROC Results")
print("Short Prompt Main Result:")
display(main_results_short)

print("Long Prompt Main Result:")
display(main_results_long)

# print(main_results_short.to_latex(index=True, float_format="%.2f"))
# print(main_results_long.to_latex(index=True, float_format="%.2f"))

In [None]:
# Generalization1: Cross Dataset and Cross Prompt Evaluation
# Model whose cross-evaluation results are plotted by the calls in this cell.
model_name = 'vicuna-7b-v1.1'


def plot_cross_dst_matrix(model_name, u_metric, c_metric, c_th):
    """Plot the train/test cross-evaluation heatmap and print average score drops.

    Rows are test settings and columns are train settings. Settings are ordered
    prompt-type-major: all datasets for ``dst_types[0]``, then ``dst_types[1]``
    and so on. Cells without a stored result stay at 0 and are included in
    every average.

    Args:
        model_name: Model identifier used in the result path and the title.
        u_metric: Uncertainty-score column. Its last ``_``-separated token
            selects the metric directory of the cross-evaluation results.
        c_metric: Correctness-score column passed to ``get_auroc``.
        c_th: Correctness threshold passed to ``get_auroc``.

    Prints:
        Drops computed as "mean of a region minus mean of its diagonal": for
        the whole matrix, for each prompt type's diagonal block, their average,
        the cross-prompt cells (same dataset, next prompt type), and each
        train column.
    """
    fig_index = [(dst_name, dst_type) for dst_type in dst_types for dst_name in dst_names]
    fig_axis = [f"{dst_names_alias[dst_name]}-{dst_type}" for dst_name, dst_type in fig_index]
    num_names = len(dst_names)
    size = len(fig_index)
    cross_eval_matrix = torch.zeros(size, size)
    # Loop-invariant: metric directory encoded in the u_metric name.
    trained_c_metric = u_metric.split("_")[-1]
    fig = go.Figure()
    annotations = []

    for i, (train_dst_name, train_dst_type) in enumerate(fig_index):
        for j, (test_dst_name, test_dst_type) in enumerate(fig_index):
            result_path = get_eval_cross_result_path(model_name, train_dst_name, train_dst_type, test_dst_name, test_dst_type, trained_c_metric)
            if not os.path.exists(f"{result_path}/dataset_info.json"):
                continue
            cross_eval_result = Dataset.load_from_disk(result_path)
            cross_eval_matrix[j][i] = get_auroc(cross_eval_result, u_metric, c_metric, c_th) * 100
            annotations.append(dict(
                x=i,
                y=j,
                text=f"{cross_eval_matrix[j][i].item():.2f}",
                showarrow=False,
                font=dict(
                    color='white'
                )
            ))
    fig.add_trace(go.Heatmap(z=cross_eval_matrix, x=fig_axis, y=fig_axis, colorscale='Inferno'))
    fig.update_layout(
        title_text=f"Model: {model_names_alias[model_name]} Correctness Metric: {c_metric} Method: {u_metric} Cross Eval Results",
        xaxis_title="Train Dataset",
        yaxis_title="Test Dataset",
        width=1000,
        height=1000,
        annotations=annotations
    )
    fig.show()

    diag_mean = cross_eval_matrix.diag().mean().item()
    overall_average_drop = cross_eval_matrix.mean().item() - diag_mean

    # One diagonal block per prompt type, sized by the number of datasets
    # rather than a hard-coded 5.
    type_drops = {}
    for t, dst_type in enumerate(dst_types):
        lo, hi = t * num_names, (t + 1) * num_names
        block = cross_eval_matrix[lo:hi, lo:hi]
        type_drops[dst_type] = block.mean().item() - block.diag().mean().item()
    cross_dst_average_drop = sum(type_drops.values()) / len(type_drops) if type_drops else float('nan')

    # Cross-prompt cells: same dataset, next prompt type (wrapping around).
    # The diagonal mean is subtracted so this is a drop like the others.
    cross_prompt_mean = sum(
        cross_eval_matrix[i][(i + num_names) % size].item() for i in range(size)
    ) / size
    cross_prompt_average_drop = cross_prompt_mean - diag_mean

    print(f"Overall Average Drop: {overall_average_drop:.2f}")
    for dst_type, drop in type_drops.items():
        print(f"{dst_type.capitalize()} Average Drop: {drop:.2f}")
    print(f"Cross Dst Average Drop: {cross_dst_average_drop:.2f}")
    print(f"Cross Prompt Average Drop: {cross_prompt_average_drop:.2f}")
    for i, axis_label in enumerate(fig_axis):
        dst_drop = cross_eval_matrix[:, i].mean().item() - cross_eval_matrix[i, i].item()
        print(f"Train Dst:{axis_label} Average Drop: {dst_drop:.2f}")


# Plot two variants (u_metric keys ending in 'rougel' and 'include'), both scored against 'include' at threshold 0.3.
plot_cross_dst_matrix(model_name, u_metric='u_score_ours_mean_soft_rougel', c_metric='include', c_th=0.3)
plot_cross_dst_matrix(model_name, u_metric='u_score_ours_mean_soft_include', c_metric='include', c_th=0.3)

In [None]:
# Generalization2: Cross Correctness Metric Evaluation
model_name = 'vicuna-7b-v1.1'
c_th = 0.3

# Row index: one (correctness metric, dataset alias) pair per combination, metric-major.
c_metric_index = [name for name in c_metrics for i in range(len(dst_names))]
dst_names_index = [dst_names_alias[name] for name in dst_names] * len(c_metrics)
multi_index = pd.MultiIndex.from_tuples(zip(c_metric_index, dst_names_index), names=['Correctness Metric', 'Dataset'])
# Columns: ACC plus every method whose display name contains 'Ours'.
columns = ['ACC'] + [u for u in u_metric_alias.values() if 'Ours' in u]
# One table per prompt type; the long table starts as an independent copy of the short one.
cross_c_metric_results_short = pd.DataFrame(columns=columns, index=multi_index).astype(float)
cross_c_metric_results_long = deepcopy(cross_c_metric_results_short)

for dst_type in dst_types:
    # Any dst_type other than 'short' is written to the long table.
    result = cross_c_metric_results_short if dst_type == 'short' else cross_c_metric_results_long
    for c_metric in c_metrics:
        for dst_name in dst_names:
            # Every cell defaults to 0.0, so a missing result path yields an all-zero row.
            new_row = {k: 0. for k in columns}
            result_path = get_eval_main_result_path(model_name, dst_name, dst_type)
            if os.path.exists(result_path):
                # The same result set is re-loaded for each correctness metric.
                test_dst = Dataset.load_from_disk(result_path)
                new_row['ACC'] = get_acc_by_c_th(test_dst, c_metric, c_th) * 100
                for u_metric, u_metric_name in u_metric_alias.items():
                    # Only fill methods that are table columns and present in this result set.
                    if u_metric_name in columns and u_metric in test_dst.column_names:
                        new_row[u_metric_name] = get_auroc(test_dst, u_metric, c_metric, c_th) * 100
            result.loc[(c_metric, dst_names_alias[dst_name])] = new_row

print("Short Prompt Cross Correctness Metric Result:")
display(cross_c_metric_results_short)

print("Long Prompt Cross Correctness Metric Result:")
display(cross_c_metric_results_long)

In [None]:
# Efficiency: Get Efficiency Results
# Model / dataset / prompt type selected for this cell; no call using them is visible here.
model_name = 'vicuna-7b-v1.1'
dst_name = 'sciq'
dst_type = 'long'


def plot_efficiency_results(model_name, dst_name, dst_type):
    """Create an empty figure and resolve the main eval-result path for the given setting.

    NOTE(review): the body ends after these two assignments — nothing is
    loaded, plotted or returned, so this looks like an unfinished stub.
    """
    fig = go.Figure()
    result_path = get_eval_main_result_path(model_name, dst_name, dst_type)

In [None]:
# Show Correctness Metric Similarity
# For each stored result set, sort rows by the base correctness metric and
# scatter-plot every correctness metric against the resulting row order.
base_c_metric = "rougel"
model_name = 'vicuna-7b-v1.1'
max_points = 500  # upper bound on the number of rows plotted per result set

for dst_name in dst_names:
    for dst_type in dst_types:
        result_path = get_eval_main_result_path(model_name, dst_name, dst_type)
        if os.path.exists(result_path):
            test_dst = Dataset.load_from_disk(result_path)
            # Clamp to the dataset size: selecting a fixed range(500) fails on
            # result sets with fewer than 500 rows.
            test_dst = test_dst.select(range(min(max_points, len(test_dst))))
            # test_dst = test_dst.add_column("idx", list(range(len(test_dst))))
            test_dst = test_dst.sort(base_c_metric)
            # Shift 'include' by -0.04 when it is 0 and by +0.04 otherwise, and
            # 'sentsim' by +0.02 above 0.99 (presumably to keep overlapping
            # markers apart in the plot — confirm).
            test_dst = test_dst.map(lambda x: dict(include=x['include'] - 0.04) if x['include'] == 0 else dict(include=x['include'] + 0.04))
            test_dst = test_dst.map(lambda x: dict(sentsim=x['sentsim'] + 0.02) if x['sentsim'] > 0.99 else dict(sentsim=x['sentsim']))
            # test_dst = test_dst.map(lambda x: dict(rougel=x['rougel']+0.02) if x['rougel'] > 0.99 else dict(rougel=x['rougel']))
            fig = go.Figure()
            for c_metric in c_metrics:
                fig.add_trace(go.Scatter(x=list(range(len(test_dst))), y=test_dst[c_metric], mode='markers', name=c_metric))
            fig.update_layout(title_text=f"Model: {model_names_alias[model_name]} Dataset: {dst_names_alias[dst_name]}-{dst_type} Correctness Metric Similarity",
                              xaxis_title=base_c_metric,
                              yaxis_title="Other Correctness Metric",
                              width=2000,
                              height=1000)
            fig.show()


In [None]:
# Merged Training: Get Merged Training Eval Results


In [None]:
# Ablation Study: Get Ablation Results

In [None]:
# Sensitivity Analysis : Get Sensitivity Results

In [None]:
# Case Study: show token level u_score
# Uses `test_dst` as left by an earlier cell; `model` is not defined in this
# view and must already be in scope (presumably provided via utils — confirm).
# To study a low-ROUGE answer instead, index into
# test_dst.filter(lambda x: x['rougel'] < 0.1).
example = test_dst[2]
print(f"gt:{example['gt']}")
print(f"options:{example['options']}")

# Tokenise ":<answer>" and drop the first token so only the answer's tokens remain.
str_tokens = model.to_str_tokens(f":{example['washed_answer']}", prepend_bos=False)[1:]
token_u_scores = example['u_score_pe_all']
fig = make_subplots(rows=2, cols=1, subplot_titles=("Token Level", "Sentence Level"), row_heights=[0.5, 0.5])

fig.add_trace(go.Scatter(x=list(range(len(str_tokens))), y=token_u_scores, mode='lines+markers'), row=1, col=1)
fig.update_xaxes(title_text='Token', tickvals=list(range(len(str_tokens))), ticktext=str_tokens, row=1, col=1)

# Sentence-level curve: each token carries the mean score of its sentence.
# A sentence ends right after a '.' token and the last one ends at the end of
# the sequence. A -1 end sentinel would give the final span a negative length,
# which flips the sign of its mean and fails when the answer contains no '.'.
sentence_u_score_pe_all = []
indices = [0] + [i + 1 for i, x in enumerate(str_tokens) if x == '.'] + [len(str_tokens)]
spans = list(zip(indices, indices[1:]))
print(len(indices))
for start, end in spans:
    span_scores = token_u_scores[start:end]
    if not span_scores:
        # Empty span, e.g. when the answer already ends with '.'.
        continue
    sentence_score = sum(span_scores) / len(span_scores)
    sentence_u_score_pe_all.extend([sentence_score] * len(span_scores))
# print(str_tokens)
for i, sentence in enumerate(example['washed_answer'].split(".")):
    print(i + 1, sentence.replace("\n", ' ').strip())
# print(len(example['u_score_pe_all']))
# print(sentence_u_score_pe_all)
# print(len(sentence_u_score_pe_all))

fig.add_trace(go.Scatter(x=list(range(len(str_tokens))), y=sentence_u_score_pe_all, mode='lines+markers'), row=2, col=1)
fig.update_xaxes(title_text='Sentence', tickvals=list(range(len(str_tokens))), ticktext=str_tokens, row=2, col=1)

fig.update_layout(height=1000, width=2500, margin=dict(l=0, r=0, b=50, t=50), title_text=example['washed_answer'])
fig.show()