# Qwen架构LLM模型Logits研究分析

这个notebook演示如何对Qwen架构的LLM模型进行logits分析，比较不同训练阶段模型的内部表征差异。

## 研究目标
- 分析三个模型（Baseline、SFT、RL）的logits差异
- 观察推理关键词（'wait', 'aha', 'check'等）的logits变化
- 多维度对比：模型间、难度级别间、正确性间的差异

## 实验设置
- 使用transformers库加载模型
- 分析JSONL格式的query-answer数据
- 生成热力图和折线图可视化结果

# 1. 环境设置与依赖导入

首先确保使用uv管理的虚拟环境，并导入所需的库。

In [None]:
# 导入必要的库
import os
import sys
import json
import jsonlines
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import warnings
warnings.filterwarnings('ignore')

# 深度学习相关
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm
from typing import Dict, List, Tuple, Any

# Seed every RNG this notebook draws from (torch first, then numpy) so that
# simulated values and any sampling are reproducible across runs.
for _seed_fn in (torch.manual_seed, np.random.seed):
    _seed_fn(42)
del _seed_fn

# Environment report: library version and available accelerators.
print("所有依赖已成功导入！")
print(f"PyTorch版本: {torch.__version__}")
print(f"CUDA可用: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU数量: {torch.cuda.device_count()}")
    print(f"当前GPU: {torch.cuda.get_device_name(0)}")

# 2. 模型加载与配置

配置三个模型的路径和参数设置。

In [None]:
# Registry of the checkpoints under comparison: short key -> display name,
# local checkpoint path, and a human-readable description.
MODEL_CONFIGS = {
    'baseline': {
        'name': 'Qwen2.5-Math-7B',
        'path': '/gpfs/models/huggingface.co/Qwen/Qwen2.5-Math-7B',
        'description': 'Baseline Qwen Model'
    },
    'sft': {
        'name': 'DeepSeek-R1-Distill-Qwen-7B',
        'path': '/gpfs/models/huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B/',
        'description': 'SFT Model'
    },
    'rl': {
        'name': 'RL-DeepSeek-R1-Distill-Qwen-7B',
        'path': '/gpfs/users/xizhiheng/qiji_projects/NorthRL/checkpoints/xizhiheng_SkyWork_runs_DeepSeek-R1-Distill-Qwen-7B-Rethink-RL-Loss/xizhiheng___fsdp_valFirst_lr2e-6_kl0-low0.4-high0.2-partial-budget-1-len32k-skywork-grpo-temperature0.6-ppo_epochs2-stale1-testtmp-4-testdynamicclip-targetpos0.5-left---right-1.0-inf-1/global_step_380',
        'description': 'RL Trained Model'
    }
}

# Analysis settings.
# Reasoning / discourse words whose occurrences and probabilities are tracked.
TARGET_TOKENS = ['wait', 'aha', 'check', 'think', 'hmm', 'let', 'actually', 'however', 'so', 'therefore']
# Tokenizer truncation length (default max_length of LogitsExtractor.extract_logits).
MAX_LENGTH = 2048
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# NOTE(review): not referenced by the analysis functions in this section;
# analyze_single_model processes texts one at a time.
BATCH_SIZE = 4

print("模型配置已设置：")
for key, config in MODEL_CONFIGS.items():
    print(f"  {key}: {config['description']}")
# "\n" is a real newline here (a "\\n" literal would print backslash-n).
print(f"\n目标分析词汇: {TARGET_TOKENS}")
print(f"设备: {DEVICE}")

In [None]:
class ModelLoader:
    """Lazily load a causal-LM checkpoint and its tokenizer from a local path.

    Construction is cheap: nothing is read from disk until load_model() is called,
    so several loaders can exist without holding several models in memory.
    """

    def __init__(self, model_path: str, device: str = 'cuda'):
        self.model_path = model_path
        # Device that callers move input tensors to; the weights themselves are
        # placed by ``device_map="auto"`` in load_model().
        self.device = device
        self.tokenizer = None
        self.model = None

    def load_model(self):
        """Load the tokenizer and model weights.

        Idempotent: if both are already loaded, returns immediately. This avoids
        re-reading the checkpoint, and avoids holding the old and new copies of
        the weights in memory at the same time during a reload.

        Returns:
            True on success (or if already loaded), False if loading raised.
        """
        if self.is_loaded():
            return True

        print(f"正在加载模型: {self.model_path}")

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_path,
                trust_remote_code=True
            )

            # Some checkpoints ship without a pad token; reuse EOS so padding works.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_path,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True
            )

            self.model.eval()
            print(f"模型加载成功！")
            return True

        except Exception as e:
            # Best-effort: report and return False so the caller can skip this model.
            print(f"模型加载失败: {e}")
            # Do not leave a half-initialised loader (e.g. tokenizer set, model missing).
            self.tokenizer = None
            self.model = None
            return False

    def unload(self):
        """Drop the model/tokenizer references and release cached CUDA memory."""
        self.model = None
        self.tokenizer = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def is_loaded(self):
        """Return True when both the model and the tokenizer are loaded."""
        return self.model is not None and self.tokenizer is not None

# One not-yet-loaded loader per configured model. Weights are read only when
# load_model() is called, so building the loaders costs no GPU/host memory.
model_loaders = {
    model_name: ModelLoader(config['path'], DEVICE)
    for model_name, config in MODEL_CONFIGS.items()
}

print("模型加载器已创建。使用时调用 load_model() 方法加载模型。")

# 3. 数据加载与预处理

加载JSONL格式的数据。每条记录包含query、answer、level（难度级别）字段，以及可选的is_correct字段（答案是否正确）。若数据文件不存在，则使用内置示例数据。

In [None]:
def load_data(file_path: str) -> List[Dict[str, Any]]:
    """Read every record from a JSONL file.

    If ``file_path`` does not exist, fall back to the built-in demo records from
    create_sample_data() so the rest of the notebook can still run.
    """
    try:
        with jsonlines.open(file_path) as reader:
            records = list(reader)
    except FileNotFoundError:
        print(f"文件 {file_path} 不存在，将创建示例数据")
        return create_sample_data()

    print(f"成功加载 {len(records)} 条数据")
    return records

def create_sample_data() -> List[Dict[str, Any]]:
    """Return a small set of demo records (fresh list on every call).

    Each record has ``query``, ``answer``, ``level`` (difficulty) and
    ``is_correct``. At least one record is incorrect, so the correct-vs-incorrect
    comparison cells have both groups to compare. With only correct answers,
    looking up the ``False`` group raises KeyError.
    """
    sample_data = [
        {
            "query": "What is the derivative of x^2 + 3x + 1?",
            "answer": "Wait, let me think about this step by step. The derivative of x^2 is 2x, the derivative of 3x is 3, and the derivative of 1 is 0. So the answer is 2x + 3.",
            "level": 2,
            "is_correct": True
        },
        {
            "query": "Solve the equation 2x + 5 = 15",
            "answer": "Hmm, I need to isolate x. So 2x = 15 - 5 = 10, therefore x = 5.",
            "level": 1,
            "is_correct": True
        },
        {
            "query": "Find the area of a circle with radius 5",
            "answer": "Actually, the formula for area of a circle is πr². So it's π × 5² = 25π.",
            "level": 1,
            "is_correct": True
        },
        {
            "query": "What is the integral of sin(x)?",
            "answer": "Aha! The integral of sin(x) is -cos(x) + C, where C is the constant of integration.",
            "level": 3,
            "is_correct": True
        },
        {
            "query": "Solve x^2 - 4x + 3 = 0",
            "answer": "Let me check... using the quadratic formula or factoring: (x-1)(x-3) = 0, so x = 1 or x = 3.",
            "level": 2,
            "is_correct": True
        },
        {
            # Deliberately wrong (correct answer is 5050).
            "query": "Find the sum of the first 100 positive integers",
            "answer": "Let me think... using n(n+1)/2 with n = 100 gives 100 × 99 / 2 = 4950. So the answer is 4950.",
            "level": 3,
            "is_correct": False
        }
    ]
    return sample_data

def preprocess_data(data: List[Dict[str, Any]]) -> pd.DataFrame:
    """Turn raw records into a DataFrame with derived text columns.

    Adds:
        full_text: ``"Query: <query>\\nAnswer: <answer>"`` with a real newline
            between the two parts.
        query_length / answer_length / full_text_length: character counts.

    Raises:
        KeyError: if the records have no ``query`` or ``answer`` field.
    """
    df = pd.DataFrame(data)

    # Vectorised concatenation; astype(str) mirrors f-string formatting of
    # non-string values (e.g. None -> "None").
    df['full_text'] = (
        "Query: " + df['query'].astype(str) + "\nAnswer: " + df['answer'].astype(str)
    )

    df['query_length'] = df['query'].str.len()
    df['answer_length'] = df['answer'].str.len()
    df['full_text_length'] = df['full_text'].str.len()

    return df

# Load the JSONL dataset (load_data falls back to demo records if the file is
# missing) and derive the text columns.
data_file = "data/queries.jsonl"
os.makedirs("data", exist_ok=True)
data = load_data(data_file)
df = preprocess_data(data)

print(f"数据预处理完成！")
print(f"数据形状: {df.shape}")
print(f"\n数据概览:")
# 'is_correct' is optional in the input, so only summarise columns that exist.
summary_columns = [
    col for col in ('level', 'query_length', 'answer_length', 'is_correct')
    if col in df.columns
]
print(df[summary_columns].describe())

In [None]:
# Dataset overview: difficulty levels, correctness, text length, keyword frequency.
import re

fig, axes = plt.subplots(2, 2, figsize=(12, 8))

# Difficulty-level distribution.
df['level'].value_counts().sort_index().plot(kind='bar', ax=axes[0,0], color='skyblue')
axes[0,0].set_title('难度级别分布')
axes[0,0].set_xlabel('难度级别')
axes[0,0].set_ylabel('数量')

# Correctness distribution ('is_correct' is optional in the input data).
if 'is_correct' in df.columns:
    df['is_correct'].value_counts().plot(kind='bar', ax=axes[0,1], color='lightgreen')
    axes[0,1].set_title('答题正确性分布')
    axes[0,1].set_xlabel('是否正确')
    axes[0,1].set_ylabel('数量')

# Text-length histogram.
axes[1,0].hist(df['full_text_length'], bins=10, color='lightcoral', alpha=0.7)
axes[1,0].set_title('文本长度分布')
axes[1,0].set_xlabel('文本长度')
axes[1,0].set_ylabel('频次')

# Keyword frequency: count whole-word, case-insensitive matches. A plain
# substring count would also hit 'so' inside 'solve'/'also'/'isolate' and
# 'let' inside 'outlet', inflating the numbers.
keyword_patterns = {
    token: re.compile(rf"\b{re.escape(token)}\b", re.IGNORECASE)
    for token in TARGET_TOKENS
}
keyword_counts = {
    token: sum(len(pattern.findall(text)) for text in df['full_text'])
    for token, pattern in keyword_patterns.items()
}

axes[1,1].bar(list(keyword_counts.keys()), list(keyword_counts.values()), color='gold')
axes[1,1].set_title('关键词出现频次')
axes[1,1].set_xlabel('关键词')
axes[1,1].set_ylabel('出现次数')
axes[1,1].tick_params(axis='x', rotation=45)

plt.tight_layout()
plt.show()

print(f"\n关键词统计:")
for token, count in keyword_counts.items():
    print(f"  {token}: {count} 次")

# 4. Logits提取函数实现

实现从模型中提取每个token位置logits的核心功能。注意：因果语言模型在位置 i 输出的logits是对第 i+1 个token的预测。因此，若某个词出现在位置 pos，模型赋予它的概率应从第 pos-1 行读取。

