# Task5_Model_Interpretability.ipynb

## Setup: mount Google Drive, install and import dependencies

In [None]:
# Objective: Use model interpretability tools (LIME & SHAP) to explain NER model predictions.

# --- Mount Google Drive ---
# The fine-tuned model from Task 4 is read from Google Drive (see
# DRIVE_PROJECT_BASE_PATH in the next cell), so Drive must be mounted first.
from google.colab import drive
drive.mount('/content/drive')

# --- Step 1: Install Necessary Libraries ---
# NOTE(review): versions are unpinned, so a later re-run may pick up releases
# whose APIs differ (LIME/SHAP change between versions). Consider pinning the
# tested versions, e.g. `%pip install -q shap==<tested version>`; `%pip` also
# guarantees the install targets the running kernel's environment.
# NOTE(review): datasets, seqeval and evaluate are not imported in the cells
# shown here — presumably carried over from training (Task 4); confirm before
# removing them.
!pip install transformers datasets seqeval accelerate evaluate
!pip install lime shap

# IMPORTANT: After running this cell, if prompted, click "Restart runtime"
# and then "Run all cells" to ensure all libraries are correctly loaded.

# --- Step 2: Import Libraries ---
import json
import os
import re

import numpy as np
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
from lime.lime_text import LimeTextExplainer
import shap

# --- Configuration for Model Loading ---

In [None]:
# IMPORTANT: Adjust 'colab_projects/EthioMart_NER' to your desired base path in Google Drive
DRIVE_PROJECT_BASE_PATH = "/content/drive/MyDrive/colab_projects/EthioMart_NER"

# Path to the best performing model saved in Google Drive by Task 4.
# Default: the checkpoint Task 4 previously reported as best (hard-coded so the
# notebook runs on its own). If Task 4 wrote `best_model_path.txt` into the
# project folder, its (non-empty) content takes precedence, so this notebook
# follows Task 4's model selection without manual edits.
BEST_MODEL_PATH = os.path.join(DRIVE_PROJECT_BASE_PATH, "XLM-R-Amharic-NER_ner_output/final_model")
BEST_MODEL_PATH_FILE = os.path.join(DRIVE_PROJECT_BASE_PATH, "best_model_path.txt")
try:
    with open(BEST_MODEL_PATH_FILE, "r", encoding="utf-8") as path_file:
        _best_model_path_from_file = path_file.read().strip()
except FileNotFoundError:
    print("best_model_path.txt not found. Using default BEST_MODEL_PATH.")
else:
    if _best_model_path_from_file:
        BEST_MODEL_PATH = _best_model_path_from_file
        print(f"Loaded best model path from file: {BEST_MODEL_PATH}")
    else:
        # An empty file would otherwise yield "" and a confusing load error later.
        print("best_model_path.txt is empty. Using default BEST_MODEL_PATH.")


# Entity label set used during training. The loaded model's `config.id2label`
# is what the code below relies on for label ids; keep this list in sync with it.
LABEL_NAMES = ["O", "B-PRODUCT", "I-PRODUCT", "B-LOC", "I-LOC", "B-PRICE", "I-PRICE"]

# --- Step 3: Load the Best Fine-Tuned Model and Tokenizer ---
# Loading is best-effort: any failure is printed and recorded in `model_loaded`
# so the interpretability section below is skipped instead of raising.
print(f"Loading best model and tokenizer from: {BEST_MODEL_PATH}")
try:
    model = AutoModelForTokenClassification.from_pretrained(BEST_MODEL_PATH)
    tokenizer = AutoTokenizer.from_pretrained(BEST_MODEL_PATH)
    # "simple" aggregation groups token predictions into entities; each result
    # is a dict with an `entity_group` key (used by the SHAP section).
    ner_pipeline = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="simple")

    # The model config is the source of the label <-> id mapping; invert it
    # here instead of trusting the order of LABEL_NAMES.
    id2label = model.config.id2label
    label2id = dict(zip(id2label.values(), id2label.keys()))
    print(f"Model's ID to Label mapping: {id2label}")
    print(f"Model's Label to ID mapping: {label2id}")
except Exception as e:
    print(f"Error loading best model: {e}. Skipping interpretability tasks.")
    model_loaded = False
else:
    model_loaded = True


