# ==============================================
# Notebook for Logit Lens and Hidden Acts ======
# ==============================================

In [None]:
from pathlib import Path
from datasets import load_dataset, DownloadMode
import torch
import os
import glob
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from old_lens.mi_utils.quant_configs.bnb_configs import load_bnb_in_4bit

# ==============================================
# Models & Dataset =============================
# ==============================================

In [None]:
# Local cache directory for the Natural Questions download (hard-coded Windows path).
filepath = r'D:\LogitLensData\nq'

# Round-trip through Path so `cache_dir` receives a normalised string.
destination_path = str(Path(filepath))
# Load only the first 1000 rows of the train split; an existing on-disk copy
# is reused instead of being downloaded again, and the result is kept in memory.
nq_dataset = load_dataset(
    'sentence-transformers/natural-questions',
    split={
        'train': 'train[:1000]'
    },
    cache_dir=destination_path,
    download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS,
    keep_in_memory=True
)

# Column views consumed by the cells below.
nq_queries = nq_dataset['train']['query']
nq_answers = nq_dataset['train']['answer']

In [None]:
# First 1000 queries (the split above already holds at most 1000 rows).
nq_1000 = nq_queries[:1000]

In [None]:
# First 500 queries; this is the prompt set used by the cells below.
nq_500 = nq_queries[:500]

In [None]:
from enum import Enum

class Models(Enum):
    """Checkpoint locations; the `.value` is what gets passed to `from_pretrained`."""
    # LLaMA-3 Instruct checkpoint used for the fp, 8-bit and 4-bit loads.
    LAIN8B = "Models/LLaMA3Instruct"
    # Second checkpoint, loaded with its own tokenizer.
    HF100B = "Models/HF1BitLLM100Btokens"


class Names(Enum):
    """Name strings keyed by the same tags as `Models`.

    NOTE(review): not referenced in the visible code; presumably display
    names for plots or logs. Confirm before relying on them.
    """
    LAIN8B = "Meta-Llama-3-8B-Instruct-fp"
    HF100B = "Llama3-8B-1.58-100B-tokens"

In [None]:
def load_model_and_tok(
    model_name: str,
    low_cpu_mem_usage: bool = True,
    local_files_only: bool = True,
    device_map: str = "cpu",
    dtype: torch.dtype = torch.float32,
    load_in_8bit: bool = False
):
    """Load a causal LM and its tokenizer from the same checkpoint.

    The model is configured to return hidden states and attention weights and
    uses the "eager" attention implementation.

    Args:
        model_name: Checkpoint path or hub id passed to ``from_pretrained``.
        low_cpu_mem_usage: Forwarded to ``AutoModelForCausalLM.from_pretrained``.
        local_files_only: If true, neither the model nor the tokenizer may be
            fetched from the network.
        device_map: Forwarded to ``from_pretrained``.
        dtype: Forwarded as ``torch_dtype``.
        load_in_8bit: If true, load with bitsandbytes 8-bit quantization.

    Returns:
        Tuple ``(model, tokenizer)``.
    """
    # Honour local_files_only for the tokenizer too; previously only the model
    # respected it, so an "offline" call could still reach the network.
    tok = AutoTokenizer.from_pretrained(model_name, local_files_only=local_files_only)

    model_kwargs = dict(
        output_hidden_states=True,
        return_dict_in_generate=True,
        return_dict=True,
        output_attentions=True,
        low_cpu_mem_usage=low_cpu_mem_usage,
        local_files_only=local_files_only,
        device_map=device_map,
        torch_dtype=dtype,
        attn_implementation="eager",
    )
    if load_in_8bit:
        # The bare `load_in_8bit=` kwarg is deprecated in transformers; an
        # explicit quantization config is the supported spelling.
        model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)

    model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
    return model, tok

In [None]:
# Baseline: LLaMA-3 Instruct with the loader defaults (float32, device_map="cpu").
model_orig, orig_tokenizer = load_model_and_tok(Models.LAIN8B.value) 

In [None]:
# Special-token ids of the baseline tokenizer, printed for reference.
eos_token_id = orig_tokenizer.eos_token_id
bos_token_id = orig_tokenizer.bos_token_id
print(f"eos: {eos_token_id}\nbos: {bos_token_id}")

In [None]:
# Same checkpoint loaded in 8-bit; note that this rebinds `orig_tokenizer`.
model_8bit, orig_tokenizer = load_model_and_tok(Models.LAIN8B.value, dtype=torch.float16, load_in_8bit=True) 

In [None]:
# Same checkpoint through the project's 4-bit loader; rebinds `orig_tokenizer` again.
# NOTE(review): assumes load_bnb_in_4bit returns (model, tokenizer) — confirm in old_lens.
model_4bit, orig_tokenizer = load_bnb_in_4bit(Models.LAIN8B.value, double_quant=False, dtype=torch.float16, device_map="cpu") 

In [None]:
# Second checkpoint (HF100B) with its own tokenizer.
model_quant, quant_tokenizer = load_model_and_tok(Models.HF100B.value) 

In [None]:
# Token count of every query in nq_500, special tokens included.
lengths = []
for q in nq_500:
    ids = orig_tokenizer.encode(q, add_special_tokens=True)
    lengths.append(len(ids))

# Summary statistics of those token counts.
print("=== Token Length Statistics (LLaMA-3-8B-Instruct tokenizer) ===")
print(f"Samples analyzed: {len(lengths)}")
print(f"Mean length:       {np.mean(lengths):.2f}")
print(f"Median length:     {np.median(lengths):.0f}")
print(f"90th percentile:   {np.percentile(lengths, 90):.0f}")
print(f"95th percentile:   {np.percentile(lengths, 95):.0f}")
print(f"Max observed len:  {np.max(lengths)}")
print(f"Min observed len:  {np.min(lengths)}")

In [None]:
import inspect
# Show which arguments the first decoder block's forward() accepts.
sig = inspect.signature(model_orig.model.layers[0].forward)
print(sig)

In [None]:
import os
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Registry of models to dump, shaped as
#   {name: {"path": <checkpoint path>, "quant": <quantization config or None>}}
# It must be a mapping because the extraction loop calls MODEL.items(); the
# previous empty string raised AttributeError there. Left empty so the cell is
# a no-op until entries are filled in.
MODEL = {}
SAVE_DIR = "weights_extracted"
os.makedirs(SAVE_DIR, exist_ok=True)


def extract_and_save_weights(model_name: str, model_cfg: dict):
    """Load a model and save every named parameter as a float32 ``.pt`` file.

    Files are written to ``SAVE_DIR/<model_name>/``; each file name is the
    parameter name with dots replaced by underscores.

    Args:
        model_name: Sub-directory name for this model's tensors.
        model_cfg: Mapping with a ``"path"`` entry (checkpoint) and an optional
            ``"quant"`` entry (quantization config; omitted or ``None`` means
            no quantization).

    NOTE(review): parameters are only cast to float32. For quantized layers
    this presumably stores the raw quantized codes, not dequantized weights —
    confirm before comparing against a full-precision dump.
    """
    print(f"\nüîπ Extracting weights from {model_name} ...")

    model = AutoModelForCausalLM.from_pretrained(
        model_cfg["path"],
        # .get(): unquantized entries may leave out the "quant" key.
        quantization_config=model_cfg.get("quant"),
        device_map="auto"
    )

    out_dir = os.path.join(SAVE_DIR, model_name)
    os.makedirs(out_dir, exist_ok=True)

    last_saved = None
    for name, param in model.named_parameters():
        w = param.detach().clone().to(torch.float32).cpu()

        fname = name.replace(".", "_") + ".pt"
        torch.save(w, os.path.join(out_dir, fname))
        last_saved = w

    print(f"Saved float32 weights for {model_name} to {out_dir}")

    # small sanity check
    total_params = sum(p.numel() for p in model.parameters())
    print(f"   Total parameters: {total_params:,}")
    # Guarded: with no parameters there is no example tensor to describe
    # (the unguarded version raised NameError on `w`).
    if last_saved is not None:
        print(f"   Example tensor dtype: {last_saved.dtype}, shape: {tuple(last_saved.shape)}")



# Dump the weights of every model registered in MODEL (name -> config mapping).
for model_name, model_cfg in MODEL.items():
    extract_and_save_weights(model_name, model_cfg)

In [None]:
import torch, glob

# Load the dumped layer-0 q_proj weight of the "base_fp32" and "bnb_8bit" runs.
base_w = torch.load("weights_extracted/base_fp32/model_layers_0_self_attn_q_proj_weight.pt")
bnb_w = torch.load("weights_extracted/bnb_8bit/model_layers_0_self_attn_q_proj_weight.pt")

# Element-wise and directional agreement between the two tensors.
# NOTE(review): if the 8-bit dump holds raw quantized codes rather than
# dequantized weights, these numbers are not meaningful — confirm.
print(base_w.dtype, bnb_w.dtype)
print("Mean abs diff:", (base_w - bnb_w).abs().mean().item())
print("Cosine similarity:", torch.nn.functional.cosine_similarity(
    base_w.flatten(), bnb_w.flatten(), dim=0
).item())

# ==============================================
# Logit Lens with Normalization ================
# ==============================================

In [None]:
import torch
import torch.nn.functional as F
from pathlib import Path
from tqdm import tqdm
import numpy as np
import gc
import warnings
# Silence warnings whose message starts with "MatMul8bitLt: inputs will be cast".
warnings.filterwarnings("ignore", message="MatMul8bitLt: inputs will be cast")


# ============================================================
# Normalization modes
# ============================================================
def apply_normalization(x, model, normalize_mode="raw", block=None, layer_index=None):
    """Return ``x`` as float32, optionally normalized according to ``normalize_mode``.

    Modes:
        "raw":       no normalization.
        "unit_rms":  divide by the RMS over the last dimension (eps = 1e-5).
        "rms_layer": apply ``block.post_attention_layernorm``.
        "norm_rms":  apply ``model.model.norm``.

    Every mode degrades to a plain float32 cast when ``layer_index == -1``
    (the value used for the embedding layer) or when ``block`` is ``None``.

    Raises:
        ValueError: if ``normalize_mode`` is not one of the modes above.
    """
    h = x.to(torch.float32)

    if normalize_mode not in ("raw", "unit_rms", "rms_layer", "norm_rms"):
        raise ValueError(f"Unknown normalization_mode: {normalize_mode}")

    # One guard covers every pass-through case.
    if normalize_mode == "raw" or layer_index == -1 or block is None:
        return h

    if normalize_mode == "unit_rms":
        rms = (h.pow(2).mean(dim=-1, keepdim=True) + 1e-5).sqrt()
        return h / rms

    if normalize_mode == "rms_layer":
        return block.post_attention_layernorm(h)

    # Only "norm_rms" is left.
    return model.model.norm(h).to(torch.float32)


# ============================================================
# Helper: causal + padding mask for LLaMA blocks
# ============================================================
def build_full_attention_mask(input_ids, attention_mask, device, model=None):
    """Combine a causal mask with a padding mask into one 4-D attention mask.

    Args:
        input_ids: 2-D tensor; only its ``(batch, seq_len)`` shape is used.
        attention_mask: ``(batch, seq_len)`` tensor where 0 marks padding, or
            ``None`` for "no padding".
        device: Device on which the mask is built.
        model: Optional; ``model.config.attn_implementation`` selects the
            output format. Missing or falsy values are treated as "eager".

    Returns:
        For "flash_attention_2" / "sdpa": a bool tensor in which True marks a
        blocked position. Otherwise a float32 tensor with -1e9 at blocked
        positions and zero elsewhere. The shape broadcasts to
        ``(batch, 1, seq_len, seq_len)``.

    NOTE(review): the bool branch uses True = blocked; confirm that this is the
    polarity the consuming attention kernel expects.
    """
    bsz, seq_len = input_ids.shape

    # Strict upper triangle = future positions; add leading batch and head axes.
    future = torch.ones(
        (seq_len, seq_len), device=device, dtype=torch.bool
    ).triu(diagonal=1)[None, None]

    if attention_mask is not None:
        padded = attention_mask[:, None, None, :] == 0
    else:
        padded = torch.zeros((bsz, 1, 1, seq_len), device=device, dtype=torch.bool)

    blocked = future | padded

    cfg = getattr(model, "config", None)
    impl = getattr(cfg, "attn_implementation", None) or "eager"

    if impl not in ("flash_attention_2", "sdpa"):
        # Additive mask for eager attention.
        return blocked.to(torch.float32) * -1e9
    return blocked.to(torch.bool)


# ============================================================
# Collector with multiple normalization variants + attention weights (optional storage)
# ============================================================
@torch.no_grad()
def collect_logit_lens_full(
    model,
    tokenizer,
    prompts,
    batch_index=0,
    max_len=17,
    device=None,
    clamp_logits=False,
    clamp_value=100.0,
    save_path=None,
    collect_attn=True,
    save_attn=True,
    norm_modes=("raw", "unit_rms", "norm_rms"),
):
    """Run prompts through the model block by block and record logit-lens data.

    Each prompt is encoded as ``[BOS] + content + [EOS]``, truncated and padded
    to exactly ``max_len`` tokens. The embedding output, every decoder block's
    output and the final-norm output are projected through ``model.lm_head``
    for each normalization mode.

    Args:
        model: Causal LM exposing ``model.embed_tokens``, ``model.layers``,
            ``model.norm`` and ``lm_head``.
        tokenizer: Tokenizer providing ``encode`` and BOS/EOS/pad ids.
        prompts: Sequence of prompt strings.
        batch_index: Stored verbatim in every row.
        max_len: Fixed sequence length, BOS and EOS included (must be >= 2).
        device: Target device; defaults to CUDA when available. For
            bitsandbytes-quantized models the model is not moved and its own
            device is used instead.
        clamp_logits: Clamp logits to ``[-clamp_value, clamp_value]``; only
            applied to quantized models.
        clamp_value: Clamp bound.
        save_path: If given, ``rows`` is written there with ``torch.save``.
        collect_attn: Ask each block for attention weights.
        save_attn: Also store the per-prompt attention in each row.
        norm_modes: Modes understood by ``apply_normalization``.

    Returns:
        ``(rows, all_hidden, all_logits, all_attn)``. ``rows`` is a list of
        per-prompt, per-layer dicts. ``all_hidden`` and ``all_logits`` are
        keyed by ``embed_tokens_<mode>``, ``layer.<i>_<mode>``, ``output_true``
        and ``output_<mode>``. ``all_attn`` is keyed by ``layer.<i>``.

    Raises:
        ValueError: if ``max_len`` is smaller than 2.

    Note:
        The ``embed_tokens_*`` entries of ``all_hidden`` / ``all_logits`` keep
        the full sequence, while all other entries drop the last position.
        The attention mask is built without passing ``model``, so it is always
        the additive float mask.
    """
    if max_len < 2:
        raise ValueError("max_len must be at least 2 (BOS + EOS)")

    device = device or ("cuda" if torch.cuda.is_available() else "cpu")

    bnb_layer_types = ("Linear4bit", "Linear8bitLt")
    is_quantized = any(any(name in type(m).__name__ for name in bnb_layer_types)
                       for m in model.modules())

    if is_quantized:
        try:
            first_param_device = next(model.parameters()).device
        except StopIteration:
            first_param_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"[info] Detected quantized model ‚Üí device {first_param_device}")
        model.eval()
        # The quantized model is not moved, so inputs must follow the model.
        device = first_param_device
    else:
        model = model.to(device).eval()

    # save_path is optional; only prepare the directory when it is given.
    if save_path:
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)

    # ============================================================
    # Tokenization
    # ============================================================
    encoded = []
    bos_id = tokenizer.bos_token_id
    eos_id = tokenizer.eos_token_id
    # `is None` rather than `or`: a pad id of 0 is valid and must be kept.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else eos_id

    for p in prompts:
        ids = tokenizer.encode(p, add_special_tokens=False)
        content = ids[: max_len - 2]
        ids = torch.tensor([bos_id] + content + [eos_id], dtype=torch.long)
        if len(ids) < max_len:
            ids = F.pad(ids, (0, max_len - len(ids)), value=pad_id)
        encoded.append(ids)

    input_ids = torch.stack(encoded, dim=0).to(device)

    # ============================================================
    # Build attention mask (stop after first EOS)
    # ============================================================
    attention_mask = torch.ones_like(input_ids, dtype=torch.long)
    for i, ids in enumerate(input_ids):
        eos_positions = (ids == eos_id).nonzero(as_tuple=True)[0]
        if len(eos_positions) > 0:
            eos_pos = eos_positions[0].item()
            attention_mask[i, eos_pos + 1:] = 0

    batch_size, seq_len = input_ids.shape
    full_mask = build_full_attention_mask(input_ids, attention_mask, device)
    print(
        f"[mask] dtype={full_mask.dtype}, shape={tuple(full_mask.shape)}, "
        f"min={full_mask.min().item()}, max={full_mask.max().item()}, "
        f"unique={torch.unique(full_mask)}"
    )
    assert (full_mask == 0).sum() < full_mask.numel(), "Mask seems to contain only zeros ‚Äî check logic!"

    position_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
    vocab_size = model.lm_head.out_features
    n_layers = len(model.model.layers)

    print(f"[info] Tokenized {batch_size} prompts | seq_len={seq_len}")
    print(f"[info] Collecting from {n_layers} layers | quantized={is_quantized}")

    # ============================================================
    # Projection helper (with consistent upcasting and sanitization)
    # ============================================================
    head_dtype = next(model.lm_head.parameters()).dtype

    def project(x):
        """Project hidden states through lm_head and return sanitized float32 logits."""
        logits = model.lm_head(x.to(torch.float32).to(head_dtype))
        logits = torch.nan_to_num(logits, nan=0.0, posinf=0.0, neginf=0.0)
        if clamp_logits and is_quantized:
            logits = logits.clamp(-clamp_value, clamp_value)
        return logits.to(torch.float32)

    # ============================================================
    # Collection containers
    # ============================================================
    rows = []
    all_hidden, all_logits, all_attn = {}, {}, {}

    def base_record(i, layer_index, layer_name):
        """Fields shared by every row of prompt ``i``."""
        return {
            "prompt_id": i,
            "prompt_text": prompts[i],
            "batch_index": batch_index,
            "vocab_size": vocab_size,
            "layer_index": layer_index,
            "layer_name": layer_name,
            "input_ids": input_ids[i].cpu(),
            "target_ids": input_ids[i, 1:].cpu(),
            "attention_mask": attention_mask[i].cpu(),
        }

    # ============================================================
    # Embedding layer
    # ============================================================
    x = model.model.embed_tokens(input_ids).to(torch.float32)
    hidden_variants = {mode: apply_normalization(x.clone(), model, mode, layer_index=-1)
                       for mode in norm_modes}
    logits_variants = {mode: project(hidden_variants[mode]) for mode in norm_modes}

    for mode in norm_modes:
        all_hidden[f"embed_tokens_{mode}"] = hidden_variants[mode].cpu()
        all_logits[f"embed_tokens_{mode}"] = logits_variants[mode].cpu()

    for i in range(batch_size):
        record = base_record(i, -1, "embed_tokens")
        # Drop the last position: it has no next-token target.
        record.update({f"hidden_{m}": hidden_variants[m][i, :-1].cpu() for m in norm_modes})
        record.update({f"logits_{m}": logits_variants[m][i, :-1].cpu() for m in norm_modes})
        rows.append(record)

    # ============================================================
    # Transformer layers
    # ============================================================
    for li, block in enumerate(model.model.layers):
        out = block(
            x, position_ids=position_ids,
            attention_mask=full_mask,
            output_attentions=collect_attn,
        )
        if isinstance(out, torch.Tensor):
            # A block that returns a bare tensor must not be indexed: out[0]
            # would silently select the first batch element.
            x, attn = out, None
        else:
            x = out[0]
            attn = out[1] if collect_attn and len(out) > 1 else None

        if attn is not None:
            # Previously all_attn was returned but never filled.
            all_attn[f"layer.{li}"] = attn.cpu()

        layer_output = x.detach().clone().to(torch.float32)
        hidden_variants = {
            mode: apply_normalization(layer_output.clone(), model, mode, block=block, layer_index=li)
            for mode in norm_modes
        }
        logits_variants = {mode: project(hidden_variants[mode]) for mode in norm_modes}

        for mode in norm_modes:
            hidden_variants[mode] = hidden_variants[mode][:, :-1, :]
            logits_variants[mode] = logits_variants[mode][:, :-1, :]
            all_hidden[f"layer.{li}_{mode}"] = hidden_variants[mode].cpu()
            all_logits[f"layer.{li}_{mode}"] = logits_variants[mode].cpu()

        for i in range(batch_size):
            record = base_record(i, li, f"layer.{li}")
            record.update({f"hidden_{m}": hidden_variants[m][i].cpu() for m in norm_modes})
            record.update({f"logits_{m}": logits_variants[m][i].cpu() for m in norm_modes})
            if save_attn and attn is not None:
                record["attn"] = attn[i].cpu()
            rows.append(record)

        torch.cuda.empty_cache()
        gc.collect()

    # ============================================================
    # Final RMSNorm output
    # ============================================================
    x = model.model.norm(x.to(torch.float32))
    h = x
    l = project(h)
    h, l = h[:, :-1, :], l[:, :-1, :]

    all_hidden["output_true"] = h.cpu()
    all_logits["output_true"] = l.cpu()

    # The final output is identical for every mode; it is stored under each
    # mode's key as well.
    for mode in norm_modes:
        all_hidden[f"output_{mode}"] = h.cpu()
        all_logits[f"output_{mode}"] = l.cpu()

    for i in range(batch_size):
        record = base_record(i, n_layers, "output")
        record.update({f"hidden_{m}": h[i].cpu() for m in norm_modes})
        record.update({f"logits_{m}": l[i].cpu() for m in norm_modes})
        rows.append(record)

    # ============================================================
    # Save and finish
    # ============================================================
    if save_path:
        torch.save(rows, save_path)
        print(f"[saved] Logit-lens data ‚Üí {save_path}")

    print(f"[info] Model has {n_layers} transformer blocks (plus embedding + output).")

    return rows, all_hidden, all_logits, all_attn

In [None]:
import torch
import gc
from tqdm import tqdm

def run_logit_lens_in_batches(
    model,
    tokenizer,
    all_prompts,
    batch_size=20,
    save_prefix="logitlens_batch",
    max_len=17,
    normalize_mode=("raw", "unit_rms", "norm_rms"),
    device=None,
    clamp_logits=False,
    collect_attn=False,
    save_attn=False 
):
    """Run ``collect_logit_lens_full`` over ``all_prompts`` in fixed-size batches.

    Each batch is saved to ``f"{save_prefix}_batch{idx:03d}.pt"``. Nothing is
    returned; results exist only on disk.

    Args:
        model, tokenizer: Forwarded to ``collect_logit_lens_full``.
        all_prompts: Sliceable sequence of prompt strings.
        batch_size: Number of prompts per batch.
        save_prefix: Path prefix for the per-batch files.
        max_len: Fixed token length per prompt.
        normalize_mode: Normalization modes, forwarded as ``norm_modes``.
        device, clamp_logits, collect_attn, save_attn: Forwarded unchanged.

    A batch that raises ``RuntimeError`` is logged and skipped; the remaining
    batches still run.
    """
    num_batches = (len(all_prompts) + batch_size - 1) // batch_size
    print(f"[run] Processing {len(all_prompts)} prompts in {num_batches} batches of {batch_size}")

    for batch_idx in tqdm(range(num_batches), desc="Running logit lens batches"):
        start = batch_idx * batch_size
        end = min((batch_idx + 1) * batch_size, len(all_prompts))
        batch_prompts = all_prompts[start:end]

        save_path = f"{save_prefix}_batch{batch_idx:03d}.pt"

        print(f"\n[batch {batch_idx+1}/{num_batches}] {len(batch_prompts)} prompts ‚Üí {save_path}")

        try:
            rows, hidden_dict, logits_dict, all_attn = collect_logit_lens_full(
                model=model,
                tokenizer=tokenizer,
                prompts=batch_prompts,
                # Forward the real batch number; without it every saved row
                # recorded batch_index=0.
                batch_index=batch_idx,
                max_len=max_len,
                device=device,
                norm_modes=normalize_mode,
                save_path=save_path,
                clamp_logits=clamp_logits,
                collect_attn=collect_attn,
                save_attn=save_attn,
            )

        except RuntimeError as e:
            print(f"[error] Batch {batch_idx} failed: {e}")
            continue

        # Release this batch's tensors before starting the next one.
        del rows, hidden_dict, logits_dict, all_attn, batch_prompts
        torch.cuda.empty_cache()
        gc.collect()

    print("\n[done] All batches processed and saved.")