In [None]:
class LogitsExtractor:
    """Run a loaded causal LM over a text and expose per-position logits.

    Causal-LM convention: row ``i`` of the logits (and of any probabilities
    derived from them) is the model's prediction for the token at position
    ``i + 1``.
    """

    def __init__(self, model_loader: ModelLoader):
        # Loader providing ``.model``, ``.tokenizer`` and ``.device``. It must be
        # loaded before extract_logits() is called.
        self.model_loader = model_loader

    def _input_device(self):
        """Return the device for input tensors.

        Uses the model's own ``device`` attribute when present (the weights are
        placed by ``device_map="auto"``), otherwise the loader's configured device.
        """
        return getattr(self.model_loader.model, 'device', self.model_loader.device)

    def extract_logits(self, text: str, max_length: int = MAX_LENGTH) -> Tuple[np.ndarray, List[str]]:
        """Tokenize ``text`` (truncated to ``max_length`` tokens) and return its logits.

        Returns:
            logits: array of shape (seq_len, vocab_size); row i predicts token i + 1.
            tokens: the ``seq_len`` input tokens as strings.

        Raises:
            ValueError: if the model/tokenizer has not been loaded.
        """
        if not self.model_loader.is_loaded():
            raise ValueError("模型未加载，请先调用 load_model()")

        tokenizer = self.model_loader.tokenizer
        inputs = tokenizer(
            text,
            return_tensors="pt",
            max_length=max_length,
            truncation=True,
            padding=False
        )

        # Convert plain int ids to token strings (no device round-trip needed).
        tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0].tolist())

        device = self._input_device()
        input_ids = inputs['input_ids'].to(device)
        attention_mask = inputs['attention_mask'].to(device)

        with torch.no_grad():
            outputs = self.model_loader.model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits[0]
            if logits.dtype == torch.bfloat16:
                # numpy has no bfloat16 dtype; upcast so .numpy() succeeds.
                logits = logits.float()
            logits = logits.cpu().numpy()

        return logits, tokens

    def get_token_probabilities(self, logits: np.ndarray, token_ids: List[int]) -> np.ndarray:
        """Return the softmax probability of each id in ``token_ids`` at every position.

        Args:
            logits: array of shape (seq_len, vocab_size), e.g. from extract_logits().
            token_ids: vocabulary ids to read out (one output column per id).

        Returns:
            float32 array of shape (seq_len, len(token_ids)); row i is the
            probability of each id being the token at position i + 1.
        """
        # Upcast first. float16 loses precision in the normaliser over a large
        # vocabulary, and some PyTorch versions lack half-precision softmax on CPU.
        logits_t = torch.as_tensor(np.asarray(logits), dtype=torch.float32)
        # p = exp(logit - logsumexp(all logits)). Only the requested columns are
        # exponentiated, so no full (seq_len, vocab_size) probability matrix is built.
        log_normaliser = torch.logsumexp(logits_t, dim=-1, keepdim=True)
        selected = logits_t[:, list(token_ids)]
        return torch.exp(selected - log_normaliser).numpy()

    def find_target_positions(self, tokens: List[str], target_tokens: List[str]) -> Dict[str, List[int]]:
        """Map each target word to the indices of ``tokens`` that spell it.

        Matching is case-insensitive. A token matches if, after stripping the
        byte-level-BPE ('Ġ') / SentencePiece ('▁') word-start markers and whitespace,
        it equals the target; the raw lower-cased token is also accepted.
        """
        positions = {target: [] for target in target_tokens}
        lowered_targets = [(target, target.lower()) for target in target_tokens]

        for idx, token in enumerate(tokens):
            raw = token.lower()
            cleaned = token.replace('Ġ', '').replace('▁', '').lower().strip()
            for target, target_lc in lowered_targets:
                if target_lc == cleaned or target_lc == raw:
                    positions[target].append(idx)

        return positions

def _resolve_target_token_id(tokenizer, word: str) -> int:
    """Pick one vocabulary id to represent ``word``.

    Variants are tried space-prefixed first. In byte-level BPE, a word in the
    middle of a text is encoded with its leading space (e.g. 'Ġwait'), and those
    are the tokens find_target_positions() matches after stripping 'Ġ'. The first
    variant that encodes to a single token wins. Otherwise the first piece of the
    first non-empty encoding is used, and 0 if nothing could be encoded.
    NOTE: one id per word cannot cover every casing/spacing variant; occurrences
    spelled differently (e.g. ' Wait' vs ' wait') map to different ids.
    """
    variants = [f" {word}", word, f" {word.capitalize()}", word.capitalize()]
    first_piece = None
    for variant in variants:
        try:
            ids = tokenizer.encode(variant, add_special_tokens=False)
        except Exception:
            # Best-effort: skip variants the tokenizer cannot encode.
            continue
        if len(ids) == 1:
            return ids[0]
        if ids and first_piece is None:
            first_piece = ids[0]
    return first_piece if first_piece is not None else 0


def analyze_single_text(extractor: LogitsExtractor, text: str, target_tokens: List[str]) -> Dict[str, Any]:
    """Extract logits for ``text`` and read out target-word probabilities.

    Returns a dict with:
        logits: (seq_len, vocab_size) array from extractor.extract_logits().
        tokens: the input tokens.
        positions: {target word: [token indices where it occurs]}.
        probabilities: (seq_len, len(target_tokens)); row i holds the probability
            of each target id being the *next* token (position i + 1).
        target_token_ids: the vocabulary id chosen for each target word.
    """
    logits, tokens = extractor.extract_logits(text)
    positions = extractor.find_target_positions(tokens, target_tokens)

    tokenizer = extractor.model_loader.tokenizer
    target_token_ids = [_resolve_target_token_id(tokenizer, token) for token in target_tokens]

    probabilities = extractor.get_token_probabilities(logits, target_token_ids)

    return {
        'logits': logits,
        'tokens': tokens,
        'positions': positions,
        'probabilities': probabilities,
        'target_token_ids': target_token_ids
    }

print("Logits提取器已定义完成！")

# 5. 特定词汇位置检测

实现检测关键推理词汇在文本中位置的功能。

In [None]:
def analyze_context_around_tokens(logits: np.ndarray, positions: Dict[str, List[int]], 
                                 tokens: List[str], context_window: int = 5) -> Dict[str, Dict[str, Any]]:
    """Collect a +/- ``context_window`` slice of logits and tokens around each hit.

    For every target word with at least one position, the result holds::

        {'positions': the input list,
         'context_logits': [logits slice per hit],
         'context_tokens': [token slice per hit],
         'logits_changes': [mean(logits[p] - logits[p - 1]) for each hit p > 0]}

    Words with no positions are omitted. Slices are clipped to [0, len(tokens)).

    NOTE: softmax is unchanged by adding a constant to every logit, so the mean
    logit change is not by itself a change in the predicted distribution.
    """
    n_tokens = len(tokens)
    analysis = {}

    for word, hits in positions.items():
        if not hits:
            continue

        windows = [
            (max(0, p - context_window), min(n_tokens, p + context_window + 1))
            for p in hits
        ]
        analysis[word] = {
            'positions': hits,
            'context_logits': [logits[lo:hi] for lo, hi in windows],
            'context_tokens': [tokens[lo:hi] for lo, hi in windows],
            'logits_changes': [np.mean(logits[p] - logits[p - 1]) for p in hits if p > 0],
        }

    return analysis

