In [None]:
import os, math, time, json, warnings, random
from dataclasses import dataclass, asdict
from typing import List, Dict, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
import seaborn as sns
from contextlib import nullcontext
from tqdm.auto import tqdm

# --- Pretty table printing & optional saving ---
# Maximum number of rows print_table() renders before summarising the rest.
PRINT_MAX_ROWS = 30

# `rich` is optional: if it cannot be imported or its Console cannot be
# created, print_table() falls back to plain pandas rendering.
_RICH = False
_console = None
try:
    from rich.console import Console
    from rich.table import Table
    _console = Console()
    _RICH = True
except Exception:
    _RICH = False
    _console = None

def _fmt_cell(v, floatfmt=".4g"):
    """Render one table cell as a string.

    Integers print without decimals and finite floats use ``floatfmt``.
    ``None``, NaN, +/-inf, booleans and any other object fall back to
    ``str(v)``, so flags keep their ``True``/``False`` spelling instead of
    being coerced to ``1``/``0``.
    """
    try:
        if v is None:
            return str(v)
        # bool (and numpy bool) must be checked before int: bool subclasses int.
        if isinstance(v, (bool, np.bool_)):
            return str(v)
        if isinstance(v, (int, np.integer)):
            return str(int(v))
        if isinstance(v, (float, np.floating)):
            if not np.isfinite(v):
                return str(v)
            return format(float(v), floatfmt)
        return str(v)
    except Exception:
        # Best effort: an unusual cell must never break table printing.
        return str(v)

def print_table(title: str, df: pd.DataFrame, max_rows: int = PRINT_MAX_ROWS, floatfmt: str = ".4g"):
    """Pretty-print at most ``max_rows`` rows of ``df`` under ``title``.

    Uses a ``rich`` table when rich was importable, otherwise
    ``DataFrame.to_string``; floats are formatted with ``floatfmt``. A footer
    reports how many rows were left out. ``None`` or empty frames print a
    one-line placeholder instead.
    """
    if df is None or df.empty:
        print(f"\n[TABLE] {title}: (empty)\n")
        return

    head = df.head(max_rows).copy()
    if _RICH:
        rich_table = Table(title=title, show_lines=True)
        for name in head.columns:
            rich_table.add_column(str(name))
        for _, series in head.iterrows():
            rich_table.add_row(*map(lambda cell: _fmt_cell(cell, floatfmt), series.tolist()))
        _console.print(rich_table)
    else:
        print(f"\n=== {title} ===")
        print(head.to_string(index=False, float_format=lambda val: format(val, floatfmt)))

    # Same truncation footer for both renderers.
    if len(df) > max_rows:
        print(f"... ({len(df) - max_rows} more rows)")

def maybe_save_csv(df: pd.DataFrame, path: str, save: bool):
    """Write ``df`` to ``path`` as CSV (without the index) when ``save`` is true.

    Nothing is written for a ``None`` or empty frame. Missing parent
    directories are created; a bare filename (no directory component) is
    written to the current working directory.
    """
    if not save or df is None or df.empty:
        return
    parent = os.path.dirname(path)
    # os.path.dirname("file.csv") == "" and os.makedirs("") raises, so only
    # create a directory when the path actually names one.
    if parent:
        os.makedirs(parent, exist_ok=True)
    df.to_csv(path, index=False)

# einops supplies `rearrange` and the `Rearrange` layer used by the ViT
# modules below; fail fast with an install hint if it cannot be imported.
try:
    from einops import rearrange
    from einops.layers.torch import Rearrange
except Exception as e:
    raise RuntimeError("Please install einops: pip install einops") from e

# Module-level list of per-run step-log DataFrames; run_exp1_decisive appends
# to it and concatenates it for the pointwise analysis.
EXP1_STEP_LOGS = []

# ------------------------------
# 0) Repro, device, I/O helpers
# ------------------------------
def set_seed(seed: int):
    """Seed Python's ``random``, NumPy and torch (CPU and every CUDA device),
    and force deterministic, non-benchmarking cuDNN kernels."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Global compute device: the first CUDA device when available, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"[INFO] Using device: {device}")
if device.type == "cuda":
    # Enable TF32 matmuls / "high" float32 matmul precision for speed on GPU.
    import torch.backends.cuda as cuda_backends
    cuda_backends.matmul.allow_tf32 = True
    torch.set_float32_matmul_precision("high")

def ensure_dir(p: str):
    """Create directory ``p`` (including parents) if it is missing; return ``p``.

    An empty string denotes the current directory, which already exists, so
    it is returned unchanged rather than passed to ``os.makedirs`` (which
    raises on ``""``).
    """
    if p:
        os.makedirs(p, exist_ok=True)
    return p

# Root directory for every CSV / figure this script writes.
OUT_DIR = ensure_dir("./aistats_release_outputs")

# ------------------------------
# 1) Small utilities
# ------------------------------
def pair(t):
    """Return ``t`` unchanged if it is a tuple, otherwise the 2-tuple ``(t, t)``."""
    if isinstance(t, tuple):
        return t
    return (t, t)

def autocast_ctx(device, precision: str):
    """Return an autocast context for ``precision`` on CUDA, a no-op elsewhere.

    On CUDA, "fp16" / "bf16" enable autocast with that dtype; any other
    precision produces a disabled autocast context.
    """
    if device.type != "cuda":
        return nullcontext()
    low_precision = {"fp16": torch.float16, "bf16": torch.bfloat16}
    enabled = precision in low_precision
    return torch.amp.autocast("cuda", enabled=enabled,
                              dtype=low_precision.get(precision, torch.float32))

def no_autocast_ctx(device):
    """Context that forces full precision: disabled CUDA autocast on GPU,
    a no-op context on any other device type."""
    return torch.amp.autocast("cuda", enabled=False) if device.type == "cuda" else nullcontext()

def bootstrap_ci(x: np.ndarray, stat_fn, n=1000, alpha=0.05, rng=None) -> Tuple[float, float, float]:
    """Return ``(stat, low, high)``: ``stat_fn(x)`` with a percentile-bootstrap CI.

    NaNs are removed first. ``n`` resamples of the same size as ``x`` are
    drawn with replacement, and the ``alpha/2`` and ``1 - alpha/2``
    percentiles of the resampled statistics form the interval (a *percentile*
    bootstrap, not the "basic"/reflected one). ``rng`` may be ``None``, an
    integer seed or a ``np.random.Generator``. If no data remain after
    dropping NaNs, three NaNs are returned.
    """
    # default_rng(None) seeds from the OS, default_rng(Generator) returns it as-is.
    rng = np.random.default_rng(rng)
    x = np.asarray(x)
    x = x[~np.isnan(x)]
    if len(x) == 0:
        return np.nan, np.nan, np.nan
    stat = stat_fn(x)
    boots = np.asarray([stat_fn(rng.choice(x, size=len(x), replace=True)) for _ in range(n)])
    # np.percentile sorts internally, so no explicit sort is needed.
    lo = np.percentile(boots, 100 * alpha / 2.0)
    hi = np.percentile(boots, 100 * (1 - alpha / 2.0))
    return stat, lo, hi

def loglog_regression(x, y):
    """
    Fit ``log10 y = a + b * log10 x`` (i.e. ``y ~ x^b``) by ordinary least squares.

    Only pairs where both values are finite and strictly positive are used.
    Returns a dict with ``n`` (points used), ``slope`` (b), ``intercept`` (a),
    ``R2`` and ``pval``. ``pval`` is a two-sided p-value for the slope from its
    t-statistic under a *normal* approximation (not the exact Student-t), so it
    is optimistic for small ``n``.

    Degenerate inputs return NaN rather than raising or reporting spurious
    significance:
      * fewer than 2 usable points, or all ``x`` identical -> every fitted
        quantity is NaN;
      * exactly 2 points -> the exact line is returned but ``pval`` is NaN
        (no residual degrees of freedom).
    """
    from math import erf

    x = np.asarray(x, dtype=float); y = np.asarray(y, dtype=float)
    keep = (x > 0) & (y > 0) & np.isfinite(x) & np.isfinite(y)
    lx = np.log10(x[keep] + 1e-12); ly = np.log10(y[keep] + 1e-12)
    n = len(lx)
    degenerate = dict(n=n, slope=np.nan, intercept=np.nan, R2=np.nan, pval=np.nan)
    if n < 2:
        return degenerate

    mx, my = lx.mean(), ly.mean()
    dx = lx - mx
    sxx = float((dx ** 2).sum())
    if sxx <= 0.0:
        # All x identical: the slope is undefined (X^T X would be singular).
        return degenerate

    slope = float((dx * (ly - my)).sum() / sxx)
    intercept = float(my - slope * mx)
    resid = ly - (intercept + slope * lx)
    ss_tot = float(((ly - my) ** 2).sum())
    ss_res = float((resid ** 2).sum())
    R2 = 1 - ss_res / max(ss_tot, 1e-12)

    dof = n - 2
    if dof < 1:
        # Two points are fit exactly; there is no residual variance to test against.
        return dict(n=n, slope=slope, intercept=intercept, R2=R2, pval=np.nan)

    # Var(b) = s^2 * [(X^T X)^-1]_{11} = s^2 / Sxx for a simple linear fit.
    se_b = np.sqrt(ss_res / dof / sxx)
    t = slope / max(se_b, 1e-12)
    pval = 2 * (1 - 0.5 * (1 + erf(abs(t) / np.sqrt(2))))
    return dict(n=n, slope=slope, intercept=intercept, R2=R2, pval=pval)

# ------------------------------
# 2) Tiny ViT
# ------------------------------
class PatchEmbedding(nn.Module):
    """Split an image into patches, embed each linearly and prepend a CLS token.

    Input is (B, C, H, W); output is (B, 1 + n_patches, dim) with learnable
    position embeddings added.

    Args:
        image_size: int or (H, W).
        patch_size: int or (pH, pW); must divide H and W exactly.
        dim: embedding width.
        channels: number of input channels C (default 3, i.e. RGB).

    Raises:
        ValueError: if the patch size does not tile the image exactly.
    """

    def __init__(self, image_size, patch_size, dim, channels: int = 3):
        super().__init__()
        H, W = pair(image_size); pH, pW = pair(patch_size)
        # Explicit check: an `assert` would silently disappear under `python -O`.
        if H % pH != 0 or W % pW != 0:
            raise ValueError(f"image size {(H, W)} is not divisible by patch size {(pH, pW)}")
        n_patches = (H // pH) * (W // pW)
        patch_dim = channels * pH * pW
        self.to_patch_embedding = nn.Sequential(
            # (B, C, H, W) -> (B, n_patches, pH * pW * C)
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=pH, p2=pW),
            nn.Linear(patch_dim, dim),
        )
        # Creation order (Linear, pos_embedding, cls_token) fixes the RNG draws
        # for a given seed; keep it stable.
        self.pos_embedding = nn.Parameter(torch.randn(1, n_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape
        x = torch.cat((self.cls_token.expand(b, -1, -1), x), dim=1)
        # Use only as many position embeddings as there are tokens (CLS + n patches).
        x = x + self.pos_embedding[:, : n + 1]
        return x

class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention that exposes its internals.

    ``forward(x)`` with ``x`` of shape (B, N, dim) returns
    ``(output, scores, attn_probs, q, k, v)``: the projected output
    (B, N, dim), the pre-softmax scores and softmax probabilities
    (B, heads, N, N), and per-head q/k/v (B, heads, N, dim_head).
    """

    def __init__(self, dim, heads=8, dim_head=64):
        super().__init__()
        width = heads * dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.to_qkv = nn.Linear(dim, width * 3, bias=False)
        self.to_out = nn.Linear(width, dim)

    def forward(self, x):
        batch, tokens, _ = x.shape
        n_heads = self.heads

        def split_heads(t):
            # (B, N, heads * d) -> (B, heads, N, d)
            return t.reshape(batch, tokens, n_heads, -1).transpose(1, 2)

        q, k, v = (split_heads(t) for t in self.to_qkv(x).chunk(3, dim=-1))
        scores = (q @ k.transpose(-1, -2)) * self.scale           # (B, heads, N, N)
        attn_probs = torch.softmax(scores, dim=-1)
        # (B, heads, N, d) -> (B, N, heads * d)
        merged = (attn_probs @ v).transpose(1, 2).reshape(batch, tokens, -1)
        return self.to_out(merged), scores, attn_probs, q, k, v

