In [1]:
#version 1

from google.colab import drive
# Mount Google Drive so the /content/drive/MyDrive/... paths used below resolve.
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# -*- coding: utf-8 -*-
"""
Complete JSONL processor with all required functions
"""

import json
import os
import gc
import requests
import re
import time
from pathlib import Path
from typing import Iterator, Dict, Any

# ✅ API configuration
# The key is read from the environment instead of being hard-coded: a literal
# key is exposed to anyone who can read this notebook. The key that used to be
# pasted here should be treated as leaked and revoked.
API_KEY = os.environ.get("OPENAI_API_KEY", "")
if not API_KEY:
    print("⚠️ OPENAI_API_KEY is not set; API requests will fail.")
MODEL = "gpt-4.1"

# ✅ Build the error-analysis prompt
def build_error_prompt(question, true_whole_answer, sample_whole_answer):
    """Return the user prompt asking the model to locate and fix the first error.

    Args:
        question: The math problem text.
        true_whole_answer: A full solution whose final result is correct.
        sample_whole_answer: A full solution that may contain mistakes.

    Returns:
        The prompt text. It asks for a JSON object with the keys
        ``first_error_sentence``, ``error_reason``, ``fix_sentence`` and
        ``fix_reason``.
    """
    json_schema = "\n".join([
        "{",
        '  "first_error_sentence": "<sentence>",',
        '  "error_reason": "<brief explanation>",',
        '  "fix_sentence": "<sentence>",',
        '  "fix_reason": "<brief explanation>"',
        "}",
    ])
    instructions = "\n".join([
        "Please help me:",
        "1. Identify the earliest mistake in the incorrect answer and provide the complete sentence from that point.",
        "2. Briefly explain why it is incorrect.",
        "3. Find the fix sentence in correct answer that and fix the error.",
        "4. Briefly explain why it can fix the error.",
    ])
    sections = [
        "Here is a math problem, its correct answer, and a sample answer that may contain mistakes.",
        f"【Question】:\n{question}",
        f"【Correct Answer】:\n{true_whole_answer}",
        f"【Incorrect Answer】:\n{sample_whole_answer}",
        instructions,
        f"Please output in the following JSON format:\n{json_schema}",
    ]
    # Sections are separated by one blank line; the whole prompt starts and
    # ends with a newline.
    return "\n" + "\n\n".join(sections) + "\n"

# ✅ Call the OpenAI chat-completions API
def call_custom_gpt_api(prompt, timeout=30):
    """Send ``prompt`` to the chat-completions endpoint and return the reply text.

    Uses the module-level ``API_KEY`` and ``MODEL``.

    Args:
        prompt: User message content.
        timeout: Seconds passed to ``requests.post`` (default 30, the value
            previously hard-coded).

    Returns:
        The ``content`` of the first choice's message.

    Raises:
        Exception: on timeout, transport error, non-200 status, or a response
            body without ``choices[0].message.content``. The underlying
            exception is chained as ``__cause__``.
    """
    headers = {
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json"
    }
    payload = {
        "model": MODEL,
        "messages": [
            {"role": "system", "content": "You are a meticulous and precise comparer."},
            {"role": "user", "content": prompt}
        ]
    }

    try:
        response = requests.post(
            "https://api.openai.com/v1/chat/completions",
            headers=headers,
            json=payload,
            timeout=timeout,
        )
    except requests.exceptions.Timeout as e:
        raise Exception("API request timeout") from e
    except requests.exceptions.RequestException as e:
        raise Exception(f"API request error: {str(e)}") from e

    if response.status_code != 200:
        raise Exception(f"API request failed: {response.status_code}, {response.text}")

    try:
        return response.json()["choices"][0]["message"]["content"]
    except (ValueError, KeyError, IndexError, TypeError) as e:
        # A 200 response whose body is not the expected chat-completion shape.
        raise Exception(f"Unexpected API response: {response.text[:500]}") from e

# ✅ Locate a sentence within the token sequence
def find_sentence_span_indices_robust(fragment, token_probs):
    """Locate ``fragment`` inside a token sequence, ignoring all whitespace.

    Args:
        fragment: Text to find (e.g. a sentence returned by the model).
        token_probs: List of dicts, each with a ``"token"`` string. The tokens
            are concatenated in order (with no separator) to form the
            searched text.

    Returns:
        ``(begin_index, end_index)``: the inclusive indices of the first and
        last token overlapping the first whitespace-insensitive occurrence of
        ``fragment``. Returns ``(-1, -1)`` when ``fragment`` is empty or
        whitespace-only, when ``token_probs`` is empty, or when there is no
        occurrence.
    """
    if not fragment or not token_probs:
        return -1, -1

    strip_ws = re.compile(r"\s+").sub
    fragment_clean = strip_ws("", fragment)
    # A whitespace-only fragment cleans to "", which str.find matches at
    # offset 0 and would otherwise yield a bogus span at the start.
    if not fragment_clean:
        return -1, -1

    # Removing whitespace per token gives the same text as removing it from
    # the joined string, and also provides the per-token lengths below.
    tokens_clean = [strip_ws("", entry["token"]) for entry in token_probs]
    char_start = "".join(tokens_clean).find(fragment_clean)
    if char_start == -1:
        return -1, -1
    char_end = char_start + len(fragment_clean)  # exclusive

    begin_index = -1
    cumulative_len = 0
    for idx, token_clean in enumerate(tokens_clean):
        cumulative_len += len(token_clean)
        # First token whose cleaned text extends past the match start.
        if begin_index == -1 and cumulative_len > char_start:
            begin_index = idx
        # First token whose cleaned text reaches the match end.
        if cumulative_len >= char_end:
            return begin_index, idx

    # Not reachable once a match was found (the full text covers char_end);
    # kept as a defensive fallback.
    return begin_index, len(token_probs) - 1

