# Environment Setup & Model Loading

## System Check

In [None]:
# Check the Python/PyTorch versions, GPU availability and memory status.
import platform

import torch
import psutil

print("=== 系统信息 ===")
# Report the interpreter version and the torch version separately (the
# "Python版本" label previously printed torch.__version__).
print(f"Python版本: {platform.python_version()}")
print(f"PyTorch版本: {torch.__version__}")
print(f"CUDA可用: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU设备: {torch.cuda.get_device_name(0)}")
    print(f"GPU内存: {torch.cuda.get_device_properties(0).total_memory // 1024**3} GB")

# Take one snapshot so that "total" and "available" are consistent with each other.
vm = psutil.virtual_memory()
print(f"系统RAM: {vm.total // 1024**3} GB")
print(f"可用RAM: {vm.available // 1024**3} GB")

# Release unused memory held by PyTorch's CUDA caching allocator (only meaningful with a GPU).
if torch.cuda.is_available():
    torch.cuda.empty_cache()

## Package Installation

In [None]:
# 安装核心库 (移除bitsandbytes，因为不需要量化)
!pip install --quiet transformers>=4.40.0
!pip install --quiet torch>=2.0.0
!pip install --quiet accelerate
!pip install --quiet plotly
!pip install --quiet numpy pandas matplotlib seaborn
!pip install --quiet tqdm

# 重启运行时（运行完这个cell后，在菜单栏选择"运行时" -> "重启运行时"）
print("安装完成！请重启运行时然后继续下一步。")

## Restart Reminder
⚠️ Important: After installation, please select "Runtime" → "Restart session" from the menu bar, then continue running the following cells.

## Load Gemma 2 2B

In [None]:
# Load Gemma 2 2B in FP16
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import gc

print("开始加载Gemma 2 2B模型 (FP16精度)...")

# Load the tokenizer and the model.
try:
    print("正在加载tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")

    print("正在加载模型（FP16精度，这可能需要几分钟）...")
    # Gemma 2 is implemented in transformers itself (>= 4.42), so trust_remote_code
    # is not needed; enabling it would allow Python code from the model repository
    # to execute locally.
    # NOTE(review): Gemma 2 checkpoints are published in bfloat16 and fp16 inference
    # may overflow to inf/NaN logits; consider torch.bfloat16 on GPUs that support it
    # (later cells label results "FP16", so update those too) -- verify before changing.
    model = AutoModelForCausalLM.from_pretrained(
        "google/gemma-2-2b",
        device_map="auto",  # place the weights on the available device(s)
        torch_dtype=torch.float16,  # load the weights in FP16
        low_cpu_mem_usage=True,  # reduce peak CPU RAM while loading
    )

    print(f"✅ 模型成功加载到设备: {next(model.parameters()).device}")
    print(f"模型数据类型: {next(model.parameters()).dtype}")
    print(f"模型参数数量: {sum(p.numel() for p in model.parameters()):,}")

    # Basic architecture facts used by later cells.
    print(f"模型层数: {model.config.num_hidden_layers}")
    print(f"隐藏层维度: {model.config.hidden_size}")
    print(f"词汇表大小: {model.config.vocab_size}")

    # Current GPU memory usage.
    if torch.cuda.is_available():
        print(f"GPU内存使用: {torch.cuda.memory_allocated() // 1024**2} MB")
        print(f"GPU内存缓存: {torch.cuda.memory_reserved() // 1024**2} MB")

except Exception as e:
    print(f"❌ 模型加载失败: {e}")
    print("请检查网络连接或尝试重新运行此cell")
    print("如果内存不足，请考虑使用更小的模型或者启用量化")
    # Re-raise: every later cell needs `model` and `tokenizer` and would otherwise
    # fail further down with a less informative NameError.
    raise

## Model Test

In [None]:
# Smoke-test the loaded model: sample a short continuation, then inspect the
# next-token distribution for the same prompt.
test_prompt = "The capital of France is"
print(f"测试提示: '{test_prompt}'")

# Encode the prompt and move the tensors to the model's device.
inputs = tokenizer(test_prompt, return_tensors="pt").to(model.device)
print(f"输入token数量: {len(inputs['input_ids'][0])}")
print(f"输入tokens: {tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])}")

# Sample a continuation. **inputs passes attention_mask together with input_ids,
# so generate() does not have to infer the mask.
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=10,
        temperature=0.7,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id
    )

generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f"生成结果: '{generated_text}'")


def safe_entropy_calculation(logits):
    """Return the entropy of softmax(logits) as a Python float, or None on failure.

    The logits are upcast to float32 first so that the softmax and the sum over
    the vocabulary are not computed in half precision. Returns None (after
    printing a diagnostic) if the logits, the probabilities or the entropy
    contain NaN/inf, or if any exception is raised.
    """
    try:
        logits = logits.float()

        # Reject non-finite logits up front.
        if torch.isnan(logits).any() or torch.isinf(logits).any():
            print("❌ Logits包含异常值")
            return None

        # log_softmax is numerically safer than log(softmax(...)).
        log_probs = torch.log_softmax(logits, dim=-1)
        probs = torch.softmax(logits, dim=-1)

        if torch.isnan(probs).any() or torch.isinf(probs).any():
            print("❌ 概率计算异常")
            return None

        entropy = -torch.sum(probs * log_probs)

        if torch.isnan(entropy) or torch.isinf(entropy):
            print("❌ 熵计算结果异常")
            return None

        return entropy.item()

    except Exception as e:
        print(f"❌ 熵计算错误: {e}")
        return None


# Inspect the next-token distribution at the last prompt position.
with torch.no_grad():
    model_outputs = model(**inputs)
    # Upcast to float32 so the softmax over the full vocabulary is computed accurately.
    logits = model_outputs.logits[0, -1]
    logits = logits.float()
    probs = torch.softmax(logits, dim=-1)

    # Top-5 next-token predictions.
    top_probs, top_indices = torch.topk(probs, 5)
    print("\n=== Top 5 预测 ===")
    for rank, (prob, token_id) in enumerate(zip(top_probs.tolist(), top_indices.tolist()), start=1):
        token = tokenizer.decode([token_id])
        print(f"{rank}. '{token}' - 概率: {prob:.4f}")

    entropy = safe_entropy_calculation(logits)
    if entropy is not None:
        print(f"\n✅ 预测熵（不确定性）: {entropy:.4f}")
    else:
        print(f"\n❌ 熵计算失败")
        # Fall back to a simpler confidence measure: the top-1 probability.
        top1_prob = probs.max().item()
        print(f"替代度量 - Top-1概率: {top1_prob:.4f} (越高越确定)")
        print(f"不确定性估计: {1-top1_prob:.4f}")

