# Implementation of Internal State-based Uncertainty Estimation

1. 数据集预处理：将不同格式的数据集处理成input+gt的格式，方便判断模型的correctness，这一部分采用固定不可调整的prompt，即Context: Question: Options: Answer:格式
2. 生成回复，为每个模型确定一个prompt，一个max_new_tokens数，然后生成回复
3. 计算回复部分的correctness指标，判断模型的回复是否正确
4. 计算uncertainty指标，包括PE, LN-PE, SAR, Ours
5. 计算AUROC，绘制AUROC/Correctness-Threshold曲线

In [1]:
from utils import *

# Turn off caching in the `datasets` module (the name comes from the star import above).
datasets.disable_caching()
# Disable autograd globally; none of the cells visible here run a backward pass.
torch.set_grad_enabled(False)

# Eval Result Config
# Model identifiers; they are interpolated into the result paths and looked up
# in the alias tables defined further down.
model_names = ["vicuna-7b-v1.1", "vicuna-13b-v1.1", "vicuna-33b-v1.3"]

# Dataset identifiers, used the same way as the model identifiers.
dst_names = ["sciq", "coqa", "triviaqa", "medmcqa", "MedQA-USMLE-4-options"]

# Names of the correctness metrics ('rougel' and 'include' are read as dataset
# columns by the cells below).
# NOTE: the separating commas are required -- without them Python joins
# adjacent string literals into the single entry 'rougelsentsiminclude'.
c_metrics = [
    'rougel',
    'sentsim',
    'include',
]

# Prompt variants; each one gets its own result directory and its own table.
dst_types = ["short", "long"]

# Accuracy value per model and dataset, keyed by the identifiers in
# `model_names` / `dst_names`.
# NOTE(review): the all-zero 13B/33B rows look like unfilled placeholders, and
# no cell visible in this notebook reads this table -- confirm before relying on it.
acc_map = {
    "vicuna-7b-v1.1": {
        "sciq": 0.60, "coqa": 0.8, "triviaqa": 0.55,
        "medmcqa": 0.30, "MedQA-USMLE-4-options": 0.30,
    },
    "vicuna-13b-v1.1": {
        "sciq": 0.0, "coqa": 0.0, "triviaqa": 0.0,
        "medmcqa": 0.0, "MedQA-USMLE-4-options": 0.0,
    },
    "vicuna-33b-v1.3": {
        "sciq": 0.0, "coqa": 0.0, "triviaqa": 0.0,
        "medmcqa": 0.0, "MedQA-USMLE-4-options": 0.0,
    },
}

# Display name of every model identifier (used for table rows and figure titles).
model_names_alias = {
    "vicuna-7b-v1.1": "Vicuna-7B",
    "vicuna-13b-v1.1": "Vicuna-13B",
    "vicuna-33b-v1.3": "Vicuna-33B",
}

# Display name of every dataset identifier.
dst_names_alias = {
    "sciq": "SciQ",
    "coqa": "CoQA",
    "triviaqa": "TriviaQA",
    "medmcqa": "MedMCQA",
    "MedQA-USMLE-4-options": "MedQA",
}

# Dataset column -> column header for each uncertainty score.
# Insertion order matters: it fixes the column order of the result tables.
u_metric_alias = {
    "u_score_pe": "PE",
    "u_score_ln_pe": "LN-PE",
    "u_score_token_sar": "TokenSAR",
    "u_score_sent_sar": "SentSAR",
    "u_score_sar": "SAR",
    "u_score_ls": "LS",
    "u_score_se": "SE",
    # The two variants of the method under study.
    "u_score_ours_mean_soft": "Ours(MS)",
    "u_score_ours_last_soft": "Ours(LS)",
}


def get_eval_main_result_path(model_name, dst_name, dst_type):
    """Build the relative directory of one main evaluation result.

    The layout is ``eval_results/<model_name>/<dst_name>_<dst_type>``.
    """
    leaf = f"{dst_name}_{dst_type}"
    return "/".join(("eval_results", f"{model_name}", leaf))


def get_eval_cross_result_path(model_name, train_dst_name, train_dst_type, test_dst_name, test_dst_type):
    """Build the relative directory of one cross-dataset evaluation result.

    The path names the checkpoint derived from the training dataset/prompt
    type, followed by the dataset/prompt type it was tested on. The ``rougel``
    segment and the ``mean_soft_best`` checkpoint suffix are fixed.
    """
    checkpoint_dir = f"v_c_{train_dst_name}_{train_dst_type}_mean_soft_best.pth"
    test_dir = f"{test_dst_name}_{test_dst_type}"
    return f"cross_eval_results/{model_name}/rougel/{checkpoint_dir}/{test_dir}"