if model_loaded:
    # --- Step 4: Prepare Custom Prediction Functions for LIME/SHAP ---

    def predict_proba_for_token_classification(texts: list[str], target_token_idx: int = 0) -> np.ndarray:
        """Return the label distribution at one token position for each text.

        Tokenizes ``texts`` as a single padded batch, runs the token-classification
        model and returns the softmax probabilities at sequence position
        ``target_token_idx`` of every row. Suitable as a LIME ``classifier_fn``
        when the token of interest sits at a fixed position.

        Args:
            texts: Raw input strings.
            target_token_idx: Index into the *tokenized* sequence, i.e. counted
                after the tokenizer adds its special tokens. Position 0 is the
                first emitted token (typically a special start token — confirm
                for your tokenizer), so the first real sub-word is usually 1.

        Returns:
            Array of shape ``(len(texts), num_labels)``. Rows where
            ``target_token_idx`` is out of range or lands on padding are zeros.
        """
        if len(texts) == 0:
            return np.zeros((0, model.config.num_labels), dtype=np.float32)

        tokenized_inputs = tokenizer(
            list(texts),
            padding=True,
            truncation=True,
            return_tensors="pt"
        ).to(model.device)

        with torch.no_grad():
            logits = model(**tokenized_inputs).logits

        # (batch, seq_len, num_labels) probabilities.
        probabilities = torch.softmax(logits, dim=-1).cpu().numpy()
        batch_size, seq_len, num_labels = probabilities.shape

        # Zero rows for texts where the requested position does not exist; the
        # width comes from the model output so all rows agree.
        output_probs = np.zeros((batch_size, num_labels), dtype=probabilities.dtype)
        if 0 <= target_token_idx < seq_len:
            attention_mask = tokenized_inputs.get("attention_mask")
            if attention_mask is None:
                is_real_token = np.ones(batch_size, dtype=bool)
            else:
                # In a padded batch, shorter texts have padding at this position.
                is_real_token = attention_mask[:, target_token_idx].bool().cpu().numpy()
            output_probs[is_real_token] = probabilities[is_real_token, target_token_idx, :]
        return output_probs

    # --- Step 5: Implement LIME (Local Interpretable Model-agnostic Explanations) ---
    # LIME masks words of the sentence and fits a local linear model to the
    # classifier's outputs. For NER the "classifier" must report the label
    # distribution of ONE specific word, so that word is located again in every
    # perturbed text.

    def lime_target_word_proba(texts: list[str], target_word_idx: int) -> np.ndarray:
        """Label probabilities of the first sub-word of one word, for each text.

        Args:
            texts: Sentences to score (LIME's perturbed copies of the example).
            target_word_idx: Position of the word to explain among the
                whitespace-separated words of each text.

        Returns:
            Array of shape ``(len(texts), num_labels)``. A row is all zeros when
            the text has too few words or the word was cut off by truncation.

        Requires a fast tokenizer (for ``return_offsets_mapping``).
        """
        if len(texts) == 0:
            return np.zeros((0, model.config.num_labels), dtype=np.float32)

        encoded = tokenizer(
            list(texts),
            padding=True,
            truncation=True,
            return_offsets_mapping=True,
            return_tensors="pt",
        )
        # Offsets are not a model input; remove them before the forward pass.
        token_offsets = encoded.pop("offset_mapping").tolist()
        encoded = encoded.to(model.device)
        with torch.no_grad():
            probabilities = torch.softmax(model(**encoded).logits, dim=-1).cpu().numpy()

        output_probs = np.zeros((len(texts), probabilities.shape[-1]), dtype=probabilities.dtype)
        for row, (text, row_offsets) in enumerate(zip(texts, token_offsets)):
            word_spans = [match.span() for match in re.finditer(r"\S+", text)]
            if target_word_idx >= len(word_spans):
                continue
            word_start, word_end = word_spans[target_word_idx]
            for tok_idx, (tok_start, tok_end) in enumerate(row_offsets):
                # First token whose character span overlaps the word. Special and
                # padding tokens get empty spans (typically (0, 0)) and never match.
                if tok_start < word_end and tok_end > word_start:
                    output_probs[row] = probabilities[row, tok_idx]
                    break
        return output_probs

    print("\n--- LIME Explanations ---")
    # LIME indexes class_names by label id, so take them from the model config.
    lime_class_names = [id2label[label_id] for label_id in sorted(id2label)]
    # bow=False replaces masked words in place (instead of deleting them), so
    # every perturbed text keeps the target word at the same word position.
    # A fixed random_state makes the sampled perturbations reproducible.
    explainer_lime = LimeTextExplainer(class_names=lime_class_names, bow=False, random_state=42)

    example_sentence_lime = "አዲስ ስልክ iPhone 15 Pro Max ዋጋው 70000 ብር ሲሆን በአዲስ አበባ ይገኛል"
    pipe_result_lime = ner_pipeline(example_sentence_lime)
    print(f"Pipeline prediction for LIME example: {pipe_result_lime}")

    word_to_explain_lime = "iPhone"
    example_words_lime = example_sentence_lime.split()
    if word_to_explain_lime not in example_words_lime:
        print(f"'{word_to_explain_lime}' not found in the example sentence for LIME explanation.")
    elif not getattr(tokenizer, "is_fast", False):
        print("Skipping LIME: locating the target word requires a fast tokenizer (offset mappings).")
    else:
        target_word_idx_lime = example_words_lime.index(word_to_explain_lime)
        target_label_id_for_lime = label2id.get("B-PRODUCT", 0)
        explanation_lime = explainer_lime.explain_instance(
            text_instance=example_sentence_lime,
            classifier_fn=lambda texts: lime_target_word_proba(texts, target_word_idx_lime),
            labels=[target_label_id_for_lime],
            num_features=5,
            num_samples=1000
        )
        print(f"\nLIME Explanation for '{word_to_explain_lime}' as '{id2label[target_label_id_for_lime]}' in: '{example_sentence_lime}'")
        for feature, weight in explanation_lime.as_list(label=target_label_id_for_lime):
            print(f"  - '{feature}': {weight:.4f}")

    # --- SHAP Explanations ---
    print("\n--- SHAP Explanations ---")

    def ner_shap_predict_fn(texts: list[str]) -> np.ndarray:
        """Flatten each text's per-token label probabilities into one output row.

        Every text is tokenized and scored on its own. Its
        ``(tokens, num_labels)`` softmax matrix is flattened row-major. All rows
        are then zero-padded on the right to the length of the longest one.

        NOTE(review): a given output column maps to different (token, label)
        pairs whenever texts tokenize to different lengths, so columns are not
        comparable across texts. Empty ``texts`` raises ``ValueError`` (from
        ``max``).
        """
        flattened_rows = []
        for sample_text in texts:
            encoded = tokenizer(sample_text, return_tensors="pt", truncation=True, padding=True).to(model.device)
            with torch.no_grad():
                sample_logits = model(**encoded).logits
            flattened_rows.append(torch.softmax(sample_logits, dim=-1).cpu().numpy().ravel())
        row_width = max(map(len, flattened_rows))
        return np.array([np.pad(row, (0, row_width - len(row)), 'constant') for row in flattened_rows])

    # SHAP's text masker hides sub-word tokens of the input, and the explainer
    # attributes the resulting change in a model output to those tokens. Per-token
    # NER outputs change length when tokens are masked, so they cannot be compared
    # position-by-position. Instead we explain ONE scalar per text: how strongly the
    # model predicts a location (B-LOC / I-LOC) anywhere in the sentence.
    loc_label_ids_shap = [label2id[name] for name in ("B-LOC", "I-LOC") if name in label2id]

    def shap_loc_score_fn(texts) -> np.ndarray:
        """Score each text by the model's strongest location prediction.

        For every non-special, non-padding token, the probabilities of the LOC
        labels (``loc_label_ids_shap``) are summed. A text's score is the maximum
        of that sum over its tokens.

        Args:
            texts: Iterable of strings (SHAP passes masked copies of the input).

        Returns:
            1-D float array with one score in [0, 1] per text. Texts without any
            real token, or a model without LOC labels, score 0.
        """
        text_list = [str(text) for text in texts]
        if not text_list or not loc_label_ids_shap:
            return np.zeros(len(text_list), dtype=np.float32)

        encoded = tokenizer(
            text_list,
            padding=True,
            truncation=True,
            return_special_tokens_mask=True,
            return_tensors="pt",
        )
        # The special-tokens mask is not a model input; use it to ignore <s>/</s>-style tokens.
        special_tokens = encoded.pop("special_tokens_mask").bool()
        real_tokens = encoded["attention_mask"].bool() & ~special_tokens
        encoded = encoded.to(model.device)
        with torch.no_grad():
            probs = torch.softmax(model(**encoded).logits, dim=-1).cpu()
        loc_prob = probs[..., loc_label_ids_shap].sum(dim=-1).masked_fill(~real_tokens, 0.0)
        return loc_prob.max(dim=1).values.numpy()

    example_sentence_shap = "አዲስ ስልክ iPhone 15 Pro Max ዋጋው 70000 ብር ሲሆን በአዲስ አበባ ይገኛል"

    try:
        # shap.maskers.Text takes no background dataset; its second positional
        # parameter is the mask token, so only the tokenizer is passed.
        masker_shap = shap.maskers.Text(tokenizer)
        explainer_shap = shap.Explainer(
            model=shap_loc_score_fn,
            masker=masker_shap,
            link=shap.links.identity,  # the score is already a probability; explain it as-is
        )
        shap_values = explainer_shap([example_sentence_shap])

        print(f"\nSHAP Explanation for: '{example_sentence_shap}'")
        print(f"Model LOC score for the full sentence: {float(shap_loc_score_fn([example_sentence_shap])[0]):.4f}")
        print("For detailed visualization, run `shap.plots.text(shap_values)` in a Jupyter environment.")
        print("Alternatively, inspect `shap_values.data` (tokens) and `shap_values.values` (importance scores).")

        loc_entity_from_pipeline = next(
            (entity for entity in ner_pipeline(example_sentence_shap) if entity['entity_group'] == 'LOC'),
            None,
        )
        if loc_entity_from_pipeline is None:
            print("No LOC entity found in pipeline prediction for SHAP example. Cannot provide targeted SHAP explanation.")
        else:
            print(f"\nFocusing on: {loc_entity_from_pipeline['word']} (Predicted as {loc_entity_from_pipeline['entity_group']})")
            print("Token contributions to the LOC score (positive = pushes towards LOC):")
            # One scalar output per text, so values[0] holds one number per token.
            for token, contribution in zip(shap_values.data[0], shap_values.values[0]):
                print(f"  '{token}': {float(contribution):+.4f}")

    except Exception as e:
        print(f"Error during SHAP explanation: {e}")
        print("SHAP for token classification is complex. Ensure prediction function output matches SHAP's expectations.")
        print("If error persists, consider simpler examples or using `shap.plots.text()` in Jupyter.")
else:
    print("\nSkipping interpretability tasks as the model could not be loaded.")
