In [None]:
import os
import re
import json
import math  # for sqrt when computing SEM
import random
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Tuple, Dict, Any, Optional

import csv
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
"""
model :
NousResearch/Meta-Llama-3-8B-Instruct
nvidia/AceMath-7B-Instruct
mistralai/Mistral-7B-Instruct-v0.3
Qwen/Qwen2.5-Math-7B-Instruct
"""

# ==============================
# Config (Hyperparameters & Multi-Model Support)
# ==============================
@dataclass
class Config:
    """
    Experiment settings shared by every model listed in ``MODEL_NAMES``.

    The field order below is also the positional order of the generated
    ``__init__``; keep it stable.
    """

    # ---------------------------------------------------------------- models
    # One entry per model to probe. Each entry may be a Hugging Face Hub id
    # (e.g. meta-llama/Meta-Llama-3-8B-Instruct) or a local directory
    # (e.g. /data/global/model/llama3_instruct/).
    MODEL_NAMES: List[str] = field(
        default_factory=lambda: [
            "/root/autodl-tmp/llama",
            "/root/autodl-tmp/Mistral",
            "/root/autodl-tmp/AceMath",
            "/root/autodl-tmp/Qwen2.5-Math-7B",
            "/root/autodl-tmp/Qwen2.5-7B-Instruct",
        ]
    )
    # Forwarded to from_pretrained; some community checkpoints require it.
    TRUST_REMOTE_CODE: bool = False
    # Load precision: "auto", "bf16", "fp16" or "fp32" (resolved by pick_dtype).
    DTYPE: str = "auto"

    # ------------------------------------------------------------ task / data
    # Number of digits in each operand of a + b.
    N_DIGITS: int = 3
    # Digit positions to probe; 0 is the least-significant position.
    POSITIONS: Tuple[int, ...] = (0, 1, 2)
    # Which label to probe: "carry" or "digits".
    TASK_TO_PROBE: str = "carry"
    # Samples per class when TASK_TO_PROBE == "carry".
    SAMPLES_PER_CLASS: int = 1000

    # --------------------------------------------------------------- training
    BATCH_SIZE: int = 500
    N_REPETITIONS: int = 5
    BASE_SEED: int = 42  # repetition r uses seed BASE_SEED + r
    EPOCHS: int = 20
    LR: float = 1e-3

    # ----------------------------------------------------------------- device
    # Evaluated once, when this class body executes.
    DEVICE: str = "cuda:0" if torch.cuda.is_available() else "cpu"

    # ----------------------------------------------------------------- output
    # Root folder; each model writes into its own sub-folder named after the
    # sanitized model name.
    OUTPUT_ROOT: Path = Path("outputs")


