# K-Fold Cross-Validation Training
This notebook implements robust stratified K-Fold cross-validation for sleep disorder classification.

## Why K-Fold Cross-Validation?
- ✅ **More Reliable Metrics**: Uses the entire dataset for both training and validation (each sample is validated exactly once)
- ✅ **Reduced Variance**: Averages performance across multiple folds
- ✅ **Better Generalization**: Checks that the model performs well on different data splits
- ✅ **Industry Standard**: Professional approach for model evaluation
- ✅ **Class Balance**: Stratified splits maintain the class distribution in every fold

## Process:
1. Split data into K folds (typically 5 or 10)
2. Train the model K times, each time using a different fold as the validation set
3. Average metrics across all folds
4. Report mean and standard deviation

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
from sklearn.metrics import confusion_matrix, classification_report
import tensorflow as tf
from tensorflow import keras
from time import time
import json

# Report the library versions in use so results can be tied to an environment.
print(f"TensorFlow version: {tf.__version__}")
print(f"NumPy version: {np.__version__}")
# IPython magic (notebook only): render matplotlib figures inline in cell output.
%matplotlib inline

## 1. K-Fold Cross-Validation Function

In [None]:
def kfold_cross_validation(model_builder, X, y, n_splits=5, epochs=150, batch_size=64,
                          random_state=42, verbose=1):
    """
    Train and evaluate a binary classifier with stratified K-Fold cross-validation.

    For every fold a fresh model is built with ``model_builder``, fitted on the
    remaining folds and scored on the held-out fold. Predicted probabilities
    are turned into class labels with a fixed 0.5 threshold.

    Notes:
    ------
    * The held-out fold is also passed to ``fit`` as ``validation_data`` and
      drives early stopping / learning-rate reduction, so the reported metrics
      are measured on data that influenced when training stopped and may be
      optimistic.
    * ``random_state`` only seeds the fold assignment; model initialisation is
      not seeded by this function.

    Parameters:
    -----------
    model_builder : function
        Zero-argument callable that returns a compiled model
    X : array-like, shape (n_samples, n_timesteps, n_features)
        Input data (converted with ``np.asarray``)
    y : array-like, shape (n_samples,)
        Binary target labels encoded as 0 / 1 (converted with ``np.asarray``)
    n_splits : int
        Number of folds
    epochs : int
        Maximum epochs per fold
    batch_size : int
        Batch size for training
    random_state : int
        Random seed for the stratified fold shuffling
    verbose : int
        Verbosity level passed to ``model.fit`` (0, 1, or 2)

    Returns:
    --------
    results : dict
        ``'fold_results'`` (list of per-fold metric dicts), ``'summary_stats'``
        (mean / sample std of every metric plus timings) and
        ``'all_predictions'`` (out-of-fold ``y_true``, ``y_pred``,
        ``y_pred_proba`` concatenated in fold order, together with
        ``sample_indices`` giving the row of ``X`` each entry belongs to)
    models : list
        List of trained models from each fold
    histories : list
        List of training histories from each fold
    """
    # Fold indices are positional. Plain NumPy arrays guarantee positional
    # indexing; a pandas Series/DataFrame would index by label or column.
    X = np.asarray(X)
    y = np.asarray(y)

    # (display label, per-fold metric key, suffix used in summary_stats)
    metric_specs = (
        ('Accuracy', 'accuracy', 'accuracy'),
        ('Precision', 'precision', 'precision'),
        ('Recall', 'recall', 'recall'),
        ('Specificity', 'specificity', 'specificity'),
        ('F1-Score', 'f1_score', 'f1'),
        ('ROC-AUC', 'roc_auc', 'roc_auc'),
    )
    rule = '=' * 70

    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)

    # Storage for results
    fold_results = []
    models = []
    histories = []
    all_y_true = []
    all_y_pred = []
    all_y_pred_proba = []
    all_val_idx = []

    print(f"\n{rule}")
    print(f"Starting {n_splits}-Fold Cross-Validation")
    print(f"{rule}")
    print(f"Total samples: {len(X)}")
    print(f"Samples per fold (approx): {len(X) // n_splits}")
    print(f"Class distribution: {np.bincount(y.astype(int))}")
    print(f"{rule}\n")

    cv_start_time = time()

    for fold, (train_idx, val_idx) in enumerate(skf.split(X, y), 1):
        print(f"\n{rule}")
        print(f"📁 FOLD {fold}/{n_splits}")
        print(f"{rule}")

        X_train_fold, X_val_fold = X[train_idx], X[val_idx]
        y_train_fold, y_val_fold = y[train_idx], y[val_idx]

        print(f"Training samples: {len(X_train_fold)}")
        print(f"Validation samples: {len(X_val_fold)}")
        print(f"Train class distribution: {np.bincount(y_train_fold.astype(int))}")
        print(f"Val class distribution: {np.bincount(y_val_fold.astype(int))}")

        # Fresh model and fresh callbacks so no fold inherits anything from
        # the previous fold's training run.
        model = model_builder()
        callbacks = _make_fold_callbacks()

        print(f"\n🚀 Training fold {fold}...")
        fold_train_start = time()

        history = model.fit(
            X_train_fold, y_train_fold,
            validation_data=(X_val_fold, y_val_fold),
            epochs=epochs,
            batch_size=batch_size,
            callbacks=callbacks,
            verbose=verbose,
            shuffle=True
        )

        fold_train_time = time() - fold_train_start

        # Flatten to 1-D so every metric receives one score per sample.
        y_pred_proba = np.asarray(model.predict(X_val_fold, verbose=0)).reshape(-1)
        y_pred = (y_pred_proba > 0.5).astype(int)

        fold_metrics = {
            'fold': fold,
            'accuracy': accuracy_score(y_val_fold, y_pred),
            'precision': precision_score(y_val_fold, y_pred, zero_division=0),
            'recall': recall_score(y_val_fold, y_pred, zero_division=0),
            'f1_score': f1_score(y_val_fold, y_pred, zero_division=0),
            'roc_auc': roc_auc_score(y_val_fold, y_pred_proba),
            'training_time': fold_train_time,
            'epochs_trained': len(history.history['loss'])
        }

        # labels=[0, 1] pins the matrix to 2x2 even if a class is absent from
        # both the fold's labels and its predictions.
        tn, fp, fn, tp = confusion_matrix(y_val_fold, y_pred, labels=[0, 1]).ravel()
        fold_metrics['specificity'] = tn / (tn + fp) if (tn + fp) > 0 else 0

        fold_results.append(fold_metrics)
        models.append(model)
        histories.append(history)
        all_y_true.extend(y_val_fold)
        all_y_pred.extend(y_pred)
        all_y_pred_proba.extend(y_pred_proba)
        all_val_idx.append(val_idx)

        print(f"\n📊 Fold {fold} Results:")
        for label, column, _ in metric_specs:
            print(f"  {label + ':':<13}{fold_metrics[column]:.4f}")
        print(f"  Training time: {fold_train_time:.2f}s ({fold_metrics['epochs_trained']} epochs)")

    total_cv_time = time() - cv_start_time

    # Aggregate statistics (pandas .std() is the sample standard deviation).
    df_results = pd.DataFrame(fold_results)

    summary_stats = {}
    for _, column, suffix in metric_specs:
        summary_stats[f'mean_{suffix}'] = float(df_results[column].mean())
        summary_stats[f'std_{suffix}'] = float(df_results[column].std())
    summary_stats['total_training_time'] = float(df_results['training_time'].sum())
    summary_stats['mean_training_time'] = float(df_results['training_time'].mean())
    summary_stats['total_cv_time'] = total_cv_time

    print(f"\n\n{rule}")
    print(f"📈 CROSS-VALIDATION SUMMARY ({n_splits} folds)")
    print(f"{rule}")
    print(f"\n🎯 Performance Metrics (Mean ± Std):")
    for label, _, suffix in metric_specs:
        mean_val = summary_stats[f'mean_{suffix}']
        std_val = summary_stats[f'std_{suffix}']
        print(f"  {label + ':':<13}{mean_val:.4f} ± {std_val:.4f}")
    print(f"\n⏱️  Training Time:")
    print(f"  Per fold (avg): {summary_stats['mean_training_time']:.2f}s")
    print(f"  Total CV time:  {total_cv_time:.2f}s ({total_cv_time/60:.2f} min)")
    print(f"{rule}\n")

    results = {
        'fold_results': fold_results,
        'summary_stats': summary_stats,
        'all_predictions': {
            'y_true': np.array(all_y_true),
            'y_pred': np.array(all_y_pred),
            'y_pred_proba': np.array(all_y_pred_proba),
            # Predictions are stored in fold order, not input order; these
            # indices map each entry back to its row in X / y.
            'sample_indices': np.concatenate(all_val_idx)
        }
    }

    return results, models, histories