print("\n✅ 模型测试完成，准备开始分析不确定性神经元！")

# Neuron Analysis

## Test Dataset

In [None]:
# Test dataset for the uncertainty-neuron experiments.
# Prompts are grouped by the kind of uncertainty they are meant to elicit. Each
# category maps to a dict with two display strings ("description",
# "expected_behavior", printed by the summary helper) and a list of prompts
# ("sentences") whose next-token distribution is analysed.

uncertainty_test_dataset = {

    # 1. Epistemic Uncertainty
    # Uncertainty caused by missing knowledge in the model; in principle reducible with more training data.
    # NOTE(review): several prompts below ask for widely known facts (the IUPAC name of
    # water, the half-life of Carbon-14, the director of Metropolis, element 47) and may
    # yield low rather than high entropy -- confirm they behave as intended.
    "epistemic": {
        "description": "模型知识不足导致的不确定性",
        "expected_behavior": "高不确定性，模型不知道答案",
        "sentences": [
            # Obscure factual knowledge
            "The population of Nauru in 2024 is",
            "The CEO of startup company Zephyr Labs is",
            "The atomic weight of Flerovium is",
            "The mayor of Vaduz, Liechtenstein is",
            "The 47th element on the periodic table is",
            "The director of the 1927 film Metropolis was",
            "The winner of the 1952 Nobel Prize in Chemistry was",
            "The height of Mount Vinson in Antarctica is",
            # Technical/specialized knowledge
            "The Hausdorff dimension of the Sierpinski triangle is",
            "The IUPAC name for water is",
            "The half-life of Carbon-14 is",
            "The speed of sound in helium at 20°C is"
        ]
    },

    # 2. Aleatoric Uncertainty
    # Ambiguity inherent in the input itself; not resolvable even by a perfect model.
    "aleatoric": {
        "description": "输入固有的模糊性和多义性",
        "expected_behavior": "中等到高不确定性，多个合理答案",
        "sentences": [
            # Subjective/opinion-based completions
            "The best programming language is",
            "The most important thing in life is",
            "The greatest movie of all time is",
            "The most beautiful color is",
            "The meaning of happiness is",
            "Success is defined as",
            # Open-ended continuations with multiple valid paths
            "She told him that",
            "They decided to go",
            "The reason for this is",
            "After thinking about it,",
            "The story ended when",
            # Lexical ambiguity (word-level multiple meanings)
            "The bank is",
            "The bat flew",
            "She couldn't bear",
            "The seal swam",
            "The light solution is"
        ]
    },

    # 3. Linguistic Uncertainty
    # Uncertainty arising from sentence structure or syntax.
    # NOTE(review): these prompts are complete sentences, so the next token is predicted
    # *after* the ambiguous sentence ends; the measured entropy may reflect "what follows a
    # finished sentence" more than the parse ambiguity itself -- worth checking.
    "linguistic": {
        "description": "语言结构和语法导致的不确定性",
        "expected_behavior": "结构性不确定性，语法解析困难",
        "sentences": [
            # PP-attachment ambiguity
            "The man saw the boy with the telescope",
            "She hit the man with the umbrella",
            "They discussed the plan in the office",
            "I saw the Grand Canyon flying to New York",
            # Syntactic ambiguity
            "Flying planes can be dangerous",
            "They are hunting dogs",
            "Visiting relatives can be boring",
            "The shooting of the hunters was terrible",
            # Garden path sentences
            "The horse raced past the barn fell",
            "The old man the boats",
            "The complex houses married and single soldiers",
            "The prime number few",
            # Coordination ambiguity
            "Old men and women were served first",
            "I saw her duck and cover"
        ]
    },

    # 4. Low Uncertainty Controls
    # High-certainty prompts used as a baseline for comparison.
    "low_uncertainty": {
        "description": "高确定性句子，作为基线对照",
        "expected_behavior": "低不确定性，明确的预期答案",
        "sentences": [
            # Basic facts
            "The capital of USA is",
            "Two plus two equals",
            "The sun rises in the",
            "The first letter of the alphabet is",
            "Christmas is celebrated on December",
            # Common knowledge
            "The color of grass is",
            "The Earth orbits the",
            "The opposite of hot is",
            "One meter equals one hundred"
        ]
    },
}

def print_dataset_summary(dataset=None):
    """Print a per-category overview of a test dataset and return its size.

    For each category, this prints its name, description, expected behaviour
    and the first two sentences. It ends with the total sentence count and a
    rough runtime estimate.

    Args:
        dataset: Mapping of category name -> dict with "description",
            "expected_behavior" and "sentences" keys. Defaults to the
            module-level ``uncertainty_test_dataset``.

    Returns:
        Total number of sentences across all categories.
    """
    if dataset is None:
        dataset = uncertainty_test_dataset

    total_sentences = 0
    print("=== 不确定性测试数据集摘要 ===\n")

    for category, data in dataset.items():
        sentences = data["sentences"]
        num_sentences = len(sentences)
        total_sentences += num_sentences

        print(f"📊 {category.upper().replace('_', ' ')} ({num_sentences} 句子)")
        print(f"   描述: {data['description']}")
        print(f"   预期行为: {data['expected_behavior']}")
        print("   示例句子:")
        for sentence in sentences[:2]:  # preview only the first two sentences
            print(f"     • '{sentence}'")
        if num_sentences > 2:
            print(f"     ... 还有 {num_sentences - 2} 个句子")
        print()

    print(f"📈 总计: {total_sentences} 个测试句子")
    # Rough runtime estimate: 0.5-1 minute per sentence.
    print(f"⏱️  预计实验时间: {total_sentences * 0.5:.1f}-{total_sentences * 1:.1f} 分钟")
    return total_sentences

