# ==============================================
# Notebook for Logit Lens and Hidden Acts ======
# ==============================================

In [1]:
from pathlib import Path
from datasets import load_dataset, DownloadMode
import torch
import os
import glob
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from old_lens.mi_utils.quant_configs.bnb_configs import load_bnb_in_4bit

# ==============================================
# Models & Dataset =============================
# ==============================================

In [2]:
filepath = r'D:\LogitLensData\nq'

destination_path = str(Path(filepath))
nq_dataset = load_dataset(
    'sentence-transformers/natural-questions',
    split={
        'train': 'train[:1000]'
    },
    cache_dir=destination_path,
    download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS,
    keep_in_memory=True
)

nq_queries = nq_dataset['train']['query']
nq_answers = nq_dataset['train']['answer']

nq_queries_200 = nq_queries[:200]
"""nq_queries_400 = nq_queries[200:400]
nq_queries_600 = nq_queries[400:600]
nq_queries_800 = nq_queries[600:800]
nq_queries_1000 = nq_queries[800:1000]"""
nq_queries_200_1 = nq_queries[:1]
nq_1000 = nq_queries[:1000]

In [3]:
from enum import Enum

class Models(Enum):
    LAIN8B = "Models/LLaMA3Instruct"
    HF100B = "Models/HF1BitLLM100Btokens"


class Names(Enum):
    LAIN8B = "Meta-Llama-3-8B-Instruct-fp"
    HF100B = "Llama3-8B-1.58-100B-tokens"

In [4]:
def load_model_and_tok(
    model_name: str,
    low_cpu_mem_usage: bool = True,
    local_files_only: bool = True,
    device_map: str = "cpu",
    dtype: torch.dtype = torch.float32,
    load_in_8bit: bool = False
):
    tok = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        output_hidden_states=True,
        return_dict_in_generate=True,
        return_dict=True,
        output_attentions=True,
        low_cpu_mem_usage=low_cpu_mem_usage,
        local_files_only=local_files_only,
        device_map=device_map,
        load_in_8bit=load_in_8bit,
        torch_dtype=dtype,
        attn_implementation="eager",
    )
    return model, tok

In [None]:
model_orig, orig_tokenizer = load_model_and_tok(Models.LAIN8B.value) 

In [5]:
model_8bit, orig_tokenizer = load_model_and_tok(Models.LAIN8B.value, dtype=torch.float16, load_in_8bit=True) 

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

In [None]:
model_4bit, orig_tokenizer = load_bnb_in_4bit(Models.LAIN8B.value, double_quant=False, dtype=torch.float16, device_map="cpu") 

In [None]:
model_quant, quant_tokenizer = load_model_and_tok(Models.HF100B.value) 

In [None]:
lengths = []
for q in nq_1000:
    ids = orig_tokenizer.encode(q, add_special_tokens=True)
    lengths.append(len(ids))

print("=== Token Length Statistics (LLaMA-3-8B-Instruct tokenizer) ===")
print(f"Samples analyzed: {len(lengths)}")
print(f"Mean length:       {np.mean(lengths):.2f}")
print(f"Median length:     {np.median(lengths):.0f}")
print(f"90th percentile:   {np.percentile(lengths, 90):.0f}")
print(f"95th percentile:   {np.percentile(lengths, 95):.0f}")
print(f"Max observed len:  {np.max(lengths)}")
print(f"Min observed len:  {np.min(lengths)}")

In [None]:
import inspect
sig = inspect.signature(model_orig.model.layers[0].forward)
print(sig)


In [1]:
import numpy as np
import torch

def load_npz_as_metrics(npz_path, norm_modes=("raw", "unit_rms", "norm_rms")):
    """
    Load a .npz file saved from the logit lens analysis
    and reconstruct the 'metrics' list expected by preprocess_metrics().
    """
    data = np.load(npz_path, allow_pickle=True)
    metrics = []

    num_rows = len(data["prompt_id"])
    for i in range(num_rows):
        row = {
            "prompt_id": int(data["prompt_id"][i]),
            "layer_index": int(data["layer_index"][i]),
            "layer_name": str(data["layer_name"][i]),
            "batch_index": int(data["batch_index"][i]),
            "vocab_size": int(data["vocab_size"][i]),
        }

        # Optional: if you saved prompt_texts, etc.
        if "prompt_text" in data:
            row["prompt_text"] = str(data["prompt_text"][i])

        # target_ids & attention_mask are stored per row
        if "target_ids" in data:
            row["target_ids"] = torch.from_numpy(data["target_ids"][i])
        if "attention_mask" in data:
            row["attention_mask"] = torch.from_numpy(data["attention_mask"][i])

        # Attach all normalization modes
        for mode in norm_modes:
            if f"logits_{mode}" in data:
                row[f"logits_{mode}"] = torch.from_numpy(data[f"logits_{mode}"][i])
            if f"hidden_{mode}" in data:
                row[f"hidden_{mode}"] = torch.from_numpy(data[f"hidden_{mode}"][i])

        metrics.append(row)

    print(f"[loaded] {num_rows} records from {npz_path}")
    return metrics


In [25]:
import os, numpy as np
data_path = "saved_data/lens_data/m_8bit/m_8bit_modes_batch000"
if not data_path.endswith(".npz"):
    data_path += ".npz"
data = np.load(data_path, allow_pickle=True)


In [27]:
data

NpzFile 'saved_data/lens_data/m_8bit/m_8bit_modes_batch000.npz' with keys: hidden_raw, logits_raw, hidden_unit_rms, logits_unit_rms, hidden_norm_rms...

# ==============================================
# Logit Lens with Normaliztion =================
# ==============================================

In [None]:
import torch
import torch.nn.functional as F
from pathlib import Path
from tqdm import tqdm
import numpy as np
import gc
import warnings
warnings.filterwarnings("ignore", message="MatMul8bitLt: inputs will be cast")


# ============================================================
# Normalization modes
# ============================================================
def apply_normalization(x, model, normalize_mode="raw", block=None, layer_index=None):
    x = x.to(torch.float32) 
    
    if normalize_mode == "raw":
        return x

    elif normalize_mode == "unit_rms":

        eps = 1e-5
        return x / (x.pow(2).mean(dim=-1, keepdim=True).add(eps).sqrt())

    elif normalize_mode == "rms_layer":
        if layer_index == -1:
            return x
        elif block is not None:
            return block.post_attention_layernorm(x)
        else:
            #return model.model.norm(x)
            return model.model.norm(x).to(torch.float32)

    elif normalize_mode == "norm_rms":
        if layer_index == -1:
            return x
        else:
            #return model.model.norm(x)
            return model.model.norm(x).to(torch.float32)

    else:
        raise ValueError(f"Unknown normalization_mode: {normalize_mode}")


