In [1]:
import inseq
import pandas as pd
from matplotlib import pyplot as plt
from transformers import AutoModelForCausalLM, AutoTokenizer
import seaborn as sns

In [2]:
model_path = "/media/cribin/4ff138a6-4ad8-4fcb-b932-97393d17c043/FaithfulSummarization/gemma_2b_span_absinth_finetuned/merged"
# model_path = "/media/cribin/4ff138a6-4ad8-4fcb-b932-97393d17c043/FaithfulSummarization/mistral-7b_span_absinth_finetuned/merged"
model = AutoModelForCausalLM.from_pretrained(model_path, load_in_8bit=True, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_path)

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [3]:
attrib_model = inseq.load_model(
    model=model,
    attribution_method="saliency"
)

The model is loaded in 8bit mode. The device cannot be changed after loading the model.


In [4]:
input_prompt = """<Anweisung>
Gegeben ist ein Artikel und ein spezifischer Satz, der aus diesem Artikel abgeleitet wurde. Analysiere den gesamten Satz, um seine Treue zum ursprünglichen Artikel zu bestimmen. Der Satz kann 'Faithful' (der Satz gibt die Information im Artikel wieder ohne jegliche Änderung), 'Intrinsic' (der Satz ändert oder widerspricht Informationen aus dem Artikel) oder 'Extrinsic' (der Satz führt neue, nicht verwandte Informationen ein, die nicht im Artikel enthalten sind) sein.
Falls der Satz 'Intrinsic' oder 'Extrinsic' ist, identifiziere und gib die entsprechenden Satzabschnitte an, wobei jeder Abschnitt ein eigenes Label erhält. Ein Satz kann verschiedene Arten von Halluzinationen enthalten.
</Anweisung>

<Artikel>
Thomas ist 6 Jahre alt.
</Artikel>
<Satz>
Thomas ist 7 Jahre alt.
</Satz>
<Antwort>
"""
faithful_completion = """{
    "Antwort": [
        {
            "Satzabschnitt": "-",
            "Label": "Faithful"
        }
    ]
}
</Antwort>"""
intrinsic_completion = """{
    "Antwort": [
        {
            "Satzabschnitt": "7",
            "Label": "Intrinsic"
        }
    ]
}
</Antwort>"""
output_faithful = input_prompt + faithful_completion
output_intrinsic = input_prompt + intrinsic_completion


