In [1]:
# Try inseq with Llama-family chat models (currently meta-llama/Meta-Llama-3-8B-Instruct)
# Extract the top-k most important sentences from the input document based on attribution scores

import argparse
import json
import os
from pathlib import Path

import evaluate
import inseq
import torch
from datasets import load_dataset
from huggingface_hub import login
from tqdm import tqdm
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM

def is_sentence_ending(text):
    """Tell whether ``text`` ends with sentence-final punctuation.

    A match is ``.``, ``!`` or ``?``, optionally followed by a closing
    double quote (``."``, ``?"``, ``!"``).

    Args:
        text: The string (typically a cleaned token) to inspect.

    Returns:
        ``True`` if ``text`` ends with one of the suffixes above, otherwise
        ``False`` (always a real ``bool``, never ``None``).
    """
    # str.endswith accepts a tuple of suffixes, so one call covers every case.
    return text.endswith((".", "!", "?", ".\"", "?\"", "!\""))
    
def get_token_length(text, tokenizer):
    """Count how many tokens ``tokenizer`` produces for ``text``.

    The tokenizer is called with ``add_special_tokens=False``, so only the
    tokens of ``text`` itself are counted.

    Args:
        text: The string to tokenize.
        tokenizer: A callable returning an object whose ``input_ids`` has a
            ``shape``; the size of its last dimension is the token count.

    Returns:
        The size of the last dimension of ``input_ids``.
    """
    input_ids = tokenizer(text, add_special_tokens=False, return_tensors="pt").input_ids
    return input_ids.shape[-1]

def clean_token(text):
    """Remove the ``Ċ`` and ``Ġ`` marker characters from a token string.

    Both markers are deleted outright (no space is inserted in their place).

    Args:
        text: The raw token string.

    Returns:
        ``text`` with every ``Ċ`` and ``Ġ`` removed.
    """
    # TODO: different operation based on model name
    for marker in ("Ċ", "Ġ"):
        text = text.replace(marker, "")
    return text

In [2]:
# Authenticate with the Hugging Face Hub.
# SECURITY: never hard-code an access token in a notebook. The token is read
# from the HF_TOKEN environment variable; when it is unset, login() receives
# token=None and prompts for the token instead.
# NOTE(review): a token was previously committed in this cell; it must be
# treated as leaked and revoked on the Hub.
login(token=os.environ.get("HF_TOKEN"))

model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
config = AutoConfig.from_pretrained(model_name)
# The context window is stored under different attribute names depending on
# the architecture; fall back from one to the other, then to None.
context_window_length = getattr(config, 'max_position_embeddings', 
                                getattr(config, 'n_positions', None))

model = AutoModelForCausalLM.from_pretrained(model_name,
                                             torch_dtype=torch.bfloat16,
                                             device_map="auto",
                                             token=True,  # replaces the deprecated `use_auth_token=True`
                                             cache_dir="/mnt/ceph_rbd/llms")
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Only override the tokenizer limit when the config actually reports one;
# assigning None would leave the tokenizer without a usable max length.
if context_window_length is not None:
    tokenizer.model_max_length = context_window_length  # 8192


The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: read).
Your token has been saved to /root/.cache/huggingface/token
Login successful


Loading checkpoint shards: 100%|██████████| 4/4 [00:02<00:00,  1.64it/s]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [10]:
# Alternative corpus: test_data = load_dataset("xsum", split="test")

# Load the CCSum data from local storage and keep only the examples whose
# "abstractiveness_bin" field equals "high"; the test split of that subset is used below.
# Indices of longest samples in CCSum: [603, 5220, 7200, 11376]
def _is_highly_abstractive(example):
    """Return True for examples in the "high" abstractiveness bin."""
    return example["abstractiveness_bin"] == "high"

ccsum_dataset = load_dataset("/mnt/ceph_rbd/datasets/ccsum")
dataset_abstractive = ccsum_dataset.filter(_is_highly_abstractive)
test_data = dataset_abstractive["test"]
print(len(test_data))

4947


In [11]:
# Build a summarisation prompt for one test document, generate a summary with
# greedy decoding, then attribute the (truncated) summary back to the prompt.
instruction = "Summarise the document below:"
doc = test_data[1125]['article']

prompt_message = f"{instruction}\n\n{doc}"
messages = [
    {"role": "user", "content": prompt_message},
]

# Token ids of the chat-formatted prompt (used for generation).
prompt = tokenizer.apply_chat_template(messages, 
                                       return_tensors="pt", 
                                       add_generation_prompt=True).to(model.device)

# The same chat-formatted prompt as plain text (used for attribution).
# NOTE(review): the downstream span code expects <|begin_of_text|> twice at the
# start of the attributed sequence -- presumably the template's BOS plus one
# added when the text is re-tokenized. TODO confirm.
prompt_text = tokenizer.apply_chat_template(messages,
                                            tokenize=False,
                                            add_generation_prompt=True)

# Stop on either the regular EOS token or the end-of-turn token.
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>")
]

inseq_model = inseq.load_model(model, "saliency", tokenizer=model_name)

# Greedy decoding (do_sample=False); sampling-only parameters such as
# `temperature` are not used in this mode, so none is passed.
# The prompt is a single unpadded sequence, so its attention mask is all ones.
# Passing the mask and pad_token_id explicitly avoids the "attention mask and
# the pad token id were not set" warning from generate().
output_ids = model.generate(prompt,
                            attention_mask=torch.ones_like(prompt),
                            do_sample=False,
                            max_new_tokens=64,
                            eos_token_id=terminators,
                            pad_token_id=tokenizer.eos_token_id)

# Decode only the newly generated tokens (everything after the prompt).
output_text = tokenizer.decode(output_ids[0, prompt.shape[-1]:], skip_special_tokens=False)
# Note: only the first sentence is kept for debugging. For the summarisation
# task, keep text until "\n\n" or the last complete sentence instead [TODO].
# Caveat: splitting on "." also cuts at abbreviations and decimal points.
output_text = output_text.split('.')[0] + "."