# ==============================
# Utilities
# ==============================
def set_seed(seed: int):
    """
    Seed every RNG the experiment touches: Python's ``random``, NumPy's
    global generator, torch's CPU generator and, when CUDA is available,
    the generators of all CUDA devices.

    Full determinism would additionally need
    ``torch.backends.cudnn.deterministic = True`` and
    ``torch.backends.cudnn.benchmark = False`` (slower); those are
    intentionally not set here.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


def sanitize_model_name(name: str) -> str:
    """
    Convert an HF model id or a local path into a safe folder name.

    - Trailing slashes are ignored, then the last '/'-separated component is
      kept ("/data/models/llama3/" -> "llama3", "gpt2/" -> "gpt2").
    - Runs of characters outside [A-Za-z0-9_.-] become a single underscore.
    - If nothing is left (e.g. "" or "/"), "model" is returned.
    """
    # Strip trailing slashes *before* splitting. The previous version kept the
    # raw name whenever the stripped name had no '/', so "gpt2/" became "gpt2_".
    last = name.rstrip("/").rsplit("/", 1)[-1]
    safe = re.sub(r"[^A-Za-z0-9_.-]+", "_", last)
    return safe if safe else "model"


def pick_dtype(device: str, pref: str = "auto"):
    """
    Choose the torch dtype used to load the model.

    Args:
        device: target device string, e.g. "cuda:0" or "cpu".
        pref: "auto", "bf16", "fp16" or "fp32". Case-insensitive; surrounding
            whitespace is ignored.

    Returns:
        (torch.dtype, textual_description), e.g. (torch.bfloat16, "bf16").
        With "auto", CUDA devices get bf16 and every other device gets fp32.

    Raises:
        ValueError: if ``pref`` is not one of the accepted values. Previously
            an unknown value was silently treated as "auto".
    """
    explicit = {
        "bf16": (torch.bfloat16, "bf16"),
        "fp16": (torch.float16, "fp16"),
        "fp32": (torch.float32, "fp32"),
    }
    key = pref.strip().lower()
    if key in explicit:
        return explicit[key]
    if key != "auto":
        raise ValueError(
            f"Unknown dtype preference {pref!r}; expected one of: auto, bf16, fp16, fp32"
        )

    # "auto": hardware support is not probed. The former try/except around
    # `return torch.bfloat16` could never trigger its fp16 fallback. On GPUs
    # without native bf16, pass pref="fp16" explicitly if that is wanted.
    if "cuda" in device:
        return torch.bfloat16, "bf16"
    return torch.float32, "fp32"


def generate_addition_sample(n_digits: int) -> Tuple[str, List[int], List[int]]:
    """
    Draw one random addition problem a + b whose operands have exactly
    ``n_digits`` digits (both operands are 0 when ``n_digits == 0``).

    Returns a triple:
      - the prompt text ``"Calculate: {a}+{b} = "``;
      - carry flags, ``n_digits + 1`` entries, least-significant position
        first; entry i is 1 iff position i's digit sum (including the
        incoming carry) reaches 10;
      - the digits of a + b, zero-padded to ``n_digits + 1`` entries,
        least-significant position first.
    """
    if n_digits > 0:
        low, high = 10 ** (n_digits - 1), 10 ** n_digits - 1
    else:
        low = high = 0
    a = random.randint(low, high)
    b = random.randint(low, high)
    width = n_digits + 1

    def lsb_first(value: int) -> List[int]:
        padded = f"{value:0{width}d}"
        return [int(ch) for ch in reversed(padded)]

    a_digits, b_digits = lsb_first(a), lsb_first(b)
    sum_digits = lsb_first(a + b)

    carry_flags: List[int] = []
    carry_in = 0
    for i in range(width):
        carry_in = int(a_digits[i] + b_digits[i] + carry_in >= 10)
        carry_flags.append(carry_in)
    return f"Calculate: {a}+{b} = ", carry_flags, sum_digits


class BalancedAdditionDataset(Dataset):
    """
    Synthetic a+b dataset built from ``generate_addition_sample``.

    task="carry":  collect ``samples_per_class`` samples whose carry flag at
                   ``target_pos`` is 1 and as many whose flag is 0 (balanced).
    task="digits": collect ``2 * samples_per_class`` samples without class
                   balancing.

    Generation stops after ``samples_per_class * 100`` attempts. If the quota
    is not met, a warning is printed and the dataset is smaller than requested.

    Each item is a dict with:
      - ``input_ids`` / ``attention_mask``: 1-D tensors for the un-padded prompt;
      - ``carry_flags`` / ``digits``: long tensors of length ``n_digits + 1``,
        least-significant position first;
      - ``prompt_text``: the raw prompt string.

    Raises:
        ValueError: if ``task`` is neither "carry" nor "digits".
    """

    def __init__(
        self,
        n_digits: int,
        target_pos: int,
        samples_per_class: int,
        tokenizer,
        task: str = "carry",
    ):
        if task not in ("carry", "digits"):
            raise ValueError(f"Unknown task {task!r}; expected 'carry' or 'digits'.")

        n_positions = n_digits + 1  # every generated sample has this many positions
        target_total = samples_per_class * 2
        max_attempts = samples_per_class * 100  # avoid infinite loops

        yes, no = [], []  # "carry": buckets for carry flag 1 / 0
        pool = []         # "digits": unbalanced samples

        def quota_met() -> bool:
            if task == "carry":
                return len(yes) >= samples_per_class and len(no) >= samples_per_class
            return len(pool) >= target_total

        # An out-of-range target position can never be satisfied, so skip the
        # generation loop entirely instead of burning every attempt.
        attempts = 0 if target_pos < n_positions else max_attempts
        while not quota_met() and attempts < max_attempts:
            attempts += 1
            sample = generate_addition_sample(n_digits)
            if task == "digits":
                pool.append(sample)
                continue
            flag_val = sample[1][target_pos]  # sample = (prompt, carry_flags, digits)
            if flag_val == 1 and len(yes) < samples_per_class:
                yes.append(sample)
            elif flag_val == 0 and len(no) < samples_per_class:
                no.append(sample)

        if task == "carry" and (len(yes) < samples_per_class or len(no) < samples_per_class):
            print(
                f"Warning: Could not generate enough balanced samples for 'carry' at position {target_pos}. "
                f"Yes: {len(yes)}, No: {len(no)}"
            )
        elif task == "digits" and len(pool) < target_total:
            print(
                f"Warning: Could not generate enough samples for 'digits' at position {target_pos}. "
                f"Collected: {len(pool)}, Expected: {samples_per_class * 2}"
            )

        self.samples = yes + no if task == "carry" else pool
        self.task = task
        self.n_digits_plus_1 = n_positions
        self.tokenizer = tokenizer
        # Note: this mutates the caller's tokenizer object.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        prompt, carry_flags, digits = self.samples[idx]
        expected_len = self.n_digits_plus_1

        # Zero-pad to a consistent length WITHOUT mutating the stored lists.
        # The previous `.extend` modified self.samples in place.
        carry_flags = list(carry_flags) + [0] * max(0, expected_len - len(carry_flags))
        digits = list(digits) + [0] * max(0, expected_len - len(digits))

        enc = self.tokenizer(prompt, return_tensors="pt", padding=False, truncation=True)
        return {
            "input_ids": enc.input_ids[0],
            "attention_mask": enc.attention_mask[0],
            "carry_flags": torch.tensor(carry_flags, dtype=torch.long),
            "digits": torch.tensor(digits, dtype=torch.long),
            "prompt_text": prompt,
        }


def collate(batch: List[Dict], pad_id: int) -> Dict[str, Any]:
    """
    Collate dataset items into one batch.

    - ``input_ids`` are right-padded with ``pad_id``.
    - ``attention_mask`` is right-padded with 0, so padded positions are
      masked out and ``attention_mask.sum(1) - 1`` is the index of each
      row's last real token.
    - ``prompt_text`` stays a list of strings.
    - Every other key (carry_flags, digits) is stacked; those tensors must
      share one shape across items.
    """
    from torch.nn.utils.rnn import pad_sequence

    coll: Dict[str, Any] = {}
    for key in batch[0].keys():
        values = [b[key] for b in batch]
        if key == "input_ids":
            coll[key] = pad_sequence(values, batch_first=True, padding_value=pad_id)
        elif key == "attention_mask":
            # Must be 0: padding the mask with pad_id (the previous behaviour)
            # gives padded positions a non-zero mask value and inflates
            # attention_mask.sum(1), which is used for last-token indexing.
            coll[key] = pad_sequence(values, batch_first=True, padding_value=0)
        elif key == "prompt_text":
            coll[key] = values
        else:  # carry_flags, digits
            coll[key] = torch.stack(values)
    return coll


class LinearProbe(nn.Module):
    """A single affine layer mapping a hidden-state vector to class logits.

    The submodule is stored as ``self.linear``; saved state dicts therefore
    use the keys ``linear.weight`` / ``linear.bias``.
    """

    def __init__(self, in_dim: int, n_classes: int):
        super(LinearProbe, self).__init__()
        self.linear = nn.Linear(in_features=in_dim, out_features=n_classes)

    def forward(self, x):
        logits = self.linear(x)
        return logits


class MLPProbe(nn.Module):
    """
    Two-layer MLP probe: Linear(in_dim -> hidden) -> ReLU -> Linear(hidden -> n_classes).

    ``hidden_dim`` falls back to ``in_dim`` when it is None (or any other
    falsy value). The layers live in ``self.mlp`` at indices 0 / 1 / 2, which
    fixes the state-dict keys (``mlp.0.*``, ``mlp.2.*``).
    """

    def __init__(self, in_dim: int, n_classes: int, hidden_dim: Optional[int] = None):
        super(MLPProbe, self).__init__()
        width = in_dim if not hidden_dim else hidden_dim
        layers = [
            nn.Linear(in_dim, width, bias=True),
            nn.ReLU(),
            nn.Linear(width, n_classes, bias=True),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.mlp(x)
        return logits


def save_probe_model(probe, layer_idx, pos_val, rep_idx, out_dir: Path):
    """
    Persist ``probe.state_dict()`` as
    ``{out_dir}/probes/pos{pos_val}/probe_layer{layer_idx}_rep{rep_idx}.pt``,
    creating intermediate directories as needed.
    """
    target_dir = out_dir.joinpath("probes", f"pos{pos_val}")
    target_dir.mkdir(parents=True, exist_ok=True)
    file_name = f"probe_layer{layer_idx}_rep{rep_idx}.pt"
    torch.save(probe.state_dict(), target_dir.joinpath(file_name))


# ==============================
# Helper: infer hidden-states layer count and dimension with a single minimal forward pass
# ==============================
@torch.no_grad()
def infer_hs_layers_and_dim(model, tokenizer, device: str) -> Tuple[int, int]:
    """
    Run one tiny forward pass ("0+0=") with ``output_hidden_states=True`` and
    report ``(len(hidden_states), hidden_size)``. The hidden size is the last
    dimension of the final hidden-state tensor.

    For typical HF causal LMs the count is 1 (embeddings) + num_hidden_layers.
    """
    encoded = tokenizer("0+0=", return_tensors="pt", padding=False, truncation=True)
    inputs = {name: tensor.to(device) for name, tensor in encoded.items()}
    hidden_states = model(**inputs, output_hidden_states=True).hidden_states
    return len(hidden_states), hidden_states[-1].shape[-1]


# ==============================
# Training: one forward per batch; train probes for all layers
# ==============================

@torch.no_grad()
def cache_features_for_loader(
    model,
    dataloader,
    device: str,
    store_dtype: torch.dtype = torch.float16,
):
    """
    Run the LLM exactly once over every sample in ``dataloader`` and cache:
      feats:  [L, N, H]  last-token vector of every hidden-state layer (CPU, ``store_dtype``)
      carry:  [N, P]     carry_flags labels
      digits: [N, P]     sum-digit labels
    L is the number of hidden-state tensors returned by the model, N the
    number of samples, H the hidden size and P the number of label positions.

    The "last token" of each row is taken at ``attention_mask.sum(1) - 1``.
    This assumes right padding with a 0/1 mask, as produced by ``collate``.

    ``use_cache=False`` is passed to every forward call; ``model.config`` is
    left untouched.

    Returns:
        (feats_LNH, carry, digits). For an empty dataloader these are empty
        tensors of shape [0, 0, 0], [0, 0] and [0, 0], instead of the
        RuntimeError that ``torch.cat([])`` would raise.
    """
    model.eval()

    feats_list, carry_list, digits_list = [], [], []

    with torch.inference_mode():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device, non_blocking=True)
            attn = batch["attention_mask"].to(device, non_blocking=True)

            outs = model(
                input_ids=input_ids,
                attention_mask=attn,
                output_hidden_states=True,
                use_cache=False,  # only hidden states are needed; skip the KV cache
            )
            hs = outs.hidden_states  # sequence of [B, T, H]
            idxs = attn.sum(1) - 1   # [B] index of each row's last real token

            # [L, B, H]: last-token vector of every layer
            last_per_layer = torch.stack(
                [h[torch.arange(h.size(0), device=h.device), idxs] for h in hs],
                dim=0,
            )

            feats_list.append(last_per_layer.to(dtype=store_dtype).cpu())
            carry_list.append(batch["carry_flags"].cpu())
            digits_list.append(batch["digits"].cpu())

    if not feats_list:
        empty_labels = torch.empty(0, 0, dtype=torch.long)
        return torch.empty(0, 0, 0, dtype=store_dtype), empty_labels, empty_labels.clone()

    feats = torch.cat(feats_list, dim=1)      # [L, N, H]
    carry = torch.cat(carry_list, dim=0)      # [N, P]
    digits = torch.cat(digits_list, dim=0)    # [N, P]
    return feats, carry, digits

def _split_train_test(n: int, test_frac: float) -> Tuple[torch.Tensor, torch.Tensor]:
    """Randomly split ``range(n)`` into (train_idx, test_idx) using the global torch RNG.

    When ``test_frac <= 0`` (no split requested) or ``n < 2`` (no split
    possible), the same permutation is returned for both train and test.
    """
    order = torch.randperm(n)
    if test_frac <= 0 or n < 2:
        return order, order
    # Keep at least one sample on each side of the split.
    n_test = min(max(int(round(n * test_frac)), 1), n - 1)
    return order[n_test:], order[:n_test]


def train_all_layers_for_one_position(
    model,                 # only used to cache features; probes are trained on the cache
    probes_per_layer: Dict[int, nn.Module],  # {layer_idx: probe}
    dataloader,           # consumed once to build the feature cache
    position: int,
    device: str,
    task: str,            # "carry" or "digits"
    rep_idx: int,
    out_dir: Path,
    epochs: int,
    lr: float,
    test_frac: float = 0.2,
) -> Dict[int, float]:
    """
    Train one probe per hidden-state layer for a single digit position and
    report held-out accuracy.

    The LLM is run once (``cache_features_for_loader``). All training epochs
    and the evaluation then operate on the cached last-token features.

    Args:
        model: causal LM used only to produce the cached hidden states.
        probes_per_layer: {layer_idx: probe}. ``layer_idx`` indexes the first
            axis of the cached [L, N, H] feature tensor.
        dataloader: yields batches produced by ``collate``. Its ``batch_size``
            (fallback 512) is reused for probe mini-batches.
        position: digit position whose label is probed (0 = least significant).
        device: device on which the probes are trained and evaluated.
        task: "carry" (labels from carry_flags) or anything else (labels from
            the sum digits).
        rep_idx: repetition index; used only in the saved probe file names.
        out_dir: directory under which ``save_probe_model`` writes the probes.
        epochs: number of passes over the training split.
        lr: AdamW learning rate.
        test_frac: fraction of samples held out for evaluation. If
            ``test_frac <= 0``, probes are evaluated on their own training
            data, which overstates accuracy.

    Returns:
        {layer_idx: accuracy on the held-out split}. All zeros (and nothing
        saved) if ``position`` is out of range for the cached labels.
    """
    # 1) One LLM pass: cache [L, N, H] features + labels.
    feats_LNH, labels_carry, labels_digits = cache_features_for_loader(
        model, dataloader, device, store_dtype=torch.float16
    )
    # Probes train in fp32; cached fp16 features are upcast per mini-batch.
    probe_dtype = torch.float32

    # 2) Label vector y for this position: [N]
    if position >= labels_carry.size(1):  # safety guard
        return {l: 0.0 for l in probes_per_layer}

    labels = labels_carry if task == "carry" else labels_digits
    y_all = labels[:, position].to(torch.long)
    N = y_all.numel()

    # Hold out part of the data. With a feature dimension that is large
    # relative to N, a linear probe can fit its training set almost
    # perfectly, so training-set accuracy says little about the features.
    train_idx, test_idx = _split_train_test(N, test_frac)

    # 3) Optimizers / loss
    for p in probes_per_layer.values():
        p.train()
        p.to(device=device, dtype=probe_dtype)
    opts = {l: torch.optim.AdamW(p.parameters(), lr=lr) for l, p in probes_per_layer.items()}
    crit = nn.CrossEntropyLoss()

    batch_size = getattr(dataloader, "batch_size", None) or 512

    # 4) Train on the cached features of the training split only.
    n_train = train_idx.numel()
    for _ in range(epochs):
        epoch_order = train_idx[torch.randperm(n_train)]
        for s in range(0, n_train, batch_size):
            idx = epoch_order[s : s + batch_size]
            yb = y_all[idx].to(device)
            for l, probe in probes_per_layer.items():
                xb = feats_LNH[l, idx].to(device=device, dtype=probe_dtype)  # fp16 -> fp32
                loss = crit(probe(xb).float(), yb)
                opts[l].zero_grad(set_to_none=True)
                loss.backward()
                opts[l].step()

    # 5) Evaluate on the held-out split.
    for p in probes_per_layer.values():
        p.eval()

    n_test = test_idx.numel()
    correct = {l: 0 for l in probes_per_layer}
    with torch.no_grad():
        for s in range(0, n_test, batch_size):
            idx = test_idx[s : s + batch_size]
            yb_cpu = y_all[idx]  # CPU long
            for l, probe in probes_per_layer.items():
                xb = feats_LNH[l, idx].to(device=device, dtype=probe_dtype)
                pred = probe(xb).argmax(-1).cpu()
                correct[l] += (pred == yb_cpu).sum().item()

    acc_by_layer = {l: (correct[l] / n_test if n_test else 0.0) for l in probes_per_layer}

    # 6) Save probes
    for l, probe in probes_per_layer.items():
        save_probe_model(probe, l, position, rep_idx, out_dir)

    return acc_by_layer



def _write_accuracy_csv(
    csv_path: Path,
    positions,
    mean_accuracies: np.ndarray,
    ci_lower_bounds: np.ndarray,
    ci_upper_bounds: np.ndarray,
    std_dev_accuracies: np.ndarray,
) -> None:
    """Write one CSV row per layer: layer index, then mean / CI low / CI high / std for each position."""
    header = ["layer"]
    for p in positions:
        header.extend([f"pos_{p}_mean", f"pos_{p}_ci_lower", f"pos_{p}_ci_upper", f"pos_{p}_std_dev"])
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(header)
        for layer_idx in range(mean_accuracies.shape[0]):
            row = [layer_idx]
            for j in range(len(positions)):
                row.extend([
                    float(mean_accuracies[layer_idx, j]),
                    float(ci_lower_bounds[layer_idx, j]),
                    float(ci_upper_bounds[layer_idx, j]),
                    float(std_dev_accuracies[layer_idx, j]),
                ])
            writer.writerow(row)


def _plot_accuracies(cfg, model_name, outdir: Path, mean_accuracies, ci_lower_bounds, ci_upper_bounds) -> None:
    """
    Plot mean accuracy per layer (one line plus a 95% CI band per position)
    and save it as a PDF under ``outdir``.

    Best-effort: a missing matplotlib or any plotting error is printed and
    otherwise ignored. Style settings are applied through ``rc_context``, so
    global rcParams are not changed, and the figure is always closed.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("matplotlib is not installed; skipping plot.")
        return

    style = {
        'font.size': 14,
        'font.family': 'serif',
        'font.sans-serif': ['Arial'],
        'axes.labelsize': 14,
        'axes.titlesize': 20,
        'xtick.labelsize': 12,
        'ytick.labelsize': 12,
        'legend.fontsize': 14,
        'lines.linewidth': 1.8,
        'lines.markersize': 6,
        'axes.grid': True,
        'grid.alpha': 0.5,
        'grid.linestyle': ':',
        'savefig.dpi': 600,
        'savefig.format': 'pdf',
        'savefig.bbox': 'tight',
    }
    position_descriptive_labels = {
        0: "ones place",
        1: "tens place",
        2: "hundreds place",
    }
    n_layers = mean_accuracies.shape[0]

    fig = None
    try:
        with plt.rc_context(style):
            fig = plt.figure(figsize=(12, 8))
            colors = plt.cm.winter(np.linspace(0, 1, len(cfg.POSITIONS)))

            for pos_list_idx, pos_val in enumerate(cfg.POSITIONS):
                label_text = position_descriptive_labels.get(pos_val, f"Pos {pos_val}")
                color = colors[pos_list_idx]
                plt.plot(range(n_layers), mean_accuracies[:, pos_list_idx], marker='o', linestyle='-',
                         color=color, label=f"{label_text} (Mean Acc.)")
                # Leading underscore: matplotlib leaves this artist out of the legend.
                plt.fill_between(range(n_layers), ci_lower_bounds[:, pos_list_idx], ci_upper_bounds[:, pos_list_idx],
                                 color=color, alpha=0.2, label=f"_{label_text} (95% CI)")

            model_tag = sanitize_model_name(model_name)
            plt.xlabel(f"Layer ({model_tag})")
            plt.ylabel("Mean Probe Accuracy")
            plt.title(f"{cfg.TASK_TO_PROBE.capitalize()} Probe: Mean Accuracy by Layer & Position\n"
                      f"({cfg.N_REPETITIONS} Runs, N_digits: {cfg.N_DIGITS})",
                      fontsize=14)
            plt.legend(loc='best', fontsize=14)
            plt.grid(True, linestyle='--', alpha=0.6)
            # Keep the 0.45 floor when everything is above it, but extend the
            # axis downward so lower accuracies (e.g. the 10-class digits
            # task) are not clipped out of view.
            y_low = max(0.0, min(0.45, float(np.min(ci_lower_bounds)) - 0.05))
            plt.ylim(y_low, 1.05)
            plt.xticks(range(n_layers))
            plt.tight_layout()

            plot_path = outdir / f"{cfg.TASK_TO_PROBE}_plot_mean_ci_nd{cfg.N_DIGITS}_nrep{cfg.N_REPETITIONS}.pdf"
            plt.savefig(plot_path, format='pdf', dpi=600)
        print(f"Saved plot to {plot_path}")
    except Exception as e:
        print(f"An error occurred during plotting: {e}")
    finally:
        if fig is not None:
            plt.close(fig)