def _make_fold_callbacks():
    """Return newly constructed training callbacks for a single fold's fit."""
    return [
        # Stop after 30 epochs without val_loss improvement and roll the model
        # back to its best weights.
        keras.callbacks.EarlyStopping(
            monitor='val_loss',
            patience=30,
            restore_best_weights=True,
            verbose=0
        ),
        # Halve the learning rate after 10 stagnant epochs, down to 1e-6.
        keras.callbacks.ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.5,
            patience=10,
            min_lr=1e-6,
            verbose=0
        ),
    ]

# Cell-level confirmation that the definition above was executed.
print("✅ K-Fold cross-validation function defined")

## 2. Results Visualization Functions

In [None]:
def plot_kfold_results(results, save_path=None):
    """
    Draw one bar chart per metric showing its value in every fold.

    Each of the six panels (accuracy, precision, recall, specificity,
    F1-score, ROC-AUC) shows the per-fold bars, a dashed line at the mean
    across folds, and the numeric value above every bar.

    Parameters:
    -----------
    results : dict
        Results dictionary from kfold_cross_validation; only
        ``results['fold_results']`` is read
    save_path : str or None
        If given, the figure is saved to this path (300 dpi) before it is shown
    """
    df = pd.DataFrame(results['fold_results'])
    metrics = ['accuracy', 'precision', 'recall', 'specificity', 'f1_score', 'roc_auc']
    fold_numbers = df['fold'].tolist()

    fig, axes = plt.subplots(2, 3, figsize=(18, 10))

    for ax, metric in zip(axes.flatten(), metrics):
        pretty_name = metric.replace('_', ' ').title()

        bars = ax.bar(df['fold'], df[metric], alpha=0.7, color='steelblue', edgecolor='black')

        mean_val = df[metric].mean()
        ax.axhline(mean_val, color='red', linestyle='--', linewidth=2, label=f'Mean: {mean_val:.4f}')

        ax.set_xlabel('Fold', fontsize=11, fontweight='bold')
        ax.set_ylabel(pretty_name, fontsize=11, fontweight='bold')
        ax.set_title(f'{pretty_name} per Fold', fontsize=12, fontweight='bold')
        # Folds are discrete: tick exactly at the fold numbers instead of
        # letting the automatic locator choose fractional positions.
        ax.set_xticks(fold_numbers)
        ax.set_ylim([0, 1.05])
        ax.legend()
        ax.grid(True, alpha=0.3, axis='y')

        # Annotate every bar with its value, centred just above the bar top.
        for bar in bars:
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width() / 2., height,
                    f'{height:.3f}',
                    ha='center', va='bottom', fontsize=9)

    plt.suptitle('K-Fold Cross-Validation Results', fontsize=16, fontweight='bold', y=1.00)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"✅ Figure saved to {save_path}")

    plt.show()