# ============================================================
# Helper: causal + padding mask for LLaMA blocks
# ============================================================
def build_full_attention_mask(input_ids, attention_mask, device, model=None):
    bsz, seq_len = input_ids.shape

    causal = torch.triu(torch.ones((seq_len, seq_len), device=device, dtype=torch.bool), diagonal=1)
    causal = causal.unsqueeze(0).unsqueeze(0)

    if attention_mask is None:
        padding_mask = torch.zeros((bsz, 1, 1, seq_len), device=device, dtype=torch.bool)
    else:
        padding_mask = (attention_mask[:, None, None, :] == 0)

    full = causal | padding_mask

    attn_impl = getattr(model.config, "attn_implementation", "eager") if model else "eager"

    if attn_impl in ["flash_attention_2", "sdpa"]:
        return full
    else:
        return full.to(torch.float32) * -1e9


# ============================================================
# Collector with multiple normalization variants + attention weights (optional storage)
# ============================================================
@torch.no_grad()
def collect_logit_lens_full(
    model,
    tokenizer,
    prompts,
    batch_index=0,
    max_len=17,
    device=None,
    clamp_logits=False,         
    clamp_value=100.0,         
    save_path=None,
    collect_attn=True,
    save_attn=True,
    norm_modes=("raw", "unit_rms", "norm_rms"),
):

    device = device or ("cuda" if torch.cuda.is_available() else "cpu")

    bnb_layer_types = ("Linear4bit", "Linear8bitLt")
    is_quantized = any(any(name in type(m).__name__ for name in bnb_layer_types)
                       for m in model.modules())

    if is_quantized:
        try:
            first_param_device = next(model.parameters()).device
        except StopIteration:
            first_param_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"[info] Detected quantized model → device {first_param_device}")
        model.eval()
    else:
        model = model.to(device).eval()

    Path(save_path).parent.mkdir(parents=True, exist_ok=True)

    # ============================================================
    # Tokenization
    # ============================================================
    encoded = []
    pad_id = tokenizer.pad_token_id or tokenizer.eos_token_id
    bos_id = tokenizer.bos_token_id
    eos_id = tokenizer.eos_token_id

    for p in prompts:
        ids = tokenizer.encode(p, add_special_tokens=False)
        content = ids[: max_len - 2]
        ids = torch.tensor([bos_id] + content + [eos_id], dtype=torch.long)
        if len(ids) < max_len:
            ids = F.pad(ids, (0, max_len - len(ids)), value=pad_id)
        encoded.append(ids)

    input_ids = torch.stack(encoded, dim=0).to(device)

    # ============================================================
    # Build attention mask (stop after first EOS)
    # ============================================================
    attention_mask = torch.ones_like(input_ids, dtype=torch.long)
    for i, ids in enumerate(input_ids):
        eos_positions = (ids == eos_id).nonzero(as_tuple=True)[0]
        if len(eos_positions) > 0:
            eos_pos = eos_positions[0].item()
            attention_mask[i, eos_pos + 1:] = 0

    batch_size, seq_len = input_ids.shape
    full_mask = build_full_attention_mask(input_ids, attention_mask, device)
    position_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
    vocab_size = model.lm_head.out_features

    print(f"[info] Tokenized {batch_size} prompts | seq_len={seq_len}")
    print(f"[info] Collecting from {len(model.model.layers)} layers | quantized={is_quantized}")

    # ============================================================
    # Projection helper (with consistent upcasting and sanitization)
    # ============================================================
    def project(x):
        x_fp32 = x.to(torch.float32)

        head_dtype = next(model.lm_head.parameters()).dtype
        x_cast = x_fp32.to(head_dtype)

        logits = model.lm_head(x_cast)

        logits = torch.nan_to_num(logits, nan=0.0, posinf=0.0, neginf=0.0)

        if clamp_logits and is_quantized:
            logits = logits.clamp(-clamp_value, clamp_value)

        return logits.to(torch.float32)

    # ============================================================
    # Collection containers
    # ============================================================
    rows = []
    all_hidden, all_logits, all_attn = {}, {}, {}

    # ============================================================
    # Embedding layer
    # ============================================================
    x = model.model.embed_tokens(input_ids).to(torch.float32)
    hidden_variants = {mode: apply_normalization(x.clone(), model, mode, layer_index=-1)
                       for mode in norm_modes}
    logits_variants = {mode: project(hidden_variants[mode]) for mode in norm_modes}

    for mode in norm_modes:
        all_hidden[f"embed_tokens_{mode}"] = hidden_variants[mode].cpu()
        all_logits[f"embed_tokens_{mode}"] = logits_variants[mode].cpu()

    for i in range(batch_size):
        rows.append({
            "prompt_id": i,
            "prompt_text": prompts[i],
            "batch_index": batch_index,
            "vocab_size": vocab_size,
            "layer_index": -1,
            "layer_name": "embed_tokens",
            "input_ids": input_ids[i].cpu(),
            "target_ids": input_ids[i, 1:].cpu(),
            "attention_mask": attention_mask[i].cpu(),
            **{f"hidden_{m}": hidden_variants[m][i, :-1].cpu() for m in norm_modes},
            **{f"logits_{m}": logits_variants[m][i, :-1].cpu() for m in norm_modes},
        })

    # ============================================================
    # Transformer layers
    # ============================================================
    for li, block in enumerate(model.model.layers):
        out = block(
            x, position_ids=position_ids,
            attention_mask=full_mask,
            output_attentions=collect_attn,
        )
        x = out[0]
        attn = out[1] if collect_attn else None

        layer_output = x.detach().clone().to(torch.float32)
        hidden_variants = {
            mode: apply_normalization(layer_output.clone(), model, mode, block=block, layer_index=li)
            for mode in norm_modes
        }
        logits_variants = {mode: project(hidden_variants[mode]) for mode in norm_modes}

        for mode in norm_modes:
            hidden_variants[mode] = hidden_variants[mode][:, :-1, :]
            logits_variants[mode] = logits_variants[mode][:, :-1, :]

        for mode in norm_modes:
            all_hidden[f"layer.{li}_{mode}"] = hidden_variants[mode].cpu()
            all_logits[f"layer.{li}_{mode}"] = logits_variants[mode].cpu()

        for i in range(batch_size):
            record = {
                "prompt_id": i,
                "prompt_text": prompts[i],
                "batch_index": batch_index,
                "vocab_size": vocab_size,
                "layer_index": li,
                "layer_name": f"layer.{li}",
                "input_ids": input_ids[i].cpu(),
                "target_ids": input_ids[i, 1:].cpu(),
                "attention_mask": attention_mask[i].cpu(),
                **{f"hidden_{m}": hidden_variants[m][i].cpu() for m in norm_modes},
                **{f"logits_{m}": logits_variants[m][i].cpu() for m in norm_modes},
            }
            if save_attn and attn is not None:
                record["attn"] = attn[i].cpu()
            rows.append(record)

        torch.cuda.empty_cache()
        gc.collect()

    # ============================================================
    # Final RMSNorm output
    # ============================================================
    x = model.model.norm(x.to(torch.float32))
    h = x
    l = project(h)
    h, l = h[:, :-1, :], l[:, :-1, :]

    all_hidden["output_true"] = h.cpu()
    all_logits["output_true"] = l.cpu()

    for mode in norm_modes:
        all_hidden[f"output_{mode}"] = h.cpu()
        all_logits[f"output_{mode}"] = l.cpu()

    for i in range(batch_size):
        rows.append({
            "prompt_id": i,
            "prompt_text": prompts[i],
            "batch_index": batch_index,
            "vocab_size": vocab_size,
            "layer_index": len(model.model.layers),
            "layer_name": "output",
            "input_ids": input_ids[i].cpu(),
            "target_ids": input_ids[i, 1:].cpu(),
            "attention_mask": attention_mask[i].cpu(),
            **{f"hidden_{m}": h[i].cpu() for m in norm_modes},
            **{f"logits_{m}": l[i].cpu() for m in norm_modes},
        })

    # ============================================================
    # Save and finish
    # ============================================================
    if save_path:
        base = os.path.splitext(save_path)[0]
        np_data = {}

        for mode in norm_modes:
            np_data[f"hidden_{mode}"] = torch.stack([r[f"hidden_{mode}"] for r in rows]).cpu().numpy()
            np_data[f"logits_{mode}"] = torch.stack([r[f"logits_{mode}"] for r in rows]).cpu().numpy()

        np_data["prompt_id"]   = np.array([r["prompt_id"] for r in rows])
        np_data["batch_index"] = np.array([r["batch_index"] for r in rows])
        np_data["layer_index"] = np.array([r["layer_index"] for r in rows])
        np_data["layer_name"]  = np.array([r["layer_name"] for r in rows], dtype=object)
        np_data["vocab_size"]  = np.array([r["vocab_size"] for r in rows])

        out_path = f"{base}.npz"
        np.savez_compressed(out_path, **np_data)
        print(f"[saved] {out_path} ({os.path.getsize(out_path)/(1024*1024):.2f} MB)")

        try:
            test = np.load(out_path, allow_pickle=True)
            print(f"[check] loaded OK: {list(test.keys())[:5]} ...")
            test.close()
        except Exception as e:
            print(f"[error] Load check failed: {e}")


    print(f"[ok] Collected {len(rows)} entries from {len(model.model.layers)} layers.")
    return rows, all_hidden, all_logits, all_attn

