# Model Interpretability for Amharic NER
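This notebook demonstrates how to use SHAP and LIME to interpret the predictions of a fine-tuned Amharic named-entity recognition (NER) model. Both explainers expect one probability vector per input text. The model's per-token predictions are therefore collapsed into a single class distribution per sentence by averaging over the first sub-token of each word.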

In [ ]:
import os
import numpy as np
import torch
import shap
from lime.lime_text import LimeTextExplainer
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForTokenClassification
from src.utils.ner_data_utils import parse_conll, build_label_maps


In [ ]:
# Load the fine-tuned token-classification model together with its tokenizer.
# NOTE(review): this path is relative to the repository root, while the data
# cell below uses '../data/...' (relative to notebooks/) — confirm which
# working directory the notebook is meant to run from.
MODEL_DIR = 'notebooks/amharicnermodel'  # Adjust as needed
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
model = AutoModelForTokenClassification.from_pretrained(MODEL_DIR)
# Inference only: switch off dropout and other training-time behaviour.
model.eval()


In [ ]:
# Parse the manually labelled CoNLL file, derive the tag <-> id maps from it,
# and keep the first five sentences as demo inputs for the explainers.
conll_path = '../data/raw/labeled_cnll_manual.txt'
sentences, ner_tags = parse_conll(conll_path)
label2id, id2label = build_label_maps(ner_tags)
# Each sentence is a sequence of word tokens; re-join with single spaces so the
# explainers receive plain text.
examples = list(map(' '.join, sentences[:5]))
print('Examples:', examples)


In [ ]:
# Prediction wrapper for SHAP/LIME
def predict_proba(texts):
    """Map raw text(s) to one class-probability vector per text.

    SHAP and LIME both expect a classifier-style function. The model, however,
    is a token classifier that yields a distribution per sub-word token. This
    wrapper collapses those distributions into one vector per text. It takes
    the softmax of the first sub-token of every word and averages those across
    words.

    Relies on the module-level ``model`` and ``tokenizer`` loaded earlier in the
    notebook. It also relies on ``BatchEncoding.word_ids``, which requires a
    *fast* tokenizer.

    Args:
        texts: A single string, or an iterable of strings (SHAP passes an
            ``np.ndarray``, LIME passes a list).

    Returns:
        ``np.ndarray`` of shape ``(len(texts), model.config.num_labels)``.
        Every row is a finite probability vector, even for texts that contain
        no words.
    """
    if isinstance(texts, np.ndarray):
        texts = texts.tolist()
    if isinstance(texts, str):
        texts = [texts]
    if not isinstance(texts, list):
        texts = list(texts)

    num_labels = model.config.num_labels
    if not texts:
        # The tokenizer rejects an empty batch; return an empty result instead.
        return np.empty((0, num_labels), dtype=np.float32)

    inputs = tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
    # BatchEncoding.to() moves the tensors and returns the same encoding, so
    # word_ids() stays available; this also works when the model is on a GPU.
    inputs = inputs.to(model.device)
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.softmax(logits, dim=-1).cpu().numpy()

    avg_probs = []
    for i in range(len(texts)):
        # Position of the first sub-token of each word, in order of appearance.
        first_pos = {}
        for pos, wid in enumerate(inputs.word_ids(batch_index=i)):
            if wid is not None:
                first_pos.setdefault(wid, pos)
        rows = probs[i][list(first_pos.values())]

        if rows.shape[0] == 0:
            # No word tokens survived (e.g. LIME removed every word, or the
            # text is empty). np.mean over an empty slice yields NaN, so fall
            # back to the non-padding (special) tokens' predictions.
            valid = inputs['attention_mask'][i].bool().cpu().numpy()
            rows = probs[i][valid]

        if rows.shape[0] == 0:
            # Nothing at all to average: report a uniform distribution.
            avg_probs.append(np.full(num_labels, 1.0 / num_labels, dtype=probs.dtype))
        else:
            avg_probs.append(rows.mean(axis=0))
    return np.array(avg_probs)


In [ ]:
# SHAP Explanation
# shap.plots.text renders an interactive HTML view, not a matplotlib figure.
# With display=False it *returns* the HTML markup as a string. Saving via
# plt.savefig would only capture an empty matplotlib canvas, so the HTML is
# written to disk directly (UTF-8, because the examples are in Ge'ez script).
os.makedirs('./results', exist_ok=True)
explainer = shap.Explainer(predict_proba, tokenizer)
shap_values = explainer(examples)
for i in range(len(examples)):
    html = shap.plots.text(shap_values[i], display=False)
    with open(f'./results/shap_text_{i}.html', 'w', encoding='utf-8') as fh:
        fh.write(html)
print('SHAP text plots saved for each example in ./results/')


In [ ]:
# LIME Explanation
# LIME pairs class_names[k] with column k of predict_proba's output. Those
# columns follow the *model's* label ids, so take the names from the model
# config. label2id was rebuilt from the CoNLL file and may be ordered
# differently or have a different size.
model_class_names = [model.config.id2label[k] for k in range(model.config.num_labels)]
if list(label2id.keys()) != model_class_names:
    print('Warning: data-derived labels differ from model.config.id2label; '
          'using the model config labels for LIME.')
class_names = model_class_names

# Make this cell runnable on its own, without relying on the SHAP cell having
# created the output directory.
os.makedirs('./results', exist_ok=True)
lime_explainer = LimeTextExplainer(class_names=class_names)
for i, example in enumerate(examples):
    exp = lime_explainer.explain_instance(
        example,
        predict_proba,
        num_features=10,
        num_samples=100
    )
    exp.save_to_file(f'./results/lime_explanation_{i}.html')
    print(f'LIME explanation for example {i} saved to ./results/lime_explanation_{i}.html')