# Cell-level confirmation that the definition above was executed.
print("✅ Visualization functions defined")

## 3. Comparison Boxplot

In [None]:
def plot_metrics_boxplot(results, save_path=None):
    """
    Create a boxplot showing the distribution of each metric across folds.

    Parameters:
    -----------
    results : dict
        Results dictionary from kfold_cross_validation; only
        ``results['fold_results']`` is read
    save_path : str or None
        If given, the figure is saved to this path (300 dpi) before it is shown
    """
    df = pd.DataFrame(results['fold_results'])
    metrics = ['accuracy', 'precision', 'recall', 'specificity', 'f1_score', 'roc_auc']

    # One array of per-fold values per metric, in the order of `metrics`.
    data_to_plot = [df[metric].values for metric in metrics]
    labels = [m.replace('_', ' ').title() for m in metrics]

    fig, ax = plt.subplots(figsize=(14, 7))

    bp = ax.boxplot(data_to_plot, patch_artist=True,
                    notch=True, showmeans=True,
                    boxprops=dict(facecolor='lightblue', alpha=0.7),
                    medianprops=dict(color='red', linewidth=2),
                    meanprops=dict(marker='D', markerfacecolor='green', markersize=8))

    # Label the boxes through the axis rather than boxplot(labels=...), which
    # newer Matplotlib releases deprecate in favour of `tick_labels`; this
    # form works on old and new versions alike.
    ax.set_xticklabels(labels, rotation=15, ha='right')

    ax.set_ylabel('Score', fontsize=12, fontweight='bold')
    ax.set_title('Distribution of Metrics Across K-Folds', fontsize=14, fontweight='bold')
    ax.set_ylim([0, 1.05])
    ax.grid(True, alpha=0.3, axis='y')

    # Legend built from the first box's median line and mean marker.
    ax.legend([bp['medians'][0], bp['means'][0]],
              ['Median', 'Mean'],
              loc='lower left', fontsize=10)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"✅ Boxplot saved to {save_path}")

    plt.show()

