In [None]:
import os
import numpy as np
from tqdm import tqdm
import torch
from sklearn.metrics import roc_auc_score
import matplotlib.pyplot as plt
import csv

from utils import compute_semantic_clusters, compute_semantic_entropy, prepare_results

import seaborn as sns
sns.set()

**输入结构**（每个样本一份 results_dict，其中包含 "sdlg" 与 "baseline" 两条通道）：

- run_key ∈ {"sdlg","baseline"}
- results_dict[run_key]['generations']：生成的回答序列（第 0 个通常是 most-likely）
- results_dict[run_key]['likelihoods']：对应的负对数似然 / token 概率等
- results_dict[run_key][f'semantic_pairs_{model_type}']：
  - semantic_pairs：NLI/DeBERTa 得到的语义等价矩阵（bool，上三角 → 对称化）
  - （可选）cleaned_semantic_pairs：清洗过的等价矩阵

**核心对比流程**（逐渐增加可用的 generations 数量 n，从 2 到 num_total_gens）

1. 取前 n 个生成：
- baseline：用 boolean_mask（前 n 个为 True）。
- SDLG：除了 boolean_mask，还构造 mask（重要性权重）：

```
mask = tensor([1] + [gen['token_likelihood'] for gen in generations[1:]])
# 注意：SDLG 里这个 mask 的首个元素固定为 1，其余元素是 (0,1) 区间内的“权重”，随后会做 L1 归一化
```

2. 截取对应的 semantic_pairs 子矩阵（并与转置求交，保证对称）。

3. 聚类（compute_semantic_clusters）：把 n 个回答按语义连通分量/团划成若干簇，得到 semantic_difference 结构（含簇 id）。

4. 算语义熵 SE（compute_semantic_entropy）：
- baseline：权重 = boolean_mask.float()（等权）
- SDLG：权重 = normalize(mask[boolean_mask])（重要性采样权重）
- 两个版本：
  - "normalised_semantic_entropy"（按簇概率正规化）
  - "unnormalised_semantic_entropy"（未正规化版本）
- 此外还为 baseline 计算了一个错误的估计（mc_estimate_over_clusters=True），在图里以 “(incorrect estimate)” 标注，用于对比。

5. 把 SE 当成“不确定度分数”，用一个正确性度量（如 bleurt/rouge/rougeL 等）阈值化成 “对/错” 二分类标签；然后计算 AUROC：
- 标签：list_correct_labels = ~ (score >= threshold) → “1 = 错误”
- 预测：SE（值越大，越“不确定”）
- roc_auc_score(labels, SE) → 不同 n（#gen）下的曲线

6. 记录 CSV + 画图：
- 同一张图里画 "sdlg" 和 "baseline" 的 AUROC 随 n（使用的 generations 数） 的变化；
- baseline 还额外画一个 “incorrect estimate” 曲线（mc_estimate_over_clusters）。

In [None]:
# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
# Dataset to evaluate; one of 'coqa', 'trivia_qa', 'truthful_qa'.
dataset = 'trivia_qa'

# Correctness metric used to label answers; one of "rougeL", "bleurt", "rouge1".
# For truthful_qa the "-diff" variant of the chosen metric is used instead.
correctness_metric = 'rougeL' + ("-diff" if dataset == 'truthful_qa' else "")

# A score >= threshold counts as a correct answer; one figure is drawn per threshold.
correctness_threshold_list = [
    0.1, 0.2, 0.3, 0.4, 0.5,
    0.6, 0.7, 0.8, 0.9, 1,
]

# Upper bound on the number of generations considered per sample.
num_total_gens = 10

# Sub-directories of `results/` to evaluate.
run_ids = ['tqa_llama2-13b-chat-1']

# Name of the model that produced the semantic-pair matrices
# (used to build the `semantic_pairs_{model_type}` key).
model_type = "deberta-large-mnli"

# Semantic-entropy variants for which an AUROC is computed (one subplot each).
# The "cleaned_*" variants can be appended to this list to evaluate them too:
#   "cleaned_normalised_semantic_entropy", "cleaned_unnormalised_semantic_entropy"
auroc_keys = [
    "normalised_semantic_entropy",
    "unnormalised_semantic_entropy",
]

