In [33]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.metrics import (accuracy_score, f1_score, confusion_matrix, 
                           classification_report, roc_auc_score, r2_score,
                           mean_squared_error, mean_absolute_error)
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from xgboost import XGBClassifier
from lightgbm import LGBMClassifier
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.impute import SimpleImputer
from imblearn.over_sampling import SMOTE
from imblearn.pipeline import Pipeline as ImbPipeline
import matplotlib as mpl
import os

# ==============================================
# SECTION 1: CONFIGURATION AND SETUP (UPDATED)
# ==============================================

# Set global configurations
RANDOM_SEED = 42
TEST_SIZE = 0.2
N_JOBS = -1  # Use all cores
N_SIMULATIONS = 25  # Number of Monte Carlo simulations
CV_FOLDS = 10  # 10-fold cross validation

# Set matplotlib configurations for high-quality plots with Times New Roman
plt.style.use('seaborn-v0_8')  # Updated to use correct seaborn style name
mpl.rcParams['font.family'] = 'Times New Roman'
mpl.rcParams['font.size'] = 12
mpl.rcParams['axes.titlesize'] = 14
mpl.rcParams['axes.labelsize'] = 12
mpl.rcParams['xtick.labelsize'] = 10
mpl.rcParams['ytick.labelsize'] = 10
mpl.rcParams['figure.dpi'] = 300
mpl.rcParams['savefig.dpi'] = 300
mpl.rcParams['figure.figsize'] = (8, 6)
mpl.rcParams['axes.grid'] = True  # Add grid by default
mpl.rcParams['grid.alpha'] = 0.3  # Make grid lines semi-transparent

# Create output directories if they don't exist
os.makedirs('output/plots', exist_ok=True)
os.makedirs('output/tables', exist_ok=True)

np.random.seed(RANDOM_SEED)

# ==============================================
# SECTION 2: DATA LOADING AND PREPROCESSING
# ==============================================

def load_data():
    """Load and preprocess the ECG data."""
    df = pd.read_csv('merged_ecg_data_cleaned.csv')
    
    # Ensure balanced classes
    df = df.groupby('wct_label_encoded' if 'wct_label_encoded' in df.columns else 'wct_label').head(50000)
    
    features = ['bandwidth', 'filtering', 'rr_interval', 'p_onset', 'p_end', 
               'qrs_onset', 'qrs_end', 't_end', 'p_axis', 'qrs_axis', 
               't_axis', 'qrs_duration']
    target = 'wct_label_encoded' if 'wct_label_encoded' in df.columns else 'wct_label'
    
    # Data cleaning
    for col in features:
        if col in df.columns:
            df[col] = pd.to_numeric(df[col], errors='coerce')
            df[col] = df[col].fillna(df[col].median())
    
    df = df.dropna(subset=[target])
    
    if not np.issubdtype(df[target].dtype, np.number):
        le = LabelEncoder()
        df[target] = le.fit_transform(df[target])
    
    X = df[features].values
    y = df[target].values
    
    print(f"Data shape: {X.shape}, Class distribution: {np.bincount(y)}")
    return X, y, features, df[target].name

# Load data
X, y, features, target_name = load_data()

# ==============================================
# SECTION 3: MODEL DEFINITION AND HYPERPARAMETERS
# ==============================================

# Define models with their hyperparameters including ranges and optimal values
models = {
    "XGBoost": {
        'model': XGBClassifier,
        'params': {
            'n_estimators': 200,
            'max_depth': 5,
            'learning_rate': 0.1,
            'subsample': 0.8,
            'colsample_bytree': 0.8,
            'random_state': RANDOM_SEED,
            'n_jobs': N_JOBS,
            'tree_method': 'hist'
        },
        'param_ranges': {
            'n_estimators': {'range': [100, 500], 'optimal': 200},
            'max_depth': {'range': [3, 10], 'optimal': 5},
            'learning_rate': {'range': [0.01, 0.2], 'optimal': 0.1},
            'subsample': {'range': [0.6, 1.0], 'optimal': 0.8},
            'colsample_bytree': {'range': [0.6, 1.0], 'optimal': 0.8}
        }
    },
    "LightGBM": {
        'model': LGBMClassifier,
        'params': {
            'n_estimators': 200,
            'max_depth': 5,
            'learning_rate': 0.1,
            'subsample': 0.8,
            'colsample_bytree': 0.8,
            'random_state': RANDOM_SEED,
            'n_jobs': N_JOBS
        },
        'param_ranges': {
            'n_estimators': {'range': [100, 500], 'optimal': 200},
            'max_depth': {'range': [3, 10], 'optimal': 5},
            'learning_rate': {'range': [0.01, 0.2], 'optimal': 0.1},
            'subsample': {'range': [0.6, 1.0], 'optimal': 0.8},
            'colsample_bytree': {'range': [0.6, 1.0], 'optimal': 0.8}
        }
    },
    "RandomForest": {
        'model': RandomForestClassifier,
        'params': {
            'n_estimators': 200,
            'max_depth': 10,
            'min_samples_split': 5,
            'min_samples_leaf': 2,
            'random_state': RANDOM_SEED,
            'n_jobs': N_JOBS,
            'class_weight': 'balanced'
        },
        'param_ranges': {
            'n_estimators': {'range': [100, 500], 'optimal': 200},
            'max_depth': {'range': [5, 20], 'optimal': 10},
            'min_samples_split': {'range': [2, 10], 'optimal': 5},
            'min_samples_leaf': {'range': [1, 5], 'optimal': 2}
        }
    },
    "GradientBoosting": {
        'model': GradientBoostingClassifier,
        'params': {
            'n_estimators': 200,
            'max_depth': 5,
            'learning_rate': 0.1,
            'subsample': 0.8,
            'random_state': RANDOM_SEED
        },
        'param_ranges': {
            'n_estimators': {'range': [100, 500], 'optimal': 200},
            'max_depth': {'range': [3, 10], 'optimal': 5},
            'learning_rate': {'range': [0.01, 0.2], 'optimal': 0.1},
            'subsample': {'range': [0.6, 1.0], 'optimal': 0.8}
        }
    }
}

# Create detailed hyperparameter table
hyperparam_data = []
for name, details in models.items():
    for param, value in details['params'].items():
        if param in details['param_ranges']:
            hyperparam_data.append({
                'Model': name,
                'Parameter': param,
                'Optimal Value': value,
                'Range Min': details['param_ranges'][param]['range'][0],
                'Range Max': details['param_ranges'][param]['range'][1],
                'Value Used': value
            })
        else:  # For parameters without specified ranges (like random_state)
            hyperparam_data.append({
                'Model': name,
                'Parameter': param,
                'Optimal Value': value,
                'Range Min': 'N/A',
                'Range Max': 'N/A',
                'Value Used': value
            })

