##### Import the Packages

In [None]:
import os
from dotenv import load_dotenv
import torch

from huggingface_hub import login

In [None]:
from AdversarialPromptGenerator import AdversarialPromptGenerator

from our_base import LocalModel, HuggingFaceEmbeddings
from our_token_shap import TokenizerSplitter, TokenSHAP

In [None]:
def get_device():
    """Return the torch device to run on: CUDA if available, then MPS, else CPU."""
    if torch.cuda.is_available():
        backend = "cuda"
    else:
        # Not every torch build exposes torch.backends.mps, so look it up defensively.
        mps_backend = getattr(torch.backends, "mps", None)
        backend = "mps" if mps_backend and mps_backend.is_available() else "cpu"
    return torch.device(backend)

# Device chosen once and shared by the cells below.
DEVICE = get_device()

##### Retrieve the Hugging Face Key

In [None]:
# Load a .env file (if present) before looking the key up in the environment.
load_dotenv()
hf_api_key = os.environ.get("HUGGINGFACE_API_KEY")
if hf_api_key:
    # Authenticate this session with the Hugging Face Hub.
    login(hf_api_key)
else:
    # Covers both an unset variable and an empty string.
    raise RuntimeError("Missing HUGGINGFACE_API_KEY. Set it in your environment or .env file.")

##### Instantiate TokenSHAP

In [None]:
# Hugging Face model id handed to the LocalModel wrapper below.
model_path = "meta-llama/Llama-3.2-1B-Instruct"

# The wrapper is configured with max_new_tokens=1 and temperature=None.
local_model = LocalModel(
    model_name=model_path,
    max_new_tokens=1,
    temperature=None,
    device=DEVICE,
)
hf_embedding = HuggingFaceEmbeddings(device=DEVICE)

# The splitter is built from the wrapper's own tokenizer.
splitter = TokenizerSplitter(local_model.tokenizer)

token_shap_local = TokenSHAP(
    model=local_model,
    splitter=splitter,
    vectorizer=hf_embedding,
    debug=True,
)

In [None]:
# Display the wrapper's `device` attribute as the cell output (sanity check after setup).
local_model.device

##### Instantiate PromptGenerator

In [None]:
# Build the prompt generator and load prompts from the saved suffix file.
adv_prompt_generator = AdversarialPromptGenerator()
adversarial_suffix_path = "./adv_suffixes.pt" # tensor of all 100 suffixes
# `all_prompts` is iterated prompt-by-prompt in the cells below.
# NOTE(review): presumably one prompt string per stored suffix; confirm against get_from().
all_prompts = adv_prompt_generator.get_from(adversarial_suffix_path)

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM

# Baseline check: run every prompt through the plain Hugging Face model and
# print the single next token it produces under greedy decoding.

# Reuse the tokenizer already attached to the splitter instead of loading a
# second copy with AutoTokenizer.
# NOTE(review): presumably the same tokenizer as local_model.tokenizer; confirm.
tokenizer = splitter.tokenizer

# Load the checkpoint named by `model_path` (no second hard-coded id) and put
# it on DEVICE, matching how LocalModel and HuggingFaceEmbeddings are set up.
model = AutoModelForCausalLM.from_pretrained(model_path).to(DEVICE)

for prompt in all_prompts:
    messages = [{"role": "user", "content": prompt}]

    # Apply the chat template and tokenize in one step; the returned tensors
    # are moved to wherever the model lives.
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)

    # Greedy decoding of exactly one new token. temperature/top_p are set to
    # None alongside do_sample=False, and EOS is used as the pad token id.
    outputs = model.generate(
        **inputs,
        max_new_tokens=1,
        do_sample=False,
        temperature=None,
        top_p=None,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Strip the prompt tokens so only the newly generated part is printed.
    prompt_length = inputs["input_ids"].shape[-1]
    print(tokenizer.decode(outputs[0][prompt_length:]))

In [None]:
# Run the TokenSHAP analysis on every prompt and print its colored-text view.
for prompt in all_prompts:
    # `df_local` is overwritten on each iteration, so only the last prompt's
    # result is still bound after the loop.
    # NOTE(review): sampling_ratio=0.0 presumably disables random sampling of
    # token subsets; confirm against TokenSHAP.analyze.
    df_local = token_shap_local.analyze(prompt, sampling_ratio=0.0)
    token_shap_local.print_colored_text()

In [None]:
# NOTE(review): generate() is called with no arguments. This looks like a
# leftover scratch cell; confirm against LocalModel.generate's signature
# whether a prompt argument is required.
local_model.generate()