# -------------------------------
# MI Logit Lens Notebook
# -------------------------------

In [None]:
from pathlib import Path
from datasets import load_dataset, DownloadMode
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from quant_configs.bnb_configs import load_bnb_in_8bit, load_bnb_in_4bit
from mi_utils.util.logit_lens_utils.logit_lens_wrapper import LogitLensWrapper

from mi_utils.logit_lens.logit_lens_degradation_analysis import (
    analyze_UNSAFE_degradation,
    analyze_SAFE_degradation,
)

from mi_utils.logit_lens.logit_lens_analysis import run_logit_lens

from mi_utils.logit_lens.metric_utils.logit_lens_helpers import (
    save_results_to_csv,
    get_activation_tensor,
    extract_activations,
    load_results_from_pt
)

from mi_utils.logit_lens.metric_utils.interp_degradation_scores import (
    degradation_diff_score,
    interpretability_diff_score,
    degradation_score
)

from mi_utils.logit_lens.plotting_utils.lens_plotting import (
    plot_layer_metric_two_dfs,
    plot_layer_deviation,
    qq_plot_probs
)

from mi_utils.logit_lens.metric_utils.carry_over_safe_metrics import (
    filter_main_layers,
    prepare_layer_tensors,
    compute_carry_over_safe_scalar
)

# -------------------------------
# Models
# -------------------------------

In [None]:
from enum import Enum

class Models(Enum):
    """Checkpoint locations; the ``.value`` is what gets passed to ``load_model_and_tok``.

    Every member must have a distinct value: ``Enum`` turns a member with a
    repeated value into an alias of the first one, so it can no longer be
    selected on its own.
    """
    GPT2 = "Models/GPT2"
    LAIN8B = "Models/LLaMA3Instruct"
    HF100B = "Models/HF1BitLLM100Btokens"
    HF10BL = "Models/HF1BitLLMLinear10B"
    HF10BS = "Models/HF1BitLLMSigmoid10B"
    OL1B = "Models/OLMo1B"
    # Previously "Models/OLMo1B" (copy-paste), which made OL7B an alias of OL1B.
    # NOTE(review): confirm the on-disk folder name of the 7B checkpoint.
    OL7B = "Models/OLMo7B"
    # NOTE(review): the two entries below have no "Models/" prefix, unlike the
    # rest — confirm whether that is intentional.
    DH3B = "DHLLaMA3B"
    DH8B = "DHLLaMA8B"

class Names(Enum):
    """Human-readable model names, keyed identically to ``Models``.

    Values are bare model names (no directory component).
    """
    GPT2 = "GPT2"
    LAIN8B = "Meta-Llama-3-8B-Instruct"
    # Previously "Models/Llama3-8B-1.58-100B-tokens": the local-folder prefix
    # had leaked into the display name, unlike every other entry.
    HF100B = "Llama3-8B-1.58-100B-tokens"
    HF10BL = "Llama3-8B-1.58-Linear-10B-tokens"
    HF10BS = "Llama3-8B-1.58-Sigmoid-k100-10B-tokens"
    OL1B = "OLMo-1B-hf"
    OL7B = "OLMo-7B-hf"
    DH3B = "DeepHermes-3-Llama-3-3B-Preview"
    DH8B = "DeepHermes-3-Llama-3-8B-Preview"

In [None]:
def load_model_and_tok(
        model_name:str,
        output_hidden_states:bool=True,
        low_cpu_mem_usage:bool=True,
        local_files_only:bool=True,
        device_map:str="cuda",
        dtype=torch.bfloat16
) -> tuple[AutoModelForCausalLM, AutoTokenizer]:
    """Load a causal LM and its tokenizer from ``model_name``.

    Args:
        model_name: Path or identifier handed to both ``from_pretrained`` calls.
        output_hidden_states: Accepted for backward compatibility but NOT
            forwarded to the model (the forwarding line was disabled in the
            original code and stays disabled here).
        low_cpu_mem_usage: Forwarded to ``AutoModelForCausalLM.from_pretrained``.
        local_files_only: Forwarded to the model AND the tokenizer, so both are
            resolved under the same offline policy.
        device_map: Forwarded to ``AutoModelForCausalLM.from_pretrained``.
        dtype: Forwarded as ``torch_dtype``.

    Returns:
        ``(model, tokenizer)``.
    """
    # The tokenizer previously ignored local_files_only while the model honoured it.
    tok = AutoTokenizer.from_pretrained(
        model_name,
        local_files_only=local_files_only,
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        #output_hidden_states=output_hidden_states,
        low_cpu_mem_usage=low_cpu_mem_usage,
        local_files_only=local_files_only,
        device_map=device_map,
        torch_dtype=dtype,
    )

    return model, tok

