# Visualization & Results Analysis
This notebook provides comprehensive visualization tools for model performance analysis.

## Visualization Tools Included:
1. **Training History Plots**: Loss and accuracy curves
2. **Confusion Matrix**: Detailed classification results
3. **ROC & PR Curves**: Model discrimination ability
4. **Attention Weights**: What the model focuses on
5. **Feature Importance**: Signal analysis *(listed for completeness; this notebook does not yet define a function for it)*
6. **Prediction Analysis**: Correct vs incorrect predictions
7. **Comparative Analysis**: Multiple model comparison
8. **Statistical Reports**: Comprehensive metrics tables

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
    confusion_matrix, classification_report, 
    roc_curve, auc, precision_recall_curve,
    accuracy_score, precision_score, recall_score, f1_score
)
import tensorflow as tf
from tensorflow import keras

# Set style
sns.set_style('whitegrid')
plt.rcParams['figure.dpi'] = 100
plt.rcParams['savefig.dpi'] = 300

print("‚úÖ Visualization libraries loaded")
%matplotlib inline

## 1. Training History Visualization

In [None]:
def _plot_history_panel(ax, history, metric, label, legend_loc):
    """Draw the training (and, when recorded, validation) curve of one metric."""
    train_values = history[metric]
    val_key = f'val_{metric}'
    # Validation series only exist when the model was fitted with validation
    # data, so treat them as optional instead of raising KeyError.
    val_values = history[val_key] if val_key in history else None

    ax.plot(train_values, 'b-', linewidth=2, label=f'Training {label}')
    if val_values is not None:
        ax.plot(val_values, 'r-', linewidth=2, label=f'Validation {label}')
    ax.set_xlabel('Epoch', fontsize=12, fontweight='bold')
    ax.set_ylabel(label, fontsize=12, fontweight='bold')
    ax.set_title(f'Model {label} Over Epochs', fontsize=14, fontweight='bold')
    ax.legend(loc=legend_loc, fontsize=11)
    ax.grid(True, alpha=0.3)

    # Annotate the last-epoch values; skip series that are missing or empty
    # (e.g. a history with zero completed epochs).
    summary_lines = []
    if len(train_values) > 0:
        summary_lines.append(f'Final Train: {train_values[-1]:.4f}')
    if val_values is not None and len(val_values) > 0:
        summary_lines.append(f'Final Val: {val_values[-1]:.4f}')
    if summary_lines:
        ax.text(0.02, 0.98, '\n'.join(summary_lines),
                transform=ax.transAxes, fontsize=10,
                verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))


def plot_training_history(history, save_path=None):
    """
    Plot training and validation accuracy/loss curves side by side
    
    Parameters:
    -----------
    history : keras History object or dict
        Training history. An object with a ``.history`` attribute is
        unwrapped to that attribute. The mapping must contain 'accuracy'
        and 'loss'; 'val_accuracy' and 'val_loss' are plotted when present.
    save_path : str or None
        Path to save figure (skipped when falsy)
    
    Raises:
    -------
    KeyError
        If 'accuracy' or 'loss' is missing from the history.
    """
    if hasattr(history, 'history'):
        history = history.history
    
    fig, axes = plt.subplots(1, 2, figsize=(16, 5))
    
    # Left panel: accuracy, right panel: loss.
    _plot_history_panel(axes[0], history, 'accuracy', 'Accuracy', 'lower right')
    _plot_history_panel(axes[1], history, 'loss', 'Loss', 'upper right')
    
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, bbox_inches='tight')
        print(f"‚úÖ Training history saved to {save_path}")
    
    plt.show()

# Cell-execution marker printed once plot_training_history has been defined.
print("‚úÖ Training history plotting function defined")

## 2. Enhanced Confusion Matrix