In [None]:
# Collect logit-lens data for the 8-bit model over the 500-query subset, on CPU,
# in batches of 10 prompts, without attention weights.
run_logit_lens_in_batches(
    model=model_8bit,
    tokenizer=orig_tokenizer,
    all_prompts=nq_500,
    batch_size=10,
    max_len=17,
    normalize_mode=("raw", "unit_rms", "norm_rms"), 
    save_prefix="saved_data/lens_data/m_8bit/m_8bit_modes",
    device="cpu",
    clamp_logits=False,
    collect_attn=False,
    save_attn=False  
) 

In [None]:
# Load the first saved batch of the m_orig run for inspection.
# weights_only=False performs full unpickling: only use it on trusted files.
ll_data = torch.load("saved_data/lens_data/m_orig/m_orig_modes_batch000.pt", weights_only=False, map_location="cpu")

In [None]:
ll_data

In [None]:
# Tabular view of the loaded rows.
ll_data_df = pd.DataFrame(ll_data)

In [None]:
ll_data_df.head()

In [None]:
ll_data_df.tail()

In [None]:
# Distinct layer names present in the batch.
ll_data_df["layer_name"].unique()

In [None]:
ll_data_df.columns

In [None]:
# Missing-value count per column.
ll_data_df.isna().sum()

In [None]:
# Decode two special-token ids.
# NOTE(review): presumably the BOS and EOS ids of this tokenizer — confirm against the output.
print(orig_tokenizer.decode([128000]))
print(orig_tokenizer.decode([128009]))

In [None]:
# Decode a short id sequence framed by those two ids.
print(orig_tokenizer.decode([128000, 9906, 1917, 128009]))

In [None]:
# Compare the tokenizer's BOS/EOS ids with the ones in the model config.
print(orig_tokenizer.bos_token_id, orig_tokenizer.eos_token_id)
print(model_8bit.config.bos_token_id, model_8bit.config.eos_token_id)

In [None]:
# Show which special tokens encode() adds on its own.
ids = orig_tokenizer.encode("Hello world", add_special_tokens=True)
print(ids)

# ==============================================
# TopK Comparison ==============================
# ==============================================

In [None]:
import os, gc, torch
import torch.nn.functional as F
import pandas as pd


# ============================================================
# Cache BOS–EOS valid indices per prompt
# ============================================================
# prompt_id -> 1-D index tensor of the positions strictly between the first and
# last attended token of that prompt.
_mask_cache = {}


def _eval_positions(attn_mask):
    """Return indices strictly between the first and last attended position.

    Returns ``None`` when fewer than two positions are attended or when no
    position lies between them.
    """
    if not isinstance(attn_mask, torch.Tensor):
        attn_mask = torch.tensor(attn_mask)
    if attn_mask.ndim == 4:
        attn_mask = attn_mask[:, 0, 0, :]
    elif attn_mask.ndim == 1:
        attn_mask = attn_mask.unsqueeze(0)

    first_row = attn_mask.to(torch.bool)[0]
    attended = first_row.nonzero(as_tuple=True)[0]
    if attended.numel() < 2:
        return None

    first, last = int(attended[0]), int(attended[-1])
    inner = torch.zeros_like(first_row, dtype=torch.bool)
    if last > first + 1:
        inner[first + 1:last] = True

    positions = inner.nonzero(as_tuple=True)[0]
    return positions if positions.numel() > 0 else None


def preprocess_metrics(metrics, lens_type="raw"):
    """Trim logits, hidden states and targets to the positions inside each prompt.

    For each row the kept positions are those strictly between the first and
    last non-zero entry of ``attention_mask``. They are computed once per
    ``prompt_id`` and reused from ``_mask_cache`` afterwards.

    Rows lacking ``logits_<lens_type>``, ``target_ids`` or ``attention_mask``,
    and rows with no position to keep, are dropped. Each returned row is a
    shallow copy whose trimmed tensors carry a leading batch axis of size 1;
    ``hidden_<lens_type>`` is ``None`` when the input row had none.
    """
    logits_key = f"logits_{lens_type}"
    hidden_key = f"hidden_{lens_type}"
    trimmed_rows = []

    for entry in metrics:
        prompt_id = entry.get("prompt_id")
        logits = entry.get(logits_key)
        hidden = entry.get(hidden_key)
        mask = entry.get("attention_mask")
        targets = entry.get("target_ids")

        if logits is None or targets is None or mask is None:
            continue

        if prompt_id in _mask_cache:
            keep = _mask_cache[prompt_id]
        else:
            keep = _eval_positions(mask)
            if keep is None:
                continue
            _mask_cache[prompt_id] = keep

        # Give everything a leading batch axis before indexing.
        if logits.ndim == 2:
            logits = logits.unsqueeze(0)
        if targets.ndim == 1:
            targets = targets.unsqueeze(0)
        if hidden is not None and hidden.ndim == 2:
            hidden = hidden.unsqueeze(0)

        if hidden is None:
            hidden_kept = None
        else:
            hidden_kept = hidden[:, keep, :].contiguous()

        trimmed_rows.append({
            **entry,
            logits_key: logits[:, keep, :].contiguous(),
            hidden_key: hidden_kept,
            "target_ids": targets[:, keep].contiguous(),
        })

    return trimmed_rows



# ============================================================
# Compute metrics and top-k similarities for A vs B
# ============================================================
@torch.no_grad()
def compute_topk(
    metrics_A,
    metrics_B,
    norm_modes=("raw", "unit_rms", "norm_rms"),
    topk=(1, 5, 10, 20),
    device="cpu",
    eps = 1e-12,
    output_dir="logs/new_summary",
    run_name=None,
    batch_idx=None,
    debug=False,
):
    """Compare two logit-lens runs position by position and save the result.

    Rows of ``metrics_A`` and ``metrics_B`` are trimmed with
    ``preprocess_metrics`` and matched on ``(prompt_id, layer_index)``. For
    every matched pair and every mode the function stores per-position lists
    of:

    * KL(A||B), KL(B||A), Jensen–Shannon divergence and its square root,
      total variation distance, entropies;
    * cosine similarity and L2 distance between hidden states;
    * log-probability / probability / perplexity of the target token;
    * for each ``k`` in ``topk``: accuracy, set overlap, agreement and
      probability-mass overlap of the top-k predictions.

    Args:
        metrics_A, metrics_B: Lists of row dicts as saved by the collector.
        norm_modes: Modes whose ``logits_<mode>`` / ``hidden_<mode>`` are compared.
        topk: Values of k for the top-k metrics.
        device: Device for the tensor math; a falsy value selects CUDA when available.
        eps: Small constant guarding divisions, logs and square roots.
        output_dir: Directory for the Parquet file (created if missing).
        run_name: File-name stem; defaults to ``"run"``.
        batch_idx: Batch number used in the file name and as a fallback for
            ``batch_index``.
        debug: Print cache and save diagnostics.

    Returns:
        ``pandas.DataFrame`` with one row per matched ``(prompt_id, layer_index)``;
        the same frame is written to
        ``<output_dir>/<run_name>_batch<idx:03d>.parquet``.

    Note:
        ``js_div_<mode>`` is the Jensen–Shannon divergence against the mixture
        ``M = (A + B) / 2``. The earlier code stored ``(KL(A||B) + KL(B||A)) / 2``
        under this name, which is the symmetrised KL and whose square root is
        not the JS distance.
    """
    os.makedirs(output_dir, exist_ok=True)
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")

    def clean(x):
        """Convert to plain Python floats, mapping NaN/inf to 0.0 (tensors are flattened)."""
        if isinstance(x, torch.Tensor):
            x = x.detach().to("cpu", copy=False).float()
            x = torch.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)
            return x.view(-1).tolist()
        elif isinstance(x, (list, np.ndarray)):
            return [float(v) if np.isfinite(v) else 0.0 for v in x]
        else:
            return float(x) if np.isfinite(x) else 0.0

    # --- preprocess all normalization modes ---
    proc_modes = {
        m: (preprocess_metrics(metrics_A, m), preprocess_metrics(metrics_B, m))
        for m in norm_modes
    }

    if debug:
        print(f"[mask cache] built for {len(_mask_cache)} prompt_ids")
        for pid, mask in list(_mask_cache.items())[:10]: 
            print(f"  prompt_id={pid:<5} ‚Üí len={len(mask)}  positions={mask.tolist()}")
        lengths = [len(v) for v in _mask_cache.values()]
        print(f"  unique mask lengths: {sorted(set(lengths))}")

    rec_map = {}
    max_k = max(topk)

    # --- main computation ---
    for mode in norm_modes:
        A_trim, B_trim = proc_modes[mode]
        if not A_trim or not B_trim:
            if debug:
                print(f"[skip] no valid rows for mode={mode}")
            continue

        # Index B once instead of scanning it for every A row. setdefault keeps
        # the first row per key, matching the earlier first-match lookup.
        b_index = {}
        for r in B_trim:
            b_index.setdefault((r.get("prompt_id"), r.get("layer_index")), r)

        for rA in A_trim:
            pid, lid = rA["prompt_id"], rA["layer_index"]
            rB = b_index.get((pid, lid))
            if rB is None:
                continue

            record = rec_map.setdefault(
                (pid, lid),
                dict(
                    prompt_id=pid,
                    batch_index=rA.get("batch_index", batch_idx),
                    layer_index=lid,
                    layer_name=rA.get("layer_name"),
                    prompt_text=rA.get("prompt_text"),
                ),
            )

            # --- tensor setup ---
            logits_A, logits_B = rA[f"logits_{mode}"], rB[f"logits_{mode}"]
            targets = rA["target_ids"]
            hidden_A, hidden_B = rA.get(f"hidden_{mode}"), rB.get(f"hidden_{mode}")

            if logits_A.ndim == 2:
                logits_A = logits_A.unsqueeze(0)
            if logits_B.ndim == 2:
                logits_B = logits_B.unsqueeze(0)
            if targets.ndim == 1:
                targets = targets.unsqueeze(0)

            logits_A, logits_B = logits_A.to(device).float(), logits_B.to(device).float()
            targets = targets.to(device)
            if hidden_A is not None:
                hidden_A = hidden_A.to(device).float()
            if hidden_B is not None:
                hidden_B = hidden_B.to(device).float()

            # Align everything to the number of target positions.
            L = targets.size(1)
            logits_A, logits_B = logits_A[:, :L, :], logits_B[:, :L, :]
            if hidden_A is not None:
                hidden_A = hidden_A[:, :L, :]
            if hidden_B is not None:
                hidden_B = hidden_B[:, :L, :]

            # Bring both sides to a common vocabulary width.
            # NOTE(review): padding logits with 0 gives the added entries
            # non-zero probability; confirm this is intended for models with
            # different vocabulary sizes.
            vocab = max(logits_A.size(-1), logits_B.size(-1))
            if logits_A.size(-1) < vocab:
                logits_A = F.pad(logits_A, (0, vocab - logits_A.size(-1)))
            if logits_B.size(-1) < vocab:
                logits_B = F.pad(logits_B, (0, vocab - logits_B.size(-1)))

            # --- probability space ---
            logp_A = F.log_softmax(logits_A, dim=-1)
            logp_B = F.log_softmax(logits_B, dim=-1)

            # Explicit renormalization after exp guards against rounding drift.
            pA = torch.exp(logp_A)
            pB = torch.exp(logp_B)
            pA = pA / (pA.sum(-1, keepdim=True) + eps)
            pB = pB / (pB.sum(-1, keepdim=True) + eps)

            # unpack first dimension (batch=1)
            logp_A, logp_B = logp_A[0], logp_B[0]
            pA, pB, tgt = pA[0], pB[0], targets[0]

            # --- divergences ---
            kl_ab = torch.sum(pA * (logp_A - logp_B), dim=-1).clamp_min(0.0)
            kl_ba = torch.sum(pB * (logp_B - logp_A), dim=-1).clamp_min(0.0)

            # Jensen–Shannon: average KL of each distribution to their mixture.
            log_mix = torch.log((0.5 * (pA + pB)).clamp_min(eps))
            js_div = (
                0.5 * torch.sum(pA * (logp_A - log_mix), dim=-1)
                + 0.5 * torch.sum(pB * (logp_B - log_mix), dim=-1)
            ).clamp_min(0.0)
            js_dist = torch.sqrt(js_div + eps)

            kl_ab = clean(kl_ab)
            kl_ba = clean(kl_ba)
            js_div = clean(js_div)
            js_dist = clean(js_dist)

            # TVD and entropy
            tvd = clean(0.5 * torch.sum(torch.abs(pA - pB), dim=-1))
            entropy_A = clean(-torch.sum(pA * logp_A, dim=-1))
            entropy_B = clean(-torch.sum(pB * logp_B, dim=-1))

            # === per-position hidden-state similarity ===
            if hidden_A is not None and hidden_B is not None:
                cosine = clean(F.cosine_similarity(hidden_A[0], hidden_B[0], dim=-1))
                l2 = clean(torch.sqrt(torch.sum((hidden_A[0] - hidden_B[0]) ** 2, dim=-1)))
            else:
                cosine, l2 = [0.0] * L, [0.0] * L

            # log-likelihood and probability of the ground-truth tokens
            logp_A_gt = torch.gather(logp_A, -1, tgt.unsqueeze(-1)).squeeze(-1)
            logp_B_gt = torch.gather(logp_B, -1, tgt.unsqueeze(-1)).squeeze(-1)
            logp_diff = logp_A_gt - logp_B_gt

            p_A_gt = torch.exp(logp_A_gt)
            p_B_gt = torch.exp(logp_B_gt)
            p_diff = p_A_gt - p_B_gt

            logp_A_gt = clean(logp_A_gt)
            logp_B_gt = clean(logp_B_gt)
            logp_diff = clean(logp_diff)
            p_A_gt = clean(p_A_gt)
            p_B_gt = clean(p_B_gt)
            p_diff = clean(p_diff)

            # === per-position cross-entropy and perplexity ===
            ce_A_pos = F.cross_entropy(
                logits_A.reshape(-1, vocab), targets.reshape(-1), reduction="none"
            ).view(targets.shape)
            ce_B_pos = F.cross_entropy(
                logits_B.reshape(-1, vocab), targets.reshape(-1), reduction="none"
            ).view(targets.shape)

            ppl_A_pos = clean(torch.exp(ce_A_pos))
            ppl_B_pos = clean(torch.exp(ce_B_pos))
            ppl_diff = [a - b for a, b in zip(ppl_A_pos, ppl_B_pos)]

            record.update(
                {
                    f"kl_ab_{mode}": kl_ab,
                    f"kl_ba_{mode}": kl_ba,
                    f"js_div_{mode}": js_div,
                    f"js_dist_{mode}": js_dist,
                    f"tvd_{mode}": tvd,
                    f"entropy_A_{mode}": entropy_A,
                    f"entropy_B_{mode}": entropy_B,
                    f"cosine_sim_{mode}": cosine,
                    f"l2_dist_{mode}": l2,
                    f"logp_diff_{mode}": logp_diff,
                    f"logp_A_gt_{mode}": logp_A_gt,
                    f"logp_B_gt_{mode}": logp_B_gt,
                    f"p_A_gt_{mode}": p_A_gt,
                    f"p_B_gt_{mode}": p_B_gt,
                    f"p_diff_{mode}": p_diff,
                    f"ppl_A_{mode}": ppl_A_pos,
                    f"ppl_B_{mode}": ppl_B_pos,
                    f"ppl_diff_{mode}": ppl_diff,
                }
            )

            # === top-k metrics ===
            fams = {key: {} for key in [
                f"acc_A_{mode}", f"acc_B_{mode}",
                f"jaccard_{mode}", f"disagree_set_{mode}", f"agree_set_{mode}",
                f"agree_correct_{mode}", f"disagree_correct_{mode}",
                f"agree_wrong_{mode}", f"prob_overlap_{mode}",
                f"top_pred_ids_A_{mode}", f"top_pred_vals_A_{mode}",
                f"top_pred_ids_B_{mode}", f"top_pred_vals_B_{mode}"
            ]}

            # One top-k call at the largest k; smaller k are prefixes of it.
            top_vals_A, top_idx_A = torch.topk(pA, max_k, -1)
            top_vals_B, top_idx_B = torch.topk(pB, max_k, -1)

            top_vals_A, top_vals_B = top_vals_A.cpu(), top_vals_B.cpu()
            top_idx_A, top_idx_B = top_idx_A.cpu(), top_idx_B.cpu()
            tgt_cpu = tgt.cpu()

            for k in topk:
                tag = f"@{k}"
                tkA, tkB = top_idx_A[:, :k], top_idx_B[:, :k]
                tvA, tvB = top_vals_A[:, :k], top_vals_B[:, :k]

                acc_A = (tkA == tgt_cpu.unsqueeze(1)).any(1).float()
                acc_B = (tkB == tgt_cpu.unsqueeze(1)).any(1).float()

                # --- set overlap (shared ids computed once per position) ---
                shared_sets = [
                    set(tkA[i].tolist()) & set(tkB[i].tolist()) for i in range(L)
                ]
                inter = torch.tensor(
                    [len(s) for s in shared_sets], dtype=torch.float32
                )
                jaccard = inter / (2 * k - inter + eps)
                disagree_set = 1.0 - jaccard
                agree_set = (inter > 0).float()

                # --- correctness relations ---
                agree_correct = acc_A * acc_B
                disagree_correct = ((acc_A + acc_B).round() == 1).float()  # XOR
                agree_wrong = ((1 - acc_A) * (1 - acc_B)).float()

                # --- probability mass overlap ---
                pmA, pmB = tvA.sum(1), tvB.sum(1)
                shared_mass = torch.zeros_like(pmA)
                for i, shared in enumerate(shared_sets):
                    if shared:
                        cols = list(shared)
                        shared_mass[i] = 0.5 * (
                            pA[i, cols].sum().cpu() + pB[i, cols].sum().cpu()
                        )
                prob_overlap = shared_mass / (0.5 * (pmA + pmB) + eps)

                # --- save results ---
                per_k = {
                    "acc_A": acc_A,
                    "acc_B": acc_B,
                    "jaccard": jaccard,
                    "disagree_set": disagree_set,
                    "agree_set": agree_set,
                    "agree_correct": agree_correct,
                    "disagree_correct": disagree_correct,
                    "agree_wrong": agree_wrong,
                    "prob_overlap": prob_overlap,
                }
                for fam, values in per_k.items():
                    fams[f"{fam}_{mode}"][tag] = values.tolist()

                # top predictions: flat lists for k == 1, nested lists otherwise
                if k == 1:
                    fams[f"top_pred_ids_A_{mode}"][tag] = [int(x) for x in tkA[:, 0].tolist()]
                    fams[f"top_pred_vals_A_{mode}"][tag] = [float(x) for x in tvA[:, 0].tolist()]
                    fams[f"top_pred_ids_B_{mode}"][tag] = [int(x) for x in tkB[:, 0].tolist()]
                    fams[f"top_pred_vals_B_{mode}"][tag] = [float(x) for x in tvB[:, 0].tolist()]
                else:
                    fams[f"top_pred_ids_A_{mode}"][tag] = [[int(x) for x in arr] for arr in tkA.tolist()]
                    fams[f"top_pred_vals_A_{mode}"][tag] = [[float(x) for x in arr] for arr in tvA.tolist()]
                    fams[f"top_pred_ids_B_{mode}"][tag] = [[int(x) for x in arr] for arr in tkB.tolist()]
                    fams[f"top_pred_vals_B_{mode}"][tag] = [[float(x) for x in arr] for arr in tvB.tolist()]

            record.update(fams)

            # --- cleanup per iteration ---
            del logits_A, logits_B, logp_A, logp_B, pA, pB, hidden_A, hidden_B, targets
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    # --- finalize and save ---
    df = pd.DataFrame(list(rec_map.values()), dtype=object)
    out_path = os.path.join(output_dir, f"{run_name or 'run'}_batch{int(batch_idx or 0):03d}.parquet")
    df.to_parquet(out_path, index=False)

    if debug:
        print(f"[saved] {len(df)} rows ‚Üí {out_path}")
    return df


In [None]:
import os, gc, torch, psutil
import pandas as pd
from tqdm import tqdm


def mem_report(note=""):
    """Print used and total system memory (in units of 1e9 bytes), prefixed by *note*."""
    vm = psutil.virtual_memory()
    used_gb = vm.used / 1e9
    total_gb = vm.total / 1e9
    print(f"[mem] {note} used {used_gb:.1f} / {total_gb:.1f} GB")


@torch.no_grad()
def run_topk_streaming(
    dir_A,
    dir_B,
    output_dir="saved_data/topk",
    norm_modes=("raw","unit_rms","norm_rms"),
    topk=(1,5,10,20),
    device=None,
    run_name="run",
    debug=True
):
    """Run ``compute_topk`` over paired ``.pt`` files, one pair at a time.

    The ``.pt`` files in ``dir_A`` and ``dir_B`` are listed, sorted by name and
    paired positionally. Each pair is loaded on CPU, handed to ``compute_topk``
    (which receives ``output_dir``, ``run_name`` and the pair's position as
    ``batch_idx``), and then released before the next pair is loaded.

    Args:
        dir_A: Directory holding the first model's ``.pt`` files.
        dir_B: Directory holding the second model's ``.pt`` files.
        output_dir: Directory passed to ``compute_topk``; created if missing.
        norm_modes: Normalisation modes forwarded to ``compute_topk``.
        topk: Top-k levels forwarded to ``compute_topk``.
        device: Device string; defaults to "cuda" when available, else "cpu".
        run_name: Prefix forwarded to ``compute_topk`` for output naming.
        debug: Forwarded to ``compute_topk``.

    Raises:
        ValueError: If the two directories contain a different number of
            ``.pt`` files.
    """
    os.makedirs(output_dir, exist_ok=True)
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")

    files_A = sorted(f for f in os.listdir(dir_A) if f.endswith(".pt"))
    files_B = sorted(f for f in os.listdir(dir_B) if f.endswith(".pt"))
    # Explicit exception instead of `assert`, which is stripped under `python -O`.
    if len(files_A) != len(files_B):
        raise ValueError(
            f"Mismatch in number of files: {len(files_A)} in {dir_A} vs {len(files_B)} in {dir_B}"
        )

    print(f"[info] Found {len(files_A)} file pairs to process")

    for batch_idx, (fa, fb) in enumerate(tqdm(zip(files_A, files_B), total=len(files_A))):
        path_A = os.path.join(dir_A, fa)
        path_B = os.path.join(dir_B, fb)
        print(f"\n[batch {batch_idx}] {fa} vs {fb}")
        mem_report("before loading")

        # NOTE(review): torch.load unpickles arbitrary objects; only point this at trusted files.
        metrics_A = torch.load(path_A, map_location="cpu")
        metrics_B = torch.load(path_B, map_location="cpu")

        print("  [compute] running compute_topk ...")
        df = compute_topk(
            metrics_A,
            metrics_B,
            norm_modes=norm_modes,
            topk=topk,
            device=device,
            output_dir=output_dir,
            run_name=run_name,
            batch_idx=batch_idx,
            debug=debug
        )

        # Report the zero-padded name (`_batch000`), matching the pattern used
        # when the parquet file is written, so the log points at a real file.
        saved_name = f"{run_name or 'run'}_batch{batch_idx:03d}.parquet"
        print(f"  [saved] {saved_name}")

        # Drop per-batch objects before the next pair is loaded to cap peak memory.
        del df, metrics_A, metrics_B
        _mask_cache.clear()
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        mem_report("after cleanup")


In [None]:
# Stream the top-k comparison for the "m_orig" vs "m_8bit" saved lens data on CPU.
run_name = "m_orig_m_8bit"

