# -------------------------------
# MI Logit Lens Notebook
# -------------------------------

In [1]:
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from datasets import DownloadMode, load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

from quant_configs.bnb_configs import load_bnb_in_8bit, load_bnb_in_4bit
from mi_utils.util.logit_lens_utils.logit_lens_wrapper import LogitLensWrapper

from mi_utils.logit_lens.logit_lens_degradation_analysis import (
    analyze_UNSAFE_degradation,
    analyze_SAFE_degradation,
)

from mi_utils.logit_lens.logit_lens_analysis import run_logit_lens

from mi_utils.logit_lens.metric_utils.logit_lens_helpers import (
    save_results_to_csv,
    get_activation_tensor,
    extract_activations,
    load_results_from_pt
)

from mi_utils.logit_lens.metric_utils.interp_degradation_scores import (
    degradation_diff_score,
    interpretability_diff_score,
    degradation_score
)

from mi_utils.logit_lens.plotting_utils.lens_plotting import (
    plot_layer_metric_two_dfs,
    plot_layer_deviation,
    qq_plot_probs
)

from mi_utils.logit_lens.metric_utils.carry_over_safe_metrics import get_carry_over_safe_with_embedding

# -------------------------------
# Models
# -------------------------------

In [2]:
from enum import Enum

class Models(Enum):
    """Local checkpoint paths; ``.value`` is passed to ``load_model_and_tok``.

    Enum members that share a value become *aliases* of the first member with
    that value, so every path below must be unique.
    """
    GPT2 = "Models/GPT2"
    LAIN8B = "Models/LLaMA3Instruct"
    HF100B = "Models/HF1BitLLM100Btokens"
    HF10BL = "Models/HF1BitLLMLinear10B"
    HF10BS = "Models/HF1BitLLMSigmoid10B"
    OL1B = "Models/OLMo1B"
    # Previously "Models/OLMo1B", which made OL7B a silent alias of OL1B.
    # NOTE(review): confirm the on-disk directory name of the 7B checkpoint.
    OL7B = "Models/OLMo7B"
    # NOTE(review): these two lack the "Models/" prefix used above — confirm.
    DH3B = "DHLLaMA3B"
    DH8B = "DHLLaMA8B"


class Names(Enum):
    """Display names for each model, using the same member names as ``Models``."""
    GPT2 = "GPT2"
    LAIN8B = "Meta-Llama-3-8B-Instruct"
    # NOTE(review): only entry carrying a "Models/" prefix — possibly unintended.
    HF100B = "Models/Llama3-8B-1.58-100B-tokens"
    HF10BL = "Llama3-8B-1.58-Linear-10B-tokens"
    HF10BS = "Llama3-8B-1.58-Sigmoid-k100-10B-tokens"
    OL1B = "OLMo-1B-hf"
    OL7B = "OLMo-7B-hf"
    DH3B = "DeepHermes-3-Llama-3-3B-Preview"
    DH8B = "DeepHermes-3-Llama-3-8B-Preview"

In [3]:
def load_model_and_tok(
        model_name:str,
        output_hidden_states:bool=True,
        low_cpu_mem_usage:bool=True,
        local_files_only:bool=True,
        device_map:str="cuda",
        dtype=torch.bfloat16
) -> tuple[AutoModelForCausalLM, AutoTokenizer]:
    """Load a causal language model and its tokenizer from ``model_name``.

    Args:
        model_name: Path or identifier passed to both ``from_pretrained`` calls.
        output_hidden_states: Currently unused. It is not forwarded to the model
            loader and is kept only so existing call sites keep working.
        low_cpu_mem_usage: Forwarded to ``AutoModelForCausalLM.from_pretrained``.
        local_files_only: Forwarded to BOTH the tokenizer and the model loader,
            so neither touches the network when True.
        device_map: Device placement forwarded to the model loader.
        dtype: Torch dtype the model weights are loaded in.

    Returns:
        ``(model, tokenizer)``.
    """
    # The tokenizer previously ignored local_files_only and could still try to
    # reach the Hub even though the model load was forced offline.
    tok = AutoTokenizer.from_pretrained(model_name, local_files_only=local_files_only)

    # output_hidden_states is intentionally not passed here (it was commented out
    # in the original); presumably the logit-lens wrapper captures activations
    # itself — TODO confirm before wiring it through.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        low_cpu_mem_usage=low_cpu_mem_usage,
        local_files_only=local_files_only,
        device_map=device_map,
        torch_dtype=dtype,
    )

    return model, tok

