# Sentiment Classification Interpretability (Library-based)

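This notebook demonstrates three interpretability techniques for a RoBERTa sequence classifier, using standard libraries where available:
1. LIME (via the `lime` library)
2. SHAP (via the `shap` library)
3. Integrated Gradients (IG), implemented directly in PyTorch - replaces LRP for better stability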

In [None]:
!pip install transformers torch pandas matplotlib numpy lime shap -q

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import json
from transformers import RobertaTokenizer, AutoModelForSequenceClassification
import lime
from lime.lime_text import LimeTextExplainer
import shap
import warnings
warnings.filterwarnings("ignore")

## Load Model

In [None]:
# Directory holding the saved classifier, its tokenizer files and label_mappings.json.
model_path = './best_roberta_model'
model = AutoModelForSequenceClassification.from_pretrained(model_path)
tokenizer = RobertaTokenizer.from_pretrained(model_path)

# Run on GPU when one is available; eval() switches the model to inference mode.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()

# JSON is UTF-8 by specification; being explicit keeps non-ASCII label names
# loading identically regardless of the platform's default encoding.
with open(f'{model_path}/label_mappings.json', 'r', encoding='utf-8') as f:
    label_mappings = json.load(f)

label_list = label_mappings['label_list']
label2id = label_mappings['label2id']
# JSON object keys are always strings, so convert them back to integer class ids.
id2label = {int(k): v for k, v in label_mappings['id2label'].items()}

def predict_proba(texts, batch_size=32):
    """Return class probabilities for one or more texts.

    Args:
        texts: A single string, or a list / NumPy array of strings. Every
            element is coerced with ``str()`` before tokenization.
        batch_size: Maximum number of texts sent through the model in one
            forward pass. Explainers such as LIME hand over all of their
            perturbed samples in a single call, so batching keeps peak
            memory bounded.

    Returns:
        A NumPy array with one row of softmax probabilities per input text.
        Empty input yields an array of shape ``(0, num_labels)``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    if isinstance(texts, str):
        texts = [texts]
    elif isinstance(texts, np.ndarray):
        texts = texts.tolist()

    texts = [str(t) for t in texts]

    # The tokenizer cannot encode an empty batch; return an empty result instead.
    if not texts:
        return np.empty((0, model.config.num_labels), dtype=np.float32)

    batch_probs = []
    for start in range(0, len(texts), batch_size):
        chunk = texts[start:start + batch_size]
        inputs = tokenizer(chunk, return_tensors='pt', padding=True, truncation=True, max_length=64)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = model(**inputs)
            probs = F.softmax(outputs.logits, dim=-1)
        batch_probs.append(probs.cpu().numpy())

    return np.concatenate(batch_probs, axis=0)

## Test Samples

In [None]:
# Example sentences that every explainer below is run on.
# NOTE(review): the texts carry no sentence punctuation — presumably this
# matches the preprocessing used at training time; confirm.
test_samples = [
    "This movie is absolutely fantastic I loved every moment of it",
    "The state of corruption in our society is utterly disgusting",
    "The product is perfect it exceeded all my expectations",
    "I'm extremely happy with this purchase Highly recommend",
    "This is the worst service I have ever encountered Absolutely horrible",
    "Bad staff never coming back",
    "Outstanding quality Best investment I ever made",
    "The acting was pathetic and the plot was nonsense",
    "I adore this place it is magical and wonderful",
    "Envy poisons my thoughts coveting others success"
]

## LIME

In [None]:
# Shared LIME text explainer; class_names supplies the label names loaded from label_mappings.json.
lime_explainer = LimeTextExplainer(class_names=label_list)

def explain_with_lime(text, num_features=5):
    """Explain the model's predicted class for ``text`` with LIME.

    Args:
        text: The sentence to explain.
        num_features: Number of features LIME is asked to report.

    Returns:
        A ``(importance, pred_class)`` tuple: ``importance`` is the list of
        ``(feature, weight)`` pairs LIME reports for the predicted class,
        ordered by descending absolute weight, and ``pred_class`` is the
        integer index of that class.
    """
    # Explain only the class the model actually predicts for the unperturbed text.
    pred_class = int(predict_proba([text])[0].argmax())

    explanation = lime_explainer.explain_instance(
        text,
        predict_proba,
        labels=[pred_class],
        num_features=num_features,
        num_samples=1000,
    )

    ranked = explanation.as_list(label=pred_class)
    ranked.sort(key=lambda pair: abs(pair[1]), reverse=True)
    return ranked, pred_class

## SHAP

In [None]:
# Shared SHAP explainer that wraps predict_proba with a Text masker built from the model's tokenizer.
shap_explainer = shap.Explainer(predict_proba, shap.maskers.Text(tokenizer))

def explain_with_shap(text, num_features=5):
    """Explain the model's predicted class for ``text`` with SHAP.

    Args:
        text: The sentence to explain.
        num_features: Maximum number of (token, score) pairs to return.

    Returns:
        A ``(word_scores, pred_class)`` tuple: ``word_scores`` holds up to
        ``num_features`` ``(token, score)`` pairs ordered by descending
        absolute score, and ``pred_class`` is the predicted class index.
        Both the index and the scores are built-in ``int`` / ``float``
        values, matching what the LIME and IG explainers return.
    """
    shap_values = shap_explainer([text])
    pred_class = int(np.argmax(predict_proba([text])[0]))

    # Indexed as (sample, token, class): take the single sample, every token,
    # and the predicted class only.
    values = shap_values.values[0, :, pred_class]
    tokens = shap_values.data[0]

    word_scores = []
    for token, score in zip(tokens, values):
        # Drop BPE / SentencePiece word-boundary markers and surrounding whitespace.
        clean_word = str(token).replace('Ġ', '').replace('▁', '').strip()
        if clean_word and clean_word not in ('<s>', '</s>', '<pad>'):
            word_scores.append((clean_word, float(score)))

    word_scores.sort(key=lambda x: abs(x[1]), reverse=True)

    return word_scores[:num_features], pred_class

## Integrated Gradients (IG)

In [None]:
def _merge_subword_attributions(tokens, scores):
    """Sum per-token attribution scores into per-word scores.

    A token that starts with the ``Ġ`` marker (or the first non-special
    token) opens a new word; any other token is appended to the current
    word. Special tokens are skipped. Words that occur more than once share
    a single entry whose score is the sum over all occurrences.

    Returns:
        A dict mapping each word to its summed attribution score.
    """
    special_tokens = ('<s>', '</s>', '<pad>', '')
    word_scores = {}
    current_word = ""
    current_score = 0.0

    for token, score in zip(tokens, scores):
        if token in special_tokens:
            continue

        piece = token.replace('Ġ', '').replace('▁', '').replace('</w>', '')
        if token.startswith('Ġ') or not current_word:
            # Flush the finished word before starting the next one.
            if current_word:
                word_scores[current_word] = word_scores.get(current_word, 0) + current_score
            current_word = piece
            current_score = float(score)
        else:
            current_word += piece
            current_score += float(score)

    if current_word:
        word_scores[current_word] = word_scores.get(current_word, 0) + current_score

    return word_scores


def explain_with_ig(text, num_features=5, n_steps=50):
    """Explain the model's predicted class for ``text`` with Integrated Gradients.

    Gradients of the predicted-class logit with respect to the input
    embeddings are averaged over ``n_steps + 1`` evenly spaced points on the
    straight line from an all-zero baseline to the actual embeddings, then
    multiplied by (embeddings - baseline) and summed over the embedding
    dimension to give one score per token. Sub-word tokens are merged into
    words.

    Args:
        text: The sentence to explain.
        num_features: Maximum number of (word, score) pairs to return.
        n_steps: Number of interpolation intervals; must be at least 1.

    Returns:
        A ``(importance, pred_class)`` tuple: ``importance`` holds up to
        ``num_features`` ``(word, score)`` pairs ordered by descending
        absolute score, and ``pred_class`` is the predicted class index.

    Raises:
        ValueError: If ``n_steps`` is smaller than 1.
    """
    if n_steps < 1:
        raise ValueError("n_steps must be a positive integer")

    inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=64)
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    # Detach so the embeddings are a constant end point of the path rather
    # than part of the autograd graph.
    embeddings = model.get_input_embeddings()(input_ids).detach()

    with torch.no_grad():
        logits = model(inputs_embeds=embeddings, attention_mask=attention_mask).logits
        # Softmax is monotonic, so the arg-max of the logits is the predicted class.
        pred_class = torch.argmax(logits[0]).item()

    baseline = torch.zeros_like(embeddings)
    delta = embeddings - baseline
    total_gradients = torch.zeros_like(embeddings)

    for step in range(n_steps + 1):
        alpha = step / n_steps
        interpolated = (baseline + alpha * delta).requires_grad_(True)

        outputs = model(inputs_embeds=interpolated, attention_mask=attention_mask)
        score = outputs.logits[0, pred_class]

        # Differentiate with respect to the interpolated input only. A plain
        # backward() would also accumulate gradients into every model
        # parameter on each step and never clear them.
        (step_gradient,) = torch.autograd.grad(score, interpolated)
        total_gradients += step_gradient

    avg_gradients = total_gradients / (n_steps + 1)
    attributions = (delta * avg_gradients).sum(dim=-1).squeeze(0).cpu().numpy()

    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
    aggregated_features = _merge_subword_attributions(tokens, attributions)

    importance = sorted(aggregated_features.items(), key=lambda x: abs(x[1]), reverse=True)

    return importance[:num_features], pred_class

In [None]:
# Run all three explainers on every test sample and collect one record per sample.
results = []
for idx, text in enumerate(test_samples):
    print(f"\nSample {idx+1}: {text}")
    # Predict once here for the printed summary; each explainer recomputes its own prediction.
    probs = predict_proba(text)[0]
    pred_class = np.argmax(probs)
    pred_label = id2label[pred_class]
    print(f"Prediction: {pred_label} ({probs[pred_class]:.3f})")
    
    # The class index each explainer returns is discarded; only the feature lists are kept.
    lime_features, _ = explain_with_lime(text)
    shap_features, _ = explain_with_shap(text)
    ig_features, _ = explain_with_ig(text)
    
    # Keep at most five features per method for the comparison plots.
    results.append({
        'sample_id': idx + 1,
        'text': text,
        'prediction': pred_label,
        'confidence': probs[pred_class],
        'lime_features': lime_features[:5],
        'shap_features': shap_features[:5],
        'ig_features': ig_features[:5]
    })

In [None]:
def visualize_comparison(sample_result):
    """Plot LIME, SHAP and IG feature scores for one sample side by side.

    Args:
        sample_result: A record with the keys 'sample_id', 'text',
            'prediction', 'confidence', 'lime_features', 'shap_features'
            and 'ig_features'; each ``*_features`` value is a list of
            ``(word, score)`` pairs.

    Draws one horizontal bar chart per method (positive scores green, all
    other scores red), ordered by ascending score, and shows the figure.
    """
    panels = (
        ('lime_features', 'LIME'),
        ('shap_features', 'SHAP'),
        ('ig_features', 'Integrated Gradients'),
    )
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    for ax, (method_key, method_name) in zip(axes, panels):
        # Rank by magnitude first, then by signed value so that bars run
        # from the most negative to the most positive score.
        by_magnitude = sorted(sample_result[method_key], key=lambda pair: abs(pair[1]), reverse=True)
        ordered = sorted(by_magnitude, key=lambda pair: pair[1])
        words = [word for word, _ in ordered]
        scores = [value for _, value in ordered]

        bar_colors = []
        for value in scores:
            bar_colors.append('#2ecc71' if value > 0 else '#e74c3c')

        positions = range(len(words))
        ax.barh(positions, scores, color=bar_colors, alpha=0.7, edgecolor='black', linewidth=0.5)
        ax.set_yticks(positions)
        ax.set_yticklabels(words, fontsize=12, fontweight='bold')
        ax.set_title(method_name, fontsize=14, fontweight='bold', pad=10)
        ax.axvline(x=0, color='black', linewidth=1.5, linestyle='--')
        ax.grid(axis='x', alpha=0.3, linestyle=':')
        ax.set_xlabel('Importance Score', fontsize=11)

    # Truncate long texts so the figure title stays readable.
    full_text = sample_result['text']
    if len(full_text) > 80:
        text_preview = full_text[:80] + '...'
    else:
        text_preview = full_text

    title = f"Sample {sample_result['sample_id']}: \"{text_preview}\"\nPrediction: {sample_result['prediction']} (Confidence: {sample_result['confidence']:.1%})"
    plt.suptitle(title, fontsize=12, fontweight='bold', y=1.0)
    plt.tight_layout()
    plt.show()

# Render one three-panel comparison figure per explained sample.
for sample_result in results:
    visualize_comparison(sample_result)