# File Location: notebooks/06_advanced_mechanics/14_custom_loops_kfold.ipynb

# Custom Loops and K-Fold Cross Validation

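This notebook explores advanced PyTorch Lightning loop customization through a K-Fold cross validation implementation. We'll build a custom loop, wrap Lightning's existing `FitLoop` so that a single `trainer.fit()` call trains every fold, and compare this with a plain Python loop over folds. The custom-loop API used here is specific to Lightning 1.x (the code assumes 1.6 to 1.9); Lightning 2.0 made the loop classes internal.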

## Learning Objectives
- Understand PyTorch Lightning's loop architecture
- Implement custom training loops and FitLoop wrappers
- Build a comprehensive K-Fold cross validation system
- Handle data splitting and validation across multiple folds
- Aggregate results and perform statistical analysis

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset, Subset
from sklearn.model_selection import KFold, StratifiedKFold
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import torchvision
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from typing import List, Dict, Any, Optional
from collections import defaultdict
import os
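# The public loop-customization API (Loop subclasses, FitLoop wrappers, Loop.replace)
# exists only in Lightning 1.x; the code below assumes 1.6-1.9. The package-level
# import works across those releases, unlike the `pytorch_lightning.loops.base` path.
from pytorch_lightning.loops import FitLoop, Loop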
from pytorch_lightning.trainer.states import TrainerFn

# Set random seeds
torch.manual_seed(42)
np.random.seed(42)

print(f"PyTorch version: {torch.__version__}")
print(f"Lightning version: {pl.__version__}")
```
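Because the loop API changed across releases and is not customizable in 2.x, a cheap guard right after the imports fails fast on an unsupported install. `supports_loop_customization` is a hypothetical helper, and the `>= 1.6, < 2.0` range is an assumption about which releases expose the pieces used in this notebook; re-check it against your installed version.

```python
import re


def supports_loop_customization(version: str) -> bool:
    """Hypothetical guard: True for Lightning versions assumed to expose the 1.x loop API."""
    # Read only the leading "major.minor" part, e.g. "1.9.5" -> (1, 9), "2.0.0rc1" -> (2, 0)
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        return False
    major, minor = int(match.group(1)), int(match.group(2))
    # Assumed supported window: 1.6 <= version < 2.0
    return (1, 6) <= (major, minor) < (2, 0)


for candidate in ["1.5.10", "1.6.0", "1.9.5", "2.0.0rc1"]:
    print(candidate, supports_loop_customization(candidate))
```

In the notebook, call `supports_loop_customization(pl.__version__)` after the imports and stop early if it returns `False`.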

## 1. Understanding Lightning Loop Architecture

```python
class LoopArchitectureExplainer:
    """
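    PyTorch Lightning 1.x Loop Architecture:
    
    1. Loop: abstract base class for all loops
    2. FitLoop: runs training epochs (trainer.fit_loop)
    3. EvaluationLoop: validation/test loops (trainer.validate_loop, trainer.test_loop)
    4. PredictionLoop: prediction loop (trainer.predict_loop)
    5. OptimizerLoop: optimizer steps for each batch (automatic optimization)
    6. Custom Loops: user-defined subclasses of Loop
    
    Loop Hierarchy (exact nesting varies between 1.x releases):
    - Trainer
      ├── fit_loop: FitLoop
      │   └── epoch_loop: TrainingEpochLoop
      │       ├── batch_loop: TrainingBatchLoop
      │       │   └── optimizer_loop: OptimizerLoop
      │       └── val_loop: EvaluationLoop
      ├── validate_loop / test_loop: EvaluationLoop
      └── predict_loop: PredictionLoop
    """
    
    @staticmethod
    def explain_loop_methods():
        methods = {
            "done (property)": "Return True when the loop should stop",
            "reset()": "Reset loop state; called at the start of every run()",
            "on_run_start()": "Called once before the first advance()",
            "on_advance_start()": "Called before each advance()",
            "advance()": "One iteration of the loop's work",
            "on_advance_end()": "Called after each advance()",
            "on_run_end()": "Called once after done becomes True; run() returns its value",
            "run()": "Entry point: reset(), on_run_start(), advance cycles until done, on_run_end()"
        }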
        
        print("Key Loop Methods:")
        for method, description in methods.items():
            print(f"  {method}: {description}")

LoopArchitectureExplainer.explain_loop_methods()
```
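The hook order of `Loop.run()` in Lightning 1.x can be traced without Lightning at all. `MiniLoop` is a hypothetical, simplified stand-in written for illustration (not Lightning's source): it calls `reset()`, then `on_run_start()`, then repeats `on_advance_start()`/`advance()`/`on_advance_end()` until `done`, and finally returns whatever `on_run_end()` returns.

```python
from typing import Any, List


class MiniLoop:
    """Plain-Python sketch of the 1.x Loop.run() control flow (illustration only)."""

    def __init__(self, num_iterations: int) -> None:
        self.num_iterations = num_iterations
        self.calls: List[str] = []  # records the order in which hooks fire
        self.count = 0

    @property
    def done(self) -> bool:
        return self.count >= self.num_iterations

    def reset(self) -> None:
        self.calls.append("reset")
        self.count = 0

    def on_run_start(self) -> None:
        self.calls.append("on_run_start")

    def on_advance_start(self) -> None:
        self.calls.append("on_advance_start")

    def advance(self) -> None:
        self.calls.append("advance")
        self.count += 1

    def on_advance_end(self) -> None:
        self.calls.append("on_advance_end")

    def on_run_end(self) -> Any:
        self.calls.append("on_run_end")
        return self.count

    def run(self) -> Any:
        self.reset()
        self.on_run_start()
        while not self.done:
            try:
                self.on_advance_start()
                self.advance()
                self.on_advance_end()
            except StopIteration:
                # An advance() that raises StopIteration ends the loop early
                break
        return self.on_run_end()


loop = MiniLoop(num_iterations=2)
result = loop.run()
print(result)      # 2
print(loop.calls)
```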

## 2. Custom K-Fold Cross Validation Loop

```python
class KFoldLoop(Loop):
    """Outer loop that runs K-Fold cross validation around a wrapped FitLoop.

    Follows the structure of the K-Fold loop example that shipped with
    PyTorch Lightning 1.6: the trainer's original FitLoop trains one fold,
    while this loop creates the splits, resets the model between folds and
    aggregates the results. Assumes Lightning 1.6-1.9.
    """
    
    def __init__(self, num_folds: int = 5, shuffle: bool = True, stratified: bool = True):
        super().__init__()
        self.num_folds = num_folds
        self.shuffle = shuffle
        self.stratified = stratified
        
        # Inner loop that trains a single fold (attached via connect())
        self.fit_loop: Optional[FitLoop] = None
        
        # Fold bookkeeping and results
        self.fold_results: List[Dict[str, Any]] = []
        self.current_fold = 0
        self.splits = []
        self.initial_state_dict = None
        
        # Data storage
        self.full_dataset = None
        self.targets = None
    
    def connect(self, fit_loop: FitLoop) -> None:
        """Attach the FitLoop that trains one fold"""
        self.fit_loop = fit_loop
        
    @property
    def done(self) -> bool:
        """Check if all folds are completed"""
        return self.current_fold >= self.num_folds
    
    def reset(self) -> None:
        """Reset the loop state (run() calls this before on_run_start())"""
        self.current_fold = 0
        self.fold_results = []
    
    def on_run_start(self, *args, **kwargs) -> None:
        """Create the fold splits and snapshot the initial weights"""
        print(f"Starting {self.num_folds}-fold cross validation")
        
        # Get the full dataset from the trainer's datamodule (populated by setup('fit'))
        datamodule = self.trainer.datamodule
        if getattr(datamodule, 'full_dataset', None) is None or not hasattr(datamodule, 'targets'):
            raise ValueError("DataModule must set 'full_dataset' and 'targets' in setup('fit')")
        self.full_dataset = datamodule.full_dataset
        self.targets = datamodule.targets
        
        # scikit-learn rejects a random_state when shuffle=False
        random_state = 42 if self.shuffle else None
        indices = np.arange(len(self.full_dataset))
        if self.stratified and self.targets is not None:
            splitter = StratifiedKFold(n_splits=self.num_folds, shuffle=self.shuffle, random_state=random_state)
            self.splits = list(splitter.split(indices, self.targets))
        else:
            splitter = KFold(n_splits=self.num_folds, shuffle=self.shuffle, random_state=random_state)
            self.splits = list(splitter.split(indices))
        
        # Every fold starts from these weights; clone() gives independent CPU copies
        self.initial_state_dict = {
            k: v.detach().cpu().clone()
            for k, v in self.trainer.lightning_module.state_dict().items()
        }
    
    def on_advance_start(self, *args, **kwargs) -> None:
        """Give the datamodule this fold's train/val subsets"""
        print(f"\n--- Fold {self.current_fold + 1}/{self.num_folds} ---")
        train_indices, val_indices = self.splits[self.current_fold]
        self.trainer.datamodule.setup_fold(
            Subset(self.full_dataset, train_indices),
            Subset(self.full_dataset, val_indices),
        )
    
    def advance(self, *args, **kwargs) -> None:
        """Train on the current fold with the wrapped FitLoop"""
        # Rebuild the dataloaders from the new subsets and put the trainer in fitting mode
        self.trainer.reset_train_dataloader()
        self.trainer.reset_val_dataloader()
        self.trainer.state.fn = TrainerFn.FITTING
        self.trainer.training = True
        
        # Run the inner loop (calling trainer.fit_loop.run() here would call this loop again)
        self.fit_loop.run()
        
        # callback_metrics holds the last epoch-level values (val_loss, val_acc, train_loss, ...);
        # reading them avoids a nested trainer.validate() call inside fit()
        val_results = {k: float(v) for k, v in self.trainer.callback_metrics.items()}
        
        train_indices, val_indices = self.splits[self.current_fold]
        self.fold_results.append({
            'fold': self.current_fold + 1,
            'train_indices': train_indices.tolist(),
            'val_indices': val_indices.tolist(),
            'val_results': val_results,
            # Independent CPU copy; dict.copy() would still share the live parameter tensors
            'model_state': {k: v.detach().cpu().clone()
                            for k, v in self.trainer.lightning_module.state_dict().items()},
        })
        
        print(f"Fold {self.current_fold + 1} completed:")
        for metric, value in val_results.items():
            print(f"  {metric}: {value:.4f}")
        
        self.current_fold += 1
    
    def on_advance_end(self) -> None:
        """Reset weights, optimizers/schedulers and the inner FitLoop before the next fold"""
        # trainer.lightning_module is read-only, so instead of building a new module
        # restore the saved initial weights in place
        self.trainer.lightning_module.load_state_dict(self.initial_state_dict)
        self.trainer.strategy.setup_optimizers(self.trainer)
        # A finished FitLoop reports done=True; swap in a fresh one (Loop.replace, PL >= 1.6)
        self.replace(fit_loop=FitLoop)
    
    def on_run_end(self) -> None:
        """Aggregate and report final results"""
        print(f"\n=== K-Fold Cross Validation Results ===")
        
        # Aggregate metrics
        metric_aggregates = self._aggregate_metrics()
        
        # Print summary
        for metric, stats in metric_aggregates.items():
            print(f"{metric}:")
            print(f"  Mean: {stats['mean']:.4f} ± {stats['std']:.4f}")
            print(f"  Min: {stats['min']:.4f}, Max: {stats['max']:.4f}")
        
        # Store aggregated results. The module itself is back at its initial
        # weights; the per-fold weights live in fold_results['model_state'].
        self.trainer.lightning_module.kfold_results = {
            'fold_results': self.fold_results,
            'aggregated_metrics': metric_aggregates
        }
    
    def _aggregate_metrics(self) -> Dict[str, Dict[str, float]]:
        """Aggregate metrics across all folds"""
        metrics = defaultdict(list)
        
        # Collect metrics from all folds
        for fold_result in self.fold_results:
            for metric, value in fold_result['val_results'].items():
                if isinstance(value, (int, float)):
                    metrics[metric].append(value)
        
        # Calculate statistics
        aggregated = {}
        for metric, values in metrics.items():
            aggregated[metric] = {
                'mean': np.mean(values),
                'std': np.std(values),
                'min': np.min(values),
                'max': np.max(values),
                'values': values
            }
        
        return aggregated
    
    def __getattr__(self, key: str) -> Any:
        # The trainer reads attributes such as `epoch_loop` and `max_epochs` from
        # `trainer.fit_loop`; forward anything this loop lacks to the wrapped FitLoop
        fit_loop = self.__dict__.get('fit_loop')
        if fit_loop is None:
            raise AttributeError(key)
        return getattr(fit_loop, key)

print("K-Fold Loop implementation complete!")
```

## 3. Custom FitLoop Wrapper

`KFoldLoop` does not subclass `FitLoop`. It wraps the trainer's existing `FitLoop`, which still trains one fold at a time, and takes its place as `trainer.fit_loop`. (A `FitLoop` subclass installed as `trainer.fit_loop` that calls `trainer.fit_loop.run()` would simply call itself.)

```python
def attach_kfold_loop(trainer: pl.Trainer, num_folds: int = 5, stratified: bool = True) -> KFoldLoop:
    """Wrap the trainer's default FitLoop in a KFoldLoop and install it as trainer.fit_loop"""
    internal_fit_loop = trainer.fit_loop  # keeps its epoch loop, max_epochs, etc.
    kfold_loop = KFoldLoop(num_folds=num_folds, stratified=stratified)
    kfold_loop.connect(internal_fit_loop)
    # The fit_loop setter also assigns the trainer to the loop and its sub-loops
    trainer.fit_loop = kfold_loop
    return kfold_loop

print("Custom FitLoop wrapper created!")
```
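The splits that `KFoldLoop.on_run_start` builds can be checked in isolation. This sketch uses hypothetical, deliberately imbalanced labels (80 samples of class 0, 20 of class 1) and the same `StratifiedKFold` settings: each validation fold keeps the 80/20 ratio, train and validation indices never overlap, and every sample is validated exactly once.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Hypothetical imbalanced labels: 80 of class 0, 20 of class 1
labels = np.array([0] * 80 + [1] * 20)
indices = np.arange(len(labels))

splitter = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
fold_splits = list(splitter.split(indices, labels))

for fold, (train_idx, val_idx) in enumerate(fold_splits, start=1):
    # Train and validation indices are disjoint within a fold
    assert np.intersect1d(train_idx, val_idx).size == 0
    # Class counts in this validation fold: 16 of class 0 and 4 of class 1
    print(fold, np.bincount(labels[val_idx]))

# Concatenating all validation folds recovers every index exactly once
all_val = np.concatenate([val_idx for _, val_idx in fold_splits])
```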

## 4. Lightning Module for K-Fold

```python
import torchmetrics


class KFoldClassifier(pl.LightningModule):
    """Classifier used throughout the K-Fold experiments"""
    
    def __init__(
        self, 
        num_classes: int = 10, 
        learning_rate: float = 0.001,
        architecture: str = 'simple',
        dropout: float = 0.5
    ):
        super().__init__()
        self.save_hyperparameters()
        
        # Build model based on architecture
        if architecture == 'simple':
            self.model = self._build_simple_model()
        elif architecture == 'resnet':
            self.model = self._build_resnet_model()
        else:
            raise ValueError(f"Unknown architecture: {architecture}")
        
        # Loss function
        self.criterion = nn.CrossEntropyLoss()
        
        # Metrics for each split (torchmetrics replaces Lightning's removed `pl.metrics`
        # module; the `task=` argument is the torchmetrics 0.11+ API)
        self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        
        # Storage for K-fold results
        self.kfold_results = None
        
        # Track fold-specific metrics
        self.fold_metrics = {
            'train_loss': [],
            'val_loss': [],
            'train_acc': [],
            'val_acc': []
        }
    
    def _build_simple_model(self):
        """Build a simple MLP model"""
        return nn.Sequential(
            nn.Flatten(),
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Dropout(self.hparams.dropout),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(self.hparams.dropout),
            nn.Linear(128, self.hparams.num_classes)
        )
    
    def _build_resnet_model(self):
        """Build a ResNet-like model for MNIST"""
        class ResNetBlock(nn.Module):
            def __init__(self, in_channels, out_channels, stride=1):
                super().__init__()
                self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1)
                self.bn1 = nn.BatchNorm2d(out_channels)
                self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
                self.bn2 = nn.BatchNorm2d(out_channels)
                
                # Shortcut connection
                self.shortcut = nn.Sequential()
                if stride != 1 or in_channels != out_channels:
                    self.shortcut = nn.Sequential(
                        nn.Conv2d(in_channels, out_channels, 1, stride=stride),
                        nn.BatchNorm2d(out_channels)
                    )
            
            def forward(self, x):
                residual = x
                out = F.relu(self.bn1(self.conv1(x)))
                out = self.bn2(self.conv2(out))
                out += self.shortcut(residual)
                out = F.relu(out)
                return out
        
        return nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            ResNetBlock(32, 32),
            ResNetBlock(32, 64, stride=2),
            ResNetBlock(64, 64),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(64, self.hparams.num_classes)
        )
    
    def forward(self, x):
        return self.model(x)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        
        # Calculate accuracy
        self.train_acc(logits, y)
        
        # Log metrics
        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('train_acc', self.train_acc, on_step=False, on_epoch=True, prog_bar=True)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        
        # Calculate accuracy
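        # Calculate accuracy: calling the metric returns this batch's accuracy
        # and also accumulates state for the epoch-level value logged below
        batch_acc = self.val_acc(logits, y)
        
        # Log metrics
        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('val_acc', self.val_acc, on_step=False, on_epoch=True, prog_bar=True)
        
        return {'val_loss': loss.detach(), 'val_acc': batch_acc}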
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5)
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': 'val_loss'
            }
        }
    
    def get_fold_summary(self):
        """Get summary of K-fold results"""
        if self.kfold_results is None:
            return "No K-fold results available"
        
        summary = []
        summary.append("K-Fold Cross Validation Summary:")
        summary.append("=" * 40)
        
        for metric, stats in self.kfold_results['aggregated_metrics'].items():
            summary.append(f"{metric}: {stats['mean']:.4f} ± {stats['std']:.4f}")
        
        return "\n".join(summary)

# Initialize model
model = KFoldClassifier(num_classes=10, learning_rate=0.001, architecture='simple')
```
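Resetting a model between folds only works if the saved weights are a real copy. `state_dict()` returns tensors that share storage with the live parameters, so `dict.copy()` duplicates the dictionary but not the tensors. This small sketch uses a stand-alone `nn.Linear` (not the classifier above) to show the difference and the restore step:

```python
import torch
import torch.nn as nn

net = nn.Linear(4, 2)

shallow = net.state_dict().copy()  # new dict, but the values are the live tensors
snapshot = {k: v.detach().clone() for k, v in net.state_dict().items()}  # independent copies

# Simulate a training step: optimizers update parameters in place
with torch.no_grad():
    net.weight.add_(1.0)

shallow_tracked_update = torch.equal(shallow["weight"], net.weight)    # True: shared storage
snapshot_tracked_update = torch.equal(snapshot["weight"], net.weight)  # False: old values kept

# Loading the snapshot puts the module back at its starting weights
net.load_state_dict(snapshot)
restored = torch.equal(snapshot["weight"], net.weight)
print(shallow_tracked_update, snapshot_tracked_update, restored)  # True False True
```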

## 5. Data Module for K-Fold Cross Validation

```python
class KFoldDataModule(pl.LightningDataModule):
    """Data module specifically designed for K-Fold cross validation"""
    
    def __init__(self, batch_size: int = 64, num_workers: int = 4):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        
        # Transform
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
        
        # Will be set during K-fold
        self.fold_train_dataset = None
        self.fold_val_dataset = None
        
        # Full dataset storage
        self.full_dataset = None
        self.targets = None
    
    def prepare_data(self):
        # Download MNIST
        torchvision.datasets.MNIST('./data', train=True, download=True)
        torchvision.datasets.MNIST('./data', train=False, download=True)
    
    def setup(self, stage: Optional[str] = None):
        # Build the combined dataset only once: Lightning calls setup() on every
        # fit()/validate(), and the manual K-fold function also calls it directly
        if self.full_dataset is not None:
            return
        if stage == 'fit' or stage is None:
            # Combine MNIST's train and test splits into a single pool for K-fold.
            # This leaves no untouched test set for a final, unbiased estimate.
            train_dataset = torchvision.datasets.MNIST('./data', train=True, transform=self.transform)
            test_dataset = torchvision.datasets.MNIST('./data', train=False, transform=self.transform)
            
            # Concatenate the raw image and label tensors
            full_data = torch.cat([train_dataset.data, test_dataset.data])
            full_targets = torch.cat([train_dataset.targets, test_dataset.targets])
            
            # Reuse the train-split object; MNIST.__getitem__ reads .data and .targets
            train_dataset.data = full_data
            train_dataset.targets = full_targets
            self.full_dataset = train_dataset
            
            # Store targets as a NumPy array for stratified splitting
            self.targets = full_targets.numpy()
    
    def setup_fold(self, train_dataset, val_dataset):
        """Setup datasets for a specific fold"""
        self.fold_train_dataset = train_dataset
        self.fold_val_dataset = val_dataset
    
    def train_dataloader(self):
        return DataLoader(
            self.fold_train_dataset, 
            batch_size=self.batch_size, 
            shuffle=True, 
            num_workers=self.num_workers,
            pin_memory=True
        )
    
    def val_dataloader(self):
        return DataLoader(
            self.fold_val_dataset, 
            batch_size=self.batch_size, 
            shuffle=False, 
            num_workers=self.num_workers,
            pin_memory=True
        )

# Initialize data module
data_module = KFoldDataModule(batch_size=64, num_workers=4)
```
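Overwriting `.data` and `.targets` on an MNIST instance works because torchvision's `MNIST.__getitem__` reads those two attributes, but it leans on an implementation detail. A more general alternative, sketched here under the assumption of ordinary map-style datasets, is `torch.utils.data.ConcatDataset`: it chains datasets, and `Subset` can then index across the boundary. Tiny `TensorDataset`s stand in for the MNIST splits; with real MNIST the stratification labels would come from each split's `.targets`.

```python
import torch
from torch.utils.data import ConcatDataset, Subset, TensorDataset

# Hypothetical stand-ins for the train (6 samples) and test (4 samples) splits
train_part = TensorDataset(torch.zeros(6, 1, 28, 28), torch.arange(6))
test_part = TensorDataset(torch.ones(4, 1, 28, 28), torch.arange(6, 10))

combined = ConcatDataset([train_part, test_part])
# Labels for stratified splitting, gathered without loading any images
combined_targets = torch.cat([train_part.tensors[1], test_part.tensors[1]]).numpy()

# A fold subset whose indices span both underlying datasets
fold_val = Subset(combined, [5, 6])
x5, y5 = fold_val[0]  # last sample of train_part
x6, y6 = fold_val[1]  # first sample of test_part
print(len(combined), int(y5), int(y6))  # 10 5 6
```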

## 6. Custom Trainer with K-Fold Integration

```python
class KFoldTrainer(pl.Trainer):
    """Trainer whose fit loop runs K-Fold cross validation.

    `fit()` is not overridden: Lightning performs its normal setup (device placement,
    `datamodule.setup('fit')`, optimizer configuration) and then calls
    `self.fit_loop.run()`, which is now the KFoldLoop.
    """
    
    def __init__(self, num_folds: int = 5, stratified: bool = True, *args, **kwargs):
        # Fold datasets do not exist until KFoldLoop calls `setup_fold`, so skip the
        # sanity-check validation Lightning runs before the fit loop starts
        kwargs.setdefault('num_sanity_val_steps', 0)
        super().__init__(*args, **kwargs)
        
        # Wrap the default FitLoop in a KFoldLoop and install it as self.fit_loop
        attach_kfold_loop(self, num_folds=num_folds, stratified=stratified)

# Alternative: Standard trainer with manual K-Fold
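def run_kfold_validation(model_class, data_module, num_folds=5, max_epochs=10):
    """Manual K-Fold validation: a fresh model and a fresh Trainer for every fold.

    `model_class` is any zero-argument callable that returns a new LightningModule
    (a class, or a lambda such as `lambda: KFoldClassifier(num_classes=10)`).
    """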
    
    # Initialize results storage
    fold_results = []
    
    # Setup full dataset
    data_module.setup('fit')
    full_dataset = data_module.full_dataset
    targets = data_module.targets
    
    # Initialize K-Fold splitter
    kfold = StratifiedKFold(n_splits=num_folds, shuffle=True, random_state=42)
    
    for fold, (train_indices, val_indices) in enumerate(kfold.split(range(len(full_dataset)), targets)):
        print(f"\n--- Fold {fold + 1}/{num_folds} ---")
        
        # Create fold datasets
        train_dataset = Subset(full_dataset, train_indices)
        val_dataset = Subset(full_dataset, val_indices)
        
        # Setup data module for this fold
        data_module.setup_fold(train_dataset, val_dataset)
        
        # Initialize fresh model for this fold
        model = model_class()
        
        # Create trainer for this fold
        trainer = pl.Trainer(
            max_epochs=max_epochs,
            accelerator='auto',
            devices=1,
            logger=False,
            enable_checkpointing=False,
            enable_progress_bar=False
        )
        
        # Train model
        trainer.fit(model, data_module)
        
        # Validate
        val_results = trainer.validate(model, data_module, verbose=False)[0]
        
        # Store results
        fold_results.append({
            'fold': fold + 1,
            'val_results': val_results,
            'model': model
        })
        
        print(f"Fold {fold + 1} - Val Acc: {val_results['val_acc']:.4f}, Val Loss: {val_results['val_loss']:.4f}")
    
    return fold_results

print("K-Fold trainer and manual validation function ready!")
```
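`run_kfold_validation` follows the standard recipe: the same splits for every candidate, a fresh model per fold, and one score per held-out fold. The recipe can be sanity-checked in scikit-learn, where a hand-written fold loop should reproduce `cross_val_score` exactly. This sketch uses a synthetic dataset and logistic regression as stand-ins for MNIST and the Lightning model.

```python
import numpy as np
from sklearn.base import clone
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_score

X, y = make_classification(n_samples=300, n_features=10, random_state=0)
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
template = LogisticRegression(max_iter=1000)

# Manual loop: a fresh (cloned) estimator per fold, scored on that fold's held-out part
manual_scores = []
for train_idx, val_idx in cv.split(X, y):
    fold_model = clone(template).fit(X[train_idx], y[train_idx])
    manual_scores.append(fold_model.score(X[val_idx], y[val_idx]))

# The library helper uses the same splits and the estimator's default score (accuracy)
library_scores = cross_val_score(template, X, y, cv=cv)
print(np.allclose(manual_scores, library_scores))  # True
```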

## 7. Running K-Fold Cross Validation

```python
# Method 1: Using Custom Trainer
print("Method 1: Custom K-Fold Trainer")
kfold_trainer = KFoldTrainer(
    num_folds=5,
    stratified=True,
    max_epochs=10,
    accelerator='auto',
    devices=1
)

# Train with K-Fold
kfold_trainer.fit(model, data_module)

# Get results
kfold_results = model.kfold_results
print(model.get_fold_summary())

# Method 2: Manual K-Fold validation
print("\nMethod 2: Manual K-Fold Validation")
manual_results = run_kfold_validation(
    lambda: KFoldClassifier(num_classes=10, learning_rate=0.001),
    data_module,
    num_folds=5,
    max_epochs=10
)

# Aggregate manual results
manual_metrics = defaultdict(list)
for result in manual_results:
    for metric, value in result['val_results'].items():
        if isinstance(value, (int, float)):
            manual_metrics[metric].append(value)

print("\nManual K-Fold Results:")
for metric, values in manual_metrics.items():
    mean_val = np.mean(values)
    std_val = np.std(values)
    print(f"{metric}: {mean_val:.4f} ± {std_val:.4f}")
```
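The summaries above use `np.std`, which defaults to `ddof=0` and divides by k. With only a handful of fold scores, the sample standard deviation (`ddof=1`, dividing by k - 1) is the more common choice, and the two differ noticeably for small k. The accuracies below are hypothetical:

```python
import numpy as np

fold_acc = np.array([0.97, 0.98, 0.99])  # hypothetical per-fold accuracies

population_std = np.std(fold_acc)       # ddof=0: divides by k
sample_std = np.std(fold_acc, ddof=1)   # ddof=1: divides by k - 1
print(round(population_std, 4), round(sample_std, 4))  # 0.0082 0.01
```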

## 8. Advanced K-Fold Analysis

```python
class KFoldAnalyzer:
    """Advanced analysis tools for K-Fold results"""
    
    def __init__(self, fold_results):
        self.fold_results = fold_results
    
    def plot_fold_performance(self):
        """Plot performance across folds"""
        if isinstance(self.fold_results, dict) and 'fold_results' in self.fold_results:
            results = self.fold_results['fold_results']
        else:
            results = self.fold_results
        
        metrics = ['val_acc', 'val_loss']
        fig, axes = plt.subplots(1, 2, figsize=(12, 5))
        
        for idx, metric in enumerate(metrics):
            values = []
            fold_nums = []
            
            for result in results:
                if 'val_results' in result and metric in result['val_results']:
                    values.append(result['val_results'][metric])
                    fold_nums.append(result.get('fold', len(fold_nums) + 1))
                elif metric in result:
                    values.append(result[metric])
                    fold_nums.append(result.get('fold', len(fold_nums) + 1))
            
            if values:
                axes[idx].bar(fold_nums, values, alpha=0.7)
                axes[idx].axhline(y=np.mean(values), color='r', linestyle='--', label=f'Mean: {np.mean(values):.4f}')
                axes[idx].set_xlabel('Fold')
                axes[idx].set_ylabel(metric.replace('_', ' ').title())
                axes[idx].set_title(f'{metric.replace("_", " ").title()} Across Folds')
                axes[idx].legend()
                axes[idx].grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.show()
    
    def statistical_analysis(self):
        """Perform statistical analysis of fold results"""
        if isinstance(self.fold_results, dict) and 'fold_results' in self.fold_results:
            results = self.fold_results['fold_results']
        else:
            results = self.fold_results
        
        # Extract metrics
        metrics_data = defaultdict(list)
        for result in results:
            val_results = result.get('val_results', result)
            for metric, value in val_results.items():
                if isinstance(value, (int, float)):
                    metrics_data[metric].append(value)
        
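        # Descriptive statistics per metric (SciPy ships as a scikit-learn dependency)
        from scipy import stats
        
        print("Statistical Analysis:")
        print("=" * 50)
        
        for metric, values in metrics_data.items():
            values = np.array(values, dtype=float)
            k = len(values)
            mean = values.mean()
            # Sample standard deviation across folds (ddof=1)
            std = values.std(ddof=1) if k > 1 else 0.0
            
            print(f"\n{metric.upper()}:")
            print(f"  Mean: {mean:.4f}")
            print(f"  Std:  {std:.4f}")
            print(f"  Min:  {values.min():.4f}")
            print(f"  Max:  {values.max():.4f}")
            print(f"  CV:   {std / mean * 100:.2f}%")
            
            # 95% t-interval for the mean fold score. Percentiles of only k scores would
            # just approximate the min/max range. Folds share training data and are not
            # independent, so treat this interval as optimistic.
            if k > 1:
                half_width = stats.t.ppf(0.975, df=k - 1) * std / np.sqrt(k)
                print(f"  95% CI (mean): [{mean - half_width:.4f}, {mean + half_width:.4f}]")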
    
    def compare_architectures(self, other_results, labels=None):
        """Compare results from different architectures"""
        if labels is None:
            labels = ['Architecture 1', 'Architecture 2']
        
        results_list = [self.fold_results, other_results]
        
        fig, ax = plt.subplots(figsize=(10, 6))
        
        metric = 'val_acc'  # Focus on accuracy
        
        for idx, (results, label) in enumerate(zip(results_list, labels)):
            if isinstance(results, dict) and 'fold_results' in results:
                fold_results = results['fold_results']
            else:
                fold_results = results
            
            values = []
            for result in fold_results:
                val_results = result.get('val_results', result)
                if metric in val_results:
                    values.append(val_results[metric])
            
            if values:
                positions = np.arange(len(values)) + idx * 0.4
                ax.bar(positions, values, width=0.35, label=label, alpha=0.7)
        
        ax.set_xlabel('Fold')
        ax.set_ylabel('Validation Accuracy')
        ax.set_title('Architecture Comparison Across Folds')
        ax.legend()
        ax.grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.show()

# Analyze results
analyzer = KFoldAnalyzer(manual_results)
analyzer.plot_fold_performance()
analyzer.statistical_analysis()
```

## 9. Model Selection with K-Fold

```python
class ModelSelector:
    """Model selection using K-Fold cross validation"""
    
    def __init__(self, data_module):
        self.data_module = data_module
        self.results = {}
    
    def evaluate_architecture(self, name, model_class, model_params, num_folds=5, max_epochs=10):
        """Evaluate a specific architecture"""
        print(f"\nEvaluating {name}...")
        
        fold_results = run_kfold_validation(
            lambda: model_class(**model_params),
            self.data_module,
            num_folds=num_folds,
            max_epochs=max_epochs
        )
        
        # Aggregate results
        metrics = defaultdict(list)
        for result in fold_results:
            for metric, value in result['val_results'].items():
                if isinstance(value, (int, float)):
                    metrics[metric].append(value)
        
        # Calculate statistics
        aggregated = {}
        for metric, values in metrics.items():
            aggregated[metric] = {
                'mean': np.mean(values),
                'std': np.std(values),
                'values': values
            }
        
        self.results[name] = {
            'fold_results': fold_results,
            'aggregated': aggregated
        }
        
        return aggregated
    
    def compare_models(self):
        """Compare all evaluated models"""
        if not self.results:
            print("No models evaluated yet!")
            return
        
        print("\nModel Comparison:")
        print("=" * 60)
        
        # Create comparison table
        comparison_data = []
        for name, results in self.results.items():
            agg = results['aggregated']
            comparison_data.append({
                'Model': name,
                'Val Acc Mean': f"{agg['val_acc']['mean']:.4f}",
                'Val Acc Std': f"{agg['val_acc']['std']:.4f}",
                'Val Loss Mean': f"{agg['val_loss']['mean']:.4f}",
                'Val Loss Std': f"{agg['val_loss']['std']:.4f}"
            })
        
        df = pd.DataFrame(comparison_data)
        print(df.to_string(index=False))
        
        # Find best model
        best_acc = 0
        best_model = None
        for name, results in self.results.items():
            acc = results['aggregated']['val_acc']['mean']
            if acc > best_acc:
                best_acc = acc
                best_model = name
        
        print(f"\nBest Model: {best_model} (Val Acc: {best_acc:.4f})")
    
    def plot_comparison(self):
        """Plot model comparison"""
        if len(self.results) < 2:
            print("Need at least 2 models for comparison")
            return
        
        fig, axes = plt.subplots(1, 2, figsize=(15, 6))
        
        models = list(self.results.keys())
        metrics = ['val_acc', 'val_loss']
        colors = plt.cm.Set3(np.linspace(0, 1, len(models)))
        
        for metric_idx, metric in enumerate(metrics):
            for model_idx, model_name in enumerate(models):
                values = self.results[model_name]['aggregated'][metric]['values']
                x_pos = np.arange(len(values)) + model_idx * 0.8 / len(models)
                
                axes[metric_idx].bar(
                    x_pos, values, 
                    width=0.8/len(models), 
                    label=model_name,
                    color=colors[model_idx],
                    alpha=0.7
                )
            
            axes[metric_idx].set_xlabel('Fold')
            axes[metric_idx].set_ylabel(metric.replace('_', ' ').title())
            axes[metric_idx].set_title(f'{metric.replace("_", " ").title()} Comparison')
            axes[metric_idx].legend()
            axes[metric_idx].grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.show()

# Model selection example
selector = ModelSelector(data_module)

# Evaluate different architectures
selector.evaluate_architecture(
    'Simple MLP',
    KFoldClassifier,
    {'num_classes': 10, 'learning_rate': 0.001, 'architecture': 'simple', 'dropout': 0.3},
    num_folds=3, max_epochs=5
)

selector.evaluate_architecture(
    'Simple MLP (High Dropout)',
    KFoldClassifier,
    {'num_classes': 10, 'learning_rate': 0.001, 'architecture': 'simple', 'dropout': 0.7},
    num_folds=3, max_epochs=5
)

selector.evaluate_architecture(
    'ResNet',
    KFoldClassifier,
    {'num_classes': 10, 'learning_rate': 0.001, 'architecture': 'resnet', 'dropout': 0.5},
    num_folds=3, max_epochs=5
)

# Compare models
selector.compare_models()
selector.plot_comparison()
```
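`compare_models` ranks candidates by mean accuracy alone. Because `run_kfold_validation` seeds `StratifiedKFold` with `random_state=42`, every candidate is evaluated on the same folds, so per-fold scores can be compared pairwise. The sketch below uses hypothetical scores and a hypothetical helper, `corrected_paired_ttest`, which follows Nadeau and Bengio's corrected resampled t-test: it inflates the variance by `n_test / n_train` to account for overlapping training sets. Applying it to k-fold CV with `n_test / n_train = 1 / (k - 1)` is a common heuristic rather than an exact test.

```python
import numpy as np
from scipy import stats

# Hypothetical per-fold accuracies for two models evaluated on the SAME folds
model_a = np.array([0.970, 0.972, 0.968])
model_b = np.array([0.965, 0.966, 0.964])

# Standard paired t-test on the per-fold differences
t_paired, p_paired = stats.ttest_rel(model_a, model_b)


def corrected_paired_ttest(a, b, n_test_over_n_train):
    """Hypothetical helper: paired t-test with the Nadeau-Bengio variance correction."""
    diff = np.asarray(a) - np.asarray(b)
    k = len(diff)
    var = np.var(diff, ddof=1)
    t = diff.mean() / np.sqrt((1.0 / k + n_test_over_n_train) * var)
    p = 2 * stats.t.sf(np.abs(t), df=k - 1)  # two-sided p-value
    return t, p


k = len(model_a)
t_corr, p_corr = corrected_paired_ttest(model_a, model_b, 1.0 / (k - 1))
print(f"paired t={t_paired:.2f} (p={p_paired:.4f}); corrected t={t_corr:.2f} (p={p_corr:.4f})")
```

With `n_test_over_n_train=0` the helper reduces to the ordinary paired t-test; the correction always shrinks `t` and raises the p-value.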

## 10. Advanced Loop Customization

```python
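class CustomValidationLoop(Loop):
    """Standalone loop that evaluates a model over one dataloader.

    It only relies on the generic Loop.run() machinery (reset, on_run_start,
    advance until done, on_run_end), so it can run outside the Trainer's own loops.
    """
    
    def __init__(self, model: pl.LightningModule, dataloader: DataLoader):
        super().__init__()
        self.model = model
        self.dataloader = dataloader
        self.current_batch = 0
        self.total_loss = 0.0
        self.total_correct = 0
        self.total_samples = 0
        self._batch_iter = None
        
    @property
    def done(self) -> bool:
        return self.current_batch >= len(self.dataloader)
    
    def reset(self) -> None:
        self.current_batch = 0
        self.total_loss = 0.0
        self.total_correct = 0
        self.total_samples = 0
    
    def on_run_start(self, *args, **kwargs) -> None:
        # Create the iterator once; calling iter() inside advance() would restart
        # the dataloader and return its first batch every time
        self._batch_iter = iter(self.dataloader)
    
    def advance(self, *args, **kwargs) -> None:
        x, y = next(self._batch_iter)
        x, y = x.to(self.model.device), y.to(self.model.device)
        
        # Compute metrics directly: validation_step calls self.log(), which is only
        # allowed inside Lightning's own hooks
        with torch.no_grad():
            logits = self.model(x)
            self.total_loss += F.cross_entropy(logits, y, reduction='sum').item()
            self.total_correct += (logits.argmax(dim=1) == y).sum().item()
        self.total_samples += y.numel()
        self.current_batch += 1
    
    def on_run_end(self) -> Dict[str, float]:
        # Sample-weighted averages, so a smaller final batch is not over-counted
        results = {
            'custom_val_loss': self.total_loss / self.total_samples,
            'custom_val_acc': self.total_correct / self.total_samples,
        }
        print(f"Custom validation completed: Loss={results['custom_val_loss']:.4f}, "
              f"Acc={results['custom_val_acc']:.4f}")
        return results

# Usage example
class CustomLoopModel(KFoldClassifier):
    """Model that re-evaluates its validation data with CustomValidationLoop (Lightning 1.x hook)"""
    
    def validation_epoch_end(self, outputs):
        # Second pass over the first validation dataloader (doubles validation cost);
        # logging is allowed here because this is a regular Lightning hook
        if self.trainer.val_dataloaders:
            loop = CustomValidationLoop(self, self.trainer.val_dataloaders[0])
            self.log_dict(loop.run())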

print("Advanced loop customization examples complete!")
```
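Averaging per-batch means weights every batch equally, so a small remainder batch counts as much as a full one. The numbers below are hypothetical and only illustrate the gap between a mean of batch means and a sample-weighted mean:

```python
import numpy as np

batch_sizes = np.array([64, 64, 2])    # hypothetical: the last batch is a remainder
batch_acc = np.array([1.0, 1.0, 0.0])  # per-batch accuracy

mean_of_batch_means = batch_acc.mean()                                 # 2/3
sample_weighted = np.sum(batch_acc * batch_sizes) / batch_sizes.sum()  # 128/130
print(round(mean_of_batch_means, 4), round(sample_weighted, 4))        # 0.6667 0.9846
```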

# Summary

This notebook demonstrated advanced PyTorch Lightning loop customization through comprehensive K-Fold cross validation implementation. Key concepts covered:

## Core Loop Architecture
- **Lightning Loop Hierarchy**: Understanding base loops, fit loops, and evaluation loops
- **Custom Loop Creation**: Building loops from scratch with proper state management
- **FitLoop Wrapping**: Extending existing loops with additional functionality
- **Loop Integration**: Installing a custom loop as `trainer.fit_loop` so that a regular `trainer.fit()` call runs it

## K-Fold Cross Validation Implementation
- **Data Splitting**: Proper stratified and non-stratified data splitting
- **Fold Management**: Handling multiple training runs with fresh model initialization
- **Result Aggregation**: Statistical analysis and confidence interval calculation
- **Model Selection**: Systematic comparison of different architectures

## Advanced Features Implemented
- **Custom Trainers**: Extended trainer classes with K-Fold integration
- **Statistical Analysis**: Comprehensive performance analysis with visualization
- **Architecture Comparison**: Side-by-side model performance evaluation
- **Validation Loops**: Custom validation logic with detailed metrics

## Best Practices Established
- Proper dataset handling for cross-validation scenarios
- Model state management across folds
- Fold-to-fold spread (standard deviation, coefficient of variation) and confidence intervals for mean scores
- Index-based `Subset` folds, so no fold copies the underlying data

## Key Benefits
- **Robust Evaluation**: More reliable model performance estimation
- **Model Selection**: Data-driven architecture choice
- **Variance Analysis**: Understanding model stability across data splits
- **Research Flexibility**: Framework for custom training experiments

## Next Steps
- Implement nested cross-validation for hyperparameter optimization
- Add support for time series cross-validation
- Integrate with hyperparameter optimization libraries
- Extend to multi-dataset evaluation scenarios

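The custom loop architecture provides a foundation for training strategies beyond a single `fit()` run while keeping Lightning's standard `Trainer` workflow. It depends on the Lightning 1.x loop API, which Lightning 2.0 no longer exposes for customization.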