In [None]:
# -*- coding: utf-8 -*-
"""
Fourier analysis comparing positive/negative samples with error and fix positions
"""

import os
import json
import numpy as np
import matplotlib.pyplot as plt
from scipy.fft import fft, fftfreq, ifft
from google.colab import drive

# ✅ Mount Google Drive so that the project files under /content/drive are reachable.
drive.mount(mountpoint='/content/drive')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# -*- coding: utf-8 -*-
"""
Memory-efficient Fourier analysis using streaming processing
"""

import os
import json
import numpy as np
import matplotlib.pyplot as plt
from scipy.fft import fft, fftfreq, ifft
import gc
from google.colab import drive

# Google Drive is mounted by the previous cell, so it is not mounted again here.

# ✅ Path configuration: the question-index range selects which files are read.
start_index, end_index = 700, 731
range_tag = "-".join((str(start_index), str(end_index)))
BASE_PATH = "/content/drive/MyDrive/Cluster-proj"

# Input files: per-token logits (JSONL) and the error/fix token-index map.
LOGITS_JSONL_PATH = (
    BASE_PATH + "/output/llm_steps/whole_logits/deepseek7b-gsm-" + range_tag + ".jsonl"
)
ERROR_FIX_INDEX_PATH = (
    BASE_PATH + "/output/error_fix_index/deepseek-7b-" + range_tag + "_error_fix_index.json"
)

def convert_json_to_jsonl_chunked(json_path, jsonl_path, chunk_size=10):
    """Convert a ``{qid: sample}`` JSON file into one-record-per-line JSONL.

    Each output line is ``{"qid": <key>, "data": <value>}``. The whole JSON
    document is still loaded with ``json.load``, so this does not reduce peak
    memory during conversion; ``chunk_size`` only controls how often progress
    is printed and the output is flushed (0 disables both).

    Records are written to ``<jsonl_path>.tmp`` and moved into place with
    ``os.replace`` only after every record has been written. An interrupted
    run therefore cannot leave a truncated JSONL at ``jsonl_path`` that the
    existence check below would later accept as complete.

    Args:
        json_path: source JSON file (a dict keyed by question id).
        jsonl_path: destination JSONL file.
        chunk_size: progress/flush interval in records.

    Raises:
        FileNotFoundError: if ``json_path`` does not exist.
    """
    print(f"🔄 Converting {json_path} to JSONL...")

    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    total_items = len(data)
    print(f"📊 Total items: {total_items}")

    tmp_path = jsonl_path + ".tmp"
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            for i, (qid, sample) in enumerate(data.items(), start=1):
                line_data = {"qid": qid, "data": sample}
                f.write(json.dumps(line_data, ensure_ascii=False) + '\n')

                if chunk_size and i % chunk_size == 0:
                    print(f"📈 Converted {i}/{total_items}")
                    f.flush()  # push buffered data to disk periodically
        # Atomic on the same filesystem: readers see either no file or a complete one.
        os.replace(tmp_path, jsonl_path)
    except BaseException:
        # Remove the partial temp file so a rerun starts clean, then propagate.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

    # Release the fully loaded JSON before the analysis starts.
    del data
    gc.collect()
    print(f"✅ Conversion complete: {jsonl_path}")


# ✅ Make sure the JSONL logits file exists; build it from the JSON export if not.
if not os.path.exists(LOGITS_JSONL_PATH):
    print(f"⚠️ JSONL file not found: {LOGITS_JSONL_PATH}")
    print("Converting JSON to JSONL first...")

    # Source JSON export (note: a different file-name stem than the JSONL target).
    json_path = f"{BASE_PATH}/output/llm_steps/whole_logits/deepseek-math-7b-gsm-{range_tag}.json"
    convert_json_to_jsonl_chunked(json_path, LOGITS_JSONL_PATH)

# ✅ Load the error/fix token-index data produced by the error analysis step.
print("📂 Loading error-fix index data...")
if not os.path.exists(ERROR_FIX_INDEX_PATH):
    print(f"⚠️ Error-fix index file not found: {ERROR_FIX_INDEX_PATH}")
    print("Please run the error analysis script first!")
    # Raise rather than call exit(). A bare exit() reports exit status 0
    # (success) in a script, and in a notebook kernel it can end the session.
    # An exception stops this cell and names the missing file.
    raise FileNotFoundError(ERROR_FIX_INDEX_PATH)