run_topk_streaming(
    dir_A="saved_data/lens_data/m_orig",
    dir_B="saved_data/lens_data/m_8bit",
    output_dir="saved_data/topk/m_8bit",
    norm_modes=("raw", "unit_rms", "norm_rms"),
    topk=(1, 5, 10),
    # The two entries below are disabled; run_topk_streaming's signature has no such parameters.
    #eos_token_id=128009,
    #bos_token_id=128000,
    run_name=run_name,
    device="cpu",
    debug=True
)

In [None]:
import pandas as pd

# Load the first batch written by the m_orig vs m_8bit run and show its shape and column names.
df = pd.read_parquet("saved_data/topk/m_8bit/m_orig_m_8bit_batch000.parquet")

print(df.shape)
print(df.columns.tolist()) 

In [None]:
# Show the `ppl_diff_raw` entry at index label 100.
df["ppl_diff_raw"][100]

In [None]:
# Count missing values per column.
df.isna().sum()

In [None]:
# Print the first row's `acc_A_<mode>` cell for each normalisation mode.
for m in ["raw", "unit_rms", "norm_rms"]:
    print(m, df.iloc[0][f"acc_A_{m}"])

In [None]:
# Preview the first rows.
df.head()

In [None]:
# Preview the last rows.
df.tail()

In [None]:
# Show the `cosine_sim_norm_rms` entry at index label 100.
df["cosine_sim_norm_rms"][100]

In [None]:
# Show the `cosine_sim_raw` entry at index label 100 (compare with the norm_rms cell above).
df["cosine_sim_raw"][100]

In [None]:
# Show the `l2_dist_raw` entry at index label 0.
df["l2_dist_raw"][0]

In [None]:
import numpy as np

# Columns whose per-row element counts are checked below.
cols = ["kl_ab_raw", "js_div_raw", "disagree_correct_raw"]
def arr_len(x):
    """Return the element count of a metric cell, or NaN when it has none.

    A dict is measured by its "@1" entry (NaN if that key is absent); a list
    or ndarray by its own length; anything else yields NaN.
    """
    if isinstance(x, dict):
        if "@1" in x:
            return len(x["@1"])
        return np.nan
    return len(x) if isinstance(x, (list, np.ndarray)) else np.nan

# Add one `len_<column>` column per checked column, then print the first 20 rows
# next to prompt_id / layer_index so the lengths can be compared at a glance.
for c in cols:
    df[f"len_{c}"] = df[c].apply(arr_len)

print(df[["prompt_id","layer_index","len_kl_ab_raw","len_js_div_raw","len_disagree_correct_raw"]].head(20))


# ==============================================
# Plot TopK Summaries ==========================
# ==============================================

In [None]:
import pandas as pd, numpy as np, matplotlib.pyplot as plt, seaborn as sns
from pathlib import Path

# Global seaborn styling for every figure produced below.
sns.set_theme(style="whitegrid", palette="deep")

# Input root, the model sub-directories to plot, and the figure output directory
# (created on the spot, including parents).
BASE = Path("saved_data")
MODELS = ["m_4bit", "m_8bit", "m_quant"]
OUT_DIR = BASE / "figures_topk_modes_fixed_full"
OUT_DIR.mkdir(parents=True, exist_ok=True)

# TOPK_METRICS are read through extract_topk (dict cells keyed "@k");
# CONT_METRICS are read through extract_flat (list/array cells).
TOPK_METRICS = ["acc_A", "acc_B", "jaccard", "disagree_correct"]
CONT_METRICS = ["cosine_sim", "l2_dist", "tvd", "ppl_diff"]
MODES = ["raw", "unit_rms", "norm_rms"]
TOPK = [1, 5, 10]


def merge_parquet_files(input_dir: str) -> pd.DataFrame:
    """Concatenate every ``*.parquet`` file in *input_dir* into one frame.

    Files are read in sorted path order and each file's rows are tagged with
    that position in a ``batch_index`` column. A few diagnostics about batch /
    prompt uniqueness are printed along the way.

    Raises:
        FileNotFoundError: If the directory contains no parquet files.
    """
    parquet_paths = sorted(Path(input_dir).glob("*.parquet"))
    if len(parquet_paths) == 0:
        raise FileNotFoundError(f"No parquet files found in {input_dir}")

    frames = [
        pd.read_parquet(path).assign(batch_index=position)
        for position, path in enumerate(parquet_paths)
    ]
    merged = pd.concat(frames, ignore_index=True)
    print(f"[merge] merged {len(parquet_paths)} parquet files ‚Üí {len(merged)} rows")

    batch_count = merged["batch_index"].nunique()
    prompt_count = merged["prompt_id"].nunique()
    pair_count = merged.groupby(["prompt_id", "batch_index"]).ngroups
    print(f"[diag] unique batch_index={batch_count} | unique prompt_id={prompt_count} | unique (prompt,batch)={pair_count}")

    # Heuristic sanity check: fewer than ten (prompt, batch) pairs per file is flagged.
    enough_pairs = pair_count >= len(parquet_paths) * 10
    if enough_pairs:
        print("[ok] prompt‚Äìbatch pairs look good")
    else:
        print("[warn] fewer prompt‚Äìbatch pairs than expected ‚Äî possible ID overlap?")

    return merged


def extract_topk(df, metric, mode, k):
    """Return the per-row mean of the "@k" entry of column ``{metric}_{mode}``.

    Rows whose cell is not a dict, or lacks the "@k" key, yield NaN. When the
    column itself is missing, a NaN-filled ndarray of length ``len(df)`` is
    returned instead of a Series.
    """
    column = f"{metric}_{mode}"
    if column not in df.columns:
        return np.full(len(df), np.nan)

    key = f"@{k}"

    def _cell_mean(cell):
        if not isinstance(cell, dict) or key not in cell:
            return np.nan
        return np.mean(cell.get(key, []))

    return df[column].apply(_cell_mean)


def extract_flat(df, metric, mode):
    """Return the per-row mean of the list/array cells in column ``{metric}_{mode}``.

    Cells that are not a list/ndarray, or are empty, yield NaN. When the column
    is missing, a NaN-filled ndarray of length ``len(df)`` is returned.
    """
    name = f"{metric}_{mode}"
    if name not in df.columns:
        return np.full(len(df), np.nan)

    def _mean_or_nan(values):
        is_sequence = isinstance(values, (list, np.ndarray))
        if is_sequence and len(values) > 0:
            return np.mean(values)
        return np.nan

    return df[name].apply(_mean_or_nan)


# Build one long-format frame (model, metric, mode, topk, layer_index, value)
# per model, then stack all models into `df_long`.
dfs = []
for model in MODELS:
    topk_dir = BASE / "topk" / model
    df = merge_parquet_files(topk_dir)

    all_rows = []

    # Top-k metrics: one slice per (metric, mode, k); value is the per-row mean of the "@k" entry.
    for metric in TOPK_METRICS:
        for mode in MODES:
            for k in TOPK:
                vals = extract_topk(df, metric, mode, k)
                dsub = pd.DataFrame({
                    "model": model,
                    "metric": metric,
                    "mode": mode,
                    "topk": k,
                    # Falls back to an all-zero layer index when the column is absent.
                    "layer_index": df.get("layer_index", pd.Series(np.zeros(len(df)))),
                    "value": vals
                })
                all_rows.append(dsub)

    # Continuous metrics have no k; they are stored with topk=1 as a placeholder.
    for metric in CONT_METRICS:
        for mode in MODES:
            vals = extract_flat(df, metric, mode)
            dsub = pd.DataFrame({
                "model": model,
                "metric": metric,
                "mode": mode,
                "topk": 1,  
                "layer_index": df.get("layer_index", pd.Series(np.zeros(len(df)))),
                "value": vals
            })
            all_rows.append(dsub)

    dfs.append(pd.concat(all_rows, ignore_index=True))

# Rows without a value (missing column / missing "@k" / empty cell) are dropped here.
df_long = pd.concat(dfs, ignore_index=True).dropna(subset=["value"])
print(f"[ok] merged all models ‚Üí {len(df_long)} rows total")
print(df_long.groupby(["metric", "mode"])["value"].count().unstack(fill_value=0))


# One overview figure per model: a two-column grid with one panel per metric,
# value vs. layer index, coloured by normalisation mode.
for model in MODELS:
    dsub = df_long[df_long["model"] == model]
    if dsub.empty:
        print(f"[skip] no data for {model}")
        continue

    metrics_unique = dsub["metric"].unique().tolist()
    nrows = int(np.ceil(len(metrics_unique) / 2))
    fig, axes = plt.subplots(nrows, 2, figsize=(14, 4 * nrows))
    fig.suptitle(f"Top-K + Continuous Metrics ‚Äî {model}", fontsize=16, weight="bold")

    # zip stops at the shorter sequence, so with an odd metric count the last axis stays empty.
    for ax, metric in zip(axes.flatten(), metrics_unique):
        d = dsub[dsub["metric"] == metric]
        sns.lineplot(
            data=d,
            x="layer_index",
            y="value",
            hue="mode",
            # Only top-k metrics get a separate line style per k.
            style="topk" if metric in TOPK_METRICS else None,
            markers=True,
            err_style="band",
            ax=ax
        )
        ax.set_title(metric.upper())
        ax.set_xlabel("Layer index")
        ax.set_ylabel("Mean value")
        ax.legend(fontsize=8)
        ax.tick_params(axis="x", rotation=0)

    plt.tight_layout()
    # Leave headroom for the suptitle after tight_layout.
    plt.subplots_adjust(top=0.9)
    out_path = OUT_DIR / f"overview_{model}.png"
    plt.savefig(out_path, dpi=300)
    plt.close()
    print(f"[saved] {out_path}")


# ==============================================
# TopK Correlations ============================
# ==============================================

### ==============================================
### Correlation with Pooling =====================
### ==============================================

In [None]:
import os
from pathlib import Path
import numpy as np
import pandas as pd
from scipy.stats import spearmanr, pearsonr, kendalltau, pointbiserialr, chi2_contingency
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns


# ============================================================
# Merge parquet files
# ============================================================
def merge_parquet_files(input_dir: str) -> pd.DataFrame:
    """Concatenate every ``*.parquet`` file in *input_dir* into one frame.

    Files are read in sorted path order and each file's rows are tagged with
    that position in a ``batch_index`` column. A few diagnostics about batch /
    prompt uniqueness are printed along the way.

    Raises:
        FileNotFoundError: If the directory contains no parquet files.
    """
    parquet_paths = sorted(Path(input_dir).glob("*.parquet"))
    if len(parquet_paths) == 0:
        raise FileNotFoundError(f"No parquet files found in {input_dir}")

    frames = [
        pd.read_parquet(path).assign(batch_index=position)
        for position, path in enumerate(parquet_paths)
    ]
    merged = pd.concat(frames, ignore_index=True)
    print(f"[merge] merged {len(parquet_paths)} parquet files ‚Üí {len(merged)} rows")

    batch_count = merged["batch_index"].nunique()
    prompt_count = merged["prompt_id"].nunique()
    pair_count = merged.groupby(["prompt_id", "batch_index"]).ngroups
    print(f"[diag] unique batch_index={batch_count} | unique prompt_id={prompt_count} | unique (prompt,batch)={pair_count}")

    # Heuristic sanity check: fewer than ten (prompt, batch) pairs per file is flagged.
    enough_pairs = pair_count >= len(parquet_paths) * 10
    if enough_pairs:
        print("[ok] prompt‚Äìbatch pairs look good")
    else:
        print("[warn] fewer prompt‚Äìbatch pairs than expected ‚Äî possible ID overlap?")

    return merged


# ============================================================
# Expand nested top-k metrics (safe names)
# ============================================================
def expand_topk_metrics(df, metrics, modes=("raw","unit_rms","norm_rms"), topk_levels=(1,5,10)):
    """Add one ``{metric}_{mode}_@{k}`` column per top-k level, in place.

    Each new cell is the float array stored under "@k" in the source dict cell,
    or an empty array when the cell is not a dict or lacks that key. Source
    columns missing from *df* are skipped. Returns the same (mutated) *df*.
    """
    def _pick(cell, key):
        if isinstance(cell, dict) and key in cell:
            return np.array(cell.get(key, []), float)
        return np.array([])

    added = []
    candidates = [f"{metric}_{mode}" for metric in metrics for mode in modes]
    for base in candidates:
        if base not in df.columns:
            continue
        for k in topk_levels:
            key = f"@{k}"
            target = f"{base}_{key}"
            df[target] = df[base].apply(lambda cell, key=key: _pick(cell, key))
            added.append(target)

    print(f"[expand] added {len(added)} flattened top-k columns")
    return df


# ============================================================
# Helper
# ============================================================
"""def safe_flatten(x):
    if isinstance(x,(list,np.ndarray)):
        return np.array(x,float).flatten()
    return np.array([],float)"""
def safe_flatten(v):
    """Coerce *v* to a 1-D float array.

    Lists and ndarrays are converted and flattened; any other value is tried as
    a single float (length-1 array) and becomes an empty array if that fails.
    """
    if not isinstance(v, (list, np.ndarray)):
        try:
            scalar = float(v)
        except Exception:
            return np.array([], dtype=float)
        return np.array([scalar], dtype=float)
    return np.asarray(v, dtype=float).flatten()

# ============================================================
# Extract anchors 
# ============================================================
def preprocess_anchors(df, modes=("raw", "unit_rms", "norm_rms")):
    """Normalise the two anchor families to equal-length NaN-padded float arrays.

    For every mode:

    * ``disagree_correct_{mode}`` (dict cells): the "@1" entry is written to a
      new ``disagree_correct_{mode}_@1`` column; cells without a usable "@1"
      entry become all-NaN arrays.
    * ``p_diff_{mode}`` (list/array or scalar cells): overwritten in place.
      The legacy double-underscore spelling ``p_diff__{mode}`` is accepted as
      well, because the correlation step looks the anchor up as
      ``{anchor}_{mode}`` (single underscore).

    Arrays shorter than the longest one in their column are right-padded with
    NaN. The frame is mutated and also returned.

    Args:
        df: Frame holding the anchor columns.
        modes: Normalisation-mode suffixes to process.

    Returns:
        The same ``df`` object with the normalised columns.
    """
    def _pad(arr, length):
        # atleast_1d guards against 0-d arrays, which have no len().
        arr = np.atleast_1d(np.asarray(arr, dtype=float))
        if len(arr) < length:
            arr = np.pad(arr, (0, length - len(arr)), constant_values=np.nan)
        return arr

    def _top1(cell):
        # Returns the "@1" payload, or None when absent (a null payload counts as absent).
        if isinstance(cell, dict):
            return cell.get("@1")
        return None

    def _as_vector(value):
        if isinstance(value, (list, np.ndarray)):
            return np.asarray(value, dtype=float)
        try:
            return np.array([float(value)], dtype=float)
        except (TypeError, ValueError, OverflowError):
            return np.array([np.nan], dtype=float)

    for mode in modes:
        col = f"disagree_correct_{mode}"
        if col not in df.columns:
            print(f"[warn] missing {col}")
            continue

        # Built-in max with a default avoids int(NaN) on an empty frame.
        max_len = max(
            (len(np.atleast_1d(_top1(d))) for d in df[col] if _top1(d) is not None),
            default=0,
        )

        def to_array(d, width=max_len):
            payload = _top1(d)
            if payload is None:
                return np.full(width, np.nan, dtype=float)
            return _pad(payload, width)

        df[f"{col}_@1"] = df[col].apply(to_array)
        print(f"[ok] normalized {col}_@1 ‚Üí len={max_len}")

    for mode in modes:
        # Single underscore follows the `{metric}_{mode}` convention used by the
        # other columns; the double-underscore form is kept for older data.
        candidates = (f"p_diff_{mode}", f"p_diff__{mode}")
        present = [c for c in candidates if c in df.columns]
        if not present:
            print(f"[warn] missing {candidates[0]} (also tried {candidates[1]})")
            continue

        for col in present:
            # Only list/array cells contribute to the target length, as before.
            max_len = max(
                (len(v) for v in df[col] if isinstance(v, (list, np.ndarray))),
                default=0,
            )

            def to_array(v, width=max_len):
                return _pad(_as_vector(v), width)

            df[col] = df[col].apply(to_array)
            print(f"[ok] normalized {col} ‚Üí len={max_len}")

    print("[done] Anchors preprocessed (binary + continuous, unified padding)]")
    return df


# ============================================================
# Correlation 
# ============================================================
def _is_binary(arr, tol=1e-6):
    arr = np.asarray(arr, dtype=float)
    arr = arr[~np.isnan(arr)]
    if len(arr) == 0:
        return False
    u = np.unique(np.round(arr, 6))
    return np.all((np.abs(u - 0) < tol) | (np.abs(u - 1) < tol))


def _phi_coefficient(x, y):
    """Phi coefficient for two binary arrays (0/1)."""
    x, y = np.asarray(x).astype(int), np.asarray(y).astype(int)
    if len(x) == 0 or len(y) == 0:
        return np.nan
    try:
        table = pd.crosstab(x, y)
        if table.shape != (2, 2):
            return np.nan
        chi2, _, _, _ = chi2_contingency(table, correction=False)
        sign = np.sign((table.loc[1,1]*table.loc[0,0]) - (table.loc[1,0]*table.loc[0,1]))
        return float(sign * np.sqrt(chi2 / len(x)))
    except Exception:
        return np.nan


# --- metric-type mapping ---
"""METRIC_TYPES = {
    **{f"acc_A_@{k}": "binary" for k in [1, 5, 10]},
    **{f"acc_B_@{k}": "binary" for k in [1, 5, 10]},
    "disagree_correct": "binary",
    "agree_correct": "binary",
    "agree_wrong": "binary",
    "agree_set": "binary",
    "disagree_set": "binary",

    "jaccard_@1": "binary",
    "jaccard_@5": "continuous",
    "jaccard_@10": "continuous",
}"""
# Measurement type per metric name, consulted when picking a correlation
# estimator. Every listed metric is currently declared "continuous".
METRIC_TYPES = dict.fromkeys(
    [
        *(f"acc_A_@{k}" for k in (1, 5, 10)),
        *(f"acc_B_@{k}" for k in (1, 5, 10)),
        "disagree_correct",
        "agree_correct",
        "agree_wrong",
        "agree_set",
        "disagree_set",
        "jaccard_@1",
        "jaccard_@5",
        "jaccard_@10",
    ],
    "continuous",
)

"""def _choose_corr_func_fixed(anchor, metric):
    anchor_t = "binary" if "disagree_correct" in anchor else "continuous"
    metric_t = METRIC_TYPES.get(metric, "continuous")

    if anchor_t == "binary" and metric_t == "binary":
        return "phi", _phi_coefficient
    elif anchor_t == "binary" or metric_t == "binary":
        return "pointbiserial", pointbiserialr
    else:
        return "spearman", spearmanr"""
def _choose_corr_func_fixed(anchor, metric):
    """Pick the correlation estimator for an (anchor, metric) pair.

    Types are looked up in METRIC_TYPES; names not listed there count as
    "continuous".

    Args:
        anchor: Anchor name used as the METRIC_TYPES key.
        metric: Metric name (including any "_@k" suffix) used as the key.

    Returns:
        ``(name, func)`` where ``name`` is "phi", "pointbiserial" or
        "spearman" and ``func(x, y)`` yields a ``(statistic, p_value)`` pair
        in every case, so callers can unpack the result uniformly.
    """
    anchor_t = METRIC_TYPES.get(anchor, "continuous")
    metric_t = METRIC_TYPES.get(metric, "continuous")

    if anchor_t == "binary" and metric_t == "binary":
        def _phi_with_pvalue(x, y):
            # _phi_coefficient returns a bare float, while the other two
            # estimators return (statistic, p_value); adapt it so unpacking works.
            from scipy.stats import chi2

            phi = _phi_coefficient(x, y)
            if not np.isfinite(phi):
                return np.nan, np.nan
            # phi is computed as sign * sqrt(chi2 / n) on a 2x2 table without
            # correction, so chi2 = n * phi**2 with one degree of freedom.
            return phi, float(chi2.sf(len(x) * phi ** 2, 1))

        return "phi", _phi_with_pvalue
    elif anchor_t == "binary" or metric_t == "binary":
        return "pointbiserial", pointbiserialr
    else:
        return "spearman", spearmanr