In [None]:
def plot_confusion_matrix(y_true, y_pred, class_names=['Healthy', 'Unhealthy'], 
                         normalize=False, save_path=None):
    """
    Plot a confusion-matrix heatmap, optionally row-normalised
    
    Parameters:
    -----------
    y_true : array
        True labels
    y_pred : array
        Predicted labels
    class_names : list
        Names of classes, used as tick labels on both axes
    normalize : bool
        If True, each row is divided by its total (fraction of each true
        class) and the raw count is printed underneath in parentheses
    save_path : str or None
        Path to save figure (skipped when falsy)
    
    Notes:
    ------
    For a 2x2 matrix a side box with sensitivity, specificity, precision
    and NPV is added; the matrix is unpacked as (tn, fp, fn, tp), i.e. the
    second class is treated as the positive one.
    """
    cm = confusion_matrix(y_true, y_pred)
    
    if normalize:
        # A class that occurs only in y_pred yields an all-zero row; leave
        # that row at 0 instead of dividing by zero (NaN cells + warning).
        row_totals = cm.sum(axis=1, keepdims=True)
        cm_display = np.divide(
            cm, row_totals,
            out=np.zeros(cm.shape, dtype=float),
            where=row_totals != 0,
        )
        fmt = '.2%'
    else:
        cm_display = cm
        fmt = 'd'
    
    fig, ax = plt.subplots(figsize=(10, 8))
    
    # Create heatmap
    sns.heatmap(cm_display, annot=True, fmt=fmt, cmap='Blues', 
                xticklabels=class_names, yticklabels=class_names,
                cbar_kws={'label': 'Count' if not normalize else 'Percentage'},
                linewidths=2, linecolor='white', ax=ax)
    
    # When showing percentages, add the raw count slightly below each
    # cell's centre (cell (i, j) is centred at x=j+0.5, y=i+0.5).
    if normalize:
        for i in range(cm.shape[0]):
            for j in range(cm.shape[1]):
                ax.text(j + 0.5, i + 0.7, f'({cm[i, j]})',
                       ha='center', va='center', fontsize=10, color='gray')
    
    ax.set_ylabel('True Label', fontsize=13, fontweight='bold')
    ax.set_xlabel('Predicted Label', fontsize=13, fontweight='bold')
    ax.set_title('Confusion Matrix', fontsize=15, fontweight='bold', pad=20)
    
    # Binary case only: derive the usual rates, guarding every denominator.
    if cm.shape == (2, 2):
        tn, fp, fn, tp = cm.ravel()
        sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        npv = tn / (tn + fn) if (tn + fn) > 0 else 0
        
        metrics_text = f'Sensitivity: {sensitivity:.3f}\nSpecificity: {specificity:.3f}\n'
        metrics_text += f'Precision: {precision:.3f}\nNPV: {npv:.3f}'
        
        # Placed to the right of the axes (x > 1 in axes coordinates).
        ax.text(1.15, 0.5, metrics_text, transform=ax.transAxes,
               fontsize=11, verticalalignment='center',
               bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.7))
    
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, bbox_inches='tight')
        print(f"‚úÖ Confusion matrix saved to {save_path}")
    
    plt.show()

# Cell-execution marker printed once plot_confusion_matrix has been defined.
print("‚úÖ Confusion matrix plotting function defined")

## 3. ROC & Precision-Recall Curves

