Run this notebook after attribute.ipynb to aggregate the sub-word tokens into full words and keep only the words whose attribution score is at or above a given threshold.


In [None]:
# Quantization settings handed to from_pretrained below: 4-bit on, 8-bit off
quantization_config = BitsAndBytesConfig(load_in_4bit=True, load_in_8bit=False)

# Key of the model to analyse; it is also interpolated into file paths in later cells
model = 'llama'     # 'bloom' / 'llama' / 'mistral'

# Short key -> checkpoint identifier passed to from_pretrained
checkpoints = dict(
    bloom="bigscience/bloom-7b1",
    llama="meta-llama/Meta-Llama-3-8B",
    mistral="mistralai/Mistral-7B-v0.1",
)
model_id = checkpoints[model]

# Load the causal LM with the quantization config, then the matching tokenizer
model_4bit = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

In [None]:
# Wrap the loaded model in a FeatureAblation attribution method and pair it with
# the tokenizer in an LLMAttribution object. Later cells call `llm.attribute(...)`
# to obtain attribution result objects.
fa = FeatureAblation(model_4bit)
llm = LLMAttribution(fa, tokenizer)

In [None]:
# Load attributions
# The path depends on the `model` key chosen above; presumably the file was
# written by attribute.ipynb -- confirm.
file = f"attributions/{model}_attr_output.pt"

# Later cells iterate over `data` and use each element as the __dict__ of an
# attribution result (keys such as 'input_tokens', 'token_attr', ...).
# NOTE(review): torch.load may unpickle arbitrary objects depending on the
# PyTorch version/settings -- only load files from a trusted source.
data = torch.load(file)

In [None]:
attributions = []

# Every loaded entry is a plain dict that becomes the __dict__ of an attribution
# result object. Only the *class* of such an object is needed, so the (expensive)
# dummy attribution is computed once instead of once per entry. Nothing is
# computed when there is no data at all.
result_cls = type(llm.attribute(TextTokenInput("ab cg", tokenizer), "bc cg")) if len(data) else None

for c, d in enumerate(data, start=1):
    # Create an empty result object (bypassing __init__) and overwrite its
    # state with the loaded attribution
    attr = result_cls.__new__(result_cls)
    attr.__dict__ = d
    attributions.append(attr)

    # Counter
    print(f"{c}/{len(data)}")

In [None]:
# Method for aggregating subwords into full words. 
def aggregate_attr_cols(attribution_dict):
    """Merge sub-word input tokens into full words, summing their attributions.

    The first input token (expected to be the begin-of-text marker, e.g.
    "<|begin_of_text|>") is dropped together with its attribution entries.
    A new word starts at the first remaining token and at every token that
    contains one of the markers 'Ġ', 'Ċ' or '▁'; all tokens up to the next
    word start are concatenated and their attributions are summed.

    Args:
        attribution_dict: mapping with the keys
            'input_tokens'  - list of token strings,
            'output_tokens' - passed through unchanged,
            'token_attr'    - 2-D tensor indexed [output token, input token],
            'seq_attr'      - 1-D tensor with one entry per input token.

    Returns:
        A new dict with the same keys where 'input_tokens' holds the merged
        words (marker characters removed), 'token_attr' has one column per
        word and 'seq_attr' one entry per word. The tensors of the input
        dict are not modified.
    """
    # Get rid of first word, which is "<|begin_of_text|>"
    inp = attribution_dict['input_tokens'][1:]
    out = attribution_dict['output_tokens']
    token = attribution_dict['token_attr'][:, 1:].detach()
    seq = attribution_dict['seq_attr'][1:].detach()

    # Nothing left to aggregate: return empty (but correctly shaped) results
    # instead of failing on inp[0] / torch.stack([]).
    if len(inp) == 0:
        return {
            'input_tokens': [],
            'output_tokens': out,
            'token_attr': token.clone(),
            'seq_attr': seq.clone(),
        }

    # Indexes of the start of every word. Every new word starts with 'Ġ', 'Ċ' or '▁'.
    # Index 0 always starts a word; the set avoids listing it twice when the
    # first token itself carries a marker (which would duplicate that word).
    word_markers = ('Ġ', 'Ċ', '▁')
    merge_cols = sorted(
        {0, *(i for i, v in enumerate(inp) if any(m in v for m in word_markers))}
    )
    # Each word ends where the next one starts; the last one runs to the end
    word_ends = merge_cols[1:] + [len(inp)]

    # Characters stripped from the merged words (word markers and other
    # tokenizer artefact characters)
    strip_table = str.maketrans('', '', 'âĢľĠĻĿĊ▁')

    new_inps = []
    tcols = []
    scols = []
    for start, end in zip(merge_cols, word_ends):
        # Concatenate the subwords and remove special characters
        new_inps.append(''.join(inp[start:end]).translate(strip_table))
        # Sum the attributions of all subwords belonging to this word
        tcols.append(token[:, start:end].sum(dim=1))
        scols.append(seq[start:end].sum(dim=0))

    return {
        'input_tokens': new_inps,
        'output_tokens': out,
        # One column per merged word
        'token_attr': torch.stack(tcols, dim=1),
        'seq_attr': torch.stack(scols),
    }
    