print(output_text)
out = inseq_model.attribute(
    input_texts=prompt_text,
    generated_texts=prompt_text + output_text,
)

# Uncomment to render the token-level attribution table:
# out.show()

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
The model is loaded with a device map. The device cannot be changed after loading.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


The document is a summary of various baseball games played on June 15, 2022.


Attributing with saliency...: 100%|██████████| 2700/2700 [00:23<00:00,  1.22s/it]


In [5]:
# Aggregate the attribution scores for each input sentence
# Process instructions and special tokens in chat template separately
# NOTE(review): the trailing space in each marker appears to stand in for the
# "\n\n" token that follows the header in the real sequence -- TODO confirm
# that both tokenize to the same number of tokens.
start_marker = "<|begin_of_text|><|begin_of_text|><|start_header_id|>user<|end_header_id|> "
end_marker = "<|start_header_id|>assistant<|end_header_id|> "

# Calculate the token length for each part of the prompt
# NOTE(review): the parts are tokenized separately and their lengths summed;
# this assumes tokenization does not merge tokens across part boundaries.
len_start_marker = get_token_length(start_marker, tokenizer)
len_end_marker = get_token_length(end_marker, tokenizer)
len_instruction = get_token_length(instruction, tokenizer)
len_prompt = get_token_length(prompt_message, tokenizer)
total_prompt_len = len_start_marker + len_prompt

# Token spans (start inclusive, end exclusive) of the non-document parts.
doc_start_pos = len_start_marker + len_instruction
start_span = (0, len_start_marker)
instr_span = (len_start_marker, doc_start_pos)
end_span = (total_prompt_len, total_prompt_len + len_end_marker)

# Position just past every sentence-final token inside the document.
# Computed once and reused for both span starts and span ends. The lower bound
# keeps punctuation in the header/instruction from producing spans that would
# overlap start_span / instr_span.
sentence_breaks = [
    i + 1
    for i, t in enumerate(out[0].target)
    if doc_start_pos <= i < total_prompt_len
    and is_sentence_ending(clean_token(t.token))
]
starts = [doc_start_pos] + sentence_breaks
ends = sentence_breaks + [total_prompt_len]
spans = [start_span, instr_span] + list(zip(starts, ends)) + [end_span]

# Keep only spans covering at least two tokens: this drops empty spans (e.g.
# when the document ends exactly on a sentence break) and single-token spans.
processed_spans = [span for span in spans if span[1] - span[0] > 1]

print(processed_spans)
res = out.aggregate("spans", target_spans=processed_spans)
res.show()

[(0, 6), (6, 13), (13, 48), (48, 70), (70, 113), (113, 141), (141, 158), (158, 182), (182, 208), (208, 250), (250, 266), (266, 294), (294, 332), (332, 346), (346, 355), (355, 405), (405, 438), (438, 470), (470, 487), (487, 526), (526, 540), (540, 564), (564, 590), (590, 617), (617, 639), (639, 660), (660, 671), (671, 693), (693, 714), (714, 718)]