# Cell-level confirmation that the definition above was executed.
print("✅ Boxplot function defined")

## 4. Training History Comparison

In [None]:
def plot_training_histories(histories, save_path=None):
    """
    Plot training and validation curves from all folds side by side.

    The left panel shows accuracy, the right panel loss. For every fold the
    training curve is solid and the validation curve dashed. A curve whose
    key is missing from a history (for example a model compiled without an
    ``accuracy`` metric) is skipped instead of raising ``KeyError``.

    Parameters:
    -----------
    histories : list
        Training histories from kfold_cross_validation; each item is either
        an object with a ``.history`` dict or such a dict itself
    save_path : str or None
        If given, the figure is saved to this path (300 dpi) before it is shown
    """
    fig, axes = plt.subplots(1, 2, figsize=(16, 5))

    # (axis, history key, axis label / title word)
    panels = (
        (axes[0], 'accuracy', 'Accuracy'),
        (axes[1], 'loss', 'Loss'),
    )

    for ax, key, label in panels:
        for i, history in enumerate(histories, 1):
            record = getattr(history, 'history', history)
            if key in record:
                ax.plot(record[key], alpha=0.5, label=f'Fold {i} Train')
            if f'val_{key}' in record:
                ax.plot(record[f'val_{key}'], alpha=0.5, linestyle='--', label=f'Fold {i} Val')

        ax.set_xlabel('Epoch', fontsize=11, fontweight='bold')
        ax.set_ylabel(label, fontsize=11, fontweight='bold')
        ax.set_title(f'Training & Validation {label} Across Folds', fontsize=12, fontweight='bold')
        # Only draw a legend when at least one curve was plotted.
        if ax.get_legend_handles_labels()[0]:
            ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)
        ax.grid(True, alpha=0.3)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"✅ Training history plot saved to {save_path}")

    plt.show()

