# -------------------------------
# MI Logit Lens Notebook
# -------------------------------

In [None]:
from pathlib import Path
from datasets import load_dataset, DownloadMode
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from quant_configs.bnb_configs import load_bnb_in_4bit
from mi_utils.util.logit_lens_utils.logit_lens_wrapper import LogitLensWrapper
from mi_utils.logit_lens.logit_lens_analysis import run_logit_lens
from mi_utils.logit_lens.metric_utils.carry_over_safe_metrics import get_carry_over_safe_with_embedding, compute_carry_over_safe_partitioned

# -------------------------------
# Models
# -------------------------------

In [None]:
from enum import Enum

class Models(Enum):
    """Model checkpoint locations, keyed by a short model tag.

    The member value is what gets handed to the loader helpers in this
    notebook as the model identifier (e.g. ``Models.DH3B.value``).
    The member names mirror those of ``Names``.
    """
    # NOTE(review): values look like directories relative to the working
    # directory (the loader defaults to local_files_only=True) — confirm.
    #GPT2 = "Models/GPT2"
    LAIN8B = "Models/LLaMA3Instruct"
    HF100B = "Models/HF1BitLLM100Btokens"
    HF10BL = "Models/HF1BitLLMLinear10B"
    HF10BS = "Models/HF1BitLLMSigmoid10B"
    LAIN3B = "Models/LLaMA3Instruct3B"
    DH3B = "Models/DHLLaMA3B"
    DH8B = "Models/DHLLaMA8B"

class Names(Enum):
    """Model names, using the same member keys as ``Models``.

    Not referenced by any other cell visible in this notebook.
    """
    #GPT2 = "GPT2"
    LAIN8B = "Meta-Llama-3-8B-Instruct"
    # NOTE(review): this is the only value carrying a "Models/" prefix; the
    # sibling 1.58-bit entries are bare names — possibly a copy-paste slip,
    # confirm against whatever consumes Names.
    HF100B = "Models/Llama3-8B-1.58-100B-tokens"
    HF10BL = "Llama3-8B-1.58-Linear-10B-tokens"
    HF10BS = "Llama3-8B-1.58-Sigmoid-k100-10B-tokens"
    # NOTE(review): includes an org prefix ("meta-llama/") unlike the other
    # entries — confirm this is intentional.
    LAIN3B = "meta-llama/Llama-3.2-3B-Instruct"
    DH3B = "DeepHermes-3-Llama-3-3B-Preview"
    DH8B = "DeepHermes-3-Llama-3-8B-Preview"