class JSONLProcessor:
    """
    Stream samples from a JSONL file through the GPT error analysis.

    Results are kept in a dict keyed by qid and periodically written to a JSON
    file, so an interrupted run can resume and skip the qids that already
    have results.

    Note: ``api_key`` and ``model`` are stored on the instance, but requests
    are sent through the module-level ``call_custom_gpt_api``, which reads
    the global ``API_KEY``/``MODEL``. The attributes are currently unused.
    """

    def __init__(self, api_key: str, model: str = "gpt-4.1"):
        self.api_key = api_key
        self.model = model
        # Number of qids with a stored result (includes ones loaded on resume).
        self.processed_count = 0
        # Samples whose processing raised out of _process_single_sample.
        self.error_count = 0
        # Set by process_jsonl_file; used for progress/speed reporting.
        self.start_time = None

    @staticmethod
    def _ensure_parent_dir(path: str):
        """Create the parent directory of ``path`` if it has one."""
        parent = os.path.dirname(path)
        # os.makedirs("") raises FileNotFoundError, which a bare filename
        # (no directory component) would otherwise trigger.
        if parent:
            os.makedirs(parent, exist_ok=True)

    def convert_json_to_jsonl(self, input_path: str, output_path: str,
                              chunk_size: int = 1000):
        """
        Rewrite a ``{qid: sample}`` JSON file as JSONL.

        Each output line is ``{"qid": qid, "data": sample}``. The whole input
        is parsed into memory with ``json.load``; ``chunk_size`` only controls
        how often progress is printed and the output is flushed.
        """
        print(f"🔄 Converting {input_path} to JSONL format...")

        self._ensure_parent_dir(output_path)

        file_size = os.path.getsize(input_path)
        print(f"📊 Input file size: {file_size / (1024**3):.2f} GB")

        with open(input_path, 'r', encoding='utf-8') as infile:
            data = json.load(infile)

        total_items = len(data)
        print(f"📊 Total items to convert: {total_items}")

        with open(output_path, 'w', encoding='utf-8') as outfile:
            for i, (qid, sample) in enumerate(data.items()):
                line_data = {
                    "qid": qid,
                    "data": sample
                }
                outfile.write(json.dumps(line_data, ensure_ascii=False) + '\n')

                if (i + 1) % chunk_size == 0:
                    print(f"📈 Converted {i + 1}/{total_items} items...")
                    outfile.flush()

        print(f"✅ Conversion complete! Saved to {output_path}")

        # Release the parsed input before the caller starts the next stage.
        del data
        gc.collect()

    def _load_existing_results(self, output_path: str) -> dict:
        """Load previously saved results so their qids are skipped (resume).

        An unreadable results file is moved to ``<output_path>.corrupt``
        rather than being silently overwritten by the next save.
        """
        if not os.path.exists(output_path):
            return {}

        print("📂 Loading existing results...")
        try:
            with open(output_path, 'r', encoding='utf-8') as f:
                results = json.load(f)
            if not isinstance(results, dict):
                raise ValueError("results file does not contain a JSON object")
        except (ValueError, OSError):
            print("⚠️ Could not load existing results, starting fresh")
            backup_path = output_path + ".corrupt"
            try:
                os.replace(output_path, backup_path)
                print(f"⚠️ Moved unreadable results file to {backup_path}")
            except OSError as e:
                print(f"⚠️ Could not back up unreadable results file: {e}")
            return {}

        self.processed_count = len(results)
        print(f"📊 Loaded {self.processed_count} existing results")
        return results

    def process_jsonl_file(self, jsonl_path: str, output_path: str,
                           batch_size: int = 10, save_interval: int = 20):
        """
        Stream ``jsonl_path`` line by line, analyse each new sample, and save.

        Args:
            jsonl_path: File in the format written by ``convert_json_to_jsonl``.
            output_path: JSON results file. If it already exists, its qids are
                skipped.
            batch_size: Samples queued before ``_process_batch`` runs.
            save_interval: Save (plus gc and a progress line) once at least
                this many samples were handled since the last save.

        Returns:
            The ``{qid: {sampling_id: analysis}}`` results dict, including
            entries loaded from an existing results file.
        """
        self.start_time = time.time()
        results = self._load_existing_results(output_path)
        self._ensure_parent_dir(output_path)

        with open(jsonl_path, 'r', encoding='utf-8') as f:
            batch = []
            handled_since_save = 0
            try:
                for line_number, line in enumerate(f, start=1):
                    line = line.strip()
                    if not line:
                        continue  # e.g. a trailing empty line

                    try:
                        line_data = json.loads(line)
                        qid = line_data["qid"]
                        sample = line_data["data"]
                    except json.JSONDecodeError as e:
                        print(f"⚠️ JSON decode error in line {line_number}: {e}")
                        continue
                    except (KeyError, TypeError) as e:
                        print(f"⚠️ Unexpected error processing line {line_number}: {e}")
                        continue

                    if qid in results:
                        print(f"⏭️ Skipping already processed: {qid}")
                        continue

                    batch.append((qid, sample))
                    if len(batch) < batch_size:
                        continue

                    self._process_batch(batch, results)
                    handled_since_save += len(batch)
                    batch = []

                    # Count items since the last save. Testing
                    # ``count % save_interval`` only fires when a batch
                    # boundary happens to land on a multiple (e.g.
                    # batch_size=3, save_interval=5 saved every 15 items).
                    if handled_since_save >= save_interval:
                        self._save_results(results, output_path)
                        handled_since_save = 0
                        gc.collect()
                        self._print_progress()

                if batch:
                    self._process_batch(batch, results)
            finally:
                # Also runs if the loop is interrupted (e.g. KeyboardInterrupt),
                # so results gathered since the last periodic save are kept.
                self._save_results(results, output_path)

        self._print_final_stats()

        return results

    def _process_batch(self, batch: list, results: dict):
        """Analyse each ``(qid, sample)`` in ``batch`` and store non-empty results."""
        for qid, sample in batch:
            try:
                print(f"🔍 Processing {qid}...")
                result = self._process_single_sample(qid, sample)
                if result:
                    results[qid] = result
                    self.processed_count += 1
                    print(f"✅ Successfully processed {qid}")
                else:
                    print(f"⚠️ No valid result for {qid}")

            except Exception as e:
                print(f"⚠️ Error processing {qid}: {str(e)}")
                self.error_count += 1
                continue

    @staticmethod
    def _parse_model_json(output: str):
        """Parse the JSON object in a model reply.

        Strips surrounding backticks and a leading ``json`` fence tag. If the
        remaining text is not valid JSON, falls back to the outermost
        ``{...}`` substring (for replies with prose around the object).

        Raises:
            json.JSONDecodeError: if no parseable JSON object is found.
        """
        text = output.strip().strip("`").strip()
        if text.startswith("json"):
            text = text[4:].strip()
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            start, end = text.find("{"), text.rfind("}")
            if start != -1 and end > start:
                return json.loads(text[start:end + 1])
            raise

    def _process_single_sample(self, qid: str, sample: dict) -> dict:
        """Analyse every incorrect sampling of one question.

        The first sampling whose ``final_result`` equals
        ``true_final_result`` is used as the reference answer. Every other
        sampling with a non-empty ``whole_answer`` is sent to the model.

        Returns:
            ``{sampling_id: analysis}``, or ``None`` when the question has no
            correct sampling or nothing could be analysed. Per-sampling
            failures are printed and skipped.
        """
        try:
            question = sample.get("question", "")
            true_final_result = sample.get("true_final_result", "")

            if not question or not true_final_result:
                print(f"⚠️ Missing question or true_final_result for {qid}")
                return None

            # Find the correct (reference) sampling.
            correct_sampling_id = None
            correct_sample_answer = None

            for sampling_id in ["sampling0", "sampling1", "sampling2"]:
                if sampling_id not in sample:
                    continue
                sampling_data = sample[sampling_id]
                if sampling_data.get("final_result") == true_final_result:
                    correct_sampling_id = sampling_id
                    correct_sample_answer = sampling_data.get("whole_answer", "")
                    break

            if correct_sample_answer is None:
                print(f"⚠️ No correct sampling found for {qid}")
                return None

            sample_results = {}

            # Analyse the incorrect samplings.
            for sampling_id in ["sampling0", "sampling1", "sampling2"]:
                if sampling_id not in sample:
                    continue
                sampling = sample[sampling_id]

                if sampling.get("final_result") == true_final_result:
                    continue  # correct sampling: nothing to locate

                incorrect_sample_answer = sampling.get("whole_answer", "")
                if not incorrect_sample_answer:
                    continue

                try:
                    print(f"  🔍 Analyzing {sampling_id}...")

                    prompt = build_error_prompt(question, correct_sample_answer, incorrect_sample_answer)
                    output = call_custom_gpt_api(prompt)
                    output_json = self._parse_model_json(output)

                    error_sentence = output_json.get("first_error_sentence", "")
                    fix_sentence = output_json.get("fix_sentence", "")

                    # Error span: incorrect sampling's tokens; fix span: the
                    # reference sampling's tokens.
                    error_token_probs = sampling.get("token_probs", [])
                    correct_token_probs = sample[correct_sampling_id].get("token_probs", [])

                    error_begin_idx, error_end_idx = find_sentence_span_indices_robust(
                        error_sentence, error_token_probs
                    )
                    fix_begin_idx, fix_end_idx = find_sentence_span_indices_robust(
                        fix_sentence, correct_token_probs
                    )

                    sample_results[sampling_id] = {
                        "first_error_sentence": error_sentence,
                        "error_reason": output_json.get("error_reason", ""),
                        "fix_sentence": fix_sentence,
                        "fix_reason": output_json.get("fix_reason", ""),
                        "correct_sampling_id": correct_sampling_id,
                        "error_token_begin_index": error_begin_idx,
                        "error_token_end_index": error_end_idx,
                        "fix_token_begin_index": fix_begin_idx,
                        "fix_token_end_index": fix_end_idx
                    }

                    print(f"  ✅ Successfully analyzed {sampling_id}")

                except json.JSONDecodeError as e:
                    print(f"  ⚠️ JSON decode error for {sampling_id}: {e}")
                    continue
                except Exception as e:
                    print(f"  ⚠️ Error analyzing {sampling_id}: {e}")
                    continue

            return sample_results if sample_results else None

        except Exception as e:
            print(f"⚠️ Error in _process_single_sample for {qid}: {e}")
            return None

    def _save_results(self, results: dict, output_path: str):
        """Write ``results`` to ``output_path`` as indented UTF-8 JSON.

        The data goes to ``<output_path>.tmp`` first and is then moved over
        the target with ``os.replace``. A crash during ``json.dump`` therefore
        cannot leave a truncated results file, which the next run would fail
        to load. Errors are printed, not raised.
        """
        tmp_path = output_path + ".tmp"
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(results, f, ensure_ascii=False, indent=2)
            os.replace(tmp_path, output_path)
            print(f"💾 Saved {len(results)} results to {output_path}")
        except Exception as e:
            print(f"⚠️ Error saving results: {e}")

    def _print_progress(self):
        """Print processed/error counts, throughput and elapsed time."""
        elapsed = time.time() - self.start_time
        speed = self.processed_count / elapsed if elapsed > 0 else 0
        print(f"📊 Progress: {self.processed_count} processed, {self.error_count} errors, "
              f"{speed:.2f} items/sec, {elapsed:.1f}s elapsed")

    def _print_final_stats(self):
        """Print the end-of-run summary."""
        elapsed = time.time() - self.start_time
        print(f"\n🎉 Processing complete!")
        print(f"📊 Total processed: {self.processed_count}")
        print(f"⚠️ Total errors: {self.error_count}")
        print(f"⏱️ Total time: {elapsed:.1f}s")
        if elapsed > 0:
            print(f"🚀 Average speed: {self.processed_count / elapsed:.2f} items/sec")