with open(ERROR_FIX_INDEX_PATH, "r", encoding="utf-8") as f:
    error_fix_data = json.load(f)
print(f"✅ Loaded error-fix data: {len(error_fix_data)} questions")

# ✅ Output locations: main 2x2 comparison figures and per-region detail figures.
OUTPUT_DIR = f"{BASE_PATH}/output/fourier_analysis_error_fix"
DETAIL_DIR = f"{BASE_PATH}/output/fourier_analysis_detail"
for _out_dir in (OUTPUT_DIR, DETAIL_DIR):
    os.makedirs(_out_dir, exist_ok=True)
del _out_dir  # keep the notebook namespace free of the loop variable

# ✅ Fourier low-pass smoothing
def fourier_smooth(y, keep_ratio=0.1):
    """Low-pass filter a 1-D signal by zeroing its high-frequency FFT bins.

    With ``k = int(N * keep_ratio)``, the DC term, the positive frequencies
    1..k and their negative-frequency mirrors N-k..N-1 are kept. All other
    bins are zeroed, and the real part of the inverse FFT is returned.

    Args:
        y: 1-D array-like of real values.
        keep_ratio: fraction of the signal length used as the cutoff bin.

    Returns:
        numpy array with the same length as ``y``. An input shorter than 4
        samples, or one whose cutoff rounds down to 0, is returned unfiltered
        (as ``np.asarray(y)``).
    """
    y = np.asarray(y)
    n_samples = len(y)
    if n_samples < 4:
        return y

    cutoff = int(n_samples * keep_ratio)
    if cutoff < 1:
        # Nothing to keep besides DC; the previous code's Y[0:-0] slice was
        # empty here too, so the signal was effectively returned unfiltered.
        return y

    spectrum = fft(y)
    # Zero bins cutoff+1 .. N-cutoff-1 so the kept set {0..cutoff} and
    # {N-cutoff..N-1} is conjugate-symmetric. The old slice [cutoff:-cutoff]
    # dropped bin +cutoff but kept its mirror -cutoff, which halved that
    # frequency's contribution in the real part of the inverse transform.
    spectrum[cutoff + 1:n_samples - cutoff] = 0
    return np.real(ifft(spectrum))

# ✅ Stream-load the logits of selected samplings for one question
def load_sample_logits(qid, sampling_ids, jsonl_path=None):
    """Return ``{sampling_id: record}`` for one question from the JSONL logits file.

    The file is scanned line by line, so only one parsed record is held at a
    time. Scanning stops at the first record whose ``"qid"`` equals ``qid``.
    Only the requested sampling ids that are present in that record's
    ``"data"`` dict are returned, in ``sampling_ids`` order.

    Args:
        qid: question id to look for.
        sampling_ids: iterable of sampling ids to extract.
        jsonl_path: file to read; defaults to ``LOGITS_JSONL_PATH``.

    Returns:
        dict mapping each found sampling id to its data. The dict is empty if
        the qid is not in the file.
    """
    path = LOGITS_JSONL_PATH if jsonl_path is None else jsonl_path
    sample_data = {}

    with open(path, 'r', encoding='utf-8') as f:
        for raw_line in f:
            raw_line = raw_line.strip()
            if not raw_line:
                # Tolerate blank / trailing lines instead of failing in json.loads("").
                continue
            record = json.loads(raw_line)
            if record["qid"] != qid:
                continue
            data = record["data"]
            sample_data = {sid: data[sid] for sid in sampling_ids if sid in data}
            break

    return sample_data