def visualize_token_positions(tokens: List[str], positions: Dict[str, List[int]], 
                            title: str = "Token Positions") -> None:
    """Scatter every token position on one axis and highlight target-word hits.

    Args:
        tokens: token strings, one per position.
        positions: {target word: [indices into ``tokens``]}; any set of words.
        title: plot title.
    """
    fig, ax = plt.subplots(figsize=(15, 4))

    # All tokens as faint grey dots.
    ax.scatter(range(len(tokens)), [0] * len(tokens), c='lightgray', s=20, alpha=0.5)

    # One colour per key of ``positions``. Sizing by the global TARGET_TOKENS
    # would raise IndexError for callers passing more words than that list.
    colors = plt.cm.Set3(np.linspace(0, 1, max(1, len(positions))))

    has_hits = False
    for i, (target_token, pos_list) in enumerate(positions.items()):
        if not pos_list:
            continue
        has_hits = True
        ax.scatter(pos_list, [0] * len(pos_list),
                   c=[colors[i]], s=100, label=target_token, alpha=0.8)
        for pos in pos_list:
            ax.annotate(target_token, (pos, 0), xytext=(0, 20),
                        textcoords='offset points', ha='center',
                        bbox=dict(boxstyle='round,pad=0.3', facecolor=colors[i], alpha=0.7))

    ax.set_xlabel('Token Position')
    ax.set_title(title)
    if has_hits:
        # legend() with no labelled artists only warns and draws nothing useful.
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    ax.grid(True, alpha=0.3)

    # Label roughly 20 evenly spaced ticks to avoid overcrowding.
    step = max(1, len(tokens) // 20)
    ax.set_xticks(range(0, len(tokens), step))
    ax.set_xticklabels([tokens[i][:10] + '...' if len(tokens[i]) > 10 else tokens[i]
                        for i in range(0, len(tokens), step)], rotation=45, ha='right')

    plt.tight_layout()
    plt.show()

# Demonstrate target-word position detection on the first sample text.
sample_text = df.iloc[0]['full_text']
print(f"示例文本: {sample_text[:200]}...")
print("\n准备分析token位置...")

# NOTE: a simplified regex tokenizer is used here only to demonstrate the idea;
# real analysis must use the model's own tokenizer (load the model first).
def demo_tokenize(text: str) -> List[str]:
    """Split lower-cased ``text`` into word tokens and single punctuation marks (demo only)."""
    import re
    # Raw string with single backslashes: \b is a word boundary and \w a word
    # character. A doubled "\\b" would match a literal backslash instead.
    return re.findall(r"\b\w+\b|[.,!?;]", text.lower())

# Locate each target word among the demo tokens (exact token equality).
demo_tokens = demo_tokenize(sample_text)
demo_positions = {
    token: [idx for idx, tok in enumerate(demo_tokens) if tok == token]
    for token in TARGET_TOKENS
}

print(f"发现的目标token位置:")
for token, positions in demo_positions.items():
    if not positions:
        continue
    print(f"  {token}: 位置 {positions}")

# Plot the detected positions.
visualize_token_positions(demo_tokens, demo_positions, "演示：目标Token位置")

# 6. 单模型Logits分析

在实际环境中加载单个模型并分析其logits模式。

**注意**: 由于模型较大，这里提供了完整的代码框架。在实际运行时请确保有足够的GPU内存。

In [None]:
# Single-model analysis.
def analyze_single_model(model_name: str, texts: List[str], target_tokens: List[str]) -> Dict[str, Any]:
    """Run analyze_single_text over ``texts`` with one model and aggregate the results.

    Returns None if the model fails to load. Otherwise a dict with:
        model_name: the key into ``model_loaders``.
        text_analyses: per-text results, each tagged with 'text_index'.
        summary_stats: counts, per-word occurrence frequencies, and mean/std of
            the probability assigned to each target word where it occurs.

    The probability for a target word found at token position ``pos`` is read
    from row ``pos - 1`` of the probability matrix. That row is the distribution
    the model produced just before emitting that token (row ``pos`` would be the
    prediction for the *following* token). Occurrences at position 0 have no
    preceding prediction and count towards frequencies only.
    """
    print(f"开始分析模型: {model_name}")

    model_loader = model_loaders[model_name]
    if not model_loader.load_model():
        print(f"模型 {model_name} 加载失败")
        return None

    extractor = LogitsExtractor(model_loader)

    model_results = {
        'model_name': model_name,
        'text_analyses': [],
        'summary_stats': {}
    }

    all_positions = {token: [] for token in target_tokens}
    all_probabilities = {token: [] for token in target_tokens}

    # Column of each word in the probability matrix (first occurrence, the same
    # as list.index), computed once instead of per occurrence.
    column_of = {}
    for col, token in enumerate(target_tokens):
        column_of.setdefault(token, col)

    for i, text in enumerate(tqdm(texts, desc=f"分析 {model_name}")):
        try:
            result = analyze_single_text(extractor, text, target_tokens)
        except Exception as e:
            # Best-effort: one failing text (e.g. OOM on a long input) must not abort the run.
            print(f"分析文本 {i} 时出错: {e}")
            continue

        result['text_index'] = i
        model_results['text_analyses'].append(result)

        probabilities = result['probabilities']
        for token, positions in result['positions'].items():
            all_positions[token].extend(positions)
            col = column_of[token]
            for pos in positions:
                # Row pos - 1 holds the prediction for the token at pos.
                if 0 < pos <= len(probabilities):
                    all_probabilities[token].append(probabilities[pos - 1, col])

    model_results['summary_stats'] = {
        'total_texts': len(texts),
        'processed_texts': len(model_results['text_analyses']),
        'token_frequencies': {token: len(positions) for token, positions in all_positions.items()},
        'average_probabilities': {token: np.mean(probs) if probs else 0.0
                                  for token, probs in all_probabilities.items()},
        'std_probabilities': {token: np.std(probs) if probs else 0.0
                              for token, probs in all_probabilities.items()}
    }

    print(f"模型 {model_name} 分析完成")
    return model_results

# Mock analysis results (for demonstration, since the real models may not be
# loadable in this environment).
def create_mock_analysis_results() -> Dict[str, Any]:
    """Build fake per-model summary statistics so the plotting cells can run.

    Every model starts from the same hand-picked probabilities. 'sft' scales
    'think'/'check'/'therefore' by 1.3, and 'rl' scales 'aha'/'actually'/'so'
    by 1.5. Standard deviations are 30% of each probability. Token frequencies
    are random integers in [1, 10) drawn from numpy's global RNG.
    """
    boosts = {
        'sft': (('think', 'check', 'therefore'), 1.3),
        'rl': (('aha', 'actually', 'so'), 1.5),
    }

    mock_results = {}
    for model_name in ('baseline', 'sft', 'rl'):
        probs = {
            'wait': 0.1, 'aha': 0.05, 'check': 0.08, 'think': 0.12, 'hmm': 0.03,
            'let': 0.15, 'actually': 0.07, 'however': 0.06, 'so': 0.20, 'therefore': 0.09
        }
        boosted, factor = boosts.get(model_name, ((), 1.0))
        for tok in boosted:
            probs[tok] *= factor

        summary = {
            'total_texts': len(df),
            'processed_texts': len(df),
            'token_frequencies': {tok: np.random.randint(1, 10) for tok in TARGET_TOKENS},
            'average_probabilities': probs,
            'std_probabilities': {tok: p * 0.3 for tok, p in probs.items()},
        }
        mock_results[model_name] = {'model_name': model_name, 'summary_stats': summary}

    return mock_results

# Demo: build mock results so the comparison/plotting cells can run without GPUs.
print("创建模拟分析结果用于演示...")
mock_results = create_mock_analysis_results()

# Per-model overview of the (mock) summary statistics.
for model_name, results in mock_results.items():
    print(f"\n{model_name.upper()} 模型统计:")
    stats = results['summary_stats']
    print(f"  处理文本数: {stats['processed_texts']}")
    print(f"  目标token平均概率:")
    for token, prob in stats['average_probabilities'].items():
        print(f"    {token}: {prob:.3f} ± {stats['std_probabilities'][token]:.3f}")

print("\n如需分析真实模型，请取消注释以下代码：")
print("# real_results = {}")
print("# for model_name in ['baseline']:  # 先从一个模型开始")
print("#     texts = df['full_text'].tolist()")
print("#     real_results[model_name] = analyze_single_model(model_name, texts, TARGET_TOKENS)")

# 7. 多模型对比分析

比较Baseline、SFT和RL三个模型的logits差异。

In [None]:
def compare_models(results_dict: Dict[str, Dict], metric: str = 'average_probabilities') -> pd.DataFrame:
    """Flatten ``summary_stats[metric]`` of every model into long format.

    Returns:
        DataFrame with one row per (model, token) and columns
        ``model``, ``token``, ``value``, in the iteration order of the inputs.
    """
    rows = [
        {'model': model_name, 'token': token, 'value': value}
        for model_name, results in results_dict.items()
        for token, value in results['summary_stats'][metric].items()
    ]
    return pd.DataFrame(rows)

def plot_model_comparison(results_dict: Dict[str, Dict]):
    """Plot per-token average probabilities of several models in a 2x2 figure.

    Panels: grouped bar chart, model x token heatmap, radar chart, and box plots.
    The box plots are drawn from samples SIMULATED as N(mean, std) per
    (model, token). They are illustrative only, not real per-occurrence data,
    and they consume numpy's global RNG.
    """
    df_comparison = compare_models(results_dict, 'average_probabilities')
    pivot_data = df_comparison.pivot(index='token', columns='model', values='value')

    fig, axes = plt.subplots(2, 2, figsize=(15, 12))

    # 1. Grouped bar chart.
    pivot_data.plot(kind='bar', ax=axes[0, 0], width=0.8)
    axes[0, 0].set_title('模型间目标Token平均概率对比')
    axes[0, 0].set_ylabel('平均概率')
    axes[0, 0].legend(title='模型')
    axes[0, 0].tick_params(axis='x', rotation=45)

    # 2. Heatmap with models as rows.
    sns.heatmap(pivot_data.T, annot=True, fmt='.3f', cmap='YlOrRd', ax=axes[0, 1])
    axes[0, 1].set_title('模型-Token概率热力图')

    # 3. Radar chart. Remove the cartesian axes in slot 3 and add a polar one in
    # its place. Creating a polar subplot in the same slot without removing it
    # leaves the empty cartesian axes drawn underneath.
    axes[1, 0].remove()
    ax_radar = fig.add_subplot(2, 2, 3, projection='polar')
    angles = np.linspace(0, 2 * np.pi, len(TARGET_TOKENS), endpoint=False)
    angles = np.concatenate((angles, [angles[0]]))  # repeat first angle to close the polygon

    base_colors = ['red', 'blue', 'green']
    for i, (model_name, results) in enumerate(results_dict.items()):
        # Fixed colours for the first three models, matplotlib's default cycle afterwards.
        color = base_colors[i] if i < len(base_colors) else f"C{i % 10}"
        values = [results['summary_stats']['average_probabilities'][token] for token in TARGET_TOKENS]
        values += [values[0]]
        ax_radar.plot(angles, values, 'o-', linewidth=2, label=model_name, color=color)
        ax_radar.fill(angles, values, alpha=0.25, color=color)

    ax_radar.set_xticks(angles[:-1])
    ax_radar.set_xticklabels(TARGET_TOKENS)
    ax_radar.set_title('模型性能雷达图')
    ax_radar.legend(loc='upper right', bbox_to_anchor=(0.1, 0.1))

    # 4. Box plot for a subset of tokens. Select by exact token instead of a
    # substring match on the combined label, which could also hit model names.
    selected_tokens = ['wait', 'think', 'check', 'so']
    box_values = []
    box_labels = []
    for model_name, results in results_dict.items():
        summary = results['summary_stats']
        for token in selected_tokens:
            samples = np.random.normal(summary['average_probabilities'][token],
                                       summary['std_probabilities'][token], 50)
            box_values.extend(samples)
            # Real newline so tick labels render on two lines.
            box_labels.extend([f"{model_name}\n{token}"] * 50)

    df_box = pd.DataFrame({'value': box_values, 'model_token': box_labels})
    sns.boxplot(data=df_box, x='model_token', y='value', ax=axes[1, 1])
    axes[1, 1].set_title('概率分布箱线图（选择部分Token）')
    axes[1, 1].tick_params(axis='x', rotation=45)

    plt.tight_layout()
    plt.show()

# Statistical significance tests.
def statistical_comparison(results_dict: Dict[str, Dict]) -> Dict[str, Any]:
    """Run a paired t-test on per-token average probabilities for every model pair.

    Each token contributes one (model1, model2) pair of averages. The samples are
    therefore matched by token, so a paired test (scipy.stats.ttest_rel) is used.
    Values are aligned by token name, not by dict insertion order, and tokens
    missing from either model are ignored.

    Returns:
        ``{"<m1>_vs_<m2>": {'t_statistic', 'p_value', 'significant' (p < 0.05)}}``
        for every unordered pair, in input order. With fewer than two shared
        tokens the statistic and p-value are NaN and 'significant' is False.
    """
    from itertools import combinations
    from scipy import stats

    comparison_results = {}

    for model1, model2 in combinations(list(results_dict), 2):
        avg1 = results_dict[model1]['summary_stats']['average_probabilities']
        avg2 = results_dict[model2]['summary_stats']['average_probabilities']
        shared = [token for token in avg1 if token in avg2]

        if len(shared) < 2:
            t_stat, p_value = float('nan'), float('nan')
        else:
            t_stat, p_value = stats.ttest_rel([avg1[t] for t in shared],
                                              [avg2[t] for t in shared])

        comparison_results[f"{model1}_vs_{model2}"] = {
            't_statistic': t_stat,
            'p_value': p_value,
            'significant': bool(p_value < 0.05)
        }

    return comparison_results

# Run the model-comparison analysis on the mock results.
from itertools import combinations

print("=== 模型对比分析 ===")

plot_model_comparison(mock_results)

stats_results = statistical_comparison(mock_results)
print("\n统计显著性测试结果:")
for comparison, result in stats_results.items():
    models = comparison.replace('_vs_', ' vs ')
    significance = "显著" if result['significant'] else "不显著"
    print(f"  {models}: t={result['t_statistic']:.3f}, p={result['p_value']:.3f} ({significance})")

# Per-token differences for every model pair (same pair order as statistical_comparison).
print("\n=== 模型间差异分析 ===")
for model1, model2 in combinations(list(mock_results), 2):
    print(f"\n{model1.upper()} vs {model2.upper()}:")
    avg1 = mock_results[model1]['summary_stats']['average_probabilities']
    avg2 = mock_results[model2]['summary_stats']['average_probabilities']
    for token in TARGET_TOKENS:
        prob1 = avg1[token]
        prob2 = avg2[token]
        diff = prob2 - prob1
        if diff > 0:
            change = "增加"
        elif diff < 0:
            change = "减少"
        else:
            change = "不变"
        # Relative change is undefined when the reference probability is 0
        # (real results use 0.0 for words that never occur).
        relative = f"{diff / prob1 * 100:+.1f}%" if prob1 else "n/a"
        print(f"  {token}: {change} {abs(diff):.3f} ({relative})")

# 8. 难度级别对比分析

分析不同难度级别（1–5）题目的logits模式差异。注意：当前实现中，只有词频统计来自真实文本，概率值为模拟数据，仅用于演示可视化流程。

In [None]:
def analyze_by_difficulty_level(df: pd.DataFrame, target_tokens: List[str]) -> Dict[int, Dict[str, Dict[str, float]]]:
    """Compute per-token statistics for each difficulty level 1-5 present in ``df``.

    Returns ``{level: {token: {'frequency', 'avg_probability', 'text_count'}}}``.
    Levels with no rows are omitted.

    'frequency' is the fraction of that level's texts (``df['full_text']``) that
    contain the token as a whole word, case-insensitively.

    NOTE: 'avg_probability' is SIMULATED placeholder data, not a measurement from
    model logits. It is drawn uniformly from [0.02, 0.25) with numpy's global RNG
    and scaled up for 'think'/'check'/'however'/'therefore' at levels above 3.
    """
    import re

    level_analysis = {}

    for level in range(1, 6):
        level_data = df[df['level'] == level]
        if len(level_data) == 0:
            continue

        level_token_stats = {}

        for token in target_tokens:
            # Whole-word match. A plain substring search would count 'so' inside
            # 'solve'/'also' and 'let' inside 'outlet'.
            pattern = rf"\b{re.escape(token)}\b"
            texts_with_token = level_data['full_text'].str.contains(pattern, case=False, regex=True)
            frequency = texts_with_token.sum() / len(level_data)

            # Simulated probability (see NOTE in the docstring).
            base_prob = np.random.uniform(0.02, 0.25)
            # Hypothesis: some reasoning words become more likely on harder problems.
            if token in ['think', 'check', 'however', 'therefore'] and level > 3:
                level_multiplier = 1 + (level - 3) * 0.2
            else:
                level_multiplier = 1.0

            level_token_stats[token] = {
                'frequency': frequency,
                'avg_probability': base_prob * level_multiplier,
                'text_count': len(level_data)
            }

        level_analysis[level] = level_token_stats

    return level_analysis

def plot_difficulty_analysis(level_analysis: Dict[int, Dict[str, Dict[str, float]]]):
    """Draw a 2x2 summary of per-level token statistics.

    Panels: frequency heatmap, probability heatmap, frequency trend lines for
    'think'/'check'/'wait'/'so', and a box plot of token probabilities per level.
    A token missing from a level is shown as 0 in the heatmaps and trend lines,
    and is left out of the box plot. Tokens are taken from TARGET_TOKENS.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))

    levels = sorted(level_analysis.keys())
    level_labels = [f'Level {l}' for l in levels]

    def _value(level, token, key):
        # Statistic ``key`` for (level, token), or 0 when the pair is absent.
        per_level = level_analysis[level]
        return per_level[token][key] if token in per_level else 0

    freq_matrix = [[_value(lv, tok, 'frequency') for tok in TARGET_TOKENS] for lv in levels]
    prob_matrix = [[_value(lv, tok, 'avg_probability') for tok in TARGET_TOKENS] for lv in levels]

    # Top row: the two heatmaps share everything except data, format, colours and title.
    heatmap_specs = (
        (axes[0, 0], freq_matrix, '.2f', 'Blues', 'Token出现频率按难度级别'),
        (axes[0, 1], prob_matrix, '.3f', 'Reds', 'Token平均概率按难度级别'),
    )
    for ax, matrix, fmt, cmap, title in heatmap_specs:
        sns.heatmap(matrix,
                    xticklabels=TARGET_TOKENS,
                    yticklabels=level_labels,
                    annot=True, fmt=fmt, cmap=cmap, ax=ax)
        ax.set_title(title)
        ax.tick_params(axis='x', rotation=45)

    # Bottom-left: frequency trend for a few key tokens.
    trend_ax = axes[1, 0]
    for tok in ('think', 'check', 'wait', 'so'):
        trend_ax.plot(levels, [_value(lv, tok, 'frequency') for lv in levels],
                      marker='o', label=tok, linewidth=2)
    trend_ax.set_xlabel('难度级别')
    trend_ax.set_ylabel('出现频率')
    trend_ax.set_title('关键Token频率变化趋势')
    trend_ax.legend()
    trend_ax.grid(True, alpha=0.3)

    # Bottom-right: distribution of token probabilities within each level.
    pairs = [
        (level_analysis[lv][tok]['avg_probability'], f'Level {lv}')
        for lv in levels
        for tok in TARGET_TOKENS
        if tok in level_analysis[lv]
    ]
    df_box = pd.DataFrame({'probability': [p for p, _ in pairs],
                           'level': [lbl for _, lbl in pairs]})
    box_ax = axes[1, 1]
    sns.boxplot(data=df_box, x='level', y='probability', ax=box_ax)
    box_ax.set_title('各难度级别Token概率分布')
    box_ax.set_ylabel('概率')

    plt.tight_layout()
    plt.show()

def analyze_difficulty_patterns(level_analysis: Dict[int, Dict[str, Dict[str, float]]]):
    """Print the correlation of difficulty level with each target token's statistics.

    For every token in TARGET_TOKENS, prints the Pearson correlation of level with
    the token's frequency and with its (simulated) average probability. Any
    |r| > 0.5 is flagged as a positive or negative correlation. A (level, token)
    pair missing from ``level_analysis`` counts as 0. np.corrcoef gives NaN when a
    series is constant; NaN is printed as 'nan' and never flagged.
    """
    print("=== 难度级别分析结果 ===")

    levels = sorted(level_analysis.keys())

    def _series(token, key):
        # Per-level values of one statistic, 0 where the token has no entry.
        return [level_analysis[level][token][key] if token in level_analysis[level] else 0
                for level in levels]

    for token in TARGET_TOKENS:
        freq_corr = np.corrcoef(levels, _series(token, 'frequency'))[0, 1]
        prob_corr = np.corrcoef(levels, _series(token, 'avg_probability'))[0, 1]

        # Real newline between tokens (a "\\n" literal would print backslash-n).
        print(f"\n{token}:")
        print(f"  频率与难度相关性: {freq_corr:.3f}")
        print(f"  概率与难度相关性: {prob_corr:.3f}")

        for label, corr in (("频率", freq_corr), ("概率", prob_corr)):
            if abs(corr) > 0.5:
                trend = "正相关" if corr > 0 else "负相关"
                print(f"  -> {label}与难度呈{trend} (|r|>0.5)")

# Run the difficulty-level analysis.
print("执行难度级别分析...")
level_analysis = analyze_by_difficulty_level(df, TARGET_TOKENS)

plot_difficulty_analysis(level_analysis)

analyze_difficulty_patterns(level_analysis)

# Per-level overview: text count plus mean frequency / probability across tokens.
print("\n=== 各级别数据概览 ===")
for level, stats in level_analysis.items():
    text_count = next(iter(stats.values()))['text_count']
    avg_freq = np.mean([data['frequency'] for data in stats.values()])
    avg_prob = np.mean([data['avg_probability'] for data in stats.values()])
    print(f"Level {level}: {text_count} 条文本, 平均token频率: {avg_freq:.3f}, 平均概率: {avg_prob:.3f}")

# 9. 正确性对比分析

分析答对和答错题目的logits差异，探索模型自信度与准确性的关系。注意：当前实现中，只有词频统计来自真实文本，概率值为模拟数据。

In [None]:
def analyze_by_correctness(df: pd.DataFrame, target_tokens: List[str]) -> Dict[bool, Dict[str, Dict[str, float]]]:
    """Split texts by ``is_correct`` and compute per-token statistics for each group.

    Returns ``{True/False: {token: {'frequency', 'avg_probability', 'text_count',
    'variance'}}}``. A group with no rows is omitted.

    'frequency' is the fraction of the group's texts (``df['full_text']``) that
    contain the token as a whole word, case-insensitively.

    NOTE: 'avg_probability' and 'variance' are SIMULATED, not measured from model
    logits. They are drawn uniformly from [0.03, 0.20) with numpy's global RNG and
    scaled by a hypothesised correctness multiplier.
    """
    import re

    correctness_analysis = {}

    for is_correct in [True, False]:
        subset = df[df['is_correct'] == is_correct]
        if len(subset) == 0:
            continue

        token_stats = {}

        for token in target_tokens:
            # Whole-word match, so e.g. 'so' is not counted inside 'solve'/'also'.
            pattern = rf"\b{re.escape(token)}\b"
            texts_with_token = subset['full_text'].str.contains(pattern, case=False, regex=True)
            frequency = texts_with_token.sum() / len(subset)

            # Simulated probability (see NOTE in the docstring).
            base_prob = np.random.uniform(0.03, 0.20)

            # Hypothesis: correct answers favour some reasoning words, and wrong
            # answers favour hesitation words.
            if is_correct and token in ['check', 'think', 'therefore', 'actually']:
                correctness_multiplier = 1.3
            elif not is_correct and token in ['wait', 'hmm', 'aha']:
                correctness_multiplier = 1.2
            else:
                correctness_multiplier = 1.0

            simulated_prob = base_prob * correctness_multiplier
            token_stats[token] = {
                'frequency': frequency,
                'avg_probability': simulated_prob,
                'text_count': len(subset),
                # Despite the key name this is a spread proxy (30% of the
                # probability); the key is kept for existing callers.
                'variance': simulated_prob * 0.3
            }

        correctness_analysis[is_correct] = token_stats

    return correctness_analysis

def plot_correctness_analysis(correctness_analysis: Dict[bool, Dict[str, Dict[str, float]]]):
    """Draw a 2x2 comparison of token statistics for correct vs. incorrect answers.

    Panels:
      1. grouped bars of token frequency;
      2. grouped bars of average probability;
      3. heatmap of the (correct - incorrect) differences for both metrics;
      4. frequency-vs-probability scatter with token annotations.

    Args:
        correctness_analysis: ``{True: {token: stats}, False: {token: stats}}``,
            where each stats dict has ``'frequency'`` and ``'avg_probability'``
            (e.g. the output of ``analyze_by_correctness``).

    If either group is missing, or the groups share no token, a message is
    printed and nothing is drawn.
    """
    # analyze_by_correctness drops a group with no rows; a one-sided comparison
    # cannot be drawn.
    if True not in correctness_analysis or False not in correctness_analysis:
        print("正确性分析缺少正确或错误答案组，跳过绘图")
        return

    correct_data = correctness_analysis[True]
    incorrect_data = correctness_analysis[False]

    # Plot exactly the tokens that were analysed (in analysis order), rather than
    # assuming the global TARGET_TOKENS matches the analysis input.
    tokens = [token for token in correct_data if token in incorrect_data]
    if not tokens:
        print("正确与错误答案组没有共同的token，跳过绘图")
        return

    fig, axes = plt.subplots(2, 2, figsize=(15, 12))

    correct_freq = [correct_data[token]['frequency'] for token in tokens]
    incorrect_freq = [incorrect_data[token]['frequency'] for token in tokens]
    correct_prob = [correct_data[token]['avg_probability'] for token in tokens]
    incorrect_prob = [incorrect_data[token]['avg_probability'] for token in tokens]

    x = np.arange(len(tokens))
    width = 0.35

    # 1. Frequency bars.
    axes[0, 0].bar(x - width/2, correct_freq, width, label='正确答案', color='green', alpha=0.7)
    axes[0, 0].bar(x + width/2, incorrect_freq, width, label='错误答案', color='red', alpha=0.7)
    axes[0, 0].set_xlabel('Token')
    axes[0, 0].set_ylabel('出现频率')
    axes[0, 0].set_title('Token出现频率：正确 vs 错误答案')
    axes[0, 0].set_xticks(x)
    axes[0, 0].set_xticklabels(tokens, rotation=45)
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)

    # 2. Probability bars.
    axes[0, 1].bar(x - width/2, correct_prob, width, label='正确答案', color='green', alpha=0.7)
    axes[0, 1].bar(x + width/2, incorrect_prob, width, label='错误答案', color='red', alpha=0.7)
    axes[0, 1].set_xlabel('Token')
    axes[0, 1].set_ylabel('平均概率')
    axes[0, 1].set_title('Token平均概率：正确 vs 错误答案')
    axes[0, 1].set_xticks(x)
    axes[0, 1].set_xticklabels(tokens, rotation=45)
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)

    # 3. Difference heatmap (positive = higher in correct answers).
    freq_diff = [c - i for c, i in zip(correct_freq, incorrect_freq)]
    prob_diff = [c - i for c, i in zip(correct_prob, incorrect_prob)]
    diff_data = np.array([freq_diff, prob_diff])

    sns.heatmap(diff_data,
                xticklabels=tokens,
                yticklabels=['频率差异', '概率差异'],
                annot=True, fmt='.3f', cmap='RdBu_r', center=0, ax=axes[1, 0])
    axes[1, 0].set_title('正确 - 错误答案的差异\n(正值表示正确答案中更高)')
    axes[1, 0].tick_params(axis='x', rotation=45)

    # 4. Frequency vs. probability scatter.
    for data, label, color in [(correct_data, '正确答案', 'green'), (incorrect_data, '错误答案', 'red')]:
        freqs = [data[token]['frequency'] for token in tokens]
        probs = [data[token]['avg_probability'] for token in tokens]

        axes[1, 1].scatter(freqs, probs, label=label, color=color, alpha=0.7, s=60)

        for token, freq, prob in zip(tokens, freqs, probs):
            axes[1, 1].annotate(token, (freq, prob),
                                xytext=(3, 3), textcoords='offset points',
                                fontsize=8, alpha=0.8)

    axes[1, 1].set_xlabel('出现频率')
    axes[1, 1].set_ylabel('平均概率')
    axes[1, 1].set_title('Token频率 vs 概率分布')
    axes[1, 1].legend()
    axes[1, 1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

def correctness_statistical_analysis(correctness_analysis: Dict[bool, Dict[str, Dict[str, float]]]):
    """Print a per-token comparison of correct vs. incorrect answer statistics.

    For every token present in both groups, print the frequency and
    average-probability differences (correct minus incorrect). Differences above
    fixed thresholds (|frequency| > 0.1, |probability| > 0.05) are flagged as
    "significant". These are heuristic cut-offs, not statistical tests. The
    direction of each flagged difference is reported from that metric's own
    sign. The report ends with group-level averages.

    Args:
        correctness_analysis: ``{True: {token: stats}, False: {token: stats}}``,
            where stats holds ``'frequency'``, ``'avg_probability'`` and
            ``'text_count'`` (e.g. the output of ``analyze_by_correctness``).
    """
    print("=== 正确性统计分析 ===")

    if True not in correctness_analysis or False not in correctness_analysis:
        print("缺少正确或错误答案组，无法比较")
        return

    correct_data = correctness_analysis[True]
    incorrect_data = correctness_analysis[False]
    tokens = [token for token in correct_data if token in incorrect_data]
    if not tokens:
        print("没有可比较的token")
        return

    print(f"正确答案数量: {correct_data[tokens[0]]['text_count']}")
    print(f"错误答案数量: {incorrect_data[tokens[0]]['text_count']}")

    freq_threshold = 0.1
    prob_threshold = 0.05
    significant_tokens = []

    print("\n各Token在正确性上的差异:")
    for token in tokens:
        freq_diff = correct_data[token]['frequency'] - incorrect_data[token]['frequency']
        prob_diff = correct_data[token]['avg_probability'] - incorrect_data[token]['avg_probability']
        freq_significant = abs(freq_diff) > freq_threshold
        prob_significant = abs(prob_diff) > prob_threshold

        print(f"\n{token}:")
        print(f"  频率差异: {freq_diff:+.3f} ({'显著' if freq_significant else '不显著'})")
        print(f"  概率差异: {prob_diff:+.3f} ({'显著' if prob_significant else '不显著'})")

        if freq_significant or prob_significant:
            significant_tokens.append(token)
        # Each verdict follows its own difference, so a token flagged only for its
        # probability is not labelled "more common" by the sign of freq_diff.
        if freq_significant:
            group = '正确' if freq_diff > 0 else '错误'
            print(f"  -> {token} 在{group}答案中更常见")
        if prob_significant:
            group = '正确' if prob_diff > 0 else '错误'
            print(f"  -> {token} 在{group}答案中的平均概率更高")

    print(f"\n显著差异的Token: {significant_tokens}")

    # Group-level averages over the compared tokens.
    correct_avg_freq = np.mean([correct_data[token]['frequency'] for token in tokens])
    incorrect_avg_freq = np.mean([incorrect_data[token]['frequency'] for token in tokens])
    correct_avg_prob = np.mean([correct_data[token]['avg_probability'] for token in tokens])
    incorrect_avg_prob = np.mean([incorrect_data[token]['avg_probability'] for token in tokens])

    print("\n整体模式:")
    print(f"  正确答案平均token频率: {correct_avg_freq:.3f}")
    print(f"  错误答案平均token频率: {incorrect_avg_freq:.3f}")
    print(f"  正确答案平均token概率: {correct_avg_prob:.3f}")
    print(f"  错误答案平均token概率: {incorrect_avg_prob:.3f}")

# Correctness comparison: compute the per-group statistics once and feed the
# same result to both the figure and the printed report.
print("执行正确性对比分析...")
correctness_analysis = analyze_by_correctness(df, TARGET_TOKENS)
plot_correctness_analysis(correctness_analysis)          # figures
correctness_statistical_analysis(correctness_analysis)   # text report

# 10. 可视化热力图实现

创建 logits 热力图，并突出显示关键词所在的位置。注意：当前热力图使用随机模拟的 logits 数据进行演示。

In [None]:
def create_advanced_heatmap(logits_data: Dict[str, np.ndarray],
                            tokens: List[str],
                            target_positions: Dict[str, List[int]],
                            title: str = "Logits Heatmap",
                            simulate: bool = True) -> None:
    """Plot one (vocabulary dimension x position) logits heatmap per model.

    For each model, show the ``top_k`` (at most 100) vocabulary columns with the
    highest mean logit across positions. Every valid position listed in
    ``target_positions`` gets a dashed red line and a label with its token.

    Args:
        logits_data: Mapping model name -> logits. With ``simulate=True`` (the
            default), only the keys are used and a random
            ``(len(tokens), 1000)`` matrix is generated per model, boosted at the
            target positions. With ``simulate=False``, each value is plotted
            as given and must be a 2-D array whose first dimension equals
            ``len(tokens)``.
        tokens: Token strings, one per sequence position (used for x tick labels).
        target_positions: Mapping token -> positions to highlight. Positions
            outside ``[0, len(tokens))`` are ignored.
        title: Title prefix for each subplot.
        simulate: Whether to replace ``logits_data`` values with simulated logits.

    Raises:
        ValueError: If ``simulate=False`` and an array is not 2-D, has no
            vocabulary columns, or its length differs from ``len(tokens)``.
    """
    seq_len = len(tokens)
    if seq_len == 0 or not logits_data:
        # plt.subplots(0, 1) and imshow on a zero-width matrix both fail.
        print("create_advanced_heatmap: 没有可绘制的数据")
        return

    vocab_size = 1000  # width of the simulated logits matrix

    model_logits = {}
    for model_name, provided in logits_data.items():
        if not simulate:
            logits = np.asarray(provided, dtype=float)
            if logits.ndim != 2 or logits.shape[0] != seq_len or logits.shape[1] == 0:
                raise ValueError(
                    f"logits for {model_name!r} must have shape "
                    f"(len(tokens)={seq_len}, vocab>0), got {logits.shape}"
                )
            model_logits[model_name] = logits
            continue

        # Simulated logits: scaled Gaussian noise, plus a boost on 50 random
        # vocabulary columns at every highlighted position.
        logits = np.random.randn(seq_len, vocab_size) * 2
        for positions in target_positions.values():
            for pos in positions:
                if 0 <= pos < seq_len:
                    boosted = np.random.choice(vocab_size, 50, replace=False)
                    logits[pos, boosted] += np.random.uniform(2, 5, 50)
        model_logits[model_name] = logits

    n_models = len(model_logits)
    # squeeze=False: always an (n_models, 1) array of Axes, even for one model.
    fig, axes = plt.subplots(n_models, 1, figsize=(20, 6 * n_models), squeeze=False)

    # X tick labels are the same for every model: about 15 truncated tokens.
    step = max(1, seq_len // 15)
    tick_positions = list(range(0, seq_len, step))
    tick_labels = [tokens[p][:8] + '...' if len(tokens[p]) > 8 else tokens[p]
                   for p in tick_positions]

    for ax, (model_name, logits) in zip(axes[:, 0], model_logits.items()):
        # Keep the top_k vocabulary columns with the highest mean logit.
        top_k = min(100, logits.shape[1])
        top_indices = np.argsort(logits.mean(axis=0))[-top_k:]
        heatmap_data = logits[:, top_indices].T  # rows: vocab dims, cols: positions

        im = ax.imshow(heatmap_data, cmap='viridis', aspect='auto', interpolation='nearest')

        # Mark the target-token positions.
        for target_token, positions in target_positions.items():
            for pos in positions:
                if 0 <= pos < seq_len:
                    ax.axvline(x=pos, color='red', linestyle='--', linewidth=2, alpha=0.8)
                    ax.text(pos, top_k // 2, target_token,
                            rotation=90, verticalalignment='center',
                            color='white', fontweight='bold', fontsize=10,
                            bbox=dict(boxstyle='round,pad=0.2', facecolor='red', alpha=0.7))

        ax.set_title(f'{title} - {model_name.upper()}', fontsize=14, fontweight='bold')
        ax.set_xlabel('Token Position')
        ax.set_ylabel(f'Top {top_k} Vocabulary Dimensions')
        ax.set_xticks(tick_positions)
        ax.set_xticklabels(tick_labels, rotation=45, ha='right')

        cbar = plt.colorbar(im, ax=ax, shrink=0.8)
        cbar.set_label('Logits Value', rotation=270, labelpad=20)

    plt.tight_layout()
    plt.show()

def create_interactive_heatmap(model_results: Dict[str, Dict], target_tokens: List[str]) -> go.Figure:
    """Build a Plotly figure with one simulated token-probability heatmap per model.

    Each model occupies one row. Its z-values are random: a uniform [0, 1)
    background of shape ``(len(target_tokens), 50)`` in which every token row
    receives 1-3 distinct randomly placed cells with values in [0.7, 1.0).
    Only the keys of ``model_results`` are used (row order and subplot titles).

    Returns:
        The assembled figure; it is not shown.
    """
    model_names = list(model_results)
    row_count = len(model_names)
    position_axis = list(range(50))  # simulated sequence length

    figure = make_subplots(
        rows=row_count,
        cols=1,
        subplot_titles=[name.upper() + " Model" for name in model_names],
        vertical_spacing=0.1,
    )

    for row, _name in enumerate(model_names, start=1):
        grid = np.random.rand(len(target_tokens), len(position_axis))
        for token_row in range(len(target_tokens)):
            hit_count = np.random.randint(1, 4)
            for column in np.random.choice(len(position_axis), hit_count, replace=False):
                grid[token_row, column] = np.random.uniform(0.7, 1.0)

        figure.add_trace(
            go.Heatmap(
                z=grid,
                x=position_axis,
                y=target_tokens,
                colorscale='Viridis',
                showscale=row == 1,  # a single colour bar is enough
                hoverongaps=False,
                hovertemplate='Position: %{x}<br>Token: %{y}<br>Probability: %{z:.3f}<extra></extra>',
            ),
            row=row,
            col=1,
        )

    figure.update_layout(
        title="Interactive Logits Heatmap Comparison",
        height=300 * row_count,
        showlegend=False,
    )

    for row in range(1, row_count + 1):
        figure.update_yaxes(title_text="Target Tokens", row=row, col=1)
        figure.update_xaxes(title_text="Token Position", row=row, col=1)

    return figure

def create_difference_heatmap(results_dict: Dict[str, Dict]) -> None:
    """Plot, for every unordered model pair, a one-column heatmap of probability differences.

    For models listed in order (m1, m2, ...), each pair (earlier, later) gets one
    panel showing ``later - earlier`` of
    ``results_dict[model]['summary_stats']['average_probabilities'][token]`` for
    every token in the global ``TARGET_TOKENS``. The colour scale is fixed to
    [-0.1, 0.1], and every cell is annotated with its value. With fewer than two
    models, nothing is drawn.
    """
    names = list(results_dict.keys())
    pairs = [(earlier, later)
             for idx, earlier in enumerate(names)
             for later in names[idx + 1:]]
    if not pairs:
        return

    _, axes = plt.subplots(1, len(pairs), figsize=(6 * len(pairs), 5))
    axes_list = [axes] if len(pairs) == 1 else list(axes)

    for ax, (earlier, later) in zip(axes_list, pairs):
        base = results_dict[earlier]['summary_stats']['average_probabilities']
        other = results_dict[later]['summary_stats']['average_probabilities']
        deltas = [other[tok] - base[tok] for tok in TARGET_TOKENS]

        # A single column: one row per token.
        image = ax.imshow(np.array(deltas).reshape(-1, 1), cmap='RdBu_r',
                          aspect='auto', vmin=-0.1, vmax=0.1)

        ax.set_title(f'{later.upper()} - {earlier.upper()}')
        ax.set_yticks(range(len(TARGET_TOKENS)))
        ax.set_yticklabels(TARGET_TOKENS)
        ax.set_xticks([])

        # White text on strongly coloured cells, black elsewhere.
        for row, delta in enumerate(deltas):
            text_colour = 'white' if abs(delta) > 0.05 else 'black'
            ax.text(0, row, f'{delta:.3f}',
                    ha='center', va='center', color=text_colour, fontweight='bold')

        colour_bar = plt.colorbar(image, ax=ax, shrink=0.6)
        colour_bar.set_label('Probability Difference', rotation=270, labelpad=15)

    plt.tight_layout()
    plt.show()

# Demonstrate the heatmap helpers on the first sample text.
print("=== 创建高级可视化热力图 ===")

sample_text = df.iloc[0]['full_text']
demo_tokens = demo_tokenize(sample_text)

# Positions of each target token in the demo token list. Matching ignores case
# and surrounding whitespace, consistent with the lower-cased matching used in
# the correctness analysis, so "Wait" or " so" are highlighted too.
# NOTE(review): assumes demo_tokenize returns plain strings — confirm.
demo_positions = {
    token: [i for i, t in enumerate(demo_tokens) if t.strip().lower() == token]
    for token in TARGET_TOKENS
}

# Placeholder logits, one matrix per model.
mock_logits_data = {
    'baseline': np.random.randn(len(demo_tokens), 1000),
    'sft': np.random.randn(len(demo_tokens), 1000),
    'rl': np.random.randn(len(demo_tokens), 1000)
}

# Static per-model heatmaps.
create_advanced_heatmap(mock_logits_data, demo_tokens, demo_positions,
                        "Advanced Logits Analysis Heatmap")

# Interactive heatmap.
print("\n创建交互式热力图...")
interactive_fig = create_interactive_heatmap(mock_results, TARGET_TOKENS)
interactive_fig.show()

# Pairwise model-difference heatmaps.
print("\n创建模型差异热力图...")
create_difference_heatmap(mock_results)

# 11. Logits变化折线图

展示关键词前后 logits 的变化趋势。注意：当前曲线由随机模拟数据生成，仅用于演示可视化效果。

In [None]:
def plot_logits_trends(model_results: Dict[str, Dict], target_tokens: List[str],
                       context_window: int = 10) -> None:
    """Plot SIMULATED logits trajectories around each target token.

    Draws a grid with one row per target token and one column per model. Each
    panel shows 5 random trajectories over the relative positions
    ``-context_window .. +context_window``. A trajectory consists of:

    - Gaussian noise;
    - a boost at offset 0 and a smaller, distance-weighted boost within 2
      positions;
    - a constant shift for ('sft', 'think'/'check') and
      ('rl', 'so'/'therefore').

    Only the keys of ``model_results`` are used; no real logits are read.

    Args:
        model_results: Mapping model name -> results (values are unused).
        target_tokens: Tokens to plot, one row each.
        context_window: Number of positions shown on each side of the token.
    """
    n_tokens = len(target_tokens)
    n_models = len(model_results)
    if n_tokens == 0 or n_models == 0:
        print("plot_logits_trends: 没有可绘制的token或模型")
        return

    # squeeze=False always yields a 2-D array of Axes. Without it, a 1x1 grid
    # comes back as a bare Axes object that cannot be indexed as axes[0, 0].
    fig, axes = plt.subplots(n_tokens, n_models,
                             figsize=(5 * n_models, 4 * n_tokens), squeeze=False)

    colors = ['blue', 'red', 'green', 'orange', 'purple']
    n_samples = 5  # simulated trajectories per panel, one colour each
    sequence_positions = range(-context_window, context_window + 1)
    peak_position = 0  # the target token sits at relative offset 0

    for token_idx, target_token in enumerate(target_tokens):
        for model_idx, model_name in enumerate(model_results):
            ax = axes[token_idx, model_idx]

            for sample_idx in range(n_samples):
                base_logits = np.random.randn(len(sequence_positions)) * 0.5

                for i, pos in enumerate(sequence_positions):
                    if pos == peak_position:
                        base_logits[i] += np.random.uniform(1, 3)
                    elif abs(pos - peak_position) <= 2:
                        # Linearly decaying boost next to the target token.
                        distance_factor = 1 - abs(pos - peak_position) / 3
                        base_logits[i] += np.random.uniform(0, 1) * distance_factor

                # Model-specific simulated shifts.
                if model_name == 'sft' and target_token in ['think', 'check']:
                    base_logits += 0.3
                elif model_name == 'rl' and target_token in ['so', 'therefore']:
                    base_logits += 0.4

                # The first trajectory is emphasised; the others are faint.
                alpha = 0.7 if sample_idx == 0 else 0.3
                linewidth = 2 if sample_idx == 0 else 1
                label = f'{model_name}_{target_token}' if sample_idx == 0 else None

                ax.plot(sequence_positions, base_logits,
                        color=colors[sample_idx], alpha=alpha,
                        linewidth=linewidth, label=label)

            ax.axvline(x=0, color='red', linestyle='--', linewidth=2, alpha=0.8)
            ax.axhline(y=0, color='gray', linestyle='-', linewidth=0.5, alpha=0.5)

            ax.set_title(f'{target_token} - {model_name.upper()}')
            ax.set_xlabel('Relative Position to Target Token')
            ax.set_ylabel('Logits Value')
            ax.grid(True, alpha=0.3)

            # Shade the +/-2 context around the target token.
            ax.axvspan(-2, 2, alpha=0.1, color='yellow',
                       label='Target Context' if token_idx == 0 and model_idx == 0 else "")

    plt.tight_layout()
    plt.show()

def create_interactive_trend_plot(model_results: Dict[str, Dict], target_tokens: List[str]) -> go.Figure:
    """Build a Plotly line chart of SIMULATED logits trends around target tokens.

    For every (model, token) pair, one trace over relative positions -10..10 is
    built from:

    - Gaussian noise scaled by 0.3;
    - a random boost in [1, 2) at offset 0;
    - a context bump of 0.6/0.4/0.2 at distance 0/1/2 from offset 0;
    - a constant shift of 0.2 for 'sft' and 0.3 for 'rl'.

    Only the keys of ``model_results`` are used.

    Returns:
        The assembled figure; it is not shown.
    """
    half_width = 10
    offsets = list(range(-half_width, half_width + 1))
    centre = len(offsets) // 2  # index of offset 0

    palette = {'baseline': 'blue', 'sft': 'red', 'rl': 'green'}
    model_shift = {'sft': 0.2, 'rl': 0.3}

    # Loop-invariant context bump: (3 - distance) * 0.2 within 3 positions of the centre.
    distance = np.abs(np.arange(len(offsets)) - centre)
    near_centre = distance <= 3
    context_bump = (3 - distance[near_centre]) * 0.2

    chart = go.Figure()

    for model_name in model_results:
        for target_token in target_tokens:
            trend = np.random.randn(len(offsets)) * 0.3
            trend[centre] += np.random.uniform(1, 2)
            trend[near_centre] += context_bump
            if model_name in model_shift:
                trend += model_shift[model_name]

            chart.add_trace(go.Scatter(
                x=offsets,
                y=trend,
                mode='lines+markers',
                name=f'{model_name}_{target_token}',
                line=dict(color=palette.get(model_name, 'black'), width=2),
                marker=dict(size=4),
                hovertemplate=f'Model: {model_name}<br>Token: {target_token}<br>Position: %{{x}}<br>Logits: %{{y:.3f}}<extra></extra>'
            ))

    # Vertical marker at the target position.
    chart.add_vline(x=0, line_dash="dash", line_color="red",
                    annotation_text="Target Token Position")

    chart.update_layout(
        title="Interactive Logits Trends Around Target Tokens",
        xaxis_title="Relative Position to Target Token",
        yaxis_title="Logits Value",
        hovermode='x unified',
        height=600
    )

    return chart

def analyze_logits_patterns(model_results: Dict[str, Dict],
                            target_tokens: "List[str] | None" = None) -> Dict[str, Any]:
    """Derive per-token pattern metrics from each model's summary statistics.

    For every model and token, using the model's
    ``summary_stats['average_probabilities' | 'std_probabilities' | 'token_frequencies']``:

    - ``confidence``: avg / (std + 1e-6);
    - ``prominence``: avg * frequency;
    - ``variability``: std / (avg + 1e-6);
    - ``frequency_normalized``: frequency divided by the sum of ALL values in
      ``token_frequencies``, or 0.0 when that sum is zero.

    Args:
        model_results: ``{model: {'summary_stats': {...}}}`` as described above.
        target_tokens: Tokens to analyse; defaults to the global ``TARGET_TOKENS``.

    Returns:
        ``{model: {token: {metric: value}}}``.
    """
    tokens = TARGET_TOKENS if target_tokens is None else target_tokens

    pattern_analysis = {}

    for model_name, results in model_results.items():
        summary = results['summary_stats']
        avg_probs = summary['average_probabilities']
        std_probs = summary['std_probabilities']
        frequencies = summary['token_frequencies']
        # The normaliser is the same for every token, so compute it once. It may be
        # zero when no token occurred at all; avoid dividing by it in that case.
        total_frequency = sum(frequencies.values())

        model_patterns = {}
        for token in tokens:
            avg_prob = avg_probs[token]
            std_prob = std_probs[token]
            frequency = frequencies[token]

            model_patterns[token] = {
                'confidence': avg_prob / (std_prob + 1e-6),   # epsilon avoids /0
                'prominence': avg_prob * frequency,
                'variability': std_prob / (avg_prob + 1e-6),
                'frequency_normalized': frequency / total_frequency if total_frequency else 0.0,
            }

        pattern_analysis[model_name] = model_patterns

    return pattern_analysis

# Logits-trend analysis on the mock results.
print("=== Logits变化趋势分析 ===")

# Static trend grid, limited to the first 4 tokens to keep the figure readable.
plot_logits_trends(mock_results, TARGET_TOKENS[:4])

# Interactive trend chart (first 3 tokens).
print("\n创建交互式趋势图...")
interactive_trend_fig = create_interactive_trend_plot(mock_results, TARGET_TOKENS[:3])
interactive_trend_fig.show()

# Per-token pattern metrics.
pattern_analysis = analyze_logits_patterns(mock_results)

print("\n=== Logits模式分析结果 ===")
for model_name, patterns in pattern_analysis.items():
    print(f"\n{model_name.upper()} 模型模式:")

    # Most stable tokens: highest confidence (mean / std).
    sorted_patterns = sorted(patterns.items(), key=lambda x: x[1]['confidence'], reverse=True)

    print("  Top 3 最稳定的tokens (高confidence):")
    for token, metrics in sorted_patterns[:3]:
        print(f"    {token}: confidence={metrics['confidence']:.2f}, prominence={metrics['prominence']:.3f}")

    # Most prominent tokens: highest probability x frequency.
    sorted_by_prominence = sorted(patterns.items(), key=lambda x: x[1]['prominence'], reverse=True)
    print("  Top 3 最突出的tokens (高prominence):")
    for token, metrics in sorted_by_prominence[:3]:
        print(f"    {token}: prominence={metrics['prominence']:.3f}, variability={metrics['variability']:.2f}")

# Pattern comparison figure
def plot_pattern_comparison(pattern_analysis: Dict[str, Dict]) -> None:
    """Draw grouped bar charts of the four pattern metrics for every model.

    One subplot per metric ('confidence', 'prominence', 'variability',
    'frequency_normalized'). Within each token group, one bar per model is
    drawn, centred on the token's tick whatever the number of models.

    Args:
        pattern_analysis: ``{model: {token: {metric: value}}}``, e.g. the output
            of ``analyze_logits_patterns``. Tokens are taken from the first
            model, keeping only those present for every model.
    """
    models = list(pattern_analysis.keys())
    if not models:
        print("没有可绘制的模式数据")
        return

    # Plot the analysed tokens rather than assuming they equal TARGET_TOKENS.
    tokens = [token for token in pattern_analysis[models[0]]
              if all(token in pattern_analysis[m] for m in models)]
    if not tokens:
        print("没有可绘制的模式数据")
        return

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    metrics = ['confidence', 'prominence', 'variability', 'frequency_normalized']
    metric_titles = ['Confidence Score', 'Prominence Score', 'Variability', 'Normalized Frequency']

    # Split a fixed group width among the models and centre each group on its
    # tick, so bars neither overlap nor drift off-centre when n_models != 3.
    group_width = 0.8
    bar_width = group_width / len(models)
    base_x = np.arange(len(tokens))

    for i, (metric, title) in enumerate(zip(metrics, metric_titles)):
        ax = axes[i // 2, i % 2]

        for j, model in enumerate(models):
            values = [pattern_analysis[model][token][metric] for token in tokens]
            offset = (j - (len(models) - 1) / 2) * bar_width
            ax.bar(base_x + offset, values, width=bar_width, label=model, alpha=0.8)

        ax.set_title(title)
        ax.set_xlabel('Tokens')
        ax.set_ylabel(title)
        ax.set_xticks(base_x)
        ax.set_xticklabels(tokens, rotation=45)
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

# Draw the per-metric model comparison.
print("\n绘制模式对比图...")
plot_pattern_comparison(pattern_analysis)

# 12. 统计分析与模式发现

进行统计检验，量化不同条件（模型之间、难度级别之间、答题正确与否之间）下 logits 差异的显著性，并总结发现的模式。

In [None]:
from scipy import stats
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans

def comprehensive_statistical_analysis(mock_results: Dict, level_analysis: Dict,
                                       correctness_analysis: Dict,
                                       target_tokens: "List[str] | None" = None) -> Dict[str, Any]:
    """Run and print significance tests across models, difficulty levels and correctness.

    Sections (printed in this order):
      1. Model vs. model: paired t-test and Wilcoxon signed-rank test on the
         per-token average probabilities (paired by token), plus a standardised
         mean difference.
      2. Difficulty level: Pearson and Spearman correlation between the level
         and each token's frequency / average probability.
      3. Correctness: chi-square test on token presence counts (using the real
         group sizes) and a t-test on SIMULATED probability samples.

    Args:
        mock_results: ``{model: {'summary_stats': {'average_probabilities': {token: p}}}}``.
        level_analysis: ``{level: {token: {'frequency', 'avg_probability'}}}``.
            Levels are passed to ``scipy.stats.pearsonr``, so they must be numeric.
        correctness_analysis: ``{True/False: {token: {'frequency',
            'avg_probability', 'text_count'}}}``.
        target_tokens: Tokens to analyse; defaults to the global ``TARGET_TOKENS``.

    Returns:
        Dict with 'model_comparisons', 'level_correlations' and
        'correctness_effects' filled in. 'principal_components' and
        'clustering_results' are returned empty.
    """
    tokens = list(TARGET_TOKENS if target_tokens is None else target_tokens)

    # Dict literal values are evaluated in order, so sections print 1 -> 2 -> 3.
    return {
        'model_comparisons': _compare_models(mock_results, tokens),
        'level_correlations': _level_correlations(level_analysis, tokens),
        'correctness_effects': _correctness_effects(correctness_analysis, tokens),
        'principal_components': {},
        'clustering_results': {},
    }


def _compare_models(mock_results: Dict, tokens: List[str]) -> Dict[str, Any]:
    """Section 1: pairwise tests on per-token average probabilities between models."""
    print("=== 1. 模型间差异显著性检验 ===")
    comparisons = {}
    models = list(mock_results.keys())

    for i, model1 in enumerate(models):
        for model2 in models[i + 1:]:
            avg1 = mock_results[model1]['summary_stats']['average_probabilities']
            avg2 = mock_results[model2]['summary_stats']['average_probabilities']
            # Pair the values by token name, not by dict iteration order, so the
            # paired tests compare the same token across the two models.
            probs1 = [avg1[token] for token in tokens]
            probs2 = [avg2[token] for token in tokens]

            t_stat, p_value = stats.ttest_rel(probs1, probs2)

            try:
                w_stat, w_p_value = stats.wilcoxon(probs1, probs2)
            except ValueError:
                # Depending on the SciPy version, wilcoxon raises when every
                # paired difference is zero (or the input is too small). Record
                # "no test" instead of aborting the whole analysis.
                w_stat, w_p_value = float('nan'), float('nan')

            pooled_std = np.std(probs1 + probs2)  # list concatenation: all values
            effect_size = (np.mean(probs2) - np.mean(probs1)) / pooled_std if pooled_std > 0 else 0.0

            comparisons[f"{model1}_vs_{model2}"] = {
                't_test': {'statistic': t_stat, 'p_value': p_value},
                'wilcoxon': {'statistic': w_stat, 'p_value': w_p_value},
                'effect_size': effect_size,
                'significant': p_value < 0.05
            }

            print(f"{model1} vs {model2}:")
            print(f"  t-test: t={t_stat:.3f}, p={p_value:.3f}")
            print(f"  Wilcoxon: W={w_stat:.3f}, p={w_p_value:.3f}")
            print(f"  Effect size: {effect_size:.3f}")
            print(f"  Significant: {'Yes' if p_value < 0.05 else 'No'}\n")

    return comparisons


def _level_correlations(level_analysis: Dict, tokens: List[str]) -> Dict[str, Any]:
    """Section 2: correlation between difficulty level and each token's statistics."""
    print("=== 2. 难度级别相关性分析 ===")
    correlations = {}
    levels = sorted(level_analysis.keys())
    if len(levels) < 2:
        # pearsonr / spearmanr need at least two data points.
        print("难度级别少于2个，跳过相关性分析")
        return correlations

    for token in tokens:
        # A token missing from a level counts as 0 for that level.
        frequencies = []
        probabilities = []
        for level in levels:
            level_stats = level_analysis[level]
            if token in level_stats:
                frequencies.append(level_stats[token]['frequency'])
                probabilities.append(level_stats[token]['avg_probability'])
            else:
                frequencies.append(0)
                probabilities.append(0)

        freq_corr, freq_p = stats.pearsonr(levels, frequencies)
        prob_corr, prob_p = stats.pearsonr(levels, probabilities)
        freq_spearman, freq_sp = stats.spearmanr(levels, frequencies)
        prob_spearman, prob_sp = stats.spearmanr(levels, probabilities)

        correlations[token] = {
            'frequency_pearson': {'correlation': freq_corr, 'p_value': freq_p},
            'probability_pearson': {'correlation': prob_corr, 'p_value': prob_p},
            'frequency_spearman': {'correlation': freq_spearman, 'p_value': freq_sp},
            'probability_spearman': {'correlation': prob_spearman, 'p_value': prob_sp}
        }

        print(f"{token}:")
        print(f"  频率-难度相关性: r={freq_corr:.3f} (p={freq_p:.3f})")
        print(f"  概率-难度相关性: r={prob_corr:.3f} (p={prob_p:.3f})")

    return correlations


def _correctness_effects(correctness_analysis: Dict, tokens: List[str]) -> Dict[str, Any]:
    """Section 3: frequency (chi-square) and simulated-probability (t-test) effects of correctness."""
    print("\n=== 3. 正确性效应分析 ===")
    effects = {}
    if True not in correctness_analysis or False not in correctness_analysis:
        print("缺少正确或错误答案组，跳过正确性效应分析")
        return effects

    correct_data = correctness_analysis[True]
    incorrect_data = correctness_analysis[False]

    for token in tokens:
        correct = correct_data[token]
        incorrect = incorrect_data[token]
        correct_freq = correct['frequency']
        incorrect_freq = incorrect['frequency']
        correct_prob = correct['avg_probability']
        incorrect_prob = incorrect['avg_probability']

        # Contingency table of (texts containing token, texts lacking it) per group,
        # scaled by the real group sizes. 100 is only a fallback pseudo-count for a
        # group that carries no 'text_count'.
        n_correct = correct.get('text_count', 100)
        n_incorrect = incorrect.get('text_count', 100)
        obs_freq = np.array([[correct_freq * n_correct, (1 - correct_freq) * n_correct],
                             [incorrect_freq * n_incorrect, (1 - incorrect_freq) * n_incorrect]])
        if (obs_freq.sum(axis=0) > 0).all() and (obs_freq.sum(axis=1) > 0).all():
            chi2, chi2_p = stats.chi2_contingency(obs_freq)[:2]
        else:
            # A token absent from (or present in) every text yields an all-zero
            # column, which chi2_contingency rejects. There is no difference to test.
            chi2, chi2_p = 0.0, 1.0

        # t-test on SIMULATED samples around each group's average probability.
        n_samples = 50
        correct_samples = np.random.normal(correct_prob, correct_prob * 0.3, n_samples)
        incorrect_samples = np.random.normal(incorrect_prob, incorrect_prob * 0.3, n_samples)
        t_stat, t_p = stats.ttest_ind(correct_samples, incorrect_samples)

        effects[token] = {
            'frequency_chi2': {'statistic': chi2, 'p_value': chi2_p},
            'probability_ttest': {'statistic': t_stat, 'p_value': t_p},
            'frequency_effect_size': abs(correct_freq - incorrect_freq),
            'probability_effect_size': abs(correct_prob - incorrect_prob)
        }

        print(f"{token}:")
        print(f"  频率差异: χ²={chi2:.3f} (p={chi2_p:.3f})")
        print(f"  概率差异: t={t_stat:.3f} (p={t_p:.3f})")

    return effects

def _feature_order(mock_results: Dict, stat_key: str) -> List[str]:
    """Return the keys of the first model's ``summary_stats[stat_key]`` in dict order.

    Every model's feature vector is built by indexing its dict with this single
    key list, so matrix columns stay aligned across models (and with plot
    labels) even if the per-model dicts were filled in different orders.
    Callers must ensure ``mock_results`` is non-empty.
    """
    first_results = next(iter(mock_results.values()))
    return list(first_results['summary_stats'][stat_key])


def _to_json_compatible(obj: Any) -> Any:
    """``json.dump`` fallback that converts NumPy scalars/arrays to native Python values.

    Statistics computed with NumPy/SciPy may be NumPy scalar types (e.g. ``np.bool_``
    from comparing a NumPy float, or ``np.int64``) that the ``json`` module rejects.
    Any other unknown type is still reported as unserialisable.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def dimensionality_reduction_analysis(mock_results: Dict) -> Dict[str, Any]:
    """Project each model's per-token average probabilities onto 2 principal components.

    Plots the models in PC1/PC2 space and the absolute PC1 loading of each token.

    Args:
        mock_results: mapping model name -> results dict; each results dict must
            provide ``results['summary_stats']['average_probabilities']`` as a
            mapping token -> average probability (same tokens for every model).

    Returns:
        Dict with ``pca_result`` (one 2-D row per model, in ``mock_results``
        order), ``explained_variance_ratio``, ``components`` and
        ``feature_importance`` (absolute PC1 loadings, in token order).

    Raises:
        ValueError: if fewer than two models are supplied.
        KeyError: if a model lacks a token present in the first model's stats.
    """
    print("\n=== 4. 主成分分析 (PCA) ===")

    if len(mock_results) < 2:
        raise ValueError(
            f"PCA with 2 components needs results for at least 2 models, got {len(mock_results)}"
        )

    # One shared token order for matrix columns AND axis labels below.
    tokens = _feature_order(mock_results, 'average_probabilities')
    labels = list(mock_results.keys())
    data_matrix = np.array([
        [results['summary_stats']['average_probabilities'][token] for token in tokens]
        for results in mock_results.values()
    ])

    # Run PCA
    pca = PCA(n_components=2)
    pca_result = pca.fit_transform(data_matrix)

    print(f"主成分解释的方差比例: {pca.explained_variance_ratio_}")
    print(f"累积解释方差: {np.sum(pca.explained_variance_ratio_):.3f}")

    # Plot PCA result
    plt.figure(figsize=(10, 6))

    plt.subplot(1, 2, 1)
    base_colors = ['red', 'blue', 'green']
    for i, model in enumerate(labels):
        # Keep red/blue/green for the first three models; further models use
        # matplotlib's default colour cycle ("C<n>") instead of being dropped.
        color = base_colors[i] if i < len(base_colors) else f'C{i}'
        plt.scatter(pca_result[i, 0], pca_result[i, 1], c=color, s=100, label=model, alpha=0.8)
        plt.annotate(model, (pca_result[i, 0], pca_result[i, 1]),
                     xytext=(5, 5), textcoords='offset points')

    plt.xlabel(f'PC1 ({pca.explained_variance_ratio_[0]:.2%} variance)')
    plt.ylabel(f'PC2 ({pca.explained_variance_ratio_[1]:.2%} variance)')
    plt.title('PCA: Model Comparison in Token Probability Space')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Feature importance: absolute loadings on the first principal component,
    # sorted from largest to smallest.
    plt.subplot(1, 2, 2)
    feature_importance = np.abs(pca.components_[0])
    indices = np.argsort(feature_importance)[::-1]

    plt.bar(range(len(tokens)), feature_importance[indices])
    plt.xlabel('Token Features')
    plt.ylabel('PC1 Loading (Absolute Value)')
    plt.title('Feature Importance in First Principal Component')
    plt.xticks(range(len(tokens)), [tokens[i] for i in indices], rotation=45)
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    return {
        'pca_result': pca_result,
        'explained_variance_ratio': pca.explained_variance_ratio_,
        'components': pca.components_,
        'feature_importance': feature_importance
    }


def clustering_analysis(mock_results: Dict) -> Dict[str, Any]:
    """Cluster models with K-means (k=2) on standardised per-token statistics.

    The feature vector of each model is its average probabilities, followed by
    its probability standard deviations, followed by its token frequencies.

    Args:
        mock_results: mapping model name -> results dict; each results dict must
            provide ``summary_stats`` with ``average_probabilities``,
            ``std_probabilities`` and ``token_frequencies`` (token -> value).

    Returns:
        Dict with ``cluster_labels`` (one label per model, in ``mock_results``
        order), ``silhouette_score`` (NaN when the score is undefined) and
        ``cluster_centers`` (in standardised feature space).

    Raises:
        ValueError: if fewer than two models are supplied.
    """
    from sklearn.preprocessing import StandardScaler
    from sklearn.metrics import silhouette_score

    print("\n=== 5. 聚类分析 ===\n")

    if len(mock_results) < 2:
        raise ValueError(
            f"K-means with 2 clusters needs results for at least 2 models, got {len(mock_results)}"
        )

    # Combined features: probabilities + standard deviations + frequencies.
    # Each statistic is read in one shared key order so columns line up across models.
    feature_keys = ['average_probabilities', 'std_probabilities', 'token_frequencies']
    key_orders = {key: _feature_order(mock_results, key) for key in feature_keys}

    all_labels = list(mock_results.keys())
    all_data = []
    for results in mock_results.values():
        summary = results['summary_stats']
        row = []
        for key in feature_keys:
            row.extend(summary[key][token] for token in key_orders[key])
        all_data.append(row)
    all_data = np.array(all_data)

    # Standardise each feature column
    scaler = StandardScaler()
    scaled_data = scaler.fit_transform(all_data)

    # K-means clustering
    kmeans = KMeans(n_clusters=2, random_state=42)
    cluster_labels = kmeans.fit_predict(scaled_data)

    print("聚类结果:")
    for model, cluster in zip(all_labels, cluster_labels):
        print(f"  {model}: Cluster {cluster}")

    # silhouette_score is only defined for 2 <= n_labels <= n_samples - 1; with
    # exactly two models (or a degenerate single-cluster fit) report NaN instead
    # of raising.
    n_found = len(set(cluster_labels.tolist()))
    if 2 <= n_found <= len(cluster_labels) - 1:
        silhouette_avg = silhouette_score(scaled_data, cluster_labels)
    else:
        silhouette_avg = float('nan')
    print(f"\n轮廓系数 (Silhouette Score): {silhouette_avg:.3f}")

    return {
        'cluster_labels': cluster_labels,
        'silhouette_score': silhouette_avg,
        'cluster_centers': kmeans.cluster_centers_
    }


def generate_research_summary(analysis_results: Dict, pattern_analysis: Dict) -> None:
    """Print a human-readable summary of the statistical and pattern analyses.

    Args:
        analysis_results: dict with ``model_comparisons`` (name -> dict with
            ``significant`` and ``effect_size``), ``level_correlations`` and
            ``correctness_effects`` as produced by ``comprehensive_statistical_analysis``.
        pattern_analysis: mapping model -> token -> metrics dict containing ``confidence``.
    """
    print("\n" + "=" * 60)
    print("                  研究发现总结")
    print("=" * 60)

    # 1. Main differences between models
    print("\n1. 模型间主要差异:")
    significant_comparisons = []
    for comparison, results in analysis_results['model_comparisons'].items():
        if results['significant']:
            significant_comparisons.append(comparison)
            effect_size = results['effect_size']
            direction = "增强" if effect_size > 0 else "减弱"
            print(f"   • {comparison.replace('_vs_', ' → ')}: {direction} (效应量: {abs(effect_size):.3f})")

    if not significant_comparisons:
        print("   • 未发现模型间显著差异")

    # 2. Key findings
    print("\n2. 关键发现:")

    # Most stable tokens: highest confidence averaged over all models
    all_confidence = {}
    for patterns in pattern_analysis.values():
        for token, metrics in patterns.items():
            all_confidence.setdefault(token, []).append(metrics['confidence'])

    avg_confidence = {token: np.mean(confs) for token, confs in all_confidence.items()}
    top_stable_tokens = sorted(avg_confidence.items(), key=lambda x: x[1], reverse=True)[:3]

    print("   • 最稳定的推理词汇:")
    for token, conf in top_stable_tokens:
        print(f"     - {token} (平均confidence: {conf:.2f})")

    # Difficulty-level effect: strong (|r| > 0.5) and significant (p < 0.05) frequency correlations
    print("\n   • 难度级别效应:")
    strong_correlations = []
    for token, corrs in analysis_results['level_correlations'].items():
        freq_corr = corrs['frequency_pearson']['correlation']
        if abs(freq_corr) > 0.5 and corrs['frequency_pearson']['p_value'] < 0.05:
            direction = "正相关" if freq_corr > 0 else "负相关"
            strong_correlations.append(f"{token} ({direction}, r={freq_corr:.2f})")

    if strong_correlations:
        for corr in strong_correlations:
            print(f"     - {corr}")
    else:
        print("     - 未发现显著的难度级别效应")

    # Correctness effect.
    # NOTE(review): comprehensive_statistical_analysis computes these t-test
    # p-values from np.random.normal samples simulated around the mean
    # probabilities, not from observed per-sample data, so "significant" here
    # is not evidence of a real effect.
    print("\n   • 正确性效应:")
    significant_correctness = []
    for token, effects in analysis_results['correctness_effects'].items():
        if effects['probability_ttest']['p_value'] < 0.05:
            effect_size = effects['probability_effect_size']
            significant_correctness.append(f"{token} (效应量: {effect_size:.3f})")

    if significant_correctness:
        for effect in significant_correctness:
            print(f"     - {effect}")
    else:
        print("     - 未发现显著的正确性效应")

    # 3. Practical recommendations
    print("\n3. 实践建议:")
    print("   • 模型改进方向:")
    if significant_comparisons:
        print("     - 重点关注显著差异的token，分析训练策略")
        print("     - 考虑增强表现较弱模型的推理能力")

    print("   • 进一步研究:")
    print("     - 增加更多样本以提高统计功效")
    print("     - 分析更多上下文窗口大小的影响")
    print("     - 研究token组合的交互效应")

    print("\n" + "=" * 60)


# Run the full analysis
print("开始执行综合统计分析...")

# Statistical analysis
analysis_results = comprehensive_statistical_analysis(mock_results, level_analysis, correctness_analysis)

# Dimensionality reduction
pca_results = dimensionality_reduction_analysis(mock_results)

# Clustering
clustering_results = clustering_analysis(mock_results)

# Research summary
generate_research_summary(analysis_results, pattern_analysis)

# Save results
print("\n保存分析结果...")
final_results = {
    'statistical_analysis': analysis_results,
    'pca_analysis': {
        'explained_variance_ratio': pca_results['explained_variance_ratio'].tolist(),
        'feature_importance': pca_results['feature_importance'].tolist()
    },
    'clustering_analysis': {
        'silhouette_score': clustering_results['silhouette_score'],
        'cluster_labels': clustering_results['cluster_labels'].tolist()
    },
    'pattern_analysis': pattern_analysis
}

# Write to JSON; create the output directory first so open() cannot fail on a
# fresh checkout, and convert NumPy scalars that json cannot encode natively.
output_path = 'results/comprehensive_analysis_results.json'
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, 'w', encoding='utf-8') as f:
    json.dump(final_results, f, ensure_ascii=False, indent=2, default=_to_json_compatible)

print(f"分析完成！结果已保存至 {output_path}")
print("\n=== Notebook演示完成 ===")
print("在实际应用中，请将模拟数据替换为真实的模型logits数据。")