In [2]:
# Try inseq with llama2 series of models
# Extract top k important sentences from the input document based on attribution scores

import inseq
import torch
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login
from datasets import load_dataset
import evaluate

import json
import argparse
from pathlib import Path
from tqdm import tqdm

def is_sentence_ending(text):
    if text.endswith(("!", ".", "?")):
        return True
    if text.endswith((".\"", "?\"", "!\"")):
        return True
    
def get_token_length(text, tokenizer):
    encoded_text = tokenizer(text, 
                             return_tensors="pt", 
                             add_special_tokens=False).input_ids
    
    return encoded_text.shape[-1]

def clean_token(text):
    # TODO: different operation based on model name
    processed_token = text.replace("Ċ", "")
    processed_token = processed_token.replace("Ġ", "")

    return processed_token

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
login("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")

model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
config = AutoConfig.from_pretrained(model_name)
context_window_length = getattr(config, 'max_position_embeddings', 
                                getattr(config, 'n_positions', None))

model = AutoModelForCausalLM.from_pretrained(model_name,
                                             torch_dtype=torch.bfloat16,
                                             device_map="auto",
                                             use_auth_token=True,
                                             cache_dir="/mnt/ssd/llms")
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.model_max_length = context_window_length  # 8192


The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: read).
Your token has been saved to /root/.cache/huggingface/token
Login successful


Loading checkpoint shards: 100%|██████████| 4/4 [00:02<00:00,  1.34it/s]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [4]:
test_data = load_dataset("xsum", split="test")

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


In [14]:
instruction = "Summarise the document below:"
doc = test_data[1528]['document']

prompt_message = f"{instruction}\n\n{doc}"
messages = [
    {"role": "user", "content": prompt_message},
]

prompt = tokenizer.apply_chat_template(messages, 
                                       return_tensors="pt", 
                                       add_generation_prompt=True).to(model.device)

prompt_text = tokenizer.apply_chat_template(messages,
                                            tokenize=False,
                                            add_generation_prompt=True)

terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>")
]

inseq_model = inseq.load_model(model, "saliency", tokenizer=model_name)
output_ids = model.generate(prompt,
                            do_sample=False,
                            max_new_tokens=64,
                            temperature=0.0,
                            eos_token_id=terminators)

output_text = tokenizer.decode(output_ids[0, prompt.shape[-1]:], skip_special_tokens=False)
output_text = output_text.split('.')[0] + "."  # Note: only keep the first sentence for debugging; for summarisaiton task: keep until \n\n or the last complete sentence [TODO]
# output_text = tokenizer.decode(output_ids[0, prompt.shape[1]:], skip_special_tokens=True)

print(output_text)
out = inseq_model.attribute(
    input_texts=prompt_text,
    generated_texts=prompt_text + output_text,
)

# out.show()

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
The model is loaded with a device map. The device cannot be changed after loading.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Outgoing US Vice President Joe Biden spoke at a state dinner in Washington, praising Canada and its Prime Minister Justin Trudeau.


Attributing with saliency...: 100%|██████████| 357/357 [00:06<00:00,  3.94it/s]


In [15]:
# Aggregate the attribution scores for each input sentence
# Process instructions and special tokens in chat template separately
start_marker = "<|begin_of_text|><|begin_of_text|><|start_header_id|>user<|end_header_id|> "
end_marker = "<|start_header_id|>assistant<|end_header_id|> "

# Calculate the token length for each part of the prompt
len_start_marker = get_token_length(start_marker, tokenizer)
len_end_marker = get_token_length(end_marker, tokenizer)
len_instruction = get_token_length(instruction, tokenizer)
len_prompt = get_token_length(prompt_message, tokenizer)
total_prompt_len = len_start_marker + len_prompt

doc_start_pos = len_start_marker + len_instruction
start_span = (0, len_start_marker)
instr_span = (len_start_marker, len_start_marker + len_instruction)
end_span = (total_prompt_len, total_prompt_len + len_end_marker)