def get_test_sentences_by_category(category=None, limit_per_category=None, dataset=None):
    """
    Return ``(sentence, category)`` pairs from the test dataset.

    Args:
        category: Name of a single category to select, or None for every
            category (in the dataset's insertion order).
        limit_per_category: Maximum number of sentences taken from each
            category (the first N); None means no limit. 0 yields no sentences.
        dataset: Category mapping to read from; defaults to the module-level
            ``uncertainty_test_dataset``.

    Returns:
        A list of ``(sentence, category)`` tuples.

    Raises:
        ValueError: If ``category`` is given but is not a key of the dataset,
            so a misspelt category is not mistaken for "all categories".
    """
    if dataset is None:
        dataset = uncertainty_test_dataset

    if category is None:
        selected = list(dataset)
    elif category in dataset:
        selected = [category]
    else:
        raise ValueError(
            f"Unknown category {category!r}; expected one of {sorted(dataset)}"
        )

    pairs = []
    for cat in selected:
        sentences = dataset[cat]["sentences"]
        if limit_per_category is not None:
            sentences = sentences[:limit_per_category]
        pairs.extend((sentence, cat) for sentence in sentences)
    return pairs

def get_recommended_test_set(quick_test=True):
    """
    Return the recommended list of ``(sentence, category)`` test pairs.

    Prints a one-line banner describing the chosen set first.

    Args:
        quick_test: If True, take at most the first 4 sentences of every
            category (a quick run); if False, return every sentence.
    """
    limit = 4 if quick_test else None
    banner = (
        "🚀 推荐：快速测试集 (每类3-4个句子，总计约15个)"
        if quick_test
        else "🔬 推荐：完整测试集 (所有句子)"
    )
    print(banner)
    return get_test_sentences_by_category(limit_per_category=limit)

# Print the dataset overview; the returned sentence total is kept as total_count.
total_count = print_dataset_summary()

# Usage hints for the dataset helper functions.
print("\n".join((
    "💡 使用建议:",
    "   • 第一次实验：使用 get_recommended_test_set(quick_test=True)",
    "   • 详细分析：使用 get_recommended_test_set(quick_test=False)",
    "   • 特定分析：使用 get_test_sentences_by_category('epistemic')",
)))
print("\n✅ 测试数据集准备完成！")

## Weight Extraction

In [None]:
print("开始提取最后一层神经元权重...")

# Index of the model's last decoder layer.
last_layer_idx = model.config.num_hidden_layers - 1
print(f"分析第{last_layer_idx}层（最后一层）")

# Extract the last layer's MLP weights and the unembedding matrix.
# On success this defines the notebook globals W_out_cpu and unembed_cpu used by
# later cells; on failure they stay undefined and later cells raise NameError.
try:
    # Gemma 2 module layout: model.model.layers[i].mlp.{gate,up,down}_proj
    last_layer = model.model.layers[last_layer_idx]

    # MLP projection weights. nn.Linear stores weight as (out_features, in_features).
    mlp_gate_proj = last_layer.mlp.gate_proj.weight.data  # (intermediate_size, hidden_size)
    mlp_up_proj = last_layer.mlp.up_proj.weight.data      # (intermediate_size, hidden_size)
    mlp_down_proj = last_layer.mlp.down_proj.weight.data  # (hidden_size, intermediate_size)

    # Unembedding matrix (maps hidden states to vocabulary logits).
    unembed_matrix = model.lm_head.weight.data  # (vocab_size, hidden_size)

    print(f"MLP gate projection形状: {mlp_gate_proj.shape}")
    print(f"MLP up projection形状: {mlp_up_proj.shape}")
    print(f"MLP down projection形状: {mlp_down_proj.shape}")
    print(f"Unembedding矩阵形状: {unembed_matrix.shape}")

    # Per-neuron output weights: down_proj maps the intermediate (neuron) activations
    # back to the hidden dimension, so row i of its transpose is neuron i's output vector.
    # NOTE(review): LogitVar later uses W_out @ W_U directly, ignoring any normalisation
    # applied between this MLP's output and lm_head (e.g. a post-MLP norm or the model's
    # final norm); if present, consider folding their weights in -- confirm against the
    # Gemma 2 module structure.
    W_out = mlp_down_proj.T  # transposed to (intermediate_size, hidden_size)
    print(f"输出权重矩阵形状: {W_out.shape}")
    print(f"设备: {W_out.device}")
    print(f"数据类型: {W_out.dtype}")

    # Move to CPU for analysis, keeping the loaded dtype (FP16 here).
    W_out_cpu = W_out.cpu()  # dtype unchanged
    unembed_cpu = unembed_matrix.cpu()  # dtype unchanged

    print("✅ 权重提取完成！")

except Exception as e:
    print(f"❌ 权重提取失败: {e}")
    print("模型结构可能与预期不同，让我们检查实际结构...")

    # Print the MLP-related modules to help debug the actual model structure.
    print("\n=== 模型结构检查 ===")
    for name, module in model.named_modules():
        if 'layer' in name and 'mlp' in name:
            print(f"{name}: {type(module)}")
            if hasattr(module, 'weight'):
                print(f"  权重形状: {module.weight.shape}")

## LogitVar Calculation

In [None]:
# Compute LogitVar and weight norms for every neuron.
print("计算LogitVar指标...")