# ============================================================
# Correlation (pooled version)
# ============================================================
def correlate_layers_by_anchor(
    df,
    anchor,
    metrics,
    modes=("raw", "unit_rms", "norm_rms"),
    topk_levels=(1, 5, 10),
    n_boot=10,
    n_perm=10,
    seed=42,
    output_dir=None,
    model="m_8bit",
    min_valid=3,
    min_resamples=20,
):
    """Correlate an output-layer anchor with every metric, pooled per layer.

    For each mode, the anchor vector of every (prompt_id, batch_index) is taken
    from the rows whose ``layer_name`` is "output". For each
    (layer_name, layer_index) group and each metric column, anchor and metric
    vectors are truncated to a common length, filtered to finite pairs,
    concatenated over prompts, and correlated once with the estimator chosen
    by ``_choose_corr_func_fixed``.

    Args:
        df: Frame with ``layer_name``, ``layer_index``, ``prompt_id``,
            ``batch_index``, the anchor columns and the metric columns.
        anchor: Anchor base name; ``{anchor}_{mode}_@1`` is preferred over
            ``{anchor}_{mode}`` when present.
        metrics: Metric base names.
        modes: Normalisation modes. "cosine_sim" and "l2_dist" are always read
            from their "raw" column.
        topk_levels: k values used to build the "_@k" column suffixes.
        n_boot: Number of bootstrap resamples.
        n_perm: Number of permutations.
        seed: Seed for NumPy's global RNG.
        output_dir: If given, the result is also written there as CSV.
        model: Model tag used in the CSV file name.
        min_valid: Minimum number of finite pairs required per prompt and overall.
        min_resamples: Bootstrap / permutation statistics are reported only
            when MORE than this many resamples succeed. With the defaults
            (10 resamples vs. a threshold of 20) they are always NaN; raise
            ``n_boot`` / ``n_perm`` or lower this value to obtain them.

    Returns:
        DataFrame with one row per (mode, layer, metric column) and the columns
        mode, anchor, metric, corr_type, layer_name, layer_index, rho,
        rho_boot_median, ci_low, ci_high, p_val, p_perm, n, pooling, reason.
    """
    result_columns = [
        "mode", "anchor", "metric", "corr_type", "layer_name", "layer_index",
        "rho", "rho_boot_median", "ci_low", "ci_high", "p_val", "p_perm",
        "n", "pooling", "reason",
    ]

    np.random.seed(seed)
    results = []

    # A resampling loop that cannot exceed the threshold can only yield NaN:
    # say so once and skip the work instead of silently discarding it.
    do_boot = n_boot > min_resamples
    do_perm = n_perm > min_resamples
    if not do_boot:
        print(f"[warn] n_boot={n_boot} <= min_resamples={min_resamples}: bootstrap CI columns will be NaN")
    if not do_perm:
        print(f"[warn] n_perm={n_perm} <= min_resamples={min_resamples}: p_perm will be NaN")

    def _row(mode, metric_name, corr_type, lname, lidx, n, reason, **stats):
        # Single place that defines the result schema; statistics default to NaN.
        row = {
            "mode": mode,
            "anchor": anchor,
            "metric": metric_name,
            "corr_type": corr_type,
            "layer_name": lname,
            "layer_index": lidx,
            "rho": np.nan,
            "rho_boot_median": np.nan,
            "ci_low": np.nan,
            "ci_high": np.nan,
            "p_val": np.nan,
            "p_perm": np.nan,
            "n": n,
            "pooling": "pooled",
            "reason": reason,
        }
        row.update(stats)
        return row

    df_out = df[df["layer_name"].str.lower() == "output"]

    # --- Build anchor map: (prompt_id, batch_index) -> {mode: anchor vector} ---
    anchor_map = {}
    for _, row in df_out.iterrows():
        key = (int(row["prompt_id"]), int(row.get("batch_index", 0)))
        anchor_vecs = {}
        for mode in modes:
            col1 = f"{anchor}_{mode}_@1"
            col2 = f"{anchor}_{mode}"
            col = col1 if col1 in df.columns else col2
            val = row.get(col)
            if isinstance(val, (list, np.ndarray)) and len(val) > 0:
                anchor_vecs[mode] = np.array(val, float)
        if anchor_vecs:
            anchor_map[key] = anchor_vecs
    print(f"[anchor_map] built {len(anchor_map)} anchors")

    # --- Force raw for these metrics ---
    SHARED_METRICS = {"cosine_sim", "l2_dist"}

    # ============================================================
    # MAIN LOOP
    # ============================================================
    for mode in modes:
        for (lname, lidx), layer_df in tqdm(df.groupby(["layer_name", "layer_index"]), desc=f"{anchor}_{mode}"):
            # The per-prompt split does not depend on the metric: do it once per layer.
            prompt_groups = [
                ((int(pid), int(bid)), group)
                for (pid, bid), group in layer_df.groupby(["prompt_id", "batch_index"])
            ]

            for metric in metrics:
                is_topk_only = metric in {"acc_A", "acc_B", "jaccard"}
                suffixes = [f"_@{k}" for k in topk_levels] if is_topk_only else [""] + [f"_@{k}" for k in topk_levels]

                for suffix in suffixes:
                    src_mode = "raw" if metric in SHARED_METRICS else mode
                    mcol = f"{metric}_{src_mode}{suffix}"
                    if mcol not in layer_df.columns:
                        continue
                    metric_name = f"{metric}{suffix}"

                    anchor_vals, metric_vals = [], []
                    for key, group in prompt_groups:
                        if key not in anchor_map or mode not in anchor_map[key]:
                            continue
                        a = anchor_map[key][mode]
                        m = np.asarray(group[mcol].iloc[0], dtype=float).flatten()
                        n = min(len(a), len(m))
                        if n < min_valid:
                            continue
                        a, m = a[:n], m[:n]
                        mask = np.isfinite(a) & np.isfinite(m)
                        if mask.sum() < min_valid:
                            continue
                        anchor_vals.append(a[mask])
                        metric_vals.append(m[mask])

                    # --- NaN / missing cases ---
                    if not anchor_vals:
                        results.append(_row(mode, metric_name, "undefined", lname, lidx, 0, "no_data"))
                        continue

                    A = np.concatenate(anchor_vals)
                    M = np.concatenate(metric_vals)
                    mask = np.isfinite(A) & np.isfinite(M)
                    n_used = int(mask.sum())
                    if n_used < min_valid:
                        continue
                    A, M = A[mask], M[mask]

                    cname, func = _choose_corr_func_fixed(anchor, metric_name)

                    # --- No variance case (labelled with the estimator that would have been used) ---
                    if np.std(A) == 0 or np.std(M) == 0:
                        results.append(_row(mode, metric_name, cname, lname, lidx, n_used, "no_variance"))
                        continue

                    # --- correlation computation ---
                    try:
                        rho, pval = func(A, M)
                    except Exception:
                        rho, pval = np.nan, np.nan

                    # Non-finite estimates are dropped without a result row.
                    if not np.isfinite(rho):
                        continue

                    # --- bootstrap ---
                    ci_low = ci_high = rho_boot = np.nan
                    if do_boot:
                        boot_rhos = []
                        for _ in range(n_boot):
                            idx = np.random.choice(n_used, n_used, replace=True)
                            try:
                                rho_b, _ = func(A[idx], M[idx])
                            except Exception:
                                rho_b = np.nan
                            if np.isfinite(rho_b):
                                boot_rhos.append(rho_b)
                        if len(boot_rhos) > min_resamples:
                            ci_low, ci_high = np.percentile(boot_rhos, [2.5, 97.5])
                            rho_boot = np.median(boot_rhos)

                    # --- permutation test ---
                    p_perm = np.nan
                    if do_perm:
                        perm_rhos = []
                        for _ in range(n_perm):
                            try:
                                rho_p, _ = func(A, np.random.permutation(M))
                            except Exception:
                                rho_p = np.nan
                            if np.isfinite(rho_p):
                                perm_rhos.append(rho_p)
                        if len(perm_rhos) > min_resamples:
                            perm_rhos = np.array(perm_rhos)
                            # +1 in numerator and denominator keeps the p-value strictly positive.
                            p_perm = (np.sum(np.abs(perm_rhos) >= abs(rho)) + 1) / (len(perm_rhos) + 1)

                    results.append(_row(
                        mode, metric_name, cname, lname, lidx, n_used, "ok",
                        rho=rho,
                        rho_boot_median=rho_boot,
                        ci_low=ci_low,
                        ci_high=ci_high,
                        p_val=pval,
                        p_perm=p_perm,
                    ))

    # ============================================================
    # Wrap up
    # ============================================================
    # Explicit columns keep the schema (and the 'rho' lookup below) valid even with zero rows.
    df_corr = pd.DataFrame(results, columns=result_columns)
    print(f"[ok] {anchor} ‚Üí {len(df_corr)} pooled correlations")
    print(f"[nan check] {df_corr['rho'].isna().sum()} NaN correlations out of {len(df_corr)} total")

    if output_dir:
        Path(output_dir).mkdir(parents=True, exist_ok=True)
        out_path = Path(output_dir) / f"lw_{model}_{anchor}_corr_pooled.csv"
        df_corr.to_csv(out_path, index=False)
        print(f"[saved correlations] {out_path}")

    return df_corr


# ============================================================
# Summarize correlations + Global Spearman
# ============================================================
def summarize_correlations(df_corr, output_dir, model, ci_mode="both"):
    """Aggregate correlation rows per group and compute one global rho.

    Rows are grouped by (mode, anchor, metric, corr_type, layer_name,
    layer_index). Within a group, finite rho values are Fisher-z transformed
    and averaged with weights ``n``; the spread of z gives ``rho_low`` /
    ``rho_high``. Bootstrap columns are averaged over the finite entries of the
    same rows and stay NaN when no bootstrap value exists (they are not
    replaced by 0). The summary is written to
    ``lw_{model}_corr_pooled_summary_{ci_mode}.csv`` inside ``output_dir``.

    Args:
        df_corr: Correlation rows; must contain ``rho`` and the group columns.
            ``n`` defaults to 1 when absent.
        output_dir: Directory for the CSV (created if missing).
        model: Model tag used in the file name.
        ci_mode: "empirical" plots the bootstrap interval; any other value
            plots the z-based interval.

    Returns:
        Tuple ``(df_summary, rho_global, n_valid_layers)``.
    """
    df = df_corr.copy()
    if "n" not in df.columns:
        df["n"] = 1

    # Fisher z of every finite rho; everything else (NaN, inf, non-numeric) stays NaN.
    rho = pd.to_numeric(df["rho"], errors="coerce")
    rho = rho.where(np.isfinite(rho))
    df["z"] = np.arctanh(rho.clip(-0.999999, 0.999999))

    drop_cols = [c for c in ["prompt_id", "batch_index"] if c in df.columns]
    if drop_cols:
        df = df.drop(columns=drop_cols)

    group_cols = ["mode", "anchor", "metric", "corr_type", "layer_name", "layer_index"]
    stat_cols = [
        "rho_mean", "rho_low", "rho_high", "n_total",
        "rho_boot_mean", "ci_low_emp", "ci_high_emp",
    ]

    def _finite_weighted_mean(values, weights):
        # Weighted mean over finite entries only; NaN when nothing is usable.
        values = np.asarray(values, dtype=float)
        weights = np.asarray(weights, dtype=float)
        ok = np.isfinite(values) & np.isfinite(weights)
        if not ok.any() or weights[ok].sum() == 0:
            return np.nan
        return float(np.average(values[ok], weights=weights[ok]))

    def _weighted_stats(g):
        out = {}
        # Restrict values AND weights to the same rows so their lengths always agree.
        finite = g[g["z"].notna()]
        if finite.empty or finite["n"].sum() == 0:
            for key in stat_cols:
                out[key] = np.nan
            out["n_total"] = 0
        else:
            z = finite["z"].to_numpy(dtype=float)
            w = finite["n"].to_numpy(dtype=float)
            z_mean = np.average(z, weights=w)
            z_std = np.sqrt(np.average((z - z_mean) ** 2, weights=w))
            out["rho_mean"] = np.tanh(z_mean)
            out["rho_low"] = np.tanh(z_mean - 1.96 * z_std)
            out["rho_high"] = np.tanh(z_mean + 1.96 * z_std)
            out["n_total"] = g["n"].sum()
            out["rho_boot_mean"] = _finite_weighted_mean(finite["rho_boot_median"], w)
            out["ci_low_emp"] = _finite_weighted_mean(finite["ci_low"], w)
            out["ci_high_emp"] = _finite_weighted_mean(finite["ci_high"], w)

        for k in group_cols:
            out[k] = g[k].iloc[0]
        return out

    # Fixed column order; also yields a well-formed empty frame when there are no groups.
    df_summary = pd.DataFrame(
        [_weighted_stats(g) for _, g in df.groupby(group_cols)],
        columns=stat_cols + group_cols,
    )

    # choose CI mode for plotting
    if ci_mode == "empirical":
        df_summary["rho_low_plot"] = df_summary["ci_low_emp"]
        df_summary["rho_high_plot"] = df_summary["ci_high_emp"]
    else:
        df_summary["rho_low_plot"] = df_summary["rho_low"]
        df_summary["rho_high_plot"] = df_summary["rho_high"]

    # --- compute global Spearman (across all layers) ---
    valid = df_summary[np.isfinite(df_summary["rho_mean"].astype(float))]
    if not valid.empty:
        z_vals = np.arctanh(valid["rho_mean"].astype(float).clip(-0.999999, 0.999999))
        weights = valid["n_total"].fillna(1)
        z_mean = np.average(z_vals, weights=weights)
        rho_global = np.tanh(z_mean)
        n_valid_layers = len(valid)
        print(f"[global Spearman rho] {rho_global:.3f} (across {n_valid_layers} valid layers)")
    else:
        rho_global = np.nan
        n_valid_layers = 0
        print("[global Spearman rho] not computable (no valid layers)")

    # --- report missing / no-variance layers ---
    if "reason" in df_corr.columns:
        missing_info = df_corr[df_corr["reason"].isin(["no_variance", "no_data"])]
        for metric in missing_info["metric"].unique():
            n_missing = len(missing_info[missing_info["metric"] == metric])
            print(f"[info] {metric}: {n_missing} layers had no variance or missing data")

    # --- save summary ---
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    out_summary = Path(output_dir) / f"lw_{model}_corr_pooled_summary_{ci_mode}.csv"
    df_summary.to_csv(out_summary, index=False)
    print(f"[saved summary] {out_summary}")
    print(f"[diag] rows={len(df_summary)} groups={df[group_cols].drop_duplicates().shape[0]}")
    print(f"[NaN summary rows: {df_summary['rho_mean'].isna().sum()}]")

    return df_summary, rho_global, n_valid_layers



# ============================================================
# Run pooled correlation pipeline 
# ============================================================
# Inputs are read from saved_data/topk/<model>; results go to saved_data/summary/<model>.
BASE = Path("saved_data")
model = "m_quant"
output_dir = BASE / "summary" / model
output_dir.mkdir(parents=True, exist_ok=True)

topk_dir = BASE / "topk" / model
df_topk = merge_parquet_files(topk_dir)
df_topk = expand_topk_metrics(
    df_topk,
    metrics=["jaccard", "acc_A", "acc_B", "disagree_correct"],
    modes=("raw", "unit_rms", "norm_rms")
)
df_topk = preprocess_anchors(df_topk)

# cosine_sim / l2_dist are only taken from their raw column: mirror it into the
# other modes when those columns do not exist yet.
for m in ["cosine_sim", "l2_dist"]:
    if f"{m}_raw" in df_topk.columns:
        for mode in ["unit_rms", "norm_rms"]:
            col_src = f"{m}_raw"
            col_dst = f"{m}_{mode}"
            if col_dst not in df_topk.columns:
                df_topk[col_dst] = df_topk[col_src]
                print(f"[copy] propagated {col_src} ‚Üí {col_dst}")

anchors = ["disagree_correct", "p_diff"]
metrics = [
    "kl_ab", "kl_ba", "js_div", "js_dist", "tvd",
    "entropy_A", "entropy_B", "cosine_sim", "l2_dist",
    "ppl_diff", "jaccard", "acc_A", "acc_B"
]

# One pooled correlation table per anchor, stacked into a single frame.
df_corr_all = []
for a in anchors:
    df_corr_all.append(correlate_layers_by_anchor(df_topk, a, metrics, model=model))
df_corr_all = pd.concat(df_corr_all, ignore_index=True)

pooled_corr_path = output_dir / f"lw_{model}_corr_pooled.csv"
pooled_summary_path = output_dir / f"lw_{model}_corr_pooled_summary.csv"

df_corr_all.to_csv(pooled_corr_path, index=False)
print(f"[saved pooled correlations] {pooled_corr_path}")

# summarize_correlations returns (summary frame, global rho, number of valid layers);
# unpack it so `df_summary` is the DataFrame that later cells merge and inspect.
df_summary, rho_global, n_valid_layers = summarize_correlations(
    df_corr_all, output_dir=output_dir, model=model
)
# summarize_correlations saves under a ci_mode-suffixed name; write the
# unsuffixed file as well so the message below refers to a file that exists.
df_summary.to_csv(pooled_summary_path, index=False)
print(f"[saved pooled summary] {pooled_summary_path}")


In [None]:
# Number of NaN correlations per layer. Printed explicitly: a notebook cell only
# displays its last expression, so this result would otherwise be discarded.
print(df_corr_all[df_corr_all["rho"].isna()].groupby("layer_name").size())

# Share of NaN correlations per metric, largest first (shown as the cell output).
df_corr_all.groupby("metric")["rho"].apply(lambda s: s.isna().mean()).sort_values(ascending=False)

In [None]:
# Mean rho per mode over all metrics whose name contains "cosine_sim".
df_corr_all.query("metric.str.contains('cosine_sim')").groupby("mode")["rho"].mean()

In [None]:
# Join the raw correlation rows with their summary rows on the grouping keys and
# print the correlation matrix between per-row rho and the summarised rho_mean.
# NOTE(review): summarize_correlations returns a 3-tuple; `df_summary` must be the
# DataFrame element of that tuple for this merge to work - confirm the assignment.
merged = pd.merge(
    df_corr_all, df_summary,
    on=["mode","anchor","metric","corr_type","layer_name","layer_index"],
    how="inner",
    suffixes=("_corr","_sum")
)
print(merged[["rho","rho_mean"]].corr())


In [None]:
# Display the stacked pooled-correlation table.
df_corr_all

In [None]:
# Display the summary object produced by the pipeline cell.
df_summary

In [None]:
# Count missing values per column of the correlation table.
df_corr_all.isna().sum()

In [None]:
# Count missing values per column of the summary.
# NOTE(review): requires `df_summary` to be a DataFrame, not the tuple returned by summarize_correlations.
df_summary.isna().sum()

### ==============================================
### Correlation Per Prompt =======================
### ==============================================

In [None]:
import os
from pathlib import Path
import numpy as np
import pandas as pd
from scipy.stats import spearmanr, pearsonr, kendalltau, pointbiserialr, chi2_contingency
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns


# ============================================================
# Merge parquet files
# ============================================================
def merge_parquet_files(input_dir: str) -> pd.DataFrame:
    """Concatenate every ``*.parquet`` file in *input_dir* into one frame.

    Files are read in sorted path order and each file's rows are tagged with
    that position in a ``batch_index`` column. A few diagnostics about batch /
    prompt uniqueness are printed along the way.

    Raises:
        FileNotFoundError: If the directory contains no parquet files.
    """
    parquet_paths = sorted(Path(input_dir).glob("*.parquet"))
    if len(parquet_paths) == 0:
        raise FileNotFoundError(f"No parquet files found in {input_dir}")

    frames = [
        pd.read_parquet(path).assign(batch_index=position)
        for position, path in enumerate(parquet_paths)
    ]
    merged = pd.concat(frames, ignore_index=True)
    print(f"[merge] merged {len(parquet_paths)} parquet files ‚Üí {len(merged)} rows")

    batch_count = merged["batch_index"].nunique()
    prompt_count = merged["prompt_id"].nunique()
    pair_count = merged.groupby(["prompt_id", "batch_index"]).ngroups
    print(f"[diag] unique batch_index={batch_count} | unique prompt_id={prompt_count} | unique (prompt,batch)={pair_count}")

    # Heuristic sanity check: fewer than ten (prompt, batch) pairs per file is flagged.
    enough_pairs = pair_count >= len(parquet_paths) * 10
    if enough_pairs:
        print("[ok] prompt‚Äìbatch pairs look good")
    else:
        print("[warn] fewer prompt‚Äìbatch pairs than expected ‚Äî possible ID overlap?")

    return merged


# ============================================================
# Expand nested top-k metrics 
# ============================================================
def expand_topk_metrics(df, metrics, modes=("raw","unit_rms","norm_rms"), topk_levels=(1,5,10)):
    """Add one ``{metric}_{mode}_@{k}`` column per top-k level, in place.

    Each new cell is the float array stored under "@k" in the source dict cell,
    or an empty array when the cell is not a dict or lacks that key. Source
    columns missing from *df* are skipped. Returns the same (mutated) *df*.
    """
    def _pick(cell, key):
        if isinstance(cell, dict) and key in cell:
            return np.array(cell.get(key, []), float)
        return np.array([])

    added = []
    candidates = [f"{metric}_{mode}" for metric in metrics for mode in modes]
    for base in candidates:
        if base not in df.columns:
            continue
        for k in topk_levels:
            key = f"@{k}"
            target = f"{base}_{key}"
            df[target] = df[base].apply(lambda cell, key=key: _pick(cell, key))
            added.append(target)

    print(f"[expand] added {len(added)} flattened top-k columns")
    return df


# ============================================================
# Helper
# ============================================================
"""def safe_flatten(x):
    if isinstance(x,(list,np.ndarray)):
        return np.array(x,float).flatten()
    return np.array([],float)"""
def safe_flatten(v):
    """Return *v* as a 1-D float array.

    Lists and arrays are converted to float and flattened. Any other value is
    treated as a scalar and wrapped in a one-element array; if it cannot be
    converted with ``float()``, an empty float array is returned instead.
    """
    if not isinstance(v, (list, np.ndarray)):
        try:
            scalar = float(v)
        except Exception:
            return np.array([], dtype=float)
        return np.array([scalar], dtype=float)
    return np.asarray(v, dtype=float).flatten()

# ============================================================
# Extract anchors 
# ============================================================
def extract_top1_disagreement_anchor(df, modes=("raw","unit_rms","norm_rms")):
    """Write the top-1 disagreement anchor as fixed-length float arrays.

    For each mode whose ``disagree_correct_{mode}`` column exists, a column
    ``disagree_correct_{mode}_@1`` is written. Each cell is the ``"@1"`` entry
    of the source dict as a float array, NaN-padded on the right to the
    longest ``"@1"`` entry found in that column. Cells that are not dicts, or
    that lack ``"@1"``, become all-NaN arrays of that same length.

    Missing source columns are reported and skipped. *df* is modified in
    place and also returned.
    """
    def _top1_len(cell):
        return len(cell.get("@1", [])) if isinstance(cell, dict) else 0

    for mode in modes:
        col = f"disagree_correct_{mode}"
        if col not in df.columns:
            print(f"[warn] missing {col}")
            continue

        lengths = df[col].apply(_top1_len)
        # Series.max() is NaN for an empty column, and `nan or 0` is still NaN
        # (NaN is truthy), so emptiness has to be tested explicitly.
        max_len = int(lengths.max()) if len(lengths) else 0

        def to_array(cell, width=max_len):
            if isinstance(cell, dict) and "@1" in cell:
                arr = np.asarray(cell["@1"], dtype=float)
            else:
                arr = np.full(width, np.nan, dtype=float)
            # Pad short vectors so every row really has the logged length.
            if len(arr) < width:
                arr = np.pad(arr, (0, width - len(arr)), constant_values=np.nan)
            return arr

        df[f"{col}_@1"] = df[col].apply(to_array)
        print(f"[ok] normalized {col}_@1 ‚Üí len={max_len}")
    return df


def preprocess_anchors(df, modes=("raw", "unit_rms", "norm_rms")):
    """Normalize both anchor families to NaN-padded float arrays of uniform length.

    Two column families are processed per mode; missing columns are reported
    and skipped:

    * ``disagree_correct_{mode}`` (dict cells): the ``"@1"`` entry is written
      to ``disagree_correct_{mode}_@1``. Cells without a usable ``"@1"`` entry
      become all-NaN arrays.
    * ``p_diff_{mode}`` (list / array / scalar cells): overwritten in place.
      Scalars become one-element arrays; values that cannot be converted to
      float become ``[nan]``.

    Within each column, arrays are right-padded with NaN to the longest
    list-like entry of that column. *df* is modified in place and returned.
    """
    def _max_len(lengths):
        # Series.max() is NaN for an empty column, and `nan or 0` is still NaN
        # (NaN is truthy), so emptiness has to be tested explicitly.
        return int(lengths.max()) if len(lengths) else 0

    def _pad(arr, width):
        if len(arr) < width:
            arr = np.pad(arr, (0, width - len(arr)), constant_values=np.nan)
        return arr

    def _binary_anchor(cell, width):
        if isinstance(cell, dict) and "@1" in cell:
            arr = np.asarray(cell["@1"], dtype=float)
        else:
            arr = np.full(width, np.nan, dtype=float)
        return _pad(arr, width)

    def _continuous_anchor(value, width):
        if isinstance(value, (list, np.ndarray)):
            arr = np.asarray(value, dtype=float)
        else:
            try:
                arr = np.array([float(value)], dtype=float)
            except (TypeError, ValueError, OverflowError):
                arr = np.array([np.nan], dtype=float)
        return _pad(arr, width)

    # --- binary anchor: disagree_correct_{mode} -> disagree_correct_{mode}_@1 ---
    for mode in modes:
        col = f"disagree_correct_{mode}"
        if col not in df.columns:
            print(f"[warn] missing {col}")
            continue

        max_len = _max_len(df[col].apply(
            lambda d: len(d.get("@1", [])) if isinstance(d, dict) else 0
        ))
        df[f"{col}_@1"] = df[col].apply(lambda d: _binary_anchor(d, max_len))
        print(f"[ok] normalized {col}_@1 ‚Üí len={max_len}")

    # --- continuous anchor: p_diff_{mode}, normalized in place ---
    for mode in modes:
        col = f"p_diff_{mode}"
        if col not in df.columns:
            print(f"[warn] missing {col}")
            continue

        max_len = _max_len(df[col].apply(
            lambda v: len(v) if isinstance(v, (list, np.ndarray)) else 0
        ))
        df[col] = df[col].apply(lambda v: _continuous_anchor(v, max_len))
        print(f"[ok] normalized {col} ‚Üí len={max_len}")

    print("[done] Anchors preprocessed (binary + continuous, unified padding)]")
    return df


# ============================================================
# Correlation 
# ============================================================
def _is_binary(arr, tol=1e-6):
    arr = np.asarray(arr, dtype=float)
    arr = arr[~np.isnan(arr)]
    if len(arr) == 0:
        return False
    u = np.unique(np.round(arr, 6))
    return np.all((np.abs(u - 0) < tol) | (np.abs(u - 1) < tol))


def _phi_coefficient(x, y):
    """Phi coefficient for two binary arrays (0/1)."""
    x, y = np.asarray(x).astype(int), np.asarray(y).astype(int)
    if len(x) == 0 or len(y) == 0:
        return np.nan
    try:
        table = pd.crosstab(x, y)
        if table.shape != (2, 2):
            return np.nan
        chi2, _, _, _ = chi2_contingency(table, correction=False)
        sign = np.sign((table.loc[1, 1] * table.loc[0, 0]) - (table.loc[1, 0] * table.loc[0, 1]))
        return float(sign * np.sqrt(chi2 / len(x)))
    except Exception:
        return np.nan