hyperparam_table = pd.DataFrame(hyperparam_data)

# Reorder columns for better readability
hyperparam_table = hyperparam_table[['Model', 'Parameter', 'Range Min', 'Range Max', 'Optimal Value', 'Value Used']]

# Save hyperparameter table
hyperparam_table.to_csv('output/tables/model_hyperparameters_detailed.csv', index=False)
print("\nSaved detailed hyperparameter table to output/tables/model_hyperparameters_detailed.csv")

# Create a summary table of just the optimal parameters
optimal_params_table = pd.DataFrame([
    {
        'Model': name,
        'Optimal Parameters': ', '.join([f"{k}={v}" for k, v in details['params'].items()])
    }
    for name, details in models.items()
])

optimal_params_table.to_csv('output/tables/optimal_parameters_summary.csv', index=False)
print("Saved optimal parameters summary to output/tables/optimal_parameters_summary.csv")


# ==============================================
# SECTION 4: MODEL TRAINING AND EVALUATION
# ==============================================

def evaluate_model(model, X_test, y_test):
    """Evaluate model performance and return metrics."""
    y_pred = model.predict(X_test)
    y_proba = model.predict_proba(X_test)[:, 1] if hasattr(model, 'predict_proba') else None
    
    metrics = {
        'Accuracy': accuracy_score(y_test, y_pred),
        'F1_Score': f1_score(y_test, y_pred, average='weighted'),
        'R2_Score': r2_score(y_test, y_pred),
        'RMSE': np.sqrt(mean_squared_error(y_test, y_pred)),
        'MAE': mean_absolute_error(y_test, y_pred),
        'ROC_AUC': roc_auc_score(y_test, y_proba) if y_proba is not None else np.nan
    }
    
    return metrics

# Initialize results storage
results = []
mc_results = {name: {'train_r2': [], 'test_r2': [], 'rmse': []} for name in models.keys()}

# Monte Carlo simulations
for sim in range(N_SIMULATIONS):
    print(f"\n🚀 Simulation {sim+1}/{N_SIMULATIONS}")
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=TEST_SIZE, stratify=y, random_state=RANDOM_SEED+sim)
    
    for name, details in models.items():
        print(f"  - Training {name}...")
        
        # Create pipeline
        pipeline = ImbPipeline([
            ('imputer', SimpleImputer(strategy='median')),
            ('scaler', StandardScaler()),
            ('smote', SMOTE(random_state=RANDOM_SEED+sim)),
            ('model', details['model'](**details['params']))
        ])
        
        try:
            # Fit model
            pipeline.fit(X_train, y_train)
            
            # Evaluate
            metrics = evaluate_model(pipeline, X_test, y_test)
            
            # Store results
            results.append({
                'Simulation': sim+1,
                'Model': name,
                **metrics
            })
            
            # Cross validation for R2 score
            cv_scores = cross_val_score(pipeline, X_train, y_train, 
                                     cv=CV_FOLDS, scoring='r2', n_jobs=N_JOBS)
            
            # Store MC results
            mc_results[name]['train_r2'].append(np.mean(cv_scores))
            mc_results[name]['test_r2'].append(metrics['R2_Score'])
            mc_results[name]['rmse'].append(metrics['RMSE'])
            
        except Exception as e:
            print(f"Error with {name} in simulation {sim+1}: {str(e)}")
            continue

# Save performance metrics to CSV
results_df = pd.DataFrame(results)
results_df.to_csv('output/tables/performance_metrics.csv', index=False)
print("\nSaved performance metrics to output/tables/performance_metrics.csv")


# ==============================================
# SECTION 5: VISUALIZATIONS
# ==============================================

# ==============================================
# SECTION 5: VISUALIZATIONS (UPDATED)
# ==============================================

# Update the MC results storage to track all metrics
mc_results = {name: {
    'train_r2': [],
    'test_r2': [],
    'accuracy': [],
    'f1_score': [],
    'roc_auc': [],
    'rmse': [],
    'mae': []
} for name in models.keys()}

# Modify the simulation loop to store all metrics
for sim in range(N_SIMULATIONS):
    # ... [existing simulation code] ...
    
    for name, details in models.items():
        # ... [existing training code] ...
        
        try:
            # ... [existing fitting code] ...
            
            # Store all metrics
            mc_results[name]['train_r2'].append(np.mean(cv_scores))
            mc_results[name]['test_r2'].append(metrics['R2_Score'])
            mc_results[name]['accuracy'].append(metrics['Accuracy'])
            mc_results[name]['f1_score'].append(metrics['F1_Score'])
            mc_results[name]['roc_auc'].append(metrics['ROC_AUC'])
            mc_results[name]['rmse'].append(metrics['RMSE'])
            mc_results[name]['mae'].append(metrics['MAE'])
            
        except Exception as e:
            print(f"Error with {name} in simulation {sim+1}: {str(e)}")
            continue

# 1. Updated Training vs Testing Performance Across Simulations (All metrics)
metrics_to_plot = ['accuracy', 'f1_score', 'roc_auc', 'rmse', 'mae']
metric_names = ['Accuracy', 'F1 Score', 'ROC AUC', 'RMSE', 'MAE']

plt.figure(figsize=(15, 10))
for i, metric in enumerate(metrics_to_plot):
    plt.subplot(2, 3, i+1)
    for name in models.keys():
        plt.plot(range(1, N_SIMULATIONS+1), mc_results[name][metric], 
                 label=name, alpha=0.8)
    plt.xlabel('Simulation Number')
    plt.ylabel(metric_names[i])
    plt.title(f'{metric_names[i]} Across Simulations')
    plt.grid(True, alpha=0.3)
    if i == 0:  # Only show legend on first plot
        plt.legend()

plt.tight_layout()
plt.savefig('output/plots/all_metrics_across_simulations.png')
plt.close()

# 2. Metric distribution comparison (Boxplot version)
plt.figure(figsize=(15, 10))
for i, metric in enumerate(metrics_to_plot):
    plt.subplot(2, 3, i+1)
    metric_data = []
    for name in models.keys():
        metric_data.append(mc_results[name][metric])
    
    plt.boxplot(metric_data, labels=models.keys())
    plt.title(f'{metric_names[i]} Distribution')
    plt.ylabel(metric_names[i])
    plt.xticks(rotation=45)
    plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('output/plots/metric_distribution_comparison.png')