In [None]:
import torch
import gc
from tqdm import tqdm

def run_logit_lens_in_batches(
    model,
    tokenizer,
    all_prompts,
    batch_size=20,
    save_prefix="logitlens_batch",
    max_len=17,
    normalize_mode=("raw", "unit_rms", "norm_rms"),
    device=None,
    clamp_logits=False,
    collect_attn=False,
    save_attn=False 
):

    num_batches = (len(all_prompts) + batch_size - 1) // batch_size
    print(f"[run] Processing {len(all_prompts)} prompts in {num_batches} batches of {batch_size}")

    for batch_idx in tqdm(range(num_batches), desc="Running logit lens batches"):
        start = batch_idx * batch_size
        end = min((batch_idx + 1) * batch_size, len(all_prompts))
        batch_prompts = all_prompts[start:end]

        save_path = f"{save_prefix}_batch{batch_idx:03d}.pt"

        print(f"\n[batch {batch_idx+1}/{num_batches}] {len(batch_prompts)} prompts → {save_path}")

        try:
            rows, hidden_dict, logits_dict, all_attn = collect_logit_lens_full(
                model=model,
                tokenizer=tokenizer,
                prompts=batch_prompts,
                max_len=max_len,
                device=device,
                norm_modes=normalize_mode,
                save_path=save_path,
                clamp_logits=clamp_logits,
                collect_attn=collect_attn,
                save_attn=save_attn,
            )

        except RuntimeError as e:
            print(f"[error] Batch {batch_idx} failed: {e}")
            continue

        del rows, hidden_dict, logits_dict, all_attn, batch_prompts
        torch.cuda.empty_cache()
        gc.collect()

    print("\n[done] All batches processed and saved.")

In [20]:
run_logit_lens_in_batches(
    model=model_8bit,
    tokenizer=orig_tokenizer,
    all_prompts=nq_queries_200_1,
    batch_size=20,
    max_len=17,
    normalize_mode=("raw", "unit_rms", "norm_rms"),
    save_prefix="saved_data/lens_data/m_8bit/m_8bit_modes",
    device="cpu",
    clamp_logits=False,
    collect_attn=False,
    save_attn=False  
) 

[run] Processing 1 prompts in 1 batches of 20


Running logit lens batches:   0%|          | 0/1 [00:00<?, ?it/s]


[batch 1/1] 1 prompts → saved_data/lens_data/m_8bit/m_8bit_modes_batch000.pt
[info] Detected quantized model → device cpu
[info] Tokenized 1 prompts | seq_len=17
[info] Collecting from 32 layers | quantized=True
[saved] saved_data/lens_data/m_8bit/m_8bit_modes_batch000.npz (481.88 MB)
[check] loaded OK: ['hidden_raw', 'logits_raw', 'hidden_unit_rms', 'logits_unit_rms', 'hidden_norm_rms'] ...
[ok] Collected 34 entries from 32 layers.