# --- metric-type mapping ---
"""METRIC_TYPES = {
    **{f"acc_A_@{k}": "binary" for k in [1, 5, 10]},
    **{f"acc_B_@{k}": "binary" for k in [1, 5, 10]},
    "disagree_correct": "binary",
    "agree_correct": "binary",
    "agree_wrong": "binary",
    "agree_set": "binary",
    "disagree_set": "binary",

    "jaccard_@1": "binary",
    "jaccard_@5": "continuous",
    "jaccard_@10": "continuous",
}"""
# Every known metric is currently typed "continuous". With this mapping
# _choose_corr_func_fixed never takes its "binary" branches for these names.
METRIC_TYPES = dict.fromkeys(
    [
        *(f"acc_A_@{k}" for k in (1, 5, 10)),
        *(f"acc_B_@{k}" for k in (1, 5, 10)),
        "disagree_correct",
        "agree_correct",
        "agree_wrong",
        "agree_set",
        "disagree_set",
        "jaccard_@1",
        "jaccard_@5",
        "jaccard_@10",
    ],
    "continuous",
)

"""def _choose_corr_func_fixed(anchor, metric):
    anchor_t = "binary" if "disagree_correct" in anchor else "continuous"
    metric_t = METRIC_TYPES.get(metric, "continuous")

    if anchor_t == "binary" and metric_t == "binary":
        return "phi", _phi_coefficient
    elif anchor_t == "binary" or metric_t == "binary":
        return "pointbiserial", pointbiserialr
    else:
        return "spearman", spearmanr"""
def _choose_corr_func_fixed(anchor, metric):
    """Pick the correlation by the declared types of *anchor* and *metric*.

    Types come from ``METRIC_TYPES``; names not listed there count as
    "continuous". Returns a ``(name, function)`` pair:

    * both binary      -> ``("phi", _phi_coefficient)``
    * exactly one      -> ``("pointbiserial", pointbiserialr)``
    * neither          -> ``("spearman", spearmanr)``
    """
    declared = (
        METRIC_TYPES.get(anchor, "continuous"),
        METRIC_TYPES.get(metric, "continuous"),
    )
    binary_count = declared.count("binary")

    if binary_count == 2:
        return "phi", _phi_coefficient
    if binary_count == 1:
        return "pointbiserial", pointbiserialr
    return "spearman", spearmanr



# ============================================================
# Correlation (per-prompt)
# ============================================================
def correlate_layers_by_anchor_perprompt(
    df,
    anchor,
    metrics,
    modes=("raw", "unit_rms", "norm_rms"),
    topk_levels=(1, 5, 10),
    n_boot=10,
    n_perm=10,
    seed=42,
    output_dir=None,
    model="m_8bit",
    min_valid=3,
):
    """Correlate an output-layer anchor with per-layer metrics, prompt by prompt.

    For every mode, every ``(layer_name, layer_index)`` group and every metric
    column, one correlation is computed per ``(prompt_id, batch_index)`` between
    the anchor vector (taken from the rows whose ``layer_name`` is "output") and
    the metric vector of that layer.  The per-prompt coefficients are combined
    with a Fisher-Z average weighted by the number of value pairs used, and a
    bootstrap interval and a permutation p-value are attached.

    Args:
        df: Frame with ``layer_name``, ``layer_index``, ``prompt_id`` and
            ``batch_index`` columns plus array-valued anchor / metric columns.
        anchor: Anchor base name; the column ``{anchor}_{mode}_@1`` is used when
            present in ``df``, otherwise ``{anchor}_{mode}``.
        metrics: Metric base names; columns are ``{metric}_{mode}{suffix}``.
        modes: Normalization modes to iterate over.
        topk_levels: ``k`` values used to build the ``_@{k}`` column suffixes.
        n_boot: Number of bootstrap resamples over prompts.
        n_perm: Number of permutations for the permutation test.
        seed: Seed passed to ``np.random.seed`` (global NumPy RNG).
        output_dir: If truthy, the result table is written there as CSV.
        model: Only used in the output file name.
        min_valid: Minimum number of finite value pairs per prompt, and minimum
            number of prompts required for the bootstrap / permutation steps.

    Returns:
        ``(df_corr, df_layer, rho_global)``: the per-(mode, layer, metric)
        result table, a per-layer Fisher-Z summary, and the unweighted Fisher-Z
        average of the per-layer values (NaN when nothing is computable).
    """
    # Seeds the *global* NumPy RNG; bootstrap and permutation draws below depend on it.
    np.random.seed(seed)
    results = []
    # Anchors are read only from rows whose layer_name is "output" (case-insensitive).
    df_out = df[df["layer_name"].str.lower() == "output"]

    # --- Build anchor map from output layer ---
    # anchor_map[(prompt_id, batch_index)][mode] -> 1-D float array
    anchor_map = {}
    for _, row in df_out.iterrows():
        key = (int(row["prompt_id"]), int(row.get("batch_index", 0)))
        anchor_vecs = {}
        for mode in modes:
            # Prefer the flattened top-1 column when it exists.
            col1 = f"{anchor}_{mode}_@1"
            col2 = f"{anchor}_{mode}"
            col = col1 if col1 in df.columns else col2
            val = row.get(col)
            # Only non-empty list/array cells are accepted as anchors.
            if isinstance(val, (list, np.ndarray)) and len(val) > 0:
                anchor_vecs[mode] = np.array(val, float)
        if anchor_vecs:
            anchor_map[key] = anchor_vecs
    print(f"[anchor_map] built {len(anchor_map)} anchors")

    skip_const = 0
    # Metrics in this set are always read from their "raw" column, whatever the mode.
    SHARED_METRICS = {"cosine_sim", "l2_dist"}

    # --- Main correlation loop ---
    for mode in modes:
        for (lname, lidx), layer_df in tqdm(
            df.groupby(["layer_name", "layer_index"]),
            desc=f"{anchor}_{mode}"
        ):
            for metric in metrics:
                # Top-k-only metrics have no un-suffixed column; others get both forms.
                is_topk_only = metric in {"acc_A", "acc_B", "jaccard"}
                suffixes = [f"_@{k}" for k in topk_levels] if is_topk_only else [""] + [f"_@{k}" for k in topk_levels]

                for suffix in suffixes:
                    src_mode = "raw" if metric in SHARED_METRICS else mode
                    mcol = f"{metric}_{src_mode}{suffix}"
                    if mcol not in layer_df.columns:
                        continue

                    per_prompt_rhos, per_prompt_ns = [], []

                    for (pid, bid), group in layer_df.groupby(["prompt_id", "batch_index"]):
                        key = (int(pid), int(bid))
                        if key not in anchor_map or mode not in anchor_map[key]:
                            continue

                        a = anchor_map[key][mode]
                        # Only the first row of the (prompt, batch) group is used.
                        m = np.asarray(group[mcol].iloc[0], dtype=float).flatten()
                        # Truncate both vectors to their common length.
                        n = min(len(a), len(m))
                        if n < min_valid:
                            continue

                        a, m = a[:n], m[:n]
                        # Drop positions where either side is NaN/inf.
                        mask = np.isfinite(a) & np.isfinite(m)
                        n_used = int(mask.sum())
                        if n_used < min_valid:
                            continue
                        # A constant vector has no defined correlation; count and skip.
                        if np.std(a[mask]) == 0 or np.std(m[mask]) == 0:
                            skip_const += 1
                            continue

                        # The choice depends only on the anchor / metric names, so
                        # cname is the same for every prompt in this inner loop.
                        cname, corr_func = _choose_corr_func_fixed(anchor, f"{metric}{suffix}")
                        try:
                            # NOTE(review): _phi_coefficient returns a bare float, so this
                            # two-value unpack would fail (-> NaN) if "phi" were selected.
                            rho, _ = corr_func(a[mask], m[mask])
                        except Exception:
                            rho = np.nan
                        if np.isfinite(rho):
                            per_prompt_rhos.append(rho)
                            per_prompt_ns.append(n_used)

                    # --- handle missing prompts ---
                    # No prompt produced a finite coefficient: record a placeholder row.
                    if not per_prompt_rhos:
                        results.append({
                            "mode": mode, "anchor": anchor, "metric": metric + suffix,
                            "corr_type": "undefined", "layer_name": lname, "layer_index": lidx,
                            "rho": np.nan, "rho_boot_median": np.nan, "ci_low": np.nan,
                            "ci_high": np.nan, "p_perm": np.nan, "n_prompts": 0,
                            "n_used_mean": np.nan, "n_used_median": np.nan,
                            "pooling": "per_prompt", "reason": "no_data"
                        })
                        continue

                    A, N = np.array(per_prompt_rhos), np.array(per_prompt_ns)
                    # Defensive guard: per_prompt_rhos is non-empty and holds only
                    # finite values here, so this is not expected to trigger.
                    if len(N) == 0 or np.all(np.isnan(A)):
                        continue

                    # --- fisher-Z weighted aggregation ---
                    # Clip so arctanh stays finite for |rho| == 1; weights are the
                    # number of value pairs each prompt contributed.
                    z = np.arctanh(np.clip(A, -0.999999, 0.999999))
                    z_mean = np.average(z, weights=N)
                    rho = np.tanh(z_mean)
                    n_mean, n_median = np.nanmean(N), np.nanmedian(N)

                    # --- bootstrap CI ---
                    # Resample prompts with replacement; 2.5 / 97.5 percentiles give the interval.
                    if len(A) >= min_valid:
                        boot_rhos = []
                        for _ in range(n_boot):
                            idx = np.random.choice(len(A), len(A), replace=True)
                            z_b = np.arctanh(np.clip(A[idx], -0.999999, 0.999999))
                            boot_rhos.append(np.tanh(np.average(z_b, weights=N[idx])))
                        ci_low, ci_high = np.percentile(boot_rhos, [2.5, 97.5])
                        rho_boot = np.median(boot_rhos)
                    else:
                        ci_low = ci_high = rho_boot = np.nan

                    # --- permutation test ---
                    # NOTE(review): only the order of the per-prompt rhos relative to the
                    # fixed weights N is shuffled; the anchor/metric pairing inside each
                    # prompt is not permuted. Confirm this is the intended null.
                    if len(A) >= min_valid:
                        perm_rhos = []
                        for _ in range(n_perm):
                            perm_z = np.arctanh(np.clip(np.random.permutation(A), -0.999999, 0.999999))
                            perm_rhos.append(np.tanh(np.average(perm_z, weights=N)))
                        perm_rhos = np.array(perm_rhos)
                        # Two-sided p-value with the +1 correction in numerator and denominator.
                        p_perm = (np.sum(np.abs(perm_rhos) >= abs(rho)) + 1) / (len(perm_rhos) + 1)
                    else:
                        p_perm = np.nan

                    results.append({
                        "mode": mode, "anchor": anchor, "metric": metric + suffix,
                        "corr_type": cname, "layer_name": lname, "layer_index": lidx,
                        "rho": rho, "rho_boot_median": rho_boot,
                        "ci_low": ci_low, "ci_high": ci_high, "p_perm": p_perm,
                        "n_prompts": len(A), "n_used_mean": n_mean, "n_used_median": n_median,
                        "pooling": "per_prompt", "reason": "ok"
                    })

    df_corr = pd.DataFrame(results)
    print(f"[ok] {anchor} ‚Üí {len(df_corr)} per-prompt correlations")
    print(f"[NaN correlations: {df_corr['rho'].isna().sum()}]")
    print(f"[skipped constant={skip_const}]")

    # The directory is not created here; it must already exist.
    if output_dir:
        out_path = Path(output_dir) / f"lw_{model}_{anchor}_corr_perprompt.csv"
        df_corr.to_csv(out_path, index=False)
        print(f"[saved per-prompt correlations] {out_path}")

    # --- per-layer Fisher-Z summary ---
    # Within each layer, rows are averaged in z-space weighted by n_prompts.
    valid = df_corr[np.isfinite(df_corr["rho"])]
    if not valid.empty:
        layer_stats = []
        for (lname, lidx), g in valid.groupby(["layer_name", "layer_index"]):
            z_vals = np.arctanh(np.clip(g["rho"], -0.999999, 0.999999))
            weights = g["n_prompts"].fillna(1)
            z_mean = np.average(z_vals, weights=weights)
            rho_layer = np.tanh(z_mean)
            layer_stats.append({"layer_name": lname, "layer_index": lidx, "rho_layer": rho_layer})
        df_layer = pd.DataFrame(layer_stats)
        # The global value is an unweighted Fisher-Z average over layers.
        rho_global = np.tanh(np.average(np.arctanh(df_layer["rho_layer"]), weights=None))
        print(f"[global per-prompt Fisher-Z rho] {rho_global:.3f} over {len(df_layer)} layers")
    else:
        df_layer = pd.DataFrame()
        rho_global = np.nan
        print("[global per-prompt Fisher-Z rho] not computable")

    return df_corr, df_layer, rho_global


# ============================================================
# summarize_correlations_perprompt (fix for single-row groups)
# ============================================================
def summarize_correlations_perprompt(df_corr, output_dir, model, ci_mode="both"):
    """Collapse correlation rows into one Fisher-Z summary row per group.

    Rows with a missing ``rho`` are dropped. The rest are grouped by
    ``(mode, anchor, metric, corr_type, layer_name, layer_index)`` and
    averaged in Fisher-Z space, weighted by ``n_prompts`` (or ``n``, or 1 when
    neither column exists). The summary is written to
    ``lw_{model}_corr_perprompt_summary_{ci_mode}.csv`` inside *output_dir*.

    Args:
        df_corr: Correlation table. Needs ``rho`` and the six group columns;
            ``rho_boot_median`` / ``ci_low`` / ``ci_high`` are used if present.
        output_dir: Existing directory for the summary CSV.
        model: Model tag used in the file name.
        ci_mode: ``"empirical"`` or ``"both"`` copies the bootstrap interval
            into ``rho_low_plot`` / ``rho_high_plot``; any other value copies
            the ``z_mean +/- 1.96 * z_std`` interval.

    Returns:
        ``(df_summary, rho_global, n_valid_layers)``. ``rho_global`` is the
        ``n_total``-weighted Fisher-Z mean over summary rows with a finite
        ``rho_mean``; ``n_valid_layers`` is the number of those rows.
    """
    group_cols = ["mode", "anchor", "metric", "corr_type", "layer_name", "layer_index"]
    stat_cols = [
        "rho_mean", "rho_low", "rho_high", "n_total",
        "rho_boot_mean", "ci_low_emp", "ci_high_emp",
    ]

    # Placeholder rows without a correlation (e.g. "no_data") carry no information here.
    df = df_corr.dropna(subset=["rho"]).copy()
    if "n_prompts" in df.columns:
        df["n"] = df["n_prompts"]
    elif "n" not in df.columns:
        df["n"] = 1

    def _finite_weighted_mean(g, column, weights):
        # Average only over rows that actually have a value. Treating a missing
        # bootstrap estimate as 0 would invent an interval bound at zero.
        if column not in g.columns:
            return np.nan
        values = g[column].to_numpy(dtype=float)
        w = np.asarray(weights, dtype=float)
        ok = np.isfinite(values) & np.isfinite(w)
        if not ok.any() or w[ok].sum() == 0:
            return np.nan
        return float(np.average(values[ok], weights=w[ok]))

    # --- weighted stats per group ---
    def _weighted_stats(g):
        out = dict.fromkeys(stat_cols, np.nan)
        if not np.any(np.isfinite(g["rho"])) or g["n"].sum() == 0:
            out["n_total"] = 0
        else:
            w = g["n"].fillna(1)
            # Clip so arctanh stays finite for |rho| == 1.
            z = np.arctanh(np.clip(g["rho"], -0.999999, 0.999999))
            z_mean = np.average(z, weights=w)
            z_std = np.sqrt(np.average((z - z_mean) ** 2, weights=w))
            out["rho_mean"] = np.tanh(z_mean)
            # For a single-row group z_std is 0, so both bounds equal rho_mean.
            out["rho_low"] = np.tanh(z_mean - 1.96 * z_std)
            out["rho_high"] = np.tanh(z_mean + 1.96 * z_std)
            out["n_total"] = g["n"].sum()
            out["rho_boot_mean"] = _finite_weighted_mean(g, "rho_boot_median", w)
            out["ci_low_emp"] = _finite_weighted_mean(g, "ci_low", w)
            out["ci_high_emp"] = _finite_weighted_mean(g, "ci_high", w)

        for k in group_cols:
            out[k] = g[k].iloc[0]
        return out

    # An explicit column list keeps the frame well-formed when there are no groups.
    rows = [_weighted_stats(g) for _, g in df.groupby(group_cols)]
    df_summary = pd.DataFrame(rows, columns=stat_cols + group_cols)

    # --- CI mode handling ---
    if ci_mode in ("empirical", "both"):
        df_summary["rho_low_plot"] = df_summary["ci_low_emp"]
        df_summary["rho_high_plot"] = df_summary["ci_high_emp"]
    else:
        df_summary["rho_low_plot"] = df_summary["rho_low"]
        df_summary["rho_high_plot"] = df_summary["rho_high"]

    # --- global Fisher-Z average across summary rows ---
    # astype(float) keeps isfinite usable on an empty (object-dtype) frame.
    valid = df_summary[np.isfinite(df_summary["rho_mean"].astype(float))]
    if not valid.empty:
        z_vals = np.arctanh(valid["rho_mean"].clip(-0.999999, 0.999999))
        weights = valid["n_total"].fillna(1)
        z_mean = np.average(z_vals, weights=weights)
        rho_global = np.tanh(z_mean)
        n_valid_layers = len(valid)
        print(f"[global per-prompt Fisher-Z rho] {rho_global:.3f} (across {n_valid_layers} valid layers)")
    else:
        rho_global = np.nan
        n_valid_layers = 0
        print("[global per-prompt Fisher-Z rho] not computable (no valid layers)")

    # --- report layers without variance or missing ---
    if "reason" in df_corr.columns:
        missing_info = df_corr[df_corr["reason"].isin(["no_variance", "no_data"])]
        if not missing_info.empty:
            for metric in missing_info["metric"].unique():
                n_missing = len(missing_info[missing_info["metric"] == metric])
                print(f"[info] {metric}: {n_missing} layers had no variance or missing data")

    # --- save ---
    out_summary = Path(output_dir) / f"lw_{model}_corr_perprompt_summary_{ci_mode}.csv"
    df_summary.to_csv(out_summary, index=False)
    print(f"[saved summary] {out_summary}")
    print(f"[diag] rows={len(df_summary)} groups={df[group_cols].drop_duplicates().shape[0]}")
    print(f"[NaN summary rows: {df_summary['rho_mean'].isna().sum()}]")

    return df_summary, rho_global, n_valid_layers


# ============================================================
# Run per-prompt correlation pipeline
# ============================================================
BASE = Path("saved_data")
model = "m_quant"
output_dir = BASE / "summary" / model
output_dir.mkdir(parents=True, exist_ok=True)

# Load all top-k batches, flatten the nested top-k dicts, normalize the anchors.
topk_dir = BASE / "topk" / model
df_topk = merge_parquet_files(topk_dir)
df_topk = expand_topk_metrics(
    df_topk,
    metrics=["jaccard", "acc_A", "acc_B", "disagree_correct"],
    modes=("raw", "unit_rms", "norm_rms")
)
df_topk = preprocess_anchors(df_topk)

# cosine_sim / l2_dist exist only as "raw"; mirror them under the other mode
# names so per-mode column lookups succeed.
for m in ["cosine_sim", "l2_dist"]:
    if f"{m}_raw" in df_topk.columns:
        for mode in ["unit_rms", "norm_rms"]:
            col_src = f"{m}_raw"
            col_dst = f"{m}_{mode}"
            if col_dst not in df_topk.columns:
                df_topk[col_dst] = df_topk[col_src]
                print(f"[copy] propagated {col_src} ‚Üí {col_dst}")

anchors = ["disagree_correct", "p_diff"]
metrics = [
    "kl_ab", "kl_ba", "js_div", "js_dist", "tvd",
    "entropy_A", "entropy_B", "cosine_sim", "l2_dist",
    "ppl_diff", "jaccard", "acc_A", "acc_B"
]

# correlate_layers_by_anchor_perprompt returns (df_corr, df_layer, rho_global).
# Only the correlation table can be concatenated; appending the whole tuple
# makes pd.concat fail.
df_corr_perprompt_all = []
for a in anchors:
    _df_corr_a, _df_layer_a, _rho_global_a = correlate_layers_by_anchor_perprompt(
        df_topk, a, metrics, model=model
    )
    df_corr_perprompt_all.append(_df_corr_a)
df_corr_perprompt_all = pd.concat(df_corr_perprompt_all, ignore_index=True)

# The summary function appends the CI mode to its file name, so the path
# reported here has to include it too.
summary_ci_mode = "both"
perprompt_corr_path = output_dir / f"lw_{model}_corr_perprompt.csv"
perprompt_summary_path = output_dir / f"lw_{model}_corr_perprompt_summary_{summary_ci_mode}.csv"

df_corr_perprompt_all.to_csv(perprompt_corr_path, index=False)
print(f"[saved per-prompt correlations] {perprompt_corr_path}")

# summarize_correlations_perprompt returns (df_summary, rho_global, n_valid_layers);
# unpack it so df_summary_pp is a DataFrame for the inspection cells below.
df_summary_pp, rho_global_pp, n_valid_layers_pp = summarize_correlations_perprompt(
    df_corr_perprompt_all, output_dir=output_dir, model=model, ci_mode=summary_ci_mode
)
print(f"[saved per-prompt summary] {perprompt_summary_path}")


In [None]:
df_corr_perprompt_all.query("metric.str.contains('cosine_sim')").groupby("mode")["rho"].mean()


In [None]:
df_corr_perprompt_all

In [None]:
df_summary_pp

In [None]:
df_corr_perprompt_all.isna().sum()

In [None]:
df_summary_pp.isna().sum()

In [None]:
# Sanity check: mean absolute gap between the summary's rho_mean and the
# raw per-prompt rho for the 8-bit model.
df_raw = pd.read_csv("saved_data/summary/m_8bit/lw_m_8bit_corr_perprompt.csv")
df_sum = pd.read_csv("saved_data/summary/m_8bit/lw_m_8bit_corr_perprompt_summary_both.csv")

# NOTE(review): the join keys omit layer_name. Confirm layer_index alone
# identifies a layer, otherwise rows pair up across layers.
merged = pd.merge(
    df_sum,
    df_raw,
    on=["mode","anchor","metric","corr_type","layer_index"],
    suffixes=("_sum","_raw"),
)
merged["diff"] = merged["rho_mean"].sub(merged["rho"])

print(merged["diff"].abs().mean())

In [None]:
# Same sanity check for the pooled correlations of the 8-bit model.
df_raw = pd.read_csv("saved_data/summary/m_8bit/lw_m_8bit_corr_pooled.csv")
df_sum = pd.read_csv("saved_data/summary/m_8bit/lw_m_8bit_corr_pooled_summary_both.csv")

# NOTE(review): the join keys omit layer_name. Confirm layer_index alone
# identifies a layer, otherwise rows pair up across layers.
merged = pd.merge(
    df_sum,
    df_raw,
    on=["mode","anchor","metric","corr_type","layer_index"],
    suffixes=("_sum","_raw"),
)
merged["diff"] = merged["rho_mean"].sub(merged["rho"])

print(merged["diff"].abs().mean())

In [None]:
import pandas as pd
from pathlib import Path

BASE_DIR = Path("saved_data/summary")
MODELS = ["m_4bit", "m_8bit", "m_quant"]

# Collect every summary CSV per model, labelling the pooling strategy from
# the file name.
dfs = []
for model in MODELS:
    for f in (BASE_DIR / model).glob("*_summary*.csv"):
        d = pd.read_csv(f).assign(model=model)
        d["pooling"] = (
            "pooled" if "pooled" in f.stem
            else "per_prompt" if any(tag in f.stem for tag in ("perprompt", "per_prompt"))
            else "unknown"
        )
        dfs.append(d)

df = pd.concat(dfs, ignore_index=True)

# For each pooling strategy, report which layer indices each model lacks
# relative to all layer indices seen anywhere in the combined frame.
for pooling in ["pooled", "per_prompt"]:
    print(f"\n=== {pooling.upper()} ===")
    d = df.loc[df["pooling"] == pooling]
    all_layers = sorted(df["layer_index"].unique())
    print(f"Alle mulige lag: {all_layers[:10]}... ({len(all_layers)} total)")
    for model in MODELS:
        layers = sorted(d.loc[d["model"] == model, "layer_index"].unique())
        missing = sorted(set(all_layers).difference(layers))
        print(f"{model}: {len(layers)} lag fundet, mangler {len(missing)} ‚Üí {missing[:10] if missing else 'ingen mangler'}")