class TransformerEncoder(nn.Module):
    """Pre-LayerNorm transformer encoder that also returns per-layer internals.

    Each layer computes ``x = x + Attn(LN1(x))`` then ``x = x + MLP(LN2(x))``.
    ``forward`` returns ``(x, inter)`` where ``inter`` maps a name to a list
    with one entry per layer (two per layer for the LayerNorm keys):
      * 'scores', 'attn_probs' -- from Attention (not detached)
      * 'q', 'k', 'v' -- detached per-head projections
      * 'ln_pre_act', 'ln_modules' -- the outputs of LN1 and LN2 and the
        modules themselves, in order (LN1, LN2) per layer.
        NOTE(review): despite the key name these are LayerNorm *outputs*
        (``ln(x)``), not LayerNorm inputs; confirm this is what the LN
        diagnostics expect.
      * 'mlp_w1', 'mlp_w2' -- detached weights of the MLP's first / second Linear
      * 'mlp_weights' -- same tensors as 'mlp_w2' (used by the weight-cond baseline)
      * 'attn_Wo' -- detached weight of the attention output projection W_O
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            block = nn.ModuleList([
                nn.LayerNorm(dim),
                Attention(dim, heads=heads, dim_head=dim_head),
                nn.LayerNorm(dim),
                nn.Sequential(nn.Linear(dim, mlp_dim), nn.GELU(), nn.Linear(mlp_dim, dim)),
            ])
            self.layers.append(block)

    def forward(self, x):
        names = ('scores', 'attn_probs', 'ln_pre_act', 'ln_modules',
                 'mlp_w1', 'mlp_w2', 'attn_Wo', 'mlp_weights', 'q', 'k', 'v')
        inter = {name: [] for name in names}

        for ln1, attn, ln2, mlp in self.layers:
            h_attn = ln1(x)
            inter['ln_pre_act'].append(h_attn)
            inter['ln_modules'].append(ln1)
            attn_out, scores, attn_probs, q, k, v = attn(h_attn)
            x = x + attn_out

            h_mlp = ln2(x)
            inter['ln_pre_act'].append(h_mlp)
            inter['ln_modules'].append(ln2)
            x = x + mlp(h_mlp)

            inter['scores'].append(scores)
            inter['attn_probs'].append(attn_probs)
            for name, tensor in (('q', q), ('k', k), ('v', v)):
                inter[name].append(tensor.detach())

            # Weights of the first two nn.Linear layers in module order (None if absent).
            linear_weights = [m.weight.detach() for m in mlp.modules() if isinstance(m, nn.Linear)][:2]
            linear_weights += [None] * (2 - len(linear_weights))
            w1, w2 = linear_weights
            inter['mlp_w1'].append(w1)
            inter['mlp_w2'].append(w2)
            # The weight-condition baseline uses W2 (kept for compatibility with older code).
            inter['mlp_weights'].append(w2)
            # Record the attention output projection W_O.
            inter['attn_Wo'].append(attn.to_out.weight.detach())
        return x, inter

class ViT(nn.Module):
    """Minimal Vision Transformer classifier.

    ``forward(img)`` returns ``(logits, inter)`` where ``inter`` is the
    per-layer intermediates dict produced by ``TransformerEncoder``.
    """

    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, dim_head=64):
        super().__init__()
        self.patch_embedding = PatchEmbedding(image_size, patch_size, dim)
        self.transformer = TransformerEncoder(dim, depth, heads, dim_head, mlp_dim)
        # "cls" classifies from the CLS token; any other value mean-pools all tokens.
        self.pool = "cls"
        self.mlp_head = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, num_classes))

    def forward(self, img):
        tokens, inter = self.transformer(self.patch_embedding(img))
        if self.pool == "cls":
            pooled = tokens[:, 0]
        else:
            pooled = tokens.mean(dim=1)
        return self.mlp_head(pooled), inter

# ------------------------------
# 3) Diagnostics (paper-aligned)
# ------------------------------
@torch.no_grad()
def spectral_norm_matrix_batch(mats: torch.Tensor, sample: int = None, iters: int = 7) -> torch.Tensor:
    """
    Estimate the largest singular value of every matrix in a batch.

    Runs ``iters`` power-iteration steps on ``M^T M`` (so rectangular matrices
    are fine) and returns the square root of the final Rayleigh quotient.

    mats: (..., n, d). If ``sample`` is given and ``mats.shape[0]`` exceeds it,
        a random subset of ``sample`` entries along dim 0 is used instead.
    Returns: float32 tensor over the (possibly subsampled) leading batch dims.
    """
    M = mats.float()
    if sample is not None and M.shape[0] > sample:
        M = M[torch.randperm(M.shape[0], device=M.device)[:sample]]
    Mt = M.transpose(-1, -2)

    def unit(vec):
        return vec / (vec.norm(dim=(-2, -1), keepdim=True) + 1e-12)

    vec = unit(torch.randn(M.shape[:-2] + (M.shape[-1], 1), device=M.device))
    for _ in range(iters):
        vec = unit(Mt @ (M @ vec))
    vec_t = vec.transpose(-1, -2)
    rayleigh = (vec_t @ (Mt @ (M @ vec))) / (vec_t @ vec + 1e-12)
    return torch.sqrt(torch.clamp(rayleigh.squeeze(-1).squeeze(-1), min=0)).to(torch.float32)

@torch.no_grad()
def spectral_norm_power(S: torch.Tensor, iters: int = 7) -> torch.Tensor:
    """
    Estimate ``||S||_2`` for a batch of square matrices via power iteration on S^T S.

    S: (..., N, N); the random start vector uses ``S``'s device and dtype.
    Returns: float32 tensor of shape ``S.shape[:-2]``.
    """
    batch_shape, n = S.shape[:-2], S.shape[-1]
    St = S.transpose(-1, -2)

    def col_norm(t):
        return torch.linalg.norm(t, dim=(-2, -1), keepdim=True) + 1e-12

    vec = torch.randn(*batch_shape, n, 1, device=S.device, dtype=S.dtype)
    vec = vec / col_norm(vec)
    for _ in range(iters):
        grown = St @ (S @ vec)
        vec = grown / col_norm(grown)
    vec_t = vec.transpose(-1, -2)
    sigma_sq = (vec_t @ (St @ (S @ vec))) / (vec_t @ vec + 1e-12)
    return torch.sqrt(torch.clamp(sigma_sq.squeeze(-1).squeeze(-1), min=0)).to(torch.float32)

@torch.no_grad()
def kappa_score_pdf(q: torch.Tensor, k: torch.Tensor, S: torch.Tensor,
                    norm_mode: str = "fro",    # "fro" (paper body) or "spec" (protocols)
                    iters: int = 7, sample_heads: int = 64) -> float:
    """Worst-case score-conditioning proxy ``||Q|| * ||K|| / (||S|| * sqrt(d))``.

    q, k: (B, H, N, d); S: (B, H, N, N). Each (batch, head) slice is one
    matrix. ``norm_mode == "fro"`` uses Frobenius norms over all heads; any
    other value uses power-iteration spectral norms on at most
    ``sample_heads`` randomly chosen heads (all heads if ``sample_heads`` is
    None). NaN / +inf ratios are mapped to 0 before taking the max.
    """
    B, H, N, d = q.shape
    Q = q.reshape(B * H, N, d)
    K = k.reshape(B * H, N, d)
    Sm = S.reshape(B * H, N, N)

    if norm_mode == "fro":
        def fro(t):
            return torch.linalg.norm(t, ord='fro', dim=(1, 2))
        Qn, Kn, Sn = fro(Q), fro(K), fro(Sm)
    else:
        n_mats = Q.shape[0]
        if sample_heads is not None and n_mats > sample_heads:
            pick = torch.randperm(n_mats, device=Q.device)[:sample_heads]
            Q, K, Sm = Q[pick], K[pick], Sm[pick]
        Qn = spectral_norm_matrix_batch(Q, sample=None, iters=iters)
        Kn = spectral_norm_matrix_batch(K, sample=None, iters=iters)
        Sn = spectral_norm_power(Sm, iters=iters)

    ratio = (Qn * Kn) / (Sn * math.sqrt(d) + 1e-12)
    ratio = torch.nan_to_num(ratio, nan=0.0, posinf=0.0)
    return float(ratio.max().item())

@torch.no_grad()
def kappa_cond_W(W: torch.Tensor, ridge: float = 1e-6) -> float:
    """Ridge-regularised condition number ``s_max / (s_min + ridge)`` of ``W``.

    Returns NaN when ``W`` is None or has fewer than two singular values.
    """
    if W is not None:
        sv = torch.linalg.svdvals(W.float())
        if sv.numel() >= 2:
            return float((sv.max() / (sv.min() + ridge)).item())
    return float('nan')

@torch.no_grad()
def ln_C_proxy(ln_module: nn.LayerNorm, pre_act: torch.Tensor) -> float:
    """LayerNorm sensitivity proxy ``sigma / sqrt(eps) + sigma^2 / eps + 1``.

    ``sigma^2`` is the median over tokens of the per-token biased variance
    along the last dim of ``pre_act``. ``eps`` is ``ln_module.eps`` (1e-5 if
    the attribute is missing), floored at 1e-12.
    """
    median_var = torch.median(torch.var(pre_act, dim=-1, unbiased=False))
    sigma_med = torch.sqrt(median_var).item()
    var_med = median_var.item()
    eps = float(max(getattr(ln_module, "eps", 1e-5), 1e-12))
    return float(sigma_med / math.sqrt(eps) + var_med / eps + 1.0)

@torch.no_grad()
def kappa_softmax_pdf(attn_probs: torch.Tensor, scores: torch.Tensor, sample_rows: int = 256, iters: int = 7) -> float:
    """Softmax-sensitivity proxy: max over rows of ``lambda_max(J) * ||s|| / ||p||``.

    For an attention row ``p`` with scores ``s`` the softmax Jacobian is
    ``J = diag(p) - p p^T``; its top eigenvalue is estimated with ``iters``
    power-iteration steps without materialising ``J``. ``attn_probs`` and
    ``scores`` are (B, H, N, N); at most ``sample_rows`` of the B*H*N rows are
    randomly sampled.
    """
    probs = attn_probs.detach().to(torch.float32)
    logits = scores.detach().to(torch.float32)
    B, H, N, _ = probs.shape
    prob_rows = probs.reshape(B * H * N, N)
    score_rows = logits.reshape(B * H * N, N)
    n_rows = prob_rows.size(0)
    if n_rows > sample_rows:
        pick = torch.randperm(n_rows, device=prob_rows.device)[:sample_rows]
        prob_rows, score_rows = prob_rows[pick], score_rows[pick]

    def jacobian_apply(vec):
        # J v = p * v - (p . v) p
        return prob_rows * vec - (prob_rows * vec).sum(dim=-1, keepdim=True) * prob_rows

    vec = torch.randn_like(prob_rows)
    vec = vec / (vec.norm(dim=-1, keepdim=True) + 1e-12)
    for _ in range(iters):
        jv = jacobian_apply(vec)
        vec = jv / (jv.norm(dim=-1, keepdim=True) + 1e-12)
    jv = jacobian_apply(vec)
    top_eig = (vec * jv).sum(dim=-1) / (vec * vec).sum(dim=-1).clamp_min(1e-12)
    ratio = score_rows.norm(dim=-1) / (prob_rows.norm(dim=-1) + 1e-12)
    return float((top_eig.clamp(min=0) * ratio).max().item())

@torch.no_grad()
def kappa_V_pdf(v: torch.Tensor, ridge: float = 1e-6, sample_heads: int = 32) -> float:
    """Largest ridge-regularised condition number among per-head value matrices.

    v: (B, H, N, d). Each of the B*H (N, d) matrices (at most
    ``sample_heads``, randomly chosen) contributes ``s_max / (s_min + ridge)``.
    Matrices with fewer than two singular values, or whose SVD fails, are
    skipped; returns NaN when none remain.
    """
    B, H, N, d = v.shape
    n_heads = B * H
    mats = v.reshape(n_heads, N, d).float()
    if n_heads > sample_heads:
        mats = mats[torch.randperm(n_heads, device=mats.device)[:sample_heads]]
    kappas = []
    for mat in mats:
        try:
            sv = torch.linalg.svdvals(mat)
            if sv.numel() >= 2:
                kappas.append((sv.max() / (sv.min() + ridge)).item())
        except Exception:
            continue
    return float(np.nanmax(kappas) if kappas else np.nan)

@torch.no_grad()
def attention_entropy(attn_probs, sample_rows=256):
    """Mean Shannon entropy (natural log) of attention rows.

    ``attn_probs`` is (B, H, T, T); at most ``sample_rows`` of its B*H*T rows
    are randomly sampled before averaging.
    """
    probs = attn_probs.detach().to(torch.float32)
    B, H, T, _ = probs.shape
    flat = probs.reshape(B * H * T, T)
    n_rows = flat.size(0)
    if n_rows > sample_rows:
        flat = flat[torch.randperm(n_rows, device=flat.device)[:sample_rows]]
    row_entropy = -(flat * torch.log(flat + 1e-12)).sum(dim=-1)
    return row_entropy.mean().item()

def _eps_mach_for(precision: str) -> float:
    """Machine epsilon for "fp16", "bf16" or "fp32"; float32's epsilon for anything else."""
    dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}
    try:
        if precision in dtypes:
            return float(torch.finfo(dtypes[precision]).eps)
    except Exception:
        pass
    # fallback
    return float(torch.finfo(torch.float32).eps)

