In [2]:
pip install jsonlines huggingface_hub

Looking in indexes: https://mirrors.aliyun.com/pypi/simple/
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.2[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [1]:
    import json
    import os
    from typing import List, Dict, Tuple
    import random
    from transformers import AutoTokenizer, AutoModelForCausalLM
    import torch
    from tqdm import tqdm
    import re
    from collections import defaultdict, Counter
    import time
    import logging
    import jsonlines
        
    # Report the runtime environment and, when CUDA is present, the GPU details
    print(f"PyTorch version: {torch.__version__}")
    print(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"CUDA version: {torch.version.cuda}")
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"Total GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
    
    # Select the device used for input tensors: first GPU if available, else CPU
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    # 1. Load the DeepSeek-R1-Distill-Llama-8B model and tokenizer from the local path below
    #    (this import repeats names already imported at the top of the cell)
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_path = "/mnt/workspace/dataroot/models/deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
    tokenizer = AutoTokenizer.from_pretrained(model_path,local_files_only=True)
    # Pad on the left so that, in batched generation, new tokens directly follow each prompt
    tokenizer.padding_side = 'left'
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,  # float16 weights to reduce GPU memory use
        device_map="cuda:0",  # place the whole model on GPU 0 (unlike `device`, no CPU fallback)
        trust_remote_code=True,
        local_files_only=True,
    )
    model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"  # hub id of the model; in this cell only the commented-out loader below refers to it
    # Earlier loader kept for reference (disabled):
    # tokenizer = AutoTokenizer.from_pretrained("/ds/", trust_remote_code=True)
    # tokenizer.padding_side = 'left'  # keep consistent with the Qwen2 setup
    # model = AutoModelForCausalLM.from_pretrained(
    #     model_name,
    #     torch_dtype=torch.float16,  # float16 weights to reduce GPU memory use
    #     device_map="cuda:0",  # place the whole model on GPU 0
    #     trust_remote_code=True,
    #     local_files_only=True
    # )
    # Switch to evaluation mode; the model is only used for inference here
    model.eval()
    print("Model loaded successfully")
    print(f"Model device: {next(model.parameters()).device}")

  from .autonotebook import tqdm as notebook_tqdm


PyTorch version: 2.5.1+cu124
CUDA available: True
CUDA version: 12.4
GPU: NVIDIA A10
Total GPU memory: 21.98 GB
Using device: cuda:0


Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.40s/it]

Model loaded successfully
Model device: cuda:0





In [1]:
#original
# Report the GPU memory already in use (skipped when CUDA is unavailable)
if torch.cuda.is_available():
    _GIB = 1024**3  # bytes per GiB
    print(f"GPU memory allocated: {torch.cuda.memory_allocated(0) / _GIB:.2f} GB")
    print(f"GPU memory reserved: {torch.cuda.memory_reserved(0) / _GIB:.2f} GB")

# 2. 构造提示函数
def create_zero_shot_prompt(passage: str, number: str) -> str:
    """Build the zero-shot Yes/No prompt asking whether `number` is an error in `passage`.

    The prompt is a one-line instruction plus question, followed by a bare
    "Answer:" line for the model to complete.
    """
    instruction = "Answer with only 'Yes' or 'No'. Do not provide explanations."
    question = f'Is "{number}" in the following passage an error? "{passage}"'
    return f"{instruction} {question}\nAnswer:"

def create_few_shot_prompt(passage: str, number: str) -> str:
    """Build the few-shot Yes/No prompt for one (passage, number) pair.

    Layout: one instruction line, four worked Question/Answer examples, then
    the target question followed by a bare "Answer:" for the model to complete.
    """
    # (passage, number, answer) demonstrations shown before the target question
    demonstrations = (
        ("Spiders have 9 limbs.", "9", "Yes"),
        ("Spiders have 8 limbs.", "8", "No"),
        ("Mike's height is -3.6 meters.", "-3.6", "Yes"),
        ("Mike's height is 1.8 meters.", "1.8", "No"),
    )
    question_template = 'Question: Is "{number}" in the following passage an error? "{passage}"\nAnswer:'

    parts = ["Answer with only 'Yes' or 'No'. Do not provide explanations.\n"]
    for demo_passage, demo_number, demo_answer in demonstrations:
        asked = question_template.format(number=demo_number, passage=demo_passage)
        parts.append(f"{asked} {demo_answer}\n")
    parts.append(question_template.format(number=number, passage=passage))
    return "".join(parts)