In [None]:
# Rows for the binary anchor that were scored with a point-biserial correlation.
subset = df[
    (df["anchor"] == "disagree_correct") &
    (df["corr_type"] == "pointbiserial")
]

print("Unique layers per model (per_prompt):")
for model in df["model"].unique():
    # Both conditions must come from `subset`. Mixing in a mask built on `df`
    # (a differently indexed frame) relies on implicit index alignment and
    # makes pandas warn about reindexing the boolean key.
    d = subset[(subset["model"] == model) & (subset["pooling"] == "per_prompt")]
    print(f"{model}: {sorted(d['layer_index'].unique())[:10]} ... total {d['layer_index'].nunique()}")


# ==============================================
# Plot TopK Correlations =======================
# ==============================================

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# === STYLE ===
# Shared seaborn theme, with smaller fonts for the dense multi-panel figures.
sns.set_theme(style="whitegrid", context="talk", palette="deep")
plt.rcParams.update(
    {
        "axes.titlesize": 12,
        "axes.labelsize": 11,
        "legend.fontsize": 8,
        "xtick.labelsize": 9,
        "ytick.labelsize": 9,
    }
)

# === CONFIG ===
BASE_DIR = Path("saved_data/summary")
MODELS = ["m_4bit", "m_8bit", "m_quant"]
# NOTE(review): the per-prompt pipeline in this notebook uses the anchor name
# "p_diff"; confirm "logp_diff" matches the anchors present in the CSVs.
ANCHORS_CONTINUOUS = ["logp_diff"]
ANCHORS_BINARY = ["disagree_correct"]
OUT_ROOT = Path("saved_data/figures_rawcorr_all")
OUT_ROOT.mkdir(parents=True, exist_ok=True)

LAYOUT_MODE = "metric"   # "mode" | "metric" | "group"

# Metrics plotted for Spearman correlations.
SPEARMAN_METRICS = [
    "tvd",
    "kl_ab",
    "kl_ba",
    "js_div",
    "js_dist",
    "cosine_sim",
    "l2_dist",
    "ppl_diff",
    "jaccard_@1",
    "jaccard_@5",
    "jaccard_@10",
]
# Metrics plotted for point-biserial / phi correlations.
BISERIAL_METRICS = [
    "acc_A_@1",
    "acc_A_@5",
    "acc_A_@10",
    "acc_B_@1",
    "acc_B_@5",
    "acc_B_@10",
    "disagree_correct",
    "jaccard_@1",
    "jaccard_@5",
    "jaccard_@10",
]
# Thematic metric groups used by the "group" layout.
GROUPS = {
    "divergence": ["tvd", "kl_ab", "kl_ba", "js_div", "js_dist", "ppl_diff"],
    "representation": ["cosine_sim", "l2_dist", "jaccard_@1", "jaccard_@5", "jaccard_@10"],
    "accuracy": ["acc_A_@1", "acc_B_@1", "acc_A_@5", "acc_B_@5", "disagree_correct"],
}

# === LOAD RAW CORR FILES ===
# Read every non-summary correlation CSV of every model directory that exists,
# tagging rows with the model and the pooling strategy named in the file stem.
dfs = []
for model in MODELS:
    model_dir = BASE_DIR / model
    if model_dir.exists():
        for f in model_dir.glob("*corr_*.csv"):
            if "summary" not in f.name:
                d = pd.read_csv(f).assign(model=model)
                d["pooling"] = (
                    "pooled" if "pooled" in f.stem
                    else "per_prompt" if any(tag in f.stem for tag in ("perprompt", "per_prompt"))
                    else "unknown"
                )
                dfs.append(d)

if len(dfs) == 0:
    raise RuntimeError("No raw correlation files found!")

df = pd.concat(dfs, ignore_index=True)
for c in ("rho",):
    df[c] = pd.to_numeric(df[c], errors="coerce")
# Unparseable layer indices are coerced to NaN and then mapped to layer 0.
layer_numbers = pd.to_numeric(df["layer_index"], errors="coerce")
df["layer_index"] = layer_numbers.fillna(0).astype(int)
df["rho_smooth"] = df["rho"]  # raw values only

print(f"[ok] merged {len(df)} rows from {df['model'].nunique()} models.")
print(df.groupby(["anchor","corr_type","pooling"])["rho"].describe().round(3))

# === HELPERS ===
def _auto_subplots(n_items, n_cols=3):
    n_rows = int(np.ceil(n_items / n_cols))
    return n_rows, n_cols

def _finalize_grid(fig, axes, title, save_path=None):
    """Hide unused panels, add the figure title, optionally save, then close.

    *axes* must expose ``.flat`` (a NumPy array of Axes). When *save_path* is
    truthy the figure is written there at 250 dpi before being closed.
    """
    for panel in axes.flat:
        if panel.has_data():
            continue
        panel.axis("off")

    fig.suptitle(title, fontsize=15, weight="bold")
    # Leave the top 5% of the canvas for the suptitle.
    fig.tight_layout(rect=[0, 0, 1, 0.95])

    if not save_path:
        plt.close(fig)
        return
    fig.savefig(save_path, dpi=250, bbox_inches="tight")
    print(f"[saved] {save_path}")
    plt.close(fig)

def _mark_nans(ax, dsub):
    d_nan = dsub[dsub["rho"].isna()]
    if not d_nan.empty:
        for i, model in enumerate(sorted(d_nan["model"].unique())):
            d_m = d_nan[d_nan["model"] == model]
            ax.scatter(
                d_m["layer_index"],
                [-0.95 + 0.05*i] * len(d_m),
                color="red", marker="x", s=50, label=f"{model} NaN"
            )

# === PLOT FUNCTIONS ===
def plot_by_mode(df_sub, corr_type, anchor, pooling, out_dir):
    """Plot rho per layer in three side-by-side panels, one per mode.

    Within a panel, colour encodes the model and line style the metric.
    The figure is saved as ``{corr_type}_{anchor}_{pooling}_modes.png`` in
    *out_dir*.
    """
    mode_names = ["raw", "unit_rms", "norm_rms"]
    fig, axes = plt.subplots(1, 3, figsize=(18, 6), sharey=True)

    for panel, mode_name in zip(axes, mode_names):
        rows = df_sub[df_sub["mode"] == mode_name]
        if rows.empty:
            panel.text(0.5, 0.5, "no data", ha="center", va="center", fontsize=11)
            continue

        sns.lineplot(
            data=rows, x="layer_index", y="rho_smooth",
            hue="model", style="metric", lw=2, ax=panel
        )
        _mark_nans(panel, rows)

        # Reference lines: zero, then +/-0.3 guides.
        panel.axhline(0, color="black", linestyle=":")
        for level in (0.3, -0.3):
            panel.axhline(level, color="gray", linestyle="--", lw=0.6, alpha=0.4)

        panel.set_ylim(-1, 1)
        panel.set_title(mode_name.upper(), fontsize=12, weight="bold")
        panel.set_xlabel("Layer index")
        # Only the first panel carries the shared y label.
        y_label = "œÅ" if mode_name == "raw" else ""
        panel.set_ylabel(y_label)
        panel.legend(fontsize=8)

    figure_title = f"{corr_type.capitalize()} ‚Äî {anchor} ({pooling})"
    figure_path = out_dir / f"{corr_type}_{anchor}_{pooling}_modes.png"
    _finalize_grid(fig, axes, figure_title, figure_path)

def plot_by_metric(df_sub, corr_type, anchor, pooling, out_dir, n_cols=3):
    """Plot rho per layer in a grid with one panel per metric.

    Within a panel, colour encodes the model and line style the mode.
    The figure is saved as ``{corr_type}_{anchor}_{pooling}_metrics.png`` in
    *out_dir*.
    """
    metric_names = sorted(df_sub["metric"].unique())
    n_rows, n_cols = _auto_subplots(len(metric_names), n_cols)
    fig, grid = plt.subplots(n_rows, n_cols, figsize=(6 * n_cols, 3.5 * n_rows), sharey=True)
    # Flatten so a single Axes, a row, or a full grid are all handled alike.
    axes = np.array(grid).reshape(-1)

    for panel, metric_name in zip(axes, metric_names):
        rows = df_sub[df_sub["metric"] == metric_name]
        if rows.empty:
            continue

        sns.lineplot(
            data=rows, x="layer_index", y="rho_smooth",
            hue="model", style="mode", lw=2.2, ax=panel
        )
        _mark_nans(panel, rows)

        # Reference lines: zero, then +/-0.3 guides.
        panel.axhline(0, color="black", linestyle=":")
        for level in (0.3, -0.3):
            panel.axhline(level, color="gray", linestyle="--", lw=0.6, alpha=0.4)

        panel.set_ylim(-1, 1)
        panel.set_title(metric_name, fontsize=10)
        panel.set_xlabel("Layer index")
        panel.set_ylabel("œÅ")
        panel.legend(fontsize=8, frameon=True)

    figure_title = f"{corr_type.capitalize()} ‚Äî {anchor} ({pooling})"
    figure_path = out_dir / f"{corr_type}_{anchor}_{pooling}_metrics.png"
    _finalize_grid(fig, axes, figure_title, figure_path)

def plot_by_group(df_sub, corr_type, anchor, pooling, out_dir):
    """Produce one three-panel figure (one panel per mode) per metric group.

    Groups come from the module-level ``GROUPS`` mapping; groups without rows
    in *df_sub* are skipped. Each figure is saved as
    ``{corr_type}_{anchor}_{pooling}_{group}.png`` in *out_dir*.
    """
    mode_names = ["raw", "unit_rms", "norm_rms"]

    for group_name, group_metrics in GROUPS.items():
        group_rows = df_sub[df_sub["metric"].isin(group_metrics)]
        if group_rows.empty:
            continue

        fig, axes = plt.subplots(1, 3, figsize=(18, 6), sharey=True)
        for panel, mode_name in zip(axes, mode_names):
            rows = group_rows[group_rows["mode"] == mode_name]
            if rows.empty:
                panel.text(0.5, 0.5, "no data", ha="center", va="center", fontsize=11)
                continue

            sns.lineplot(
                data=rows, x="layer_index", y="rho_smooth",
                hue="model", style="metric", lw=2, ax=panel
            )
            _mark_nans(panel, rows)

            # Reference lines: zero, then +/-0.3 guides.
            panel.axhline(0, color="black", linestyle=":")
            for level in (0.3, -0.3):
                panel.axhline(level, color="gray", linestyle="--", lw=0.6, alpha=0.4)

            panel.set_ylim(-1, 1)
            panel.set_title(f"{group_name.title()} ({mode_name})", fontsize=11)
            panel.set_xlabel("Layer index")
            # Only the first panel carries the shared y label.
            y_label = "œÅ" if mode_name == "raw" else ""
            panel.set_ylabel(y_label)
            panel.legend(fontsize=8, frameon=True)

        figure_title = f"{group_name.capitalize()} ‚Äî {corr_type.capitalize()} ({anchor}/{pooling})"
        figure_path = out_dir / f"{corr_type}_{anchor}_{pooling}_{group_name}.png"
        _finalize_grid(fig, axes, figure_title, figure_path)

# === MAIN LOOP ===
# One output folder per pooling strategy; LAYOUT_MODE picks the figure layout.
for pooling in ["pooled", "per_prompt"]:
    out_dir = OUT_ROOT / pooling
    out_dir.mkdir(parents=True, exist_ok=True)

    # Continuous anchors, scored with Spearman
    for anchor in ANCHORS_CONTINUOUS:
        df_s = df.loc[
            df["anchor"].eq(anchor)
            & df["corr_type"].eq("spearman")
            & df["pooling"].eq(pooling)
            & df["metric"].isin(SPEARMAN_METRICS)
        ]
        if not df_s.empty:
            if LAYOUT_MODE == "mode":
                plot_by_mode(df_s, "spearman", anchor, pooling, out_dir)
            elif LAYOUT_MODE == "metric":
                plot_by_metric(df_s, "spearman", anchor, pooling, out_dir)
            elif LAYOUT_MODE == "group":
                plot_by_group(df_s, "spearman", anchor, pooling, out_dir)

    # Binary anchors, scored with point-biserial or phi
    for corr_type in ["pointbiserial", "phi"]:
        for anchor in ANCHORS_BINARY:
            df_b = df.loc[
                df["anchor"].eq(anchor)
                & df["corr_type"].eq(corr_type)
                & df["pooling"].eq(pooling)
                & df["metric"].isin(BISERIAL_METRICS)
            ]
            if not df_b.empty:
                if LAYOUT_MODE == "mode":
                    plot_by_mode(df_b, corr_type, anchor, pooling, out_dir)
                elif LAYOUT_MODE == "metric":
                    plot_by_metric(df_b, corr_type, anchor, pooling, out_dir)
                elif LAYOUT_MODE == "group":
                    plot_by_group(df_b, corr_type, anchor, pooling, out_dir)

# === MIXED ALL-COMBINATION PLOTS ===
# Every anchor against every metric, regardless of correlation type.
print("\n[mixed plotting] Generating full cross-anchor √ó metric plots...")

ALL_CORR_TYPES = ["phi", "pointbiserial", "spearman"]

for pooling in ["pooled", "per_prompt"]:
    out_dir = OUT_ROOT / pooling
    out_dir.mkdir(parents=True, exist_ok=True)

    for anchor in ANCHORS_BINARY + ANCHORS_CONTINUOUS:
        df_all = df.loc[
            df["anchor"].eq(anchor)
            & df["pooling"].eq(pooling)
            & df["corr_type"].isin(ALL_CORR_TYPES)
        ]
        if df_all.empty:
            print(f"[skip] no data for {anchor} ({pooling})")
        else:
            print(f"[mixed] {anchor} ({pooling}) ‚Üí {df_all['metric'].nunique()} metrics plotted")
            plot_by_metric(df_all, "mixed_all", anchor, pooling, out_dir)
            plot_by_group(df_all, "mixed_all", anchor, pooling, out_dir)

# === QUICK TOPLIST ===
# Mean rho per (anchor, metric, corr_type, model), highest first.
top = (
    df.groupby(["anchor", "metric", "corr_type", "model"])["rho"]
    .mean()
    .reset_index()
    .sort_values("rho", ascending=False)
)
print("\nTop correlations:")
print(top.head(20).to_string(index=False))


In [None]:
df.query("metric == 'ppl_diff'")[["rho","corr_type","anchor","pooling"]].describe()


In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# === STYLE ===
# Apply the seaborn theme first, then override individual matplotlib font sizes.
sns.set_theme(style="whitegrid", context="talk", palette="deep")
plt.rcParams["axes.titlesize"] = 12
plt.rcParams["axes.labelsize"] = 11
plt.rcParams["legend.fontsize"] = 8
plt.rcParams["xtick.labelsize"] = 9
plt.rcParams["ytick.labelsize"] = 9

# === CONFIG ===
BASE_DIR = Path("saved_data/summary")
MODELS = ["m_4bit", "m_8bit", "m_quant"]
ANCHORS_CONTINUOUS = ["logp_diff"]
ANCHORS_BINARY = ["disagree_correct"]
OUT_ROOT = Path("saved_data/figures_rawcorr")
OUT_ROOT.mkdir(parents=True, exist_ok=True)

LAYOUT_MODE = "metric"   # "mode" | "metric" | "group"

# Metrics kept when filtering Spearman rows.
SPEARMAN_METRICS = [
    "tvd", "kl_ab", "kl_ba", "js_div", "js_dist",
    "cosine_sim", "l2_dist", "ppl_diff",
    "jaccard_@1", "jaccard_@5", "jaccard_@10",
]
# Metrics kept when filtering point-biserial / phi rows.
BISERIAL_METRICS = [
    "acc_A_@1", "acc_A_@5", "acc_A_@10",
    "acc_B_@1", "acc_B_@5", "acc_B_@10",
    "disagree_correct",
    "jaccard_@1", "jaccard_@5", "jaccard_@10",
]
# Metric families used by plot_by_group (one figure per family).
GROUPS = {
    "divergence": [
        "tvd", "kl_ab", "kl_ba", "js_div", "js_dist", "ppl_diff",
    ],
    "representation": [
        "cosine_sim", "l2_dist", "jaccard_@1", "jaccard_@5", "jaccard_@10",
    ],
    "accuracy": [
        "acc_A_@1", "acc_B_@1", "acc_A_@5", "acc_B_@5", "disagree_correct",
    ],
}

# === LOAD RAW CORR FILES ===
# Read every "*corr_*.csv" under BASE_DIR/<model> whose name lacks "summary",
# tagging each frame with its model and a pooling label taken from the stem.
dfs = []
for model in MODELS:
    model_dir = BASE_DIR / model
    if model_dir.exists():
        for f in model_dir.glob("*corr_*.csv"):
            if "summary" not in f.name:
                d = pd.read_csv(f)
                d["model"] = model
                d["pooling"] = (
                    "pooled" if "pooled" in f.stem
                    else "per_prompt" if ("perprompt" in f.stem or "per_prompt" in f.stem)
                    else "unknown"
                )
                dfs.append(d)

if not dfs:
    raise RuntimeError("No raw correlation files found!")

df = pd.concat(dfs, ignore_index=True)
for c in ("rho",):
    df[c] = pd.to_numeric(df[c], errors="coerce")
# Layer indices that fail to parse become NaN and are then mapped to 0.
df["layer_index"] = (
    pd.to_numeric(df["layer_index"], errors="coerce")
    .fillna(0)
    .astype(int)
)
# No smoothing is applied: the plotted column is a copy of the raw values.
df["rho_smooth"] = df["rho"]

print(f"[ok] merged {len(df)} rows from {df['model'].nunique()} models.")
print(df.groupby(["anchor","corr_type","pooling"])["rho"].describe().round(3))

# === HELPERS ===
def _auto_subplots(n_items, n_cols=3):
    """Return ``(n_rows, n_cols)`` with just enough rows to hold *n_items* cells."""
    rows_needed = np.ceil(n_items / n_cols)
    return int(rows_needed), n_cols

def _finalize_grid(fig, axes, title, save_path=None):
    """Hide unused panels, add a bold suptitle, optionally save, then close *fig*.

    Args:
        fig: Figure to finish.
        axes: Array of axes (iterated through ``axes.flat``).
        title: Text for the figure-level title.
        save_path: Destination for the image; nothing is written when falsy.
    """
    unused_axes = [panel for panel in axes.flat if not panel.has_data()]
    for panel in unused_axes:
        panel.axis("off")

    fig.suptitle(title, fontsize=15, weight="bold")
    # Leave the top 5% of the canvas free for the suptitle.
    fig.tight_layout(rect=[0, 0, 1, 0.95])

    if save_path:
        fig.savefig(save_path, dpi=250, bbox_inches="tight")
        print(f"[saved] {save_path}")

    plt.close(fig)

def _mark_nans(ax, dsub):
    """Scatter a red "x" at every layer whose ``rho`` is NaN, one series per model.

    Args:
        ax: Axes-like object providing ``scatter``.
        dsub: DataFrame with ``rho``, ``layer_index`` and ``model`` columns.
    """
    missing = dsub.loc[dsub["rho"].isna()]
    if missing.empty:
        return

    model_names = sorted(missing["model"].unique())
    for i, model_name in enumerate(model_names):
        rows = missing.loc[missing["model"] == model_name]
        # Slightly offset per model so markers of different models do not overlap.
        y_values = [-0.95 + 0.05*i] * len(rows)
        ax.scatter(
            rows["layer_index"],
            y_values,
            color="red", marker="x", s=50, label=f"{model_name} NaN",
        )

# === PLOT FUNCTIONS ===
def plot_by_mode(df_sub, corr_type, anchor, pooling, out_dir):
    """Draw one panel per mode ("raw", "unit_rms", "norm_rms") and save a PNG.

    Each panel plots ``rho_smooth`` against ``layer_index`` with one hue per
    model and one line style per metric; NaN correlations are flagged through
    ``_mark_nans``.

    Args:
        df_sub: Rows to plot; needs the columns ``mode``, ``layer_index``,
            ``rho_smooth``, ``rho``, ``model`` and ``metric``.
        corr_type: Correlation label used in the title and the file name.
        anchor: Anchor name used in the title and the file name.
        pooling: Pooling label used in the title and the file name.
        out_dir: Directory path supporting ``/`` (a ``pathlib.Path`` at the
            call sites in this file).
    """
    fig, axes = plt.subplots(1, 3, figsize=(18, 6), sharey=True)
    for ax, mode in zip(axes, ["raw", "unit_rms", "norm_rms"]):
        dmode = df_sub[df_sub["mode"] == mode]
        if dmode.empty:
            # Axes coordinates keep the placeholder centred; in data
            # coordinates it drifts when the shared y-limits change.
            ax.text(0.5, 0.5, "no data", ha="center", va="center",
                    fontsize=11, transform=ax.transAxes)
            continue
        sns.lineplot(
            data=dmode, x="layer_index", y="rho_smooth",
            hue="model", style="metric", lw=2, ax=ax
        )
        _mark_nans(ax, dmode)
        # Reference lines at zero and at +/-0.3.
        ax.axhline(0, color="black", linestyle=":")
        ax.axhline(0.3, color="gray", linestyle="--", lw=0.6, alpha=0.4)
        ax.axhline(-0.3, color="gray", linestyle="--", lw=0.6, alpha=0.4)
        ax.set_ylim(-1, 1)
        ax.set_title(mode.upper(), fontsize=12, weight="bold")
        ax.set_xlabel("Layer index")
        # "\u03c1" is the Greek letter rho (the previous literal was mojibake).
        ax.set_ylabel("\u03c1" if mode == "raw" else "")
        ax.legend(fontsize=8)
    # "\u2014" is an em dash (the previous literal was mojibake).
    _finalize_grid(fig, axes, f"{corr_type.capitalize()} \u2014 {anchor} ({pooling})",
                   out_dir / f"{corr_type}_{anchor}_{pooling}_modes.png")

def plot_by_metric(df_sub, corr_type, anchor, pooling, out_dir, n_cols=3):
    """Draw one panel per metric in an ``n_cols``-wide grid and save a PNG.

    Each panel plots ``rho_smooth`` against ``layer_index`` with one hue per
    model and one line style per mode; NaN correlations are flagged through
    ``_mark_nans``. Returns without drawing when *df_sub* has no non-NaN
    metric label.

    Args:
        df_sub: Rows to plot; needs the columns ``metric``, ``layer_index``,
            ``rho_smooth``, ``rho``, ``model`` and ``mode``.
        corr_type: Correlation label used in the title and the file name.
        anchor: Anchor name used in the title and the file name.
        pooling: Pooling label used in the title and the file name.
        out_dir: Directory path supporting ``/`` (a ``pathlib.Path`` at the
            call sites in this file).
        n_cols: Number of grid columns.
    """
    # NaN labels cannot be sorted together with strings, so drop them first.
    metrics = sorted(df_sub["metric"].dropna().unique())
    if not metrics:
        # A zero-row grid is rejected by plt.subplots; nothing to draw anyway.
        return
    n_rows, n_cols = _auto_subplots(len(metrics), n_cols)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(6 * n_cols, 3.5 * n_rows), sharey=True)
    # Flatten so a single row, a single axes and a 2-D grid are handled alike.
    axes = np.array(axes).reshape(-1)
    for ax, metric in zip(axes, metrics):
        dmet = df_sub[df_sub["metric"] == metric]
        sns.lineplot(
            data=dmet, x="layer_index", y="rho_smooth",
            hue="model", style="mode", lw=2.2, ax=ax
        )
        _mark_nans(ax, dmet)
        # Reference lines at zero and at +/-0.3.
        ax.axhline(0, color="black", linestyle=":")
        ax.axhline(0.3, color="gray", linestyle="--", lw=0.6, alpha=0.4)
        ax.axhline(-0.3, color="gray", linestyle="--", lw=0.6, alpha=0.4)
        ax.set_ylim(-1, 1)
        ax.set_title(metric, fontsize=10)
        ax.set_xlabel("Layer index")
        # "\u03c1" is the Greek letter rho (the previous literal was mojibake).
        ax.set_ylabel("\u03c1")
        ax.legend(fontsize=8, frameon=True)
    # "\u2014" is an em dash (the previous literal was mojibake).
    _finalize_grid(fig, axes, f"{corr_type.capitalize()} \u2014 {anchor} ({pooling})",
                   out_dir / f"{corr_type}_{anchor}_{pooling}_metrics.png")

