# Model Training and Evaluation Framework

This notebook provides a comprehensive framework for training and evaluating all pneumonia detection models developed in this project. It includes:

1. **Unified Training Pipeline**: Standardized training procedures for all model types
2. **Evaluation Metrics**: Comprehensive performance assessment
3. **Model Comparison**: Side-by-side analysis of different approaches
4. **Hyperparameter Configuration**: Per-model settings for batch size, epochs, learning rate, and early-stopping patience
5. **Cross-Validation**: Robust performance estimation
6. **Model Persistence**: Saving trained model weights for later reuse

This framework enables reproducible training and fair comparison of all pneumonia detection approaches.

## 1. Setup and Imports

In [None]:
import os
import sys
import time
import json
from datetime import datetime
from pathlib import Path

# Core ML libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, random_split
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
import timm

# Scikit-learn
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import (
    classification_report, confusion_matrix, ConfusionMatrixDisplay,
    roc_curve, auc, precision_recall_curve, average_precision_score
)
from sklearn.svm import SVC
from sklearn.preprocessing import StandardScaler
import joblib

# Image processing
from PIL import Image
import cv2

# Utilities
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

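# Set random seeds for reproducibility (Python, NumPy, and PyTorch CPU/GPU)
import random
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)  # no-op when CUDA is unavailable

# Configuration
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")

# Create output directories (including any missing parent directories)
RESULTS_DIR = Path("../results")
MODELS_DIR = Path("../models")
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
MODELS_DIR.mkdir(parents=True, exist_ok=True)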

## 2. Data Loading and Preparation

In [None]:
class PneumoniaDataset(Dataset):
    """
    Custom dataset class for pneumonia detection with flexible transformations
    """
    def __init__(self, data_dir, transform=None, split='train'):
        self.data_dir = Path(data_dir)
        self.transform = transform
        self.split = split
        
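        # Load image paths and labels (label 0 = normal, 1 = pneumonia).
        # Paths are sorted so the sample order, and therefore any index-based
        # split, does not depend on the filesystem's listing order.
        self.samples = []
        self.labels = []
        
        for class_idx, class_name in enumerate(['normal', 'pneumonia']):
            class_dir = self.data_dir / split / class_name
            if class_dir.exists():
                for img_path in sorted(class_dir.iterdir()):
                    if img_path.suffix.lower() in ('.jpg', '.jpeg', '.png'):
                        self.samples.append(str(img_path))
                        self.labels.append(class_idx)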
        
        print(f"Loaded {len(self.samples)} {split} images")
        print(f"Class distribution: Normal={self.labels.count(0)}, Pneumonia={self.labels.count(1)}")
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        img_path = self.samples[idx]
        label = self.labels[idx]
        
        # Load image
        image = Image.open(img_path).convert('RGB')
        
        if self.transform:
            image = self.transform(image)
        
        return image, label

def get_data_transforms(image_size=224, augment=True):
    """
    Get data transformation pipelines for training and validation
    """
    if augment:
        train_transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.RandomRotation(15),
            transforms.RandomHorizontalFlip(0.5),
            transforms.ColorJitter(brightness=0.2, contrast=0.2),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
    else:
        train_transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
    
    val_transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    
    return train_transform, val_transform

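def create_data_loaders(data_dir, batch_size=32, image_size=224, val_split=0.2, seed=42):
    """
    Create train, validation, and test data loaders.

    The validation set is a stratified hold-out from `data_dir/train`;
    the test set is read from `data_dir/test`.
    """
    train_transform, val_transform = get_data_transforms(image_size)
    
    # Two views of the same training images: augmented for training, plain for validation.
    # (random_split returns Subsets that share ONE underlying dataset, so reassigning
    # its transform would also switch off augmentation for the training subset.)
    train_base = PneumoniaDataset(data_dir, train_transform, 'train')
    val_base = PneumoniaDataset(data_dir, val_transform, 'train')
    assert train_base.samples == val_base.samples  # both views must index the same files
    
    # Stratified, seeded split of indices: class ratios match and runs are reproducible
    train_idx, val_idx = train_test_split(
        np.arange(len(train_base)),
        test_size=val_split,
        stratify=train_base.labels,
        random_state=seed
    )
    train_dataset = torch.utils.data.Subset(train_base, train_idx.tolist())
    val_dataset = torch.utils.data.Subset(val_base, val_idx.tolist())
    
    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    
    # Test loader
    test_dataset = PneumoniaDataset(data_dir, val_transform, 'test')
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    
    return train_loader, val_loader, test_loader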
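A stratified split keeps the Normal/Pneumonia ratio the same in the training and validation subsets, which matters when one class dominates. The sketch below shows the idea on hypothetical labels. The 250/750 class counts are made up for illustration. It uses scikit-learn's `train_test_split` with `stratify=` to split indices rather than images, so the same split can be applied to differently transformed views of one dataset.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical labels: 250 normal (0) and 750 pneumonia (1) images
labels = np.array([0] * 250 + [1] * 750)
indices = np.arange(len(labels))

# Split indices (not images) so the split can be reused across dataset views
train_idx, val_idx = train_test_split(
    indices, test_size=0.2, stratify=labels, random_state=42
)

def pneumonia_fraction(idx):
    """Fraction of the selected samples labelled pneumonia."""
    return labels[idx].mean()

print(f"Overall:    {labels.mean():.3f}")
print(f"Train:      {pneumonia_fraction(train_idx):.3f}")
print(f"Validation: {pneumonia_fraction(val_idx):.3f}")
```

Without `stratify=`, a random 20% hold-out can drift from the overall class ratio, especially for small datasets.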

## 3. Model Architectures

In [None]:
class XceptionModel(nn.Module):
    """Xception model for pneumonia classification"""
    def __init__(self, pretrained=True, freeze_layers=100):
        super(XceptionModel, self).__init__()
        self.xception = timm.create_model('xception', pretrained=pretrained)
        
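        # Freeze the first `freeze_layers` parameter tensors (weights and biases
        # count separately, so this is a tensor count rather than a layer count)
        for i, (name, param) in enumerate(self.xception.named_parameters()):
            if i < freeze_layers:
                param.requires_grad = False
        
        # Remove timm's pooling and classification head so the backbone returns
        # the 2048-channel feature map; pooling and the new head are added below
        self.xception.global_pool = nn.Identity()
        self.xception.fc = nn.Identity()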
        
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(2048, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 1)
        )
    
    def forward(self, x):
        x = self.xception(x)
        x = self.pool(x).view(x.size(0), -1)
        x = self.classifier(x)
        return x

class XceptionLSTM(nn.Module):
    """Xception-LSTM hybrid model"""
    def __init__(self, pretrained=True, freeze_layers=100):
        super(XceptionLSTM, self).__init__()
        self.xception = timm.create_model("xception", pretrained=pretrained, features_only=True)
        
        # Freeze the first `freeze_layers` parameter tensors (a tensor count, not a layer count)
        for i, (name, param) in enumerate(self.xception.named_parameters()):
            if i < freeze_layers:
                param.requires_grad = False
        
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.reshape = nn.Flatten(2)
        self.transpose = lambda x: x.permute(0, 2, 1)
        self.lstm = nn.LSTM(input_size=2048, hidden_size=256, batch_first=True)
        self.fc1 = nn.Linear(256, 64)
        self.dropout = nn.Dropout(0.46)
        self.fc2 = nn.Linear(64, 1)

    def forward(self, x):
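        x = self.xception(x)[-1]   # last feature map: (B, 2048, H, W)
        x = self.pool(x)           # (B, 2048, 1, 1)
        x = self.reshape(x)        # (B, 2048, 1)
        x = self.transpose(x)      # (B, 1, 2048): a sequence of length 1
        x, _ = self.lstm(x)        # (B, 1, 256)
        x = x[:, -1, :]            # output at the last (only) time step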
        x = self.fc1(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x

class FusionModel(nn.Module):
    """Fusion model combining Xception and VGG16"""
    def __init__(self, pretrained=True):
        super(FusionModel, self).__init__()
        
        # Xception branch
        self.xception = timm.create_model('xception', pretrained=pretrained)
        self.xception.global_pool = nn.Identity()
        self.xception.fc = nn.Identity()
        
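        # VGG16 branch (only the convolutional `features` are used in forward)
        # `weights=` replaces the deprecated `pretrained=` argument in torchvision >= 0.13
        vgg_weights = torchvision.models.VGG16_Weights.IMAGENET1K_V1 if pretrained else None
        self.vgg = torchvision.models.vgg16(weights=vgg_weights)
        self.vgg.classifier = nn.Identity()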
        
        # Pooling layers
        self.xception_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.vgg_pool = nn.AdaptiveAvgPool2d((1, 1))
        
        # Fusion classifier
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(2048 + 512, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 1)
        )
    
    def forward(self, x):
        # Xception path
        x1 = self.xception(x)
        x1 = self.xception_pool(x1).view(x.size(0), -1)
        
        # VGG16 path
        x2 = self.vgg.features(x)
        x2 = self.vgg_pool(x2).view(x.size(0), -1)
        
        # Concatenate and classify
        fused = torch.cat((x1, x2), dim=1)
        out = self.classifier(fused)
        return out

def get_model(model_name, **kwargs):
    """Factory function to create models"""
    models = {
        'xception': XceptionModel,
        'xception_lstm': XceptionLSTM,
        'fusion': FusionModel
    }
    
    if model_name not in models:
        raise ValueError(f"Unknown model: {model_name}")
    
    return models[model_name](**kwargs)
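The input widths of the classifier heads follow from the backbone channel counts used above. Xception's final feature map has 2048 channels and VGG16's `features` output has 512, so the fusion head receives 2048 + 512 = 2560 values. The short calculation below counts the trainable parameters in each head (weights plus biases). Dropout, ReLU, and pooling have no parameters. The helper names are illustrative only.

```python
def linear_params(in_features, out_features):
    """Parameters in a fully connected layer: weight matrix plus bias vector."""
    return in_features * out_features + out_features

def lstm_params(input_size, hidden_size):
    """Single-layer PyTorch nn.LSTM: 4 gates, each with input weights,
    hidden weights, and two bias vectors (bias_ih and bias_hh)."""
    return 4 * hidden_size * (input_size + hidden_size + 2)

XCEPTION_DIM = 2048
VGG16_DIM = 512

# Head used by XceptionModel and FusionModel: in -> 128 -> 1
xception_head = linear_params(XCEPTION_DIM, 128) + linear_params(128, 1)
fusion_head = linear_params(XCEPTION_DIM + VGG16_DIM, 128) + linear_params(128, 1)
# XceptionLSTM head: LSTM(2048 -> 256) -> 64 -> 1
lstm_head = lstm_params(XCEPTION_DIM, 256) + linear_params(256, 64) + linear_params(64, 1)

print(f"Xception head:      {xception_head:,} parameters")
print(f"Fusion head:        {fusion_head:,} parameters")
print(f"Xception-LSTM head: {lstm_head:,} parameters")
```

The LSTM dominates the Xception-LSTM head at roughly 2.4M parameters, even though it only sees a sequence of length one after global pooling.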

## 4. Training Framework

In [None]:
class ModelTrainer:
    """Unified training framework for all model types"""
    
    def __init__(self, model, criterion, optimizer, device=None, scheduler=None):
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device or DEVICE
        
        self.model.to(self.device)
        
        # Training history
        self.history = {
            'train_loss': [],
            'train_acc': [],
            'val_loss': [],
            'val_acc': [],
            'lr': []
        }
    
    def train_epoch(self, train_loader):
        """Train for one epoch"""
        self.model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        progress_bar = tqdm(train_loader, desc='Training')
        for inputs, labels in progress_bar:
            inputs, labels = inputs.to(self.device), labels.float().to(self.device)
            
            self.optimizer.zero_grad()
            # squeeze(1) turns (B, 1) logits into (B,) without collapsing a batch of size 1
            outputs = self.model(inputs).squeeze(1)
            loss = self.criterion(outputs, labels)
            loss.backward()
            self.optimizer.step()
            
            # Weight the (mean) batch loss by batch size so a smaller final batch is not over-counted
            running_loss += loss.item() * labels.size(0)
            predicted = (torch.sigmoid(outputs) > 0.5).float()
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
            progress_bar.set_postfix({
                'Loss': f'{running_loss/total:.4f}',
                'Acc': f'{100.*correct/total:.2f}%'
            })
        
        epoch_loss = running_loss / total
        epoch_acc = correct / total
        
        return epoch_loss, epoch_acc
    
    def validate_epoch(self, val_loader):
        """Validate for one epoch"""
        self.model.eval()
        running_loss = 0.0
        correct = 0
        total = 0
        
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(self.device), labels.float().to(self.device)
                
                # squeeze(1) keeps a (B,) shape even when the last batch has one sample
                outputs = self.model(inputs).squeeze(1)
                loss = self.criterion(outputs, labels)
                
                running_loss += loss.item() * labels.size(0)
                predicted = (torch.sigmoid(outputs) > 0.5).float()
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        
        epoch_loss = running_loss / total
        epoch_acc = correct / total
        
        return epoch_loss, epoch_acc
    
    def train(self, train_loader, val_loader, epochs, early_stopping_patience=None):
        """Complete training loop"""
        best_val_acc = 0.0
        best_model_state = None
        patience_counter = 0
        
        print(f"Starting training for {epochs} epochs...")
        start_time = time.time()
        
        for epoch in range(epochs):
            print(f"\nEpoch {epoch+1}/{epochs}")
            print("-" * 50)
            
            # Training phase
            train_loss, train_acc = self.train_epoch(train_loader)
            
            # Validation phase
            val_loss, val_acc = self.validate_epoch(val_loader)
            
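            # Update learning rate (ReduceLROnPlateau needs the monitored metric)
            if self.scheduler is not None:
                if isinstance(self.scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                    self.scheduler.step(val_loss)
                else:
                    self.scheduler.step()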
            
            # Record history
            self.history['train_loss'].append(train_loss)
            self.history['train_acc'].append(train_acc)
            self.history['val_loss'].append(val_loss)
            self.history['val_acc'].append(val_acc)
            self.history['lr'].append(self.optimizer.param_groups[0]['lr'])
            
            print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
            print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
            print(f"Learning Rate: {self.optimizer.param_groups[0]['lr']:.6f}")
            
            # Save best model. state_dict().copy() would be shallow: its tensors
            # keep changing as training continues, so clone them explicitly.
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                best_model_state = {k: v.detach().cpu().clone()
                                    for k, v in self.model.state_dict().items()}
                patience_counter = 0
                print(f"New best validation accuracy: {best_val_acc:.4f}")
            else:
                patience_counter += 1
            
            # Early stopping
            if early_stopping_patience and patience_counter >= early_stopping_patience:
                print(f"Early stopping triggered after {patience_counter} epochs without improvement")
                break
        
        # Load best model
        if best_model_state is not None:
            self.model.load_state_dict(best_model_state)
        
        training_time = time.time() - start_time
        print(f"\nTraining completed in {training_time:.2f} seconds")
        print(f"Best validation accuracy: {best_val_acc:.4f}")
        
        return self.history
    
    def plot_training_history(self, save_path=None):
        """Plot training history"""
        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        
        # Loss
        axes[0, 0].plot(self.history['train_loss'], label='Train Loss')
        axes[0, 0].plot(self.history['val_loss'], label='Val Loss')
        axes[0, 0].set_title('Model Loss')
        axes[0, 0].set_xlabel('Epoch')
        axes[0, 0].set_ylabel('Loss')
        axes[0, 0].legend()
        axes[0, 0].grid(True)
        
        # Accuracy
        axes[0, 1].plot(self.history['train_acc'], label='Train Acc')
        axes[0, 1].plot(self.history['val_acc'], label='Val Acc')
        axes[0, 1].set_title('Model Accuracy')
        axes[0, 1].set_xlabel('Epoch')
        axes[0, 1].set_ylabel('Accuracy')
        axes[0, 1].legend()
        axes[0, 1].grid(True)
        
        # Learning Rate
        axes[1, 0].plot(self.history['lr'])
        axes[1, 0].set_title('Learning Rate')
        axes[1, 0].set_xlabel('Epoch')
        axes[1, 0].set_ylabel('Learning Rate')
        axes[1, 0].grid(True)
        
        # Remove empty subplot
        axes[1, 1].remove()
        
        plt.tight_layout()
        
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        
        plt.show()
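The interaction between `best_val_acc`, `patience_counter`, and `early_stopping_patience` in `train()` can be traced without a model. The sketch below replays a hypothetical sequence of validation accuracies through the same rules:

- A strictly higher accuracy saves a checkpoint and resets the counter.
- Anything else increments the counter.
- Training stops once the counter reaches the patience.
- A falsy patience disables early stopping.

`simulate_early_stopping` is an illustrative helper, not part of the framework.

```python
def simulate_early_stopping(val_accs, patience):
    """Replay ModelTrainer.train()'s checkpoint and early-stopping rules.

    Returns (best_epoch, stop_epoch) as 1-based epoch numbers, where
    stop_epoch is the last epoch that actually ran.
    """
    best_acc, best_epoch, counter = 0.0, None, 0
    for epoch, acc in enumerate(val_accs, start=1):
        if acc > best_acc:          # strict improvement -> new checkpoint
            best_acc, best_epoch, counter = acc, epoch, 0
        else:                       # ties count as "no improvement"
            counter += 1
        if patience and counter >= patience:
            return best_epoch, epoch
    return best_epoch, len(val_accs)

# Hypothetical validation accuracies (illustration only)
accs = [0.80, 0.85, 0.88, 0.87, 0.88, 0.86, 0.90]
print(simulate_early_stopping(accs, patience=3))
print(simulate_early_stopping(accs, patience=None))
```

With `patience=3`, training stops after epoch 6 and restores the epoch-3 weights. The tie at epoch 5 does not reset the counter, so the later 0.90 is never reached.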

## 5. Evaluation Framework

In [None]:
class ModelEvaluator:
    """Comprehensive model evaluation framework"""
    
    def __init__(self, model, device=None):
        self.model = model
        self.device = device or DEVICE
        self.model.to(self.device)
    
    def evaluate(self, test_loader, return_predictions=False):
        """Comprehensive evaluation on test set"""
        self.model.eval()
        
        all_predictions = []
        all_probabilities = []
        all_labels = []
        
        with torch.no_grad():
            for inputs, labels in tqdm(test_loader, desc='Evaluating'):
                inputs = inputs.to(self.device)
                
                outputs = self.model(inputs).squeeze(1)  # (B,) even for a batch of size 1
                probabilities = torch.sigmoid(outputs)
                predictions = (probabilities > 0.5).float()
                
                all_predictions.extend(predictions.cpu().numpy())
                all_probabilities.extend(probabilities.cpu().numpy())
                all_labels.extend(labels.numpy())
        
        all_predictions = np.array(all_predictions)
        all_probabilities = np.array(all_probabilities)
        all_labels = np.array(all_labels)
        
        # Calculate metrics
        results = self._calculate_metrics(all_labels, all_predictions, all_probabilities)
        
        if return_predictions:
            results['predictions'] = all_predictions
            results['probabilities'] = all_probabilities
            results['labels'] = all_labels
        
        return results
    
    def _calculate_metrics(self, y_true, y_pred, y_prob):
        """Calculate comprehensive evaluation metrics"""
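        # Basic metrics; labels=[0, 1] keeps the matrix 2x2 even if a class is absent
        cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
        # Cast NumPy integers to Python ints so the counts are JSON-serializable
        tn, fp, fn, tp = (int(v) for v in cm.ravel())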
        
        accuracy = (tp + tn) / (tp + tn + fp + fn)
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
        
        # ROC and PR curves
        fpr, tpr, _ = roc_curve(y_true, y_prob)
        roc_auc = auc(fpr, tpr)
        
        precision_curve, recall_curve, _ = precision_recall_curve(y_true, y_prob)
        pr_auc = average_precision_score(y_true, y_prob)
        
        return {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'specificity': specificity,
            'f1_score': f1,
            'roc_auc': roc_auc,
            'pr_auc': pr_auc,
            'confusion_matrix': cm,
            'tp': tp, 'fp': fp, 'tn': tn, 'fn': fn,
            'fpr': fpr, 'tpr': tpr,
            'precision_curve': precision_curve,
            'recall_curve': recall_curve
        }
    
    def plot_evaluation(self, results, model_name="Model", save_path=None):
        """Plot comprehensive evaluation results"""
        fig, axes = plt.subplots(2, 2, figsize=(15, 12))
        
        # Confusion Matrix
        disp = ConfusionMatrixDisplay(confusion_matrix=results['confusion_matrix'],
                                      display_labels=['Normal', 'Pneumonia'])
        disp.plot(ax=axes[0, 0], cmap='Blues', values_format='d')
        axes[0, 0].set_title(f'{model_name} - Confusion Matrix')
        
        # ROC Curve
        axes[0, 1].plot(results['fpr'], results['tpr'], 
                       label=f'ROC Curve (AUC = {results["roc_auc"]:.3f})')
        axes[0, 1].plot([0, 1], [0, 1], 'k--', label='Random')
        axes[0, 1].set_xlabel('False Positive Rate')
        axes[0, 1].set_ylabel('True Positive Rate')
        axes[0, 1].set_title(f'{model_name} - ROC Curve')
        axes[0, 1].legend()
        axes[0, 1].grid(True)
        
        # Precision-Recall Curve
        axes[1, 0].plot(results['recall_curve'], results['precision_curve'],
                       label=f'PR Curve (AUC = {results["pr_auc"]:.3f})')
        axes[1, 0].set_xlabel('Recall')
        axes[1, 0].set_ylabel('Precision')
        axes[1, 0].set_title(f'{model_name} - Precision-Recall Curve')
        axes[1, 0].legend()
        axes[1, 0].grid(True)
        
        # Metrics Bar Plot
        metrics = ['Accuracy', 'Precision', 'Recall', 'Specificity', 'F1-Score']
        values = [results['accuracy'], results['precision'], results['recall'], 
                 results['specificity'], results['f1_score']]
        
        bars = axes[1, 1].bar(metrics, values, color=['skyblue', 'lightgreen', 'lightcoral', 'lightyellow', 'lightpink'])
        axes[1, 1].set_ylabel('Score')
        axes[1, 1].set_title(f'{model_name} - Performance Metrics')
        axes[1, 1].set_ylim(0, 1)
        
        # Add value labels on bars
        for bar, value in zip(bars, values):
            axes[1, 1].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                           f'{value:.3f}', ha='center', va='bottom')
        
        plt.tight_layout()
        
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        
        plt.show()
    
    def print_detailed_results(self, results, model_name="Model"):
        """Print detailed evaluation results"""
        print(f"\n{'='*60}")
        print(f"{model_name} - Detailed Evaluation Results")
        print(f"{'='*60}")
        
        print(f"\nClassification Metrics:")
        print(f"  Accuracy:    {results['accuracy']:.4f}")
        print(f"  Precision:   {results['precision']:.4f}")
        print(f"  Recall:      {results['recall']:.4f}")
        print(f"  Specificity: {results['specificity']:.4f}")
        print(f"  F1-Score:    {results['f1_score']:.4f}")
        
        print(f"\nAUC Scores:")
        print(f"  ROC AUC:     {results['roc_auc']:.4f}")
        print(f"  PR AUC:      {results['pr_auc']:.4f}")
        
        print(f"\nConfusion Matrix Counts:")
        print(f"  True Positives:  {results['tp']}")
        print(f"  False Positives: {results['fp']}")
        print(f"  True Negatives:  {results['tn']}")
        print(f"  False Negatives: {results['fn']}")
        
        print(f"\nConfusion Matrix:")
        print(results['confusion_matrix'])
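The scalar metrics reported above all derive from the four confusion-matrix counts. For a binary problem with `labels=[0, 1]`, scikit-learn's `confusion_matrix(...).ravel()` returns them in the order `tn, fp, fn, tp`. The sketch below applies the same formulas to hypothetical counts, which are made up and are not results from this project. `metrics_from_counts` is an illustrative helper.

```python
def metrics_from_counts(tp, fp, tn, fn):
    """Same formulas as ModelEvaluator._calculate_metrics, with zero-division guards."""
    total = tp + fp + tn + fn
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0          # sensitivity
    specificity = tn / (tn + fp) if tn + fp else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {
        'accuracy': (tp + tn) / total,
        'precision': precision,
        'recall': recall,
        'specificity': specificity,
        'f1_score': f1,
    }

# Hypothetical counts for illustration only
m = metrics_from_counts(tp=90, fp=10, tn=80, fn=20)
for name, value in m.items():
    print(f"{name:<12} {value:.4f}")
```

F1 can equivalently be written as 2·TP / (2·TP + FP + FN), here 180 / 210 ≈ 0.857. The 20 false negatives show up in recall (0.818) but are partly hidden by the 0.85 accuracy.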

## 6. Model Comparison Framework

In [None]:
def compare_models(models_dict, test_loader, save_path=None):
    """
    Compare multiple models on the same test set
    
    Args:
        models_dict: Dictionary of {model_name: model} pairs
        test_loader: Test data loader
        save_path: Path to save comparison results
    """
    results = {}
    
    print("Evaluating models...")
    for model_name, model in models_dict.items():
        print(f"\nEvaluating {model_name}...")
        evaluator = ModelEvaluator(model)
        model_results = evaluator.evaluate(test_loader)
        results[model_name] = model_results
    
    # Create comparison DataFrame
    metrics = ['accuracy', 'precision', 'recall', 'specificity', 'f1_score', 'roc_auc', 'pr_auc']
    comparison_df = pd.DataFrame({
        model_name: [model_results[metric] for metric in metrics]
        for model_name, model_results in results.items()
    }, index=metrics)
    
    # Plot comparison
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    
    # Metrics comparison
    comparison_df.T.plot(kind='bar', ax=axes[0, 0], rot=45)
    axes[0, 0].set_title('Model Performance Comparison')
    axes[0, 0].set_ylabel('Score')
    axes[0, 0].legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    axes[0, 0].grid(True, alpha=0.3)
    
    # ROC curves comparison
    for model_name, model_results in results.items():
        axes[0, 1].plot(model_results['fpr'], model_results['tpr'], 
                       label=f'{model_name} (AUC = {model_results["roc_auc"]:.3f})')
    axes[0, 1].plot([0, 1], [0, 1], 'k--', label='Random')
    axes[0, 1].set_xlabel('False Positive Rate')
    axes[0, 1].set_ylabel('True Positive Rate')
    axes[0, 1].set_title('ROC Curves Comparison')
    axes[0, 1].legend()
    axes[0, 1].grid(True)
    
    # PR curves comparison
    for model_name, model_results in results.items():
        axes[1, 0].plot(model_results['recall_curve'], model_results['precision_curve'],
                       label=f'{model_name} (AUC = {model_results["pr_auc"]:.3f})')
    axes[1, 0].set_xlabel('Recall')
    axes[1, 0].set_ylabel('Precision')
    axes[1, 0].set_title('Precision-Recall Curves Comparison')
    axes[1, 0].legend()
    axes[1, 0].grid(True)
    
    # F1-Score vs Accuracy scatter
    for model_name, model_results in results.items():
        axes[1, 1].scatter(model_results['accuracy'], model_results['f1_score'], 
                          label=model_name, s=100)
        axes[1, 1].annotate(model_name, 
                           (model_results['accuracy'], model_results['f1_score']),
                           xytext=(5, 5), textcoords='offset points')
    axes[1, 1].set_xlabel('Accuracy')
    axes[1, 1].set_ylabel('F1-Score')
    axes[1, 1].set_title('Accuracy vs F1-Score')
    axes[1, 1].grid(True)
    
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    
    plt.show()
    
    # Print comparison table
    print("\n" + "="*80)
    print("MODEL PERFORMANCE COMPARISON")
    print("="*80)
    print(comparison_df.round(4))
    
    # Save results (save_path may be a str or a pathlib.Path)
    if save_path:
        save_path = Path(save_path)
        results_file = save_path.with_name(save_path.stem + '_results.json')
        with open(results_file, 'w') as f:
            # Drop curve arrays and convert NumPy values to built-in types for JSON
            json_results = {}
            for model_name, model_results in results.items():
                json_results[model_name] = {
                    k: v.tolist() if isinstance(v, np.ndarray)
                    else v.item() if isinstance(v, np.generic)
                    else v
                    for k, v in model_results.items()
                    if k not in ['fpr', 'tpr', 'precision_curve', 'recall_curve', 'confusion_matrix']
                }
            json.dump(json_results, f, indent=2)
        
        comparison_df.to_csv(save_path.with_name(save_path.stem + '_comparison.csv'))
    
    return results, comparison_df
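`json.dump` cannot serialize NumPy arrays or NumPy integer scalars such as the counts produced by `confusion_matrix(...).ravel()`. `np.float64` happens to work because it subclasses Python's `float`. A small recursive converter handles nested result dictionaries before saving. `to_jsonable` is a hypothetical helper that is not defined elsewhere in this notebook, and the values below are made up.

```python
import json
import numpy as np

def to_jsonable(obj):
    """Recursively convert NumPy arrays/scalars to built-in Python types."""
    if isinstance(obj, dict):
        return {str(k): to_jsonable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_jsonable(v) for v in obj]
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):      # np.int64, np.float32, np.bool_, ...
        return obj.item()
    return obj

# Hypothetical evaluation results with mixed NumPy types
results = {
    'xception': {
        'accuracy': np.float64(0.85),
        'tp': np.int64(90),
        'confusion_matrix': np.array([[80, 10], [20, 90]]),
    }
}
print(json.dumps(to_jsonable(results), indent=2))
```

Converting everything up front, rather than filtering keys, also keeps the confusion matrix and curves if you later decide to save them.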

## 7. Training Configuration and Execution

In [None]:
def train_model_pipeline(model_name, data_dir, config=None):
    """
    Complete training pipeline for a single model
    
    Args:
        model_name: Name of the model to train
        data_dir: Path to the data directory
        config: Training configuration dictionary
    """
    # Default configuration
    default_config = {
        'batch_size': 32,
        'image_size': 224,
        'epochs': 25,
        'learning_rate': 1e-4,
        'weight_decay': 1e-5,
        'early_stopping_patience': 5,
        'val_split': 0.2
    }
    
    if config:
        default_config.update(config)
    config = default_config
    
    print(f"Training {model_name} with configuration:")
    for key, value in config.items():
        print(f"  {key}: {value}")
    
    # Create data loaders
    train_loader, val_loader, test_loader = create_data_loaders(
        data_dir, 
        batch_size=config['batch_size'],
        image_size=config['image_size'],
        val_split=config['val_split']
    )
    
    # Create model
    model = get_model(model_name)
    
    # Define loss and optimizer
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), 
                          lr=config['learning_rate'], 
                          weight_decay=config['weight_decay'])
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.5)
    
    # Create trainer
    trainer = ModelTrainer(model, criterion, optimizer, scheduler=scheduler)
    
    # Train model
    history = trainer.train(
        train_loader, val_loader, 
        epochs=config['epochs'],
        early_stopping_patience=config['early_stopping_patience']
    )
    
    # Evaluate model
    evaluator = ModelEvaluator(model)
    results = evaluator.evaluate(test_loader)
    
    # Print results
    evaluator.print_detailed_results(results, model_name)
    
    # Plot results
    trainer.plot_training_history(save_path=RESULTS_DIR / f"{model_name}_training_history.png")
    evaluator.plot_evaluation(results, model_name, save_path=RESULTS_DIR / f"{model_name}_evaluation.png")
    
    # Save model
    model_path = MODELS_DIR / f"{model_name}_weights.pth"
    torch.save(model.state_dict(), model_path)
    print(f"Model saved to {model_path}")
    
    return model, history, results, test_loader
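`ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.5)` halves the learning rate only after the validation loss has failed to improve for more than `patience` consecutive epochs. The sketch below is a simplified re-implementation for illustration, not the library code. It assumes PyTorch's default relative threshold of `1e-4` and ignores `cooldown`, `min_lr`, and `eps`. The loss values are hypothetical.

```python
def simulate_plateau_lr(val_losses, lr=1e-4, factor=0.5, patience=3, threshold=1e-4):
    """Simplified 'min'-mode ReduceLROnPlateau with a relative threshold."""
    best = float('inf')
    bad_epochs = 0
    lrs = []
    for loss in val_losses:
        if loss < best * (1 - threshold):   # counts as an improvement
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
        if bad_epochs > patience:           # strictly more than `patience`
            lr *= factor
            bad_epochs = 0
        lrs.append(lr)                      # LR in effect after this epoch's step
    return lrs

# Hypothetical validation losses: no improvement in epochs 3-6
losses = [0.50, 0.40, 0.41, 0.42, 0.43, 0.44, 0.39]
lrs = simulate_plateau_lr(losses)
print(lrs)
```

The first reduction happens after epoch 6, the fourth non-improving epoch. The scheduler watches validation loss, while early stopping in `ModelTrainer.train()` watches validation accuracy, so the two can trigger independently.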

## 8. Example Training and Evaluation

Uncomment the cells below to run training and evaluation:

In [None]:
# Configuration for training experiments
TRAINING_CONFIGS = {
    'xception': {
        'batch_size': 32,
        'epochs': 20,
        'learning_rate': 1e-4,
        'early_stopping_patience': 5
    },
    'xception_lstm': {
        'batch_size': 24,  # Smaller batch size for LSTM
        'epochs': 25,
        'learning_rate': 1e-4,
        'early_stopping_patience': 7
    },
    'fusion': {
        'batch_size': 24,  # Smaller batch size for fusion model
        'epochs': 30,
        'learning_rate': 5e-5,  # Lower learning rate for fusion
        'early_stopping_patience': 8
    }
}

print("Training configurations loaded.")
print("Uncomment the training cells below to start training.")
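`train_model_pipeline` starts from its defaults and overlays each entry of `TRAINING_CONFIGS` with `dict.update`. Any key a model config omits, here `image_size`, `weight_decay`, and `val_split`, falls back to the default. The sketch below reproduces that merge. As a hypothetical extra safeguard that the pipeline does not have, it also rejects unknown keys, so a typo such as `'learning_rat'` fails loudly instead of being silently ignored. `DEFAULT_CONFIG` and `merge_config` are illustrative names.

```python
# Mirrors the defaults inside train_model_pipeline (hypothetical module-level copy)
DEFAULT_CONFIG = {
    'batch_size': 32,
    'image_size': 224,
    'epochs': 25,
    'learning_rate': 1e-4,
    'weight_decay': 1e-5,
    'early_stopping_patience': 5,
    'val_split': 0.2,
}

def merge_config(overrides, defaults=DEFAULT_CONFIG):
    """Return defaults overlaid with overrides; reject keys the pipeline never reads."""
    unknown = set(overrides) - set(defaults)
    if unknown:
        raise KeyError(f"Unknown config keys: {sorted(unknown)}")
    merged = dict(defaults)   # copy so the shared defaults are never mutated
    merged.update(overrides)
    return merged

fusion_cfg = merge_config({'batch_size': 24, 'epochs': 30,
                           'learning_rate': 5e-5, 'early_stopping_patience': 8})
print(fusion_cfg)
```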

In [None]:
# Example: Train a single model
# DATA_DIR = "../data"  # Adjust path to your data directory

# model, history, results, test_loader = train_model_pipeline(
#     model_name='xception',
#     data_dir=DATA_DIR,
#     config=TRAINING_CONFIGS['xception']
# )

print("Single model training pipeline is ready.")
print("Uncomment the above code and set DATA_DIR to train a model.")

In [None]:
# Example: Train and compare all models
# DATA_DIR = "../data"  # Adjust path to your data directory

# trained_models = {}
# all_results = {}

# for model_name in ['xception', 'xception_lstm', 'fusion']:
#     print(f"\n{'='*60}")
#     print(f"Training {model_name.upper()}")
#     print(f"{'='*60}")
#     
#     model, history, results, test_loader = train_model_pipeline(
#         model_name=model_name,
#         data_dir=DATA_DIR,
#         config=TRAINING_CONFIGS[model_name]
#     )
#     
#     trained_models[model_name] = model
#     all_results[model_name] = results

# # Compare all models
# comparison_results, comparison_df = compare_models(
#     trained_models, 
#     test_loader, 
#     save_path=RESULTS_DIR / "model_comparison.png"
# )

print("Multi-model training and comparison pipeline is ready.")
print("Uncomment the above code and set DATA_DIR to train and compare all models.")

## 9. Cross-Validation Framework

In [None]:
def cross_validate_model(model_name, data_dir, k_folds=5, config=None):
    """
    Perform stratified k-fold cross-validation for robust performance estimation
    
    Args:
        model_name: Name of the model to validate
        data_dir: Path to the data directory
        k_folds: Number of folds for cross-validation
        config: Optional overrides for batch_size, image_size, epochs, learning_rate
    """
    # Defaults match the previously hard-coded values; config entries override them
    cv_config = {'batch_size': 32, 'image_size': 224, 'epochs': 10, 'learning_rate': 1e-4}
    if config:
        cv_config.update(config)
    
    print(f"Performing {k_folds}-fold cross-validation for {model_name}")
    
    # Two views of the training images: augmented for fitting, plain for validation
    train_transform, val_transform = get_data_transforms(cv_config['image_size'])
    train_dataset = PneumoniaDataset(data_dir, train_transform, 'train')
    val_dataset = PneumoniaDataset(data_dir, val_transform, 'train')
    
    # Prepare for stratified k-fold
    skf = StratifiedKFold(n_splits=k_folds, shuffle=True, random_state=42)
    
    fold_results = []
    
    for fold, (train_idx, val_idx) in enumerate(skf.split(train_dataset.samples, train_dataset.labels)):
        print(f"\nFold {fold + 1}/{k_folds}")
        print("-" * 30)
        
        # Create fold datasets
        train_subset = torch.utils.data.Subset(train_dataset, train_idx.tolist())
        val_subset = torch.utils.data.Subset(val_dataset, val_idx.tolist())
        
        # Create data loaders
        train_loader = DataLoader(train_subset, batch_size=cv_config['batch_size'], shuffle=True)
        val_loader = DataLoader(val_subset, batch_size=cv_config['batch_size'], shuffle=False)
        
        # Create and train a fresh model for each fold
        model = get_model(model_name)
        criterion = nn.BCEWithLogitsLoss()
        optimizer = optim.Adam(model.parameters(), lr=cv_config['learning_rate'])
        
        trainer = ModelTrainer(model, criterion, optimizer)
        # Note: train() keeps the epoch with the best accuracy on this same fold,
        # so the fold scores below are somewhat optimistic
        trainer.train(train_loader, val_loader, epochs=cv_config['epochs'])
        
        # Evaluate fold
        evaluator = ModelEvaluator(model)
        val_results = evaluator.evaluate(val_loader)
        
        fold_results.append({
            'fold': fold + 1,
            'accuracy': val_results['accuracy'],
            'precision': val_results['precision'],
            'recall': val_results['recall'],
            'f1_score': val_results['f1_score'],
            'roc_auc': val_results['roc_auc']
        })
        
        print(f"Fold {fold + 1} Results:")
        print(f"  Accuracy: {val_results['accuracy']:.4f}")
        print(f"  F1-Score: {val_results['f1_score']:.4f}")
        print(f"  ROC AUC:  {val_results['roc_auc']:.4f}")
    
    # Aggregate results
    cv_df = pd.DataFrame(fold_results)
    
    print(f"\n{'='*50}")
    print(f"{k_folds}-Fold Cross-Validation Results for {model_name}")
    print(f"{'='*50}")
    
    for metric in ['accuracy', 'precision', 'recall', 'f1_score', 'roc_auc']:
        mean_val = cv_df[metric].mean()
        std_val = cv_df[metric].std()
        print(f"{metric.replace('_', ' ').title() + ':':<12} {mean_val:.4f} ± {std_val:.4f}")
    
    return cv_df

# Example usage:
# cv_results = cross_validate_model('xception', DATA_DIR, k_folds=5)

print("Cross-validation framework is ready.")
print("Uncomment the example usage above to run cross-validation.")
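The cross-validation summary reports the mean and standard deviation of each metric across folds. `pandas.Series.std()` computes the sample standard deviation (`ddof=1`), whereas `numpy.std` defaults to the population version (`ddof=0`). With only five folds the difference is noticeable, so it is worth stating which one a report uses. The fold accuracies below are hypothetical.

```python
import numpy as np

# Hypothetical per-fold accuracies (illustration only)
fold_acc = np.array([0.90, 0.92, 0.88, 0.91, 0.89])

mean_acc = fold_acc.mean()
sample_std = fold_acc.std(ddof=1)      # matches pandas' default
population_std = fold_acc.std()        # NumPy's default, ddof=0

print(f"{'Accuracy:':<12} {mean_acc:.4f} ± {sample_std:.4f}")
print(f"Population std: {population_std:.4f}")
```

Here the sample standard deviation is about 0.0158 and the population value about 0.0141. The sample version is the usual choice when the folds are treated as a sample of possible splits.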

## 10. Summary and Best Practices

This comprehensive training and evaluation framework provides:

### Key Features:

1. **Unified Training Pipeline**: Consistent training procedures across all model types
2. **Comprehensive Evaluation**: Multiple metrics and visualizations
3. **Model Comparison**: Side-by-side performance analysis
4. **Cross-Validation**: Robust performance estimation
5. **Reproducibility**: Fixed random seeds and structured configuration

### Best Practices Implemented:

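1. **Data Augmentation**: Random rotation, horizontal flips, and brightness/contrast jitter to improve generalization
2. **Early Stopping**: Stops training when validation accuracy stops improving, limiting overfitting
3. **Learning Rate Scheduling**: `ReduceLROnPlateau` halves the learning rate when validation loss plateaus
4. **Comprehensive Metrics**: Precision, recall (sensitivity), specificity, F1, ROC AUC, and PR AUC, since accuracy alone is misleading for imbalanced medical data
5. **Model Persistence**: Trained weights saved to `MODELS_DIR` for later reuse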

### Usage Guidelines:

1. **Start Simple**: Begin with basic CNN models before complex architectures
2. **Validate Properly**: Use stratified splits for class balance
3. **Monitor Training**: Watch for overfitting and convergence
4. **Compare Fairly**: Use same test set for all model comparisons
5. **Document Results**: Save configurations and results for reproducibility

### Medical AI Considerations:

1. **Sensitivity vs Specificity**: Balance based on clinical requirements
2. **Confidence Estimation**: Important for clinical decision support
3. **Interpretability**: Consider model explainability for medical use
4. **Robustness**: Test on diverse datasets and conditions
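Balancing sensitivity and specificity usually means moving the decision threshold away from the fixed 0.5 used in the evaluation code. One common approach picks the largest threshold on the ROC curve whose sensitivity (true positive rate) meets a target, then reports the specificity that threshold yields. The sketch below applies this to synthetic scores. The 0.95 target and the Beta-distributed scores are assumptions for illustration, not clinical recommendations. `threshold_for_sensitivity` is an illustrative helper.

```python
import numpy as np
from sklearn.metrics import roc_curve

def threshold_for_sensitivity(y_true, y_prob, target_sensitivity=0.95):
    """Largest threshold whose true positive rate reaches the target.

    roc_curve returns thresholds in decreasing order with non-decreasing TPR,
    so the first index meeting the target gives the highest such threshold.
    drop_intermediate=False keeps every candidate threshold.
    """
    fpr, tpr, thresholds = roc_curve(y_true, y_prob, drop_intermediate=False)
    idx = int(np.argmax(tpr >= target_sensitivity))
    return thresholds[idx], tpr[idx], 1.0 - fpr[idx]

# Synthetic data: pneumonia scores drawn higher on average than normal scores
rng = np.random.default_rng(0)
y_true = np.array([0] * 200 + [1] * 200)
y_prob = np.concatenate([rng.beta(2, 5, 200), rng.beta(5, 2, 200)])

thr, sens, spec = threshold_for_sensitivity(y_true, y_prob)
print(f"threshold={thr:.3f}  sensitivity={sens:.3f}  specificity={spec:.3f}")
```

`roc_curve` treats `score >= threshold` as positive. The evaluator's `> 0.5` rule differs only when a score equals the threshold exactly. The threshold should be chosen on validation data, not on the test set used for reporting.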

This framework enables systematic development and evaluation of pneumonia detection models suitable for clinical research and applications.