In [None]:
import os

# Read the OpenAI API key from the environment (set it in ~/.bashrc or with
# `export OPENAI_API_KEY=...`) and stop here with a clear error if it is missing.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    # `not` rejects both an unset variable (None) and an exported-but-empty one ("").
    raise ValueError("OPENAI_API_KEY environment variable is not set. Please set it before running this notebook.")
print("✓ OpenAI API 密钥已从环境变量加载")

✓ OpenAI API 密钥已从代码中加载


In [2]:
# !pip install openai transformers datasets accelerate torch pandas pyarrow tqdm
# 'accelerate' 是为了更快地加载和运行模型
# 'pyarrow' 是为了将 DataFrame 保存为 parquet 格式
# 'openai' 用于 GPT-3.5 API

import torch
import numpy as np
import pandas as pd
from datasets import load_dataset
from openai import OpenAI
from torch.nn.functional import softmax
from tqdm import tqdm
import os

# Choose the compute device: CUDA when a GPU is visible, CPU otherwise.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(f"Using device: {DEVICE}")

# OpenAI client; the key is taken from the OPENAI_API_KEY environment variable.
client = OpenAI(
    api_key=os.getenv("OPENAI_API_KEY"),
)

Using device: cuda


In [3]:
# Chat model queried through the OpenAI API; no model weights are loaded locally.
# Requests rely on the OPENAI_API_KEY environment variable being set.
MODEL_ID = "gpt-3.5-turbo"

print(f"Using model: {MODEL_ID} via OpenAI API", "Model ready to use.", sep="\n")

Using model: gpt-3.5-turbo via OpenAI API
Model ready to use.


In [4]:
# Discover every MMLU subject config, drop the aggregate 'all' config, and load the
# 'test' split of the first NUM_SUBJECTS subjects.
from datasets import get_dataset_config_names

# How many subjects to annotate. Every question costs one API request in the scoring
# loop, so keep this small while iterating and raise it (e.g. to 20) for a full run.
NUM_SUBJECTS = 2

print("Fetching all available MMLU subjects...")
all_subjects = get_dataset_config_names("cais/mmlu")
print(f"Found {len(all_subjects)} subjects in MMLU dataset")

# 'all' is the concatenation of every subject, not a subject of its own.
filtered_subjects = [s for s in all_subjects if s != 'all']
print(f"After filtering out 'all', {len(filtered_subjects)} subjects remain")

SUBJECTS = filtered_subjects[:NUM_SUBJECTS]
print(f"Selected first {len(SUBJECTS)} subjects for processing:")
print(f"Subjects: {SUBJECTS}")

# subject name -> loaded 'test' split
mmlu_data = {}

for subject in SUBJECTS:
    # Use the 'test' split; its rows carry the gold 'answer' index used for scoring.
    try:
        dataset = load_dataset("cais/mmlu", subject, split="test")
    except Exception as e:
        # Best effort: a subject that fails to load is reported and skipped.
        print(f"Failed to load {subject}: {e}")
        continue
    mmlu_data[subject] = dataset
    print(f"Loaded {len(dataset)} questions for subject: {subject}")

print(f"\nSuccessfully loaded {len(mmlu_data)} subjects out of {len(SUBJECTS)} selected subjects")

# Answer letters. Prompting and scoring use only as many letters as a question has
# options (the recorded run averaged 4 options per question).
CHOICES = ["A", "B", "C", "D", "E", "F"]

Fetching all available MMLU subjects...
Found 59 subjects in MMLU dataset
After filtering out 'all', 58 subjects remain
Selected first 2 subjects for processing:
Subjects: ['abstract_algebra', 'anatomy']
Loaded 100 questions for subject: abstract_algebra
Loaded 135 questions for subject: anatomy

Successfully loaded 2 subjects out of 2 selected subjects