def calculate_logit_var(neuron_weights, unembed_matrix, chunk_size=512):
    """
    Compute the LogitVar metric for each neuron.

    LogitVar(i) = Var(W_out^(i) @ W_U / ||W_out^(i) @ W_U||_2)

    Neurons are projected ``chunk_size`` rows at a time, so the full
    (num_neurons, vocab_size) float32 projection matrix is never held in memory
    at once. The result is identical to projecting all rows together.

    Args:
        neuron_weights: (num_neurons, hidden_size) matrix; row i is neuron i's
            output weight vector.
        unembed_matrix: (vocab_size, hidden_size) unembedding matrix.
        chunk_size: Number of neurons projected per step (>= 1).

    Returns:
        (logit_vars, norms): float32 tensors of shape (num_neurons,) on
        ``neuron_weights.device``. ``norms`` holds ||W_out^(i) @ W_U||_2,
        clamped to >= 1e-10. ``logit_vars`` is NaN for neurons whose
        normalized projection contains non-finite values.

    Raises:
        ValueError: If ``chunk_size`` < 1.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    # Upcast once to float32 for accurate matmul and variance.
    unembed_t = unembed_matrix.float().T  # (hidden_size, vocab_size)
    num_neurons = neuron_weights.shape[0]
    device = neuron_weights.device

    logit_vars = torch.full((num_neurons,), float('nan'), device=device)
    norms = torch.empty(num_neurons, device=device)

    for start in range(0, num_neurons, chunk_size):
        stop = min(start + chunk_size, num_neurons)
        projections = neuron_weights[start:stop].float() @ unembed_t

        # L2 norm of each projection; the clamp avoids division by zero.
        chunk_norms = torch.clamp(torch.norm(projections, dim=1, keepdim=True), min=1e-10)
        normalized = projections / chunk_norms

        # Variance of each normalized projection; non-finite rows stay NaN.
        finite = torch.isfinite(normalized).all(dim=1)
        chunk_vars = torch.full((stop - start,), float('nan'), device=device)
        chunk_vars[finite] = torch.var(normalized[finite], dim=1)

        logit_vars[start:stop] = chunk_vars
        norms[start:stop] = chunk_norms.squeeze(1)

    return logit_vars, norms

# Run the LogitVar computation for every neuron of the last layer.
try:
    print(f"计算{W_out_cpu.shape[0]}个神经元的指标...")

    logit_vars, projection_norms = calculate_logit_var(W_out_cpu, unembed_cpu)

    # L2 norm of each neuron's output-weight vector (one row of W_out).
    weight_norms = torch.norm(W_out_cpu.float(), dim=1)

    print(f"LogitVar计算完成: {logit_vars.shape}")
    print(f"权重范数计算完成: {weight_norms.shape}")
    print(f"投影范数计算完成: {projection_norms.shape}")

    # calculate_logit_var reports NaN for neurons whose normalized projection is not
    # finite. A single NaN would turn mean/std/quantile below into NaN, and every
    # comparison against a NaN threshold is False (silently giving zero candidates),
    # so the LogitVar statistics and threshold use the finite values only.
    finite_logit_vars = logit_vars[torch.isfinite(logit_vars)]
    num_non_finite = logit_vars.numel() - finite_logit_vars.numel()
    if num_non_finite:
        print(f"⚠️ {num_non_finite} 个神经元的LogitVar不是有限值，已在统计中忽略")

    # Summary statistics.
    print(f"\n=== 统计摘要 ===")
    print(f"LogitVar - 均值: {finite_logit_vars.mean():.6f}, 标准差: {finite_logit_vars.std():.6f}")
    print(f"权重范数 - 均值: {weight_norms.mean():.6f}, 标准差: {weight_norms.std():.6f}")
    print(f"投影范数 - 均值: {projection_norms.mean():.6f}, 标准差: {projection_norms.std():.6f}")

    # Outliers: low LogitVar (bottom 10%) combined with a large weight norm (top 10%).
    logit_var_threshold = finite_logit_vars.quantile(0.1)
    weight_norm_threshold = weight_norms.quantile(0.9)

    # Candidate uncertainty neurons (a NaN LogitVar compares False and is never selected).
    uncertainty_candidates = (logit_vars < logit_var_threshold) & (weight_norms > weight_norm_threshold)
    num_candidates = uncertainty_candidates.sum().item()

    print(f"\n=== 候选不确定性神经元 ===")
    print(f"低LogitVar阈值: {logit_var_threshold:.6f}")
    print(f"高权重范数阈值: {weight_norm_threshold:.6f}")
    print(f"找到候选神经元: {num_candidates} 个")

    if num_candidates > 0:
        candidate_indices = torch.where(uncertainty_candidates)[0]
        print(f"候选神经元索引: {candidate_indices.tolist()}")

        # Details for the first 5 candidates (in neuron-index order).
        for idx in candidate_indices[:5]:
            print(f"  神经元 {idx.item()}: LogitVar={logit_vars[idx]:.6f}, 权重范数={weight_norms[idx]:.6f}")

    print("\n✅ LogitVar分析完成！")

except Exception as e:
    print(f"❌ 计算失败: {e}")
    import traceback
    traceback.print_exc()

## Visualization

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Plot LogitVar / weight-norm / projection-norm diagnostics for the last-layer neurons
# and pick the neurons (notebook global `top_candidates`) used by the causal experiment.
# Relies on globals from the LogitVar cell: weight_norms, logit_vars, projection_norms,
# uncertainty_candidates, num_candidates, logit_var_threshold, weight_norm_threshold.

# Set plotting style
plt.style.use('default')
sns.set_palette("husl")

# Create scatter plots
fig, axes = plt.subplots(2, 2, figsize=(15, 12))
fig.suptitle('Gemma 2 2B (FP16) Last Layer Neuron Analysis: Finding Uncertainty Neurons', fontsize=16, fontweight='bold')

# 1. Main scatter plot: Weight norm vs LogitVar
ax1 = axes[0, 0]
scatter = ax1.scatter(weight_norms, logit_vars, alpha=0.6, s=30, c='blue', edgecolors='none')
ax1.set_xlabel('Weight L2 Norm', fontsize=12)
ax1.set_ylabel('LogitVar', fontsize=12)
ax1.set_title('Weight Norm vs LogitVar\n(Bottom-right = Uncertainty Neuron Candidates)', fontsize=11)
ax1.grid(True, alpha=0.3)

# Mark candidate uncertainty neurons (candidate_y is reused by panel 4 below).
if num_candidates > 0:
    candidate_x = weight_norms[uncertainty_candidates]
    candidate_y = logit_vars[uncertainty_candidates]
    ax1.scatter(candidate_x, candidate_y, c='red', s=80, marker='o',
               edgecolors='black', linewidths=2, alpha=0.8, label=f'Candidates ({num_candidates} neurons)')
    ax1.legend()

# 2. LogitVar distribution histogram, with the 10th-percentile threshold marked
ax2 = axes[0, 1]
ax2.hist(logit_vars.numpy(), bins=50, alpha=0.7, color='green', edgecolor='black')
ax2.axvline(logit_var_threshold, color='red', linestyle='--', linewidth=2, label=f'10th percentile: {logit_var_threshold:.4f}')
ax2.set_xlabel('LogitVar', fontsize=12)
ax2.set_ylabel('Number of Neurons', fontsize=12)
ax2.set_title('LogitVar Distribution', fontsize=11)
ax2.legend()
ax2.grid(True, alpha=0.3)

# 3. Weight norm distribution histogram, with the 90th-percentile threshold marked
ax3 = axes[1, 0]
ax3.hist(weight_norms.numpy(), bins=50, alpha=0.7, color='orange', edgecolor='black')
ax3.axvline(weight_norm_threshold, color='red', linestyle='--', linewidth=2, label=f'90th percentile: {weight_norm_threshold:.4f}')
ax3.set_xlabel('Weight L2 Norm', fontsize=12)
ax3.set_ylabel('Number of Neurons', fontsize=12)
ax3.set_title('Weight Norm Distribution', fontsize=11)
ax3.legend()
ax3.grid(True, alpha=0.3)

# 4. Projection norm vs LogitVar
ax4 = axes[1, 1]
ax4.scatter(projection_norms, logit_vars, alpha=0.6, s=30, c='purple', edgecolors='none')
ax4.set_xlabel('Projection Norm (||W_out @ W_U||)', fontsize=12)
ax4.set_ylabel('LogitVar', fontsize=12)
ax4.set_title('Projection Norm vs LogitVar', fontsize=11)
ax4.grid(True, alpha=0.3)

if num_candidates > 0:
    candidate_proj = projection_norms[uncertainty_candidates]
    ax4.scatter(candidate_proj, candidate_y, c='red', s=80, marker='o',
               edgecolors='black', linewidths=2, alpha=0.8)

plt.tight_layout()
plt.show()

# Print detailed analysis results
print("=== Detailed Analysis Results ===")
print(f"Total number of neurons: {len(weight_norms)}")
print(f"Candidate uncertainty neurons: {num_candidates}")

if num_candidates > 0:
    print(f"\n=== Top {min(30, num_candidates)} Candidate Neuron Details ===")
    # NOTE(review): torch.where returns indices in ascending neuron-index order, so the
    # "top" candidates listed/selected here are the lowest-index ones, not the most
    # extreme; consider ranking (e.g. by LogitVar or weight norm) before slicing.
    candidate_indices = torch.where(uncertainty_candidates)[0]
    for i, idx in enumerate(candidate_indices[:min(30, num_candidates)]):
        idx_val = idx.item()
        print(f"Neuron {idx_val:4d}: "
              f"LogitVar={logit_vars[idx]:.4e}, "
              f"Weight norm={weight_norms[idx]:.4f}, "
              f"Projection norm={projection_norms[idx]:.4f}")

    # Save candidate neuron indices for subsequent analysis (at most 5)
    top_candidates = candidate_indices[:5] if num_candidates >= 5 else candidate_indices
    print(f"\nSelecting top {len(top_candidates)} neurons for further validation: {top_candidates.tolist()}")
else:
    print("No obvious candidate uncertainty neurons found")
    # Fall back to the neurons with the lowest LogitVar for analysis
    sorted_indices = torch.argsort(logit_vars)
    top_candidates = sorted_indices[:3]  # Top 3 neurons with lowest LogitVar
    print(f"Selecting 3 neurons with lowest LogitVar for analysis: {top_candidates.tolist()}")

print("\n✅ Visualization analysis completed! Next we will validate the causal effects of these candidate neurons.")

# Causal Inference Experiment

## Causal Verification Experiment Setup

In [None]:
# Causal validation: activation-patching experiment.
print("准备因果验证实验...")

def calculate_entropy(logits):
    """Return the Shannon entropy (in nats) of softmax(logits) over the last dimension.

    Half-precision inputs (float16/bfloat16) are upcast to float32 first, so that
    the softmax and the sum over a large vocabulary keep their accuracy.
    Positions whose logit is -inf (probability exactly 0) contribute 0, following
    the convention 0 * log 0 = 0, instead of producing NaN from 0 * -inf.

    Args:
        logits: Tensor of unnormalised scores; the last dimension indexes the
            vocabulary.

    Returns:
        Tensor of entropies with the last dimension reduced.
    """
    if logits.dtype in (torch.float16, torch.bfloat16):
        logits = logits.float()
    log_probs = torch.log_softmax(logits, dim=-1)
    probs = torch.softmax(logits, dim=-1)
    # Only -inf log-probabilities need masking; NaN from fully degenerate rows still propagates.
    plogp = torch.where(torch.isneginf(log_probs), torch.zeros_like(probs), probs * log_probs)
    entropy = -torch.sum(plogp, dim=-1)
    return entropy

def create_hook_fn(layer_idx, neuron_indices, intervention_type='zero'):
    """Build a hook that overwrites selected channels (last dimension) of a tensor.

    The returned callable can be registered in two ways:

    * ``module.register_forward_hook(hook)``: called as
      ``hook(module, input, output)``. The channels of ``output`` are modified
      in place and ``output`` is returned. On a Gemma 2 ``mlp`` module, the
      output's last dimension is the hidden (residual) dimension, not the MLP
      neuron dimension.
    * ``module.register_forward_pre_hook(hook)``: called as
      ``hook(module, args)``. A modified copy of ``args[0]`` is returned in a
      new args tuple. On ``mlp.down_proj`` this edits the intermediate
      activations, i.e. the neurons indexed by the rows of
      ``down_proj.weight.T``.

    Args:
        layer_idx: Not used by the hook; kept for interface compatibility.
        neuron_indices: Channel indices to overwrite. Indices that are not
            smaller than the size of the last dimension are skipped.
        intervention_type: 'zero' sets the channels to 0. 'mean' sets each
            channel to its mean over the first two dimensions (assumed to be
            batch and sequence).

    Returns:
        The hook callable.

    Raises:
        ValueError: If ``intervention_type`` is neither 'zero' nor 'mean'.
    """
    if intervention_type not in ('zero', 'mean'):
        raise ValueError(
            f"Unknown intervention_type: {intervention_type!r} (expected 'zero' or 'mean')"
        )

    def _patch(tensor):
        """Overwrite the selected channels of ``tensor`` in place and return it."""
        # Skip indices beyond the last dimension (same bounds rule as the per-index loop).
        channels = [int(i) for i in neuron_indices if int(i) < tensor.shape[-1]]
        if not channels:
            return tensor
        if intervention_type == 'zero':
            tensor[..., channels] = 0
        else:
            # Per-channel mean, computed before any channel is overwritten.
            channel_mean = tensor.mean(dim=[0, 1], keepdim=True)
            tensor[..., channels] = channel_mean[..., channels]
        return tensor

    def hook_fn(module, input, output=None):
        if output is None:
            # Forward pre-hook: patch a copy of the first positional argument.
            return (_patch(input[0].clone()),) + tuple(input[1:])
        # Forward hook: patch the module output in place.
        return _patch(output)

    return hook_fn

# Use the structured test dataset.
print("=== 选择测试句子 ===")

# Use the full test set.
test_data = get_recommended_test_set(quick_test=False)
test_sentences = [sentence for sentence, _ in test_data]  # sentences only
sentence_categories = [category for _, category in test_data]  # category of each sentence

print(f"选择了 {len(test_sentences)} 个测试句子，涵盖 {len(set(sentence_categories))} 种不确定性类型")
print("\n按类别显示测试句子:")

# Show the sentences grouped by category, in first-appearance order. Iterating a
# plain set() would make the order vary between interpreter runs, because str
# hashing is randomised.
for category in dict.fromkeys(sentence_categories):
    category_sentences = [sent for sent, cat in test_data if cat == category]
    print(f"\n📊 {category.upper().replace('_', ' ')} ({len(category_sentences)} 句子):")
    for rank, sentence in enumerate(category_sentences, start=1):
        print(f"   {rank}. '{sentence}'")

# Choose the neurons to test: up to 5 candidates from the visualization cell.
if 'top_candidates' in locals() and len(top_candidates) > 0:
    test_neurons = top_candidates[:5].tolist()
    print(f"\n将测试神经元: {test_neurons}")
else:
    # No candidates available: fall back to fixed example neuron indices as a control.
    test_neurons = [100, 200, 500]
    print(f"\n使用示例神经元进行测试: {test_neurons}")

print("\n✅ 验证实验设置完成！")

## Activation Patching Experiment

In [None]:
import torch.nn as nn
from contextlib import contextmanager

print("开始执行激活补丁实验...")


def _make_down_proj_pre_hook(neuron_indices, intervention_type):
    """Build a forward pre-hook for ``mlp.down_proj`` that overwrites MLP neurons.

    ``down_proj`` consumes the MLP's intermediate activations, with one channel
    per neuron. That is the same dimension indexed by the rows of
    ``W_out = down_proj.weight.T``, which were used to pick the candidate
    neurons. Patching here therefore edits the neurons themselves, rather than
    hidden (residual) dimensions of the MLP output.

    Args:
        neuron_indices: Neuron indices into the intermediate dimension.
        intervention_type: 'zero' sets the neurons to 0. 'mean' sets each
            neuron to its mean over dims 0 and 1 (assumed to be batch and
            sequence) of the current input.

    Raises:
        ValueError: If ``intervention_type`` is neither 'zero' nor 'mean'.
    """
    if intervention_type not in ('zero', 'mean'):
        raise ValueError(f"Unknown intervention_type: {intervention_type!r}")
    index = [int(i) for i in neuron_indices]

    def pre_hook(module, args):
        acts = args[0].clone()  # do not mutate the tensor produced upstream
        if intervention_type == 'zero':
            acts[..., index] = 0
        else:
            acts[..., index] = acts.mean(dim=(0, 1), keepdim=True)[..., index]
        return (acts,) + tuple(args[1:])

    return pre_hook


@contextmanager
def _forward_pre_hook(module, hook):
    """Register ``hook`` on ``module`` for the ``with`` body; it is always removed."""
    handle = module.register_forward_pre_hook(hook)
    try:
        yield
    finally:
        handle.remove()


def _last_token_prediction(model, tokenizer, inputs):
    """Return ``(entropy, top-3 decoded tokens)`` for the last position's next-token distribution."""
    with torch.no_grad():
        # Upcast to float32 so the entropy/softmax over the vocabulary is accurate.
        logits = model(**inputs).logits[0, -1].float()
    entropy = calculate_entropy(logits).item()
    top_indices = torch.topk(torch.softmax(logits, dim=-1), 3).indices
    top_tokens = [tokenizer.decode([token_id]).strip() for token_id in top_indices.tolist()]
    return entropy, top_tokens