# Output switches: append results to CSV files / draw and save figures.
log_results = True
plot_results = True

Sanity check: load a single results pickle and inspect its contents (the path below is hard-coded and independent of `run_ids`).

In [None]:
import pickle

# Replace with the path to your own results file.
file_path = 'results/tqa_llama2-7b-chat-3/results_dict_0.pkl'

# NOTE: unpickling can execute arbitrary code; only load files you generated yourself.
with open(file_path, 'rb') as f:
    data = pickle.load(f)

# Top-level keys of the per-sample results dictionary.
print("字典键：", data.keys())

# Print the fields this notebook relies on.
print("Question:", data["question"])
print("Correctness:", data["correctness_dict"])
print("sdlg generations:", data["sdlg"]["generations"])
print("baseline generations:", data["baseline"]["generations"])

# Warn when either channel has no generations at all.
for _run_name, _empty_msg in (
    ("sdlg", "⚠️ SDLG generations 是空的"),
    ("baseline", "⚠️ Baseline generations 是空的"),
):
    if not data[_run_name]["generations"]:
        print(_empty_msg)

In [None]:
# The "cleaned" semantic-pair matrices are only needed when at least one of the
# cleaned entropy variants was requested in `auroc_keys`.
compute_cleaned = "cleaned_normalised_semantic_entropy" in auroc_keys or "cleaned_unnormalised_semantic_entropy" in auroc_keys


# One figure per correctness threshold, with one subplot per AUROC key.
if plot_results:
    fig_dict = {}
    for correctness_threshold in correctness_threshold_list:
        # squeeze=False always returns a 2-D array of Axes; taking row 0 yields a
        # 1-D array that can be indexed as axs[k] even when len(auroc_keys) == 1
        # (with the default squeeze=True a single subplot comes back as a bare
        # Axes object, which is not subscriptable).
        fig, axs = plt.subplots(nrows=1,
                                ncols=len(auroc_keys),
                                figsize=(5*len(auroc_keys), 5),
                                squeeze=False)
        fig_dict[correctness_threshold] = {"fig": fig, "axs": axs[0]}