# 3. 加载 BeNEDect 数据集
def load_benedect_dataset(file_path: str) -> List[Dict]:
    """Load the BeNEDect JSON file and expand each sample into two prompt records.

    Every sample yields one record built from its correct passage (expected
    answer "No") and one built from its error passage (expected answer "Yes").
    Samples at index 0, 48, 96, ... use the few-shot prompt; all others use the
    zero-shot prompt. Samples missing a required field are reported and skipped.

    Args:
        file_path: Path to a JSON file whose top level is an object; its values
            are the sample dicts.

    Returns:
        List of record dicts with keys prompt, expected_answer, dataset,
        operation, error_annotation, passage, number and prompt_type.

    Raises:
        FileNotFoundError: If `file_path` does not exist.
        ValueError: If the file is not valid JSON.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"数据集文件 {file_path} 不存在，请确认路径！")
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            dataset_dict = json.load(f)
    except json.JSONDecodeError as e:
        raise ValueError(f"JSON 文件解析错误：{e}") from e

    dataset = list(dataset_dict.values())
    required_fields = ('correct_number', 'correct_passage', 'error_number', 'error_passage', 'dataset', 'operation')
    # (passage key, number key, expected answer) for the two records built per sample
    variants = (
        ('correct_passage', 'correct_number', "No"),
        ('error_passage', 'error_number', "Yes"),
    )
    save_list = []

    for i, data in enumerate(tqdm(dataset, desc="Processing dataset")):
        missing_fields = [field for field in required_fields if field not in data]
        if missing_fields:
            for field in missing_fields:
                print(f"样本 {data.get('id', '未知')} 缺少字段 {field}")
            # Skip the whole sample: building prompts from it would raise KeyError.
            continue

        # The index counts every sample, including skipped ones.
        is_few_shot = i % 48 == 0
        prompt_fn = create_few_shot_prompt if is_few_shot else create_zero_shot_prompt
        prompt_type = "few_shot" if is_few_shot else "zero_shot"

        for passage_key, number_key, expected_answer in variants:
            save_list.append({
                "prompt": prompt_fn(data[passage_key], data[number_key]),
                "expected_answer": expected_answer,
                "dataset": data['dataset'],
                "operation": data['operation'],
                "error_annotation": data.get('error_annotation', {}),
                "passage": data[passage_key],
                "number": data[number_key],
                "prompt_type": prompt_type,
            })

    return save_list

# 4. 单条推理
def predict_single(prompt: str, max_retries: int = 3) -> str:
    """Generate a short sampled completion for one prompt, retrying on RuntimeError.

    Uses the module-level `tokenizer`, `model` and `device`. The prompt is
    truncated to 512 tokens and at most 5 new tokens are sampled
    (temperature 0.3, top_p 0.5).

    Args:
        prompt: Full prompt text.
        max_retries: Maximum number of generation attempts.

    Returns:
        The decoded, stripped continuation, or "generation_error" if no
        attempt succeeded.
    """
    print(f"Single prediction prompt: {prompt[:100]}...")
    # Result reported when every attempt fails (also covers max_retries == 0)
    prediction = "generation_error"

    for attempt in range(1, max_retries + 1):
        inputs = outputs = None
        try:
            inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
            inputs = {k: v.to(device) for k, v in inputs.items()}
            print(f"Input device: {inputs['input_ids'].device}")
            print(f"Input shape: {inputs['input_ids'].shape}")

            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=5,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                    do_sample=True,
                    temperature=0.3,
                    top_p=0.5
                )

            # Keep only the tokens generated after the prompt
            prompt_length = inputs['input_ids'].shape[1]
            prediction = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()
            print(f"Single prediction success: Raw Prediction: {prediction}")
            break

        except RuntimeError as e:
            print(f"单条推理失败（尝试 {attempt}/{max_retries}）：{e}")
            torch.cuda.empty_cache()
            time.sleep(1)
            if attempt == max_retries:
                print("单条推理失败，跳过")

        finally:
            # Drop the tensor references first: empty_cache() can only release
            # cached blocks that no live tensor occupies.
            inputs = outputs = None
            torch.cuda.empty_cache()
            print(f"GPU memory after single prediction: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")

    return prediction

# 5. 批次推理
def predict_batch(prompts: List[str], batch_size: int = 8, max_retries: int = 3) -> List[str]:
    """Generate short sampled completions for all prompts, batch by batch.

    Uses the module-level `tokenizer`, `model` and `device`. Each batch is
    padded, truncated to 512 tokens and sampled for at most 5 new tokens
    (temperature 0.3, top_p 0.5). A batch is retried on RuntimeError.

    Args:
        prompts: Prompt texts, processed in order.
        batch_size: Number of prompts per generate call.
        max_retries: Maximum number of attempts per batch.

    Returns:
        One decoded, stripped continuation per prompt, in input order. Every
        prompt of a batch that never succeeded gets "generation_error".
    """
    predictions = []

    for i in tqdm(range(0, len(prompts), batch_size), desc="Predicting"):
        batch_prompts = prompts[i:i + batch_size]
        batch_index = i // batch_size
        # Result reported when every attempt fails (also covers max_retries == 0)
        batch_preds = ["generation_error"] * len(batch_prompts)

        for attempt in range(1, max_retries + 1):
            inputs = outputs = None
            try:
                # Pad to a common length so the batch forms one rectangular tensor
                inputs = tokenizer(
                    batch_prompts,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=512,
                    return_attention_mask=True
                )
                inputs = {k: v.to(device) for k, v in inputs.items()}
                print(f"Batch {batch_index} input shapes: input_ids={inputs['input_ids'].shape}, attention_mask={inputs['attention_mask'].shape}")

                with torch.no_grad():
                    outputs = model.generate(
                        **inputs,
                        max_new_tokens=5,
                        pad_token_id=tokenizer.pad_token_id,
                        eos_token_id=tokenizer.eos_token_id,
                        do_sample=True,
                        temperature=0.3,
                        top_p=0.5
                    )

                # Every row shares the same padded prompt length; keep what follows it
                prompt_length = inputs['input_ids'].shape[1]
                batch_preds = [
                    tokenizer.decode(output[prompt_length:], skip_special_tokens=True).strip()
                    for output in outputs
                ]
                # Show raw predictions for every 10th batch only
                if i % (10 * batch_size) == 0:
                    print(f"批次 {batch_index}: 原始预测: {batch_preds}")
                break

            except RuntimeError as e:
                print(f"批次推理失败（尝试 {attempt}/{max_retries}）：{e}")
                torch.cuda.empty_cache()
                time.sleep(1)
                if attempt == max_retries:
                    print(f"批次 {batch_index} 推理失败，跳过")

            finally:
                # Drop the tensor references first: empty_cache() can only
                # release cached blocks that no live tensor occupies.
                inputs = outputs = None
                torch.cuda.empty_cache()
                print(f"GPU memory after batch {batch_index}: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")

        predictions.extend(batch_preds)

    return predictions

# 6. 保存预测结果到 JSONL
def save_predictions_to_jsonl(data: List[Dict], predictions: List[str], output_file: str):
    """Write one JSON line per (record, prediction) pair to `output_file`.

    Records and predictions are paired positionally; pairing stops at the
    shorter of the two lists. The file is opened in write mode.
    """
    leading_keys = ("prompt", "passage", "number", "expected_answer")
    trailing_keys = ("dataset", "operation", "error_annotation", "prompt_type")

    with jsonlines.open(output_file, mode='w') as writer:
        for item, pred in zip(data, predictions):
            # Same field order as before: record fields, raw prediction, metadata
            row = {key: item[key] for key in leading_keys}
            row["raw_prediction"] = pred
            for key in trailing_keys:
                row[key] = item[key]
            writer.write(row)

NameError: name 'torch' is not defined

In [30]:
# Report the GPU memory already in use (skipped when CUDA is unavailable)
if torch.cuda.is_available():
    _GIB = 1024**3  # bytes per GiB
    print(f"GPU memory allocated: {torch.cuda.memory_allocated(0) / _GIB:.2f} GB")
    print(f"GPU memory reserved: {torch.cuda.memory_reserved(0) / _GIB:.2f} GB")

# 2. 构造提示函数
"""
log:

