In [1]:
# -*- coding: utf-8 -*-
# Multi-model + adaptive layer count + per-model subdirectory saving version

# --- Imports ---
import os, re, math, random, datetime, collections, gc
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
from sklearn.utils import shuffle
import matplotlib.pyplot as plt
import scipy.stats

# ============ 0. Global config (edit as needed) ============
# Support multiple HF repo IDs or local paths
MODELS = [
        #"/root/autodl-tmp/llama",
        #"/root/autodl-tmp/Mistral",
        #"/root/autodl-tmp/AceMath",
        "/root/autodl-tmp/Qwen2.5-Math-7B",
        "/root/autodl-tmp/Qwen2.5-7B-Instruct"
    ]

DEVICE_ID = 0
DEVICE = torch.device(f"cuda:{DEVICE_ID}" if torch.cuda.is_available() else "cpu")

# Experiment & training settings
N_SAMPLES_PER_DIGIT = 500
BATCH_SIZE_PROBE = 64
NUM_EPOCHS_PROBE = 5
LEARNING_RATE_PROBE = 0.001
HIDDEN_DIM_FACTOR = 1
TEST_SPLIT_SIZE = 0.25
RANDOM_SEED_BASE = 42
N_REPETITIONS = 5
CONFIDENCE_LEVEL = 0.95

# Data generation settings (target: hundreds digit)
ADD_MIN_SUM = 100
N_MULT_TEST_SAMPLES = 500
MULT_MIN_PRODUCT = 100
MULT_MAX_PRODUCT = 999
N_SUB_TEST_SAMPLES = 500
SUB_MIN_DIFFERENCE = 100
SUB_MAX_A = 999

# Auto layer probing strategy:
# - "auto": automatically use [0 .. num_hidden_layers] (including embedding layer 0)
# - You can also change to a specific list such as [1..num_hidden_layers] to probe only blocks:
#   modify one line inside run_experiment_for_model
GLOBAL_LAYERS_TO_PROBE = "auto"

# --- Main experiment directory ---
experiment_timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
MAIN_EXPERIMENT_DIR = f"./experiment_hundreds_CI_{experiment_timestamp}"
os.makedirs(MAIN_EXPERIMENT_DIR, exist_ok=True)
print(f"[INIT] Main experiment directory: {MAIN_EXPERIMENT_DIR}")
print(f"[INIT] Using device: {DEVICE}")

# ============ 1. Utility functions ============
def sanitize_model_id(model_id: str) -> str:
    """
    Make a safe folder name from a model id:
    - Strip leading/trailing slashes
    - Replace non-alphanumeric and . - _ with '_'
    """
    base = model_id.strip().rstrip("/")
    # Use basename (for "org/name" -> "name"; for local path -> last directory)
    base = os.path.basename(base)
    safe = re.sub(r"[^0-9A-Za-z._-]+", "_", base)
    return safe or "model"

def set_seed(seed_value: int):
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed_value)

def format_add_prompt(a, b): return f"Calculate: {a}+{b} = "
def format_mult_prompt(a, b): return f"Calculate: {a}x{b} = "
def format_sub_prompt(a, b): return f"Calculate: {a}-{b} = "

def get_layers_to_probe(model, mode="auto"):
    """
    Return a list of layer indices for indexing hidden_states.
    In HF models, hidden_states[0] is the embedding; 1..N are block outputs.
    """
    if mode == "auto":
        num_layers = getattr(model.config, "num_hidden_layers", None)
        if isinstance(num_layers, int) and num_layers > 0:
            return list(range(0, num_layers + 1))  # include embedding layer
        # Fallback: run a dummy forward to check length
        tok = AutoTokenizer.from_pretrained("gpt2")  # fallback only; not actually used
        with torch.no_grad():
            outs = model(**tok("test", return_tensors="pt"))
        return list(range(0, len(outs.hidden_states)))
    elif isinstance(mode, (list, tuple)):
        return list(mode)
    else:
        raise ValueError("Unsupported GLOBAL_LAYERS_TO_PROBE setting.")

# --- Activation extraction ---
def get_activations(prompts, layers_to_probe, current_model, current_tokenizer, batch_size=8, desc_prefix=""):
    from tqdm.auto import tqdm
    activations = {layer: [] for layer in layers_to_probe}
    current_model.eval()
    with torch.no_grad():
        for i in tqdm(range(0, len(prompts), batch_size), desc=f"{desc_prefix} Extracting activations"):
            batch_prompts = prompts[i : i + batch_size]
            try:
                inputs = current_tokenizer(batch_prompts, return_tensors="pt", padding=True, truncation=True, max_length=64).to(DEVICE)
                outputs = current_model(**inputs)
                hidden_states = outputs.hidden_states
                sequence_lengths = inputs.attention_mask.sum(dim=1)
                last_token_indices = sequence_lengths - 1
                for layer_idx in layers_to_probe:
                    if layer_idx < 0 or layer_idx >= len(hidden_states):
                        continue
                    layer_hidden_states = hidden_states[layer_idx]
                    batch_indices = torch.arange(layer_hidden_states.size(0), device=DEVICE)
                    last_token_activations = layer_hidden_states[batch_indices, last_token_indices, :]
                    activations[layer_idx].append(last_token_activations.to(torch.float32).cpu().numpy())
            except Exception as e:
                print(f"[WARN] Batch {i} ({desc_prefix}) failed: {e}")
    final_activations = {}
    for layer in layers_to_probe:
        if activations[layer]:
            try:
                final_activations[layer] = np.concatenate(activations[layer], axis=0)
            except ValueError as e:
                print(f"[WARN] Layer {layer} concatenation failed: {e} | shapes: {[a.shape for a in activations[layer]]}")
                final_activations[layer] = np.array([])
        else:
            final_activations[layer] = np.array([])
    return final_activations