# ✅ Entry point
def main(range_tag: str = "700-731",
         base_path: str = "/content/drive/MyDrive/Cluster-proj"):
    """Convert one range's logits JSON to JSONL, then run the error analysis.

    Args:
        range_tag: Question-range tag embedded in the file names (default
            "700-731", the value previously hard-coded).
        base_path: Project root on the mounted Drive.
    """
    input_json = f"{base_path}/output/llm_steps/whole_logits/deepseek7b-gsm-{range_tag}-hidden.json"
    output_jsonl = f"{base_path}/output/llm_steps/whole_logits/deepseek7b-gsm-{range_tag}.jsonl"
    output_results = f"{base_path}/output/error_fix_index/deepseek-7b-{range_tag}_error_fix_index.json"

    processor = JSONLProcessor(API_KEY, MODEL)

    # Step 1: convert to JSONL unless a previous run already did.
    # NOTE: an existing JSONL is assumed complete. A conversion interrupted
    # mid-write leaves a truncated file that is reused here; delete it to
    # force re-conversion.
    if not os.path.exists(output_jsonl):
        processor.convert_json_to_jsonl(input_json, output_jsonl)
    else:
        print(f"📂 JSONL file already exists: {output_jsonl}")

    # Step 2: process the JSONL file (results are written by process_jsonl_file).
    print("\n🚀 Starting JSONL processing...")
    processor.process_jsonl_file(
        jsonl_path=output_jsonl,
        output_path=output_results,
        batch_size=1,      # 1 keeps per-item logging easy to follow while debugging
        save_interval=5,   # save results every 5 items
    )

    print(f"\n✅ All processing complete! Results saved to {output_results}")


if __name__ == "__main__":
    main()

🔄 Converting /content/drive/MyDrive/Cluster-proj/output/llm_steps/whole_logits/deepseek7b-gsm-700-731-hidden.json to JSONL format...
📊 Input file size: 1.79 GB
📊 Total items to convert: 31
✅ Conversion complete! Saved to /content/drive/MyDrive/Cluster-proj/output/llm_steps/whole_logits/deepseek7b-gsm-700-731.jsonl

🚀 Starting JSONL processing...
🔍 Processing q_700...
  🔍 Analyzing sampling0...
  ✅ Successfully analyzed sampling0
  🔍 Analyzing sampling2...
  ✅ Successfully analyzed sampling2
✅ Successfully processed q_700
🔍 Processing q_701...
  🔍 Analyzing sampling1...
  ✅ Successfully analyzed sampling1
✅ Successfully processed q_701
🔍 Processing q_702...
⚠️ No valid result for q_702
🔍 Processing q_703...
  🔍 Analyzing sampling1...
  ✅ Successfully analyzed sampling1
✅ Successfully processed q_703
🔍 Processing q_704...
⚠️ No valid result for q_704
💾 Saved 3 results to /content/drive/MyDrive/Cluster-proj/output/error_fix_index/deepseek-7b-700-731_error_fix_index.json
📊 Progress: 3 proc

In [None]:

BASE_PATH = "/content/drive/MyDrive/Cluster-proj"
# Relies on `range_tag` from a previously executed cell; it is not defined here.
LOGITS_PATH = "/".join(
    [BASE_PATH, "output", "llm_steps", "whole_logits", f"deepseek-math-7b-gsm-{range_tag}.json"]
)

