# Token Attribution Analysis with Fair Forge

This notebook demonstrates how to use Fair Forge's explainability module to compute and visualize token attributions for language model responses.

Token attribution helps answer: **"Which parts of the input influenced the model's output the most?"**
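
To make the question concrete, here is a minimal leave-one-out sketch in plain Python. It does not use Fair Forge. `toy_score` and `KEY_WORDS` are hypothetical stand-ins for a real model score (for example, the log-probability of the target text), chosen only so the example runs without downloading a model.

```python
from typing import Callable, List


def leave_one_out(words: List[str], score_fn: Callable[[List[str]], float]) -> List[float]:
    """Attribute to each word the drop in score observed when it is removed."""
    baseline = score_fn(words)
    attributions = []
    for i in range(len(words)):
        reduced = words[:i] + words[i + 1:]
        attributions.append(baseline - score_fn(reduced))
    return attributions


# Hypothetical scorer: counts how many "key" words are present in the input.
KEY_WORDS = {"capital", "france"}


def toy_score(words: List[str]) -> float:
    return float(sum(w.lower().strip("?") in KEY_WORDS for w in words))


prompt = "What is the capital of France?".split()
scores = leave_one_out(prompt, toy_score)
for word, score in zip(prompt, scores):
    print(f"{word:10} {score:+.1f}")
```

Only `capital` and `France?` receive a non-zero score here, because they are the only words the toy scorer reacts to.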

## Prerequisites

Install Fair Forge with explainability support:
```bash
pip install "alquimia-fair-forge[explainability]"
```

In [4]:
# Install dependencies (uncomment if needed)
#%pip install alquimia-fair-forge[explainability] transformers torch -q
#!uv pip install --python {sys.executable} --force-reinstall "$(ls ../../../dist/*.whl)[explainability]" -q
#%pip install transformers torch -q


## 1. Setup: Load Model and Tokenizer

We'll use a small model for demonstration. The explainability module works with any Hugging Face causal language model.

In [6]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Disable autocast for compatibility with attribution methods
torch.set_autocast_enabled(False)

# Load a small model (Qwen3-0.6B for this example)
repo_id = "Qwen/Qwen3-0.6B"

tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    torch_dtype=torch.float16  # Important: use float16 for attribution methods
)

print(f"Loaded model: {repo_id}")


## 2. Import Fair Forge Explainability Module

In [None]:
from fair_forge.explainability import (
    AttributionExplainer,
    AttributionMethod,
    Granularity,
    compute_attributions,
)

# Create an explainer instance
explainer = AttributionExplainer(
    model=model,
    tokenizer=tokenizer,
    default_method=AttributionMethod.LIME,  # LIME is a good default
    default_granularity=Granularity.WORD,   # Word-level is most interpretable
    verbose=True,
)

print("Explainer created!")

## 3. Compute Attributions for a Single Response

Let's compute attributions for a simple Q&A interaction.

In [None]:
# Define the conversation
messages = [
    {
        "role": "system",
        "content": "Answer concisely with 3 bullet points."
    },
    {
        "role": "user",
        "content": "Explain Albert Einstein's theory of relativity"
    }
]

# The model's actual response (target text to explain)
target = """
- Einstein proposed that time and space are not absolute but relative, influencing how we perceive the universe.
- He introduced the concept of spacetime curvature, which explains gravity as a distortion of spacetime.
- His equations, like the general and special theories, have revolutionized our understanding of physics.
"""

# Compute attributions
result = explainer.explain(
    messages=messages,
    target=target,
    method=AttributionMethod.LIME,
)

print(f"Computed {len(result.attributions)} attributions")
print(f"Method: {result.method.value}")
print(f"Granularity: {result.granularity.value}")

## 4. Analyze Top Contributing Tokens/Words

In [None]:
# Get top 10 most important tokens
top_10 = result.get_top_k(10)

print("Top 10 Most Important Words/Tokens:")
print("-" * 40)
for i, attr in enumerate(top_10, 1):
    print(f"{i:2}. '{attr.text:15}' | score: {attr.score:+.4f} | normalized: {attr.normalized_score:.4f}")
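
The ranking step is easy to reason about in isolation. The sketch below ranks by absolute score and rescales by the largest magnitude. Both choices are assumptions made for illustration; `get_top_k()` and `normalized_score` in Fair Forge may be defined differently. `ToyAttribution` is a hypothetical stand-in for the library's attribution objects.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyAttribution:
    text: str
    score: float


def top_k_by_magnitude(attrs: List[ToyAttribution], k: int) -> List[ToyAttribution]:
    """Return the k items with the largest absolute score, largest first."""
    return sorted(attrs, key=lambda a: abs(a.score), reverse=True)[:k]


def normalize_by_max_abs(attrs: List[ToyAttribution]) -> List[float]:
    """Scale magnitudes into [0, 1]; all zeros if every score is zero."""
    max_abs = max((abs(a.score) for a in attrs), default=0.0)
    if max_abs == 0.0:
        return [0.0 for _ in attrs]
    return [abs(a.score) / max_abs for a in attrs]


# Made-up scores, not real model output.
toy = [
    ToyAttribution("relativity", 2.0),
    ToyAttribution("the", -0.5),
    ToyAttribution("Einstein", 1.0),
    ToyAttribution("concisely", -1.5),
]
top_two = [a.text for a in top_k_by_magnitude(toy, 2)]
normalized = normalize_by_max_abs(toy)
print(top_two, normalized)
```

Ranking by magnitude keeps strongly negative contributions visible. Ranking by signed score would instead surface only the words that push the output up.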

## 5. Visualize Attributions

The explainer can display an interactive visualization of the attributions.

In [None]:
# Display the attribution visualization
explainer.visualize(result)

## 6. Compare Different Attribution Methods

Fair Forge supports multiple attribution methods. Let's compare a few.

In [None]:
# Define a simpler example for faster computation
simple_messages = [{"role": "user", "content": "What is the capital of France?"}]
simple_target = "The capital of France is Paris."

# Compare different methods
methods_to_compare = [
    AttributionMethod.LIME,
    AttributionMethod.OCCLUSION,
]

for method in methods_to_compare:
    print(f"\n{'='*50}")
    print(f"Method: {method.value.upper()}")
    print(f"{'='*50}")
    
    result = explainer.explain(
        messages=simple_messages,
        target=simple_target,
        method=method,
    )
    
    print("\nTop 5 contributing words:")
    for attr in result.get_top_k(5):
        print(f"  '{attr.text}': {attr.score:+.4f}")
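
Printed lists are hard to compare by eye, so one option is to summarize agreement between two methods with a rank correlation. The sketch below is plain Python and independent of Fair Forge. The two score dictionaries are made-up numbers, not real LIME or Occlusion output, and the simple Spearman formula used here assumes no tied scores.

```python
from typing import Dict


def ranks(scores: Dict[str, float]) -> Dict[str, int]:
    """Rank words by descending absolute score (1 = most important)."""
    ordered = sorted(scores, key=lambda w: abs(scores[w]), reverse=True)
    return {word: position for position, word in enumerate(ordered, start=1)}


def spearman(a: Dict[str, float], b: Dict[str, float]) -> float:
    """Spearman rank correlation; assumes identical keys and no ties."""
    if a.keys() != b.keys():
        raise ValueError("both methods must score the same words")
    n = len(a)
    if n < 2:
        raise ValueError("need at least two words to compare rankings")
    ranks_a, ranks_b = ranks(a), ranks(b)
    d_squared = sum((ranks_a[w] - ranks_b[w]) ** 2 for w in a)
    return 1.0 - 6.0 * d_squared / (n * (n * n - 1))


# Hypothetical scores for illustration only.
method_a = {"What": 0.05, "capital": 0.60, "of": 0.02, "France": 0.90}
method_b = {"What": 0.10, "capital": 0.70, "of": 0.01, "France": 0.40}

rho = spearman(method_a, method_b)
print(f"Rank agreement: {rho:.2f}")
```

A value near 1 means the two methods order the words similarly. A low or negative value is a signal to inspect the individual attributions before trusting either ranking.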

## 7. Batch Processing

Process multiple Q&A pairs at once.

In [None]:
# Define multiple Q&A pairs
qa_pairs = [
    (
        [{"role": "user", "content": "What is machine learning?"}],
        "Machine learning is a subset of AI that enables systems to learn from data."
    ),
    (
        [{"role": "user", "content": "What is deep learning?"}],
        "Deep learning uses neural networks with many layers to process complex patterns."
    ),
]

# Compute attributions for all pairs
batch_results = explainer.explain_batch(qa_pairs)

print(f"Processed {len(batch_results)} items")
print(f"Total compute time: {batch_results.total_compute_time_seconds:.2f}s")

# Show top words for each
for i, result in enumerate(batch_results):
    print(f"\nQ&A #{i+1}:")
    print(f"  Top 3 words: {[attr.text for attr in result.get_top_k(3)]}")
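
Once several items have been scored, it can help to aggregate across them. The sketch below averages the absolute score of each word over a batch. It works on plain `(word, score)` tuples rather than Fair Forge result objects, and the numbers are made up for illustration.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def mean_abs_score_per_word(batch: List[List[Tuple[str, float]]]) -> Dict[str, float]:
    """Average |score| for each lowercased word across all items in the batch."""
    totals: Dict[str, float] = defaultdict(float)
    counts: Dict[str, int] = defaultdict(int)
    for item in batch:
        for word, score in item:
            key = word.lower()
            totals[key] += abs(score)
            counts[key] += 1
    return {word: totals[word] / counts[word] for word in totals}


# Hypothetical per-item attributions, one inner list per Q&A pair.
toy_batch = [
    [("machine", 0.5), ("learning", 0.25)],
    [("deep", 0.75), ("Learning", -0.75)],
]

summary = mean_abs_score_per_word(toy_batch)
for word, value in sorted(summary.items(), key=lambda kv: kv[1], reverse=True):
    print(f"{word:10} {value:.2f}")
```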

## 8. Using the Convenience Function

For one-off attributions, use the `compute_attributions` function.

In [None]:
# Quick one-liner for attribution computation
quick_result = compute_attributions(
    model=model,
    tokenizer=tokenizer,
    messages=[{"role": "user", "content": "Why is the sky blue?"}],
    target="The sky appears blue due to Rayleigh scattering of sunlight.",
    method=AttributionMethod.LIME,
    granularity=Granularity.WORD,
)

print("Top 5 contributing words:")
for attr in quick_result.get_top_k(5):
    print(f"  '{attr.text}': {attr.score:+.4f}")

## 9. Export Results for Further Analysis

In [None]:
# Export to dictionary format
viz_data = quick_result.to_dict_for_visualization()
print("Visualization data structure:")
print(f"  tokens: {viz_data['tokens'][:5]}...")
print(f"  scores: {viz_data['scores'][:5]}...")

# Export full result as JSON-compatible dict
result_dict = quick_result.model_dump()
print(f"\nFull result keys: {list(result_dict.keys())}")
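
To persist the exported data, the dictionary can be written with the standard `json` module. The sketch below uses a hypothetical dictionary shaped like the `tokens` and `scores` keys printed above, so it runs without a model. If the result object is a Pydantic v2 model, as the `model_dump()` call suggests, `model_dump(mode="json")` should convert enums and similar values into JSON-safe types first; that is an assumption about the result class, not something this notebook confirms.

```python
import json
import tempfile
from pathlib import Path

# Hypothetical data with the same two keys as the visualization export.
viz_like = {
    "tokens": ["Why", "is", "the", "sky", "blue?"],
    "scores": [0.1, 0.0, 0.0, 0.6, 0.8],
}

# Guard against misaligned exports before writing anything.
if len(viz_like["tokens"]) != len(viz_like["scores"]):
    raise ValueError("tokens and scores must have the same length")

out_path = Path(tempfile.mkdtemp()) / "attributions.json"
out_path.write_text(json.dumps(viz_like, indent=2), encoding="utf-8")

# Read the file back to confirm the round trip is lossless.
restored = json.loads(out_path.read_text(encoding="utf-8"))
print(f"Wrote {len(restored['tokens'])} tokens to {out_path.name}")
```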

## Summary

This notebook demonstrated:

1. **Setup**: Loading a model and creating an `AttributionExplainer`
2. **Single attributions**: Using `explainer.explain()` for individual Q&A pairs
3. **Top-K analysis**: Finding the most important tokens with `get_top_k()`
4. **Visualization**: Interactive display with `explainer.visualize()`
5. **Method comparison**: Comparing LIME and Occlusion on the same prompt
6. **Batch processing**: Processing multiple items with `explain_batch()`
7. **Convenience function**: Quick one-off attribution with `compute_attributions()`
8. **Export**: Getting data for further analysis

### Available Attribution Methods

**Gradient-based** (typically faster; require access to model gradients):
- `SALIENCY`, `INTEGRATED_GRADIENTS`, `GRADIENT_SHAP`, `SMOOTH_GRAD`, `SQUARE_GRAD`, `VAR_GRAD`, `INPUT_X_GRADIENT`

**Perturbation-based** (model-agnostic; need only forward passes, but many of them):
- `LIME`, `KERNEL_SHAP`, `OCCLUSION`, `SOBOL`
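
As background for the perturbation-based family: Kernel SHAP estimates Shapley values, which average a feature's marginal contribution over every order in which features can be added. The sketch below computes exact Shapley values by enumerating permutations, which is only feasible for a handful of features. It is not Kernel SHAP and not Fair Forge's implementation; `toy_value` is a hypothetical scorer that pays off only when two words appear together.

```python
from itertools import permutations
from typing import Callable, FrozenSet, List


def exact_shapley(n_features: int, value_fn: Callable[[FrozenSet[int]], float]) -> List[float]:
    """Average each feature's marginal contribution over all feature orderings."""
    totals = [0.0] * n_features
    n_orderings = 0
    for order in permutations(range(n_features)):
        present = set()
        previous = value_fn(frozenset())
        for idx in order:
            present.add(idx)
            current = value_fn(frozenset(present))
            totals[idx] += current - previous
            previous = current
        n_orderings += 1
    return [total / n_orderings for total in totals]


words = ["capital", "of", "France"]


def toy_value(present: FrozenSet[int]) -> float:
    # Hypothetical scorer: 1.0 only if both "capital" (0) and "France" (2) are present.
    return 1.0 if {0, 2} <= present else 0.0


shapley = exact_shapley(len(words), toy_value)
for word, value in zip(words, shapley):
    print(f"{word:8} {value:+.2f}")
```

The two interacting words split the credit evenly, and the values sum to the full score of 1.0. Removing either word on its own would drop the score by the full 1.0, which is why single-word removal can overstate interacting words.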

### Granularity Options
- `TOKEN`: Individual tokens (finest granularity)
- `WORD`: Word-level (recommended for interpretability)
- `SENTENCE`: Sentence-level (coarsest granularity)
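
To illustrate how the granularities relate, the sketch below merges token-level scores into word-level scores by summing the sub-tokens of each word. Summation and the `Ġ` word-start marker (used by byte-level BPE tokenizers) are assumptions for illustration; Fair Forge may aggregate differently. The tokens and scores are made up.

```python
from typing import List, Tuple


def merge_tokens_to_words(
    tokens: List[str], scores: List[float], space_marker: str = "Ġ"
) -> List[Tuple[str, float]]:
    """Sum sub-token scores into words; a marker prefix starts a new word."""
    if len(tokens) != len(scores):
        raise ValueError("tokens and scores must have the same length")
    words: List[Tuple[str, float]] = []
    for token, score in zip(tokens, scores):
        starts_word = token.startswith(space_marker)
        text = token[len(space_marker):] if starts_word else token
        if starts_word or not words:
            words.append((text, score))
        else:
            # Continuation piece: extend the current word and add its score.
            previous_text, previous_score = words[-1]
            words[-1] = (previous_text + text, previous_score + score)
    return words


# Hypothetical sub-word split and scores.
tokens = ["Ray", "leigh", "Ġscat", "tering"]
token_scores = [0.25, 0.25, 0.5, 0.125]

word_scores = merge_tokens_to_words(tokens, token_scores)
print(word_scores)
```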