In [None]:
def plot_roc_and_pr_curves(y_true, y_pred_proba, save_path=None):
    """
    Draw the ROC curve (left) and the Precision-Recall curve (right)
    
    Parameters:
    -----------
    y_true : array
        True labels
    y_pred_proba : array
        Predicted probabilities (scores passed to roc_curve and
        precision_recall_curve)
    save_path : str or None
        Path to save figure (skipped when falsy)
    
    Returns:
    --------
    roc_auc, pr_auc : tuple
        Areas under the ROC and PR curves, both computed with ``auc``
    """
    fig, (roc_ax, pr_ax) = plt.subplots(1, 2, figsize=(16, 6))
    bold = {'fontweight': 'bold'}
    
    # ---- ROC panel ----
    false_pos_rate, true_pos_rate, _ = roc_curve(y_true, y_pred_proba)
    roc_auc = auc(false_pos_rate, true_pos_rate)
    
    roc_ax.plot(false_pos_rate, true_pos_rate, color='darkorange', lw=3,
                label=f'ROC curve (AUC = {roc_auc:.3f})')
    # Diagonal reference line of a random classifier.
    roc_ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--',
                label='Random Classifier')
    roc_ax.fill_between(false_pos_rate, true_pos_rate, 0, alpha=0.2, color='orange')
    roc_ax.set_xlim([0.0, 1.0])
    roc_ax.set_ylim([0.0, 1.05])
    roc_ax.set_xlabel('False Positive Rate', fontsize=12, **bold)
    roc_ax.set_ylabel('True Positive Rate', fontsize=12, **bold)
    roc_ax.set_title('Receiver Operating Characteristic (ROC) Curve', fontsize=13, **bold)
    roc_ax.legend(loc='lower right', fontsize=11)
    roc_ax.grid(True, alpha=0.3)
    
    # ---- Precision-Recall panel ----
    precision_vals, recall_vals, _ = precision_recall_curve(y_true, y_pred_proba)
    pr_auc = auc(recall_vals, precision_vals)
    
    pr_ax.plot(recall_vals, precision_vals, color='green', lw=3,
               label=f'PR curve (AUC = {pr_auc:.3f})')
    pr_ax.fill_between(recall_vals, precision_vals, 0, alpha=0.2, color='green')
    pr_ax.set_xlim([0.0, 1.0])
    pr_ax.set_ylim([0.0, 1.05])
    pr_ax.set_xlabel('Recall', fontsize=12, **bold)
    pr_ax.set_ylabel('Precision', fontsize=12, **bold)
    pr_ax.set_title('Precision-Recall Curve', fontsize=13, **bold)
    pr_ax.legend(loc='lower left', fontsize=11)
    pr_ax.grid(True, alpha=0.3)
    
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, bbox_inches='tight')
        print(f"‚úÖ ROC and PR curves saved to {save_path}")
    
    plt.show()
    
    return roc_auc, pr_auc

# Cell-execution marker printed once plot_roc_and_pr_curves has been defined.
print("‚úÖ ROC and PR curve plotting function defined")

## 4. Attention Weights Visualization

In [None]:
def visualize_attention_weights(model, X_sample, sample_idx=0, save_path=None):
    """
    Plot one input signal above the attention weights produced for it
    
    The first layer whose name contains "attention" (case-insensitive) is
    used. Its output is unpacked as a pair and the second element is
    treated as the attention weights. Any exception raised along the way
    is caught and reported with a printed message (best-effort helper).
    
    Parameters:
    -----------
    model : keras.Model
        Model with attention layer
    X_sample : array
        Input samples
    sample_idx : int
        Index of sample to visualize
    save_path : str or None
        Path to save figure (skipped when falsy)
    """
    try:
        # Locate the attention layer by name; None when there is no match.
        attn_layer = next(
            (lyr for lyr in model.layers if 'attention' in lyr.name.lower()),
            None,
        )
        if attn_layer is None:
            print("‚ö†Ô∏è No attention layer found in model")
            return
        
        # Sub-model that stops at the attention layer so its output can be read.
        probe = keras.Model(inputs=model.input, outputs=attn_layer.output)
        
        # Slice (rather than index) to keep the leading batch dimension.
        single = X_sample[sample_idx:sample_idx+1]
        layer_outputs = probe.predict(single, verbose=0)
        _, weights = layer_outputs
        weights = weights[0].flatten()
        
        fig, (signal_ax, weight_ax) = plt.subplots(2, 1, figsize=(15, 8))
        
        # Top panel: the raw input, flattened to one trace.
        signal_ax.plot(single[0].flatten(), 'b-', linewidth=1.5, alpha=0.7)
        signal_ax.set_title('Original Signal', fontsize=13, fontweight='bold')
        signal_ax.set_ylabel('Amplitude', fontsize=11)
        signal_ax.grid(True, alpha=0.3)
        
        # Bottom panel: attention weights as a filled line.
        positions = range(len(weights))
        weight_ax.plot(weights, 'r-', linewidth=1.5)
        weight_ax.fill_between(positions, weights, alpha=0.3, color='red')
        weight_ax.set_title('Attention Weights (What the model focuses on)',
                            fontsize=13, fontweight='bold')
        weight_ax.set_xlabel('Time Steps', fontsize=11)
        weight_ax.set_ylabel('Attention Weight', fontsize=11)
        weight_ax.grid(True, alpha=0.3)
        
        plt.tight_layout()
        
        if save_path:
            plt.savefig(save_path, bbox_inches='tight')
            print(f"‚úÖ Attention visualization saved to {save_path}")
        
        plt.show()
        
    except Exception as e:
        print(f"‚ö†Ô∏è Could not visualize attention: {e}")