def get_c_th_by_acc(test_dst, c_metric, acc):
    """Return the correctness threshold that corresponds to a target accuracy.

    The scores in ``test_dst[c_metric]`` are ranked from highest to lowest and
    the score at rank ``int(n * acc)`` is returned, so about ``acc`` of the
    scores lie strictly above the returned threshold.

    Args:
        test_dst: Container indexable by column name, yielding the scores.
        c_metric: Name of the correctness-score column.
        acc: Target accuracy as a fraction, expected in ``[0, 1]``.

    Returns:
        One of the scores from the column. ``acc >= 1`` yields the lowest
        score and ``acc <= 0`` the highest.

    Raises:
        IndexError: If the column holds no scores.
    """
    ranked = sorted(test_dst[c_metric], reverse=True)
    # Clamp the rank: acc == 1.0 would otherwise index one past the end, and a
    # negative product would silently wrap around to the low-score end.
    rank = max(0, min(int(len(ranked) * acc), len(ranked) - 1))
    return ranked[rank]


def get_acc_by_c_th(test_dst, c_metric, c_th):
    """Return the fraction of samples whose correctness score exceeds ``c_th``.

    Args:
        test_dst: Container indexable by column name, yielding the scores.
        c_metric: Name of the correctness-score column.
        c_th: Threshold; a sample counts as correct only if its score is
            strictly greater.

    Returns:
        A float in ``[0, 1]``; ``0.0`` when the column holds no scores.
    """
    scores = test_dst[c_metric]
    # Divide by the number of scores rather than len(test_dst): the two agree
    # for row-oriented datasets, and this also stays correct for plain
    # column mappings.
    total = len(scores)
    if total == 0:
        return 0.0
    return sum(1 for score in scores if score > c_th) / total

In [15]:
# Show Case
# Print selected fields of the first ten rows of the Vicuna-7B / SciQ / short
# main result, followed by one AUROC value.
test_dst = Dataset.load_from_disk(get_eval_main_result_path('vicuna-7b-v1.1', 'sciq', 'short'))
shown_fields = ('washed_answer', 'gt', 'options', 'include')
for row_idx in range(10):
    record = test_dst[row_idx]
    for field in shown_fields:
        print(f"{field}: {record[field]}")
    print()
# Arguments: uncertainty column, correctness column, and (presumably) the
# correctness threshold -- get_auroc is defined outside this notebook.
print(get_auroc(test_dst, 'u_score_ours_mean_soft', 'include', 0.3))

washed_answer: Darwin, Linnaeus, Scopes, Shaw
gt: darwin
options: ['Scopes', 'shaw', 'darwin', 'Linnaeus']
include: 0

washed_answer: hydrochloric, amino, lactic, fatty
gt: amino
options: ['hydrochloric', 'amino', 'lactic', 'fatty']
include: 0

washed_answer: genes
gt: nucleotides
options: ['carotenoids', 'proteins', 'genes', 'nucleotides']
include: 0

washed_answer: Wetland
gt: wetland
options: ['plains', 'wetland', 'grassland', 'tundra']
include: 1

washed_answer: gamma rays
gt: the sun
options: ['gamma rays', 'the sun', 'the moon', 'decomposition']
include: 0

washed_answer: blood vessels, organs, muscles, tissue
gt: blood vessels
options: ['blood vessels', 'organs', 'muscles', 'tissue']
include: 0

washed_answer: catabolic and anabolic

Metabolic reactions can be broadly classified into two categories: catabolic and anabolic reactions.

Catabolic reactions break down large molecules into smaller ones, releasing energy in the process. These reactions occur in the cytoplasm and the m

In [16]:
# Show Case
# Same inspection as for the main result, but on the cross-eval result where
# the train and test sets are both Vicuna-7B / SciQ / short.
test_dst = Dataset.load_from_disk(get_eval_cross_result_path('vicuna-7b-v1.1', 'sciq', 'short', 'sciq', 'short'))
shown_fields = ('washed_answer', 'gt', 'include')
for row_idx in range(10):
    record = test_dst[row_idx]
    for field in shown_fields:
        print(f"{field}: {record[field]}")
    print()