def run_one_model(cfg: Config, model_name: str):
    """
    Run the full pipeline for one model:
    load tokenizer/model -> detect hidden-state layers -> build one DataLoader
    per position -> train/evaluate per-layer probes for ``cfg.N_REPETITIONS``
    seeds -> save a CSV (mean, 95% CI, std per layer/position), a PDF plot
    and ``run_meta.json`` under ``cfg.OUTPUT_ROOT / sanitize_model_name(model_name)``.

    Returns None. If the model cannot be loaded, an error is printed and the
    function returns early. Once loaded, the model reference is dropped (and
    the CUDA cache emptied) even when a later step raises, so the next model
    in a multi-model run does not inherit its memory.
    """
    outdir = cfg.OUTPUT_ROOT / sanitize_model_name(model_name)
    outdir.mkdir(parents=True, exist_ok=True)

    print(f"\n========== Running model: {model_name} ==========")
    print(f"Outputs will be saved under: {outdir.resolve()}")

    model_dtype, dtype_name = pick_dtype(cfg.DEVICE, cfg.DTYPE)

    # --- Load tokenizer & model ---
    print(f"Loading tokenizer from {model_name}...")
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True, trust_remote_code=cfg.TRUST_REMOTE_CODE)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        print(f"Tokenizer pad_token was None; set to eos_token: {tokenizer.eos_token}")
    pad_id = tokenizer.pad_token_id

    print(f"Loading model from {model_name} with dtype={dtype_name} on {cfg.DEVICE} ...")
    if "cuda" not in cfg.DEVICE and dtype_name != "fp32":
        print("Warning: CPU device detected; dtype forced to fp32 for stability.")
        model_dtype = torch.float32
        dtype_name = "fp32"

    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            output_hidden_states=True,
            torch_dtype=model_dtype,
            trust_remote_code=cfg.TRUST_REMOTE_CODE
        ).to(cfg.DEVICE).eval()
    except Exception as e:
        print(f"Error loading model: {e}")
        print("Please ensure MODEL_NAMES are valid and you have permissions if it is private.")
        print("If VRAM is limited, consider a smaller model (e.g., 'gpt2') for testing.")
        return

    try:
        # --- Number of hidden-state tensors (embeddings + each layer) and hidden size ---
        n_layers_hs, hidden_size = infer_hs_layers_and_dim(model, tokenizer, cfg.DEVICE)
        print(f"Detected hidden_states layers (including embeddings): {n_layers_hs}")
        print(f"Detected hidden_size: {hidden_size}")

        # --- One DataLoader per position (balanced for 'carry') ---
        loaders = {}
        print("Preparing datasets and dataloaders...")
        expected_samples = cfg.SAMPLES_PER_CLASS * 2  # same target for "carry" and "digits"
        for pos_val in cfg.POSITIONS:
            ds = BalancedAdditionDataset(cfg.N_DIGITS, pos_val, cfg.SAMPLES_PER_CLASS, tokenizer, task=cfg.TASK_TO_PROBE)
            loaders[pos_val] = DataLoader(
                ds,
                batch_size=cfg.BATCH_SIZE,
                shuffle=True,
                collate_fn=lambda b, pid=pad_id: collate(b, pid),
            )
            print(f"  Pos {pos_val}: {len(ds)} samples -> {len(loaders[pos_val])} batches")
            if len(ds) < expected_samples:
                print(
                    f"  Warning: Dataset for Pos {pos_val} (Task: {cfg.TASK_TO_PROBE}) has fewer samples "
                    f"({len(ds)}) than expected ({expected_samples})."
                )

        if "cuda" in cfg.DEVICE:
            try:
                print(f"CUDA Device Name: {torch.cuda.get_device_name(torch.device(cfg.DEVICE))}")
            except Exception:
                pass  # purely informational

        # --- Repetitions: accuracies collected with shape (R, L, P) ---
        num_classes_probe = 2 if cfg.TASK_TO_PROBE == "carry" else 10
        all_runs_accuracies_list: List[List[List[float]]] = []
        for rep_idx in range(cfg.N_REPETITIONS):
            current_seed = cfg.BASE_SEED + rep_idx
            set_seed(current_seed)
            print(f"\n--- Repetition {rep_idx + 1}/{cfg.N_REPETITIONS} with Seed {current_seed} ---")

            accuracies_current_run: List[List[float]] = [[0.0] * len(cfg.POSITIONS) for _ in range(n_layers_hs)]
            for pos_list_idx, pos_val in enumerate(tqdm(cfg.POSITIONS, desc=f"Rep {rep_idx + 1} Positions")):
                # Fresh probes for every (repetition, position).
                probes_for_all_layers: Dict[int, nn.Module] = {
                    l: LinearProbe(hidden_size, num_classes_probe) for l in range(n_layers_hs)
                }
                accs = train_all_layers_for_one_position(
                    model,
                    probes_for_all_layers,
                    loaders[pos_val],
                    position=pos_val,
                    device=cfg.DEVICE,
                    task=cfg.TASK_TO_PROBE,
                    rep_idx=rep_idx,
                    out_dir=outdir,
                    epochs=cfg.EPOCHS,
                    lr=cfg.LR,
                )
                for l in range(n_layers_hs):
                    accuracies_current_run[l][pos_list_idx] = accs[l]

            all_runs_accuracies_list.append(accuracies_current_run)

        if not all_runs_accuracies_list:
            print("No accuracy data collected. Skipping save for this model.")
            return

        # --- Aggregate across repetitions ---
        all_runs_accuracies_np = np.array(all_runs_accuracies_list)  # (R, L, P)
        n_runs = all_runs_accuracies_np.shape[0]
        mean_accuracies = np.mean(all_runs_accuracies_np, axis=0)
        # Sample standard deviation (ddof=1) so that the SEM below is the usual
        # s / sqrt(R). A single run has no spread estimate, so it gets 0.
        if n_runs > 1:
            std_dev_accuracies = np.std(all_runs_accuracies_np, axis=0, ddof=1)
        else:
            std_dev_accuracies = np.zeros_like(mean_accuracies)
        sem_accuracies = std_dev_accuracies / math.sqrt(n_runs)
        # Normal approximation. With few repetitions a Student-t critical
        # value would give wider intervals.
        z_score = 1.96
        ci_lower_bounds = np.clip(mean_accuracies - z_score * sem_accuracies, 0, 1)
        ci_upper_bounds = np.clip(mean_accuracies + z_score * sem_accuracies, 0, 1)

        csv_path = outdir / f"{cfg.TASK_TO_PROBE}_accuracies_mean_ci_nd{cfg.N_DIGITS}_nrep{cfg.N_REPETITIONS}.csv"
        _write_accuracy_csv(csv_path, cfg.POSITIONS, mean_accuracies, ci_lower_bounds, ci_upper_bounds, std_dev_accuracies)
        print(f"Saved mean accuracies with CI to {csv_path}")

        _plot_accuracies(cfg, model_name, outdir, mean_accuracies, ci_lower_bounds, ci_upper_bounds)

        # --- Run metadata ---
        meta = {
            "model_name": model_name,
            "sanitized_dir": sanitize_model_name(model_name),
            "device": cfg.DEVICE,
            "dtype": dtype_name,
            "n_digits": cfg.N_DIGITS,
            "positions": list(cfg.POSITIONS),
            "task": cfg.TASK_TO_PROBE,
            "samples_per_class": cfg.SAMPLES_PER_CLASS,
            "batch_size": cfg.BATCH_SIZE,
            "repetitions": cfg.N_REPETITIONS,
            "epochs": cfg.EPOCHS,
            "lr": cfg.LR,
            "detected_layers_hidden_states": n_layers_hs,
            "hidden_size": hidden_size,
            "csv_path": str(csv_path),
        }
        with open(outdir / "run_meta.json", "w", encoding="utf-8") as f:
            json.dump(meta, f, ensure_ascii=False, indent=2)
    finally:
        # Free VRAM even if a step above raised.
        del model
        if "cuda" in cfg.DEVICE:
            torch.cuda.empty_cache()