In [5]:
input_tokens = tokenizer(input_prompt, return_tensors="pt").to(model.device)
model_completion = model.generate(input_tokens.input_ids, max_new_tokens=100)
output_text = tokenizer.batch_decode(model_completion, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
print(output_text)

<Anweisung>
Gegeben ist ein Artikel und ein spezifischer Satz, der aus diesem Artikel abgeleitet wurde. Analysiere den gesamten Satz, um seine Treue zum ursprünglichen Artikel zu bestimmen. Der Satz kann 'Faithful' (der Satz gibt die Information im Artikel wieder ohne jegliche Änderung), 'Intrinsic' (der Satz ändert oder widerspricht Informationen aus dem Artikel) oder 'Extrinsic' (der Satz führt neue, nicht verwandte Informationen ein, die nicht im Artikel enthalten sind) sein.
Falls der Satz 'Intrinsic' oder 'Extrinsic' ist, identifiziere und gib die entsprechenden Satzabschnitte an, wobei jeder Abschnitt ein eigenes Label erhält. Ein Satz kann verschiedene Arten von Halluzinationen enthalten.
</Anweisung>

<Artikel>
Thomas ist 6 Jahre alt.
</Artikel>
<Satz>
Thomas ist 7 Jahre alt.
</Satz>
<Antwort>
{
    "Antwort": [
        {
            "Satzabschnitt": "-",
            "Label": "Faithful"
        }
    ]
}
</Antwort>


In [17]:
out = attrib_model.attribute(
    input_texts=input_prompt,
    generated_texts=output_faithful,
    # contrast_targets=output_faithful,
    # attributed_fn="contrast_prob_diff",
    # step_scores=["probability", "contrast_prob_diff"],
)

# subw_sqa_agg = out.aggregate("subwords", special_symbol=("▁", "\n")).aggregate()
#subw_viz = subw_sqa_agg.show(return_html=True, do_aggregation=False)


Attributing with saliency...: 100%|██████████| 233/233 [00:09<00:00,  4.13it/s]


In [18]:
output_attribution = out[0].aggregate("subwords")
start_indices = [i for i, source_token in enumerate(list(output_attribution.source)) if "<Artikel>" in source_token.token]
end_indices = [i for i, source_token in enumerate(output_attribution.source) if "</Satz>" in source_token.token]
input_start_pos = start_indices[0] if start_indices else -1
input_end_pos =end_indices[0] + 1 if end_indices else -1
output_attribution[input_start_pos:input_end_pos].show()

Unnamed: 0_level_0,Unnamed: 1_level_0,0,1,2,3,4,5,6,7,8,9,10
Unnamed: 0_level_1,Unnamed: 1_level_1,▁alt. </Satz> <Antwort> {,"▁▁▁▁""Antwort"":",▁[,▁▁▁▁▁▁▁▁{,"▁▁▁▁▁▁▁▁▁▁▁▁""Satzabschnitt"":","▁""-"",","▁▁▁▁▁▁▁▁▁▁▁▁""Label"":","▁""Faithful""",▁▁▁▁▁▁▁▁},▁▁▁▁] } </Antwort,>
0,▁enthalten. </Anweisung> <Artikel> Thomas,0.151,0.114,0.095,0.09,0.047,0.058,0.037,0.032,0.059,0.037,0.0
1,▁ist,0.087,0.024,0.035,0.037,0.02,0.044,0.0,0.017,0.034,0.01,0.0
2,▁6,0.057,0.019,0.032,0.036,0.019,0.054,0.0,0.015,0.023,0.01,0.0
3,▁Jahre,0.052,0.022,0.046,0.047,0.022,0.05,0.0,0.028,0.023,0.013,0.0
4,▁alt. </Artikel> <Satz> Thomas,0.184,0.092,0.116,0.113,0.071,0.117,0.121,0.049,0.076,0.058,0.0
5,▁ist,0.056,0.019,0.033,0.036,0.016,0.036,0.0,0.009,0.023,0.01,0.0
6,▁7,0.064,0.02,0.049,0.039,0.022,0.076,0.0,0.012,0.032,0.012,0.0
7,▁Jahre,0.068,0.02,0.031,0.039,0.019,0.055,0.0,0.018,0.021,0.008,0.0
8,▁alt. </Satz> <Antwort> {,0.28,0.353,0.251,0.176,0.121,0.103,0.141,0.048,0.099,0.158,0.0
9,"▁▁▁▁""Antwort"":",Unnamed: 2_level_11,0.317,0.313,0.203,0.213,0.09,0.161,0.04,0.094,0.132,0.0
10,▁[,Unnamed: 2_level_12,Unnamed: 3_level_12,Unnamed: 4_level_12,0.184,0.098,0.069,0.116,0.024,0.095,0.108,0.0
11,▁▁▁▁▁▁▁▁{,Unnamed: 2_level_13,Unnamed: 3_level_13,Unnamed: 4_level_13,Unnamed: 5_level_13,0.104,0.082,0.134,0.029,0.099,0.062,0.0
12,"▁▁▁▁▁▁▁▁▁▁▁▁""Satzabschnitt"":",Unnamed: 2_level_14,Unnamed: 3_level_14,Unnamed: 4_level_14,Unnamed: 5_level_14,0.226,0.167,0.186,0.057,0.075,0.038,0.0
13,"▁""-"",",Unnamed: 2_level_15,Unnamed: 3_level_15,Unnamed: 4_level_15,Unnamed: 5_level_15,Unnamed: 6_level_15,Unnamed: 7_level_15,0.104,0.051,0.077,0.035,0.0
14,"▁▁▁▁▁▁▁▁▁▁▁▁""Label"":",Unnamed: 2_level_16,Unnamed: 3_level_16,Unnamed: 4_level_16,Unnamed: 5_level_16,Unnamed: 6_level_16,Unnamed: 7_level_16,Unnamed: 8_level_16,0.08,0.06,0.031,0.0
15,"▁""Faithful""",Unnamed: 2_level_17,Unnamed: 3_level_17,Unnamed: 4_level_17,Unnamed: 5_level_17,Unnamed: 6_level_17,Unnamed: 7_level_17,Unnamed: 8_level_17,0.492,0.11,0.076,0.0
16,▁▁▁▁▁▁▁▁},Unnamed: 2_level_18,Unnamed: 3_level_18,Unnamed: 4_level_18,Unnamed: 5_level_18,Unnamed: 6_level_18,Unnamed: 7_level_18,Unnamed: 8_level_18,Unnamed: 9_level_18,Unnamed: 10_level_18,0.075,0.0
17,▁▁▁▁] } </Antwort,Unnamed: 2_level_19,Unnamed: 3_level_19,Unnamed: 4_level_19,Unnamed: 5_level_19,Unnamed: 6_level_19,Unnamed: 7_level_19,Unnamed: 8_level_19,Unnamed: 9_level_19,Unnamed: 10_level_19,0.126,1.0
18,>,Unnamed: 2_level_20,Unnamed: 3_level_20,Unnamed: 4_level_20,Unnamed: 5_level_20,Unnamed: 6_level_20,Unnamed: 7_level_20,Unnamed: 8_level_20,Unnamed: 9_level_20,Unnamed: 10_level_20,Unnamed: 11_level_20,Unnamed: 12_level_20


In [34]:
# test = subw_sqa_agg.sequence_attributions[0]
# source_tokens = test.source
# attribution_start_pos = [i for i, target_token in enumerate(list(test.target)) if '"Satzabschnitt":' in target_token.token][0]
# attribution_end_pos = [i for i, target_token in enumerate(list(test.target)) if '"Label":' in target_token.token][0] + 2
# start_indices = [i for i, source_token in enumerate(list(test.source)) if "<Artikel>" in source_token.token]
# end_indices = [i for i, source_token in enumerate(test.source) if "</Satz>" in source_token.token]
# input_start_pos = start_indices[0] if start_indices else -1
# input_end_pos =end_indices[0] + 1 if end_indices else -1
# offset = len(test.target) - len(test.source)
# selected_target_attributions = test.target_attributions[input_start_pos:input_end_pos, attribution_start_pos - len(test.source) - offset:attribution_end_pos - len(test.source) - offset]
# attribution_column_labels = [ test.target[index].token for index in range(attribution_start_pos, attribution_end_pos)]
# input_row_labels = [ test.source[index].token for index in range(input_start_pos, input_end_pos) ]
# assert len(input_row_labels) == selected_target_attributions.shape[0]
# assert len(attribution_column_labels) == selected_target_attributions.shape[1]

In [None]:
# # Create DataFrame
# df = pd.DataFrame(selected_target_attributions.numpy(), index=input_row_labels, columns=attribution_column_labels)
# 
# # Create heatmap
# plt.figure(figsize=(20, 8))
# ax = sns.heatmap(df, annot=True, cmap="coolwarm",
#                  linewidths=.5, cbar=True)
# 
# # Rotate labels for better fit
# plt.yticks(rotation=0, fontsize=10)
# ax.xaxis.tick_top()
# ax.set_xticklabels(ax.get_xticklabels(), rotation=30)
# # Improve layout
# plt.tight_layout()
# 
# # Show plot
# plt.show()

In [None]:
# def is_sentence_ending(tok: str):
#     return tok.endswith(".") or tok.endswith(":")
# 
# 
# start_pos = out[0].attr_pos_start
# ends = [i + 1 for i, t in enumerate(out[0].target) if is_sentence_ending(t.token) and i < start_pos - 1] + [
#     start_pos - 1]
# starts = [0] + [i + 1 for i, t in enumerate(out[0].target) if is_sentence_ending(t.token) and i < start_pos - 1]
# spans = list(zip(starts, ends))
# res = out.aggregate("spans", target_spans=spans)
# res.show()