# Arguments: uncertainty column, correctness column, and (presumably) the
# correctness threshold -- get_auroc is defined outside this notebook.
print(get_auroc(test_dst, 'u_score_ours_mean_soft', 'include', 0.3))

washed_answer: Darwin, Linnaeus, Scopes, Shaw
gt: darwin
include: 0

washed_answer: hydrochloric, amino, lactic, fatty
gt: amino
include: 0

washed_answer: genes
gt: nucleotides
include: 0

washed_answer: Wetland
gt: wetland
include: 1

washed_answer: gamma rays
gt: the sun
include: 0

washed_answer: blood vessels, organs, muscles, tissue
gt: blood vessels
include: 0

washed_answer: catabolic and anabolic

Metabolic reactions can be broadly classified into two categories: catabolic and anabolic reactions.

Catabolic reactions break down large molecules into smaller ones, releasing energy in the process. These reactions occur in the cytoplasm and the mitochondria of cells and are responsible for breaking down nutrients into smaller molecules that can be used for energy or for building new molecules. Examples of catabolic reactions include the breakdown of glucose to produce energy in the form of ATP, the
gt: catabolic and anabolic
include: 1

washed_answer: volcanic ash
gt: volcanic ash

In [8]:
# Main Results: Get Main Results
# Fills one table per prompt type (short / long). Each (model, dataset) row
# holds the accuracy and one AUROC per uncertainty metric, all scaled to
# percent. Rows whose result directory is missing stay at 0.

c_metric = 'rougel'
default_c_th = 0.3

# Row labels: the model varies slowest, the dataset fastest.
model_names_index = [model_names_alias[m] for m in model_names for _ in dst_names]
dst_names_index = [dst_names_alias[d] for _ in model_names for d in dst_names]
multi_index = pd.MultiIndex.from_tuples(zip(model_names_index, dst_names_index), names=['Model', 'Dataset'])
result_columns = ['ACC'] + list(u_metric_alias.values())
main_results_short = pd.DataFrame(columns=result_columns, index=multi_index).astype(float)
main_results_long = deepcopy(main_results_short)

for model_name in model_names:
    for dst_name in dst_names:
        row_key = (model_names_alias[model_name], dst_names_alias[dst_name])
        for dst_type in dst_types:
            # Start from an all-zero row so a missing result is still written.
            new_row = dict.fromkeys(main_results_short.columns, 0.)
            result_path = get_eval_main_result_path(model_name, dst_name, dst_type)
            if os.path.exists(result_path):
                test_dst = Dataset.load_from_disk(result_path)
                c_th = default_c_th
                acc = get_acc_by_c_th(test_dst, c_metric, c_th)
                new_row['ACC'] = acc * 100
                for u_metric, u_metric_name in u_metric_alias.items():
                    # Skip metrics that were not computed for this result.
                    if u_metric not in test_dst.column_names:
                        continue
                    new_row[u_metric_name] = get_auroc(test_dst, u_metric, c_metric, c_th) * 100
            if dst_type == 'short':
                result = main_results_short
            else:
                result = main_results_long
            result.loc[row_key] = new_row

print("Short Prompt Main Result:")
display(main_results_short)

print("Long Prompt Main Result:")
display(main_results_long)

# LaTeX versions of both tables, two decimals per value.
print(main_results_short.to_latex(index=True, float_format="%.2f"))
print(main_results_long.to_latex(index=True, float_format="%.2f"))

Short Prompt Main Result:


Unnamed: 0_level_0,Unnamed: 1_level_0,ACC,PE,LN-PE,TokenSAR,SentSAR,SAR,LS,SE,Ours(MS),Ours(LS)
Model,Dataset,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
Vicuna-7B,SciQ,66.6,81.494444,57.313751,52.64561,75.829647,60.063881,74.565509,73.647974,86.126171,83.728714
Vicuna-7B,CoQA,52.2,74.528287,51.025185,50.379134,72.138059,54.986454,66.491928,60.077911,83.330929,81.749467
Vicuna-7B,TriviaQA,52.9,76.647041,76.537673,76.499344,76.850726,77.772426,73.753306,76.956281,72.515743,70.878836
Vicuna-7B,MedMCQA,25.1,83.87784,62.72294,58.141267,72.43496,56.667323,70.13016,52.861451,88.772547,85.187687
Vicuna-7B,MedQA,4.9,78.950192,62.209489,61.49059,80.293569,60.954098,63.866177,52.490397,92.167214,76.766669
Vicuna-13B,SciQ,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Vicuna-13B,CoQA,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Vicuna-13B,TriviaQA,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Vicuna-13B,MedMCQA,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Vicuna-13B,MedQA,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