prompt changes:
### ROLE. You are an expert fact-checker specializing in numerical accuracy across biology, physics, history, mathematics, and everyday scenarios. ### TASK. Determine if the given number contains a factual error within the provided context. ### ANALYSIS PROCESS. Follow this reasoning sequence: 1. CONTEXT ANALYSIS: Identify the domain and type of measurement being described 2. GENERATED KNOWLEDGE: Recall established facts, typical ranges, and known standards for this specific domain and measurement type 3. PLAUSIBILITY CHECK: Compare the number against expected ranges, physical laws, biological constraints, historical accuracy, and mathematical consistency 4. ERROR DETECTION: Check for biological impossibilities, physical violations, historical inaccuracies, mathematical contradictions, or scale/magnitude errors. ### OUTPUT. Answer must only be "yes" or "no",do not provide any explanations and anything other than "yes" or "no". Yes = Contains factual error, No = Factually accurate. ### EVALUATION. Number to evaluate: "{number}",Passage: "{passage}". Is "{number}" in the following passage an error? "{passage}"
Answer:
### ROLE. You are an expert fact-checker specializing in numerical accuracy across biology, physics, history, mathematics, and everyday scenarios. ### TASK. Determine if the given number contains a factual error within the provided context. ### ANALYSIS PROCESS. Follow this reasoning sequence: 1. CONTEXT ANALYSIS: Identify the domain and type of measurement being described 2. GENERATED KNOWLEDGE: Recall established facts, typical ranges, and known standards for this specific domain and measurement type 3. PLAUSIBILITY CHECK: Compare the number against expected ranges, physical laws, biological constraints, historical accuracy, and mathematical consistency 4. ERROR DETECTION: Check for biological impossibilities, physical violations, historical inaccuracies, mathematical contradictions, or scale/magnitude errors. ### OUTPUT. Answer must only be "Yes" or "No",do not provide explanations. Yes = Contains factual error, No = Factually accurate. ### EVALUATION. Number to evaluate: "{number}",Passage: "{passage}". Is "{number}" in the following passage an error? "{passage}"
Answer:
Overall Metrics:
TN: 896 (0.187)
FN: 1036 (0.216)
FP: 1504 (0.313)
TP: 1364 (0.284)
Generation Error: 889 (0.185)
Accuracy: 0.471


### ROLE. You are an expert fact-checker specializing in numerical accuracy across biology, physics, history, mathematics, and everyday scenarios. ### TASK. Determine if the given number contains a factual error within the provided context. ### ANALYSIS PROCESS. Follow this reasoning sequence: 1. CONTEXT ANALYSIS: Identify the domain and type of measurement being described 2. GENERATED KNOWLEDGE: Recall established facts, typical ranges, and known standards for this specific domain and measurement type 3. PLAUSIBILITY CHECK: Compare the number against expected ranges, physical laws, biological constraints, historical accuracy, and mathematical consistency 4. ERROR DETECTION: Check for biological impossibilities, physical violations, historical inaccuracies, mathematical contradictions, or scale/magnitude errors. ### OUTPUT FORMAT. Answer must only be "yes" or "no". Do not provide any any explanations. ### EVALUATION. Is "{number}" in the following passage an error? "{passage}"?
Answer:
### ROLE. You are an expert fact-checker specializing in numerical accuracy across biology, physics, history, mathematics, and everyday scenarios. ### TASK. Determine if the given number contains a factual error within the provided context. ### ANALYSIS PROCESS. Follow this reasoning sequence: 1. CONTEXT ANALYSIS: Identify the domain and type of measurement being described 2. GENERATED KNOWLEDGE: Recall established facts, typical ranges, and known standards for this specific domain and measurement type 3. PLAUSIBILITY CHECK: Compare the number against expected ranges, physical laws, biological constraints, historical accuracy, and mathematical consistency 4. ERROR DETECTION: Check for biological impossibilities, physical violations, historical inaccuracies, mathematical contradictions, or scale/magnitude errors. ### OUTPUT FORMAT. Answer must only be "Yes" or "No". Do not provide any explanations. ### EVALUATION. Is "{number}" in the following passage an error? "{passage}"?
Answer:
Overall Metrics:
TN: 733 (0.153)
FN: 850 (0.177)
FP: 1667 (0.347)
TP: 1550 (0.323)
Generation Error: 371 (0.077)
Accuracy: 0.476

majority voting:
Overall Metrics:
TN: 2059 (0.429)
FN: 1748 (0.364)
TP: 652 (0.136)
FP: 341 (0.071)
Accuracy: 0.565