In [None]:
def save_fp_acts_to_pt(fp_acts, save_name:str) -> None:
    """Save per-layer activations to ``logs/cka_svcca_acts/<save_name>.pt``.

    Args:
        fp_acts: Mapping of layer name to a dict holding ``"hidden"`` and
            ``"mask"`` tensors. Only these two keys are written.
        save_name: File stem of the output file (``.pt`` is appended).

    The tensors are detached and moved to CPU before saving. The output folder
    is created when missing; previously ``torch.save`` failed on a fresh
    checkout because ``logs/cka_svcca_acts`` did not exist.
    """
    fp_acts_to_save = {
        lname: {
            "hidden": act["hidden"].detach().cpu(),
            "mask": act["mask"].detach().cpu(),
        }
        for lname, act in fp_acts.items()
    }

    out_path = Path("logs/cka_svcca_acts") / f"{save_name}.pt"
    out_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(fp_acts_to_save, out_path)

def load_fp_acts(fp_acts_name:str):
    """Return the object stored in ``logs/cka_svcca_acts/<fp_acts_name>.pt``."""
    acts_file = f"logs/cka_svcca_acts/{fp_acts_name}.pt"
    return torch.load(acts_file)

# -------------------------------
# Parameters & Datasets
# -------------------------------

In [None]:
LL_DIR = "logs/logit_lens_logs"

BATCH_DIR = f"{LL_DIR}/batch_analysis"
INTERP_DIR = f"{LL_DIR}/interp_analysis"

NQ_DIR = "natural_questions"
GS_DIR = "gsm8k"

In [None]:
EPS = 1e-12 
TOPK = 5

In [None]:
GPT_L = ["final_layernorm", "lm_head"]
LLAMA_L = ["norm", "lm_head"]

In [None]:
texts = ["Hello world", "This is a test"]

In [None]:
# Cache folder for the Natural Questions download.
# NOTE(review): hard-coded absolute Windows path — this cell only runs as-is on a
# machine with that drive layout.
filepath = r'D:\LogitLensData\nq'

# The path is passed through Path and converted back to str for cache_dir.
destination_path = str(Path(filepath))
# Requests only the first 10 rows of the train split; because `split` is a dict,
# the result is indexed by the 'train' key in the following cells.
nq_dataset = load_dataset(
    'sentence-transformers/natural-questions',
    split={
        'train': 'train[:10]'
    },
    cache_dir=destination_path,
    download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS,
    keep_in_memory=True
)

In [None]:
nq_queries = list(nq_dataset['train']['query'])
nq_answers = list(nq_dataset['train']['answer'])

In [None]:
# NOTE(review): this cell repeats the previous one without the list(...) wrapper,
# so running it rebinds both names to whatever the dataset column indexing
# returns. Only one of the two cells is needed — confirm which form downstream
# code expects.
nq_queries = nq_dataset['train']['query']
nq_answers = nq_dataset['train']['answer']

In [None]:
# Cache folder for the GSM8K download. This rebinds `filepath` and
# `destination_path` from the Natural Questions cell.
# NOTE(review): hard-coded absolute Windows path — machine-specific.
filepath = r'D:\LogitLensData\gsm8k'

destination_path = str(Path(filepath))
# Requests the 'main' configuration, first 20 rows of the train split; the
# result is indexed by the 'train' key in the following cell.
gsm8k_dataset = load_dataset(
    'gsm8k', 'main',
    split={
        'train': 'train[:20]'
    },
    cache_dir=destination_path,
    download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS,
    keep_in_memory=True
)

In [None]:
gsm8k_questions = gsm8k_dataset['train']['question']
gsm8k_answers = gsm8k_dataset['train']['answer']

In [None]:
DATA = nq_answers

# -------------------------------
# Extract and save FP Activations for CKA/SVCCA
# -------------------------------