def plot_by_group(df_sub, corr_type, anchor, pooling, out_dir):
    """Save one three-panel figure (one panel per mode) for each group in GROUPS.

    Groups with no matching rows in *df_sub* are skipped. Each panel plots
    ``rho_smooth`` against ``layer_index`` with one hue per model and one line
    style per metric; NaN correlations are flagged through ``_mark_nans``.

    Args:
        df_sub: Rows to plot; needs the columns ``metric``, ``mode``,
            ``layer_index``, ``rho_smooth``, ``rho`` and ``model``.
        corr_type: Correlation label used in the title and the file name.
        anchor: Anchor name used in the title and the file name.
        pooling: Pooling label used in the title and the file name.
        out_dir: Directory path supporting ``/`` (a ``pathlib.Path`` at the
            call sites in this file).
    """
    for gname, gmetrics in GROUPS.items():
        dgroup = df_sub[df_sub["metric"].isin(gmetrics)]
        if dgroup.empty:
            continue
        fig, axes = plt.subplots(1, 3, figsize=(18, 6), sharey=True)
        for ax, mode in zip(axes, ["raw", "unit_rms", "norm_rms"]):
            dmode = dgroup[dgroup["mode"] == mode]
            if dmode.empty:
                # Axes coordinates keep the placeholder centred; in data
                # coordinates it drifts when the shared y-limits change.
                ax.text(0.5, 0.5, "no data", ha="center", va="center",
                        fontsize=11, transform=ax.transAxes)
                continue
            sns.lineplot(
                data=dmode, x="layer_index", y="rho_smooth",
                hue="model", style="metric", lw=2, ax=ax
            )
            _mark_nans(ax, dmode)
            # Reference lines at zero and at +/-0.3.
            ax.axhline(0, color="black", linestyle=":")
            ax.axhline(0.3, color="gray", linestyle="--", lw=0.6, alpha=0.4)
            ax.axhline(-0.3, color="gray", linestyle="--", lw=0.6, alpha=0.4)
            ax.set_ylim(-1, 1)
            ax.set_title(f"{gname.title()} ({mode})", fontsize=11)
            ax.set_xlabel("Layer index")
            # "\u03c1" is the Greek letter rho (the previous literal was mojibake).
            ax.set_ylabel("\u03c1" if mode == "raw" else "")
            ax.legend(fontsize=8, frameon=True)
        # "\u2014" is an em dash (the previous literal was mojibake).
        _finalize_grid(fig, axes,
                       f"{gname.capitalize()} \u2014 {corr_type.capitalize()} ({anchor}/{pooling})",
                       out_dir / f"{corr_type}_{anchor}_{pooling}_{gname}.png")

# === MAIN LOOP ===
# Plotting routine per LAYOUT_MODE value; any other value plots nothing.
_LAYOUT_PLOTTERS = {
    "mode": plot_by_mode,
    "metric": plot_by_metric,
    "group": plot_by_group,
}

for pooling in ["pooled", "per_prompt"]:
    out_dir = OUT_ROOT / pooling
    out_dir.mkdir(parents=True, exist_ok=True)

    # Continuous anchors: Spearman rows restricted to SPEARMAN_METRICS.
    for anchor in ANCHORS_CONTINUOUS:
        df_s = df.loc[
            df["metric"].isin(SPEARMAN_METRICS)
            & (df["pooling"] == pooling)
            & (df["corr_type"] == "spearman")
            & (df["anchor"] == anchor)
        ]
        if not df_s.empty and LAYOUT_MODE in _LAYOUT_PLOTTERS:
            _LAYOUT_PLOTTERS[LAYOUT_MODE](df_s, "spearman", anchor, pooling, out_dir)

    # Binary anchors: point-biserial and phi rows restricted to BISERIAL_METRICS.
    for corr_type in ["pointbiserial", "phi"]:
        for anchor in ANCHORS_BINARY:
            df_b = df.loc[
                df["metric"].isin(BISERIAL_METRICS)
                & (df["pooling"] == pooling)
                & (df["corr_type"] == corr_type)
                & (df["anchor"] == anchor)
            ]
            if not df_b.empty and LAYOUT_MODE in _LAYOUT_PLOTTERS:
                _LAYOUT_PLOTTERS[LAYOUT_MODE](df_b, corr_type, anchor, pooling, out_dir)

# === QUICK TOPLIST ===
# Mean rho per (anchor, metric, corr_type, model), largest means first;
# only the first 20 rows are printed.
top = df.groupby(["anchor", "metric", "corr_type", "model"])["rho"].mean()
top = top.reset_index()
top = top.sort_values(by="rho", ascending=False)

print("\nTop correlations:")
print(top.head(20).to_string(index=False))


In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# === STYLE ===
# Apply the seaborn theme first, then override individual matplotlib font sizes.
sns.set_theme(style="whitegrid", context="talk", palette="deep")
plt.rcParams["axes.titlesize"] = 12
plt.rcParams["axes.labelsize"] = 11
plt.rcParams["legend.fontsize"] = 8
plt.rcParams["xtick.labelsize"] = 9
plt.rcParams["ytick.labelsize"] = 9

# === CONFIG ===
BASE_DIR = Path("saved_data/summary")
MODELS = ["m_4bit", "m_8bit", "m_quant"]
ANCHORS_CONTINUOUS = ["logp_diff"]
ANCHORS_BINARY = ["disagree_correct"]
OUT_ROOT = Path("saved_data/figures_flexible_rawcorr")
OUT_ROOT.mkdir(parents=True, exist_ok=True)

LAYOUT_MODE = "metric"   # "mode" | "metric" | "group"

# Metrics kept when filtering Spearman rows.
SPEARMAN_METRICS = [
    "tvd", "kl_ab", "kl_ba", "js_div", "js_dist",
    "cosine_sim", "l2_dist", "ppl_diff",
    "jaccard_@1", "jaccard_@5", "jaccard_@10",
]
# Metrics kept when filtering point-biserial / phi rows.
BISERIAL_METRICS = [
    "acc_A_@1", "acc_A_@5", "acc_A_@10",
    "acc_B_@1", "acc_B_@5", "acc_B_@10",
    "disagree_correct",
    "jaccard_@1", "jaccard_@5", "jaccard_@10",
]
# Metric families used by plot_by_group (one figure per family).
GROUPS = {
    "divergence": [
        "tvd", "kl_ab", "kl_ba", "js_div", "js_dist", "ppl_diff",
    ],
    "representation": [
        "cosine_sim", "l2_dist", "jaccard_@1", "jaccard_@5", "jaccard_@10",
    ],
    "accuracy": [
        "acc_A_@1", "acc_B_@1", "acc_A_@5", "acc_B_@5", "disagree_correct",
    ],
}

# === LOAD RAW CORR FILES ===
# Read every "*corr_*.csv" under BASE_DIR/<model> whose name lacks "summary",
# tagging each frame with its model and a pooling label taken from the stem.
dfs = []
for model in MODELS:
    model_dir = BASE_DIR / model
    if model_dir.exists():
        for f in model_dir.glob("*corr_*.csv"):
            if "summary" not in f.name:
                d = pd.read_csv(f)
                d["model"] = model
                d["pooling"] = (
                    "pooled" if "pooled" in f.stem
                    else "per_prompt" if ("perprompt" in f.stem or "per_prompt" in f.stem)
                    else "unknown"
                )
                dfs.append(d)

if not dfs:
    raise RuntimeError("No raw correlation files found!")

df = pd.concat(dfs, ignore_index=True)
for c in ("rho",):
    df[c] = pd.to_numeric(df[c], errors="coerce")
# Layer indices that fail to parse become NaN and are then mapped to 0.
df["layer_index"] = (
    pd.to_numeric(df["layer_index"], errors="coerce")
    .fillna(0)
    .astype(int)
)

# === RAW VALUES (no smoothing) ===
# The plotted column is a copy of the raw values.
df["rho_smooth"] = df["rho"]

print(f"[ok] merged {len(df)} rows from {df['model'].nunique()} models.")
print(df.groupby(["anchor","corr_type","pooling"])["rho"].describe().round(3))

# === HELPERS ===
def _auto_subplots(n_items, n_cols=3):
    """Return ``(n_rows, n_cols)`` with just enough rows to hold *n_items* cells."""
    rows_needed = np.ceil(n_items / n_cols)
    return int(rows_needed), n_cols

def _finalize_grid(fig, axes, title, save_path=None):
    """Hide unused panels, add a bold suptitle, optionally save, then close *fig*.

    Args:
        fig: Figure to finish.
        axes: Array of axes (iterated through ``axes.flat``).
        title: Text for the figure-level title.
        save_path: Destination for the image; nothing is written when falsy.
    """
    unused_axes = [panel for panel in axes.flat if not panel.has_data()]
    for panel in unused_axes:
        panel.axis("off")

    fig.suptitle(title, fontsize=15, weight="bold")
    # Leave the top 5% of the canvas free for the suptitle.
    fig.tight_layout(rect=[0, 0, 1, 0.95])

    if save_path:
        fig.savefig(save_path, dpi=250, bbox_inches="tight")
        print(f"[saved] {save_path}")

    plt.close(fig)

# === ADD NAN MARKERS ===
def _mark_nans(ax, dsub):
    """Scatter a red "x" on the zero line at every layer whose ``rho`` is NaN.

    Args:
        ax: Axes-like object providing ``scatter``.
        dsub: DataFrame with ``rho`` and ``layer_index`` columns.
    """
    missing = dsub.loc[dsub["rho"].isna()]
    if missing.empty:
        return
    zero_line = [0] * len(missing)
    ax.scatter(
        missing["layer_index"],
        zero_line,
        color="red", marker="x", s=60, label="NaN / no corr",
    )

# === PLOT FUNCTIONS ===
def plot_by_mode(df_sub, corr_type, anchor, pooling, out_dir):
    """Draw one panel per mode ("raw", "unit_rms", "norm_rms") and save a PNG.

    Each panel plots ``rho_smooth`` against ``layer_index`` with one hue per
    model and one line style per metric; NaN correlations are flagged through
    ``_mark_nans``.

    Args:
        df_sub: Rows to plot; needs the columns ``mode``, ``layer_index``,
            ``rho_smooth``, ``rho``, ``model`` and ``metric``.
        corr_type: Correlation label used in the title and the file name.
        anchor: Anchor name used in the title and the file name.
        pooling: Pooling label used in the title and the file name.
        out_dir: Directory path supporting ``/`` (a ``pathlib.Path`` at the
            call sites in this file).
    """
    fig, axes = plt.subplots(1, 3, figsize=(18, 6), sharey=True)
    for ax, mode in zip(axes, ["raw", "unit_rms", "norm_rms"]):
        dmode = df_sub[df_sub["mode"] == mode]
        if dmode.empty:
            # Axes coordinates keep the placeholder centred; in data
            # coordinates it drifts when the shared y-limits change.
            ax.text(0.5, 0.5, "no data", ha="center", va="center",
                    fontsize=11, transform=ax.transAxes)
            continue
        sns.lineplot(
            data=dmode, x="layer_index", y="rho_smooth",
            hue="model", style="metric", lw=2, ax=ax
        )
        _mark_nans(ax, dmode)
        # Reference lines at zero and at +/-0.3.
        ax.axhline(0, color="black", linestyle=":")
        ax.axhline(0.3, color="gray", linestyle="--", lw=0.6, alpha=0.4)
        ax.axhline(-0.3, color="gray", linestyle="--", lw=0.6, alpha=0.4)
        ax.set_ylim(-1, 1)
        ax.set_title(mode.upper(), fontsize=12, weight="bold")
        ax.set_xlabel("Layer index")
        # "\u03c1" is the Greek letter rho (the previous literal was mojibake).
        ax.set_ylabel("\u03c1" if mode == "raw" else "")
        ax.legend(fontsize=8)
    # "\u2014" is an em dash (the previous literal was mojibake).
    _finalize_grid(fig, axes, f"{corr_type.capitalize()} \u2014 {anchor} ({pooling})",
                   out_dir / f"{corr_type}_{anchor}_{pooling}_modes.png")

def plot_by_metric(df_sub, corr_type, anchor, pooling, out_dir, n_cols=3):
    """Draw one panel per metric in an ``n_cols``-wide grid and save a PNG.

    Each panel plots ``rho_smooth`` against ``layer_index`` with one hue per
    model and one line style per mode; NaN correlations are flagged through
    ``_mark_nans``. Returns without drawing when *df_sub* has no non-NaN
    metric label.

    Args:
        df_sub: Rows to plot; needs the columns ``metric``, ``layer_index``,
            ``rho_smooth``, ``rho``, ``model`` and ``mode``.
        corr_type: Correlation label used in the title and the file name.
        anchor: Anchor name used in the title and the file name.
        pooling: Pooling label used in the title and the file name.
        out_dir: Directory path supporting ``/`` (a ``pathlib.Path`` at the
            call sites in this file).
        n_cols: Number of grid columns.
    """
    # NaN labels cannot be sorted together with strings, so drop them first.
    metrics = sorted(df_sub["metric"].dropna().unique())
    if not metrics:
        # A zero-row grid is rejected by plt.subplots; nothing to draw anyway.
        return
    n_rows, n_cols = _auto_subplots(len(metrics), n_cols)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(6 * n_cols, 3.5 * n_rows), sharey=True)
    # Flatten so a single row, a single axes and a 2-D grid are handled alike.
    axes = np.array(axes).reshape(-1)
    for ax, metric in zip(axes, metrics):
        dmet = df_sub[df_sub["metric"] == metric]
        sns.lineplot(
            data=dmet, x="layer_index", y="rho_smooth",
            hue="model", style="mode", lw=2.2, ax=ax
        )
        _mark_nans(ax, dmet)
        # Reference lines at zero and at +/-0.3.
        ax.axhline(0, color="black", linestyle=":")
        ax.axhline(0.3, color="gray", linestyle="--", lw=0.6, alpha=0.4)
        ax.axhline(-0.3, color="gray", linestyle="--", lw=0.6, alpha=0.4)
        ax.set_ylim(-1, 1)
        ax.set_title(metric, fontsize=10)
        ax.set_xlabel("Layer index")
        # "\u03c1" is the Greek letter rho (the previous literal was mojibake).
        ax.set_ylabel("\u03c1")
        ax.legend(fontsize=8, frameon=True)
    # "\u2014" is an em dash (the previous literal was mojibake).
    _finalize_grid(fig, axes, f"{corr_type.capitalize()} \u2014 {anchor} ({pooling})",
                   out_dir / f"{corr_type}_{anchor}_{pooling}_metrics.png")

def plot_by_group(df_sub, corr_type, anchor, pooling, out_dir):
    """Save one three-panel figure (one panel per mode) for each group in GROUPS.

    Groups with no matching rows in *df_sub* are skipped. Each panel plots
    ``rho_smooth`` against ``layer_index`` with one hue per model and one line
    style per metric; NaN correlations are flagged through ``_mark_nans``.

    Args:
        df_sub: Rows to plot; needs the columns ``metric``, ``mode``,
            ``layer_index``, ``rho_smooth``, ``rho`` and ``model``.
        corr_type: Correlation label used in the title and the file name.
        anchor: Anchor name used in the title and the file name.
        pooling: Pooling label used in the title and the file name.
        out_dir: Directory path supporting ``/`` (a ``pathlib.Path`` at the
            call sites in this file).
    """
    for gname, gmetrics in GROUPS.items():
        dgroup = df_sub[df_sub["metric"].isin(gmetrics)]
        if dgroup.empty:
            continue
        fig, axes = plt.subplots(1, 3, figsize=(18, 6), sharey=True)
        for ax, mode in zip(axes, ["raw", "unit_rms", "norm_rms"]):
            dmode = dgroup[dgroup["mode"] == mode]
            if dmode.empty:
                # Axes coordinates keep the placeholder centred; in data
                # coordinates it drifts when the shared y-limits change.
                ax.text(0.5, 0.5, "no data", ha="center", va="center",
                        fontsize=11, transform=ax.transAxes)
                continue
            sns.lineplot(
                data=dmode, x="layer_index", y="rho_smooth",
                hue="model", style="metric", lw=2, ax=ax
            )
            _mark_nans(ax, dmode)
            # Reference lines at zero and at +/-0.3.
            ax.axhline(0, color="black", linestyle=":")
            ax.axhline(0.3, color="gray", linestyle="--", lw=0.6, alpha=0.4)
            ax.axhline(-0.3, color="gray", linestyle="--", lw=0.6, alpha=0.4)
            ax.set_ylim(-1, 1)
            ax.set_title(f"{gname.title()} ({mode})", fontsize=11)
            ax.set_xlabel("Layer index")
            # "\u03c1" is the Greek letter rho (the previous literal was mojibake).
            ax.set_ylabel("\u03c1" if mode == "raw" else "")
            ax.legend(fontsize=8, frameon=True)
        # "\u2014" is an em dash (the previous literal was mojibake).
        _finalize_grid(fig, axes,
                       f"{gname.capitalize()} \u2014 {corr_type.capitalize()} ({anchor}/{pooling})",
                       out_dir / f"{corr_type}_{anchor}_{pooling}_{gname}.png")

# === MAIN LOOP ===
# Plotting routine per LAYOUT_MODE value; any other value plots nothing.
_LAYOUT_PLOTTERS = {
    "mode": plot_by_mode,
    "metric": plot_by_metric,
    "group": plot_by_group,
}

for pooling in ["pooled", "per_prompt"]:
    out_dir = OUT_ROOT / pooling
    out_dir.mkdir(parents=True, exist_ok=True)

    # Continuous anchors: Spearman rows restricted to SPEARMAN_METRICS.
    for anchor in ANCHORS_CONTINUOUS:
        df_s = df.loc[
            df["metric"].isin(SPEARMAN_METRICS)
            & (df["pooling"] == pooling)
            & (df["corr_type"] == "spearman")
            & (df["anchor"] == anchor)
        ]
        if not df_s.empty and LAYOUT_MODE in _LAYOUT_PLOTTERS:
            _LAYOUT_PLOTTERS[LAYOUT_MODE](df_s, "spearman", anchor, pooling, out_dir)

    # Binary anchors: point-biserial and phi rows restricted to BISERIAL_METRICS.
    for corr_type in ["pointbiserial", "phi"]:
        for anchor in ANCHORS_BINARY:
            df_b = df.loc[
                df["metric"].isin(BISERIAL_METRICS)
                & (df["pooling"] == pooling)
                & (df["corr_type"] == corr_type)
                & (df["anchor"] == anchor)
            ]
            if not df_b.empty and LAYOUT_MODE in _LAYOUT_PLOTTERS:
                _LAYOUT_PLOTTERS[LAYOUT_MODE](df_b, corr_type, anchor, pooling, out_dir)

# === QUICK TOPLIST ===
# Mean rho per (anchor, metric, corr_type, model), largest means first;
# only the first 20 rows are printed.
top = df.groupby(["anchor", "metric", "corr_type", "model"])["rho"].mean()
top = top.reset_index()
top = top.sort_values(by="rho", ascending=False)

print("\nTop correlations:")
print(top.head(20).to_string(index=False))



In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path


# === STYLE ===
sns.set_theme(style="whitegrid", context="talk", palette="deep")

# === CONFIG ===
BASE_DIR = Path("saved_data/summary")
MODELS = ["m_4bit", "m_8bit", "m_quant"]
ANCHORS_CONTINUOUS = ["logp_diff"]
ANCHORS_BINARY = ["disagree_correct"]
OUT_ROOT = Path("saved_data/figures_flexible_clean")
OUT_ROOT.mkdir(parents=True, exist_ok=True)

LAYOUT_MODE = "metric"   # "mode" | "metric" | "group"

# Metrics kept when filtering Spearman rows.
SPEARMAN_METRICS = [
    "tvd", "kl_ab", "kl_ba", "js_div", "js_dist",
    "cosine_sim", "l2_dist", "ppl_diff",
    "jaccard_@1", "jaccard_@5", "jaccard_@10",
]
# Metrics kept when filtering point-biserial rows.
BISERIAL_METRICS = [
    "acc_A_@1", "acc_A_@5", "acc_A_@10",
    "acc_B_@1", "acc_B_@5", "acc_B_@10",
    "disagree_correct",
    "jaccard_@1", "jaccard_@5", "jaccard_@10",
]

# Metric families used by plot_by_group (one figure per family).
GROUPS = {
    "divergence": [
        "tvd", "kl_ab", "kl_ba", "js_div", "js_dist", "ppl_diff",
    ],
    "representation": [
        "cosine_sim", "l2_dist", "jaccard_@1", "jaccard_@5", "jaccard_@10",
    ],
    "accuracy": [
        "acc_A_@1", "acc_B_@1", "acc_A_@5", "acc_B_@5", "disagree_correct",
    ],
}

# === LOAD SUMMARY FILES ===
# Read every "*_summary*.csv" under BASE_DIR/<model>, tagging each frame with
# its model and a pooling label taken from the file stem.
dfs = []
for model in MODELS:
    model_dir = BASE_DIR / model
    if model_dir.exists():
        for f in model_dir.glob("*_summary*.csv"):
            d = pd.read_csv(f)
            d["model"] = model
            d["pooling"] = (
                "pooled" if "pooled" in f.stem
                else "per_prompt" if ("perprompt" in f.stem or "per_prompt" in f.stem)
                else "unknown"
            )
            dfs.append(d)

if not dfs:
    raise RuntimeError("No summary files found!")

df = pd.concat(dfs, ignore_index=True)
for c in ("rho_mean", "rho_low", "rho_high"):
    df[c] = pd.to_numeric(df[c], errors="coerce")
# Layer indices that fail to parse become NaN and are then mapped to 0.
df["layer_index"] = (
    pd.to_numeric(df["layer_index"], errors="coerce")
    .fillna(0)
    .astype(int)
)
print(f"[ok] merged {len(df)} rows from {df['model'].nunique()} models.")


def _auto_subplots(n_items, n_cols=3):
    """Compute the grid shape: enough rows of ``n_cols`` cells for *n_items*."""
    rows_needed = np.ceil(n_items / n_cols)
    return int(rows_needed), n_cols


def _finalize_grid(fig, axes, title, save_path=None):
    """Hide unused panels, add a bold suptitle, optionally save, then close *fig*.

    Args:
        fig: Figure to finish.
        axes: Array of axes (iterated through ``axes.flat``).
        title: Text for the figure-level title.
        save_path: Destination for the image; nothing is written when falsy.
    """
    unused_axes = [panel for panel in axes.flat if not panel.has_data()]
    for panel in unused_axes:
        panel.axis("off")

    fig.suptitle(title, fontsize=15, weight="bold")
    # Leave the top 5% of the canvas free for the suptitle.
    fig.tight_layout(rect=[0, 0, 1, 0.95])

    if save_path:
        fig.savefig(save_path, dpi=250, bbox_inches="tight")
        print(f"[saved] {save_path}")

    plt.close(fig)


def plot_by_mode(df_sub, corr_type, anchor, pooling, out_dir):
    """Draw one panel per mode ("raw", "unit_rms", "norm_rms") and save a PNG.

    Each panel plots ``rho_mean`` against ``layer_index`` with one hue per
    model and one line style per metric.

    Args:
        df_sub: Rows to plot; needs the columns ``mode``, ``layer_index``,
            ``rho_mean``, ``model`` and ``metric``.
        corr_type: Correlation label used in the title and the file name.
        anchor: Anchor name used in the title and the file name.
        pooling: Pooling label used in the title and the file name.
        out_dir: Directory path supporting ``/`` (a ``pathlib.Path`` at the
            call sites in this file).
    """
    fig, axes = plt.subplots(1, 3, figsize=(18, 6), sharey=True)
    modes = ["raw", "unit_rms", "norm_rms"]

    for ax, mode in zip(axes, modes):
        dmode = df_sub[df_sub["mode"] == mode]
        if dmode.empty:
            # Axes coordinates keep the placeholder centred; in data
            # coordinates it can fall outside the shared y-range entirely.
            ax.text(0.5, 0.5, "no data", ha="center", va="center",
                    fontsize=11, transform=ax.transAxes)
            continue
        sns.lineplot(
            data=dmode, x="layer_index", y="rho_mean",
            hue="model", style="metric", markers=False,
            lw=2, ax=ax
        )
        ax.set_title(mode.upper(), fontsize=12, weight="bold")
        ax.set_xlabel("Layer index")
        # "\u03c1" is the Greek letter rho (the previous literal was mojibake).
        ax.set_ylabel("Mean \u03c1" if mode == "raw" else "")
        ax.axhline(0, color="black", linestyle=":")
        ax.legend(fontsize=8)

    # "\u2014" is an em dash (the previous literal was mojibake).
    _finalize_grid(fig, axes, f"{corr_type.capitalize()} \u2014 {anchor} ({pooling})",
                   out_dir / f"{corr_type}_{anchor}_{pooling}_modes.png")