# --- Linear probe ---
class LinearProbe(nn.Module):
    def __init__(self, in_dim: int, n_classes: int):
        super().__init__()
        self.linear = nn.Linear(in_dim, n_classes)
    def forward(self, x):
        return self.linear(x)

# ============ 2. Full pipeline for a single model ============
def run_experiment_for_model(model_id: str):
    safe_name = sanitize_model_id(model_id)
    model_dir = os.path.join(MAIN_EXPERIMENT_DIR, safe_name)
    os.makedirs(model_dir, exist_ok=True)

    aggregated_log_filename = os.path.join(model_dir, f"AGGREGATED_probe_log_{safe_name}_{experiment_timestamp}.txt")
    if os.path.exists(aggregated_log_filename):
        os.remove(aggregated_log_filename)

    # Load model & tokenizer
    print(f"\n===== Loading model: {model_id} =====")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        output_hidden_states=True,
        torch_dtype=torch.bfloat16,
        device_map=None  # we manually .to(DEVICE)
    )
    model.to(DEVICE)
    model.eval()
    if tokenizer.pad_token is None:
        print("[INFO] Tokenizer has no pad_token; using eos_token as pad.")
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = 'left'
    print("[OK] Model loaded.")

    # Adaptive layers
    layers_to_probe = get_layers_to_probe(model, GLOBAL_LAYERS_TO_PROBE)
    print(f"[INFO] Layers to probe (first 5 shown): {layers_to_probe[:5]}... total {len(layers_to_probe)} layers (including embedding layer 0).")

    # Store results across repetitions
    overall_aggregated_results = collections.defaultdict(lambda: collections.defaultdict(list))

    # ========= Repeated experiments =========
    for rep_idx in range(N_REPETITIONS):
        current_seed = RANDOM_SEED_BASE + rep_idx
        set_seed(current_seed)
        print(f"\n--- Repetition {rep_idx + 1}/{N_REPETITIONS} | Seed = {current_seed} ---")

        # Run directory for this repetition
        run_dir = os.path.join(model_dir, f"run_{rep_idx}_seed_{current_seed}")
        os.makedirs(run_dir, exist_ok=True)
        PROBE_SAVE_DIR_RUN = os.path.join(run_dir, "probe_models_hundreds")
        os.makedirs(PROBE_SAVE_DIR_RUN, exist_ok=True)
        run_log_filename = os.path.join(run_dir, f"run_log_{experiment_timestamp}.txt")
        if os.path.exists(run_log_filename):
            os.remove(run_log_filename)

        with open(aggregated_log_filename, "a") as agg_f:
            agg_f.write(f"\n--- Repetition {rep_idx + 1}/{N_REPETITIONS} with Seed {current_seed} ---\n")

        # --- Generate addition data (target: hundreds digit) ---
        n_add_total_target = N_SAMPLES_PER_DIGIT * 10
        n_add_to_generate_initial = n_add_total_target * 30
        potential_add_samples = []
        attempts_add = 0
        max_attempts_add = n_add_to_generate_initial * 3
        while len(potential_add_samples) < n_add_to_generate_initial and attempts_add < max_attempts_add:
            a = random.randint(1, 899); b = random.randint(1, 899)
            s = a + b
            if s >= ADD_MIN_SUM:
                potential_add_samples.append({'a': a, 'b': b, 'sum': s})
            attempts_add += 1
        samples_by_hundreds_add = collections.defaultdict(list)
        for item in potential_add_samples:
            hd = (item['sum'] // 100) % 10
            samples_by_hundreds_add[hd].append(item)

        all_add_samples_final_run = []
        for digit_val in range(10):
            available = samples_by_hundreds_add[digit_val]
            n_to_take = min(N_SAMPLES_PER_DIGIT, len(available))
            if len(available) < N_SAMPLES_PER_DIGIT:
                print(f"[WARN][Seed {current_seed}] Addition hundreds digit {digit_val} has only {len(available)} samples.")
            if not available:
                print(f"[ERR ][Seed {current_seed}] Addition hundreds digit {digit_val} has no samples.")
                continue
            random.shuffle(available)
            selected = available[:n_to_take]
            for sample in selected:
                sample['label'] = digit_val
                sample['prompt'] = format_add_prompt(sample['a'], sample['b'])
            all_add_samples_final_run.extend(selected)
        random.shuffle(all_add_samples_final_run)
        print(f"[OK] Total addition train+test samples: {len(all_add_samples_final_run)}")
        if not all_add_samples_final_run:
            with open(aggregated_log_filename, "a") as agg_f:
                agg_f.write("CRITICAL ERROR: No ADDITION samples. Repetition skipped.\n")
            with open(run_log_filename, "a") as run_f:
                run_f.write("CRITICAL ERROR: No ADDITION samples. Repetition skipped.\n")
            continue

        # Multiplication test set
        mult_samples_test_run = []
        attempts_mult = 0; max_attempts_mult = N_MULT_TEST_SAMPLES * 50
        min_factor = 2; max_factor = 99
        while len(mult_samples_test_run) < N_MULT_TEST_SAMPLES and attempts_mult < max_attempts_mult:
            attempts_mult += 1
            a = random.randint(min_factor, max_factor); b = random.randint(min_factor, max_factor)
            p = a * b
            if MULT_MIN_PRODUCT <= p <= MULT_MAX_PRODUCT:
                hd = (p // 100) % 10
                mult_samples_test_run.append({'a':a,'b':b,'product':p,'label':hd,'prompt':format_mult_prompt(a,b)})
        if len(mult_samples_test_run) > N_MULT_TEST_SAMPLES:
            random.shuffle(mult_samples_test_run)
            mult_samples_test_run = mult_samples_test_run[:N_MULT_TEST_SAMPLES]
        print(f"[OK] Multiplication test samples: {len(mult_samples_test_run)}")

        # Subtraction test set
        sub_samples_test_run = []
        attempts_sub = 0; max_attempts_sub = N_SUB_TEST_SAMPLES * 100
        while len(sub_samples_test_run) < N_SUB_TEST_SAMPLES and attempts_sub < max_attempts_sub:
            attempts_sub += 1
            a = random.randint(SUB_MIN_DIFFERENCE + 1, SUB_MAX_A)
            max_b = a - SUB_MIN_DIFFERENCE
            if max_b < 1: continue
            b = random.randint(1, max_b)
            d = a - b
            if d >= SUB_MIN_DIFFERENCE:
                hd = (d // 100) % 10
                sub_samples_test_run.append({'a':a,'b':b,'difference':d,'label':hd,'prompt':format_sub_prompt(a,b)})
        if len(sub_samples_test_run) > N_SUB_TEST_SAMPLES:
            random.shuffle(sub_samples_test_run)
            sub_samples_test_run = sub_samples_test_run[:N_SUB_TEST_SAMPLES]
        print(f"[OK] Subtraction test samples: {len(sub_samples_test_run)}")

        # --- Extract activations ---
        from tqdm.auto import tqdm
        print("[ACT] Extracting activations for ADDITION...")
        prompts_add_all_run = [s['prompt'] for s in all_add_samples_final_run]
        activations_add_all_run = get_activations(prompts_add_all_run, layers_to_probe, model, tokenizer, batch_size=BATCH_SIZE_PROBE, desc_prefix=f"ADD S{current_seed}")

        activations_mult_test_run = {}; y_mult_test_labels_run = np.array([])
        if mult_samples_test_run:
            prompts_mult_test_run = [s['prompt'] for s in mult_samples_test_run]
            activations_mult_test_run = get_activations(prompts_mult_test_run, layers_to_probe, model, tokenizer, batch_size=BATCH_SIZE_PROBE, desc_prefix=f"MULT S{current_seed}")
            y_mult_test_labels_run = np.array([s['label'] for s in mult_samples_test_run])

        activations_sub_test_run = {}; y_sub_test_labels_run = np.array([])
        if sub_samples_test_run:
            prompts_sub_test_run = [s['prompt'] for s in sub_samples_test_run]
            activations_sub_test_run = get_activations(prompts_sub_test_run, layers_to_probe, model, tokenizer, batch_size=BATCH_SIZE_PROBE, desc_prefix=f"SUB S{current_seed}")
            y_sub_test_labels_run = np.array([s['label'] for s in sub_samples_test_run])

        # --- Train & evaluate ---
        y_add_all_run = np.array([s['label'] for s in all_add_samples_final_run])
        print(f"[PROBE] Start layer-wise probing ({min(layers_to_probe)}..{max(layers_to_probe)})")

        for layer in tqdm(layers_to_probe, desc=f"Probing Layers S{current_seed}"):
            layer_results_run = {"layer": layer, "accuracy_add": np.nan, "cross_accuracy_mult": np.nan, "cross_accuracy_sub": np.nan, "report_add": "Skipped"}
            if layer not in activations_add_all_run or activations_add_all_run[layer].shape[0] == 0:
                with open(run_log_filename, "a") as f:
                    f.write(f"--- Layer {layer} ---\nSkipped: No ADDITION activation data\n{'-'*20}\n")
                with open(aggregated_log_filename, "a") as agg_f:
                    agg_f.write(f"L{layer}: ADD No Act\n")
                overall_aggregated_results[layer]["accuracy_add"].append(np.nan)
                overall_aggregated_results[layer]["cross_accuracy_mult"].append(np.nan)
                overall_aggregated_results[layer]["cross_accuracy_sub"].append(np.nan)
                continue

            X_add_all_layer_run = activations_add_all_run[layer]
            if X_add_all_layer_run.shape[0] != len(y_add_all_run):
                with open(run_log_filename, "a") as f:
                    f.write(f"--- Layer {layer} ---\nSkipped: Mismatch activations/labels\n{'-'*20}\n")
                with open(aggregated_log_filename, "a") as agg_f:
                    agg_f.write(f"L{layer}: ADD Mismatch\n")
                overall_aggregated_results[layer]["accuracy_add"].append(np.nan)
                overall_aggregated_results[layer]["cross_accuracy_mult"].append(np.nan)
                overall_aggregated_results[layer]["cross_accuracy_sub"].append(np.nan)
                continue

            unique_labels_run = np.unique(y_add_all_run)
            if len(unique_labels_run) < 2:
                with open(run_log_filename, "a") as f:
                    f.write(f"--- Layer {layer} ---\nSkipped: Insufficient unique classes ({len(unique_labels_run)})\n{'-'*20}\n")
                with open(aggregated_log_filename, "a") as agg_f:
                    agg_f.write(f"L{layer}: ADD FewClasses\n")
                overall_aggregated_results[layer]["accuracy_add"].append(np.nan)
                overall_aggregated_results[layer]["cross_accuracy_mult"].append(np.nan)
                overall_aggregated_results[layer]["cross_accuracy_sub"].append(np.nan)
                continue

            OUTPUT_DIM = 10
            try:
                X_add_train_run, X_add_test_run, y_add_train_run, y_add_test_run = train_test_split(
                    X_add_all_layer_run, y_add_all_run, test_size=TEST_SPLIT_SIZE, random_state=current_seed, stratify=y_add_all_run
                )
            except ValueError as e_split:
                with open(run_log_filename, "a") as f:
                    f.write(f"--- Layer {layer} ---\nSkipped: Train/Test split error: {e_split}\n{'-'*20}\n")
                with open(aggregated_log_filename, "a") as agg_f:
                    agg_f.write(f"L{layer}: ADD SplitErr\n")
                overall_aggregated_results[layer]["accuracy_add"].append(np.nan)
                overall_aggregated_results[layer]["cross_accuracy_mult"].append(np.nan)
                overall_aggregated_results[layer]["cross_accuracy_sub"].append(np.nan)
                continue

            X_add_train_run, y_add_train_run = shuffle(X_add_train_run, y_add_train_run, random_state=current_seed)
            INPUT_DIM_run = X_add_train_run.shape[1]
            HIDDEN_DIM_run = int(INPUT_DIM_run * HIDDEN_DIM_FACTOR)

            probe = LinearProbe(INPUT_DIM_run, OUTPUT_DIM).to(DEVICE)
            criterion = nn.CrossEntropyLoss()
            optimizer = optim.Adam(probe.parameters(), lr=LEARNING_RATE_PROBE)

            train_dataset = TensorDataset(torch.tensor(X_add_train_run, dtype=torch.float32).to(DEVICE),
                                          torch.tensor(y_add_train_run, dtype=torch.long).to(DEVICE))
            train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE_PROBE, shuffle=True)

            probe.train()
            for _ in range(NUM_EPOCHS_PROBE):
                for batch_X, batch_y in train_loader:
                    optimizer.zero_grad()
                    logits = probe(batch_X)
                    loss = criterion(logits, batch_y)
                    loss.backward()
                    optimizer.step()

            # Save the probe for this layer
            layer_probe_dir = os.path.join(PROBE_SAVE_DIR_RUN, f"layer_{layer}")
            os.makedirs(layer_probe_dir, exist_ok=True)
            torch.save(probe.state_dict(), os.path.join(layer_probe_dir, "state_dict.pt"))

            # ADD test
            probe.eval()
            with torch.no_grad():
                logits = probe(torch.tensor(X_add_test_run, dtype=torch.float32).to(DEVICE))
                _, pred_idx = torch.max(logits, 1)
                y_pred = pred_idx.cpu().numpy()
                y_true = torch.tensor(y_add_test_run, dtype=torch.long).cpu().numpy()
                acc_add = accuracy_score(y_true, y_pred)
                try:
                    report_add = classification_report(
                        y_true, y_pred, labels=list(range(OUTPUT_DIM)), digits=3, zero_division=0,
                        target_names=[f"Digit {i}" for i in range(OUTPUT_DIM)]
                    )
                except ValueError as e:
                    report_add = f"Acc: {acc_add:.4f}, Report failed: {e}"

            layer_results_run["accuracy_add"] = acc_add
            layer_results_run["report_add"] = report_add
            overall_aggregated_results[layer]["accuracy_add"].append(acc_add)

            # MULT transfer
            acc_cross_mult = np.nan
            if mult_samples_test_run and layer in activations_mult_test_run and activations_mult_test_run[layer].shape[0] > 0:
                Xm = activations_mult_test_run[layer]
                if Xm.shape[0] == len(y_mult_test_labels_run):
                    with torch.no_grad():
                        logits_m = probe(torch.tensor(Xm, dtype=torch.float32).to(DEVICE))
                        _, pred_m = torch.max(logits_m, 1)
                        acc_cross_mult = accuracy_score(y_mult_test_labels_run, pred_m.cpu().numpy())
            layer_results_run["cross_accuracy_mult"] = acc_cross_mult
            overall_aggregated_results[layer]["cross_accuracy_mult"].append(acc_cross_mult)

            # SUB transfer
            acc_cross_sub = np.nan
            if sub_samples_test_run and layer in activations_sub_test_run and activations_sub_test_run[layer].shape[0] > 0:
                Xs = activations_sub_test_run[layer]
                if Xs.shape[0] == len(y_sub_test_labels_run):
                    with torch.no_grad():
                        logits_s = probe(torch.tensor(Xs, dtype=torch.float32).to(DEVICE))
                        _, pred_s = torch.max(logits_s, 1)
                        acc_cross_sub = accuracy_score(y_sub_test_labels_run, pred_s.cpu().numpy())
            layer_results_run["cross_accuracy_sub"] = acc_cross_sub
            overall_aggregated_results[layer]["cross_accuracy_sub"].append(acc_cross_sub)

            # Logging
            with open(run_log_filename, "a") as f_run:
                f_run.write(f"--- Layer {layer} ---\n")
                f_run.write(f"ADDITION Test Accuracy: {acc_add:.4f}\n")
                f_run.write(f"Cross-Task Accuracy (Mult): {acc_cross_mult:.4f}\n")
                f_run.write(f"Cross-Task Accuracy (Sub): {acc_cross_sub:.4f}\n")
                f_run.write(f"ADDITION Classification Report:\n{report_add}\n{'-'*20}\n")
            with open(aggregated_log_filename, "a") as f_agg:
                f_agg.write(f"L{layer}: AddAcc={acc_add:.4f}, MultAcc={acc_cross_mult:.4f}, SubAcc={acc_cross_sub:.4f}\n")

        print(f"[DONE] Layer-wise probing finished for Seed {current_seed}.")

    # ========= Aggregate statistics + plot (per model) =========
    print("\n[AGG] Computing final statistics...")
    alpha = 1.0 - CONFIDENCE_LEVEL
    with open(aggregated_log_filename, "a") as f_agg:
        f_agg.write("\n\n--- Aggregated Statistics Across Runs ---\n")
        f_agg.write(f"Number of Repetitions: {N_REPETITIONS}, Confidence Level: {CONFIDENCE_LEVEL*100}%\n")

    sorted_layers = sorted(list(overall_aggregated_results.keys()))
    plot_data = {
        "layers": sorted_layers,
        "add": {"mean": [], "ci_margin": []},
        "mult": {"mean": [], "ci_margin": []},
        "sub": {"mean": [], "ci_margin": []}
    }

    for layer in sorted_layers:
        with open(aggregated_log_filename, "a") as f_agg:
            f_agg.write(f"\n--- Layer {layer} ---\n")
        for metric_key, display_name, plot_key in [
            ("accuracy_add", "ADDITION Test Accuracy", "add"),
            ("cross_accuracy_mult", "Cross-Task Accuracy (MULT)", "mult"),
            ("cross_accuracy_sub", "Cross-Task Accuracy (SUB)", "sub")
        ]:
            vals = np.array([v for v in overall_aggregated_results[layer][metric_key] if not np.isnan(v)])
            mean_acc = np.nan
            ci_lower = ci_upper = np.nan
            ci_margin = 0.0
            if len(vals) >= 2:
                mean_acc = np.mean(vals)
                se = scipy.stats.sem(vals)
                if se > 0:
                    lo, hi = scipy.stats.t.interval(CONFIDENCE_LEVEL, len(vals)-1, loc=mean_acc, scale=se)
                    ci_lower, ci_upper = lo, hi
                    ci_margin = (hi - lo) / 2
                else:
                    ci_lower = ci_upper = mean_acc
                    ci_margin = 0.0
            elif len(vals) == 1:
                mean_acc = vals[0]
            plot_data[plot_key]["mean"].append(mean_acc)
            plot_data[plot_key]["ci_margin"].append(ci_margin)
            line = f"{display_name}: Mean={mean_acc:.4f}, {CONFIDENCE_LEVEL*100}% CI=({ci_lower:.4f}, {ci_upper:.4f}), N_valid_runs={len(vals)}"
            print(line)
            with open(aggregated_log_filename, "a") as f_agg:
                f_agg.write(line + "\n")

    # Plotting
    plt.figure(figsize=(12, 8))
    plt.rcParams.update({
        'font.size': 14, 'font.family': 'serif', 'font.sans-serif': ['Arial'],
        'axes.labelsize': 14, 'axes.titlesize': 20,
        'xtick.labelsize': 12, 'ytick.labelsize': 12, 'legend.fontsize': 14,
        'lines.linewidth': 1.8, 'lines.markersize': 6,
        'axes.grid': True, 'grid.alpha': 0.5, 'grid.linestyle': ':',
        'savefig.dpi': 600, 'savefig.format': 'pdf', 'savefig.bbox': 'tight',
    })

    # ADD
    add_means = np.array(plot_data["add"]["mean"])
    add_cis = np.array(plot_data["add"]["ci_margin"])
    plt.plot(plot_data["layers"], add_means, marker='o', linestyle='-', label='Mean Probe Accuracy on ADDITION')
    plt.fill_between(plot_data["layers"], add_means - add_cis, add_means + add_cis, alpha=0.2, label=f'{CONFIDENCE_LEVEL*100}% CI (ADD)')

    # MULT
    mult_means = np.array(plot_data["mult"]["mean"])
    mult_cis = np.array(plot_data["mult"]["ci_margin"])
    plt.plot(plot_data["layers"], mult_means, marker='x', linestyle=':', label='Mean Probe Accuracy on MULTIPLICATION')
    plt.fill_between(plot_data["layers"], mult_means - mult_cis, mult_means + mult_cis, alpha=0.2, label=f'{CONFIDENCE_LEVEL*100}% CI (MULT)')

    # SUB
    sub_means = np.array(plot_data["sub"]["mean"])
    sub_cis = np.array(plot_data["sub"]["ci_margin"])
    plt.plot(plot_data["layers"], sub_means, marker='^', linestyle='--', label='Mean Probe Accuracy on SUBTRACTION')
    plt.fill_between(plot_data["layers"], sub_means - sub_cis, sub_means + sub_cis, alpha=0.2, label=f'{CONFIDENCE_LEVEL*100}% CI (SUB)')

    # Chance level
    chance = 1.0 / 10
    plt.axhline(y=chance, linestyle='-.', label=f'Chance Level ({chance:.2f}, 10 classes)')

    plt.xlabel(f"Model Layer Index ({safe_name})")
    plt.ylabel("Mean Probe Accuracy Score")
    plt.title(f"[{safe_name}] Mean Probe Performance & {int(CONFIDENCE_LEVEL*100)}% CI (N={N_REPETITIONS})")
    if plot_data["layers"]:
        plt.xticks(plot_data["layers"])
    plt.ylim(0.0, 1.05)
    plt.legend(loc='best', fontsize='small')
    plt.tight_layout()

    plot_pdf = os.path.join(model_dir, f"{safe_name}_mean_probe_accuracy_CI_{experiment_timestamp}.pdf")
    plot_png = os.path.join(model_dir, f"{safe_name}_mean_probe_accuracy_CI_{experiment_timestamp}.png")
    plt.savefig(plot_pdf, format='pdf', bbox_inches='tight')
    plt.savefig(plot_png, bbox_inches='tight')
    print(f"[SAVE] Plot saved: {plot_pdf}\n[SAVE] Plot saved: {plot_png}")
    plt.close()

    print(f"[LOG] Aggregated log: {aggregated_log_filename}")
    print(f"[OK ] Model {model_id} finished; results at: {model_dir}")

    # Free GPU memory
    del model
    try:
        torch.cuda.empty_cache()
    except Exception:
        pass
    gc.collect()

# ============ 3. Entry point: run multiple models in sequence ============
if __name__ == "__main__":
    if not MODELS:
        print("Please provide at least one model path or HF repo id in MODELS.")
    for mid in MODELS:
        try:
            run_experiment_for_model(mid)
        except Exception as e:
            print(f"[ERROR] Model {mid} failed to run: {e}")


[INIT] Main experiment directory: ./experiment_hundreds_CI_20250909_164549
[INIT] Using device: cuda:0

===== Loading model: /root/autodl-tmp/Qwen2.5-Math-7B =====


Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

[OK] Model loaded.
[INFO] Layers to probe (first 5 shown): [0, 1, 2, 3, 4]... total 29 layers (including embedding layer 0).

--- Repetition 1/5 | Seed = 42 ---
[OK] Total addition train+test samples: 5000
[OK] Multiplication test samples: 500
[OK] Subtraction test samples: 500
[ACT] Extracting activations for ADDITION...


ADD S42 Extracting activations:   0%|          | 0/79 [00:00<?, ?it/s]

MULT S42 Extracting activations:   0%|          | 0/8 [00:00<?, ?it/s]

SUB S42 Extracting activations:   0%|          | 0/8 [00:00<?, ?it/s]

[PROBE] Start layer-wise probing (0..28)


Probing Layers S42:   0%|          | 0/29 [00:00<?, ?it/s]

[DONE] Layer-wise probing finished for Seed 42.

--- Repetition 2/5 | Seed = 43 ---
[OK] Total addition train+test samples: 5000
[OK] Multiplication test samples: 500
[OK] Subtraction test samples: 500
[ACT] Extracting activations for ADDITION...


ADD S43 Extracting activations:   0%|          | 0/79 [00:00<?, ?it/s]

MULT S43 Extracting activations:   0%|          | 0/8 [00:00<?, ?it/s]

SUB S43 Extracting activations:   0%|          | 0/8 [00:00<?, ?it/s]

[PROBE] Start layer-wise probing (0..28)


Probing Layers S43:   0%|          | 0/29 [00:00<?, ?it/s]

[DONE] Layer-wise probing finished for Seed 43.

--- Repetition 3/5 | Seed = 44 ---
[OK] Total addition train+test samples: 5000
[OK] Multiplication test samples: 500
[OK] Subtraction test samples: 500
[ACT] Extracting activations for ADDITION...


ADD S44 Extracting activations:   0%|          | 0/79 [00:00<?, ?it/s]

MULT S44 Extracting activations:   0%|          | 0/8 [00:00<?, ?it/s]

SUB S44 Extracting activations:   0%|          | 0/8 [00:00<?, ?it/s]

[PROBE] Start layer-wise probing (0..28)


Probing Layers S44:   0%|          | 0/29 [00:00<?, ?it/s]

[DONE] Layer-wise probing finished for Seed 44.

--- Repetition 4/5 | Seed = 45 ---
[OK] Total addition train+test samples: 5000
[OK] Multiplication test samples: 500
[OK] Subtraction test samples: 500
[ACT] Extracting activations for ADDITION...


ADD S45 Extracting activations:   0%|          | 0/79 [00:00<?, ?it/s]

MULT S45 Extracting activations:   0%|          | 0/8 [00:00<?, ?it/s]

SUB S45 Extracting activations:   0%|          | 0/8 [00:00<?, ?it/s]

[PROBE] Start layer-wise probing (0..28)


Probing Layers S45:   0%|          | 0/29 [00:00<?, ?it/s]

[DONE] Layer-wise probing finished for Seed 45.

--- Repetition 5/5 | Seed = 46 ---
[OK] Total addition train+test samples: 5000
[OK] Multiplication test samples: 500
[OK] Subtraction test samples: 500
[ACT] Extracting activations for ADDITION...


ADD S46 Extracting activations:   0%|          | 0/79 [00:00<?, ?it/s]

MULT S46 Extracting activations:   0%|          | 0/8 [00:00<?, ?it/s]

SUB S46 Extracting activations:   0%|          | 0/8 [00:00<?, ?it/s]

[PROBE] Start layer-wise probing (0..28)


Probing Layers S46:   0%|          | 0/29 [00:00<?, ?it/s]

[DONE] Layer-wise probing finished for Seed 46.

[AGG] Computing final statistics...
ADDITION Test Accuracy: Mean=0.1277, 95.0% CI=(0.1223, 0.1331), N_valid_runs=5
Cross-Task Accuracy (MULT): Mean=0.0664, 95.0% CI=(0.0152, 0.1176), N_valid_runs=5
Cross-Task Accuracy (SUB): Mean=0.0760, 95.0% CI=(0.0611, 0.0909), N_valid_runs=5
ADDITION Test Accuracy: Mean=0.1301, 95.0% CI=(0.1185, 0.1416), N_valid_runs=5
Cross-Task Accuracy (MULT): Mean=0.0832, 95.0% CI=(0.0428, 0.1236), N_valid_runs=5
Cross-Task Accuracy (SUB): Mean=0.0764, 95.0% CI=(0.0074, 0.1454), N_valid_runs=5
ADDITION Test Accuracy: Mean=0.1389, 95.0% CI=(0.1240, 0.1537), N_valid_runs=5
Cross-Task Accuracy (MULT): Mean=0.0700, 95.0% CI=(0.0466, 0.0934), N_valid_runs=5
Cross-Task Accuracy (SUB): Mean=0.1324, 95.0% CI=(0.0185, 0.2463), N_valid_runs=5
ADDITION Test Accuracy: Mean=0.1437, 95.0% CI=(0.1303, 0.1571), N_valid_runs=5
Cross-Task Accuracy (MULT): Mean=0.1212, 95.0% CI=(0.0600, 0.1824), N_valid_runs=5
Cross-Task Accuracy (



Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

[OK] Model loaded.
[INFO] Layers to probe (first 5 shown): [0, 1, 2, 3, 4]... total 29 layers (including embedding layer 0).

--- Repetition 1/5 | Seed = 42 ---
[OK] Total addition train+test samples: 5000
[OK] Multiplication test samples: 500
[OK] Subtraction test samples: 500
[ACT] Extracting activations for ADDITION...


ADD S42 Extracting activations:   0%|          | 0/79 [00:00<?, ?it/s]

MULT S42 Extracting activations:   0%|          | 0/8 [00:00<?, ?it/s]

SUB S42 Extracting activations:   0%|          | 0/8 [00:00<?, ?it/s]

[PROBE] Start layer-wise probing (0..28)


Probing Layers S42:   0%|          | 0/29 [00:00<?, ?it/s]

[DONE] Layer-wise probing finished for Seed 42.

--- Repetition 2/5 | Seed = 43 ---
[OK] Total addition train+test samples: 5000
[OK] Multiplication test samples: 500
[OK] Subtraction test samples: 500
[ACT] Extracting activations for ADDITION...


ADD S43 Extracting activations:   0%|          | 0/79 [00:00<?, ?it/s]

MULT S43 Extracting activations:   0%|          | 0/8 [00:00<?, ?it/s]

SUB S43 Extracting activations:   0%|          | 0/8 [00:00<?, ?it/s]

[PROBE] Start layer-wise probing (0..28)


Probing Layers S43:   0%|          | 0/29 [00:00<?, ?it/s]

[DONE] Layer-wise probing finished for Seed 43.

--- Repetition 3/5 | Seed = 44 ---
[OK] Total addition train+test samples: 5000
[OK] Multiplication test samples: 500
[OK] Subtraction test samples: 500
[ACT] Extracting activations for ADDITION...


ADD S44 Extracting activations:   0%|          | 0/79 [00:00<?, ?it/s]

MULT S44 Extracting activations:   0%|          | 0/8 [00:00<?, ?it/s]

SUB S44 Extracting activations:   0%|          | 0/8 [00:00<?, ?it/s]

[PROBE] Start layer-wise probing (0..28)


Probing Layers S44:   0%|          | 0/29 [00:00<?, ?it/s]

[DONE] Layer-wise probing finished for Seed 44.

--- Repetition 4/5 | Seed = 45 ---
[OK] Total addition train+test samples: 5000
[OK] Multiplication test samples: 500
[OK] Subtraction test samples: 500
[ACT] Extracting activations for ADDITION...


ADD S45 Extracting activations:   0%|          | 0/79 [00:00<?, ?it/s]

MULT S45 Extracting activations:   0%|          | 0/8 [00:00<?, ?it/s]

SUB S45 Extracting activations:   0%|          | 0/8 [00:00<?, ?it/s]

[PROBE] Start layer-wise probing (0..28)


Probing Layers S45:   0%|          | 0/29 [00:00<?, ?it/s]

[DONE] Layer-wise probing finished for Seed 45.

--- Repetition 5/5 | Seed = 46 ---
[OK] Total addition train+test samples: 5000
[OK] Multiplication test samples: 500
[OK] Subtraction test samples: 500
[ACT] Extracting activations for ADDITION...


ADD S46 Extracting activations:   0%|          | 0/79 [00:00<?, ?it/s]

MULT S46 Extracting activations:   0%|          | 0/8 [00:00<?, ?it/s]

SUB S46 Extracting activations:   0%|          | 0/8 [00:00<?, ?it/s]

[PROBE] Start layer-wise probing (0..28)


Probing Layers S46:   0%|          | 0/29 [00:00<?, ?it/s]

[DONE] Layer-wise probing finished for Seed 46.

[AGG] Computing final statistics...
ADDITION Test Accuracy: Mean=0.1290, 95.0% CI=(0.1246, 0.1333), N_valid_runs=5
Cross-Task Accuracy (MULT): Mean=0.0596, 95.0% CI=(0.0079, 0.1113), N_valid_runs=5
Cross-Task Accuracy (SUB): Mean=0.0728, 95.0% CI=(0.0515, 0.0941), N_valid_runs=5
ADDITION Test Accuracy: Mean=0.1226, 95.0% CI=(0.1064, 0.1387), N_valid_runs=5
Cross-Task Accuracy (MULT): Mean=0.0604, 95.0% CI=(0.0185, 0.1023), N_valid_runs=5
Cross-Task Accuracy (SUB): Mean=0.1096, 95.0% CI=(0.0272, 0.1920), N_valid_runs=5
ADDITION Test Accuracy: Mean=0.1378, 95.0% CI=(0.1139, 0.1616), N_valid_runs=5
Cross-Task Accuracy (MULT): Mean=0.1052, 95.0% CI=(0.0819, 0.1285), N_valid_runs=5
Cross-Task Accuracy (SUB): Mean=0.0616, 95.0% CI=(0.0001, 0.1231), N_valid_runs=5
ADDITION Test Accuracy: Mean=0.1370, 95.0% CI=(0.1276, 0.1463), N_valid_runs=5
Cross-Task Accuracy (MULT): Mean=0.0912, 95.0% CI=(0.0153, 0.1671), N_valid_runs=5
Cross-Task Accuracy (