# ✅ Prompt builder
def build_error_prompt(question, true_whole_answer, sample_whole_answer):
    """Return the user prompt asking the model to locate and fix the first error.

    Args:
        question: The math problem text.
        true_whole_answer: A full solution whose final result is correct.
        sample_whole_answer: A full solution that may contain mistakes.

    Returns:
        The prompt text. It asks for a JSON object with the keys
        ``first_error_sentence``, ``error_reason``, ``fix_sentence`` and
        ``fix_reason``.
    """
    # Doubled braces below are literal braces in the f-string (the JSON schema).
    return f"""
Here is a math problem, its correct answer, and a sample answer that may contain mistakes.

【Question】:
{question}

【Correct Answer】:
{true_whole_answer}

【Incorrect Answer】:
{sample_whole_answer}

Please help me:
1. Identify the earliest mistake in the incorrect answer and provide the complete sentence from that point.
2. Briefly explain why it is incorrect.
3. Find the fix sentence in correct answer that and fix the error.
4. Briefly explain why it can fix the error.

Please output in the following JSON format:
{{
  "first_error_sentence": "<sentence>",
  "error_reason": "<brief explanation>",
  "fix_sentence": "<sentence>",
  "fix_reason": "<brief explanation>"
}}
"""



# ✅ Call the OpenAI chat-completions API
def call_custom_gpt_api(prompt, timeout=30):
    """Send ``prompt`` to the chat-completions endpoint and return the reply text.

    Uses the module-level ``API_KEY`` and ``MODEL``.

    Args:
        prompt: User message content.
        timeout: Seconds before ``requests`` gives up (default 30). Without a
            timeout, ``requests.post`` can block indefinitely on a stalled
            connection and freeze the calling loop.

    Returns:
        The ``content`` of the first choice's message.

    Raises:
        requests.exceptions.RequestException: on transport errors or timeout.
        Exception: on a non-200 status code.
    """
    headers = {
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json"
    }
    payload = {
        "model": MODEL,
        "messages": [
            {"role": "system", "content": "You are a meticulous and precise comparer."},
            {"role": "user", "content": prompt}
        ]
    }
    response = requests.post(
        "https://api.openai.com/v1/chat/completions",
        headers=headers,
        json=payload,
        timeout=timeout,
    )
    if response.status_code != 200:
        raise Exception(f"API request failed: {response.status_code}, {response.text}")
    return response.json()["choices"][0]["message"]["content"]



In [None]:
#version 2

# -*- coding: utf-8 -*-
"""
Advanced JSONL processing with memory management and progress tracking
"""

import json
import os
import gc
from pathlib import Path
import time
from typing import Iterator, Dict, Any

class JSONLProcessor:
    """
    Stream samples from a JSONL file through the GPT error analysis.

    In this variant, ``_process_single_sample`` lets any failure (missing
    key, API error, unparseable model output) propagate. ``_process_batch``
    counts it in ``error_count`` and continues with the next sample.

    Note: ``api_key`` and ``model`` are stored on the instance, but requests
    go through the module-level ``call_custom_gpt_api``, which reads the
    global ``API_KEY``/``MODEL``.
    """

    def __init__(self, api_key: str, model: str = "gpt-4.1"):
        self.api_key = api_key
        self.model = model
        # Number of qids with a stored result (includes ones loaded on resume).
        self.processed_count = 0
        # Samples whose processing raised.
        self.error_count = 0
        # Set by process_jsonl_file; used for progress/speed reporting.
        self.start_time = None

    @staticmethod
    def _ensure_parent_dir(path: str):
        """Create the parent directory of ``path`` if it has one."""
        parent = os.path.dirname(path)
        # os.makedirs("") raises FileNotFoundError for bare filenames.
        if parent:
            os.makedirs(parent, exist_ok=True)

    def convert_json_to_jsonl(self, input_path: str, output_path: str,
                              chunk_size: int = 1000):
        """
        Rewrite a ``{qid: sample}`` JSON file as JSONL.

        Each output line is ``{"qid": qid, "data": sample}``. The whole input
        is parsed into memory with ``json.load``; ``chunk_size`` only controls
        how often progress is printed and the output is flushed.
        """
        print(f"🔄 Converting {input_path} to JSONL format...")

        self._ensure_parent_dir(output_path)

        file_size = os.path.getsize(input_path)
        print(f"📊 Input file size: {file_size / (1024**3):.2f} GB")

        with open(input_path, 'r', encoding='utf-8') as infile:
            data = json.load(infile)

        total_items = len(data)
        print(f"📊 Total items to convert: {total_items}")

        with open(output_path, 'w', encoding='utf-8') as outfile:
            for i, (qid, sample) in enumerate(data.items()):
                line_data = {
                    "qid": qid,
                    "data": sample
                }
                outfile.write(json.dumps(line_data, ensure_ascii=False) + '\n')

                if (i + 1) % chunk_size == 0:
                    print(f"📈 Converted {i + 1}/{total_items} items...")
                    outfile.flush()

        print(f"✅ Conversion complete! Saved to {output_path}")

        # Release the parsed input before the caller starts the next stage.
        del data
        gc.collect()

    def process_jsonl_file(self, jsonl_path: str, output_path: str,
                           batch_size: int = 10, save_interval: int = 50):
        """
        Stream ``jsonl_path`` line by line, analyse each new sample, and save.

        Args:
            jsonl_path: File in the format written by ``convert_json_to_jsonl``.
            output_path: JSON results file. If it already exists, its qids are
                skipped.
            batch_size: Samples queued before ``_process_batch`` runs.
            save_interval: Save (plus gc and a progress line) once at least
                this many samples were handled since the last save.

        Returns:
            The ``{qid: {sampling_id: analysis}}`` results dict, including
            loaded entries.
        """
        self.start_time = time.time()
        results = {}

        # Resume: qids already present in the results file are skipped.
        if os.path.exists(output_path):
            print("📂 Loading existing results...")
            with open(output_path, 'r', encoding='utf-8') as f:
                results = json.load(f)
            self.processed_count = len(results)
            print(f"📊 Loaded {self.processed_count} existing results")

        self._ensure_parent_dir(output_path)

        with open(jsonl_path, 'r', encoding='utf-8') as f:
            batch = []
            handled_since_save = 0
            try:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue  # e.g. a trailing empty line; json.loads("") would raise

                    line_data = json.loads(line)
                    qid = line_data["qid"]

                    if qid in results:
                        continue

                    batch.append((qid, line_data["data"]))
                    if len(batch) < batch_size:
                        continue

                    self._process_batch(batch, results)
                    handled_since_save += len(batch)
                    batch = []

                    # Count items since the last save. Testing
                    # ``processed_count % save_interval`` skipped saves
                    # whenever a batch jumped over a multiple, and never
                    # counted failed samples.
                    if handled_since_save >= save_interval:
                        self._save_results(results, output_path)
                        handled_since_save = 0
                        gc.collect()
                        self._print_progress()

                if batch:
                    self._process_batch(batch, results)
            finally:
                # Also runs if the loop is interrupted or a line fails to
                # parse, so results gathered since the last save are kept.
                self._save_results(results, output_path)

        self._print_final_stats()

        return results

    def _process_batch(self, batch: list, results: dict):
        """Analyse each ``(qid, sample)`` in ``batch``; count samples that raise."""
        for qid, sample in batch:
            try:
                result = self._process_single_sample(qid, sample)
                if result:
                    results[qid] = result
                    self.processed_count += 1

            except Exception as e:
                print(f"⚠️ Error processing {qid}: {str(e)}")
                self.error_count += 1
                continue

    def _process_single_sample(self, qid: str, sample: dict) -> dict:
        """Analyse every incorrect sampling of one question.

        The first sampling whose ``final_result`` equals
        ``true_final_result`` is the reference answer; every other sampling
        is sent to the model.

        Returns:
            ``{sampling_id: analysis}``, or ``None`` when no sampling is
            correct or there is nothing to analyse. Exceptions propagate.
        """
        question = sample["question"]
        true_final_result = sample["true_final_result"]

        # Find the correct (reference) sampling.
        correct_sampling_id = None
        correct_sample_answer = None

        for sampling_id in ["sampling0", "sampling1", "sampling2"]:
            if sampling_id not in sample:
                continue
            if sample[sampling_id]["final_result"] == true_final_result:
                correct_sampling_id = sampling_id
                correct_sample_answer = sample[sampling_id]["whole_answer"]
                break

        if correct_sample_answer is None:
            return None

        sample_results = {}

        # Analyse the incorrect samplings.
        for sampling_id in ["sampling0", "sampling1", "sampling2"]:
            if sampling_id not in sample:
                continue
            sampling = sample[sampling_id]

            if sampling["final_result"] == true_final_result:
                continue  # correct sampling: nothing to locate

            incorrect_sample_answer = sampling["whole_answer"]

            prompt = build_error_prompt(question, correct_sample_answer, incorrect_sample_answer)
            output = call_custom_gpt_api(prompt)

            # Drop a surrounding ``` fence (strip removes any backtick chars)
            # and a leading "json" language tag.
            output = output.strip().strip("```")
            if output.startswith("json"):
                output = output[4:].strip()

            output_json = json.loads(output)

            error_sentence = output_json["first_error_sentence"]
            fix_sentence = output_json["fix_sentence"]

            # Error span: incorrect sampling's tokens; fix span: reference tokens.
            error_token_probs = sample[sampling_id]["token_probs"]
            correct_token_probs = sample[correct_sampling_id]["token_probs"]

            error_begin_idx, error_end_idx = find_sentence_span_indices_robust(
                error_sentence, error_token_probs
            )
            fix_begin_idx, fix_end_idx = find_sentence_span_indices_robust(
                fix_sentence, correct_token_probs
            )

            sample_results[sampling_id] = {
                "first_error_sentence": error_sentence,
                "error_reason": output_json["error_reason"],
                "fix_sentence": fix_sentence,
                "fix_reason": output_json["fix_reason"],
                "correct_sampling_id": correct_sampling_id,
                "error_token_begin_index": error_begin_idx,
                "error_token_end_index": error_end_idx,
                "fix_token_begin_index": fix_begin_idx,
                "fix_token_end_index": fix_end_idx
            }

        return sample_results if sample_results else None

    def _save_results(self, results: dict, output_path: str):
        """Write ``results`` to ``output_path`` as indented UTF-8 JSON.

        The data goes to ``<output_path>.tmp`` first and is then moved over
        the target with ``os.replace``, so a crash during ``json.dump`` cannot
        leave a truncated results file (which would break the next resume).
        """
        tmp_path = output_path + ".tmp"
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, output_path)
        print(f"💾 Saved {len(results)} results to {output_path}")

    def _print_progress(self):
        """Print processed/error counts, throughput and elapsed time."""
        elapsed = time.time() - self.start_time
        speed = self.processed_count / elapsed if elapsed > 0 else 0
        print(f"📊 Progress: {self.processed_count} processed, {self.error_count} errors, "
              f"{speed:.2f} items/sec, {elapsed:.1f}s elapsed")

    def _print_final_stats(self):
        """Print the end-of-run summary."""
        elapsed = time.time() - self.start_time
        print(f"\n🎉 Processing complete!")
        print(f"📊 Total processed: {self.processed_count}")
        print(f"⚠️ Total errors: {self.error_count}")
        print(f"⏱️ Total time: {elapsed:.1f}s")
        # Guard: a near-instant run can measure elapsed == 0.
        if elapsed > 0:
            print(f"🚀 Average speed: {self.processed_count / elapsed:.2f} items/sec")

