# 🧠 Financial Sentiment Model Explainability Dashboard - Generalized

## Overview
This notebook provides comprehensive explainability analysis for **any** trained financial sentiment model in your collection. It includes four complementary explanation methods accessible through an interactive dashboard.

### Explanation Methods
- **🎯 SHAP**: Game-theory based feature importance
- **🔍 LIME**: Local interpretable model-agnostic explanations
- **👁️ Attention**: Model attention head visualization (PyTorch models only)
- **🌡️ GradCAM**: Gradient-based attribution (PyTorch models only)

### Dashboard Features
- **Model Selection**: Choose any model from your trained collection
- **Mistake Analysis**: Examine specific model errors
- **Custom Text Analysis**: Test any financial text
- **Interactive Interface**: Tabbed layout for easy comparison
- **On-demand Computation**: Explanations are computed only when requested, and SHAP and LIME explainers are cached per model
- **Configuration-Driven**: Paths and explainability parameters are read from `../config/pipeline_config.json`
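
As a rough illustration of what "configuration-driven" means here, the sketch below shows a hypothetical excerpt of `pipeline_config.json` containing only the keys this notebook reads, plus a small lookup helper. The JSON values simply mirror the defaults used in the code cells. The helper `get_setting` and the `logging.level` key are made up for the example and are not part of `ConfigManager`.

```python
import json

# Hypothetical excerpt of pipeline_config.json. Only keys that this notebook
# reads are shown; the values mirror the in-code defaults.
EXAMPLE_CONFIG_JSON = """
{
  "models": {"output_dir": "models"},
  "data": {"processed_data_dir": "data/processed"},
  "explainability": {
    "max_sequence_length": 512,
    "batch_size": 8,
    "sample_size": 100,
    "random_seed": 42
  }
}
"""

example_config = json.loads(EXAMPLE_CONFIG_JSON)


def get_setting(cfg, section, key, default):
    """Read cfg[section][key], falling back to default if either level is missing."""
    return cfg.get(section, {}).get(key, default)


# A key that is present comes from the file
print(get_setting(example_config, "explainability", "batch_size", 8))
# A key that is absent (hypothetical "logging" section) falls back to the default
print(get_setting(example_config, "logging", "level", "INFO"))
```

Because every lookup carries a default, the notebook can still run when a section is missing from the config file.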

In [1]:
# Import configuration system and explainability utilities
import sys
import os
sys.path.append("../")

from src.pipeline_utils import ConfigManager, StateManager, LoggingManager

# Core libraries
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import warnings
import pickle
import json
import time
from typing import Dict, List, Optional, Tuple, Any
warnings.filterwarnings('ignore')

# Model and tokenizer
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# ONNX support for explainability
try:
    import onnxruntime as ort
    onnx_available = True
except ImportError:
    print("⚠️ ONNX Runtime not available. Only PyTorch models will be supported.")
    onnx_available = False

# Explainability libraries
try:
    import shap
    shap_available = True
except ImportError:
    print("⚠️ SHAP not available. Install with: pip install shap")
    shap_available = False

try:
    from lime.lime_text import LimeTextExplainer
    lime_available = True
except ImportError:
    print("⚠️ LIME not available. Install with: pip install lime")
    lime_available = False

try:
    from bertviz import head_view
    bertviz_available = True
except ImportError:
    print("⚠️ BertViz not available. Install with: pip install bertviz")
    bertviz_available = False

try:
    from captum.attr import LayerGradCam
    captum_available = True
except ImportError:
    print("⚠️ Captum not available. Install with: pip install captum")
    captum_available = False

# Dashboard components
import ipywidgets as widgets
from IPython.display import display, HTML, clear_output

# Initialize managers
config = ConfigManager("../config/pipeline_config.json")
state = StateManager("../config/pipeline_state.json")
logger_manager = LoggingManager(config, 'explainability')
logger = logger_manager.get_logger()

print("✅ All libraries imported successfully")
logger.info("📊 Starting Model Explainability Analysis - Generalized Pipeline")

2025-08-13 00:13:08,346 - pipeline.explainability - INFO - 📊 Starting Model Explainability Analysis - Generalized Pipeline


✅ All libraries imported successfully


## 2. 🔍 Model Discovery & Configuration

In [2]:
# Discover available models for explainability analysis
logger.info("🔍 Discovering available models...")

# Load configuration
explainability_config = config.get('explainability', {})
models_config = config.get('models', {})
data_config = config.get('data', {})

# Model discovery
models_dir = Path(f"../{models_config.get('output_dir', 'models')}")
print(f"📂 Models directory: {models_dir}")

# Discover available models
available_models = {}

if models_dir.exists():
    for model_path in models_dir.iterdir():
        if not model_path.is_dir() or model_path.name.startswith('.'):
            continue
            
        model_name = model_path.name
        
        # Check for required files
        config_file = model_path / "config.json"
        label_encoder_file = model_path / "label_encoder.pkl"
        pytorch_files = list(model_path.glob("*.safetensors")) + list(model_path.glob("pytorch_model.bin"))
        
        # Check for ONNX model files
        onnx_dir = model_path / "onnx"
        onnx_files = list(onnx_dir.glob("*.onnx")) if onnx_dir.exists() else []
        
        if config_file.exists() and label_encoder_file.exists() and (pytorch_files or onnx_files):
            model_info = {
                'name': model_name,
                'path': model_path,
                'config_file': config_file,
                'label_encoder_file': label_encoder_file,
                'has_pytorch': len(pytorch_files) > 0,
                'has_onnx': len(onnx_files) > 0,
                'pytorch_files': pytorch_files,
                'onnx_files': onnx_files
            }
            available_models[model_name] = model_info
            
            status = []
            if model_info['has_pytorch']:
                status.append("PyTorch")
            if model_info['has_onnx']:
                status.append("ONNX")
            
            print(f"   ✅ Found: {model_name} ({', '.join(status)})")
        else:
            print(f"   ⚠️ Invalid model directory: {model_name} (missing required files)")

print(f"\n📊 Discovery Summary:")
print(f"   🤖 Total models found: {len(available_models)}")
pytorch_count = sum(1 for m in available_models.values() if m['has_pytorch'])
onnx_count = sum(1 for m in available_models.values() if m['has_onnx'])
print(f"   🔥 PyTorch models: {pytorch_count}")
print(f"   ⚡ ONNX models: {onnx_count}")

if len(available_models) == 0:
    logger.error("No valid models found for explainability analysis")
    raise RuntimeError("No models found. Please ensure models have been trained and have label encoders.")

logger.info(f"Model discovery completed: {len(available_models)} models available")

2025-08-13 00:13:08,377 - pipeline.explainability - INFO - 🔍 Discovering available models...
2025-08-13 00:13:08,383 - pipeline.explainability - INFO - Model discovery completed: 7 models available


📂 Models directory: ../models
   ✅ Found: tinybert-financial-classifier-fine-tuned (PyTorch, ONNX)
   ✅ Found: all-MiniLM-L6-v2-financial-sentiment (PyTorch, ONNX)
   ✅ Found: distilbert-financial-sentiment (PyTorch, ONNX)
   ✅ Found: finbert-tone-financial-sentiment (PyTorch, ONNX)
   ⚠️ Invalid model directory: tinybert-financial-classifier_explainability_fine_tuned (missing required files)
   ⚠️ Invalid model directory: SmolLM2-360M-Instruct-financial-sentiment (missing required files)
   ✅ Found: tinybert-financial-classifier (PyTorch, ONNX)
   ✅ Found: tinybert-financial-classifier-pruned (PyTorch, ONNX)
   ✅ Found: mobilebert-uncased-financial-sentiment (PyTorch, ONNX)

📊 Discovery Summary:
   🤖 Total models found: 7
   🔥 PyTorch models: 7
   ⚡ ONNX models: 7
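
The discovery output above flags two directories as invalid without saying which file is missing. A minimal diagnostic sketch, assuming the same three requirements as the discovery cell (a `config.json`, a `label_encoder.pkl`, and at least one weight file), could look like this. The function name `missing_requirements` is hypothetical.

```python
from pathlib import Path
import tempfile


def missing_requirements(model_path: Path) -> list:
    """List which of the discovery cell's requirements a model directory lacks."""
    missing = []
    if not (model_path / "config.json").exists():
        missing.append("config.json")
    if not (model_path / "label_encoder.pkl").exists():
        missing.append("label_encoder.pkl")

    # Either PyTorch weights or an ONNX export is enough
    has_pytorch = any(model_path.glob("*.safetensors")) or (model_path / "pytorch_model.bin").exists()
    onnx_dir = model_path / "onnx"
    has_onnx = onnx_dir.exists() and any(onnx_dir.glob("*.onnx"))
    if not (has_pytorch or has_onnx):
        missing.append("weights (*.safetensors, pytorch_model.bin, or onnx/*.onnx)")
    return missing


# Demo on a throwaway directory that contains a config file and nothing else
with tempfile.TemporaryDirectory() as tmp:
    demo_dir = Path(tmp) / "demo-model"
    demo_dir.mkdir()
    (demo_dir / "config.json").write_text("{}")
    demo_missing = missing_requirements(demo_dir)
    print(demo_missing)
```

Running it against the flagged directories under `../models` would show whether the label encoder or the weights are absent.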


## 3. 📊 Data Loading & Validation

In [3]:
# Load validation data from the config-specified path
data_config = config.get('data', {})
processed_data_dir = data_config.get('processed_data_dir', 'data/processed')
validation_path = os.path.join('..', processed_data_dir, 'validation.csv')

print(f"📂 Loading validation data from: {validation_path}")
print(f"📁 Absolute path: {os.path.abspath(validation_path)}")

try:
    validation_df = pd.read_csv(validation_path)
    print(f"✅ Loaded {len(validation_df)} validation samples")
    
    # Extract texts and labels
    validation_texts = validation_df['text'].tolist()
    validation_labels_text = validation_df['label'].tolist()
    
    # Get unique labels for encoding
    unique_labels = sorted(validation_df['label'].unique())
    label_to_id = {label: i for i, label in enumerate(unique_labels)}
    validation_labels = [label_to_id[label] for label in validation_labels_text]
    
    print(f"📊 Label distribution:")
    for label in unique_labels:
        count = validation_labels_text.count(label)
        print(f"  {label}: {count} samples ({count/len(validation_df)*100:.1f}%)")
        
    print(f"🎯 Available labels: {unique_labels}")
    print(f"📈 Sample texts preview:")
    for i, (text, label) in enumerate(zip(validation_texts[:3], validation_labels_text[:3])):
        print(f"  [{label}]: {text[:80]}...")
        
except FileNotFoundError:
    print(f"❌ Validation file not found at: {validation_path}")
    print("📝 Trying alternative validation data sources...")
    
    # Try to load from test data as fallback
    alternative_paths = [
        '../data/processed/full_processed.csv',
        '../data/FinancialClassification/test.csv',
        '../data/FinancialAuditor/test.csv',
        '../data/FinancialPhraseBank/all-data.csv'
    ]
    
    validation_df = None
    for alt_path in alternative_paths:
        try:
            print(f"🔍 Checking: {alt_path}")
            if os.path.exists(alt_path):
                validation_df = pd.read_csv(alt_path)
                print(f"✅ Using fallback data from: {alt_path}")
                
                # Handle different column naming conventions
                if 'text' in validation_df.columns:
                    text_col = 'text'
                elif 'sentence' in validation_df.columns:
                    text_col = 'sentence'
                elif len(validation_df.columns) >= 2:  # Handle FinancialPhraseBank format
                    # FinancialPhraseBank has no headers, assume first col is label, second is text
                    validation_df.columns = ['label', 'text']
                    text_col = 'text'
                else:
                    print(f"❌ Unknown data format in {alt_path}")
                    validation_df = None  # Discard so the sample-data fallback can still run
                    continue
                
                if 'label' in validation_df.columns:
                    label_col = 'label'
                elif 'sentiment' in validation_df.columns:
                    label_col = 'sentiment'
                else:
                    print(f"❌ No label column found in {alt_path}")
                    validation_df = None  # Discard so the sample-data fallback can still run
                    continue
                
                validation_texts = validation_df[text_col].tolist()
                validation_labels_text = validation_df[label_col].tolist()
                
                # Clean up any malformed data
                clean_data = []
                for text, label in zip(validation_texts, validation_labels_text):
                    if pd.notna(text) and pd.notna(label) and str(text).strip() and str(label).strip():
                        clean_data.append((str(text).strip(), str(label).strip()))
                
                if not clean_data:
                    print(f"❌ No valid data found in {alt_path}")
                    validation_df = None  # Discard so the sample-data fallback can still run
                    continue
                
                validation_texts = [item[0] for item in clean_data]
                validation_labels_text = [item[1] for item in clean_data]
                
                # Sample for performance if too large.
                # Index into the cleaned lists, which may be shorter than the raw DataFrame.
                if len(validation_texts) > 500:
                    sample_indices = np.random.choice(len(validation_texts), 500, replace=False)
                    validation_texts = [validation_texts[i] for i in sample_indices]
                    validation_labels_text = [validation_labels_text[i] for i in sample_indices]
                    print(f"📋 Sampled {len(validation_texts)} examples for performance")
                
                unique_labels = sorted(set(validation_labels_text))
                label_to_id = {label: i for i, label in enumerate(unique_labels)}
                validation_labels = [label_to_id[label] for label in validation_labels_text]
                
                print(f"📊 Label distribution:")
                for label in unique_labels:
                    count = validation_labels_text.count(label)
                    print(f"  {label}: {count} samples ({count/len(validation_labels_text)*100:.1f}%)")
                
                break
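        except Exception as e:
            # Report the failure instead of skipping silently, then try the next source
            print(f"⚠️ Could not load {alt_path}: {e}")
            validation_df = None
            continue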
    
    if validation_df is None:
        print("❌ No validation data found. Creating sample data...")
        # Create sample validation data
        sample_texts = [
            "The company reported strong quarterly earnings with revenue growth exceeding expectations.",
            "Market volatility continues to pose challenges for the financial sector.",
            "The business maintained steady performance despite economic headwinds.",
            "Declining sales figures indicate potential market challenges ahead.",
            "The merger announcement boosted investor confidence significantly."
        ]
        sample_labels = ["positive", "negative", "neutral", "negative", "positive"]
        
        validation_texts = sample_texts
        validation_labels_text = sample_labels  
        unique_labels = ["positive", "negative", "neutral"]
        label_to_id = {label: i for i, label in enumerate(unique_labels)}
        validation_labels = [label_to_id[label] for label in validation_labels_text]
        
        print(f"✅ Using sample data with {len(validation_texts)} samples")

print("✅ Data loading complete")
print(f"📋 Final dataset: {len(validation_texts)} samples, {len(unique_labels)} classes")

# Configuration for explainability
EXPLAINABILITY_CONFIG = {
    'max_sequence_length': explainability_config.get('max_sequence_length', 512),
    'batch_size': explainability_config.get('batch_size', 8),
    'sample_size': explainability_config.get('sample_size', min(100, len(validation_texts))),
    'random_seed': explainability_config.get('random_seed', 42)
}

print(f"🔧 Explainability Configuration:")
for key, value in EXPLAINABILITY_CONFIG.items():
    print(f"   📋 {key}: {value}")

logger.info("Data loading completed successfully")

2025-08-13 00:13:08,411 - pipeline.explainability - INFO - Data loading completed successfully


📂 Loading validation data from: ../data/processed/validation.csv
📁 Absolute path: /Users/matthew/Documents/deepmind_internship/data/processed/validation.csv
✅ Loaded 485 validation samples
📊 Label distribution:
  negative: 61 samples (12.6%)
  neutral: 288 samples (59.4%)
  positive: 136 samples (28.0%)
🎯 Available labels: ['negative', 'neutral', 'positive']
📈 Sample texts preview:
  [neutral]: The solution will be installed in the USA to support the North American operatio...
  [negative]: Scanfil , a systems supplier and contract manufacturer to the communications sec...
  [neutral]: `` The sale of the oxygen measurement business strengthens our goal to focus on ...
✅ Data loading complete
📋 Final dataset: 485 samples, 3 classes
🔧 Explainability Configuration:
   📋 max_sequence_length: 512
   📋 batch_size: 8
   📋 sample_size: 100
   📋 random_seed: 42
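
The data cell derives label ids from the sorted labels found in the validation file, while predictions are decoded with each model's own `label_encoder`. The two id spaces only agree if both see the same label set in the same order. The sketch below is one way to check that, assuming the encoder exposes its classes as a list (as `get_class_names()` does later in this notebook). The helper name `check_label_alignment` is hypothetical.

```python
def check_label_alignment(data_labels, encoder_classes):
    """Compare the data-derived label order with the model's class order."""
    data_order = sorted(set(data_labels))
    encoder_order = list(encoder_classes)
    return {
        # True only when id i means the same label on both sides
        "aligned": data_order == encoder_order,
        "only_in_data": sorted(set(data_order) - set(encoder_order)),
        "only_in_encoder": sorted(set(encoder_order) - set(data_order)),
    }


# Same label set: ids line up
print(check_label_alignment(
    ["neutral", "positive", "negative", "neutral"],
    ["negative", "neutral", "positive"],
))

# A class missing from the data shifts every id after it
print(check_label_alignment(
    ["neutral", "positive"],
    ["negative", "neutral", "positive"],
))
```

In the notebook this could be called with `validation_labels_text` and `model_wrapper.get_class_names()` before comparing predictions against `validation_labels`.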


## 4. 🔧 Universal Model Wrapper

In [4]:
class UniversalModelWrapper:
    """Universal wrapper that works with both PyTorch and ONNX models for explainability"""
    
    def __init__(self, model_info: Dict, model_type: str = 'auto'):
        """
        Initialize model wrapper
        
        Args:
            model_info: Model information dictionary from discovery
            model_type: 'pytorch', 'onnx', or 'auto' (auto-detect)
        """
        self.model_info = model_info
        self.model_path = model_info['path']
        self.model_name = model_info['name']
        
        # Auto-detect model type if not specified
        if model_type == 'auto':
            if model_info['has_pytorch']:
                self.model_type = 'pytorch'
            elif model_info['has_onnx']:
                self.model_type = 'onnx'
            else:
                raise ValueError(f"No supported model format found for {self.model_name}")
        else:
            self.model_type = model_type
        
        # Load tokenizer (same for both types)
        self.tokenizer = AutoTokenizer.from_pretrained(str(self.model_path))
        
        # Load label encoder
        with open(self.model_info['label_encoder_file'], 'rb') as f:
            self.label_encoder = pickle.load(f)
        
        # Load appropriate model
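        if self.model_type == 'onnx':
            if not onnx_available:
                raise ImportError("ONNX Runtime not available. Install with: pip install onnxruntime")
            if not model_info['has_onnx']:
                raise ValueError(f"ONNX model not available for {self.model_name}")
            # Prefer onnx/model.onnx, otherwise use the first .onnx file found during discovery
            onnx_path = self.model_path / 'onnx' / 'model.onnx'
            if not onnx_path.exists():
                onnx_path = sorted(model_info['onnx_files'])[0]
            self.session = ort.InferenceSession(str(onnx_path))
            self.model = None  # No PyTorch model for ONNX
            print(f"✅ Loaded ONNX model: {self.model_name}")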
        else:
            if not model_info['has_pytorch']:
                raise ValueError(f"PyTorch model not available for {self.model_name}")
            
            # Load PyTorch model with eager attention implementation to support explainability
            try:
                # First attempt: Load with explicit eager attention implementation
                self.model = AutoModelForSequenceClassification.from_pretrained(
                    str(self.model_path),
                    attn_implementation="eager"  # Force eager attention for explainability compatibility
                )
                print(f"✅ Loaded PyTorch model with eager attention: {self.model_name}")
            except Exception as eager_error:
                print(f"⚠️ Could not load with eager attention: {eager_error}")
                print("🔄 Loading with default implementation...")
                
                # Fallback: Load normally and set implementation afterward
                self.model = AutoModelForSequenceClassification.from_pretrained(str(self.model_path))
                
                # Try to set eager attention implementation for explainability
                if hasattr(self.model.config, 'attn_implementation'):
                    self.model.config.attn_implementation = 'eager'
                if hasattr(self.model.config, '_attn_implementation'):
                    self.model.config._attn_implementation = 'eager'
                
                print(f"✅ Loaded PyTorch model (attention implementation set post-load): {self.model_name}")
            
            self.model.eval()
            self.session = None  # No ONNX session for PyTorch
        
        logger.info(f"Model wrapper initialized: {self.model_name} ({self.model_type})")
    
    def predict_class(self, texts):
        """Predict sentiment class for text(s)"""
        if isinstance(texts, str):
            texts = [texts]
        
        predictions = []
        
        if self.model_type == 'onnx':
            predictions = self._predict_onnx_class(texts)
        else:
            predictions = self._predict_pytorch_class(texts)
        
        return np.array(predictions)
    
    def predict_probs(self, texts):
        """Get prediction probabilities for text(s)"""
        if isinstance(texts, str):
            texts = [texts]
        
        if self.model_type == 'onnx':
            return self._predict_onnx_probs(texts)
        else:
            return self._predict_pytorch_probs(texts)
    
    def _predict_pytorch_class(self, texts):
        """PyTorch class prediction"""
        predictions = []
        self.model.eval()
        
        with torch.no_grad():
            for text in texts:
                encoding = self.tokenizer(text, return_tensors='pt', 
                                        max_length=EXPLAINABILITY_CONFIG['max_sequence_length'], 
                                        truncation=True, padding=True)
                outputs = self.model(**encoding)
                predicted_class = torch.argmax(outputs.logits, dim=-1).item()
                predictions.append(predicted_class)
        
        return predictions
    
    def _predict_pytorch_probs(self, texts):
        """PyTorch probability prediction"""
        all_probs = []
        self.model.eval()
        
        with torch.no_grad():
            for text in texts:
                encoding = self.tokenizer(text, return_tensors='pt', 
                                        max_length=EXPLAINABILITY_CONFIG['max_sequence_length'], 
                                        truncation=True, padding=True)
                outputs = self.model(**encoding)
                probs = torch.softmax(outputs.logits, dim=-1)[0].cpu().numpy()  # [0] drops only the batch dimension
                all_probs.append(probs)
        
        return np.array(all_probs)
    
    def _predict_onnx_class(self, texts):
        """ONNX class prediction"""
        probabilities = self._predict_onnx_probs(texts)
        return np.argmax(probabilities, axis=-1).tolist()
    
    def _predict_onnx_probs(self, texts):
        """ONNX probability prediction"""
        all_probs = []
        
        for text in texts:
            # Tokenize
            encoding = self.tokenizer(text, return_tensors='np', 
                                    max_length=EXPLAINABILITY_CONFIG['max_sequence_length'], 
                                    truncation=True, padding=True)
            
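            # Prepare inputs: feed only the tensors this ONNX graph declares
            # (some exports also expect token_type_ids)
            expected_inputs = {node.name for node in self.session.get_inputs()}
            inputs = {
                name: np.asarray(value).astype(np.int64)
                for name, value in encoding.items()
                if name in expected_inputs
            }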
            
            # Run inference
            outputs = self.session.run(None, inputs)
            logits = outputs[0]
            
            # Convert to probabilities
            probs = self._softmax(logits[0])  # Take first (and only) sample
            all_probs.append(probs)
        
        return np.array(all_probs)
    
    def _softmax(self, x):
        """Numpy softmax implementation for ONNX"""
        exp_x = np.exp(x - np.max(x))
        return exp_x / np.sum(exp_x)
    
    def get_model_for_attention(self):
        """Get the PyTorch model for attention visualization (ONNX not supported)"""
        if self.model_type != 'pytorch':
            raise ValueError("Attention visualization only supported for PyTorch models")
        return self.model
    
    def get_tokenizer(self):
        """Get the tokenizer"""
        return self.tokenizer
    
    def get_label_encoder(self):
        """Get the label encoder"""
        return self.label_encoder
    
    def get_class_names(self):
        """Get class names"""
        return list(self.label_encoder.classes_)

print("✅ Universal Model Wrapper defined")
logger.info("Universal model wrapper class ready")

2025-08-13 00:13:08,433 - pipeline.explainability - INFO - Universal model wrapper class ready


✅ Universal Model Wrapper defined
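
The wrapper's `_softmax` subtracts the maximum logit before exponentiating. A small worked example of why that matters is sketched below in plain Python, so it runs without NumPy. The function name `stable_softmax` and the logits are made up for illustration.

```python
import math


def stable_softmax(logits):
    """Softmax over a 1-D list of logits, shifted by the max for numerical stability."""
    shift = max(logits)
    exps = [math.exp(v - shift) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


large_logits = [1000.0, 1001.0, 1002.0]

# Exponentiating the raw logits overflows a float
try:
    math.exp(large_logits[0])
    naive_overflows = False
except OverflowError:
    naive_overflows = True
print("naive exp overflows:", naive_overflows)

# Shifting by the max keeps every exponent <= 0, so the result is finite,
# and softmax is unchanged by adding a constant to all logits
probs = stable_softmax(large_logits)
print(probs)
```

The shifted result is identical to the softmax of `[0.0, 1.0, 2.0]`, since only the differences between logits matter.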


## 5. 🧩 Explainability Methods

### 5.1 SHAP Implementation (Universal)

In [5]:
# SHAP explainer cache (per model)
_shap_explainers = {}

def get_shap_explainer(model_wrapper):
    """Get SHAP explainer for model (lazy initialization)"""
    if not shap_available:
        raise ImportError("SHAP not available. Install with: pip install shap")
    
    model_key = f"{model_wrapper.model_name}_{model_wrapper.model_type}"
    
    if model_key not in _shap_explainers:
        print(f"🧮 Initializing SHAP explainer for {model_wrapper.model_name}...")
        
        # Create prediction function for SHAP
        def predict_for_shap(texts):
            return model_wrapper.predict_probs(texts)
        
        _shap_explainers[model_key] = shap.Explainer(predict_for_shap, model_wrapper.tokenizer)
    
    return _shap_explainers[model_key]

def explain_with_shap(model_wrapper, text, target_class=None):
    """Generate SHAP explanation for text"""
    if not shap_available:
        print("❌ SHAP not available. Install with: pip install shap")
        return
    
    # print("⏳ Computing SHAP values...")
    
    try:
        explainer = get_shap_explainer(model_wrapper)
        shap_values = explainer([text])
        
        if target_class is None:
            target_class = model_wrapper.predict_class(text)[0]
        
        # Display SHAP plot
        shap.plots.text(shap_values[0, :, target_class])
        
        pred_label = model_wrapper.label_encoder.inverse_transform([target_class])[0]
        print(f"📊 SHAP explanation for class: {pred_label}")
        print(f"🔧 Model: {model_wrapper.model_name} ({model_wrapper.model_type})")
        
    except Exception as e:
        print(f"❌ SHAP explanation failed: {str(e)}")
        print("💡 Common issues:")
        print("   - Model compatibility with SHAP")
        print("   - Text preprocessing differences")
        print("   - Try using LIME as alternative")

print("✅ SHAP implementation ready")

✅ SHAP implementation ready
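
To make the "game-theory" idea behind SHAP concrete, the sketch below computes exact Shapley values for a three-token toy example by enumerating every coalition of tokens. It is an illustration of the underlying definition, not how the `shap` library computes text explanations (which rely on approximations). The scoring function `toy_score` and its weights are invented, and tokens are assumed to be unique.

```python
from itertools import combinations
from math import factorial


def shapley_values(players, value_fn):
    """Exact Shapley values: weighted average marginal contribution of each player."""
    n = len(players)
    values = {}
    for p in players:
        others = [q for q in players if q != p]
        total = 0.0
        for k in range(n):  # coalition sizes 0 .. n-1
            weight = factorial(k) * factorial(n - k - 1) / factorial(n)
            for subset in combinations(others, k):
                coalition = set(subset)
                total += weight * (value_fn(coalition | {p}) - value_fn(coalition))
        values[p] = total
    return values


def toy_score(present):
    """Invented 'positive sentiment' score for a set of present tokens."""
    score = 0.0
    if "strong" in present:
        score += 0.5
    if "earnings" in present:
        score += 0.1
    if "not" in present and "strong" in present:
        score -= 0.8  # negation flips the effect of "strong"
    return score


tokens = ["not", "strong", "earnings"]
phi = shapley_values(tokens, toy_score)
for token in tokens:
    print(f"{token}: {phi[token]:+.3f}")
```

The interaction penalty is split evenly between "not" and "strong", and the values sum to the difference between the full-text score and the empty-text score.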


### 5.2 LIME Implementation (Universal)

In [6]:
# LIME explainer cache (per model)
_lime_explainers = {}

def get_lime_explainer(model_wrapper):
    """Get LIME explainer for model (lazy initialization)"""
    if not lime_available:
        raise ImportError("LIME not available. Install with: pip install lime")
    
    model_key = f"{model_wrapper.model_name}_{model_wrapper.model_type}"
    
    if model_key not in _lime_explainers:
        _lime_explainers[model_key] = LimeTextExplainer(
            class_names=model_wrapper.get_class_names()
        )
    
    return _lime_explainers[model_key]

def explain_with_lime(model_wrapper, text):
    """Generate LIME explanation for text"""
    if not lime_available:
        print("❌ LIME not available. Install with: pip install lime")
        return
    
    print("⏳ Computing LIME explanation...")
    
    try:
        explainer = get_lime_explainer(model_wrapper)
        
        # Create prediction function for LIME
        def predict_for_lime(texts):
            """Prediction function for LIME: list of strings -> (n_samples, n_classes) array"""
            # LIME passes a list of perturbed strings; normalize defensively
            texts = [texts] if isinstance(texts, str) else [str(t) for t in texts]
            
            try:
                return model_wrapper.predict_probs(texts)
            except Exception as e:
                print(f"Error in LIME prediction: {e}")
                # Return uniform probabilities over the model's classes if processing fails
                n_classes = len(model_wrapper.get_class_names())
                return np.full((len(texts), n_classes), 1.0 / n_classes)
        
        explanation = explainer.explain_instance(
            text,
            predict_for_lime,
            num_features=20,
            labels=tuple(range(len(model_wrapper.get_class_names())))
        )
        
        display(HTML(explanation.as_html()))
        print(f"📊 LIME explanation generated")
        print(f"🔧 Model: {model_wrapper.model_name} ({model_wrapper.model_type})")
        
    except Exception as e:
        print(f"❌ LIME explanation failed: {str(e)}")
        print("💡 Common LIME issues:")
        print("   - Text preprocessing differences")
        print("   - Prediction function format mismatch")
        print("   - Try using SHAP instead")

print("✅ LIME implementation ready")

✅ LIME implementation ready
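
LIME explains a single prediction by querying the model on many perturbed copies of the text with words removed, then fitting a simple weighted linear model to the responses. The sketch below illustrates only the first step, perturbation sampling, to show what kind of input `predict_for_lime` receives. It is a simplified stand-in written for this notebook, not the `lime` library's code, and `sample_perturbations` is a hypothetical name.

```python
import random


def sample_perturbations(text, num_samples=5, seed=42):
    """Return (mask, perturbed_text) pairs; mask[i] == 1 means word i was kept."""
    rng = random.Random(seed)
    words = text.split()
    if not words:
        return [([], text)]

    # The first sample is always the unperturbed text
    samples = [([1] * len(words), " ".join(words))]
    for _ in range(num_samples - 1):
        # Remove at least one word, and keep at least one when possible
        n_remove = rng.randint(1, max(1, len(words) - 1))
        removed = set(rng.sample(range(len(words)), n_remove))
        mask = [0 if i in removed else 1 for i in range(len(words))]
        perturbed = " ".join(w for w, keep in zip(words, mask) if keep)
        samples.append((mask, perturbed))
    return samples


for mask, perturbed in sample_perturbations("revenue growth exceeded expectations this quarter"):
    print(mask, "->", perturbed)
```

Each perturbed string is scored by the model, and the binary masks become the features of the local linear model whose coefficients LIME reports as word importances.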


### 5.3 Attention Visualization (PyTorch Only)

In [7]:
def explain_with_attention(model_wrapper, text):
    """Generate attention visualization for text (PyTorch only)"""
    if model_wrapper.model_type != 'pytorch':
        print("❌ Attention visualization only supported for PyTorch models")
        print("💡 Switch to a PyTorch model to use this feature")
        return
    
    if not bertviz_available:
        print("⚠️ BertViz not available. Using custom attention visualization...")
        
    print("⏳ Generating attention visualization...")
    
    try:
        model = model_wrapper.get_model_for_attention()
        tokenizer = model_wrapper.get_tokenizer()
        
        # Tokenize with attention output
        inputs = tokenizer(text, return_tensors='pt', truncation=True, 
                          max_length=EXPLAINABILITY_CONFIG['max_sequence_length'])
        
        # Get model outputs with attention
        with torch.no_grad():
            outputs = model(**inputs, output_attentions=True)
            attentions = outputs.attentions
            tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
        
        # Check if we have valid attention and tokens
        if attentions is None or len(attentions) == 0:
            print("❌ No attention weights available")
            return
            
        if len(tokens) == 0:
            print("❌ No tokens available")
            return
        
        pred_class = torch.argmax(outputs.logits, dim=-1).item()
        pred_label = model_wrapper.label_encoder.inverse_transform([pred_class])[0]
        
        # Try BertViz first if available
        if bertviz_available:
            try:
                print("🎯 Attempting interactive attention visualization...")
                head_view(attentions, tokens)
                print(f"👁️ Interactive attention visualization for prediction: {pred_label}")
                print(f"🔧 Model: {model_wrapper.model_name} ({model_wrapper.model_type})")
                return
                
            except Exception as viz_error:
                print(f"❌ BertViz interactive view failed: {viz_error}")
                print("💡 Using custom attention heatmap...")
        
        # Custom attention visualization fallback
        _visualize_attention_heatmap(attentions, tokens, pred_label, model_wrapper.model_name)
            
    except Exception as e:
        print(f"❌ Attention analysis failed: {str(e)}")
        print("💡 This might be due to model architecture compatibility issues")
        print("🔧 Try reloading the model or use other explainability methods")

def _visualize_attention_heatmap(attentions, tokens, pred_label, model_name):
    """Create custom attention heatmap visualization"""
    # Get average attention across all layers and heads
    avg_attention = torch.stack(attentions).mean(dim=0)  # Average across layers
    avg_attention = avg_attention.mean(dim=1)  # Average across heads
    attention_matrix = avg_attention[0].detach().cpu().numpy()  # Get first batch
    
    # Clean tokens for display: strip the WordPiece continuation prefix
    clean_tokens = [token[2:] if token.startswith('##') else token for token in tokens]
    
    # Limit to reasonable size for visualization
    max_len = min(len(clean_tokens), 50)
    attention_matrix = attention_matrix[:max_len, :max_len]
    display_tokens = clean_tokens[:max_len]
    
    # Create visualization
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))
    
    # 1. Full attention heatmap
    sns.heatmap(attention_matrix, 
                xticklabels=display_tokens,
                yticklabels=display_tokens,
                cmap='Blues',
                ax=ax1,
                cbar_kws={'label': 'Attention Weight'})
    ax1.set_title(f'Attention Heatmap\nModel: {model_name}\nPrediction: {pred_label}', 
                 fontsize=14, weight='bold')
    ax1.set_xlabel('Attended Tokens')
    ax1.set_ylabel('Query Tokens')
    plt.setp(ax1.get_xticklabels(), rotation=45, ha='right')
    plt.setp(ax1.get_yticklabels(), rotation=0)
    
    # 2. CLS token attention (what the model focuses on for classification)
    cls_attention = attention_matrix[0, 1:]  # CLS token attention to other tokens
    tokens_for_cls = display_tokens[1:]  # Skip CLS token
    
    # Sort by attention weight
    token_attention_pairs = list(zip(tokens_for_cls, cls_attention))
    token_attention_pairs.sort(key=lambda x: x[1], reverse=True)
    
    # Take top 15 for readability
    top_tokens, top_weights = zip(*token_attention_pairs[:15])
    
    bars = ax2.barh(range(len(top_tokens)), top_weights, color='skyblue')
    ax2.set_yticks(range(len(top_tokens)))
    ax2.set_yticklabels(top_tokens)
    ax2.set_xlabel('Attention Weight')
    ax2.set_title(f'Top Attended Tokens\n(CLS token attention)', fontsize=14, weight='bold')
    ax2.invert_yaxis()
    
    # Add value labels on bars
    for i, (bar, weight) in enumerate(zip(bars, top_weights)):
        ax2.text(weight + 0.001, i, f'{weight:.3f}', 
                va='center', ha='left', fontsize=10)
    
    plt.tight_layout()
    plt.show()
    
    # Print statistics
    print(f"📊 Attention Statistics:")
    print(f"   • Model: {model_name}")
    print(f"   • Number of layers: {len(attentions)}")
    # Each attention tensor is (batch, heads, seq_len, seq_len), so heads live on dim 1
    print(f"   • Number of heads per layer: {attentions[0].shape[1]}")
    print(f"   • Sequence length: {len(tokens)}")
    print(f"   • Max attention weight: {attention_matrix.max():.4f}")
    print(f"   • Average attention weight: {attention_matrix.mean():.4f}")
    
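    print(f"\n🎯 Top 5 tokens by CLS attention:")
    # Filter special tokens before ranking so the numbering stays contiguous
    content_pairs = [(t, w) for t, w in token_attention_pairs if t not in ['[SEP]', '[PAD]']]
    for i, (token, weight) in enumerate(content_pairs[:5]):
        print(f"   {i+1}. '{token}': {weight:.4f}")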
    
    print(f"👁️ Custom attention visualization complete for: {pred_label}")

print("✅ Attention visualization ready")

✅ Attention visualization ready
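
The layer and head averaging used in the attention visualization can be checked on dummy data. The sketch below is a NumPy stand-in, not the notebook's PyTorch path: the sizes and the `attentions` list are made up for illustration, and only the axis layout `(batch, heads, seq_len, seq_len)` is taken from the code above.

```python
import numpy as np

# Dummy stand-in for `outputs.attentions`: one array per layer,
# each shaped (batch, heads, seq_len, seq_len). Sizes are illustrative only.
num_layers, batch, heads, seq_len = 4, 1, 2, 5
rng = np.random.default_rng(0)
attentions = []
for _ in range(num_layers):
    raw = rng.random((batch, heads, seq_len, seq_len))
    # Normalize each row to sum to 1, like a softmax output
    attentions.append(raw / raw.sum(axis=-1, keepdims=True))

stacked = np.stack(attentions)        # (layers, batch, heads, seq, seq)
avg_layers = stacked.mean(axis=0)     # (batch, heads, seq, seq)
avg_heads = avg_layers.mean(axis=1)   # (batch, seq, seq)
attention_matrix = avg_heads[0]       # (seq, seq) for the first example

# Row 0 is [CLS] acting as the query; drop its attention to itself
cls_attention = attention_matrix[0, 1:]

print("heads per layer:", attentions[0].shape[1])
print("matrix shape:", attention_matrix.shape)
print("row sums:", attention_matrix.sum(axis=-1))
```

Because every per-head matrix is row-normalized, the averaged matrix keeps rows that sum to 1, which makes the row sums a quick sanity check on the axes being averaged.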


### 5.4 GradCAM Implementation (PyTorch Only)

In [None]:
def explain_with_gradcam(model_wrapper, text, target_layer=None):
    """Generate GradCAM explanation for text (PyTorch only)"""
    if model_wrapper.model_type != 'pytorch':
        print("❌ GradCAM visualization only supported for PyTorch models")
        print("💡 Switch to a PyTorch model to use this feature")
        return
    
    if not captum_available:
        print("❌ Captum not available. Install with: pip install captum")
        return
        
    print("⏳ Generating GradCAM visualization...")
    
    try:
        from captum.attr import LayerGradCam, TokenReferenceBase
        from captum.attr import visualization as viz
        
        model = model_wrapper.get_model_for_attention()
        tokenizer = model_wrapper.get_tokenizer()
        
        # Tokenize input
        inputs = tokenizer(text, return_tensors='pt', truncation=True, 
                          max_length=EXPLAINABILITY_CONFIG['max_sequence_length'],
                          padding=True)
        
        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']
        
        # Get model prediction
        model.eval()
        with torch.no_grad():
            outputs = model(**inputs)
            pred_class = torch.argmax(outputs.logits, dim=-1).item()
            pred_label = model_wrapper.label_encoder.inverse_transform([pred_class])[0]
        
        print(f"🎯 Model prediction: {pred_label}")
        
        # Determine target layer for GradCAM
        if target_layer is None:
            # Try to find the last transformer layer
            if hasattr(model, 'bert'):  # BERT-based models
                target_layer = model.bert.encoder.layer[-1]
            elif hasattr(model, 'distilbert'):  # DistilBERT
                target_layer = model.distilbert.transformer.layer[-1]
            elif hasattr(model, 'roberta'):  # RoBERTa
                target_layer = model.roberta.encoder.layer[-1]
            elif hasattr(model, 'albert'):  # ALBERT
                target_layer = model.albert.encoder.albert_layer_groups[-1].albert_layers[-1]
            else:
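                # Generic fallback: keep the last indexed encoder block
                # (e.g. "encoder.layer.11") rather than the first match, which is
                # usually a ModuleList container whose forward hooks never fire
                for name, module in model.named_modules():
                    lname = name.lower()
                    if 'encoder' in lname and 'layer' in lname and lname.rsplit('.', 1)[-1].isdigit():
                        target_layer = module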
                
                if target_layer is None:
                    print("❌ Could not automatically detect target layer for GradCAM")
                    print("💡 Model architecture not supported for automatic layer detection")
                    return
        
        print(f"🔍 Using target layer: {target_layer.__class__.__name__}")
        
        # Create wrapper function that returns only logits
        def forward_func(input_ids, attention_mask=None):
            """Forward function that returns only logits tensor for Captum"""
            if attention_mask is not None:
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            else:
                outputs = model(input_ids=input_ids)
            return outputs.logits
        
        # Create GradCAM with the forward function
        grad_cam = LayerGradCam(forward_func, target_layer)
        
        # Generate attributions
        try:
            # Try with attention mask first
            attributions = grad_cam.attribute(
                inputs=input_ids,
                target=pred_class,
                additional_forward_args=(attention_mask,)
            )
        except Exception as attr_error:
            print(f"❌ Attribution failed with attention mask: {attr_error}")
            print("🔄 Retrying without attention mask...")
            try:
                # Create new GradCAM without attention mask support
                def simple_forward_func(input_ids):
                    return model(input_ids=input_ids).logits
                
                grad_cam_simple = LayerGradCam(simple_forward_func, target_layer)
                attributions = grad_cam_simple.attribute(
                    inputs=input_ids,
                    target=pred_class
                )
            except Exception as attr_error2:
                print(f"❌ Attribution failed: {attr_error2}")
                print("💡 Trying alternative approach...")
                
                # Final fallback - use a wrapper class
                try:
                    class ModelWrapper(torch.nn.Module):
                        def __init__(self, model):
                            super().__init__()
                            self.model = model
                        
                        def forward(self, input_ids):
                            return self.model(input_ids=input_ids).logits
                    
                    wrapped_model = ModelWrapper(model)
                    grad_cam_wrapped = LayerGradCam(wrapped_model, target_layer)
                    attributions = grad_cam_wrapped.attribute(
                        inputs=input_ids,
                        target=pred_class
                    )
                except Exception as attr_error3:
                    print(f"❌ All attribution methods failed: {attr_error3}")
                    print("💡 This model architecture might not be compatible with GradCAM")
                    return
        
        # Convert to numpy and process
        attributions_np = attributions.squeeze().detach().cpu().numpy()
        
        # Get tokens for visualization
        # Pass a plain list of ints for the first (only) sequence in the batch
        tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
        
        # Create visualization
        _visualize_gradcam(attributions_np, tokens, text, pred_label, model_wrapper.model_name)
        
    except Exception as e:
        print(f"❌ GradCAM analysis failed: {str(e)}")
        print("💡 Common issues:")
        print("   - Model architecture not supported")
        print("   - Layer selection problems")
        print("   - Memory constraints with large models")

def _visualize_gradcam(attributions, tokens, original_text, pred_label, model_name):
    """Create GradCAM visualization with text highlighting"""
    
    # Process attributions
    if len(attributions.shape) > 1:
        # If multi-dimensional, take mean across dimensions
        attr_scores = np.mean(attributions, axis=tuple(range(1, len(attributions.shape))))
    else:
        attr_scores = attributions
    
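    # Ensure we have the right number of scores for tokens
    if len(attr_scores) != len(tokens):
        # A mismatch suggests the attribution axis is not the token axis
        print(f"⚠️ Attribution length ({len(attr_scores)}) does not match token count ({len(tokens)}); "
              "scores may not align with tokens")
    min_len = min(len(attr_scores), len(tokens))
    attr_scores = attr_scores[:min_len]
    tokens = tokens[:min_len]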
    
    # Normalize attributions to [0, 1]
    attr_scores = attr_scores - attr_scores.min()
    if attr_scores.max() > 0:
        attr_scores = attr_scores / attr_scores.max()
    
    # Clean tokens and merge subwords for better text reconstruction
    display_tokens = []
    display_scores = []
    reconstructed_words = []
    word_scores = []
    current_word = ""
    current_score = 0
    word_token_count = 0
    
    for token, score in zip(tokens, attr_scores):
        if token in ['[CLS]', '[SEP]', '[PAD]']:
            continue
            
        if token.startswith('##'):
            # Subword continuation
            current_word += token[2:]
            current_score += score
            word_token_count += 1
        else:
            # New word - finalize previous word if exists
            if current_word:
                reconstructed_words.append(current_word)
                word_scores.append(current_score / word_token_count)  # Average score for word
            
            # Start new word
            current_word = token
            current_score = score
            word_token_count = 1
            
        # Keep individual tokens for detailed analysis
        # (special tokens were already skipped at the top of the loop)
        display_tokens.append(token[2:] if token.startswith('##') else token)
        display_scores.append(score)
    
    # Don't forget the last word
    if current_word:
        reconstructed_words.append(current_word)
        word_scores.append(current_score / word_token_count)
    
    # Create clean, focused visualization
    fig = plt.figure(figsize=(16, 9))
    # Simplified layout: Large text panel + compact side panels
    gs = fig.add_gridspec(2, 2, height_ratios=[3, 1], width_ratios=[2.5, 1], 
                         hspace=0.2, wspace=0.3, 
                         left=0.05, right=0.95, top=0.92, bottom=0.08)
    
    # 1. Enhanced text highlighting (main focus - spans top row)
    ax_text = fig.add_subplot(gs[0, :])
    _create_enhanced_text_highlight(ax_text, reconstructed_words, word_scores, original_text, pred_label, model_name)
    
    # 2. Top words bar chart (bottom left)
    ax_bars = fig.add_subplot(gs[1, 0])
    _create_word_importance_chart(ax_bars, reconstructed_words, word_scores)
    
    # 3. Statistics summary (bottom right)
    ax_stats = fig.add_subplot(gs[1, 1])
    _create_stats_summary(ax_stats, word_scores, reconstructed_words, pred_label, model_name)
    
    plt.suptitle('GradCAM Text Explainability Analysis', fontsize=18, fontweight='bold', y=0.97)
    plt.show()
    
    # Print enhanced statistics without emojis
    print(f"GradCAM Analysis Summary:")
    print(f"   Model: {model_name}")
    print(f"   Prediction: {pred_label}")
    print(f"   Words analyzed: {len(reconstructed_words)}")
    print(f"   Tokens analyzed: {len(tokens)}")
    print(f"   Max word importance: {max(word_scores):.4f}")
    print(f"   Average word importance: {np.mean(word_scores):.4f}")
    
    # Show top important words with better formatting
    word_importance = list(zip(reconstructed_words, word_scores))
    word_importance.sort(key=lambda x: x[1], reverse=True)
    
    print(f"\nMost Important Words for '{pred_label}' prediction:")
    for i, (word, score) in enumerate(word_importance[:8]):
        if score > 0.8:
            intensity = "[CRITICAL]"
        elif score > 0.65:
            intensity = "[HIGH]"
        elif score > 0.45:
            intensity = "[MEDIUM]"
        elif score > 0.25:
            intensity = "[LOW]"
        else:
            intensity = "[MINIMAL]"
        print(f"   {i+1}. {intensity} '{word}': {score:.3f}")
    
    print(f"\nGradCAM text analysis complete!")

def _create_enhanced_text_highlight(ax, words, scores, original_text, pred_label, model_name):
    """Create enhanced text highlighting with more focus and better information display"""
    ax.set_xlim(0, 12)
    ax.set_ylim(0, 7)
    ax.axis('off')
    
    # Clean title without emojis
    title_box = dict(boxstyle="round,pad=0.6", facecolor='#2C3E50', edgecolor='#1A252F', linewidth=2)
    ax.text(6, 6.6, f"GradCAM Analysis: {pred_label.upper()}", 
            ha='center', va='center', fontsize=18, weight='bold', color='white',
            bbox=title_box)
    
    # Model info
    ax.text(6, 6.1, f"Model: {model_name}", ha='center', va='center', 
            fontsize=12, style='italic', color='#34495E')
    
    # Text highlighting with enhanced visual hierarchy
    y_pos = 5.0
    x_pos = 0.4
    max_width = 11.2
    line_height = 0.5
    
    # Better score normalization for clearer visual differences
    if max(scores) > 0:
        norm_scores = np.array(scores) / max(scores)
        # Apply sigmoid-like transformation for better visual separation
        norm_scores = 1 / (1 + np.exp(-5 * (norm_scores - 0.5)))
    else:
        norm_scores = np.array(scores)
    
    # Add importance scale indicator
    ax.text(0.4, 5.4, "Importance Scale:", fontsize=12, weight='bold', color='#2C3E50')
    
    for i, (word, score) in enumerate(zip(words, norm_scores)):
        # Refined color scheme for better contrast and readability
        if score > 0.80:
            bg_color = '#E74C3C'  # Strong red
            text_color = 'white'
            weight = 'bold'
            border_color = '#C0392B'
            importance = 'CRITICAL'
        elif score > 0.65:
            bg_color = '#E67E22'  # Orange
            text_color = 'white'
            weight = 'bold'
            border_color = '#D35400'
            importance = 'HIGH'
        elif score > 0.45:
            bg_color = '#F39C12'  # Yellow-orange
            text_color = 'black'
            weight = 'bold'
            border_color = '#E67E22'
            importance = 'MEDIUM'
        elif score > 0.25:
            bg_color = '#27AE60'  # Green
            text_color = 'white'
            weight = 'normal'
            border_color = '#229954'
            importance = 'LOW'
        else:
            bg_color = '#BDC3C7'  # Light gray
            text_color = '#2C3E50'
            weight = 'normal'
            border_color = '#95A5A6'
            importance = 'MINIMAL'
        
        # Calculate word width
        word_width = len(word) * 0.09 + 0.5
        
        # Line wrapping
        if x_pos + word_width > max_width:
            y_pos -= line_height
            x_pos = 0.4
        
        # Enhanced word styling with subtle shadow effect
        word_box = dict(boxstyle="round,pad=0.3", facecolor=bg_color, 
                       edgecolor=border_color, linewidth=1.5, alpha=0.95)
        
        word_text = ax.text(x_pos, y_pos, word, fontsize=14, weight=weight, color=text_color,
                           bbox=word_box, ha='left', va='center')
        
        # Show the display-normalized score above HIGH and CRITICAL words
        if score > 0.65:
            ax.text(x_pos + word_width/2, y_pos + 0.2, f'{score:.2f}', 
                   ha='center', va='center', fontsize=9, weight='bold', 
                   color=border_color, alpha=0.8)
        
        x_pos += word_width + 0.15
    
    # Enhanced legend with cleaner design
    legend_y = 1.8
    legend_spacing = 2.2
    
    # Legend background
    ax.add_patch(plt.Rectangle((0.2, 1.2), 11.6, 1.4, 
                              facecolor='#ECF0F1', edgecolor='#BDC3C7', 
                              linewidth=1.5, alpha=0.9))
    
    ax.text(6, 2.4, "Word Importance Legend", ha='center', va='center',
            fontsize=14, weight='bold', color='#2C3E50')
    
    legend_items = [
        ("CRITICAL", "#E74C3C", "white"),
        ("HIGH", "#E67E22", "white"), 
        ("MEDIUM", "#F39C12", "black"),
        ("LOW", "#27AE60", "white"),
        ("MINIMAL", "#BDC3C7", "#2C3E50")
    ]
    
    for i, (label, color, text_color) in enumerate(legend_items):
        x_pos = 0.8 + i * legend_spacing
        legend_box = dict(boxstyle="round,pad=0.3", facecolor=color, 
                         edgecolor='#2C3E50', linewidth=1.2, alpha=0.95)
        ax.text(x_pos, legend_y, label, fontsize=11, weight='bold',
                bbox=legend_box, ha='center', va='center', color=text_color)
    
    # Add key statistics in the legend area
    stats_text = f"Total words: {len(words)} | Max score: {max(scores):.3f} | Avg score: {np.mean(scores):.3f}"
    ax.text(6, 1.4, stats_text, ha='center', va='center',
            fontsize=11, color='#7F8C8D', style='italic')

def _create_word_importance_chart(ax, words, scores):
    """Create a polished word importance bar chart for poster presentation"""
    # Get top words
    word_scores = list(zip(words, scores))
    word_scores.sort(key=lambda x: x[1], reverse=True)
    
    top_n = min(8, len(word_scores))  # Reduced for cleaner look
    top_words, top_scores = zip(*word_scores[:top_n])
    
    # Enhanced color scheme
    colors = ['#FF4444', '#FF6644', '#FF8844', '#FFAA44', 
              '#FFCC44', '#DDCC44', '#BBAA44', '#99AA44'][:top_n]
    
    bars = ax.barh(range(len(top_words)), top_scores, color=colors, alpha=0.9, 
                   edgecolor='#333', linewidth=1.2)
    
    # Clean word labels (truncate if too long)
    clean_words = [w[:12] + '...' if len(w) > 12 else w for w in top_words]
    ax.set_yticks(range(len(clean_words)))
    ax.set_yticklabels(clean_words, fontsize=11, weight='bold')
    ax.set_xlabel('Importance Score', fontsize=11, weight='bold', color='#333')
    ax.set_title('Top Important Words', fontsize=13, weight='bold', pad=12, color='#333')
    ax.invert_yaxis()
    
    # Enhanced value labels
    for i, (bar, score) in enumerate(zip(bars, top_scores)):
        ax.text(score + 0.005, i, f'{score:.3f}', 
                va='center', ha='left', fontsize=10, weight='bold', color='#333')
    
    # Styling improvements
    ax.grid(True, axis='x', alpha=0.3, linestyle='--')
    ax.set_axisbelow(True)
    for spine in ax.spines.values():
        spine.set_edgecolor('#333')
        spine.set_linewidth(1.2)
    ax.tick_params(colors='#333', labelsize=10)

def _create_stats_summary(ax, scores, words, pred_label, model_name):
    """Create a clean statistics summary panel for poster presentation"""
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.axis('off')
    
    # Title
    ax.text(0.5, 0.95, 'Analysis Statistics', ha='center', va='top', 
            fontsize=14, weight='bold', color='#2C3E50',
            bbox=dict(boxstyle="round,pad=0.3", facecolor='#ECF0F1', 
                     edgecolor='#34495E', linewidth=2))
    
    # Key statistics with cleaner formatting
    stats_text = f"""Model: {model_name}
Prediction: {pred_label}

STATISTICS:
Words analyzed: {len(words)}
Max importance: {max(scores):.3f}
Mean importance: {np.mean(scores):.3f}
Std deviation: {np.std(scores):.3f}

DISTRIBUTION:
Critical (>0.80): {sum(1 for s in scores if s > 0.80)}
High (0.65-0.80): {sum(1 for s in scores if 0.65 < s <= 0.80)}
Medium (0.45-0.65): {sum(1 for s in scores if 0.45 < s <= 0.65)}
Low (0.25-0.45): {sum(1 for s in scores if 0.25 < s <= 0.45)}
Minimal (<=0.25): {sum(1 for s in scores if s <= 0.25)}"""
    
    ax.text(0.05, 0.85, stats_text.strip(), ha='left', va='top', 
            fontsize=10, color='#2C3E50', linespacing=1.5,
            bbox=dict(boxstyle="round,pad=0.4", facecolor='#F8F9FA', 
                     edgecolor='#BDC3C7', linewidth=1, alpha=0.95))

def _create_attribution_distribution(ax, scores, words):
    """Create attribution score distribution"""
    ax.hist(scores, bins=20, color='skyblue', alpha=0.7, edgecolor='black')
    ax.set_xlabel('Attribution Score', fontsize=11)
    ax.set_ylabel('Number of Words', fontsize=11)
    ax.set_title('Distribution of Word Importance Scores', fontsize=12, weight='bold')
    ax.grid(True, alpha=0.3)
    
    # Add statistics
    mean_score = np.mean(scores)
    std_score = np.std(scores)
    ax.axvline(mean_score, color='red', linestyle='--', linewidth=2, label=f'Mean: {mean_score:.3f}')
    ax.axvline(mean_score + std_score, color='orange', linestyle='--', alpha=0.7, label=f'+1σ: {mean_score + std_score:.3f}')
    ax.legend(fontsize=10)

print("✅ GradCAM implementation ready")

✅ GradCAM implementation ready
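
Grad-CAM was designed for convolutional feature maps shaped `(batch, channels, height, width)`, and Captum's `LayerGradCam` documentation describes summing over the channel dimension (dim 1). Assuming that same reduction is applied to a transformer layer output shaped `(batch, seq_len, hidden)`, the sequence axis would be the one collapsed, leaving one value per hidden unit rather than one per token. The NumPy sketch below only mimics the two reductions on random arrays. It does not call Captum, and all array names are hypothetical.

```python
import numpy as np

# Dummy activations and gradients for one transformer layer output,
# shaped (batch, seq_len, hidden). Values are random placeholders.
batch, seq_len, hidden = 1, 6, 8
rng = np.random.default_rng(0)
acts = rng.standard_normal((batch, seq_len, hidden))
grads = rng.standard_normal((batch, seq_len, hidden))

# Conv-style reduction (dim 1 treated as channels): average gradients over
# the trailing dims, weight the activations, then sum over dim 1.
conv_weights = grads.mean(axis=2, keepdims=True)               # (batch, seq, 1)
conv_style = (conv_weights * acts).sum(axis=1, keepdims=True)  # (batch, 1, hidden)

# Token-level variant: treat hidden units as channels and token positions
# as the "spatial" axis, so the result has one score per token.
channel_weights = grads.mean(axis=1, keepdims=True)            # (batch, 1, hidden)
token_scores = np.maximum((channel_weights * acts).sum(axis=2), 0)  # (batch, seq)

print("conv-style shape:", conv_style.shape)
print("token-level shape:", token_scores.shape)
```

Printing `attributions.shape` inside `explain_with_gradcam` is the reliable way to see which layout the installed Captum version returns before trusting the word-level highlights.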


## 6. 🎛️ Interactive Dashboard

In [9]:
class GeneralizedExplainabilityDashboard:
    """Interactive dashboard for model explainability analysis with model selection"""
    
    def __init__(self, available_models, validation_texts, validation_labels, validation_labels_text):
        self.available_models = available_models
        self.validation_texts = validation_texts
        self.validation_labels = validation_labels
        self.validation_labels_text = validation_labels_text
        self.current_model_wrapper = None
        self.setup_data()
        self.create_interface()
    
    def setup_data(self):
        """Setup data for mistake analysis (will be updated when model changes)"""
        self.incorrect_indices = []
        print("📊 Dashboard initialized - select a model to analyze mistakes")
    
    def update_mistake_data(self):
        """Update mistake analysis data when model changes"""
        if self.current_model_wrapper is None:
            return
        
        print(f"🔄 Analyzing mistakes for {self.current_model_wrapper.model_name}...")
        predictions = self.current_model_wrapper.predict_class(self.validation_texts)
        self.incorrect_indices = np.where(predictions != np.array(self.validation_labels))[0]
        
        # Update mistake selector options
        mistake_options = [(f"Mistake {i+1}: {self.validation_texts[idx][:50]}...", i) 
                          for i, idx in enumerate(self.incorrect_indices[:20])]  # Limit for performance
        
        if len(mistake_options) == 0:
            mistake_options = [("No mistakes found!", 0)]
        
        self.mistake_selector.options = mistake_options
        print(f"📊 Found {len(self.incorrect_indices)} mistakes out of {len(self.validation_texts)} samples")
    
    def create_interface(self):
        """Create the dashboard interface"""
        # Model selector
        model_options = [(name, name) for name in self.available_models.keys()]
        self.model_selector = widgets.Dropdown(
            options=model_options,
            value=None,  # No default selection
            description='Select Model:',
            style={'description_width': '120px'},
            layout=widgets.Layout(width='400px')
        )
        
        # Model type selector
        self.model_type_selector = widgets.ToggleButtons(
            options=[('PyTorch', 'pytorch'), ('ONNX', 'onnx')],
            value='pytorch',  # Default to PyTorch
            description='Model Type:',
            style={'description_width': '120px'},
            disabled=True  # Enable after model selection
        )
        
        # Input mode selector
        self.input_mode = widgets.ToggleButtons(
            options=[('Custom Text', 'custom'), ('Analyze Mistakes', 'mistakes')],
            value='custom',  # Start with custom text since no model is selected yet
            description='Analysis Mode:',
            style={'description_width': '120px'}
        )
        
        # Mistake selector
        self.mistake_selector = widgets.Dropdown(
            options=[("Select a model first", 0)],
            description='Select Mistake:',
            style={'description_width': '120px'},
            layout=widgets.Layout(width='600px'),
            disabled=True  # Enable after model loading
        )
        
        # Custom text input
        self.text_input = widgets.Textarea(
            value='The company reported strong quarterly earnings with revenue growth exceeding expectations.',
            placeholder='Enter financial text to analyze...',
            description='Text:',
            layout=widgets.Layout(width='100%', height='80px'),
            style={'description_width': '60px'}
        )
        
        # Control buttons
        self.load_model_button = widgets.Button(
            description='🔄 Load Model',
            button_style='info',
            layout=widgets.Layout(width='150px'),
            disabled=True  # Enable after model selection
        )
        
        self.analyze_button = widgets.Button(
            description='🚀 Analyze',
            button_style='primary',
            layout=widgets.Layout(width='120px'),
            disabled=True  # Enable after model loading
        )
        
        self.clear_button = widgets.Button(
            description='🧹 Clear',
            button_style='warning',
            layout=widgets.Layout(width='120px')
        )
        
        # Method selection
        available_methods = []
        if shap_available:
            available_methods.append('SHAP')
        if lime_available:
            available_methods.append('LIME')
        available_methods.append('Attention')  # Will check PyTorch requirement later
        if captum_available:
            available_methods.append('GradCAM')  # Will check PyTorch requirement later
        
        self.method_selector = widgets.SelectMultiple(
            options=available_methods,
            value=[available_methods[0]] if available_methods else [],
            description='Methods:',
            style={'description_width': '80px'},
            layout=widgets.Layout(width='200px', height='120px')
        )
        
        # Output tabs
        self.output_tabs = widgets.Tab()
        self.method_outputs = {}
        
        # Status output
        self.status_output = widgets.Output()
        
        # Event handlers
        self.model_selector.observe(self.on_model_change, names='value')
        self.model_type_selector.observe(self.on_model_type_change, names='value')
        self.input_mode.observe(self.on_mode_change, names='value')
        self.load_model_button.on_click(self.on_load_model)
        self.analyze_button.on_click(self.on_analyze)
        self.clear_button.on_click(self.on_clear)
        
        # No model is auto-selected; the user picks one from the dropdown
    
    def on_model_change(self, change):
        """Handle model selection change"""
        if change['new'] is None:
            self.model_type_selector.disabled = True
            self.load_model_button.disabled = True
            return
            
        model_info = self.available_models[change['new']]
        available_types = []
        
        if model_info['has_pytorch']:
            available_types.append(('PyTorch', 'pytorch'))
        if model_info['has_onnx']:
            available_types.append(('ONNX', 'onnx'))
        
        self.model_type_selector.options = available_types
        # Only enable loading when at least one format exists for this model
        self.model_type_selector.disabled = not available_types
        self.load_model_button.disabled = not available_types
        
        # Set default selection
        if available_types:
            if model_info['has_pytorch']:
                self.model_type_selector.value = 'pytorch'
            else:
                self.model_type_selector.value = 'onnx'
        
        with self.status_output:
            clear_output(wait=True)
            print(f"📋 Selected: {change['new']}")
            print(f"✅ Available formats: {', '.join([t[0] for t in available_types])}")
            print("👆 Click 'Load Model' to initialize")
    
    def on_model_type_change(self, change):
        """Handle model type change"""
        if self.model_selector.value and change['new']:
            with self.status_output:
                clear_output(wait=True)
                print(f"📋 Model: {self.model_selector.value}")
                print(f"🔧 Type: {change['new']}")
                print("👆 Click 'Load Model' to initialize")
    
    def on_mode_change(self, change):
        """Handle input mode change"""
        # Interface will be updated in display method
        pass
    
    def on_load_model(self, button):
        """Load the selected model"""
        try:
            if not self.model_selector.value:
                with self.status_output:
                    clear_output(wait=True)
                    print("❌ Please select a model first!")
                return
            
            model_info = self.available_models[self.model_selector.value]
            model_type = self.model_type_selector.value
            
            with self.status_output:
                clear_output(wait=True)
                print(f"🔄 Loading {model_info['name']} ({model_type})...")
                print("⏳ This may take a moment...")
            
            # Load the model
            self.current_model_wrapper = UniversalModelWrapper(model_info, model_type)
            
            # Update mistake data
            self.update_mistake_data()
            
            # Enable analysis controls
            self.analyze_button.disabled = False
            self.mistake_selector.disabled = False
            
            # Update method availability based on model type
            available_methods = []
            if shap_available:
                available_methods.append('SHAP')
            if lime_available:
                available_methods.append('LIME')
            if model_type == 'pytorch':  # Attention and GradCAM only work with PyTorch
                available_methods.append('Attention')
                if captum_available:
                    available_methods.append('GradCAM')
            
            self.method_selector.options = available_methods
            if available_methods:
                self.method_selector.value = [available_methods[0]]
            
            with self.status_output:
                clear_output(wait=True)
                print("✅ Model loaded successfully!")
                print(f"📊 Model: {model_info['name']} ({model_type})")
                print(f"🎯 Classes: {', '.join(self.current_model_wrapper.get_class_names())}")
                print(f"📋 Available methods: {', '.join(available_methods)}")
                print(f"🔍 Found {len(self.incorrect_indices)} mistakes for analysis")
                print("\n🚀 Ready for analysis! Select methods and click 'Analyze'")
                
        except Exception as e:
            with self.status_output:
                clear_output(wait=True)
                print(f"❌ Error loading model: {str(e)}")
                print("💡 Try selecting a different model or format")
    
    def on_analyze(self, button):
        """Handle analyze button click"""
        try:
            if self.current_model_wrapper is None:
                with self.status_output:
                    clear_output(wait=True)
                    print("❌ Please load a model first!")
                return
            
            # Get text and prediction info
            if self.input_mode.value == 'mistakes':
                # len() works for both the initial empty list and the NumPy array set after loading
                if len(self.incorrect_indices) == 0:
                    with self.status_output:
                        clear_output(wait=True)
                        print("❌ No mistakes found for this model!")
                    return
                
                mistake_idx = self.mistake_selector.value
                if mistake_idx >= len(self.incorrect_indices):
                    with self.status_output:
                        clear_output(wait=True)
                        print("❌ Invalid mistake selection!")
                    return
                
                sample_idx = self.incorrect_indices[mistake_idx]
                text = self.validation_texts[sample_idx]
                true_label = self.validation_labels_text[sample_idx]
                pred_class = int(self.current_model_wrapper.predict_class(text)[0])
                pred_label = self.current_model_wrapper.label_encoder.inverse_transform([pred_class])[0]
            else:
                text = self.text_input.value.strip()
                if not text:
                    with self.status_output:
                        clear_output(wait=True)
                        print("❌ Please enter some text to analyze!")
                    return
                pred_class = int(self.current_model_wrapper.predict_class(text)[0])
                pred_label = self.current_model_wrapper.label_encoder.inverse_transform([pred_class])[0]
                true_label = "Unknown"
            
            # Generate explanations for selected methods
            self.generate_explanations(text, pred_label, true_label, pred_class)
            
        except Exception as e:
            with self.status_output:
                clear_output(wait=True)
                print(f"❌ Error during analysis: {str(e)}")
                import traceback
                print(f"🔍 Details: {traceback.format_exc()}")
    
    def generate_explanations(self, text, pred_label, true_label, pred_class):
        """Generate selected explanations for the text"""
        selected_methods = list(self.method_selector.value)
        
        if not selected_methods:
            with self.status_output:
                clear_output(wait=True)
                print("❌ Please select at least one explanation method!")
            return
        
        # Create method outputs
        self.method_outputs = {}
        for method in selected_methods:
            self.method_outputs[method] = widgets.Output()
        
        self.output_tabs.children = list(self.method_outputs.values())
        for i, method in enumerate(self.method_outputs.keys()):
            self.output_tabs.set_title(i, f'{method}')
        
        # Create header (escape the text so user input cannot inject HTML into the summary)
        import html
        preview = html.escape(text[:200]) + ('...' if len(text) > 200 else '')
        header_html = f"""
        <div style='background: #f8f9fa; padding: 15px; margin: 10px 0; border-radius: 8px; 
                    border-left: 4px solid #007bff; box-shadow: 0 2px 8px rgba(0,0,0,0.1);'>
            <h4 style='margin: 0 0 10px 0; color: #007bff;'>📝 Analysis Summary</h4>
            <p style='margin: 5px 0;'><strong>Text:</strong> <em>"{preview}"</em></p>
            <p style='margin: 5px 0;'><strong>Model:</strong> {self.current_model_wrapper.model_name} ({self.current_model_wrapper.model_type})</p>
            <p style='margin: 5px 0;'><strong>Prediction:</strong> 
               <span style='color: #28a745; font-weight: bold;'>{pred_label}</span></p>
            {f'<p style="margin: 5px 0;"><strong>True Label:</strong> <span style="color: #dc3545; font-weight: bold;">{true_label}</span></p>' if true_label != "Unknown" else ''}
        </div>
        """
        
        with self.status_output:
            clear_output(wait=True)
            print("🧠 Generating explanations...")
            print(f"📋 Methods: {', '.join(selected_methods)}")
            print("⏳ This may take a moment...")
        
        # Generate explanations for each selected method
        for method in selected_methods:
            with self.method_outputs[method]:
                display(HTML(header_html))
                
                try:
                    # print(f"🔍 Running {method} analysis...")
                    if method == 'SHAP':
                        explain_with_shap(self.current_model_wrapper, text, pred_class)
                    elif method == 'LIME':
                        explain_with_lime(self.current_model_wrapper, text)
                    elif method == 'Attention':
                        explain_with_attention(self.current_model_wrapper, text)
                    elif method == 'GradCAM':
                        explain_with_gradcam(self.current_model_wrapper, text)
                    print(f"✅ {method} analysis complete!")
                except Exception as e:
                    print(f"❌ {method} failed: {str(e)}")
                    import traceback
                    print(f"🔍 Error details: {traceback.format_exc()}")
        
        with self.status_output:
            clear_output(wait=True)
            print("✅ Analysis complete!")
            print("📊 Check the tabs above for detailed explanations")
            print("🔄 Modify text or methods and click 'Analyze' again for new results")
    
    def on_clear(self, button):
        """Clear all outputs"""
        for output in self.method_outputs.values():
            with output:
                clear_output()
        
        with self.status_output:
            clear_output(wait=True)
            print("🧹 All results cleared! Ready for new analysis.")
    
    def display(self):
        """Display the dashboard"""
        # Title
        title = widgets.HTML(
            value="""
            <div style='text-align: center; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
                        color: white; padding: 20px; border-radius: 10px; margin-bottom: 20px;'>
                <h2 style='margin: 0; font-size: 24px;'>🧠 Generalized Financial Sentiment Explainability Dashboard</h2>
                <p style='margin: 10px 0 0 0; opacity: 0.9;'>Universal AI model explanation and analysis</p>
            </div>
            """
        )
        
        # Model selection section
        model_section = widgets.VBox([
            widgets.HTML("<h3>🔧 Model Configuration</h3>"),
            widgets.HBox([self.model_selector, self.model_type_selector, self.load_model_button],
                        layout=widgets.Layout(align_items='center'))
        ])
        
        # Create dynamic input container
        def get_input_widget():
            if self.input_mode.value == 'mistakes':
                return self.mistake_selector
            else:
                return self.text_input
        
        # Input section
        input_container = widgets.VBox([
            widgets.HTML("<h3>📝 Analysis Input</h3>"),
            self.input_mode,
            get_input_widget()
        ])
        
        # Update input container based on mode
        def update_input_display(*args):
            input_container.children = [
                widgets.HTML("<h3>📝 Analysis Input</h3>"),
                self.input_mode, 
                get_input_widget()
            ]
        
        self.input_mode.observe(update_input_display, names='value')
        
        # Controls section
        controls_section = widgets.VBox([
            widgets.HTML("<h3>⚡ Analysis Controls</h3>"),
            widgets.HBox([
                self.method_selector,
                widgets.VBox([self.analyze_button, self.clear_button],
                           layout=widgets.Layout(align_items='center'))
            ])
        ])
        
        # Status section
        status_section = widgets.VBox([
            widgets.HTML("<h3>📊 Status</h3>"),
            self.status_output
        ])
        
        # Results section
        results_section = widgets.VBox([
            widgets.HTML("<h3>📈 Results</h3>"),
            self.output_tabs
        ])
        
        # Main dashboard
        dashboard = widgets.VBox([
            title,
            model_section,
            input_container,
            controls_section,
            status_section,
            results_section
        ], layout=widgets.Layout(padding='10px'))
        
        return dashboard

print("✅ Enhanced Dashboard class defined")

✅ Enhanced Dashboard class defined


In [10]:
# Initialize and display the dashboard
try:
    # Clear any existing dashboard
    from IPython.display import clear_output
    clear_output(wait=True)
    
    dashboard = GeneralizedExplainabilityDashboard(
        available_models=available_models,
        validation_texts=validation_texts,
        validation_labels=validation_labels,
        validation_labels_text=validation_labels_text
    )
    
    print("🚀 Enhanced Dashboard initialized successfully!")
    print(f" Found {len(available_models)} models ready for analysis")
    print("\n  Quick Start Instructions:")
    print("1. 🎯 Select a model from the dropdown")
    print("2. 🔧 Choose PyTorch or ONNX format") 
    print("3. 🔄 Click 'Load Model' to initialize")
    print("4. 📝 Enter text or select 'Analyze Mistakes'")
    print("5. ✅ Pick explanation methods and click 'Analyze'")
    print("\n" + "="*60)
    
    # Display the dashboard with proper widget rendering
    dashboard_widget = dashboard.display()
    display(dashboard_widget)
    
    print("\n🎉 Dashboard ready! Start by selecting a model above.")
    
except Exception as e:
    print(f"❌ Error initializing dashboard: {str(e)}")
    import traceback
    print(f"🔍 Full error details:\n{traceback.format_exc()}")
    print("\n🔧 Fallback: Individual analysis functions available:")
    print("- explain_with_shap(model_wrapper, text, pred_class)")
    print("- explain_with_lime(model_wrapper, text)")
    print("- explain_with_attention(model_wrapper, text)")
    print("- explain_with_gradcam(model_wrapper, text)")

📊 Dashboard initialized - select a model to analyze mistakes
🚀 Enhanced Dashboard initialized successfully!
 Found 7 models ready for analysis

  Quick Start Instructions:
1. 🎯 Select a model from the dropdown
2. 🔧 Choose PyTorch or ONNX format
3. 🔄 Click 'Load Model' to initialize
4. 📝 Enter text or select 'Analyze Mistakes'
5. ✅ Pick explanation methods and click 'Analyze'





🎉 Dashboard ready! Start by selecting a model above.


## 7. 📚 Example Usage & Next Steps

This generalized explainability notebook provides comprehensive model interpretation capabilities across your entire model collection. Here's how to make the most of it:

### 🎯 Key Features
- **Universal Model Support**: Works with both PyTorch and ONNX models from your collection
- **Four Explainability Methods**: SHAP, LIME, Attention Visualization, and GradCAM
- **Interactive Dashboard**: Point-and-click interface for quick analysis
- **Mistake Analysis**: Automatically finds and analyzes model misclassifications
- **Custom Text Analysis**: Test any financial text with multiple explanation methods
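
At its core, mistake analysis means comparing predictions with ground-truth labels and keeping the positions where they disagree. The sketch below shows that step in plain Python. The function name `find_incorrect_indices` and the sample data are hypothetical, and the notebook's own implementation may differ.

```python
def find_incorrect_indices(predictions, labels):
    """Return the positions where the predicted class differs from the true class."""
    if len(predictions) != len(labels):
        raise ValueError("predictions and labels must have the same length")
    return [i for i, (pred, true) in enumerate(zip(predictions, labels)) if pred != true]

# Hypothetical encoded labels: 0 = negative, 1 = neutral, 2 = positive
example_preds = [2, 1, 0, 2, 1]
example_labels = [2, 0, 0, 1, 1]
incorrect = find_incorrect_indices(example_preds, example_labels)
print(incorrect)  # → [1, 3]
```

Each returned index can then be used to look up the original text and its true label for closer inspection.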

### 🔧 Configuration-Driven Approach
All settings are loaded from `../config/pipeline_config.json`, ensuring consistency with your training and evaluation pipeline:
- Model paths and metadata
- Validation datasets
- Analysis parameters
- Output configurations
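
For orientation, here is a minimal sketch of reading a JSON config file with the standard library. The keys shown (`models_dir`, `validation_data`) are hypothetical, since the actual schema of `pipeline_config.json` is not shown here, and the notebook itself imports `ConfigManager` from `src.pipeline_utils` rather than reading raw JSON.

```python
import json
import tempfile
from pathlib import Path

def load_pipeline_config(path):
    """Load a JSON config file and return its contents as a dict."""
    path = Path(path)
    if not path.is_file():
        raise FileNotFoundError(f"Config file not found: {path}")
    with path.open(encoding="utf-8") as handle:
        return json.load(handle)

# Demo with a temporary file so the sketch runs anywhere;
# in the notebook the path would be ../config/pipeline_config.json.
with tempfile.TemporaryDirectory() as tmp_dir:
    demo_path = Path(tmp_dir) / "pipeline_config.json"
    demo_settings = {"models_dir": "models/", "validation_data": "data/val.csv"}  # hypothetical keys
    demo_path.write_text(json.dumps(demo_settings), encoding="utf-8")
    demo_config = load_pipeline_config(demo_path)

print(demo_config["models_dir"])  # → models/
```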

### 🚀 For Fine-Tuning Integration
This notebook is designed to support your fine-tuning workflow:

1. **Pre-Fine-Tuning Analysis**: Understand what your base models focus on
2. **Mistake Identification**: Find systematic errors to address in fine-tuning
3. **Post-Fine-Tuning Comparison**: Compare explanations before and after fine-tuning
4. **Method Selection**: Choose the best explanation methods for your specific use case
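
One way to make the post-fine-tuning comparison in step 3 concrete is to measure how much the most important tokens overlap between two explanations. The sketch below assumes each explanation has already been reduced to a `{token: score}` dict. The helpers `top_k_tokens` and `top_k_overlap` and the scores are hypothetical and not part of this notebook.

```python
def top_k_tokens(attributions, k):
    """Return the k tokens with the largest absolute attribution score."""
    ranked = sorted(attributions, key=lambda tok: abs(attributions[tok]), reverse=True)
    return set(ranked[:k])

def top_k_overlap(before, after, k=3):
    """Jaccard overlap of the top-k tokens in two explanations (0.0 to 1.0)."""
    top_before = top_k_tokens(before, k)
    top_after = top_k_tokens(after, k)
    union = top_before | top_after
    if not union:
        return 1.0  # two empty explanations are trivially identical
    return len(top_before & top_after) / len(union)

# Made-up attribution scores for the same sentence, before and after fine-tuning
before = {"profit": 0.42, "fell": -0.35, "quarter": 0.05, "the": 0.01, "sharply": -0.20}
after = {"profit": 0.30, "fell": -0.50, "quarter": 0.02, "the": 0.00, "sharply": -0.10, "guidance": 0.25}
print(top_k_overlap(before, after, k=3))  # → 0.5
```

A low overlap suggests that fine-tuning shifted what the model attends to, which is worth checking against the mistakes you set out to fix.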

### 🎨 Visualization Options
- **SHAP**: Per-token feature importance, showing whether each token pushes the prediction toward or away from a class
- **LIME**: Local interpretable model-agnostic explanations
- **Attention**: Attention weight heatmaps (PyTorch models only)
- **GradCAM**: Gradient-based class activation mapping (PyTorch models only)
- **Interactive Interface**: Compare multiple methods side-by-side

### 📈 Performance Considerations
- Models are loaded on-demand to manage memory
- ONNX models typically provide faster inference
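- Explanation methods differ in cost: SHAP and LIME are perturbation-based and typically run the model many times per explanation, while Attention and GradCAM typically need only a single forward pass (plus a backward pass for GradCAM)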
- Results are cached for repeated analysis
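
How the notebook caches results is not shown in this section. As an illustration only, a cache keyed by model, method, and text could look like the sketch below. `ExplanationCache` and `slow_explain` are hypothetical names, and `slow_explain` is a stand-in for a real explanation call.

```python
class ExplanationCache:
    """Memoize explanation results by (model name, method, text)."""

    def __init__(self):
        self._store = {}
        self.hits = 0
        self.misses = 0

    def get_or_compute(self, model_name, method, text, compute):
        """Return the cached result, or call compute(text) and store its result."""
        key = (model_name, method, text)
        if key in self._store:
            self.hits += 1
            return self._store[key]
        self.misses += 1
        result = compute(text)
        self._store[key] = result
        return result

def slow_explain(text):
    """Stand-in for an expensive explanation call; maps each token to its length."""
    return {token: len(token) for token in text.split()}

cache = ExplanationCache()
first = cache.get_or_compute("demo-model", "SHAP", "profits rose", slow_explain)
second = cache.get_or_compute("demo-model", "SHAP", "profits rose", slow_explain)
print(cache.hits, cache.misses)  # → 1 1
```

Including the model name and method in the key keeps results from different models or explanation methods from overwriting each other.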

### 🔄 Next Steps
1. Run the dashboard above to start exploring your models
2. Analyze systematic mistakes across different model architectures
3. Use insights to guide fine-tuning data augmentation
4. Compare explanation consistency across model types
5. Document findings for model selection decisions
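
For step 2, a simple starting point is to look for validation samples that several models get wrong. The sketch below assumes you have already collected the misclassified sample indices per model. The function `shared_mistakes`, the model names, and the indices are made up for illustration.

```python
from collections import Counter

def shared_mistakes(mistakes_by_model, min_models=2):
    """Return sorted sample indices misclassified by at least `min_models` models."""
    counts = Counter()
    for indices in mistakes_by_model.values():
        counts.update(set(indices))  # set() so each model counts a sample once
    return sorted(idx for idx, n in counts.items() if n >= min_models)

# Hypothetical per-model lists of misclassified validation indices
mistakes_by_model = {
    "model_a": [3, 17, 42, 58],
    "model_b": [3, 42, 99],
    "model_c": [3, 17, 101],
}
print(shared_mistakes(mistakes_by_model))                # → [3, 17, 42]
print(shared_mistakes(mistakes_by_model, min_models=3))  # → [3]
```

Samples that every model misclassifies are good candidates for checking label quality or for targeted data augmentation.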

Happy exploring! 🧠✨