Running logit lens batches: 100%|██████████| 1/1 [02:06<00:00, 126.80s/it]


[done] All batches processed and saved.





In [None]:
import numpy as np

data = np.load("saved_data/lens_data/m_8bit/m_8bit_modes_batch000.npz", allow_pickle=True)

print(list(data.keys()))


['hidden_raw', 'logits_raw', 'hidden_unit_rms', 'logits_unit_rms', 'hidden_norm_rms', 'logits_norm_rms', 'prompt_id', 'batch_index', 'layer_index', 'layer_name', 'vocab_size']


In [29]:
for k in data.keys():
    arr = data[k]
    print(f"{k:20s} shape={arr.shape}, dtype={arr.dtype}")



hidden_raw           shape=(34, 16, 4096), dtype=float32
logits_raw           shape=(34, 16, 128256), dtype=float32
hidden_unit_rms      shape=(34, 16, 4096), dtype=float32
logits_unit_rms      shape=(34, 16, 128256), dtype=float32
hidden_norm_rms      shape=(34, 16, 4096), dtype=float32
logits_norm_rms      shape=(34, 16, 128256), dtype=float32
prompt_id            shape=(34,), dtype=int64
batch_index          shape=(34,), dtype=int64
layer_index          shape=(34,), dtype=int64
layer_name           shape=(34,), dtype=object
vocab_size           shape=(34,), dtype=int64


In [31]:
arrays = dict(data)
data.close()  

In [None]:
import numpy as np
import torch

path = "saved_data/lens_data/m_8bit/m_8bit_modes_batch000.npz"
data = np.load(path, allow_pickle=True)

for k in data.keys():
    arr = data[k]
    if not np.issubdtype(arr.dtype, np.floating):
        continue 
    n_nan = np.isnan(arr).sum()
    n_inf = np.isinf(arr).sum()
    print(f"{k:20s}: NaNs={n_nan:,}, Infs={n_inf:,}")


hidden_raw          : NaNs=0, Infs=0
logits_raw          : NaNs=0, Infs=0
hidden_unit_rms     : NaNs=0, Infs=0
logits_unit_rms     : NaNs=0, Infs=0
hidden_norm_rms     : NaNs=0, Infs=0
logits_norm_rms     : NaNs=0, Infs=0


In [34]:
import pandas as pd

summary = []
for k in data.keys():
    if "logits" in k:
        arr = data[k]
        summary.append({
            "name": k,
            "mean": float(np.mean(arr)),
            "std": float(np.std(arr)),
            "max": float(np.max(arr)),
            "min": float(np.min(arr))
        })
pd.DataFrame(summary)


Unnamed: 0,name,mean,std,max,min
0,logits_raw,-0.068328,1.736763,37.6875,-66.75
1,logits_unit_rms,-0.56426,1.427832,22.0,-15.554688
2,logits_norm_rms,-0.349259,2.006538,26.0,-13.992188


In [None]:
print(orig_tokenizer.decode([128000]))
print(orig_tokenizer.decode([128009]))

<|begin_of_text|>
<|eot_id|>


In [None]:
print(orig_tokenizer.decode([128000, 9906, 1917, 128009]))

<|begin_of_text|>Hello world<|eot_id|>


In [11]:
print(orig_tokenizer.bos_token_id, orig_tokenizer.eos_token_id)
print(model_8bit.config.bos_token_id, model_8bit.config.eos_token_id)

128000 128009
128000 128009


In [21]:
ids = orig_tokenizer.encode("Hello world", add_special_tokens=True)
print(ids)

[128000, 9906, 1917]


In [None]:
run_logit_lens_in_batches(
    model=model_quant,
    tokenizer=quant_tokenizer,
    all_prompts=nq_queries_200,
    batch_size=20,
    max_len=18,
    normalize_mode="raw",
    save_prefix="logs/new_model_lens/raw/m_quant_raw",
    device="cpu",
) 

# ==============================================
# TopK Comparison ==============================
# ==============================================

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
import os


def safe_mean(x):
    if x is None:
        return 0.0
    if torch.is_tensor(x):
        if x.numel() == 0:
            return 0.0
        x = torch.nan_to_num(x.detach().float(), nan=0.0, posinf=0.0, neginf=0.0).flatten()
        if x.numel() == 0:
            return 0.0
        return float(x.mean().item())
    if isinstance(x, (list, tuple, np.ndarray)):
        arr = np.asarray(x, dtype=np.float64).ravel()
        if arr.size == 0:
            return 0.0
        arr = np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)
        return float(np.mean(arr))
    try:
        val = float(x)
        if not np.isfinite(val):
            return 0.0
        return val
    except Exception:
        return 0.0


def preprocess_metrics(metrics, lens_type="raw"):
    processed = []
    for row in metrics:
        logits = row.get(f"logits_{lens_type}", row.get("logits"))
        attn_mask = row.get("attention_mask")
        targets = row.get("target_ids")

        if logits is None or targets is None:
            continue

        # Convert to tensors
        logits = logits.unsqueeze(0) if logits.ndim == 2 else logits
        targets = targets.unsqueeze(0) if targets.ndim == 1 else targets

        # --- Build eval mask (ignore BOS, keep up to first EOS)
        if attn_mask is not None:
            eval_mask = attn_mask.clone()
            eval_mask[:, 0] = 0  # ignore BOS
            eval_mask = eval_mask.bool()
        else:
            seq_len = logits.size(1)
            eval_mask = torch.ones((1, seq_len), dtype=torch.bool)
            eval_mask[:, 0] = 0

        # Apply mask to logits + targets
        logits = logits[:, eval_mask[0], :]
        targets = targets[:, eval_mask[0]]

        # Store
        row_proc = dict(row)
        row_proc["logits"] = logits
        row_proc["target_ids"] = targets
        row_proc["eval_mask"] = eval_mask
        processed.append(row_proc)

    return processed


def summarize(results, mode="raw"):
    df = pd.DataFrame(results[mode])
    summary = df.groupby("layer_index").agg({
        "kl_ab": "mean",
        "jsd": "mean",
        "cosine_sim": "mean",
        "l2_dist": "mean",
        "ppl_A": "mean",
        "ppl_B": "mean"
    })
    return summary