@torch.no_grad()
def rho_LN(ln_module: nn.LayerNorm, pre_act: torch.Tensor, eps_mach: float = None) -> float:
    """LayerNorm rounding proxy ``(sigma^2 / eps) * d_model * eps_mach``.

    ``sigma^2`` is the median per-token biased variance over the last dim of
    ``pre_act`` and ``d_model`` is that dim's size. ``eps`` is
    ``ln_module.eps`` (1e-12 if absent), floored at 1e-12; ``eps_mach``
    defaults to float32 machine epsilon.
    """
    eps_mach = float(torch.finfo(torch.float32).eps) if eps_mach is None else eps_mach
    median_var = torch.median(torch.var(pre_act, dim=-1, unbiased=False)).item()
    ln_eps = max(getattr(ln_module, "eps", 1e-12), 1e-12)
    return float(median_var / ln_eps * pre_act.shape[-1] * eps_mach)

@torch.no_grad()
def forward_error_same_weights(model, x, lp_precision="fp16"):
    """Relative L2 error between a full-precision and a low-precision forward pass.

    Runs ``model(x)`` once with autocast disabled (the reference) and once
    under ``autocast_ctx`` for ``lp_precision``, with the same weights, and
    returns ``||y_lp - y_ref|| / (||y_ref|| + 1e-12)`` as a Python float.

    The contexts are chosen from ``x.device`` via the shared helpers. Off-GPU,
    or for precisions other than "fp16"/"bf16", both passes run in full
    precision and the error should be ~0. Previously a "cuda" autocast was
    requested unconditionally, which only warns and disables itself on CPU.
    """
    dev = x.device
    with no_autocast_ctx(dev):
        y_ref, _ = model(x)
    with autocast_ctx(dev, lp_precision):
        y_lp, _ = model(x)
    diff = torch.norm(y_lp.float() - y_ref.float())
    scale = torch.norm(y_ref.float()) + 1e-12
    return (diff / scale).item()

@torch.no_grad()
def weight_condition_number(weights: List[torch.Tensor]):
    """Condition number ``s_max / s_min`` over the nonzero singular values of the
    LAST matrix in ``weights``.

    Returns NaN for an empty list, when fewer than two nonzero singular values
    remain, or when the SVD raises.
    """
    if len(weights) == 0:
        return np.nan
    last = weights[-1].float()
    try:
        sv = torch.linalg.svdvals(last)
        nonzero = sv[sv > 0]
        return np.nan if nonzero.numel() < 2 else (nonzero.max() / nonzero.min()).item()
    except Exception:
        return np.nan

@torch.no_grad()
def spectral_norm_W(W: torch.Tensor) -> float:
    """Largest singular value of ``W``; NaN if ``W`` is None or has no singular values."""
    if W is None:
        return float('nan')
    sv = torch.linalg.svdvals(W.float())
    if sv.numel() == 0:
        return float('nan')
    return float(sv.max().item())

# ------------------------------
# 4) Data & training
# ------------------------------
def build_dataloader(batch_size=256, num_workers=None, generator=None):
    """Shuffled CIFAR-10 training loader with images normalised to [-1, 1].

    The dataset is downloaded to ./data if needed, and incomplete last
    batches are dropped. ``num_workers`` defaults to the CPU count (4 if
    unknown).

    ``persistent_workers`` and ``prefetch_factor`` are passed only when
    ``num_workers > 0``, because ``DataLoader`` raises ValueError for them in
    single-process mode. ``pin_memory`` is enabled only when CUDA is
    available.
    """
    if num_workers is None:
        num_workers = os.cpu_count() or 4
    transform = transforms.Compose([
        transforms.ToTensor(),                                   # [0, 1]
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # -> [-1, 1]
    ])
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    worker_kwargs = dict(persistent_workers=True, prefetch_factor=4) if num_workers > 0 else {}
    return DataLoader(
        trainset, batch_size=batch_size, shuffle=True, drop_last=True,
        num_workers=num_workers, pin_memory=torch.cuda.is_available(),
        generator=generator, **worker_kwargs,
    )

def _train_ln_C_or_zero(inter, idx):
    """ln_C_proxy for the idx-th recorded LayerNorm; 0.0 if it cannot be computed."""
    try:
        return ln_C_proxy(inter['ln_modules'][idx], inter['ln_pre_act'][idx])
    except Exception:
        # Best-effort diagnostic: a failure here must not abort training.
        return 0.0