# Example usage
def main(range_tag: str = "901-950",
         base_path: str = "/content/drive/MyDrive/Cluster-proj"):
    """Convert one range's logits JSON to JSONL, then run the error analysis.

    Args:
        range_tag: Question-range tag embedded in the file names (default
            "901-950", the value previously hard-coded).
        base_path: Project root on the mounted Drive.
    """
    # No key literal here. Requests go through call_custom_gpt_api, which
    # reads the module-level API_KEY; JSONLProcessor only stores the key it
    # is given. The key previously pasted here was exposed in the notebook
    # and should be revoked.
    api_key = os.environ.get("OPENAI_API_KEY", "")

    input_json = f"{base_path}/output/llm_steps/whole_logits/deepseek-math-7b-gsm-{range_tag}.json"
    output_jsonl = f"{base_path}/output/llm_steps/whole_logits/deepseek-math-7b-gsm-{range_tag}.jsonl"
    output_results = f"{base_path}/output/error_fix_index/deepseek-math-7b-{range_tag}_error_fix_index.json"

    processor = JSONLProcessor(api_key)

    # Step 1: convert to JSONL unless a previous run already did.
    if not os.path.exists(output_jsonl):
        processor.convert_json_to_jsonl(input_json, output_jsonl)
    else:
        print(f"📂 JSONL file already exists: {output_jsonl}")

    # Step 2: process the JSONL file (results are written by process_jsonl_file).
    print("\n🚀 Starting JSONL processing...")
    processor.process_jsonl_file(
        jsonl_path=output_jsonl,
        output_path=output_results,
        batch_size=5,      # smaller batches keep fewer samples in memory
        save_interval=20,  # save after every 20 handled samples
    )

    print(f"\n✅ All processing complete! Results saved to {output_results}")


if __name__ == "__main__":
    main()

In [None]:

import json
import os
import requests
from difflib import SequenceMatcher
import re


# ✅ API config
# The key is read from the environment instead of being hard-coded: a literal
# key is exposed to anyone who can read this notebook. The key that used to be
# pasted here should be treated as leaked and revoked.
API_KEY = os.environ.get("OPENAI_API_KEY", "")
if not API_KEY:
    print("⚠️ OPENAI_API_KEY is not set; API requests will fail.")
MODEL = "gpt-4.1"

# ✅ Paths
start_index = 901
end_index = 950
range_tag = f"{start_index}-{end_index}"




In [2]:
from google.colab import drive
# Mount Google Drive so the /content/drive/MyDrive/... paths used below resolve.
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:

BASE_PATH = "/content/drive/MyDrive/Cluster-proj"
# Input logits for the current range; `range_tag` comes from the config cell.
LOGITS_PATH = "/".join(
    [BASE_PATH, "output", "llm_steps", "whole_logits", f"deepseek-math-7b-gsm-{range_tag}.json"]
)

