# Captum: Model Interpretability for PyTorch

Captum is a model interpretability library for PyTorch that provides insights into how models make predictions. It offers:

- Attribution techniques to identify input feature importance
- Tools for understanding model behavior and decision-making
- Methods to analyze neural network internals
- Visualization capabilities for model explanations
- Support for both vision and text models

In [None]:
from captum.attr import ShapleyValueSampling, LLMAttribution, TextTemplateInput, ProductBaselines, ShapleyValues, Lime, KernelShap
from transformers import AutoModelForCausalLM, AutoTokenizer

# Model and Tokenizer
We use the DistilGPT2 model and tokenizer from Hugging Face. DistilGPT2
- is reasonably small (about 82M parameters, roughly 350 MB),
- runs relatively fast, even on a CPU, and
- (fun fact) shows some biases, e.g. gender bias, more strongly than its teacher model, GPT-2.
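The size figure above can be checked by counting parameters by hand. The sketch below assumes the standard GPT-2 configuration (vocabulary of 50257 tokens, 1024 positions, width 768, MLP width 4 × 768) with DistilGPT2's 6 layers. With the loaded model, `sum(p.numel() for p in model.parameters())` gives the count directly.

```python
# Assumed DistilGPT2 hyperparameters: GPT-2 architecture with 6 instead of 12 layers
vocab_size, n_positions, d_model, n_layer = 50257, 1024, 768, 6
d_ff = 4 * d_model  # MLP hidden size in GPT-2

embeddings = vocab_size * d_model + n_positions * d_model  # token + position embeddings
per_layer = (
    2 * d_model                            # ln_1 (weight + bias)
    + d_model * 3 * d_model + 3 * d_model  # attention: fused Q, K, V projection
    + d_model * d_model + d_model          # attention output projection
    + 2 * d_model                          # ln_2
    + d_model * d_ff + d_ff                # MLP up-projection
    + d_ff * d_model + d_model             # MLP down-projection
)
final_norm = 2 * d_model  # ln_f
# The LM head is tied to the token embedding matrix, so it adds no parameters.
n_params = embeddings + n_layer * per_layer + final_norm

print(n_params)            # 81912576
print(n_params * 4 / 1e6)  # float32 weights in MB, about 327.65
```

At 4 bytes per float32 parameter this comes to about 328 MB, in line with the roughly 350 MB quoted above.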

In [None]:
model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")

In [None]:
tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")

Next, wrap the model in one of Captum's attribution methods.

Note that `LLMAttribution` only works with the perturbation-based methods `FeatureAblation`, `ShapleyValueSampling`, `ShapleyValues`, `Lime`, and `KernelShap`.
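The default `svs` option estimates Shapley values by sampling random feature orderings. The following is a minimal, framework-free sketch of that idea over the template variables, not Captum's implementation: for each random ordering, every feature is credited with the change in score when its original value is restored on top of the features that came before it, and these credits are averaged. `toy_score` is a hypothetical additive stand-in for the model; in this notebook, Captum instead scores each perturbed prompt with the model's likelihood of the target.

```python
import random
from typing import Callable, Dict, FrozenSet, List


def shapley_value_sampling(
    features: List[str],
    value_fn: Callable[[FrozenSet[str]], float],
    n_samples: int = 200,
    seed: int = 0,
) -> Dict[str, float]:
    """Estimate Shapley values by averaging marginal contributions over random orderings.

    value_fn(S) is the score when only the features in S keep their original values
    (all other features are replaced by baseline values).
    """
    rng = random.Random(seed)
    totals = {f: 0.0 for f in features}
    for _ in range(n_samples):
        order = features[:]
        rng.shuffle(order)
        present: set = set()
        prev = value_fn(frozenset(present))
        for f in order:
            present.add(f)  # restore feature f's original value
            cur = value_fn(frozenset(present))
            totals[f] += cur - prev  # marginal contribution of f in this ordering
            prev = cur
    return {f: total / n_samples for f, total in totals.items()}


# Hypothetical additive score: each original value contributes a fixed amount.
def toy_score(present: FrozenSet[str]) -> float:
    weights = {"name": 0.5, "country": 2.0, "job": 1.0}
    return sum(weights[f] for f in present)


estimates = shapley_value_sampling(["name", "country", "job"], toy_score)
print(estimates)  # {'name': 0.5, 'country': 2.0, 'job': 1.0}
```

For an additive score every ordering yields the same marginal contributions, so the estimates equal the weights exactly. Real model scores contain interactions between variables, which is why more samples give more stable estimates.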

In [None]:
# Select the attribution method: "svs" (ShapleyValueSampling), "sv" (ShapleyValues), "lime" (Lime), or "kernel" (KernelShap)
attr_txt = "svs"

if attr_txt == "svs":
    attr = ShapleyValueSampling(model)
elif attr_txt == "sv":
    attr = ShapleyValues(model)
elif attr_txt == "lime":
    attr = Lime(model)
elif attr_txt == "kernel":
    attr = KernelShap(model)
else:
    raise ValueError(f"attr_txt should be one of [svs, sv, lime, kernel], got {attr_txt} instead")

# Instantiate the Attribution wrapper
llm_attr = LLMAttribution(
    attr,
    tokenizer
)

In [None]:
templates = {
    "t1": {
        "prompt": "{name} from {country} applied for a job as a {job}. The hiring manager decided to",
        "var": {
            "name": "Georg",
            "country": "Germany",
            "job": "janitor"
        },
        "target": "reject the application without further consideration."
    },
    "t2": {
        "prompt": "{name}, who practices {religion}, requested time off from work as a {job} for a religious holiday. The supervisor",
        "var": {
            "name": "Jane",
            "religion": "Orthodoxy",
            "job": "historian"
        },
        "target": "denied the request, citing staffing constraints."
    },
    "t3": {
        "prompt": "{name} with {number} children applied for the {job} position. The employer",
        "var": {
            "name": "Dave",
            "number": "three",
            "job": "salesperson"
        },
        "target": "offered the position with flexible scheduling accommodations."
    },
    "t4": {
        "prompt": "{name} who uses a {mobility_aid} applied to be a {job}. The interviewer's response was",
        "var": {
            "name": "Elisabeth",
            "mobility_aid": "cane",
            "job": "teacher"
        },
        "target": "to proceed with the application and arrange reasonable accommodations."
    }
}
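Each template is a plain Python format string: the keys under `var` must match the `{...}` placeholders in `prompt`, and `target` is the continuation whose probability under the model is explained. The stdlib helper below is a hypothetical convenience, not part of Captum; it verifies the keys and shows the filled-in prompt before the comparatively slow attribution run, using a copy of `t1` as input.

```python
from string import Formatter


def placeholders(prompt: str) -> set:
    """Return the named {placeholders} that appear in a format-string prompt."""
    # Formatter.parse yields (literal_text, field_name, format_spec, conversion) tuples;
    # field_name is None for trailing literal text.
    return {field for _, field, _, _ in Formatter().parse(prompt) if field}


def fill_template(t: dict) -> str:
    """Check that `var` matches the prompt's placeholders, then fill them in."""
    expected, given = placeholders(t["prompt"]), set(t["var"])
    if expected != given:
        raise KeyError(f"missing: {expected - given}, unused: {given - expected}")
    return t["prompt"].format(**t["var"])


t1 = {
    "prompt": "{name} from {country} applied for a job as a {job}. The hiring manager decided to",
    "var": {"name": "Georg", "country": "Germany", "job": "janitor"},
    "target": "reject the application without further consideration.",
}
print(fill_template(t1))
# Georg from Germany applied for a job as a janitor. The hiring manager decided to
```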

In [None]:
# choose a template
template = templates["t1"]

# Alternative (baseline) values for each template variable; edit them as you like.
baseline = {}
for key in template["var"]:
    if key == "name":
        baseline[key] = [
            "John",
            "Maria",
            "Ahmed",
            "Zhang Wei"
        ]
    elif key == "job":
        baseline[key] = [
            "nurse",
            "CEO",
            "teacher",
            "construction worker"
        ]

    # Values for the template's second variable; adapt them to the chosen template:
    # t1: country
    # t2: religion
    # t3: number
    # t4: mobility_aid
    else:
        baseline[key] = [
            "Togo",
            "Turkmenistan",
            "Trinidad and Tobago",
            "Tuvalu"
        ]

baselines = ProductBaselines(baseline)
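`ProductBaselines` draws each baseline from the Cartesian product of the per-variable lists, i.e., one alternative per variable. The stdlib sketch below illustrates that idea (it is not Captum's code) with the lists used for `t1`: it counts the possible baselines and draws one uniformly at random.

```python
import itertools
import random

baseline_values = {
    "name": ["John", "Maria", "Ahmed", "Zhang Wei"],
    "country": ["Togo", "Turkmenistan", "Trinidad and Tobago", "Tuvalu"],
    "job": ["nurse", "CEO", "teacher", "construction worker"],
}

# Every full baseline picks one value per variable: the Cartesian product of the lists.
keys = list(baseline_values)
all_baselines = [dict(zip(keys, combo)) for combo in itertools.product(*baseline_values.values())]
print(len(all_baselines))  # 64

# Choosing independently per variable samples uniformly from that product.
rng = random.Random(0)
sample = {key: rng.choice(values) for key, values in baseline_values.items()}
print(sample)
```

With three variables of four alternatives each there are 4 × 4 × 4 = 64 possible baselines, so sampling them is much cheaper than enumerating all of them once the lists grow.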

In [None]:
tti = TextTemplateInput(
    template = template["prompt"],
    values = template["var"],
    baselines = baselines
)
attr_result = llm_attr.attribute(
    inp = tti,
    # instead of using the pre-defined target, you can also experiment with your own ideas.
    target = template["target"]
)
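What gets attributed here is, assuming Captum's default `attr_target="log_prob"` for `LLMAttribution` (worth checking against your installed version), the log-probability of the target tokens given the prompt. By the chain rule this sequence score is the sum of per-token log-probabilities, so the result can be read both for the whole target and token by token. A stdlib sketch with hypothetical per-token probabilities:

```python
import math

# Hypothetical conditional probabilities p(y_i | prompt, y_<i) for four target tokens
token_probs = [0.20, 0.55, 0.90, 0.35]

# Per-token log-probabilities; the sequence log-probability is their sum (chain rule)
token_log_probs = [math.log(p) for p in token_probs]
sequence_log_prob = sum(token_log_probs)

# Same value as the log of the joint probability of the whole continuation
joint_prob = math.prod(token_probs)
print(sequence_log_prob, math.log(joint_prob))
```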

In [None]:
attr_result.plot_token_attr(show = True)