def combine_results(output_dir="logs/new_summary", run_name="run"):
    import glob, pandas as pd
    files = glob.glob(f"{output_dir}/{run_name}_*.parquet")
    dfs = []
    for f in files:
        mode = os.path.basename(f).replace(f"{run_name}_", "").replace(".parquet", "")
        df = pd.read_parquet(f)
        df["mode"] = mode
        dfs.append(df)
    combined = pd.concat(dfs, ignore_index=True)
    print(f"[merged] {len(combined)} total rows from {len(files)} files.")
    return combined


@torch.no_grad()
def compute_all_metrics_multi(
    metrics_A,
    metrics_B,
    norm_modes=("raw", "unit_rms", "norm_rms"),
    topk=(1, 5, 10, 20),
    device=None,
    eps=1e-9,
    output_dir="logs/new_summary",
    run_name=None,
):
    """
    Compute divergence, similarity, and behavioral alignment metrics
    across multiple normalization modes — in *wide-format*
    (columns per normalization mode), including all top-k accuracy
    and agreement/disagreement measures.
    """
    import os, pandas as pd

    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    topk = sorted(set([1] + list(topk)))
    os.makedirs(output_dir, exist_ok=True)

    # ------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------
    def safe(x):
        return torch.nan_to_num(x, nan=0.0, posinf=30.0, neginf=-30.0)

    def pad_vocab_dim(logits, target_vocab):
        pad = target_vocab - logits.size(-1)
        if pad > 0:
            logits = F.pad(logits, (0, pad))
        return logits

    preprocess = lambda m, t: preprocess_metrics(m, lens_type=t)
    metrics_proc_A = {mode: preprocess(metrics_A, mode) for mode in norm_modes}
    metrics_proc_B = {mode: preprocess(metrics_B, mode) for mode in norm_modes}

    base_rows = metrics_proc_A[norm_modes[0]]

    # Base metadata template
    meta_rows = [
        {
            "prompt_id": rA.get("prompt_id"),
            "prompt_text": rA.get("prompt_text"),
            "batch_index": rA.get("batch_index"),
            "layer_index": rA.get("layer_index"),
            "layer_name": rA.get("layer_name", f"layer.{rA.get('layer_index', -1)}"),
            "vocab_size": rA.get("vocab_size"),
        }
        for rA in base_rows
    ]

    results = []

    # ============================================================
    # MAIN LOOP — per (prompt × layer)
    # ============================================================
    for meta in meta_rows:
        pid, lid = meta["prompt_id"], meta["layer_index"]
        merged = meta.copy()

        for mode in norm_modes:
            row_A = [r for r in metrics_proc_A[mode] if r["prompt_id"] == pid and r["layer_index"] == lid][0]
            row_B = [r for r in metrics_proc_B[mode] if r["prompt_id"] == pid and r["layer_index"] == lid][0]

            logits_A = safe(row_A["logits"].to(device).float())
            logits_B = safe(row_B["logits"].to(device).float())
            targets = row_A["target_ids"].to(device)

            vocab_target = max(logits_A.size(-1), logits_B.size(-1))
            logits_A, logits_B = pad_vocab_dim(logits_A, vocab_target), pad_vocab_dim(logits_B, vocab_target)

            # --- Probabilities ---
            logp_A = F.log_softmax(logits_A, dim=-1)
            logp_B = F.log_softmax(logits_B, dim=-1)
            probs_A, probs_B = logp_A.exp(), logp_B.exp()

            # ========================================================
            # === Divergence and Similarity Metrics
            # ========================================================
            kl_ab = torch.sum(probs_A * (logp_A - logp_B), dim=-1).mean().item()
            kl_ba = torch.sum(probs_B * (logp_B - logp_A), dim=-1).mean().item()
            jsd = 0.5 * (kl_ab + kl_ba)
            tvd = 0.5 * torch.sum(torch.abs(probs_A - probs_B), dim=-1).mean().item()

            entropy_A = (-torch.sum(probs_A * logp_A, dim=-1)).mean().item()
            entropy_B = (-torch.sum(probs_B * logp_B, dim=-1)).mean().item()

            cosine_sim = F.cosine_similarity(logits_A, logits_B, dim=-1).mean().item()
            l2_dist = torch.norm(logits_A - logits_B, dim=-1).mean().item()

            # ========================================================
            # === Ground-truth log-probabilities
            # ========================================================
            gather_idx = targets.unsqueeze(-1).clamp(max=logp_A.size(-1) - 1)
            logp_gt_A = logp_A.gather(-1, gather_idx).squeeze(-1)
            logp_gt_B = logp_B.gather(-1, gather_idx).squeeze(-1)
            logp_diff = (logp_gt_A - logp_gt_B).mean().item()

            # ========================================================
            # === Dispersion + Perplexity
            # ========================================================
            sorted_A, _ = torch.sort(probs_A, dim=-1)
            sorted_B, _ = torch.sort(probs_B, dim=-1)
            iqr_A = (torch.quantile(sorted_A, 0.75, dim=-1) - torch.quantile(sorted_A, 0.25, dim=-1)).mean().item()
            iqr_B = (torch.quantile(sorted_B, 0.75, dim=-1) - torch.quantile(sorted_B, 0.25, dim=-1)).mean().item()

            ce_A = F.cross_entropy(logits_A.view(-1, vocab_target), targets.view(-1), reduction="mean")
            ce_B = F.cross_entropy(logits_B.view(-1, vocab_target), targets.view(-1), reduction="mean")
            ppl_A, ppl_B = torch.exp(ce_A).item(), torch.exp(ce_B).item()

            # ========================================================
            # === Top-k Accuracy + Agreement/Disagreement
            # ========================================================
            max_k = max(topk)
            top_vals_A, top_idx_A = torch.topk(probs_A, max_k, dim=-1)
            top_vals_B, top_idx_B = torch.topk(probs_B, max_k, dim=-1)

            for k in topk:
                tkA, tkB = top_idx_A[:, :, :k], top_idx_B[:, :, :k]
                tvA, tvB = top_vals_A[:, :, :k], top_vals_B[:, :, :k]

                # --- Basic overlaps ---
                inter_mask = (tkA.unsqueeze(-1) == tkB.unsqueeze(-2))
                inter = inter_mask.any(dim=-1).sum(dim=2).float()
                jaccard = (inter / (2 * k - inter + eps)).mean().item()
                disagree = (1 - jaccard)

                # --- Accuracy ---
                tgt_expand = targets.unsqueeze(0) if targets.ndim == 1 else targets
                acc_A = (tkA == tgt_expand.unsqueeze(-1)).any(dim=-1).float().mean().item()
                acc_B = (tkB == tgt_expand.unsqueeze(-1)).any(dim=-1).float().mean().item()

                # --- Correctness-based ---
                acc_A_mat = (tkA == tgt_expand.unsqueeze(-1)).any(dim=-1).float()
                acc_B_mat = (tkB == tgt_expand.unsqueeze(-1)).any(dim=-1).float()
                agree_correct = (acc_A_mat * acc_B_mat).mean().item()
                disagree_correct = ((acc_A_mat + acc_B_mat == 1).float()).mean().item()
                agree_wrong = (((1 - acc_A_mat) * (1 - acc_B_mat)).float()).mean().item()

                # --- Shared probability mass ---
                prob_mass_A, prob_mass_B = tvA.sum(dim=-1), tvB.sum(dim=-1)
                shared_mass = torch.zeros_like(prob_mass_A)
                for i in range(tkA.size(0)):
                    shared_tokens = set(tkA[i].reshape(-1).tolist()) & set(tkB[i].reshape(-1).tolist())
                    if shared_tokens:
                        st = torch.tensor(list(shared_tokens), device=device)
                        shared_mass[i] = 0.5 * (probs_A[i, :, st].sum() + probs_B[i, :, st].sum())
                prob_overlap = (shared_mass / (0.5 * (prob_mass_A + prob_mass_B) + eps)).mean().item()

                # --- Add as flattened columns ---
                merged.update({
                    f"acc_A_top{k}_{mode}": acc_A,
                    f"acc_B_top{k}_{mode}": acc_B,
                    f"jaccard_top{k}_{mode}": jaccard,
                    f"disagree_top{k}_{mode}": disagree,
                    f"agree_correct_top{k}_{mode}": agree_correct,
                    f"disagree_correct_top{k}_{mode}": disagree_correct,
                    f"agree_wrong_top{k}_{mode}": agree_wrong,
                    f"prob_overlap_top{k}_{mode}": prob_overlap,
                })

            # ========================================================
            # === Add scalar metrics for this mode
            # ========================================================
            merged.update({
                f"kl_ab_{mode}": kl_ab,
                f"kl_ba_{mode}": kl_ba,
                f"jsd_{mode}": jsd,
                f"tvd_{mode}": tvd,
                f"entropy_A_{mode}": entropy_A,
                f"entropy_B_{mode}": entropy_B,
                f"cosine_sim_{mode}": cosine_sim,
                f"l2_dist_{mode}": l2_dist,
                f"logp_diff_{mode}": logp_diff,
                f"ppl_A_{mode}": ppl_A,
                f"ppl_B_{mode}": ppl_B,
                f"ppl_diff_{mode}": ppl_A - ppl_B,
                f"iqr_A_{mode}": iqr_A,
                f"iqr_B_{mode}": iqr_B,
            })

        results.append(merged)

    # ============================================================
    # SAVE WIDE-FORMAT RESULTS
    # ============================================================
    df = pd.DataFrame(results)
    out_path = os.path.join(output_dir, f"{run_name or 'run'}_metrics_wide.parquet")
    df.to_parquet(out_path, index=False)
    print(f"[saved] {len(df)} rows with full multi-mode metrics → {out_path}")
    return df