Long Prompt Main Result:


Unnamed: 0_level_0,Unnamed: 1_level_0,ACC,PE,LN-PE,TokenSAR,SentSAR,SAR,LS,SE,Ours(MS),Ours(LS)
Model,Dataset,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
Vicuna-7B,SciQ,20.8,50.692623,77.468192,73.288474,52.062087,67.557789,58.036798,51.621382,88.927739,86.360905
Vicuna-7B,CoQA,38.6,63.344079,59.224317,58.842889,62.571096,54.083475,62.267304,61.997266,78.890652,72.618184
Vicuna-7B,TriviaQA,36.9,75.067536,67.876516,68.145371,71.89732,54.045499,77.248872,58.929776,85.726833,87.194156
Vicuna-7B,MedMCQA,16.5,64.432589,50.808928,51.367084,62.763201,52.803484,55.372528,62.919978,80.601343,75.462529
Vicuna-7B,MedQA,15.0,67.643922,52.041569,52.715294,65.254118,59.762745,63.347059,65.269804,79.352549,76.653333
Vicuna-13B,SciQ,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Vicuna-13B,CoQA,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Vicuna-13B,TriviaQA,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Vicuna-13B,MedMCQA,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Vicuna-13B,MedQA,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


\begin{tabular}{llrrrrrrrrrr}
\toprule
 &  & ACC & PE & LN-PE & TokenSAR & SentSAR & SAR & LS & SE & Ours(MS) & Ours(LS) \\
Model & Dataset &  &  &  &  &  &  &  &  &  &  \\
\midrule
\multirow[t]{5}{*}{Vicuna-7B} & SciQ & 66.60 & 81.49 & 57.31 & 52.65 & 75.83 & 60.06 & 74.57 & 73.65 & 86.13 & 83.73 \\
 & CoQA & 52.20 & 74.53 & 51.03 & 50.38 & 72.14 & 54.99 & 66.49 & 60.08 & 83.33 & 81.75 \\
 & TriviaQA & 52.90 & 76.65 & 76.54 & 76.50 & 76.85 & 77.77 & 73.75 & 76.96 & 72.52 & 70.88 \\
 & MedMCQA & 25.10 & 83.88 & 62.72 & 58.14 & 72.43 & 56.67 & 70.13 & 52.86 & 88.77 & 85.19 \\
 & MedQA & 4.90 & 78.95 & 62.21 & 61.49 & 80.29 & 60.95 & 63.87 & 52.49 & 92.17 & 76.77 \\
\cline{1-12}
\multirow[t]{5}{*}{Vicuna-13B} & SciQ & 0.00 & 0.00 & 0.00 & 0.00 & 0.00 & 0.00 & 0.00 & 0.00 & 0.00 & 0.00 \\
 & CoQA & 0.00 & 0.00 & 0.00 & 0.00 & 0.00 & 0.00 & 0.00 & 0.00 & 0.00 & 0.00 \\
 & TriviaQA & 0.00 & 0.00 & 0.00 & 0.00 & 0.00 & 0.00 & 0.00 & 0.00 & 0.00 & 0.00 \\
 & MedMCQA & 0.00 & 0.00 & 0.00 & 0.0

In [50]:
# Generalization: Get Cross Eval Results
# For every model, draw a heatmap whose columns are the training
# (dataset, prompt type) pairs and whose rows are the test pairs.
# Diagonal annotations show the in-domain AUROC in percent; every other cell
# shows the difference to the in-domain AUROC of its column.
u_metric = 'u_score_ours_mean_soft'
c_metric = 'rougel'
c_th = 0.3


def _heatmap_label(x, y, value):
    """Return a plotly annotation that prints `value` with two decimals at cell (x, y)."""
    return dict(
        x=x,
        y=y,
        text=f"{value:.2f}",
        showarrow=False,
        font=dict(
            color='white'
        )
    )


