In [None]:
# Task 5: Model Interpretability for NER

import pandas as pd
import shap
import torch
from datasets import Dataset
from lime.lime_text import LimeTextExplainer
from shap import Explainer, summary_plot  # Import SHAP Explainer and summary_plot
from transformers import (XLMRobertaTokenizerFast, XLMRobertaForTokenClassification,
                          DistilBertTokenizerFast, DistilBertForTokenClassification,
                          BertTokenizerFast, BertForTokenClassification)

# Step 1: Load the labeled dataset in CoNLL format
def load_conll_data(file_path):
    """Load a CoNLL-formatted file into a pandas DataFrame.

    Each non-blank line describes one token. The first whitespace-separated
    column is the token and the last column is its label, so both two-column
    (``token label``) and multi-column CoNLL files are accepted. Blank lines
    separate sentences; runs of blank lines are treated as one boundary.

    Args:
        file_path: Path to a UTF-8 encoded CoNLL file.

    Returns:
        A DataFrame with a ``text`` column (the sentence's tokens joined by
        single spaces) and a ``labels`` column (the list of per-token labels).

    Raises:
        ValueError: If a non-blank line has fewer than two columns.
    """
    data = []
    tokens, labels = [], []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line_no, line in enumerate(f, start=1):
            columns = line.split()
            if not columns:
                # Sentence boundary; skip it if no tokens are pending.
                if tokens:
                    data.append((" ".join(tokens), labels))
                    tokens, labels = [], []
                continue
            if len(columns) < 2:
                raise ValueError(
                    f"{file_path}:{line_no}: expected '<token> ... <label>', "
                    f"got {line.strip()!r}"
                )
            tokens.append(columns[0])
            labels.append(columns[-1])
    # The file may not end with a blank line, so flush the last sentence.
    if tokens:
        data.append((" ".join(tokens), labels))

    print(f"Loaded {len(data)} sentences from the CoNLL file.")  # Debugging info
    return pd.DataFrame(data, columns=["text", "labels"])

# Load data
conll_file_path = '../output/labeled_telegram_data.conll'
df = load_conll_data(conll_file_path)

# Load the fine-tuned model and tokenizer.
# Change `model_name` to interpret a different fine-tuned model; the weights
# are read from ../models/<model_name>.
# NOTE(review): the tokenizer is loaded from the Hub checkpoint while the
# weights come from the local fine-tuned directory - confirm fine-tuning did
# not alter the tokenizer (otherwise load it from the same directory).
model_name = "xlm-roberta-base"
tokenizer = XLMRobertaTokenizerFast.from_pretrained(model_name)
model = XLMRobertaForTokenClassification.from_pretrained(f'../models/{model_name}')

# Step 2: Prepare data for interpretation
texts = df["text"].head(2).tolist()  # Use the first two examples for demonstration

# Tokenization: padding makes the batch rectangular; attention_mask marks real tokens
inputs = tokenizer(texts, return_tensors="pt", truncation=True, padding=True)

# Step 3: Make predictions with the fine-tuned model
with torch.no_grad():
    outputs = model(**inputs)

# Extract logits and predicted label ids (argmax over dim 2 picks one label
# id per sub-token position of each sequence).
logits = outputs.logits
predicted_labels = torch.argmax(logits, dim=2)

# Show the predicted label of every non-padding sub-token so the
# interpretation steps below can be compared against the raw predictions.
id2label = model.config.id2label
for seq_idx, sentence in enumerate(texts):
    keep = inputs["attention_mask"][seq_idx].bool()
    seq_tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][seq_idx][keep].tolist())
    seq_labels = [id2label[label_id] for label_id in predicted_labels[seq_idx][keep].tolist()]
    print(f"Predictions for: {sentence}")
    print(list(zip(seq_tokens, seq_labels)))
