# 🧠 Explainability-Driven Fine-Tuning for Financial NLP Models

## Overview
This notebook demonstrates how to leverage explainability methods to guide the fine-tuning of financial NLP models. Rather than treating explainability as a post-training analysis tool, we use it as an integral part of the fine-tuning process to create more robust and interpretable models.

### Key Objectives
1. **Identify Model Weaknesses**: Use explainability to discover systematic errors and attention biases
2. **Design Targeted Fine-Tuning**: Create data augmentation and loss strategies based on explainability insights
3. **Optimize for Interpretability**: Balance performance improvements with explainable decision boundaries
4. **Quantify Explainability Improvements**: Track changes in both accuracy and interpretability metrics

### Methodology
This notebook builds on the comprehensive explainability analysis from notebook #5, focusing specifically on using those insights to drive fine-tuning decisions. We'll implement:

- **Feature Importance-Based Augmentation**: Targeted data augmentation based on SHAP/LIME insights
- **Attention-Guided Training**: Modified attention mechanisms based on attention visualization  
- **Counterfactual Fine-Tuning**: Training with explainability-generated counterfactual examples
- **Attribution Preservation**: Loss terms that encourage maintaining useful attribution patterns
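
As a rough illustration of the first item, the sketch below swaps tokens with high attribution scores for synonyms to produce extra training variants. It is a minimal sketch under stated assumptions, not the notebook's implementation: the `SYNONYMS` table, the `augment_by_importance` helper, and the importance scores are all hypothetical stand-ins, and the swap is assumed to preserve the label.

```python
from typing import Dict, List

# Hypothetical synonym table for illustration; not part of the pipeline.
SYNONYMS: Dict[str, List[str]] = {
    "strong": ["robust", "solid"],
    "declining": ["falling", "shrinking"],
}

def augment_by_importance(
    text: str,
    importance: Dict[str, float],
    threshold: float = 0.1,
    synonyms: Dict[str, List[str]] = SYNONYMS,
) -> List[str]:
    """Return copies of `text` with each high-importance token swapped for a synonym."""
    tokens = text.split()
    variants = []
    for pos, token in enumerate(tokens):
        key = token.lower().strip(".,;:!?")
        # Skip tokens whose attribution magnitude is at or below the threshold.
        if abs(importance.get(key, 0.0)) <= threshold:
            continue
        for alt in synonyms.get(key, []):
            swapped = tokens.copy()
            swapped[pos] = alt  # casing and punctuation of the original token are dropped
            variants.append(" ".join(swapped))
    return variants

# Importance scores here are made-up stand-ins for SHAP/LIME output.
variants = augment_by_importance(
    "Strong demand drove record sales this quarter.",
    {"strong": 0.42, "sales": 0.03},
)
print(variants)
```

In practice the synonym source would need to be a curated financial lexicon, since naive swaps can change sentiment in this domain.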

### Academic Focus
This research-oriented approach provides:
- Systematic methodology for explainability-driven optimization
- Quantitative metrics for measuring explainability impact
- Comparative analysis of different fine-tuning strategies
- Visual documentation of improvement patterns

### Pipeline Integration
The notebook integrates with the existing model training pipeline and reuses explainability tools from previous notebooks to maintain consistency across the workflow.

In [1]:
# Import necessary libraries
import sys
import os
sys.path.append("../")

# Pipeline utilities - reuse existing infrastructure
from src.pipeline_utils import ConfigManager, StateManager, LoggingManager

# Core libraries
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import warnings
import pickle
import json
import time
from tqdm.auto import tqdm
from typing import Dict, List, Optional, Tuple, Any, Union
from collections import defaultdict, Counter
import random

# Suppress warnings
warnings.filterwarnings('ignore')

# Model and tokenizer for fine-tuning
from transformers import (
    AutoTokenizer, 
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding,
    EarlyStoppingCallback,
    get_linear_schedule_with_warmup
)
from datasets import Dataset
from torch.utils.data import DataLoader
from torch.optim import AdamW

# Explainability libraries - only import what we need
print("🔍 Importing explainability libraries...")
try:
    import shap
    shap_available = True
    print("✅ SHAP available")
except ImportError:
    print("⚠️ SHAP not available. Install with: pip install shap")
    shap_available = False

try:
    from lime.lime_text import LimeTextExplainer
    lime_available = True
    print("✅ LIME available")
except ImportError:
    print("⚠️ LIME not available. Install with: pip install lime")
    lime_available = False

# Visualization and interactivity
import ipywidgets as widgets
from IPython.display import display, HTML, clear_output

# Initialize configuration managers
config = ConfigManager("../config/pipeline_config.json")
state = StateManager("../config/pipeline_state.json")
logger_manager = LoggingManager(config, 'explainability_fine_tuning')
logger = logger_manager.get_logger()

print("✅ All libraries imported successfully")
print(f"📂 Models directory: {config.get('models', {}).get('output_dir', 'models')}")
print(f"📊 Data directory: {config.get('data', {}).get('processed_data_dir', 'data/processed')}")

logger.info("🔍 Starting Explainability-Driven Fine-Tuning Pipeline")

🔍 Importing explainability libraries...


2025-08-11 14:50:10,637 - pipeline.explainability_fine_tuning - INFO - 🔍 Starting Explainability-Driven Fine-Tuning Pipeline


✅ SHAP available
✅ LIME available
✅ All libraries imported successfully
📂 Models directory: models
📊 Data directory: data/processed


## 1. 🔍 Load Models & Data from Previous Notebooks

We'll leverage the model discovery and data loading logic from the previous explainability notebook to avoid code duplication.

In [2]:
# Load models and data using existing pipeline infrastructure
print("🔍 Discovering available models...")

# Model discovery (reuse logic from notebook 5)
models_config = config.get('models', {})
models_dir = Path(f"../{models_config.get('output_dir', 'models')}")
print(f"📂 Models directory: {models_dir}")

available_models = {}
if models_dir.exists():
    for model_path in models_dir.iterdir():
        if not model_path.is_dir() or model_path.name.startswith('.'):
            continue
            
        model_name = model_path.name
        config_file = model_path / "config.json"
        label_encoder_file = model_path / "label_encoder.pkl"
        pytorch_files = list(model_path.glob("*.safetensors")) + list(model_path.glob("pytorch_model.bin"))
        
        if config_file.exists() and label_encoder_file.exists() and pytorch_files:
            available_models[model_name] = {
                'name': model_name,
                'path': model_path,
                'config_file': config_file,
                'label_encoder_file': label_encoder_file,
                'pytorch_files': pytorch_files
            }
            print(f"   ✅ Found: {model_name}")

print(f"📊 Total models available: {len(available_models)}")

# Load training data
data_config = config.get('data', {})
processed_data_dir = data_config.get('processed_data_dir', 'data/processed')

# Try to load training data
train_path = f"../{processed_data_dir}/train.csv"
val_path = f"../{processed_data_dir}/validation.csv"

print(f"📊 Loading training data from: {processed_data_dir}")

# Load training data with fallback
try:
    train_df = pd.read_csv(train_path)
    val_df = pd.read_csv(val_path)
    print(f"✅ Loaded {len(train_df)} training samples, {len(val_df)} validation samples")
except FileNotFoundError:
    print("⚠️ Standard data files not found, creating sample data...")
    # Create sample data for testing
    sample_data = {
        'text': [
            "The company reported strong quarterly earnings with revenue growth exceeding expectations.",
            "Market volatility continues to pose challenges for the financial sector.",
            "The business maintained steady performance despite economic headwinds.",
            "Declining sales figures indicate potential market challenges ahead.",
            "The merger announcement boosted investor confidence significantly.",
            "Regulatory changes may impact future profitability.",
            "Strong demand drove record sales this quarter.",
            "Economic uncertainty affects investor sentiment."
        ] * 20,  # Repeat for more samples
        'label': ["positive", "negative", "neutral", "negative", "positive", "negative", "positive", "negative"] * 20
    }
    
    train_df = pd.DataFrame(sample_data)
    val_df = train_df.sample(frac=0.3, random_state=42)  # Use 30% for validation
    train_df = train_df.drop(val_df.index)
    
    print(f"✅ Created sample data: {len(train_df)} training, {len(val_df)} validation samples")

# Extract features and labels
train_texts = train_df['text'].tolist()
val_texts = val_df['text'].tolist()

# Get unique labels and create label encoders
unique_labels = sorted(set(train_df['label'].unique()) | set(val_df['label'].unique()))
label_to_id = {label: i for i, label in enumerate(unique_labels)}
id_to_label = {i: label for label, i in label_to_id.items()}

train_labels = [label_to_id[label] for label in train_df['label']]
val_labels = [label_to_id[label] for label in val_df['label']]

print(f"🏷️ Labels: {', '.join(unique_labels)}")
print(f"📋 Data ready: {len(train_texts)} training, {len(val_texts)} validation samples")

logger.info("Model and data discovery completed")

2025-08-11 14:50:10,675 - pipeline.explainability_fine_tuning - INFO - Model and data discovery completed


🔍 Discovering available models...
📂 Models directory: ../models
   ✅ Found: tinybert-financial-classifier-fine-tuned
   ✅ Found: all-MiniLM-L6-v2-financial-sentiment
   ✅ Found: distilbert-financial-sentiment
   ✅ Found: finbert-tone-financial-sentiment
   ✅ Found: tinybert-financial-classifier
   ✅ Found: tinybert-financial-classifier-pruned
   ✅ Found: mobilebert-uncased-financial-sentiment
📊 Total models available: 7
📊 Loading training data from: data/processed
✅ Loaded 4361 training samples, 485 validation samples
🏷️ Labels: negative, neutral, positive
📋 Data ready: 4361 training, 485 validation samples
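
The discovery rule used in the cell above (a directory counts as a model only if it contains `config.json`, `label_encoder.pkl`, and a weights file) could be factored into a small helper. The sketch below is one possible way to do that; `is_complete_model_dir` is a hypothetical name, and the temporary directory only exists to demonstrate the check.

```python
import tempfile
from pathlib import Path

REQUIRED_FILES = ("config.json", "label_encoder.pkl")

def is_complete_model_dir(path: Path) -> bool:
    """True if `path` is a visible directory with config, label encoder, and weights."""
    if not path.is_dir() or path.name.startswith("."):
        return False
    has_weights = any(path.glob("*.safetensors")) or (path / "pytorch_model.bin").exists()
    return has_weights and all((path / name).exists() for name in REQUIRED_FILES)

# Demonstrate on a throwaway directory tree.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    good = root / "demo-model"
    good.mkdir()
    for name in ("config.json", "label_encoder.pkl", "model.safetensors"):
        (good / name).touch()
    bad = root / "incomplete"
    bad.mkdir()
    (bad / "config.json").touch()  # no label encoder, no weights
    found = sorted(p.name for p in root.iterdir() if is_complete_model_dir(p))

print(found)
```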


## 2. 🧠 Explainability-Driven Fine-Tuning Core

This section implements the core methodology for using explainability insights to guide fine-tuning decisions.
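
Before reading the full class, it may help to see the central bookkeeping step in isolation: mistakes are grouped by confusion pattern (`true_to_pred`) and each pattern is given a priority by how often it occurs. The sketch below assumes the same dictionary keys and count thresholds (5 and 3) as the class that follows; the function names and example mistakes are invented for illustration.

```python
from collections import defaultdict
from typing import Dict, List

def group_confusions(mistakes: List[dict]) -> Dict[str, List[str]]:
    """Group misclassified texts under a 'true_to_pred' key."""
    patterns = defaultdict(list)
    for m in mistakes:
        key = f"{m['true_class_name']}_to_{m['pred_class_name']}"
        patterns[key].append(m["text"])
    return dict(patterns)

def priority(count: int) -> str:
    """Map a pattern's frequency to a priority label."""
    if count >= 5:
        return "HIGH"
    if count >= 3:
        return "MEDIUM"
    return "LOW"

# Made-up mistakes, shaped like the entries analyze_baseline_performance collects.
mistakes = [
    {"text": "Sales were flat.", "true_class_name": "neutral", "pred_class_name": "positive"},
    {"text": "Margins held steady.", "true_class_name": "neutral", "pred_class_name": "positive"},
    {"text": "Guidance was unchanged.", "true_class_name": "neutral", "pred_class_name": "positive"},
    {"text": "Costs rose sharply.", "true_class_name": "negative", "pred_class_name": "neutral"},
]

patterns = group_confusions(mistakes)
# Most frequent confusion first, mirroring the prioritization in the class.
ranked = sorted(patterns.items(), key=lambda kv: len(kv[1]), reverse=True)
for name, texts in ranked:
    print(f"{name}: {len(texts)} cases [{priority(len(texts))}]")
```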

In [None]:
class ExplainabilityDrivenFineTuner:
    """
    Core class for explainability-driven fine-tuning optimization.
    Uses insights from SHAP and LIME to guide fine-tuning strategies.
    """
    
    def __init__(self, base_model_path, tokenizer, label_encoder, config):
        self.base_model_path = base_model_path
        self.tokenizer = tokenizer
        self.label_encoder = label_encoder
        self.config = config
        
        # Load base model with eager attention for explainability
        self.model = AutoModelForSequenceClassification.from_pretrained(
            base_model_path,
            attn_implementation="eager",  # For explainability compatibility
            num_labels=len(label_encoder.classes_)
        )
        
        # Initialize explainability analyzers
        self.shap_analyzer = None
        self.lime_analyzer = None
        
        # Track insights and improvements
        self.explainability_insights = {
            'mistake_patterns': {},
            'token_importance': {},
            'decision_boundaries': {},
            'class_confusion': {}
        }
        
        self.fine_tuning_strategy = {
            'data_augmentation': [],
            'training_focus': [],
            'hyperparameters': {}
        }
        
        print(f"✅ Initialized ExplainabilityDrivenFineTuner for {Path(base_model_path).name}")
    
    def analyze_baseline_performance(self, val_texts, val_labels, max_samples=100):
        """
        Analyze baseline model performance to identify areas for improvement
        """
        print("🔍 Analyzing baseline model performance...")
        
        # Subsample the validation data to keep the analysis fast
        if len(val_texts) > max_samples:
            indices = random.sample(range(len(val_texts)), max_samples)
            sampled_texts = [val_texts[i] for i in indices]
            sampled_labels = [val_labels[i] for i in indices]
        else:
            sampled_texts = val_texts
            sampled_labels = val_labels
        
        # Get predictions
        predictions = []
        confidences = []
        
        self.model.eval()
        with torch.no_grad():
            for text in tqdm(sampled_texts, desc="Getting predictions"):
                inputs = self.tokenizer(text, return_tensors='pt', truncation=True, max_length=512)
                outputs = self.model(**inputs)
                probs = torch.softmax(outputs.logits, dim=-1)
                pred_class = torch.argmax(probs, dim=-1).item()
                confidence = torch.max(probs).item()
                
                predictions.append(pred_class)
                confidences.append(confidence)
        
        # Calculate metrics
        from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
        
        accuracy = accuracy_score(sampled_labels, predictions)
        report = classification_report(sampled_labels, predictions, 
                                     target_names=self.label_encoder.classes_, 
                                     output_dict=True)
        conf_matrix = confusion_matrix(sampled_labels, predictions)
        
        # Identify mistakes and low-confidence predictions
        mistakes = []
        low_confidence = []
        
        for i, (text, true_label, pred_label, conf) in enumerate(zip(sampled_texts, sampled_labels, predictions, confidences)):
            if true_label != pred_label:
                mistakes.append({
                    'index': i,
                    'text': text,
                    'true_label': true_label,
                    'pred_label': pred_label,
                    'confidence': conf,
                    'true_class_name': self.label_encoder.classes_[true_label],
                    'pred_class_name': self.label_encoder.classes_[pred_label]
                })
            
            if conf < 0.7:  # Low confidence threshold
                low_confidence.append({
                    'index': i,
                    'text': text,
                    'true_label': true_label,
                    'pred_label': pred_label,
                    'confidence': conf
                })
        
        baseline_analysis = {
            'accuracy': accuracy,
            'classification_report': report,
            'confusion_matrix': conf_matrix,
            'mistakes': mistakes,
            'low_confidence': low_confidence,
            'avg_confidence': np.mean(confidences)
        }
        
        print(f"📊 Baseline Analysis:")
        print(f"   • Accuracy: {accuracy:.3f}")
        print(f"   • Mistakes: {len(mistakes)} out of {len(sampled_texts)}")
        print(f"   • Low confidence predictions: {len(low_confidence)}")
        print(f"   • Average confidence: {np.mean(confidences):.3f}")
        
        return baseline_analysis
    
    def extract_explainability_insights(self, mistakes, train_texts=None, train_labels=None, max_samples=20):
        """
        Extract insights from explainability analysis of mistakes
        """
        print("🧠 Extracting explainability insights from mistakes...")
        
        if len(mistakes) == 0:
            print("✅ No mistakes found in the analyzed sample; nothing to explain.")
            return {}
        
        # Sample mistakes if too many
        if len(mistakes) > max_samples:
            sampled_mistakes = random.sample(mistakes, max_samples)
            print(f"📊 Analyzing {max_samples} mistakes out of {len(mistakes)}")
        else:
            sampled_mistakes = mistakes
        
        insights = {
            'mistake_patterns': {},
            'token_importance': {},
            'class_confusion': {}
        }
        
        # Pattern analysis - group mistakes by confusion type
        for mistake in sampled_mistakes:
            true_class = mistake['true_class_name']
            pred_class = mistake['pred_class_name']
            pattern = f"{true_class}_to_{pred_class}"
            
            if pattern not in insights['mistake_patterns']:
                insights['mistake_patterns'][pattern] = []
            insights['mistake_patterns'][pattern].append(mistake['text'])
        
        # SHAP analysis if available
        if shap_available:
            print("📊 Running SHAP analysis on mistakes...")
            try:
                # Initialize SHAP analyzer
                def predict_fn(texts):
                    predictions = []
                    self.model.eval()
                    with torch.no_grad():
                        for text in texts:
                            inputs = self.tokenizer(text, return_tensors='pt', truncation=True, max_length=512)
                            outputs = self.model(**inputs)
                            probs = torch.softmax(outputs.logits, dim=-1).numpy()
                            predictions.append(probs[0])
                    return np.array(predictions)
                
                self.shap_analyzer = shap.Explainer(predict_fn, self.tokenizer)
                
                # Analyze first few mistakes with SHAP
                mistake_texts = [m['text'] for m in sampled_mistakes[:5]]  # Limit for performance
                shap_values = self.shap_analyzer(mistake_texts)
                
                # Extract token importance patterns
                for i, text in enumerate(mistake_texts):
                    mistake = sampled_mistakes[i]
                    pred_class = mistake['pred_label']
                    
                    # Get SHAP values for the predicted class
                    values = shap_values[i, :, pred_class].values
                    tokens = shap_values[i].data
                    
                    # Store important tokens
                    for token, importance in zip(tokens, values):
                        if abs(importance) > 0.1 and token.strip():  # Significance threshold
                            if token not in insights['token_importance']:
                                insights['token_importance'][token] = []
                            insights['token_importance'][token].append(importance)
                
                print("✅ SHAP analysis completed")
                
            except Exception as e:
                print(f"⚠️ SHAP analysis failed: {e}")
        
        # LIME analysis if available
        if lime_available:
            print("🔍 Running LIME analysis on boundary cases...")
            try:
                self.lime_analyzer = LimeTextExplainer(class_names=list(self.label_encoder.classes_))
                
                def predict_fn_lime(texts):
                    predictions = []
                    self.model.eval()
                    with torch.no_grad():
                        for text in texts:
                            inputs = self.tokenizer(text, return_tensors='pt', truncation=True, max_length=512)
                            outputs = self.model(**inputs)
                            probs = torch.softmax(outputs.logits, dim=-1).numpy()
                            predictions.append(probs[0])
                    return np.array(predictions)
                
                # Analyze boundary cases (mistakes with moderate confidence)
                boundary_cases = [m for m in sampled_mistakes if 0.4 < m['confidence'] < 0.8][:3]
                if not boundary_cases:
                    print("⚠️ No boundary cases found for LIME analysis (0.4 < confidence < 0.8). Skipping LIME.")
                else:
                    for mistake in boundary_cases:
                        try:
                            explanation = self.lime_analyzer.explain_instance(
                                mistake['text'],
                                predict_fn_lime,
                                labels=(mistake['pred_label'],),  # explain the predicted class; LIME defaults to label 1 only
                                num_features=10
                            )
                            # Store boundary features
                            features = explanation.as_list(mistake['pred_label'])
                            for feature, importance in features:
                                if abs(importance) > 0.1:
                                    if feature not in insights['token_importance']:
                                        insights['token_importance'][feature] = []
                                    insights['token_importance'][feature].append(importance)
                        except Exception as lime_case_exc:
                            print(f"⚠️ LIME failed for text: '{mistake['text'][:60]}...' | Error: {lime_case_exc}")
                    print(f"✅ LIME analysis completed for {len(boundary_cases)} boundary case(s)")
            except Exception as e:
                print(f"⚠️ LIME analysis failed: {e}")
                import traceback
                print(f"🔍 LIME traceback: {traceback.format_exc()}")
        else:
            print("⚠️ LIME is not available. Please install with: pip install lime")
        
        # Store insights
        self.explainability_insights.update(insights)
        
        print(f"📋 Insights extracted:")
        print(f"   • Mistake patterns: {len(insights['mistake_patterns'])}")
        print(f"   • Important tokens identified: {len(insights['token_importance'])}")
        
        return insights
    
    def design_fine_tuning_strategy(self, insights):
        """
        Design comprehensive fine-tuning strategy based on explainability insights
        """
        print("🎯 Designing comprehensive fine-tuning strategy based on insights...")
        
        strategy = {
            'data_augmentation': [],
            'training_focus': [],
            'hyperparameters': {}
        }
        
        # 1. Address mistake patterns with prioritization
        if 'mistake_patterns' in insights:
            # Sort patterns by frequency (more mistakes = higher priority)
            pattern_priorities = sorted(insights['mistake_patterns'].items(), 
                                      key=lambda x: len(x[1]), reverse=True)
            
            print(f"📈 Prioritizing {len(pattern_priorities)} confusion patterns:")
            
            for pattern, examples in pattern_priorities:
                if len(examples) >= 1:  # Include all patterns, even single mistakes
                    priority = "HIGH" if len(examples) >= 5 else "MEDIUM" if len(examples) >= 3 else "LOW"
                    strategy['data_augmentation'].append({
                        'type': 'contrastive_examples',
                        'pattern': pattern,
                        'count': len(examples),
                        'priority': priority,
                        'recommendation': f"Generate contrastive examples for {pattern} confusion ({len(examples)} cases)"
                    })
                    print(f"   • {pattern}: {len(examples)} cases [{priority} priority]")
        
        # 2. Focus on problematic tokens with importance weighting
        if 'token_importance' in insights and insights['token_importance']:
            # Calculate average absolute importance for each token
            avg_importance = {token: np.mean(np.abs(values)) 
                            for token, values in insights['token_importance'].items()}
            
            # Get top tokens by importance
            top_tokens = sorted(avg_importance.items(), key=lambda x: x[1], reverse=True)[:15]
            
            if top_tokens:
                high_impact_tokens = [token for token, importance in top_tokens if importance > 0.1]
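                # Tokens are only recorded upstream when |importance| > 0.1, so this bucket
                # stays empty unless that recording threshold is lowered.
                medium_impact_tokens = [token for token, importance in top_tokens if 0.05 <= importance <= 0.1]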
                
                if high_impact_tokens:
                    strategy['training_focus'].append({
                        'type': 'token_attention',
                        'tokens': high_impact_tokens,
                        'priority': 'HIGH',
                        'recommendation': f"Focus heavily on {len(high_impact_tokens)} high-impact tokens"
                    })
                
                if medium_impact_tokens:
                    strategy['training_focus'].append({
                        'type': 'token_attention',
                        'tokens': medium_impact_tokens,
                        'priority': 'MEDIUM',
                        'recommendation': f"Focus moderately on {len(medium_impact_tokens)} medium-impact tokens"
                    })
                
                print(f"🎯 Token focus strategy:")
                print(f"   • High-impact tokens: {len(high_impact_tokens)}")
                print(f"   • Medium-impact tokens: {len(medium_impact_tokens)}")
        
        # 3. Adaptive hyperparameters based on mistake complexity
        total_mistakes = sum(len(examples) for examples in insights.get('mistake_patterns', {}).values())
        mistake_complexity = len(insights.get('mistake_patterns', {}))
        
        # Adjust learning parameters based on complexity
        if mistake_complexity > 5:  # Many different confusion patterns
            learning_rate = 2e-5  # More conservative for complex patterns
            num_epochs = 6  # More epochs needed
        elif mistake_complexity > 3:  # Moderate complexity
            learning_rate = 3e-5  # Moderate rate
            num_epochs = 5
        else:  # Simple patterns
            learning_rate = 4e-5  # More aggressive
            num_epochs = 4
        
        strategy['hyperparameters'] = {
            'learning_rate': learning_rate,
            'batch_size': 16,
            'num_epochs': num_epochs,
            'warmup_ratio': 0.15,
            'weight_decay': 0.01,
            'max_grad_norm': 1.0,
            'fp16': torch.cuda.is_available(),  # mixed precision needs a CUDA device
            'dataloader_drop_last': False,
            'eval_accumulation_steps': 1,
            'logging_steps': 10,
            'save_steps': 200,
            'eval_steps': 50,
            'metric_for_best_model': 'eval_accuracy',
            'load_best_model_at_end': True,
            'greater_is_better': True,
        }
        
        self.fine_tuning_strategy.update(strategy)
        
        print("📋 Comprehensive fine-tuning strategy:")
        print(f"   • Data augmentation: {len(strategy['data_augmentation'])} pattern-based strategies")
        print(f"   • Training focus: {len(strategy['training_focus'])} token-based strategies")
        print(f"   • Adaptive learning rate: {learning_rate:.1e} (based on {mistake_complexity} patterns)")
        print(f"   • Training epochs: {num_epochs} (based on complexity)")
        
        return strategy
    
    def fine_tune_model(self, train_texts, train_labels, val_texts, val_labels, strategy, output_dir="../models/explainability_fine_tuned/"):
        """
        Fine-tune the model using explainability-guided recommendations.
        Applies hyperparameters and (optionally) data augmentation from strategy.
        """
        # os, Dataset, TrainingArguments, Trainer and DataCollatorWithPadding are imported in the first cell
        os.makedirs(output_dir, exist_ok=True)
        
        # Use recommended hyperparameters
        hyperparams = strategy.get('hyperparameters', {})
        
        # Optionally augment or reweight data based on strategy['data_augmentation'] (not implemented here)
        train_data = {
            'text': train_texts,
            'label': train_labels
        }
        val_data = {
            'text': val_texts,
            'label': val_labels
        }
        
        train_dataset = Dataset.from_dict(train_data)
        val_dataset = Dataset.from_dict(val_data)
        
        def preprocess(batch):
            return self.tokenizer(batch['text'], truncation=True, padding='max_length', max_length=128)
        
        train_dataset = train_dataset.map(preprocess, batched=True)
        val_dataset = val_dataset.map(preprocess, batched=True)
        
        training_args = TrainingArguments(
            output_dir=output_dir,
            learning_rate=hyperparams.get('learning_rate', 3e-5),
            num_train_epochs=hyperparams.get('num_epochs', 3),
            per_device_train_batch_size=hyperparams.get('batch_size', 16),
            warmup_ratio=hyperparams.get('warmup_ratio', 0.1),
            weight_decay=hyperparams.get('weight_decay', 0.01),
            logging_steps=hyperparams.get('logging_steps', 10),
            evaluation_strategy="steps",  # newer transformers releases rename this argument to eval_strategy
            eval_steps=hyperparams.get('eval_steps', 50),
            save_steps=hyperparams.get('save_steps', 200),
            load_best_model_at_end=True,
            metric_for_best_model="eval_accuracy",
            greater_is_better=True,
            fp16=hyperparams.get('fp16', False),
        )
        
        data_collator = DataCollatorWithPadding(self.tokenizer)
        
        from sklearn.metrics import accuracy_score, f1_score
        def compute_metrics(eval_pred):
            logits, labels = eval_pred
            preds = np.argmax(logits, axis=1)
            acc = accuracy_score(labels, preds)
            f1 = f1_score(labels, preds, average='weighted')
            return {"accuracy": acc, "f1": f1}
        
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            tokenizer=self.tokenizer,
            data_collator=data_collator,
            compute_metrics=compute_metrics,
        )
        
        print("\n🚀 Fine-tuning model...")
        trainer.train()
        print("✅ Fine-tuning complete!")
        
        # Save the fine-tuned model
        trainer.save_model(output_dir)
        print(f"📁 Model saved to {output_dir}")
        return trainer

print("✅ ExplainabilityDrivenFineTuner class defined (with fine-tuning)")

✅ ExplainabilityDrivenFineTuner class defined
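
`design_fine_tuning_strategy` ranks tokens by the mean of their absolute attribution values. The sketch below isolates that step to show why absolute values are used: a token that pushes strongly in opposite directions across examples would otherwise average out to something small. The `rank_tokens` helper and the scores are hypothetical, chosen only to illustrate the aggregation.

```python
from statistics import mean
from typing import Dict, List, Tuple

def rank_tokens(token_importance: Dict[str, List[float]], top_k: int = 15) -> List[Tuple[str, float]]:
    """Rank tokens by mean absolute attribution, highest first."""
    avg = {
        token: mean(abs(v) for v in values)
        for token, values in token_importance.items()
        if values  # skip tokens with no recorded attributions
    }
    return sorted(avg.items(), key=lambda kv: kv[1], reverse=True)[:top_k]

# Made-up attribution values standing in for collected SHAP/LIME scores.
scores = {
    "growth": [0.30, -0.10],   # signed mean would be 0.10; mean |value| is 0.20
    "despite": [-0.25],
    "quarter": [0.12, 0.12],
}

ranked_tokens = rank_tokens(scores)
for token, value in ranked_tokens:
    print(f"{token}: {value:.2f}")
```

One design consequence is that the ranking discards direction; if the sign matters for a given strategy (for example, down-weighting tokens that push toward the wrong class), the signed values would need to be kept alongside.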


## 3. 🎮 Interactive Fine-Tuning Dashboard

This section provides an interactive interface to run the explainability-driven fine-tuning process.

In [None]:
class ExplainabilityFineTuningDashboard:
    """
    Interactive dashboard for explainability-driven fine-tuning
    """
    
    def __init__(self, available_models, train_data, val_data):
        self.available_models = available_models
        self.train_texts, self.train_labels = train_data
        self.val_texts, self.val_labels = val_data
        self.fine_tuner = None
        self.last_strategy = None
        
        self.create_interface()
    
    def create_interface(self):
        """Create the dashboard interface"""
        
        # Model selector
        model_options = [(name, name) for name in self.available_models.keys()]
        self.model_selector = widgets.Dropdown(
            options=model_options,
            description='Base Model:',
            style={'description_width': '120px'},
            layout=widgets.Layout(width='400px')
        )
        
        # Control buttons
        self.analyze_button = widgets.Button(
            description='🔍 Analyze Model',
            button_style='info',
            layout=widgets.Layout(width='150px')
        )
        
        self.fine_tune_button = widgets.Button(
            description='🚀 Fine-Tune',
            button_style='success',
            layout=widgets.Layout(width='150px'),
            disabled=True
        )
        
        self.benchmark_button = widgets.Button(
            description='📊 Run Benchmarks',
            button_style='warning',
            layout=widgets.Layout(width='150px'),
            disabled=True
        )
        
        # Progress and status
        self.status_output = widgets.Output()
        
        # Event handlers
        self.analyze_button.on_click(self.on_analyze)
        self.fine_tune_button.on_click(self.on_fine_tune)
        self.benchmark_button.on_click(self.on_benchmark)
    
    def on_analyze(self, button):
        """Analyze selected model for fine-tuning opportunities"""
        with self.status_output:
            clear_output(wait=True)
            
            if not self.model_selector.value:
                print("❌ Please select a model first!")
                return
            
            model_info = self.available_models[self.model_selector.value]
            
            try:
                print(f"🔄 Loading model: {model_info['name']}")
                
                # Load tokenizer and label encoder
                tokenizer = AutoTokenizer.from_pretrained(str(model_info['path']))
                
                with open(model_info['label_encoder_file'], 'rb') as f:
                    label_encoder = pickle.load(f)
                
                # Initialize fine-tuner
                self.fine_tuner = ExplainabilityDrivenFineTuner(
                    str(model_info['path']),
                    tokenizer,
                    label_encoder,
                    config
                )
                
                print("🔍 Analyzing baseline performance...")
                baseline_analysis = self.fine_tuner.analyze_baseline_performance(
                    self.val_texts, self.val_labels
                )
                
                print("🧠 Extracting explainability insights...")
                insights = self.fine_tuner.extract_explainability_insights(
                    baseline_analysis['mistakes'],
                    train_texts=self.train_texts,
                    train_labels=self.train_labels
                )
                
                print("🎯 Designing fine-tuning strategy...")
                strategy = self.fine_tuner.design_fine_tuning_strategy(insights)
                self.last_strategy = strategy
                
                print("\n✅ Analysis complete! Ready for fine-tuning.")
                self.fine_tune_button.disabled = False
                
            except Exception as e:
                print(f"❌ Analysis failed: {str(e)}")
                import traceback
                print(f"🔍 Details: {traceback.format_exc()}")
    
    def on_fine_tune(self, button):
        """Execute fine-tuning based on explainability insights"""
        with self.status_output:
            clear_output(wait=True)
            if self.fine_tuner is None or self.last_strategy is None:
                print("❌ Please analyze a model first!")
                return
            
            try:
                print("🚀 Starting explainability-guided fine-tuning...")
                trainer = self.fine_tuner.fine_tune_model(
                    self.train_texts, self.train_labels, self.val_texts, self.val_labels, self.last_strategy
                )
                print("✅ Fine-tuning complete!")
                print("📁 Model saved to ../models/explainability_fine_tuned/")
                self.benchmark_button.disabled = False
            except Exception as e:
                print(f"❌ Fine-tuning failed: {str(e)}")
                import traceback
                print(f"🔍 Details: {traceback.format_exc()}")
    
    def on_benchmark(self, button):
        """Show how to benchmark the fine-tuned model (benchmarks run in notebook #7)"""
        with self.status_output:
            clear_output(wait=True)  # match the other handlers so stale output does not pile up
            if self.fine_tuner is None:
                print("❌ Please analyze a model first!")
                return
            
            try:
                # This handler only prints guidance; no benchmark is executed here
                print("📊 Benchmarking runs in a separate notebook.")
                
                print("\n🎯 Next Steps:")
                print("1. Open notebook #7 (7_benchmarks_generalized.ipynb)")
                print("2. Run all cells to benchmark your model")
                print("3. Analyze results and insights from explainability analysis")
                
            except Exception as e:
                print(f"❌ Benchmarking setup failed: {str(e)}")
                print("💡 Please manually run notebook #7 for benchmarking results")
    
    def display(self):
        """Display the dashboard"""
        title = widgets.HTML(
            value="""
            <div style='text-align: center; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
                        color: white; padding: 20px; border-radius: 10px; margin-bottom: 20px;'>
                <h2 style='margin: 0; font-size: 24px;'>🧠 Explainability-Driven Fine-Tuning Dashboard</h2>
                <p style='margin: 10px 0 0 0; opacity: 0.9;'>Optimize models using explainability insights</p>
            </div>
            """
        )
        
        controls = widgets.VBox([
            widgets.HTML("<h3>🔧 Model Selection</h3>"),
            self.model_selector,
            widgets.HTML("<h3>⚡ Actions</h3>"),
            widgets.HBox([self.analyze_button, self.fine_tune_button, self.benchmark_button]),
            widgets.HTML("<h3>📊 Status & Progress</h3>"),
            self.status_output
        ])
        
        return widgets.VBox([title, controls])

print("✅ ExplainabilityFineTuningDashboard class defined")

# Initialize and display the dashboard
try:
    if len(available_models) > 0:
        print("🔄 Setting up explainability-driven fine-tuning environment...")
        
        # Create the fine-tuning dashboard
        dashboard = ExplainabilityFineTuningDashboard(
            available_models,
            (train_texts, train_labels),
            (val_texts, val_labels)
        )
        
        print("🎉 Dashboard initialized!")
        print("\n📋 Instructions:")
        print("1. Select a base model from the dropdown")
        print("2. Click 'Analyze Model' to identify fine-tuning opportunities")
        print("3. Click 'Fine-Tune' to run explainability-guided fine-tuning")
        print("4. Click 'Run Benchmarks' for benchmarking instructions (notebook #7)")
        print("\n💡 This provides comprehensive explainability analysis for research")
        
        # Display the dashboard
        display(dashboard.display())
        
    else:
        print("❌ No models found.")
        print("💡 Please ensure you have trained models available in the models directory")
        
except Exception as e:
    print(f"❌ Error setting up dashboard: {str(e)}")
    print("\n🔧 Please ensure:")
    print("   1. Models are available in the models directory")
    print("   2. Training data is available")
    print("   3. All dependencies are installed")

✅ ExplainabilityFineTuningDashboard class defined
🔄 Setting up explainability-driven fine-tuning environment...
🎉 Dashboard initialized!

📋 Instructions:
1. Select a base model from the dropdown
2. Click 'Analyze Model' to identify fine-tuning opportunities
3. Click 'Fine-Tune' to run explainability-guided fine-tuning
4. Click 'Run Benchmarks' for benchmarking instructions (notebook #7)

💡 This provides comprehensive explainability analysis for research



## 4. 📈 Next Steps: Benchmarking & Research

After running the explainability-driven fine-tuning, here's how to proceed with your research:

### 🔬 Academic Research Workflow
1. **📊 Quantitative Analysis**: Run your benchmarking script to measure performance improvements
2. **🔍 Qualitative Analysis**: Compare explainability patterns before and after fine-tuning
3. **📝 Documentation**: Record insights and methodology for your research paper
4. **🔄 Iterative Refinement**: Refine the process based on results

### 📊 Key Metrics to Track
- **Accuracy Improvement**: Change in validation accuracy relative to the baseline model
- **Confidence Stability**: Reduction in the share of low-confidence predictions
- **Mistake Pattern Changes**: Shifts in which types of errors (true vs. predicted label) occur
- **Explainability Consistency**: How coherent attribution patterns are before vs. after fine-tuning
- **Training Efficiency**: Convergence speed and stability
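
As a rough illustration, the first three metrics can be computed from class probabilities alone. The helper names, the 0.6 confidence threshold, and the toy probabilities below are assumptions made for this sketch; they are not values or functions from the notebook.

```python
from collections import Counter
from typing import List, Sequence


def predict(probs: Sequence[Sequence[float]]) -> List[int]:
    """Pick the highest-probability class index for each example."""
    return [max(range(len(row)), key=row.__getitem__) for row in probs]


def accuracy(labels: Sequence[int], preds: Sequence[int]) -> float:
    """Fraction of examples whose prediction matches the label."""
    return sum(t == p for t, p in zip(labels, preds)) / len(labels)


def low_confidence_rate(probs: Sequence[Sequence[float]], threshold: float = 0.6) -> float:
    """Fraction of predictions whose top-class probability is below the threshold."""
    return sum(max(row) < threshold for row in probs) / len(probs)


def mistake_pairs(labels: Sequence[int], preds: Sequence[int]) -> Counter:
    """Count (true, predicted) label pairs over misclassified examples."""
    return Counter((t, p) for t, p in zip(labels, preds) if t != p)


# Toy data: 6 validation examples, 3 classes (illustrative values only)
labels = [0, 1, 2, 1, 0, 2]
base_probs = [[0.5, 0.3, 0.2], [0.4, 0.35, 0.25], [0.2, 0.3, 0.5],
              [0.1, 0.8, 0.1], [0.3, 0.55, 0.15], [0.1, 0.2, 0.7]]
tuned_probs = [[0.8, 0.1, 0.1], [0.3, 0.6, 0.1], [0.1, 0.2, 0.7],
               [0.1, 0.85, 0.05], [0.45, 0.5, 0.05], [0.05, 0.15, 0.8]]

base_preds, tuned_preds = predict(base_probs), predict(tuned_probs)

accuracy_gain = accuracy(labels, tuned_preds) - accuracy(labels, base_preds)
print(f"Accuracy gain: {accuracy_gain:+.3f}")
print(f"Low-confidence rate: {low_confidence_rate(base_probs):.2f} -> "
      f"{low_confidence_rate(tuned_probs):.2f}")
print(f"Baseline mistakes: {dict(mistake_pairs(labels, base_preds))}")
print(f"Fine-tuned mistakes: {dict(mistake_pairs(labels, tuned_preds))}")
```

Comparing the two `mistake_pairs` counters shows which confusions disappeared and which remain, which is often more informative than the accuracy delta alone.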

### 🎯 Research Contributions
This methodology is intended to demonstrate:
- **Novel Approach**: Using explainability to guide fine-tuning rather than only to analyze a model after training
- **Quantifiable Impact**: Improvements in both performance and interpretability that can be measured against the baseline
- **Systematic Framework**: Reproducible methodology for explainability-driven optimization
- **Domain Application**: Validation in financial NLP where interpretability is critical

### 📁 Outputs
- **Fine-tuned Model**: Saved to `../models/{model_name}-explainability-fine-tuned/`
- **Training Logs**: Available in `../logs/explainability_fine_tuning/`
- **Insights Data**: Stored in the `ExplainabilityDrivenFineTuner` instance
- **Benchmark Results**: Generated by running notebook #7 (`7_benchmarks_generalized.ipynb`)

### 🚀 Paper Sections This Supports
- **Methodology**: Systematic explainability-driven optimization process
- **Experiments**: Controlled comparison of fine-tuning strategies
- **Results**: Performance and interpretability improvements
- **Analysis**: Ablation studies and insight validation
- **Discussion**: Trade-offs between performance and interpretability

**🎉 Ready to run your benchmarking script and measure the improvements!**

The fine-tuned model will be available alongside your existing models for comparative analysis.