# ✅ Prompt builder
def build_error_prompt(question, true_whole_answer, sample_whole_answer):
    """Return the user prompt asking the model to locate and fix the first error.

    Args:
        question: The math problem text.
        true_whole_answer: A full solution whose final result is correct.
        sample_whole_answer: A full solution that may contain mistakes.

    Returns:
        The prompt text. It asks for a JSON object with the keys
        ``first_error_sentence``, ``error_reason``, ``fix_sentence`` and
        ``fix_reason``.
    """
    # Doubled braces below are literal braces in the f-string (the JSON schema).
    return f"""
Here is a math problem, its correct answer, and a sample answer that may contain mistakes.

【Question】:
{question}

【Correct Answer】:
{true_whole_answer}

【Incorrect Answer】:
{sample_whole_answer}

Please help me:
1. Identify the earliest mistake in the incorrect answer and provide the complete sentence from that point.
2. Briefly explain why it is incorrect.
3. Find the fix sentence in correct answer that and fix the error.
4. Briefly explain why it can fix the error.

Please output in the following JSON format:
{{
  "first_error_sentence": "<sentence>",
  "error_reason": "<brief explanation>",
  "fix_sentence": "<sentence>",
  "fix_reason": "<brief explanation>"
}}
"""



# ✅ Call the OpenAI chat-completions API
def call_custom_gpt_api(prompt, timeout=30):
    """Send ``prompt`` to the chat-completions endpoint and return the reply text.

    Uses the module-level ``API_KEY`` and ``MODEL``.

    Args:
        prompt: User message content.
        timeout: Seconds before ``requests`` gives up (default 30). Without a
            timeout, ``requests.post`` can block indefinitely on a stalled
            connection and freeze the calling loop.

    Returns:
        The ``content`` of the first choice's message.

    Raises:
        requests.exceptions.RequestException: on transport errors or timeout.
        Exception: on a non-200 status code.
    """
    headers = {
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json"
    }
    payload = {
        "model": MODEL,
        "messages": [
            {"role": "system", "content": "You are a meticulous and precise comparer."},
            {"role": "user", "content": prompt}
        ]
    }
    response = requests.post(
        "https://api.openai.com/v1/chat/completions",
        headers=headers,
        json=payload,
        timeout=timeout,
    )
    if response.status_code != 200:
        raise Exception(f"API request failed: {response.status_code}, {response.text}")
    return response.json()["choices"][0]["message"]["content"]



In [4]:
import re
# ✅ Matching helper: return the start/end token index of a fragment
def find_sentence_span_indices_robust(fragment, token_probs):
    """Locate ``fragment`` inside a token sequence, ignoring all whitespace.

    Args:
        fragment: Text to find (e.g. a sentence returned by the model).
        token_probs: List of dicts, each with a ``"token"`` string. The tokens
            are concatenated in order (with no separator) to form the
            searched text.

    Returns:
        ``(begin_index, end_index)``: the inclusive indices of the first and
        last token overlapping the first whitespace-insensitive occurrence of
        ``fragment``. Returns ``(-1, -1)`` when ``fragment`` is empty, ``None``
        or whitespace-only, when ``token_probs`` is empty, or when there is
        no occurrence.
    """
    # An empty sentence is stored when the model output cannot be parsed.
    # Without this guard, "".find-style matching hits offset 0 and a bogus
    # span at the start of the answer would be recorded.
    if not fragment or not token_probs:
        return -1, -1

    strip_ws = re.compile(r"\s+").sub
    fragment_clean = strip_ws("", fragment)
    if not fragment_clean:  # whitespace-only fragment
        return -1, -1

    # Removing whitespace per token gives the same text as removing it from
    # the joined string, and also provides the per-token lengths below.
    tokens_clean = [strip_ws("", entry["token"]) for entry in token_probs]
    char_start = "".join(tokens_clean).find(fragment_clean)
    if char_start == -1:
        return -1, -1
    char_end = char_start + len(fragment_clean)  # exclusive

    begin_index = -1
    cumulative_len = 0
    for idx, token_clean in enumerate(tokens_clean):
        cumulative_len += len(token_clean)
        # First token whose cleaned text extends past the match start.
        if begin_index == -1 and cumulative_len > char_start:
            begin_index = idx
        # First token whose cleaned text reaches the match end.
        if cumulative_len >= char_end:
            return begin_index, idx

    # Not reachable once a match was found; kept as a defensive fallback.
    return begin_index, len(token_probs) - 1

In [1]:
# Try to add 4 GiB of swap backed by /tmp/swapfile.
# NOTE(review): in the recorded run, `swapon` failed with "Invalid argument",
# so this cell did not actually enable any swap.
!fallocate -l 4G /tmp/swapfile
!chmod 600 /tmp/swapfile
!mkswap /tmp/swapfile
!swapon /tmp/swapfile

Setting up swapspace version 1, size = 4 GiB (4294963200 bytes)
no label, UUID=0082653d-b97c-4cd0-9b1e-a8ce020a3a7e
swapon: /tmp/swapfile: swapon failed: Invalid argument


In [None]:

# ✅ Load logits_data
# The whole JSON file is parsed into memory at once.
# Explicit UTF-8 matches how the other cells read and write these files; the
# platform default encoding may differ.
with open(LOGITS_PATH, "r", encoding="utf-8") as f:
    logits_data = json.load(f)


In [None]:

results = {}
# (qid, sampling_id, error message) for API calls that failed; those
# samplings get no entry in `results` and can be retried later.
failed_calls = []

# ✅ Main loop: compare each incorrect sampling against a correct one
for qid, sample in logits_data.items():
    question = sample["question"]
    true_final_result = sample["true_final_result"]

    # The first correct sampling serves as the reference answer.
    correct_sampling_id = None
    correct_sample_answer = None

    for sampling_id in ["sampling0", "sampling1", "sampling2"]:
        if sampling_id not in sample:
            continue
        if sample[sampling_id]["final_result"] == true_final_result:
            correct_sampling_id = sampling_id
            correct_sample_answer = sample[sampling_id]["whole_answer"]
            break

    # Without a correct sampling there is nothing to compare against.
    if correct_sample_answer is None:
        print(f"⚠️ No correct sampling found for {qid}, skipping...")
        continue

    # Analyse the incorrect samplings.
    for sampling_id in ["sampling0", "sampling1", "sampling2"]:
        if sampling_id not in sample:
            continue
        sampling = sample[sampling_id]

        if sampling["final_result"] == true_final_result:
            continue  # correct sampling: nothing to locate

        incorrect_sample_answer = sampling["whole_answer"]

        prompt = build_error_prompt(question, correct_sample_answer, incorrect_sample_answer)
        # One failed request (timeout, HTTP error, ...) must not abort the
        # whole loop and discard the questions that are still to come.
        try:
            output = call_custom_gpt_api(prompt)
        except Exception as e:
            print(f"⚠️ API call failed for {qid} / {sampling_id}: {e}")
            failed_calls.append((qid, sampling_id, str(e)))
            continue
        print(f"\n🔍 {qid} / {sampling_id} (comparing with {correct_sampling_id}):\n{output}")

        # Drop a surrounding ``` fence and a leading "json" tag.
        output = output.strip().strip("```")
        if output.startswith("json"):
            output = output[4:].strip()

        try:
            output_json = json.loads(output)
            error_sentence = output_json["first_error_sentence"]
            error_reason = output_json["error_reason"]
            fix_sentence = output_json["fix_sentence"]
            fix_reason = output_json["fix_reason"]
        except (ValueError, KeyError, TypeError) as e:
            # Keep the raw reply in error_reason so it can be inspected later.
            print(f"⚠️ JSON parsing failed: {e}")
            error_sentence = ""
            error_reason = output
            fix_sentence = ""
            fix_reason = ""

        results.setdefault(qid, {})[sampling_id] = {
            "first_error_sentence": error_sentence,
            "error_reason": error_reason,
            "fix_sentence": fix_sentence,
            "fix_reason": fix_reason,
            "correct_sampling_id": correct_sampling_id  # reference sampling used
        }

