In [1]:
import os, json, re
from datetime import datetime
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import matplotlib.pyplot as plt
from collections import Counter
import numpy as np
import random
import matplotlib  # for colormap access

# =========================
# Global hyperparameters (adjust as needed)
# =========================
# Local checkpoint directories, processed in this order.
MODEL_PATHS = [
    "/root/autodl-tmp/llama",
    "/root/autodl-tmp/Mistral",
    "/root/autodl-tmp/AceMath",
    "/root/autodl-tmp/Qwen2.5-Math-7B",
    "/root/autodl-tmp/Qwen2.5-7B-Instruct",
]
# Each model gets its own sub-directory below this root.
OUTPUT_ROOT_DIR = "./logit_lens_results"

# Number of arithmetic problems per experiment.
NUM_DATASET_SAMPLES = 1000
# Number of repeated experiments per model.
NUM_REPETITIONS = 5
# Base random seed; repetition i uses BASE_SEED + i.
BASE_SEED = 42

# Generation settings
MAX_INPUT_LENGTH = 512
MAX_NEW_TOKENS = 1
DO_SAMPLE = False

# Visualization settings
DISPLAY_LAST_N_LAYERS = 10
FIGSIZE = (12, 7)

# =========================
# Utility functions
# =========================
def set_seeds(seed_value):
    """Seed Python's ``random``, NumPy and PyTorch with *seed_value*.

    When CUDA is available, the generators of all CUDA devices are seeded too.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed_value)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed_value)

def sanitize_name(s: str) -> str:
    """Turn a HuggingFace id or a filesystem path into a safe directory name.

    Surrounding whitespace and leading/trailing path separators are removed,
    only the last path component is kept, any remaining separator becomes
    ``__`` and every character outside ``A-Za-z0-9._-+=@`` becomes ``_``.
    """
    trimmed = s.strip().strip("/\\")
    leaf = os.path.basename(trimmed)
    if not leaf:
        # basename() returned nothing; fall back to the trimmed input
        leaf = trimmed
    for separator in ("/", "\\"):
        leaf = leaf.replace(separator, "__")
    return re.sub(r"[^A-Za-z0-9._\-+=@]", "_", leaf)

def ensure_dir(p: str):
    """Create directory *p* (including missing parents); do nothing if it already exists."""
    os.makedirs(name=p, exist_ok=True)

def generate_addition_problem():
    """
    Generate an addition prompt for two three-digit numbers whose sum is
    also a three-digit number.

    The first operand ``x`` is drawn from 100..899 and the second operand
    ``y`` from 100..(999 - x), both via the module-level ``random`` state, so
    seeding ``random`` beforehand makes the output reproducible.

    Returns:
        A prompt of the form ``'Calculate: X + Y = '`` (note the trailing space).
    """
    x = random.randint(100, 899)
    # x <= 899 guarantees 999 - x >= 100, so the range for y is never empty
    # and no "could not generate" fallback is required.
    y = random.randint(100, 999 - x)
    return f"Calculate: {x} + {y} = "

# =========================
# Global Matplotlib style
# =========================
_RC_OVERRIDES = {
    # fonts and text sizes
    'font.size': 10,
    'font.family': 'serif',
    'font.sans-serif': ['Arial'],
    'axes.labelsize': 12,
    'axes.titlesize': 14,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'legend.fontsize': 16,
    # lines and markers
    'lines.linewidth': 1.8,
    'lines.markersize': 6,
    # grid
    'axes.grid': True,
    'grid.alpha': 0.5,
    'grid.linestyle': ':',
    # figure export
    'savefig.dpi': 300,
    'savefig.format': 'pdf',
    'savefig.bbox': 'tight',
}
plt.rcParams.update(_RC_OVERRIDES)

_CMAP_NAME = 'tab10'
try:
    colors_cmap = matplotlib.colormaps.get_cmap(_CMAP_NAME)
except AttributeError:
    # Fall back to the pyplot accessor when the attribute lookup above fails
    # (presumably on Matplotlib versions without that registry API).
    colors_cmap = plt.cm.get_cmap(_CMAP_NAME)
# First colour of the palette; used for every bar in the histogram.
bar_plot_color = colors_cmap(0)

# =========================
# Single experiment (one repetition for one model)
# =========================
def run_single_experiment(current_seed, model_obj, tokenizer_obj, device_obj, num_samples, num_total_layers):
    """Run one logit-lens repetition and report when each target token first wins.

    For every generated addition prompt the model's own next token (first
    token returned by ``generate``) is taken as the target. The hidden state
    of the last prompt token at each layer is projected through
    ``model_obj.lm_head`` and the first layer whose arg-max equals the target
    is recorded.

    Args:
        current_seed: Seed passed to ``set_seeds`` before prompts are drawn.
        model_obj: Causal LM exposing ``generate``, ``lm_head`` and
            ``output_hidden_states`` support.
        tokenizer_obj: Tokenizer matching ``model_obj``; must have a pad token.
        device_obj: Device the tokenized batch is moved to.
        num_samples: Number of prompts to generate.
        num_total_layers: Maximum number of hidden-state entries to inspect.

    Returns:
        A 4-tuple ``(first_top1_layer_per_sample, target_token_strs, prompts,
        num_never_top1)``. ``first_top1_layer_per_sample`` holds one layer
        index per prompt, or ``-1`` if the target was never the arg-max. If no
        prompt could be generated, ``([-1] * num_samples, [], [], num_samples)``
        is returned.
    """
    set_seeds(current_seed)
    print(f"\n--- Running Experiment with Seed: {current_seed} ---")

    # 1) Generate the prompts
    prompts, attempts = [], 0
    max_attempts = num_samples * 2
    while len(prompts) < num_samples and attempts < max_attempts:
        problem = generate_addition_problem()
        if problem:
            prompts.append(problem)
        attempts += 1
    if not prompts:
        print("Error: No prompts were generated for this run. Skipping.")
        return [-1] * num_samples, [], [], num_samples
    if len(prompts) < num_samples:
        print(f"Warning: Only generated {len(prompts)} prompts out of {num_samples} requested for seed {current_seed}.")

    # 2) Tokenize. Batched generation with a decoder-only model needs the
    #    padding on the left so that the new token directly follows the last
    #    real prompt token; enforce that here and restore the caller's setting.
    previous_padding_side = tokenizer_obj.padding_side
    tokenizer_obj.padding_side = "left"
    try:
        generation_inputs = tokenizer_obj(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=MAX_INPUT_LENGTH,
        ).to(device_obj)
    finally:
        tokenizer_obj.padding_side = previous_padding_side

    input_ids = generation_inputs["input_ids"]
    attention_mask = generation_inputs["attention_mask"]
    batch_size, prompt_width = input_ids.shape

    # 3) Predict the next token (the model's own next step).
    # NOTE(review): generate() may also apply settings from the model's
    # generation config (e.g. a repetition penalty), in which case the chosen
    # token can differ from the raw arg-max of the final logits — confirm
    # that this is the intended target definition.
    with torch.no_grad():
        generated_outputs = model_obj.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=DO_SAMPLE,
            pad_token_id=tokenizer_obj.pad_token_id
        )

    # Always use the FIRST newly generated token, whatever MAX_NEW_TOKENS is.
    target_token_ids = generated_outputs[:, prompt_width]
    target_token_strs = [tokenizer_obj.decode(token_id, skip_special_tokens=True) for token_id in target_token_ids]

    # 4) Forward pass that exposes the hidden states of every layer
    with torch.no_grad():
        outputs = model_obj(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True
        )

    hidden_states_all_layers = outputs.hidden_states
    lm_head = model_obj.lm_head

    # Index of the last attended token per row, derived from the mask itself
    # so it is correct for left- and right-padded batches alike.
    one_based_positions = torch.arange(1, prompt_width + 1, device=attention_mask.device)
    last_token_positions = (attention_mask.long() * one_based_positions).max(dim=1).values - 1
    batch_indices = torch.arange(batch_size, device=attention_mask.device)

    # 5) Walk through the layers and remember the first arg-max hit per sample.
    # NOTE(review): lm_head is applied directly to each layer's hidden state;
    # whether a final normalisation should be applied first depends on the
    # model architecture — confirm.
    first_top1_layer = torch.full(
        (batch_size,), -1, dtype=torch.long, device=target_token_ids.device
    )
    for layer_idx, layer_hidden_states in enumerate(hidden_states_all_layers[:num_total_layers]):
        last_token_hidden_states = layer_hidden_states[batch_indices, last_token_positions, :]
        top1_pred_ids = torch.argmax(lm_head(last_token_hidden_states), dim=1)
        newly_matched = (top1_pred_ids.to(target_token_ids.device) == target_token_ids) & (first_top1_layer == -1)
        first_top1_layer[newly_matched] = layer_idx

    # Single device-to-host transfer instead of one .item() call per sample.
    first_top1_layer_per_sample = first_top1_layer.tolist()

    num_never_top1_this_run = first_top1_layer_per_sample.count(-1)
    print(f"Seed {current_seed} done. 'Never Top-1' samples: {num_never_top1_this_run}")
    return first_top1_layer_per_sample, target_token_strs, prompts, num_never_top1_this_run

# =========================
# One model: repeated runs + aggregation + outputs
# =========================
def _std_and_sem_over_runs(values):
    """Return ``(std, standard_error)`` across axis 0, where axis 0 indexes runs."""
    arr = np.asarray(values, dtype=float)
    n_runs = arr.shape[0]
    # A standard error is based on the sample standard deviation (ddof=1).
    # With a single run that would divide by zero, so fall back to ddof=0.
    ddof = 1 if n_runs > 1 else 0
    std = np.std(arr, axis=0, ddof=ddof)
    return std, std / np.sqrt(n_runs)


def _plot_last_layers_histogram(model_id, out_dir, num_layers, mean_counts_per_layer,
                                std_err_counts_per_layer, mean_never_top1, std_err_never_top1):
    """Plot mean first-Top-1 counts for the last layers, save the figure in *out_dir* and return its path."""
    start_layer_for_display = max(0, num_layers - DISPLAY_LAST_N_LAYERS)
    layers_to_plot = list(range(start_layer_for_display, num_layers))
    mean_counts_for_displayed_layers = [mean_counts_per_layer[idx] for idx in layers_to_plot]
    error_for_displayed_layers = [std_err_counts_per_layer[idx] for idx in layers_to_plot]

    fig, ax = plt.subplots(figsize=FIGSIZE)
    try:
        if not layers_to_plot:
            ax.text(0.5, 0.5,
                    f"Model has only {num_layers} layer(s).\nCannot display the last {DISPLAY_LAST_N_LAYERS} layers.",
                    ha='center', va='center', transform=ax.transAxes,
                    bbox=dict(boxstyle='round,pad=0.5', fc='lightyellow', alpha=0.8))
        else:
            bars = ax.bar(
                layers_to_plot,
                mean_counts_for_displayed_layers,
                yerr=error_for_displayed_layers,
                capsize=4,
                color=bar_plot_color,
                zorder=2,
                width=0.8,
                edgecolor='black',
                linewidth=0.7
            )
            ax.set_xticks(layers_to_plot)
            ax.set_xticklabels([str(l) for l in layers_to_plot])

            # The list is non-empty in this branch, so max() needs no default.
            max_mean_count_in_display = max(mean_counts_for_displayed_layers)
            for bar_obj, mean_val in zip(bars, mean_counts_for_displayed_layers):
                height = bar_obj.get_height()
                if height > 0:
                    # Value label slightly above the bar top
                    ax.text(
                        bar_obj.get_x() + bar_obj.get_width() / 2.0,
                        height + 0.02 * max_mean_count_in_display,
                        f'{mean_val:.1f}',
                        ha='center', va='bottom',
                        fontsize=plt.rcParams['xtick.labelsize'] - 1
                    )

            # Explain an (almost) empty chart instead of leaving it blank.
            num_samples_reached_top1_overall_avg = float(np.sum(mean_counts_per_layer))
            if NUM_DATASET_SAMPLES > 0 and sum(mean_counts_for_displayed_layers) < 0.1:
                extra_info = ""
                if num_samples_reached_top1_overall_avg > 0.1:
                    extra_info = (f"\n(Note: Avg. {num_samples_reached_top1_overall_avg:.1f} sample(s) reached Top-1 "
                                  f"in earlier layers not shown.)")
                elif num_samples_reached_top1_overall_avg < 0.1:
                    extra_info = (f"\n(Avg. {num_samples_reached_top1_overall_avg:.1f} samples reached Top-1 in any layer.)")
                ax.text(
                    0.5, 0.5,
                    (f"Avg. count of samples first becoming Top-1 is near zero\n"
                     f"in the displayed layers ({start_layer_for_display}-{num_layers-1}).{extra_info}"),
                    ha='center', va='center', transform=ax.transAxes,
                    fontsize=plt.rcParams['legend.fontsize'],
                    bbox=dict(boxstyle='round,pad=0.3', fc='lightyellow', alpha=0.8)
                )

        title_line1 = (f"{model_id} | First layer where model's predicted token becomes Top-1 "
                       f"(Shown: {start_layer_for_display}-{num_layers-1})")
        title_line2 = (f"3-digit additions | Repetitions: {NUM_REPETITIONS}")
        title_line3 = (f"Never Top-1: {mean_never_top1:.1f} ± "
                       f"{std_err_never_top1:.1f} (SE)")

        ax.set_title(f"{title_line1}\n{title_line2}\n{title_line3}", wrap=True, fontsize=16)
        ax.set_xlabel(f"Layer Index (Range shown: {start_layer_for_display} to {num_layers-1})")
        ax.set_ylabel(f"Avg. # Samples (First Top-1 in this layer, ±SE over {NUM_REPETITIONS} runs)")

        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.tick_params(direction='in', top=False, right=False)
        ax.yaxis.grid(True, linestyle=plt.rcParams['grid.linestyle'], linewidth=0.6, alpha=plt.rcParams['grid.alpha'])
        ax.set_axisbelow(True)

        # Info box with the "never Top-1" statistic (a count, hence never negative).
        never_top1_text = (f"Avg. 'Never Top-1': {mean_never_top1:.1f} "
                           f"± {std_err_never_top1:.1f} (SE)\n"
                           f"over {NUM_REPETITIONS} runs.")
        ax.text(
            0.98, 0.05, never_top1_text,
            ha='right', va='bottom', transform=ax.transAxes,
            fontsize=plt.rcParams['legend.fontsize'] - 1, bbox=dict(boxstyle='round,pad=0.3', fc='lightcoral', alpha=0.7)
        )

        fig.tight_layout(pad=1.0)
        fig_path = os.path.join(out_dir, f"logit_lens_histogram_last{len(layers_to_plot)}L.pdf")
        fig.savefig(fig_path)
    finally:
        # Close the figure even if drawing or saving failed.
        plt.close(fig)
    return fig_path


def run_for_one_model(model_path: str):
    """Run all repetitions for one model and write its results to disk.

    Loads tokenizer and model from *model_path* (bfloat16), runs
    ``NUM_REPETITIONS`` experiments with seeds ``BASE_SEED + i`` and writes,
    into ``OUTPUT_ROOT_DIR/<sanitized model name>``:

    * ``first_top1_layers_run<k>.csv`` — raw per-sample result of run ``k``
    * ``aggregated_counts_per_layer.csv`` — mean count and standard error per layer
    * ``summary.json`` — hyperparameters and aggregated statistics
    * ``logit_lens_histogram_last<N>L.pdf`` — bar chart of the last layers

    Standard deviations / standard errors use the sample standard deviation
    (ddof=1) whenever more than one run is available.

    Args:
        model_path: HuggingFace id or local directory of the model.
    """
    import gc  # local import: only needed for the cleanup at the end

    model_id = sanitize_name(model_path)
    out_dir = os.path.join(OUTPUT_ROOT_DIR, model_id)
    ensure_dir(out_dir)
    print(f"\n=== Loading model: {model_path} ===")

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16
    )
    try:
        model.eval()
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        model.to(device)
        # NOTE(review): the padding side is only forced to "left" when the
        # tokenizer had no pad token; tokenizers that ship a pad token keep
        # their own configured side — confirm this is intended.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            tokenizer.padding_side = "left"
            model.config.pad_token_id = tokenizer.eos_token_id

        # +1 for the extra hidden-states entry (index 0) in addition to one per block.
        num_layers = model.config.num_hidden_layers + 1
        print(f"Transformer blocks: {model.config.num_hidden_layers}; hidden_states count: {num_layers}")

        # ---- Repeated experiments ----
        all_runs_first_top1_layers_data = []
        all_runs_num_never_top1_counts = []

        for i in range(NUM_REPETITIONS):
            current_run_seed = BASE_SEED + i
            first_top1_data, _, _, num_never_top1 = run_single_experiment(
                current_run_seed, model, tokenizer, device, NUM_DATASET_SAMPLES, num_layers
            )
            all_runs_first_top1_layers_data.append(first_top1_data)
            all_runs_num_never_top1_counts.append(num_never_top1)

            # Save the raw result of this run (CSV)
            csv_path = os.path.join(out_dir, f"first_top1_layers_run{i+1}.csv")
            with open(csv_path, "w", encoding="utf-8") as f:
                f.write("sample_idx,first_top1_layer\n")
                f.writelines(f"{idx},{layer}\n" for idx, layer in enumerate(first_top1_data))

        # ---- Aggregate statistics ----
        # Row = run, column = layer, value = number of samples whose target
        # first became Top-1 at that layer.
        aggregated_layer_counts_matrix = np.zeros((NUM_REPETITIONS, num_layers), dtype=float)
        for run_idx, single_run_results in enumerate(all_runs_first_top1_layers_data):
            if len(single_run_results) != NUM_DATASET_SAMPLES:
                print(f"Warning: Run {run_idx} returned {len(single_run_results)} results; expected {NUM_DATASET_SAMPLES}.")
            counts_for_this_run = Counter(l for l in single_run_results if l != -1)
            for layer_idx_val, count_val in counts_for_this_run.items():
                if 0 <= layer_idx_val < num_layers:
                    aggregated_layer_counts_matrix[run_idx, layer_idx_val] = count_val

        mean_counts_per_layer = np.mean(aggregated_layer_counts_matrix, axis=0)
        _, std_err_counts_per_layer = _std_and_sem_over_runs(aggregated_layer_counts_matrix)

        mean_num_never_top1_overall = float(np.mean(all_runs_num_never_top1_counts))
        never_std, never_sem = _std_and_sem_over_runs(all_runs_num_never_top1_counts)
        std_dev_num_never_top1_overall = float(never_std)
        std_err_num_never_top1_overall = float(never_sem)

        # Save the aggregated CSV
        agg_csv = os.path.join(out_dir, "aggregated_counts_per_layer.csv")
        with open(agg_csv, "w", encoding="utf-8") as f:
            f.write("layer_idx,mean_count,std_err\n")
            f.writelines(
                f"{li},{mean_counts_per_layer[li]:.6f},{std_err_counts_per_layer[li]:.6f}\n"
                for li in range(num_layers)
            )

        # Save the summary JSON (hyperparameters, layer count, statistics)
        summary_json = os.path.join(out_dir, "summary.json")
        summary = {
            "model_id": model_id,
            "model_path": model_path,
            "timestamp": datetime.now().isoformat(),
            "params": {
                "NUM_DATASET_SAMPLES": NUM_DATASET_SAMPLES,
                "NUM_REPETITIONS": NUM_REPETITIONS,
                "BASE_SEED": BASE_SEED,
                "MAX_INPUT_LENGTH": MAX_INPUT_LENGTH,
                "MAX_NEW_TOKENS": MAX_NEW_TOKENS,
                "DO_SAMPLE": DO_SAMPLE,
            },
            "num_layers": num_layers,
            "display_last_n_layers": min(DISPLAY_LAST_N_LAYERS, num_layers),
            "mean_counts_per_layer": mean_counts_per_layer.tolist(),
            "std_err_counts_per_layer": std_err_counts_per_layer.tolist(),
            "never_top1_counts_per_run": all_runs_num_never_top1_counts,
            "mean_never_top1": mean_num_never_top1_overall,
            "stddev_never_top1": std_dev_num_never_top1_overall,
            "stderr_never_top1": std_err_num_never_top1_overall,
        }
        with open(summary_json, "w", encoding="utf-8") as f:
            json.dump(summary, f, ensure_ascii=False, indent=2)

        # ---- Plot and save into this model's directory ----
        print("\nVisualizing aggregated results (last layers)...")
        fig_path = _plot_last_layers_histogram(
            model_id, out_dir, num_layers,
            mean_counts_per_layer, std_err_counts_per_layer,
            mean_num_never_top1_overall, std_err_num_never_top1_overall,
        )
        print(f"Saved figure to → {fig_path}")
    finally:
        # Release GPU/host memory for the next model, even after a failure.
        del model, tokenizer
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print(f"Finished model: {model_id}. Results in: {out_dir}")

# =========================
# Main entry point: run multiple models in batch
# =========================
def main():
    """Create the results root and process every entry of ``MODEL_PATHS`` in order."""
    ensure_dir(OUTPUT_ROOT_DIR)
    print(f"Results root: {OUTPUT_ROOT_DIR}")

    for model_path in MODEL_PATHS:
        run_for_one_model(model_path)

    print("\nAll models finished.")


if __name__ == "__main__":
    main()


Results root: ./logit_lens_results

=== Loading model: /root/autodl-tmp/llama ===


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Transformer blocks: 32; hidden_states count: 33

--- Running Experiment with Seed: 42 ---




Seed 42 done. 'Never Top-1' samples: 0

--- Running Experiment with Seed: 43 ---
Seed 43 done. 'Never Top-1' samples: 0

--- Running Experiment with Seed: 44 ---
Seed 44 done. 'Never Top-1' samples: 0

--- Running Experiment with Seed: 45 ---
Seed 45 done. 'Never Top-1' samples: 0

--- Running Experiment with Seed: 46 ---
Seed 46 done. 'Never Top-1' samples: 0

Visualizing aggregated results (last layers)...
Saved figure to → ./logit_lens_results/llama/logit_lens_histogram_last10L.pdf
Finished model: llama. Results in: ./logit_lens_results/llama

=== Loading model: /root/autodl-tmp/Mistral ===


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Transformer blocks: 32; hidden_states count: 33

--- Running Experiment with Seed: 42 ---
Seed 42 done. 'Never Top-1' samples: 0

--- Running Experiment with Seed: 43 ---
Seed 43 done. 'Never Top-1' samples: 0

--- Running Experiment with Seed: 44 ---
Seed 44 done. 'Never Top-1' samples: 0

--- Running Experiment with Seed: 45 ---
Seed 45 done. 'Never Top-1' samples: 0

--- Running Experiment with Seed: 46 ---
Seed 46 done. 'Never Top-1' samples: 0

Visualizing aggregated results (last layers)...
Saved figure to → ./logit_lens_results/Mistral/logit_lens_histogram_last10L.pdf
Finished model: Mistral. Results in: ./logit_lens_results/Mistral

=== Loading model: /root/autodl-tmp/AceMath ===


Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Transformer blocks: 28; hidden_states count: 29

--- Running Experiment with Seed: 42 ---
Seed 42 done. 'Never Top-1' samples: 0

--- Running Experiment with Seed: 43 ---
Seed 43 done. 'Never Top-1' samples: 0

--- Running Experiment with Seed: 44 ---
Seed 44 done. 'Never Top-1' samples: 0

--- Running Experiment with Seed: 45 ---
Seed 45 done. 'Never Top-1' samples: 0

--- Running Experiment with Seed: 46 ---
Seed 46 done. 'Never Top-1' samples: 0

Visualizing aggregated results (last layers)...
Saved figure to → ./logit_lens_results/AceMath/logit_lens_histogram_last10L.pdf
Finished model: AceMath. Results in: ./logit_lens_results/AceMath

=== Loading model: /root/autodl-tmp/Qwen2.5-Math-7B ===


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Transformer blocks: 28; hidden_states count: 29

--- Running Experiment with Seed: 42 ---
Seed 42 done. 'Never Top-1' samples: 0

--- Running Experiment with Seed: 43 ---
Seed 43 done. 'Never Top-1' samples: 0

--- Running Experiment with Seed: 44 ---
Seed 44 done. 'Never Top-1' samples: 0

--- Running Experiment with Seed: 45 ---
Seed 45 done. 'Never Top-1' samples: 0

--- Running Experiment with Seed: 46 ---
Seed 46 done. 'Never Top-1' samples: 0

Visualizing aggregated results (last layers)...
Saved figure to → ./logit_lens_results/Qwen2.5-Math-7B/logit_lens_histogram_last10L.pdf
Finished model: Qwen2.5-Math-7B. Results in: ./logit_lens_results/Qwen2.5-Math-7B

=== Loading model: /root/autodl-tmp/Qwen2.5-7B-Instruct ===


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Transformer blocks: 28; hidden_states count: 29

--- Running Experiment with Seed: 42 ---




Seed 42 done. 'Never Top-1' samples: 237

--- Running Experiment with Seed: 43 ---
Seed 43 done. 'Never Top-1' samples: 216

--- Running Experiment with Seed: 44 ---
Seed 44 done. 'Never Top-1' samples: 218

--- Running Experiment with Seed: 45 ---
Seed 45 done. 'Never Top-1' samples: 252

--- Running Experiment with Seed: 46 ---
Seed 46 done. 'Never Top-1' samples: 221

Visualizing aggregated results (last layers)...
Saved figure to → ./logit_lens_results/Qwen2.5-7B-Instruct/logit_lens_histogram_last10L.pdf
Finished model: Qwen2.5-7B-Instruct. Results in: ./logit_lens_results/Qwen2.5-7B-Instruct

All models finished.