"""






def create_zero_shot_prompt(passage: str, number: str) -> str:
    """Build the structured zero-shot fact-checking prompt for one (passage, number) pair.

    The prompt is a single line of "###"-headed sections (role, task, analysis
    process, output format, evaluation question) followed by a bare "Answer:"
    line for the model to complete.
    """
    # Wording is reproduced exactly, since the prompt text is what the model sees.
    sections = [
        "### ROLE. You are an expert fact-checker specializing in numerical accuracy across biology, physics, history, mathematics, and everyday scenarios.",
        "### TASK. Determine if the given number contains a factual error within the provided context.",
        "### ANALYSIS PROCESS. Follow this reasoning sequence: 1. CONTEXT ANALYSIS: Identify the domain and type of measurement being described 2. GENERATED KNOWLEDGE: Recall established facts, typical ranges, and known standards for this specific domain and measurement type 3. PLAUSIBILITY CHECK: Compare the number against expected ranges, physical laws, biological constraints, historical accuracy, and mathematical consistency 4. ERROR DETECTION: Check for biological impossibilities, physical violations, historical inaccuracies, mathematical contradictions, or scale/magnitude errors.",
        '### OUTPUT FORMAT. Answer must only be "yes" or "no". Do not provide any any explanations.',
        f'### EVALUATION. Is "{number}" in the following passage an error? "{passage}"?',
    ]
    return " ".join(sections) + "\nAnswer:"

def create_few_shot_prompt(passage: str, number: str) -> str:
    """Build the structured few-shot fact-checking prompt for one (passage, number) pair.

    Layout: a single header line of "###"-headed instruction sections, four
    worked Question/Answer examples, then the target question followed by a
    bare "Answer:" for the model to complete.

    The target question appears exactly once, after the examples. Asking it in
    the header as well would put an "Answer:" cue before the examples.
    """
    examples = [
        {"passage": "Spiders have 9 limbs.", "number": "9", "answer": "Yes"},
        {"passage": "Spiders have 8 limbs.", "number": "8", "answer": "No"},
        {"passage": "Mike's height is -3.6 meters.", "number": "-3.6", "answer": "Yes"},
        {"passage": "Mike's height is 1.8 meters.", "number": "1.8", "answer": "No"}
    ]
    header_sections = [
        "### ROLE. You are an expert fact-checker specializing in numerical accuracy across biology, physics, history, mathematics, and everyday scenarios.",
        "### TASK. Determine if the given number contains a factual error within the provided context.",
        "### ANALYSIS PROCESS. Follow this reasoning sequence: 1. CONTEXT ANALYSIS: Identify the domain and type of measurement being described 2. GENERATED KNOWLEDGE: Recall established facts, typical ranges, and known standards for this specific domain and measurement type 3. PLAUSIBILITY CHECK: Compare the number against expected ranges, physical laws, biological constraints, historical accuracy, and mathematical consistency 4. ERROR DETECTION: Check for biological impossibilities, physical violations, historical inaccuracies, mathematical contradictions, or scale/magnitude errors.",
        '### OUTPUT FORMAT. Answer must only be "Yes" or "No". Do not provide any explanations.',
        "### EVALUATION.",
    ]
    # Instructions only; the newline keeps the first example on its own line.
    prompt = " ".join(header_sections) + "\n"
    for ex in examples:
        prompt += (
            f"Question: Is \"{ex['number']}\" in the following passage an error? \"{ex['passage']}\"\n"
            f"Answer: {ex['answer']}\n"
        )
    prompt += f"Question: Is \"{number}\" in the following passage an error? \"{passage}\"\nAnswer:"
    return prompt

# 3. 加载 BeNEDect 数据集
def load_benedect_dataset(file_path: str) -> List[Dict]:
    """Load the first half of the BeNEDect JSON file as prompt records.

    Only the first `len(dataset) // 2` samples are used. Every used sample
    yields one record built from its correct passage (expected answer "No")
    and one built from its error passage (expected answer "Yes"). Samples at
    index 0, 48, 96, ... use the few-shot prompt; all others use the zero-shot
    prompt. Samples missing a required field are reported and skipped.

    Args:
        file_path: Path to a JSON file whose top level is an object; its values
            are the sample dicts.

    Returns:
        List of record dicts with keys prompt, expected_answer, dataset,
        operation, error_annotation, passage, number and prompt_type.

    Raises:
        FileNotFoundError: If `file_path` does not exist.
        ValueError: If the file is not valid JSON.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"数据集文件 {file_path} 不存在，请确认路径！")
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            dataset_dict = json.load(f)
    except json.JSONDecodeError as e:
        raise ValueError(f"JSON 文件解析错误：{e}") from e

    dataset = list(dataset_dict.values())
    # NOTE: deliberately restricted to the first half of the samples in this cell
    dataset = dataset[:len(dataset) // 2]
    required_fields = ('correct_number', 'correct_passage', 'error_number', 'error_passage', 'dataset', 'operation')
    # (passage key, number key, expected answer) for the two records built per sample
    variants = (
        ('correct_passage', 'correct_number', "No"),
        ('error_passage', 'error_number', "Yes"),
    )
    save_list = []

    for i, data in enumerate(tqdm(dataset, desc="Processing dataset")):
        missing_fields = [field for field in required_fields if field not in data]
        if missing_fields:
            for field in missing_fields:
                print(f"样本 {data.get('id', '未知')} 缺少字段 {field}")
            # Skip the whole sample: building prompts from it would raise KeyError.
            continue

        # The index counts every sample, including skipped ones.
        is_few_shot = i % 48 == 0
        prompt_fn = create_few_shot_prompt if is_few_shot else create_zero_shot_prompt
        prompt_type = "few_shot" if is_few_shot else "zero_shot"

        for passage_key, number_key, expected_answer in variants:
            save_list.append({
                "prompt": prompt_fn(data[passage_key], data[number_key]),
                "expected_answer": expected_answer,
                "dataset": data['dataset'],
                "operation": data['operation'],
                "error_annotation": data.get('error_annotation', {}),
                "passage": data[passage_key],
                "number": data[number_key],
                "prompt_type": prompt_type,
            })

    return save_list

# 4. 单条推理
def predict_single(prompt: str, max_retries: int = 3) -> str:
    """Generate a short sampled completion for one prompt, retrying on RuntimeError.

    Uses the module-level `tokenizer`, `model` and `device`. The prompt is
    truncated to 512 tokens and at most 5 new tokens are sampled
    (temperature 0.3, top_p 0.8).

    Args:
        prompt: Full prompt text.
        max_retries: Maximum number of generation attempts.

    Returns:
        The decoded, stripped continuation, or "generation_error" if no
        attempt succeeded.
    """
    print(f"Single prediction prompt: {prompt[:100]}...")
    # Result reported when every attempt fails (also covers max_retries == 0)
    prediction = "generation_error"

    for attempt in range(1, max_retries + 1):
        inputs = outputs = None
        try:
            inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
            inputs = {k: v.to(device) for k, v in inputs.items()}
            print(f"Input device: {inputs['input_ids'].device}")
            print(f"Input shape: {inputs['input_ids'].shape}")

            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=5,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                    do_sample=True,
                    temperature=0.3,
                    top_p=0.8
                )

            # Keep only the tokens generated after the prompt
            prompt_length = inputs['input_ids'].shape[1]
            prediction = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()
            print(f"Single prediction success: Raw Prediction: {prediction}")
            break

        except RuntimeError as e:
            print(f"单条推理失败（尝试 {attempt}/{max_retries}）：{e}")
            torch.cuda.empty_cache()
            time.sleep(1)
            if attempt == max_retries:
                print("单条推理失败，跳过")

        finally:
            # Drop the tensor references first: empty_cache() can only release
            # cached blocks that no live tensor occupies.
            inputs = outputs = None
            torch.cuda.empty_cache()
            print(f"GPU memory after single prediction: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")

    return prediction

