# Token Attribution Analysis with Fair Forge

This notebook demonstrates how to use Fair Forge's explainability module to compute and visualize token attributions for language model responses.

Token attribution helps answer: **"Which parts of the input influenced the model's output the most?"**
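
One way to make this question concrete is leave-one-out occlusion: remove one input word at a time and measure how much a score of the output drops. The sketch below is a minimal illustration, assuming a hypothetical toy scoring function `toy_score` (a weighted word count) in place of a real model so that it runs without any dependencies. It is not how Fair Forge computes attributions internally.

```python
def toy_score(words):
    """Hypothetical stand-in for a model: fixed per-word weights, summed."""
    weights = {"Einstein's": 2.0, "theory": 1.0, "relativity": 3.0}
    return sum(weights.get(w, 0.0) for w in words)


def occlusion_attributions(words, score_fn):
    """Attribution of word i = score(all words) - score(all words except i)."""
    base = score_fn(words)
    return [
        (w, base - score_fn(words[:i] + words[i + 1:]))
        for i, w in enumerate(words)
    ]


prompt_words = "Explain Albert Einstein's theory of relativity".split()
for word, score in occlusion_attributions(prompt_words, toy_score):
    print(f"{word:12} {score:+.1f}")
```

Words whose removal lowers the score the most receive the largest attributions; real methods replace `toy_score` with a quantity computed from the model.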

## Prerequisites

Install Fair Forge with explainability support:
```bash
pip install alquimia-fair-forge[explainability]
```

In [None]:
# Install dependencies (uncomment if needed)
# %pip install alquimia-fair-forge[explainability] transformers torch -q

## 1. Setup: Load Model and Tokenizer

We'll use a small model for demonstration. The explainability module works with any HuggingFace causal language model.

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Disable autocast for compatibility with attribution methods
torch.set_autocast_enabled(False)

# Load a small model (Qwen3-0.6B for this example)
repo_id = "Qwen/Qwen3-0.6B"

tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    torch_dtype=torch.float16  # Important: use float16 for attribution methods
)

print(f"Loaded model: {repo_id}")

## 2. Import Fair Forge Explainability Module

The module uses class-based attribution methods for extensibility. You pass method classes (like `Lime`, `Saliency`) directly to the explainer.
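
Passing a class rather than an instance lets the explainer construct the method itself, using the model and settings it already holds. The sketch below illustrates that general pattern with hypothetical names (`AttributionMethod`, `UniformBaseline`, `ToyExplainer`); it is an assumption-labeled illustration, not Fair Forge's internal implementation.

```python
from abc import ABC, abstractmethod


class AttributionMethod(ABC):
    """Hypothetical base class: each subclass implements one attribution technique."""

    def __init__(self, score_fn):
        self.score_fn = score_fn  # callable: list of words -> float

    @abstractmethod
    def attribute(self, words):
        """Return one score per input word."""


class UniformBaseline(AttributionMethod):
    """Hypothetical sanity-check baseline: spreads the full score evenly across words."""

    def attribute(self, words):
        total = self.score_fn(words)
        return [total / len(words)] * len(words)


class ToyExplainer:
    """Hypothetical explainer that instantiates whichever method class it receives."""

    def __init__(self, score_fn, default_method):
        self.score_fn = score_fn
        self.default_method = default_method

    def explain(self, words, method=None):
        method_cls = method or self.default_method
        # The explainer, not the caller, builds the method instance
        return method_cls(self.score_fn).attribute(words)


demo_explainer = ToyExplainer(score_fn=lambda ws: 6.0, default_method=UniformBaseline)
print(demo_explainer.explain(["a", "b", "c"]))
```

With this design, adding a new technique means writing one subclass and passing that class to `explain`.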

In [None]:
from fair_forge.explainability import (
    AttributionExplainer,
    Granularity,
    # Attribution method classes
    Lime,
    Occlusion,
    Saliency,
    IntegratedGradients,
    # Convenience function (used in section 9)
    compute_attributions,
)

# Create an explainer instance
explainer = AttributionExplainer(
    model=model,
    tokenizer=tokenizer,
    default_method=Lime,            # Pass the class directly
    default_granularity=Granularity.WORD,
    verbose=True,
)

print("Explainer created!")

## 3. Format Your Prompt

**Important:** You are responsible for formatting prompts according to your model's requirements. This keeps the explainability module focused on attribution computation and avoids coupling with specific LLM prompt formats.
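
To see roughly what `apply_chat_template` produces, here is a hypothetical hand-written formatter for the ChatML-style layout (`<|im_start|>role ... <|im_end|>`) used by Qwen chat models. It is a simplified sketch: the real Qwen3 template can add extra content (for example, related to thinking mode), so use the tokenizer's own template in practice.

```python
def format_chatml(messages, add_generation_prompt=True):
    """Hypothetical, simplified ChatML formatter (illustration only)."""
    parts = [f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages]
    if add_generation_prompt:
        # Open an assistant turn so the model continues with its answer
        parts.append("<|im_start|>assistant\n")
    return "".join(parts)


demo_prompt = format_chatml([{"role": "user", "content": "Hi"}])
print(demo_prompt)
```

Because template markers like these are part of the prompt string, they may also receive attribution scores.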

In [None]:
# Define the conversation
messages = [
    {
        "role": "system",
        "content": "Answer concisely with 3 bullet points."
    },
    {
        "role": "user",
        "content": "Explain Albert Einstein's theory of relativity"
    }
]

# Format the prompt using your tokenizer (model-specific)
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print("Formatted prompt:")
print(prompt[:500] + "..." if len(prompt) > 500 else prompt)

## 4. Compute Attributions

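Pass the formatted prompt together with the target response you want to explain. The explainer scores how much each part of the prompt contributed to the model producing that target.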
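
Attribution for a fixed target is commonly framed in terms of how likely the model finds that target given the prompt. With teacher forcing, the target's log-probability is the sum of the log-probabilities of each target token, conditioned on the prompt and the preceding target tokens. The sketch below computes that quantity with a hypothetical toy next-token distribution `toy_next_token_probs`; with a real model the probabilities come from a softmax over the logits. Whether Fair Forge scores targets exactly this way is an assumption.

```python
import math


def toy_next_token_probs(context):
    """Hypothetical model: probability for each candidate next token."""
    if context and context[-1] == "is":
        return {"Paris": 0.9, "London": 0.1}
    return {"The": 0.25, "capital": 0.25, "is": 0.25, "Paris": 0.25}


def target_log_prob(prompt_tokens, target_tokens, next_token_probs):
    """Sum of log p(target_t | prompt, target_<t) under teacher forcing."""
    total = 0.0
    context = list(prompt_tokens)
    for token in target_tokens:
        total += math.log(next_token_probs(context)[token])
        context.append(token)  # feed the true target token, not a sampled one
    return total


lp = target_log_prob(["France", "?"], ["The", "capital", "is", "Paris"], toy_next_token_probs)
print(f"log p(target | prompt) = {lp:.4f}")
```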

In [None]:
# The model's actual response (target text to explain)
target = """
- Einstein proposed that time and space are not absolute but relative.
- He introduced spacetime curvature, explaining gravity as distortion.
- His equations revolutionized physics and led to E=mc^2.
"""

# Compute attributions
result = explainer.explain(
    prompt=prompt,
    target=target,
    method=Lime,  # Pass the class directly
)

print(f"Computed {len(result.attributions)} attributions")
print(f"Method: {result.method.value}")
print(f"Granularity: {result.granularity.value}")

## 5. Analyze Top Contributing Tokens/Words
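
Attribution scores are typically signed: a positive score means a unit pushed the model toward the target, a negative one pushed it away. A common way to rank importance is therefore by absolute score, and a common normalization divides by the largest absolute score so values fall in [-1, 1]. The sketch below implements both on plain `(text, score)` pairs; whether `get_top_k` and `normalized_score` follow exactly these conventions is an assumption.

```python
def normalize(pairs):
    """Max-abs normalization: divide every score by the largest |score|."""
    max_abs = max((abs(s) for _, s in pairs), default=0.0)
    if max_abs == 0.0:
        return [(t, 0.0) for t, _ in pairs]
    return [(t, s / max_abs) for t, s in pairs]


def top_k(pairs, k):
    """Return the k pairs with the largest absolute score."""
    return sorted(pairs, key=lambda p: abs(p[1]), reverse=True)[:k]


# Made-up scores for illustration only
demo_pairs = [("Explain", 0.1), ("Einstein's", 0.8), ("theory", -0.4), ("relativity", 1.6)]
for text, norm in normalize(top_k(demo_pairs, 3)):
    print(f"{text:12} {norm:+.2f}")
```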

In [None]:
# Get top 10 most important tokens
top_10 = result.get_top_k(10)

print("Top 10 Most Important Words/Tokens:")
print("-" * 40)
for i, attr in enumerate(top_10, 1):
    norm = f"{attr.normalized_score:.4f}" if attr.normalized_score is not None else "N/A"
    print(f"{i:2}. '{attr.text:15}' | score: {attr.score:+.4f} | normalized: {norm}")

## 6. Visualize Attributions

The explainer can display an interactive visualization of the attributions.
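
Outside a notebook (for example, in a terminal or a log file), a plain-text rendering can stand in for the interactive view. The hypothetical helper below draws one bar per unit, using `+` for positive and `-` for negative scores. It works on plain lists, such as the `tokens` and `scores` fields returned by `to_dict_for_visualization()` in section 10.

```python
def render_text_bars(tokens, scores, width=20):
    """Hypothetical plain-text attribution view: bar length is proportional to |score|."""
    max_abs = max((abs(s) for s in scores), default=0.0) or 1.0  # avoid dividing by zero
    lines = []
    for token, score in zip(tokens, scores):
        n = round(width * abs(score) / max_abs)
        bar = ("+" if score >= 0 else "-") * n
        lines.append(f"{token:>12} | {bar}")
    return "\n".join(lines)


print(render_text_bars(["capital", "France", "?"], [0.5, 1.0, -0.25], width=8))
```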

In [None]:
# Display the attribution visualization
explainer.visualize(result)

## 7. Compare Different Attribution Methods

Fair Forge supports multiple attribution methods as classes. Let's compare two perturbation-based methods, `Lime` and `Occlusion`.
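
Different methods rarely produce identical scores, so one simple way to compare them is to measure how much their top-k sets overlap. The sketch below computes the Jaccard overlap of two rankings given as `(text, score)` pairs. The example scores are made up for illustration and are not output from `Lime` or `Occlusion`.

```python
def top_k_texts(pairs, k):
    """Texts of the k pairs with the largest absolute score (repeated words collapse)."""
    return {t for t, _ in sorted(pairs, key=lambda p: abs(p[1]), reverse=True)[:k]}


def top_k_jaccard(pairs_a, pairs_b, k):
    """Size of the intersection divided by size of the union of two top-k sets."""
    a, b = top_k_texts(pairs_a, k), top_k_texts(pairs_b, k)
    union = a | b
    return len(a & b) / len(union) if union else 1.0


# Made-up scores for illustration only
method_a = [("capital", 0.9), ("France", 0.8), ("What", 0.1), ("?", 0.05)]
method_b = [("France", 1.2), ("?", 0.4), ("capital", 0.3), ("What", -0.2)]
print(top_k_jaccard(method_a, method_b, k=2))
```

A value of 1.0 means both methods agree on the top-k set; values near 0 suggest the explanation depends strongly on the method chosen.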

In [None]:
# Define a simpler example for faster computation
simple_messages = [{"role": "user", "content": "What is the capital of France?"}]
simple_prompt = tokenizer.apply_chat_template(simple_messages, tokenize=False, add_generation_prompt=True)
simple_target = "The capital of France is Paris."

# Compare different methods (pass the class directly)
methods_to_compare = [Lime, Occlusion]

for method_class in methods_to_compare:
    print(f"\n{'='*50}")
    # __name__ is defined for every Python class
    print(f"Method: {method_class.__name__.upper()}")
    print(f"{'='*50}")

    result = explainer.explain(
        prompt=simple_prompt,
        target=simple_target,
        method=method_class,
    )

    print("\nTop 5 contributing words:")
    for attr in result.get_top_k(5):
        print(f"  '{attr.text}': {attr.score:+.4f}")

## 8. Batch Processing

Process multiple prompt/target pairs at once.
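
Conceptually, batch processing applies the single-item explanation to each pair and tracks the total time. The sketch below shows that loop with a hypothetical `explain_fn` callable and a hypothetical `BatchResult` container that mirrors the `total_compute_time_seconds` field used in the next cell. Fair Forge's actual `explain_batch` may batch or parallelize differently.

```python
import time
from dataclasses import dataclass, field


@dataclass
class BatchResult:
    """Hypothetical container: per-item results plus total wall-clock time."""
    results: list = field(default_factory=list)
    total_compute_time_seconds: float = 0.0

    def __len__(self):
        return len(self.results)

    def __iter__(self):
        return iter(self.results)


def explain_batch_sketch(pairs, explain_fn):
    """Run explain_fn(prompt, target) on each pair and time the whole batch."""
    batch = BatchResult()
    start = time.perf_counter()
    for prompt, target in pairs:
        batch.results.append(explain_fn(prompt, target))
    batch.total_compute_time_seconds = time.perf_counter() - start
    return batch


demo_batch = explain_batch_sketch(
    [("Q1", "A one"), ("Q2", "A two words")],
    explain_fn=lambda p, t: len(t.split()),  # stand-in for a real explainer call
)
print(len(demo_batch), list(demo_batch), f"{demo_batch.total_compute_time_seconds:.6f}s")
```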

In [None]:
# Define multiple prompt/target pairs (prompts should be pre-formatted)
qa_pairs = [
    (
        tokenizer.apply_chat_template(
            [{"role": "user", "content": "What is machine learning?"}],
            tokenize=False, add_generation_prompt=True
        ),
        "Machine learning is a subset of AI that enables systems to learn from data."
    ),
    (
        tokenizer.apply_chat_template(
            [{"role": "user", "content": "What is deep learning?"}],
            tokenize=False, add_generation_prompt=True
        ),
        "Deep learning uses neural networks with many layers to process complex patterns."
    ),
]

# Compute attributions for all pairs
batch_results = explainer.explain_batch(qa_pairs)

print(f"Processed {len(batch_results)} items")
print(f"Total compute time: {batch_results.total_compute_time_seconds:.2f}s")

# Show top words for each
for i, result in enumerate(batch_results):
    print(f"\nQ&A #{i+1}:")
    print(f"  Top 3 words: {[attr.text for attr in result.get_top_k(3)]}")

## 9. Using the Convenience Function

For one-off attributions, `compute_attributions` takes the model and tokenizer directly, so you don't need to create an `AttributionExplainer` first.

In [None]:
# Format prompt
prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Why is the sky blue?"}],
    tokenize=False, add_generation_prompt=True
)

# Single call: pass the model and tokenizer directly, no explainer instance needed
quick_result = compute_attributions(
    model=model,
    tokenizer=tokenizer,
    prompt=prompt,
    target="The sky appears blue due to Rayleigh scattering of sunlight.",
    method=Lime,
    granularity=Granularity.WORD,
)

print("Top 5 contributing words:")
for attr in quick_result.get_top_k(5):
    print(f"  '{attr.text}': {attr.score:+.4f}")

## 10. Export Results for Further Analysis
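
Once scores are in plain Python structures, the standard library is enough to persist them. The sketch below serializes a dictionary shaped like the `tokens`/`scores` fields of `to_dict_for_visualization()` to JSON and CSV strings. The sample values are placeholders, and no other keys of that dictionary are assumed.

```python
import csv
import io
import json

# Placeholder data shaped like the 'tokens' and 'scores' fields
viz_like = {"tokens": ["Why", "is", "the", "sky", "blue"], "scores": [0.1, 0.05, 0.0, 0.7, 0.9]}

# JSON: one self-describing document
json_text = json.dumps(viz_like, indent=2)

# CSV: one row per token, convenient for spreadsheets or pandas
buffer = io.StringIO()
writer = csv.writer(buffer)
writer.writerow(["token", "score"])
writer.writerows(zip(viz_like["tokens"], viz_like["scores"]))
csv_text = buffer.getvalue()

print(csv_text)
```

Writing to files instead is a matter of replacing the string buffer with `open(path, "w", newline="")`.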

In [None]:
# Export to dictionary format
viz_data = quick_result.to_dict_for_visualization()
print("Visualization data structure:")
print(f"  tokens: {viz_data['tokens'][:5]}...")
print(f"  scores: {viz_data['scores'][:5]}...")

# Export the full result as a JSON-compatible dict (Pydantic v2 `model_dump`)
result_dict = quick_result.model_dump(mode="json")
print(f"\nFull result keys: {list(result_dict.keys())}")

## Summary

This notebook demonstrated:

1. **Setup**: Loading a model and creating an `AttributionExplainer`
2. **Prompt formatting**: Using `tokenizer.apply_chat_template()` (user responsibility)
3. **Single attributions**: Using `explainer.explain()` with pre-formatted prompts
4. **Top-K analysis**: Finding the most important tokens with `get_top_k()`
5. **Visualization**: Interactive display with `explainer.visualize()`
6. **Method comparison**: Comparing class-based methods such as `Lime` and `Occlusion`
7. **Batch processing**: Processing multiple items with `explain_batch()`
8. **Convenience function**: Quick one-off attribution with `compute_attributions()`
9. **Export**: Getting data for further analysis

### Available Attribution Method Classes

**Gradient-based** (often faster; require access to the model's gradients):
- `Saliency`, `IntegratedGradients`, `GradientShap`, `SmoothGrad`, `SquareGrad`, `VarGrad`, `InputXGradient`
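
Gradient-based methods need derivatives of the model output with respect to its inputs (for text models, usually the input embeddings). As a worked example of one of them, the sketch below approximates Integrated Gradients for a hypothetical two-input function `toy_f` with a known gradient, using a midpoint Riemann sum along the straight path from a baseline to the input. It also checks the completeness property: the attributions sum to approximately `toy_f(input) - toy_f(baseline)`.

```python
def toy_f(x):
    """Hypothetical differentiable 'model' of two inputs."""
    return x[0] ** 2 + 3 * x[0] * x[1]


def toy_grad(x):
    """Analytic gradient of toy_f (a real model would use autograd)."""
    return [2 * x[0] + 3 * x[1], 3 * x[0]]


def integrated_gradients(x, baseline, grad_fn, steps=50):
    """Midpoint Riemann approximation of Integrated Gradients."""
    avg_grad = [0.0] * len(x)
    for s in range(steps):
        alpha = (s + 0.5) / steps
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        for i, g in enumerate(grad_fn(point)):
            avg_grad[i] += g / steps
    # Scale the path-averaged gradient by the input's distance from the baseline
    return [(xi - b) * g for xi, b, g in zip(x, baseline, avg_grad)]


x_input, baseline = [1.0, 2.0], [0.0, 0.0]
ig = integrated_gradients(x_input, baseline, toy_grad)
print(ig, sum(ig), toy_f(x_input) - toy_f(baseline))
```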

**Perturbation-based** (model-agnostic; need only forward passes, though usually many of them):
- `Lime`, `KernelShap`, `Occlusion`, `Sobol`
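
Perturbation-based methods only need to run the model. The sketch below is a minimal LIME-style procedure: randomly keep or drop each word, score each perturbed input with a hypothetical black box `toy_model`, and fit a weighted linear model whose coefficients serve as attributions. It assumes a simple proximity kernel (inputs keeping more words weigh more) and solves the weighted least-squares normal equations by Gaussian elimination; real LIME implementations differ in sampling, kernel, and regularization.

```python
import random


def toy_model(mask):
    """Hypothetical black box: additive effect of each kept word (mask[i] is 0 or 1)."""
    return 0.5 + 2.0 * mask[0] - 1.0 * mask[1] + 0.0 * mask[2] + 3.0 * mask[3]


def solve_linear(a, b):
    """Solve a x = b by Gaussian elimination with partial pivoting."""
    n = len(b)
    m = [row[:] + [b[i]] for i, row in enumerate(a)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(m[r][col]))
        m[col], m[pivot] = m[pivot], m[col]
        for r in range(col + 1, n):
            factor = m[r][col] / m[col][col]
            for c in range(col, n + 1):
                m[r][c] -= factor * m[col][c]
    x = [0.0] * n
    for r in range(n - 1, -1, -1):
        x[r] = (m[r][n] - sum(m[r][c] * x[c] for c in range(r + 1, n))) / m[r][r]
    return x


def lime_attributions(n_features, model, n_samples=200, seed=0):
    rng = random.Random(seed)
    rows, ys, ws = [], [], []
    for _ in range(n_samples):
        z = [rng.randint(0, 1) for _ in range(n_features)]
        rows.append([1.0] + z)  # leading 1.0 is the intercept column
        ys.append(model(z))
        ws.append((1 + sum(z)) / (1 + n_features))  # simple proximity kernel
    d = n_features + 1
    # Weighted normal equations: (X^T W X) beta = X^T W y
    xtwx = [[sum(w * r[i] * r[j] for r, w in zip(rows, ws)) for j in range(d)] for i in range(d)]
    xtwy = [sum(w * r[i] * y for r, w, y in zip(rows, ws, ys)) for i in range(d)]
    beta = solve_linear(xtwx, xtwy)
    return beta[1:]  # drop the intercept: one coefficient per word


print([round(c, 3) for c in lime_attributions(4, toy_model)])
```

Because `toy_model` is exactly linear, the surrogate recovers its per-word effects; for a real model the coefficients are only a local approximation.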

### Granularity Options
- `Granularity.TOKEN`: Individual tokens (finest granularity)
- `Granularity.WORD`: Word-level (recommended for interpretability)
- `Granularity.SENTENCE`: Sentence-level (coarsest granularity)
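
Coarser granularities combine several token scores into one score per unit. The sketch below sums subword-token scores into word scores, assuming (hypothetically) decoded tokens in which a leading space marks the start of a new word, as byte-level BPE tokenizers commonly produce. Summing is one common aggregation choice; Fair Forge may aggregate differently.

```python
def tokens_to_words(tokens, scores):
    """Sum token scores into words; a token starting with a space begins a new word."""
    words, word_scores = [], []
    for token, score in zip(tokens, scores):
        if not words or token.startswith(" "):
            words.append(token.strip())
            word_scores.append(score)
        else:
            words[-1] += token  # continuation subword: extend the current word
            word_scores[-1] += score
    return list(zip(words, word_scores))


demo_tokens = ["Rel", "ativity", " is", " fascinating"]
demo_scores = [0.25, 0.5, 0.125, 0.25]
print(tokens_to_words(demo_tokens, demo_scores))
```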

### Key Design Decisions
- **Pre-formatted prompts**: Users format prompts according to their model's requirements
- **Class-based methods**: Attribution methods are classes for extensibility
- **Parser interface**: Custom parsers can be implemented for different attribution libraries
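
A parser in this design converts whatever a backend attribution library returns into a common result type. The sketch below expresses that contract as a `typing.Protocol`; `TokenScore`, `AttributionParser`, and `PairListParser` are hypothetical names for illustration, not Fair Forge's actual classes.

```python
from dataclasses import dataclass
from typing import Any, List, Protocol


@dataclass
class TokenScore:
    """Hypothetical common output type: one text unit and its score."""
    text: str
    score: float


class AttributionParser(Protocol):
    """Contract: turn a backend's raw output into TokenScore objects."""

    def parse(self, raw: Any) -> List[TokenScore]:
        ...


class PairListParser:
    """Hypothetical parser for a backend that returns [(text, score), ...]."""

    def parse(self, raw: Any) -> List[TokenScore]:
        return [TokenScore(text=str(t), score=float(s)) for t, s in raw]


def to_common_format(parser: AttributionParser, raw: Any) -> List[TokenScore]:
    return parser.parse(raw)


print(to_common_format(PairListParser(), [("sky", 0.7), ("blue", 0.9)]))
```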