# 1. Settings

Let's start by importing the required packages and defining a helper that selects the compute `device`:

In [None]:
# Load IPython's autoreload extension and switch it to mode 2, so that
# imported modules are reloaded before each cell is executed and edits
# to the project's .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
from dotenv import load_dotenv
import torch

from huggingface_hub import login

In [None]:
from AdversarialPromptGenerator import AdversarialPromptGenerator

from our_base import LocalModel, HuggingFaceEmbeddings
from our_token_shap import TokenizerSplitter, TokenSHAP, get_text_before_last_underscore

In [None]:
def get_device():
    """Return the preferred ``torch.device``: CUDA, then MPS, then CPU."""
    # Some torch builds ship without the MPS backend module, so look it
    # up defensively instead of touching torch.backends.mps directly.
    mps_backend = getattr(torch.backends, "mps", None)

    if torch.cuda.is_available():
        device_name = "cuda"
    elif mps_backend and mps_backend.is_available():
        device_name = "mps"
    else:
        device_name = "cpu"
    return torch.device(device_name)


# Device shared by the rest of the notebook.
DEVICE = get_device()

# 2. Hugging Face

First, read the Hugging Face API key from the environment (or a `.env` file) and log in:

In [None]:
# Load variables from a .env file (if present) into the process environment.
load_dotenv()

hf_api_key = os.environ.get("HUGGINGFACE_API_KEY")
if hf_api_key is None or hf_api_key == "":
    # Stop here with a clear message rather than attempting a login without a key.
    raise RuntimeError("Missing HUGGINGFACE_API_KEY. Set it in your environment or .env file.")

# Authenticate this session with the Hugging Face Hub.
login(hf_api_key)

# 3. TokenSHAP

Then, instantiate TokenSHAP on top of a Hugging Face model, here `meta-llama/Llama-3.2-1B-Instruct`:

In [None]:
# Hugging Face checkpoint to analyse.
model_path = "meta-llama/Llama-3.2-1B-Instruct"

# Project wrapper around the checkpoint, placed on the selected DEVICE.
local_model = LocalModel(
    model_name=model_path,
    max_new_tokens=1,
    temperature=None,
    device=DEVICE,
    dtype="float16",
)

# Vectorizer and splitter handed to TokenSHAP; the splitter reuses the
# model's own tokenizer.
hf_embedding = HuggingFaceEmbeddings(device=DEVICE)
splitter = TokenizerSplitter(local_model.tokenizer)

token_shap = TokenSHAP(
    model=local_model,
    splitter=splitter,
    vectorizer=hf_embedding,
    debug=True,
)

In [None]:
# Display the device reported by the model wrapper (e.g. to compare with DEVICE).
local_model.device

Instantiate the `AdversarialPromptGenerator` and load the adversarial prompts from the saved suffix file:

In [None]:
# Saved tensor holding all 100 adversarial suffixes.
adversarial_suffix_path = "./adv_suffixes.pt"

# Build the prompt list from the saved suffixes.
adv_prompt_generator = AdversarialPromptGenerator()
all_prompts = adv_prompt_generator.get_from(adversarial_suffix_path)

##### Test of Specific Functions of TokenSHAP

In [None]:
# Call the private baseline step directly on the first prompt; the cell
# displays whatever it returns.
token_shap._calculate_baseline(all_prompts[0])

In [None]:
# Call the private per-combination step directly on the first prompt.
# NOTE(review): the positional 0.0 is presumably the sampling ratio
# (analyze() is called with sampling_ratio=0.0 elsewhere) — confirm.
token_shap._get_result_per_combination(all_prompts[0], 0.0)

In [None]:
# Run the full analysis on the first prompt and keep the result in df_local,
# then call print_colored_text() to render the outcome.
# NOTE(review): df_local is presumably a DataFrame (going by its name) — confirm.
df_local = token_shap.analyze(all_prompts[0], sampling_ratio=0.0)
token_shap.print_colored_text()

In [None]:
# Inspect the Shapley values of the last analysed prompt: for every entry,
# print the raw key and value, then the token text, its tokenizer id and
# the string that id decodes back to.
tokenizer = token_shap.model.tokenizer
for key, value in token_shap.shapley_values.items():
    print(key, value)

    # NOTE(review): presumably strips the trailing "_<index>" suffix from
    # the key (going by the helper's name) — confirm against our_token_shap.
    token = get_text_before_last_underscore(key)

    # Convert token string → token id → decoded string
    token_id = tokenizer.convert_tokens_to_ids(token)
    if token_id is None:
        # A tokenizer without an unknown-token id can return None for text
        # that is not in its vocabulary; decoding it would raise.
        print(token, "-> no token id")
        continue
    print(token, token_id, repr(tokenizer.decode([token_id])))

##### Full Loop to Analyse All 100 Prompts

In [None]:
# Analyse every prompt and store each run in the shared results file,
# under a zero-padded run id (run_000, run_001, ...).
results_dir = "./results"
results_file = "all_shapley.json"

for i, prompt in enumerate(all_prompts):
    df_local = token_shap.analyze(prompt, sampling_ratio=0.0)
    # Optional: call token_shap.print_colored_text() here to render each prompt.
    run_id = f"run_{i:03d}"
    token_shap.save_results(results_dir, results_file, run_id=run_id, prompt=prompt)