# ---------------------------------------------------------------------------
# Main evaluation loop
#
# For every run id and every channel ("sdlg" / "baseline"):
#   1. load the per-sample results via `prepare_results`,
#   2. for n = 2 .. num_total_gens, restrict every sample to its first n
#      generations, cluster them semantically and compute the semantic entropy,
#   3. for every correctness threshold, compute the AUROC of the entropy as a
#      predictor of "answer is wrong" (label 1 = wrong, label 0 = correct),
#   4. optionally append the numbers to CSV files and draw them into the
#      per-threshold figures held in `fig_dict`.
# ---------------------------------------------------------------------------
color_count = 0
for run_id in run_ids:

    basepath = os.path.join('results', run_id)

    for run_key in ["sdlg", "baseline"]:

        if plot_results:
            # x = number of generations, y = dict of AUROCs per key.
            # The "kuhn" variant stores the mc_estimate_over_clusters=True
            # estimate, which is only computed for the baseline channel.
            plot_dict = {threshold: {'plot_data_x': [], 'plot_data_y': []}
                         for threshold in correctness_threshold_list}
            plot_dict_kuhn = {threshold: {'plot_data_x': [], 'plot_data_y': []}
                              for threshold in correctness_threshold_list}

        list_results_dict, list_rougeL, dataset_size = prepare_results(num_samples=8000,
                                                                        run_key=run_key,
                                                                        metric=correctness_metric,
                                                                        start_sample_id=0,
                                                                        base_path=basepath)

        # The correctness scores do not depend on num_gens or the threshold:
        # build the tensor once instead of once per (num_gens, threshold) pair.
        correctness_scores = torch.tensor(list_rougeL)

        # iterate over num_gens
        for num_gens in tqdm(range(2, num_total_gens+1)):

            all_semantic_entropies, all_semantic_entropies_kuhn = [], []
            list_num_semantic_clusters, list_num_generations = [], []

            # iterate over instances
            for results_dict in list_results_dict:

                run_results = results_dict[run_key]
                generations = run_results['generations']
                pairs_dict = run_results[f'semantic_pairs_{model_type}']

                # ---------- create mask of considered generations
                # Only the first `num_gens` generations are used; samples with
                # fewer generations use all of them.
                num_available = len(generations)
                num_used = min(num_gens, num_available)
                boolean_mask = torch.tensor([True] * num_used + [False] * (num_available - num_used))

                if run_key == "sdlg":
                    # Importance weights: 1 for the first generation, the stored
                    # token likelihood for every further generation.
                    mask = torch.tensor([1] + [gen['token_likelihood'] for gen in generations[1:]])
                    assert torch.all(mask > 0) and torch.all(mask[1:] < 1), f"mask: {mask}"
                elif run_key == "baseline":
                    mask = boolean_mask
                # ----------

                # Kept as an alternative x-axis (average number of generations actually used).
                list_num_generations.append(torch.sum(boolean_mask).item())

                # The mask is a prefix mask, so the considered items are indices 0 .. num_used-1.
                all_considered_generations = [generations[m] for m in range(num_used)]
                all_considered_likelihoods = [run_results['likelihoods'][m] for m in range(num_used)]

                if pairs_dict['semantic_pairs'].shape[0] > 1:
                    semantic_pairs = pairs_dict['semantic_pairs'][boolean_mask, :][:, boolean_mask]
                    if compute_cleaned:
                        cleaned_semantic_pairs = pairs_dict['cleaned_semantic_pairs'][boolean_mask, :][:, boolean_mask]
                else:
                    # deal with single generations
                    assert boolean_mask.item() == True, f"mask: {boolean_mask}"
                    semantic_pairs = pairs_dict['semantic_pairs']
                    if compute_cleaned:
                        cleaned_semantic_pairs = pairs_dict['cleaned_semantic_pairs']

                # compute symmetric adjacency matrix (keep a pair only if it holds in both directions)
                semantic_pairs = semantic_pairs & semantic_pairs.T
                assert np.array_equal(semantic_pairs, semantic_pairs.T)
                if compute_cleaned:
                    cleaned_semantic_pairs = cleaned_semantic_pairs & cleaned_semantic_pairs.T
                    assert np.array_equal(cleaned_semantic_pairs, cleaned_semantic_pairs.T)

                # compute semantic clusters
                semantic_difference = compute_semantic_clusters(generations=all_considered_generations,
                                                                cleaned_semantic_pairs=cleaned_semantic_pairs if compute_cleaned else None,
                                                                semantic_pairs=semantic_pairs,
                                                                compute_cleaned=compute_cleaned)
                list_num_semantic_clusters.append(torch.unique(semantic_difference["semantic_clusters"]).shape[0])

                # compute semantic entropy
                # baseline: a weight of 1 per considered generation;
                # sdlg: L1-normalised importance weights of the considered generations.
                if run_key == "baseline":
                    weights = boolean_mask[boolean_mask].to(torch.float32)
                else:
                    weights = torch.nn.functional.normalize(mask[boolean_mask].to(torch.float32), p=1, dim=0)

                all_semantic_entropies.append(
                    compute_semantic_entropy(weights=weights,
                                             mc_estimate_over_clusters=False,
                                             neg_log_likelihoods=all_considered_likelihoods,
                                             semantic_difference=semantic_difference,
                                             compute_cleaned=compute_cleaned))

                if run_key == "baseline":
                    # Second estimate with mc_estimate_over_clusters=True; it is
                    # plotted as "(incorrect estimate)" for comparison.
                    all_semantic_entropies_kuhn.append(
                        compute_semantic_entropy(weights=weights,
                                                 mc_estimate_over_clusters=True,
                                                 neg_log_likelihoods=all_considered_likelihoods,
                                                 semantic_difference=semantic_difference,
                                                 compute_cleaned=compute_cleaned))

            for correctness_threshold in correctness_threshold_list:
                # Label 1 = wrong answer (score below threshold), label 0 = correct answer.
                list_correct_labels = torch.logical_not(correctness_scores >= correctness_threshold)
                aurocs, aurocs_kuhn = {}, {}
                for key in auroc_keys:
                    aurocs[key] = roc_auc_score(list_correct_labels,
                                                [d[key] for d in all_semantic_entropies])

                    if run_key == "baseline":
                        aurocs_kuhn[key] = roc_auc_score(list_correct_labels,
                                                         [d[key] for d in all_semantic_entropies_kuhn])

                if plot_results:
                    plot_dict[correctness_threshold]['plot_data_x'].append(num_gens) # sum(list_num_generations) / len(list_num_generations))
                    plot_dict[correctness_threshold]['plot_data_y'].append(aurocs)

                    if run_key == "baseline":
                        plot_dict_kuhn[correctness_threshold]['plot_data_x'].append(num_gens)
                        plot_dict_kuhn[correctness_threshold]['plot_data_y'].append(aurocs_kuhn)

                if log_results:
                    # newline='' is what the csv module expects for file objects;
                    # without it, extra blank rows appear on platforms that
                    # translate line endings.
                    with open(os.path.join(basepath, 'auroc_results.csv'), 'a', newline='') as f:
                        writer = csv.writer(f)
                        writer.writerow([run_id,
                                         run_key,
                                         num_gens,
                                         correctness_threshold,
                                         "importance_sampling",
                                         aurocs["normalised_semantic_entropy"],
                                         aurocs["unnormalised_semantic_entropy"]])
                        if run_key == "baseline":
                            writer.writerow([run_id,
                                             run_key,
                                             num_gens,
                                             correctness_threshold,
                                             "mc_estimate_over_clusters",
                                             aurocs_kuhn["normalised_semantic_entropy"],
                                             aurocs_kuhn["unnormalised_semantic_entropy"]])

                    with open(os.path.join(basepath, 'avg_num_clusters_found.csv'), 'a', newline='') as f:
                        writer = csv.writer(f)
                        num_clusters = torch.tensor(list_num_semantic_clusters)
                        clusters_correct_answer = num_clusters[list_correct_labels == 0]
                        clusters_wrong_answer = num_clusters[list_correct_labels == 1]
                        # Labels are 1 for wrong answers, so their sum counts the
                        # WRONG answers. Each count is written next to the mean it
                        # belongs to. NOTE: rows appended by earlier versions of
                        # this cell have these two count columns swapped.
                        num_wrong = list_correct_labels.sum().item()
                        num_correct = len(list_correct_labels) - num_wrong
                        writer.writerow([run_id, run_key, num_gens, correctness_threshold,
                                         sum(list_num_semantic_clusters) / len(list_num_semantic_clusters),
                                         torch.mean(clusters_correct_answer.float()).item(), num_correct,
                                         torch.mean(clusters_wrong_answer.float()).item(), num_wrong])

        if plot_results:
            for correctness_threshold in correctness_threshold_list:
                fig_entry = fig_dict[correctness_threshold]
                curve = plot_dict[correctness_threshold]
                curve_kuhn = plot_dict_kuhn[correctness_threshold]

                for k, key in enumerate(auroc_keys):
                    ax = fig_entry["axs"][k]
                    ax.plot(curve['plot_data_x'],
                            [d[key] for d in curve['plot_data_y']],
                            c=f"C{color_count}",
                            alpha=0.8,
                            marker="o" if run_key=="sdlg" else "^",
                            label=f"{run_id} ({run_key})")

                    if run_key == "baseline":
                        ax.plot(curve_kuhn['plot_data_x'],
                                [d[key] for d in curve_kuhn['plot_data_y']],
                                c=f"C{color_count}",
                                alpha=0.4,
                                marker="v",
                                label=f"{run_id} (incorrect estimate)")

                    ax.set_xlabel('# generations')
                    ax.set_ylabel('AUROC')
                    ax.set_xlim([1, num_total_gens+1])
                    ax.set_xticks(range(2, num_total_gens+1, 2))
                    ax.set_ylim([0.5, 1])
                    ax.legend()
                    ax.set_title(key)

                # "baseline" is the last channel drawn for a run id, so the figure
                # is complete here; finalise and save it once per threshold
                # instead of once per AUROC key.
                if run_key == "baseline":
                    fig_entry["fig"].suptitle(f'Correctness threshold: {correctness_threshold}', fontsize=12)
                    fig_entry["fig"].tight_layout()
                    fig_entry["fig"].savefig(os.path.join(basepath, f'{model_type}_{dataset_size}samples_{correctness_threshold}threshold.png'))

        # A new colour for every (run_id, run_key) curve.
        color_count += 1
        