# ==============================================
# Notebook for Logit Lens and Hidden Acts ======
# ==============================================

In [1]:
from pathlib import Path
from datasets import load_dataset, DownloadMode
import torch
import os
import glob
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from old_lens.mi_utils.quant_configs.bnb_configs import load_bnb_in_4bit

# ==============================================
# Models & Dataset =============================
# ==============================================

In [None]:
# Local Hugging Face cache directory for the Natural Questions download.
# Set the LOGITLENS_NQ_DIR environment variable to relocate it; when unset,
# the original machine-specific location is used unchanged.
filepath = os.environ.get("LOGITLENS_NQ_DIR", r'D:\LogitLensData\nq')

destination_path = str(Path(filepath))

# Only the first 1000 training examples are requested; an existing cached
# copy under `destination_path` is reused instead of being downloaded again.
nq_dataset = load_dataset(
    'sentence-transformers/natural-questions',
    split={
        'train': 'train[:1000]'
    },
    cache_dir=destination_path,
    download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS,
    keep_in_memory=True
)

nq_queries = nq_dataset['train']['query']
nq_answers = nq_dataset['train']['answer']

nq_queries_200 = nq_queries[:200]
# Further slices, currently disabled:
# nq_queries_400 = nq_queries[200:400]
# nq_queries_600 = nq_queries[400:600]
# nq_queries_800 = nq_queries[600:800]
# nq_queries_1000 = nq_queries[800:1000]
# nq_queries_200_1 = nq_queries[:1]
nq_1000 = nq_queries[:1000]

In [5]:
from enum import Enum

class Models(Enum):
    """Model locations handed to `load_model_and_tok` / `load_bnb_in_4bit` via `.value`."""
    # The values look like relative local directories; the loader defaults to
    # local_files_only=True, so they are presumably resolved against the
    # notebook's working directory — TODO confirm.
    LAIN8B = "Models/LLaMA3Instruct"
    HF100B = "Models/HF1BitLLM100Btokens"


class Names(Enum):
    """Name strings keyed by the same members as `Models`."""
    # Not referenced anywhere in the visible part of this notebook; presumably
    # used for labelling plots or output files — verify against later cells.
    LAIN8B = "Meta-Llama-3-8B-Instruct-fp"
    HF100B = "Llama3-8B-1.58-100B-tokens"

In [3]:
def load_model_and_tok(
    model_name: str,
    low_cpu_mem_usage: bool = True,
    local_files_only: bool = True,
    device_map: str = "cpu",
    dtype: torch.dtype = torch.float32,
    load_in_8bit: bool = False
):
    """
    Load a causal LM and its tokenizer, configured to expose hidden states and attentions.

    Args:
        model_name: Hub id or local directory passed to `from_pretrained`.
        low_cpu_mem_usage: forwarded to `AutoModelForCausalLM.from_pretrained`.
        local_files_only: when True, neither the tokenizer nor the model may be
            fetched from the network.
        device_map: forwarded to `from_pretrained`.
        dtype: forwarded as `torch_dtype`.
        load_in_8bit: when True, request BitsAndBytes 8-bit quantization.

    Returns:
        Tuple `(model, tokenizer)`.
    """
    # Apply the same offline policy to the tokenizer as to the model.
    tok = AutoTokenizer.from_pretrained(model_name, local_files_only=local_files_only)

    # Passing `load_in_8bit=` directly to from_pretrained is deprecated; the
    # supported route is a BitsAndBytesConfig given as `quantization_config`.
    quant_kwargs = {}
    if load_in_8bit:
        quant_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        output_hidden_states=True,
        return_dict_in_generate=True,
        return_dict=True,
        output_attentions=True,
        low_cpu_mem_usage=low_cpu_mem_usage,
        local_files_only=local_files_only,
        device_map=device_map,
        torch_dtype=dtype,
        # Eager attention is requested explicitly together with output_attentions=True.
        attn_implementation="eager",
        **quant_kwargs,
    )
    return model, tok

In [6]:
# Baseline model loaded with the defaults of load_model_and_tok (device_map="cpu", float32, no 8-bit).
model_orig, orig_tokenizer = load_model_and_tok(Models.LAIN8B.value) 

Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

In [5]:
# 8-bit variant of the same checkpoint (float16 dtype); note this rebinds `orig_tokenizer`.
model_8bit, orig_tokenizer = load_model_and_tok(Models.LAIN8B.value, dtype=torch.float16, load_in_8bit=True) 

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

In [6]:
# 4-bit variant loaded through the project helper `load_bnb_in_4bit` (imported from
# old_lens.mi_utils.quant_configs.bnb_configs; its behavior is not visible here — confirm there).
# This also rebinds `orig_tokenizer`.
model_4bit, orig_tokenizer = load_bnb_in_4bit(Models.LAIN8B.value, double_quant=False, dtype=torch.float16, device_map="cpu") 

Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

In [4]:
# Second checkpoint (Models.HF100B) loaded with the same defaults (device_map="cpu", float32), with its own tokenizer.
model_quant, quant_tokenizer = load_model_and_tok(Models.HF100B.value) 

In [6]:
# Number of tokens (special tokens included) that each query in nq_1000 encodes to.
lengths = [len(orig_tokenizer.encode(query, add_special_tokens=True)) for query in nq_1000]

# Summary report, one statistic per line.
print("\n".join([
    "=== Token Length Statistics (LLaMA-3-8B-Instruct tokenizer) ===",
    f"Samples analyzed: {len(lengths)}",
    f"Mean length:       {np.mean(lengths):.2f}",
    f"Median length:     {np.median(lengths):.0f}",
    f"90th percentile:   {np.percentile(lengths, 90):.0f}",
    f"95th percentile:   {np.percentile(lengths, 95):.0f}",
    f"Max observed len:  {np.max(lengths)}",
    f"Min observed len:  {np.min(lengths)}",
]))

=== Token Length Statistics (LLaMA-3-8B-Instruct tokenizer) ===
Samples analyzed: 1000
Mean length:       11.21
Median length:     11
90th percentile:   14
95th percentile:   15
Max observed len:  21
Min observed len:  9


In [25]:
import inspect
# Debug aid: print the forward() signature of the first decoder layer to see
# which keyword arguments can be passed when calling blocks manually.
sig = inspect.signature(model_orig.model.layers[0].forward)
print(sig)


(hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[transformers.cache_utils.Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]


# ==============================================
# Logit Lens with Normalization ================
# ==============================================

In [None]:
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from pathlib import Path
from tqdm import tqdm
import gc
import warnings
# Suppress warnings whose message starts with "MatMul8bitLt: inputs will be cast"
# (presumably emitted by the 8-bit matmul path when it receives float32 inputs — confirm).
warnings.filterwarnings("ignore", message="MatMul8bitLt: inputs will be cast")


# ============================================================
# Normalization modes (as before)
# ============================================================
def apply_normalization(x, model, normalize_mode="raw", block=None, layer_index=None):
    """
    Normalize a hidden-state tensor before it is projected to logits.

    The input is cast to float32 first and every branch returns float32.

    Args:
        x: hidden-state tensor; normalization acts on the last dimension.
        model: model exposing `model.norm`; only accessed by the "rms_layer"
            (when `block` is None) and "norm_rms" modes.
        normalize_mode: one of
            "raw"       - return `x` unchanged (apart from the float32 cast);
            "unit_rms"  - divide by sqrt(mean(x**2) + 1e-5) over the last dim;
            "rms_layer" - apply `block.post_attention_layernorm` when `block`
                          is given, otherwise `model.model.norm`;
            "norm_rms"  - apply `model.model.norm`.
        block: optional decoder block, only used by "rms_layer".
        layer_index: -1 marks the embedding output, which "rms_layer" and
            "norm_rms" leave un-normalized.

    Returns:
        float32 tensor.

    Raises:
        ValueError: if `normalize_mode` is not one of the modes above.
    """
    x = x.to(torch.float32)  # ensure consistent dtype

    if normalize_mode == "raw":
        return x

    elif normalize_mode == "unit_rms":
        eps = 1e-5
        return x / (x.pow(2).mean(dim=-1, keepdim=True).add(eps).sqrt())

    elif normalize_mode == "rms_layer":
        if layer_index == -1:
            return x
        elif block is not None:
            # Cast back to float32 like every other branch; the layernorm
            # module may hand back a lower-precision tensor.
            return block.post_attention_layernorm(x).to(torch.float32)
        else:
            return model.model.norm(x).to(torch.float32)

    elif normalize_mode == "norm_rms":
        if layer_index == -1:
            return x
        else:
            return model.model.norm(x).to(torch.float32)

    else:
        raise ValueError(f"Unknown normalization_mode: {normalize_mode}")

# ============================================================
# Helper: causal + padding mask for LLaMA blocks
# ============================================================
def build_full_attention_mask(input_ids, attention_mask, device):
    """
    Build an additive float32 attention mask of shape [batch, 1, tgt_len, src_len].

    A position is blocked when it lies in the future of the query (causal
    rule) or when the key is padding (`attention_mask == 0`). Blocked
    positions hold -1e9; all others hold zero. `input_ids` is only consulted
    for its (batch, seq_len) shape. With `attention_mask=None` no key is
    treated as padding.
    """
    n_rows, n_tokens = input_ids.shape

    # Strict upper triangle = keys that come after the query position.
    future = torch.triu(
        torch.ones((n_tokens, n_tokens), device=device, dtype=torch.bool),
        diagonal=1,
    )[None, None]  # [1, 1, T, T]

    # Padding keys, broadcast over heads and query positions.
    if attention_mask is not None:
        padded = attention_mask[:, None, None, :] == 0
    else:
        padded = torch.zeros((n_rows, 1, 1, n_tokens), device=device, dtype=torch.bool)

    blocked = future | padded  # True where attention is not allowed

    # Convert the boolean mask to additive form.
    return blocked.to(torch.float32) * -1e9


# ============================================================
# Collector with multiple normalization variants + attention weights (optional storage)
# ============================================================
@torch.no_grad()
def collect_logit_lens_full(
    model,
    tokenizer,
    prompts,
    batch_index=0,
    max_len=17,
    device=None,
    clamp_logits=True,
    save_path=None,
    collect_attn=True,
    save_attn=True,
    norm_modes=("raw", "unit_rms", "norm_rms"),
):
    """
    Run a manual layer-by-layer forward pass and record logit-lens data for one batch.

    Each prompt is tokenized without special tokens, truncated to `max_len - 2`
    content tokens, wrapped as BOS + content + EOS and right-padded to `max_len`.
    The embedding output, the output of every block in `model.model.layers` and
    the final `model.model.norm` output are normalized with every mode in
    `norm_modes` (see `apply_normalization`) and projected through `model.lm_head`.

    Args:
        model: causal LM exposing `model.embed_tokens`, `model.layers`,
            `model.norm` and `lm_head`. Unless BitsAndBytes layers are
            detected, the model is moved to `device`.
        tokenizer: tokenizer providing `encode`, `bos_token_id`, `eos_token_id`
            and optionally `pad_token_id` (EOS is used when it is None).
        prompts: sequence of prompt strings.
        batch_index: value stored in the "batch_index" field of every row.
        max_len: fixed sequence length after truncation / padding.
        device: target device; defaults to CUDA when available, else CPU.
        clamp_logits: when True, NaN / +inf / -inf logits are replaced by
            0 / 80 / -80 (finite values are left untouched).
        save_path: when truthy, `rows` is written there with `torch.save`
            (parent directories are created); when falsy nothing is saved.
        collect_attn: forwarded to each block as `output_attentions`.
        save_attn: when True and attentions were returned, stores them under
            "attn" in each row.
        norm_modes: a mode name or an iterable of mode names.

    Returns:
        Tuple `(rows, all_hidden, all_logits, all_attn)`:
            rows: list of dicts, one per (prompt, layer), holding ids, mask and
                the trimmed `hidden_<mode>` / `logits_<mode>` tensors.
            all_hidden / all_logits: dicts of CPU tensors keyed
                "<layer name>_<mode>" plus "output_true".
            all_attn: dict that this function never fills (always empty).

    Notes:
        * The "embed_tokens_<mode>" entries of `all_hidden` / `all_logits` keep
          the full sequence, whereas layer and output entries (and all row
          tensors) drop the last position.
        * Blocks are driven with `position_ids`; the transformers warning shown
          in this notebook's output says this will be replaced by mandatory
          `position_embeddings` in a later release.
    """
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")

    # A bare string would otherwise be iterated character by character.
    if isinstance(norm_modes, str):
        norm_modes = (norm_modes,)
    else:
        norm_modes = tuple(norm_modes)

    # --- Detect BitsAndBytes quantized models ---
    bnb_layer_types = ("Linear4bit", "Linear8bitLt")
    is_bnb_quantized = any(
        any(name in type(m).__name__ for name in bnb_layer_types)
        for m in model.modules()
    )

    if is_bnb_quantized:
        # Quantized models are left where they are instead of being moved with .to().
        try:
            first_param_device = next(model.parameters()).device
        except StopIteration:
            first_param_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"[info] Detected BitsAndBytes quantized model → already on device {first_param_device}")
        model.eval()
    else:
        model = model.to(device).eval()

    # Only touch the filesystem when a save location was actually given.
    if save_path:
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)

    # ============================================================
    # Tokenization
    # ============================================================
    encoded = []
    # Explicit None check: a pad token id of 0 is valid and must not fall back to EOS.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id
    bos_id = tokenizer.bos_token_id
    eos_id = tokenizer.eos_token_id

    for p in prompts:
        ids = tokenizer.encode(p, add_special_tokens=False)
        content = ids[: max_len - 2]  # leave room for BOS and EOS
        ids = torch.tensor([bos_id] + content + [eos_id], dtype=torch.long)
        if len(ids) < max_len:
            ids = F.pad(ids, (0, max_len - len(ids)), value=pad_id)
        encoded.append(ids)

    input_ids = torch.stack(encoded, dim=0).to(device)

    # ============================================================
    # Build correct attention mask (BOS + content + first EOS)
    # ============================================================
    # Everything after the first EOS is padding, even when pad_id == eos_id.
    attention_mask = torch.ones_like(input_ids, dtype=torch.long)
    for i, ids in enumerate(input_ids):
        eos_positions = (ids == eos_id).nonzero(as_tuple=True)[0]
        if len(eos_positions) > 0:
            eos_pos = eos_positions[0].item()
            attention_mask[i, eos_pos + 1 :] = 0

    # ============================================================
    # Attention + positional setup
    # ============================================================
    batch_size, seq_len = input_ids.shape
    full_mask = build_full_attention_mask(input_ids, attention_mask, device)
    position_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
    vocab_size = model.lm_head.out_features

    print(f"[info] Tokenized {batch_size} prompts | seq_len={seq_len} | pad_id={pad_id} | eos_id={eos_id}")
    print(f"[info] Collecting from {len(model.model.layers)} layers | attn compute={collect_attn}, save={save_attn}")

    # dtype the unembedding expects: the first lm_head parameter's dtype when
    # any parameter is not float32, otherwise float32. Computed once because it
    # does not change between projections.
    if any(p.dtype != torch.float32 for p in model.lm_head.parameters()):
        head_dtype = next(model.lm_head.parameters()).dtype
    else:
        head_dtype = torch.float32

    def project(x):
        """Project hidden states through lm_head and return float32 logits."""
        # ensure consistent dtype for analysis, then match the head's dtype
        x_cast = x.to(torch.float32).to(head_dtype)

        logits = model.lm_head(x_cast)

        if clamp_logits:
            logits = torch.nan_to_num(logits, nan=0.0, posinf=80, neginf=-80)

        # always return FP32 results for consistency
        return logits.to(torch.float32)

    rows = []
    all_hidden, all_logits, all_attn = {}, {}, {}

    # ============================================================
    # Embedding layer
    # ============================================================
    x = model.model.embed_tokens(input_ids).to(torch.float32)
    hidden_variants = {mode: apply_normalization(x.clone(), model, mode, layer_index=-1) for mode in norm_modes}
    logits_variants = {mode: project(hidden_variants[mode]) for mode in norm_modes}

    for mode in norm_modes:
        all_hidden[f"embed_tokens_{mode}"] = hidden_variants[mode].cpu()
        all_logits[f"embed_tokens_{mode}"] = logits_variants[mode].cpu()

    for i in range(batch_size):
        rows.append({
            "prompt_id": i,
            "prompt_text": prompts[i],
            "batch_index": batch_index,
            "vocab_size": vocab_size,
            "layer_index": -1,
            "layer_name": "embed_tokens",
            "input_ids": input_ids[i].cpu(),
            "target_ids": input_ids[i, 1:].cpu(),
            "attention_mask": attention_mask[i].cpu(),
            # Drop the last position so each entry lines up with target_ids.
            **{f"hidden_{m}": hidden_variants[m][i, :-1].cpu() for m in norm_modes},
            **{f"logits_{m}": logits_variants[m][i, :-1].cpu() for m in norm_modes},
        })

    # ============================================================
    # Transformer layers
    # ============================================================
    for li, block in enumerate(model.model.layers):
        out = block(
            x,
            position_ids=position_ids,
            attention_mask=full_mask,
            output_attentions=collect_attn,
        )
        x = out[0]
        attn = out[1] if collect_attn else None

        layer_output = x.detach().clone().to(torch.float32)
        hidden_variants = {
            mode: apply_normalization(layer_output.clone(), model, mode, block=block, layer_index=li)
            for mode in norm_modes
        }
        logits_variants = {mode: project(hidden_variants[mode]) for mode in norm_modes}

        # Trim for next-token prediction before storing.
        for mode in norm_modes:
            hidden_variants[mode] = hidden_variants[mode][:, :-1, :]
            logits_variants[mode] = logits_variants[mode][:, :-1, :]

        for mode in norm_modes:
            all_hidden[f"layer.{li}_{mode}"] = hidden_variants[mode].cpu()
            all_logits[f"layer.{li}_{mode}"] = logits_variants[mode].cpu()

        for i in range(batch_size):
            record = {
                "prompt_id": i,
                "prompt_text": prompts[i],
                "batch_index": batch_index,
                "vocab_size": vocab_size,
                "layer_index": li,
                "layer_name": f"layer.{li}",
                "input_ids": input_ids[i].cpu(),
                "target_ids": input_ids[i, 1:].cpu(),
                "attention_mask": attention_mask[i].cpu(),
                **{f"hidden_{m}": hidden_variants[m][i].cpu() for m in norm_modes},
                **{f"logits_{m}": logits_variants[m][i].cpu() for m in norm_modes},
            }
            if save_attn and attn is not None:
                record["attn"] = attn[i].cpu()
            rows.append(record)

        torch.cuda.empty_cache()
        gc.collect()

    # ============================================================
    # Final model RMSNorm output (apply model.norm for ALL modes)
    # ============================================================
    # Apply the model's own final norm — no custom normalization
    x = model.model.norm(x.to(torch.float32))

    # Compute the true hidden states and logits using the same projection
    final_hidden = x
    final_logits = project(final_hidden)

    # Trim for next-token prediction
    final_hidden, final_logits = final_hidden[:, :-1, :], final_logits[:, :-1, :]

    # Store once for reference
    all_hidden["output_true"] = final_hidden.cpu()
    all_logits["output_true"] = final_logits.cpu()

    # --- Apply this SAME normalized output to all norm_modes ---
    for mode in norm_modes:
        all_hidden[f"output_{mode}"] = final_hidden.cpu()
        all_logits[f"output_{mode}"] = final_logits.cpu()

    # --- Append a per-prompt metadata record identical to other layers ---
    for i in range(batch_size):
        rows.append({
            "prompt_id": i,
            "prompt_text": prompts[i],
            "batch_index": batch_index,
            "vocab_size": vocab_size,
            "layer_index": len(model.model.layers),
            "layer_name": "output",
            "input_ids": input_ids[i].cpu(),
            "target_ids": input_ids[i, 1:].cpu(),
            "attention_mask": attention_mask[i].cpu(),
            **{f"hidden_{m}": final_hidden[i].cpu() for m in norm_modes},
            **{f"logits_{m}": final_logits[i].cpu() for m in norm_modes},
        })

    # ============================================================
    # Save and finish
    # ============================================================
    if save_path:
        torch.save(rows, save_path)
        print(f"[saved] Logit-lens data → {save_path}")

    print(f"[ok] Collected {len(rows)} entries from {len(model.model.layers)} layers (no redundant tensors)")
    return rows, all_hidden, all_logits, all_attn


In [None]:
import torch
import gc
from tqdm import tqdm