In [None]:
llama_fp, llama_tok = load_model_and_tok(Models.LAIN8B.value, device_map="cpu", dtype=torch.float32)

In [None]:
llama_fp_wrapper = LogitLensWrapper(
    model=llama_fp,
    tokenizer=llama_tok,
    block_step=1,
    include_input=True,
    force_include_output=True,
    include_subblocks=True,
    decoder_layer_names=LLAMA_L,
    device="cpu"
)

In [None]:
llama_fp_acts = extract_activations(wrapper=llama_fp_wrapper, prompts=nq_answers)

In [None]:
save_fp_acts_to_pt(llama_fp_acts, "llama_fp_acts_nq_answers")

# -------------------------------
# 2. HF1BitLLM/Llama3-8B-1.58-100B-tokens
# -------------------------------

In [None]:
#hf100b_m, hf100b_tok = load_model_and_tok(Models.HF100B.value, dtype=torch.bfloat16)
hf100b_m, hf100b_tok = load_model_and_tok(Models.HF100B.value, device_map="cpu", dtype=torch.float32)

In [None]:
hf100b_wrapper = LogitLensWrapper(
    model=hf100b_m,
    tokenizer=hf100b_tok,
    block_step=1,
    include_input=True,
    force_include_output=True,
    include_subblocks=True,
    decoder_layer_names=LLAMA_L, 
    device="cpu"
)

In [None]:
hf100b_unsafe_results = analyze_UNSAFE_degradation(hf100b_wrapper, texts, top_k=5, decoder=None)

In [None]:
# Store the UNSAFE-degradation results inside LL_DIR. The "/" separator was
# missing, which produced "logs/logit_lens_logshf100b_unsafe_results.csv"
# next to the log folder instead of inside it.
save_results_to_csv(hf100b_unsafe_results, filename=f"{LL_DIR}/hf100b_unsafe_results.csv")

In [None]:
hf100b_section_scores, hf100b_overall_score = degradation_score(hf100b_unsafe_results)

print("Section scores:", hf100b_section_scores)
print("Overall interpretability score:", hf100b_overall_score)

In [None]:
hf_acts = extract_activations(wrapper=hf100b_wrapper, prompts=texts)

In [None]:
save_fp_acts_to_pt(hf_acts, "hf_acts")

In [None]:
llama_fp_acts = load_fp_acts("llama_fp_acts_nq_answers")

In [None]:
hf_acts = extract_activations(wrapper=hf100b_wrapper, prompts=DATA)

In [None]:
save_fp_acts_to_pt(hf_acts, "hf100b_acts_nq_answers")

In [None]:
hf_acts = load_fp_acts("hf100b_acts_nq_answers")

In [None]:
# Run the logit-lens analysis for the HF 1.58-bit model over DATA.
# NOTE(review): model_name / dataset_name are hard-coded labels; DATA is bound
# to nq_answers earlier in the notebook, so update dataset_name if DATA is
# re-pointed. How these labels are used for output naming is decided inside
# run_logit_lens — not visible here.
run_logit_lens(
    wrapper=hf100b_wrapper,
    prompts=DATA,
    model_name="llama_hf100b",
    dataset_name="nq_answers",
    topk=TOPK,
    eps=EPS,
    skip_input_layer= False, # True to skip embedding
    include_final_norm=True,
    save_layer_probs=False
)

In [None]:
hf_df = pd.read_csv("logs/logit_lens_logs/batch_analysis/dataset_hf1bitllm_0.csv")

In [2]:
import torch

# Load the saved logit-lens results for the NQ-answers / llama_hf100b run.
# SECURITY: weights_only=False performs full pickle deserialization, which can
# execute arbitrary code embedded in the file — only load files produced by
# this project on a trusted machine.
data = torch.load(
    "logs/logit_lens_logs/logit_lens_analysis/nq_answers_llama_hf100b.pt",
    weights_only=False  # allow full unpickling
)


In [3]:
import pandas as pd
hf_df = pd.DataFrame(data)

In [4]:

# NOTE(review): compute_scalar_metrics is neither imported nor defined in the
# visible part of this notebook — confirm where it comes from before re-running.
df_with_metrics = compute_scalar_metrics(hf_df, topk=5)