# Cell-execution marker printed once visualize_attention_weights has been defined.
print("‚úÖ Attention visualization function defined")

## 5. Prediction Analysis

In [None]:
def plot_prediction_analysis(X_test, y_test, y_pred, y_pred_proba, 
                            class_names=['Healthy', 'Unhealthy'],
                            num_samples=6, save_path=None):
    """
    Plot randomly chosen correctly and incorrectly classified samples
    
    Up to ``num_samples // 2`` misclassified samples are drawn; the
    remaining panels are filled with correctly classified ones (and with
    further misclassified ones if too few correct samples exist). Only as
    many panels as there are selected samples are created, so the figure
    never contains blank axes.
    
    Parameters:
    -----------
    X_test : array
        Test samples; each sample is flattened before plotting
    y_test : array
        True labels, used as integer indices into class_names
    y_pred : array
        Predicted labels, used as integer indices into class_names
    y_pred_proba : array
        Prediction probabilities, one value per sample. The value is shown
        unchanged as "Confidence".
        NOTE(review): if this is P(second class), the figure shown for a
        first-class prediction is not that prediction's confidence - confirm.
    class_names : list
        Class names
    num_samples : int
        Maximum number of samples to show
    save_path : str or None
        Path to save figure (skipped when falsy)
    
    Raises:
    -------
    ValueError
        If y_pred_proba does not hold exactly one value per test label.
    """
    # Flatten so (n, 1) column vectors and plain lists compare element-wise
    # instead of broadcasting to (n, n) or collapsing to a single bool.
    y_test = np.asarray(y_test).ravel()
    y_pred = np.asarray(y_pred).ravel()
    y_pred_proba = np.asarray(y_pred_proba).ravel()
    if y_pred_proba.size != y_test.size:
        raise ValueError(
            "y_pred_proba must contain exactly one probability per sample"
        )
    
    # Find correct and incorrect predictions
    correct_idx = np.where(y_test == y_pred)[0]
    incorrect_idx = np.where(y_test != y_pred)[0]
    
    # Split the panel budget: half for errors, the rest for correct samples,
    # then hand any unused slots back to the errors.
    budget = max(int(num_samples), 0)
    n_incorrect = min(budget // 2, len(incorrect_idx))
    n_correct = min(budget - n_incorrect, len(correct_idx))
    n_incorrect = min(budget - n_correct, len(incorrect_idx))
    
    if n_correct + n_incorrect == 0:
        # plt.subplots cannot build a zero-row grid; nothing to draw anyway.
        print("No predictions available to plot")
        return
    
    selected_correct = (np.random.choice(correct_idx, n_correct, replace=False)
                        if n_correct else [])
    selected_incorrect = (np.random.choice(incorrect_idx, n_incorrect, replace=False)
                          if n_incorrect else [])
    
    # Correct samples first, then incorrect ones.
    panels = [(idx, True) for idx in selected_correct]
    panels += [(idx, False) for idx in selected_incorrect]
    
    # squeeze=False always returns a 2-D array, so one panel needs no
    # special-casing.
    fig, axes = plt.subplots(len(panels), 1, figsize=(14, 3 * len(panels)),
                             squeeze=False)
    axes = axes.ravel()
    
    for ax, (idx, is_correct) in zip(axes, panels):
        true_name = class_names[int(y_test[idx])]
        pred_name = class_names[int(y_pred[idx])]
        if is_correct:
            line_style, colour, verdict = 'g-', 'green', '‚úÖ CORRECT'
        else:
            line_style, colour, verdict = 'r-', 'red', '‚ùå INCORRECT'
        
        ax.plot(np.asarray(X_test[idx]).ravel(), line_style, linewidth=1, alpha=0.8)
        ax.set_title(f'{verdict}: True={true_name}, '
                    f'Pred={pred_name} (Confidence: {y_pred_proba[idx]:.3f})',
                    fontsize=11, fontweight='bold', color=colour)
        ax.set_ylabel('Amplitude')
        ax.grid(True, alpha=0.3)
    
    # Only the bottom panel carries the shared x-axis label.
    axes[-1].set_xlabel('Time Steps', fontsize=11)
    
    plt.suptitle('Prediction Analysis: Correct vs Incorrect Classifications', 
                fontsize=14, fontweight='bold', y=1.00)
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, bbox_inches='tight')
        print(f"‚úÖ Prediction analysis saved to {save_path}")
    
    plt.show()