In [5]:
def format_mmlu_prompt(sample, subject_name):
    """Build a zero-shot, direct-answer multiple-choice prompt for one MMLU row.

    The prompt names the subject (underscores shown as spaces), lists each option as
    "<letter>. <text>" on its own line using the letters in CHOICES, and ends with
    "Answer:" so the next generated token should be the chosen letter. No
    chain-of-thought reasoning is requested.

    Args:
        sample: MMLU row with 'question' (str) and 'choices' (list of str) fields.
        subject_name: MMLU config name, e.g. 'abstract_algebra'.

    Returns:
        The prompt string.
    """
    readable_subject = subject_name.replace("_", " ")
    question_text = sample['question']
    option_block = "".join(
        f"{CHOICES[pos]}. {option}\n" for pos, option in enumerate(sample['choices'])
    )

    header = (
        f"The following is a multiple-choice question about {readable_subject}. "
        "Please choose the single most likely answer."
    )
    # option_block already ends with "\n", so a blank line precedes "Answer:".
    return f"{header}\n\nQuestion: {question_text}\n{option_block}\nAnswer:"

def _request_first_token_top_logprobs(prompt, model_id, client, max_retries=3, retry_delay=1):
    """Request one token from the chat API and return its top-20 alternatives.

    Failed calls are retried with exponential backoff (retry_delay * 2**attempt seconds).

    Returns:
        A list of (token, logprob) pairs; [] when the response carries no logprobs;
        None when every attempt raised or the client returned None.
    """
    import time

    for attempt in range(max_retries):
        try:
            response = client.chat.completions.create(
                model=model_id,
                messages=[{"role": "user", "content": prompt}],
                logprobs=True,    # return log-probabilities of the generated token
                top_logprobs=20,  # ... plus the 20 most likely alternatives
                max_tokens=1,     # only the answer letter is needed
                temperature=0,    # deterministic output
            )
        except Exception as e:
            if attempt == max_retries - 1:
                print(f"Error calling OpenAI API after {max_retries} attempts: {e}")
                return None
            wait_time = retry_delay * (2 ** attempt)
            if "rate limit" in str(e).lower() or "429" in str(e):
                print(f"Rate limit hit, waiting {wait_time}s before retry {attempt + 1}/{max_retries}...")
            time.sleep(wait_time)
            continue

        if response is None:
            return None
        logprobs = response.choices[0].logprobs
        if logprobs and logprobs.content:
            return [(item.token, item.logprob) for item in logprobs.content[0].top_logprobs]
        return []
    return None  # only reached when max_retries < 1


def get_choice_probabilities(prompt, model_id, client, num_choices=None, labels=None):
    """Estimate the model's probability distribution over the answer letters.

    Sends `prompt` as a single user message, asks for one token with its top-20
    log-probabilities, and maps those candidates onto the answer letters. A candidate
    token counts toward a letter when its alphabetic characters are exactly that letter
    (case-sensitive), so " A", "A", "A." and "(A" all add to A. Their probabilities are
    summed, because the model may spread the same answer across several surface forms.
    Lowercase tokens are not matched, because "a" is far more often the English article
    than choice A. The per-letter masses are then renormalized to sum to 1.

    Args:
        prompt: Prompt text.
        model_id: Chat model name passed to the API.
        client: OpenAI client (any object exposing `chat.completions.create`).
        num_choices: Number of answer options; defaults to len(labels).
        labels: Answer letters; defaults to the module-level CHOICES.

    Returns:
        np.ndarray of length num_choices summing to 1. A uniform distribution is returned
        when the API call fails after all retries or when no letter appears among the
        top-20 tokens. Letters absent from the top-20 get probability 0.

    Raises:
        ValueError: if num_choices is not between 1 and len(labels).
    """
    if labels is None:
        labels = CHOICES
    if num_choices is None:
        num_choices = len(labels)
    if not 1 <= num_choices <= len(labels):
        raise ValueError(f"num_choices must be between 1 and {len(labels)}, got {num_choices}")
    letters = list(labels[:num_choices])
    uniform = np.ones(num_choices) / num_choices

    top_tokens = _request_first_token_top_logprobs(prompt, model_id, client)
    if top_tokens is None:
        return uniform

    # Accumulate log-mass per letter; -inf means "not seen among the top tokens".
    index_of = {letter: k for k, letter in enumerate(letters)}
    log_mass = np.full(num_choices, -np.inf)
    for token, logprob in top_tokens:
        letter = ''.join(c for c in token if c.isalpha())
        k = index_of.get(letter)
        if k is not None:
            log_mass[k] = np.logaddexp(log_mass[k], logprob)

    found = np.isfinite(log_mass)
    if not found.all():
        missing = [letter for letter, ok in zip(letters, found) if not ok]
        print(f"Warning: Could not find logprob for choice token(s) {missing}")
        print(f"Available tokens (top 10): {[tok for tok, _ in top_tokens[:10]]}")
    if not found.any():
        return uniform

    # Shift by the max before exponentiating for numerical stability; exp(-inf) -> 0.
    weights = np.exp(log_mass - log_mass.max())
    return weights / weights.sum()