plt.close()

# 3. Metric correlation heatmap
corr_metrics = ['accuracy', 'f1_score', 'roc_auc', 'rmse', 'mae']
corr_names = ['Accuracy', 'F1', 'AUC', 'RMSE', 'MAE']

# Calculate correlations across all simulations
all_metrics = []
for name in models.keys():
    for sim in range(N_SIMULATIONS):
        metrics = {m: mc_results[name][m][sim] for m in corr_metrics}
        metrics['Model'] = name
        all_metrics.append(metrics)

corr_df = pd.DataFrame(all_metrics)
corr_matrix = corr_df[corr_metrics].corr()

plt.figure(figsize=(8, 6))
sns.heatmap(corr_matrix, annot=True, cmap='coolwarm', 
            xticklabels=corr_names, yticklabels=corr_names,
            vmin=-1, vmax=1, fmt='.2f')
plt.title('Metric Correlation Heatmap')
plt.tight_layout()
plt.savefig('output/plots/metric_correlations.png')
plt.close()

# ... [keep all other existing visualization code] ...

# 3. Model fitting diagram (Underfitting/Overfitting/Balanced)
plt.figure(figsize=(8, 6))
x = np.linspace(0, 1, 100)
y_true = np.sin(2 * np.pi * x)

# Underfit
y_under = 0.5 * x + 0.2
# Overfit
y_over = np.sin(2 * np.pi * x) + 0.3 * np.cos(20 * np.pi * x)
# Good fit
y_good = np.sin(2 * np.pi * x) + np.random.normal(0, 0.1, len(x))

plt.scatter(x, y_true + np.random.normal(0, 0.05, len(x)), 
            color='blue', alpha=0.3, label='Data points')
plt.plot(x, y_under, 'r-', linewidth=2, label='Underfitting')
plt.plot(x, y_over, 'g-', linewidth=2, label='Overfitting')
plt.plot(x, y_good, 'k--', linewidth=2, label='Good fit')
plt.xlabel('Input')
plt.ylabel('Output')
plt.title('Model Fitting Illustration')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.savefig('output/plots/model_fitting_diagram.png')
plt.close()

# 4. WCT Predictions Statistics
# Get predictions from last simulation
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=TEST_SIZE, stratify=y, random_state=RANDOM_SEED+N_SIMULATIONS)

best_model = ImbPipeline([
    ('imputer', SimpleImputer(strategy='median')),
    ('scaler', StandardScaler()),
    ('smote', SMOTE(random_state=RANDOM_SEED)),
    ('model', XGBClassifier(**models['XGBoost']['params']))
])
best_model.fit(X_train, y_train)
y_pred = best_model.predict(X_test)

# Confusion matrix
plt.figure(figsize=(8, 6))
cm = confusion_matrix(y_test, y_pred)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
            xticklabels=['Negative', 'Positive'],
            yticklabels=['Negative', 'Positive'])
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('WCT Prediction Confusion Matrix')
plt.tight_layout()
plt.savefig('output/plots/wct_confusion_matrix.png')
plt.close()

# Prediction distribution
plt.figure(figsize=(8, 6))
pd.Series(y_pred).value_counts().plot(kind='bar')
plt.xlabel('WCT Prediction')
plt.ylabel('Count')
plt.title('WCT Prediction Distribution')
plt.xticks([0, 1], ['Negative', 'Positive'], rotation=0)
plt.tight_layout()
plt.savefig('output/plots/wct_prediction_distribution.png')
plt.close()

# ==============================================
# SECTION 6: FINAL OUTPUTS
# ==============================================

print("\n🔥 All tasks completed successfully!")
print("Generated outputs:")
print("- Model hyperparameters table")
print("- Performance metrics CSV")
print("- Training vs testing performance plot")
print("- RMSE across simulations plot")
print("- Model fitting diagram")
print("- WCT prediction statistics plots")


ParserError: Error tokenizing data. C error: Calling read(nbytes) on source failed. Try engine='python'.

In [12]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

# Set style for academic publication
plt.style.use('seaborn-v0_8')
plt.rcParams['font.family'] = 'Times New Roman'
plt.rcParams['figure.dpi'] = 300
plt.rcParams['axes.titlesize'] = 12
plt.rcParams['axes.labelsize'] = 10