# Cell-execution marker printed once plot_prediction_analysis has been defined.
print("‚úÖ Prediction analysis function defined")

## 6. Comprehensive Metrics Report

In [None]:
def generate_metrics_report(y_true, y_pred, y_pred_proba, class_names=['Healthy', 'Unhealthy']):
    """
    Compute, print and return a set of classification metrics
    
    Parameters:
    -----------
    y_true : array
        True labels
    y_pred : array
        Predicted labels
    y_pred_proba : array
        Prediction probabilities (scores used for ROC-AUC)
    class_names : list
        Class names, used in the classification report and as the
        row/column labels of the printed confusion matrix
    
    Returns:
    --------
    metrics_dict : dict
        Accuracy, precision, recall, specificity, F1, ROC-AUC, NPV, FPR
        and FNR. Specificity, NPV, FPR and FNR are 0 unless the confusion
        matrix is 2x2. ROC-AUC is 0.0 when it cannot be computed (a
        message is printed in that case).
    """
    # roc_auc_score is not part of the notebook's top-level sklearn import.
    # Previously the resulting NameError was swallowed by a bare except and
    # ROC-AUC was always reported as 0.0.
    from sklearn.metrics import roc_auc_score
    
    # Accept plain lists as well as arrays (needed for .astype below).
    y_true = np.asarray(y_true)
    
    # Calculate confusion matrix
    cm = confusion_matrix(y_true, y_pred)
    
    # Calculate basic metrics
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, zero_division=0)
    recall = recall_score(y_true, y_pred, zero_division=0)
    f1 = f1_score(y_true, y_pred, zero_division=0)
    
    # ROC-AUC: keep the 0.0 fallback, but only for the error sklearn raises
    # on unusable input (e.g. a single class in y_true), and say so.
    try:
        roc_auc = roc_auc_score(y_true, y_pred_proba)
    except ValueError as exc:
        print(f"ROC-AUC could not be computed ({exc}); reporting 0.0")
        roc_auc = 0.0
    
    # Rates that are only defined for the binary (2x2) case; every
    # denominator is guarded against zero.
    if cm.shape == (2, 2):
        tn, fp, fn, tp = cm.ravel()
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
        npv = tn / (tn + fn) if (tn + fn) > 0 else 0
        fpr = fp / (fp + tn) if (fp + tn) > 0 else 0
        fnr = fn / (fn + tp) if (fn + tp) > 0 else 0
    else:
        specificity = npv = fpr = fnr = 0
    
    metrics_dict = {
        'Accuracy': accuracy,
        'Precision (PPV)': precision,
        'Recall (Sensitivity/TPR)': recall,
        'Specificity (TNR)': specificity,
        'F1-Score': f1,
        'ROC-AUC': roc_auc,
        'NPV': npv,
        'FPR': fpr,
        'FNR': fnr
    }
    
    # Print report
    print("\n" + "="*70)
    print("üìä COMPREHENSIVE METRICS REPORT")
    print("="*70)
    print(f"\nTotal Samples: {len(y_true)}")
    # ravel() so column-vector labels are accepted by bincount.
    print(f"Class Distribution: {np.bincount(y_true.astype(int).ravel())}")
    print("\n" + "-"*70)
    print("PERFORMANCE METRICS:")
    print("-"*70)
    for metric, value in metrics_dict.items():
        print(f"  {metric:30s}: {value:.4f} ({value*100:.2f}%)")
    print("="*70)
    
    # Classification report
    print("\n" + "="*70)
    print("CLASSIFICATION REPORT:")
    print("="*70)
    print(classification_report(y_true, y_pred, target_names=class_names, digits=4))
    
    # Confusion matrix
    print("\n" + "="*70)
    print("CONFUSION MATRIX:")
    print("="*70)
    print(pd.DataFrame(cm, index=class_names, columns=class_names))
    print("="*70 + "\n")
    
    return metrics_dict