In [4]:
def save_fp_acts_to_pt(fp_acts, save_name:str, save_dir:str="logs/cka_svcca_acts") -> None:
    """Save detached CPU copies of ``hidden``/``mask`` activations to ``<save_dir>/<save_name>.pt``.

    Args:
        fp_acts: Mapping from a layer key (presumably the layer name produced by
            ``extract_activations``) to a dict with ``"hidden"`` and ``"mask"``
            tensors. Any other per-layer keys are dropped.
        save_name: File stem; ``.pt`` is appended.
        save_dir: Output directory. It is created if missing. Defaults to the
            directory used before this parameter existed.
    """
    out_dir = Path(save_dir)
    # torch.save does not create parent directories; without this a fresh
    # checkout fails with FileNotFoundError.
    out_dir.mkdir(parents=True, exist_ok=True)

    # detach() drops autograd history and cpu() makes the file loadable without a GPU.
    fp_acts_to_save = {
        lname: {
            "hidden": act["hidden"].detach().cpu(),
            "mask": act["mask"].detach().cpu(),
        }
        for lname, act in fp_acts.items()
    }

    torch.save(fp_acts_to_save, out_dir / f"{save_name}.pt")

def load_fp_acts(fp_acts_name:str, save_dir:str="logs/cka_svcca_acts", weights_only:bool=True):
    """Load activations previously written to ``<save_dir>/<fp_acts_name>.pt``.

    Args:
        fp_acts_name: File stem; ``.pt`` is appended.
        save_dir: Directory to read from. Defaults to the directory used before
            this parameter existed.
        weights_only: Forwarded to ``torch.load``. True restricts unpickling to
            tensors and plain containers, which is all ``save_fp_acts_to_pt``
            writes. Pass False only for trusted files holding other objects.

    Returns:
        Whatever object was stored (a dict of dicts of tensors when written by
        ``save_fp_acts_to_pt``).
    """
    return torch.load(Path(save_dir) / f"{fp_acts_name}.pt", weights_only=weights_only)

# -------------------------------
# Parameters & Datasets
# -------------------------------

In [5]:
# Root folder for all logit-lens logs; analysis sub-folders live beneath it.
LL_DIR = "logs/logit_lens_logs"
BATCH_DIR = "/".join((LL_DIR, "batch_analysis"))
INTERP_DIR = "/".join((LL_DIR, "interp_analysis"))

# Per-dataset sub-folder names.
NQ_DIR, GS_DIR = "natural_questions", "gsm8k"

In [6]:
# EPS is passed as `eps` and TOPK as `topk` to run_logit_lens.
EPS, TOPK = 1e-08, 5

In [7]:
# Module names given to LogitLensWrapper as `decoder_layer_names` (LLAMA_L is
# used below). Presumably the final norm + unembedding modules of each
# architecture family — TODO confirm.
GPT_L = [
    "final_layernorm",
    "lm_head",
]
LLAMA_L = [
    "norm",
    "lm_head",
]

In [None]:
# Two short sample prompts for the quick degradation / activation checks.
texts = [
    "Hello world",
    "This is a test",
]

In [8]:
# Dataset cache location. Defaults to the original local drive, but can be
# overridden with the LOGIT_LENS_DATA_DIR environment variable so the notebook
# also runs on machines without that path.
filepath = Path(os.environ.get("LOGIT_LENS_DATA_DIR", r"D:\LogitLensData")) / "nq"
destination_path = str(filepath)

# Only the first 20 training rows are loaded to keep the runs cheap.
nq_dataset = load_dataset(
    'sentence-transformers/natural-questions',
    split={'train': 'train[:20]'},
    cache_dir=destination_path,
    download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS,
    keep_in_memory=True,
)

Generating train split:   0%|          | 0/100231 [00:00<?, ? examples/s]

In [9]:
# Copy the dataset columns into plain Python lists.
nq_queries = [query for query in nq_dataset['train']['query']]
nq_answers = [answer for answer in nq_dataset['train']['answer']]

In [9]:
# Materialize plain lists, matching the other cell that builds these variables.
# Without list() this cell silently replaced those lists with whatever column
# type `datasets` returns, so results depended on cell execution order.
nq_queries = list(nq_dataset['train']['query'])
nq_answers = list(nq_dataset['train']['answer'])