def aggregate_attr_rows(attribution_dict):
    """Merge sub-word output tokens into full words, summing their attributions.

    A new word starts at the first output token and at every token that
    contains one of the markers 'Ġ', 'Ċ' or '▁'; all tokens up to the next
    word start are concatenated and their rows of 'token_attr' are summed.
    The first token is always treated as a word start so that output tokens
    preceding the first marker are kept rather than dropped.

    Args:
        attribution_dict: mapping with the keys
            'input_tokens'  - passed through unchanged,
            'output_tokens' - list of token strings,
            'token_attr'    - 2-D tensor indexed [output token, input token],
            'seq_attr'      - passed through unchanged.

    Returns:
        A new dict with the same keys where 'output_tokens' holds the merged
        words (marker characters removed) and 'token_attr' has one row per
        word. The tensors of the input dict are not modified.
    """
    inp = attribution_dict['input_tokens']
    out = attribution_dict['output_tokens']
    token = attribution_dict['token_attr'].detach()
    seq = attribution_dict['seq_attr']

    # No output tokens: return an empty result instead of failing in torch.stack([])
    if len(out) == 0:
        return {
            'input_tokens': inp,
            'output_tokens': [],
            'token_attr': token.clone(),
            'seq_attr': seq,
        }

    # Indexes of the start of every word. Every new word starts with 'Ġ', 'Ċ' or '▁'.
    # Index 0 always starts a word (same convention as for the input columns);
    # the set keeps it unique when the first token carries a marker itself.
    word_markers = ('Ġ', 'Ċ', '▁')
    merge_rows = sorted(
        {0, *(i for i, v in enumerate(out) if any(m in v for m in word_markers))}
    )
    # Each word ends where the next one starts; the last one runs to the end
    word_ends = merge_rows[1:] + [len(out)]

    # Characters stripped from the merged words (word markers and other
    # tokenizer artefact characters)
    strip_table = str.maketrans('', '', 'âĢľĠĻĿĊ▁')

    new_outs = []
    trows = []
    for start, end in zip(merge_rows, word_ends):
        # Concatenate the subwords and remove special characters
        new_outs.append(''.join(out[start:end]).translate(strip_table))
        # Sum the attribution rows of all subwords belonging to this word
        trows.append(token[start:end].sum(dim=0))

    return {
        'input_tokens': inp,
        'output_tokens': new_outs,
        'token_attr': torch.stack(trows),
        'seq_attr': seq,
    }

# Visual sanity check on the first example: raw sub-word attributions next to
# their word-level aggregation.
# A throwaway attribution result is created and its __dict__ replaced with the
# aggregated data, so the original unaggregated result in attributions[0] is kept.
attr = llm.attribute(TextTokenInput("ab cg", tokenizer), "bc cg")
word_level = aggregate_attr_rows(aggregate_attr_cols(attributions[0].__dict__))
attr.__dict__ = word_level

# Token-level plots first (original, then aggregated) ...
for result in (attributions[0], attr):
    result.plot_token_attr(show=True)

# ... followed by the sequence-level plots in the same order
for result in (attributions[0], attr):
    result.plot_seq_attr(show=True)

In [None]:
# Load the prompts and generated continuation
# The file is read inside a context manager so the handle is closed, and with an
# explicit encoding so non-ASCII text does not depend on the platform default.
sentences = []
with open(f"results/{model}/most_toxic.jsonl", "r", encoding="utf-8") as f:
    for x in f:
        # Tolerate blank lines (e.g. a trailing newline at the end of the file)
        if not x.strip():
            continue
        j = json.loads(x)
        sentences.append(j["prompt"] + j["generated"])

# Aggregate attributions
aggregated = [aggregate_attr_rows(aggregate_attr_cols(i.__dict__)) for i in attributions]

# Minimum attribution score for an output word to be recorded for an input word
attr_threshold = 2

attributed_to = []

for d in aggregated:
    att = {}

    # Convert once to nested Python lists: rows are output words, columns input words
    rows = d["token_attr"].detach().tolist()

    # Iterate over all attributions and store the output words where the attribution is >= threshold
    for i, row in enumerate(rows):
        for j, elem in enumerate(row):
            # Every input word gets an entry, even if no output word reaches the threshold.
            # Repeated input words share one entry.
            out_words = att.setdefault(d["input_tokens"][j], [])

            if elem >= attr_threshold:
                out_words.append(d["output_tokens"][i])

    attributed_to.append(att)

# Write to file
write_dir = f"aggregated_attr/{model}_attr_threshold.txt"

# One block per example: the sentence, the JSON-encoded mapping, then a blank line
with open(write_dir, "w", encoding="utf-8") as f:
    for i, a in enumerate(attributed_to):
        f.write(sentences[i] + "\n" + json.dumps(a) + "\n\n")