def _train_step_diagnostics(model, inputs, inter, precision, ksoft_rows, kscore_mats,
                            power_iters, norm_mode, use_res_relax):
    """Compute the diagnostic columns of one train_one_epoch log row.

    Returns an ordered dict with 'k_softmax_pdf', 'k_score_pdf', 'k_V_pdf',
    'pred_rhs', the per-term sums, the baselines ('entropy', 'w_cond',
    'rho_LN'), the observed 'fwd_error' and 'failure' = 0.
    """
    n_layers = len(inter['scores'])
    per_layer_terms, rho_factors = [], []
    ksoft_all, kscore_all, kV_all = [], [], []
    attn_core_sum = attn_term_sum = k_eff_sum = C_LN_sum = 0.0

    for li in range(n_layers):
        scores = inter['scores'][li]
        probs = inter['attn_probs'][li]
        Wo, W1, W2 = inter['attn_Wo'][li], inter['mlp_w1'][li], inter['mlp_w2'][li]

        # --- Paper-aligned diagnostics ---
        ksoft = kappa_softmax_pdf(probs, scores, sample_rows=ksoft_rows, iters=power_iters)
        kscore = kappa_score_pdf(inter['q'][li], inter['k'][li], scores, norm_mode=norm_mode,
                                 iters=power_iters, sample_heads=kscore_mats)
        kV = kappa_V_pdf(inter['v'][li], ridge=3e-6, sample_heads=kscore_mats)

        Wo_spec = spectral_norm_W(Wo)
        attn_core = ksoft * (1.0 + kscore) * kV
        attn_term = attn_core * Wo_spec
        k_eff = kappa_cond_W(W1) + kappa_cond_W(W2) + kappa_cond_W(Wo)
        # Two LayerNorms are recorded per layer (before attention, before MLP); keep the larger.
        C_LN = max(_train_ln_C_or_zero(inter, 2 * li), _train_ln_C_or_zero(inter, 2 * li + 1))

        per_layer_terms.append(attn_term + k_eff + C_LN)
        if use_res_relax:
            # Residual amplification estimate, capped at 0.9. It is only needed for
            # the relaxed sum, so the extra SVDs are skipped otherwise.
            rho_hat = min(0.9, Wo_spec + spectral_norm_W(W2) * spectral_norm_W(W1))
            rho_factors.append(1.0 + rho_hat)

        ksoft_all.append(ksoft); kscore_all.append(kscore); kV_all.append(kV)
        attn_core_sum += float(attn_core)
        attn_term_sum += float(attn_term)
        k_eff_sum += float(k_eff)
        C_LN_sum += float(C_LN)

    if use_res_relax:
        # Each layer's term is multiplied by the (1 + rho) factors of all later layers.
        pred_sum = sum(term * math.prod(rho_factors[li + 1:])
                       for li, term in enumerate(per_layer_terms))
    else:
        pred_sum = float(np.sum(per_layer_terms))

    diag = {
        'k_softmax_pdf': float(np.max(ksoft_all)),
        'k_score_pdf': float(np.max(kscore_all)),
        'k_V_pdf': float(np.max(kV_all)),
        'pred_rhs': float(pred_sum),
        'attn_core_sum': float(attn_core_sum),
        'attn_term_sum': float(attn_term_sum),
        'k_eff_sum': float(k_eff_sum),
        'C_LN_sum': float(C_LN_sum),
        # baselines
        'entropy': float(np.mean([
            attention_entropy(inter['attn_probs'][li], sample_rows=min(4 * ksoft_rows, 2048))
            for li in range(n_layers)
        ])),
        'w_cond': weight_condition_number(inter['mlp_weights']),
    }
    try:
        diag['rho_LN'] = rho_LN(inter['ln_modules'][-1], inter['ln_pre_act'][-1].detach())
    except Exception:
        diag['rho_LN'] = float('nan')
    diag['fwd_error'] = forward_error_same_weights(model, inputs, lp_precision=precision)
    diag['failure'] = 0
    return diag


def train_one_epoch(model, dataloader, optimizer, criterion,
                    precision="fp16", log_freq=20, max_steps=60,
                    ksoft_rows=256, kscore_mats=64, power_iters=7,
                    norm_mode="spec",
                    use_res_relax=False):
    """Train ``model`` for up to ``max_steps`` batches, logging stability diagnostics.

    Each step runs forward/backward under ``autocast_ctx(device, precision)``,
    with a GradScaler enabled only for fp16 on CUDA. Every ``log_freq`` steps
    (starting at step 0), the model's per-layer intermediates are turned into
    the proxies and combined into ``pred_rhs``:
      * softmax, score and value conditioning;
      * weight condition numbers;
      * LayerNorm terms.
    These are logged alongside baselines (attention entropy, MLP weight
    condition number, ``rho_LN``) and the observed FP32-vs-low-precision
    forward error.

    Training stops early with a row flagged ``failure=1`` as soon as the loss
    is non-finite (NaN or +/-inf).

    Returns:
        DataFrame with one row per logged step (empty if nothing was logged).
    """
    model.train()
    scaler = torch.amp.GradScaler("cuda", enabled=(precision == "fp16" and device.type == "cuda"))
    logs = []

    max_steps = min(max_steps, len(dataloader))
    pbar = tqdm(dataloader, total=max_steps, desc=f"Train(prec={precision})")
    try:
        for step, (inputs, labels) in enumerate(pbar):
            if step >= max_steps:
                break
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            optimizer.zero_grad(set_to_none=True)
            with autocast_ctx(device, precision):
                outputs, inter = model(inputs)
                loss = criterion(outputs, labels)

            # NaN *or* inf loss means the run has diverged: record it and stop
            # before non-finite gradients reach the optimizer.
            if not torch.isfinite(loss):
                logs.append({'step': step, 'loss': float('nan'), 'precision': precision, 'failure': 1})
                break

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            if step % log_freq == 0:
                entry = {'step': step, 'loss': float(loss.detach().cpu()), 'precision': precision}
                entry.update(_train_step_diagnostics(
                    model, inputs, inter, precision, ksoft_rows, kscore_mats,
                    power_iters, norm_mode, use_res_relax))
                logs.append(entry)
                pbar.set_postfix({'loss': entry['loss'], 'k_softmax_pdf': entry['k_softmax_pdf']})
    finally:
        # Breaking out of the loop leaves the progress bar open; close it explicitly.
        pbar.close()

    return pd.DataFrame(logs)

# ------------------------------
# 5) Experiments
# ------------------------------
@dataclass
class ExpConfig:
    """Hyper-parameters for the experiment sweeps.

    Exp1 (run_exp1_decisive) trains one model for every combination of
    ``seeds`` x ``precision`` x ``dims``; the remaining fields are shared
    model / diagnostic / training settings.
    """
    seeds: List[int]          # RNG seeds to sweep
    precision: List[str]      # precision names to sweep ("fp16", "bf16", "fp32")
    dims: List[int]           # ViT embedding widths to sweep (Exp1 uses mlp_dim = 2 * dim)
    image_size: int = 32      # input image side length passed to ViT
    patch: int = 4            # ViT patch side length
    heads: int = 4            # attention heads per layer
    depth: int = 2            # number of transformer layers
    iters_exp1: int = 7       # power-iteration steps for Exp1 diagnostics
    iters_exp2: int = 7       # NOTE(review): presumably power-iteration steps for Exp2; not read in this section
    iters_exp3: int = 7       # NOTE(review): presumably power-iteration steps for Exp3; not read in this section
    steps_exp1: int = 160     # training steps per Exp1 run (capped at len(dataloader))
    steps_exp2: int = 800     # NOTE(review): presumably training steps for Exp2; not read in this section
    steps_exp3: int = 400     # NOTE(review): presumably training steps for Exp3; not read in this section
    ksoft_rows: int = 256     # attention rows sampled for k_softmax
    kscore_mats: int = 64     # heads sampled for k_score and k(V)
    batch_size: int = 256     # NOTE(review): presumably the DataLoader batch size; not read in this section

def _exp1_col_stat(df, col, reducer):
    """Apply a NaN-aware ``reducer`` to ``df[col]``; NaN when the column is absent
    (a run that fails on its very first step logs no diagnostic columns)."""
    if col not in df.columns:
        return np.nan
    return float(reducer(df[col].to_numpy(dtype=float)))


def _exp1_proxy_stats(x, y, label, n_report):
    """Correlate one proxy ``x`` with the observed error ``y``.

    Uses jointly finite pairs: Pearson on log10 values, Spearman on average
    ranks (so ties are handled), plus ``loglog_regression``. With fewer than 3
    such pairs every statistic is NaN and ``n`` is ``n_report``.
    """
    x = np.asarray(x, dtype=float); y = np.asarray(y, dtype=float)
    ok = np.isfinite(x) & np.isfinite(y)
    if ok.sum() < 3:
        return dict(proxy=label, pearson_log=np.nan, spearman=np.nan,
                    n=n_report, slope=np.nan, intercept=np.nan, R2=np.nan, pval=np.nan)
    xo, yo = x[ok], y[ok]
    pearson = np.corrcoef(np.log10(xo + 1e-12), np.log10(yo + 1e-12))[0, 1]
    spearman = np.corrcoef(pd.Series(xo).rank().to_numpy(), pd.Series(yo).rank().to_numpy())[0, 1]
    return dict(proxy=label, pearson_log=pearson, spearman=spearman, **loglog_regression(x, y))


def _exp1_winsorize(s):
    """Clip ``s`` to its 1st..99th percentile range."""
    lo, hi = np.nanpercentile(s, [1, 99])
    return np.clip(s, lo, hi)


def _exp1_loglog_scatter(data, x, y, title, xlabel, fname, **scatter_kwargs):
    """Save a log-log seaborn scatter of ``y`` against ``x`` to ``OUT_DIR/fname``."""
    plt.figure(figsize=(8, 6))
    sns.scatterplot(data=data, x=x, y=y, **scatter_kwargs)
    plt.xscale('log'); plt.yscale('log')
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel('Observed FP32 vs LP error')
    plt.grid(True, which='both', ls='--', alpha=0.5)
    plt.tight_layout()
    plt.savefig(os.path.join(OUT_DIR, fname), dpi=240)
    plt.close()


