# -------------------------------
# MI Logit Lens Notebook
# -------------------------------

In [None]:
from pathlib import Path
from datasets import load_dataset, DownloadMode
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from quant_configs.bnb_configs import load_bnb_in_8bit, load_bnb_in_4bit
from mi_utils.util.logit_lens_utils.logit_lens_wrapper import LogitLensWrapper

from mi_utils.logit_lens.logit_lens_degradation_analysis import (
    analyze_UNSAFE_degradation,
    analyze_SAFE_degradation,
)

from mi_utils.logit_lens.logit_lens_analysis import run_logit_lens

from mi_utils.logit_lens.metric_utils.logit_lens_helpers import (
    save_results_to_csv,
    get_activation_tensor,
    extract_activations,
    load_results_from_pt
)

from mi_utils.logit_lens.metric_utils.interp_degradation_scores import (
    degradation_diff_score,
    interpretability_diff_score,
    degradation_score
)

from mi_utils.logit_lens.plotting_utils.lens_plotting import (
    plot_layer_metric_two_dfs,
    plot_layer_deviation,
    qq_plot_probs
)

from mi_utils.logit_lens.metric_utils.carry_over_safe_metrics import get_carry_over_safe_with_embedding

# -------------------------------
# Models
# -------------------------------

In [None]:
from enum import Enum

class Models(Enum):
    """Checkpoint locations handed to ``load_model_and_tok`` via ``.value``.

    Every member must have a distinct value: ``Enum`` turns a repeated value
    into an alias of the first member that used it.
    """

    GPT2 = "Models/GPT2"
    LAIN8B = "Models/LLaMA3Instruct"
    HF100B = "Models/HF1BitLLM100Btokens"
    HF10BL = "Models/HF1BitLLMLinear10B"
    HF10BS = "Models/HF1BitLLMSigmoid10B"
    OL1B = "Models/OLMo1B"
    # Previously "Models/OLMo1B": the duplicated value made OL7B an alias of
    # OL1B, so requesting the 7B model silently resolved to the 1B checkpoint.
    # TODO(review): confirm the actual directory name of the 7B checkpoint.
    OL7B = "Models/OLMo7B"
    # NOTE(review): these two lack the "Models/" prefix used above - confirm
    # whether that is intentional.
    DH3B = "DHLLaMA3B"
    DH8B = "DHLLaMA8B"

class Names(Enum):
    """Model name strings, keyed by the same member names as ``Models``."""

    GPT2 = "GPT2"
    LAIN8B = "Meta-Llama-3-8B-Instruct"
    # NOTE(review): only this entry carries a "Models/" prefix; the other
    # members are bare names - confirm whether the prefix is intended.
    HF100B = "Models/Llama3-8B-1.58-100B-tokens"
    HF10BL = "Llama3-8B-1.58-Linear-10B-tokens"
    HF10BS = "Llama3-8B-1.58-Sigmoid-k100-10B-tokens"
    OL1B = "OLMo-1B-hf"
    OL7B = "OLMo-7B-hf"
    DH3B = "DeepHermes-3-Llama-3-3B-Preview"
    DH8B = "DeepHermes-3-Llama-3-8B-Preview"

