In [1]:
%load_ext autoreload
%autoreload 2

In [1]:
import sys

import matplotlib.colors as mcolors
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

sys.path.append("../..")
from interpreto.attributions.base import AttributionOutput
from interpreto.attributions.methods import OcclusionExplainer
from interpreto.commons.granularity import GranularityLevel
from interpreto.visualizations.attributions.classification_highlight import GenerationAttributionVisualization

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

explainer = OcclusionExplainer(
    model=model, batch_size=4, tokenizer=tokenizer, granularity_level=GranularityLevel.ALL_TOKENS
)

attribution_outputs = explainer.explain(
    model_inputs="Hi there, how are you?", mode="softmax", generation_kwargs={"max_length": 10}
)

# mode in {"logits", "softmax", "log_softmax"}

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


In [5]:
attribution_outputs

[AttributionOutput(attributions=tensor([[-4.7965e-02,  3.3612e-03,  1.0128e-01,  1.0324e-01,  1.6909e-02,
           5.6971e-03,  2.3930e-01,  0.0000e+00, -8.9407e-08,  0.0000e+00,
          -8.9407e-08,  0.0000e+00, -8.9407e-08,  0.0000e+00, -8.9407e-08,
           0.0000e+00, -8.9407e-08,  0.0000e+00, -8.9407e-08, -5.3942e-06],
         [-3.1621e-03, -1.2726e-04, -2.3127e-05,  1.6425e-03, -6.6048e-04,
           7.0919e-03, -3.6035e-03,  9.6907e-01, -1.1921e-07,  0.0000e+00,
          -1.1921e-07,  0.0000e+00, -1.1921e-07,  0.0000e+00, -1.1921e-07,
           0.0000e+00, -1.1921e-07,  0.0000e+00, -1.1921e-07, -8.3447e-07],
         [ 1.3944e-02,  8.4414e-03,  1.6453e-02,  2.0449e-02,  2.2180e-02,
           2.3977e-02,  3.8142e-02,  7.0758e-02,  6.7674e-02,  0.0000e+00,
          -7.4506e-09,  0.0000e+00, -7.4506e-09,  0.0000e+00, -7.4506e-09,
           0.0000e+00, -7.4506e-09,  0.0000e+00, -7.4506e-09, -1.1474e-06],
         [ 1.6756e-01, -4.8457e-02,  9.2429e-04,  8.9635e-02,  1.8

In [4]:
attribution_outputs[0].attributions.shape, len(attribution_outputs[0].elements)

(torch.Size([13, 20]), 20)

In [10]:
test

NameError: name 'test' is not defined

In [6]:
viz = GenerationAttributionVisualization(
    attribution_output=attribution_outputs[0],
    color=mcolors.to_rgb("orange"),
    highlight_border=False,
    normalize=False,
    css=".common-word-style {font-size: 1.5em}",
)
viz.display()

AssertionError: The attribution shape (13) does not match the number of elements (20)

In [3]:
inputs_sentence = ["This", " ", "is an", " example of ", "a", " ", "sentence"]
outputs_sentence = [
    "A",
    "B",
    "C",
    " ",
    "look",
    " ",
    "behind",
    " ",
    "you",
    " : ",
    "there",
    " ",
    "is",
    " ",
    "a",
    " ",
    "monster",
    " ",
    "under",
    " ",
    "the",
    " ",
    "bed",
    "!!!",
]
nb_concepts = 1


def make_attributions_outputs(inputs, outputs):
    attributions = torch.rand(len(inputs) + len(outputs), len(outputs))  # (l, l_g)
    return AttributionOutput(elements=inputs + outputs, attributions=attributions)


generation_output = make_attributions_outputs(inputs_sentence, outputs_sentence)
print(f"{generation_output.attributions.shape = }")

generation_output.attributions[0, 0] = (
    10.0  # normalization check: when selecting output 0, the intput 0 should be the most important
)

generation_output.attributions.shape = torch.Size([31, 24])


In [4]:
viz = GenerationAttributionVisualization(
    attribution_output=generation_output,
    color=mcolors.to_rgb("orange"),
    highlight_border=False,
    normalize=False,
    css=".common-word-style {font-size: 1.5em}",
)
viz.display()

In [None]:
# viz.save("attributions_generation.html")