# Print the four scalar metric columns for inspection.
print(df_with_metrics[['acc_top1_scalar', 'acc_topk_scalar', 'stab_top1_scalar', 'stab_topk_scalar']])


     acc_top1_scalar  acc_topk_scalar  stab_top1_scalar  stab_topk_scalar
0                0.0              0.0               0.0               0.0
1                0.0              0.0               0.0               0.0
2                0.0              0.0               0.0               0.0
3                0.0              0.0               0.0               0.0
4                0.0              0.0               0.0               0.0
..               ...              ...               ...               ...
965              0.0              0.0               0.0               0.0
966              0.0              0.0               0.0               0.0
967              0.0              0.0               0.0               0.0
968              0.0              0.0               0.0               0.0
969              0.0              0.0               0.0               0.0

[970 rows x 4 columns]


In [None]:
df = compute_scalar_metrics(hf_df, topk=5)

# Check scalar metrics
print(df[['acc_top1_scalar', 'acc_topk_scalar', 'stab_top1_scalar', 'stab_topk_scalar']])


In [None]:
# Suppose your DataFrame is hf_df
df_with_scalars = compute_scalar_metrics(hf_df, topk=5)

# See the results
print(df_with_scalars[['acc_top1_scalar', 'acc_topk_scalar', 'stab_top1_scalar', 'stab_topk_scalar']])


In [None]:
valid_layer_names = [lname for lname in hf_df['layer_name'].unique() if lname.startswith('layers.')]
print(valid_layer_names)


In [None]:
df = compute_scalar_metrics(hf_df, topk=5)

# Check the results
print(df[['acc_top1_scalar', 'acc_topk_scalar', 'stab_top1_scalar', 'stab_topk_scalar']])

In [None]:
# Compute the four carry-over-safe metrics for every row of hf_df and store
# them as new columns.
# NOTE(review): prepare_row_tensors_filtered and the carry_over_safe_* functions
# are not imported or defined in the visible part of this notebook — confirm
# their origin.
TOPK = 5
# `df` is an alias of hf_df, not a copy: the columns assigned at the end of this
# cell are added to hf_df as well.
df = hf_df
# One result list per metric, filled row by row
acc_top1_list = []
acc_topk_list = []
stab_top1_list = []
stab_topk_list = []

for _, row in df.iterrows():
    # Prepare tensors, automatically filtering only 'layers.N'
    layer_logits, input_ids, target_ids = prepare_row_tensors_filtered(row)

    # Compute metrics
    acc_top1 = carry_over_safe_accuracy_top1(layer_logits, target_ids, input_ids)
    acc_topk = carry_over_safe_accuracy_topk(layer_logits, target_ids, input_ids, k=TOPK)
    stab_top1 = carry_over_safe_stability_top1(layer_logits, input_ids)
    stab_topk = carry_over_safe_stability_topk(layer_logits, input_ids, k=TOPK)

    # Append results as NumPy arrays
    # (presumably the metrics are CPU tensors without grad, as .numpy() requires — TODO confirm)
    acc_top1_list.append(acc_top1.numpy())
    acc_topk_list.append(acc_topk.numpy())
    stab_top1_list.append(stab_top1.numpy())
    stab_topk_list.append(stab_topk.numpy())

# Assign results back to the DataFrame (one array per row in each column)
df['acc_top1'] = acc_top1_list
df['acc_topk'] = acc_topk_list
df['stab_top1'] = stab_top1_list
df['stab_topk'] = stab_topk_list


In [None]:
print(df[['acc_top1', 'acc_topk', 'stab_top1', 'stab_topk']].head())


In [None]:
import torch
import numpy as np

import torch
import numpy as np

def safe_mean(seq):
    """Average the entries of *seq* and return the result as a Python float.

    ``None`` and empty input both yield ``0.0``. Tensor entries are read with
    ``.item()``; every other entry is passed through ``float()``. The average
    itself is taken over a float32 tensor.
    """
    if seq is None:
        return 0.0
    values = [
        float(entry.item()) if isinstance(entry, torch.Tensor) else float(entry)
        for entry in seq
    ]
    if not values:
        return 0.0
    return float(torch.tensor(values, dtype=torch.float).mean())

df['acc_top1_scalar'] = df['acc_top1'].apply(safe_mean)
df['acc_topk_scalar'] = df['acc_topk'].apply(safe_mean)
df['stab_top1_scalar'] = df['stab_top1'].apply(safe_mean)
df['stab_topk_scalar'] = df['stab_topk'].apply(safe_mean)



