In [1]:
from pathlib import Path
from datasets import load_dataset, DownloadMode
import torch
import os
import glob
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from llama_wrapper import LlamaPromptLens, run_logit_lens_batched

In [2]:
from enum import Enum

class Models(Enum):
    LAIN8B = "Models/LLaMA3Instruct"
    HF100B = "Models/HF1BitLLM100Btokens"


class Names(Enum):
    LAIN8B = "Meta-Llama-3-8B-Instruct-fp"
    HF100B = "Llama3-8B-1.58-100B-tokens"

In [3]:
filepath = r'D:\LogitLensData\nq'

destination_path = str(Path(filepath))
nq_dataset = load_dataset(
    'sentence-transformers/natural-questions',
    split={
        'train': 'train[:1000]'
    },
    cache_dir=destination_path,
    download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS,
    keep_in_memory=True
)

In [4]:
nq_queries = nq_dataset['train']['query']
nq_answers = nq_dataset['train']['answer']

In [5]:
nq_queries = nq_dataset['train']['query']

In [6]:
nq_queries_200 = nq_queries[:200]
nq_queries_400 = nq_queries[200:400]
nq_queries_600 = nq_queries[400:600]
nq_queries_800 = nq_queries[600:800]
nq_queries_1000 = nq_queries[800:1000]

### LLaMA FP

In [7]:
llama8b_fp = LlamaPromptLens(
    model_id=Models.LAIN8B.value,
    apply_per_layer_norm=True,
    include_subblocks=False,
    device="cpu"
)

Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

Architecture detected: llama
Standard FP16 or FP32 model.


In [8]:
run_logit_lens_batched(
    lens=llama8b_fp,
    prompts=nq_queries_200,
    dataset_name="nq_query_200",
    model_name="llama8b_fp",
    save_dir="logs/lens_batches_norm/llama8b_fp/nq_200",
    proj_precision=None,
    batch_size=10,
)

[✓] Saved batch 0: logs/lens_batches_norm/llama8b_fp/nq_200/nq_query_200_llama8b_fp_batch0.pt
[✓] Saved batch 1: logs/lens_batches_norm/llama8b_fp/nq_200/nq_query_200_llama8b_fp_batch1.pt
[✓] Saved batch 2: logs/lens_batches_norm/llama8b_fp/nq_200/nq_query_200_llama8b_fp_batch2.pt
[✓] Saved batch 3: logs/lens_batches_norm/llama8b_fp/nq_200/nq_query_200_llama8b_fp_batch3.pt
[✓] Saved batch 4: logs/lens_batches_norm/llama8b_fp/nq_200/nq_query_200_llama8b_fp_batch4.pt
[✓] Saved batch 5: logs/lens_batches_norm/llama8b_fp/nq_200/nq_query_200_llama8b_fp_batch5.pt
[✓] Saved batch 6: logs/lens_batches_norm/llama8b_fp/nq_200/nq_query_200_llama8b_fp_batch6.pt
[✓] Saved batch 7: logs/lens_batches_norm/llama8b_fp/nq_200/nq_query_200_llama8b_fp_batch7.pt
[✓] Saved batch 8: logs/lens_batches_norm/llama8b_fp/nq_200/nq_query_200_llama8b_fp_batch8.pt
[✓] Saved batch 9: logs/lens_batches_norm/llama8b_fp/nq_200/nq_query_200_llama8b_fp_batch9.pt
[✓] Saved batch 10: logs/lens_batches_norm/llama8b_fp/nq_200

In [7]:
import torch

data = torch.load(
    "logs/lens_batches/nq_query_llama8b_fp_batch0.pt",
    weights_only=False 
)

### HF1BitLLM

In [10]:
llama8b_hf100b = LlamaPromptLens(
    model_id=Models.HF100B.value,
    apply_per_layer_norm=True,
    include_subblocks=False,
    device="cpu"
)

Architecture detected: bitnet
BitNet model (BitLinear layers).


In [11]:
run_logit_lens_batched(
    lens=llama8b_hf100b,
    prompts=nq_queries_200,
    dataset_name="nq_query_200",
    model_name="llama8b_hf100b",
    save_dir="logs/lens_batches_norm/llama8b_hf100b/nq_200",
    proj_precision=None,
    batch_size=10,
)