In [6]:
def calculate_aps_score(probs, choice_index, tol=1e-9):
    """Compute the APS (Adaptive Prediction Sets) non-conformity score of one candidate answer.

        S(X, y) = sum of P_j over every option j with P_j >= P_y

    This is the total probability mass of the options ranked at least as high as y,
    including y itself. The most likely option scores its own probability, and the least
    likely scores the full mass (1 when probs sum to 1). A larger score means y conforms
    less to the model's prediction.

    Args:
        probs: 1-D array-like of option probabilities, e.g. [P(A), P(B), P(C), P(D)].
        choice_index: Index of the option being scored (0=A, 1=B, ...).
        tol: Slack on the >= comparison, so that options whose probability equals P_y up
            to floating-point noise are treated as ties and included.

    Returns:
        float: the APS score.
    """
    probs = np.asarray(probs, dtype=float)
    prob_y = probs[choice_index]
    # Apply the tolerance in the filter itself so near-ties are counted.
    return float(probs[probs >= prob_y - tol].sum())

# --- Quick sanity check of the APS scoring function ---
# APS sums the probabilities of every option at least as likely as the scored one:
#   S(A) = 0.42
#   S(B) = 0.42 + 0.40 = 0.82
#   S(C) = 0.82 + 0.10 = 0.92
#   S(D) = 0.92 + 0.05 = 0.97
#   S(E) = 0.97 + 0.02 = 0.99
#   S(F) = 0.99 + 0.01 = 1.00
test_probs = np.array([0.42, 0.40, 0.10, 0.05, 0.02, 0.01])
expected_scores = [0.42, 0.82, 0.92, 0.97, 0.99, 1.0]

for idx, letter in enumerate("ABCDEF"):
    print(f"Test S({letter}): {calculate_aps_score(test_probs, idx)}")

# Fail loudly if the scores drift from the hand-computed values.
np.testing.assert_allclose(
    [calculate_aps_score(test_probs, idx) for idx in range(len(test_probs))],
    expected_scores,
)


Test S(A): 0.42
Test S(B): 0.8200000000000001
Test S(C): 0.92
Test S(D): 0.9700000000000001
Test S(E): 0.9900000000000001
Test S(F): 1.0


In [7]:
# One record per (question, answer option); the next cell turns this into a DataFrame.
results_list = []

for subj, subj_questions in mmlu_data.items():
    print(f"\nProcessing subject: {subj}...")

    for q_idx, q in enumerate(tqdm(subj_questions)):
        prompt_text = format_mmlu_prompt(q, subj)
        n_opts = len(q['choices'])  # questions may list different numbers of options

        # A failure on one question is reported and that question is skipped.
        try:
            probs = get_choice_probabilities(prompt_text, MODEL_ID, client, num_choices=n_opts)
        except Exception as e:
            print(f"Error processing question {q_idx}: {e}")
            continue

        gold_idx = q['answer']  # index of the correct option, 0 .. n_opts-1

        # Score every option, not just the gold one, so each row carries its own APS value.
        results_list.extend(
            {
                "question_id": f"{subj}_{q_idx}",
                "subject": subj,
                "question": q['question'],
                "choice_str": CHOICES[k],
                "choice_index": k,
                "choice_text": q['choices'][k],
                "probability": probs[k],
                "aps_score": calculate_aps_score(probs, k),
                "is_ground_truth": (k == gold_idx),
            }
            for k in range(n_opts)
        )

print("\nAll processing complete.")


Processing subject: abstract_algebra...


100%|██████████| 100/100 [01:00<00:00,  1.65it/s]