ends = [i + 1 for i, t in enumerate(out[0].target) if is_sentence_ending(clean_token(t.token)) and i < total_prompt_len] + [total_prompt_len]
starts = [doc_start_pos] + [i + 1 for i, t in enumerate(out[0].target) if is_sentence_ending(clean_token(t.token)) and i < total_prompt_len]
spans = [start_span, instr_span] + list(zip(starts, ends)) + [end_span]

# Remove empty spans 
processed_spans = []
for span in spans:
    if span[0] + 1 < span[1]:
        processed_spans.append(span)

print(processed_spans)
res = out.aggregate("spans", target_spans=processed_spans)
res.show()

[(0, 6), (6, 13), (13, 33), (33, 57), (57, 70), (70, 90), (90, 114), (114, 139), (139, 164), (164, 182), (182, 198), (198, 225), (225, 261), (261, 282), (282, 308), (308, 328), (328, 332)]


Unnamed: 0_level_0,Out,going,ĠUS,ĠVice,ĠPresident,ĠJoe,ĠBiden,Ġspoke,Ġat,Ġa,Ġstate,Ġdinner,Ġin,ĠWashington,",",Ġpraising,ĠCanada,Ġand,Ġits,ĠPrime,ĠMinister,ĠJustin,ĠTrudeau,.
<|begin_of_text|><|begin_of_text|><|start_header_id|>user<|end_header_id|>ĊĊ,0.184,0.177,0.152,0.169,0.137,0.178,0.129,0.158,0.145,0.145,0.145,0.149,0.151,0.145,0.107,0.11,0.118,0.126,0.101,0.114,0.126,0.091,0.095,0.139
SummariseĠtheĠdocumentĠbelow:ĊĊ,0.099,0.082,0.06,0.082,0.059,0.064,0.039,0.072,0.047,0.045,0.057,0.039,0.039,0.034,0.033,0.05,0.04,0.048,0.034,0.026,0.027,0.03,0.017,0.049
TheĠoutgoingĠvice-presidentĠspokeĠduringĠaĠstateĠdinnerĠandĠtookĠtheĠopportunityĠtoĠpraiseĠAmerica'sĠnorthernĠneighbour.Ċ,0.175,0.198,0.124,0.153,0.161,0.091,0.062,0.101,0.164,0.222,0.14,0.213,0.099,0.118,0.046,0.043,0.087,0.078,0.055,0.053,0.033,0.035,0.027,0.029
"""TheĠworldĠisĠgoingĠtoĠspendĠaĠlotĠofĠtimeĠlookingĠtoĠyou,ĠMrĠPrimeĠMinister"",ĠheĠtoldĠtheĠCanadianĠleader.Ċ",0.019,0.02,0.064,0.022,0.038,0.046,0.03,0.03,0.026,0.044,0.02,0.014,0.034,0.049,0.016,0.016,0.076,0.048,0.04,0.074,0.075,0.031,0.023,0.015
MrĠBidenĠhasĠbeenĠhighlyĠcriticalĠofĠUSĠPresident-electĠDonaldĠTrump.Ċ,0.021,0.024,0.076,0.058,0.034,0.089,0.091,0.026,0.021,0.02,0.017,0.014,0.023,0.025,0.013,0.018,0.03,0.022,0.027,0.019,0.017,0.03,0.02,0.014
"""ViveĠleĠCanadaĠbecauseĠweĠneedĠyouĠvery,ĠveryĠbadly,""ĠheĠtoldĠtheĠdinnerĠguests.Ċ",0.098,0.018,0.019,0.023,0.011,0.014,0.017,0.032,0.023,0.03,0.025,0.022,0.025,0.059,0.014,0.031,0.072,0.067,0.061,0.042,0.016,0.014,0.014,0.022
HeĠwentĠonĠtoĠdescribeĠtheĠself-doubtĠthatĠliberalĠleadersĠacrossĠtheĠworldĠareĠcurrentlyĠexperiencingĠafterĠseveralĠpoliticalĠdefeats.Ċ,0.012,0.006,0.008,0.008,0.005,0.008,0.007,0.012,0.008,0.007,0.008,0.006,0.008,0.01,0.009,0.008,0.018,0.011,0.016,0.016,0.006,0.006,0.008,0.01
"ButĠheĠpraisedĠ""genuineĠleaders""ĠincludingĠGermanĠChancellorĠAngelaĠMerkel,ĠsayingĠsuchĠstatesmenĠandĠwomenĠareĠinĠshortĠsupply.Ċ",0.012,0.013,0.01,0.01,0.008,0.012,0.009,0.03,0.011,0.016,0.033,0.014,0.01,0.017,0.01,0.021,0.052,0.028,0.017,0.021,0.015,0.019,0.03,0.017
"MrĠTrudeauĠreportedlyĠbecameĠemotionalĠduringĠMrĠBiden'sĠremarksĠwhenĠtheĠAmericanĠspokeĠofĠhisĠlateĠfather,ĠformerĠPrimeĠMinisterĠPierreĠTrudeau.Ċ",0.021,0.022,0.025,0.023,0.018,0.044,0.029,0.021,0.036,0.017,0.016,0.012,0.022,0.029,0.01,0.011,0.039,0.026,0.024,0.038,0.037,0.046,0.12,0.015
"""You'reĠaĠsuccessfulĠfatherĠwhenĠyourĠchildrenĠturnĠoutĠbetterĠthanĠyou,""ĠMrĠBidenĠsaid.Ċ",0.01,0.006,0.008,0.009,0.007,0.009,0.009,0.011,0.009,0.015,0.015,0.012,0.013,0.019,0.007,0.008,0.009,0.006,0.008,0.009,0.005,0.005,0.004,0.006
ThisĠisĠtheĠsecondĠstateĠdinnerĠsharedĠbyĠtheĠtwoĠNorthĠAmericanĠnationsĠthisĠyear.Ċ,0.022,0.011,0.021,0.023,0.009,0.016,0.007,0.032,0.037,0.09,0.085,0.065,0.062,0.104,0.029,0.017,0.018,0.017,0.012,0.009,0.01,0.01,0.006,0.007
PresidentĠBarackĠObamaĠhostedĠPrimeĠMinisterĠTrudeauĠatĠtheĠWhiteĠHouseĠinĠMarchĠandĠlaterĠinĠtheĠsummerĠvisitedĠCanadaĠtoĠgiveĠaĠspeechĠinĠparliament.Ċ,0.167,0.013,0.034,0.047,0.016,0.03,0.019,0.032,0.024,0.034,0.029,0.024,0.04,0.099,0.022,0.04,0.016,0.015,0.013,0.011,0.028,0.018,0.018,0.009
"CanadianĠofficialsĠsayĠtheĠvisitĠisĠnotĠaboutĠ""specificĠpolicy"",ĠbutĠratherĠ""anĠopportunityĠtoĠshowĠtheĠdepthĠofĠtheĠrelationship"",ĠsaidĠKateĠPurchase,ĠMrĠTrudeau'sĠdirectorĠofĠcommunications.Ċ",0.005,0.008,0.009,0.009,0.009,0.01,0.008,0.011,0.009,0.009,0.009,0.007,0.009,0.008,0.009,0.008,0.008,0.006,0.008,0.008,0.015,0.014,0.016,0.008
"TheĠdinnerĠensuresĠ""thatĠthereĠisĠcontinuityĠinĠtheĠrelationship""ĠwithĠtheĠnewĠAmericanĠadministration,ĠsheĠadded.Ċ",0.005,0.008,0.01,0.009,0.012,0.018,0.008,0.012,0.008,0.01,0.011,0.011,0.01,0.007,0.008,0.008,0.007,0.007,0.008,0.006,0.006,0.006,0.006,0.007
ExpertsĠsayĠMrĠBidenĠwillĠseekĠtoĠassureĠCanadiansĠthatĠtheĠUS-CanadaĠrelationshipĠwillĠremainĠstrongĠduringĠPresident-electĠDonaldĠTrump'sĠpresidency.Ċ,0.008,0.013,0.021,0.02,0.027,0.037,0.029,0.015,0.012,0.012,0.014,0.017,0.015,0.008,0.014,0.011,0.02,0.015,0.037,0.01,0.029,0.012,0.01,0.019
OnĠFridayĠMrĠBidenĠisĠmeetingĠwithĠCanada'sĠprovincialĠpremiersĠandĠindigenousĠleadersĠtoĠdiscussĠclimateĠchange.,0.011,0.019,0.019,0.02,0.025,0.027,0.02,0.018,0.019,0.01,0.016,0.013,0.014,0.008,0.017,0.015,0.01,0.016,0.014,0.014,0.014,0.011,0.009,0.016
<|start_header_id|>assistant<|end_header_id|>ĊĊ,0.097,0.077,0.106,0.106,0.057,0.042,0.065,0.086,0.059,0.074,0.08,0.092,0.037,0.028,0.041,0.084,0.029,0.042,0.047,0.029,0.079,0.031,0.038,0.05
Out,0.034,0.107,0.048,0.058,0.046,0.033,0.035,0.048,0.051,0.023,0.048,0.034,0.027,0.022,0.037,0.038,0.026,0.033,0.032,0.027,0.028,0.026,0.022,0.044
going,Unnamed: 1_level_19,0.18,0.073,0.049,0.038,0.031,0.025,0.041,0.027,0.019,0.023,0.018,0.019,0.013,0.019,0.018,0.015,0.013,0.015,0.012,0.018,0.015,0.009,0.021
ĠUS,Unnamed: 1_level_20,Unnamed: 2_level_20,0.113,0.066,0.056,0.049,0.038,0.032,0.041,0.016,0.024,0.017,0.02,0.017,0.02,0.02,0.022,0.018,0.017,0.013,0.015,0.015,0.008,0.015
ĠVice,Unnamed: 1_level_21,Unnamed: 2_level_21,Unnamed: 3_level_21,0.037,0.053,0.022,0.021,0.014,0.013,0.008,0.015,0.011,0.011,0.007,0.018,0.014,0.008,0.012,0.014,0.008,0.012,0.01,0.006,0.008
ĠPresident,Unnamed: 1_level_22,Unnamed: 2_level_22,Unnamed: 3_level_22,Unnamed: 4_level_22,0.175,0.07,0.065,0.039,0.02,0.013,0.018,0.015,0.015,0.011,0.014,0.014,0.014,0.016,0.011,0.01,0.021,0.018,0.012,0.01
ĠJoe,Unnamed: 1_level_23,Unnamed: 2_level_23,Unnamed: 3_level_23,Unnamed: 4_level_23,Unnamed: 5_level_23,0.058,0.096,0.025,0.022,0.013,0.019,0.016,0.016,0.011,0.02,0.018,0.015,0.015,0.013,0.014,0.029,0.023,0.015,0.011
ĠBiden,Unnamed: 1_level_24,Unnamed: 2_level_24,Unnamed: 3_level_24,Unnamed: 4_level_24,Unnamed: 5_level_24,Unnamed: 6_level_24,0.142,0.029,0.018,0.01,0.014,0.012,0.017,0.008,0.013,0.01,0.014,0.013,0.01,0.01,0.013,0.026,0.017,0.009
Ġspoke,Unnamed: 1_level_25,Unnamed: 2_level_25,Unnamed: 3_level_25,Unnamed: 4_level_25,Unnamed: 5_level_25,Unnamed: 6_level_25,Unnamed: 7_level_25,0.073,0.03,0.01,0.011,0.011,0.013,0.009,0.021,0.015,0.013,0.016,0.016,0.01,0.011,0.01,0.009,0.013
Ġat,Unnamed: 1_level_26,Unnamed: 2_level_26,Unnamed: 3_level_26,Unnamed: 4_level_26,Unnamed: 5_level_26,Unnamed: 6_level_26,Unnamed: 7_level_26,Unnamed: 8_level_26,0.121,0.033,0.026,0.019,0.024,0.015,0.027,0.034,0.015,0.017,0.02,0.014,0.01,0.012,0.007,0.026
Ġa,Unnamed: 1_level_27,Unnamed: 2_level_27,Unnamed: 3_level_27,Unnamed: 4_level_27,Unnamed: 5_level_27,Unnamed: 6_level_27,Unnamed: 7_level_27,Unnamed: 8_level_27,Unnamed: 9_level_27,0.055,0.035,0.025,0.033,0.019,0.031,0.029,0.016,0.018,0.016,0.013,0.008,0.011,0.007,0.022
Ġstate,Unnamed: 1_level_28,Unnamed: 2_level_28,Unnamed: 3_level_28,Unnamed: 4_level_28,Unnamed: 5_level_28,Unnamed: 6_level_28,Unnamed: 7_level_28,Unnamed: 8_level_28,Unnamed: 9_level_28,Unnamed: 10_level_28,0.047,0.031,0.031,0.014,0.024,0.017,0.009,0.01,0.01,0.009,0.007,0.009,0.007,0.012
Ġdinner,Unnamed: 1_level_29,Unnamed: 2_level_29,Unnamed: 3_level_29,Unnamed: 4_level_29,Unnamed: 5_level_29,Unnamed: 6_level_29,Unnamed: 7_level_29,Unnamed: 8_level_29,Unnamed: 9_level_29,Unnamed: 10_level_29,Unnamed: 11_level_29,0.066,0.06,0.017,0.035,0.018,0.012,0.011,0.01,0.009,0.006,0.008,0.003,0.012
Ġin,Unnamed: 1_level_30,Unnamed: 2_level_30,Unnamed: 3_level_30,Unnamed: 4_level_30,Unnamed: 5_level_30,Unnamed: 6_level_30,Unnamed: 7_level_30,Unnamed: 8_level_30,Unnamed: 9_level_30,Unnamed: 10_level_30,Unnamed: 11_level_30,Unnamed: 12_level_30,0.104,0.034,0.052,0.029,0.021,0.02,0.015,0.012,0.009,0.009,0.004,0.018
ĠWashington,Unnamed: 1_level_31,Unnamed: 2_level_31,Unnamed: 3_level_31,Unnamed: 4_level_31,Unnamed: 5_level_31,Unnamed: 6_level_31,Unnamed: 7_level_31,Unnamed: 8_level_31,Unnamed: 9_level_31,Unnamed: 10_level_31,Unnamed: 11_level_31,Unnamed: 12_level_31,Unnamed: 13_level_31,0.035,0.069,0.043,0.013,0.013,0.015,0.013,0.009,0.013,0.009,0.024
",",Unnamed: 1_level_32,Unnamed: 2_level_32,Unnamed: 3_level_32,Unnamed: 4_level_32,Unnamed: 5_level_32,Unnamed: 6_level_32,Unnamed: 7_level_32,Unnamed: 8_level_32,Unnamed: 9_level_32,Unnamed: 10_level_32,Unnamed: 11_level_32,Unnamed: 12_level_32,Unnamed: 13_level_32,Unnamed: 14_level_32,0.194,0.072,0.019,0.017,0.016,0.018,0.015,0.018,0.005,0.049
Ġpraising,Unnamed: 1_level_33,Unnamed: 2_level_33,Unnamed: 3_level_33,Unnamed: 4_level_33,Unnamed: 5_level_33,Unnamed: 6_level_33,Unnamed: 7_level_33,Unnamed: 8_level_33,Unnamed: 9_level_33,Unnamed: 10_level_33,Unnamed: 11_level_33,Unnamed: 12_level_33,Unnamed: 13_level_33,Unnamed: 14_level_33,Unnamed: 15_level_33,0.112,0.037,0.036,0.04,0.035,0.021,0.031,0.019,0.081
ĠCanada,Unnamed: 1_level_34,Unnamed: 2_level_34,Unnamed: 3_level_34,Unnamed: 4_level_34,Unnamed: 5_level_34,Unnamed: 6_level_34,Unnamed: 7_level_34,Unnamed: 8_level_34,Unnamed: 9_level_34,Unnamed: 10_level_34,Unnamed: 11_level_34,Unnamed: 12_level_34,Unnamed: 13_level_34,Unnamed: 14_level_34,Unnamed: 15_level_34,Unnamed: 16_level_34,0.082,0.067,0.064,0.061,0.017,0.025,0.011,0.048
Ġand,Unnamed: 1_level_35,Unnamed: 2_level_35,Unnamed: 3_level_35,Unnamed: 4_level_35,Unnamed: 5_level_35,Unnamed: 6_level_35,Unnamed: 7_level_35,Unnamed: 8_level_35,Unnamed: 9_level_35,Unnamed: 10_level_35,Unnamed: 11_level_35,Unnamed: 12_level_35,Unnamed: 13_level_35,Unnamed: 14_level_35,Unnamed: 15_level_35,Unnamed: 16_level_35,Unnamed: 17_level_35,0.079,0.078,0.047,0.025,0.024,0.012,0.019
Ġits,Unnamed: 1_level_36,Unnamed: 2_level_36,Unnamed: 3_level_36,Unnamed: 4_level_36,Unnamed: 5_level_36,Unnamed: 6_level_36,Unnamed: 7_level_36,Unnamed: 8_level_36,Unnamed: 9_level_36,Unnamed: 10_level_36,Unnamed: 11_level_36,Unnamed: 12_level_36,Unnamed: 13_level_36,Unnamed: 14_level_36,Unnamed: 15_level_36,Unnamed: 16_level_36,Unnamed: 17_level_36,Unnamed: 18_level_36,0.066,0.064,0.03,0.047,0.024,0.021
ĠPrime,Unnamed: 1_level_37,Unnamed: 2_level_37,Unnamed: 3_level_37,Unnamed: 4_level_37,Unnamed: 5_level_37,Unnamed: 6_level_37,Unnamed: 7_level_37,Unnamed: 8_level_37,Unnamed: 9_level_37,Unnamed: 10_level_37,Unnamed: 11_level_37,Unnamed: 12_level_37,Unnamed: 13_level_37,Unnamed: 14_level_37,Unnamed: 15_level_37,Unnamed: 16_level_37,Unnamed: 17_level_37,Unnamed: 18_level_37,Unnamed: 19_level_37,0.099,0.036,0.067,0.024,0.017
ĠMinister,Unnamed: 1_level_38,Unnamed: 2_level_38,Unnamed: 3_level_38,Unnamed: 4_level_38,Unnamed: 5_level_38,Unnamed: 6_level_38,Unnamed: 7_level_38,Unnamed: 8_level_38,Unnamed: 9_level_38,Unnamed: 10_level_38,Unnamed: 11_level_38,Unnamed: 12_level_38,Unnamed: 13_level_38,Unnamed: 14_level_38,Unnamed: 15_level_38,Unnamed: 16_level_38,Unnamed: 17_level_38,Unnamed: 18_level_38,Unnamed: 19_level_38,Unnamed: 20_level_38,0.123,0.073,0.033,0.011
ĠJustin,Unnamed: 1_level_39,Unnamed: 2_level_39,Unnamed: 3_level_39,Unnamed: 4_level_39,Unnamed: 5_level_39,Unnamed: 6_level_39,Unnamed: 7_level_39,Unnamed: 8_level_39,Unnamed: 9_level_39,Unnamed: 10_level_39,Unnamed: 11_level_39,Unnamed: 12_level_39,Unnamed: 13_level_39,Unnamed: 14_level_39,Unnamed: 15_level_39,Unnamed: 16_level_39,Unnamed: 17_level_39,Unnamed: 18_level_39,Unnamed: 19_level_39,Unnamed: 20_level_39,Unnamed: 21_level_39,0.099,0.092,0.02
ĠTrudeau,Unnamed: 1_level_40,Unnamed: 2_level_40,Unnamed: 3_level_40,Unnamed: 4_level_40,Unnamed: 5_level_40,Unnamed: 6_level_40,Unnamed: 7_level_40,Unnamed: 8_level_40,Unnamed: 9_level_40,Unnamed: 10_level_40,Unnamed: 11_level_40,Unnamed: 12_level_40,Unnamed: 13_level_40,Unnamed: 14_level_40,Unnamed: 15_level_40,Unnamed: 16_level_40,Unnamed: 17_level_40,Unnamed: 18_level_40,Unnamed: 19_level_40,Unnamed: 20_level_40,Unnamed: 21_level_40,Unnamed: 22_level_40,0.183,0.02
.,Unnamed: 1_level_41,Unnamed: 2_level_41,Unnamed: 3_level_41,Unnamed: 4_level_41,Unnamed: 5_level_41,Unnamed: 6_level_41,Unnamed: 7_level_41,Unnamed: 8_level_41,Unnamed: 9_level_41,Unnamed: 10_level_41,Unnamed: 11_level_41,Unnamed: 12_level_41,Unnamed: 13_level_41,Unnamed: 14_level_41,Unnamed: 15_level_41,Unnamed: 16_level_41,Unnamed: 17_level_41,Unnamed: 18_level_41,Unnamed: 19_level_41,Unnamed: 20_level_41,Unnamed: 21_level_41,Unnamed: 22_level_41,Unnamed: 23_level_41,0.038


In [16]:
def clean_token(text):
    # TODO: different operation based on model name
    processed_token = text.replace("Ċ", "")
    processed_token = processed_token.replace("Ġ", " ")

    return processed_token

# Aggregate the attribution scores for each input sentence
tok_out = res.aggregate()
prompt_last_index = tok_out[0].attr_pos_start
input_sequences = [clean_token(t.token) for t in tok_out[0].target[2:prompt_last_index-1]]
print(input_sequences)

# Note: no need to use cleaned sequences
attr_scores = tok_out[0].target_attributions[2:prompt_last_index-1].tolist()
assert(len(input_sequences) == len(attr_scores))

sent_scores = dict()
for seq_ix, seq in enumerate(input_sequences):
    sent_scores[seq] = max(attr_scores[seq_ix])

# Extract top K important sentences
sorted_sent_scores = dict(sorted(sent_scores.items(), key=lambda x: x[1], reverse=True))
top_k_sents = list(sorted_sent_scores.keys())[:3]

print(top_k_sents)
for sent in top_k_sents:
    print(sent_scores[sent])

# Store both the attributed sentences and their aggregated scores
attributed_sents = []
for sent in top_k_sents:
    attributed_sents.append(
        {
            "input_sequence": sent,
            "score": sent_scores[sent]
        }
    )

print(attributed_sents)

["The outgoing vice-president spoke during a state dinner and took the opportunity to praise America's northern neighbour.", '"The world is going to spend a lot of time looking to you, Mr Prime Minister", he told the Canadian leader.', 'Mr Biden has been highly critical of US President-elect Donald Trump.', '"Vive le Canada because we need you very, very badly," he told the dinner guests.', 'He went on to describe the self-doubt that liberal leaders across the world are currently experiencing after several political defeats.', 'But he praised "genuine leaders" including German Chancellor Angela Merkel, saying such statesmen and women are in short supply.', "Mr Trudeau reportedly became emotional during Mr Biden's remarks when the American spoke of his late father, former Prime Minister Pierre Trudeau.", '"You\'re a successful father when your children turn out better than you," Mr Biden said.', 'This is the second state dinner shared by the two North American nations this year.', 'Pres