def plot_all_confusion_matrices(y_true, y_preds, model_names):
    """
    Generate a 2x2 grid of confusion matrices for four models
    
    Parameters:
    y_true (array): True labels (0 or 1)
    y_preds (list of arrays): Predictions from each model
    model_names (list): Names of the four models
    """
    # Create figure and axes
    fig, axs = plt.subplots(2, 2, figsize=(12, 10))
    fig.suptitle('Confusion Matrices Comparison Across Models', y=1.02, fontsize=14)
    
    # Define colormaps for each model
    cmaps = ['Blues', 'Oranges', 'Greens', 'Reds']
    
    # Plot each confusion matrix
    for i, (pred, name, cmap) in enumerate(zip(y_preds, model_names, cmaps)):
        ax = axs[i//2, i%2]
        cm = confusion_matrix(y_true, pred)
        sns.heatmap(cm, annot=True, fmt='d', cmap=cmap, ax=ax,
                   cbar=False, annot_kws={'size': 10})
        
        # Calculate metrics
        tn, fp, fn, tp = cm.ravel()
        sensitivity = tp / (tp + fn)
        specificity = tn / (tn + fp)
        
        # Set titles and labels
        ax.set_title(f'{name}\nSens: {sensitivity:.3f}, Spec: {specificity:.3f}', 
                    pad=12, fontsize=11)
        ax.set_xlabel('Predicted Label', labelpad=8)
        ax.set_ylabel('True Label', labelpad=8)
        ax.set_xticklabels(['Negative', 'Positive'])
        ax.set_yticklabels(['Negative', 'Positive'], rotation=0)
        
    
    

    
    plt.tight_layout()
    plt.savefig('output/plots/all_confusion_matrices.png', bbox_inches='tight')
    plt.close()
    print("Saved combined confusion matrices to output/plots/all_confusion_matrices.png")

# Example usage with mock data (replace with your actual predictions)
y_true = np.random.randint(0, 2, 1000)  # True labels
y_preds = [
    np.random.binomial(1, 0.98, 1000),  # XGBoost predictions
    np.random.binomial(1, 0.97, 1000),  # LightGBM predictions
    np.random.binomial(1, 0.96, 1000),  # Random Forest predictions
    np.random.binomial(1, 0.95, 1000)   # Gradient Boosting predictions
]
model_names = ['XGBoost', 'LightGBM', 'Random Forest', 'Gradient Boosting']

plot_all_confusion_matrices(y_true, y_preds, model_names)

Saved combined confusion matrices to output/plots/all_confusion_matrices.png


In [28]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from sklearn.metrics import confusion_matrix

# Set style for academic publication
plt.style.use('seaborn-v0_8')
plt.rcParams['font.family'] = 'Times New Roman'
plt.rcParams['figure.dpi'] = 300
plt.rcParams['axes.titlesize'] = 12
plt.rcParams['axes.labelsize'] = 10
plt.rcParams['xtick.labelsize'] = 9
plt.rcParams['ytick.labelsize'] = 9

def generate_model_analysis_plots(y_true, y_preds, feature_importances, model_names, features):
    """
    Generate comprehensive model evaluation plots including:
    - Confusion matrices for all models
    - Feature importance comparison
    
    Parameters:
    y_true (array): True labels (0 or 1)
    y_preds (list of arrays): Predictions from each model
    feature_importances (list of arrays): Feature importance scores
    model_names (list): Names of the models
    features (list): Feature names
    """
    # Create figure with 2 rows (confusion matrices and feature importance)
    fig = plt.figure(figsize=(14, 14))
    gs = fig.add_gridspec(2, 1, height_ratios=[1, 1.2])
    
    # --------------------------------------------------
    # Confusion Matrices (Top Row)
    # --------------------------------------------------
    axs_conf = fig.add_subplot(gs[0])
    axs_conf.axis('off')  # Turn off main axis for subplots
    
    # Create 2x2 grid for confusion matrices
    inner_gs_conf = gs[0].subgridspec(2, 2, wspace=0.15, hspace=0.25)
    axs_conf = inner_gs_conf.subplots()
    
    # Define colormaps for each model
    cmaps = ['Blues', 'Oranges', 'Greens', 'Reds']
    
    # Plot each confusion matrix
    for i, (pred, name, cmap) in enumerate(zip(y_preds, model_names, cmaps)):
        ax = axs_conf[i//2, i%2]
        cm = confusion_matrix(y_true, pred)
        sns.heatmap(cm, annot=True, fmt='d', cmap=cmap, ax=ax,
                   cbar=False, annot_kws={'size': 10})
        
        # Calculate metrics
        tn, fp, fn, tp = cm.ravel()
        sensitivity = tp / (tp + fn)
        specificity = tn / (tn + fp)
        accuracy = (tp + tn) / (tp + tn + fp + fn)
        
        # Set titles and labels
        ax.set_title(f'{name}\nSens: {sensitivity:.3f}, Spec: {specificity:.3f}\nAcc: {accuracy:.3f}', 
                    pad=12, fontsize=11)
        ax.set_xlabel('Predicted Label', labelpad=8)
        ax.set_ylabel('True Label', labelpad=8)
        ax.set_xticklabels(['Negative', 'Positive'])
        ax.set_yticklabels(['Negative', 'Positive'], rotation=0)
        
        # Add border
        for _, spine in ax.spines.items():
            spine.set_visible(True)
            spine.set_linewidth(1.5)
    
    
    # --------------------------------------------------
    # Feature Importance (Bottom Row)
    # --------------------------------------------------
    axs_feat = fig.add_subplot(gs[1])
    axs_feat.axis('off')  # Turn off main axis for subplots
    
    # Create dataframe for feature importances
    feat_df = pd.DataFrame(feature_importances, index=model_names, columns=features).T
    feat_df = feat_df.sort_values(by=model_names[0], ascending=False)
    
    # Create 2x2 grid for feature importance
    inner_gs_feat = gs[1].subgridspec(2, 2, wspace=0.15, hspace=0.3)
    axs_feat = inner_gs_feat.subplots()
    
    # Plot feature importance for each model
    for i, (name, cmap) in enumerate(zip(model_names, cmaps)):
        ax = axs_feat[i//2, i%2]
        sns.barplot(x=feat_df[name].values, y=feat_df.index, ax=ax, 
                   color=sns.color_palette(cmap)[2], saturation=0.8)
        
        # Formatting
        ax.set_title(f'{name} Feature Importance', pad=12, fontsize=11)
        ax.set_xlabel('Importance Score', labelpad=8)
        ax.set_ylabel('')
        ax.grid(True, alpha=0.3, axis='x')
        
        # Add value labels for top 3 features
        for j, (feature, val) in enumerate(zip(feat_df.index, feat_df[name])):
            if j < 3:  # Only label top 3 features
                ax.text(val + 0.01, j, f'{val:.3f}', 
                        va='center', fontsize=9)
    

    
    plt.tight_layout()
    plt.savefig('output/plots/full_model_analysis.png', bbox_inches='tight')
    plt.close()
    print("Saved comprehensive model analysis to output/plots/full_model_analysis.png")


In [29]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from scipy.stats import gaussian_kde

# Set style
plt.style.use('seaborn-v0_8')
plt.rcParams['font.family'] = 'Times New Roman'
plt.rcParams['figure.dpi'] = 300

# Data from Table 5.2
data = {
    'Model': ['XGBoost', 'LightGBM', 'Random Forest', 'Gradient Boosting'],
    'Avg Accuracy': [0.9998, 0.9997, 0.9995, 0.9993],
    'F1 Score': [0.9998, 0.9996, 0.9994, 0.9992],
    'ROC AUC': [0.9999, 0.9998, 0.9996, 0.9995],
    'Max RMSE': [0.0071, 0.0122, 0.0122, 0.0158]
}

# Create density plot
def plot_metric_density():
    fig, ax = plt.subplots(figsize=(8, 6))
    
    # Create density distributions for each metric
    metrics = ['Avg Accuracy', 'F1 Score', 'ROC AUC', 'Max RMSE']
    colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']
    
    for metric, color in zip(metrics, colors):
        values = data[metric]
        
        # Create density estimate
        density = gaussian_kde(values)
        xs = np.linspace(min(values)-0.0005, max(values)+0.0005, 200)
        
        # Plot density curve
        ax.plot(xs, density(xs), color=color, label=metric)
        
        # Add rug plot
        ax.scatter(values, [-0.0005]*len(values), color=color, marker='|', s=100)
    
    # Formatting
    ax.set_xlim(0.999, 1.0005)
    ax.set_ylim(-0.005, 25)
    ax.set_xlabel('Metric Value', fontsize=10)
    ax.set_ylabel('Density', fontsize=10)
    ax.set_title('Model Metric Density Distributions', fontsize=12)
    ax.grid(True, alpha=0.3)
    ax.legend(fontsize=9)
    
    # Add model labels at bottom
    for i, model in enumerate(data['Model']):
        ax.text(0.9991 + i*0.0003, -0.003, model, 
                rotation=45, fontsize=8, ha='left')
    
    plt.tight_layout()
    plt.savefig('output/plots/metric_density.png')
    plt.close()
    print("Saved density plot to output/plots/metric_density.png")

# Generate plot
plot_metric_density()

Saved density plot to output/plots/metric_density.png


In [31]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
from scipy import stats

# Load and prepare the data
df = pd.read_csv('merged_ecg_data_cleaned.csv')

# Convert string ranges to numeric (take first value)
def convert_freq_range(x):
    if isinstance(x, str) and '-' in x and 'Hz' in x:
        return float(x.split('-')[0])
    return x

df['bandwidth'] = df['bandwidth'].apply(convert_freq_range)
df['filtering'] = df['filtering'].apply(convert_freq_range)

# Ensure numeric types
numeric_features = ['rr_interval', 'p_onset', 'p_end', 'qrs_onset', 
                   'qrs_end', 't_end', 'p_axis', 'qrs_axis', 
                   't_axis', 'qrs_duration']
for feat in numeric_features:
    df[feat] = pd.to_numeric(df[feat], errors='coerce')

# Handle target variable
target = 'wct_label_encoded' if 'wct_label_encoded' in df.columns else 'wct_label'
if not np.issubdtype(df[target].dtype, np.number):
    df[target] = df[target].astype('category').cat.codes

# Create output directory
os.makedirs('statistical_analysis', exist_ok=True)

# 1. Temporal Feature Trends (Line Graphs)
temporal_features = ['rr_interval', 'p_onset', 'qrs_onset', 't_end']
plt.figure(figsize=(14, 8))

for i, feat in enumerate(temporal_features, 1):
    plt.subplot(2, 2, i)
    sns.lineplot(data=df, x=df.index, y=feat, hue=target, 
                estimator='mean', errorbar=('ci', 95))
    plt.title(f'Trend of {feat}')
    plt.xlabel('Time Sequence')
    plt.ylabel(feat)
    plt.legend(title='WCT Status')

plt.tight_layout()
plt.savefig('statistical_analysis/temporal_trends.png')
plt.close()

# 2. Statistical Distribution Over Time
stats_to_plot = ['mean', 'std', 'skew']
plt.figure(figsize=(15, 5))

for i, stat in enumerate(stats_to_plot, 1):
    plt.subplot(1, 3, i)
    df_rolling = df[numeric_features].rolling(100).agg(stat)
    for feat in ['rr_interval', 'qrs_duration']:  # Key features
        sns.lineplot(data=df_rolling, x=df_rolling.index, y=feat)
    plt.title(f'Rolling {stat} of Key Features')
    plt.xlabel('Time Sequence')
    plt.ylabel(stat)
    plt.legend(['RR Interval', 'QRS Duration'])

plt.tight_layout()
plt.savefig('statistical_analysis/rolling_stats.png')
plt.close()

# 3. Statistical Significance Over Time
significant_features = []
window_size = 500  # Analyze in windows of 500 samples

for feat in numeric_features:
    p_values = []
    for i in range(0, len(df), window_size):
        window = df.iloc[i:i+window_size]
        if len(window[target].unique()) == 2:  # Both classes present
            group0 = window[window[target] == 0][feat].dropna()
            group1 = window[window[target] == 1][feat].dropna()
            if len(group0) > 10 and len(group1) > 10:
                _, p = stats.mannwhitneyu(group0, group1)
                p_values.append(p)
    
    if p_values and np.mean(p_values) < 0.05:
        significant_features.append(feat)

# 4. Plot Most Significant Features
plt.figure(figsize=(14, 6))
for i, feat in enumerate(significant_features[:3], 1):
    plt.subplot(1, 3, i)
    sns.lineplot(data=df, x=df.index, y=feat, hue=target,
                estimator='median', errorbar=None)
    plt.title(f'{feat} by WCT Status')
    plt.xlabel('Time Sequence')
    plt.ylabel(feat)

plt.tight_layout()
plt.savefig('statistical_analysis/significant_features_trends.png')
plt.close()

# 5. Statistical Metric Trends
metrics = ['mean', 'std', 'median']
plt.figure(figsize=(15, 10))

for i, metric in enumerate(metrics, 1):
    plt.subplot(2, 2, i)
    df_grouped = df.groupby(target)[numeric_features].agg(metric).T
    df_grouped.plot(marker='o', ax=plt.gca())
    plt.title(f'{metric.capitalize()} Values by WCT Status')
    plt.ylabel('Value')
    plt.xticks(rotation=45)
    plt.legend(title='WCT Status')

plt.tight_layout()
plt.savefig('statistical_analysis/metric_trends.png')
plt.close()

print("Statistical line graphs generated and saved to 'statistical_analysis' directory")

  plt.tight_layout()
  plt.savefig('statistical_analysis/rolling_stats.png')


Statistical line graphs generated and saved to 'statistical_analysis' directory


In [38]:
# ==============================================
# COMPLETE SALIENCY VISUALIZATION FOR ECG WCT DETECTION
# ==============================================

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from sklearn.inspection import permutation_importance
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from xgboost import XGBClassifier
import shap
import os

# Set up directories
os.makedirs('output/plots', exist_ok=True)
os.makedirs('output/tables', exist_ok=True)

# ==============================================
# 1. DATA LOADING AND PREPROCESSING
# ==============================================

def load_and_preprocess_data():
    """Load and preprocess the ECG dataset"""
    print("Loading merged_ecg_data_cleaned.csv...")
    try:
        df = pd.read_csv('merged_ecg_data_cleaned.csv')
    except FileNotFoundError:
        raise FileNotFoundError("Could not find 'merged_ecg_data_cleaned.csv'. Please ensure the file exists in the current directory.")
    
    # Select features and target
    features = ['bandwidth', 'filtering', 'rr_interval', 'p_onset', 'p_end', 
               'qrs_onset', 'qrs_end', 't_end', 'p_axis', 'qrs_axis', 
               't_axis', 'qrs_duration']
    target = 'wct_label_encoded' if 'wct_label_encoded' in df.columns else 'wct_label'
    
    # Handle missing values
    for col in features:
        if col in df.columns:
            df[col] = pd.to_numeric(df[col], errors='coerce')
            df[col] = df[col].fillna(df[col].median())
    
    # Encode target if needed
    if not np.issubdtype(df[target].dtype, np.number):
        df[target] = df[target].astype('category').cat.codes
    
    X = df[features].values
    y = df[target].values
    
    print(f"Data shape: {X.shape}, Features: {features}")
    return X, y, features

X, y, features = load_and_preprocess_data()

# ==============================================
# 2. MODEL TRAINING
# ==============================================

print("\nTraining XGBoost model for interpretation...")
model = Pipeline([
    ('scaler', StandardScaler()),
    ('model', XGBClassifier(
        n_estimators=200,
        max_depth=5,
        learning_rate=0.1,
        random_state=42,
        n_jobs=-1
    ))
])
model.fit(X, y)

# ==============================================
# 3. ENHANCED SALIENCY VISUALIZATIONS
# ==============================================

def plot_ecg_saliency(model, X, y, features, n_samples=1000):
    """
    Generate beautiful saliency visualizations for ECG features
    """
    # Permutation Importance
    print("Calculating permutation importance...")
    result = permutation_importance(
        model, X[:n_samples], y[:n_samples],
        n_repeats=10,
        random_state=42,
        n_jobs=-1
    )
    
    # Sort features by importance
    sorted_idx = result.importances_mean.argsort()[::-1]
    sorted_features = np.array(features)[sorted_idx]
    sorted_importance = result.importances[sorted_idx]
    
    # Create ECG-appropriate color gradient
    ecg_colors = ["#FF6B6B", "#FFA3A3", "#FFD3B6", "#DCE2C8", "#A5D8D6", "#74B3CE"]
    cmap = mcolors.LinearSegmentedColormap.from_list("ecg_cmap", ecg_colors)
    
    # Create figure with ECG styling
    fig, ax = plt.subplots(figsize=(14, 8))
    plt.style.use('default')  # Using default style
    
    # Create horizontal violin plot
    parts = ax.violinplot(
        sorted_importance.T,
        vert=False,
        showmeans=True,
        showextrema=True,
        widths=0.7
    )
    
    # Color each violin by importance
    for pc, color in zip(parts['bodies'], np.linspace(0, 1, len(sorted_features))):
        pc.set_facecolor(cmap(color))
        pc.set_edgecolor('#2d3436')
        pc.set_alpha(0.9)
    
    # ECG-specific styling
    ax.set_yticks(range(1, len(sorted_features)+1))
    ax.set_yticklabels(sorted_features, fontsize=10, fontfamily='serif')
    ax.set_title("ECG Feature Importance for WCT Detection\n(Permutation Importance Analysis)", 
                fontsize=16, pad=20, fontweight='bold', fontfamily='serif')
    ax.set_xlabel("Mean Accuracy Decrease When Permuted", 
                 fontsize=12, fontfamily='serif')
    ax.set_ylabel("ECG Features", fontsize=12, fontfamily='serif')
    
    # Add clinical context annotation
    ax.annotate(
        "Higher values indicate features more critical\nfor Wide Complex Tachycardia detection",
        xy=(0.7, 0.9), xycoords='axes fraction',
        fontsize=10, fontfamily='serif',
        bbox=dict(boxstyle="round", alpha=0.1, facecolor="white")
    )
    
    # Final styling
    ax.grid(True, linestyle='--', alpha=0.2)
    ax.set_facecolor('#f8f9fa')
    
    # Add professional colorbar - FIXED with proper mappable
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=0, vmax=len(sorted_features)))
    sm.set_array([])
    cbar = fig.colorbar(sm, ax=ax, pad=0.02)
    cbar.set_label('Feature Importance Rank', 
                  rotation=270, 
                  labelpad=20,
                  fontfamily='serif')
    
    plt.tight_layout()
    plt.savefig('output/plots/ecg_saliency_map.png', 
               dpi=300, 
               bbox_inches='tight',
               transparent=True)
    plt.close()
    
    return result, sorted_features

# Generate the main saliency visualization
perm_result, important_features = plot_ecg_saliency(model, X, y, features)
print("✅ Saved ECG saliency map to output/plots/ecg_saliency_map.png")

# ==============================================
# 4. SHAP VALUE ANALYSIS (DETAILED FEATURE IMPACT)
# ==============================================

def plot_shap_analysis(model, X, features, n_samples=500):
    """Generate SHAP value visualizations"""
    print("\nGenerating SHAP explanations...")
    
    try:
        # Prepare SHAP explainer
        explainer = shap.TreeExplainer(model.named_steps['model'])
        
        # Transform data through pipeline
        X_transformed = model.named_steps['scaler'].transform(X[:n_samples])
        
        # Compute SHAP values
        shap_values = explainer.shap_values(X_transformed)
        
        # Create clinical color scheme (red=positive, blue=negative)
        shap_cmap = mcolors.LinearSegmentedColormap.from_list(
            "shap_clinical", ["#1e90ff", "#f1f2f6", "#ff4757"])
        
        # Beeswarm plot
        plt.figure()
        shap.summary_plot(
            shap_values, 
            X_transformed, 
            feature_names=features,
            plot_type="dot",
            color=shap_cmap,
            show=False,
            plot_size=(12, 8)
        )
        
        # Clinical styling
        plt.title("SHAP Feature Impact on WCT Predictions\n(Red=Higher Risk, Blue=Lower Risk)", 
                 fontsize=14, pad=20, fontweight='bold', fontfamily='serif')
        plt.gcf().set_facecolor('#f8f9fa')
        plt.tight_layout()
        plt.savefig('output/plots/ecg_shap_beeswarm.png', 
                   dpi=300, 
                   bbox_inches='tight',
                   transparent=True)
        plt.close()
        
        # Feature importance plot
        plt.figure()
        shap.summary_plot(
            shap_values, 
            X_transformed, 
            feature_names=features,
            plot_type="bar",
            show=False,
            plot_size=(10, 6)
        )
        plt.title("SHAP Feature Importance Ranking", 
                 fontsize=14, pad=20, fontweight='bold', fontfamily='serif')
        plt.gcf().set_facecolor('#f8f9fa')
        plt.tight_layout()
        plt.savefig('output/plots/ecg_shap_importance.png', 
                   dpi=300, 
                   bbox_inches='tight',
                   transparent=True)
        plt.close()
    except Exception as e:
        print(f"⚠️ Could not generate SHAP plots: {str(e)}")
        print("You may need to install SHAP: pip install shap")

# Generate SHAP plots
plot_shap_analysis(model, X, features)
print("✅ Saved SHAP visualizations to output/plots/")

# ==============================================
# 5. ECG FEATURE GROUP ANALYSIS
# ==============================================

def plot_ecg_feature_groups(importance_result, features):
    """Visualize importance by clinical ECG feature groups"""
    # Define clinically meaningful groups
    ecg_groups = {
        'Timing Intervals': ['rr_interval', 'p_onset', 'p_end', 
                            'qrs_onset', 'qrs_end', 't_end', 'qrs_duration'],
        'Electrical Axes': ['p_axis', 'qrs_axis', 't_axis'],
        'Signal Quality': ['bandwidth', 'filtering']
    }
    
    # Calculate group importances
    group_data = []
    for group_name, group_features in ecg_groups.items():
        group_mask = [f in group_features for f in features]
        group_mean = np.mean(importance_result.importances_mean[group_mask])
        group_std = np.mean(importance_result.importances_std[group_mask])
        group_data.append({
            'Group': group_name,
            'Importance': group_mean,
            'Std': group_std
        })
    
    group_df = pd.DataFrame(group_data).sort_values('Importance', ascending=False)
    
    # Create plot
    fig, ax = plt.subplots(figsize=(10, 6))
    colors = plt.cm.PuBu(np.linspace(0.3, 0.9, len(group_df)))
    
    ax.barh(
        group_df['Group'],
        group_df['Importance'],
        xerr=group_df['Std'],
        color=colors,
        alpha=0.8,
        capsize=5
    )
    
    # Clinical ECG styling
    ax.set_title("WCT Detection Importance by ECG Feature Category", 
                fontsize=16, pad=20, fontweight='bold', fontfamily='serif')
    ax.set_xlabel("Mean Permutation Importance", fontsize=12, fontfamily='serif')
    ax.set_ylabel("ECG Feature Categories", fontsize=12, fontfamily='serif')
    ax.grid(True, linestyle='--', alpha=0.2)
    ax.set_facecolor('#f8f9fa')
    ax.invert_yaxis()
    
    plt.tight_layout()
    plt.savefig('output/plots/ecg_group_importance.png', 
               dpi=300, 
               bbox_inches='tight',
               transparent=True)
    plt.close()

# Generate grouped importance plot
plot_ecg_feature_groups(perm_result, features)
print("✅ Saved ECG group importance to output/plots/ecg_group_importance.png")

# ==============================================
# FINAL OUTPUT
# ==============================================

print("\n🎉 ECG Saliency Analysis Complete!")
print("Generated the following clinical visualizations:")
print("1. output/plots/ecg_saliency_map.png - Main permutation importance")
print("2. output/plots/ecg_shap_beeswarm.png - Detailed SHAP values")
print("3. output/plots/ecg_shap_importance.png - SHAP importance ranking")
print("4. output/plots/ecg_group_importance.png - Clinical feature groups")

Loading merged_ecg_data_cleaned.csv...


  return np.nanmean(a, axis, out=out, keepdims=keepdims)
  return np.nanmean(a, axis, out=out, keepdims=keepdims)


Data shape: (800035, 12), Features: ['bandwidth', 'filtering', 'rr_interval', 'p_onset', 'p_end', 'qrs_onset', 'qrs_end', 't_end', 'p_axis', 'qrs_axis', 't_axis', 'qrs_duration']

Training XGBoost model for interpretation...


  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count


Calculating permutation importance...
✅ Saved ECG saliency map to output/plots/ecg_saliency_map.png

Generating SHAP explanations...


  shap.summary_plot(
  return _nanquantile_unchecked(
  shap.summary_plot(


✅ Saved SHAP visualizations to output/plots/
✅ Saved ECG group importance to output/plots/ecg_group_importance.png

🎉 ECG Saliency Analysis Complete!
Generated the following clinical visualizations:
1. output/plots/ecg_saliency_map.png - Main permutation importance
2. output/plots/ecg_shap_beeswarm.png - Detailed SHAP values
3. output/plots/ecg_shap_importance.png - SHAP importance ranking
4. output/plots/ecg_group_importance.png - Clinical feature groups


In [2]:
# ==============================================
# ECG SALIENCY VISUALIZATION - FINAL WORKING VERSION
# ==============================================

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from sklearn.inspection import permutation_importance
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from xgboost import XGBClassifier
import shap
import os

# Set up directories
os.makedirs('output/plots', exist_ok=True)
os.makedirs('output/tables', exist_ok=True)

# ==============================================
# 1. DATA LOADING AND PREPROCESSING
# ==============================================

def load_and_preprocess_data():
    """Load and preprocess the ECG dataset"""
    print("Loading merged_ecg_data_cleaned.csv...")
    try:
        df = pd.read_csv('merged_ecg_data_cleaned.csv')
    except FileNotFoundError:
        raise FileNotFoundError("Could not find 'merged_ecg_data_cleaned.csv'")
    
    # Select features and target
    features = ['bandwidth', 'filtering', 'rr_interval', 'p_onset', 'p_end', 
               'qrs_onset', 'qrs_end', 't_end', 'p_axis', 'qrs_axis', 
               't_axis', 'qrs_duration']
    target = 'wct_label_encoded' if 'wct_label_encoded' in df.columns else 'wct_label'
    
    # Handle missing values
    for col in features:
        if col in df.columns:
            df[col] = pd.to_numeric(df[col], errors='coerce')
            df[col] = df[col].fillna(df[col].median())
    
    # Encode target if needed
    if not np.issubdtype(df[target].dtype, np.number):
        df[target] = df[target].astype('category').cat.codes
    
    X = df[features].values
    y = df[target].values
    
    print(f"Data shape: {X.shape}, Features: {features}")
    return X, y, features

X, y, features = load_and_preprocess_data()

# ==============================================
# 2. MODEL TRAINING
# ==============================================

print("\nTraining XGBoost model for interpretation...")
model = Pipeline([
    ('scaler', StandardScaler()),
    ('model', XGBClassifier(
        n_estimators=200,
        max_depth=5,
        learning_rate=0.1,
        random_state=42,
        n_jobs=-1
    ))
])
model.fit(X, y)

# ==============================================
# 3. ENHANCED SALIENCY VISUALIZATIONS
# ==============================================

def plot_ecg_saliency(model, X, y, features, n_samples=1000):
    """Generate beautiful saliency visualizations"""
    # Permutation Importance
    print("Calculating permutation importance...")
    result = permutation_importance(
        model, X[:n_samples], y[:n_samples],
        n_repeats=10,
        random_state=42,
        n_jobs=-1
    )
    
    # Sort features by importance
    sorted_idx = result.importances_mean.argsort()[::-1]
    sorted_features = np.array(features)[sorted_idx]
    sorted_importance = result.importances[sorted_idx]
    
    # Create figure with modern style
    plt.figure(figsize=(12, 8))
    
    # Create boxplot with enhanced styling
    boxprops = dict(linestyle='-', linewidth=1.5, color='darkblue')
    whiskerprops = dict(linestyle='--', linewidth=1, color='black')
    medianprops = dict(linestyle='-', linewidth=2, color='firebrick')
    
    bp = plt.boxplot(
        sorted_importance.T,
        vert=False,
        patch_artist=True,
        boxprops=boxprops,
        whiskerprops=whiskerprops,
        medianprops=medianprops
    )
    
    # Color boxes with clinical color scheme
    clinical_colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', 
                      '#9467bd', '#8c564b', '#e377c2', '#7f7f7f']
    for i, box in enumerate(bp['boxes']):
        box.set_facecolor(clinical_colors[i % len(clinical_colors)])
        box.set_alpha(0.7)
    
    # Add feature names
    plt.yticks(range(1, len(sorted_features)+1), sorted_features, fontsize=10)
    
    # Add titles and labels
    plt.title('ECG Feature Importance for WCT Detection', 
             fontsize=14, pad=20, fontweight='bold')
    plt.xlabel('Mean Accuracy Decrease', fontsize=12)
    plt.ylabel('ECG Features', fontsize=12)
    
    # Add grid
    plt.grid(True, linestyle=':', alpha=0.3)
    
    plt.tight_layout()
    plt.savefig('output/plots/ecg_saliency_map.png', dpi=300)
    plt.close()
    
    return result, sorted_features

# Generate the main saliency visualization
perm_result, important_features = plot_ecg_saliency(model, X, y, features)
print("✅ Saved ECG saliency map to output/plots/ecg_saliency_map.png")

# ==============================================
# 4. SHAP VALUE ANALYSIS
# ==============================================

def plot_shap_analysis(model, X, features, n_samples=500):
    """Generate SHAP value visualizations"""
    print("\nGenerating SHAP explanations...")
    
    try:
        # Prepare SHAP explainer
        explainer = shap.TreeExplainer(model.named_steps['model'])
        X_transformed = model.named_steps['scaler'].transform(X[:n_samples])
        shap_values = explainer.shap_values(X_transformed)
        
        # Beeswarm plot with clinical colors
        plt.figure()
        shap.summary_plot(
            shap_values, 
            X_transformed, 
            feature_names=features,
            plot_type="dot",
            show=False,
            plot_size=(12, 8),
            cmap=plt.get_cmap('coolwarm')
        )
        plt.title("SHAP Feature Impact on WCT Predictions", 
                 fontsize=14, pad=20, fontweight='bold')
        plt.tight_layout()
        plt.savefig('output/plots/ecg_shap_beeswarm.png', dpi=300)
        plt.close()
        
    except Exception as e:
        print(f"⚠️ Could not generate SHAP plots: {str(e)}")

# Generate SHAP plots
plot_shap_analysis(model, X, features)
print("✅ Saved SHAP visualizations to output/plots/")

# ==============================================
# 5. ECG FEATURE GROUP ANALYSIS
# ==============================================

def plot_ecg_feature_groups(importance_result, features):
    """Visualize importance by clinical ECG feature groups"""
    ecg_groups = {
        'Timing Intervals': ['rr_interval', 'p_onset', 'p_end', 'qrs_onset', 'qrs_end', 't_end', 'qrs_duration'],
        'Electrical Axes': ['p_axis', 'qrs_axis', 't_axis'],
        'Signal Quality': ['bandwidth', 'filtering']
    }
    
    # Calculate group importances
    group_data = []
    for group_name, group_features in ecg_groups.items():
        group_mask = [f in group_features for f in features]
        group_mean = np.mean(importance_result.importances_mean[group_mask])
        group_data.append({'Group': group_name, 'Importance': group_mean})
    
    group_df = pd.DataFrame(group_data).sort_values('Importance', ascending=False)
    
    # Create plot with clinical styling
    plt.figure(figsize=(10, 5))
    bars = plt.barh(
        group_df['Group'],
        group_df['Importance'],
        color=['#1f77b4', '#ff7f0e', '#2ca02c']  # Clinical color scheme
    )
    
    # Add value labels
    for bar in bars:
        width = bar.get_width()
        plt.text(width*1.02, bar.get_y() + bar.get_height()/2,
                f'{width:.3f}',
                va='center')
    
    plt.title('WCT Detection Importance by ECG Feature Category', 
             fontsize=14, pad=20, fontweight='bold')
    plt.xlabel('Mean Permutation Importance')
    plt.grid(True, linestyle=':', alpha=0.3)
    plt.tight_layout()
    plt.savefig('output/plots/ecg_group_importance.png', dpi=300)
    plt.close()

# Generate grouped importance plot
plot_ecg_feature_groups(perm_result, features)
print("✅ Saved ECG group importance to output/plots/ecg_group_importance.png")

# ==============================================
# FINAL OUTPUT
# ==============================================

print("\n🎉 ECG Saliency Analysis Complete!")
print("Generated the following visualizations:")
print("1. output/plots/ecg_saliency_map.png - Feature importance")
print("2. output/plots/ecg_shap_beeswarm.png - SHAP values")
print("3. output/plots/ecg_group_importance.png - Feature groups")

Loading merged_ecg_data_cleaned.csv...


  return np.nanmean(a, axis, out=out, keepdims=keepdims)
  return np.nanmean(a, axis, out=out, keepdims=keepdims)


Data shape: (800035, 12), Features: ['bandwidth', 'filtering', 'rr_interval', 'p_onset', 'p_end', 'qrs_onset', 'qrs_end', 't_end', 'p_axis', 'qrs_axis', 't_axis', 'qrs_duration']

Training XGBoost model for interpretation...


  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count


Calculating permutation importance...
✅ Saved ECG saliency map to output/plots/ecg_saliency_map.png

Generating SHAP explanations...


  return _nanquantile_unchecked(


✅ Saved SHAP visualizations to output/plots/
✅ Saved ECG group importance to output/plots/ecg_group_importance.png

🎉 ECG Saliency Analysis Complete!
Generated the following visualizations:
1. output/plots/ecg_saliency_map.png - Feature importance
2. output/plots/ecg_shap_beeswarm.png - SHAP values
3. output/plots/ecg_group_importance.png - Feature groups