In [None]:
def load_model_and_tok(
        model_name:str,
        output_hidden_states:bool=True,
        low_cpu_mem_usage:bool=True,
        local_files_only:bool=True,
        device_map:str="cuda",
        dtype=torch.bfloat16,
        load_in_8bit=False
) -> tuple[AutoModelForCausalLM, AutoTokenizer]:
    """Load a causal language model and its tokenizer from ``model_name``.

    Args:
        model_name: Path or identifier forwarded to ``from_pretrained``.
        output_hidden_states: Kept for backward compatibility; it is
            currently NOT forwarded to the model.
        low_cpu_mem_usage: Forwarded to the model's ``from_pretrained``.
        local_files_only: Forwarded to both the tokenizer and the model so
            neither of them reaches out to the network when set.
        device_map: Forwarded to the model's ``from_pretrained``.
        dtype: Forwarded as ``torch_dtype``.
        load_in_8bit: When true, the model is loaded with a
            ``BitsAndBytesConfig(load_in_8bit=True)`` quantization config.

    Returns:
        The ``(model, tokenizer)`` pair.
    """
    # The tokenizer must honour the same offline setting as the model.
    tok = AutoTokenizer.from_pretrained(
        model_name,
        local_files_only=local_files_only,
    )

    model_kwargs = {
        "low_cpu_mem_usage": low_cpu_mem_usage,
        "local_files_only": local_files_only,
        "device_map": device_map,
        "torch_dtype": dtype,
    }
    if load_in_8bit:
        # Passing ``load_in_8bit`` directly to ``from_pretrained`` is
        # deprecated; the supported route is a quantization config. The key is
        # only added on request so a quantization config stored in the
        # checkpoint's own config is left untouched otherwise.
        model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)

    model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)

    return model, tok

In [None]:
def save_fp_acts_to_pt(fp_acts, save_name:str) -> None:
    """Write activations to ``logs/cka_svcca_acts/{save_name}.pt``.

    Args:
        fp_acts: Mapping of layer name to a dict holding ``"hidden"`` and
            ``"mask"`` tensors.
        save_name: File stem of the output file (without ``.pt``).

    Each tensor is detached and moved to CPU before saving. The output
    directory is created when it does not exist yet.
    """
    fp_acts_to_save = {
        lname: {
            "hidden": act["hidden"].detach().cpu(),
            "mask": act["mask"].detach().cpu(),
        }
        for lname, act in fp_acts.items()
    }

    out_file = f"logs/cka_svcca_acts/{save_name}.pt"
    # torch.save does not create missing parent directories.
    Path(out_file).parent.mkdir(parents=True, exist_ok=True)
    torch.save(fp_acts_to_save, out_file)

def load_fp_acts(fp_acts_name:str):
    """Return the object stored in ``logs/cka_svcca_acts/{fp_acts_name}.pt``."""
    acts_file = f"logs/cka_svcca_acts/{fp_acts_name}.pt"
    return torch.load(acts_file)

# -------------------------------
# Parameters & Datasets
# -------------------------------

In [None]:
# Root directory for logit-lens logs; the two sub-directories are derived from it.
LL_DIR = "logs/logit_lens_logs"

BATCH_DIR = f"{LL_DIR}/batch_analysis"
INTERP_DIR = f"{LL_DIR}/interp_analysis"

# Per-dataset names. NOTE(review): not referenced in this section -
# presumably sub-directory names for the two datasets; confirm.
NQ_DIR = "natural_questions"
GS_DIR = "gsm8k"

In [None]:
# Passed as ``eps`` and ``topk`` to run_logit_lens further down.
EPS = 1e-8 
TOPK = 5

In [None]:
# Name lists supplied as ``decoder_layer_names`` to LogitLensWrapper.
# LLAMA_L is used for every wrapper in this section; GPT_L is not used here
# (presumably the GPT-2 counterpart - confirm).
GPT_L = ["final_layernorm", "lm_head"]
LLAMA_L = ["norm", "lm_head"]

In [None]:
# Two short sample prompts; not referenced elsewhere in this section.
texts = ["Hello world", "This is a test"]

In [None]:
# Machine-specific cache location for the Natural Questions dataset.
filepath = r'D:\LogitLensData\nq'

destination_path = str(Path(filepath))
# Request only the slice 'train[:20]', exposed under the key 'train';
# an existing copy in cache_dir is reused instead of re-downloading.
nq_dataset = load_dataset(
    'sentence-transformers/natural-questions',
    split={
        'train': 'train[:20]'
    },
    cache_dir=destination_path,
    download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS,
    keep_in_memory=True
)