# ✅ Entropy-curve comparison plot (memory-conscious version)
def _draw_entropy_panel(ax, idx, raw, smooth, color, title, region=None,
                        region_color=None, region_label=None):
    """Plot raw and smoothed entropy on ``ax`` and optionally mark a token region.

    Args:
        ax: matplotlib Axes to draw on.
        idx: x positions (token indices).
        raw: raw per-token entropy values.
        smooth: smoothed entropy values, same length as ``raw``.
        color: line colour for both curves.
        title: subplot title.
        region: optional ``(begin, end)`` token-index pair to shade and outline.
        region_color: colour of the shaded span and its boundary lines.
        region_label: legend label for the shaded span.
    """
    ax.plot(idx, raw, 'o-', alpha=0.3, label="Original Entropy", color=color)
    ax.plot(idx, smooth, '-', label="Smoothed Entropy", color=color, linewidth=2)
    if region is not None:
        begin, end = region
        ax.axvspan(begin, end, alpha=0.2, color=region_color, label=region_label)
        for boundary in (begin, end):
            ax.axvline(boundary, color=region_color, linestyle='--', alpha=0.7)
    ax.set_title(title)
    ax.set_xlabel("Token Index")
    ax.set_ylabel("Entropy")
    ax.legend()
    ax.grid(True)


def _region_or_none(analysis_data, begin_key, end_key):
    """Return ``(begin, end)`` from ``analysis_data`` if both are non-negative, else None.

    Missing keys (defaulting to -1), ``None`` values and negative values all
    mean "no region".
    """
    begin = analysis_data.get(begin_key, -1)
    end = analysis_data.get(end_key, -1)
    if begin is None or end is None or begin < 0 or end < 0:
        return None
    return begin, end


