# -------------------------------
# MI Logit Lens Notebook
# -------------------------------

In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from quant_configs.bnb_configs import load_bnb_in_8bit, load_bnb_in_4bit
from mi_utils.util.logit_lens_utils.logit_lens_wrapper import LogitLensWrapper
from mi_utils.logit_lens.logit_lens_interpretability import (
    analyze_logit_lens_batch,
    analyze_UNSAFE_interpretability,
    analyze_SAFE_interpretability,
    interpretability_UNSAFE_score,
    save_results_to_csv
)

In [2]:
from enum import Enum

class Models(Enum):
    """Checkpoint locations for the models analysed in this notebook.

    Each member's ``value`` is a relative path string that is handed to the
    model/tokenizer loading helpers (e.g. ``Models.GPT2.value``).
    """

    GPT2 = "Models/GPT2"
    # The misspelled name is kept as the canonical member so existing
    # references (and ``.name`` lookups) keep working unchanged.
    INSTUCT8B = "Models/LLaMA3Instruct"
    # Correctly spelled alias of INSTUCT8B (same value -> Enum alias, so it
    # does not appear as an extra member when iterating over Models).
    INSTRUCT8B = "Models/LLaMA3Instruct"
    HF100B = "Models/HF1BitLLM100Btokens"
    HF10BL = "Models/HF1BitLLMLinear10B"
    HF10BS = "Models/HF1BitLLMSigmoid10B"


# Directory prefix for the CSV result files written by this notebook.
# The trailing slash is required: file names are appended to it directly
# inside f-strings (f"{LL_DIR}<name>.csv").
LL_DIR = "logs/logit_lens_logs/"