def plot_by_metric(df_sub, corr_type, anchor, pooling, out_dir, n_cols=3):
    """Draw one panel per metric in an ``n_cols``-wide grid and save a PNG.

    Each panel plots ``rho_mean`` against ``layer_index`` with one hue per
    model and one line style per mode. Returns without drawing when *df_sub*
    has no non-NaN metric label.

    Args:
        df_sub: Rows to plot; needs the columns ``metric``, ``layer_index``,
            ``rho_mean``, ``model`` and ``mode``.
        corr_type: Correlation label used in the title and the file name.
        anchor: Anchor name used in the title and the file name.
        pooling: Pooling label used in the title and the file name.
        out_dir: Directory path supporting ``/`` (a ``pathlib.Path`` at the
            call sites in this file).
        n_cols: Number of grid columns.
    """
    # NaN labels cannot be sorted together with strings, so drop them first.
    metrics = sorted(df_sub["metric"].dropna().unique())
    if not metrics:
        # A zero-row grid is rejected by plt.subplots; nothing to draw anyway.
        return
    n_rows, n_cols = _auto_subplots(len(metrics), n_cols)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(6 * n_cols, 3.5 * n_rows), sharey=True)
    # Flatten so a single row, a single axes and a 2-D grid are handled alike.
    axes = np.array(axes).reshape(-1)

    for ax, metric in zip(axes, metrics):
        dmet = df_sub[df_sub["metric"] == metric]
        sns.lineplot(
            data=dmet, x="layer_index", y="rho_mean",
            hue="model", style="mode", lw=2.2,
            ax=ax
        )
        ax.axhline(0, color="black", linestyle=":")
        ax.set_title(metric, fontsize=10)
        ax.set_xlabel("Layer index")
        # "\u03c1" is the Greek letter rho (the previous literal was mojibake).
        ax.set_ylabel("\u03c1")
        ax.legend(fontsize=8, frameon=True)

    # "\u2014" is an em dash (the previous literal was mojibake).
    _finalize_grid(fig, axes, f"{corr_type.capitalize()} \u2014 {anchor} ({pooling})",
                   out_dir / f"{corr_type}_{anchor}_{pooling}_metrics.png")


def plot_by_group(df_sub, corr_type, anchor, pooling, out_dir):
    """Save one three-panel figure (one panel per mode) for each group in GROUPS.

    Groups with no matching rows in *df_sub* are skipped. Each panel plots
    ``rho_mean`` against ``layer_index`` with one hue per model and one line
    style per metric.

    Args:
        df_sub: Rows to plot; needs the columns ``metric``, ``mode``,
            ``layer_index``, ``rho_mean`` and ``model``.
        corr_type: Correlation label used in the title and the file name.
        anchor: Anchor name used in the title and the file name.
        pooling: Pooling label used in the title and the file name.
        out_dir: Directory path supporting ``/`` (a ``pathlib.Path`` at the
            call sites in this file).
    """
    for gname, gmetrics in GROUPS.items():
        dgroup = df_sub[df_sub["metric"].isin(gmetrics)]
        if dgroup.empty:
            continue
        fig, axes = plt.subplots(1, 3, figsize=(18, 6), sharey=True)
        for ax, mode in zip(axes, ["raw", "unit_rms", "norm_rms"]):
            dmode = dgroup[dgroup["mode"] == mode]
            if dmode.empty:
                # Axes coordinates keep the placeholder centred; in data
                # coordinates it can fall outside the shared y-range entirely.
                ax.text(0.5, 0.5, "no data", ha="center", va="center",
                        fontsize=11, transform=ax.transAxes)
                continue
            sns.lineplot(
                data=dmode, x="layer_index", y="rho_mean",
                hue="model", style="metric", lw=2,
                ax=ax
            )
            ax.axhline(0, color="black", linestyle=":")
            ax.set_title(f"{gname.title()} ({mode})", fontsize=11)
            ax.set_xlabel("Layer index")
            # "\u03c1" is the Greek letter rho (the previous literal was mojibake).
            ax.set_ylabel("\u03c1" if mode == "raw" else "")
            ax.legend(fontsize=8, frameon=True)

        # "\u2014" is an em dash (the previous literal was mojibake).
        _finalize_grid(fig, axes,
                       f"{gname.capitalize()} \u2014 {corr_type.capitalize()} ({anchor}/{pooling})",
                       out_dir / f"{corr_type}_{anchor}_{pooling}_{gname}.png")


# Plotting routine per LAYOUT_MODE value; any other value plots nothing.
_LAYOUT_PLOTTERS = {
    "mode": plot_by_mode,
    "metric": plot_by_metric,
    "group": plot_by_group,
}

for pooling in ["pooled", "per_prompt"]:
    out_dir = OUT_ROOT / pooling
    out_dir.mkdir(parents=True, exist_ok=True)

    # Continuous anchors: Spearman rows restricted to SPEARMAN_METRICS.
    for anchor in ANCHORS_CONTINUOUS:
        df_s = df.loc[
            df["metric"].isin(SPEARMAN_METRICS)
            & (df["pooling"] == pooling)
            & (df["corr_type"] == "spearman")
            & (df["anchor"] == anchor)
        ]
        if not df_s.empty and LAYOUT_MODE in _LAYOUT_PLOTTERS:
            _LAYOUT_PLOTTERS[LAYOUT_MODE](df_s, "spearman", anchor, pooling, out_dir)

    # Binary anchors: point-biserial rows restricted to BISERIAL_METRICS.
    for anchor in ANCHORS_BINARY:
        df_b = df.loc[
            df["metric"].isin(BISERIAL_METRICS)
            & (df["pooling"] == pooling)
            & (df["corr_type"] == "pointbiserial")
            & (df["anchor"] == anchor)
        ]
        if not df_b.empty and LAYOUT_MODE in _LAYOUT_PLOTTERS:
            _LAYOUT_PLOTTERS[LAYOUT_MODE](df_b, "pointbiserial", anchor, pooling, out_dir)


In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# === STYLE ===
sns.set_theme(style="whitegrid", context="talk", palette="deep")

# === CONFIG ===
BASE_DIR = Path("saved_data/summary")
MODELS = ["m_4bit", "m_8bit", "m_quant"]
ANCHORS_CONTINUOUS = ["logp_diff"]
ANCHORS_BINARY = ["disagree_correct"]
OUT_ROOT = Path("saved_data/figures_flexible_rawcorr")
OUT_ROOT.mkdir(parents=True, exist_ok=True)

LAYOUT_MODE = "metric"   # "mode" | "metric" | "group"

# Metrics kept when filtering Spearman rows.
SPEARMAN_METRICS = [
    "tvd", "kl_ab", "kl_ba", "js_div", "js_dist",
    "cosine_sim", "l2_dist", "ppl_diff",
    "jaccard_@1", "jaccard_@5", "jaccard_@10",
]
# Metrics kept when filtering point-biserial rows.
BISERIAL_METRICS = [
    "acc_A_@1", "acc_A_@5", "acc_A_@10",
    "acc_B_@1", "acc_B_@5", "acc_B_@10",
    "disagree_correct",
    "jaccard_@1", "jaccard_@5", "jaccard_@10",
]
# Metric families used by plot_by_group (one figure per family).
GROUPS = {
    "divergence": [
        "tvd", "kl_ab", "kl_ba", "js_div", "js_dist", "ppl_diff",
    ],
    "representation": [
        "cosine_sim", "l2_dist", "jaccard_@1", "jaccard_@5", "jaccard_@10",
    ],
    "accuracy": [
        "acc_A_@1", "acc_B_@1", "acc_A_@5", "acc_B_@5", "disagree_correct",
    ],
}


# === LOAD RAW CORR FILES (not summaries) ===
# Read every "*corr_*.csv" under BASE_DIR/<model> whose name lacks "summary",
# tagging each frame with its model and a pooling label taken from the stem.
dfs = []
for model in MODELS:
    model_dir = BASE_DIR / model
    if model_dir.exists():
        for f in model_dir.glob("*corr_*.csv"):
            if "summary" not in f.name:
                d = pd.read_csv(f)
                d["model"] = model
                d["pooling"] = (
                    "pooled" if "pooled" in f.stem
                    else "per_prompt" if ("perprompt" in f.stem or "per_prompt" in f.stem)
                    else "unknown"
                )
                dfs.append(d)

if not dfs:
    raise RuntimeError("No raw correlation files found!")

df = pd.concat(dfs, ignore_index=True)
for c in ("rho",):
    df[c] = pd.to_numeric(df[c], errors="coerce")
# Layer indices that fail to parse become NaN and are then mapped to 0.
df["layer_index"] = (
    pd.to_numeric(df["layer_index"], errors="coerce")
    .fillna(0)
    .astype(int)
)
print(f"[ok] merged {len(df)} raw correlation rows from {df['model'].nunique()} models.")


# === PLOTTING HELPERS ===
def _auto_subplots(n_items, n_cols=3):
    """Return ``(n_rows, n_cols)`` with just enough rows to hold *n_items* cells."""
    rows_needed = np.ceil(n_items / n_cols)
    return int(rows_needed), n_cols

def _finalize_grid(fig, axes, title, save_path=None):
    """Hide unused panels, add a bold suptitle, optionally save, then close *fig*.

    Args:
        fig: Figure to finish.
        axes: Array of axes (iterated through ``axes.flat``).
        title: Text for the figure-level title.
        save_path: Destination for the image; nothing is written when falsy.
    """
    unused_axes = [panel for panel in axes.flat if not panel.has_data()]
    for panel in unused_axes:
        panel.axis("off")

    fig.suptitle(title, fontsize=15, weight="bold")
    # Leave the top 5% of the canvas free for the suptitle.
    fig.tight_layout(rect=[0, 0, 1, 0.95])

    if save_path:
        fig.savefig(save_path, dpi=250, bbox_inches="tight")
        print(f"[saved] {save_path}")

    plt.close(fig)


# === PLOT FUNCTIONS (identical, but with y='rho') ===
def plot_by_mode(df_sub, corr_type, anchor, pooling, out_dir):
    """Draw one panel per mode ("raw", "unit_rms", "norm_rms") and save a PNG.

    Each panel plots ``rho`` against ``layer_index`` with one hue per model
    and one line style per metric.

    Args:
        df_sub: Rows to plot; needs the columns ``mode``, ``layer_index``,
            ``rho``, ``model`` and ``metric``.
        corr_type: Correlation label used in the title and the file name.
        anchor: Anchor name used in the title and the file name.
        pooling: Pooling label used in the title and the file name.
        out_dir: Directory path supporting ``/`` (a ``pathlib.Path`` at the
            call sites in this file).
    """
    fig, axes = plt.subplots(1, 3, figsize=(18, 6), sharey=True)
    modes = ["raw", "unit_rms", "norm_rms"]
    for ax, mode in zip(axes, modes):
        dmode = df_sub[df_sub["mode"] == mode]
        if dmode.empty:
            # Axes coordinates keep the placeholder centred; in data
            # coordinates it can fall outside the shared y-range entirely.
            ax.text(0.5, 0.5, "no data", ha="center", va="center",
                    fontsize=11, transform=ax.transAxes)
            continue
        sns.lineplot(
            data=dmode, x="layer_index", y="rho",
            hue="model", style="metric", lw=2, ax=ax
        )
        ax.set_title(mode.upper(), fontsize=12, weight="bold")
        ax.set_xlabel("Layer index")
        # "\u03c1" is the Greek letter rho (the previous literal was mojibake).
        ax.set_ylabel("\u03c1" if mode == "raw" else "")
        ax.axhline(0, color="black", linestyle=":")
        ax.legend(fontsize=8)
    # "\u2014" is an em dash (the previous literal was mojibake).
    _finalize_grid(fig, axes, f"{corr_type.capitalize()} \u2014 {anchor} ({pooling})",
                   out_dir / f"{corr_type}_{anchor}_{pooling}_modes_raw.png")


def plot_by_metric(df_sub, corr_type, anchor, pooling, out_dir, n_cols=3):
    """Draw one panel per metric in an ``n_cols``-wide grid and save a PNG.

    Each panel plots ``rho`` against ``layer_index`` with one hue per model
    and one line style per mode. Returns without drawing when *df_sub* has no
    non-NaN metric label.

    Args:
        df_sub: Rows to plot; needs the columns ``metric``, ``layer_index``,
            ``rho``, ``model`` and ``mode``.
        corr_type: Correlation label used in the title and the file name.
        anchor: Anchor name used in the title and the file name.
        pooling: Pooling label used in the title and the file name.
        out_dir: Directory path supporting ``/`` (a ``pathlib.Path`` at the
            call sites in this file).
        n_cols: Number of grid columns.
    """
    # NaN labels cannot be sorted together with strings, so drop them first.
    metrics = sorted(df_sub["metric"].dropna().unique())
    if not metrics:
        # A zero-row grid is rejected by plt.subplots; nothing to draw anyway.
        return
    n_rows, n_cols = _auto_subplots(len(metrics), n_cols)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(6 * n_cols, 3.5 * n_rows), sharey=True)
    # Flatten so a single row, a single axes and a 2-D grid are handled alike.
    axes = np.array(axes).reshape(-1)
    for ax, metric in zip(axes, metrics):
        dmet = df_sub[df_sub["metric"] == metric]
        sns.lineplot(
            data=dmet, x="layer_index", y="rho",
            hue="model", style="mode", lw=2.2, ax=ax
        )
        ax.axhline(0, color="black", linestyle=":")
        ax.set_title(metric, fontsize=10)
        ax.set_xlabel("Layer index")
        # "\u03c1" is the Greek letter rho (the previous literal was mojibake).
        ax.set_ylabel("\u03c1")
        ax.legend(fontsize=8, frameon=True)
    # "\u2014" is an em dash (the previous literal was mojibake).
    _finalize_grid(fig, axes, f"{corr_type.capitalize()} \u2014 {anchor} ({pooling})",
                   out_dir / f"{corr_type}_{anchor}_{pooling}_metrics_raw.png")


def plot_by_group(df_sub, corr_type, anchor, pooling, out_dir):
    """Save one three-panel figure (one panel per mode) for each group in GROUPS.

    Groups with no matching rows in *df_sub* are skipped. Each panel plots
    ``rho`` against ``layer_index`` with one hue per model and one line style
    per metric.

    Args:
        df_sub: Rows to plot; needs the columns ``metric``, ``mode``,
            ``layer_index``, ``rho`` and ``model``.
        corr_type: Correlation label used in the title and the file name.
        anchor: Anchor name used in the title and the file name.
        pooling: Pooling label used in the title and the file name.
        out_dir: Directory path supporting ``/`` (a ``pathlib.Path`` at the
            call sites in this file).
    """
    for gname, gmetrics in GROUPS.items():
        dgroup = df_sub[df_sub["metric"].isin(gmetrics)]
        if dgroup.empty:
            continue
        fig, axes = plt.subplots(1, 3, figsize=(18, 6), sharey=True)
        for ax, mode in zip(axes, ["raw", "unit_rms", "norm_rms"]):
            dmode = dgroup[dgroup["mode"] == mode]
            if dmode.empty:
                # Axes coordinates keep the placeholder centred; in data
                # coordinates it can fall outside the shared y-range entirely.
                ax.text(0.5, 0.5, "no data", ha="center", va="center",
                        fontsize=11, transform=ax.transAxes)
                continue
            sns.lineplot(
                data=dmode, x="layer_index", y="rho",
                hue="model", style="metric", lw=2, ax=ax
            )
            ax.axhline(0, color="black", linestyle=":")
            ax.set_title(f"{gname.title()} ({mode})", fontsize=11)
            ax.set_xlabel("Layer index")
            # "\u03c1" is the Greek letter rho (the previous literal was mojibake).
            ax.set_ylabel("\u03c1" if mode == "raw" else "")
            ax.legend(fontsize=8, frameon=True)
        # "\u2014" is an em dash (the previous literal was mojibake).
        _finalize_grid(fig, axes,
                       f"{gname.capitalize()} \u2014 {corr_type.capitalize()} ({anchor}/{pooling})",
                       out_dir / f"{corr_type}_{anchor}_{pooling}_{gname}_raw.png")


# === MAIN LOOP ===
# Plotting routine per LAYOUT_MODE value; any other value plots nothing.
_LAYOUT_PLOTTERS = {
    "mode": plot_by_mode,
    "metric": plot_by_metric,
    "group": plot_by_group,
}

for pooling in ["pooled", "per_prompt"]:
    out_dir = OUT_ROOT / pooling
    out_dir.mkdir(parents=True, exist_ok=True)

    # Continuous anchors: Spearman rows restricted to SPEARMAN_METRICS.
    for anchor in ANCHORS_CONTINUOUS:
        df_s = df.loc[
            df["metric"].isin(SPEARMAN_METRICS)
            & (df["pooling"] == pooling)
            & (df["corr_type"] == "spearman")
            & (df["anchor"] == anchor)
        ]
        if not df_s.empty and LAYOUT_MODE in _LAYOUT_PLOTTERS:
            _LAYOUT_PLOTTERS[LAYOUT_MODE](df_s, "spearman", anchor, pooling, out_dir)

    # Binary anchors: point-biserial rows restricted to BISERIAL_METRICS.
    for anchor in ANCHORS_BINARY:
        df_b = df.loc[
            df["metric"].isin(BISERIAL_METRICS)
            & (df["pooling"] == pooling)
            & (df["corr_type"] == "pointbiserial")
            & (df["anchor"] == anchor)
        ]
        if not df_b.empty and LAYOUT_MODE in _LAYOUT_PLOTTERS:
            _LAYOUT_PLOTTERS[LAYOUT_MODE](df_b, "pointbiserial", anchor, pooling, out_dir)


In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from pathlib import Path

# === STYLE ===
sns.set_theme(style="whitegrid", context="talk", palette="deep")

# === CONFIG ===
BASE_DIR = Path("saved_data/summary")
MODELS = ["m_4bit", "m_8bit", "m_quant"]
ANCHORS_CONTINUOUS = ["logp_diff"]
ANCHORS_BINARY = ["disagree_correct"]
OUT_ROOT = Path("saved_data/figures_heatmaps_rawcorr")
OUT_ROOT.mkdir(parents=True, exist_ok=True)

# Metrics kept when filtering Spearman rows.
SPEARMAN_METRICS = [
    "tvd", "kl_ab", "kl_ba", "js_div", "js_dist",
    "cosine_sim", "l2_dist", "ppl_diff",
    "jaccard_@1", "jaccard_@5", "jaccard_@10",
]
# Metrics kept when filtering point-biserial rows.
BISERIAL_METRICS = [
    "acc_A_@1", "acc_A_@5", "acc_A_@10",
    "acc_B_@1", "acc_B_@5", "acc_B_@10",
    "disagree_correct",
    "jaccard_@1", "jaccard_@5", "jaccard_@10",
]

# === LOAD RAW CORR FILES ===
# Read every "*corr_*.csv" under BASE_DIR/<model> whose name lacks "summary",
# tagging each frame with its model and a pooling label taken from the stem.
dfs = []
for model in MODELS:
    model_dir = BASE_DIR / model
    if model_dir.exists():
        for f in model_dir.glob("*corr_*.csv"):
            if "summary" not in f.name:
                d = pd.read_csv(f)
                d["model"] = model
                d["pooling"] = (
                    "pooled" if "pooled" in f.stem
                    else "per_prompt" if ("perprompt" in f.stem or "per_prompt" in f.stem)
                    else "unknown"
                )
                dfs.append(d)

if not dfs:
    raise RuntimeError("No raw correlation files found!")

df = pd.concat(dfs, ignore_index=True)
df["rho"] = pd.to_numeric(df["rho"], errors="coerce")
# Layer indices that fail to parse become NaN and are then mapped to 0.
df["layer_index"] = (
    pd.to_numeric(df["layer_index"], errors="coerce")
    .fillna(0)
    .astype(int)
)
print(f"[ok] merged {len(df)} rows from {df['model'].nunique()} models.")


# === HEATMAP FUNCTION ===
def plot_heatmap(df_sub, corr_type, anchor, pooling, out_dir):
    """
    Save one rho heatmap (rows = metric, columns = layer index) per
    mode/model combination found in ``df_sub``.

    Parameters
    ----------
    df_sub : pandas.DataFrame
        Correlation rows already filtered by the caller. The columns
        "mode", "model", "metric", "layer_index" and "rho" are read here.
    corr_type, anchor, pooling : str
        Used only for the figure title and the output file name.
    out_dir : pathlib.Path
        Directory that receives the PNG files (joined with ``/``).

    Returns
    -------
    None
        Figures are written to disk; slices with no usable rho are skipped.
    """
    modes = ("raw", "unit_rms", "norm_rms")
    for mode in modes:
        dmode = df_sub[df_sub["mode"] == mode]
        if dmode.empty:
            continue

        for model in dmode["model"].unique():
            # Rows without a numeric rho contribute nothing to the mean, so
            # drop them first; a slice left empty would otherwise produce an
            # empty pivot that cannot be drawn.
            dmodel = dmode[dmode["model"] == model].dropna(subset=["rho"])
            if dmodel.empty:
                continue

            # Pivot so that rows = metric, columns = layer_index
            heat = dmodel.pivot_table(
                index="metric", columns="layer_index", values="rho", aggfunc="mean"
            ).sort_index()
            if heat.empty:
                continue

            fig = plt.figure(figsize=(12, max(6, 0.4 * len(heat))))
            try:
                sns.heatmap(
                    heat, cmap="coolwarm", center=0, annot=True, fmt=".2f",
                    # \u03c1 is the Greek letter rho; written as an escape so
                    # the label survives a wrongly-decoded source file.
                    cbar_kws={"label": "\u03c1 (correlation)"}, linewidths=0.4
                )
                # \u2014 is an em dash separating the model from the details.
                plt.title(
                    f"{model} \u2014 {anchor} ({pooling}, {corr_type}, {mode})",
                    fontsize=13, weight="bold",
                )
                plt.xlabel("Layer index")
                plt.ylabel("Metric")
                plt.tight_layout()

                out_path = out_dir / f"{model}_{corr_type}_{anchor}_{pooling}_{mode}_heatmap.png"
                plt.savefig(out_path, dpi=300, bbox_inches="tight")
            finally:
                # Always release the figure, even if drawing or saving fails.
                plt.close(fig)
            print(f"[saved] {out_path}")


# === MAIN LOOP ===
# One pass per correlation family: (corr_type label, anchors, allowed metrics).
# Continuous anchors use Spearman rows; binary anchors use point-biserial rows.
_corr_passes = (
    ("spearman", ANCHORS_CONTINUOUS, SPEARMAN_METRICS),
    ("pointbiserial", ANCHORS_BINARY, BISERIAL_METRICS),
)

for pooling in ["pooled", "per_prompt"]:
    out_dir = OUT_ROOT / pooling
    out_dir.mkdir(parents=True, exist_ok=True)

    for corr_type, anchors, metrics in _corr_passes:
        for anchor in anchors:
            # Keep only rows matching this anchor, correlation type, pooling
            # and the metric whitelist of the current pass.
            keep = (
                (df["anchor"] == anchor)
                & (df["corr_type"] == corr_type)
                & (df["pooling"] == pooling)
                & (df["metric"].isin(metrics))
            )
            subset = df[keep]
            if subset.empty:
                continue
            plot_heatmap(subset, corr_type, anchor, pooling, out_dir)