In [None]:
import numpy as np
import torch

def load_npz_as_metrics(npz_path, norm_modes=("raw", "unit_rms", "norm_rms")):
    """
    Load a .npz file saved from the logit lens analysis
    and reconstruct the 'metrics' list expected by preprocess_metrics().
    """
    data = np.load(npz_path, allow_pickle=True)
    metrics = []

    num_rows = len(data["prompt_id"])
    for i in range(num_rows):
        row = {
            "prompt_id": int(data["prompt_id"][i]),
            "layer_index": int(data["layer_index"][i]),
            "layer_name": str(data["layer_name"][i]),
            "batch_index": int(data["batch_index"][i]),
            "vocab_size": int(data["vocab_size"][i]),
        }

        # Optional: if you saved prompt_texts, etc.
        if "prompt_text" in data:
            row["prompt_text"] = str(data["prompt_text"][i])

        # target_ids & attention_mask are stored per row
        if "target_ids" in data:
            row["target_ids"] = torch.from_numpy(data["target_ids"][i])
        if "attention_mask" in data:
            row["attention_mask"] = torch.from_numpy(data["attention_mask"][i])

        # Attach all normalization modes
        for mode in norm_modes:
            if f"logits_{mode}" in data:
                row[f"logits_{mode}"] = torch.from_numpy(data[f"logits_{mode}"][i])
            if f"hidden_{mode}" in data:
                row[f"hidden_{mode}"] = torch.from_numpy(data[f"hidden_{mode}"][i])

        metrics.append(row)

    print(f"[loaded] {num_rows} records from {npz_path}")
    return metrics

metrics_A = load_npz_as_metrics("model_A_batch_03.npz")
metrics_B = load_npz_as_metrics("model_B_batch_03.npz")

results = compute_all_metrics_multi(metrics_A, metrics_B)

In [None]:
import pandas as pd
df_raw = pd.read_parquet("logs/new_summary/run_raw.parquet")

In [None]:
"""
valid_positions = record["attention_mask"][:-1].bool()
logits = record["logits"][valid_positions]
targets = record["target_ids"][valid_positions]

bos_id, eos_id = tokenizer.bos_token_id, tokenizer.eos_token_id
mask = (record["attention_mask"][:-1].bool()) '\'
        & (record["input_ids"][:-1] != bos_id) '\'
        & (record["input_ids"][:-1] != eos_id)

"""

# ==============================================
# Hidden Acts Similarity & PPL =================
# ==============================================

In [None]:
import os
import glob
import math
import torch
import pandas as pd
import torch.nn.functional as F
from tqdm import tqdm

def load_hidden_batches(folder, pattern="*.pt"):
    """Load all saved hidden-state batches collected with collect_hidden_states_full()."""
    files = sorted(glob.glob(os.path.join(folder, pattern)))
    all_rows = []
    for f in files:
        print(f"[load] {f}")
        rows = torch.load(f, map_location="cpu")
        all_rows.extend(rows)
    print(f"[ok] Loaded {len(all_rows)} rows from {len(files)} files.")
    return all_rows


