# -------------------------------
# MI Logit Lens Notebook
# -------------------------------

In [None]:
from pathlib import Path
from datasets import load_dataset, DownloadMode
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from quant_configs.bnb_configs import load_bnb_in_8bit, load_bnb_in_4bit
from mi_utils.util.logit_lens_utils.logit_lens_wrapper import LogitLensWrapper

from mi_utils.logit_lens.logit_lens_degradation_analysis import (
    analyze_UNSAFE_degradation,
    analyze_SAFE_degradation,
)

from mi_utils.logit_lens.logit_lens_unsafe_analysis import run_logit_lens_unsafe
from mi_utils.logit_lens.logit_lens_analysis import run_logit_lens

from mi_utils.logit_lens.metric_utils.logit_lens_helpers import (
    save_results_to_csv,
    get_activation_tensor,
    extract_activations
)

from mi_utils.logit_lens.metric_utils.interp_degradation_scores import (
    degradation_diff_score,
    interpretability_diff_score,
    degradation_score
)

from mi_utils.logit_lens.plotting_utils.lens_plotting import (
    plot_layer_metric_two_dfs,
    plot_layer_deviation,
    qq_plot_probs
)

# -------------------------------
# Models
# -------------------------------

In [None]:
from enum import Enum

class Models(Enum):
    """Checkpoint locations of the models analysed in this notebook.

    Each value is the string handed to the model loaders. Values must be
    unique: ``Enum`` turns a member whose value repeats an earlier one into
    an alias of that earlier member, so it stops being a distinct model.
    """

    GPT2 = "Models/GPT2"
    LAIN8B = "Models/LLaMA3Instruct"
    HF100B = "Models/HF1BitLLM100Btokens"
    HF10BL = "Models/HF1BitLLMLinear10B"
    HF10BS = "Models/HF1BitLLMSigmoid10B"
    OL1B = "Models/OLMo1B"
    # This entry used to repeat OL1B's path, which made OL7B an alias of OL1B.
    # NOTE(review): directory name inferred from the OL1B pattern - confirm
    # that it matches the folder on disk.
    OL7B = "Models/OLMo7B"
    # NOTE(review): the DeepHermes entries have no "Models/" prefix, unlike
    # every other entry - confirm whether that is intentional.
    DH3B = "DHLLaMA3B"
    DH8B = "DHLLaMA8B"

class Names(Enum):
    """Model name strings, one per model key.

    Presumably the upstream (Hugging Face) repository names of the
    checkpoints - confirm against the download script. Values are bare
    names without any directory prefix.
    """

    GPT2 = "GPT2"
    LAIN8B = "Meta-Llama-3-8B-Instruct"
    # A stray "Models/" prefix was removed here so this entry is a bare name
    # like all the others.
    HF100B = "Llama3-8B-1.58-100B-tokens"
    HF10BL = "Llama3-8B-1.58-Linear-10B-tokens"
    HF10BS = "Llama3-8B-1.58-Sigmoid-k100-10B-tokens"
    OL1B = "OLMo-1B-hf"
    OL7B = "OLMo-7B-hf"
    DH3B = "DeepHermes-3-Llama-3-3B-Preview"
    DH8B = "DeepHermes-3-Llama-3-8B-Preview"

