# 🧠 Financial Sentiment Model Explainability Dashboard

## Overview
This notebook provides comprehensive explainability analysis for the fine-tuned TinyBERT financial sentiment classification model. It includes four complementary explanation methods accessible through an interactive dashboard.

### Explanation Methods
- **🎯 SHAP**: Game-theory based feature importance
- **🔍 LIME**: Local interpretable model-agnostic explanations
- **👁️ Attention**: Model attention head visualization
- **🌡️ GradCAM**: Gradient-based visual attribution
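
To make "game-theory based feature importance" concrete, here is a minimal sketch of exact Shapley values on a toy problem. The scoring function `toy_score` and its three tokens are hypothetical and stand in for a model's output when only some tokens are kept. The `shap` library approximates these values rather than enumerating every ordering as this sketch does.

```python
from itertools import permutations

def shapley_values(players, value):
    """Exact Shapley values: average each player's marginal
    contribution over every possible ordering of the players."""
    totals = {p: 0.0 for p in players}
    orderings = list(permutations(players))
    for order in orderings:
        coalition = frozenset()
        for p in order:
            with_p = coalition | {p}
            totals[p] += value(with_p) - value(coalition)
            coalition = with_p
    return {p: total / len(orderings) for p, total in totals.items()}

def toy_score(tokens):
    """Hypothetical 'positive sentiment' score for a set of kept tokens."""
    score = 0.0
    if "profit" in tokens:
        score += 0.5
    if "rose" in tokens:
        score += 0.25
    if "profit" in tokens and "rose" in tokens:
        score += 0.25  # interaction term, shared between the two tokens
    return score

phi = shapley_values(["profit", "rose", "the"], toy_score)
for token, contribution in phi.items():
    print(f"{token}: {contribution:.3f}")
```

The contributions sum to the score of the full sentence minus the score of the empty one, which is the property that lets a SHAP plot be read as a decomposition of the prediction.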

### Dashboard Features
- **Mistake Analysis**: Examine specific model errors
- **Custom Text Analysis**: Test any financial text
- **Interactive Interface**: Tabbed layout for easy comparison
- **On-demand Computation**: Explanations are computed only when requested; the SHAP and LIME explainers are initialized lazily

## 1. 📦 Setup & Imports

In [1]:
!ls

0_setup.ipynb                   4_benchmarks.ipynb
1_data_processing.ipynb         5_explainability.ipynb
2_train_models.ipynb            5_explainability_original.ipynb
3_convert_to_onnx.ipynb         6_fine_tune.ipynb


In [2]:
%cd ..

/Users/matthew/Documents/deepmind_internship


In [3]:
# Core libraries
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Model and tokenizer
from transformers import BertTokenizerFast, BertForSequenceClassification
from sklearn.preprocessing import LabelEncoder

# Explainability libraries
import shap
from lime.lime_text import LimeTextExplainer
from bertviz import head_view
from captum.attr import LayerGradCam

# Dashboard components
import ipywidgets as widgets
from IPython.display import display, HTML, clear_output

print("✅ All libraries imported successfully")

✅ All libraries imported successfully


## 2. 🗂️ Data & Model Loading

In [4]:
# Configuration
MODEL_DIR = Path('models/tinybert-financial-classifier')
DATA_FILE = 'data/FinancialPhraseBank/all-data.csv'
RANDOM_SEED = 42
TEST_SIZE = 0.25

# Load full dataset and create train-test split
from sklearn.model_selection import train_test_split

# Load data with correct encoding and column names (matching training notebook)
df = pd.read_csv(DATA_FILE, header=None, names=["label", "sentence"], encoding="latin-1")
df["sentence"] = df["sentence"].str.strip('"')  # Remove extra quotes

# Create train-test split with same parameters as training
train_df, test_df = train_test_split(
    df, 
    test_size=TEST_SIZE, 
    random_state=RANDOM_SEED, 
    stratify=df['label']
)

# Extract test data
test_texts = test_df['sentence'].tolist()  # Note: column is 'sentence' not 'text'

# Load label encoder
import pickle
with open(MODEL_DIR / 'label_encoder.pkl', 'rb') as f:
    label_encoder = pickle.load(f)

true_labels_encoded = label_encoder.transform(test_df['label'])

print(f"📊 Loaded full dataset: {len(df)} samples")
print(f"📊 Test set: {len(test_texts)} samples (25% split)")
print(f"📋 Label classes: {list(label_encoder.classes_)}")
print(f"🎲 Random seed: {RANDOM_SEED}")
print(f"✅ Data loaded successfully with correct encoding")

📊 Loaded full dataset: 4846 samples
📊 Test set: 1212 samples (25% split)
📋 Label classes: ['negative', 'neutral', 'positive']
🎲 Random seed: 42
✅ Data loaded successfully with correct encoding


In [5]:
# Load model and tokenizer
print("🔄 Loading model and tokenizer...")

tokenizer = BertTokenizerFast.from_pretrained(MODEL_DIR)
pt_model = BertForSequenceClassification.from_pretrained(MODEL_DIR)
pt_model.eval()

print("✅ Model and tokenizer loaded successfully")
print(f"📱 Model type: {type(pt_model).__name__}")
print(f"🎯 Number of classes: {pt_model.config.num_labels}")

🔄 Loading model and tokenizer...
✅ Model and tokenizer loaded successfully
📱 Model type: BertForSequenceClassification
🎯 Number of classes: 3


## 3. 🔧 Core Prediction Functions

In [6]:
def predict_class(texts):
    """Predict sentiment class for text(s)"""
    if isinstance(texts, str):
        texts = [texts]
    
    predictions = []
    pt_model.eval()
    
    with torch.no_grad():
        for text in texts:
            encoding = tokenizer(text, return_tensors='pt', max_length=512, truncation=True, padding=True)
            outputs = pt_model(**encoding)
            predicted_class = torch.argmax(outputs.logits, dim=-1).item()
            predictions.append(predicted_class)
    
    return np.array(predictions)

def predict_probs_for_shap(texts):
    """Get prediction probabilities for SHAP"""
    if isinstance(texts, str):
        texts = [texts]
    
    all_probs = []
    pt_model.eval()
    
    with torch.no_grad():
        for text in texts:
            encoding = tokenizer(text, return_tensors='pt', max_length=512, truncation=True, padding=True)
            outputs = pt_model(**encoding)
            probs = torch.softmax(outputs.logits, dim=-1).squeeze().numpy()
            all_probs.append(probs)
    
    return np.array(all_probs)

def predict_probs_for_lime(texts):
    """Get prediction probabilities for LIME (one row of class probabilities per text)"""
    # LIME passes a list of perturbed strings; normalize any other input to a list of str
    if isinstance(texts, str):
        texts = [texts]
    else:
        texts = [str(t) for t in texts]
    
    all_probs = []
    pt_model.eval()
    
    with torch.no_grad():
        for text in texts:
            try:
                encoding = tokenizer(text, return_tensors='pt', max_length=512, truncation=True, padding=True)
                outputs = pt_model(**encoding)
                probs = torch.softmax(outputs.logits, dim=-1).squeeze().cpu().numpy()
                all_probs.append(probs)
            except Exception as e:
                print(f"Error processing text: {text[:50]}... Error: {e}")
                # Return default probabilities if processing fails
                all_probs.append(np.array([0.33, 0.33, 0.34]))
    
    return np.array(all_probs)

print("✅ Core prediction functions defined")

✅ Core prediction functions defined


## 4. 🧩 Explainability Methods

### 4.1 SHAP Implementation

In [7]:
# SHAP explainer (lazy initialization for performance)
_shap_explainer = None

def get_shap_explainer():
    """Get SHAP explainer (lazy initialization)"""
    global _shap_explainer
    if _shap_explainer is None:
        print("🧮 Initializing SHAP explainer...")
        _shap_explainer = shap.Explainer(predict_probs_for_shap, tokenizer)
    return _shap_explainer

def explain_with_shap(text, target_class=None):
    """Generate SHAP explanation for text"""
    print("⏳ Computing SHAP values...")
    
    explainer = get_shap_explainer()
    shap_values = explainer([text])
    
    if target_class is None:
        target_class = predict_class(text)[0]
    
    # Display SHAP plot
    shap.plots.text(shap_values[0, :, target_class])
    
    pred_label = label_encoder.inverse_transform([target_class])[0]
    print(f"📊 SHAP explanation for class: {pred_label}")

print("✅ SHAP implementation ready")

✅ SHAP implementation ready


### 4.2 LIME Implementation

In [8]:
# LIME explainer (lazy initialization)
_lime_explainer = None

def get_lime_explainer():
    """Get LIME explainer (lazy initialization)"""
    global _lime_explainer
    if _lime_explainer is None:
        _lime_explainer = LimeTextExplainer(
            class_names=label_encoder.classes_
            # Removed 'mode' parameter as it's not valid for LimeTextExplainer
        )
    return _lime_explainer

def explain_with_lime(text):
    """Generate LIME explanation for text"""
    print("⏳ Computing LIME explanation...")
    
    explainer = get_lime_explainer()
    explanation = explainer.explain_instance(
        text,
        predict_probs_for_lime,
        num_features=20,
        labels=(0, 1, 2)
    )
    
    display(HTML(explanation.as_html()))
    print("📊 LIME explanation generated")

print("✅ LIME implementation ready")

✅ LIME implementation ready


### 4.3 Attention Visualization

In [9]:
def explain_with_attention(text):
    """Generate attention visualization for text"""
    print("⏳ Generating attention visualization...")
    
    try:
        # Tokenize with attention output
        inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=512)
        
        # Get model outputs with attention
        with torch.no_grad():
            # Force eager attention for BertViz compatibility
            original_impl = getattr(pt_model.config, '_attn_implementation', None)
            pt_model.config._attn_implementation = 'eager'
            
            outputs = pt_model(**inputs, output_attentions=True)
            attentions = outputs.attentions
            tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
            
            # Restore original implementation
            if original_impl is not None:
                pt_model.config._attn_implementation = original_impl
        
        # Check if we have valid attention and tokens
        if attentions is None or len(attentions) == 0:
            print("❌ No attention weights available")
            return

        if len(tokens) == 0:
            print("❌ No tokens available")
            return
        
        pred_class = torch.argmax(outputs.logits, dim=-1).item()
        pred_label = label_encoder.inverse_transform([pred_class])[0]
        
        # Try BertViz first
        try:
            # Enable widget display
            from IPython.display import Javascript
            display(Javascript("""
                require.config({
                    paths: {
                        d3: 'https://d3js.org/d3.v5.min'
                    }
                });
            """))
            
            print("🎯 Attempting interactive attention visualization...")
            head_view(attentions, tokens)
            print(f"👁️ Interactive attention visualization for prediction: {pred_label}")

        except Exception as viz_error:
            print(f"❌ BertViz interactive view failed: {viz_error}")
            print("💡 Using custom attention heatmap...")
            
            # Custom attention visualization
            _visualize_attention_heatmap(attentions, tokens, pred_label)
            
    except Exception as e:
        print(f"❌ Attention analysis failed: {str(e)}")
        print("💡 This might be due to model architecture or BertViz compatibility issues")

def _visualize_attention_heatmap(attentions, tokens, pred_label):
    """Create custom attention heatmap visualization"""
    import matplotlib.pyplot as plt
    import seaborn as sns
    
    # Get average attention across all layers and heads
    # Shape: (num_layers, batch_size, num_heads, seq_len, seq_len)
    avg_attention = torch.stack(attentions).mean(dim=0)  # Average across layers
    avg_attention = avg_attention.mean(dim=1)  # Average across heads
    attention_matrix = avg_attention[0].detach().cpu().numpy()  # Get first (and only) batch
    
    # Clean tokens for display
    clean_tokens = []
    for token in tokens:
        if token.startswith('##'):
            clean_tokens.append(token[2:])
        elif token in ['[CLS]', '[SEP]', '[PAD]']:
            clean_tokens.append(token)
        else:
            clean_tokens.append(token)
    
    # Limit to reasonable size for visualization
    max_len = min(len(clean_tokens), 50)
    attention_matrix = attention_matrix[:max_len, :max_len]
    display_tokens = clean_tokens[:max_len]
    
    # Create visualization
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))
    
    # 1. Full attention heatmap
    sns.heatmap(attention_matrix, 
                xticklabels=display_tokens,
                yticklabels=display_tokens,
                cmap='Blues',
                ax=ax1,
                cbar_kws={'label': 'Attention Weight'})
    ax1.set_title(f'Attention Heatmap\nPrediction: {pred_label}', fontsize=14, weight='bold')
    ax1.set_xlabel('Attended Tokens')
    ax1.set_ylabel('Query Tokens')
    plt.setp(ax1.get_xticklabels(), rotation=45, ha='right')
    plt.setp(ax1.get_yticklabels(), rotation=0)
    
    # 2. CLS token attention (what the model focuses on for classification)
    cls_attention = attention_matrix[0, 1:]  # CLS token attention to other tokens (skip self-attention)
    tokens_for_cls = display_tokens[1:]  # Skip CLS token
    
    # Sort by attention weight
    token_attention_pairs = list(zip(tokens_for_cls, cls_attention))
    token_attention_pairs.sort(key=lambda x: x[1], reverse=True)
    
    # Take top 15 for readability
    top_tokens, top_weights = zip(*token_attention_pairs[:15])
    
    bars = ax2.barh(range(len(top_tokens)), top_weights, color='skyblue')
    ax2.set_yticks(range(len(top_tokens)))
    ax2.set_yticklabels(top_tokens)
    ax2.set_xlabel('Attention Weight')
    ax2.set_title(f'Top Attended Tokens for Classification\n(CLS token attention)', fontsize=14, weight='bold')
    ax2.invert_yaxis()
    
    # Add value labels on bars
    for i, (bar, weight) in enumerate(zip(bars, top_weights)):
        ax2.text(weight + 0.001, i, f'{weight:.3f}', 
                va='center', ha='left', fontsize=10)
    
    plt.tight_layout()
    plt.show()
    
    # Print statistics
    print(f"📊 Attention Statistics:")
    print(f"   • Number of layers: {len(attentions)}")
    # Each per-layer tensor has shape (batch_size, num_heads, seq_len, seq_len)
    print(f"   • Number of heads per layer: {attentions[0].shape[1]}")
    print(f"   • Sequence length: {len(tokens)}")
    print(f"   • Max attention weight: {attention_matrix.max():.4f}")
    print(f"   • Average attention weight: {attention_matrix.mean():.4f}")
    
    print(f"\n🎯 Top 5 tokens by CLS attention:")
    for i, (token, weight) in enumerate(token_attention_pairs[:5]):
        if token not in ['[SEP]', '[PAD]']:
            print(f"   {i+1}. '{token}': {weight:.4f}")
    
    print(f"👁️ Custom attention visualization complete for: {pred_label}")

print("✅ Attention visualization ready")

✅ Attention visualization ready


### 4.4 GradCAM Implementation

In [10]:
class ModelWrapper(torch.nn.Module):
    """Wrapper to fix SequenceClassifierOutput error with Captum"""
    def __init__(self, model):
        super().__init__()
        self.model = model
    
    def forward(self, input_ids, attention_mask=None):
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.logits

def explain_with_gradcam(text, target_class=None):
    """Generate GradCAM explanation for text"""
    print("⏳ Computing GradCAM attributions...")
    
    try:
        # Tokenize input
        inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=512)
        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']
        
        # Get prediction if not specified
        if target_class is None:
            with torch.no_grad():
                outputs = pt_model(**inputs)
                target_class = torch.argmax(outputs.logits, dim=1).item()
        
        # Ensure target_class is the right type for Captum
        target_class = int(target_class)  # Convert to Python int from numpy.int64
        
        # Use ModelWrapper for Captum compatibility
        wrapped_model = ModelWrapper(pt_model)
        wrapped_model.eval()
        
        # Try to access embedding layer with different paths
        embedding_layer = None
        try:
            embedding_layer = pt_model.bert.embeddings.word_embeddings
        except AttributeError:
            try:
                embedding_layer = pt_model.embeddings.word_embeddings
            except AttributeError:
                try:
                    embedding_layer = pt_model.get_input_embeddings()
                except AttributeError:
                    print("❌ Could not access embedding layer")
                    return
        
        if embedding_layer is None:
            print("❌ Embedding layer not found")
            return
        
        # Initialize LayerGradCam
        layer_gradcam = LayerGradCam(wrapped_model, embedding_layer)
        
        # Generate attributions
        attributions = layer_gradcam.attribute(
            input_ids,
            target=target_class,
            additional_forward_args=(attention_mask,)
        )
        
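        # Process attributions
        attribution_scores = attributions.squeeze().detach().cpu().numpy()
        tokens = tokenizer.convert_ids_to_tokens(input_ids.squeeze())
        
        if len(attribution_scores.shape) > 1:
            attribution_scores = attribution_scores.sum(axis=-1)
        
        # LayerGradCam sums over dim 1, which for a (batch, seq_len, hidden) embedding
        # output is the token axis, so the result may not hold one score per token.
        if attribution_scores.size != len(tokens):
            print(f"Warning: got {attribution_scores.size} attribution scores for "
                  f"{len(tokens)} tokens; token colors below may not be meaningful")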
        
        # Visualize
        _visualize_gradcam(tokens, attribution_scores, attention_mask, target_class)
        
        pred_label = label_encoder.inverse_transform([target_class])[0]
        print(f"🌡️ GradCAM explanation for class: {pred_label}")
        
    except Exception as e:
        print(f"❌ GradCAM error: {str(e)}")
        print("💡 Falling back to attention-based attribution...")
        
        # Fallback: Use attention weights as pseudo-GradCAM
        try:
            inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=512)
            with torch.no_grad():
                outputs = pt_model(**inputs, output_attentions=True)
                attentions = outputs.attentions
                
                if attentions is not None and len(attentions) > 0:
                    # Average attention across layers and heads
                    avg_attention = torch.stack(attentions).mean(dim=0).mean(dim=1)
                    cls_attention = avg_attention[0, 0, :].detach().cpu().numpy()
                    tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'].squeeze())
                    
                    # Visualize as pseudo-GradCAM
                    _visualize_gradcam(tokens, cls_attention, inputs['attention_mask'], 
                                     torch.argmax(outputs.logits, dim=-1).item())
                    print("📊 Used attention weights as attribution fallback")
                else:
                    print("❌ No attention weights available for fallback")
        except Exception as fallback_error:
            print(f"❌ Fallback also failed: {fallback_error}")
            print("💡 Try using SHAP or LIME for alternative explanations")

def _visualize_gradcam(tokens, attribution_scores, attention_mask, target_class):
    """Create GradCAM visualization"""
    fig, ax = plt.subplots(figsize=(14, 4))
    
    # Normalize attributions
    abs_attributions = np.abs(attribution_scores)
    if abs_attributions.max() > 0:
        normalized_attrs = abs_attributions / abs_attributions.max()
    else:
        normalized_attrs = abs_attributions
    
    # Plot tokens with color intensity
    colors = plt.cm.Reds(normalized_attrs)
    x_positions = []
    
    for i, (token, attr, color) in enumerate(zip(tokens, normalized_attrs, colors)):
        if token in ['[CLS]', '[SEP]', '[PAD]'] or attention_mask[0][i].item() == 0:
            continue
        
        clean_token = token.replace('##', '')
        if not clean_token.strip():
            continue
        
        x_pos = len(x_positions) * 1.2
        x_positions.append(x_pos)
        
        bbox_props = dict(boxstyle="round,pad=0.3", facecolor=color, alpha=0.8)
        ax.text(x_pos, 0.5, clean_token, fontsize=11, ha='center', va='center',
                bbox=bbox_props, weight='bold' if attr > 0.5 else 'normal')
    
    # Format plot
    if x_positions:
        ax.set_xlim(-0.5, max(x_positions) + 0.5)
    ax.set_ylim(0, 1)
    ax.axis('off')
    
    pred_label = label_encoder.inverse_transform([target_class])[0]
    ax.set_title(f'GradCAM Attribution for Class: {pred_label}\n(Darker Red = Higher Attribution)', 
                fontsize=14, pad=30, weight='bold')
    
    # Add colorbar
    sm = plt.cm.ScalarMappable(cmap=plt.cm.Reds, norm=plt.Normalize(vmin=0, vmax=1))
    sm.set_array([])
    cbar = plt.colorbar(sm, ax=ax, orientation='horizontal', shrink=0.6, pad=0.15)
    cbar.set_label('Attribution Intensity', fontsize=12)
    
    plt.tight_layout()
    plt.show()

print("✅ GradCAM implementation ready")

✅ GradCAM implementation ready


## 5. 🎛️ Interactive Dashboard

In [11]:
class ExplainabilityDashboard:
    """Interactive dashboard for model explainability analysis"""
    
    def __init__(self):
        self.setup_data()
        self.create_interface()
    
    def setup_data(self):
        """Setup data for mistake analysis"""
        predictions_encoded = predict_class(test_texts)
        self.incorrect_indices = np.where(predictions_encoded != true_labels_encoded)[0]
        print(f"📊 Found {len(self.incorrect_indices)} mistakes out of {len(test_texts)} samples")
    
    def create_interface(self):
        """Create the dashboard interface"""
        # Input mode selector
        self.input_mode = widgets.ToggleButtons(
            options=[('Analyze Mistakes', 'mistakes'), ('Custom Text', 'custom')],
            value='mistakes',
            description='Analysis Mode:',
            style={'description_width': 'initial'}
        )
        
        # Mistake selector
        mistake_options = [(f"Mistake {i+1}: {test_texts[idx][:50]}...", i) 
                          for i, idx in enumerate(self.incorrect_indices[:20])]  # Limit for performance
        self.mistake_selector = widgets.Dropdown(
            options=mistake_options,
            description='Select Mistake:',
            style={'description_width': 'initial'}
        )
        
        # Custom text input
        self.text_input = widgets.Textarea(
            placeholder='Enter financial text to analyze...',
            description='Text:',
            layout=widgets.Layout(width='100%', height='80px'),
            style={'description_width': 'initial'}
        )
        
        # Control buttons
        self.analyze_button = widgets.Button(
            description='🚀 Analyze',
            button_style='primary',
            layout=widgets.Layout(width='120px')
        )
        
        self.clear_button = widgets.Button(
            description='🧹 Clear',
            button_style='warning',
            layout=widgets.Layout(width='120px')
        )
        
        # Output tabs
        self.output_tabs = widgets.Tab()
        self.method_outputs = {
            'SHAP': widgets.Output(),
            'LIME': widgets.Output(),
            'Attention': widgets.Output(),
            'GradCAM': widgets.Output()
        }
        
        self.output_tabs.children = list(self.method_outputs.values())
        for i, method in enumerate(self.method_outputs.keys()):
            self.output_tabs.set_title(i, f'{method}')
        
        # Status output
        self.status_output = widgets.Output()
        
        # Event handlers
        self.input_mode.observe(self.on_mode_change, names='value')
        self.analyze_button.on_click(self.on_analyze)
        self.clear_button.on_click(self.on_clear)
    
    def on_mode_change(self, change):
        """Handle input mode change"""
        # Update the input container dynamically
        if hasattr(self, 'input_container'):
            if change['new'] == 'mistakes':
                self.input_container.children = [self.input_mode, self.mistake_selector]
            else:
                self.input_container.children = [self.input_mode, self.text_input]
    
    def update_interface(self):
        """Unused placeholder; on_mode_change updates the input container directly"""
        pass
    
    def on_analyze(self, button):
        """Handle analyze button click"""
        try:
            # Get text and prediction info
            if self.input_mode.value == 'mistakes':
                mistake_idx = self.mistake_selector.value
                sample_idx = self.incorrect_indices[mistake_idx]
                text = test_texts[sample_idx]
                true_label = label_encoder.inverse_transform([true_labels_encoded[sample_idx]])[0]
                pred_class = int(predict_class(text)[0])  # Ensure Python int
                pred_label = label_encoder.inverse_transform([pred_class])[0]
            else:
                text = self.text_input.value.strip()
                if not text:
                    with self.status_output:
                        clear_output(wait=True)
                        print("❌ Please enter some text to analyze!")
                    return
                pred_class = int(predict_class(text)[0])  # Ensure Python int
                pred_label = label_encoder.inverse_transform([pred_class])[0]
                true_label = "Unknown"
            
            # Generate explanations
            self.generate_explanations(text, pred_label, true_label, pred_class)
            
        except Exception as e:
            with self.status_output:
                clear_output(wait=True)
                print(f"❌ Error during analysis: {str(e)}")
    
    def generate_explanations(self, text, pred_label, true_label, pred_class):
        """Generate all explanations for the text"""
        # Clear outputs
        for output in self.method_outputs.values():
            with output:
                clear_output()
        
        # Create header (escape the text so characters such as < or & render literally)
        import html
        header_html = f"""
        <div style='background: #f8f9fa; padding: 15px; margin: 10px 0; border-radius: 8px; 
                    border-left: 4px solid #007bff; box-shadow: 0 2px 8px rgba(0,0,0,0.1);'>
            <h4 style='margin: 0 0 10px 0; color: #007bff;'>📝 Analysis Summary</h4>
            <p style='margin: 5px 0;'><strong>Text:</strong> <em>"{html.escape(text)}"</em></p>
            <p style='margin: 5px 0;'><strong>Model Prediction:</strong> 
               <span style='color: #28a745; font-weight: bold;'>{pred_label}</span></p>
            {f'<p style="margin: 5px 0;"><strong>True Label:</strong> <span style="color: #dc3545; font-weight: bold;">{true_label}</span></p>' if true_label != "Unknown" else ''}
        </div>
        """
        
        with self.status_output:
            clear_output(wait=True)
            print("🧠 Generating explanations...")
        
        # SHAP
        with self.method_outputs['SHAP']:
            display(HTML(header_html))
            try:
                explain_with_shap(text, pred_class)
            except Exception as e:
                print(f"❌ SHAP failed: {str(e)}")
        
        # LIME
        with self.method_outputs['LIME']:
            display(HTML(header_html))
            try:
                explain_with_lime(text)
            except Exception as e:
                print(f"❌ LIME failed: {str(e)}")
                print("💡 Common LIME issues:")
                print("   - Text preprocessing differences")
                print("   - Prediction function format mismatch")
                print("   - Try using SHAP instead")
        
        # Attention
        with self.method_outputs['Attention']:
            display(HTML(header_html))
            try:
                # BertViz doesn't work well in widget contexts, so use custom visualization
                print("⏳ Generating attention visualization...")
                print("💡 Using custom heatmap (BertViz widgets don't render in dashboard)")
                
                # Get attention data
                inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=512)
                with torch.no_grad():
                    # Force eager attention for compatibility
                    original_impl = getattr(pt_model.config, '_attn_implementation', None)
                    pt_model.config._attn_implementation = 'eager'
                    
                    outputs = pt_model(**inputs, output_attentions=True)
                    attentions = outputs.attentions
                    tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
                    
                    # Restore original implementation
                    if original_impl is not None:
                        pt_model.config._attn_implementation = original_impl
                
                if attentions is not None and len(attentions) > 0:
                    pred_label = label_encoder.inverse_transform([pred_class])[0]
                    _visualize_attention_heatmap(attentions, tokens, pred_label)
                else:
                    print("❌ No attention weights available")

            except Exception as e:
                print(f"❌ Attention failed: {str(e)}")
                print("💡 Common Attention issues:")
                print("   - BertViz compatibility with model architecture")
                print("   - JavaScript widget display problems")
                print("   - Try refreshing the notebook kernel")
        
        # GradCAM
        with self.method_outputs['GradCAM']:
            display(HTML(header_html))
            try:
                explain_with_gradcam(text, pred_class)
            except Exception as e:
                print(f"❌ GradCAM failed: {str(e)}")
                print("💡 Common GradCAM issues:")
                print("   - Model architecture compatibility")
                print("   - Captum version mismatch")
                print("   - GPU/CPU tensor issues")
        
        with self.status_output:
            clear_output(wait=True)
            print("✅ Analysis complete! Explore the tabs above.")
    
    def on_clear(self, button):
        """Clear all outputs"""
        for output in self.method_outputs.values():
            with output:
                clear_output()
        
        with self.status_output:
            clear_output(wait=True)
            print("🧹 All results cleared! Ready for new analysis.")
    
    def display(self):
        """Display the dashboard"""
        # Title
        title = widgets.HTML(
            value="""
            <div style='text-align: center; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
                        color: white; padding: 20px; border-radius: 10px; margin-bottom: 20px;'>
                <h2 style='margin: 0; font-size: 24px;'>🧠 Financial Sentiment Explainability Dashboard</h2>
                <p style='margin: 10px 0 0 0; opacity: 0.9;'>Comprehensive AI model explanation and analysis</p>
            </div>
            """
        )
        
        # Dynamic input section that updates based on mode
        self.input_container = widgets.VBox([
            self.input_mode,
            self.mistake_selector if self.input_mode.value == 'mistakes' else self.text_input
        ])
        
        # Controls
        controls = widgets.HBox([
            self.analyze_button,
            self.clear_button
        ], layout=widgets.Layout(justify_content='space-between', width='250px'))
        
        # Main dashboard
        dashboard = widgets.VBox([
            title,
            self.input_container,
            controls,
            self.status_output,
            self.output_tabs
        ])
        
        return dashboard

print("✅ Dashboard class defined")

✅ Dashboard class defined


## 6. 🚀 Launch Dashboard

In [12]:
# Initialize and display dashboard
print("🎯 Initializing Explainability Dashboard...")
dashboard = ExplainabilityDashboard()
display(dashboard.display())
print("✅ Dashboard is ready! Use the interface above to analyze model predictions.")

🎯 Initializing Explainability Dashboard...
📊 Found 253 mistakes out of 1212 samples


VBox(children=(HTML(value="\n            <div style='text-align: center; background: linear-gradient(135deg, #…

✅ Dashboard is ready! Use the interface above to analyze model predictions.


## 7. Quick Misclassification Analysis

This section analyzes the model's misclassifications to identify patterns that the next notebook can target during fine-tuning.

### 📋 Enhanced Analysis Methodology

The fine-tuning and pruning analysis runs in five steps, each producing inputs for model optimization:

**🔍 Step 1: Basic Performance Analysis**
- Generate predictions and confidence scores for the entire test set
- Calculate accuracy, error rates, and confidence distributions
- Identify misclassified samples and low-confidence predictions
- Establish baseline metrics for optimization tracking

**📊 Step 2: Per-Class Performance Analysis**
- Generate detailed confusion matrix and class-wise metrics
- Calculate precision, recall, and F1-scores for each sentiment class
- Identify most problematic classes requiring targeted fine-tuning
- Analyze error patterns between specific class pairs
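
As a rough illustration of how the per-class metrics in Step 2 follow from a confusion matrix, the sketch below derives precision, recall, and F1 in plain Python. The matrix values and the helper name `per_class_metrics` are made up for this example; the notebook itself relies on scikit-learn for these numbers.

```python
def per_class_metrics(cm):
    """Return {class_index: (precision, recall, f1)} from a square confusion matrix.

    Rows are true classes, columns are predicted classes.
    """
    n = len(cm)
    metrics = {}
    for k in range(n):
        tp = cm[k][k]
        fn = sum(cm[k]) - tp                        # true k, predicted as another class
        fp = sum(cm[r][k] for r in range(n)) - tp   # predicted k, true class differs
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = (2 * precision * recall / (precision + recall)
              if precision + recall else 0.0)
        metrics[k] = (precision, recall, f1)
    return metrics


# Hypothetical 3-class matrix (not the notebook's results).
toy_cm = [
    [9, 1, 0],
    [2, 6, 2],
    [0, 1, 9],
]
toy_metrics = per_class_metrics(toy_cm)
for k, (p, r, f) in toy_metrics.items():
    print(f"class {k}: P={p:.3f}, R={r:.3f}, F1={f:.3f}")
```

Reading the off-diagonal cells of a row shows which classes a given true class is most often mistaken for, which is the "class pair" view mentioned above.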

**🎯 Step 3: Confidence Distribution Analysis**
- Analyze prediction confidence across different thresholds
- Calculate coverage and accuracy at various confidence levels
- Identify low-confidence samples for fine-tuning focus
- Assess entropy distribution for pruning strategy recommendations
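
To make the Step 3 quantities concrete, here is a minimal sketch in plain Python of coverage, accuracy at a confidence threshold, and prediction entropy. The probabilities, labels, and function names are hypothetical. One deliberate difference from the analysis cell: this sketch returns `None` when no sample reaches the threshold, whereas the cell reports an accuracy of 0 in that case.

```python
import math


def coverage_and_accuracy(probs, preds, labels, threshold):
    """Coverage and accuracy over predictions whose top probability >= threshold."""
    kept = [i for i, p in enumerate(probs) if max(p) >= threshold]
    coverage = len(kept) / len(probs)
    if not kept:
        return coverage, None  # accuracy is undefined on an empty subset
    correct = sum(preds[i] == labels[i] for i in kept)
    return coverage, correct / len(kept)


def entropy(p, eps=1e-10):
    """Shannon entropy in nats; eps guards against log(0)."""
    return -sum(q * math.log(q + eps) for q in p)


# Hypothetical probabilities for four samples over three classes.
toy_probs = [
    [0.90, 0.05, 0.05],
    [0.40, 0.35, 0.25],
    [0.10, 0.80, 0.10],
    [0.34, 0.33, 0.33],
]
toy_preds = [max(range(3), key=p.__getitem__) for p in toy_probs]
toy_labels = [0, 1, 1, 2]

cov, acc = coverage_and_accuracy(toy_probs, toy_preds, toy_labels, 0.7)
empty_cov, empty_acc = coverage_and_accuracy(toy_probs, toy_preds, toy_labels, 0.95)
print(cov, acc, empty_cov, empty_acc)
```

A near-uniform distribution such as the last toy row has entropy close to ln(3), the maximum for three classes, while a peaked row like the first one scores much lower.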

**📝 Step 4: Enhanced Linguistic Pattern Analysis**
- Use TF-IDF over unigrams, bigrams, and trigrams for vocabulary analysis
- Identify both problematic keywords (higher in errors) and protective keywords (higher in correct predictions)
- Analyze n-gram patterns that correlate with model failures
- Generate targeted keywords for data augmentation strategies
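
The idea behind Step 4 can be sketched without any dependencies: compute TF-IDF for every text, average it per group, and subtract. The version below is a simplified stand-in, not the notebook's vectorizer. It assumes whitespace tokenization and unigrams only, with no stop-word or document-frequency filtering, and uses a smoothed IDF with L2 row normalization. The sentences and function names are invented for illustration.

```python
import math
from collections import Counter


def tfidf_rows(docs, vocab):
    """L2-normalized TF-IDF rows using smoothed IDF: ln((1 + n) / (1 + df)) + 1."""
    n = len(docs)
    tokenized = [d.lower().split() for d in docs]
    df = Counter(t for toks in tokenized for t in set(toks))
    idf = {t: math.log((1 + n) / (1 + df[t])) + 1 for t in vocab}
    rows = []
    for toks in tokenized:
        tf = Counter(toks)
        row = [tf[t] * idf[t] for t in vocab]
        norm = math.sqrt(sum(v * v for v in row)) or 1.0
        rows.append([v / norm for v in row])
    return rows


def mean_score_diff(error_texts, correct_texts):
    """Per-term mean TF-IDF in error_texts minus mean TF-IDF in correct_texts."""
    all_docs = error_texts + correct_texts
    vocab = sorted({t for d in all_docs for t in d.lower().split()})
    rows = tfidf_rows(all_docs, vocab)
    err, ok = rows[:len(error_texts)], rows[len(error_texts):]
    diff = {}
    for j, term in enumerate(vocab):
        mean_err = sum(r[j] for r in err) / len(err)
        mean_ok = sum(r[j] for r in ok) / len(ok)
        diff[term] = mean_err - mean_ok
    return diff


# Hypothetical sentences, not taken from the dataset.
toy_errors = ["profit fell despite higher sales", "sales fell slightly"]
toy_correct = ["profit rose sharply", "profit rose on higher sales"]
toy_diff = mean_score_diff(toy_errors, toy_correct)
for term, score in sorted(toy_diff.items(), key=lambda kv: kv[1], reverse=True)[:3]:
    print(f"{term}: {score:+.4f}")
```

Terms with a positive difference appear more heavily in misclassified texts ("problematic"), and terms with a negative difference appear more heavily in correctly classified texts ("protective").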

**💾 Step 5: Fine-Tuning & Pruning Recommendations**
- Generate specific learning rate recommendations based on current performance
- Identify high-priority samples for hard negative mining
- Provide confidence-based pruning strategies with expected performance impact
- Suggest targeted data augmentation approaches for problematic patterns
- Compile actionable recommendations for immediate implementation

In [None]:
# Universal Model Analysis (PyTorch + ONNX Compatible)
import onnxruntime as ort
from collections import defaultdict
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_fscore_support
import os
import json
import pickle  # needed by pickle.load() when loading the label encoder
import time
import numpy as np

class UniversalModelAnalyzer:
    """Universal analyzer that works with both PyTorch and ONNX models"""
    
    def __init__(self, model_path, model_type='auto'):
        """
        Initialize analyzer for different model types
        
        Args:
            model_path: Path to model directory
            model_type: 'pytorch', 'onnx', or 'auto' (auto-detect)
        """
        self.model_path = Path(model_path)
        self.model_name = self.model_path.name
        
        # Auto-detect model type if not specified
        if model_type == 'auto':
            onnx_path = self.model_path / 'onnx' / 'model.onnx'
            if onnx_path.exists():
                self.model_type = 'onnx'
            else:
                self.model_type = 'pytorch'
        else:
            self.model_type = model_type
        
        # Load tokenizer (same for both types)
        self.tokenizer = BertTokenizerFast.from_pretrained(model_path)
        
        # Load label encoder
        with open(self.model_path / 'label_encoder.pkl', 'rb') as f:
            self.label_encoder = pickle.load(f)
        
        # Load appropriate model
        if self.model_type == 'onnx':
            onnx_path = self.model_path / 'onnx' / 'model.onnx'
            self.session = ort.InferenceSession(str(onnx_path))
            print(f"✅ Loaded ONNX model: {self.model_name}")
        else:
            self.model = BertForSequenceClassification.from_pretrained(model_path)
            self.model.eval()
            print(f"✅ Loaded PyTorch model: {self.model_name}")
    
    def predict_batch(self, texts, batch_size=32):
        """Predict classes and probabilities for batch of texts"""
        if isinstance(texts, str):
            texts = [texts]
        
        predictions = []
        probabilities = []
        
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]
            
            if self.model_type == 'onnx':
                batch_preds, batch_probs = self._predict_onnx_batch(batch_texts)
            else:
                batch_preds, batch_probs = self._predict_pytorch_batch(batch_texts)
            
            predictions.extend(batch_preds)
            probabilities.extend(batch_probs)
        
        return np.array(predictions), np.array(probabilities)
    
    def _predict_onnx_batch(self, texts):
        """ONNX batch prediction"""
        # Tokenize batch
        encodings = self.tokenizer(
            texts, 
            return_tensors='np',
            max_length=512, 
            truncation=True, 
            padding=True
        )
        
        # Run ONNX inference
        inputs = {
            'input_ids': encodings['input_ids'].astype(np.int64),
            'attention_mask': encodings['attention_mask'].astype(np.int64)
        }
        
        outputs = self.session.run(None, inputs)
        logits = outputs[0]
        
        # Convert to predictions and probabilities
        predictions = np.argmax(logits, axis=-1)
        probabilities = self._softmax(logits)
        
        return predictions.tolist(), probabilities.tolist()
    
    def _predict_pytorch_batch(self, texts):
        """PyTorch batch prediction"""
        encodings = self.tokenizer(
            texts,
            return_tensors='pt',
            max_length=512,
            truncation=True,
            padding=True
        )
        
        with torch.no_grad():
            outputs = self.model(**encodings)
            logits = outputs.logits
            predictions = torch.argmax(logits, dim=-1)
            probabilities = torch.softmax(logits, dim=-1)
        
        return predictions.cpu().numpy().tolist(), probabilities.cpu().numpy().tolist()
    
    def _softmax(self, x):
        """Numpy softmax implementation for ONNX"""
        exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
        return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

# Initialize the universal analyzer for current model
print("🔄 Initializing Universal Model Analyzer...")
analyzer = UniversalModelAnalyzer(MODEL_DIR, 'auto')  # Auto-detect model type

def analyze_model_performance():
    """Comprehensive model performance analysis using universal analyzer"""
    
    print("🔍 Generating comprehensive performance analysis...")
    
    # 1. Basic Performance Metrics
    print(f"\n📊 STEP 1: Basic Performance Analysis ({analyzer.model_type})")
    predictions, probabilities = analyzer.predict_batch(test_texts)
    max_probs = np.max(probabilities, axis=1)
    
    # Misclassification analysis
    misclassified_mask = predictions != true_labels_encoded
    misclassified_indices = np.where(misclassified_mask)[0]
    misclassified_texts = [test_texts[i] for i in misclassified_indices]
    
    accuracy = np.mean(predictions == true_labels_encoded)
    error_rate = 1 - accuracy
    
    print(f"   • Overall Accuracy: {accuracy:.4f}")
    print(f"   • Error Rate: {error_rate:.4f}")
    print(f"   • Total Misclassifications: {len(misclassified_texts)}")
    print(f"   • Average Confidence: {np.mean(max_probs):.4f}")
    print(f"   • Model Type: {analyzer.model_type.upper()}")
    
    return {
        'predictions': predictions,
        'probabilities': probabilities,
        'misclassified_mask': misclassified_mask,
        'misclassified_indices': misclassified_indices,
        'misclassified_texts': misclassified_texts,
        'accuracy': accuracy,
        'error_rate': error_rate,
        'avg_confidence': np.mean(max_probs),
        'model_type': analyzer.model_type
    }

def analyze_class_performance(performance_data):
    """Detailed per-class performance metrics"""
    
    print("\n📈 STEP 2: Per-Class Performance Analysis")
    
    # Get class-wise metrics
    precision, recall, f1, support = precision_recall_fscore_support(
        true_labels_encoded, performance_data['predictions'], average=None
    )
    
    # Confusion matrix
    cm = confusion_matrix(true_labels_encoded, performance_data['predictions'])
    
    class_metrics = {}
    print("   Per-Class Metrics:")
    for i, class_name in enumerate(analyzer.label_encoder.classes_):
        class_metrics[class_name] = {
            'precision': float(precision[i]),
            'recall': float(recall[i]),
            'f1_score': float(f1[i]),
            'support': int(support[i]),
            'errors': int(np.sum(cm[i]) - cm[i][i])  # False negatives: samples of this class predicted as another class
        }
        print(f"     • {class_name}: P={precision[i]:.3f}, R={recall[i]:.3f}, F1={f1[i]:.3f}")
    
    # Identify most problematic classes (for targeted fine-tuning)
    problematic_classes = sorted(class_metrics.items(), key=lambda x: x[1]['f1_score'])[:2]
    print(f"\n   🎯 Most Problematic Classes (lowest F1): {[c[0] for c in problematic_classes]}")
    
    return {
        'class_metrics': class_metrics,
        'confusion_matrix': cm.tolist(),
        'problematic_classes': [c[0] for c in problematic_classes]
    }

def analyze_confidence_distribution(performance_data):
    """Analyze prediction confidence for pruning insights"""
    
    print("\n🎯 STEP 3: Confidence Distribution Analysis")
    
    probabilities = performance_data['probabilities']
    max_probs = np.max(probabilities, axis=1)
    entropy = -np.sum(probabilities * np.log(probabilities + 1e-10), axis=1)
    
    # Confidence thresholds for different scenarios
    confidence_thresholds = [0.5, 0.7, 0.8, 0.9, 0.95]
    confidence_analysis = {}
    
    print("   Confidence Distribution:")
    for threshold in confidence_thresholds:
        high_conf_mask = max_probs >= threshold
        if np.any(high_conf_mask):
            high_conf_accuracy = np.mean(
                performance_data['predictions'][high_conf_mask] == true_labels_encoded[high_conf_mask]
            )
        else:
            # No sample reaches this threshold; accuracy is reported as 0 as a placeholder
            high_conf_accuracy = 0
        coverage = np.mean(high_conf_mask)
        
        confidence_analysis[f"threshold_{threshold}"] = {
            'accuracy': float(high_conf_accuracy),
            'coverage': float(coverage),
            'sample_count': int(np.sum(high_conf_mask))
        }
        print(f"     • ≥{threshold}: {coverage:.1%} samples, {high_conf_accuracy:.3f} accuracy")
    
    # Low confidence samples (candidates for fine-tuning focus)
    low_conf_threshold = 0.6
    low_conf_mask = max_probs < low_conf_threshold
    low_conf_indices = np.where(low_conf_mask)[0]
    
    print(f"\n   📉 Low Confidence Samples (<{low_conf_threshold}): {len(low_conf_indices)} ({len(low_conf_indices)/len(test_texts):.1%})")
    
    return {
        'confidence_analysis': confidence_analysis,
        'entropy_stats': {
            'mean': float(np.mean(entropy)),
            'std': float(np.std(entropy)),
            'high_entropy_samples': int(np.sum(entropy > np.percentile(entropy, 90)))
        },
        'low_confidence_indices': low_conf_indices.tolist(),
        'low_confidence_threshold': low_conf_threshold
    }

def analyze_linguistic_patterns(performance_data):
    """Enhanced keyword analysis for understanding model failures"""
    
    print("\n📝 STEP 4: Linguistic Pattern Analysis")
    
    misclassified_texts = performance_data['misclassified_texts']
    correctly_classified_texts = [test_texts[i] for i in range(len(test_texts)) if not performance_data['misclassified_mask'][i]]
    
    if len(misclassified_texts) == 0:
        print("   ✅ No misclassified texts found!")
        return {
            'problematic_keywords': [],
            'protective_keywords': [],
            'vocabulary_size': 0,
            'tfidf_params': {}
        }
    
    # Enhanced TF-IDF analysis
    vectorizer = TfidfVectorizer(
        max_features=500, 
        stop_words='english', 
        ngram_range=(1, 3),  # Include trigrams for better context
        min_df=2,
        max_df=0.8
    )
    
    # Match group sizes for a fair comparison (takes the first N correct texts, not a random sample)
    sample_size = min(len(misclassified_texts), len(correctly_classified_texts))
    balanced_correct = correctly_classified_texts[:sample_size]
    
    all_texts = misclassified_texts + balanced_correct
    vectorizer.fit(all_texts)
    
    # TF-IDF difference analysis
    misc_tfidf = vectorizer.transform(misclassified_texts).mean(axis=0).A1
    correct_tfidf = vectorizer.transform(balanced_correct).mean(axis=0).A1
    
    feature_names = vectorizer.get_feature_names_out()
    score_diff = misc_tfidf - correct_tfidf
    
    # Top problematic and protective features
    problematic_indices = score_diff.argsort()[-15:][::-1]  # Top 15 problematic
    protective_indices = score_diff.argsort()[:10]  # Top 10 protective
    
    problematic_keywords = [(feature_names[i], float(score_diff[i])) for i in problematic_indices if score_diff[i] > 0.001]
    protective_keywords = [(feature_names[i], float(abs(score_diff[i]))) for i in protective_indices if score_diff[i] < -0.001]
    
    print(f"   🚨 Top Problematic Keywords (higher in errors):")
    for keyword, score in problematic_keywords[:8]:
        print(f"     • '{keyword}': +{score:.4f}")
    
    print(f"\n   ✅ Top Protective Keywords (higher in correct):")
    for keyword, score in protective_keywords[:5]:
        print(f"     • '{keyword}': -{score:.4f}")
    
    return {
        'problematic_keywords': problematic_keywords,
        'protective_keywords': protective_keywords,
        'vocabulary_size': len(feature_names),
        'tfidf_params': {
            'max_features': 500,
            'ngram_range': [1, 3],
            'min_df': 2,
            'max_df': 0.8
        }
    }

def generate_fine_tuning_recommendations(performance_data, class_analysis, confidence_analysis, linguistic_analysis):
    """Generate specific recommendations for fine-tuning and pruning"""
    
    print("\n🎯 STEP 5: Fine-Tuning & Pruning Recommendations")
    
    recommendations = {
        'fine_tuning': {},
        'pruning': {},
        'data_augmentation': {},
        'architecture': {}
    }
    
    # Fine-tuning recommendations
    print("   📚 Fine-Tuning Recommendations:")
    
    # Target problematic classes
    problematic_classes = class_analysis['problematic_classes']
    recommendations['fine_tuning']['target_classes'] = problematic_classes
    print(f"     • Focus on classes: {problematic_classes}")
    
    # Learning rate suggestions based on performance
    if performance_data['accuracy'] > 0.8:
        lr_suggestion = "1e-5 to 5e-5 (conservative fine-tuning)"
    elif performance_data['accuracy'] > 0.7:
        lr_suggestion = "5e-5 to 1e-4 (moderate fine-tuning)"
    else:
        lr_suggestion = "1e-4 to 5e-4 (aggressive fine-tuning)"
    
    recommendations['fine_tuning']['learning_rate'] = lr_suggestion
    recommendations['fine_tuning']['model_type'] = performance_data['model_type']
    print(f"     • Learning rate: {lr_suggestion}")
    print(f"     • Model type: {performance_data['model_type'].upper()}")
    
    # Sample selection for fine-tuning
    low_conf_count = len(confidence_analysis['low_confidence_indices'])
    recommendations['fine_tuning']['focus_samples'] = {
        'low_confidence_count': low_conf_count,
        'misclassified_count': len(performance_data['misclassified_indices']),
        'strategy': 'Hard negative mining + low confidence samples'
    }
    print(f"     • Priority samples: {low_conf_count} low-confidence + {len(performance_data['misclassified_indices'])} errors")
    
    # Pruning recommendations
    print("\n   ✂️ Pruning Recommendations:")
    
    # Confidence-based pruning strategy
    high_conf_90 = confidence_analysis['confidence_analysis']['threshold_0.9']
    if high_conf_90['coverage'] > 0.7 and high_conf_90['accuracy'] > 0.95:
        pruning_strategy = "Aggressive pruning (30-50%) - high confidence retained"
    elif high_conf_90['coverage'] > 0.5:
        pruning_strategy = "Moderate pruning (20-30%) - good confidence distribution"
    else:
        pruning_strategy = "Conservative pruning (10-20%) - low confidence samples"
    
    recommendations['pruning']['strategy'] = pruning_strategy
    recommendations['pruning']['confidence_threshold'] = 0.9
    recommendations['pruning']['expected_coverage'] = high_conf_90['coverage']
    recommendations['pruning']['model_type'] = performance_data['model_type']
    print(f"     • Strategy: {pruning_strategy}")
    print(f"     • Confidence threshold: 0.9 (covers {high_conf_90['coverage']:.1%} with {high_conf_90['accuracy']:.3f} accuracy)")
    
    # Data augmentation recommendations
    print("\n   🔄 Data Augmentation Recommendations:")
    top_problematic = linguistic_analysis['problematic_keywords'][:5]
    recommendations['data_augmentation']['target_keywords'] = [kw[0] for kw in top_problematic]
    recommendations['data_augmentation']['methods'] = [
        'Synonym replacement for problematic terms',
        'Back-translation for class balance',
        'Paraphrasing for robustness'
    ]
    print(f"     • Target problematic keywords: {[kw[0] for kw in top_problematic[:3]]}")
    print(f"     • Focus on: {problematic_classes} classes")
    
    return recommendations

# Execute comprehensive analysis
print("🚀 Starting Universal Fine-Tuning & Pruning Analysis")
print("="*60)

# Start timing for the entire analysis
analysis_start_time = time.time()

# Run analysis pipeline
performance_data = analyze_model_performance()
class_analysis = analyze_class_performance(performance_data)
confidence_analysis = analyze_confidence_distribution(performance_data)
linguistic_analysis = analyze_linguistic_patterns(performance_data)
recommendations = generate_fine_tuning_recommendations(
    performance_data, class_analysis, confidence_analysis, linguistic_analysis
)

def convert_to_json_serializable(obj):
    """Convert numpy types to JSON serializable Python types"""
    if isinstance(obj, np.integer):
        return int(obj)
    elif isinstance(obj, np.floating):
        return float(obj)
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
    elif isinstance(obj, dict):
        return {key: convert_to_json_serializable(value) for key, value in obj.items()}
    elif isinstance(obj, list):
        return [convert_to_json_serializable(item) for item in obj]
    else:
        return obj

# Compile comprehensive results (convert all numpy types to JSON serializable)
comprehensive_results = {
    'metadata': {
        'analysis_date': time.strftime('%Y-%m-%d %H:%M:%S'),
        'dataset_size': len(test_texts),
        'model_path': str(MODEL_DIR),
        'model_name': analyzer.model_name,
        'model_type': performance_data['model_type'],
        'analysis_type': 'universal_fine_tuning_pruning'
    },
    'performance_metrics': {
        'overall_accuracy': float(performance_data['accuracy']),
        'error_rate': float(performance_data['error_rate']),
        'avg_confidence': float(performance_data['avg_confidence']),
        'total_misclassifications': len(performance_data['misclassified_texts'])
    },
    'class_analysis': convert_to_json_serializable(class_analysis),
    'confidence_analysis': convert_to_json_serializable(confidence_analysis),
    'linguistic_analysis': convert_to_json_serializable(linguistic_analysis),
    'recommendations': convert_to_json_serializable(recommendations),
    'sample_indices': {
        'misclassified': [int(x) for x in performance_data['misclassified_indices']],
        'low_confidence': [int(x) for x in confidence_analysis['low_confidence_indices']]
    }
}

# Save results
os.makedirs('analysis_results', exist_ok=True)

# Detailed JSON for programmatic use
with open('analysis_results/comprehensive_analysis.json', 'w') as f:
    json.dump(comprehensive_results, f, indent=2)

# Human-readable summary
summary_text = f"""
UNIVERSAL FINE-TUNING & PRUNING ANALYSIS SUMMARY
===============================================

🎯 PERFORMANCE OVERVIEW ({performance_data['model_type'].upper()}):
• Overall Accuracy: {performance_data['accuracy']:.1%}
• Error Rate: {performance_data['error_rate']:.1%}
• Average Confidence: {performance_data['avg_confidence']:.3f}
• Model: {analyzer.model_name}

📊 KEY METRICS FOR FINE-TUNING:
• Most Problematic Classes: {class_analysis['problematic_classes']}
• Low Confidence Samples: {len(confidence_analysis['low_confidence_indices'])} ({len(confidence_analysis['low_confidence_indices'])/len(test_texts):.1%})
• High-Priority Samples: {len(performance_data['misclassified_indices'])} errors + {len(confidence_analysis['low_confidence_indices'])} low-conf

🔧 RECOMMENDED FINE-TUNING STRATEGY:
• Learning Rate: {recommendations['fine_tuning']['learning_rate']}
• Focus Areas: {', '.join(class_analysis['problematic_classes'])}
• Target Keywords: {', '.join([kw[0] for kw in linguistic_analysis['problematic_keywords'][:5]])}

✂️ RECOMMENDED PRUNING STRATEGY:
• {recommendations['pruning']['strategy']}
• Confidence Threshold: {recommendations['pruning']['confidence_threshold']}
• Expected Coverage: {recommendations['pruning']['expected_coverage']:.1%}

📋 NEXT STEPS:
1. Use misclassified samples for hard negative mining
2. Focus fine-tuning on {', '.join(class_analysis['problematic_classes'])} classes
3. Apply confidence-based pruning with 0.9 threshold
4. Monitor performance on high-entropy samples
"""

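# Explicit UTF-8 so the non-ASCII symbols in the summary are written correctly on every platform
with open('analysis_results/analysis_summary.txt', 'w', encoding='utf-8') as f: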
    f.write(summary_text)

print(f"\n💾 ANALYSIS COMPLETE!")
print(f"📁 Results saved to:")
print(f"   • analysis_results/comprehensive_analysis.json (detailed metrics)")
print(f"   • analysis_results/analysis_summary.txt (human-readable)")
print(f"\n🎯 Ready for Notebook 6: Fine-tuning with targeted improvements!")
print(f"⏱️  Analysis completed in {time.time() - analysis_start_time:.1f} seconds")
print(f"🔧 Model Type Used: {performance_data['model_type'].upper()}")

🚀 Starting Comprehensive Fine-Tuning & Pruning Analysis
🔍 Generating comprehensive performance analysis...

📊 STEP 1: Basic Performance Analysis
   • Overall Accuracy: 0.7913
   • Error Rate: 0.2087
   • Total Misclassifications: 253
   • Average Confidence: 0.7314

📈 STEP 2: Per-Class Performance Analysis
   Per-Class Metrics:
     • negative: P=0.653, R=0.848, F1=0.738
     • neutral: P=0.862, R=0.831, F1=0.846
     • positive: P=0.724, R=0.683, F1=0.703

   🎯 Most Problematic Classes (lowest F1): ['positive', 'negative']

🎯 STEP 3: Confidence Distribution Analysis
   Confidence Distribution:
     • ≥0.5: 95.6% samples, 0.803 accuracy
     • ≥0.7: 64.5% samples, 0.886 accuracy
     • ≥0.8: 39.4% samples, 0.929 accuracy
     • ≥0.9: 0.0% samples, 0.000 accuracy
     • ≥0.95: 0.0% samples, 0.000 accuracy

   📉 Low Confidence Samples (<0.6): 195 (16.1%)

📝 STEP 4: Linguistic Pattern Analysis
   🚨 Top Problematic Keywords (high

## 8. 📋 Summary

### ✅ Completed:
- **Interactive Dashboard**: SHAP, LIME, attention, and GradCAM explanations for any text
- **Mistake Analysis**: Analyze specific model errors
- **Misclassification Patterns**: Key insights for fine-tuning

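### 📊 Key Findings:
- Error rate: 20.9% on the test set (253 of 1,212 samples misclassified)
- Lowest F1 scores: `positive` (0.703) and `negative` (0.738); `neutral` reaches 0.846
- No prediction reaches 0.9 confidence; 39.4% of samples reach 0.8, with 0.929 accuracy on that subset
- Problematic keywords extracted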
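
The headline figures can be cross-checked from the counts printed earlier in the notebook (253 mistakes out of 1212 samples, 195 low-confidence samples). The short calculation below is only a sanity check; the variable names are illustrative.

```python
total_samples = 1212   # test-set size reported when the dashboard initialized
misclassified = 253    # mistakes reported in the same output line
low_confidence = 195   # samples below the 0.6 confidence cut-off

error_rate = misclassified / total_samples
accuracy = 1 - error_rate
low_conf_share = low_confidence / total_samples

print(f"accuracy={accuracy:.4f} error_rate={error_rate:.4f} low_conf={low_conf_share:.1%}")
# prints: accuracy=0.7913 error_rate=0.2087 low_conf=16.1%
```

These match the Step 1 and Step 3 values in the analysis output, so the dashboard's mistake count and the analyzer's metrics agree.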

### 🔜 Next Steps:
Results saved to `analysis_results/` for **Notebook 6: Fine-tuning with Pruning Methods**
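
As a hedged sketch of how a follow-up notebook might pick up these results, the snippet below reads the `sample_indices` section of the saved JSON and merges the misclassified and low-confidence indices into one priority list. The helper name `load_priority_indices` is hypothetical, and the demo writes a small stand-in file with made-up values rather than reading the real output; how Notebook 6 actually consumes the file is not shown here.

```python
import json
import tempfile
from pathlib import Path


def load_priority_indices(json_path):
    """Return sorted test-set indices that are misclassified or low-confidence."""
    with open(json_path) as f:
        results = json.load(f)
    samples = results["sample_indices"]
    return sorted(set(samples["misclassified"]) | set(samples["low_confidence"]))


# Stand-in file with the same 'sample_indices' layout; the values are made up.
with tempfile.TemporaryDirectory() as tmp:
    demo_path = Path(tmp) / "comprehensive_analysis.json"
    demo_path.write_text(json.dumps(
        {"sample_indices": {"misclassified": [3, 7, 11], "low_confidence": [7, 20]}}
    ))
    priority = load_priority_indices(demo_path)

print(priority)
```

Taking the union avoids counting a sample twice when it is both misclassified and low-confidence.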