# Cell-level confirmation that the definition above was executed.
print("✅ Training history plotting function defined")

## 5. Save Results Function

In [None]:
def save_kfold_results(results, save_dir='./kfold_results'):
    """
    Save K-Fold results to files inside ``save_dir``.

    Files written:
      * ``fold_results.csv``   - one row per fold
      * ``summary_stats.json`` - aggregate statistics
      * ``y_true.npy``, ``y_pred.npy``, ``y_pred_proba.npy`` - out-of-fold
        labels, predicted labels and predicted probabilities

    Parameters:
    -----------
    results : dict
        Results from kfold_cross_validation (keys ``'fold_results'``,
        ``'summary_stats'`` and ``'all_predictions'`` are read)
    save_dir : str
        Directory to save results; created if it does not exist
    """
    import os
    os.makedirs(save_dir, exist_ok=True)

    def _to_builtin(value):
        # json only understands built-in types; NumPy scalars such as
        # float32 / int64 are not float/int subclasses and need converting.
        if isinstance(value, np.generic):
            return value.item()
        if isinstance(value, np.ndarray):
            return value.tolist()
        raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")

    # Per-fold metrics as CSV
    df_folds = pd.DataFrame(results['fold_results'])
    df_folds.to_csv(os.path.join(save_dir, "fold_results.csv"), index=False)
    print(f"✅ Fold results saved to {save_dir}/fold_results.csv")

    # Summary statistics as JSON
    with open(os.path.join(save_dir, "summary_stats.json"), 'w', encoding='utf-8') as f:
        json.dump(results['summary_stats'], f, indent=4, default=_to_builtin)
    print(f"✅ Summary statistics saved to {save_dir}/summary_stats.json")

    # Out-of-fold predictions as .npy arrays
    predictions = results['all_predictions']
    for name in ('y_true', 'y_pred', 'y_pred_proba'):
        np.save(os.path.join(save_dir, f"{name}.npy"), predictions[name])
    print(f"✅ Predictions saved to {save_dir}/")

    print(f"\n✅ All results saved to {save_dir}/")

# Cell-level confirmation that the definition above was executed.
print("✅ Save results function defined")

## 6. Example Usage Template

In [None]:
# Example usage (uncomment to use):
# 
# # Define model builder function
# def create_model():
#     from tensorflow import keras
#     model = keras.Sequential([
#         keras.layers.Input(shape=(1024, 1)),
#         keras.layers.Conv1D(32, 7, activation='relu', padding='same'),
#         keras.layers.MaxPooling1D(2),
#         keras.layers.LSTM(64),
#         keras.layers.Dense(32, activation='relu'),
#         keras.layers.Dense(1, activation='sigmoid')
#     ])
#     model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
#     return model
# 
# # Load your data
# # X = ...
# # y = ...
# 
# # Run K-Fold CV
# results, models, histories = kfold_cross_validation(
#     model_builder=create_model,
#     X=X,
#     y=y,
#     n_splits=5,
#     epochs=100,
#     batch_size=64,
#     verbose=1
# )
# 
# # Visualize results
# plot_kfold_results(results, save_path='kfold_metrics.png')
# plot_metrics_boxplot(results, save_path='kfold_boxplot.png')
# plot_training_histories(histories, save_path='kfold_training.png')
# 
# # Save results
# save_kfold_results(results, save_dir='./my_kfold_results')

# Closing banner: confirm the notebook ran and list the helpers it defines.
print("\n" + "="*70)
print("✅ K-Fold Cross-Validation utilities loaded successfully!")
print("="*70)
print("\nAvailable functions:")
print("  - kfold_cross_validation(model_builder, X, y, n_splits, ...)")
print("  - plot_kfold_results(results, save_path)")
print("  - plot_metrics_boxplot(results, save_path)")
print("  - plot_training_histories(histories, save_path)")
print("  - save_kfold_results(results, save_dir)")
print("\n" + "="*70)