# SHAP Explanations for DistilBERT

This notebook uses the helper utilities in `src.analysis.shap_utils` to load a sample of test texts, compute SHAP values for the fine-tuned DistilBERT classifier, and visualize the tokens that most influence its predictions, both aggregated across samples and for individual examples.

In [None]:
from pathlib import Path
import sys

# Locate the project root so `src.*` is importable whether the kernel was
# started in the repository root or in the `notebooks/` subdirectory.
PROJECT_ROOT = Path.cwd()
if not (PROJECT_ROOT / "src").exists():
    # Assumes the notebook sits exactly one level below the project root —
    # TODO confirm if notebooks are ever nested deeper.
    PROJECT_ROOT = PROJECT_ROOT.parent
# Prepend rather than append: sys.path is searched in order, and `src` is a
# generic package name, so an earlier entry providing a different `src`
# would otherwise shadow this project's package.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
PROJECT_ROOT

WindowsPath('c:/Users/Bruker/Documents/Notater/NTNU/TDT13_NLP/TDT13_AI_text_prediction')

In [None]:
import matplotlib.pyplot as plt
import shap

from src.analysis.shap_utils import (
    ShapConfig,
    aggregate_token_importance,
    build_text_classifier,
    compute_shap_values,
    load_test_texts,
    summarize_examples,
)

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# SHAP run configuration.
# The checkpoint path is anchored at PROJECT_ROOT. A bare relative path is
# resolved against the kernel's working directory, which is `notebooks/` when
# the notebook is launched from there, and that directory does not contain
# the model. The string was then treated as a Hugging Face Hub repo id, which
# raised the HFValidationError seen in the load cell.
cfg = ShapConfig(
    checkpoint_path=PROJECT_ROOT / "models" / "distilbert-debug",
    num_samples=32,  # presumably the number of test texts to explain — confirm in shap_utils
    data_limit=2000,  # presumably a cap on rows loaded before splitting — confirm in shap_utils
    sample_seed=42,  # seed for the sampling step, for reproducible runs
    test_ratio=0.2,  # presumably the held-out fraction the texts are drawn from
)
cfg

ShapConfig(checkpoint_path=WindowsPath('models/distilbert-debug'), num_samples=32, data_limit=2000, sample_seed=42, test_ratio=0.2, algorithm='partition')

In [None]:
# Load the sampled test texts, build the classifier from the checkpoint, and
# compute SHAP values for every loaded text with the configured algorithm.
texts = load_test_texts(cfg)
# Fail fast with a clear message before building the model: the plotting
# cells below index shap_values[0], which would otherwise raise an opaque
# IndexError on an empty sample.
assert len(texts) > 0, "load_test_texts returned no texts; check the ShapConfig sampling settings"
clf = build_text_classifier(cfg.checkpoint_path)
shap_values = compute_shap_values(clf, texts, algorithm=cfg.algorithm)
len(texts)

HFValidationError: Repo id must be in the form 'repo_name' or 'namespace/repo_name': 'C:/Users/Bruker/Documents/Notater/NTNU/TDT13_NLP/TDT13_AI_text_prediction/notebooks/models/distilbert-debug'. Use `repo_type` argument if needed.

In [None]:
# Aggregate token contributions across all explained texts and keep the top 20
# (the resulting table is what the bar chart below draws from).
token_importance = aggregate_token_importance(
    shap_values,
    top_k=20,
)
token_importance

In [None]:
# Per-example view of the explanations (top_k=5 — presumably the five most
# influential tokens for each text; confirm in shap_utils). Only the first
# rows are displayed.
example_summary = summarize_examples(
    clf,
    shap_values,
    texts,
    top_k=5,
)
example_summary.head()

In [None]:
# Horizontal bar chart of the aggregated token importances.
# Bars are placed at explicit integer positions and labelled afterwards,
# rather than using the token strings as categorical y-values. This way a
# token string that appears more than once cannot collapse onto a single bar.
# Rows are reversed so the first row of token_importance is drawn at the top.
if not token_importance.empty:
    plot_rows = token_importance.iloc[::-1]
    positions = list(range(len(plot_rows)))
    # Grow the figure with the number of bars so tick labels stay readable.
    fig, ax = plt.subplots(figsize=(8, max(4, 0.3 * len(plot_rows))))
    ax.barh(positions, plot_rows["mean_abs_score"])
    ax.set_yticks(positions)
    ax.set_yticklabels(plot_rows["token"].astype(str))
    ax.set_xlabel("Mean |SHAP| contribution")
    ax.set_title("Top token drivers across samples")
    fig.tight_layout()
    plt.show()
else:
    print("No aggregated token importance available.")

In [None]:
# Render the per-token SHAP attribution for the first sampled text.
# Guarded like the second-example cell so an empty sample does not raise IndexError.
if len(texts) > 0:
    shap.plots.text(shap_values[0])
else:
    print("No texts were loaded, so there is no example to plot.")

In [None]:
# A second example gives a point of comparison; skip it when only one text was sampled.
if len(texts) >= 2:
    shap.plots.text(shap_values[1])