def run_exp1_decisive(dataloader, cfg: ExpConfig, save_csv: bool = False, print_tables: bool = True):
    """Exp1: does the combined proxy predict the observed low-precision forward error?

    A fresh ViT is trained with ``train_one_epoch`` for ``cfg.steps_exp1``
    steps for every (seed, precision, dim). The global RNGs are re-seeded per
    configuration, so runs sharing a seed start from the same initial weights
    regardless of sweep order. The loader's shuffle order is presumably the
    same too, when it draws from the global torch RNG.

    Each run's step log is summarised (95th percentiles; median for the
    weight condition number). The summaries are then:
      * aggregated over seeds;
      * correlated with the observed error;
      * optionally written as CSV under ``OUT_DIR``;
      * plotted to ``OUT_DIR``.

    Per-run step logs are also appended to the module-level
    ``EXP1_STEP_LOGS``. The pointwise analysis here uses only this call's runs.

    Returns:
        ``(res, agg, stats)``:
          * per-run summaries;
          * per-(precision, dim) aggregates;
          * proxy statistics on the low-precision aggregates.
        All three are empty DataFrames if no run produced a log.
    """
    def q95(a):
        return np.nanquantile(a, 0.95)

    rows = []
    run_logs = []   # this call's step logs only (EXP1_STEP_LOGS accumulates across calls)
    for seed in cfg.seeds:
        for prec in cfg.precision:
            for dim in cfg.dims:
                # Re-seed per configuration, not once per seed. Otherwise a config's
                # initial weights depend on which configs ran before it, and
                # same-seed runs at different precisions are not paired.
                set_seed(seed)
                print(f"[Exp1] seed={seed} prec={prec} dim={dim}")
                model = ViT(image_size=cfg.image_size, patch_size=cfg.patch, num_classes=10,
                            dim=dim, depth=cfg.depth, heads=cfg.heads, mlp_dim=dim*2).to(device)
                opt = optim.Adam(model.parameters(), lr=1e-3)
                crit = nn.CrossEntropyLoss()
                df = train_one_epoch(
                    model, dataloader, opt, crit,
                    precision=prec, max_steps=cfg.steps_exp1,
                    ksoft_rows=cfg.ksoft_rows, kscore_mats=cfg.kscore_mats,
                    power_iters=cfg.iters_exp1, norm_mode="spec"
                )
                if df.empty:
                    continue
                df['seed'] = seed
                df['precision'] = prec
                df['dim'] = dim
                snapshot = df.copy()
                run_logs.append(snapshot)
                EXP1_STEP_LOGS.append(snapshot)

                rows.append(dict(
                    seed=seed, precision=prec, dim=dim,
                    pred_combo=_exp1_col_stat(df, 'pred_rhs', q95),
                    pred_ksoft=_exp1_col_stat(df, 'k_softmax_pdf', q95),
                    pred_kscore=_exp1_col_stat(df, 'k_score_pdf', q95),
                    pred_kV=_exp1_col_stat(df, 'k_V_pdf', q95),
                    pred_entropy=_exp1_col_stat(df, 'entropy', q95),
                    pred_wcond=_exp1_col_stat(df, 'w_cond', np.nanmedian),
                    obs_error=_exp1_col_stat(df, 'fwd_error', q95),
                    failure_rate=float(df['failure'].mean()),
                ))

    res = pd.DataFrame(rows)
    if res.empty:
        print("[Exp1] No results.")
        return res, pd.DataFrame(), pd.DataFrame()

    agg = res.groupby(['precision', 'dim']).agg({
        'pred_combo':   ['mean', 'std'],
        'pred_ksoft':   ['mean', 'std'],
        'pred_kscore':  ['mean', 'std'],
        'pred_kV':      ['mean', 'std'],
        'pred_entropy': ['mean', 'std'],
        'pred_wcond':   ['mean', 'std'],
        'obs_error':    ['mean', 'std'],
        'failure_rate': ['mean']
    }).reset_index()
    # Flatten ('pred_combo', 'mean') -> 'pred_combo_mean'; ('precision', '') -> 'precision'.
    agg.columns = [
        ((f"{c[0]}_{c[1]}" if c[1] else c[0]) if isinstance(c, tuple) else c)
        for c in agg.columns
    ]

    maybe_save_csv(agg, os.path.join(OUT_DIR, "exp1_results_by_config.csv"), save_csv)
    maybe_save_csv(res, os.path.join(OUT_DIR, "exp1_results_raw.csv"), save_csv)
    if print_tables:
        print_table("Exp1 — results by config (aggregated)", agg)
        print_table("Exp1 — results raw (per seed/prec/dim, q95 summaries)", res)

    # Low-precision configurations only (fall back to everything if there are none).
    lp = agg[agg['precision'].isin(['fp16', 'bf16'])].copy()
    if lp.empty:
        lp = agg.copy()

    proxies = [
        ('pred_combo_mean',  'Combined (ours)'),
        ('pred_ksoft_mean',  'k_softmax'),
        ('pred_kscore_mean', 'k_score'),
        ('pred_kV_mean',     'k(V)'),
        ('pred_entropy_mean','Entropy'),
        ('pred_wcond_mean',  'WeightCond')
    ]
    stats = pd.DataFrame([
        _exp1_proxy_stats(lp[col], lp['obs_error_mean'], label, n_report=len(lp))
        for col, label in proxies if col in lp.columns
    ])

    maybe_save_csv(stats, os.path.join(OUT_DIR, "exp1_proxy_stats.csv"), save_csv)
    if print_tables:
        print_table("Exp1 — proxy stats (LP mean vs observed mean)", stats)

    # Scale the prediction by machine epsilon so precisions can be pooled.
    res['eps_mach'] = res['precision'].map(_eps_mach_for)
    res['pred_combo_eps'] = res['pred_combo'] * res['eps_mach']
    maybe_save_csv(res, os.path.join(OUT_DIR, "exp1_results_raw_with_eps.csv"), save_csv)

    raw_proxies = [
        ('pred_combo',     'Combined (ours)'),
        ('pred_combo_eps', 'Combined·εmach'),
        ('pred_ksoft',     'k_softmax'),
        ('pred_kscore',    'k_score'),
        ('pred_kV',        'k(V)'),
        ('pred_entropy',   'Entropy'),
        ('pred_wcond',     'WeightCond'),
    ]

    def compute_stats(df, ycol):
        """Proxy stats on rows where every proxy and ``ycol`` are finite."""
        keep = [c for c, _ in raw_proxies if c in df.columns] + [ycol]
        df = df.replace([np.inf, -np.inf], np.nan).dropna(subset=keep, how='any')
        return pd.DataFrame([
            _exp1_proxy_stats(df[key], df[ycol], lab, n_report=len(df))
            for key, lab in raw_proxies if key in df.columns
        ])

    stats_by_prec = []
    for prec in ['fp16', 'bf16']:
        sub = res[res['precision'] == prec].copy()
        if len(sub) == 0:
            continue
        s = compute_stats(sub, ycol='obs_error')
        s.insert(0, 'precision', prec)
        stats_by_prec.append(s)
    stats_by_prec = pd.concat(stats_by_prec, ignore_index=True) if stats_by_prec else pd.DataFrame()
    maybe_save_csv(stats_by_prec, os.path.join(OUT_DIR, "exp1_proxy_stats_by_precision_raw.csv"), save_csv)

    stats_eps = compute_stats(res, ycol='obs_error')
    maybe_save_csv(stats_eps, os.path.join(OUT_DIR, "exp1_proxy_stats_mixed_with_eps.csv"), save_csv)

    if print_tables:
        print_table("Exp1 — stratified proxy stats (raw, by precision)", stats_by_prec)
        print_table("Exp1 — mixed-precision proxy stats (εmach scaled)", stats_eps)

    if run_logs:
        step_df = pd.concat(run_logs, ignore_index=True)
        maybe_save_csv(step_df, os.path.join(OUT_DIR, "exp1_step_logs.csv"), save_csv)

        value_cols = ["pred_rhs", "k_softmax_pdf", "k_score_pdf", "k_V_pdf", "fwd_error"]
        # reindex (not [...]) so runs that logged no diagnostics cannot raise KeyError.
        step_df = (step_df.reindex(columns=["seed", "precision", "dim", "step"] + value_cols)
                   .replace([np.inf, -np.inf], np.nan).dropna())
        if not step_df.empty:
            for c in value_cols:
                step_df[c] = _exp1_winsorize(step_df[c])

        out_rows = []
        for prec in ["fp16", "bf16", "fp32"]:
            sub = step_df[step_df["precision"] == prec]
            if len(sub) < 100:
                continue
            log_pred = np.log10(sub["pred_rhs"] + 1e-12)
            log_err = np.log10(sub["fwd_error"] + 1e-12)
            st = loglog_regression(sub["pred_rhs"].to_numpy(), sub["fwd_error"].to_numpy())
            # loglog_regression also returns 'n'; passing both as keywords raised
            # TypeError, so keep the regression's count as 'n_fit'.
            st['n_fit'] = st.pop('n')
            out_rows.append(dict(precision=prec, n=len(sub),
                                 spearman=log_pred.corr(log_err, method="spearman"),
                                 kendall=log_pred.corr(log_err, method="kendall"),
                                 **st))
        pointwise_by_prec = pd.DataFrame(out_rows)
        maybe_save_csv(pointwise_by_prec, os.path.join(OUT_DIR, "exp1_pointwise_stats_by_precision.csv"), save_csv)
        if print_tables:
            print_table("Exp1 — pointwise stats by precision (winsorized 1/99%)", pointwise_by_prec)

    rhs_label = 'Predicted RHS: k_softmax · (1 + k_score) · k(V) · ||W_O||_2  +  κ_eff  +  C_LN'

    if 'pred_combo_mean' in lp.columns:
        _exp1_loglog_scatter(
            lp, 'pred_combo_mean', 'obs_error_mean',
            'Exp1: Predicted vs Observed Instability (LP only, mean over seeds)',
            rhs_label, "fig_exp1.png",
            hue='precision', style='dim', s=160, alpha=0.9)

    raw_lp = res[res['precision'].isin(['fp16', 'bf16'])].copy()
    if len(raw_lp) >= 3:
        _exp1_loglog_scatter(
            raw_lp, 'pred_combo', 'obs_error',
            'Exp1 Raw: Predicted vs Observed (per-seed)',
            rhs_label, "fig_exp1_raw_scatter.png",
            hue='precision', style='dim', size='seed', sizes=(60, 160), alpha=0.9)

    if len(res) >= 3:
        # The x-axis here is the eps-scaled prediction, so say so in the label.
        _exp1_loglog_scatter(
            res, 'pred_combo_eps', 'obs_error',
            'Exp1 Mixed: Predicted·εmach vs Observed (per-seed, mixed precision)',
            'Predicted RHS · εmach: (k_softmax · (1 + k_score) · k(V) · ||W_O||_2  +  κ_eff  +  C_LN) · εmach',
            "fig_exp1_mixed_eps_scatter.png",
            hue='precision', style='dim', size='seed', sizes=(60, 160), alpha=0.9)

    return res, agg, stats

