# Sentiment Classification Interpretability (Library-based)

This notebook demonstrates interpretability techniques using standard libraries:
1. LIME
2. SHAP
3. Integrated Gradients (IG) — used instead of Layer-wise Relevance Propagation (LRP) because its attributions are more stable

In [None]:
!pip install transformers torch pandas matplotlib numpy lime shap -q

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import json
from transformers import RobertaTokenizer, AutoModelForSequenceClassification
import lime
from lime.lime_text import LimeTextExplainer
import shap
import warnings
warnings.filterwarnings("ignore")

## Load Model

In [None]:
# Directory holding the saved model, its tokenizer files and label_mappings.json.
model_path = './best_roberta_model'

# Prefer the GPU when one is visible, otherwise run on CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

tokenizer = RobertaTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)
model = model.to(device)
model.eval()  # inference mode (disables dropout)

with open(f'{model_path}/label_mappings.json', 'r') as mapping_file:
    label_mappings = json.load(mapping_file)

label_list = label_mappings['label_list']
label2id = label_mappings['label2id']

# JSON object keys are always strings; convert the class ids back to int so
# they can be looked up with the model's argmax index.
id2label = {}
for raw_id, label_name in label_mappings['id2label'].items():
    id2label[int(raw_id)] = label_name

def predict_proba(texts, batch_size=64, max_length=64):
    """Return softmax class probabilities for one or more texts.

    LIME and SHAP both call this with large batches of perturbed texts, so the
    inputs are run through the model in chunks of ``batch_size``. This avoids
    having to fit e.g. LIME's 1000 samples in a single forward pass.

    Args:
        texts: a single string, a list of strings, or a NumPy array of strings.
        batch_size: maximum number of texts per forward pass (must be >= 1).
        max_length: truncation length in tokens.

    Returns:
        np.ndarray of shape (len(texts), num_labels) with probabilities per class.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")

    if isinstance(texts, str):
        texts = [texts]
    elif isinstance(texts, np.ndarray):
        texts = texts.tolist()
    texts = [str(t) for t in texts]

    if not texts:
        # The tokenizer cannot encode an empty batch; return an empty result instead.
        return np.zeros((0, model.config.num_labels), dtype=np.float32)

    chunks = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        inputs = tokenizer(batch, return_tensors='pt', padding=True,
                           truncation=True, max_length=max_length)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            logits = model(**inputs).logits
        chunks.append(F.softmax(logits, dim=-1).cpu().numpy())
    return np.concatenate(chunks, axis=0)

## Test Samples

In [None]:
# Short, hand-written review sentences with strongly positive or negative wording
# (movies, products, service, venues), used as probes for all three explainers.
test_samples = [
    "This movie is absolutely fantastic! I loved every moment of it.",
    "Terrible experience, waste of time and money. Very disappointed.",
    "The product is perfect, it exceeded all my expectations.",
    "I'm extremely happy with this purchase! Highly recommend!",
    "This is the worst service I've ever encountered. Absolutely horrible.",
    "Bad staff. Never coming back.",
    "Outstanding quality! Best investment I ever made.",
    "The acting was pathetic and the plot was nonsense.",
    "I adore this place, it's magical and wonderful.",
    "Brilliant work! Truly exceptional and inspiring."
]

## LIME

In [None]:
# LIME samples random perturbations of the input; fixing random_state makes the
# reported word weights reproducible across Restart & Run All.
lime_explainer = LimeTextExplainer(class_names=label_list, random_state=42)

def explain_with_lime(text, num_features=5, num_samples=1000):
    """Explain the model's predicted class for ``text`` with LIME.

    The predicted class is taken from ``predict_proba`` and only that label is
    explained. ``lime_explainer`` generates ``num_samples`` perturbed versions of
    ``text``, scores them with ``predict_proba`` and fits a local surrogate.

    Args:
        text: input sentence.
        num_features: maximum number of words LIME reports.
        num_samples: number of perturbations LIME generates. More samples give
            steadier weights at the cost of more model calls.

    Returns:
        (importance, pred_class): ``importance`` is a list of (word, weight)
        pairs sorted by decreasing |weight|; ``pred_class`` is the predicted
        class index as a Python int.
    """
    probs = predict_proba([text])[0]
    pred_class = int(np.argmax(probs))

    explanation = lime_explainer.explain_instance(
        text,
        predict_proba,
        num_features=num_features,
        num_samples=num_samples,
        labels=[pred_class],
    )

    importance = sorted(explanation.as_list(label=pred_class),
                        key=lambda item: abs(item[1]), reverse=True)
    return importance, pred_class

## SHAP

In [None]:
# Model-agnostic SHAP explainer over raw strings. Inputs are masked with a Text
# masker built from the model's own tokenizer, and every masked variant is scored
# with predict_proba (softmax probabilities, one column per class).
# NOTE(review): which algorithm shap.Explainer picks for a Text masker (Partition,
# per shap docs) may vary by installed shap version — confirm.
shap_explainer = shap.Explainer(predict_proba, shap.maskers.Text(tokenizer))

def explain_with_shap(text, num_features=5):
    """Explain the model's predicted class for ``text`` with SHAP word scores.

    shap's Text masker reports its tokens as text *segments* (e.g. ``'This '``
    or ``' movie'``; special tokens come back as ``''``), not as raw BPE tokens.
    Words are therefore rebuilt from whitespace:
    - a segment that starts with whitespace, or follows one that ended with
      whitespace, begins a new word;
    - any other segment is glued onto the previous word and its score is added.
    A raw byte-level BPE ``'Ġ'`` marker is treated as a leading space too.
    Repeated words have their scores summed.

    Args:
        text: input sentence.
        num_features: number of words to return.

    Returns:
        (importance, pred_class): up to ``num_features`` (word, score) pairs
        sorted by decreasing |score|, and the predicted class index (int).
    """
    explanation = shap_explainer([text])
    pred_class = int(np.argmax(predict_proba([text])[0]))

    values = explanation.values[0, :, pred_class]
    segments = explanation.data[0]

    words = []          # [word, score] pairs in text order
    at_boundary = True  # the next non-empty segment starts a new word
    for segment, score in zip(segments, values):
        segment = str(segment).replace('Ġ', ' ')
        piece = segment.strip()
        if not piece:
            # Special tokens ('') and whitespace-only segments separate words and
            # do not contribute a word of their own.
            at_boundary = True
            continue
        if at_boundary or segment[0].isspace():
            words.append([piece, float(score)])
        else:
            words[-1][0] += piece
            words[-1][1] += float(score)
        at_boundary = segment[-1].isspace()

    aggregated_features = {}
    for word, score in words:
        aggregated_features[word] = aggregated_features.get(word, 0.0) + score

    importance = sorted(aggregated_features.items(), key=lambda x: abs(x[1]), reverse=True)
    return importance[:num_features], pred_class

## Integrated Gradients (IG)
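Integrated Gradients replaces the earlier LRP / plain-gradient heatmap. It takes the gradients of the predicted-class logit at points along a straight path, which runs from an all-zero embedding baseline to the actual input embeddings. It averages these gradients and multiplies them by the input-minus-baseline difference. This gives more stable token attributions than a single gradient.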

In [None]:
def explain_with_ig(text, num_features=5, n_steps=50):
    """Explain the model's predicted class for ``text`` with Integrated Gradients.

    The input token embeddings are interpolated from an all-zero baseline:
    ``baseline + alpha * (embeddings - baseline)`` for ``alpha = i / n_steps``,
    ``i = 0..n_steps``. For each point, the gradient of the predicted-class logit
    is taken with respect to that point. The ``n_steps + 1`` gradients are
    averaged and multiplied by ``embeddings - baseline``. Summing over the
    embedding dimension gives one score per token. Sub-word tokens are then
    merged into words ('Ġ' marks a word start) and repeated words are summed.

    Args:
        text: input sentence (truncated to 64 tokens).
        num_features: number of words to return.
        n_steps: number of interpolation intervals (must be >= 1).

    Returns:
        (importance, pred_class): up to ``num_features`` (word, score) pairs
        sorted by decreasing |score|, and the predicted class index (int).

    Raises:
        ValueError: if ``n_steps`` is smaller than 1.
    """
    if n_steps < 1:
        raise ValueError("n_steps must be >= 1")

    inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=64)
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    # Compute embeddings without autograd history. Each path point below then
    # becomes its own leaf tensor. Without this, the points are non-leaf views of
    # the embedding lookup: their .grad is never filled in and every backward pass
    # re-enters the shared embedding graph.
    with torch.no_grad():
        embeddings = model.get_input_embeddings()(input_ids)
        logits = model(inputs_embeds=embeddings, attention_mask=attention_mask).logits
    pred_class = int(torch.argmax(logits[0]).item())

    baseline = torch.zeros_like(embeddings)
    delta = embeddings - baseline
    total_gradients = torch.zeros_like(embeddings)
    for step in range(n_steps + 1):
        alpha = float(step) / n_steps
        point = (baseline + alpha * delta).requires_grad_(True)
        score = model(inputs_embeds=point, attention_mask=attention_mask).logits[0, pred_class]
        # autograd.grad returns d(score)/d(point) without accumulating into the
        # model parameters' .grad buffers.
        (point_grad,) = torch.autograd.grad(score, point)
        total_gradients += point_grad

    avg_gradients = total_gradients / (n_steps + 1)
    attributions = (delta * avg_gradients).sum(dim=-1).squeeze(0).detach().cpu().numpy()

    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
    special_tokens = set(getattr(tokenizer, 'all_special_tokens', ())) | {'<s>', '</s>', '<pad>'}

    words = []          # [word, score] pairs in text order
    at_boundary = True  # the next real token starts a new word
    for token, score in zip(tokens, attributions):
        if token in special_tokens:
            at_boundary = True
            continue
        piece = token.replace('Ġ', '').replace('</w>', '')
        if not piece:
            # A bare 'Ġ' (extra whitespace) only separates words.
            at_boundary = True
            continue
        if at_boundary or token.startswith('Ġ'):
            words.append([piece, float(score)])
        else:
            words[-1][0] += piece
            words[-1][1] += float(score)
        # '</w>' (end-of-word BPE marker) closes the current word.
        at_boundary = token.endswith('</w>')

    aggregated_features = {}
    for word, score in words:
        aggregated_features[word] = aggregated_features.get(word, 0.0) + score

    importance = sorted(aggregated_features.items(), key=lambda x: abs(x[1]), reverse=True)
    return importance[:num_features], pred_class

In [None]:
# Run every explainer on every probe sentence and keep the top-5 features of each.
results = []
for sample_no, text in enumerate(test_samples, start=1):
    print(f"\nSample {sample_no}: {text}")

    probs = predict_proba(text)[0]
    pred_class = np.argmax(probs)
    pred_label = id2label[pred_class]
    confidence = probs[pred_class]
    print(f"Prediction: {pred_label} ({confidence:.3f})")

    # Each explainer returns (features, pred_class); only the features are kept.
    explanations = {
        'lime_features': explain_with_lime(text)[0],
        'shap_features': explain_with_shap(text)[0],
        'ig_features': explain_with_ig(text)[0],
    }

    record = {
        'sample_id': sample_no,
        'text': text,
        'prediction': pred_label,
        'confidence': confidence,
    }
    record.update({key: features[:5] for key, features in explanations.items()})
    results.append(record)

In [None]:
def visualize_comparison(sample_result):
    """Plot LIME, SHAP and IG feature scores for one sample side by side.

    ``sample_result`` is one entry of ``results``. The keys read are
    'lime_features', 'shap_features', 'ig_features' (lists of (word, score)
    pairs), 'text', 'sample_id', 'prediction' and 'confidence'. Each panel is a
    horizontal bar chart with the largest |score| at the top. Positive scores
    are green; zero or negative scores are red.
    """
    panels = [
        ('lime_features', 'LIME'),
        ('shap_features', 'SHAP'),
        ('ig_features', 'Integrated Gradients'),
    ]
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    for ax, (feature_key, panel_title) in zip(axes, panels):
        # Rank by descending |score|, then flip so barh draws the strongest feature last (top).
        ranked = sorted(sample_result[feature_key], key=lambda pair: abs(pair[1]), reverse=True)[::-1]
        labels = [word for word, _ in ranked]
        values = [score for _, score in ranked]
        bar_colors = ['#2ecc71' if v > 0 else '#e74c3c' for v in values]
        positions = range(len(labels))

        ax.barh(positions, values, color=bar_colors, alpha=0.7, edgecolor='black', linewidth=0.5)
        ax.set_yticks(positions)
        ax.set_yticklabels(labels, fontsize=12, fontweight='bold')
        ax.set_title(panel_title, fontsize=14, fontweight='bold', pad=10)
        ax.axvline(x=0, color='black', linewidth=1.5, linestyle='--')
        ax.grid(axis='x', alpha=0.3, linestyle=':')
        ax.set_xlabel('Importance Score', fontsize=11)

    full_text = sample_result['text']
    if len(full_text) > 80:
        text_preview = full_text[:80] + '...'
    else:
        text_preview = full_text
    title = f"Sample {sample_result['sample_id']}: \"{text_preview}\"\nPrediction: {sample_result['prediction']} (Confidence: {sample_result['confidence']:.1%})"
    fig.suptitle(title, fontsize=12, fontweight='bold', y=1.0)
    fig.tight_layout()
    plt.show()


for sample_result in results:
    visualize_comparison(sample_result)

In [None]:
def _normalize_word(word):
    """Lower-case ``word`` and strip surrounding punctuation for cross-method comparison."""
    return str(word).lower().strip('.,!?;:"\'')


def _feature_words(features):
    """Return the set of non-empty normalised words in a list of (word, score) pairs."""
    normalized = (_normalize_word(word) for word, _ in features)
    return {word for word in normalized if word}


fig = plt.figure(figsize=(18, 12))
gs = fig.add_gridspec(3, 2, hspace=0.35, wspace=0.3)

# Panel 1: number of top features shared by each pair of methods and by all three.
# Words are normalised first. LIME splits on non-word characters, while the SHAP/IG
# word merging keeps attached punctuation ("fantastic!"), so raw strings would
# under-count agreement.
ax1 = fig.add_subplot(gs[0, :])
agreement_matrix = []
for result in results:
    lime_words = _feature_words(result['lime_features'])
    shap_words = _feature_words(result['shap_features'])
    ig_words = _feature_words(result['ig_features'])
    agreement_matrix.append([
        len(lime_words & shap_words),
        len(lime_words & ig_words),
        len(shap_words & ig_words),
        len(lime_words & shap_words & ig_words)
    ])
agreement_matrix = np.array(agreement_matrix)
# vmax=5 matches the 5 features kept per method in `results`.
im1 = ax1.imshow(agreement_matrix.T, cmap='YlGn', aspect='auto', vmin=0, vmax=5)
ax1.set_yticks(range(4))
ax1.set_yticklabels(['L-S', 'L-IG', 'S-IG', 'All 3'], fontsize=11, fontweight='bold')
ax1.set_xticks(range(len(results)))
ax1.set_xticklabels([f"S{r['sample_id']}" for r in results], fontsize=10)
ax1.set_xlabel('Sample', fontsize=12, fontweight='bold')
ax1.set_title('Feature Agreement Between Methods', fontsize=14, fontweight='bold', pad=15)
for i in range(4):
    for j in range(len(results)):
        ax1.text(j, i, int(agreement_matrix[j, i]), ha='center', va='center', color='black', fontsize=10, fontweight='bold')
plt.colorbar(im1, ax=ax1, label='Shared Features')

# Panel 2: model confidence per sample, coloured by predicted label.
ax2 = fig.add_subplot(gs[1, 0])
confidences = [r['confidence'] for r in results]
predictions = [r['prediction'] for r in results]
# Label names come from label_mappings.json, so match them case-insensitively;
# anything other than positive/negative is drawn orange.
label_colors = {'positive': '#2ecc71', 'negative': '#e74c3c'}
colors = [label_colors.get(str(p).lower(), '#f39c12') for p in predictions]
ax2.bar(range(len(results)), confidences, color=colors, alpha=0.7, edgecolor='black')
ax2.set_ylim(0, 1.1)
ax2.set_xlabel('Sample', fontsize=12, fontweight='bold')
ax2.set_ylabel('Confidence', fontsize=12, fontweight='bold')
ax2.set_title('Prediction Confidence', fontsize=14, fontweight='bold')
ax2.set_xticks(range(len(results)))
ax2.set_xticklabels([f"S{i+1}" for i in range(len(results))])
ax2.grid(axis='y', alpha=0.3)

# Panel 3: agreement averaged over all samples.
ax3 = fig.add_subplot(gs[1, 1])
avg_agreements = agreement_matrix.mean(axis=0)
pairs = ['L-S', 'L-IG', 'S-IG', 'All 3']
ax3.bar(pairs, avg_agreements, color=['#3498db', '#9b59b6', '#e67e22', '#e74c3c'], alpha=0.7, edgecolor='black')
ax3.set_ylabel('Avg Shared Features', fontsize=12, fontweight='bold')
ax3.set_title('Average Agreement', fontsize=14, fontweight='bold')
ax3.set_ylim(0, 5)
ax3.grid(axis='y', alpha=0.3)
for i, val in enumerate(avg_agreements):
    ax3.text(i, val + 0.1, f'{val:.1f}', ha='center', fontsize=11, fontweight='bold')

# Panel 4: how often each word appears among the top-3 features of any method.
ax4 = fig.add_subplot(gs[2, :])
all_features = {}
for result in results:
    for word, _ in (result['lime_features'][:3] + result['shap_features'][:3] + result['ig_features'][:3]):
        word = _normalize_word(word)
        if word:
            all_features[word] = all_features.get(word, 0) + 1
top_features = sorted(all_features.items(), key=lambda x: x[1], reverse=True)[:20]
if top_features:
    words, counts = zip(*top_features)
    colors_freq = plt.cm.viridis(np.linspace(0.3, 0.9, len(words)))
    ax4.barh(range(len(words)), counts, color=colors_freq, alpha=0.8, edgecolor='black')
    ax4.set_yticks(range(len(words)))
    ax4.set_yticklabels(words, fontsize=12, fontweight='bold')
    ax4.set_xlabel('Frequency', fontsize=12, fontweight='bold')
    ax4.set_title('Most Influential Features', fontsize=14, fontweight='bold')
    ax4.invert_yaxis()
    ax4.grid(axis='x', alpha=0.3)
    for i, count in enumerate(counts):
        ax4.text(count + 0.3, i, f'{int(count)}', va='center', fontsize=10, fontweight='bold')
fig.suptitle('Summary Analysis', fontsize=16, fontweight='bold', y=0.995)
plt.show()
print(f"Average confidence: {np.mean(confidences):.1%}")
print(f"Average agreement (all 3): {avg_agreements[3]:.2f} features")
if top_features:
    print(f"Most influential word: '{top_features[0][0]}' ({top_features[0][1]} appearances)")