In [None]:
def load_model_and_tok(
        model_name: str,
        output_hidden_states: bool = True,
        low_cpu_mem_usage: bool = True,
        local_files_only: bool = True,
        device_map: str = "cuda",
        dtype=torch.bfloat16,
        load_in_8bit=False
) -> tuple[AutoModelForCausalLM, AutoTokenizer]:
    """Load a causal language model and its tokenizer.

    Args:
        model_name: Path or identifier handed to ``from_pretrained`` for both
            the tokenizer and the model.
        output_hidden_states: Currently unused. It is kept in the signature so
            existing callers keep working; it is deliberately not forwarded
            to ``from_pretrained``.
        low_cpu_mem_usage: Forwarded to the model's ``from_pretrained``.
        local_files_only: Forwarded to BOTH the tokenizer and the model load,
            so that neither step reaches out to the network when it is True.
        device_map: Forwarded to the model's ``from_pretrained``.
        dtype: Forwarded to the model's ``from_pretrained`` as ``torch_dtype``.
        load_in_8bit: When truthy, the model is loaded with
            ``BitsAndBytesConfig(load_in_8bit=True)`` as its quantization
            config. When falsy, no quantization argument is passed at all.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # Honour local_files_only for the tokenizer too; previously only the model
    # load respected it, so the tokenizer could still trigger a download.
    tok = AutoTokenizer.from_pretrained(
        model_name,
        local_files_only=local_files_only,
    )

    model_kwargs = {
        "low_cpu_mem_usage": low_cpu_mem_usage,
        "local_files_only": local_files_only,
        "device_map": device_map,
        "torch_dtype": dtype,
    }
    if load_in_8bit:
        # The bare `load_in_8bit=` kwarg of from_pretrained is deprecated in
        # favour of an explicit quantization config, and is only supplied
        # when 8-bit loading is actually requested.
        model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)

    model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)

    return model, tok

In [None]:
# Shared settings for the logit-lens run cells below.
EPS = 1e-8  # passed as `eps` to run_logit_lens (presumably a numerical-stability epsilon — verify in the helper)
TOPK = 5  # passed as `topk` to run_logit_lens

#GPT_L = ["final_layernorm", "lm_head"]
# Passed as `decoder_layer_names` to LogitLensWrapper for the LLaMA-based models.
LLAMA_L = ["norm", "lm_head"]

In [None]:
# Local cache directory for the Natural Questions data.
filepath = r'D:\LogitLensData\nq'
destination_path = str(Path(filepath))

# Only the first 20 training rows are requested; because `split` is a mapping,
# the result is indexed by the 'train' key in the following cell.
nq_dataset = load_dataset(
    path='sentence-transformers/natural-questions',
    split=dict(train='train[:20]'),
    cache_dir=destination_path,
    download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS,
    keep_in_memory=True,
)

In [None]:
# Alternative that forces plain Python lists (disabled):
#nq_queries = list(nq_dataset['train']['query'])
#nq_answers = list(nq_dataset['train']['answer'])

# Pull the 'query' and 'answer' columns out of the 20-row train slice.
nq_queries = nq_dataset['train']['query']
nq_answers = nq_dataset['train']['answer']

In [None]:
# Local cache directory for the GSM8K data (rebinds the names used by the
# Natural Questions cell).
filepath = r'D:\LogitLensData\gsm8k'
destination_path = str(Path(filepath))

# 'main' configuration, first 20 training rows only; the result is indexed by
# the 'train' key in the following cell.
gsm8k_dataset = load_dataset(
    path='gsm8k',
    name='main',
    split=dict(train='train[:20]'),
    cache_dir=destination_path,
    download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS,
    keep_in_memory=True,
)

In [None]:
# Pull the 'question' and 'answer' columns out of the 20-row GSM8K train slice.
gsm8k_questions = gsm8k_dataset['train']['question']
gsm8k_answers = gsm8k_dataset['train']['answer']

In [None]:
# Prompt sets for the logit-lens runs: Natural Questions queries and GSM8K
# questions. DATA_NQ is what the run_logit_lens cell below passes as `prompts`.
DATA_NQ = nq_queries
DATA_GS = gsm8k_questions

In [None]:
# Load the DH3B checkpoint with load_in_8bit=True, placed on CPU in float32.
# NOTE(review): bitsandbytes 8-bit loading has traditionally required a CUDA
# device — confirm this combination works with device_map="cpu".
# NOTE(review): dh3b_8bit is not used by any later cell visible here.
dh3b_8bit, dh3b_8bit_tok = load_model_and_tok(Models.DH3B.value, device_map="cpu", dtype=torch.float32, load_in_8bit=True)

In [None]:
# Load the same DH3B checkpoint through the project's 4-bit bitsandbytes
# helper; the returned pair is unpacked as (model, tokenizer) and fed to
# LogitLensWrapper below.
dh3b_4bit, dh3b_4bit_tok = load_bnb_in_4bit(Models.DH3B.value, device_map="cpu")

In [None]:
"""for name, param in model.named_parameters():
    print(f"{name}: {param.dtype}")"""

In [None]:
# Wrap the 4-bit DH3B model and its tokenizer for logit-lens analysis.
# The option semantics are defined by the project's LogitLensWrapper; the
# values are documented here only as far as this notebook shows.
bnb_wrapper = LogitLensWrapper(
    model=dh3b_4bit,
    tokenizer=dh3b_4bit_tok,
    block_step=1,
    include_input=True,
    force_include_output=True,
    include_subblocks=True,
    decoder_layer_names=LLAMA_L,  # ["norm", "lm_head"]
    device="cpu",  # matches the device_map used when loading dh3b_4bit
    max_len=16,  # same value is passed to run_logit_lens below
)

In [None]:
# Run the logit-lens analysis for the 4-bit DH3B wrapper on the Natural
# Questions prompts. `model_name` and `dataset_name` are plain labels here;
# NOTE(review): presumably they determine the output file name — verify in
# run_logit_lens. Keep them in sync with `wrapper` and `prompts` when
# switching model or dataset.
run_logit_lens(
    wrapper=bnb_wrapper,
    prompts=DATA_NQ,
    model_name="dh3b_4bit",
    dataset_name="nq_queries", # gsm8k_questions, nq_queries
    topk=TOPK,
    eps=EPS,
    skip_input_layer= False, # True to skip embedding
    include_final_norm=True,
    save_layer_probs=False,
    proj_precision="fp32",
    max_len=16  # same value as the wrapper's max_len
)

In [None]:
import torch

# Load a saved .pt file from the logit-lens analysis log directory.
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects, which can execute code — only load files you produced or trust.
# NOTE(review): the file name refers to gsm8k_questions / llama3b_8bit, while
# the run cell above uses dataset_name="nq_queries" and model_name="dh3b_4bit"
# — confirm this is the file you intend to inspect.
data = torch.load(
    "logs/logit_lens_logs/logit_lens_analysis/gsm8k_questions_llama3b_8bit.pt",
    weights_only=False 
)


In [None]:
import pandas as pd
# Put the loaded results into a DataFrame for inspection.
# Assumes `data` is something pd.DataFrame accepts (e.g. records or a dict of
# columns) — TODO confirm against what run_logit_lens saves.
df = pd.DataFrame(data)

In [None]:
# Compute the project's carry-over-safe metric on the loaded results.
# topk is hard-coded to 5 here, which equals TOPK defined earlier.
# NOTE(review): prefix='layers.' presumably selects layer names starting with
# that prefix — verify in the helper.
get_carry_over_safe_with_embedding(df, topk=5, prefix='layers.')

In [None]:
# Show the column labels of the loaded results.
df.columns

In [None]:
# List the distinct values of the layer_name column.
df["layer_name"].unique()

In [None]:
# Inspect the entropy_seq entry at index label 2 (the third row if the frame
# has the default integer index — TODO confirm).
df["entropy_seq"][2]

In [None]:
# Preview the first two rows.
df.head(2)

In [None]:
# Count missing values per column.
df.isna().sum()

In [None]:
# Per-column count of entries equal to +inf and to -inf, respectively.
# The target is bound as a lambda default so each pass compares against its
# own value.
pos_inf_counts, neg_inf_counts = (
    df.apply(lambda col, target=target: (col == target).sum())
    for target in (np.inf, -np.inf)
)
# Negative-infinity counts are printed first, then positive-infinity counts.
print(neg_inf_counts, pos_inf_counts, sep="\n")

In [None]:
#model.save_pretrained("Models/", safe_serialization=True)
#tokenizer.save_pretrained("Models/")