[✓] Saved batch 0: logs/lens_batches_norm/llama8b_hf100b/nq_200/nq_query_200_llama8b_hf100b_batch0.pt
[✓] Saved batch 1: logs/lens_batches_norm/llama8b_hf100b/nq_200/nq_query_200_llama8b_hf100b_batch1.pt
[✓] Saved batch 2: logs/lens_batches_norm/llama8b_hf100b/nq_200/nq_query_200_llama8b_hf100b_batch2.pt
[✓] Saved batch 3: logs/lens_batches_norm/llama8b_hf100b/nq_200/nq_query_200_llama8b_hf100b_batch3.pt
[✓] Saved batch 4: logs/lens_batches_norm/llama8b_hf100b/nq_200/nq_query_200_llama8b_hf100b_batch4.pt
[✓] Saved batch 5: logs/lens_batches_norm/llama8b_hf100b/nq_200/nq_query_200_llama8b_hf100b_batch5.pt
[✓] Saved batch 6: logs/lens_batches_norm/llama8b_hf100b/nq_200/nq_query_200_llama8b_hf100b_batch6.pt
[✓] Saved batch 7: logs/lens_batches_norm/llama8b_hf100b/nq_200/nq_query_200_llama8b_hf100b_batch7.pt
[✓] Saved batch 8: logs/lens_batches_norm/llama8b_hf100b/nq_200/nq_query_200_llama8b_hf100b_batch8.pt
[✓] Saved batch 9: logs/lens_batches_norm/llama8b_hf100b/nq_200/nq_query_200_llama

In [11]:
import torch

data = torch.load(
    "logs/lens_batches/nq_query_norm_llama8b_hf100b_batch0.pt",
    weights_only=False 
)

In [12]:
data.tail()

Unnamed: 0,prompt_id,prompt_text,dataset,vocab_size,layer_index,layer_name,input_ids,target_ids,logits,position
325,9,when did fosters home for imaginary friends start,nq_query_norm,128000,28,layer_27,"[tensor(128000), tensor(9493), tensor(1550), t...","[tensor(9493), tensor(1550), tensor(37413), te...","[[tensor(0.6027), tensor(0.5580), tensor(0.581...","[tensor(0), tensor(1), tensor(2), tensor(3), t..."
326,9,when did fosters home for imaginary friends start,nq_query_norm,128000,29,layer_28,"[tensor(128000), tensor(9493), tensor(1550), t...","[tensor(9493), tensor(1550), tensor(37413), te...","[[tensor(0.3937), tensor(0.6215), tensor(0.407...","[tensor(0), tensor(1), tensor(2), tensor(3), t..."
327,9,when did fosters home for imaginary friends start,nq_query_norm,128000,30,layer_29,"[tensor(128000), tensor(9493), tensor(1550), t...","[tensor(9493), tensor(1550), tensor(37413), te...","[[tensor(0.3690), tensor(1.1385), tensor(0.828...","[tensor(0), tensor(1), tensor(2), tensor(3), t..."
328,9,when did fosters home for imaginary friends start,nq_query_norm,128000,31,layer_30,"[tensor(128000), tensor(9493), tensor(1550), t...","[tensor(9493), tensor(1550), tensor(37413), te...","[[tensor(1.3250), tensor(2.6986), tensor(1.801...","[tensor(0), tensor(1), tensor(2), tensor(3), t..."
329,9,when did fosters home for imaginary friends start,nq_query_norm,128000,32,layer_31,"[tensor(128000), tensor(9493), tensor(1550), t...","[tensor(9493), tensor(1550), tensor(37413), te...","[[tensor(0.7254), tensor(1.6154), tensor(0.954...","[tensor(0), tensor(1), tensor(2), tensor(3), t..."


# Softmax and Single Model Computations

In [None]:
import torch
import torch.nn.functional as F
import pandas as pd
from collections import defaultdict