# Prepare the predictions for LIME: one probability vector per input text
def get_predictions(text, batch_size=32):
    """Return per-class probabilities in the shape LIME's ``classifier_fn`` expects.

    LIME's text explainer calls this with a list of perturbed strings and
    expects an array of shape ``(n_samples, n_classes)``. The token-level model
    yields one label distribution per sub-token, so the softmax distributions
    are averaged over each text's non-special, non-padding tokens to give one
    distribution over NER labels per text.

    Args:
        text: A single string or a list of strings (LIME passes a list).
        batch_size: Number of texts sent through the model at once. LIME may
            request thousands of perturbed samples, so they are processed in
            chunks to bound memory use.

    Returns:
        numpy.ndarray of shape ``(len(texts), model.config.num_labels)``. A
        text with no non-special tokens (e.g. every word masked out) gets an
        all-zero row.
    """
    samples = [text] if isinstance(text, str) else list(text)
    num_labels = model.config.num_labels
    chunks = []
    for start in range(0, len(samples), batch_size):
        enc = tokenizer(
            samples[start:start + batch_size],
            return_tensors="pt",
            truncation=True,
            padding=True,
            return_special_tokens_mask=True,
        ).to(model.device)
        # The model does not accept this key; it is only used for masking below.
        special = enc.pop("special_tokens_mask").bool()
        with torch.no_grad():
            probs = torch.softmax(model(**enc).logits, dim=-1)
        # Average only over real word pieces: drop padding and special tokens.
        keep = (enc["attention_mask"].bool() & ~special).unsqueeze(-1).to(probs.dtype)
        # clamp avoids 0/0 when a text has no non-special tokens left.
        chunks.append((probs * keep).sum(dim=1) / keep.sum(dim=1).clamp(min=1))
    if not chunks:
        return torch.zeros((0, num_labels)).numpy()
    return torch.cat(chunks).cpu().numpy()

# Step 4: Interpret model predictions using LIME
# LIME indexes class_names by class id (e.g. when titling plots), so it needs
# a real list ordered by label id rather than a non-subscriptable dict_values view.
lime_explainer = LimeTextExplainer(
    class_names=[model.config.id2label[i] for i in range(model.config.num_labels)]
)

# Function to explain a single prediction
def explain_with_lime(text, top_labels=1, num_features=10, num_samples=5000):
    """Explain ``get_predictions``'s output for ``text`` with LIME.

    Args:
        text: The sentence to explain.
        top_labels: Number of highest-scoring labels to produce explanations for.
        num_features: Maximum number of words reported per explanation.
        num_samples: Number of perturbed samples LIME generates (5000 is LIME's
            own default). Each sample costs a model evaluation, so lowering
            this gives faster but noisier explanations.

    Returns:
        The explanation object returned by ``lime_explainer.explain_instance``.
    """
    return lime_explainer.explain_instance(
        text,
        get_predictions,
        top_labels=top_labels,
        num_features=num_features,
        num_samples=num_samples,
    )

# Analyze LIME explanations
for text in texts:
    explanation = explain_with_lime(text)
    print(f"LIME Explanation for: {text}")
    # as_pyplot_figure() defaults to label=1, but with top_labels=1 only the
    # top predicted class is explained, so plot the label that actually exists.
    explained_label = explanation.available_labels()[0]
    explanation.as_pyplot_figure(label=explained_label)  # Display the explanation plot

# Step 5: Interpret model predictions using SHAP
# SHAP needs a callable that maps a batch of strings to an (n_samples, n_outputs)
# array; the raw token-classification model is not such a callable, and SHAP's
# text masker perturbs strings rather than input_id tensors.
def _shap_predict(batch):
    """Return, per text, the mean label distribution over its non-special tokens."""
    enc = tokenizer(
        [str(s) for s in batch],  # SHAP may pass a numpy array of strings
        return_tensors="pt",
        truncation=True,
        padding=True,
        return_special_tokens_mask=True,
    ).to(model.device)
    special = enc.pop("special_tokens_mask").bool()
    with torch.no_grad():
        probs = torch.softmax(model(**enc).logits, dim=-1)
    keep = (enc["attention_mask"].bool() & ~special).unsqueeze(-1).to(probs.dtype)
    return ((probs * keep).sum(dim=1) / keep.sum(dim=1).clamp(min=1)).cpu().numpy()


shap_label_names = [model.config.id2label[i] for i in range(model.config.num_labels)]
# Passing the tokenizer as the masker makes SHAP wrap it in a text masker.
shap_explainer = Explainer(_shap_predict, tokenizer, output_names=shap_label_names)

# Generate SHAP values for the example texts
shap_values = shap_explainer(texts)

# Plot token-level SHAP attributions for each output label
shap.plots.text(shap_values)

print("Interpretation completed.")