In [None]:
# Dataset cache location. Defaults to the original local drive, but can be
# overridden with the LOGIT_LENS_DATA_DIR environment variable so the notebook
# also runs on machines without that path.
filepath = Path(os.environ.get("LOGIT_LENS_DATA_DIR", r"D:\LogitLensData")) / "gsm8k"
destination_path = str(filepath)

# Only the first 20 training rows of the "main" config are loaded.
gsm8k_dataset = load_dataset(
    'gsm8k', 'main',
    split={'train': 'train[:20]'},
    cache_dir=destination_path,
    download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS,
    keep_in_memory=True,
)

In [None]:
# Materialize plain lists, consistent with how the NQ prompt lists are built.
gsm8k_questions = list(gsm8k_dataset['train']['question'])
gsm8k_answers = list(gsm8k_dataset['train']['answer'])

In [13]:
# Prompt set fed to run_logit_lens (which records it as dataset_name="nq_queries").
DATA = nq_queries

# -------------------------------
# Extract and save full-precision (FP) activations for CKA/SVCCA comparison
# -------------------------------

In [None]:
# Reference Llama-3-8B-Instruct model, loaded in float32 on CPU.
llama_fp, llama_tok = load_model_and_tok(
    model_name=Models.LAIN8B.value,
    device_map="cpu",
    dtype=torch.float32,
)

In [None]:
# Logit-lens wrapper around the float32 reference model.
# NOTE(review): unlike hf100b_wrapper this sets no max_len; confirm both
# wrappers truncate prompts identically before comparing activations.
llama_fp_wrapper = LogitLensWrapper(
    model=llama_fp,
    tokenizer=llama_tok,
    decoder_layer_names=LLAMA_L,
    device="cpu",
    block_step=1,
    include_input=True,
    include_subblocks=True,
    force_include_output=True,
)

In [None]:
# Reference activations on the NQ answers (saved under "llama_fp_acts_nq_answers").
llama_fp_acts = extract_activations(prompts=nq_answers, wrapper=llama_fp_wrapper)

In [None]:
save_fp_acts_to_pt(fp_acts=llama_fp_acts, save_name="llama_fp_acts_nq_answers")

# -------------------------------
# 2. HF1BitLLM/Llama3-8B-1.58-100B-tokens
# -------------------------------

In [10]:
# float32 CPU load. For a bf16 load with the default device_map, call
# load_model_and_tok(Models.HF100B.value, dtype=torch.bfloat16) instead.
hf100b_m, hf100b_tok = load_model_and_tok(
    Models.HF100B.value,
    device_map="cpu",
    dtype=torch.float32,
)

In [11]:
# Logit-lens wrapper around the HF100B model; prompts are capped via max_len=16.
hf100b_wrapper = LogitLensWrapper(
    model=hf100b_m,
    tokenizer=hf100b_tok,
    decoder_layer_names=LLAMA_L,
    device="cpu",
    max_len=16,
    block_step=1,
    include_input=True,
    include_subblocks=True,
    force_include_output=True,
)

In [None]:
# Use the shared TOPK constant rather than a literal 5, so this analysis stays in
# sync with the other top-k settings in the notebook.
hf100b_unsafe_results = analyze_UNSAFE_degradation(hf100b_wrapper, texts, top_k=TOPK, decoder=None)

In [None]:
# LL_DIR has no trailing slash, so the separator must be explicit. Without it the
# file was written to "logs/logit_lens_logshf100b_unsafe_results.csv".
save_results_to_csv(hf100b_unsafe_results, filename=f"{LL_DIR}/hf100b_unsafe_results.csv")

In [None]:
# Summarize hf100b_unsafe_results with degradation_score into section-level
# scores and one overall score, then print both.
hf100b_section_scores, hf100b_overall_score = degradation_score(hf100b_unsafe_results)

print("Section scores:", hf100b_section_scores)
print("Overall interpretability score:", hf100b_overall_score)

In [None]:
# Quick activation extraction on the two sample prompts.
hf_acts = extract_activations(prompts=texts, wrapper=hf100b_wrapper)

In [None]:
save_fp_acts_to_pt(fp_acts=hf_acts, save_name="hf_acts")

In [None]:
# Reload the float32 reference activations saved earlier.
llama_fp_acts = load_fp_acts(fp_acts_name="llama_fp_acts_nq_answers")

In [None]:
# Must use the same prompts as the reference activations ("llama_fp_acts_nq_answers",
# extracted from nq_answers) and as the "..._nq_answers" save name below.
# Previously this used DATA (nq_queries), which made the CKA/SVCCA comparison
# run on mismatched inputs.
hf_acts = extract_activations(wrapper=hf100b_wrapper, prompts=nq_answers)