def posthoc_pointwise_analysis(step_csv_path: str):
    """
    Post-hoc, per-step analysis of predicted RHS vs. observed forward error.

    Reads the Exp1 step log at ``step_csv_path``, keeps the columns
    seed, precision, dim, step, pred_rhs, k_softmax_pdf, k_score_pdf, k_V_pdf and
    fwd_error, and drops rows containing NaN or +/-inf in any of them. Each
    numeric column is then winsorized at its 1st/99th percentile *within each
    precision*, so one precision's error scale cannot clip another's tail.

    Outputs (written to OUT_DIR):
      - exp1_pointwise_stats_by_precision.csv: for each precision with >= 100 rows,
        Spearman/Kendall correlation of log10(pred_rhs) vs log10(fwd_error) plus
        the ``loglog_regression`` statistics.
      - exp1_pointwise_slope_per_run.csv: ``loglog_regression`` statistics per
        (precision, dim, seed) run with >= 30 rows.
      - fig_exp1_pointwise_hex_{fp16,bf16}.png: log-log hexbin plots for each
        low precision with >= 200 strictly positive (pred_rhs, fwd_error) rows.

    Prints a message and returns early if the file is missing, lacks a required
    column, or has no finite rows. Always returns None.
    """
    if not os.path.exists(step_csv_path):
        print(f"[Exp1] step csv not found: {step_csv_path}")
        return
    df = pd.read_csv(step_csv_path)

    cols = ["seed","precision","dim","step","pred_rhs","k_softmax_pdf","k_score_pdf","k_V_pdf","fwd_error"]
    value_cols = ["pred_rhs","k_softmax_pdf","k_score_pdf","k_V_pdf","fwd_error"]

    missing = [c for c in cols if c not in df.columns]
    if missing:
        print(f"[Exp1] step csv {step_csv_path} is missing columns: {missing}")
        return

    df = df[cols].replace([np.inf, -np.inf], np.nan).dropna()
    if df.empty:
        print(f"[Exp1] no finite rows in step csv: {step_csv_path}")
        return

    def winsorize(s: pd.Series) -> pd.Series:
        # Clip to the series' own 1st/99th percentiles to tame extreme outliers.
        lo, hi = np.nanpercentile(s, [1, 99])
        return s.clip(lo, hi)

    # Winsorize per precision. Pooled percentiles would clip the upper tail of the
    # noisiest precision and the lower floor of the most accurate one using
    # quantiles dominated by the other precisions.
    for c in value_cols:
        df[c] = df.groupby("precision")[c].transform(winsorize)

    out_rows = []
    for prec in ["fp16","bf16","fp32"]:
        sub = df[df["precision"] == prec]
        if len(sub) < 100:
            continue
        # Small offset keeps log10 finite for exact zeros.
        X = np.log10(sub["pred_rhs"] + 1e-12)
        Y = np.log10(sub["fwd_error"] + 1e-12)
        sp = X.corr(Y, method="spearman")
        kd = X.corr(Y, method="kendall")
        st = loglog_regression(sub["pred_rhs"].to_numpy(), sub["fwd_error"].to_numpy())
        out_rows.append(dict(precision=prec, n=len(sub), spearman=sp, kendall=kd, **st))
    pd.DataFrame(out_rows).to_csv(os.path.join(OUT_DIR,"exp1_pointwise_stats_by_precision.csv"), index=False)

    rows = []
    for (prec, dim, seed), g in df.groupby(["precision","dim","seed"]):
        if len(g) < 30:
            continue
        st = loglog_regression(g["pred_rhs"].to_numpy(), g["fwd_error"].to_numpy())
        rows.append(dict(precision=prec, dim=dim, seed=seed, **st))
    perrun = pd.DataFrame(rows)
    perrun.to_csv(os.path.join(OUT_DIR,"exp1_pointwise_slope_per_run.csv"), index=False)

    for prec in ["fp16","bf16"]:
        sub = df[df["precision"] == prec]
        # A log-scaled hexbin needs strictly positive coordinates.
        sub = sub[(sub["pred_rhs"] > 0) & (sub["fwd_error"] > 0)]
        if len(sub) < 200:
            continue
        plt.figure(figsize=(7.2,5.8))
        plt.hexbin(sub["pred_rhs"], sub["fwd_error"], gridsize=40, xscale="log", yscale="log", mincnt=3)
        plt.xlabel("Pred RHS (k_softmax · (1 + k_score) · k(V))")
        plt.ylabel("Observed LP vs FP32 error")
        plt.title(f"Exp1 Pointwise (precision={prec})")
        plt.colorbar(label="counts")
        plt.tight_layout()
        plt.savefig(os.path.join(OUT_DIR, f"fig_exp1_pointwise_hex_{prec}.png"), dpi=220)
        plt.close()