if failed_calls:
    print(f"\n⚠️ {len(failed_calls)} API call(s) failed; see `failed_calls` to retry them.")


In [None]:

# ✅ Second pass: attach token indices to every result
def _span_or_missing(sentence, token_probs):
    """Return the token span of ``sentence``, or ``(-1, -1)`` if it is blank.

    Entries whose model output could not be parsed are stored with an empty
    sentence. Since ``str.find("")`` returns 0, an unguarded matcher would
    report a bogus span at the start of the answer for them.
    """
    if not sentence or not sentence.strip():
        return -1, -1
    return find_sentence_span_indices_robust(sentence, token_probs)


for qid, sample_data in results.items():
    for sampling_id, info in sample_data.items():
        error_sentence = info["first_error_sentence"]
        fix_sentence = info["fix_sentence"]
        correct_sampling_id = info["correct_sampling_id"]

        # The error sentence is located in the incorrect sampling's tokens.
        error_token_probs = logits_data[qid][sampling_id]["token_probs"]
        error_begin_idx, error_end_idx = _span_or_missing(error_sentence, error_token_probs)

        # The fix sentence is located in the reference sampling's tokens.
        correct_token_probs = logits_data[qid][correct_sampling_id]["token_probs"]
        fix_begin_idx, fix_end_idx = _span_or_missing(fix_sentence, correct_token_probs)

        info["error_token_begin_index"] = error_begin_idx
        info["error_token_end_index"] = error_end_idx
        info["fix_token_begin_index"] = fix_begin_idx
        info["fix_token_end_index"] = fix_end_idx

        print(f"{qid} / {sampling_id}:")
        print(f"  Error: [{error_begin_idx}, {error_end_idx}] : {error_sentence}")
        print(f"  Fix (from {correct_sampling_id}): [{fix_begin_idx}, {fix_end_idx}] : {fix_sentence}")

# ✅ Save results
output_dir = os.path.join(BASE_PATH, "output/error_fix_index")
os.makedirs(output_dir, exist_ok=True)
output_path = os.path.join(output_dir, f"deepseek-math-7b-{range_tag}_error_fix_index.json")
# Explicit UTF-8: ensure_ascii=False writes raw non-ASCII characters, which
# the platform default encoding may not be able to encode.
with open(output_path, "w", encoding="utf-8") as f:
    json.dump(results, f, ensure_ascii=False, indent=2)

print(f"\n✅ 所有结果已保存到 {output_path}")


In [None]:

# ✅ Validation: compare each stored token span with the sentence it should cover
def _span_match_label(reconstructed, sentence):
    """Classify how the span text relates to the model's sentence.

    Both sides are lower-cased and stripped of all whitespace. Spans are
    whole tokens, so a correct span can carry a few extra characters at its
    edges; that case is reported as CONTAINED rather than DIFFERENT.
    """
    span_clean = re.sub(r"\s+", "", reconstructed.lower())
    sentence_clean = re.sub(r"\s+", "", sentence.lower())
    if span_clean == sentence_clean:
        return "✅ MATCH"
    if sentence_clean and sentence_clean in span_clean:
        return "☑️ CONTAINED"
    return "❌ DIFFERENT"


def _span_text(token_probs, start, end):
    """Concatenate the ``"token"`` strings in the inclusive range [start, end]."""
    # Joined with "" — the same way the span matcher builds its search text.
    return "".join(entry["token"] for entry in token_probs[start:end + 1]).strip()


print("\n" + "="*80)
print("开始验证...")

for qid, sample_data in results.items():
    for sampling_id, info in sample_data.items():
        error_sentence = (info.get("first_error_sentence") or "").strip()
        error_start = info.get("error_token_begin_index", -1)
        error_end = info.get("error_token_end_index", -1)

        fix_sentence = (info.get("fix_sentence") or "").strip()
        fix_start = info.get("fix_token_begin_index", -1)
        fix_end = info.get("fix_token_end_index", -1)

        correct_sampling_id = info.get("correct_sampling_id", "")

        # Print the header for every entry. Previously it was printed only in
        # the error branch, so a fix span of an entry without an error span
        # appeared under the previous entry's header.
        print("=" * 60)
        print(f"🔍 {qid} / {sampling_id}")

        # Error sentence vs. the incorrect sampling's tokens.
        if error_start != -1 and error_end != -1:
            error_token_probs = logits_data[qid][sampling_id]["token_probs"]
            error_reconstructed = _span_text(error_token_probs, error_start, error_end)
            print(f"📌 Error token span [{error_start}, {error_end}]:\n{error_reconstructed}")
            print(f"📌 Error sentence:\n{error_sentence}")
            print(f"🔎 Error比对结果: {_span_match_label(error_reconstructed, error_sentence)}")
        else:
            print("⚠️ No error token span recorded")

        # Fix sentence vs. the reference sampling's tokens.
        if fix_start != -1 and fix_end != -1:
            correct_token_probs = logits_data[qid][correct_sampling_id]["token_probs"]
            fix_reconstructed = _span_text(correct_token_probs, fix_start, fix_end)
            print(f"\n📌 Fix token span (from {correct_sampling_id}) [{fix_start}, {fix_end}]:\n{fix_reconstructed}")
            print(f"📌 Fix sentence:\n{fix_sentence}")
            print(f"🔎 Fix比对结果: {_span_match_label(fix_reconstructed, fix_sentence)}\n")
        else:
            print("⚠️ No fix token span recorded\n")

In [None]:
#version 1

In [2]:
# Maps question id -> {sampling_id: {"first_error_sentence", "error_reason"}}.
results = {}

SAMPLING_IDS = ("sampling0", "sampling1", "sampling2")


def _strip_code_fence(text):
    """Remove a Markdown code fence (```json ... ```) wrapped around *text*.

    ``None`` (no API output) is treated as an empty string so that the
    subsequent ``json.loads`` fails cleanly instead of raising AttributeError.
    """
    # strip("`") removes the ``` fence characters at both ends; the original
    # strip("") removed nothing because its chars argument was empty.
    text = (text or "").strip().strip("`").strip()
    if text.startswith("json"):
        # Drop the language tag left behind by a ```json fence.
        text = text[len("json"):].strip()
    return text


def _find_reference_answer(sample, true_final_result):
    """Return the whole answer used as the correct reference for one question.

    Prefers the first sampling whose ``final_result`` equals the ground truth,
    and falls back to the record's own ``true_whole_answer`` (``None`` if absent).
    """
    for sid in SAMPLING_IDS:
        candidate = sample.get(sid)
        if candidate is not None and candidate["final_result"] == true_final_result:
            return candidate["whole_answer"]
    return sample.get("true_whole_answer")


