In [1]:
# --- Imports ---
import os
import re
import time
import json
import random
import logging
import datetime
from collections import defaultdict
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm.auto import tqdm
from torch.utils.data import TensorDataset, DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
from scipy import stats
import matplotlib
matplotlib.use("Agg")  # non-interactive backend (no GUI)
import matplotlib.pyplot as plt


# ========== Global logging ==========
logger = logging.getLogger("probe")
logger.propagate = False  # avoid duplicate output to root logger


def setup_logger(log_file_path: str, also_console: bool = False, console_level=logging.WARNING):
    """Configure the logger: write to a file and optionally also to the console.

    Args:
        log_file_path: Path to the log file to write.
        also_console: If True, also log to the console.
        console_level: Console log level if also_console is True.
    """
    logger.setLevel(logging.INFO)
    for h in list(logger.handlers):
        logger.removeHandler(h)

    os.makedirs(os.path.dirname(log_file_path), exist_ok=True)
    file_handler = logging.FileHandler(log_file_path, encoding="utf-8")
    file_handler.setLevel(logging.INFO)
    fmt = logging.Formatter("%(asctime)s [%(levelname)s] %(message)s")
    file_handler.setFormatter(fmt)
    logger.addHandler(file_handler)

    if also_console:
        console_handler = logging.StreamHandler()
        console_handler.setLevel(console_level)
        console_handler.setFormatter(fmt)
        logger.addHandler(console_handler)

    logger.info("======= Logger initialized =======")
    logger.info(f"Log file: {log_file_path}")


# --- Global Configuration (each model runs the main pipeline separately) ---
# ########################################################################
#                          ↓↓↓ Global config ↓↓↓                         #
# ########################################################################

# Put multiple HF repo IDs or local paths here (ensure they are available)
MODELS = [
    "/root/autodl-tmp/llama",
    "/root/autodl-tmp/AceMath",
    "/root/autodl-tmp/Mistral",
    "/root/autodl-tmp/Qwen2.5-7B-Instruct",
    "/root/autodl-tmp/Qwen2.5-Math-7B"
]

def sanitize_model_name(name: str) -> str:
    """
    Convert an HF model name or local path into a safe folder name.
    - Take the last path component if there's a '/'
    - Replace non-alphanumeric characters with underscores
    """
    # Take the "last component"
    last = name.rstrip("/").split("/")[-1] if "/" in name.rstrip("/") else name
    # If it's a disk path but the last component is empty, fall back one level
    if last == "":
        last = Path(name).name
    # Replace illegal characters
    safe = re.sub(r"[^A-Za-z0-9_.-]+", "_", last)
    return safe if safe else "model"

DEVICE_ID = 0

# === Automatic layer selection flags ===
AUTO_INCLUDE_EMBEDDING_STATE = False  # True includes hidden_states[0] (embeddings) for probing
AUTO_EVERY_K_LAYERS = 1               # Probe every k layers (use 2/3/4 for large models)
AUTO_LIMIT_MAX_LAYERS = None          # Limit to first N layers; None = no limit
USE_TQDM = False                      # Disable progress bar (key milestones logged instead)

# --- 3-Way Commutativity Probe Config ---
N_SAMPLES_PER_CLASS_3WAY = 500
TOTAL_SAMPLES_3WAY = N_SAMPLES_PER_CLASS_3WAY * 3
BATCH_SIZE_PROBE = 64
NUM_EPOCHS_PROBE = 10
LEARNING_RATE_PROBE = 0.001
TEST_SPLIT_SIZE = 0.2
OUTPUT_DIM_3WAY = 3

# Multiple replicates (each model will run with these seeds)
N_REPLICATES = 5
RANDOM_SEEDS_LIST = [42, 43, 44, 45, 46]
if len(RANDOM_SEEDS_LIST) != N_REPLICATES:
    raise ValueError("Length of RANDOM_SEEDS_LIST must be equal to N_REPLICATES")

# --- OOD Test Config ---
OOD_TEST_SAMPLES_PER_CLASS = {1: 100, 2: 100, 3: 100, 4: 100}

# --- Single root output directory (each model writes under its subfolder) ---
BASE_OUTPUT_DIR = "./outputs"  # final path: ./outputs/<model_name_short>/
# ########################################################################