# Prompt type varies slowest, dataset fastest (avoids shadowing the builtin `type`).
fig_index = [(dst_name, dst_type) for dst_type in dst_types for dst_name in dst_names]
fig_axis = [f"{dst_names_alias[dst_name]}-{dst_type}" for dst_name, dst_type in fig_index]
for model_name in model_names:
    cross_eval_matrix = torch.zeros(len(dst_types) * len(dst_names), len(dst_types) * len(dst_names))
    fig = go.Figure()
    annotations = []
    for i, (train_dst_name, train_dst_type) in enumerate(fig_index):
        # Pass 1: fill column i with the raw AUROC of every available test set.
        for j, (test_dst_name, test_dst_type) in enumerate(fig_index):
            result_path = get_eval_cross_result_path(model_name, train_dst_name, train_dst_type, test_dst_name, test_dst_type)
            if os.path.exists(f"{result_path}/dataset_info.json"):
                cross_eval_result = Dataset.load_from_disk(result_path)
                cross_eval_matrix[j][i] = get_auroc(cross_eval_result, u_metric, c_metric, c_th) * 100
                if i == j:
                    # Label the diagonal before it is zeroed by the subtraction below.
                    annotations.append(_heatmap_label(i, j, cross_eval_matrix[j][i].item()))
        # Pass 2: express column i relative to its in-domain (diagonal) AUROC.
        # NOTE(review): a missing result keeps its initial 0 and therefore
        # shows up as -center_auroc, not as "no data".
        center_auroc = cross_eval_matrix[i][i].item()
        for j in range(len(fig_index)):
            cross_eval_matrix[j][i] -= center_auroc
            if i != j:
                annotations.append(_heatmap_label(i, j, cross_eval_matrix[j][i].item()))
    fig.add_trace(go.Heatmap(z=cross_eval_matrix, x=fig_axis, y=fig_axis, colorscale='Inferno'))
    fig.update_layout(
        title_text=f"{model_names_alias[model_name]} Cross Eval Results",
        xaxis_title="Train Dataset",
        yaxis_title="Test Dataset",
        width=1000,
        height=1000,
        annotations=annotations
    )
    fig.show()

In [None]:
# Efficiency: Get Efficiency Results

In [None]:
# Ablation Study: Get Ablation Results

In [0]:
# Sensitivity Analysis : Get Sensitivity Results

In [None]:
# Case Study: show token level u_score
# Plots the per-token 'u_score_pe_all' values of one example (top) and the
# same values averaged per sentence (bottom).
# Row 2 is the example shown. To study a failure case instead, pick a row
# whose 'rougel' score is below 0.1.
example = test_dst[2]
print(f"gt:{example['gt']}")
print(f"options:{example['options']}")

# The answer is tokenized behind a ':' prefix and the first token is dropped,
# presumably so the answer is tokenized as a continuation -- TODO confirm.
str_tokens = model.to_str_tokens(f":{example['washed_answer']}", prepend_bos=False)[1:]
token_positions = list(range(len(str_tokens)))
fig = make_subplots(rows=2, cols=1, subplot_titles=("Token Level", "Sentence Level"), row_heights=[0.5, 0.5])

fig.add_trace(go.Scatter(x=token_positions, y=example['u_score_pe_all'], mode='lines+markers'), row=1, col=1)
fig.update_xaxes(title_text='Token', tickvals=token_positions, ticktext=str_tokens, row=1, col=1)

# Sentence spans: a '.' token opens a new span and the last span runs to the
# end of the token list. Closing with len(str_tokens) instead of -1 keeps the
# width of the last span positive, so the final sentence is averaged and
# plotted and an answer without any '.' no longer fails.
# Assumes 'u_score_pe_all' holds one score per entry of str_tokens -- TODO confirm.
sentence_u_score_pe_all = []
indices = [0] + [i for i, x in enumerate(str_tokens) if x == '.'] + [len(str_tokens)]
print(len(indices))
for start, end in zip(indices, indices[1:]):
    width = end - start
    if width == 0:
        # Happens when the very first token is '.'; nothing to average.
        continue
    sentence_score = sum(example['u_score_pe_all'][start:end]) / width
    sentence_u_score_pe_all.extend([sentence_score] * width)
for i, sentence in enumerate(example['washed_answer'].split(".")):
    print(i + 1, sentence.replace("\n", ' ').strip())

fig.add_trace(go.Scatter(x=token_positions, y=sentence_u_score_pe_all, mode='lines+markers'), row=2, col=1)
fig.update_xaxes(title_text='Sentence', tickvals=token_positions, ticktext=str_tokens, row=2, col=1)

fig.update_layout(height=1000, width=2500, margin=dict(l=0, r=0, b=50, t=50), title_text=example['washed_answer'])
fig.show()