def main():
    """Build the default Config, ensure the output root exists, and run every configured model in turn."""
    cfg = Config()
    cfg.OUTPUT_ROOT.mkdir(parents=True, exist_ok=True)

    print(f"Using device: {cfg.DEVICE}")
    running_on_cpu = "cuda" not in cfg.DEVICE
    if running_on_cpu and cfg.DTYPE.lower() != "fp32":
        print("Info: On CPU, dtype will be fp32 regardless of config.DTYPE.")

    for name in cfg.MODEL_NAMES:
        run_one_model(cfg, name)


# Run the whole multi-model pipeline when this file/cell is executed as the
# main module.
if __name__ == "__main__":
    main()


Using device: cuda:0

Outputs will be saved under: /root/EMNLP1 copy/Carry_Signal_Experiment/outputs/AceMath
Loading tokenizer from /root/autodl-tmp/AceMath...
Loading model from /root/autodl-tmp/AceMath with dtype=bf16 on cuda:0 ...


Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Detected hidden_states layers (including embeddings): 29
Detected hidden_size: 3584
Preparing datasets and dataloaders...
  Pos 0: 2000 samples -> 4 batches
  Pos 1: 2000 samples -> 4 batches
  Pos 2: 2000 samples -> 4 batches
CUDA Device Name: Tesla V100S-PCIE-32GB