# 5. 批次推理
def predict_batch(prompts: List[str], batch_size: int = 8, max_retries: int = 3) -> List[str]:
    """Generate short sampled completions for all prompts, batch by batch.

    Uses the module-level `tokenizer`, `model` and `device`. Each batch is
    padded, truncated to 512 tokens and sampled for at most 5 new tokens
    (temperature 0.3, top_p 0.8). A batch is retried on RuntimeError.

    Args:
        prompts: Prompt texts, processed in order.
        batch_size: Number of prompts per generate call.
        max_retries: Maximum number of attempts per batch.

    Returns:
        One decoded, stripped continuation per prompt, in input order. Every
        prompt of a batch that never succeeded gets "generation_error".
    """
    predictions = []

    for i in tqdm(range(0, len(prompts), batch_size), desc="Predicting"):
        batch_prompts = prompts[i:i + batch_size]
        batch_index = i // batch_size
        # Result reported when every attempt fails (also covers max_retries == 0)
        batch_preds = ["generation_error"] * len(batch_prompts)

        for attempt in range(1, max_retries + 1):
            inputs = outputs = None
            try:
                # Pad to a common length so the batch forms one rectangular tensor
                inputs = tokenizer(
                    batch_prompts,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=512,
                    return_attention_mask=True
                )
                inputs = {k: v.to(device) for k, v in inputs.items()}
                print(f"Batch {batch_index} input shapes: input_ids={inputs['input_ids'].shape}, attention_mask={inputs['attention_mask'].shape}")

                with torch.no_grad():
                    outputs = model.generate(
                        **inputs,
                        max_new_tokens=5,
                        pad_token_id=tokenizer.pad_token_id,
                        eos_token_id=tokenizer.eos_token_id,
                        do_sample=True,
                        temperature=0.3,
                        top_p=0.8
                    )

                # Every row shares the same padded prompt length; keep what follows it
                prompt_length = inputs['input_ids'].shape[1]
                batch_preds = [
                    tokenizer.decode(output[prompt_length:], skip_special_tokens=True).strip()
                    for output in outputs
                ]
                # Show raw predictions for every 10th batch only
                if i % (10 * batch_size) == 0:
                    print(f"批次 {batch_index}: 原始预测: {batch_preds}")
                break

            except RuntimeError as e:
                print(f"批次推理失败（尝试 {attempt}/{max_retries}）：{e}")
                torch.cuda.empty_cache()
                time.sleep(1)
                if attempt == max_retries:
                    print(f"批次 {batch_index} 推理失败，跳过")

            finally:
                # Drop the tensor references first: empty_cache() can only
                # release cached blocks that no live tensor occupies.
                inputs = outputs = None
                torch.cuda.empty_cache()
                print(f"GPU memory after batch {batch_index}: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")

        predictions.extend(batch_preds)

    return predictions

# 6. 保存预测结果到 JSONL
def save_predictions_to_jsonl(data: List[Dict], predictions: List[str], output_file: str):
    """Write one JSON line per (record, prediction) pair to `output_file`.

    Records and predictions are paired positionally; pairing stops at the
    shorter of the two lists. The file is opened in write mode.
    """
    leading_keys = ("prompt", "passage", "number", "expected_answer")
    trailing_keys = ("dataset", "operation", "error_annotation", "prompt_type")

    with jsonlines.open(output_file, mode='w') as writer:
        for item, pred in zip(data, predictions):
            # Same field order as before: record fields, raw prediction, metadata
            row = {key: item[key] for key in leading_keys}
            row["raw_prediction"] = pred
            for key in trailing_keys:
                row[key] = item[key]
            writer.write(row)

GPU memory allocated: 15.08 GB
GPU memory reserved: 15.18 GB


In [2]:
#deepseek edit majority voting




# Report the GPU memory already in use (skipped when CUDA is unavailable)
if torch.cuda.is_available():
    _GIB = 1024**3  # bytes per GiB
    print(f"GPU memory allocated: {torch.cuda.memory_allocated(0) / _GIB:.2f} GB")
    print(f"GPU memory reserved: {torch.cuda.memory_reserved(0) / _GIB:.2f} GB")

    


def create_zero_shot_prompt(passage: str, number: str) -> str:
    """Build the structured zero-shot fact-checking prompt for one (passage, number) pair.

    The prompt is a single line of "###"-headed sections (role, task, analysis
    process, output format, evaluation question) followed by a bare "Answer:"
    line for the model to complete.
    """
    # Wording is reproduced exactly, since the prompt text is what the model sees.
    sections = [
        "### ROLE. You are an expert fact-checker specializing in numerical accuracy across biology, physics, history, mathematics, and everyday scenarios.",
        "### TASK. Determine if the given number contains a factual error within the provided context.",
        "### ANALYSIS PROCESS. Follow this reasoning sequence: 1. CONTEXT ANALYSIS: Identify the domain and type of measurement being described 2. GENERATED KNOWLEDGE: Recall established facts, typical ranges, and known standards for this specific domain and measurement type 3. PLAUSIBILITY CHECK: Compare the number against expected ranges, physical laws, biological constraints, historical accuracy, and mathematical consistency 4. ERROR DETECTION: Check for biological impossibilities, physical violations, historical inaccuracies, mathematical contradictions, or scale/magnitude errors.",
        '### OUTPUT FORMAT. Answer must only be "yes" or "no". Do not provide any any explanations.',
        f'### EVALUATION. Is "{number}" in the following passage an error? "{passage}"?',
    ]
    return " ".join(sections) + "\nAnswer:"

def create_few_shot_prompt(passage: str, number: str) -> str:
    """Build the structured few-shot fact-checking prompt for one (passage, number) pair.

    Layout: a single header line of "###"-headed instruction sections, four
    worked Question/Answer examples, then the target question followed by a
    bare "Answer:" for the model to complete.

    The target question appears exactly once, after the examples. Asking it in
    the header as well would put an "Answer:" cue before the examples.
    """
    examples = [
        {"passage": "Spiders have 9 limbs.", "number": "9", "answer": "Yes"},
        {"passage": "Spiders have 8 limbs.", "number": "8", "answer": "No"},
        {"passage": "Mike's height is -3.6 meters.", "number": "-3.6", "answer": "Yes"},
        {"passage": "Mike's height is 1.8 meters.", "number": "1.8", "answer": "No"}
    ]
    header_sections = [
        "### ROLE. You are an expert fact-checker specializing in numerical accuracy across biology, physics, history, mathematics, and everyday scenarios.",
        "### TASK. Determine if the given number contains a factual error within the provided context.",
        "### ANALYSIS PROCESS. Follow this reasoning sequence: 1. CONTEXT ANALYSIS: Identify the domain and type of measurement being described 2. GENERATED KNOWLEDGE: Recall established facts, typical ranges, and known standards for this specific domain and measurement type 3. PLAUSIBILITY CHECK: Compare the number against expected ranges, physical laws, biological constraints, historical accuracy, and mathematical consistency 4. ERROR DETECTION: Check for biological impossibilities, physical violations, historical inaccuracies, mathematical contradictions, or scale/magnitude errors.",
        '### OUTPUT FORMAT. Answer must only be "Yes" or "No". Do not provide any explanations.',
        "### EVALUATION.",
    ]
    # Instructions only; the newline keeps the first example on its own line.
    prompt = " ".join(header_sections) + "\n"
    for ex in examples:
        prompt += (
            f"Question: Is \"{ex['number']}\" in the following passage an error? \"{ex['passage']}\"\n"
            f"Answer: {ex['answer']}\n"
        )
    prompt += f"Question: Is \"{number}\" in the following passage an error? \"{passage}\"\nAnswer:"
    return prompt