In [None]:
df[['acc_top1_scalar', 'acc_topk_scalar', 'stab_top1_scalar', 'stab_topk_scalar']].head()



In [None]:
hf_df.columns

In [None]:
df["layer_name"].unique()

In [None]:
# Same per-row metric computation as the earlier loop, but calling the
# *_vectorized variants and the unfiltered prepare_row_tensors.
# NOTE(review): none of these helpers are imported or defined in the visible
# part of this notebook — confirm their origin.
TOPK = 5
# Alias, not a copy: the columns assigned below are added to hf_df too.
df = hf_df
# One result list per metric, filled row by row
acc_top1_list = []
acc_topk_list = []
stab_top1_list = []
stab_topk_list = []

for _, row in df.iterrows():
    layer_logits, input_ids, target_ids = prepare_row_tensors(row)
    
    acc_top1_list.append(carry_over_safe_accuracy_top1_vectorized(layer_logits, target_ids, input_ids).numpy())
    acc_topk_list.append(carry_over_safe_accuracy_topk_vectorized(layer_logits, target_ids, input_ids, k=TOPK).numpy())
    stab_top1_list.append(carry_over_safe_stability_top1_vectorized(layer_logits, input_ids).numpy())
    stab_topk_list.append(carry_over_safe_stability_topk_vectorized(layer_logits, input_ids, k=TOPK).numpy())

# Assign as new columns (overwrites any columns of the same name from an earlier cell)
df['acc_top1'] = acc_top1_list
df['acc_topk'] = acc_topk_list
df['stab_top1'] = stab_top1_list
df['stab_topk'] = stab_topk_list


In [None]:
hf_df.columns

In [None]:
print(hf_df.head(2))

In [None]:
import torch

# Scratch/debug cell: calls the top-1 accuracy metric for every row but keeps
# nothing — `acc_top1` is overwritten on each iteration and never stored.
# After the loop, layer_logits / input_ids / target_ids hold the last row's values.
df = hf_df
# Alias of hf_df (see assignment above), not a copy
TOPK = 5

for _, row in df.iterrows():
    layer_logits, input_ids, target_ids = prepare_row_tensors(row)

    acc_top1 = carry_over_safe_accuracy_top1(layer_logits, target_ids, input_ids)
    # The remaining three metrics are disabled in this cell.
    #acc_topk = carry_over_safe_accuracy_topk(layer_logits, target_ids, input_ids, k=5)
    #stab_top1 = carry_over_safe_stability_top1(layer_logits, input_ids)
    #stab_topk = carry_over_safe_stability_topk(layer_logits, input_ids, k=5)


In [None]:
hf_df["ece"][:10]

In [None]:
# NOTE(review): scratch cell — `t` is not assigned anywhere in the visible
# notebook, so this raises NameError unless `t` was defined in a cell not shown.
# `input_ids` is whatever an earlier loop left bound on its last iteration.
carry_tokens = set(input_ids[:t+1].tolist())

In [None]:
hf_df.isna().sum()

In [None]:
import pandas as pd
import numpy as np

# Assuming your logit lens data is loaded into a DataFrame `df`
# df = pd.DataFrame(data)  # if not done yet
df = hf_df

cols_to_check = ["kl_next_layer_mean", "kl_next_layer_seq", "nwd"]

def check_nan_inf(x):
    """Return a tuple of flags: (has_nan, has_pos_inf, has_neg_inf).

    Args:
        x: A scalar, or a (possibly nested) list, tuple or NumPy array of
            scalars. Sequences are checked element by element and a flag is
            set if any element sets it.

    Returns:
        Three Python bools. Values that cannot be read as a float (strings,
        ``None``, ...) and empty sequences yield ``(False, False, False)``.

    Tuples and NumPy arrays used to fall into the scalar branch, fail the
    ``float()`` conversion and be silently reported as clean.
    """
    if isinstance(x, np.ndarray):
        # tolist() yields a scalar for 0-d arrays and nested lists otherwise.
        x = x.tolist()
    if isinstance(x, (list, tuple)):
        results = [check_nan_inf(v) for v in x]
        return (
            any(r[0] for r in results),
            any(r[1] for r in results),
            any(r[2] for r in results),
        )
    try:
        val = float(x)
    except (TypeError, ValueError, OverflowError, RuntimeError):
        # Not a number: stay best-effort, but no longer swallow
        # KeyboardInterrupt/SystemExit as the former bare `except:` did.
        return False, False, False
    return bool(np.isnan(val)), bool(np.isposinf(val)), bool(np.isneginf(val))