# Cell-execution marker printed once generate_metrics_report has been defined.
print("‚úÖ Metrics report function defined")

## 7. Model Comparison Table

In [None]:
def create_comparison_table(results_dict, save_path=None, lower_is_better=('FPR', 'FNR')):
    """
    Create comparison table for multiple models
    
    Parameters:
    -----------
    results_dict : dict
        Dictionary with model names as keys and metrics dicts as values
    save_path : str or None
        Path to save table as CSV (skipped when falsy)
    lower_is_better : iterable of str
        Metric names for which the smallest value wins (error rates).
        Every other metric is ranked by its largest value.
    
    Returns:
    --------
    df : DataFrame
        Comparison table: one row per model, one column per metric,
        rounded to 4 decimals
    """
    df = pd.DataFrame(results_dict).T
    df = df.round(4)
    
    # Print the full table
    print("\n" + "="*100)
    print("üìä MODEL COMPARISON TABLE")
    print("="*100)
    print(df.to_string())
    print("="*100 + "\n")
    
    # Highlight best values
    print("\nüèÜ BEST PERFORMING MODEL PER METRIC:")
    print("-"*100)
    minimise = set(lower_is_better)
    for col in df.columns:
        # Rank only real numbers: non-numeric or missing entries would make
        # idxmax/idxmin fail or return NaN.
        scores = pd.to_numeric(df[col], errors='coerce').dropna()
        if scores.empty:
            print(f"  {str(col):30s}: {'n/a':20s}")
            continue
        # Error rates such as FPR/FNR are best when smallest.
        best_model = scores.idxmin() if col in minimise else scores.idxmax()
        best_value = scores.loc[best_model]
        # str() keeps the ':30s' / ':20s' specs valid for non-string
        # metric or model names (e.g. integer keys).
        print(f"  {str(col):30s}: {str(best_model):20s} ({best_value:.4f})")
    print("="*100 + "\n")
    
    if save_path:
        df.to_csv(save_path)
        print(f"‚úÖ Comparison table saved to {save_path}")
    
    return df

# Cell-execution marker printed once create_comparison_table has been defined.
print("‚úÖ Comparison table function defined")

## 8. Complete Visualization Pipeline