def run_logit_lens_in_batches(
    model,
    tokenizer,
    all_prompts,
    batch_size=20,
    save_prefix="logitlens_batch",
    max_len=17,
    normalize_mode=("raw", "unit_rms", "norm_rms"),
    device=None,
    clamp_logits=True,
    collect_attn=False,
    save_attn=False
):
    """
    Runs collect_logit_lens_full() in batches to avoid OOM.
    Each batch is saved as a separate .pt file named
    "<save_prefix>_batch<NNN>.pt", where NNN is the zero-padded batch number.

    Args:
        model, tokenizer: forwarded to collect_logit_lens_full().
        all_prompts: sequence of prompt strings (must support len() and slicing).
        batch_size: number of prompts per batch; must be >= 1.
        save_prefix: path prefix of the per-batch output files.
        max_len: fixed token length forwarded to the collector.
        normalize_mode: a mode name or an iterable of mode names, forwarded
            as `norm_modes`.
        device, clamp_logits, collect_attn, save_attn: forwarded unchanged.

    Returns:
        None. Results are only written to disk.

    Raises:
        ValueError: if `batch_size` is smaller than 1.

    A batch that raises RuntimeError is reported and skipped so the remaining
    batches still run; other exception types propagate.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # The collector iterates over the modes, so wrap a single mode name.
    if isinstance(normalize_mode, str):
        normalize_mode = (normalize_mode,)

    num_batches = (len(all_prompts) + batch_size - 1) // batch_size
    print(f"[run] Processing {len(all_prompts)} prompts in {num_batches} batches of {batch_size}")

    for batch_idx in tqdm(range(num_batches), desc="Running logit lens batches"):
        start = batch_idx * batch_size
        end = min((batch_idx + 1) * batch_size, len(all_prompts))
        batch_prompts = all_prompts[start:end]

        save_path = f"{save_prefix}_batch{batch_idx:03d}.pt"

        print(f"\n[batch {batch_idx+1}/{num_batches}] {len(batch_prompts)} prompts → {save_path}")

        try:
            rows, hidden_dict, logits_dict, all_attn = collect_logit_lens_full(
                model=model,
                tokenizer=tokenizer,
                prompts=batch_prompts,
                # Record the real batch number; prompt_id restarts at 0 in every
                # batch, so rows are only identifiable together with batch_index.
                batch_index=batch_idx,
                max_len=max_len,
                device=device,
                norm_modes=normalize_mode,
                save_path=save_path,
                clamp_logits=clamp_logits,
                collect_attn=collect_attn,
                save_attn=save_attn,
            )

        except RuntimeError as e:
            print(f"[error] Batch {batch_idx} failed: {e}")
            continue

        # -------------------------------
        # Free GPU + CPU memory
        # -------------------------------
        del rows, hidden_dict, logits_dict, all_attn, batch_prompts
        torch.cuda.empty_cache()
        gc.collect()

    print("\n[done] All batches processed and saved.")

In [None]:
# Collect logit-lens data for all prompts in nq_1000 with the 4-bit model on CPU,
# in batches of 20; each batch is written to its own .pt file under
# saved_data/lens_data/m_4bit/. Attention weights are neither computed nor stored.
run_logit_lens_in_batches(
    model=model_4bit,
    tokenizer=orig_tokenizer,
    all_prompts=nq_1000,
    batch_size=20,
    max_len=17,
    normalize_mode=("raw", "unit_rms", "norm_rms"),
    save_prefix="saved_data/lens_data/m_4bit/m_4bit_modes",
    device="cpu",
    clamp_logits=True,
    collect_attn=False,
    save_attn=False  
) 

[run] Processing 1000 prompts in 50 batches of 20


Running logit lens batches:   0%|          | 0/50 [00:00<?, ?it/s]


[batch 1/50] 20 prompts → saved_data/lens_data/m_4bit/m_4bit_modes_batch000.pt
[info] Detected BitsAndBytes quantized model → already on device cpu
[info] Tokenized 20 prompts | seq_len=17 | pad_id=128009 | eos_id=128009
[info] Collecting from 32 layers | attn compute=False, save=False


The attention layers in this model are transitioning from computing the RoPE embeddings internally through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed `position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be removed and `position_embeddings` will be mandatory.


[saved] Logit-lens data → saved_data/lens_data/m_4bit/m_4bit_modes_batch000.pt
[ok] Collected 680 entries from 32 layers (no redundant tensors)


Running logit lens batches:   2%|▏         | 1/50 [06:21<5:11:11, 381.05s/it]


[batch 2/50] 20 prompts → saved_data/lens_data/m_4bit/m_4bit_modes_batch001.pt
[info] Detected BitsAndBytes quantized model → already on device cpu
[info] Tokenized 20 prompts | seq_len=17 | pad_id=128009 | eos_id=128009
[info] Collecting from 32 layers | attn compute=False, save=False
[saved] Logit-lens data → saved_data/lens_data/m_4bit/m_4bit_modes_batch001.pt
[ok] Collected 680 entries from 32 layers (no redundant tensors)


Running logit lens batches:   4%|▍         | 2/50 [12:24<4:56:42, 370.89s/it]


[batch 3/50] 20 prompts → saved_data/lens_data/m_4bit/m_4bit_modes_batch002.pt
[info] Detected BitsAndBytes quantized model → already on device cpu
[info] Tokenized 20 prompts | seq_len=17 | pad_id=128009 | eos_id=128009
[info] Collecting from 32 layers | attn compute=False, save=False
[saved] Logit-lens data → saved_data/lens_data/m_4bit/m_4bit_modes_batch002.pt
[ok] Collected 680 entries from 32 layers (no redundant tensors)


Running logit lens batches:   6%|▌         | 3/50 [19:09<5:02:33, 386.25s/it]


[batch 4/50] 20 prompts → saved_data/lens_data/m_4bit/m_4bit_modes_batch003.pt
[info] Detected BitsAndBytes quantized model → already on device cpu
[info] Tokenized 20 prompts | seq_len=17 | pad_id=128009 | eos_id=128009
[info] Collecting from 32 layers | attn compute=False, save=False
[saved] Logit-lens data → saved_data/lens_data/m_4bit/m_4bit_modes_batch003.pt
[ok] Collected 680 entries from 32 layers (no redundant tensors)


Running logit lens batches:   8%|▊         | 4/50 [25:24<4:52:53, 382.04s/it]


[batch 5/50] 20 prompts → saved_data/lens_data/m_4bit/m_4bit_modes_batch004.pt
[info] Detected BitsAndBytes quantized model → already on device cpu
[info] Tokenized 20 prompts | seq_len=17 | pad_id=128009 | eos_id=128009
[info] Collecting from 32 layers | attn compute=False, save=False
[saved] Logit-lens data → saved_data/lens_data/m_4bit/m_4bit_modes_batch004.pt
[ok] Collected 680 entries from 32 layers (no redundant tensors)


Running logit lens batches:  10%|█         | 5/50 [32:18<4:55:05, 393.46s/it]


[batch 6/50] 20 prompts → saved_data/lens_data/m_4bit/m_4bit_modes_batch005.pt
[info] Detected BitsAndBytes quantized model → already on device cpu
[info] Tokenized 20 prompts | seq_len=17 | pad_id=128009 | eos_id=128009
[info] Collecting from 32 layers | attn compute=False, save=False
[saved] Logit-lens data → saved_data/lens_data/m_4bit/m_4bit_modes_batch005.pt
[ok] Collected 680 entries from 32 layers (no redundant tensors)


Running logit lens batches:  12%|█▏        | 6/50 [38:52<4:48:37, 393.59s/it]


[batch 7/50] 20 prompts → saved_data/lens_data/m_4bit/m_4bit_modes_batch006.pt
[info] Detected BitsAndBytes quantized model → already on device cpu
[info] Tokenized 20 prompts | seq_len=17 | pad_id=128009 | eos_id=128009
[info] Collecting from 32 layers | attn compute=False, save=False
[saved] Logit-lens data → saved_data/lens_data/m_4bit/m_4bit_modes_batch006.pt
[ok] Collected 680 entries from 32 layers (no redundant tensors)


Running logit lens batches:  14%|█▍        | 7/50 [45:22<4:41:11, 392.35s/it]


[batch 8/50] 20 prompts → saved_data/lens_data/m_4bit/m_4bit_modes_batch007.pt
[info] Detected BitsAndBytes quantized model → already on device cpu
[info] Tokenized 20 prompts | seq_len=17 | pad_id=128009 | eos_id=128009
[info] Collecting from 32 layers | attn compute=False, save=False
[saved] Logit-lens data → saved_data/lens_data/m_4bit/m_4bit_modes_batch007.pt
[ok] Collected 680 entries from 32 layers (no redundant tensors)


Running logit lens batches:  16%|█▌        | 8/50 [51:51<4:33:58, 391.40s/it]


[batch 9/50] 20 prompts → saved_data/lens_data/m_4bit/m_4bit_modes_batch008.pt
[info] Detected BitsAndBytes quantized model → already on device cpu
[info] Tokenized 20 prompts | seq_len=17 | pad_id=128009 | eos_id=128009
[info] Collecting from 32 layers | attn compute=False, save=False
[saved] Logit-lens data → saved_data/lens_data/m_4bit/m_4bit_modes_batch008.pt
[ok] Collected 680 entries from 32 layers (no redundant tensors)


Running logit lens batches:  18%|█▊        | 9/50 [58:39<4:30:57, 396.51s/it]


[batch 10/50] 20 prompts → saved_data/lens_data/m_4bit/m_4bit_modes_batch009.pt
[info] Detected BitsAndBytes quantized model → already on device cpu
[info] Tokenized 20 prompts | seq_len=17 | pad_id=128009 | eos_id=128009
[info] Collecting from 32 layers | attn compute=False, save=False
[saved] Logit-lens data → saved_data/lens_data/m_4bit/m_4bit_modes_batch009.pt
[ok] Collected 680 entries from 32 layers (no redundant tensors)


Running logit lens batches:  20%|██        | 10/50 [1:04:57<4:20:29, 390.73s/it]


[batch 11/50] 20 prompts → saved_data/lens_data/m_4bit/m_4bit_modes_batch010.pt
[info] Detected BitsAndBytes quantized model → already on device cpu
[info] Tokenized 20 prompts | seq_len=17 | pad_id=128009 | eos_id=128009
[info] Collecting from 32 layers | attn compute=False, save=False
[saved] Logit-lens data → saved_data/lens_data/m_4bit/m_4bit_modes_batch010.pt
[ok] Collected 680 entries from 32 layers (no redundant tensors)


Running logit lens batches:  22%|██▏       | 11/50 [1:11:07<4:09:49, 384.34s/it]


[batch 12/50] 20 prompts → saved_data/lens_data/m_4bit/m_4bit_modes_batch011.pt
[info] Detected BitsAndBytes quantized model → already on device cpu
[info] Tokenized 20 prompts | seq_len=17 | pad_id=128009 | eos_id=128009
[info] Collecting from 32 layers | attn compute=False, save=False
[saved] Logit-lens data → saved_data/lens_data/m_4bit/m_4bit_modes_batch011.pt
[ok] Collected 680 entries from 32 layers (no redundant tensors)


Running logit lens batches:  24%|██▍       | 12/50 [1:17:09<3:59:11, 377.67s/it]


[batch 13/50] 20 prompts → saved_data/lens_data/m_4bit/m_4bit_modes_batch012.pt
[info] Detected BitsAndBytes quantized model → already on device cpu
[info] Tokenized 20 prompts | seq_len=17 | pad_id=128009 | eos_id=128009
[info] Collecting from 32 layers | attn compute=False, save=False
[saved] Logit-lens data → saved_data/lens_data/m_4bit/m_4bit_modes_batch012.pt
[ok] Collected 680 entries from 32 layers (no redundant tensors)


Running logit lens batches:  26%|██▌       | 13/50 [1:24:29<4:04:28, 396.45s/it]


[batch 14/50] 20 prompts → saved_data/lens_data/m_4bit/m_4bit_modes_batch013.pt
[info] Detected BitsAndBytes quantized model → already on device cpu
[info] Tokenized 20 prompts | seq_len=17 | pad_id=128009 | eos_id=128009
[info] Collecting from 32 layers | attn compute=False, save=False
[saved] Logit-lens data → saved_data/lens_data/m_4bit/m_4bit_modes_batch013.pt
[ok] Collected 680 entries from 32 layers (no redundant tensors)


Running logit lens batches:  28%|██▊       | 14/50 [1:33:15<4:21:20, 435.56s/it]


[batch 15/50] 20 prompts → saved_data/lens_data/m_4bit/m_4bit_modes_batch014.pt
[info] Detected BitsAndBytes quantized model → already on device cpu
[info] Tokenized 20 prompts | seq_len=17 | pad_id=128009 | eos_id=128009
[info] Collecting from 32 layers | attn compute=False, save=False
[saved] Logit-lens data → saved_data/lens_data/m_4bit/m_4bit_modes_batch014.pt
[ok] Collected 680 entries from 32 layers (no redundant tensors)


Running logit lens batches:  30%|███       | 15/50 [1:40:14<4:11:18, 430.81s/it]


[batch 16/50] 20 prompts → saved_data/lens_data/m_4bit/m_4bit_modes_batch015.pt
[info] Detected BitsAndBytes quantized model → already on device cpu
[info] Tokenized 20 prompts | seq_len=17 | pad_id=128009 | eos_id=128009
[info] Collecting from 32 layers | attn compute=False, save=False
[saved] Logit-lens data → saved_data/lens_data/m_4bit/m_4bit_modes_batch015.pt
[ok] Collected 680 entries from 32 layers (no redundant tensors)


Running logit lens batches:  32%|███▏      | 16/50 [1:48:00<4:10:05, 441.33s/it]


[batch 17/50] 20 prompts → saved_data/lens_data/m_4bit/m_4bit_modes_batch016.pt
[info] Detected BitsAndBytes quantized model → already on device cpu
[info] Tokenized 20 prompts | seq_len=17 | pad_id=128009 | eos_id=128009
[info] Collecting from 32 layers | attn compute=False, save=False
[saved] Logit-lens data → saved_data/lens_data/m_4bit/m_4bit_modes_batch016.pt
[ok] Collected 680 entries from 32 layers (no redundant tensors)


Running logit lens batches:  34%|███▍      | 17/50 [1:54:31<3:54:21, 426.11s/it]


[batch 18/50] 20 prompts → saved_data/lens_data/m_4bit/m_4bit_modes_batch017.pt
[info] Detected BitsAndBytes quantized model → already on device cpu
[info] Tokenized 20 prompts | seq_len=17 | pad_id=128009 | eos_id=128009
[info] Collecting from 32 layers | attn compute=False, save=False
[saved] Logit-lens data → saved_data/lens_data/m_4bit/m_4bit_modes_batch017.pt
[ok] Collected 680 entries from 32 layers (no redundant tensors)


Running logit lens batches:  36%|███▌      | 18/50 [2:03:14<4:02:49, 455.28s/it]


[batch 19/50] 20 prompts → saved_data/lens_data/m_4bit/m_4bit_modes_batch018.pt
[info] Detected BitsAndBytes quantized model → already on device cpu
[info] Tokenized 20 prompts | seq_len=17 | pad_id=128009 | eos_id=128009
[info] Collecting from 32 layers | attn compute=False, save=False
[saved] Logit-lens data → saved_data/lens_data/m_4bit/m_4bit_modes_batch018.pt
[ok] Collected 680 entries from 32 layers (no redundant tensors)


Running logit lens batches:  38%|███▊      | 19/50 [2:11:54<4:05:14, 474.66s/it]


[batch 20/50] 20 prompts → saved_data/lens_data/m_4bit/m_4bit_modes_batch019.pt
[info] Detected BitsAndBytes quantized model → already on device cpu
[info] Tokenized 20 prompts | seq_len=17 | pad_id=128009 | eos_id=128009
[info] Collecting from 32 layers | attn compute=False, save=False
[saved] Logit-lens data → saved_data/lens_data/m_4bit/m_4bit_modes_batch019.pt
[ok] Collected 680 entries from 32 layers (no redundant tensors)


Running logit lens batches:  40%|████      | 20/50 [2:18:31<3:45:44, 451.50s/it]


[batch 21/50] 20 prompts → saved_data/lens_data/m_4bit/m_4bit_modes_batch020.pt
[info] Detected BitsAndBytes quantized model → already on device cpu
[info] Tokenized 20 prompts | seq_len=17 | pad_id=128009 | eos_id=128009
[info] Collecting from 32 layers | attn compute=False, save=False
[saved] Logit-lens data → saved_data/lens_data/m_4bit/m_4bit_modes_batch020.pt
[ok] Collected 680 entries from 32 layers (no redundant tensors)


Running logit lens batches:  42%|████▏     | 21/50 [2:24:32<3:25:04, 424.29s/it]


[batch 22/50] 20 prompts → saved_data/lens_data/m_4bit/m_4bit_modes_batch021.pt
[info] Detected BitsAndBytes quantized model → already on device cpu
[info] Tokenized 20 prompts | seq_len=17 | pad_id=128009 | eos_id=128009
[info] Collecting from 32 layers | attn compute=False, save=False
[saved] Logit-lens data → saved_data/lens_data/m_4bit/m_4bit_modes_batch021.pt
[ok] Collected 680 entries from 32 layers (no redundant tensors)


Running logit lens batches:  44%|████▍     | 22/50 [2:30:46<3:10:55, 409.13s/it]


[batch 23/50] 20 prompts → saved_data/lens_data/m_4bit/m_4bit_modes_batch022.pt
[info] Detected BitsAndBytes quantized model → already on device cpu
[info] Tokenized 20 prompts | seq_len=17 | pad_id=128009 | eos_id=128009
[info] Collecting from 32 layers | attn compute=False, save=False
[saved] Logit-lens data → saved_data/lens_data/m_4bit/m_4bit_modes_batch022.pt
[ok] Collected 680 entries from 32 layers (no redundant tensors)


Running logit lens batches:  46%|████▌     | 23/50 [2:36:39<2:56:30, 392.25s/it]


[batch 24/50] 20 prompts → saved_data/lens_data/m_4bit/m_4bit_modes_batch023.pt
[info] Detected BitsAndBytes quantized model → already on device cpu
[info] Tokenized 20 prompts | seq_len=17 | pad_id=128009 | eos_id=128009
[info] Collecting from 32 layers | attn compute=False, save=False
[saved] Logit-lens data → saved_data/lens_data/m_4bit/m_4bit_modes_batch023.pt
[ok] Collected 680 entries from 32 layers (no redundant tensors)


Running logit lens batches:  48%|████▊     | 24/50 [2:43:42<2:54:01, 401.58s/it]


[batch 25/50] 20 prompts → saved_data/lens_data/m_4bit/m_4bit_modes_batch024.pt
[info] Detected BitsAndBytes quantized model → already on device cpu
[info] Tokenized 20 prompts | seq_len=17 | pad_id=128009 | eos_id=128009
[info] Collecting from 32 layers | attn compute=False, save=False
[saved] Logit-lens data → saved_data/lens_data/m_4bit/m_4bit_modes_batch024.pt
[ok] Collected 680 entries from 32 layers (no redundant tensors)


Running logit lens batches:  50%|█████     | 25/50 [2:51:05<2:52:28, 413.95s/it]


[batch 26/50] 20 prompts → saved_data/lens_data/m_4bit/m_4bit_modes_batch025.pt
[info] Detected BitsAndBytes quantized model → already on device cpu
[info] Tokenized 20 prompts | seq_len=17 | pad_id=128009 | eos_id=128009
[info] Collecting from 32 layers | attn compute=False, save=False
[saved] Logit-lens data → saved_data/lens_data/m_4bit/m_4bit_modes_batch025.pt
[ok] Collected 680 entries from 32 layers (no redundant tensors)


Running logit lens batches:  52%|█████▏    | 26/50 [2:57:50<2:44:31, 411.29s/it]


[batch 27/50] 20 prompts → saved_data/lens_data/m_4bit/m_4bit_modes_batch026.pt
[info] Detected BitsAndBytes quantized model → already on device cpu
[info] Tokenized 20 prompts | seq_len=17 | pad_id=128009 | eos_id=128009
[info] Collecting from 32 layers | attn compute=False, save=False
[saved] Logit-lens data → saved_data/lens_data/m_4bit/m_4bit_modes_batch026.pt
[ok] Collected 680 entries from 32 layers (no redundant tensors)


Running logit lens batches:  54%|█████▍    | 27/50 [3:04:13<2:34:23, 402.74s/it]


[batch 28/50] 20 prompts → saved_data/lens_data/m_4bit/m_4bit_modes_batch027.pt
[info] Detected BitsAndBytes quantized model → already on device cpu
[info] Tokenized 20 prompts | seq_len=17 | pad_id=128009 | eos_id=128009
[info] Collecting from 32 layers | attn compute=False, save=False
[saved] Logit-lens data → saved_data/lens_data/m_4bit/m_4bit_modes_batch027.pt
[ok] Collected 680 entries from 32 layers (no redundant tensors)


Running logit lens batches:  56%|█████▌    | 28/50 [3:11:26<2:31:02, 411.91s/it]


[batch 29/50] 20 prompts → saved_data/lens_data/m_4bit/m_4bit_modes_batch028.pt
[info] Detected BitsAndBytes quantized model → already on device cpu
[info] Tokenized 20 prompts | seq_len=17 | pad_id=128009 | eos_id=128009
[info] Collecting from 32 layers | attn compute=False, save=False
[saved] Logit-lens data → saved_data/lens_data/m_4bit/m_4bit_modes_batch028.pt
[ok] Collected 680 entries from 32 layers (no redundant tensors)


Running logit lens batches:  58%|█████▊    | 29/50 [3:18:17<2:24:02, 411.54s/it]


[batch 30/50] 20 prompts → saved_data/lens_data/m_4bit/m_4bit_modes_batch029.pt
[info] Detected BitsAndBytes quantized model → already on device cpu
[info] Tokenized 20 prompts | seq_len=17 | pad_id=128009 | eos_id=128009
[info] Collecting from 32 layers | attn compute=False, save=False
[saved] Logit-lens data → saved_data/lens_data/m_4bit/m_4bit_modes_batch029.pt
[ok] Collected 680 entries from 32 layers (no redundant tensors)


Running logit lens batches:  60%|██████    | 30/50 [3:24:53<2:15:37, 406.86s/it]


[batch 31/50] 20 prompts → saved_data/lens_data/m_4bit/m_4bit_modes_batch030.pt
[info] Detected BitsAndBytes quantized model → already on device cpu
[info] Tokenized 20 prompts | seq_len=17 | pad_id=128009 | eos_id=128009
[info] Collecting from 32 layers | attn compute=False, save=False
[saved] Logit-lens data → saved_data/lens_data/m_4bit/m_4bit_modes_batch030.pt
[ok] Collected 680 entries from 32 layers (no redundant tensors)


Running logit lens batches:  62%|██████▏   | 31/50 [3:31:02<2:05:14, 395.51s/it]


[batch 32/50] 20 prompts → saved_data/lens_data/m_4bit/m_4bit_modes_batch031.pt
[info] Detected BitsAndBytes quantized model → already on device cpu
[info] Tokenized 20 prompts | seq_len=17 | pad_id=128009 | eos_id=128009
[info] Collecting from 32 layers | attn compute=False, save=False
[saved] Logit-lens data → saved_data/lens_data/m_4bit/m_4bit_modes_batch031.pt
[ok] Collected 680 entries from 32 layers (no redundant tensors)


Running logit lens batches:  64%|██████▍   | 32/50 [3:37:13<1:56:30, 388.35s/it]


[batch 33/50] 20 prompts → saved_data/lens_data/m_4bit/m_4bit_modes_batch032.pt
[info] Detected BitsAndBytes quantized model → already on device cpu
[info] Tokenized 20 prompts | seq_len=17 | pad_id=128009 | eos_id=128009
[info] Collecting from 32 layers | attn compute=False, save=False


In [None]:
# Logit-lens run for the HF100B model (`model_quant`) on the first 200 NQ
# queries, using only the "raw" (un-normalized) mode.
run_logit_lens_in_batches(
    model=model_quant,
    tokenizer=quant_tokenizer,
    all_prompts=nq_queries_200,
    batch_size=20,
    max_len=18,
    # The modes are iterated downstream, so pass a tuple of names rather than a bare string.
    normalize_mode=("raw",),
    save_prefix="logs/new_model_lens/raw/m_quant_raw",
    device="cpu",
)

In [15]:
import torch
import pandas as pd 

# Load one saved logit-lens batch onto the CPU.
# NOTE: weights_only=False runs the full pickle machinery, so only point this
# at files produced by this project.
data = torch.load(
    "saved_data/lens_data/m_orig_norm_modes_batch000.pt",
    map_location="cpu",
    weights_only=False,
)

In [16]:
# Tabulate the loaded batch (assumes `data` is a list of per-row dicts — TODO confirm).
data_df = pd.DataFrame(data)

In [46]:
# Preview the first five rows of the batch.
data_df.head()

Unnamed: 0,prompt_id,prompt_text,batch_index,vocab_size,layer_index,layer_name,input_ids,target_ids,attention_mask,hidden_raw,hidden_unit_rms,hidden_norm_rms,logits_raw,logits_unit_rms,logits_norm_rms,hidden,logits
0,0,when did richmond last play in a preliminary f...,0,128256,-1,embed_tokens,"[tensor(128000), tensor(9493), tensor(1550), t...","[tensor(9493), tensor(1550), tensor(9257), ten...","[tensor(1), tensor(1), tensor(1), tensor(1), t...","[[tensor(-8.2970e-05), tensor(0.0003), tensor(...","[[tensor(-8.2970e-05), tensor(0.0003), tensor(...","[[tensor(-8.2970e-05), tensor(0.0003), tensor(...","[[tensor(-0.0105), tensor(-0.0092), tensor(-0....","[[tensor(-0.0105), tensor(-0.0092), tensor(-0....","[[tensor(-0.0105), tensor(-0.0092), tensor(-0....",,
1,0,when did richmond last play in a preliminary f...,0,128256,0,layer.0,"[tensor(128000), tensor(9493), tensor(1550), t...","[tensor(9493), tensor(1550), tensor(9257), ten...","[tensor(1), tensor(1), tensor(1), tensor(1), t...","[[tensor(0.0013), tensor(0.0040), tensor(-0.00...","[[tensor(0.0102), tensor(0.0307), tensor(-0.03...","[[tensor(0.0271), tensor(0.0790), tensor(-0.10...","[[tensor(-0.0531), tensor(-0.1477), tensor(-0....","[[tensor(-0.4077), tensor(-1.1344), tensor(-0....","[[tensor(-0.7749), tensor(-2.9273), tensor(-1....",,
2,0,when did richmond last play in a preliminary f...,0,128256,1,layer.1,"[tensor(128000), tensor(9493), tensor(1550), t...","[tensor(9493), tensor(1550), tensor(9257), ten...","[tensor(1), tensor(1), tensor(1), tensor(1), t...","[[tensor(-0.0454), tensor(0.0954), tensor(-0.0...","[[tensor(-0.0075), tensor(0.0157), tensor(-0.0...","[[tensor(-0.0199), tensor(0.0405), tensor(-0.0...","[[tensor(-7.2014), tensor(-13.3586), tensor(-8...","[[tensor(-1.1870), tensor(-2.2018), tensor(-1....","[[tensor(-1.3217), tensor(-2.3338), tensor(-1....",,
3,0,when did richmond last play in a preliminary f...,0,128256,2,layer.2,"[tensor(128000), tensor(9493), tensor(1550), t...","[tensor(9493), tensor(1550), tensor(9257), ten...","[tensor(1), tensor(1), tensor(1), tensor(1), t...","[[tensor(-0.0479), tensor(0.0981), tensor(-0.0...","[[tensor(-0.0079), tensor(0.0162), tensor(-0.0...","[[tensor(-0.0210), tensor(0.0417), tensor(-0.0...","[[tensor(-7.2171), tensor(-13.3865), tensor(-8...","[[tensor(-1.1894), tensor(-2.2061), tensor(-1....","[[tensor(-1.3264), tensor(-2.3425), tensor(-1....",,
4,0,when did richmond last play in a preliminary f...,0,128256,3,layer.3,"[tensor(128000), tensor(9493), tensor(1550), t...","[tensor(9493), tensor(1550), tensor(9257), ten...","[tensor(1), tensor(1), tensor(1), tensor(1), t...","[[tensor(-0.0325), tensor(0.1043), tensor(-0.0...","[[tensor(-0.0054), tensor(0.0172), tensor(-0.0...","[[tensor(-0.0142), tensor(0.0443), tensor(-0.0...","[[tensor(-7.2138), tensor(-13.4237), tensor(-8...","[[tensor(-1.1886), tensor(-2.2117), tensor(-1....","[[tensor(-1.3254), tensor(-2.3557), tensor(-1....",,


In [19]:
# Preview the last five rows of the batch.
data_df.tail()

Unnamed: 0,prompt_id,prompt_text,batch_index,vocab_size,layer_index,layer_name,input_ids,target_ids,attention_mask,hidden_raw,hidden_unit_rms,hidden_norm_rms,logits_raw,logits_unit_rms,logits_norm_rms
29,0,when did richmond last play in a preliminary f...,0,128256,28,layer.28,"[tensor(128000), tensor(9493), tensor(1550), t...","[tensor(9493), tensor(1550), tensor(9257), ten...","[tensor(1), tensor(1), tensor(1), tensor(1), t...","[[tensor(0.0686), tensor(0.1682), tensor(0.047...","[[tensor(0.0114), tensor(0.0280), tensor(0.007...","[[tensor(0.0303), tensor(0.0721), tensor(0.020...","[[tensor(-6.2383), tensor(-12.7422), tensor(-8...","[[tensor(-1.0371), tensor(-2.1172), tensor(-1....","[[tensor(-1.1064), tensor(-2.1855), tensor(-1...."
30,0,when did richmond last play in a preliminary f...,0,128256,29,layer.29,"[tensor(128000), tensor(9493), tensor(1550), t...","[tensor(9493), tensor(1550), tensor(9257), ten...","[tensor(1), tensor(1), tensor(1), tensor(1), t...","[[tensor(0.1700), tensor(0.1889), tensor(0.050...","[[tensor(0.0283), tensor(0.0314), tensor(0.008...","[[tensor(0.0751), tensor(0.0811), tensor(0.021...","[[tensor(-5.3750), tensor(-12.2188), tensor(-7...","[[tensor(-0.8945), tensor(-2.0332), tensor(-1....","[[tensor(-0.8179), tensor(-2.0156), tensor(-1...."
31,0,when did richmond last play in a preliminary f...,0,128256,30,layer.30,"[tensor(128000), tensor(9493), tensor(1550), t...","[tensor(9493), tensor(1550), tensor(9257), ten...","[tensor(1), tensor(1), tensor(1), tensor(1), t...","[[tensor(0.1513), tensor(0.2639), tensor(0.072...","[[tensor(0.0254), tensor(0.0443), tensor(0.012...","[[tensor(0.0675), tensor(0.1142), tensor(0.031...","[[tensor(-2.8730), tensor(-10.8047), tensor(-6...","[[tensor(-0.4824), tensor(-1.8154), tensor(-1....","[[tensor(0.0625), tensor(-1.5752), tensor(-1.0..."
32,0,when did richmond last play in a preliminary f...,0,128256,31,layer.31,"[tensor(128000), tensor(9493), tensor(1550), t...","[tensor(9493), tensor(1550), tensor(9257), ten...","[tensor(1), tensor(1), tensor(1), tensor(1), t...","[[tensor(1.0524), tensor(-0.0069), tensor(-0.2...","[[tensor(1.6513), tensor(-0.0108), tensor(-0.3...","[[tensor(4.3862), tensor(-0.0278), tensor(-1.0...","[[tensor(1.6826), tensor(2.3926), tensor(3.320...","[[tensor(2.6406), tensor(3.7539), tensor(5.210...","[[tensor(4.7617), tensor(6.1094), tensor(10.72..."
33,0,when did richmond last play in a preliminary f...,0,128256,32,output,"[tensor(128000), tensor(9493), tensor(1550), t...","[tensor(9493), tensor(1550), tensor(9257), ten...","[tensor(1), tensor(1), tensor(1), tensor(1), t...","[[tensor(4.3862), tensor(-0.0278), tensor(-1.0...","[[tensor(4.3862), tensor(-0.0278), tensor(-1.0...","[[tensor(4.3862), tensor(-0.0278), tensor(-1.0...","[[tensor(4.7617), tensor(6.1094), tensor(10.72...","[[tensor(4.7617), tensor(6.1094), tensor(10.72...","[[tensor(4.7617), tensor(6.1094), tensor(10.72..."


In [34]:
# List the column names available in the batch.
data_df.columns

Index(['prompt_id', 'prompt_text', 'batch_index', 'vocab_size', 'layer_index',
       'layer_name', 'input_ids', 'target_ids', 'attention_mask', 'hidden_raw',
       'hidden_unit_rms', 'hidden_norm_rms', 'logits_raw', 'logits_unit_rms',
       'logits_norm_rms'],
      dtype='object')

In [17]:
# Count missing values per column.
data_df.isna().sum()

prompt_id          0
prompt_text        0
batch_index        0
vocab_size         0
layer_index        0
layer_name         0
input_ids          0
target_ids         0
attention_mask     0
hidden_raw         0
hidden_unit_rms    0
hidden_norm_rms    0
logits_raw         0
logits_unit_rms    0
logits_norm_rms    0
dtype: int64

# ==============================================
# TopK Comparison ==============================
# ==============================================

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
from scipy.special import kl_div
from scipy.spatial.distance import jensenshannon, jaccard
from scipy.stats import entropy
import pandas as pd
import os


def full_jaccard_over_positions(tkA, tkB, eps=1e-9, device="cpu"):
    """
    Sequence-level Jaccard similarity of two top-k index tensors.

    For every entry along dim 0, all token ids of that entry (across every
    remaining dimension) are pooled into one set per model, and
    |A ∩ B| / (|A ∪ B| + eps) is computed.

    Args:
        tkA, tkB: integer tensors indexed along dim 0; tkB must have at least
            as many entries along dim 0 as tkA.
        eps: stabiliser added to the denominator.
        device: device of the returned tensor.

    Returns:
        Tensor of shape [tkA.size(0)] with finite values in [0, 1]. An entry
        where both sets are empty scores 1.0. An empty batch yields an empty
        tensor instead of raising.
    """
    scores = []
    for i in range(tkA.size(0)):
        set_a = set(tkA[i].flatten().tolist())
        set_b = set(tkB[i].flatten().tolist())
        union_size = len(set_a | set_b)
        if union_size == 0:
            # Both sets empty: treat as perfect agreement.
            scores.append(1.0)
        else:
            scores.append(len(set_a & set_b) / (union_size + eps))

    # Build the result in one shot; torch.tensor([]) is a valid empty tensor,
    # whereas torch.stack on an empty list raises.
    out = torch.tensor(scores, device=device)
    out = torch.nan_to_num(out, nan=0.0, posinf=0.0, neginf=0.0)
    return out.clamp(0.0, 1.0)

def safe_mean(x):
    """
    Mean of `x` as a plain finite float.

    Accepts None, tensors, lists/tuples/arrays and scalars. NaN and ±inf
    entries count as 0.0; empty or unconvertible inputs give 0.0.
    """
    if x is None:
        return 0.0

    if torch.is_tensor(x):
        if x.numel() == 0:
            return 0.0
        cleaned = torch.nan_to_num(
            x.detach().float(), nan=0.0, posinf=0.0, neginf=0.0
        )
        return float(cleaned.flatten().mean().item())

    if isinstance(x, (list, tuple, np.ndarray)):
        values = np.asarray(x, dtype=np.float64).ravel()
        if values.size == 0:
            return 0.0
        return float(np.mean(np.nan_to_num(values, nan=0.0, posinf=0.0, neginf=0.0)))

    # Anything else is treated as a scalar; failure to convert means 0.0.
    try:
        scalar = float(x)
    except Exception:
        return 0.0
    return scalar if np.isfinite(scalar) else 0.0


@torch.no_grad()
def compute_divergence(
    metrics_A,
    metrics_B,
    topk=(1, 5, 10, 20),
    eps=1e-9,
    mode="aligned",
    batch_num=0,
    lens_type="raw",
    device=None,
    tokenizer=None,
    debug_checks=False,
):
    """
    Compare the logit-lens outputs of two models, row by row.

    Every row is a dict with at least "prompt_id", "layer_index",
    "layer_name" and "logits" (a [T, V] or [B, T, V] tensor). "target_ids",
    "prompt_text" and "input_ids" are read when present. Row A supplies the
    metadata and the targets of each pair.

    Args:
        metrics_A, metrics_B: row lists for model A and model B.
        topk: k values for the top-k metrics; k=1 is always included.
        eps: stabiliser added to the Jaccard / overlap denominators.
        mode: "aligned" pairs the two lists positionally (zip);
            "position_final" keeps, per prompt_id, only the row with the
            highest layer_index and pairs prompts present in both lists.
        batch_num: stored verbatim in every result row under "batch".
        lens_type, debug_checks: currently unused; kept so existing calls
            keep working.
        device: torch device; defaults to CUDA when available, else CPU.
        tokenizer: when given, top-1 predictions are also decoded to text.

    Returns:
        list[dict]: one entry per compared pair holding full-vocabulary
        divergences (KL both ways, Jensen-Shannon, total variation, entropy,
        inter-quartile range), ground-truth log-probs, cross-entropy /
        perplexity, and per-k top-k metrics (dicts keyed by k).

    Raises:
        ValueError: `mode` is not one of the two supported values.
        AssertionError: paired rows disagree on prompt_id (or on
            layer_index in "aligned" mode).
    """
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    topk = sorted(set([1] + list(topk)))
    max_k = max(topk)
    detailed = []

    def safe(x, default=None, eps=1e-9):
        """
        Sanitise a metric or log-prob tensor.

        NaN -> 0, +inf -> 30, -inf -> -30; then floating tensors are clamped
        to [-30, 30] when their mean is below 1.0 and to [0, 1e6] otherwise.
        `default` and `eps` are accepted but unused. Not meant for raw logits:
        the [0, 1e6] branch would erase every negative value.
        """
        if x is None:
            return torch.tensor(0.0)

        x = torch.nan_to_num(x, nan=0.0, posinf=30.0, neginf=-30.0)

        if torch.is_floating_point(x):
            mean_val = x.mean().item() if x.numel() > 0 else 0.0
            if mean_val < 1.0:  # looks like log-probs
                x = torch.clamp(x, min=-30.0, max=30.0)
            else:
                x = torch.clamp(x, min=0.0, max=1e6)

        return x

    def sanitize_logits(x):
        """Replace NaN with 0 and ±inf with the largest finite values.

        Finite logits are left untouched: log_softmax is already numerically
        stable, and clamping them would change the predicted distribution.
        """
        return torch.nan_to_num(x, nan=0.0)

    def pad_vocab_dim(logits, target_vocab):
        """Right-pad the vocab axis to `target_vocab` entries.

        Padding uses -inf so vocabulary slots a model does not have receive
        (after the log-prob floor) negligible probability instead of the
        mass a logit of 0 would get.
        """
        pad = target_vocab - logits.size(-1)
        if pad > 0:
            logits = F.pad(logits, (0, pad), value=float("-inf"))
        return logits

    def compute_final_perplexity(final_logits, labels, ignore_index=-100):
        """Return (perplexity, cross-entropy) of `final_logits` against `labels`.

        Perplexity is exp(mean cross-entropy) with the cross-entropy clamped
        to [0, 30] (a NaN/inf loss is treated as 30); the returned
        cross-entropy is the unclamped value.
        """
        if final_logits.ndim == 2:
            final_logits = final_logits.unsqueeze(0)
        if labels.ndim == 1:
            labels = labels.unsqueeze(0)

        # Align logits and labels if off by one (common for decoder models)
        if final_logits.size(1) == labels.size(1) + 1:
            final_logits = final_logits[:, :-1, :]
        elif final_logits.size(1) + 1 == labels.size(1):
            labels = labels[:, 1:]

        ce_loss = F.cross_entropy(
            final_logits.reshape(-1, final_logits.size(-1)),
            labels.reshape(-1),
            ignore_index=ignore_index,
            reduction="mean",
        )

        if torch.isnan(ce_loss) or torch.isinf(ce_loss):
            safe_ce = torch.tensor(30.0, device=ce_loss.device)
        else:
            safe_ce = ce_loss.clamp(min=0.0, max=30.0)

        return float(torch.exp(safe_ce)), float(ce_loss)

    # Pair rows depending on mode
    if mode == "aligned":
        # Layer-by-layer (#layers × #prompts): divergence over depth
        pairs = list(zip(metrics_A, metrics_B))
    elif mode == "position_final":
        # Final-layer only (#prompts): agreement of the final behaviour
        def last_by_prompt(rows):
            best = {}
            for r in rows:
                pid = r["prompt_id"]
                if pid not in best or r["layer_index"] > best[pid]["layer_index"]:
                    best[pid] = r
            return best

        last_A = last_by_prompt(metrics_A)
        last_B = last_by_prompt(metrics_B)
        common_pids = sorted(set(last_A) & set(last_B))
        pairs = [(last_A[p], last_B[p]) for p in common_pids]
    else:
        raise ValueError(f"Unknown mode: {mode}")

    # ============================================================
    # === MAIN LOOP ==============================================
    # ============================================================
    for row_A, row_B in pairs:
        assert row_A["prompt_id"] == row_B["prompt_id"]
        if mode == "aligned":
            assert row_A["layer_index"] == row_B["layer_index"]

        logits_A = sanitize_logits(row_A["logits"].float().to(device))
        logits_B = sanitize_logits(row_B["logits"].float().to(device))
        if logits_A.ndim == 2:
            logits_A = logits_A.unsqueeze(0)
        if logits_B.ndim == 2:
            logits_B = logits_B.unsqueeze(0)
        vocab_target = max(logits_A.size(-1), logits_B.size(-1))
        logits_A = pad_vocab_dim(logits_A, vocab_target)
        logits_B = pad_vocab_dim(logits_B, vocab_target)

        # --- Stable log-softmax; log-probs are floored at -30 by safe() ---
        logp_A_clean = safe(F.log_softmax(logits_A, dim=-1), -30.0)
        logp_B_clean = safe(F.log_softmax(logits_B, dim=-1), -30.0)
        probs_A_clean, probs_B_clean = logp_A_clean.exp(), logp_B_clean.exp()

        # ============================================================
        # === FULL-VOCAB DIVERGENCES =================================
        # ============================================================
        kl_ab_full = safe(torch.sum(probs_A_clean * (logp_A_clean - logp_B_clean), dim=-1))
        kl_ba_full = safe(torch.sum(probs_B_clean * (logp_B_clean - logp_A_clean), dim=-1))
        kl_ab_token_full, kl_ba_token_full = float(kl_ab_full.mean()), float(kl_ba_full.mean())

        try:
            jsd_dist_full = torch.tensor(
                jensenshannon(
                    probs_A_clean.detach().cpu().numpy(),
                    probs_B_clean.detach().cpu().numpy(),
                    axis=-1,
                ),
                device=device,
            )
        except Exception as err:
            # Best-effort fallback, but make the substitution visible: zeros
            # would otherwise read as "identical distributions".
            print(
                f"[warn] Jensen-Shannon distance failed ({err}); storing zeros for "
                f"prompt {row_A['prompt_id']}, layer {row_A['layer_index']}"
            )
            jsd_dist_full = torch.zeros_like(kl_ab_full)

        jsd_div_full = safe(jsd_dist_full.pow(2))
        vocab_tvd_full = safe(0.5 * torch.sum(torch.abs(probs_A_clean - probs_B_clean), dim=-1))
        heat_A_full = safe(-torch.sum(probs_A_clean * logp_A_clean, dim=-1))
        heat_B_full = safe(-torch.sum(probs_B_clean * logp_B_clean, dim=-1))

        # ============================================================
        # === GROUND-TRUTH LOG-PROBABILITIES (FULL VOCAB, UNMASKED) ===
        # ============================================================
        targets = row_A.get("target_ids")
        if targets is not None:
            # gather / cross_entropy / == all need the targets on the same
            # device as the logits, and int64 indices.
            targets = targets.to(device=device, dtype=torch.long)
            if targets.ndim == 1:
                targets = targets.unsqueeze(0)

            # Align sequence lengths (decoder shift handling)
            logp_A_gt, logp_B_gt = logp_A_clean, logp_B_clean
            if logp_A_gt.size(1) == targets.size(1) + 1:
                logp_A_gt = logp_A_gt[:, :-1, :]
                logp_B_gt = logp_B_gt[:, :-1, :]
            elif logp_A_gt.size(1) + 1 == targets.size(1):
                targets = targets[:, 1:]

            gather_idx = targets.unsqueeze(-1).clamp(max=logp_A_gt.size(-1) - 1)
            logp_gt_A_full = logp_A_gt.gather(-1, gather_idx).squeeze(-1)
            logp_gt_B_full = logp_B_gt.gather(-1, gather_idx).squeeze(-1)
            logp_gt_diff_full = logp_gt_A_full - logp_gt_B_full

            mean_logp_gt_A_full = float(logp_gt_A_full.mean().item())
            mean_logp_gt_B_full = float(logp_gt_B_full.mean().item())
            mean_logp_gt_diff_full = float(logp_gt_diff_full.mean().item())
        else:
            logp_gt_A_full = logp_gt_B_full = logp_gt_diff_full = torch.empty(0)
            mean_logp_gt_A_full = mean_logp_gt_B_full = mean_logp_gt_diff_full = float("nan")

        # ============================================================
        # === DISPERSION + QUANTILES =================================
        # ============================================================
        # torch.quantile sorts internally, so no explicit pre-sort is needed;
        # both quartiles come from a single call per model.
        quartiles = torch.tensor(
            [0.25, 0.75], dtype=probs_A_clean.dtype, device=probs_A_clean.device
        )
        q25_A, q75_A = torch.quantile(probs_A_clean, quartiles, dim=-1)
        q25_B, q75_B = torch.quantile(probs_B_clean, quartiles, dim=-1)
        iqr_A, iqr_B = q75_A - q25_A, q75_B - q25_B

        # ============================================================
        # === PERPLEXITY OF THIS ROW'S LOGITS ========================
        # ============================================================
        if targets is not None:
            # Upcast like the main path so stored half-precision logits do not
            # degrade (or break) the cross-entropy.
            final_ppl_A, ce_A = compute_final_perplexity(row_A["logits"].float().to(device), targets)
            final_ppl_B, ce_B = compute_final_perplexity(row_B["logits"].float().to(device), targets)

            # Signed difference (directional drift)
            final_ppl_diff = float(final_ppl_A - final_ppl_B)
        else:
            final_ppl_A = final_ppl_B = final_ppl_diff = float("nan")
            ce_A = ce_B = float("nan")

        # ============================================================
        # === TOP-K METRICS ==========================================
        # ============================================================
        top_vals_A, top_idx_A = torch.topk(probs_A_clean, max_k, dim=-1)
        top_vals_B, top_idx_B = torch.topk(probs_B_clean, max_k, dim=-1)

        acc_A_topk, acc_B_topk = {}, {}
        agree_topk, jaccard_topk = {}, {}
        disagree_topk = {}
        disagree_correct_topk, agree_correct_topk, agree_wrong_topk = {}, {}, {}
        full_jaccard_topk = {}
        prob_mass_A_topk, prob_mass_B_topk = {}, {}
        shared_mass_topk, tail_mass_A_topk, tail_mass_B_topk = {}, {}, {}
        prob_overlap_topk, prob_mass_overlap_topk = {}, {}

        for k in topk:
            tkA, tkB = top_idx_A[:, :, :k], top_idx_B[:, :, :k]
            tvA, tvB = top_vals_A[:, :, :k], top_vals_B[:, :, :k]

            # --- probability & overlap metrics ---
            prob_mass_A_topk[k], prob_mass_B_topk[k] = tvA.sum(dim=-1), tvB.sum(dim=-1)
            tail_mass_A_topk[k], tail_mass_B_topk[k] = 1 - prob_mass_A_topk[k], 1 - prob_mass_B_topk[k]

            # Per-position set intersection size of the two top-k lists.
            inter_mask = (tkA.unsqueeze(-1) == tkB.unsqueeze(-2))
            inter_counts = inter_mask.any(dim=-1).sum(dim=2).float()
            jaccard_topk[k] = inter_counts / (2 * k - inter_counts + eps)
            agree_topk[k] = (inter_counts == k).float()
            full_jaccard_topk[k] = full_jaccard_over_positions(tkA, tkB, eps, device)

            # --- shared probability mass ---
            # NOTE(review): the sum runs over all positions of the sequence and
            # is written to every position of that sequence — confirm intended.
            shared_mass = torch.zeros_like(prob_mass_A_topk[k])
            for i in range(tkA.size(0)):
                shared_tokens = set(tkA[i].reshape(-1).tolist()) & set(tkB[i].reshape(-1).tolist())
                if shared_tokens:
                    st = torch.tensor(list(shared_tokens), device=device)
                    shared_mass[i] = 0.5 * (
                        probs_A_clean[i, :, st].sum() + probs_B_clean[i, :, st].sum()
                    )
            shared_mass_topk[k] = shared_mass
            prob_mass_overlap_topk[k] = prob_mass_A_topk[k] * prob_mass_B_topk[k]
            prob_overlap_topk[k] = shared_mass / (
                0.5 * (prob_mass_A_topk[k] + prob_mass_B_topk[k]) + eps
            )

            # --- general set disagreement (always defined) ---
            disagree_topk[k] = 1.0 - jaccard_topk[k]

            # --- accuracy & correctness metrics ---
            if targets is not None:
                # Restrict predictions to the positions that have a target so
                # the decoder-shift case (one extra logit step) still compares.
                n_tgt = targets.size(1)
                tgt_col = targets.unsqueeze(-1)
                acc_A = (tkA[:, :n_tgt] == tgt_col).any(dim=-1).float()
                acc_B = (tkB[:, :n_tgt] == tgt_col).any(dim=-1).float()

                acc_A_topk[k], acc_B_topk[k] = acc_A, acc_B

                disagree_correct_topk[k] = (acc_A + acc_B == 1).float()        # exactly one correct
                agree_correct_topk[k] = (acc_A * acc_B).float()               # both correct
                agree_wrong_topk[k] = ((1 - acc_A) * (1 - acc_B)).float()     # both wrong
            else:
                acc_A_topk[k] = acc_B_topk[k] = None
                disagree_correct_topk[k] = None
                agree_correct_topk[k] = None
                agree_wrong_topk[k] = None

        # ============================================================
        # === PREDICTED TOKENS (optional for qualitative analysis) ===
        # ============================================================
        pred_top1_A = top_idx_A[:, :, 0].detach().cpu()
        pred_top1_B = top_idx_B[:, :, 0].detach().cpu()

        if tokenizer is not None:
            decoded_preds_A = [tokenizer.decode(seq, skip_special_tokens=False) for seq in pred_top1_A]
            decoded_preds_B = [tokenizer.decode(seq, skip_special_tokens=False) for seq in pred_top1_B]
        else:
            decoded_preds_A = decoded_preds_B = None

        # ============================================================
        # === DETAILED ROW ===========================================
        # ============================================================
        detailed_row = {
            "batch": batch_num,
            "prompt_id": row_A["prompt_id"],
            "prompt_text": row_A.get("prompt_text"),
            "layer_index": row_A["layer_index"],
            "layer_name": row_A["layer_name"],
            "vocab_size": vocab_target,
            "input_ids": row_A.get("input_ids"),
            "targets": row_A.get("target_ids"),
            "pred_top1_A": pred_top1_A,
            "pred_top1_B": pred_top1_B,
            "decoded_top1_A": decoded_preds_A,
            "decoded_top1_B": decoded_preds_B,
            "kl_ab_full": kl_ab_full.cpu(),
            "kl_ba_full": kl_ba_full.cpu(),
            "kl_ab_token_full": kl_ab_token_full,
            "kl_ba_token_full": kl_ba_token_full,
            "jsd_dist_full": jsd_dist_full.cpu(),
            "jsd_div_full": jsd_div_full.cpu(),
            "vocab_tvd_full": vocab_tvd_full.cpu(),
            "heat_A_full": heat_A_full.cpu(),
            "heat_B_full": heat_B_full.cpu(),
            "iqr_A": iqr_A.cpu(),
            "iqr_B": iqr_B.cpu(),
            "final_ce_A": ce_A,
            "final_ce_B": ce_B,
            "final_ppl_A": final_ppl_A,
            "final_ppl_B": final_ppl_B,
            "final_ppl_diff": final_ppl_diff,
            "acc_A_topk": acc_A_topk,
            "acc_B_topk": acc_B_topk,
            "agree_topk": agree_topk,
            "jaccard_topk": jaccard_topk,
            "disagree_topk": disagree_topk,
            "disagree_correct_topk": disagree_correct_topk,
            "agree_correct_topk": agree_correct_topk,
            "agree_wrong_topk": agree_wrong_topk,
            "full_jaccard_topk": full_jaccard_topk,
            "prob_mass_A_topk": prob_mass_A_topk,
            "prob_mass_B_topk": prob_mass_B_topk,
            "tail_mass_A_topk": tail_mass_A_topk,
            "tail_mass_B_topk": tail_mass_B_topk,
            "shared_mass_topk": shared_mass_topk,
            "prob_overlap_topk": prob_overlap_topk,
            "prob_mass_overlap_topk": prob_mass_overlap_topk,
            "logp_gt_A_full": logp_gt_A_full.cpu(),
            "logp_gt_B_full": logp_gt_B_full.cpu(),
            "logp_gt_diff_full": logp_gt_diff_full.cpu(),
            "mean_logp_gt_A_full": mean_logp_gt_A_full,
            "mean_logp_gt_B_full": mean_logp_gt_B_full,
            "mean_logp_gt_diff_full": mean_logp_gt_diff_full,
        }
        detailed.append(detailed_row)

    print(f"[ok] Compared {len(detailed)} pairs (vocab padded safely).")

    return detailed

In [None]:
# Scratch notes kept as an inert string literal: nothing in it is executed.
# It sketches two ways of selecting positions from a saved record before
# scoring: (1) keep positions where attention_mask[:-1] is set, and
# (2) additionally drop positions whose input id is the tokenizer's BOS or EOS.
"""
valid_positions = record["attention_mask"][:-1].bool()
logits = record["logits"][valid_positions]
targets = record["target_ids"][valid_positions]

