# Model Explainability: SHAP and LIME Analysis

**Project:** HEARTS Adaptation - Gender Bias Detection  
**Task:** Token-level explainability using SHAP and LIME

This notebook implements:
1. SHAP token-level importance calculations
2. LIME token-level importance calculations
3. Similarity metrics between SHAP and LIME:
   - Cosine Similarity
   - Pearson Correlation
   - Jensen-Shannon Divergence
4. Token rankings for correct and incorrect predictions
5. Visualization of explainability results

**Reference:** King, T., Wu, Z., Koshiyama, A., Kazim, E., & Treleaven, P. (2024). Hearts: A holistic framework for explainable, sustainable and robust text stereotype detection. arXiv preprint arXiv:2409.11579.


In [4]:
# Import required libraries
import pandas as pd
import numpy as np
import os
import re
import pickle
from pathlib import Path
from typing import List, Tuple, Dict, Optional
import warnings
warnings.filterwarnings('ignore')

# SHAP and LIME
import shap
from lime import lime_text
from lime.lime_text import LimeTextExplainer

# Transformers
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch
import torch.nn.functional as F

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import pearsonr
from scipy.spatial.distance import jensenshannon

# Resolve the project root: when the notebook runs from `notebooks/`,
# the project root is one level up; otherwise it is the working directory.
current_dir = Path.cwd()
project_root = current_dir.parent if current_dir.name == 'notebooks' else current_dir

# Input/output locations, all relative to the project root
data_dir = project_root / 'data'
models_dir = project_root / 'models'
results_dir = project_root / 'results'
explainability_dir = project_root / 'explainability'
paper_figures_dir = explainability_dir / 'paper_figures'
generated_figures_dir = explainability_dir / 'generated_figures'

# Make sure every output directory exists before anything is written
for _output_dir in (explainability_dir, paper_figures_dir, generated_figures_dir):
    os.makedirs(_output_dir, exist_ok=True)

# Seed NumPy and PyTorch (CPU and, when present, all GPUs) for reproducibility
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

# Prefer the GPU when one is available
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Echo the resolved locations so a wrong working directory is easy to spot
print(f"\nProject root: {project_root}")
print(f"Data directory: {data_dir}")
print(f"Models directory: {models_dir}")
print(f"Explainability directory: {explainability_dir}")
print(f"Paper figures directory: {paper_figures_dir}")
print(f"Generated figures directory: {generated_figures_dir}")


ModuleNotFoundError: No module named 'shap'

## Load Test Data and Model

Load the test data and the fine-tuned ALBERT-V2 model (as per the paper):


In [5]:
# Load test data
def load_test_data(data_dir=None):
    """
    Load the preprocessed test split.

    Parameters:
    -----------
    data_dir : str or Path, optional
        Directory containing the `splits/test.csv` file. Plain strings are
        accepted and converted to `Path`. When None, `project_root / 'data'`
        is used.

    Returns:
    --------
    test_data : pd.DataFrame
        Contents of `splits/test.csv`. A `label` column is required, because
        the label distribution is printed.

    Raises:
    -------
    FileNotFoundError
        If `splits/test.csv` does not exist under `data_dir`.
    """
    # Coerce to Path so that string paths work with the `/` operator below
    data_dir = project_root / 'data' if data_dir is None else Path(data_dir)

    test_path = data_dir / 'splits' / 'test.csv'

    if not test_path.exists():
        raise FileNotFoundError(
            f"Test data not found. Please run 01_Data_Loading_Preprocessing.ipynb first.\n"
            f"Expected file: {test_path}"
        )

    test_data = pd.read_csv(test_path)
    print(f"Loaded test data: {len(test_data)} examples")
    print(f"\nTest label distribution:")
    print(test_data['label'].value_counts().sort_index())

    return test_data

# Load model
def load_model_for_explainability(model_dir, device='cpu'):
    """
    Load a fine-tuned sequence classifier and its tokenizer for explanation runs.

    Parameters:
    -----------
    model_dir : str
        Directory (or identifier) passed to `from_pretrained`
    device : str
        Device the model is moved to

    Returns:
    --------
    (model, tokenizer) : tuple
        The model, switched to evaluation mode on `device`, and its tokenizer
    """
    print(f"\nLoading model from: {model_dir}")

    # Binary classification head (2 labels).
    # NOTE(review): `ignore_mismatched_sizes=True` presumably lets a checkpoint
    # with a differently sized head load without error — confirm the fine-tuned
    # head is really the one being loaded.
    model = AutoModelForSequenceClassification.from_pretrained(
        model_dir,
        num_labels=2,
        ignore_mismatched_sizes=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_dir)

    # Inference only: place on the target device and switch to eval mode
    model = model.to(device)
    model.eval()

    print(f"Model loaded successfully on {device}")
    return model, tokenizer

# Read the held-out test split
test_data = load_test_data(data_dir)

# Locate the fine-tuned ALBERT-V2 checkpoint (the model used in the paper);
# fail early with a pointer to the training notebook when it is missing.
albert_model_dir = models_dir / 'job_descriptions' / 'albert_albert-base-v2'
if albert_model_dir.exists():
    model, tokenizer = load_model_for_explainability(str(albert_model_dir), device=device)
else:
    raise FileNotFoundError(
        f"ALBERT-V2 model not found. Please run 02_Model_Training.ipynb first.\n"
        f"Expected directory: {albert_model_dir}"
    )


NameError: name 'data_dir' is not defined

## Custom Regex Tokenizer

To keep token boundaries consistent between SHAP and LIME, we define a custom regex tokenizer, as described in the paper:


In [None]:
def custom_regex_tokenizer(text: str) -> List[str]:
    """
    Split text into word tokens and single punctuation characters.

    Runs of word characters become one token each; every character that is
    neither a word character nor whitespace becomes its own token.
    Whitespace is discarded.
    """
    word_or_symbol = r'\b\w+\b|[^\w\s]'
    return re.findall(word_or_symbol, text)

# Smoke-test the tokenizer on one sentence and print the resulting tokens.
# `sample_text` and `tokens` are module-level names left behind for inspection.
sample_text = "We are looking for a strong leader who can manage teams effectively."
tokens = custom_regex_tokenizer(sample_text)
print(f"Sample text: {sample_text}")
print(f"Tokens: {tokens}")
print(f"Number of tokens: {len(tokens)}")


## Model Prediction Function

Wrapper function for model predictions that will be used by SHAP and LIME:


In [None]:
def model_predict(
    texts: List[str],
    model,
    tokenizer,
    device='cpu',
    batch_size: int = 32
) -> np.ndarray:
    """
    Predict probability of positive class (biased) for a list of texts.

    Texts are scored in mini-batches so that large inputs (for example the
    thousands of perturbed samples LIME generates) do not have to fit into a
    single forward pass.

    Parameters:
    -----------
    texts : List[str]
        Input texts. A single string is treated as a one-element list, and a
        NumPy array of strings is converted to a plain list.
    model : AutoModelForSequenceClassification
        Fine-tuned model
    tokenizer : AutoTokenizer
        Tokenizer
    device : str
        Device to run inference on
    batch_size : int
        Number of texts per forward pass (default: 32)

    Returns:
    --------
    probs : np.ndarray
        Array of probabilities for positive class (shape: [len(texts)])

    Raises:
    -------
    ValueError
        If `batch_size` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    if isinstance(texts, str):
        texts = [texts]
    elif isinstance(texts, np.ndarray):
        # Explainers may pass NumPy string arrays; tokenizers expect a list
        texts = texts.tolist()
    else:
        texts = list(texts)

    # Nothing to score: return an empty vector instead of calling the tokenizer
    if not texts:
        return np.empty(0, dtype=np.float32)

    batch_probs = []
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            # Tokenize one mini-batch (padded to its longest member)
            encoded = tokenizer(
                texts[start:start + batch_size],
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors='pt'
            )

            # Move to device
            encoded = {k: v.to(device) for k, v in encoded.items()}

            # Softmax over the class logits; column 1 is the positive class
            logits = model(**encoded).logits
            probs = F.softmax(logits, dim=-1)
            batch_probs.append(probs[:, 1].cpu().numpy())

    return np.concatenate(batch_probs)


## Test Model Prediction Function

Test the prediction function after model is loaded:


In [None]:
# Smoke-test the prediction function on two hand-written sentences.
# Guarded so the cell prints a hint instead of failing when the model-loading
# cell has not been run yet (i.e. `model` / `tokenizer` are not defined).
if 'model' in globals() and 'tokenizer' in globals():
    sample_texts = ["We are looking for a strong leader.", "The team needs a collaborative member."]
    probs = model_predict(sample_texts, model, tokenizer, device=device)
    print(f"Sample predictions:")
    for text, prob in zip(sample_texts, probs):
        print(f"  Text: {text}")
        print(f"  Probability (biased): {prob:.4f}")
else:
    print("Note: Model and tokenizer not loaded yet. Please run Cell 3 first to load the model.")


## SHAP Token-Level Importance

Calculate SHAP values for token-level importance. According to the paper, SHAP values are calculated using:

$$\phi_{ij} = \sum_{S \subseteq N_i \setminus \{j\}} \frac{|S|!(|N_i| - |S| - 1)!}{|N_i|!} [f_i(S \cup \{j\}) - f_i(S)]$$

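where $N_i$ is the set of tokens in instance $i$, $f_i(S)$ is the model output when only the tokens in $S$ are present, and $\phi_i = (\phi_{i1}, \phi_{i2}, \ldots, \phi_{i|N_i|})$ is the SHAP vector for instance $i$.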


In [None]:
def calculate_shap_values(
    texts: List[str],
    model,
    tokenizer,
    device='cpu',
    max_evals: int = 100,
    batch_size: int = 10
) -> List[np.ndarray]:
    """
    Calculate SHAP values for a list of texts.

    Parameters:
    -----------
    texts : List[str]
        List of input texts
    model : AutoModelForSequenceClassification
        Fine-tuned model
    tokenizer : AutoTokenizer
        Tokenizer. It is also handed to `shap.Explainer` as the masker.
    device : str
        Device to run inference on
    max_evals : int
        Maximum number of evaluations for SHAP (default: 100)
    batch_size : int
        Currently unused; kept so existing callers keep working

    Returns:
    --------
    shap_values_list : List[np.ndarray]
        List of 1-D SHAP value arrays for the positive class, one per text
    """
    def predict_wrapper(texts_input):
        """Adapt explainer inputs to the list of strings `model_predict` expects"""
        if isinstance(texts_input, str):
            texts_input = [texts_input]
        elif isinstance(texts_input, np.ndarray):
            # Masked samples may arrive as a NumPy string array
            texts_input = texts_input.tolist()
        else:
            texts_input = list(texts_input)
        return model_predict(texts_input, model, tokenizer, device)

    # The model tokenizer (not custom_regex_tokenizer) defines the SHAP features.
    # NOTE(review): attributions therefore appear to be per model-tokenizer token,
    # which may not line up one-to-one with the regex tokens used for LIME — confirm.
    # NOTE(review): `predict_wrapper` returns one probability per text while two
    # output names are supplied — confirm this is what SHAP expects.
    explainer = shap.Explainer(predict_wrapper, tokenizer, output_names=['Non-Biased', 'Biased'])

    shap_values_list = []

    print(f"Calculating SHAP values for {len(texts)} texts...")
    for i, text in enumerate(texts):
        if (i + 1) % 10 == 0:
            print(f"  Processing text {i+1}/{len(texts)}")

        explanation = explainer([text], max_evals=max_evals)

        # One text was explained, so take the first row. This also unwraps
        # the case where the values come back as a 1-D object array.
        raw = np.asarray(explanation.values)
        if raw.ndim > 1 or raw.dtype == object:
            raw = raw[0]
        values = np.asarray(raw, dtype=float)

        # With one column per class, keep the positive-class column
        if values.ndim > 1:
            values = values[:, 1] if values.shape[1] > 1 else values[:, 0]

        shap_values_list.append(values)

    print(f"Completed SHAP calculation for {len(texts)} texts")
    return shap_values_list

# Smoke-test SHAP on the first three test texts with a reduced evaluation budget.
# `sample_texts` is reused by the LIME smoke test further down.
print("Testing SHAP calculation on a small sample...")
sample_texts = test_data['text'].head(3).tolist()
shap_values_sample = calculate_shap_values(sample_texts, model, tokenizer, device=device, max_evals=50)
print(f"\nSHAP values calculated for {len(shap_values_sample)} texts")
for i, (text, shap_vals) in enumerate(zip(sample_texts, shap_values_sample)):
    print(f"\nText {i+1}: {text[:50]}...")
    print(f"  SHAP values shape: {shap_vals.shape}")
    # Prints token *indices* (not token strings), largest |value| first
    print(f"  Top 3 tokens by absolute SHAP value: {np.argsort(np.abs(shap_vals))[-3:][::-1]}")


In [None]:
def calculate_lime_values(
    texts: List[str],
    model,
    tokenizer,
    device='cpu',
    num_features: int = 20,
    num_samples: int = 5000
) -> List[np.ndarray]:
    """
    Calculate LIME values for a list of texts.

    Parameters:
    -----------
    texts : List[str]
        List of input texts
    model : AutoModelForSequenceClassification
        Fine-tuned model
    tokenizer : AutoTokenizer
        Tokenizer
    device : str
        Device to run inference on
    num_features : int
        Number of top features to explain (default: 20). Tokens that LIME
        does not report keep a value of 0.0.
    num_samples : int
        Number of samples for LIME (default: 5000)

    Returns:
    --------
    lime_values_list : List[np.ndarray]
        List of LIME value arrays, one per text. Each array has one entry per
        token produced by `custom_regex_tokenizer`, in token order. A text
        without any tokens yields an empty array.
    """
    # LIME accepts any callable as `split_expression`
    explainer = LimeTextExplainer(
        class_names=['Non-Biased', 'Biased'],
        split_expression=custom_regex_tokenizer
    )

    # LIME scores all perturbed samples in one classifier call; feed the model
    # in slices so `num_samples` texts never go through a single forward pass.
    predict_chunk = 64

    def predict_proba_wrapper(texts_input):
        """Return an [n, 2] array of [prob_non_biased, prob_biased] rows"""
        if isinstance(texts_input, str):
            texts_input = [texts_input]
        texts_input = list(texts_input)
        if not texts_input:
            return np.empty((0, 2))
        probs = np.concatenate([
            model_predict(texts_input[start:start + predict_chunk], model, tokenizer, device)
            for start in range(0, len(texts_input), predict_chunk)
        ])
        return np.column_stack([1 - probs, probs])

    lime_values_list = []

    print(f"Calculating LIME values for {len(texts)} texts...")
    for i, text in enumerate(texts):
        if (i + 1) % 10 == 0:
            print(f"  Processing text {i+1}/{len(texts)}")

        tokens = custom_regex_tokenizer(text)
        if not tokens:
            # Nothing to perturb; keep positions aligned with `texts`
            lime_values_list.append(np.zeros(0))
            continue

        explanation = explainer.explain_instance(
            text,
            predict_proba_wrapper,
            num_features=num_features,
            num_samples=num_samples
        )

        # Start every token at zero, then fill in what LIME reported
        token_scores = dict.fromkeys(tokens, 0.0)
        for feature, score in explanation.as_list():
            feature_tokens = custom_regex_tokenizer(feature)
            if not feature_tokens:
                continue
            # A feature that splits into several tokens shares its score evenly
            share = score / len(feature_tokens)
            for token in feature_tokens:
                if token in token_scores:
                    token_scores[token] += share

        # Repeated tokens share one score, so each occurrence gets the same value
        lime_values_list.append(np.array([token_scores[token] for token in tokens]))

    print(f"Completed LIME calculation for {len(texts)} texts")
    return lime_values_list

# Smoke-test LIME on the same `sample_texts` used for the SHAP smoke test,
# with a reduced number of perturbation samples.
print("Testing LIME calculation on a small sample...")
lime_values_sample = calculate_lime_values(sample_texts, model, tokenizer, device=device, num_samples=1000)
print(f"\nLIME values calculated for {len(lime_values_sample)} texts")
for i, (text, lime_vals) in enumerate(zip(sample_texts, lime_values_sample)):
    print(f"\nText {i+1}: {text[:50]}...")
    print(f"  LIME values shape: {lime_vals.shape}")
    # Prints token *indices* (not token strings), largest |value| first
    print(f"  Top 3 tokens by absolute LIME value: {np.argsort(np.abs(lime_vals))[-3:][::-1]}")


In [None]:
def cosine_similarity(vec1: np.ndarray, vec2: np.ndarray) -> float:
    """Return the cosine of the angle between two vectors (0.0 if either has zero norm)"""
    numerator = np.dot(vec1, vec2)
    magnitude1 = np.linalg.norm(vec1)
    magnitude2 = np.linalg.norm(vec2)

    # Only divide when both vectors have a non-zero length
    if magnitude1 != 0 and magnitude2 != 0:
        return numerator / (magnitude1 * magnitude2)

    return 0.0


def pearson_correlation(vec1: np.ndarray, vec2: np.ndarray) -> float:
    """Return the Pearson correlation of two vectors, or 0.0 when it is undefined"""
    size = len(vec1)
    # Needs two equally long vectors with at least two entries each
    if size < 2 or size != len(vec2):
        return 0.0

    coefficient, _p_value = pearsonr(vec1, vec2)
    if np.isnan(coefficient):
        return 0.0
    return coefficient


def jensen_shannon_divergence(vec1: np.ndarray, vec2: np.ndarray) -> float:
    """
    Calculate Jensen-Shannon Divergence between two vectors.

    As per the paper:
    P_j = (φ_{ij} + |Min(φ_i)|) / Σ(φ_{ij} + |Min(φ_i)|)
    Q_j = (β_{ij} + |Min(β_i)|) / Σ(β_{ij} + |Min(β_i)|)
    JSD = sqrt(0.5 * Σ P_j log(P_j / (P_j/2 + Q_j/2)) + 0.5 * Σ Q_j log(Q_j / (P_j/2 + Q_j/2)))

    The natural logarithm is used.

    Parameters:
    -----------
    vec1, vec2 : array-like
        Importance vectors of equal shape (NumPy arrays or plain sequences)

    Returns:
    --------
    jsd : float
        The divergence; 0.0 when both vectors are empty

    Raises:
    -------
    ValueError
        If the two vectors do not have the same shape.
    """
    vec1 = np.asarray(vec1, dtype=float)
    vec2 = np.asarray(vec2, dtype=float)

    if vec1.shape != vec2.shape:
        raise ValueError(
            f"Vectors must have the same shape, got {vec1.shape} and {vec2.shape}"
        )
    # Degenerate case: nothing to compare (np.min would fail on empty input)
    if vec1.size == 0:
        return 0.0

    def _to_distribution(values):
        """Shift by |min| and normalize to sum 1; uniform if everything is zero"""
        shifted = values + abs(np.min(values))
        total = np.sum(shifted)
        if total == 0:
            return np.full(values.shape, 1.0 / values.size)
        return shifted / total

    P = _to_distribution(vec1)
    Q = _to_distribution(vec2)
    M = (P + Q) / 2

    # Avoid log(0)
    epsilon = 1e-10
    P = np.clip(P, epsilon, 1)
    Q = np.clip(Q, epsilon, 1)
    M = np.clip(M, epsilon, 1)

    kl_pm = np.sum(P * np.log(P / M))
    kl_qm = np.sum(Q * np.log(Q / M))

    # Rounding can push the sum slightly below zero; clamp before the sqrt
    divergence = 0.5 * kl_pm + 0.5 * kl_qm
    return np.sqrt(max(divergence, 0.0))


def calculate_similarity_metrics(
    shap_values: np.ndarray,
    lime_values: np.ndarray
) -> Dict[str, float]:
    """
    Compare a SHAP vector and a LIME vector with three similarity measures.

    Parameters:
    -----------
    shap_values : np.ndarray
        SHAP values for a text instance
    lime_values : np.ndarray
        LIME values for the same text instance

    Returns:
    --------
    metrics : Dict[str, float]
        Keys 'cosine_similarity', 'pearson_correlation' and
        'jensen_shannon_divergence'
    """
    # When the lengths differ, both vectors are cut to the shorter one
    # (the trailing entries of the longer vector are dropped, nothing is padded)
    common_len = min(len(shap_values), len(lime_values))
    shap_aligned, lime_aligned = shap_values[:common_len], lime_values[:common_len]

    return {
        'cosine_similarity': cosine_similarity(shap_aligned, lime_aligned),
        'pearson_correlation': pearson_correlation(shap_aligned, lime_aligned),
        'jensen_shannon_divergence': jensen_shannon_divergence(shap_aligned, lime_aligned),
    }

# Smoke-test the similarity metrics on the first SHAP/LIME sample pair
# produced by the two smoke tests above.
print("Testing similarity metrics...")
if len(shap_values_sample) > 0 and len(lime_values_sample) > 0:
    test_metrics = calculate_similarity_metrics(shap_values_sample[0], lime_values_sample[0])
    print(f"\nSimilarity metrics for first sample:")
    for metric_name, value in test_metrics.items():
        print(f"  {metric_name}: {value:.4f}")


In [None]:
def run_explainability_analysis(
    test_data: pd.DataFrame,
    model,
    tokenizer,
    device='cpu',
    sample_size: Optional[int] = None,
    max_evals_shap: int = 100,
    num_samples_lime: int = 5000
) -> pd.DataFrame:
    """
    Run complete explainability analysis on test data.

    Steps: model predictions, SHAP values, LIME values, then the three
    SHAP-vs-LIME similarity metrics per example.

    Parameters:
    -----------
    test_data : pd.DataFrame
        Test data with 'text' and 'label' columns
    model : AutoModelForSequenceClassification
        Fine-tuned model
    tokenizer : AutoTokenizer
        Tokenizer
    device : str
        Device to run inference on
    sample_size : Optional[int]
        Number of samples to analyze (None for all). Sampling uses
        random_state=42 and only happens when `sample_size` is smaller than
        the number of rows.
    max_evals_shap : int
        Maximum evaluations for SHAP
    num_samples_lime : int
        Number of samples for LIME

    Returns:
    --------
    results_df : pd.DataFrame
        DataFrame with SHAP values, LIME values, predictions, and similarity metrics
    """
    # Sample data if specified
    if sample_size is not None and sample_size < len(test_data):
        test_data = test_data.sample(n=sample_size, random_state=42).reset_index(drop=True)
        print(f"Sampled {sample_size} examples for analysis")

    texts = test_data['text'].tolist()
    true_labels = test_data['label'].values

    # Get predictions
    print("\n" + "="*60)
    print("Step 1: Getting model predictions...")
    print("="*60)
    # Predict in slices so the whole sample never goes through one forward pass
    prediction_chunk = 32
    if texts:
        predictions = np.concatenate([
            model_predict(texts[start:start + prediction_chunk], model, tokenizer, device)
            for start in range(0, len(texts), prediction_chunk)
        ])
    else:
        predictions = np.empty(0, dtype=np.float32)
    # A probability above 0.5 counts as the positive (biased) class
    predicted_labels = (predictions > 0.5).astype(int)

    # Identify correct and incorrect predictions
    correct_mask = (predicted_labels == true_labels)
    incorrect_mask = ~correct_mask

    print(f"\nPredictions summary:")
    print(f"  Total examples: {len(texts)}")
    print(f"  Correct predictions: {correct_mask.sum()}")
    print(f"  Incorrect predictions: {incorrect_mask.sum()}")

    # Calculate SHAP values
    print("\n" + "="*60)
    print("Step 2: Calculating SHAP values...")
    print("="*60)
    shap_values_list = calculate_shap_values(
        texts, model, tokenizer, device, max_evals=max_evals_shap
    )

    # Calculate LIME values
    print("\n" + "="*60)
    print("Step 3: Calculating LIME values...")
    print("="*60)
    lime_values_list = calculate_lime_values(
        texts, model, tokenizer, device, num_samples=num_samples_lime
    )

    # Calculate similarity metrics
    print("\n" + "="*60)
    print("Step 4: Calculating similarity metrics...")
    print("="*60)
    similarity_metrics_list = []
    for i, (shap_vals, lime_vals) in enumerate(zip(shap_values_list, lime_values_list)):
        if (i + 1) % 100 == 0:
            print(f"  Processing {i+1}/{len(texts)}")
        similarity_metrics_list.append(calculate_similarity_metrics(shap_vals, lime_vals))

    # Create results DataFrame (one row per analyzed text)
    results_df = pd.DataFrame({
        'text': texts,
        'true_label': true_labels,
        'predicted_label': predicted_labels,
        'predicted_probability': predictions,
        'correct_prediction': correct_mask,
        'shap_values': shap_values_list,
        'lime_values': lime_values_list,
        'cosine_similarity': [m['cosine_similarity'] for m in similarity_metrics_list],
        'pearson_correlation': [m['pearson_correlation'] for m in similarity_metrics_list],
        'jensen_shannon_divergence': [m['jensen_shannon_divergence'] for m in similarity_metrics_list]
    })

    print("\n" + "="*60)
    print("Analysis Complete!")
    print("="*60)
    print(f"\nSimilarity metrics summary:")
    print(f"  Cosine Similarity - Mean: {results_df['cosine_similarity'].mean():.4f}, Std: {results_df['cosine_similarity'].std():.4f}")
    print(f"  Pearson Correlation - Mean: {results_df['pearson_correlation'].mean():.4f}, Std: {results_df['pearson_correlation'].std():.4f}")
    print(f"  Jensen-Shannon Divergence - Mean: {results_df['jensen_shannon_divergence'].mean():.4f}, Std: {results_df['jensen_shannon_divergence'].std():.4f}")

    return results_df

# Run the full pipeline on a small random sample of the test set.
# The paper uses 1005 examples; the reduced settings below trade fidelity
# for speed during initial testing.
print("Starting explainability analysis...")
print("Note: This may take a while. Using a small sample for initial testing.")
results_df = run_explainability_analysis(
    test_data,
    model,
    tokenizer,
    device=device,
    sample_size=50,  # Start with 50 for testing, increase to 1005 for full analysis
    max_evals_shap=50,  # Reduced for faster testing
    num_samples_lime=1000  # Reduced for faster testing
)


## Token Rankings for Correct and Incorrect Predictions

Generate token rankings based on SHAP and LIME values for correct and incorrect predictions:


In [None]:
def get_token_rankings(
    results_df: pd.DataFrame,
    top_k: int = 10
) -> Dict[str, pd.DataFrame]:
    """
    Build per-example top-token tables, split by prediction outcome and method.

    Parameters:
    -----------
    results_df : pd.DataFrame
        Results from explainability analysis
    top_k : int
        Number of top tokens to extract

    Returns:
    --------
    rankings : Dict[str, pd.DataFrame]
        Up to four tables keyed 'correct_shap', 'correct_lime',
        'incorrect_shap' and 'incorrect_lime'. An outcome with no examples
        contributes no keys.
    """
    def _top_tokens(row, value_column):
        """Return (tokens, values) of the top-k entries by absolute value"""
        row_tokens = custom_regex_tokenizer(row['text'])
        scores = row[value_column]

        # Cut both sequences to the shorter length before ranking
        usable = min(len(row_tokens), len(scores))
        row_tokens = row_tokens[:usable]
        scores = scores[:usable]

        # Largest absolute value first
        order = np.argsort(np.abs(scores))[-top_k:][::-1]
        return [row_tokens[i] for i in order], [scores[i] for i in order]

    rankings = {}

    for outcome, is_correct in (('correct', True), ('incorrect', False)):
        subset = results_df[results_df['correct_prediction'] == is_correct].copy()
        if len(subset) == 0:
            continue

        for method in ('shap', 'lime'):
            extracted = [
                _top_tokens(row, f'{method}_values')
                for _, row in subset.iterrows()
            ]
            rankings[f'{outcome}_{method}'] = pd.DataFrame({
                'text': subset['text'].values,
                'true_label': subset['true_label'].values,
                'predicted_label': subset['predicted_label'].values,
                'top_tokens': [pair[0] for pair in extracted],
                'top_values': [pair[1] for pair in extracted]
            })

    return rankings

# Build the ranking tables from the analysis results
print("Extracting token rankings...")
rankings = get_token_rankings(results_df, top_k=10)

# Report how many examples each table holds
print(f"\nRankings extracted:")
for key in rankings.keys():
    print(f"  {key}: {len(rankings[key])} examples")

# Show the first example of the SHAP tables for each prediction outcome
for _ranking_key, _headline in (
    ('correct_shap', "Sample: Top tokens for a correct prediction (SHAP)"),
    ('incorrect_shap', "Sample: Top tokens for an incorrect prediction (SHAP)"),
):
    if _ranking_key not in rankings or len(rankings[_ranking_key]) == 0:
        continue
    print("\n" + "="*60)
    print(_headline)
    print("="*60)
    sample_idx = 0
    sample = rankings[_ranking_key].iloc[sample_idx]
    print(f"Text: {sample['text'][:100]}...")
    print(f"True label: {sample['true_label']}, Predicted: {sample['predicted_label']}")
    print(f"Top tokens: {sample['top_tokens']}")
    print(f"Top values: {[f'{v:.4f}' for v in sample['top_values']]}")


In [None]:
def plot_similarity_metrics(results_df: pd.DataFrame, save_path: Optional[Path] = None):
    """Draw one histogram per similarity metric, with the mean marked by a dashed line"""
    panels = (
        ('cosine_similarity', 'Cosine Similarity'),
        ('pearson_correlation', 'Pearson Correlation'),
        ('jensen_shannon_divergence', 'Jensen-Shannon Divergence'),
    )

    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    for ax, (column, label) in zip(axes, panels):
        series = results_df[column]
        ax.hist(series, bins=30, edgecolor='black', alpha=0.7)
        ax.axvline(series.mean(), color='red', linestyle='--',
                   label=f'Mean: {series.mean():.4f}')
        ax.set_xlabel(label)
        ax.set_ylabel('Frequency')
        ax.set_title(f'{label} Distribution')
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()

    # Write the figure to disk only when a target path was given
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Saved figure to: {save_path}")

    plt.show()


def plot_correct_vs_incorrect_similarity(results_df: pd.DataFrame, save_path: Optional[Path] = None):
    """Draw side-by-side box plots of each similarity metric for correct and incorrect predictions"""
    panels = (
        ('cosine_similarity', 'Cosine Similarity'),
        ('pearson_correlation', 'Pearson Correlation'),
        ('jensen_shannon_divergence', 'Jensen-Shannon Divergence'),
    )

    # Split the rows by prediction outcome
    hits = results_df[results_df['correct_prediction'] == True]
    misses = results_df[results_df['correct_prediction'] == False]

    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    for ax, (column, label) in zip(axes, panels):
        # Missing values are dropped before plotting
        groups = [hits[column].dropna(), misses[column].dropna()]
        ax.boxplot(groups, labels=['Correct', 'Incorrect'])
        ax.set_ylabel(label)
        ax.set_title(f'{label}: Correct vs Incorrect')
        ax.grid(True, alpha=0.3)

    plt.tight_layout()

    # Write the figure to disk only when a target path was given
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Saved figure to: {save_path}")

    plt.show()


def plot_token_importance_example(
    text: str,
    shap_values: np.ndarray,
    lime_values: np.ndarray,
    top_k: int = 10,
    save_path: Optional[Path] = None
):
    """
    Plot SHAP and LIME values side by side for the top tokens of one example.

    The tokens shown in BOTH panels are the `top_k` tokens with the largest
    absolute SHAP value; the right panel shows the LIME values of those same
    tokens. The highest-ranked token is drawn at the top.

    Parameters:
    -----------
    text : str
        The example text (tokenized with `custom_regex_tokenizer`)
    shap_values : np.ndarray
        SHAP values for the text
    lime_values : np.ndarray
        LIME values for the text
    top_k : int
        Number of tokens to show
    save_path : Optional[Path]
        Where to save the figure; nothing is written when omitted
    """
    tokens = custom_regex_tokenizer(text)

    # Cut tokens and both value vectors to the shortest common length
    min_len = min(len(tokens), len(shap_values), len(lime_values))
    tokens = tokens[:min_len]
    shap_aligned = shap_values[:min_len]
    lime_aligned = lime_values[:min_len]

    # Rank by absolute SHAP value, largest first
    top_indices = np.argsort(np.abs(shap_aligned))[-top_k:][::-1]
    top_tokens = [tokens[i] for i in top_indices]

    # The LIME title states that the tokens are SHAP-ranked, since the
    # selection above does not look at the LIME values at all
    panels = (
        ('SHAP Value', f'Top {top_k} Tokens by SHAP Importance',
         [shap_aligned[i] for i in top_indices]),
        ('LIME Value', f'LIME Values for Top {top_k} SHAP-Ranked Tokens',
         [lime_aligned[i] for i in top_indices]),
    )

    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    positions = list(range(len(top_tokens)))

    for ax, (xlabel, title, bar_values) in zip(axes, panels):
        # Negative contributions in red, non-negative in green
        bar_colors = ['red' if v < 0 else 'green' for v in bar_values]
        ax.barh(positions, bar_values, color=bar_colors, alpha=0.7)
        ax.set_yticks(positions)
        ax.set_yticklabels(top_tokens)
        ax.set_xlabel(xlabel)
        ax.set_title(title)
        ax.grid(True, alpha=0.3, axis='x')
        ax.axvline(0, color='black', linestyle='-', linewidth=0.5)
        # barh puts index 0 at the bottom; flip so rank 1 is on top
        ax.invert_yaxis()

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Saved figure to: {save_path}")

    plt.show()

# Render and save the three figures into the generated-figures directory
print("Creating visualizations...")

# Figure 1: distribution of each similarity metric
plot_similarity_metrics(results_df, save_path=generated_figures_dir / 'similarity_metrics_distribution.png')

# Figure 2: the same metrics, split by prediction outcome
plot_correct_vs_incorrect_similarity(results_df, save_path=generated_figures_dir / 'similarity_correct_vs_incorrect.png')

# Figure 3: token-level view of the first analyzed example (skipped when there are no results)
if len(results_df) > 0:
    sample_row = results_df.iloc[0]
    plot_token_importance_example(
        text=sample_row['text'],
        shap_values=sample_row['shap_values'],
        lime_values=sample_row['lime_values'],
        top_k=10,
        save_path=generated_figures_dir / 'token_importance_example.png',
    )


## Save Results

Save the explainability analysis results:


In [None]:
# Output location for this analysis
results_output_dir = results_dir / 'job_descriptions' / 'explainability'
os.makedirs(results_output_dir, exist_ok=True)

# Two copies of the results: one with the value arrays as plain lists, and one
# for the CSV in which the arrays are stored as their string representation
results_to_save = results_df.copy()
results_csv = results_df.copy()
for _array_column in ('shap_values', 'lime_values'):
    results_to_save[_array_column] = results_to_save[_array_column].apply(lambda arr: arr.tolist())
    results_csv[_array_column] = results_csv[_array_column].apply(lambda arr: str(arr.tolist()))
results_csv.to_csv(results_output_dir / 'explainability_results.csv', index=False)

# Pickle the untouched DataFrame (arrays intact), the ranking tables and a
# mean/std summary of every similarity metric
with open(results_output_dir / 'explainability_results_full.pkl', 'wb') as f:
    pickle.dump({
        'results_df': results_df,
        'rankings': rankings,
        'similarity_summary': {
            metric: {
                'mean': float(results_df[metric].mean()),
                'std': float(results_df[metric].std())
            }
            for metric in ('cosine_similarity', 'pearson_correlation', 'jensen_shannon_divergence')
        }
    }, f)

# One CSV per ranking table
for key, ranking_df in rankings.items():
    ranking_df.to_csv(results_output_dir / f'token_rankings_{key}.csv', index=False)

# Summarize what was written
print(f"\nResults saved to: {results_output_dir}")
print(f"  - explainability_results.csv")
print(f"  - explainability_results_full.pkl")
for key in rankings.keys():
    print(f"  - token_rankings_{key}.csv")

print(f"\nFigures saved to: {generated_figures_dir}")
print(f"  - similarity_metrics_distribution.png")
print(f"  - similarity_correct_vs_incorrect.png")
print(f"  - token_importance_example.png")