def run_intervention_experiment(model, tokenizer, test_data, neuron_indices, layer_idx):
    """
    Run the neuron-intervention experiment on every test sentence.

    For each ``(sentence, category)`` pair, the next-token distribution at the
    last position is computed three times:

    * without intervention;
    * with the selected MLP neurons of ``layer_idx`` set to zero;
    * with those neurons set to their mean activation over the sentence's
      positions.

    Entropy and the top-3 decoded tokens are recorded for each run, and a
    per-sentence summary is printed.

    Args:
        model: Causal LM exposing ``model.model.layers[layer_idx].mlp.down_proj``.
        tokenizer: Tokenizer matching ``model``.
        test_data: Iterable of ``(sentence, category)`` pairs.
        neuron_indices: Indices into the MLP intermediate dimension (rows of
            ``down_proj.weight.T``). They are not silently skipped: an
            out-of-range index makes the patched forward pass raise.
        layer_idx: Index of the decoder layer to intervene on.

    Returns:
        Dict of equal-length lists with keys 'sentences', 'categories',
        'baseline_entropy', 'zero_entropy', 'mean_entropy',
        'entropy_change_zero', 'entropy_change_mean', 'baseline_top_tokens',
        'zero_top_tokens' and 'mean_top_tokens'.
    """
    results = {
        'sentences': [],
        'categories': [],
        'baseline_entropy': [],
        'zero_entropy': [],
        'mean_entropy': [],
        'entropy_change_zero': [],
        'entropy_change_mean': [],
        'baseline_top_tokens': [],
        'zero_top_tokens': [],
        'mean_top_tokens': []
    }

    # The down projection's input holds the per-neuron activations of this layer's MLP.
    down_proj = model.model.layers[layer_idx].mlp.down_proj
    zero_hook = _make_down_proj_pre_hook(neuron_indices, 'zero')
    mean_hook = _make_down_proj_pre_hook(neuron_indices, 'mean')

    for sentence, category in tqdm(test_data, desc="测试句子"):
        inputs = tokenizer(sentence, return_tensors="pt").to(model.device)

        # 1. Baseline prediction (no intervention).
        baseline_entropy, baseline_top_tokens = _last_token_prediction(model, tokenizer, inputs)

        # 2. Zero-ablation; the hook is removed even if the forward pass raises.
        with _forward_pre_hook(down_proj, zero_hook):
            zero_entropy, zero_top_tokens = _last_token_prediction(model, tokenizer, inputs)

        # 3. Mean-ablation.
        with _forward_pre_hook(down_proj, mean_hook):
            mean_entropy, mean_top_tokens = _last_token_prediction(model, tokenizer, inputs)

        # Entropy change relative to the baseline.
        entropy_change_zero = zero_entropy - baseline_entropy
        entropy_change_mean = mean_entropy - baseline_entropy

        row = {
            'sentences': sentence,
            'categories': category,
            'baseline_entropy': baseline_entropy,
            'zero_entropy': zero_entropy,
            'mean_entropy': mean_entropy,
            'entropy_change_zero': entropy_change_zero,
            'entropy_change_mean': entropy_change_mean,
            'baseline_top_tokens': baseline_top_tokens,
            'zero_top_tokens': zero_top_tokens,
            'mean_top_tokens': mean_top_tokens,
        }
        for key, value in row.items():
            results[key].append(value)

        # Show results as they are produced.
        print(f"\n--- [{category.upper()}] '{sentence}' ---")
        print(f"基线熵: {baseline_entropy:.4f}")
        print(f"零化熵: {zero_entropy:.4f} (Δ: {entropy_change_zero:+.4f})")
        print(f"均值熵: {mean_entropy:.4f} (Δ: {entropy_change_mean:+.4f})")
        print(f"Top-3预测: {' | '.join(baseline_top_tokens)}")

    return results