bos_id, eos_id = tokenizer.bos_token_id, tokenizer.eos_token_id
mask = (record["attention_mask"][:-1].bool()) '\'
        & (record["input_ids"][:-1] != bos_id) '\'
        & (record["input_ids"][:-1] != eos_id)

"""

# ==============================================
# Hidden Acts Similarity & PPL =================
# ==============================================

In [None]:
import os
import glob
import math
import torch
import pandas as pd
import torch.nn.functional as F
from tqdm import tqdm

def load_hidden_batches(folder, pattern="*.pt"):
    """
    Concatenate every saved batch file in `folder` into one list of rows.

    Files matching `pattern` are loaded onto the CPU in sorted path order and
    their contents appended in that order (presumably files written by
    collect_hidden_states_full() — confirm against the producer).
    """
    batch_paths = glob.glob(os.path.join(folder, pattern))
    batch_paths.sort()

    collected = []
    for path in batch_paths:
        print(f"[load] {path}")
        collected += torch.load(path, map_location="cpu")

    print(f"[ok] Loaded {len(collected)} rows from {len(batch_paths)} files.")
    return collected


def analyze_hidden_similarity_batched(rows_A, rows_B, compute_perplexity=True, pad_token_id=0):
    """
    Compare hidden activations from two models (A vs B).

    Rows are matched on (prompt_id, layer_name). A pair is skipped when either
    side lacks a "hidden" tensor or the two tensors differ in shape.

    Computes per pair:
      - L2 distance and cosine similarity along the last axis (per position)
        and their means over the whole tensor
      - optionally, perplexity of each model's "logits" against row A's
        "target_ids", plus the signed log-perplexity difference (B - A)

    Args:
        rows_A, rows_B: lists of row dicts for model A and model B.
        compute_perplexity: compute the perplexity columns when both rows
            carry "logits" and row A carries "target_ids".
        pad_token_id: target id excluded from the perplexity average. The
            default (0) keeps the previous behaviour; pass the tokenizer's
            real pad id so that padding is actually masked out.

    Returns:
        df_detailed : DataFrame with one row per matched (prompt, layer) pair.

    Raises:
        ValueError: no pair could be matched.
    """

    def sequence_perplexity(logits, target):
        """exp(mean negative log-likelihood) of `target` over non-pad positions."""
        log_probs = torch.nan_to_num(
            F.log_softmax(logits, dim=-1), nan=0.0, posinf=30.0, neginf=-30.0
        ).clamp(min=-30.0, max=30.0)
        nll = -log_probs.gather(-1, target.unsqueeze(-1)).squeeze(-1)
        keep = (target != pad_token_id).float()
        # If every position is padding this is 0/0 and the result is NaN.
        mean_nll = (nll * keep).sum() / keep.sum()
        return math.exp(mean_nll.item())

    detailed = []
    index_B = {(r["prompt_id"], r["layer_name"]): r for r in rows_B if "hidden" in r}

    for rA in tqdm(rows_A, desc="Comparing hidden states"):
        key = (rA["prompt_id"], rA["layer_name"])
        rB = index_B.get(key)
        if not rB or "hidden" not in rA or "hidden" not in rB:
            continue

        hA, hB = rA["hidden"], rB["hidden"]
        if not isinstance(hA, torch.Tensor) or not isinstance(hB, torch.Tensor):
            continue
        if hA.shape != hB.shape:
            continue

        # === Hidden-state metrics ===
        l2_pos = torch.norm(hA - hB, dim=-1)
        cos_pos = F.cosine_similarity(hA, hB, dim=-1)
        l2_seq = l2_pos.mean().item()
        cos_seq = cos_pos.mean().item()

        # === Optional perplexity ===
        pplA = pplB = ppl_diff_signed = None
        if compute_perplexity and "logits" in rA and "logits" in rB:
            if "target_ids" in rA:
                # Both models are scored against row A's targets.
                target = rA["target_ids"]
                if rA["logits"] is not None:
                    pplA = sequence_perplexity(rA["logits"], target)
                if rB["logits"] is not None:
                    pplB = sequence_perplexity(rB["logits"], target)

            # signed PPL difference (log-scale for stability)
            if pplA is not None and pplB is not None:
                ppl_diff_signed = math.log(pplB + 1e-9) - math.log(pplA + 1e-9)

        detailed.append({
            "prompt_id": rA["prompt_id"],
            "batch_index": rA.get("batch_index"),
            "layer_index": rA["layer_index"],
            "layer_name": rA["layer_name"],
            "mean_l2_seq": l2_seq,
            "mean_cos_seq": cos_seq,
            "l2_pos": l2_pos.cpu(),
            "cos_pos": cos_pos.cpu(),
            "ppl_A": pplA,
            "ppl_B": pplB,
            "ppl_diff_signed": ppl_diff_signed,
        })

    if not detailed:
        raise ValueError("No valid pairs found. Check layer alignment and hidden data.")

    df_detailed = pd.DataFrame(detailed)
    print(f"[ok] Computed similarity metrics for {df_detailed['layer_name'].nunique()} layers.")
    return df_detailed


# ==============================================
# Head Semantics ===============================
# ==============================================

In [None]:
import torch
import gc
from tqdm import tqdm
from pathlib import Path

# orig→quant: orig_quant
# quant→orig quant_orig
@torch.no_grad()
def project_hidden_through_other_head(
    projector_model,
    hidden_file,
    batch_num=0,
    save_prefix="saved_data/projections",
    direction_tag="orig_quant",
    device="cpu",
):
    """
    Project pre-saved hidden activations through ANOTHER model's LM head.

    Loads the row dicts stored in `hidden_file`, applies
    `projector_model.lm_head` to each row's "hidden" tensor and saves the
    resulting logits (with the row's metadata) to
    `{save_prefix}/proj_{direction_tag}_batch{NNN}.pt`.

    Args:
        projector_model: model whose `lm_head` performs the projection.
        hidden_file: path to a torch-saved list of row dicts; rows without a
            "hidden" entry are skipped.
        batch_num: batch identifier; integer-like values are zero-padded to
            three digits in the file name, other strings are used verbatim.
        save_prefix: output directory (created if missing).
        direction_tag: label embedded in the output file name and log lines.
        device: device the projector model is moved to when that is possible.

    Returns:
        str: path of the saved projection file.
    """

    # --- Ensure model + device setup ---
    try:
        projector_model = projector_model.to(device)
    except ValueError as err:
        # transformers raises ValueError for `.to()` on some
        # bitsandbytes-quantized models; use such a model where it already is.
        print(f"[info] Projector model stays on its current device ({err})")
    projector_model = projector_model.eval()

    # Feed activations to wherever the LM head actually lives, so inputs and
    # weights share a device even when the move above was refused.
    head_param = next(projector_model.lm_head.parameters(), None)
    head_device = head_param.device if head_param is not None else torch.device(device)

    Path(save_prefix).mkdir(parents=True, exist_ok=True)

    # --- Handle batch number formatting ---
    if isinstance(batch_num, str) and not batch_num.isdigit():
        batch_str = batch_num
    else:
        try:
            batch_str = f"{int(batch_num):03d}"
        except Exception:
            batch_str = str(batch_num)

    # --- Define save path ---
    save_path = f"{save_prefix}/proj_{direction_tag}_batch{batch_str}.pt"
    print(f"[run] Projecting hidden activations {direction_tag} → {save_path}")

    # --- Load pre-saved activations ---
    # NOTE: weights_only=False unpickles arbitrary objects; only use on files
    # produced by this project.
    rows = torch.load(hidden_file, map_location="cpu", weights_only=False)
    results = []

    # --- Main projection loop ---
    for r in tqdm(rows, desc=f"[batch {batch_str}] {direction_tag} projection"):
        if "hidden" not in r:
            continue

        h = r["hidden"].to(head_device)
        logits_cross = projector_model.lm_head(h).cpu()

        # Preserve all meta info
        results.append({
            "prompt_id": r.get("prompt_id"),
            "prompt_text": r.get("prompt_text"),
            "batch_index": r.get("batch_index", batch_num),
            "layer_index": r.get("layer_index"),
            "layer_name": r.get("layer_name"),
            "input_ids": r.get("input_ids"),
            "target_ids": r.get("target_ids"),
            "logits_crossproj": logits_cross,
        })

        # Only the device copy of the activations can be released here; the
        # logits stay referenced by `results` until they are saved.
        del h, logits_cross

    # --- Save the result ---
    torch.save(results, save_path)

    # One cleanup pass after the loop instead of a full GC sweep per row.
    torch.cuda.empty_cache()
    gc.collect()

    print(f"[saved] → {save_path} ({len(results)} entries, batch_index={batch_num})")
    return save_path

# Batch to project in both directions.
batch_num = "009"

# (1) orig→quant: original model's hidden states through the quantized model's LM head
project_hidden_through_other_head(
    direction_tag="orig_quant",
    projector_model=model_quant,
    hidden_file=f"saved_data/hidden_acts/m_orig_batch{batch_num}.pt",
    save_prefix="saved_data/projections",
    batch_num=batch_num,
    device="cpu",
)

# (2) quant→orig: quantized model's hidden states through the original model's LM head
project_hidden_through_other_head(
    direction_tag="quant_orig",
    projector_model=model_orig,
    hidden_file=f"saved_data/hidden_acts/m_quant_batch{batch_num}.pt",
    save_prefix="saved_data/projections",
    batch_num=batch_num,
    device="cpu",
)

In [None]:
import os
import gc
import torch
import torch.nn.functional as F
from tqdm import tqdm
from pathlib import Path

@torch.no_grad()
def run_projection_comparison_single_batch(
    file_A,
    file_B,
    tokenizer,
    batch_num=0,
    save_prefix="saved_data/head_semantics/semantics_",
    topk_levels=(1, 5, 10, 20),
    device="cpu",
):
    """
    Compare *already projected* logits stored in two files (e.g., orig→quant vs quant→orig).

    Rows of ``file_A`` and ``file_B`` are paired on ``(prompt_id, layer_name)``.
    For every pair where both rows carry ``"logits_crossproj"`` with identical
    shapes, computes:
      • cosine similarity & L2 distance between logits (per-position + sequence mean)
      • Gini coefficient of the softmax distribution (per-position + sequence mean)
      • top-k Jaccard overlaps (per-position + sequence mean) for each k in ``topk_levels``
      • decoded top-k tokens of the position-averaged logits, for interpretability

    Args:
        file_A, file_B: paths to ``torch.save``-d lists of row dicts. Each row must
            provide ``"prompt_id"`` and ``"layer_name"``; rows of ``file_A`` must
            also provide ``"layer_index"``. Logits are indexed as ``[position, vocab]``.
        tokenizer: object exposing ``decode(list_of_ids)``; only used to turn
            top-k vocabulary ids into readable strings.
        batch_num: int or numeric string. Formats the output filename and is the
            fallback ``"batch_index"`` for rows that do not carry one.
        save_prefix: output path prefix; its parent directory is created if missing.
        topk_levels: iterable of k values for the Jaccard / decoded-token metrics.
            Each k is clamped to the vocabulary size.
        device: device the files are mapped to and the metrics are computed on.

    Returns:
        list[dict]: one entry per compared pair. The same list is saved to
        ``f"{save_prefix}batch{int(batch_num):03d}.pt"``.

    Pairs are skipped when the partner row is missing, when either side lacks
    ``"logits_crossproj"``, or when the two logits tensors differ in shape. The
    shape check matters because cosine similarity and subtraction would
    otherwise broadcast silently, or the per-position loop would index past the
    shorter tensor and abort the batch before anything is saved.
    """

    Path(save_prefix).parent.mkdir(parents=True, exist_ok=True)
    save_path = f"{save_prefix}batch{int(batch_num):03d}.pt"

    # SECURITY: weights_only=False unpickles arbitrary Python objects. Only load
    # files produced by this notebook / a trusted source.
    rows_A = torch.load(file_A, map_location=device, weights_only=False)
    rows_B = torch.load(file_B, map_location=device, weights_only=False)

    # --- Helper for Gini coefficient ---
    def gini_coefficient(x):
        """Gini coefficient of softmax(x) along the last dimension."""
        probs = torch.softmax(x, dim=-1)
        n = probs.shape[-1]
        sorted_p, _ = torch.sort(probs, dim=-1)
        cum_p = torch.cumsum(sorted_p, dim=-1)
        # Closed form on the ascending cumulative sums: 0 for a uniform
        # distribution, 1 - 1/n for a one-hot distribution.
        return 1 - 2 * cum_p.sum(dim=-1) / (n * cum_p[..., -1]) + 1 / n

    # --- Helper for top-k index sets ---
    def topk_index_sets(logits, k):
        """One set of top-k vocabulary indices per position (row) of ``logits``."""
        return [set(torch.topk(row, k).indices.tolist()) for row in logits]

    # Later duplicates of a (prompt_id, layer_name) key in file_B win.
    index_B = {(r["prompt_id"], r["layer_name"]): r for r in rows_B}
    results = []
    skipped_shape_mismatch = 0

    for rA in tqdm(rows_A, desc=f"[batch {batch_num}] Comparing projected logits"):
        key = (rA["prompt_id"], rA["layer_name"])
        rB = index_B.get(key)
        if not rB or "logits_crossproj" not in rA or "logits_crossproj" not in rB:
            continue

        logits_A = rA["logits_crossproj"].to(device)   # [T, V]
        logits_B = rB["logits_crossproj"].to(device)   # [T, V]

        # Refuse to compare tensors that only line up through broadcasting.
        if logits_A.shape != logits_B.shape:
            skipped_shape_mismatch += 1
            del logits_A, logits_B
            continue

        vocab_dim = logits_A.shape[-1]

        # --- Per-position metrics ---
        cos_pos = F.cosine_similarity(logits_A, logits_B, dim=-1)           # [T]
        l2_pos = torch.norm(logits_A - logits_B, dim=-1)                    # [T]
        gini_A_pos = gini_coefficient(logits_A)                             # [T]
        gini_B_pos = gini_coefficient(logits_B)                             # [T]
        delta_gini_pos = gini_B_pos - gini_A_pos

        # --- Sequence-level summaries ---
        cos_seq = cos_pos.mean().item()
        l2_seq = l2_pos.mean().item()
        gini_A_seq = gini_A_pos.mean().item()
        gini_B_seq = gini_B_pos.mean().item()
        delta_gini_seq = gini_B_seq - gini_A_seq

        # --- Top-k Jaccard overlaps ---
        jaccard_pos_dict = {}
        jaccard_seq_dict = {}
        decoded_shift = {}

        # Position-averaged logits do not depend on k: compute them once.
        mean_A = logits_A.mean(dim=0)
        mean_B = logits_B.mean(dim=0)

        for k in topk_levels:
            k_safe = min(k, vocab_dim)

            jaccard_pos = []
            for set_a, set_b in zip(
                topk_index_sets(logits_A, k_safe),
                topk_index_sets(logits_B, k_safe),
            ):
                union = len(set_a | set_b)
                jaccard_pos.append(len(set_a & set_b) / union if union > 0 else 0.0)

            jaccard_pos_tensor = torch.tensor(jaccard_pos)
            jaccard_pos_dict[f"jaccard_top{k}_pos"] = jaccard_pos_tensor.cpu()
            jaccard_seq_dict[f"jaccard_top{k}_seq"] = jaccard_pos_tensor.mean().item()

            # Decode top-k tokens of the mean logits (for readability)
            topA_k = torch.topk(mean_A, k_safe).indices.tolist()
            topB_k = torch.topk(mean_B, k_safe).indices.tolist()
            decoded_shift[f"top{k}_A"] = [tokenizer.decode([i]) for i in topA_k]
            decoded_shift[f"top{k}_B"] = [tokenizer.decode([i]) for i in topB_k]

        # --- Aggregate everything ---
        result_entry = {
            "prompt_id": rA["prompt_id"],
            "batch_index": rA.get("batch_index", batch_num),
            "layer_index": rA["layer_index"],
            "layer_name": rA["layer_name"],
            "input_ids": rA.get("input_ids"),
            "target_ids": rA.get("target_ids"),

            # --- sequence-level summary ---
            "cosine_seq": cos_seq,
            "l2_seq": l2_seq,
            "gini_A_seq": gini_A_seq,
            "gini_B_seq": gini_B_seq,
            "delta_gini_seq": delta_gini_seq,
            **jaccard_seq_dict,

            # --- detailed per-position tensors ---
            "cosine_pos": cos_pos.cpu(),
            "l2_pos": l2_pos.cpu(),
            "gini_A_pos": gini_A_pos.cpu(),
            "gini_B_pos": gini_B_pos.cpu(),
            "delta_gini_pos": delta_gini_pos.cpu(),
            **jaccard_pos_dict,

            "decoded_shift": decoded_shift,
        }

        results.append(result_entry)

        # Release the (potentially large) logits before loading the next pair.
        del logits_A, logits_B, mean_A, mean_B
        torch.cuda.empty_cache()
        gc.collect()

    torch.save(results, save_path)
    if skipped_shape_mismatch:
        print(
            f"[warn] skipped {skipped_shape_mismatch} pair(s) whose logits shapes "
            f"differ between file_A and file_B"
        )
    print(f"[saved] → {save_path} ({len(results)} entries)")
    return results

# Tokenizer loaded from the original model's directory; the comparison function
# only uses it to decode top-k vocabulary ids into readable tokens.
from transformers import AutoTokenizer
orig_tokenizer = AutoTokenizer.from_pretrained("Models/LLaMA3Instruct")

# NOTE(review): this cell reads projections for batch "002", while the projection
# calls shown earlier in this notebook write batch "009". The "002" projection
# files must already exist on disk — confirm the intended batch.
batch_num = "002"

# file_A: projections tagged "orig_quant"; file_B: projections tagged "quant_orig".
# Results are saved under the explicit save_prefix below (not the function default)
# and, as the last expression of the cell, also returned for display.
run_projection_comparison_single_batch(
    file_A=f"saved_data/projections/proj_orig_quant_batch{batch_num}.pt",
    file_B=f"saved_data/projections/proj_quant_orig_batch{batch_num}.pt",
    tokenizer=orig_tokenizer,
    batch_num=batch_num,
    save_prefix="saved_data/projection_comparisons/semantics_",
)