def extract_metrics(
    data,
    topk=[1, 5, 10, 20],
    mask_ids=[128000, 128009],
    device=None,
):
    """
    Extract top-k accuracy, log-probs, and NLL metrics from logit lens outputs.
    Works with both a list of dicts and a pandas DataFrame (as produced by _run_logit_lens_batch).
    """

    if isinstance(data, pd.DataFrame):
        data = data.to_dict(orient="records")

    results = []
    all_k = sorted(set([1] + list(topk)))  # always include top-1
    max_k = max(all_k)

    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    mask_ids = torch.tensor(mask_ids, device=device)

    for row in data:
        logits = row["logits"].to(device)  # [seq_len, vocab]
        targets = row.get("target_ids", None)
        if targets is not None:
            targets = targets.to(device)

        # --- normalize shapes ---
        if logits.dim() == 3 and logits.size(0) == 1:
            logits = logits.squeeze(0)
        if targets is not None and targets.dim() == 2 and targets.size(0) == 1:
            targets = targets.squeeze(0)

        seq_len, vocab_size = logits.shape

        # --- stable softmax ---
        log_probs = F.log_softmax(logits, dim=-1)
        probs = log_probs.exp()

        # --- top-k predictions ---
        top_vals, top_idx = torch.topk(probs, max_k, dim=-1)
        topk_preds = {k: top_idx[:, :k] for k in all_k}
        topk_vals = {k: top_vals[:, :k] for k in all_k}

        # --- metrics if targets exist ---
        if targets is not None:
            # mask out special tokens
            mask = ~torch.isin(targets, mask_ids)

            # target log probs and nll
            target_logprobs = log_probs[torch.arange(seq_len, device=device), targets]
            target_probs = target_logprobs.exp()
            nll = -target_logprobs

            # accuracy metrics
            correct_topk, acc_topk = {}, {}
            for k in all_k:
                correct = (targets.unsqueeze(-1) == topk_preds[k]).any(dim=-1).int()
                correct_topk[k] = correct
                valid_correct = correct[mask]
                acc_mean = valid_correct.float().mean().item() if valid_correct.numel() > 0 else float("nan")
                acc_topk[k] = {"mean": acc_mean, "per_token": correct.tolist()}

            # monotonicity sanity check
            for k1, k2 in zip(all_k[:-1], all_k[1:]):
                assert (correct_topk[k1] <= correct_topk[k2]).all(), f"Top-{k1} not subset of top-{k2}"

        else:
            mask = target_probs = target_logprobs = nll = correct_topk = acc_topk = None

        results.append({
            "prompt_id": row.get("prompt_id"),
            "prompt_text": row.get("prompt_text"),
            "layer_index": row.get("layer_index"),
            "layer_name": row.get("layer_name"),
            "dataset": row.get("dataset"),
            "vocab_size": vocab_size,
            "input_ids": row.get("input_ids"),
            "position": row.get("position"),
            "targets": targets,
            "mask": mask,
            "logits": logits.cpu(),
            "log_probs": log_probs.cpu(),
            "probs": probs.cpu(),
            "nll": nll.cpu() if nll is not None else None,
            "target_probs": target_probs.cpu() if target_probs is not None else None,
            "target_logprobs": target_logprobs.cpu() if target_logprobs is not None else None,
            "topk_preds": {k: v.cpu() for k, v in topk_preds.items()},
            "topk_vals": {k: v.cpu() for k, v in topk_vals.items()},
            "correct_topk": {k: v.cpu() for k, v in correct_topk.items()} if correct_topk else None,
            "acc_topk": acc_topk,
        })

    return results


# Load logits:
"""norm_logits_data = torch.load(
    "logs/lens_batches/nq_query_norm_llama8b_hf100b_batch0.pt",
    weights_only=False 
)"""

raw_logits_data = torch.load(
    "logs/lens_batches_raw/llama8b_fp/nq_200/nq_query_200_llama8b_fp_batch0.pt",
    weights_only=False 
)

model_logits = raw_logits_data

# Extract topk metrics:
extracted_metrics = extract_metrics(model_logits)

# Saving as a torch file:
save_path = "logs/batch_probs_raw/llama8b_fp/nq_200/nq_query_200_llama8b_fp_batch0.pt"
torch.save(extracted_metrics, save_path)

TypeError: string indices must be integers