In [None]:
def complete_visualization_pipeline(model, history, X_test, y_test, 
                                   class_names=['Healthy', 'Unhealthy'],
                                   save_dir='./results'):
    """
    Run every visualization step in sequence and save the figures
    
    Predictions are obtained with ``model.predict`` (flattened) and
    thresholded at 0.5 to obtain class labels; the plotting and reporting
    helpers of this notebook are then called one after another.
    
    Parameters:
    -----------
    model : keras.Model
        Trained model
    history : keras History or dict
        Training history
    X_test, y_test : arrays
        Test data
    class_names : list
        Class names
    save_dir : str
        Directory to save all plots (created if missing)
    
    Returns:
    --------
    metrics : dict
        The dictionary returned by generate_metrics_report
    """
    import os
    os.makedirs(save_dir, exist_ok=True)
    
    rule = "=" * 70
    
    def _target(filename):
        # Output path of one figure inside save_dir.
        return f"{save_dir}/{filename}"
    
    print("\n" + rule)
    print("üé® Running Complete Visualization Pipeline")
    print(rule + "\n")
    
    # Scores straight from the model, then hard labels at the 0.5 threshold.
    print("üìä Generating predictions...")
    raw_scores = model.predict(X_test, verbose=0)
    y_pred_proba = raw_scores.flatten()
    y_pred = (y_pred_proba > 0.5).astype(int)
    
    # Step 1: training curves
    print("üìà Plotting training history...")
    plot_training_history(history, save_path=_target("training_history.png"))
    
    # Step 2: row-normalised confusion matrix
    print("üìä Creating confusion matrix...")
    plot_confusion_matrix(
        y_test, y_pred, class_names,
        normalize=True,
        save_path=_target("confusion_matrix.png"),
    )
    
    # Step 3: ROC and PR curves
    print("üìà Plotting ROC and PR curves...")
    plot_roc_and_pr_curves(y_test, y_pred_proba,
                           save_path=_target("roc_pr_curves.png"))
    
    # Step 4: example correct / incorrect predictions
    print("üîç Analyzing predictions...")
    plot_prediction_analysis(
        X_test, y_test, y_pred, y_pred_proba, class_names,
        num_samples=6,
        save_path=_target("prediction_analysis.png"),
    )
    
    # Step 5: printed metrics report (its dict is the pipeline's result)
    print("üìä Generating metrics report...")
    metrics = generate_metrics_report(y_test, y_pred, y_pred_proba, class_names)
    
    # Step 6: attention weights for the first test sample; the helper
    # reports on its own when the model has no attention layer.
    print("üëÅÔ∏è Attempting attention visualization...")
    visualize_attention_weights(
        model, X_test,
        sample_idx=0,
        save_path=_target("attention_weights.png"),
    )
    
    print("\n" + rule)
    print(f"‚úÖ All visualizations saved to {save_dir}/")
    print(rule + "\n")
    
    return metrics

# Cell-execution marker printed once complete_visualization_pipeline has been defined.
print("‚úÖ Complete visualization pipeline defined")

## Example Usage

In [None]:
# Example usage (uncomment to use):
# 
# # After training your model
# # model = ...
# # history = ...
# # X_test, y_test = ...
# 
# # Run complete visualization
# metrics = complete_visualization_pipeline(
#     model=model,
#     history=history,
#     X_test=X_test,
#     y_test=y_test,
#     class_names=['Healthy', 'Unhealthy'],
#     save_dir='./my_results'
# )

# Closing banner: list every helper this notebook defines, with its signature.
_banner = "=" * 70
_available_functions = (
    "  - plot_training_history(history, save_path)",
    "  - plot_confusion_matrix(y_true, y_pred, class_names, normalize, save_path)",
    "  - plot_roc_and_pr_curves(y_true, y_pred_proba, save_path)",
    "  - visualize_attention_weights(model, X_sample, sample_idx, save_path)",
    "  - plot_prediction_analysis(X_test, y_test, y_pred, y_pred_proba, ...)",
    "  - generate_metrics_report(y_true, y_pred, y_pred_proba, class_names)",
    "  - create_comparison_table(results_dict, save_path)",
    "  - complete_visualization_pipeline(model, history, X_test, y_test, ...)",
)

print("\n" + _banner)
print("‚úÖ Visualization utilities loaded successfully!")
print(_banner)
print("\nAvailable functions:")
for _entry in _available_functions:
    print(_entry)
print("\n" + _banner)