In [None]:
save_fp_acts_to_pt(fp_acts=hf_acts, save_name="hf100b_acts_nq_answers")

In [None]:
hf_acts = load_fp_acts(fp_acts_name="hf100b_acts_nq_answers")

In [14]:
# Logit-lens pass of the HF100B wrapper over DATA. Results are presumably saved
# under a name derived from model_name / dataset_name (the .pt path loaded
# later matches "nq_queries_llama_hf100b").
run_logit_lens(
    wrapper=hf100b_wrapper,
    prompts=DATA,
    # identifiers for the saved results
    model_name="llama_hf100b",
    dataset_name="nq_queries",
    # metric / projection settings
    topk=TOPK,
    eps=EPS,
    proj_precision="fp32",
    max_len=16,
    # which layers / outputs to include
    skip_input_layer=False,  # True skips the embedding layer
    include_final_norm=True,
    save_layer_probs=False,
)

Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


In [15]:
# Reload the logit-lens results file (its name matches the run_logit_lens
# model_name / dataset_name). weights_only=False performs full unpickling, which
# can execute arbitrary code; only use it on files produced locally.
data = torch.load(
    f"{LL_DIR}/logit_lens_analysis/nq_queries_llama_hf100b.pt",
    weights_only=False,
)


In [16]:
# Tabulate the loaded records. Judging by the head() output, there is one row
# per (prompt_id, layer_index).
hf_df = pd.DataFrame(data=data)

In [17]:
# Use the same TOPK as the run_logit_lens call that produced hf_df, so the
# top-k carry-over metrics match the stored top-k predictions.
get_carry_over_safe_with_embedding(hf_df, topk=TOPK, prefix='layers.')

  return layers_tensor, torch.tensor(input_ids, dtype=torch.long), torch.tensor(target_ids, dtype=torch.long)