## Run Experiment & Analyze Results

In [None]:
# Execute complete uncertainty neuron validation experiment.
# Uses notebook globals from earlier cells: last_layer_idx (weight extraction),
# test_neurons / test_data / test_sentences (experiment setup), model and tokenizer.
print("🧪 Starting uncertainty neuron validation experiment\n")

# Set experiment parameters
LAYER_IDX = last_layer_idx  # Use the last layer
print(f"Target layer: Layer {LAYER_IDX}")
print(f"Test neurons: {test_neurons}")
print(f"Number of test sentences: {len(test_sentences)}")

# Execute experiment: baseline, zero- and mean-intervention for every test sentence.
experiment_results = run_intervention_experiment(
    model=model,
    tokenizer=tokenizer,
    test_data=test_data,
    neuron_indices=test_neurons,
    layer_idx=LAYER_IDX
)

print(f"\n✅ Experiment completed! Tested {len(experiment_results['sentences'])} sentences")

# === Results Analysis ===
def analyze_results_by_category(results):
    """Summarise intervention results per uncertainty category.

    For each category (in order of first appearance), prints:

    * the mean ± std of the baseline entropy and of both entropy changes;
    * the share of sentences where each intervention increased entropy.

    Args:
        results: Dict as returned by ``run_intervention_experiment``. It must
            contain equal-length lists under 'sentences', 'categories',
            'baseline_entropy', 'entropy_change_zero' and 'entropy_change_mean'.

    Returns:
        pandas.DataFrame with one row per sentence and the columns 'sentence',
        'category', 'baseline_entropy', 'entropy_change_zero' and
        'entropy_change_mean'.
    """
    import pandas as pd

    # Convert to a DataFrame for easier analysis.
    df = pd.DataFrame({
        'sentence': results['sentences'],
        'category': results['categories'],
        'baseline_entropy': results['baseline_entropy'],
        'entropy_change_zero': results['entropy_change_zero'],
        'entropy_change_mean': results['entropy_change_mean']
    })

    print("=== Results Analysis by Uncertainty Type ===\n")

    # sort=False keeps the groups in first-appearance order (like df['category'].unique()).
    for category, cat_data in df.groupby('category', sort=False):
        n_samples = len(cat_data)

        print(f"📊 {category.upper().replace('_', ' ')} ({n_samples} samples)")
        print(f"   Baseline entropy: {cat_data['baseline_entropy'].mean():.4f} ± {cat_data['baseline_entropy'].std():.4f}")
        print(f"   Zero intervention effect: {cat_data['entropy_change_zero'].mean():.4f} ± {cat_data['entropy_change_zero'].std():.4f}")
        print(f"   Mean intervention effect: {cat_data['entropy_change_mean'].mean():.4f} ± {cat_data['entropy_change_mean'].std():.4f}")

        # Effect direction: how often each intervention increased entropy.
        zero_positive = (cat_data['entropy_change_zero'] > 0).sum()
        mean_positive = (cat_data['entropy_change_mean'] > 0).sum()
        print(f"   Proportion with zero intervention increasing entropy: {zero_positive/n_samples:.1%}")
        print(f"   Proportion with mean intervention increasing entropy: {mean_positive/n_samples:.1%}")
        print()

    return df