Processing subject: anatomy...


100%|██████████| 135/135 [01:18<00:00,  1.71it/s]


All processing complete.





In [8]:
# Collect the per-(question, option) records into one DataFrame.
df_scores = pd.DataFrame(results_list)

# Save as Parquet (typed and more compact than CSV). Writing falls back through other
# engines and finally to CSV when PyArrow hits its filesystem-registration error.
import os
import pyarrow as pa
import pyarrow.parquet as pq

output_filename = os.path.abspath("mmlu_with_aps_scores.parquet")


def _save_scores(df, parquet_path):
    """Write `df` to `parquet_path`, falling back to CSV if Parquet writing fails.

    Attempts, in order: pyarrow `write_table`; then, only for the PyArrow registry error
    ("ArrowKeyError" in the type name or "already registered" in the message), pandas
    with engine='pyarrow', pandas with engine='fastparquet', and finally a CSV with the
    same base name. Any other error from the first attempt is re-raised.

    Returns:
        The path of the file actually written.
    """
    try:
        pq.write_table(pa.Table.from_pandas(df, preserve_index=False), parquet_path)
        return parquet_path
    except Exception as e:
        if "ArrowKeyError" not in type(e).__name__ and "already registered" not in str(e):
            raise

    # Swap only the extension; str.replace would also rewrite a '.parquet' inside directory names.
    csv_path = os.path.splitext(parquet_path)[0] + ".csv"
    try:
        df.to_parquet(parquet_path, index=False, engine='pyarrow')
        return parquet_path
    except Exception:
        # Deliberate fallback to fastparquet; catching Exception (not a bare except)
        # keeps Ctrl-C able to interrupt the cell.
        pass
    try:
        df.to_parquet(parquet_path, index=False, engine='fastparquet')
        return parquet_path
    except ImportError:
        print(f"Warning: PyArrow filesystem registry error. Saving as CSV instead: {csv_path}")
    except Exception as e2:
        print(f"Warning: Could not save as Parquet ({e2}). Saving as CSV instead: {csv_path}")
    df.to_csv(csv_path, index=False)
    return csv_path


output_filename = _save_scores(df_scores, output_filename)

if df_scores.empty:
    # pd.DataFrame([]) has no 'question_id' column, so there is nothing to summarize.
    print(f"No rows were produced. Data saved to {output_filename}")
else:
    # Each question contributes one row per answer option.
    choices_per_question = df_scores.groupby('question_id').size()
    avg_choices = choices_per_question.mean()
    print(f"Successfully processed {len(df_scores)} rows ({len(choices_per_question)} questions).")
    print(f"Average number of choices per question: {avg_choices:.1f}")
    print(f"Data saved to {output_filename}")

    # --- Sanity-check the data ---
    print("\n--- DataFrame Head ---")
    print(df_scores.head(8))

    # Show the first question actually processed, rather than a hard-coded id from a
    # subject that may not have been loaded.
    example_qid = df_scores['question_id'].iloc[0]
    print("\n--- Example: Scores for one question ---")
    print(df_scores[df_scores['question_id'] == example_qid][
        ['question_id', 'choice_str', 'probability', 'aps_score', 'is_ground_truth']
    ])

Successfully processed 940 rows (235 questions).
Average number of choices per question: 4.0
Data saved to /egr/research-hintlab/liuxin73/projects/conformal-factual-lm/ACI/MMLU copy/mmlu_with_aps_scores.csv

--- DataFrame Head ---
          question_id           subject  \
0  abstract_algebra_0  abstract_algebra   
1  abstract_algebra_0  abstract_algebra   
2  abstract_algebra_0  abstract_algebra   
3  abstract_algebra_0  abstract_algebra   
4  abstract_algebra_1  abstract_algebra   
5  abstract_algebra_1  abstract_algebra   
6  abstract_algebra_1  abstract_algebra   
7  abstract_algebra_1  abstract_algebra   

                                            question choice_str  choice_index  \
0  Find the degree for the given field extension ...          A             0   
1  Find the degree for the given field extension ...          B             1   
2  Find the degree for the given field extension ...          C             2   
3  Find the degree for the given field extension ...     