In [None]:
def load_model_and_tok(
        model_name: str,
        output_hidden_states: bool = True,
        low_cpu_mem_usage: bool = True,
        local_files_only: bool = True,
        device_map: str = "cuda",
        dtype: torch.dtype = torch.bfloat16
) -> tuple[AutoModelForCausalLM, AutoTokenizer]:
    """Load a causal language model and its tokenizer.

    Args:
        model_name: Path or identifier passed to both ``from_pretrained`` calls.
        output_hidden_states: Accepted for backward compatibility only; it is
            NOT forwarded to the model (forwarding was disabled in the
            original code).
        low_cpu_mem_usage: Forwarded to the model's ``from_pretrained``.
        local_files_only: Forwarded to both the tokenizer and the model
            loader, so neither of them reaches out to the network when True.
        device_map: Forwarded to the model's ``from_pretrained``.
        dtype: Forwarded to the model's ``from_pretrained`` as ``torch_dtype``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # Honour local_files_only for the tokenizer as well; previously only the
    # model load respected it.
    tok = AutoTokenizer.from_pretrained(
        model_name,
        local_files_only=local_files_only,
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        low_cpu_mem_usage=low_cpu_mem_usage,
        local_files_only=local_files_only,
        device_map=device_map,
        torch_dtype=dtype,
    )

    return model, tok

# -------------------------------
# Parameters & Datasets
# -------------------------------

In [None]:
# Root folder of the logit-lens logs; note that it has no trailing slash.
LL_DIR = "logs/logit_lens_logs"

# Sub-folders located directly under LL_DIR.
BATCH_DIR = LL_DIR + "/batch_analysis"
INTERP_DIR = LL_DIR + "/interp_analysis"

# Dataset folder names.
NQ_DIR, GS_DIR = "natural_questions", "gsm8k"

In [None]:
# Small positive constant; presumably a numerical-stability epsilon - it is
# not referenced in the cells visible here.
EPS = 1e-12 
# k passed as `topk` / `top_k` to the analysis helpers in later cells.
TOPK = 5

In [None]:
# Decoder layer names per model family; these match the
# `decoder_layer_names` given to the LogitLensWrapper instances in later cells.
GPT_L = ["final_layernorm", "lm_head"]
LLAMA_L = ["norm", "lm_head"]

In [None]:
# Two short prompts used as the input of every analysis call in later cells.
texts = ["Hello world", "This is a test"]

In [None]:
# Cache directory for the Natural Questions data. Absolute Windows path; the
# raw string keeps the backslashes literal.
filepath = r'D:\LogitLensData\nq'

destination_path = str(Path(filepath))
# Request only the first 20 rows of the train split, reuse an existing copy
# in the cache directory when there is one, and keep the data in memory.
nq_dataset = load_dataset(
    'sentence-transformers/natural-questions',
    split={
        'train': 'train[:20]'
    },
    cache_dir=destination_path,
    download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS,
    keep_in_memory=True
)

In [None]:
# 'query' and 'answer' columns of the loaded train subset.
nq_queries= nq_dataset['train']['query']
nq_answers = nq_dataset['train']['answer']

In [None]:
# Cache directory for the GSM8K data. Rebinds the `filepath` and
# `destination_path` globals used for the previous dataset.
filepath = r'D:\LogitLensData\gsm8k'

destination_path = str(Path(filepath))
# Request only the first 20 rows of the train split of the 'main'
# configuration, reuse an existing cached copy, and keep the data in memory.
gsm8k_dataset = load_dataset(
    'gsm8k', 'main',
    split={
        'train': 'train[:20]'
    },
    cache_dir=destination_path,
    download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS,
    keep_in_memory=True
)

In [None]:
# 'question' and 'answer' columns of the loaded train subset.
gsm8k_questions = gsm8k_dataset['train']['question']
gsm8k_answers = gsm8k_dataset['train']['answer']

In [None]:
# Selected prompt set; NOTE(review): the cells visible here pass `texts`
# rather than DATA to the analysis helpers - confirm which one is intended.
DATA = nq_queries

# -------------------------------
# 1. GPT2
# -------------------------------

## -------------------------------
## 1.a. GPT2: 4-bit BnB
## -------------------------------

In [None]:
gpt2, gpt2_tok = load_model_and_tok(Models.GPT2.value, dtype=torch.float32)

In [None]:
gpt2_wrapper = LogitLensWrapper(
    model=gpt2,
    tokenizer=gpt2_tok,
    block_step=1,
    include_input=True,
    force_include_output=True,
    include_subblocks=True,
    decoder_layer_names=["final_layernorm", "lm_head"],
    device="cuda"
)

In [None]:
fp_acts = extract_activations(wrapper=gpt2_wrapper, prompts=texts)

In [None]:
fp_acts

In [None]:
gpt2_4bit, gpt2_4bit_tok = load_bnb_in_4bit(Models.GPT2.value, dtype=torch.float16)

In [None]:
gpt2_4bit_wrapper = LogitLensWrapper(
    model=gpt2_4bit,
    tokenizer=gpt2_4bit_tok,
    block_step=1,
    include_input=True,
    force_include_output=True,
    include_subblocks=True,
    decoder_layer_names=["final_layernorm", "lm_head"],
    device="cuda"
)

In [None]:
q_acts = extract_activations(wrapper=gpt2_4bit_wrapper, prompts=texts)

In [None]:
run_logit_lens_unsafe(
    wrapper=gpt2_4bit_wrapper,
    prompts=texts,
    model_name="gpt2_4bit",
    dataset_name="test",
    A_acts=fp_acts,
    B_acts=q_acts,
    topk=TOPK,
    skip_input_layer= True,
    include_final_norm=True,
    save_layer_probs=False
)

In [None]:
df = pd.read_csv("logs/logit_lens_logs/batch_analysis/test_gpt2_4bit.csv")

In [None]:
print(df.head(2))

In [None]:
print(df.head(2))

In [None]:
df.columns

In [None]:
df["entropy_seq"]

In [None]:
pos_inf_counts = df.apply(lambda col: (col == np.inf).sum())
neg_inf_counts = df.apply(lambda col: (col == -np.inf).sum())
print(f"{neg_inf_counts}\n{pos_inf_counts}")

In [None]:
for k,v in df.items():
    print(v[0])

In [None]:
df["layer_name"].unique()

In [None]:
df.isna().sum()

In [None]:
run_logit_lens_unsafe(
    wrapper=gpt2_wrapper,
    prompts=texts,
    model_name="gpt2_fp",
    dataset_name="test",
    A_acts=fp_acts,
    B_acts=fp_acts,
    topk=TOPK,
    skip_input_layer= True,
    include_final_norm=True,
    save_layer_probs=False
)

In [None]:
df_fp = pd.read_csv("logs/logit_lens_logs/batch_analysis/test_gpt2_fp.csv")

In [None]:
print(df_fp.head(2))

In [None]:
df_fp["layer_name"].unique()

In [None]:
df_fp.columns

In [None]:
pos_inf_counts = df_fp.apply(lambda col: (col == np.inf).sum())
neg_inf_counts = df_fp.apply(lambda col: (col == -np.inf).sum())
print(f"{neg_inf_counts}\n{pos_inf_counts}")

In [None]:
df_fp.isna().sum()

In [None]:
print(df.head(2))

In [None]:
print(df_fp.head(2))

In [None]:
gpt2_4bit_unsafe_results = analyze_UNSAFE_degradation(gpt2_4bit_wrapper, texts=texts, top_k=TOPK, decoder=None)

In [None]:
save_results_to_csv(gpt2_4bit_unsafe_results, filename=f"{LL_DIR}gpt2-4bit_unsafe_results.csv")

In [None]:
gpt2_4bit_section_scores, gpt2_4bit_overall_score = degradation_score(gpt2_4bit_unsafe_results)

print("Section scores:", gpt2_4bit_section_scores)
print("Overall interpretability score:", gpt2_4bit_overall_score)

In [None]:
gpt2_4bit_safe_results = analyze_SAFE_degradation(gpt2_4bit_wrapper, texts, top_k=5, decoder=None)

In [None]:
save_results_to_csv(gpt2_4bit_safe_results, filename=f"{LL_DIR}gpt2-4bit_safe_results.csv")

# -------------------------------
# 2. HF1BitLLM/Llama3-8B-1.58-100B-tokens
# -------------------------------

In [None]:
#hf100b_m, hf100b_tok = load_model_and_tok(Models.HF100B.value, dtype=torch.bfloat16)
hf100b_m, hf100b_tok = load_model_and_tok(Models.HF100B.value, dtype=torch.float32)

In [None]:
wrapper = LogitLensWrapper(
    model=hf100b_m,
    tokenizer=hf100b_tok,
    block_step=1,
    include_input=True,
    force_include_output=True,
    include_subblocks=True,
    decoder_layer_names=["norm", "lm_head"], 
    device="cuda"
)

In [None]:
hf100b_unsafe_results = analyze_UNSAFE_degradation(wrapper, texts, top_k=5, decoder=None)

In [None]:
save_results_to_csv(hf100b_unsafe_results, filename=f"{LL_DIR}hf100b_unsafe_results.csv")

In [None]:
hf100b_section_scores, hf100b_overall_score = degradation_score(hf100b_unsafe_results)

print("Section scores:", hf100b_section_scores)
print("Overall interpretability score:", hf100b_overall_score)

In [None]:
# Activations of the HF1BitLLM wrapper on the test prompts.
hf_acts = extract_activations(wrapper=wrapper, prompts=texts)

In [None]:
# Logit-lens run with the same activations passed as both A and B.
# Unlike the GPT-2 runs, no dataset_name is given and save_layer_probs is True.
run_logit_lens_unsafe(
    wrapper=wrapper,
    prompts=texts,
    model_name="hf1bitllm_0",
    A_acts=hf_acts,
    B_acts=hf_acts,
    topk=TOPK,
    skip_input_layer= True,
    include_final_norm=True,
    save_layer_probs=True
)

In [None]:
# NOTE(review): the file name starts with "dataset" although no dataset_name
# was passed above - presumably the helper's default value; confirm.
hf_df = pd.read_csv("logs/logit_lens_logs/batch_analysis/dataset_hf1bitllm_0.csv")

In [None]:
hf_df.columns

In [None]:
print(hf_df.head(2))

In [None]:
# NOTE(review): `df` is the GPT-2 4-bit frame loaded in an earlier section,
# not hf_df - presumably shown here to compare columns; confirm.
df.columns

In [None]:
#model.save_pretrained("Models/", safe_serialization=True)
#tokenizer.save_pretrained("Models/")