{'per_prompt': {0: {'acc_top1_scalar': 0.4117647111415863,
   'acc_topk_scalar': 0.5882353186607361,
   'persistency_top1_scalar': 0.4800758957862854,
   'persistency_topk_scalar': 0.08349145948886871,
   'consistency_top1_scalar': 0.4117647111415863,
   'consistency_topk_scalar': 0.5424836874008179,
   'earliness_top1_scalar': 0.0,
   'earliness_topk_scalar': 0.018975334241986275},
  1: {'acc_top1_scalar': 0.47058823704719543,
   'acc_topk_scalar': 0.6470588445663452,
   'persistency_top1_scalar': 0.4326375424861908,
   'persistency_topk_scalar': 0.09487665444612503,
   'consistency_top1_scalar': 0.47058823704719543,
   'consistency_topk_scalar': 0.5882353186607361,
   'earliness_top1_scalar': 0.0,
   'earliness_topk_scalar': 0.009487674571573734},
  2: {'acc_top1_scalar': 0.23529411852359772,
   'acc_topk_scalar': 0.4117647111415863,
   'persistency_top1_scalar': 0.464895635843277,
   'persistency_topk_scalar': 0.0759013295173645,
   'consistency_top1_scalar': 0.1871657818555832,
   

In [18]:
# Column labels of the results frame (DataFrame.keys() returns .columns).
hf_df.keys()

Index(['prompt_id', 'prompt_text', 'layer_index', 'layer_name', 'seq_len',
       'vocab_size', 'topk_pred_tokens_seq', 'preds_seq', 'top1_mean_prob',
       'topk_mean_prob', 'logit_mean', 'logit_std_mean', 'logit_var_mean',
       'prob_mean', 'prob_std_mean', 'prob_var_mean', 'entropy_seq',
       'normalized_entropy_seq', 'correct_1_seq', 'correct_topk_seq', 'ece',
       'ngram_correct_2_seq', 'ngram_correct_3_seq', 'repetition_ratio',
       'logits', 'input_ids', 'target_ids'],
      dtype='object')

In [None]:
# Distinct layer names present in the results.
hf_df.layer_name.unique()

In [20]:
hf_df.head(n=5)

Unnamed: 0,prompt_id,prompt_text,layer_index,layer_name,seq_len,vocab_size,topk_pred_tokens_seq,preds_seq,top1_mean_prob,topk_mean_prob,...,normalized_entropy_seq,correct_1_seq,correct_topk_seq,ece,ngram_correct_2_seq,ngram_correct_3_seq,repetition_ratio,logits,input_ids,target_ids
0,0,when did richmond last play in a preliminary f...,0,embed_tokens,17,128000,"[[8157, 1601, 75248, 52105, 27014], [78142, 36...","[8157, 78142, 539, 304, 26577, 478, 13, 68025,...",0.018687,0.006148,...,"[0.8536966443061829, 0.8598365783691406, 0.690...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0.018687,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]",0.3125,"[[tensor(-2.6243), tensor(-1.9217), tensor(-1....","[tensor(0), tensor(1), tensor(2), tensor(3), t...","[tensor(279), tensor(279), tensor(1274), tenso..."
1,0,when did richmond last play in a preliminary f...,1,layers.0.self_attn,17,128000,"[[8157, 52105, 1601, 30885, 75248], [13852, 30...","[8157, 13852, 539, 35517, 90646, 365, 25, 2746...",0.013611,0.006437,...,"[0.8585053086280823, 0.8094022274017334, 0.808...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0.013611,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]",0.3125,"[[tensor(-1.9366), tensor(-1.4553), tensor(-1....","[tensor(0), tensor(1), tensor(2), tensor(3), t...","[tensor(279), tensor(279), tensor(1274), tenso..."
2,0,when did richmond last play in a preliminary f...,2,layers.0.mlp,17,128000,"[[78321, 75248, 11881, 38596, 81913], [20693, ...","[78321, 20693, 539, 15638, 90646, 48300, 33583...",0.060587,0.017295,...,"[0.8441224694252014, 0.8093522191047668, 0.261...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0.060587,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]",0.1875,"[[tensor(-0.2932), tensor(-0.8723), tensor(-2....","[tensor(0), tensor(1), tensor(2), tensor(3), t...","[tensor(279), tensor(279), tensor(1274), tenso..."
3,0,when did richmond last play in a preliminary f...,3,layers.0,17,128000,"[[78321, 75248, 11881, 38596, 81913], [20693, ...","[78321, 20693, 539, 15638, 90646, 48300, 33583...",0.060587,0.017295,...,"[0.8441224694252014, 0.8093522191047668, 0.261...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0.060587,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]",0.1875,"[[tensor(-0.2932), tensor(-0.8723), tensor(-2....","[tensor(0), tensor(1), tensor(2), tensor(3), t...","[tensor(279), tensor(279), tensor(1274), tenso..."
4,0,when did richmond last play in a preliminary f...,4,layers.1.self_attn,17,128000,"[[117054, 30952, 126459, 91871, 63643], [43577...","[117054, 43577, 539, 893, 90646, 48300, 1797, ...",0.028503,0.008836,...,"[0.9160845279693604, 0.8508396148681641, 0.568...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0.028503,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]",0.125,"[[tensor(1.7907), tensor(0.8550), tensor(1.499...","[tensor(0), tensor(1), tensor(2), tensor(3), t...","[tensor(279), tensor(279), tensor(1274), tenso..."


In [19]:
# Missing-value count per column (isnull is an alias of isna).
hf_df.isnull().sum()

prompt_id                 0
prompt_text               0
layer_index               0
layer_name                0
seq_len                   0
vocab_size                0
topk_pred_tokens_seq      0
preds_seq                 0
top1_mean_prob            0
topk_mean_prob            0
logit_mean                0
logit_std_mean            0
logit_var_mean            0
prob_mean                 0
prob_std_mean             0
prob_var_mean             0
entropy_seq               0
normalized_entropy_seq    0
correct_1_seq             0
correct_topk_seq          0
ece                       0
ngram_correct_2_seq       0
ngram_correct_3_seq       0
repetition_ratio          0
logits                    0
input_ids                 0
target_ids                0
dtype: int64

In [None]:
# Count ±inf per numeric column. Object columns in hf_df hold lists / tensors,
# where an elementwise `== np.inf` is meaningless (and may raise if a cell is a
# multi-element tensor), so only numeric columns are checked.
# NOTE: infinities inside list-valued cells (e.g. entropy_seq) are NOT detected.
numeric_cols = hf_df.select_dtypes(include="number")
pos_inf_counts = (numeric_cols == np.inf).sum()
neg_inf_counts = (numeric_cols == -np.inf).sum()
pd.DataFrame({"neg_inf": neg_inf_counts, "pos_inf": pos_inf_counts})

In [None]:
#model.save_pretrained("Models/", safe_serialization=True)
#tokenizer.save_pretrained("Models/")