In [None]:
#nq_queries = list(nq_dataset['train']['query'])
#nq_answers = list(nq_dataset['train']['answer'])

In [None]:
# Pull the 'query' and 'answer' columns out of the loaded NQ slice.
# NOTE(review): depending on the installed `datasets` version these may be
# column objects rather than plain lists - wrap in list() if a list is needed.
nq_queries = nq_dataset['train']['query']
nq_answers = nq_dataset['train']['answer']

In [None]:
# Machine-specific cache location for GSM8K; note that this rebinds the
# `filepath` / `destination_path` names used for the NQ dataset above.
filepath = r'D:\LogitLensData\gsm8k'

destination_path = str(Path(filepath))
# Request only the slice 'train[:20]' of the 'main' configuration.
gsm8k_dataset = load_dataset(
    'gsm8k', 'main',
    split={
        'train': 'train[:20]'
    },
    cache_dir=destination_path,
    download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS,
    keep_in_memory=True
)

In [None]:
# Pull the 'question' and 'answer' columns out of the loaded GSM8K slice.
gsm8k_questions = gsm8k_dataset['train']['question']
gsm8k_answers = gsm8k_dataset['train']['answer']

In [None]:
# Prompts fed to run_logit_lens below; currently the NQ queries.
DATA = nq_queries

# -------------------------------
# Extract and save FP Activations for CKA/SVCCA
# -------------------------------

In [None]:
# Unquantized LLaMA-3 Instruct checkpoint, loaded on CPU in float16.
llama_fp, llama_tok = load_model_and_tok(Models.LAIN8B.value, device_map="cpu", dtype=torch.float16)

In [None]:
# Same checkpoint with load_in_8bit=True.
# NOTE(review): 8-bit loading combined with device_map="cpu" may not be
# supported by the quantization backend - confirm.
llama_8bit, llama8bit_tok = load_model_and_tok(Models.LAIN8B.value, device_map="cpu", dtype=torch.float32, load_in_8bit=True)

In [None]:
# Same checkpoint loaded through the project helper load_bnb_in_4bit.
llama_4bit, llama4bit_tok = load_bnb_in_4bit(Models.LAIN8B.value, device_map="cpu")

In [None]:
# Print the dtype of every named parameter of the 4-bit model.
for name, param in llama_4bit.named_parameters():
    print(f"{name}: {param.dtype}")


In [None]:
# LogitLensWrapper around the float16 LLaMA model and its tokenizer,
# configured for CPU with max_len=16 and the LLAMA_L decoder layer names.
llama_fp_wrapper = LogitLensWrapper(
    model=llama_fp,
    tokenizer=llama_tok,
    block_step=1,
    include_input=True,
    force_include_output=True,
    include_subblocks=True,
    decoder_layer_names=LLAMA_L,
    device="cpu",
    max_len=16,
)

In [None]:
# Same wrapper configuration as llama_fp_wrapper, but around the 4-bit
# model and its tokenizer (the 8-bit model is not wrapped in this section).
llama_bnb_wrapper = LogitLensWrapper(
    model=llama_4bit,
    tokenizer=llama4bit_tok,
    block_step=1,
    include_input=True,
    force_include_output=True,
    include_subblocks=True,
    decoder_layer_names=LLAMA_L,
    device="cpu",
    max_len=16,
)

In [None]:
# Extract activations from the float16 wrapper; note the prompts are the
# NQ answers here, not DATA (the NQ queries).
llama_fp_acts = extract_activations(wrapper=llama_fp_wrapper, prompts=nq_answers)

In [None]:
# Persist the activations as logs/cka_svcca_acts/llama_fp_acts_nq_answers.pt.
save_fp_acts_to_pt(llama_fp_acts, "llama_fp_acts_nq_answers")

# -------------------------------
# 2. HF1BitLLM/Llama3-8B-1.58 models (100B-tokens, Linear-10B, Sigmoid-10B)
# -------------------------------