# --- Helper Classes (Probe Definitions) ---
class PyTorchMLPProbe(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.layer_1 = nn.Linear(input_dim, hidden_dim)
        self.layer_2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = self.layer_1(x)
        x = self.layer_2(x)
        return x


class LinearProbe(nn.Module):
    def __init__(self, in_dim: int, n_classes: int):
        super().__init__()
        self.linear = nn.Linear(in_dim, n_classes, bias=True)

    def forward(self, x):
        return self.linear(x)


# --- Function Definitions ---
def setup_environment(seed, device_id):
    """Set random seeds and choose computation device."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        device = torch.device(f"cuda:{device_id}")
    else:
        device = torch.device("cpu")
    logger.info(f"Using device: {device}")
    return device


def load_model_and_tokenizer(model_name, device):
    """Load pretrained model and tokenizer."""
    logger.info(f"Loading model: {model_name} ...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        output_hidden_states=True,
        torch_dtype=torch.bfloat16,  # adjust if needed
    )
    model.to(device)
    model.eval()
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    logger.info("Model and tokenizer loaded successfully.")
    return model, tokenizer


def resolve_layers_to_probe(model, tokenizer, device,
                            include_embedding=False, every_k=1, limit=None,
                            sample_text="Hello"):
    """
    Run a minimal forward pass to read the length of hidden_states and return
    the indices to probe.
    - If include_embedding=False, start from index 1 (Transformer block outputs only)
    - every_k: sample every k-th layer
    - limit: keep only the first N indices
    """
    model.eval()
    with torch.no_grad():
        inputs = tokenizer(sample_text, return_tensors="pt").to(device)
        outputs = model(**inputs, output_hidden_states=True)
        n_h = len(outputs.hidden_states)  # typically num_hidden_layers + 1
    start = 0 if include_embedding else 1
    indices = list(range(start, n_h))
    if every_k and every_k > 1:
        indices = indices[::every_k]
    if limit is not None:
        indices = indices[:limit]
    logger.info(f"[Auto] hidden_states length = {n_h}; probing indices = {indices}")
    return indices


def generate_3way_commutativity_data(n_samples_per_class, seed):
    """Generate training data for 3-way commutativity probing."""
    logger.info(f"--- Generating 3-Way Commutativity Data ---")
    logger.info(f"Target: {n_samples_per_class} samples per class (a+b (a>b), b+a (a>b), a+a)")
    _random = random.Random(seed)
    samples_class_0, samples_class_1, samples_class_2 = [], [], []

    # Class 2 (a == b)
    max_attempts_c2 = n_samples_per_class * 5
    count_c2 = 0
    while len(samples_class_2) < n_samples_per_class and count_c2 < max_attempts_c2:
        a = _random.randint(10, 99)
        prompt = f"Calculate: {a}+{a} = "
        samples_class_2.append({'prompt': prompt, 'label': 2, 'a': a, 'b': a})
        count_c2 += 1
    if count_c2 >= max_attempts_c2:
        logger.warning("Max attempts reached generating Class 2 samples.")

    # Class 0/1 (a > b)
    generated_pairs_c01 = 0
    max_attempts_c01 = n_samples_per_class * 10
    while generated_pairs_c01 < n_samples_per_class and generated_pairs_c01 * 2 < max_attempts_c01:
        a = _random.randint(10, 99)
        b = _random.randint(10, 99)
        if a == b:
            continue
        if a < b:
            a, b = b, a
        prompt0 = f"Calculate: {a}+{b} = "
        samples_class_0.append({'prompt': prompt0, 'label': 0, 'a': a, 'b': b})
        prompt1 = f"Calculate: {b}+{a} = "
        samples_class_1.append({'prompt': prompt1, 'label': 1, 'a': a, 'b': b})
        generated_pairs_c01 += 1
    if generated_pairs_c01 * 2 >= max_attempts_c01:
        logger.warning("Max attempts reached generating Class 0/1 samples.")

    all_samples = samples_class_0 + samples_class_1 + samples_class_2
    _random.shuffle(all_samples)
    logger.info(f"Generated final dataset with {len(all_samples)} samples.")
    if all_samples:
        all_labels_list = [s['label'] for s in all_samples]
        final_counts_dict = {i: all_labels_list.count(i) for i in range(OUTPUT_DIM_3WAY)}
        logger.info(f"Final label distribution: {final_counts_dict}")
    else:
        raise ValueError("No 3-WAY COMMUTATIVITY samples generated.")
    return all_samples


def get_activations(model, tokenizer, prompts, layers_to_probe, device,
                    batch_size=32, desc_prefix=""):
    """Get hidden state activations (last token) for specified layers and inputs."""
    logger.info(f"[{desc_prefix}] Extracting activations for {len(prompts)} prompts, layers={layers_to_probe}")
    activations = {layer: [] for layer in layers_to_probe}
    model.eval()
    iterable = range(0, len(prompts), batch_size)
    iterator = tqdm(iterable, desc=f"{desc_prefix} Extract", disable=not USE_TQDM)
    with torch.no_grad():
        for i in iterator:
            batch_prompts = prompts[i: i + batch_size]
            try:
                inputs = tokenizer(batch_prompts, return_tensors="pt", padding=True,
                                   truncation=True, max_length=64).to(device)
                outputs = model(**inputs)
                hidden_states = outputs.hidden_states
                sequence_lengths = inputs.attention_mask.sum(dim=1)
                last_token_indices = sequence_lengths - 1
                for layer_idx in layers_to_probe:
                    if not (0 <= layer_idx < len(hidden_states)):
                        continue
                    layer_hidden_states = hidden_states[layer_idx]
                    last_token_activations = layer_hidden_states[
                        torch.arange(layer_hidden_states.size(0)), last_token_indices, :
                    ]
                    activations[layer_idx].append(
                        last_token_activations.to(torch.float32).cpu().numpy()
                    )
            except Exception as e:
                logger.exception(f"Error processing batch at index {i} ({desc_prefix}): {e}")

    final_activations = {}
    for layer in layers_to_probe:
        if activations[layer]:
            try:
                final_activations[layer] = np.concatenate(activations[layer], axis=0)
            except ValueError as e:
                logger.exception(f"Error concatenating for layer {layer} ({desc_prefix}): {e}")
                final_activations[layer] = np.array([])
        else:
            final_activations[layer] = np.array([])
    logger.info(f"[{desc_prefix}] Finished extracting activations.")
    return final_activations


def train_and_evaluate_probes(
        all_samples, extracted_activations, layers_to_probe, device,
        log_filename_base="probe_log", output_dir="./outputs",
        num_epochs=NUM_EPOCHS_PROBE, batch_size=BATCH_SIZE_PROBE, learning_rate=LEARNING_RATE_PROBE,
        test_split_size=TEST_SPLIT_SIZE, random_seed=42, output_dim=OUTPUT_DIM_3WAY
):
    """Train, evaluate, and save a probe per layer; return evaluation results."""
    y_all = np.array([s['label'] for s in all_samples])
    results = {"layer": [], "accuracy": [], "report_dict": [], "f1_class0": [], "f1_class1": [], "f1_class2": []}

    os.makedirs(output_dir, exist_ok=True)
    experiment_timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    per_run_report_path = os.path.join(output_dir, f"{log_filename_base}_{experiment_timestamp}.txt")
    logger.info(f"Per-layer classification reports -> {per_run_report_path}")

    with open(per_run_report_path, "w", encoding="utf-8") as f:
        f.write(f"Output dir: {output_dir}\n")
        f.write(f"Layers: {layers_to_probe}\n")
        f.write(f"Num samples: {len(all_samples)}\n\n")

    logger.info(f"--- Starting Probe Training Loop for layers {layers_to_probe[0]}..{layers_to_probe[-1]} ---")

    for layer in tqdm(layers_to_probe, desc="Probe train/eval", disable=not USE_TQDM):
        logger.info(f"--- Processing Layer {layer} ---")
        if layer not in extracted_activations or extracted_activations[layer].shape[0] == 0:
            logger.warning(f"Skipping layer {layer}: No activation data.")
            results["layer"].append(layer)
            results["accuracy"].append(np.nan)
            results["report_dict"].append({})
            results["f1_class0"].append(np.nan)
            results["f1_class1"].append(np.nan)
            results["f1_class2"].append(np.nan)
            continue

        X_all_layer = extracted_activations[layer]
        if X_all_layer.shape[0] != len(y_all):
            logger.warning(f"Skipping layer {layer}: Mismatch activations ({X_all_layer.shape[0]}) / labels ({len(y_all)}).")
            results["layer"].append(layer)
            results["accuracy"].append(np.nan)
            results["report_dict"].append({})
            results["f1_class0"].append(np.nan)
            results["f1_class1"].append(np.nan)
            results["f1_class2"].append(np.nan)
            continue

        try:
            X_train, X_test, y_train, y_test = train_test_split(
                X_all_layer, y_all, test_size=test_split_size, random_state=random_seed, stratify=y_all
            )
        except ValueError as e:
            logger.exception(f"Skipping layer {layer}: Error during train/test split: {e}")
            results["layer"].append(layer)
            results["accuracy"].append(np.nan)
            results["report_dict"].append({})
            results["f1_class0"].append(np.nan)
            results["f1_class1"].append(np.nan)
            results["f1_class2"].append(np.nan)
            continue

        input_dim_probe = X_train.shape[1]
        probe = LinearProbe(input_dim_probe, output_dim).to(device)

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(probe.parameters(), lr=learning_rate)

        X_train_tensor = torch.tensor(X_train, dtype=torch.float32).to(device)
        y_train_tensor = torch.tensor(y_train, dtype=torch.long).to(device)
        X_test_tensor = torch.tensor(X_test, dtype=torch.float32).to(device)
        y_test_tensor = torch.tensor(y_test, dtype=torch.long).to(device)

        train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

        # Train
        probe.train()
        for epoch in range(num_epochs):
            for batch_X, batch_y in train_loader:
                optimizer.zero_grad()
                outputs = probe(batch_X)
                loss = criterion(outputs, batch_y)
                loss.backward()
                optimizer.step()
        logger.info(f"Layer {layer} - Training completed.")

        # Save probe state_dict
        layer_probe_dir = os.path.join(output_dir, f"layer_{layer}")
        os.makedirs(layer_probe_dir, exist_ok=True)
        state_dict_path = os.path.join(layer_probe_dir, "state_dict.pt")
        torch.save(probe.state_dict(), state_dict_path)
        logger.info(f"Saved probe state_dict -> {state_dict_path}")

        # Eval
        probe.eval()
        with torch.no_grad():
            test_outputs = probe(X_test_tensor)
            _, predicted_indices = torch.max(test_outputs, 1)
            y_pred_np = predicted_indices.cpu().numpy()
            y_true_np = y_test_tensor.cpu().numpy()

            acc_test = accuracy_score(y_true_np, y_pred_np)
            target_names = ['a+b (a>b)', 'b+a (a>b)', 'a+a (a=b)']

            report_dict_sklearn = classification_report(
                y_true_np, y_pred_np, labels=list(range(output_dim)),
                target_names=target_names, digits=3, zero_division=0, output_dict=True
            )
            report_str_sklearn = classification_report(
                y_true_np, y_pred_np, labels=list(range(output_dim)),
                target_names=target_names, digits=3, zero_division=0, output_dict=False
            )

            f1_c0 = report_dict_sklearn.get(target_names[0], {}).get('f1-score', np.nan)
            f1_c1 = report_dict_sklearn.get(target_names[1], {}).get('f1-score', np.nan)
            f1_c2 = report_dict_sklearn.get(target_names[2], {}).get('f1-score', np.nan)

        results["layer"].append(layer)
        results["accuracy"].append(acc_test)
        results["report_dict"].append(report_dict_sklearn)
        results["f1_class0"].append(f1_c0)
        results["f1_class1"].append(f1_c1)
        results["f1_class2"].append(f1_c2)

        logger.info(f"Layer {layer} - Test Acc: {acc_test:.4f} | F1(C0,C1,C2): {f1_c0:.3f}, {f1_c1:.3f}, {f1_c2:.3f}")

        with open(per_run_report_path, "a", encoding="utf-8") as f:
            f.write(f"--- Layer {layer} ---\n")
            f.write(f"Test Accuracy: {acc_test:.4f}\n")
            f.write(f"F1-Score Class 0: {f1_c0:.4f}\n")
            f.write(f"F1-Score Class 1: {f1_c1:.4f}\n")
            f.write(f"F1-Score Class 2: {f1_c2:.4f}\n")
            f.write(f"Classification Report:\n{report_str_sklearn}\n")
            f.write("-" * 20 + "\n")

    logger.info("--- Probe Training Loop Finished ---")
    return results


def generate_ood_comm_samples(n_digits, n_per_class, seed, exclude_prompts=None):
    """Generate OOD samples; optionally exclude specific prompts."""
    _local_rng = random.Random(seed)
    samples_0, samples_1, samples_2 = [], [], []
    low, high = (10 ** (n_digits - 1), 10 ** n_digits - 1) if n_digits > 1 else (1, 9)

    # Class 2 (a==b)
    attempts_c2 = 0
    max_attempts_c2_ood = n_per_class * 50
    while len(samples_2) < n_per_class and attempts_c2 < max_attempts_c2_ood:
        a = _local_rng.randint(low, high)
        p = f"Calculate: {a}+{a} = "
        if exclude_prompts and p in exclude_prompts:
            attempts_c2 += 1
            continue
        samples_2.append(dict(prompt=p, label=2, a=a, b=a))
        if exclude_prompts:
            exclude_prompts.add(p)
        attempts_c2 = 0
    if len(samples_2) < n_per_class:
        logger.warning(f"[WARN] OOD Class-2 ({n_digits}-digit): Got {len(samples_2)}/{n_per_class}")

    # Class 0/1
    attempts_c01 = 0
    max_attempts_c01_ood = n_per_class * 50
    while len(samples_0) < n_per_class and attempts_c01 < max_attempts_c01_ood:
        a, b = _local_rng.randint(low, high), _local_rng.randint(low, high)
        if a == b:
            attempts_c01 += 1
            continue
        if a < b:
            a, b = b, a
        p0, p1 = f"Calculate: {a}+{b} = ", f"Calculate: {b}+{a} = "
        if exclude_prompts and (p0 in exclude_prompts or p1 in exclude_prompts):
            attempts_c01 += 1
            continue
        samples_0.append(dict(prompt=p0, label=0, a=a, b=b))
        samples_1.append(dict(prompt=p1, label=1, a=a, b=b))
        if exclude_prompts:
            exclude_prompts.update([p0, p1])
        attempts_c01 = 0
    if len(samples_0) < n_per_class:
        logger.warning(f"[WARN] OOD Class-0/1 ({n_digits}-digit): Got {len(samples_0)}/{n_per_class}")

    return samples_0 + samples_1 + samples_2


def evaluate_probe_on_ood_layer(model_llm, tokenizer_llm, layer_idx, probe_model,
                                prompts, labels, device, batch_size_activations):
    """Evaluate a specific layer's probe on OOD data."""
    activations_ood_layer = get_activations(
        model_llm, tokenizer_llm, prompts, [layer_idx], device,
        batch_size_activations, desc_prefix=f"OOD_L{layer_idx}"
    )
    if layer_idx not in activations_ood_layer or activations_ood_layer[layer_idx].shape[0] == 0:
        logger.warning(f"Could not get OOD activations for layer {layer_idx}")
        return np.nan, "No OOD activations"

    X_ood = activations_ood_layer[layer_idx]
    if X_ood.shape[0] != len(labels):
        logger.warning(f"Mismatch OOD activations ({X_ood.shape[0]}) and labels ({len(labels)}) for layer {layer_idx}")
        return np.nan, "Mismatch OOD activations/labels"

    X_ood_t = torch.tensor(X_ood, dtype=torch.float32).to(device)
    probe_model.eval()
    with torch.no_grad():
        y_pred_ood = probe_model(X_ood_t).argmax(dim=1).cpu().numpy()

    acc = accuracy_score(labels, y_pred_ood)
    report = classification_report(labels, y_pred_ood, labels=[0, 1, 2],
                                   target_names=['a+b', 'b+a', 'a+a'],
                                   zero_division=0, digits=3, output_dict=False)
    return acc, report


def evaluate_all_probes_on_ood(model_llm, tokenizer_llm, layers_to_probe, probes_dir_for_run,
                               ood_datasets_map, device, input_dim_probes,
                               output_dim_probes=OUTPUT_DIM_3WAY,
                               batch_size_activations=BATCH_SIZE_PROBE):
    """Evaluate all saved probes across all OOD datasets (one pass per dataset)."""
    logger.info("--- Evaluating Probes on OOD Datasets (one pass per dataset) ---")

    # 1) Preload all layer probes to reduce repeated I/O
    probes = {}
    for layer in layers_to_probe:
        probe_path = os.path.join(probes_dir_for_run, f"layer_{layer}", "state_dict.pt")
        if not os.path.exists(probe_path):
            logger.warning(f"Probe for layer {layer} missing at {probe_path}")
            continue
        probe = LinearProbe(input_dim_probes, output_dim_probes).to(device)
        probe.load_state_dict(torch.load(probe_path, map_location=device))
        probe.eval()
        probes[layer] = probe

    ood_summary = defaultdict(list)
    ood_summary['layer'] = layers_to_probe

    # 2) For each OOD dataset: extract activations for all layers in a single pass
    for digits, ds_ood in ood_datasets_map.items():
        prompts_ood = [s['prompt'] for s in ds_ood]
        labels_ood = np.array([s['label'] for s in ds_ood])

        acts_by_layer = get_activations(
            model_llm, tokenizer_llm, prompts_ood, layers_to_probe, device,
            batch_size=batch_size_activations, desc_prefix=f"OOD_{digits}d_ALL"
        )  # key: one-time extraction for all layers

        acc_key = f"acc_{digits}d"
        for layer in layers_to_probe:
            if layer not in probes or layer not in acts_by_layer or acts_by_layer[layer].size == 0:
                ood_summary[acc_key].append(float('nan'))
                continue
            X = torch.tensor(acts_by_layer[layer], dtype=torch.float32, device=device)
            with torch.no_grad():
                pred = probes[layer](X).argmax(dim=1).cpu().numpy()
            acc = (pred == labels_ood).mean()
            ood_summary[acc_key].append(float(acc))
            logger.info(f"[OOD {digits}d] Layer {layer}: acc={acc:.4f}")

    return ood_summary



def generate_final_comparison_plots_with_ci(
        all_runs_results_in_domain,   # list of results_3way_in_domain dicts
        all_runs_ood_summaries,       # list of ood_evaluation_summary dicts
        layers_plotted_config,        # e.g., per-model LAYERS_TO_PROBE
        output_dir,
        plot_timestamp,
        model_name_short,
        n_replicates
):
    """
    Generate final in-domain vs OOD performance comparison plots (line + bar)
    with approximate 95% CIs, and save to a unified folder per model.
    Note: layers_plotted_config is the model's own layer index list; each model
    is plotted and saved separately.
    """
    logger.info("--- Generating Final Aggregated Plots (with CI) ---")
    os.makedirs(output_dir, exist_ok=True)
    num_layers = len(layers_plotted_config)

    # Initialize aggregation containers
    acc_2d_in_domain_runs = np.full((n_replicates, num_layers), np.nan)
    acc_1d_ood_runs = np.full((n_replicates, num_layers), np.nan)
    acc_2d_ood_runs = np.full((n_replicates, num_layers), np.nan)
    acc_3d_ood_runs = np.full((n_replicates, num_layers), np.nan)
    acc_4d_ood_runs = np.full((n_replicates, num_layers), np.nan)

    mean_acc_2d_in_domain_bar_runs = np.full(n_replicates, np.nan)
    mean_acc_1d_ood_bar_runs = np.full(n_replicates, np.nan)
    mean_acc_2d_ood_bar_runs = np.full(n_replicates, np.nan)
    mean_acc_3d_ood_bar_runs = np.full(n_replicates, np.nan)
    mean_acc_4d_ood_bar_runs = np.full(n_replicates, np.nan)

    # Fill data from each run
    for i_run in range(n_replicates):
        res_in_domain_run = all_runs_results_in_domain[i_run]
        layer_to_idx_map_in_domain = {l: idx for idx, l in enumerate(res_in_domain_run['layer'])}

        current_run_in_domain_accs = []
        for layer_config_idx, layer_val in enumerate(layers_plotted_config):
            if layer_val in layer_to_idx_map_in_domain:
                source_idx = layer_to_idx_map_in_domain[layer_val]
                acc = res_in_domain_run['accuracy'][source_idx]
                acc_2d_in_domain_runs[i_run, layer_config_idx] = acc
                if not np.isnan(acc):
                    current_run_in_domain_accs.append(acc)
        if current_run_in_domain_accs:
            mean_acc_2d_in_domain_bar_runs[i_run] = np.nanmean(current_run_in_domain_accs)

        # OOD
        ood_summary_run = all_runs_ood_summaries[i_run]
        layer_to_idx_map_ood = {l: idx for idx, l in enumerate(ood_summary_run['layer'])}
        current_run_ood_accs = {'1d': [], '2d': [], '3d': [], '4d': []}
        for layer_config_idx, layer_val in enumerate(layers_plotted_config):
            if layer_val in layer_to_idx_map_ood:
                source_idx = layer_to_idx_map_ood[layer_val]
                if source_idx < len(ood_summary_run.get('acc_1d', [])):
                    acc_1d = ood_summary_run['acc_1d'][source_idx]
                    acc_1d_ood_runs[i_run, layer_config_idx] = acc_1d
                    if not np.isnan(acc_1d): current_run_ood_accs['1d'].append(acc_1d)
                if source_idx < len(ood_summary_run.get('acc_2d', [])):
                    acc_2d = ood_summary_run['acc_2d'][source_idx]
                    acc_2d_ood_runs[i_run, layer_config_idx] = acc_2d
                    if not np.isnan(acc_2d): current_run_ood_accs['2d'].append(acc_2d)
                if source_idx < len(ood_summary_run.get('acc_3d', [])):
                    acc_3d = ood_summary_run['acc_3d'][source_idx]
                    acc_3d_ood_runs[i_run, layer_config_idx] = acc_3d
                    if not np.isnan(acc_3d): current_run_ood_accs['3d'].append(acc_3d)
                if source_idx < len(ood_summary_run.get('acc_4d', [])):
                    acc_4d = ood_summary_run['acc_4d'][source_idx]
                    acc_4d_ood_runs[i_run, layer_config_idx] = acc_4d
                    if not np.isnan(acc_4d): current_run_ood_accs['4d'].append(acc_4d)

        if current_run_ood_accs['1d']: mean_acc_1d_ood_bar_runs[i_run] = np.nanmean(current_run_ood_accs['1d'])
        if current_run_ood_accs['2d']: mean_acc_2d_ood_bar_runs[i_run] = np.nanmean(current_run_ood_accs['2d'])
        if current_run_ood_accs['3d']: mean_acc_3d_ood_bar_runs[i_run] = np.nanmean(current_run_ood_accs['3d'])
        if current_run_ood_accs['4d']: mean_acc_4d_ood_bar_runs[i_run] = np.nanmean(current_run_ood_accs['4d'])

    # Stats: mean and approx 95% CI
    def get_mean_and_ci(data_runs_per_layer, confidence=0.95):
        mean_per_layer = np.nanmean(data_runs_per_layer, axis=0)
        sem_per_layer = stats.sem(data_runs_per_layer, axis=0, nan_policy='omit')
        sem_per_layer = np.nan_to_num(sem_per_layer)
        h_per_layer = 1.96 * sem_per_layer
        return mean_per_layer, np.clip(mean_per_layer - h_per_layer, 0, 1), np.clip(mean_per_layer + h_per_layer, 0, 1)

    mean_2d_in, ci_low_2d_in, ci_high_2d_in = get_mean_and_ci(acc_2d_in_domain_runs)
    mean_1d_ood, ci_low_1d_ood, ci_high_1d_ood = get_mean_and_ci(acc_1d_ood_runs)
    mean_2d_ood, ci_low_2d_ood, ci_high_2d_ood = get_mean_and_ci(acc_2d_ood_runs)
    mean_3d_ood, ci_low_3d_ood, ci_high_3d_ood = get_mean_and_ci(acc_3d_ood_runs)
    mean_4d_ood, ci_low_4d_ood, ci_high_4d_ood = get_mean_and_ci(acc_4d_ood_runs)

    layers_for_plot = np.array(layers_plotted_config)

    # Matplotlib style
    plt.rcParams.update({
        'font.size': 14,
        'font.family': 'serif',
        'font.sans-serif': ['Arial'],
        'axes.labelsize': 14,
        'axes.titlesize': 20,
        'xtick.labelsize': 12,
        'ytick.labelsize': 12,
        'legend.fontsize': 14,
        'lines.linewidth': 1.8,
        'lines.markersize': 6,
        'axes.grid': True,
        'grid.alpha': 0.5,
        'grid.linestyle': ':',
        'savefig.dpi': 600,
        'savefig.format': 'pdf',
        'savefig.bbox': 'tight',
    })
    colors = plt.cm.get_cmap('tab10', 5)
    color_map = {
        '2-digit_in': colors(0), '1-digit_ood': colors(1), '2-digit_ood': colors(2),
        '3-digit_ood': colors(3), '4-digit_ood': colors(4), 'chance': 'grey'
    }
    marker_map = {'2-digit_in': 'o', '1-digit_ood': '^', '2-digit_ood': 'd',
                  '3-digit_ood': 's', '4-digit_ood': 'v'}

    # Line plot
    fig1, ax1 = plt.subplots(figsize=(10, 5.5))

    def plot_line_with_ci(ax, x, mean, ci_low, ci_high, color, marker, label):
        ax.plot(x, mean, marker=marker, label=label, color=color)
        ax.fill_between(x, ci_low, ci_high, color=color, alpha=0.2)

    plot_line_with_ci(ax1, layers_for_plot, mean_2d_in, ci_low_2d_in, ci_high_2d_in,
                      color_map['2-digit_in'], marker_map['2-digit_in'], '2-digit (In-Domain Train)')
    ax1.plot(layers_for_plot, mean_2d_ood, marker=marker_map['2-digit_ood'],
             label='2-digit (OOD Test)', color=color_map['2-digit_ood'], linestyle='--')
    ax1.fill_between(layers_for_plot, ci_low_2d_ood, ci_high_2d_ood, color=color_map['2-digit_ood'], alpha=0.2)
    plot_line_with_ci(ax1, layers_for_plot, mean_1d_ood, ci_low_1d_ood, ci_high_1d_ood,
                      color_map['1-digit_ood'], marker_map['1-digit_ood'], '1-digit (OOD Test)')
    plot_line_with_ci(ax1, layers_for_plot, mean_3d_ood, ci_low_3d_ood, ci_high_3d_ood,
                      color_map['3-digit_ood'], marker_map['3-digit_ood'], '3-digit (OOD Test)')
    plot_line_with_ci(ax1, layers_for_plot, mean_4d_ood, ci_low_4d_ood, ci_high_4d_ood,
                      color_map['4-digit_ood'], marker_map['4-digit_ood'], '4-digit (OOD Test)')

    ax1.axhline(y=1.0 / 3, linestyle='--', linewidth=1.0, label='Chance (0.333)',
                color=color_map['chance'], zorder=0)
    model_tag = sanitize_model_name(model_id)  
    ax1.set_xlabel(f"Layer Index({model_tag})")
    ax1.set_ylabel("Mean Accuracy")
    ax1.set_title(f"Probe Generalisation Across Digit Lengths ({n_replicates} Runs)")
    ax1.set_xticks(layers_for_plot)
    ax1.set_ylim(0.0, 1.05)
    ax1.legend(loc='lower right', fontsize=12, frameon=True,
               facecolor='white', framealpha=0.8, borderpad=0.2,
               labelspacing=0.2, handlelength=1.0, handletextpad=0.4)
    ax1.spines['top'].set_visible(False)
    ax1.spines['right'].set_visible(False)
    ax1.tick_params(direction='in', top=False, right=False)
    ax1.minorticks_on()
    ax1.grid(True, linestyle=':', linewidth=0.4, which='minor', alpha=0.4, axis='y')
    ax1.grid(True, linestyle=':', linewidth=0.6, which='major', alpha=0.7, axis='y')
    plt.tight_layout(pad=0.5)
    pdf_line_path = os.path.join(output_dir, f"AGG_probe_comm_generalisation_line_{model_name_short}_{plot_timestamp}.pdf")
    plt.savefig(pdf_line_path)
    plt.close(fig1)
    logger.info(f"Saved aggregated per-layer line chart (PDF) -> {pdf_line_path}")

    # Bar chart
    bar_means = {
        "2-digit (Train)": np.nanmean(np.nanmean(acc_2d_in_domain_runs, axis=1)),
        "1-digit (OOD)": np.nanmean(np.nanmean(acc_1d_ood_runs, axis=1)),
        "2-digit (OOD)": np.nanmean(np.nanmean(acc_2d_ood_runs, axis=1)),
        "3-digit (OOD)": np.nanmean(np.nanmean(acc_3d_ood_runs, axis=1)),
        "4-digit (OOD)": np.nanmean(np.nanmean(acc_4d_ood_runs, axis=1)),
    }
    bar_sems = {
        "2-digit (Train)": stats.sem(np.nanmean(acc_2d_in_domain_runs, axis=1), nan_policy='omit'),
        "1-digit (OOD)": stats.sem(np.nanmean(acc_1d_ood_runs, axis=1), nan_policy='omit'),
        "2-digit (OOD)": stats.sem(np.nanmean(acc_2d_ood_runs, axis=1), nan_policy='omit'),
        "3-digit (OOD)": stats.sem(np.nanmean(acc_3d_ood_runs, axis=1), nan_policy='omit'),
        "4-digit (OOD)": stats.sem(np.nanmean(acc_4d_ood_runs, axis=1), nan_policy='omit'),
    }
    for k in bar_sems:
        bar_sems[k] = np.nan_to_num(bar_sems[k])

    labels_bar = list(bar_means.keys())
    means_bar_values = [bar_means[k] for k in labels_bar]
    sems_bar_values = [1.96 * bar_sems[k] for k in labels_bar]

    colors_map = plt.cm.get_cmap('tab10', 5)
    bar_colors_map_list = [colors_map(0), colors_map(1), colors_map(2), colors_map(3), colors_map(4)]

    fig2, ax2 = plt.subplots(figsize=(7, 5.5))
    bars = ax2.bar(range(len(labels_bar)), means_bar_values, yerr=sems_bar_values,
                   color=bar_colors_map_list, width=0.6, edgecolor='black', linewidth=0.7, capsize=5)
    ax2.set_xticks(range(len(labels_bar)))
    ax2.set_xticklabels(labels_bar, rotation=25, ha="right")
    ax2.axhline(y=1.0 / 3, linestyle='--', linewidth=1.2, color='grey', zorder=0)
    ax2.set_ylabel("Mean Accuracy (Averaged Across Layers & Runs)")
    ymax = max([v for v in means_bar_values if not np.isnan(v)] + [0.0])
    ax2.set_ylim(0, ymax * 1.15 if ymax > 0 else 1.0)
    for bar in bars:
        yval = bar.get_height()
        ax2.text(bar.get_x() + bar.get_width() / 2.0, yval + 0.005, f'{yval:.3f}',
                 va='bottom', ha='center', fontsize=12)
    ax2.spines['top'].set_visible(False)
    ax2.spines['right'].set_visible(False)
    ax2.tick_params(direction='in', top=False, right=False, bottom=False)
    ax2.yaxis.grid(True, linestyle=':', linewidth=0.6, alpha=0.7)
    ax2.set_axisbelow(True)
    plt.tight_layout(pad=0.5)
    pdf_bar_path = os.path.join(output_dir, f"AGG_probe_comm_generalisation_bar_{model_name_short}_{plot_timestamp}.pdf")
    plt.savefig(pdf_bar_path)
    plt.close(fig2)
    logger.info(f"Saved aggregated mean-accuracy bar chart (PDF) -> {pdf_bar_path}")


# ========== Single-model main routine ==========

def run_experiment_for_model(model_id: str):
    """Run the full pipeline for a single model: load, auto-select layers, train N runs,
    OOD evaluation, and aggregated plotting.
    """
    # Timestamp & short model name
    RUN_TIMESTAMP = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    model_name_short = os.path.basename(model_id.strip('/')) or re.sub(r'[^A-Za-z0-9_.-]+', '_', model_id)

    # Unified output directory (per model)
    OUTPUT_DIR = os.path.join(BASE_OUTPUT_DIR, model_name_short)
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Main log goes under the unified output directory
    LOG_PATH = os.path.join(OUTPUT_DIR, f"run_{RUN_TIMESTAMP}.log")
    setup_logger(LOG_PATH, also_console=False)

    # 1) Environment & model (init with the first seed)
    temp_device = setup_environment(RANDOM_SEEDS_LIST[0], DEVICE_ID)
    try:
        model_llm, tokenizer_llm = load_model_and_tokenizer(model_id, temp_device)
    except Exception as e:
        logger.exception(f"[FATAL] Failed to load model {model_id}: {e}")
        return  # do not stop other models

    # 2) Dynamically determine layers_to_probe (per model)
    layers_to_probe = resolve_layers_to_probe(
        model_llm, tokenizer_llm, temp_device,
        include_embedding=AUTO_INCLUDE_EMBEDDING_STATE,
        every_k=AUTO_EVERY_K_LAYERS,
        limit=AUTO_LIMIT_MAX_LAYERS,
        sample_text="Hello"
    )
    if not isinstance(layers_to_probe, (list, tuple)) or len(layers_to_probe) == 0:
        logger.error("LAYERS_TO_PROBE is empty; skipping this model.")
        try:
            del model_llm; del tokenizer_llm
        except Exception:
            pass
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return
    logger.info(f"Final LAYERS_TO_PROBE for {model_name_short}: {layers_to_probe}")

    # 3) Determine probe input dimension (with a small sample)
    logger.info("Determining probe input dimension...")
    _placeholder_samples = generate_3way_commutativity_data(10, RANDOM_SEEDS_LIST[0])
    _placeholder_prompts = [s['prompt'] for s in _placeholder_samples]
    _placeholder_activations = get_activations(
        model_llm, tokenizer_llm, _placeholder_prompts, [layers_to_probe[0]],
        temp_device, batch_size=10, desc_prefix="PLACEHOLDER"
    )
    if layers_to_probe[0] not in _placeholder_activations or _placeholder_activations[layers_to_probe[0]].size == 0:
        logger.error("Could not get placeholder activations to determine probe input dimension. Skip model.")
        try:
            del model_llm; del tokenizer_llm
        except Exception:
            pass
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return
    probe_input_dimension_fixed = _placeholder_activations[layers_to_probe[0]].shape[1]
    logger.info(f"Determined fixed probe input dimension: {probe_input_dimension_fixed}")
    del _placeholder_samples, _placeholder_prompts, _placeholder_activations

    # 4) Multiple replicates per model
    all_runs_results_in_domain_list = []
    all_runs_ood_summaries_list = []

    for i_run, current_seed in enumerate(RANDOM_SEEDS_LIST, start=1):
        logger.info("=" * 20 + f" RUN {i_run}/{N_REPLICATES} (Seed: {current_seed}) " + "=" * 20)

        current_device = setup_environment(current_seed, DEVICE_ID)
        # Ensure the model is on the current device (usually the same; move if needed)
        try:
            model_llm.to(current_device)
        except Exception as e:
            logger.warning(f"Model .to({current_device}) raised: {e}")

        # Probes & reports for each run are written in this model's subdirectory
        run_identifier = f"seed_{current_seed}"
        probes_dir_for_run = os.path.join(OUTPUT_DIR, f"run_{run_identifier}")
        os.makedirs(probes_dir_for_run, exist_ok=True)

        # 4.1 Generate training data
        all_3way_samples_train = generate_3way_commutativity_data(N_SAMPLES_PER_CLASS_3WAY, seed=current_seed)
        prompts_3way_train = [s['prompt'] for s in all_3way_samples_train]

        # 4.2 Extract activations
        activations_3way_train = get_activations(
            model_llm, tokenizer_llm, prompts_3way_train, layers_to_probe, current_device,
            batch_size=BATCH_SIZE_PROBE, desc_prefix=f"TRAIN_COMM_3WAY_Run{i_run}"
        )

        # 4.3 Train & evaluate probes (save per layer)
        results_3way_this_run = train_and_evaluate_probes(
            all_3way_samples_train, activations_3way_train, layers_to_probe, current_device,
            log_filename_base=f"probe_3way_comm_{model_name_short}_{run_identifier}",
            output_dir=probes_dir_for_run,  # write each layer's state_dict and this run's report here
            random_seed=current_seed,
        )
        all_runs_results_in_domain_list.append(results_3way_this_run)

        # 4.4 Generate OOD data (2-digit excludes any training prompts from this run)
        logger.info(f"--- Generating OOD Datasets for Run {i_run} ---")
        train_prompt_set_this_run = {s['prompt'] for s in all_3way_samples_train}
        ood_datasets_this_run = {
            1: generate_ood_comm_samples(1, OOD_TEST_SAMPLES_PER_CLASS[1], seed=current_seed + 1001),
            2: generate_ood_comm_samples(2, OOD_TEST_SAMPLES_PER_CLASS[2], seed=current_seed + 1002,
                                         exclude_prompts=train_prompt_set_this_run.copy()),
            3: generate_ood_comm_samples(3, OOD_TEST_SAMPLES_PER_CLASS[3], seed=current_seed + 1003),
            4: generate_ood_comm_samples(4, OOD_TEST_SAMPLES_PER_CLASS[4], seed=current_seed + 1004),
        }

        # 4.5 OOD evaluation (load each layer's probe; results logged)
        ood_summary_this_run = evaluate_all_probes_on_ood(
            model_llm, tokenizer_llm, layers_to_probe,
            probes_dir_for_run=probes_dir_for_run,  # load layer probes from this run dir
            ood_datasets_map=ood_datasets_this_run,
            device=current_device,
            input_dim_probes=probe_input_dimension_fixed,
        )
        all_runs_ood_summaries_list.append(ood_summary_this_run)

    # 5) Aggregated plots (with CI) — write under this model's output directory
    plot_timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    generate_final_comparison_plots_with_ci(
        all_runs_results_in_domain_list,
        all_runs_ood_summaries_list,
        layers_to_probe,
        OUTPUT_DIR,
        plot_timestamp,
        model_name_short,
        N_REPLICATES
    )

    logger.info("--- All Runs and Aggregated Plotting Finished (this model) ---")
    logger.info(f"Unified output dir (this model): {OUTPUT_DIR}")

    # 6) Memory cleanup
    try:
        del model_llm; del tokenizer_llm
    except Exception:
        pass
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


# --- Main (multi-model entry) ---
if __name__ == "__main__":
    if not MODELS:
        print("[WARN] MODELS is empty. Please add at least one model path or HF repo id to the MODELS list at the top of the script.")
    else:
        for midx, model_id in enumerate(MODELS, start=1):
            try:
                print(f"\n===== MODEL {midx}/{len(MODELS)}: {model_id} =====")
                run_experiment_for_model(model_id)
            except Exception as e:
                # Catch per-model unhandled exceptions and continue to the next model
                try:
                    logger.exception(f"[FATAL] Uncaught error while running model {model_id}: {e}")
                except Exception:
                    print(f"[FATAL] Uncaught error while running model {model_id}: {e}")
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
        print("\n>>> All models finished (see per-model logs under ./outputs/<model_name_short>/).\n")



===== MODEL 1/5: /root/autodl-tmp/llama =====




Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

  colors = plt.cm.get_cmap('tab10', 5)
  colors_map = plt.cm.get_cmap('tab10', 5)



===== MODEL 2/5: /root/autodl-tmp/AceMath =====


Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

  colors = plt.cm.get_cmap('tab10', 5)
  colors_map = plt.cm.get_cmap('tab10', 5)



===== MODEL 3/5: /root/autodl-tmp/Mistral =====




Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

  colors = plt.cm.get_cmap('tab10', 5)
  colors_map = plt.cm.get_cmap('tab10', 5)



===== MODEL 4/5: /root/autodl-tmp/Qwen2.5-7B-Instruct =====




Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

  colors = plt.cm.get_cmap('tab10', 5)
  colors_map = plt.cm.get_cmap('tab10', 5)



===== MODEL 5/5: /root/autodl-tmp/Qwen2.5-Math-7B =====




Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

  colors = plt.cm.get_cmap('tab10', 5)
  colors_map = plt.cm.get_cmap('tab10', 5)



>>> All models finished (see per-model logs under ./outputs/<model_name_short>/).

