Run this notebook after `get_most_toxic.py` to compute attribution scores for every generated output token (the continuation) with respect to the input tokens (the prompt).

After this notebook, run `attr_aggregate_and_threshold.ipynb`.

In [None]:
# Install runtime dependencies for this notebook (IPython shell escapes).
# accelerate: presumably required by device_map="auto" in the model load below.
!pip install accelerate
# bitsandbytes: backs BitsAndBytesConfig 4-bit quantization; installed from the
# official PyPI index explicitly via -i.
!pip install -i https://pypi.org/simple/ bitsandbytes
# captum: provides FeatureAblation / LLMAttribution used for attribution.
!pip install captum

import json
import os

import torch
from captum.attr import FeatureAblation, LLMAttribution, TextTokenInput
from huggingface_hub import login
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer

# Authenticate with the Hugging Face Hub before downloading checkpoints
# (presumably needed for gated models such as Meta-Llama-3 — confirm).
# The token is read from the HF_TOKEN environment variable rather than being
# hard-coded here, so no secret ends up committed with the notebook.
# If HF_TOKEN is unset, login() receives None and prompts interactively instead.
login(token=os.environ.get("HF_TOKEN"))

In [None]:
# Quantize model weights to 4-bit with bitsandbytes (8-bit explicitly disabled).
quantization_config = BitsAndBytesConfig(load_in_4bit=True, load_in_8bit=False)

# Short name of the model whose generations are attributed: 'bloom', 'llama'
# or 'mistral'. It is also used to build file paths later in this notebook.
model = 'mistral'

# Hugging Face Hub checkpoint for each supported short model name.
MODEL_IDS = {
    'bloom': "bigscience/bloom-7b1",
    'llama': "meta-llama/Meta-Llama-3-8B",
    'mistral': "mistralai/Mistral-7B-v0.1",
}
model_id = MODEL_IDS[model]

# Load the quantized causal LM, letting device_map="auto" decide placement,
# and the matching tokenizer for the same checkpoint.
model_4bit = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

In [None]:
# Perturbation-based attribution method over the quantized model's forward pass.
fa = FeatureAblation(model_4bit)
# Captum's LLM wrapper: pairs the attribution method with the tokenizer so that
# prompts and target continuations can be given as text.
llm = LLMAttribution(attr_method=fa, tokenizer=tokenizer)

In [None]:
# Input: records produced by get_most_toxic.py, one JSON object per line with
# "prompt" and "generated" fields.
file = f"results/{model}/most_toxic.jsonl"
# Output path. Create its directory up front: otherwise torch.save can fail
# with FileNotFoundError only after every (slow) attribution has been computed.
out_path = f'attributions/{model}_attr_output.pt'
os.makedirs(os.path.dirname(out_path), exist_ok=True)

outs = []

# Context manager so the file handle is closed; explicit UTF-8 because the
# prompts/continuations are arbitrary text.
with open(file, encoding="utf-8") as f:
    for line in f:
        if not line.strip():
            continue  # skip blank lines instead of crashing in json.loads
        message = json.loads(line)
        prompt = message["prompt"]
        generated = message["generated"]
        inp = TextTokenInput(prompt, tokenizer)

        # Attribute the generated continuation (target) back to the prompt tokens.
        out = llm.attribute(inp, target=generated, show_progress=True)
        # Keep the result object's attribute dict for serialization.
        # NOTE(review): any tensors in it may live on the GPU; pass map_location
        # to torch.load when reading on a CPU-only machine.
        outs.append(out.__dict__)

# Save all attributions as one list, one entry per input record.
torch.save(outs, out_path)