--- Repetition 1/5 with Seed 42 ---


Rep 1 Positions: 100%|██████████| 3/3 [02:06<00:00, 42.14s/it]



--- Repetition 2/5 with Seed 43 ---


Rep 2 Positions: 100%|██████████| 3/3 [02:06<00:00, 42.09s/it]



--- Repetition 3/5 with Seed 44 ---


Rep 3 Positions: 100%|██████████| 3/3 [02:05<00:00, 41.98s/it]



--- Repetition 4/5 with Seed 45 ---


Rep 4 Positions: 100%|██████████| 3/3 [02:05<00:00, 41.79s/it]



--- Repetition 5/5 with Seed 46 ---


Rep 5 Positions: 100%|██████████| 3/3 [02:05<00:00, 41.94s/it]


Saved mean accuracies with CI to outputs/AceMath/carry_accuracies_mean_ci_nd3_nrep5.csv
Saved plot to outputs/AceMath/carry_plot_mean_ci_nd3_nrep5.pdf

Outputs will be saved under: /root/EMNLP1 copy/Carry_Signal_Experiment/outputs/Qwen2.5-Math-7B
Loading tokenizer from /root/autodl-tmp/Qwen2.5-Math-7B...
Loading model from /root/autodl-tmp/Qwen2.5-Math-7B with dtype=bf16 on cuda:0 ...




Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Detected hidden_states layers (including embeddings): 29
Detected hidden_size: 3584
Preparing datasets and dataloaders...
  Pos 0: 2000 samples -> 4 batches
  Pos 1: 2000 samples -> 4 batches
  Pos 2: 2000 samples -> 4 batches