# Execute analysis: print per-category statistics and keep the per-sentence DataFrame for plotting.
results_df = analyze_results_by_category(experiment_results)

# === Visualize Analysis Results ===
def plot_experiment_results(df):
    """Plot a 2x2 overview of the intervention experiment and show it.

    Panels:
    (1) baseline entropy per sample, coloured by category;
    (2) mean entropy change per category for both interventions;
    (3) baseline entropy vs zero-intervention entropy change;
    (4) histogram of the entropy changes.

    Args:
        df: DataFrame with the columns 'category', 'baseline_entropy',
            'entropy_change_zero' and 'entropy_change_mean' (as returned by
            ``analyze_results_by_category``).
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    fig.suptitle('Uncertainty Neuron Intervention Experiment Results Analysis (FP16)', fontsize=16, fontweight='bold')

    # Categories in order of first appearance, each with one colour shared by all panels.
    categories = df['category'].unique()
    colors = sns.color_palette("husl", len(categories))

    # 1. Baseline entropy distribution by category
    ax1 = axes[0, 0]
    for category, color in zip(categories, colors):
        cat_data = df[df['category'] == category]
        ax1.scatter(cat_data.index, cat_data['baseline_entropy'],
                    label=category.replace('_', ' '), alpha=0.7, s=60, color=color)

    ax1.set_xlabel('Sample Index')
    ax1.set_ylabel('Baseline Entropy')
    ax1.set_title('Baseline Uncertainty for Different Sentence Types')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # 2. Intervention effect comparison
    ax2 = axes[0, 1]
    # groupby() sorts its keys; reindex so bar i belongs to categories[i],
    # matching the tick labels set below.
    category_means = (
        df.groupby('category')[['entropy_change_zero', 'entropy_change_mean']]
        .mean()
        .reindex(categories)
    )
    x = np.arange(len(categories))
    width = 0.35

    ax2.bar(x - width/2, category_means['entropy_change_zero'], width,
            label='Zero Intervention', alpha=0.8, color='red')
    ax2.bar(x + width/2, category_means['entropy_change_mean'], width,
            label='Mean Intervention', alpha=0.8, color='blue')

    ax2.set_xlabel('Uncertainty Type')
    ax2.set_ylabel('Average Entropy Change')
    ax2.set_title('Comparison of Different Intervention Methods')
    ax2.set_xticks(x)
    ax2.set_xticklabels([cat.replace('_', '\n') for cat in categories])
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    ax2.axhline(y=0, color='black', linestyle='-', alpha=0.5)

    # 3. Scatter plot: baseline entropy vs intervention effect
    ax3 = axes[1, 0]
    for category, color in zip(categories, colors):
        cat_data = df[df['category'] == category]
        ax3.scatter(cat_data['baseline_entropy'], cat_data['entropy_change_zero'],
                    label=category.replace('_', ' '), alpha=0.7, s=60, color=color)

    ax3.set_xlabel('Baseline Entropy')
    ax3.set_ylabel('Zero Intervention Entropy Change')
    ax3.set_title('Baseline Uncertainty vs Intervention Effect')
    ax3.legend()
    ax3.grid(True, alpha=0.3)
    ax3.axhline(y=0, color='black', linestyle='-', alpha=0.5)

    # 4. Effect strength distribution
    ax4 = axes[1, 1]
    ax4.hist([df['entropy_change_zero'], df['entropy_change_mean']],
             bins=10, alpha=0.7, label=['Zero Intervention', 'Mean Intervention'], color=['red', 'blue'])

    ax4.set_xlabel('Entropy Change')
    ax4.set_ylabel('Frequency')
    ax4.set_title('Distribution of Intervention Effect Strength')
    ax4.legend()
    ax4.grid(True, alpha=0.3)
    ax4.axvline(x=0, color='black', linestyle='-', alpha=0.5)

    plt.tight_layout()
    plt.show()

# Generate analysis charts
plot_experiment_results(results_df)

# === Experiment Conclusions ===
print("=== 🎯 Experiment Conclusions ===")

# Calculate overall effects: mean entropy change over all sentences.
overall_zero_effect = results_df['entropy_change_zero'].mean()
overall_mean_effect = results_df['entropy_change_mean'].mean()

print(f"Overall intervention effects:")
print(f"  Zero intervention average effect: {overall_zero_effect:.4f}")
print(f"  Mean intervention average effect: {overall_mean_effect:.4f}")

# Count sentences whose entropy changed by more than 0.1 in either direction.
# NOTE: "significant" here is a fixed |Δentropy| > 0.1 threshold (entropy is in nats,
# since calculate_entropy uses log_softmax); no statistical test is performed.
significant_zero = (abs(results_df['entropy_change_zero']) > 0.1).sum()
significant_mean = (abs(results_df['entropy_change_mean']) > 0.1).sum()
total_samples = len(results_df)

print(f"\nSignificant effect statistics (|change| > 0.1):")
print(f"  Zero intervention significant effects: {significant_zero}/{total_samples} ({significant_zero/total_samples:.1%})")
print(f"  Mean intervention significant effects: {significant_mean}/{total_samples} ({significant_mean/total_samples:.1%})")

# Uncertainty neuron determination.
# NOTE(review): only a mean *increase* above 0.05 counts as evidence; a consistent
# decrease in entropy after ablation would also land in the "not obvious" branch.
if overall_zero_effect > 0.05 or overall_mean_effect > 0.05:
    print(f"\n✅ Conclusion: Neurons {test_neurons} may be uncertainty neurons!")
    print("   - Intervention on these neurons significantly affected model prediction uncertainty")
    print("   - Recommend conducting deeper analysis and testing more neurons")
else:
    print(f"\n❓ Conclusion: The uncertainty role of neurons {test_neurons} is not obvious")
    print("   - Recommend testing other candidate neurons")
    print("   - Or try different intervention methods")

print(f"\n🎉 Your first mechanistic interpretability experiment is complete!")
print("   Next steps you can try:")
print("   • Test neurons from more layers")
print("   • Use a larger test dataset")
print("   • Implement more refined activation patching methods")
print("   • Analyze specific types of uncertainty neurons")

# Display current and peak GPU memory usage.
if torch.cuda.is_available():
    print(f"\nCurrent GPU memory usage: {torch.cuda.memory_allocated() // 1024**2} MB")
    print(f"Peak GPU memory: {torch.cuda.max_memory_allocated() // 1024**2} MB")
    print(f"Total GPU memory: {torch.cuda.get_device_properties(0).total_memory // 1024**3} GB")