def analyze_hidden_similarity_batched(rows_A, rows_B, compute_perplexity=True):
    """
    Compare hidden activations from two models (A vs B).
    Computes:
      - L2 distance (positional, sequence, layer)
      - Cosine similarity (positional, sequence, layer)
      - Optional: perplexity per model if logits available

    Returns:
        df_detailed : detailed per-prompt/layer/position metrics
    """

    def safe(x, default=None, eps=1e-9):
        """Clamp and sanitize tensor values."""
        if x is None:
            return torch.tensor(default if default is not None else 0.0)
        x = torch.nan_to_num(x, nan=0.0, posinf=30.0, neginf=-30.0)
        if torch.is_floating_point(x):
            mean_val = x.mean().item() if x.numel() > 0 else 0.0
            if mean_val < 1.0:  # looks like log-probs
                x = torch.clamp(x, min=-30.0, max=30.0)
            else:
                x = torch.clamp(x, min=0.0, max=1e6)
        return x

    detailed = []
    index_B = {(r["prompt_id"], r["layer_name"]): r for r in rows_B if "hidden" in r}

    for rA in tqdm(rows_A, desc="Comparing hidden states"):
        key = (rA["prompt_id"], rA["layer_name"])
        rB = index_B.get(key)
        if not rB or "hidden" not in rA or "hidden" not in rB:
            continue

        hA, hB = rA["hidden"], rB["hidden"]
        if not isinstance(hA, torch.Tensor) or not isinstance(hB, torch.Tensor):
            continue
        if hA.shape != hB.shape:
            continue

        # === Hidden-state metrics ===
        diff = hA - hB
        l2_pos = torch.norm(diff, dim=-1)
        cos_pos = F.cosine_similarity(hA, hB, dim=-1)
        l2_seq = l2_pos.mean().item()
        cos_seq = cos_pos.mean().item()

        # === Optional perplexity ===
        pplA = pplB = ppl_diff_signed = None
        if compute_perplexity and "logits" in rA and "logits" in rB:
            for label, logits in zip(("A", "B"), (rA["logits"], rB["logits"])):
                if logits is not None and "target_ids" in rA:
                    log_probs = safe(F.log_softmax(logits, dim=-1))
                    target = rA["target_ids"]
                    nll = -log_probs.gather(-1, target.unsqueeze(-1)).squeeze(-1)
                    mask = (target != 0).float()
                    mean_nll = (nll * mask).sum() / mask.sum()
                    ppl = math.exp(mean_nll.item())
                    if label == "A":
                        pplA = ppl
                    else:
                        pplB = ppl

            # signed PPL difference (log-scale for stability)
            if pplA is not None and pplB is not None:
                ppl_diff_signed = math.log(pplB + 1e-9) - math.log(pplA + 1e-9)

        detailed.append({
            "prompt_id": rA["prompt_id"],
            "batch_index": rA.get("batch_index"),
            "layer_index": rA["layer_index"],
            "layer_name": rA["layer_name"],
            "mean_l2_seq": l2_seq,
            "mean_cos_seq": cos_seq,
            "l2_pos": l2_pos.cpu(),
            "cos_pos": cos_pos.cpu(),
            "ppl_A": pplA,
            "ppl_B": pplB,
            "ppl_diff_signed": ppl_diff_signed,
        })

    if not detailed:
        raise ValueError("No valid pairs found. Check layer alignment and hidden data.")

    df_detailed = pd.DataFrame(detailed)
    print(f"[ok] Computed similarity metrics for {df_detailed['layer_name'].nunique()} layers.")
    return df_detailed


# ==============================================
# Head Semantics ===============================
# ==============================================

In [None]:
import torch
import gc
from tqdm import tqdm
from pathlib import Path

# orig→quant: orig_quant
# quant→orig quant_orig
@torch.no_grad()
def project_hidden_through_other_head(
    projector_model,
    hidden_file,
    batch_num=0,
    save_prefix="saved_data/projections",
    direction_tag="orig_quant",
    device="cpu",
):
    """
    Projects pre-saved hidden activations (from collect_hidden_states_full)
    through ANOTHER model's LM head (cross-model projection).

    """

    # --- Ensure model + device setup ---
    projector_model = projector_model.to(device).eval()
    Path(save_prefix).mkdir(parents=True, exist_ok=True)

    # --- Handle batch number formatting ---
    if isinstance(batch_num, str) and not batch_num.isdigit():
        batch_str = batch_num
    else:
        try:
            batch_str = f"{int(batch_num):03d}"
        except Exception:
            batch_str = str(batch_num)

    # --- Define save path ---
    save_path = f"{save_prefix}/proj_{direction_tag}_batch{batch_str}.pt"
    print(f"[run] Projecting hidden activations {direction_tag} → {save_path}")

    # --- Load pre-saved activations ---
    rows = torch.load(hidden_file, map_location="cpu", weights_only=False)
    results = []

    # --- Main projection loop ---
    for r in tqdm(rows, desc=f"[batch {batch_str}] {direction_tag} projection"):
        if "hidden" not in r:
            continue

        h = r["hidden"].to(device)
        logits_cross = projector_model.lm_head(h).cpu()

        # Preserve all meta info
        results.append({
            "prompt_id": r.get("prompt_id"),
            "prompt_text": r.get("prompt_text"),
            "batch_index": r.get("batch_index", batch_num),
            "layer_index": r.get("layer_index"),
            "layer_name": r.get("layer_name"),
            "input_ids": r.get("input_ids"),
            "target_ids": r.get("target_ids"),
            "logits_crossproj": logits_cross,
        })

        # Free memory
        del h, logits_cross
        torch.cuda.empty_cache()
        gc.collect()

    # --- Save the result ---
    torch.save(results, save_path)
    print(f"[saved] → {save_path} ({len(results)} entries, batch_index={batch_num})")
    return save_path

batch_num = "009"

# (1) Project original model’s hidden states through quantized model’s head (quant→orig)
project_hidden_through_other_head(
    projector_model=model_quant,
    hidden_file=f"saved_data/hidden_acts/m_orig_batch{batch_num}.pt",
    batch_num=batch_num,
    save_prefix="saved_data/projections",
    direction_tag="orig_quant",
    device="cpu",
)

batch_num = "009"