# ✅ Main loop: ask the model to locate the first error in every incorrect sampling.
for qid, sample in logits_data.items():
    question = sample["question"]
    true_final_result = sample["true_final_result"]

    # Resolve the reference once per question. Assigning it only when the
    # current sampling happened to be correct left earlier samplings compared
    # against an undefined name or against the previous question's answer.
    true_whole_answer = _find_reference_answer(sample, true_final_result)
    if not true_whole_answer:
        print(f"⚠️ {qid}: no reference answer available, skipped")
        continue

    for sampling_id in SAMPLING_IDS:
        if sampling_id not in sample:
            continue
        sampling = sample[sampling_id]
        if sampling["final_result"] == true_final_result:
            # A correct sampling has no first error to locate.
            continue

        sample_whole_answer = sampling["whole_answer"]

        # Build the prompt and call the API.
        prompt = build_error_prompt(question, true_whole_answer, sample_whole_answer)
        output = call_custom_gpt_api(prompt)
        print(f"\n🔍 {qid} / {sampling_id}:\n{output}")

        # Remove a possible ```json ... ``` wrapper, then parse the JSON reply.
        output = _strip_code_fence(output)
        try:
            output_json = json.loads(output)
            sentence = output_json["first_error_sentence"]
            error_reason = output_json["error_reason"]
        except (ValueError, KeyError, TypeError) as e:
            # ValueError covers json.JSONDecodeError; TypeError covers a reply
            # that parses to a non-dict (e.g. a JSON list).
            print(f"⚠️ JSON parsing failed: {e}")
            sentence = ""
            error_reason = output

        # ✅ Save the result (the field is named first_error_sentence).
        results.setdefault(qid, {})[sampling_id] = {
            "first_error_sentence": sentence,
            "error_reason": error_reason,
        }

NameError: name 'logits_data' is not defined

In [3]:


# ✅ Second pass: attach a token-index span to every stored error sentence.
# Only the saved API results are reused; no new API calls are made.
for qid, per_sampling in results.items():
    for sampling_id, record in per_sampling.items():
        error_sentence = record["first_error_sentence"]
        probs = logits_data[qid][sampling_id]["token_probs"]

        # Locate the [first, last] token indices covering the sentence.
        span_start, span_end = find_sentence_span_indices_robust(error_sentence, probs)

        record.update(
            first_error_token_index=span_start,
            last_error_token_index=span_end,
        )

        # Optional: print the span for a quick visual check.
        print(f"{qid} / {sampling_id} → [{span_start}, {span_end}] : {error_sentence}")

In [8]:
# Persist the error-analysis results (including token spans) as JSON.
output_dir = os.path.join(BASE_PATH, "output/error_index")
os.makedirs(output_dir, exist_ok=True)
output_path = os.path.join(output_dir, f"deepseek-math-7b-{range_tag}_index.json")
# ensure_ascii=False writes raw non-ASCII characters, so pin the file encoding
# instead of relying on the platform default.
with open(output_path, "w", encoding="utf-8") as f:
    json.dump(results, f, ensure_ascii=False, indent=2)

print(f"\n✅ 所有结果已保存到 {output_path}")



✅ 所有结果已保存到 /content/drive/MyDrive/Cluster-proj/output/error_index/deepseek-math-7b-901-950_index.json


In [11]:
# Quick inspection: list the top-level keys stored for question q_947.
logits_data['q_947'].keys()

dict_keys(['question', 'true_whole_answer', 'true_final_result', 'sampling0', 'sampling1', 'sampling2'])

In [19]:
# Quick inspection: the ground-truth final result recorded for question q_927.
logits_data['q_927']['true_final_result']

'31'

In [23]:
# Quick inspection: the final result produced by q_927's sampling2.
logits_data['q_927']['sampling2']['final_result']

'50'

In [9]:
# Check: rebuild each stored error span from the token list and compare it
# with the sentence the model reported.


import json
import os
import re

# ✅ Path configuration
BASE_PATH = "/content/drive/MyDrive/Cluster-proj"
range_tag = "901-950"
LOGITS_PATH = f"{BASE_PATH}/output/llm_steps/whole_logits/deepseek7b-gsm-{range_tag}.json"
ERROR_INDEX_PATH = f"{BASE_PATH}/output/error_index/deepseek-math-7b-{range_tag}_index.json"

# ✅ Load both data sources. Reuse logits_data if an earlier cell already
# loaded it; otherwise read it here instead of failing with a NameError.
if "logits_data" not in globals():
    with open(LOGITS_PATH, "r", encoding="utf-8") as f:
        logits_data = json.load(f)

# The index file is written with ensure_ascii=False, so read it as UTF-8.
with open(ERROR_INDEX_PATH, "r", encoding="utf-8") as f:
    error_index_data = json.load(f)


def _squash(text):
    """Lower-case *text* and remove all whitespace for a loose comparison."""
    return re.sub(r"\s+", "", text.lower())


# ✅ For every entry: join the span's tokens and compare with the sentence.
for qid, sample_data in error_index_data.items():
    for sampling_id, info in sample_data.items():
        # `or ""` guards against a stored null sentence.
        sentence = (info.get("first_error_sentence") or "").strip()
        start = info.get("first_error_token_index", -1)
        end = info.get("last_error_token_index", -1)

        # A None index would slice the whole token list, so treat it as missing too.
        if start is None or end is None or start == -1 or end == -1:
            print(f"{qid} / {sampling_id} ❌ 缺失 index")
            continue

        token_probs = logits_data[qid][sampling_id]["token_probs"]
        tokens = [entry["token"] for entry in token_probs[start:end + 1]]
        reconstructed = " ".join(tokens).strip()

        print("=" * 60)
        print(f"🔍 {qid} / {sampling_id}")
        print(f"📌 Token span [{start}, {end}]:\n{reconstructed}")
        print(f"\n📌 Error sentence:\n{sentence}")

        # Loose comparison: case- and whitespace-insensitive equality.
        match_status = "✅ MATCH" if _squash(reconstructed) == _squash(sentence) else "❌ DIFFERENT"

        print(f"\n🔎 比对结果: {match_status}\n")


🔍 q_902 / sampling0
📌 Token span [36, 74]:
He fed an equal number of straw s to the pig lets , which means he fed  1 8 0 straw s /  2 0 pig lets =  9 straw s to each pig let .

📌 Error sentence:
He fed an equal number of straws to the piglets, which means he fed 180 straws / 20 piglets = 9 straws to each piglet.

🔎 比对结果: ✅ MATCH

🔍 q_906 / sampling2
📌 Token span [34, 56]:
At the second stop ,  3 people got off the bus , so the number of passengers decreased by  3 .

📌 Error sentence:
At the second stop, 3 people got off the bus, so the number of passengers decreased by 3.

🔎 比对结果: ✅ MATCH

🔍 q_911 / sampling0
📌 Token span [136, 182]:
Now , to find the combined time the all igators walked , we need to add the time Paul spent walking to the Nile Delta ( 4 hours ) and the time the other six all igators spent walking on the return journey ( 6 hours ).

📌 Error sentence:
Now, to find the combined time the alligators walked, we need to add the time Paul spent walking to the Nile Delta (4 hou