In [None]:
# 100B-tokens checkpoint, loaded on CPU in float32 (the commented-out line
# is the earlier bfloat16 variant using the default device_map).
#hf100b_m, hf100b_tok = load_model_and_tok(Models.HF100B.value, dtype=torch.bfloat16)
hf100b_m, hf100b_tok = load_model_and_tok(Models.HF100B.value, device_map="cpu", dtype=torch.float32)

In [None]:
# Linear-10B checkpoint, loaded on CPU in float32.
hfbl_m, hfbl_tok = load_model_and_tok(Models.HF10BL.value, device_map="cpu", dtype=torch.float32)

In [None]:
# Sigmoid-10B checkpoint, loaded on CPU in float32.
hfbs_m, hfbs_tok = load_model_and_tok(Models.HF10BS.value, device_map="cpu", dtype=torch.float32)

In [None]:
# Wrapper around the Linear-10B model (hfbl_m); swap model/tokenizer here to
# analyse one of the other two checkpoints loaded above.
hf_wrapper = LogitLensWrapper(
    model=hfbl_m,
    tokenizer=hfbl_tok,
    block_step=1,
    include_input=True,
    force_include_output=True,
    include_subblocks=True,
    decoder_layer_names=LLAMA_L, 
    device="cpu",
    max_len=16
)

In [None]:
# Run the logit-lens analysis on DATA with the wrapper built above.
# model_name / dataset_name are hard-coded labels; keep them in sync with
# the wrapped model and with DATA when either is changed.
# NOTE(review): the result file loaded in the next cell is named
# "nq_queries_llama_hfbl_m.pt", presumably built from these two labels.
run_logit_lens(
    wrapper=hf_wrapper,
    prompts=DATA,
    model_name="llama_hfbl_m",
    dataset_name="nq_queries",
    topk=TOPK,
    eps=EPS,
    skip_input_layer= False, # True to skip embedding
    include_final_norm=True,
    save_layer_probs=False,
    proj_precision="fp32",
    max_len=16
)

In [None]:
import torch

# Load the saved logit-lens results.
# SECURITY: weights_only=False performs full unpickling, which can execute
# arbitrary code - only load files produced locally by this notebook.
data = torch.load(
    "logs/logit_lens_logs/logit_lens_analysis/nq_queries_llama_hfbl_m.pt",
    weights_only=False 
)


In [None]:
import pandas as pd
# Turn the loaded results into a DataFrame for inspection.
df = pd.DataFrame(data)

In [None]:
# NOTE(review): topk is hard-coded to 5 here; it equals TOPK today -
# confirm whether it should follow TOPK if that constant changes.
get_carry_over_safe_with_embedding(df, topk=5, prefix='layers.')

In [None]:
# Show the column labels of the results frame.
df.columns

In [None]:
# Show the distinct values of the "layer_name" column.
df["layer_name"].unique()

In [None]:
# Show the column labels again (repeat of the earlier inspection cell).
df.columns

In [None]:
# Inspect the "entropy_seq" entry with index label 2.
df["entropy_seq"][2]

In [None]:
# Preview the first two rows.
df.head(2)

In [None]:
# Count missing values per column.
df.isna().sum()

In [None]:
# Count, per column, the cells that compare equal to +inf / -inf.
# The -inf counts are printed first, then the +inf counts.
# NOTE(review): the comparison is applied to each cell as a whole; values
# nested inside sequence-valued cells are presumably not inspected - confirm.
pos_inf_counts = df.apply(lambda col: (col == np.inf).sum())
neg_inf_counts = df.apply(lambda col: (col == -np.inf).sum())
print(f"{neg_inf_counts}\n{pos_inf_counts}")

In [None]:
#model.save_pretrained("Models/", safe_serialization=True)
#tokenizer.save_pretrained("Models/")