In [3]:
def load_model_and_tok(
        model_name:str,
        output_hidden_states:bool=True,
        low_cpu_mem_usage:bool=True,
        local_files_only:bool=True,
        device_map:str="cuda",
        dtype=torch.bfloat16
) -> tuple[AutoModelForCausalLM, AutoTokenizer]:
    """Load a causal language model and its tokenizer from the same location.

    Args:
        model_name: Path or identifier passed to both ``from_pretrained`` calls.
        output_hidden_states: Kept in the signature for backward compatibility
            but currently NOT forwarded to the model (the forwarding line was
            commented out in the original call and is left disabled here so
            existing behaviour does not change).
        low_cpu_mem_usage: Forwarded to ``AutoModelForCausalLM.from_pretrained``.
        local_files_only: Forwarded to BOTH the tokenizer and the model so that
            the two loads resolve files consistently (previously only the model
            honoured it).
        device_map: Forwarded to ``AutoModelForCausalLM.from_pretrained``.
        dtype: Forwarded to the model as ``torch_dtype``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    tok = AutoTokenizer.from_pretrained(
        model_name,
        local_files_only=local_files_only,
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        # output_hidden_states is deliberately not passed; see the docstring.
        low_cpu_mem_usage=low_cpu_mem_usage,
        local_files_only=local_files_only,
        device_map=device_map,
        torch_dtype=dtype,
    )

    return model, tok

In [4]:
# Prompts passed to every logit-lens analysis call in this notebook.
texts = ["Hello world", "This is a test"]

# -------------------------------
# 1. GPT2
# -------------------------------

## -------------------------------
## 1.a. GPT2: 4-bit BnB
## -------------------------------

In [5]:
# Load GPT-2 through the project's 4-bit loading helper with float16 as dtype;
# the result is unpacked here as (model, tokenizer).
gpt2_4bit, gpt2_4bit_tok = load_bnb_in_4bit(Models.GPT2.value, dtype=torch.float16)

In [6]:
# Wrap the 4-bit GPT-2 model and its tokenizer for logit-lens analysis.
wrapper = LogitLensWrapper(
    model=gpt2_4bit,
    tokenizer=gpt2_4bit_tok,
    block_step=1,
    include_input=True,
    force_include_output=True,
    include_subblocks=True,
    # NOTE(review): Hugging Face GPT-2 names its final layer norm `ln_f`, not
    # `final_layernorm` — confirm LogitLensWrapper resolves this name for GPT-2.
    decoder_layer_names=["final_layernorm", "lm_head"],
    device="cuda"
)

In [7]:
# Run the UNSAFE analysis variant on `texts` with top_k=5 and no explicit decoder.
gpt2_4bit_unsafe_results = analyze_UNSAFE_interpretability(wrapper, texts, top_k=5, decoder=None)

In [8]:
# Persist the GPT-2 4-bit UNSAFE results as a CSV file under LL_DIR.
save_results_to_csv(gpt2_4bit_unsafe_results, filename=f"{LL_DIR}gpt2-4bit_unsafe_results.csv")

Results saved to logs/logit_lens_logs/gpt2-4bit_unsafe_results.csv


In [9]:
# Score the UNSAFE results; the scorer's return value is unpacked as
# (per-section scores, overall score) and both are printed.
gpt2_4bit_section_scores, gpt2_4bit_overall_score = interpretability_UNSAFE_score(gpt2_4bit_unsafe_results)

print("Section scores:", gpt2_4bit_section_scores)
print("Overall interpretability score:", gpt2_4bit_overall_score)

Section scores: {'first': 0.0, 'early': 0.11111111111111116, 'mid': 0.25, 'late': 0.25, 'last': 0.0625}
Overall interpretability score: 0.13472222222222224


In [10]:
# Run the SAFE analysis variant on the same wrapper and `texts`, with the same
# top_k=5 / decoder=None arguments as the UNSAFE run.
gpt2_4bit_safe_results = analyze_SAFE_interpretability(wrapper, texts, top_k=5, decoder=None)

In [11]:
# Persist the GPT-2 4-bit SAFE results as a CSV file under LL_DIR.
save_results_to_csv(gpt2_4bit_safe_results, filename=f"{LL_DIR}gpt2-4bit_safe_results.csv")

Results saved to logs/logit_lens_logs/gpt2-4bit_safe_results.csv


In [12]:
# NOTE(review): the SAFE results are scored with interpretability_UNSAFE_score;
# no SAFE-specific scorer is imported in this notebook. Confirm this pairing is
# intended — the recorded output for this cell is 1.0 in every section.
gpt2_4bit_safe_scores, gpt2_4bit_safe_score = interpretability_UNSAFE_score(gpt2_4bit_safe_results)

print("Section scores:", gpt2_4bit_safe_scores)
print("Overall interpretability score:", gpt2_4bit_safe_score)

Section scores: {'first': 1.0, 'early': 1.0, 'mid': 1.0, 'late': 1.0, 'last': 1.0}
Overall interpretability score: 1.0


# -------------------------------
# 2. HF1BitLLM/Llama3-8B-1.58-100B-tokens
# -------------------------------

In [5]:
# Load the HF100B checkpoint and tokenizer with the notebook's own loader.
# dtype=torch.bfloat16 matches the loader's default and is spelled out explicitly.
hf100b_m, hf100b_tok = load_model_and_tok(Models.HF100B.value, dtype=torch.bfloat16)

In [6]:
# Wrap the HF100B model for logit-lens analysis.
# NOTE: this rebinds the notebook-global name `wrapper`, replacing the GPT-2
# wrapper created earlier; cells run after this one see the HF100B wrapper.
wrapper = LogitLensWrapper(
    model=hf100b_m,
    tokenizer=hf100b_tok,
    block_step=1,
    include_input=True,
    force_include_output=True,
    include_subblocks=True,
    decoder_layer_names=["norm", "lm_head"], 
    device="cuda"
)

In [7]:
# Run the UNSAFE analysis variant on `texts` with top_k=5 and no explicit decoder,
# using whichever model `wrapper` currently refers to.
hf100b_unsafe_results = analyze_UNSAFE_interpretability(wrapper, texts, top_k=5, decoder=None)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


In [9]:
# Persist the HF100B UNSAFE results as a CSV file under LL_DIR.
save_results_to_csv(hf100b_unsafe_results, filename=f"{LL_DIR}hf100b_unsafe_results.csv")

Results saved to logs/logit_lens_logs/hf100b_unsafe_results.csv


In [10]:
# Score the HF100B UNSAFE results and print per-section and overall scores.
# NOTE(review): the recorded output for this cell is 1.0 in every section, unlike
# the GPT-2 UNSAFE run — worth a sanity check that the results are as expected.
hf100b_section_scores, hf100b_overall_score = interpretability_UNSAFE_score(hf100b_unsafe_results)

print("Section scores:", hf100b_section_scores)
print("Overall interpretability score:", hf100b_overall_score)

Section scores: {'first': 1.0, 'early': 1.0, 'mid': 1.0, 'late': 1.0, 'last': 1.0}
Overall interpretability score: 1.0


In [8]:
# Disabled export step, kept for reference only.
# NOTE(review): `model` and `tokenizer` are not defined in the visible cells and
# the target path is an empty string — supply both before re-enabling.
#model.save_pretrained("", safe_serialization=True)
#tokenizer.save_pretrained("")