CUDA Device Name: Tesla V100S-PCIE-32GB

--- Repetition 1/5 with Seed 42 ---


Rep 1 Positions: 100%|██████████| 3/3 [02:05<00:00, 41.98s/it]



--- Repetition 2/5 with Seed 43 ---


Rep 2 Positions: 100%|██████████| 3/3 [02:06<00:00, 42.21s/it]



--- Repetition 3/5 with Seed 44 ---


Rep 3 Positions: 100%|██████████| 3/3 [02:06<00:00, 42.15s/it]



--- Repetition 4/5 with Seed 45 ---


Rep 4 Positions: 100%|██████████| 3/3 [02:06<00:00, 42.19s/it]



--- Repetition 5/5 with Seed 46 ---


Rep 5 Positions: 100%|██████████| 3/3 [02:05<00:00, 41.97s/it]


Saved mean accuracies with CI to outputs/Qwen2.5-Math-7B/carry_accuracies_mean_ci_nd3_nrep5.csv
Saved plot to outputs/Qwen2.5-Math-7B/carry_plot_mean_ci_nd3_nrep5.pdf

Outputs will be saved under: /root/EMNLP1 copy/Carry_Signal_Experiment/outputs/Qwen2.5-7B-Instruct
Loading tokenizer from /root/autodl-tmp/Qwen2.5-7B-Instruct...
Loading model from /root/autodl-tmp/Qwen2.5-7B-Instruct with dtype=bf16 on cuda:0 ...




Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Detected hidden_states layers (including embeddings): 29
Detected hidden_size: 3584
Preparing datasets and dataloaders...
  Pos 0: 2000 samples -> 4 batches
  Pos 1: 2000 samples -> 4 batches
  Pos 2: 2000 samples -> 4 batches
CUDA Device Name: Tesla V100S-PCIE-32GB

--- Repetition 1/5 with Seed 42 ---


Rep 1 Positions: 100%|██████████| 3/3 [02:06<00:00, 42.06s/it]



--- Repetition 2/5 with Seed 43 ---


Rep 2 Positions: 100%|██████████| 3/3 [02:06<00:00, 42.00s/it]



--- Repetition 3/5 with Seed 44 ---


Rep 3 Positions: 100%|██████████| 3/3 [02:06<00:00, 42.19s/it]



--- Repetition 4/5 with Seed 45 ---


Rep 4 Positions: 100%|██████████| 3/3 [02:07<00:00, 42.38s/it]



--- Repetition 5/5 with Seed 46 ---


Rep 5 Positions: 100%|██████████| 3/3 [02:06<00:00, 42.32s/it]


Saved mean accuracies with CI to outputs/Qwen2.5-7B-Instruct/carry_accuracies_mean_ci_nd3_nrep5.csv
Saved plot to outputs/Qwen2.5-7B-Instruct/carry_plot_mean_ci_nd3_nrep5.pdf