# Apply check_nan_inf to every cell of each checked column, then split the
# resulting (nan, +inf, -inf) tuples into three boolean Series per column.
flags = {}
for col in cols_to_check:
    col_flags = df[col].map(check_nan_inf)
    flags[col + "_nan"] = col_flags.map(lambda t: t[0])
    flags[col + "_pos_inf"] = col_flags.map(lambda t: t[1])
    flags[col + "_neg_inf"] = col_flags.map(lambda t: t[2])

# Combine into a DataFrame with one boolean column per (column, issue) pair
flags_df = pd.DataFrame(flags)

# Keep the rows of df where at least one flag is set
problem_rows = df[flags_df.any(axis=1)]

print("Rows with NaN or +/- Inf:")
print(problem_rows)


In [None]:
hf_df.isna()

In [None]:
cols_to_check = df.columns

def has_problem(x):
    """Tell whether *x*, read as a float, is NaN or +/-Inf.

    Anything ``float()`` rejects with ``ValueError`` or ``TypeError`` (text,
    ``None``, lists, ...) counts as unproblematic and yields ``False``.
    """
    try:
        number = float(x)
    except (ValueError, TypeError):
        return False
    else:
        return np.isnan(number) or np.isinf(number)

# Boolean mask: True if any column in the row has NaN or ±Inf
# NOTE(review): DataFrame.applymap is deprecated since pandas 2.1 in favour of
# DataFrame.map; switch once the minimum pandas version of this project allows it.
problem_mask = df[cols_to_check].applymap(has_problem).any(axis=1)

# Extract rows with problems
problem_rows = df[problem_mask]

print(f"Rows with NaN or ±Inf ({len(problem_rows)} rows):")
print(problem_rows)


In [None]:
hf_df.isna().sum()

In [None]:
# Print, for every row, the index label and whether kl_next_layer_mean is NaN.
# Series.values is an ndarray attribute (not a method) and ndarrays have no
# .isna(), so the former `.values().isna()` chain raised before the loop ran.
for k, v in hf_df["kl_next_layer_mean"].isna().items():
    print(k, v)

In [None]:
hf_df.info()

In [None]:
def has_problem_extended(x):
    """Flag NaN / +/-Inf scalars and empty list, tuple or ndarray values.

    Returns ``True`` when *x* converts to a float that is NaN or infinite, or
    when *x* is an empty list, tuple or NumPy array; ``False`` otherwise.
    """
    # Numeric check; values float() rejects simply skip it.
    try:
        number = float(x)
    except (ValueError, TypeError):
        number = None
    if number is not None and (np.isnan(number) or np.isinf(number)):
        return True
    # Empty-container check
    return isinstance(x, (list, tuple, np.ndarray)) and len(x) == 0

problem_mask = df.applymap(has_problem_extended).any(axis=1)
problem_rows = df[problem_mask]

print(f"Rows with NaN, ±Inf, or empty sequences ({len(problem_rows)} rows):")
print(problem_rows)


In [None]:
# Show the rows whose kl_next_layer_mean is NaN. The former expression compared
# the column with its own isna() mask and passed the result to .iloc, which
# does not accept a boolean Series; label-based .loc with the mask is the
# intended selection.
hf_df.loc[hf_df["kl_next_layer_mean"].isna()]

In [None]:
# Per-column count of cells equal to +inf / -inf.
# NOTE(review): the comparison is applied to each column as a whole; columns
# holding lists or arrays per cell may not compare element-wise as intended —
# confirm on the object-typed columns.
pos_inf_counts = df.apply(lambda col: (col == np.inf).sum())
neg_inf_counts = df.apply(lambda col: (col == -np.inf).sum())
# Printed in the order: -inf counts, then +inf counts.
print(f"{neg_inf_counts}\n{pos_inf_counts}")

In [None]:
#model.save_pretrained("Models/", safe_serialization=True)
#tokenizer.save_pretrained("Models/")