# (2) Project quantized model’s hidden states through original model’s head
project_hidden_through_other_head(
    projector_model=model_orig,
    hidden_file=f"saved_data/hidden_acts/m_quant_batch{batch_num}.pt",
    batch_num=batch_num,
    save_prefix="saved_data/projections",
    direction_tag="quant_orig",
    device="cpu",
)

In [None]:
import os
import gc
import torch
import torch.nn.functional as F
from tqdm import tqdm
from pathlib import Path

@torch.no_grad()
def run_projection_comparison_single_batch(
    file_A,
    file_B,
    tokenizer,
    batch_num=0,
    save_prefix="saved_data/head_semantics/semantics_",
    topk_levels=(1, 5, 10, 20),
    device="cpu",
):
    """
    Compare *already projected* logits from two models (e.g., orig→quant vs quant→orig).
    Computes:
      • cosine & L2 distances between logits (sequence-level + per-position)
      • Gini coefficients (sequence-level + per-position)
      • top-k Jaccard overlaps (per-position + sequence-level)
      • decoded top-k tokens for interpretability
    """

    Path(save_prefix).parent.mkdir(parents=True, exist_ok=True)
    save_path = f"{save_prefix}batch{int(batch_num):03d}.pt"

    rows_A = torch.load(file_A, map_location=device, weights_only=False)
    rows_B = torch.load(file_B, map_location=device, weights_only=False)

    # --- Helper for Gini coefficient ---
    def gini_coefficient(x):
        """Compute Gini per position (for logits → probs)."""
        probs = torch.softmax(x, dim=-1)
        n = probs.shape[-1]
        sorted_p, _ = torch.sort(probs, dim=-1)
        cum_p = torch.cumsum(sorted_p, dim=-1)
        gini = 1 - 2 * cum_p.sum(dim=-1) / (n * cum_p[..., -1]) + 1 / n
        return gini

    index_B = {(r["prompt_id"], r["layer_name"]): r for r in rows_B}
    results = []

    for rA in tqdm(rows_A, desc=f"[batch {batch_num}] Comparing projected logits"):
        key = (rA["prompt_id"], rA["layer_name"])
        rB = index_B.get(key)
        if not rB or "logits_crossproj" not in rA or "logits_crossproj" not in rB:
            continue

        logits_A = rA["logits_crossproj"].to(device)   # [T, V]
        logits_B = rB["logits_crossproj"].to(device)   # [T, V]
        vocab_dim = logits_A.shape[-1]
        seq_len = logits_A.shape[0]

        # --- Per-position metrics ---
        cos_pos = F.cosine_similarity(logits_A, logits_B, dim=-1)           # [T]
        l2_pos = torch.norm(logits_A - logits_B, dim=-1)                    # [T]
        gini_A_pos = gini_coefficient(logits_A)                             # [T]
        gini_B_pos = gini_coefficient(logits_B)                             # [T]
        delta_gini_pos = gini_B_pos - gini_A_pos

        # --- Sequence-level summaries ---
        cos_seq = cos_pos.mean().item()
        l2_seq = l2_pos.mean().item()
        gini_A_seq = gini_A_pos.mean().item()
        gini_B_seq = gini_B_pos.mean().item()
        delta_gini_seq = gini_B_seq - gini_A_seq

        # --- Top-k Jaccard overlaps ---
        jaccard_pos_dict = {}
        jaccard_seq_dict = {}
        decoded_shift = {}

        for k in topk_levels:
            k_safe = min(k, vocab_dim)
            jaccard_pos = []
            for t in range(seq_len):
                topA_t = torch.topk(logits_A[t], k_safe).indices.tolist()
                topB_t = torch.topk(logits_B[t], k_safe).indices.tolist()
                inter = len(set(topA_t) & set(topB_t))
                union = len(set(topA_t) | set(topB_t))
                jaccard_pos.append(inter / union if union > 0 else 0.0)
            jaccard_pos_tensor = torch.tensor(jaccard_pos)
            jaccard_pos_dict[f"jaccard_top{k}_pos"] = jaccard_pos_tensor.cpu()
            jaccard_seq_dict[f"jaccard_top{k}_seq"] = jaccard_pos_tensor.mean().item()

            # Decode top-k tokens for mean logits (for readability)
            mean_A = logits_A.mean(dim=0)
            mean_B = logits_B.mean(dim=0)
            topA_k = torch.topk(mean_A, k_safe).indices.tolist()
            topB_k = torch.topk(mean_B, k_safe).indices.tolist()
            decoded_shift[f"top{k}_A"] = [tokenizer.decode([i]) for i in topA_k]
            decoded_shift[f"top{k}_B"] = [tokenizer.decode([i]) for i in topB_k]

        # --- Aggregate everything ---
        result_entry = {
            "prompt_id": rA["prompt_id"],
            "batch_index": rA.get("batch_index", batch_num),
            "layer_index": rA["layer_index"],
            "layer_name": rA["layer_name"],
            "input_ids": rA.get("input_ids"),
            "target_ids": rA.get("target_ids"),

            # --- sequence-level summary ---
            "cosine_seq": cos_seq,
            "l2_seq": l2_seq,
            "gini_A_seq": gini_A_seq,
            "gini_B_seq": gini_B_seq,
            "delta_gini_seq": delta_gini_seq,
            **jaccard_seq_dict,

            # --- detailed per-position tensors ---
            "cosine_pos": cos_pos.cpu(),
            "l2_pos": l2_pos.cpu(),
            "gini_A_pos": gini_A_pos.cpu(),
            "gini_B_pos": gini_B_pos.cpu(),
            "delta_gini_pos": delta_gini_pos.cpu(),
            **jaccard_pos_dict,

            "decoded_shift": decoded_shift,
        }

        results.append(result_entry)
        del logits_A, logits_B
        torch.cuda.empty_cache()
        gc.collect()

    torch.save(results, save_path)
    print(f"[saved] → {save_path} ({len(results)} entries)")
    return results

from transformers import AutoTokenizer
orig_tokenizer = AutoTokenizer.from_pretrained("Models/LLaMA3Instruct")

batch_num = "002"

run_projection_comparison_single_batch(
    file_A=f"saved_data/projections/proj_orig_quant_batch{batch_num}.pt",
    file_B=f"saved_data/projections/proj_quant_orig_batch{batch_num}.pt",
    tokenizer=orig_tokenizer,
    batch_num=batch_num,
    save_prefix="saved_data/projection_comparisons/semantics_",
)