def plot_entropy_with_error_fix_positions_efficient(qid, neg_sid, pos_sid, analysis_data, save_dir):
    """Save a 2x2 figure comparing an incorrect and a correct sample of one question.

    Panels:
      (0,0) negative-sample entropy, raw and Fourier-smoothed, with the error region;
      (0,1) positive-sample entropy, raw and smoothed, with the fix region;
      (1,0) both smoothed curves overlaid with both regions;
      (1,1) one-sided FFT amplitude spectra of both raw entropy sequences.

    Args:
        qid: question id used to look up the JSONL record.
        neg_sid: sampling id of the incorrect (negative) sample.
        pos_sid: sampling id of the correct (positive) sample.
        analysis_data: dict that may hold ``error_token_begin_index``,
            ``error_token_end_index``, ``fix_token_begin_index`` and
            ``fix_token_end_index``.
        save_dir: directory the PNG is written into.

    Returns:
        True if the figure was saved. False if the pair was skipped (missing
        logits, or a sequence shorter than 4 tokens) or an error occurred.
        Errors are printed, not raised.
    """
    fig = None
    try:
        # Stream-load only the two samplings that are needed.
        sample_logits = load_sample_logits(qid, [neg_sid, pos_sid])

        if neg_sid not in sample_logits or pos_sid not in sample_logits:
            print(f"⚠️ 缺失 logits：{qid} | {neg_sid} / {pos_sid}")
            return False

        # Per-token entropy (each token record is read as tok["topk_info"]["entropy"]).
        entropy_neg = np.array([tok["topk_info"]["entropy"] for tok in sample_logits[neg_sid]["token_probs"]])
        entropy_pos = np.array([tok["topk_info"]["entropy"] for tok in sample_logits[pos_sid]["token_probs"]])
        # Only the entropy arrays are used below; drop the raw logits before plotting.
        del sample_logits

        if len(entropy_neg) < 4 or len(entropy_pos) < 4:
            print(f"⚠️ 序列太短，跳过 {qid}")
            return False

        idx_neg = np.arange(len(entropy_neg))
        idx_pos = np.arange(len(entropy_pos))

        smooth_neg = fourier_smooth(entropy_neg, keep_ratio=0.1)
        smooth_pos = fourier_smooth(entropy_pos, keep_ratio=0.1)

        error_region = _region_or_none(analysis_data, "error_token_begin_index", "error_token_end_index")
        fix_region = _region_or_none(analysis_data, "fix_token_begin_index", "fix_token_end_index")

        fig, axs = plt.subplots(2, 2, figsize=(16, 10))

        # Panel 1: negative (incorrect) sample with its error region.
        _draw_entropy_panel(axs[0, 0], idx_neg, entropy_neg, smooth_neg, 'red',
                            f"Negative Sample ({neg_sid}) - Error Position",
                            error_region, 'red', "Error Region")

        # Panel 2: positive (correct) sample with its fix region.
        _draw_entropy_panel(axs[0, 1], idx_pos, entropy_pos, smooth_pos, 'green',
                            f"Positive Sample ({pos_sid}) - Fix Position",
                            fix_region, 'blue', "Fix Region")

        # Panel 3: smoothed curves overlaid. Error indices refer to the negative
        # sample's tokens and fix indices to the positive sample's, so the two
        # shaded spans share an x axis but not a token sequence.
        ax = axs[1, 0]
        ax.plot(idx_neg, smooth_neg, '-', label=f"{neg_sid} (Error)", color='red', linewidth=2)
        ax.plot(idx_pos, smooth_pos, '-', label=f"{pos_sid} (Correct)", color='green', linewidth=2)
        if error_region is not None:
            ax.axvspan(*error_region, alpha=0.15, color='red', label="Error Region")
        if fix_region is not None:
            ax.axvspan(*fix_region, alpha=0.15, color='blue', label="Fix Region")
        ax.set_title("Positive vs Negative Sample Comparison")
        ax.set_xlabel("Token Index")
        ax.set_ylabel("Entropy")
        ax.legend()
        ax.grid(True)

        # Panel 4: one-sided amplitude spectra (first N//2 bins = non-negative frequencies).
        # NOTE(review): amplitudes are not normalised by sequence length, so
        # sequences of different lengths are not directly comparable in magnitude.
        ax = axs[1, 1]
        for entropy, sid, color in ((entropy_neg, neg_sid, 'red'), (entropy_pos, pos_sid, 'green')):
            n_tokens = len(entropy)
            freqs = fftfreq(n_tokens, d=1)[:n_tokens // 2]
            amplitude = np.abs(fft(entropy))[:n_tokens // 2]
            ax.plot(freqs, amplitude, label=f"{sid} Spectrum", color=color)
        ax.set_title("Fourier Amplitude Spectrum Comparison")
        ax.set_xlabel("Frequency")
        ax.set_ylabel("Amplitude")
        ax.legend()
        ax.grid(True)

        fig.suptitle(f"Entropy Analysis: {qid} | {neg_sid} vs {pos_sid}", fontsize=14)
        fig.tight_layout()

        fname = f"{qid}_{neg_sid}_vs_{pos_sid}_error_fix_analysis.png"
        fig.savefig(os.path.join(save_dir, fname), dpi=150, bbox_inches='tight')
        print(f"✅ Saved: {fname}")
        return True

    except Exception as e:
        print(f"❌ Processing failed {qid} | {neg_sid} vs {pos_sid}: {e}")
        return False
    finally:
        # Always release the figure, including on failure, so that open
        # figures do not accumulate across many pairs.
        if fig is not None:
            plt.close(fig)
        gc.collect()

# ✅ Detailed per-region analysis plot (memory-conscious version)
def _region_within(begin, end, length):
    """True if begin/end are non-negative and ``end`` indexes a sequence of ``length``."""
    return begin is not None and end is not None and begin >= 0 and 0 <= end < length


def _annotate_region_tokens(ax, indices, values, token_probs, max_annotations=10):
    """Label at most ``max_annotations`` evenly spaced points with their token text.

    Args:
        ax: matplotlib Axes to annotate.
        indices: token indices of the plotted points.
        values: y values of the plotted points, aligned with ``indices``.
        token_probs: per-token records; ``token_probs[i]["token"]`` is the label.
        max_annotations: upper bound on the number of labels.
    """
    count = len(indices)
    if count == 0:
        return
    # Ceil division keeps the label count <= max_annotations. Floor division
    # allowed up to roughly 2 * max_annotations labels.
    step = -(-count // max_annotations)
    for i in range(0, count, step):
        idx = indices[i]
        if idx < len(token_probs):
            # Truncate long tokens. Escape '$' so that paired dollar signs in
            # model output are not parsed as matplotlib mathtext.
            token_text = token_probs[idx]["token"][:10].replace("$", r"\$")
            ax.annotate(token_text, (idx, values[i]),
                        textcoords="offset points", xytext=(0, 10), ha='center', fontsize=8)


def plot_detailed_position_analysis_efficient(qid, neg_sid, pos_sid, analysis_data, save_dir):
    """Save a 1x3 figure zooming into the error region and the fix region.

    Panels: entropy over the error region of the negative sample (with token
    labels), entropy over the fix region of the positive sample (with token
    labels), and the two regions overlaid by relative position (truncated to
    the shorter region). A panel whose region is missing or out of range is
    left empty.

    Args:
        qid: question id used to look up the JSONL record.
        neg_sid: sampling id of the incorrect (negative) sample.
        pos_sid: sampling id of the correct (positive) sample.
        analysis_data: dict that may hold ``error_token_begin_index``,
            ``error_token_end_index``, ``fix_token_begin_index`` and
            ``fix_token_end_index`` (inclusive token indices).
        save_dir: directory the PNG is written into.

    Returns:
        True if the figure was saved. False if logits were missing or an error
        occurred. Errors are printed, not raised.
    """
    fig = None
    try:
        sample_logits = load_sample_logits(qid, [neg_sid, pos_sid])

        if neg_sid not in sample_logits or pos_sid not in sample_logits:
            return False

        neg_probs = sample_logits[neg_sid]["token_probs"]
        pos_probs = sample_logits[pos_sid]["token_probs"]

        entropy_neg = np.array([tok["topk_info"]["entropy"] for tok in neg_probs])
        entropy_pos = np.array([tok["topk_info"]["entropy"] for tok in pos_probs])

        # Inclusive token-index bounds; -1 (the default) means "not available".
        error_begin = analysis_data.get("error_token_begin_index", -1)
        error_end = analysis_data.get("error_token_end_index", -1)
        fix_begin = analysis_data.get("fix_token_begin_index", -1)
        fix_end = analysis_data.get("fix_token_end_index", -1)

        has_error = _region_within(error_begin, error_end, len(entropy_neg))
        has_fix = _region_within(fix_begin, fix_end, len(entropy_pos))

        fig, axs = plt.subplots(1, 3, figsize=(18, 6))

        # Panel 1: error region of the negative sample.
        error_entropy = None
        if has_error:
            error_entropy = entropy_neg[error_begin:error_end + 1]
            error_indices = np.arange(error_begin, error_end + 1)
            axs[0].plot(error_indices, error_entropy, 'ro-', linewidth=2, markersize=6)
            axs[0].set_title("Error Region Entropy Details")
            axs[0].set_xlabel("Token Index")
            axs[0].set_ylabel("Entropy")
            axs[0].grid(True)
            _annotate_region_tokens(axs[0], error_indices, error_entropy, neg_probs)

        # Panel 2: fix region of the positive sample.
        fix_entropy = None
        if has_fix:
            fix_entropy = entropy_pos[fix_begin:fix_end + 1]
            fix_indices = np.arange(fix_begin, fix_end + 1)
            axs[1].plot(fix_indices, fix_entropy, 'go-', linewidth=2, markersize=6)
            axs[1].set_title("Fix Region Entropy Details")
            axs[1].set_xlabel("Token Index")
            axs[1].set_ylabel("Entropy")
            axs[1].grid(True)
            _annotate_region_tokens(axs[1], fix_indices, fix_entropy, pos_probs)

        # Panel 3: both regions aligned by relative position, truncated to the shorter one.
        if has_error and has_fix:
            min_len = min(len(error_entropy), len(fix_entropy))
            if min_len > 0:
                x_indices = np.arange(min_len)
                axs[2].plot(x_indices, error_entropy[:min_len], 'ro-', label="Error Region", linewidth=2)
                axs[2].plot(x_indices, fix_entropy[:min_len], 'go-', label="Fix Region", linewidth=2)
                axs[2].set_title("Error vs Fix Region Comparison")
                axs[2].set_xlabel("Relative Position")
                axs[2].set_ylabel("Entropy")
                axs[2].legend()
                axs[2].grid(True)

        fig.suptitle(f"Position Analysis: {qid} | Error vs Fix", fontsize=14)
        fig.tight_layout()

        fname = f"{qid}_{neg_sid}_vs_{pos_sid}_position_detail.png"
        fig.savefig(os.path.join(save_dir, fname), dpi=150, bbox_inches='tight')
        return True

    except Exception as e:
        print(f"❌ Detailed analysis failed {qid}: {e}")
        return False
    finally:
        # Release the figure on every path so that open figures do not accumulate.
        if fig is not None:
            plt.close(fig)
        gc.collect()

# ✅ Main processing loop: one (negative, positive) sample pair at a time.
print("\n🚀 Starting analysis generation...")

# Counters
total_samples = 0
processed_count = 0
error_count = 0

for qid, sample_data in error_fix_data.items():
    for neg_sid, analysis in sample_data.items():
        pos_sid = analysis.get("correct_sampling_id")
        if not pos_sid:
            continue
        total_samples += 1
        print(f"🔍 Processing {total_samples}: {qid} | {neg_sid} vs {pos_sid}")

        try:
            main_ok = plot_entropy_with_error_fix_positions_efficient(
                qid, neg_sid, pos_sid, analysis, OUTPUT_DIR
            )
            detail_ok = plot_detailed_position_analysis_efficient(
                qid, neg_sid, pos_sid, analysis, DETAIL_DIR
            )
        except Exception as e:
            print(f"❌ Processing failed {qid} | {neg_sid} vs {pos_sid}: {e}")
            error_count += 1
            continue

        # The plot helpers catch their own exceptions, so failures do not reach
        # the except above. An explicit False return means the figure was not
        # produced. None (a helper without a return value) counts as success.
        if main_ok is False or detail_ok is False:
            error_count += 1
        else:
            processed_count += 1

        # Periodic collection and progress report.
        if total_samples % 5 == 0:
            gc.collect()
            print(f"📊 Processed {processed_count}/{total_samples} samples")

# ✅ Final statistics: print them, then persist them as JSON next to the main plots.
print("\n".join([
    "\n🎉 Analysis completed!",
    f"📊 Total samples: {total_samples}",
    f"✅ Successfully processed: {processed_count}",
    f"❌ Failed to process: {error_count}",
    f"📁 Main analysis plots saved to: {OUTPUT_DIR}",
    f"📁 Detailed analysis plots saved to: {DETAIL_DIR}",
]))

summary_data = {
    "processing_summary": dict(
        total_samples=total_samples,
        successful_processed=processed_count,
        failed_processed=error_count,
        success_rate=(
            format(processed_count / total_samples * 100, ".1f") + "%"
            if total_samples > 0
            else "0%"
        ),
    ),
    "output_directories": dict(
        main_analysis=OUTPUT_DIR,
        detailed_analysis=DETAIL_DIR,
    ),
}

summary_file = os.path.join(OUTPUT_DIR, "processing_summary.json")
with open(summary_file, 'w', encoding='utf-8') as summary_fh:
    json.dump(summary_data, summary_fh, ensure_ascii=False, indent=2)

print(f"📄 Processing summary saved to: {summary_file}")

📂 Loading error-fix index data...
✅ Loaded error-fix data: 19 questions

🚀 Starting analysis generation...
🔍 Processing 1: q_700 | sampling0 vs sampling1
✅ Saved: q_700_sampling0_vs_sampling1_error_fix_analysis.png
🔍 Processing 2: q_700 | sampling2 vs sampling1
✅ Saved: q_700_sampling2_vs_sampling1_error_fix_analysis.png
🔍 Processing 3: q_701 | sampling1 vs sampling0
✅ Saved: q_701_sampling1_vs_sampling0_error_fix_analysis.png
🔍 Processing 4: q_703 | sampling1 vs sampling0
✅ Saved: q_703_sampling1_vs_sampling0_error_fix_analysis.png
🔍 Processing 5: q_705 | sampling1 vs sampling0
✅ Saved: q_705_sampling1_vs_sampling0_error_fix_analysis.png
📊 Processed 5/5 samples
🔍 Processing 6: q_705 | sampling2 vs sampling0
✅ Saved: q_705_sampling2_vs_sampling0_error_fix_analysis.png
🔍 Processing 7: q_707 | sampling0 vs sampling1
✅ Saved: q_707_sampling0_vs_sampling1_error_fix_analysis.png
🔍 Processing 8: q_707 | sampling2 vs sampling1
✅ Saved: q_707_sampling2_vs_sampling1_error_fix_analysis.png
🔍 Pro