def _exp2_analyze_one(df, seed, target, out_tag, OUT_DIR):
    """
    Run the early-warning analysis of ``df['k_softmax_pdf']`` against one target column.

    ``target`` is the column to analyze (e.g. 'loss' or 'fwd_error'); ``out_tag``
    is a short suffix used only in output file names (e.g. 'loss', 'fwd'). Steps:
      - save a dual-axis time-series figure of the target and max k_softmax;
      - lead–lag cross-correlation of the z-scored signal and target for lags in
        [-max_lag, max_lag] with max_lag = min(60, len(df)//3), save a lead–lag
        figure, and run a 1000-draw permutation test on the best (max) correlation;
      - Spike Precision@K: among the K = max(10, len//10) steps where k_softmax
        most exceeds its 20-step moving average, the fraction followed within the
        next 40 steps by a target value above both the current value and the
        target's 20-step mean + 1.5·std.

    Lag convention: a positive lag pairs signal[t] with target[t + lag], i.e. the
    signal leads the target.

    Non-finite values (NaN/inf, e.g. a failed forward-error measurement) are
    ignored pair-wise in the correlations instead of turning every correlation
    into NaN.

    Returns dict(seed, target, best_lag, best_corr, pval, prec_at_k). If no lag has
    a defined correlation, best_lag is 0 and best_corr/pval are NaN.
    """
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt

    fig, ax1 = plt.subplots(figsize=(12,5))
    ax1.plot(df['step'], df[target], lw=1.8, label=target)
    ax1.set_xlabel('Step'); ax1.set_ylabel(target)
    ax2 = ax1.twinx()
    ax2.plot(df['step'], df['k_softmax_pdf'], lw=1.8, label='max k_softmax (paper)')
    ax2.set_ylabel('max k_softmax (paper)')
    fig.suptitle(f'Exp2: Early Warning (seed={seed}, target={target})')
    fig.tight_layout()
    fig.savefig(os.path.join(OUT_DIR, f"fig_exp2_seed{seed}_{out_tag}.png"), dpi=220)
    plt.close(fig)

    def _zscore(a):
        # Standardize, treating non-finite entries as missing.
        a = np.asarray(a, dtype=float).copy()
        a[~np.isfinite(a)] = np.nan
        return (a - np.nanmean(a)) / (np.nanstd(a) + 1e-12)

    def _pearson(a, b):
        # Pearson correlation over pairs where both values are finite; NaN when
        # fewer than two pairs remain or either side is constant.
        ok = np.isfinite(a) & np.isfinite(b)
        if not ok.all():
            a, b = a[ok], b[ok]
        if a.size < 2:
            return np.nan
        a = a - a.mean()
        b = b - b.mean()
        denom = np.sqrt(np.dot(a, a) * np.dot(b, b))
        return float(np.dot(a, b) / denom) if denom > 0 else np.nan

    def _lagged_corrs(a, b, lag_values):
        # lag < 0 pairs a[t - lag] with b[t]; lag > 0 pairs a[t] with b[t + lag].
        n = len(a)
        out = np.empty(len(lag_values))
        for i, L in enumerate(lag_values):
            if L < 0:
                out[i] = _pearson(a[-L:], b[:n + L])
            elif L > 0:
                out[i] = _pearson(a[:n - L], b[L:])
            else:
                out[i] = _pearson(a, b)
        return out

    x = _zscore(df['k_softmax_pdf'].to_numpy(dtype=float))
    y = _zscore(df[target].to_numpy(dtype=float))

    max_lag = min(60, len(x)//3)
    lags = np.arange(-max_lag, max_lag+1)
    ccs = _lagged_corrs(x, y, lags)

    if np.isnan(ccs).all():
        # No lag has a defined correlation (too few finite points or constant series).
        best_lag, best_cc, pval = 0, float('nan'), float('nan')
    else:
        best_idx = int(np.nanargmax(ccs))
        best_lag = int(lags[best_idx])
        best_cc = float(ccs[best_idx])

        # Null distribution of the max-over-lags correlation under a random
        # re-ordering of the signal.
        rng = np.random.default_rng(seed+123)
        perm = np.empty(1000)
        for j in range(len(perm)):
            ccs_p = _lagged_corrs(rng.permutation(x), y, lags)
            perm[j] = np.nan if np.isnan(ccs_p).all() else np.nanmax(ccs_p)
        pval = float((np.sum(perm >= best_cc) + 1) / (len(perm) + 1))

    plt.figure(figsize=(7,3.6))
    plt.plot(lags, ccs)
    plt.axvline(best_lag, ls='--')
    plt.title(f'Lead–Lag (seed={seed}, target={target})')
    plt.xlabel('Lag (steps)'); plt.ylabel('corr')
    plt.tight_layout()
    plt.savefig(os.path.join(OUT_DIR, f"fig_exp2_leadlag_seed{seed}_{out_tag}.png"), dpi=220)
    plt.close()

    H, Z, MA = 40, 1.5, 20
    series = df[target].to_numpy(dtype=float)
    s = pd.Series(series)
    mu = s.rolling(MA, min_periods=MA).mean().to_numpy()
    sd = s.rolling(MA, min_periods=MA).std(ddof=0).to_numpy()
    thr = (mu + Z*np.nan_to_num(sd, nan=0.0))
    # max(series[i : i+H]) for every i, computed as a trailing rolling max over the
    # reversed series; NaN entries are skipped rather than poisoning the window.
    future_max = s.iloc[::-1].rolling(H, min_periods=1).max().iloc[::-1].to_numpy()
    # Until the moving statistics exist (first MA-1 steps) the threshold is +inf: no event.
    event = (future_max > np.maximum(series, np.nan_to_num(thr, nan=np.inf))).astype(int)

    score = (df['k_softmax_pdf'] - df['k_softmax_pdf'].rolling(MA, min_periods=MA).mean()).fillna(0).to_numpy()
    K = max(10, len(score)//10)
    top_idx = np.argsort(score)[-K:]
    prec_at_k = float(event[top_idx].mean())

    return dict(seed=seed, target=target, best_lag=best_lag, best_corr=best_cc,
                pval=pval, prec_at_k=prec_at_k)

def run_exp2_early_warning(dataloader, cfg: ExpConfig, precision='fp16',
                           save_csv: bool = False, print_tables: bool = True):
    """
    Exp2: evaluate max k_softmax as an early-warning signal during training.

    For each seed in ``cfg.seeds``: reseed, build a fresh ViT (dim=128, mlp_dim=256),
    train it with Adam(lr=1e-3) via ``train_one_epoch`` at ``precision`` for up to
    ``cfg.steps_exp2`` steps with per-step logging, then analyse the log with
    ``_exp2_analyze_one`` against 'loss' (file tag 'loss') and 'fwd_error'
    (file tag 'fwd'), printing a one-line summary for each. Seeds whose log comes
    back empty are skipped.

    Returns ``(df_all, lag_df)``: all per-step logs concatenated (with a ``seed``
    column) and one statistics row per (seed, target); two empty DataFrames if no
    seed produced a log. CSVs are written only when ``save_csv`` is true; tables
    are printed only when ``print_tables`` is true.
    """
    analyses = (('loss', 'loss'), ('fwd_error', 'fwd'))
    per_seed_logs = []
    stat_rows = []

    for seed in cfg.seeds:
        set_seed(seed)
        print(f"[Exp2] seed={seed} prec={precision}")

        net = ViT(
            image_size=cfg.image_size, patch_size=cfg.patch, num_classes=10,
            dim=128, depth=cfg.depth, heads=cfg.heads, mlp_dim=256
        ).to(device)
        optimizer = optim.Adam(net.parameters(), lr=1e-3)
        loss_fn = nn.CrossEntropyLoss()

        log = train_one_epoch(
            net, dataloader, optimizer, loss_fn,
            precision=precision, log_freq=1, max_steps=cfg.steps_exp2,
            ksoft_rows=cfg.ksoft_rows, kscore_mats=cfg.kscore_mats,
            power_iters=cfg.iters_exp2
        )
        if log.empty:
            continue

        log['seed'] = seed
        per_seed_logs.append(log)

        for target, tag in analyses:
            st = _exp2_analyze_one(log, seed, target, tag, OUT_DIR)
            stat_rows.append(st)
            print(f"[Exp2][seed={seed}][{target}] lag={st['best_lag']}, "
                  f"corr={st['best_corr']:.3f}, p≈{st['pval']:.4f}, "
                  f"Prec@K≈{st['prec_at_k']:.2f}")

    if not per_seed_logs:
        return pd.DataFrame(), pd.DataFrame()

    df_all = pd.concat(per_seed_logs, ignore_index=True)
    lag_df = pd.DataFrame(stat_rows)

    maybe_save_csv(df_all, os.path.join(OUT_DIR, "exp2_logs_all.csv"), save_csv)
    maybe_save_csv(lag_df, os.path.join(OUT_DIR, "exp2_leadlag_both_targets.csv"), save_csv)
    for target_name, rows in lag_df.groupby('target'):
        maybe_save_csv(rows, os.path.join(OUT_DIR, f"exp2_leadlag_{target_name}.csv"), save_csv)

    if not print_tables:
        return df_all, lag_df

    print_table("Exp2 — lead–lag & Prec@K (both targets)", lag_df)

    fwd_rows = lag_df[lag_df['target'] == 'fwd_error']
    if len(fwd_rows) > 0:
        quick = {
            'mean_lead_lag_steps': np.mean(fwd_rows['best_lag']),
            'mean_corr': np.mean(fwd_rows['best_corr']),
            'mean_Prec@K': np.mean(fwd_rows['prec_at_k']),
            'n_seeds': len(fwd_rows),
        }
        print_table("Exp2 — fwd_error summary (quick)", pd.DataFrame([quick]))

    return df_all, lag_df

@torch.no_grad()
def dynamic_eps_from_rhoLN(pre_act: torch.Tensor, rho_star: float,
                           floor: float, cap: float, eps_mach: float) -> float:
    """
    Choose a LayerNorm eps from the spread of its pre-activations.

        eps = median_tokens(Var(pre_act, last dim)) * d_model * eps_mach / rho_star

    Var is the biased (population) variance over the last dimension, and
    d_model = pre_act.shape[-1]. The median is ``torch.median``, which returns the
    lower of the two middle values for an even count. rho_star is clamped below
    at 1e-8 to avoid division by zero. The result is clamped with
    ``min(max(eps, floor), cap)``.

    fp16/bf16 inputs are upcast to float32 before the variance is taken, so squared
    deviations cannot overflow the low-precision range. A NaN median variance
    (e.g. NaN activations) returns ``floor`` instead of propagating NaN.
    """
    x = pre_act.float() if pre_act.dtype in (torch.float16, torch.bfloat16) else pre_act
    var_per_token = torch.var(x, dim=-1, unbiased=False)
    sigma2_med = torch.median(var_per_token).item()
    if math.isnan(sigma2_med):
        # max/min below would let NaN through unclamped.
        return float(floor)
    d_model = pre_act.shape[-1]
    eps_new = (sigma2_med * d_model * eps_mach) / max(rho_star, 1e-8)
    return float(min(max(eps_new, floor), cap))

def train_with_intervention(model, dataloader, optimizer, criterion,
                            precision="fp16", intervention=False, rho_star=0.6,
                            floor=1e-6, cap=1e-2, check_every=5, sample_for_check=16,
                            max_steps=400, iters=7):
    """
    Train ``model`` for up to ``max_steps`` batches, optionally adapting LayerNorm eps.

    Intervention (``intervention=True``): every ``check_every`` steps, the first
    ``sample_for_check`` inputs of the batch are run through the model without
    gradients. This yields each LayerNorm module and its pre-activation. Any
    LayerNorm whose ``rho_LN`` is below 1.0 gets
    ``eps = max(original eps, current eps, dynamic_eps_from_rhoLN(...))``, so eps
    is only ever raised during the run.

    Each step runs the forward pass under ``autocast_ctx(device, precision)``, then
    a GradScaler-scaled backward and optimizer step. The scaler is enabled only for
    fp16 on CUDA. Afterwards the step logs the loss and
    ``forward_error_same_weights`` (NaN if that measurement raises).

    A non-finite loss has no usable gradient, so that batch's parameter update is
    skipped and 10.0 is logged as a placeholder loss.

    ``iters`` is accepted for call compatibility and is not used here.

    The original LayerNorm eps values are restored, and the progress bar is closed,
    when the function exits, including when an exception propagates.

    Returns a DataFrame with one row per step and columns step, loss, fwd_error.
    """
    model.train()
    original_eps = {m: m.eps for m in model.modules() if isinstance(m, nn.LayerNorm)}
    scaler = torch.amp.GradScaler("cuda", enabled=(precision=="fp16" and device.type=="cuda"))
    logs = []

    max_steps = min(max_steps, len(dataloader))
    pbar = tqdm(dataloader, total=max_steps, desc=f"Train(prec={precision})")

    eps_mach = _eps_mach_for(precision)

    try:
        for step, (inp, lab) in enumerate(pbar):
            if step >= max_steps:
                break
            inp = inp.to(device, non_blocking=True)
            lab = lab.to(device, non_blocking=True)

            if intervention and (step % check_every == 0):
                with torch.no_grad():
                    sub = inp[:sample_for_check]
                    _, inter = model(sub)
                    for ln_mod, pre in zip(inter['ln_modules'], inter['ln_pre_act']):
                        rho_now = rho_LN(ln_mod, pre, eps_mach=eps_mach)
                        if rho_now < 1.0:
                            eps_new = dynamic_eps_from_rhoLN(
                                pre, rho_star=rho_star, floor=floor, cap=cap, eps_mach=eps_mach
                            )
                            ln_mod.eps = max(original_eps.get(ln_mod, ln_mod.eps), ln_mod.eps, eps_new)

            optimizer.zero_grad(set_to_none=True)
            with autocast_ctx(device, precision):
                out, _ = model(inp)
                loss = criterion(out, lab)

            if bool(torch.isfinite(loss).all()):
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                loss_val = float(loss.detach().cpu())
            else:
                # A substituted constant would have no grad_fn (backward would raise),
                # so skip the update and only log a placeholder value.
                loss_val = 10.0

            try:
                fe = forward_error_same_weights(model, inp, lp_precision=precision)
            except Exception:
                # Best-effort diagnostic: a failed measurement is logged as NaN.
                fe = float('nan')

            logs.append({
                'step': step,
                'loss': loss_val,
                'fwd_error': float(fe),
            })
            pbar.set_postfix({'loss': loss_val, 'fwd_error': float(fe)})
    finally:
        pbar.close()
        for m, eps in original_eps.items():
            m.eps = eps

    return pd.DataFrame(logs, columns=['step', 'loss', 'fwd_error'])

def run_exp3_intervention(dl_control, dl_intervene, cfg: ExpConfig,
                          precision='fp16',
                          rho_grid=(0.5,0.6,0.7), cap_grid=(5e-3, 1e-2),
                          save_csv: bool = False, print_tables: bool = True):
    """
    Exp3: paired comparison of plain training vs. the LayerNorm-eps intervention.

    For every seed × rho* × cap combination, two models are trained for
    ``cfg.steps_exp3`` steps:
      - a control model (``intervention=False``) on ``dl_control``;
      - an intervention model (``intervention=True`` with that rho*/cap) on
        ``dl_intervene``.
    ``set_seed(seed)`` is called immediately before building each of the two
    models, so the pair starts from identical initial weights. Without this, the
    intervention model would be initialized from an RNG state already advanced by
    the control run, and the measured difference would mix the intervention
    effect with initialization noise.

    For each combination, a 10-step moving-average loss figure is saved. The mean
    of the last 50 finite values of loss and fwd_error is recorded for each arm,
    and ``improvement_* = control - intervene``, so positive means the
    intervention lowered the quantity.

    Returns ``(res, agg, series_pngs)``:
      - per-combination rows;
      - mean/std of the improvements grouped by (rho_star, cap);
      - the paths of the saved per-combination figures.
    If no combination was run, ``res`` and ``agg`` are empty and no summary
    figures are drawn.
    """
    tail = 50

    def tail_mean(df, key):
        # Mean of the last `tail` logged values, ignoring NaN/inf; NaN if none remain.
        s = df[key].tail(tail).replace([np.inf, -np.inf], np.nan).dropna()
        return float(s.mean()) if len(s) else float('nan')

    def build_model(seed):
        # Reseed right before construction so both arms get the same initial weights.
        set_seed(seed)
        return ViT(image_size=cfg.image_size, patch_size=cfg.patch, num_classes=10,
                   dim=128, depth=cfg.depth, heads=cfg.heads, mlp_dim=256).to(device)

    results = []
    series_pngs = []
    crit = nn.CrossEntropyLoss()

    for seed in cfg.seeds:
        for rho_star in rho_grid:
            for cap in cap_grid:
                print(f"[Exp3] seed={seed} rho*={rho_star} cap={cap}")

                # control
                model_c = build_model(seed)
                opt_c = optim.Adam(model_c.parameters(), lr=1e-3)
                df_c = train_with_intervention(
                    model_c, dl_control, opt_c, crit,
                    precision=precision, intervention=False,
                    max_steps=cfg.steps_exp3, iters=cfg.iters_exp3
                )

                # intervention
                model_i = build_model(seed)
                opt_i = optim.Adam(model_i.parameters(), lr=1e-3)
                df_i = train_with_intervention(
                    model_i, dl_intervene, opt_i, crit,
                    precision=precision, intervention=True,
                    rho_star=rho_star, cap=cap,
                    max_steps=cfg.steps_exp3, iters=cfg.iters_exp3
                )

                # loss curves (moving average)
                fig = plt.figure(figsize=(12,5))
                plt.plot(df_c['step'], df_c['loss'].rolling(10).mean(), label='Control', lw=2)
                plt.plot(df_i['step'], df_i['loss'].rolling(10).mean(),
                         label=f'Intervene (rho*={rho_star}, cap={cap})', lw=2)
                plt.title(f'Exp3 seed={seed}'); plt.xlabel('Step'); plt.ylabel('Loss (10-step MA)')
                plt.legend(); plt.tight_layout()
                png_path = os.path.join(OUT_DIR, f"fig_exp3_seed{seed}_rho{rho_star}_cap{cap}.png")
                fig.savefig(png_path, dpi=220); plt.close(fig)
                series_pngs.append(png_path)

                # tail metrics
                m_loss_c = tail_mean(df_c, 'loss');      m_loss_i = tail_mean(df_i, 'loss')
                m_fwd_c  = tail_mean(df_c, 'fwd_error'); m_fwd_i  = tail_mean(df_i, 'fwd_error')

                results.append(dict(
                    seed=seed, rho_star=rho_star, cap=cap,
                    loss_tail_control=m_loss_c, loss_tail_intervene=m_loss_i,
                    fwd_tail_control=m_fwd_c,  fwd_tail_intervene=m_fwd_i,
                    improvement_loss=(m_loss_c - m_loss_i),
                    improvement_fwd=(m_fwd_c - m_fwd_i)
                ))

    res = pd.DataFrame(results)
    maybe_save_csv(res, os.path.join(OUT_DIR,"exp3_results_grid.csv"), save_csv)
    if res.empty:
        # Nothing to aggregate (e.g. no seeds or an empty grid).
        return res, pd.DataFrame(), series_pngs

    agg = res.groupby(['rho_star','cap']).agg({
        'improvement_loss':['mean','std'],
        'improvement_fwd':['mean','std'],
    }).reset_index()
    agg.columns = ['rho_star','cap','improv_loss_mean','improv_loss_std','improv_fwd_mean','improv_fwd_std']
    maybe_save_csv(agg, os.path.join(OUT_DIR,"exp3_results_agg.csv"), save_csv)

    if print_tables:
        print_table("Exp3 — grid (per-seed improvements)", res)
        print_table("Exp3 — aggregated improvements by (rho*, cap)", agg)

    # summary pointplots (improvements averaged over cap for each (rho*, seed))
    avg = res.groupby(['rho_star','seed']).agg({
        'improvement_loss':'mean',
        'improvement_fwd':'mean'
    }).reset_index()

    plt.figure(figsize=(7,4.2))
    sns.pointplot(data=avg, x='rho_star', y='improvement_loss', errorbar='sd')
    plt.title('Exp3: Improvement (control - intervene) by rho*')
    plt.ylabel('Loss improvement (↑ better)'); plt.tight_layout()
    plt.savefig(os.path.join(OUT_DIR,"fig_exp3_summary.png"), dpi=220); plt.close()

    plt.figure(figsize=(7,4.2))
    sns.pointplot(data=avg, x='rho_star', y='improvement_fwd', errorbar='sd')
    plt.title('Exp3: Forward-error Δ by rho*')
    plt.ylabel('Fwd-error change (≈0 is neutral)'); plt.tight_layout()
    plt.savefig(os.path.join(OUT_DIR,"fig_exp3_summary_fwd.png"), dpi=220); plt.close()

    return res, agg, series_pngs

# ------------------------------
# 6) Main orchestrator
# ------------------------------
def main(fast_mode=False, run1=True, run2=False, run3=False,
         save_csv: bool = False, print_tables: bool = True):
    """
    Orchestrator: build the config and data, run the selected experiments, and
    list the output files.

    - fast_mode: use the reduced configuration (3 seeds, one dim, fewer steps)
    - run1 / run2 / run3: run Exp1, Exp2 (fp16 early warning) and Exp3
      (fp16 LayerNorm-eps intervention)
    - save_csv: also write CSVs alongside terminal tables
    - print_tables: pretty-print tables to terminal

    The run configuration is always written to OUT_DIR/run_config.json.
    """
    # ---- configs ----
    if fast_mode:
        cfg = ExpConfig(
            seeds=[0, 1, 2],
            precision=['fp32', 'fp16', 'bf16'],
            dims=[96],
            steps_exp1=60, steps_exp2=300, steps_exp3=200,
            ksoft_rows=128, kscore_mats=32,
            iters_exp1=5, iters_exp2=5, iters_exp3=5,
            batch_size=256
        )
    else:
        cfg = ExpConfig(
            seeds=[0, 1, 2, 3, 4],
            precision=['fp32', 'bf16', 'fp16'],
            dims=[96, 128],
            steps_exp1=160, steps_exp2=800, steps_exp3=500,
            ksoft_rows=256, kscore_mats=64,
            iters_exp1=7, iters_exp2=7, iters_exp3=7,
            batch_size=256 if device.type == "cuda" else 128
        )

    # ---- dump config ----
    ensure_dir(OUT_DIR)
    with open(os.path.join(OUT_DIR, "run_config.json"), "w") as f:
        json.dump(asdict(cfg), f, indent=2)

    # ---- data ----
    trainloader = build_dataloader(batch_size=cfg.batch_size)

    # ---- orchestrate ----
    results = {}

    if run1:
        res_raw, res_cfg, stats = run_exp1_decisive(
            trainloader, cfg, save_csv=save_csv, print_tables=print_tables
        )
        if save_csv:
            results['exp1_raw'] = os.path.join(OUT_DIR, "exp1_results_raw.csv")
            results['exp1_cfg'] = os.path.join(OUT_DIR, "exp1_results_by_config.csv")
            results['exp1_raw_eps'] = os.path.join(OUT_DIR, "exp1_results_raw_with_eps.csv")
            results['exp1_stats'] = os.path.join(OUT_DIR, "exp1_proxy_stats.csv")
            results['exp1_stats_by_prec_raw'] = os.path.join(OUT_DIR, "exp1_proxy_stats_by_precision_raw.csv")
            results['exp1_stats_mixed_with_eps'] = os.path.join(OUT_DIR, "exp1_proxy_stats_mixed_with_eps.csv")
            results['exp1_pointwise_stats_by_precision'] = os.path.join(OUT_DIR,"exp1_pointwise_stats_by_precision.csv")
            results['exp1_step_logs'] = os.path.join(OUT_DIR, "exp1_step_logs.csv")

    if run2:
        df_all, lag_df = run_exp2_early_warning(
            trainloader, cfg, precision='fp16',
            save_csv=save_csv, print_tables=print_tables
        )
        if save_csv:
            results['exp2_logs'] = os.path.join(OUT_DIR, "exp2_logs_all.csv")
            results['exp2_leadlag_both'] = os.path.join(OUT_DIR, "exp2_leadlag_both_targets.csv")
            results['exp2_leadlag_loss'] = os.path.join(OUT_DIR, "exp2_leadlag_loss.csv")
            # run_exp2_early_warning names the per-target files after the target
            # column ('fwd_error'), not the figure tag ('fwd').
            results['exp2_leadlag_fwd'] = os.path.join(OUT_DIR, "exp2_leadlag_fwd_error.csv")

    if run3:
        print("[INFO] Creating dedicated dataloaders for Exp3...")
        # Two generators cloned from one state, so control and intervention see
        # the same shuffle order.
        g_base = torch.Generator(device='cpu'); g_base.manual_seed(42)
        g1 = torch.Generator(device='cpu'); g1.set_state(g_base.get_state())
        g2 = torch.Generator(device='cpu'); g2.set_state(g_base.get_state())
        dl_control   = build_dataloader(batch_size=cfg.batch_size, generator=g1)
        dl_intervene = build_dataloader(batch_size=cfg.batch_size, generator=g2)

        res_grid, res_agg, pngs = run_exp3_intervention(
            dl_control, dl_intervene, cfg, precision='fp16',
            rho_grid=(0.5, 0.6, 0.7), cap_grid=(5e-3, 1e-2),
            save_csv=save_csv, print_tables=print_tables
        )
        if save_csv:
            results['exp3_grid'] = os.path.join(OUT_DIR, "exp3_results_grid.csv")
            results['exp3_agg'] = os.path.join(OUT_DIR, "exp3_results_agg.csv")

    # ---- summary ----
    print("\nSaved PNGs/CSVs in:", OUT_DIR)
    for k, v in results.items():
        print(f"  - {k}: {v}")

    print("\nFigures:")
    figs = [
        "fig_exp1.png",
        "fig_exp1_raw_scatter.png",
        "fig_exp1_mixed_eps_scatter.png",
        "fig_exp2_seed*_{loss,fwd}.png",
        "fig_exp2_leadlag_seed*_{loss,fwd}.png",
        "fig_exp3_seed*_rho*_cap*.png",
        "fig_exp3_summary.png",
        "fig_exp3_summary_fwd.png",
    ]
    for f in figs:
        print("  -", f)

    print("\nDone.")

if __name__ == "__main__":
    # Run every experiment with pretty terminal tables; CSV writing stays off.
    main(
        fast_mode=False,
        run1=True,
        run2=True,
        run3=True,
        save_csv=False,
        print_tables=True,
    )

KeyboardInterrupt: 