# Load Model, Tokenizer, and NER Pipeline

In [None]:
!pip install shap lime transformers datasets torch --quiet
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
import torch

# Checkpoint directory, given relative to the notebook's working directory.
model_path = "../models/xlm-roberta-ner"

# Tokenizer and token-classification weights are both read from that directory.
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForTokenClassification.from_pretrained(model_path)

# aggregation_strategy="simple" makes the pipeline return grouped entities
# (each result carries an `entity_group` key) instead of per-token tags.
ner_pipeline = pipeline(
    task="ner",
    model=model,
    tokenizer=tokenizer,
    aggregation_strategy="simple",
)


# Run NER on Sample Text

In [None]:
# Example input mixing Latin-script and Ethiopic-script text.
sample_text = "3pcs silicon brush spatulas ዋጋ-550ብር ቦሌ አዲስ አበባ"

ner_results = ner_pipeline(sample_text)

# One output line per grouped entity: surface text, label, rounded confidence.
for entity in ner_results:
    print(
        "{} → {} (score: {})".format(
            entity["word"],
            entity["entity_group"],
            round(entity["score"], 2),
        )
    )


# Explain with LIME

In [None]:
from lime.lime_text import LimeTextExplainer

# LIME looks class names up by column index, so hand it a real list ordered by
# label id. A bare `id2label.values()` view is not indexable and its order
# follows dict insertion order, which need not match the logit columns.
explainer = LimeTextExplainer(
    class_names=[model.config.id2label[i] for i in sorted(model.config.id2label)]
)

def predict_proba(texts):
    """Return one label-probability vector per input text, for use by LIME.

    Each text is tokenized and run through ``model``; the per-token softmax
    distributions are then averaged over all token positions (including any
    special tokens the tokenizer adds) to give a single distribution per text.

    Args:
        texts: Iterable of strings (LIME passes the perturbed samples).

    Returns:
        A 2-D NumPy array of shape ``(len(texts), num_labels)``. LIME slices
        the result as ``labels[:, label]``, so a plain Python list of 1-D
        arrays is not sufficient.
    """
    per_text_probs = []
    for text in texts:
        encoded = tokenizer(text, return_tensors="pt", truncation=True)
        with torch.no_grad():
            logits = model(**encoded).logits
        token_probs = torch.nn.functional.softmax(logits, dim=-1)
        # Batch size is 1 here, so index 0 selects the only sequence.
        per_text_probs.append(token_probs[0].mean(dim=0))

    if not per_text_probs:
        # Keep the 2-D contract even when no texts were supplied.
        return torch.empty((0, model.config.num_labels)).numpy()
    return torch.stack(per_text_probs).numpy()

# Explain sample
# NOTE: explain_instance defaults to labels=(1,), so only the label with id 1
# is explained here; pass `labels=` or `top_labels=` to target other classes.
exp = explainer.explain_instance(sample_text, predict_proba, num_features=10)
# `text` is a boolean flag ("also render the highlighted text"), not the text
# itself; the original string is already stored inside the explanation.
exp.show_in_notebook(text=True)


# Use SHAP for Token-Level Attribution on the Sample

In [None]:
import shap

# Explainer setup
def tokenize_for_shap(texts):
    """Tokenize a batch of texts into padded, truncated PyTorch tensors.

    Args:
        texts: A single string, or any iterable of string-like items. SHAP's
            text masker supplies a NumPy array of strings, which the Hugging
            Face tokenizer does not accept directly, so the input is
            normalised to a plain ``list[str]`` first.

    Returns:
        The tokenizer's batch encoding with ``return_tensors="pt"``.
    """
    if isinstance(texts, str):
        batch = [texts]
    else:
        batch = [str(text) for text in texts]
    return tokenizer(batch, padding=True, truncation=True, return_tensors="pt")

# Load SHAP's JavaScript visualisation library into the notebook output so
# SHAP's interactive plots can render.
shap.initjs()

# Wrap model
class WrappedModel:
    """Callable adapter that maps raw texts to one label distribution each.

    SHAP calls the instance with a batch of (masked) texts and expects a 2-D
    array of scores back, one row per text.
    """

    def __call__(self, texts):
        """Score a batch of texts.

        Args:
            texts: A string, or an iterable of string-like items (SHAP passes
                a NumPy array of strings).

        Returns:
            NumPy array of shape ``(batch, num_labels)`` holding, for each
            text, the softmax probabilities averaged over its real tokens.
        """
        # The Hugging Face tokenizer only accepts str / list[str], so convert
        # whatever SHAP supplies before tokenizing.
        if isinstance(texts, str):
            batch = [texts]
        else:
            batch = [str(text) for text in texts]

        inputs = tokenize_for_shap(batch)
        with torch.no_grad():
            logits = model(**inputs).logits
        probs = logits.softmax(dim=-1)

        # Average over non-padding positions only. A plain mean over dim=1
        # would include pad tokens, making a text's score depend on the
        # longest other text in the batch.
        mask = inputs["attention_mask"].unsqueeze(-1).to(probs.dtype)
        token_counts = mask.sum(dim=1).clamp(min=1)
        return ((probs * mask).sum(dim=1) / token_counts).numpy()

wrapped_model = WrappedModel()

# NOTE: this rebinds `explainer`, replacing the LIME explainer created earlier
# in the notebook; re-run the LIME cell before reusing it.
# output_names labels each model output with its tag name (ordered by label
# id to match the logit columns) so the plot does not show bare indices.
explainer = shap.Explainer(
    wrapped_model,
    tokenizer,
    output_names=[model.config.id2label[i] for i in sorted(model.config.id2label)],
)

# Explain the single sample and render its per-token attributions.
shap_values = explainer([sample_text])
shap.plots.text(shap_values[0])


# Analyze Ambiguous or Incorrect Examples

In [None]:
# Template: keep every (tokens, predicted_tags, true_tags) triple whose
# predicted tags differ from the true tags.
# The three `...` arguments to zip() are placeholders: replace them with the
# token, predicted-tag and true-tag sequences before running this cell,
# otherwise zip() raises TypeError because Ellipsis is not iterable.
# NOTE(review): `!=` assumes the tag sequences are plain lists/tuples; confirm
# against the real data, since array types compare element-wise instead.
wrong_preds = [
    (tokens, predicted_tags, true_tags)
    for tokens, predicted_tags, true_tags in zip(..., ..., ...)
    if predicted_tags != true_tags
]