Unnamed: 0_level_0,The,ĠPalestinian,ĠAuthority,Ġhas,Ġofficially,Ġbecome,Ġthe,Ġ,123,rd,Ġmember,Ġof,Ġthe,ĠInternational,ĠCriminal,ĠCourt,Ġ(,ICC,"),",Ġgiving,Ġthe,Ġcourt,Ġjurisdiction,Ġover,Ġalleged,Ġcrimes,Ġin,ĠPalestinian,Ġterritories,.
<|begin_of_text|><|begin_of_text|><|start_header_id|>user<|end_header_id|>ĊĊ,0.223,0.164,0.174,0.166,0.171,0.165,0.149,0.114,0.269,0.087,0.1,0.113,0.136,0.15,0.133,0.075,0.086,0.136,0.087,0.133,0.101,0.085,0.111,0.082,0.116,0.087,0.099,0.148,0.109,0.089
SummariseĠtheĠdocumentĠbelow:ĊĊ,0.097,0.076,0.052,0.064,0.061,0.049,0.047,0.041,0.028,0.018,0.026,0.029,0.028,0.036,0.028,0.017,0.029,0.016,0.035,0.039,0.023,0.027,0.046,0.026,0.042,0.038,0.036,0.023,0.027,0.03
"(CNN)TheĠPalestinianĠAuthorityĠofficiallyĠbecameĠtheĠ123rdĠmemberĠofĠtheĠInternationalĠCriminalĠCourtĠonĠWednesday,ĠaĠstepĠthatĠgivesĠtheĠcourtĠjurisdictionĠoverĠallegedĠcrimesĠinĠPalestinianĠterritories.",0.03,0.11,0.292,0.08,0.173,0.144,0.141,0.132,0.25,0.143,0.107,0.065,0.104,0.21,0.202,0.086,0.05,0.171,0.044,0.119,0.136,0.055,0.122,0.081,0.13,0.094,0.076,0.36,0.235,0.088
"ĠTheĠformalĠaccessionĠwasĠmarkedĠwithĠaĠceremonyĠatĠTheĠHague,ĠinĠtheĠNetherlands,ĠwhereĠtheĠcourtĠisĠbased.",0.009,0.013,0.011,0.025,0.074,0.05,0.019,0.022,0.008,0.005,0.017,0.015,0.012,0.025,0.01,0.005,0.007,0.008,0.03,0.019,0.01,0.008,0.009,0.007,0.006,0.007,0.007,0.006,0.007,0.014
"ĠTheĠPalestiniansĠsignedĠtheĠICC'sĠfoundingĠRomeĠStatuteĠinĠJanuary,ĠwhenĠtheyĠalsoĠacceptedĠitsĠjurisdictionĠoverĠallegedĠcrimesĠcommittedĠ""inĠtheĠoccupiedĠPalestinianĠterritory,ĠincludingĠEastĠJerusalem,ĠsinceĠJuneĠ13,Ġ2014.""",0.045,0.045,0.027,0.029,0.024,0.02,0.018,0.029,0.016,0.011,0.041,0.019,0.026,0.048,0.037,0.02,0.025,0.063,0.018,0.032,0.016,0.016,0.029,0.024,0.023,0.031,0.037,0.062,0.067,0.12
"ĠLaterĠthatĠmonth,ĠtheĠICCĠopenedĠaĠpreliminaryĠexaminationĠintoĠtheĠsituationĠinĠPalestinianĠterritories,ĠpavingĠtheĠwayĠforĠpossibleĠwarĠcrimesĠinvestigationsĠagainstĠIsraelis.",0.018,0.029,0.008,0.008,0.013,0.007,0.007,0.007,0.008,0.004,0.009,0.006,0.005,0.006,0.006,0.007,0.005,0.007,0.007,0.012,0.008,0.007,0.012,0.01,0.011,0.04,0.01,0.063,0.029,0.02
"ĠAsĠmembersĠofĠtheĠcourt,ĠPalestiniansĠmayĠbeĠsubjectĠtoĠcounter-chargesĠasĠwell.",0.009,0.008,0.01,0.007,0.017,0.009,0.015,0.011,0.004,0.003,0.019,0.012,0.01,0.006,0.006,0.004,0.003,0.003,0.005,0.009,0.016,0.005,0.007,0.004,0.005,0.007,0.004,0.002,0.003,0.007
"ĠIsraelĠandĠtheĠUnitedĠStates,ĠneitherĠofĠwhichĠisĠanĠICCĠmember,ĠopposedĠtheĠPalestinians'ĠeffortsĠtoĠjoinĠtheĠbody.",0.012,0.011,0.012,0.016,0.018,0.023,0.013,0.015,0.007,0.004,0.018,0.019,0.01,0.014,0.011,0.005,0.008,0.011,0.006,0.012,0.006,0.005,0.006,0.004,0.005,0.006,0.005,0.003,0.003,0.006
"ĠButĠPalestinianĠForeignĠMinisterĠRiadĠal-Malki,ĠspeakingĠatĠWednesday'sĠceremony,ĠsaidĠitĠwasĠaĠmoveĠtowardĠgreaterĠjustice.",0.027,0.088,0.026,0.047,0.009,0.009,0.009,0.007,0.004,0.003,0.009,0.006,0.007,0.013,0.009,0.006,0.005,0.003,0.01,0.011,0.005,0.003,0.004,0.003,0.005,0.005,0.005,0.003,0.006,0.006
"Ġ""AsĠPalestineĠformallyĠbecomesĠaĠStateĠPartyĠtoĠtheĠRomeĠStatuteĠtoday,ĠtheĠworldĠisĠalsoĠaĠstepĠcloserĠtoĠendingĠaĠlongĠeraĠofĠimpunityĠandĠinjustice,""ĠheĠsaid,ĠaccordingĠtoĠanĠICCĠnewsĠrelease.",0.029,0.052,0.018,0.038,0.012,0.012,0.017,0.018,0.019,0.006,0.046,0.023,0.014,0.011,0.007,0.013,0.005,0.008,0.013,0.011,0.007,0.004,0.005,0.004,0.008,0.008,0.005,0.004,0.007,0.007
"Ġ""Indeed,ĠtodayĠbringsĠusĠcloserĠtoĠourĠsharedĠgoalsĠofĠjusticeĠandĠpeace.""",0.003,0.002,0.003,0.007,0.003,0.003,0.003,0.003,0.005,0.002,0.002,0.002,0.003,0.002,0.002,0.003,0.002,0.003,0.003,0.003,0.002,0.002,0.002,0.002,0.002,0.002,0.002,0.001,0.001,0.002
"ĠJudgeĠKunikoĠOzaki,ĠaĠviceĠpresidentĠofĠtheĠICC,ĠsaidĠaccedingĠtoĠtheĠtreatyĠwasĠjustĠtheĠfirstĠstepĠforĠtheĠPalestinians.",0.005,0.005,0.009,0.016,0.014,0.011,0.009,0.007,0.011,0.004,0.008,0.01,0.008,0.007,0.007,0.007,0.006,0.01,0.008,0.013,0.005,0.005,0.004,0.004,0.003,0.004,0.004,0.002,0.003,0.004
"Ġ""AsĠtheĠRomeĠStatuteĠtodayĠentersĠintoĠforceĠforĠtheĠStateĠofĠPalestine,ĠPalestineĠacquiresĠallĠtheĠrightsĠasĠwellĠasĠresponsibilitiesĠthatĠcomeĠwithĠbeingĠaĠStateĠPartyĠtoĠtheĠStatute.",0.008,0.01,0.018,0.051,0.02,0.016,0.014,0.015,0.04,0.011,0.013,0.012,0.021,0.013,0.008,0.04,0.008,0.018,0.021,0.016,0.012,0.007,0.01,0.006,0.007,0.019,0.005,0.004,0.005,0.007
"ĠTheseĠareĠsubstantiveĠcommitments,ĠwhichĠcannotĠbeĠtakenĠlightly,""ĠsheĠsaid.",0.009,0.005,0.004,0.003,0.003,0.003,0.003,0.013,0.006,0.003,0.003,0.003,0.004,0.004,0.004,0.002,0.002,0.002,0.002,0.004,0.002,0.003,0.003,0.002,0.002,0.003,0.005,0.001,0.002,0.004
ĠRightsĠgroupĠHumanĠRightsĠWatchĠwelcomedĠtheĠdevelopment.,0.003,0.003,0.003,0.003,0.003,0.003,0.003,0.003,0.003,0.002,0.003,0.002,0.002,0.003,0.003,0.003,0.002,0.001,0.003,0.003,0.002,0.002,0.002,0.002,0.002,0.003,0.002,0.001,0.001,0.002
"Ġ""GovernmentsĠseekingĠtoĠpenalizeĠPalestineĠforĠjoiningĠtheĠICCĠshouldĠimmediatelyĠendĠtheirĠpressure,ĠandĠcountriesĠthatĠsupportĠuniversalĠacceptanceĠofĠtheĠcourt'sĠtreatyĠshouldĠspeakĠoutĠtoĠwelcomeĠitsĠmembership,""ĠsaidĠBalkeesĠJarrah,ĠinternationalĠjusticeĠcounselĠforĠtheĠgroup.",0.005,0.003,0.005,0.003,0.004,0.004,0.004,0.007,0.004,0.003,0.007,0.006,0.003,0.006,0.006,0.008,0.005,0.003,0.004,0.005,0.004,0.004,0.004,0.004,0.003,0.006,0.004,0.002,0.002,0.004
"Ġ""What'sĠobjectionableĠisĠtheĠattemptsĠtoĠundermineĠinternationalĠjustice,ĠnotĠPalestine'sĠdecisionĠtoĠjoinĠaĠtreatyĠtoĠwhichĠoverĠ100ĠcountriesĠaroundĠtheĠworldĠareĠmembers.""",0.01,0.003,0.003,0.004,0.005,0.004,0.003,0.006,0.009,0.006,0.01,0.005,0.003,0.003,0.003,0.003,0.003,0.002,0.002,0.004,0.002,0.003,0.003,0.002,0.004,0.003,0.003,0.001,0.002,0.002
"ĠInĠJanuary,ĠwhenĠtheĠpreliminaryĠICCĠexaminationĠwasĠopened,ĠIsraeliĠPrimeĠMinisterĠBenjaminĠNetanyahuĠdescribedĠitĠasĠanĠoutrage,ĠsayingĠtheĠcourtĠwasĠoversteppingĠitsĠboundaries.",0.008,0.004,0.005,0.013,0.011,0.007,0.006,0.012,0.003,0.006,0.003,0.004,0.005,0.007,0.005,0.009,0.01,0.007,0.01,0.006,0.004,0.006,0.004,0.005,0.007,0.009,0.004,0.002,0.007,0.004
"ĠTheĠUnitedĠStatesĠalsoĠsaidĠitĠ""strongly""ĠdisagreedĠwithĠtheĠcourt'sĠdecision.",0.004,0.004,0.005,0.031,0.024,0.012,0.018,0.052,0.007,0.012,0.006,0.003,0.014,0.013,0.005,0.015,0.004,0.01,0.002,0.012,0.008,0.009,0.004,0.006,0.013,0.007,0.003,0.002,0.002,0.004
"Ġ""AsĠweĠhaveĠsaidĠrepeatedly,ĠweĠdoĠnotĠbelieveĠthatĠPalestineĠisĠaĠstateĠandĠthereforeĠweĠdoĠnotĠbelieveĠthatĠitĠisĠeligibleĠtoĠjoinĠtheĠICC,""ĠtheĠStateĠDepartmentĠsaidĠinĠaĠstatement.",0.006,0.004,0.004,0.012,0.014,0.003,0.005,0.017,0.002,0.004,0.007,0.007,0.007,0.005,0.006,0.009,0.006,0.004,0.007,0.006,0.004,0.005,0.003,0.003,0.004,0.006,0.003,0.002,0.003,0.003
ĠItĠurgedĠtheĠwarringĠsidesĠtoĠresolveĠtheirĠdifferencesĠthroughĠdirectĠnegotiations.,0.003,0.003,0.006,0.002,0.003,0.004,0.002,0.004,0.002,0.002,0.004,0.003,0.002,0.002,0.003,0.023,0.004,0.003,0.002,0.009,0.003,0.004,0.003,0.005,0.015,0.025,0.007,0.003,0.009,0.003
"Ġ""WeĠwillĠcontinueĠtoĠopposeĠactionsĠagainstĠIsraelĠatĠtheĠICCĠasĠcounterproductiveĠtoĠtheĠcauseĠofĠpeace,""ĠitĠsaid.",0.004,0.003,0.003,0.002,0.002,0.002,0.002,0.003,0.001,0.002,0.002,0.003,0.003,0.004,0.004,0.003,0.003,0.002,0.004,0.003,0.002,0.002,0.003,0.002,0.002,0.003,0.002,0.002,0.002,0.003
"ĠButĠtheĠICCĠbegsĠtoĠdifferĠwithĠtheĠdefinitionĠofĠaĠstateĠforĠitsĠpurposesĠandĠrefersĠtoĠtheĠterritoriesĠasĠ""Palestine.""",0.005,0.004,0.007,0.006,0.012,0.003,0.004,0.005,0.002,0.002,0.006,0.004,0.008,0.006,0.008,0.007,0.005,0.005,0.003,0.004,0.004,0.006,0.004,0.003,0.005,0.007,0.007,0.017,0.015,0.011
"ĠWhileĠaĠpreliminaryĠexaminationĠisĠnotĠaĠformalĠinvestigation,ĠitĠallowsĠtheĠcourtĠtoĠreviewĠevidenceĠandĠdetermineĠwhetherĠtoĠinvestigateĠsuspectsĠonĠbothĠsides.",0.007,0.004,0.003,0.003,0.002,0.003,0.003,0.003,0.002,0.004,0.002,0.003,0.003,0.004,0.005,0.006,0.003,0.002,0.003,0.003,0.004,0.008,0.005,0.004,0.004,0.007,0.005,0.001,0.002,0.004
"ĠProsecutorĠFatouĠBensoudaĠsaidĠherĠofficeĠwouldĠ""conductĠitsĠanalysisĠinĠfullĠindependenceĠandĠimpartiality.""",0.016,0.005,0.007,0.006,0.005,0.008,0.007,0.005,0.003,0.003,0.003,0.004,0.006,0.008,0.011,0.058,0.005,0.008,0.006,0.004,0.005,0.004,0.01,0.013,0.005,0.008,0.006,0.003,0.004,0.01
"ĠTheĠwarĠbetweenĠIsraelĠandĠHamasĠmilitantsĠinĠGazaĠlastĠsummerĠleftĠmoreĠthanĠ2,000ĠpeopleĠdead.",0.119,0.054,0.026,0.03,0.009,0.021,0.013,0.02,0.011,0.012,0.016,0.018,0.023,0.025,0.044,0.013,0.021,0.016,0.01,0.02,0.005,0.009,0.016,0.006,0.024,0.01,0.022,0.004,0.007,0.02
ĠTheĠinquiryĠwillĠincludeĠallegedĠwarĠcrimesĠcommittedĠsinceĠJune.,0.003,0.004,0.004,0.003,0.003,0.004,0.004,0.006,0.002,0.004,0.003,0.004,0.009,0.013,0.016,0.015,0.006,0.005,0.004,0.004,0.003,0.004,0.003,0.005,0.012,0.024,0.013,0.003,0.003,0.005
"ĠTheĠInternationalĠCriminalĠCourtĠwasĠsetĠupĠinĠ2002ĠtoĠprosecuteĠgenocide,ĠcrimesĠagainstĠhumanityĠandĠwarĠcrimes.",0.006,0.007,0.008,0.007,0.007,0.011,0.011,0.023,0.004,0.007,0.011,0.012,0.043,0.065,0.084,0.088,0.018,0.026,0.017,0.01,0.01,0.015,0.009,0.01,0.008,0.016,0.008,0.004,0.005,0.006
"ĠCNN'sĠVascoĠCotovio,ĠKareemĠKhadderĠandĠFaithĠKarimiĠcontributedĠtoĠthisĠreport.",0.02,0.022,0.017,0.017,0.016,0.014,0.017,0.014,0.008,0.016,0.012,0.015,0.012,0.009,0.011,0.01,0.014,0.008,0.014,0.011,0.01,0.012,0.013,0.01,0.014,0.012,0.013,0.005,0.008,0.011
<|start_header_id|>assistant<|end_header_id|>ĊĊ,0.172,0.09,0.05,0.076,0.08,0.067,0.068,0.066,0.063,0.032,0.042,0.034,0.059,0.054,0.058,0.046,0.04,0.022,0.047,0.054,0.034,0.039,0.074,0.033,0.075,0.039,0.049,0.031,0.04,0.039
The,0.081,0.111,0.037,0.064,0.046,0.042,0.072,0.041,0.026,0.035,0.029,0.046,0.03,0.021,0.036,0.027,0.036,0.019,0.05,0.03,0.031,0.033,0.05,0.037,0.05,0.048,0.047,0.015,0.024,0.037
ĠPalestinian,Unnamed: 1_level_32,0.055,0.031,0.031,0.019,0.024,0.023,0.018,0.007,0.015,0.015,0.017,0.015,0.011,0.009,0.009,0.015,0.008,0.018,0.012,0.01,0.015,0.025,0.012,0.016,0.013,0.012,0.007,0.009,0.016
ĠAuthority,Unnamed: 1_level_33,Unnamed: 2_level_33,0.109,0.054,0.031,0.037,0.035,0.028,0.011,0.02,0.024,0.027,0.019,0.016,0.013,0.018,0.043,0.021,0.027,0.02,0.023,0.029,0.015,0.02,0.013,0.015,0.016,0.013,0.017,0.015
Ġhas,Unnamed: 1_level_34,Unnamed: 2_level_34,Unnamed: 3_level_34,0.076,0.044,0.052,0.045,0.036,0.014,0.015,0.028,0.033,0.026,0.02,0.013,0.015,0.02,0.009,0.026,0.021,0.025,0.02,0.017,0.014,0.013,0.018,0.017,0.01,0.016,0.012
Ġofficially,Unnamed: 1_level_35,Unnamed: 2_level_35,Unnamed: 3_level_35,Unnamed: 4_level_35,0.049,0.039,0.04,0.025,0.008,0.012,0.017,0.017,0.012,0.008,0.007,0.007,0.01,0.006,0.023,0.012,0.012,0.014,0.012,0.008,0.007,0.008,0.009,0.004,0.006,0.009
Ġbecome,Unnamed: 1_level_36,Unnamed: 2_level_36,Unnamed: 3_level_36,Unnamed: 4_level_36,Unnamed: 5_level_36,0.115,0.061,0.041,0.009,0.011,0.026,0.019,0.013,0.009,0.005,0.006,0.007,0.004,0.019,0.021,0.011,0.01,0.011,0.007,0.008,0.008,0.008,0.004,0.005,0.008
Ġthe,Unnamed: 1_level_37,Unnamed: 2_level_37,Unnamed: 3_level_37,Unnamed: 4_level_37,Unnamed: 5_level_37,Unnamed: 6_level_37,0.09,0.05,0.016,0.015,0.038,0.022,0.014,0.007,0.006,0.004,0.007,0.005,0.016,0.014,0.014,0.011,0.012,0.007,0.008,0.007,0.008,0.003,0.005,0.006
Ġ,Unnamed: 1_level_38,Unnamed: 2_level_38,Unnamed: 3_level_38,Unnamed: 4_level_38,Unnamed: 5_level_38,Unnamed: 6_level_38,Unnamed: 7_level_38,0.082,0.035,0.054,0.038,0.037,0.017,0.009,0.007,0.008,0.01,0.008,0.009,0.01,0.011,0.01,0.008,0.008,0.006,0.006,0.007,0.003,0.004,0.006
123,Unnamed: 1_level_39,Unnamed: 2_level_39,Unnamed: 3_level_39,Unnamed: 4_level_39,Unnamed: 5_level_39,Unnamed: 6_level_39,Unnamed: 7_level_39,Unnamed: 8_level_39,0.073,0.106,0.037,0.029,0.015,0.007,0.007,0.007,0.01,0.006,0.009,0.01,0.009,0.008,0.007,0.006,0.005,0.005,0.006,0.003,0.004,0.005
rd,Unnamed: 1_level_40,Unnamed: 2_level_40,Unnamed: 3_level_40,Unnamed: 4_level_40,Unnamed: 5_level_40,Unnamed: 6_level_40,Unnamed: 7_level_40,Unnamed: 8_level_40,Unnamed: 9_level_40,0.301,0.086,0.049,0.022,0.01,0.009,0.006,0.011,0.007,0.013,0.016,0.01,0.008,0.009,0.007,0.007,0.007,0.009,0.004,0.005,0.008
Ġmember,Unnamed: 1_level_41,Unnamed: 2_level_41,Unnamed: 3_level_41,Unnamed: 4_level_41,Unnamed: 5_level_41,Unnamed: 6_level_41,Unnamed: 7_level_41,Unnamed: 8_level_41,Unnamed: 9_level_41,Unnamed: 10_level_41,0.103,0.075,0.03,0.012,0.01,0.006,0.01,0.008,0.014,0.017,0.012,0.01,0.01,0.008,0.006,0.007,0.008,0.003,0.004,0.007
Ġof,Unnamed: 1_level_42,Unnamed: 2_level_42,Unnamed: 3_level_42,Unnamed: 4_level_42,Unnamed: 5_level_42,Unnamed: 6_level_42,Unnamed: 7_level_42,Unnamed: 8_level_42,Unnamed: 9_level_42,Unnamed: 10_level_42,Unnamed: 11_level_42,0.167,0.085,0.027,0.016,0.011,0.015,0.01,0.02,0.022,0.018,0.015,0.014,0.01,0.009,0.009,0.012,0.004,0.006,0.01
Ġthe,Unnamed: 1_level_43,Unnamed: 2_level_43,Unnamed: 3_level_43,Unnamed: 4_level_43,Unnamed: 5_level_43,Unnamed: 6_level_43,Unnamed: 7_level_43,Unnamed: 8_level_43,Unnamed: 9_level_43,Unnamed: 10_level_43,Unnamed: 11_level_43,Unnamed: 12_level_43,0.11,0.034,0.018,0.018,0.022,0.016,0.017,0.016,0.016,0.017,0.011,0.012,0.008,0.009,0.011,0.003,0.005,0.007
ĠInternational,Unnamed: 1_level_44,Unnamed: 2_level_44,Unnamed: 3_level_44,Unnamed: 4_level_44,Unnamed: 5_level_44,Unnamed: 6_level_44,Unnamed: 7_level_44,Unnamed: 8_level_44,Unnamed: 9_level_44,Unnamed: 10_level_44,Unnamed: 11_level_44,Unnamed: 12_level_44,Unnamed: 13_level_44,0.035,0.028,0.022,0.029,0.022,0.016,0.009,0.013,0.019,0.009,0.011,0.007,0.007,0.009,0.003,0.005,0.006
ĠCriminal,Unnamed: 1_level_45,Unnamed: 2_level_45,Unnamed: 3_level_45,Unnamed: 4_level_45,Unnamed: 5_level_45,Unnamed: 6_level_45,Unnamed: 7_level_45,Unnamed: 8_level_45,Unnamed: 9_level_45,Unnamed: 10_level_45,Unnamed: 11_level_45,Unnamed: 12_level_45,Unnamed: 13_level_45,Unnamed: 14_level_45,0.074,0.079,0.068,0.052,0.031,0.019,0.025,0.027,0.016,0.016,0.014,0.015,0.014,0.005,0.009,0.009
ĠCourt,Unnamed: 1_level_46,Unnamed: 2_level_46,Unnamed: 3_level_46,Unnamed: 4_level_46,Unnamed: 5_level_46,Unnamed: 6_level_46,Unnamed: 7_level_46,Unnamed: 8_level_46,Unnamed: 9_level_46,Unnamed: 10_level_46,Unnamed: 11_level_46,Unnamed: 12_level_46,Unnamed: 13_level_46,Unnamed: 14_level_46,Unnamed: 15_level_46,0.15,0.119,0.084,0.053,0.033,0.044,0.047,0.025,0.026,0.021,0.024,0.024,0.008,0.013,0.015
Ġ(,Unnamed: 1_level_47,Unnamed: 2_level_47,Unnamed: 3_level_47,Unnamed: 4_level_47,Unnamed: 5_level_47,Unnamed: 6_level_47,Unnamed: 7_level_47,Unnamed: 8_level_47,Unnamed: 9_level_47,Unnamed: 10_level_47,Unnamed: 11_level_47,Unnamed: 12_level_47,Unnamed: 13_level_47,Unnamed: 14_level_47,Unnamed: 15_level_47,Unnamed: 16_level_47,0.18,0.1,0.079,0.044,0.069,0.068,0.03,0.041,0.022,0.024,0.025,0.011,0.017,0.016
ICC,Unnamed: 1_level_48,Unnamed: 2_level_48,Unnamed: 3_level_48,Unnamed: 4_level_48,Unnamed: 5_level_48,Unnamed: 6_level_48,Unnamed: 7_level_48,Unnamed: 8_level_48,Unnamed: 9_level_48,Unnamed: 10_level_48,Unnamed: 11_level_48,Unnamed: 12_level_48,Unnamed: 13_level_48,Unnamed: 14_level_48,Unnamed: 15_level_48,Unnamed: 16_level_48,Unnamed: 17_level_48,0.031,0.052,0.019,0.042,0.053,0.021,0.022,0.016,0.016,0.017,0.007,0.012,0.013
"),",Unnamed: 1_level_49,Unnamed: 2_level_49,Unnamed: 3_level_49,Unnamed: 4_level_49,Unnamed: 5_level_49,Unnamed: 6_level_49,Unnamed: 7_level_49,Unnamed: 8_level_49,Unnamed: 9_level_49,Unnamed: 10_level_49,Unnamed: 11_level_49,Unnamed: 12_level_49,Unnamed: 13_level_49,Unnamed: 14_level_49,Unnamed: 15_level_49,Unnamed: 16_level_49,Unnamed: 17_level_49,Unnamed: 18_level_49,0.078,0.022,0.06,0.086,0.027,0.024,0.015,0.021,0.019,0.006,0.012,0.012
Ġgiving,Unnamed: 1_level_50,Unnamed: 2_level_50,Unnamed: 3_level_50,Unnamed: 4_level_50,Unnamed: 5_level_50,Unnamed: 6_level_50,Unnamed: 7_level_50,Unnamed: 8_level_50,Unnamed: 9_level_50,Unnamed: 10_level_50,Unnamed: 11_level_50,Unnamed: 12_level_50,Unnamed: 13_level_50,Unnamed: 14_level_50,Unnamed: 15_level_50,Unnamed: 16_level_50,Unnamed: 17_level_50,Unnamed: 18_level_50,Unnamed: 19_level_50,0.045,0.031,0.024,0.023,0.022,0.015,0.013,0.018,0.008,0.01,0.017
Ġthe,Unnamed: 1_level_51,Unnamed: 2_level_51,Unnamed: 3_level_51,Unnamed: 4_level_51,Unnamed: 5_level_51,Unnamed: 6_level_51,Unnamed: 7_level_51,Unnamed: 8_level_51,Unnamed: 9_level_51,Unnamed: 10_level_51,Unnamed: 11_level_51,Unnamed: 12_level_51,Unnamed: 13_level_51,Unnamed: 14_level_51,Unnamed: 15_level_51,Unnamed: 16_level_51,Unnamed: 17_level_51,Unnamed: 18_level_51,Unnamed: 19_level_51,Unnamed: 20_level_51,0.062,0.043,0.036,0.036,0.022,0.015,0.018,0.009,0.01,0.015
Ġcourt,Unnamed: 1_level_52,Unnamed: 2_level_52,Unnamed: 3_level_52,Unnamed: 4_level_52,Unnamed: 5_level_52,Unnamed: 6_level_52,Unnamed: 7_level_52,Unnamed: 8_level_52,Unnamed: 9_level_52,Unnamed: 10_level_52,Unnamed: 11_level_52,Unnamed: 12_level_52,Unnamed: 13_level_52,Unnamed: 14_level_52,Unnamed: 15_level_52,Unnamed: 16_level_52,Unnamed: 17_level_52,Unnamed: 18_level_52,Unnamed: 19_level_52,Unnamed: 20_level_52,Unnamed: 21_level_52,0.066,0.034,0.042,0.02,0.015,0.017,0.008,0.01,0.009
Ġjurisdiction,Unnamed: 1_level_53,Unnamed: 2_level_53,Unnamed: 3_level_53,Unnamed: 4_level_53,Unnamed: 5_level_53,Unnamed: 6_level_53,Unnamed: 7_level_53,Unnamed: 8_level_53,Unnamed: 9_level_53,Unnamed: 10_level_53,Unnamed: 11_level_53,Unnamed: 12_level_53,Unnamed: 13_level_53,Unnamed: 14_level_53,Unnamed: 15_level_53,Unnamed: 16_level_53,Unnamed: 17_level_53,Unnamed: 18_level_53,Unnamed: 19_level_53,Unnamed: 20_level_53,Unnamed: 21_level_53,Unnamed: 22_level_53,0.046,0.064,0.021,0.016,0.015,0.007,0.011,0.011
Ġover,Unnamed: 1_level_54,Unnamed: 2_level_54,Unnamed: 3_level_54,Unnamed: 4_level_54,Unnamed: 5_level_54,Unnamed: 6_level_54,Unnamed: 7_level_54,Unnamed: 8_level_54,Unnamed: 9_level_54,Unnamed: 10_level_54,Unnamed: 11_level_54,Unnamed: 12_level_54,Unnamed: 13_level_54,Unnamed: 14_level_54,Unnamed: 15_level_54,Unnamed: 16_level_54,Unnamed: 17_level_54,Unnamed: 18_level_54,Unnamed: 19_level_54,Unnamed: 20_level_54,Unnamed: 21_level_54,Unnamed: 22_level_54,Unnamed: 23_level_54,0.167,0.055,0.041,0.033,0.014,0.018,0.019
Ġalleged,Unnamed: 1_level_55,Unnamed: 2_level_55,Unnamed: 3_level_55,Unnamed: 4_level_55,Unnamed: 5_level_55,Unnamed: 6_level_55,Unnamed: 7_level_55,Unnamed: 8_level_55,Unnamed: 9_level_55,Unnamed: 10_level_55,Unnamed: 11_level_55,Unnamed: 12_level_55,Unnamed: 13_level_55,Unnamed: 14_level_55,Unnamed: 15_level_55,Unnamed: 16_level_55,Unnamed: 17_level_55,Unnamed: 18_level_55,Unnamed: 19_level_55,Unnamed: 20_level_55,Unnamed: 21_level_55,Unnamed: 22_level_55,Unnamed: 23_level_55,Unnamed: 24_level_55,0.052,0.039,0.035,0.011,0.01,0.014
Ġcrimes,Unnamed: 1_level_56,Unnamed: 2_level_56,Unnamed: 3_level_56,Unnamed: 4_level_56,Unnamed: 5_level_56,Unnamed: 6_level_56,Unnamed: 7_level_56,Unnamed: 8_level_56,Unnamed: 9_level_56,Unnamed: 10_level_56,Unnamed: 11_level_56,Unnamed: 12_level_56,Unnamed: 13_level_56,Unnamed: 14_level_56,Unnamed: 15_level_56,Unnamed: 16_level_56,Unnamed: 17_level_56,Unnamed: 18_level_56,Unnamed: 19_level_56,Unnamed: 20_level_56,Unnamed: 21_level_56,Unnamed: 22_level_56,Unnamed: 23_level_56,Unnamed: 24_level_56,Unnamed: 25_level_56,0.058,0.063,0.012,0.013,0.009
Ġin,Unnamed: 1_level_57,Unnamed: 2_level_57,Unnamed: 3_level_57,Unnamed: 4_level_57,Unnamed: 5_level_57,Unnamed: 6_level_57,Unnamed: 7_level_57,Unnamed: 8_level_57,Unnamed: 9_level_57,Unnamed: 10_level_57,Unnamed: 11_level_57,Unnamed: 12_level_57,Unnamed: 13_level_57,Unnamed: 14_level_57,Unnamed: 15_level_57,Unnamed: 16_level_57,Unnamed: 17_level_57,Unnamed: 18_level_57,Unnamed: 19_level_57,Unnamed: 20_level_57,Unnamed: 21_level_57,Unnamed: 22_level_57,Unnamed: 23_level_57,Unnamed: 24_level_57,Unnamed: 25_level_57,Unnamed: 26_level_57,0.072,0.015,0.017,0.016
ĠPalestinian,Unnamed: 1_level_58,Unnamed: 2_level_58,Unnamed: 3_level_58,Unnamed: 4_level_58,Unnamed: 5_level_58,Unnamed: 6_level_58,Unnamed: 7_level_58,Unnamed: 8_level_58,Unnamed: 9_level_58,Unnamed: 10_level_58,Unnamed: 11_level_58,Unnamed: 12_level_58,Unnamed: 13_level_58,Unnamed: 14_level_58,Unnamed: 15_level_58,Unnamed: 16_level_58,Unnamed: 17_level_58,Unnamed: 18_level_58,Unnamed: 19_level_58,Unnamed: 20_level_58,Unnamed: 21_level_58,Unnamed: 22_level_58,Unnamed: 23_level_58,Unnamed: 24_level_58,Unnamed: 25_level_58,Unnamed: 26_level_58,Unnamed: 27_level_58,0.03,0.029,0.026
Ġterritories,Unnamed: 1_level_59,Unnamed: 2_level_59,Unnamed: 3_level_59,Unnamed: 4_level_59,Unnamed: 5_level_59,Unnamed: 6_level_59,Unnamed: 7_level_59,Unnamed: 8_level_59,Unnamed: 9_level_59,Unnamed: 10_level_59,Unnamed: 11_level_59,Unnamed: 12_level_59,Unnamed: 13_level_59,Unnamed: 14_level_59,Unnamed: 15_level_59,Unnamed: 16_level_59,Unnamed: 17_level_59,Unnamed: 18_level_59,Unnamed: 19_level_59,Unnamed: 20_level_59,Unnamed: 21_level_59,Unnamed: 22_level_59,Unnamed: 23_level_59,Unnamed: 24_level_59,Unnamed: 25_level_59,Unnamed: 26_level_59,Unnamed: 27_level_59,Unnamed: 28_level_59,0.078,0.05
.,Unnamed: 1_level_60,Unnamed: 2_level_60,Unnamed: 3_level_60,Unnamed: 4_level_60,Unnamed: 5_level_60,Unnamed: 6_level_60,Unnamed: 7_level_60,Unnamed: 8_level_60,Unnamed: 9_level_60,Unnamed: 10_level_60,Unnamed: 11_level_60,Unnamed: 12_level_60,Unnamed: 13_level_60,Unnamed: 14_level_60,Unnamed: 15_level_60,Unnamed: 16_level_60,Unnamed: 17_level_60,Unnamed: 18_level_60,Unnamed: 19_level_60,Unnamed: 20_level_60,Unnamed: 21_level_60,Unnamed: 22_level_60,Unnamed: 23_level_60,Unnamed: 24_level_60,Unnamed: 25_level_60,Unnamed: 26_level_60,Unnamed: 27_level_60,Unnamed: 28_level_60,Unnamed: 29_level_60,0.06


In [6]:
def clean_token(text):
    """Turn a raw token string into readable text.

    ``Ċ`` markers are removed and each ``Ġ`` marker becomes a single space.

    NOTE(review): this redefines ``clean_token`` from the first cell, which
    deletes ``Ġ`` instead of replacing it with a space; whichever cell ran
    last determines the behaviour seen by callers.

    Args:
        text: The raw token string.

    Returns:
        The cleaned string.
    """
    # TODO: different operation based on model name
    return text.replace("Ċ", "").replace("Ġ", " ")

# Number of highest-scoring sentences to keep.
TOP_K = 3

# Aggregate the attribution scores for each input sentence
tok_out = res.aggregate()
prompt_last_index = tok_out[0].attr_pos_start

# Rows 0 and 1 are the chat-template header and the instruction, and the row
# just before attr_pos_start is the assistant header; the rows in between are
# the document sentences. (Assumes none of those three spans was filtered out
# when the spans were built.)
sentence_rows = slice(2, prompt_last_index - 1)
input_sequences = [clean_token(t.token) for t in tok_out[0].target[sentence_rows]]
print(input_sequences)

# Note: no need to use cleaned sequences
attr_scores = tok_out[0].target_attributions[sentence_rows].tolist()
assert len(input_sequences) == len(attr_scores), \
    "each sentence must have exactly one row of attribution scores"

# Score each sentence by its largest attribution value.
sent_scores = dict()
for seq, scores in zip(input_sequences, attr_scores):
    best = max(scores)
    # Identical sentences share one key: keep the highest score seen rather
    # than letting a later, lower-scoring duplicate overwrite it.
    if seq not in sent_scores or best > sent_scores[seq]:
        sent_scores[seq] = best

# Extract top K important sentences
sorted_sent_scores = dict(sorted(sent_scores.items(), key=lambda x: x[1], reverse=True))
top_k_sents = list(sorted_sent_scores)[:TOP_K]

print(top_k_sents)
for sent in top_k_sents:
    print(sent_scores[sent])

# Store both the attributed sentences and their scores
attributed_sents = [
    {"input_sequence": sent, "score": sent_scores[sent]}
    for sent in top_k_sents
]

print(attributed_sents)

['(CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories.', ' The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based.', ' The Palestinians signed the ICC\'s founding Rome Statute in January, when they also accepted its jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East Jerusalem, since June 13, 2014."', ' Later that month, the ICC opened a preliminary examination into the situation in Palestinian territories, paving the way for possible war crimes investigations against Israelis.', ' As members of the court, Palestinians may be subject to counter-charges as well.', " Israel and the United States, neither of which is an ICC member, opposed the Palestinians' efforts to join the body.", " But Palestinian Foreign Minister Riad al-Malki, speaking 