# 3. 加载 BeNEDect 数据集
def load_benedect_dataset(file_path: str) -> List[Dict]:
    """Load the BeNEDect JSON file and expand each sample into two prompt records.

    Every sample yields one record built from its correct passage (expected
    answer "No") and one built from its error passage (expected answer "Yes").
    Samples at index 0, 48, 96, ... use the few-shot prompt; all others use the
    zero-shot prompt. Samples missing a required field are reported and skipped.

    Args:
        file_path: Path to a JSON file whose top level is an object; its values
            are the sample dicts.

    Returns:
        List of record dicts with keys prompt, expected_answer, dataset,
        operation, error_annotation, passage, number and prompt_type.

    Raises:
        FileNotFoundError: If `file_path` does not exist.
        ValueError: If the file is not valid JSON.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"数据集文件 {file_path} 不存在，请确认路径！")
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            dataset_dict = json.load(f)
    except json.JSONDecodeError as e:
        raise ValueError(f"JSON 文件解析错误：{e}") from e

    dataset = list(dataset_dict.values())
    # The full dataset is used here; halving is disabled in this cell:
    #dataset = dataset[:len(dataset) // 2]#!!!
    required_fields = ('correct_number', 'correct_passage', 'error_number', 'error_passage', 'dataset', 'operation')
    # (passage key, number key, expected answer) for the two records built per sample
    variants = (
        ('correct_passage', 'correct_number', "No"),
        ('error_passage', 'error_number', "Yes"),
    )
    save_list = []

    for i, data in enumerate(tqdm(dataset, desc="Processing dataset")):
        missing_fields = [field for field in required_fields if field not in data]
        if missing_fields:
            for field in missing_fields:
                print(f"样本 {data.get('id', '未知')} 缺少字段 {field}")
            # Skip the whole sample: building prompts from it would raise KeyError.
            continue

        # The index counts every sample, including skipped ones.
        is_few_shot = i % 48 == 0
        prompt_fn = create_few_shot_prompt if is_few_shot else create_zero_shot_prompt
        prompt_type = "few_shot" if is_few_shot else "zero_shot"

        for passage_key, number_key, expected_answer in variants:
            save_list.append({
                "prompt": prompt_fn(data[passage_key], data[number_key]),
                "expected_answer": expected_answer,
                "dataset": data['dataset'],
                "operation": data['operation'],
                "error_annotation": data.get('error_annotation', {}),
                "passage": data[passage_key],
                "number": data[number_key],
                "prompt_type": prompt_type,
            })

    return save_list

    
def majority_vote(responses: List[str]) -> str:
    """Pick the raw response that represents the majority Yes/No vote.

    Each response is labelled "Yes", "No" or "unparsed" from its first word
    (case-insensitive, leading punctuation/whitespace ignored). Only the whole
    words "yes" and "no" count, so e.g. "Number ..." or "not ..." is unparsed.

    Args:
        responses: Raw model responses (normally three).

    Returns:
        The original text of the first response carrying the winning label.
        On a Yes/No tie, the first unparsed response is returned, or the very
        first response if none is unparsed. An empty input returns "".
    """
    if not responses:
        return ""

    labels = []
    for response in responses:
        match = re.match(r"\W*(yes|no)\b", response, flags=re.IGNORECASE)
        labels.append(match.group(1).capitalize() if match else "unparsed")

    count_yes = labels.count("Yes")
    count_no = labels.count("No")

    if count_yes > count_no:
        winner = "Yes"
    elif count_no > count_yes:
        winner = "No"
    else:
        # No majority: surface an unparsed response rather than pick a side
        winner = "unparsed"

    for label, response in zip(labels, responses):
        if label == winner:
            return response
    # Tie with no unparsed response (possible with an even number of votes)
    return responses[0]

def predict_single(prompt: str, max_retries: int = 3, num_votes: int = 3) -> str:
    """Sample several short completions for one prompt and vote on them.

    Uses the module-level ``tokenizer``, ``model`` and ``device``. The prompt
    is truncated to 512 tokens, ``num_votes`` completions of at most 5 new
    tokens are sampled (temperature 0.3, top_p 0.5), and the decoded
    completions are passed to ``majority_vote``.

    Args:
        prompt: Full prompt text.
        max_retries: Maximum number of attempts when generation raises
            ``RuntimeError``.
        num_votes: Number of sampled completions to vote over.

    Returns:
        Whatever ``majority_vote`` returns for the sampled completions, or
        "generation_error" if no attempt succeeded (including when
        ``max_retries`` is not positive).
    """
    print(f"Single prediction prompt: {prompt[:100]}...")
    # Start from the failure sentinel so the function never returns None,
    # matching the default used by predict_batch.
    prediction = "generation_error"

    for attempt in range(1, max_retries + 1):
        try:
            inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
            inputs = {k: v.to(device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=5,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                    do_sample=True,
                    temperature=0.3,
                    top_p=0.5,
                    num_return_sequences=num_votes
                )

            # Drop the prompt tokens so only the newly generated text is decoded.
            input_length = inputs['input_ids'].shape[1]
            responses = [
                tokenizer.decode(outputs[i][input_length:], skip_special_tokens=True).strip()
                for i in range(num_votes)
            ]

            prediction = majority_vote(responses)
            break

        except RuntimeError as e:
            print(f"Retry {attempt}/{max_retries}: {e}")
            time.sleep(1)
        finally:
            # Runs on both success and failure, so no separate call is
            # needed in the except branch.
            torch.cuda.empty_cache()

    return prediction


def predict_batch(prompts: List[str], batch_size: int = 4, max_retries: int = 3, num_votes: int = 3) -> List[str]:
    """Run majority-vote prediction over a list of prompts in batches.

    Uses the module-level ``tokenizer``, ``model`` and ``device``. For each
    batch, ``num_votes`` completions of at most 5 new tokens are sampled per
    prompt (temperature 0.3, top_p 0.5) and reduced with ``majority_vote``.

    Args:
        prompts: Prompt texts, in the order results should be returned.
        batch_size: Number of prompts tokenized and generated together.
        max_retries: Maximum attempts per batch when ``RuntimeError`` is raised.
        num_votes: Number of sampled completions per prompt.

    Returns:
        One prediction per prompt, in input order. Every prompt of a batch
        that never succeeds gets "generation_error".
    """
    predictions = []

    for i in tqdm(range(0, len(prompts), batch_size), desc="Predicting"):
        batch_prompts = prompts[i:i + batch_size]
        # Real end index; the last batch may be shorter than batch_size.
        batch_end = i + len(batch_prompts)
        batch_preds = ["generation_error"] * len(batch_prompts)  # Default

        for attempt in range(1, max_retries + 1):
            try:
                inputs = tokenizer(
                    batch_prompts,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=512
                )
                inputs = {k: v.to(device) for k, v in inputs.items()}

                with torch.no_grad():
                    outputs = model.generate(
                        **inputs,
                        max_new_tokens=5,
                        pad_token_id=tokenizer.pad_token_id,
                        eos_token_id=tokenizer.eos_token_id,
                        do_sample=True,
                        temperature=0.3,
                        top_p=0.5,
                        num_return_sequences=num_votes
                    )

                # All prompts share the padded length, so generated tokens
                # start at the same column for every row (the tokenizer is
                # configured for left padding when it is loaded).
                input_length = inputs['input_ids'].shape[1]

                # Collect into a scratch list and publish it only once the
                # whole batch decoded, so a failure part-way through cannot
                # leave batch_preds shorter than batch_prompts.
                voted = []
                for j in range(len(batch_prompts)):
                    # Samples for prompt j are expected in consecutive rows
                    # j*num_votes .. j*num_votes + num_votes - 1.
                    responses = [
                        tokenizer.decode(
                            outputs[j * num_votes + k][input_length:],
                            skip_special_tokens=True
                        ).strip()
                        for k in range(num_votes)
                    ]
                    voted.append(majority_vote(responses))

                batch_preds = voted
                break

            except RuntimeError as e:
                print(f"Batch {i}-{batch_end} retry {attempt}/{max_retries}: {e}")
                time.sleep(2)
            finally:
                # Runs on both success and failure, so no separate call is
                # needed in the except branch.
                torch.cuda.empty_cache()

        predictions.extend(batch_preds)

    return predictions
# 6. Save the prediction results to JSONL
def save_predictions_to_jsonl(data: List[Dict], predictions: List[str], output_file: str):
    """Write one JSON line per (item, prediction) pair to ``output_file``.

    Args:
        data: Evaluation items; each must provide the keys copied below.
        predictions: Raw model predictions, paired with ``data`` by position.
        output_file: Destination path, opened in write mode (overwritten).

    Pairing uses ``zip``, so writing stops at the shorter of the two inputs.
    """
    # Field order is preserved so every output line has the same layout:
    # the item fields, with the prediction inserted after expected_answer.
    leading_fields = ("prompt", "passage", "number", "expected_answer")
    trailing_fields = ("dataset", "operation", "error_annotation", "prompt_type")

    with jsonlines.open(output_file, mode='w') as writer:
        for record, raw_pred in zip(data, predictions):
            row = {field: record[field] for field in leading_fields}
            row["raw_prediction"] = raw_pred
            for field in trailing_fields:
                row[field] = record[field]
            writer.write(row)

GPU memory allocated: 14.96 GB
GPU memory reserved: 15.18 GB


In [3]:
"""
majority voting performance
Overall Metrics:
TN: 2059 (0.429)
FN: 1748 (0.364)
TP: 652 (0.136)
FP: 341 (0.071)
Accuracy: 0.565
"""
# The string above is the author's record of metrics from a majority-voting
# run; it is a bare expression and has no effect on execution.

# Load the evaluation items (load_benedect_dataset is defined in an earlier cell).
file_path = "/mnt/workspace/model_eval_first/BeNEDect_all.json"  # confirm the path
dataset = load_benedect_dataset(file_path)

# Smoke test: show one randomly chosen item and run it through the
# single-prompt path before starting the full batch run.
# NOTE(review): no random seed is set in this cell, so the sample may differ
# between runs — confirm whether that is intended.
random_sample = random.choice(dataset)
print("=== 随机样本推理 ===")
print(f"Passage: {random_sample['passage']}")
print(f"Number: {random_sample['number']}")
print(f"Expected Answer: {random_sample['expected_answer']}")
print(f"Prompt Type: {random_sample['prompt_type']}")
print(f"Prompt:\n{random_sample['prompt']}")

random_pred = predict_single(random_sample['prompt'])
print(f"Raw Prediction: {random_pred}")

# Full run: predict every prompt, then write items and predictions to JSONL.
prompts = [item['prompt'] for item in dataset]
predictions = predict_batch(prompts, batch_size=4, max_retries=3)  # reduced batch size
output_file = "/mnt/workspace/model_eval_first/deepseek_8b/deepseek_predictions.jsonl"
save_predictions_to_jsonl(dataset, predictions, output_file)
print(f"预测结果已保存到 {output_file}")

Processing dataset: 100%|██████████| 4800/4800 [00:00<00:00, 244952.66it/s]


=== 随机样本推理 ===
Passage: 48 PA school districts asking for waiver on 180 day class requirement
Number: 48
Expected Answer: No
Prompt Type: zero_shot
Prompt:
### ROLE. You are an expert fact-checker specializing in numerical accuracy across biology, physics, history, mathematics, and everyday scenarios. ### TASK. Determine if the given number contains a factual error within the provided context. ### ANALYSIS PROCESS. Follow this reasoning sequence: 1. CONTEXT ANALYSIS: Identify the domain and type of measurement being described 2. GENERATED KNOWLEDGE: Recall established facts, typical ranges, and known standards for this specific domain and measurement type 3. PLAUSIBILITY CHECK: Compare the number against expected ranges, physical laws, biological constraints, historical accuracy, and mathematical consistency 4. ERROR DETECTION: Check for biological impossibilities, physical violations, historical inaccuracies, mathematical contradictions, or scale/magnitude errors. ### OUTPUT FORMAT. A

Predicting: 100%|██████████| 2400/2400 [52:09<00:00,  1.30s/it]


预测结果已保存到 /mnt/workspace/model_eval_first/deepseek_8b/deepseek_predictions.jsonl


In [4]:
import json
from collections import Counter, defaultdict
from typing import List, Dict, Tuple

In [5]:
import re

def parse_prediction(raw_prediction: str) -> str:
    """Normalise a raw model response to 'yes', 'no' or 'generation_error'.

    The match is case-insensitive and looks for the whole words "yes" / "no"
    anywhere in the text; "yes" takes precedence when both occur.

    Args:
        raw_prediction: Raw response text. Non-string values (for example a
            null stored in the predictions file) are treated as unparsable
            instead of raising.

    Returns:
        'yes', 'no', or 'generation_error' when neither word is found.
    """
    if not isinstance(raw_prediction, str):
        return 'generation_error'

    text = raw_prediction.lower()
    if re.search(r'\byes\b', text):
        return 'yes'
    if re.search(r'\bno\b', text):
        return 'no'
    return 'generation_error'

def evaluate_model(data_list: List[Dict], unparsed_output_file: str = 'unparsed_predictions.json') -> Tuple[Dict, Dict]:
    """Score parsed predictions against expected answers.

    Each item's ``raw_prediction`` is parsed with ``parse_prediction`` and the
    result is stored back on the item under the key ``'parsel_prediction'``.
    Outcomes are tallied as TP/TN/FN/FP overall and per dataset, error type,
    operation and prompt type. A prediction that cannot be parsed never
    equals the expected answer, so it is tallied as FN or FP and additionally
    under 'Generation Error'.

    Args:
        data_list: Prediction records; items are mutated in place.
        unparsed_output_file: JSON path that receives unparsable samples whose
            expected answer is "yes". Written only if there is at least one.

    Returns:
        ``(metrics, detailed_metrics)``: the overall Counter (including an
        'Accuracy' entry) and a dict of per-dimension Counters.
    """
    metrics = Counter()
    detailed_metrics = {
        dimension: defaultdict(Counter)
        for dimension in ('by_domain', 'by_error_type', 'by_operation', 'by_prompt_type')
    }
    unparsed_data = {}  # unparsable samples, keyed as {id: {...}}

    def _tally(label, error_types, domain, operation, prompt_type):
        # Bump `label` in the overall counter and in every breakdown.
        metrics[label] += 1
        for error_type in error_types:
            detailed_metrics['by_error_type'][error_type][label] += 1
        detailed_metrics['by_domain'][domain][label] += 1
        detailed_metrics['by_operation'][operation][label] += 1
        detailed_metrics['by_prompt_type'][prompt_type][label] += 1

    for idx, item in enumerate(data_list):
        expected = item['expected_answer'].lower()  # Yes/No lower-cased
        pred = parse_prediction(item['raw_prediction'])
        item['parsel_prediction'] = pred  # keep the parsed result on the item
        is_unparsed = pred == 'generation_error'

        # Only unparsable samples with expected answer "yes" (error samples)
        # are collected for the JSON dump.
        if is_unparsed and expected == 'yes':
            unparsed_data[f"unparsed_{idx}"] = {
                "error_number": item['number'],
                "error_passage": item['passage'],
                "dataset": item['dataset'],
                "operation": item['operation'],
                "error_annotation": item['error_annotation'],
                # Left blank: to be filled in manually or from the raw data.
                "correct_number": "",
                "correct_passage": ""
            }

        domain = item['dataset']
        operation = item['operation']
        prompt_type = item['prompt_type']
        error_types = [name for name, count in item['error_annotation'].items() if count > 0]

        # Confusion-matrix cell for this item.
        expects_yes = expected == 'yes'
        if pred == expected:
            outcome = 'TP' if expects_yes else 'TN'
        else:
            outcome = 'FN' if expects_yes else 'FP'
        _tally(outcome, error_types, domain, operation, prompt_type)

        if is_unparsed:
            _tally('Generation Error', error_types, domain, operation, prompt_type)

    # Dump the collected unparsable samples, if any.
    if unparsed_data:
        with open(unparsed_output_file, 'w', encoding='utf-8') as f:
            json.dump(unparsed_data, f, indent=2, ensure_ascii=False)
        print("注意：JSON 文件仅包含 expected_answer='Yes' 的样本，correct_number 和 correct_passage 需手动补充")
    else:
        print("没有无法解析的数据")

    total = len(data_list)
    if total > 0:
        metrics['Accuracy'] = (metrics['TP'] + metrics['TN']) / total
    else:
        metrics['Accuracy'] = 0
    return metrics, detailed_metrics

In [6]:
# Read predictions.jsonl: one JSON record per line.
data_list = []
# NOTE(review): this is a relative path, while the prediction cell writes to an
# absolute path under /mnt/workspace/... — confirm the working directory.
input_file = 'deepseek_predictions.jsonl'  # confirm the path
with open(input_file, 'r', encoding='utf-8') as f:
    for line in f:
        data = json.loads(line)
        data_list.append(data)

# Evaluate the model and save the samples that could not be parsed.
unparsed_output_file = 'deepseek_unparsed_predictions.json'
metrics, detailed_metrics = evaluate_model(data_list, unparsed_output_file)

# Print the overall metrics.
print("\nOverall Metrics:")
total = len(data_list)
for key, value in metrics.items():
    # Accuracy is already a ratio; every other entry is a count, shown with
    # its share of all records.
    if key == 'Accuracy':
        print(f"{key}: {value:.3f}")
    else:
        print(f"{key}: {value} ({value / total:.3f})")

# Print the metrics broken down by each dimension.
print("\nMetrics by Domain:")
for domain, counts in detailed_metrics['by_domain'].items():
    print(f"{domain}: {dict(counts)}")

print("\nMetrics by Error Type:")
for error_type, counts in detailed_metrics['by_error_type'].items():
    print(f"{error_type}: {dict(counts)}")

print("\nMetrics by Operation:")
for operation, counts in detailed_metrics['by_operation'].items():
    print(f"{operation}: {dict(counts)}")

print("\nMetrics by Prompt Type:")
for prompt_type, counts in detailed_metrics['by_prompt_type'].items():
    print(f"{prompt_type}: {dict(counts)}")

没有无法解析的数据

Overall Metrics:
TN: 1695 (0.177)
FN: 1634 (0.170)
FP: 3105 (0.323)
TP: 3166 (0.330)
Accuracy: 0.506

Metrics by Domain:
Numeracy_600K_article_title: {'TN': 321, 'FN': 313, 'FP': 679, 'TP': 687}
aclsent: {'FP': 523, 'TP': 604, 'FN': 344, 'TN': 425}
DROP: {'FP': 448, 'TP': 439, 'TN': 546, 'FN': 555}
qa-text-source-comparison: {'FP': 630, 'TP': 615, 'TN': 294, 'FN': 309}
FinNum: {'FP': 825, 'TP': 821, 'FN': 113, 'TN': 109}

Metrics by Error Type:
Error in Number Relationships: {'TN': 66, 'FN': 59, 'TP': 137, 'FP': 130}
Undetectable Error: {'FP': 346, 'TP': 351, 'TN': 118, 'FN': 113}
Type Error: {'FP': 339, 'FN': 191, 'TN': 179, 'TP': 327}
Anomaly: {'FP': 160, 'TP': 160, 'TN': 70, 'FN': 70}
Improper Data: {'FP': 23, 'TP': 25, 'TN': 6, 'FN': 4}
Factual Error: {'TN': 18, 'TP': 39, 'FP': 39, 'FN': 18}

Metrics by Operation:
*2: {'TN': 60, 'FN': 57, 'FP': 99, 'TP': 102}
-10: {'FP': 121, 'TP': 119, 'TN': 62, 'FN': 64}
+1: {'TN': 62, 'FN': 64, 'FP': 126, 'TP': 124}
*0.9: {'FP': 124, 