In [1]:
# Block 1: Dependencies and Setup

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau
import torchvision.transforms as transforms

import numpy as np
import pandas as pd
import nibabel as nib
import os
import random
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix, classification_report
import warnings
warnings.filterwarnings('ignore')

# Progress bar
from tqdm import tqdm
import csv
import time

# Medical imaging libraries
try:
    import torchio as tio
    print("TorchIO available for medical augmentations")
except ImportError:
    print("TorchIO not found. Install with: pip install torchio")
    tio = None

try:
    import monai
    from monai.networks.nets import ResNet
    from monai.transforms import (
        Compose, LoadImage, ScaleIntensity, RandRotate, RandFlip,
        RandGaussianNoise, RandBiasField, Rand3DElastic
    )
    print("MONAI available for medical networks")
except ImportError:
    print("MONAI not found. Install with: pip install monai")
    monai = None

# Set device and random seeds for reproducibility
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

def set_seed(seed=42):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU plus all CUDA devices), then puts cuDNN into deterministic mode with
    benchmarking switched off.

    Args:
        seed: Integer seed handed to every generator. Defaults to 42.
    """
    # Same seed for each library's seeding entry point, in the original order.
    for seed_fn in (random.seed, np.random.seed,
                    torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    # Deterministic kernels on, autotuner off.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

# Configuration
# Shared settings dict. The helpers defined later in this notebook read
# 'learning_rate', 'weight_decay', 'num_epochs', 'early_stopping_patience',
# 'num_classes' and 'input_size' from it; the whole dict is also passed to
# create_dataloaders().
CONFIG = {
    # NOTE(review): machine-specific absolute Windows path - the notebook cannot
    # run on another machine until this is edited.
    'data_path': r"C:\Users\acer\Desktop\Project_TMJOA\Data\training_dataset_3D",  # Update this to your data path
    'batch_size': 4,  # Small batch size for 3D volumes
    'learning_rate': 1e-4,
    'num_epochs': 100,
    'early_stopping_patience': 20,
    'num_folds': 5,  # not referenced by the code visible in this part of the notebook
    'input_size': (254, 254, 254),  # Adjust based on your data
    'num_classes': 2,
    'weight_decay': 5e-5,
    'dropout_rate': 0.3,  # not referenced by the code visible in this part of the notebook
    # NOTE(review): train_ensemble_models() takes its own `num_models` argument
    # (default 5) and does not read this key in the code visible here.
    'ensemble_size': 5
}

# Echo the key settings so the cell output records what the run used.
print("Configuration loaded successfully")
print(f"Input size: {CONFIG['input_size']}")
print(f"Batch size: {CONFIG['batch_size']}")
print(f"Learning rate: {CONFIG['learning_rate']}")

TorchIO available for medical augmentations
MONAI available for medical networks
Using device: cpu
Configuration loaded successfully
Input size: (254, 254, 254)
Batch size: 4
Learning rate: 0.0001


In [2]:
from utils.augmentation import TMJAugmentations, MixUp3D, CutMix3D

# Augmentation pipelines for the training and evaluation phases. They are not
# referenced again in the cells visible here - presumably the dataset code
# builds or receives its own; TODO confirm.
tmj_augment_train = TMJAugmentations(training=True)
tmj_augment_val = TMJAugmentations(training=False)
# `mixup` and `cutmix` are looked up as module-level globals inside
# train_epoch(), so this cell must have been run before any training starts.
mixup = MixUp3D(alpha=0.2)
cutmix = CutMix3D(alpha=1.0)

In [3]:
from utils.unit_test import validate_data_structure
from utils.dataloader import create_dataloaders

# Validate data structure first
# The loaders/datasets below are only bound when validation succeeds; on failure
# the cell just prints a message and defines nothing.
if validate_data_structure(CONFIG['data_path']):
    # Create all dataloaders
    train_loader, val_loader, test_loader, train_dataset, val_dataset, test_dataset = create_dataloaders(
        CONFIG['data_path'], 
        CONFIG, 
        use_balanced_sampling=True
    )
    
    # Smoke test: load the first sample of the TRAINING split only (the val and
    # test splits are not sampled here). Each item unpacks to
    # (volume, label, patient id).
    if len(train_dataset) > 0:
        sample_volume, sample_label, sample_patient = train_dataset[0]
        print(f"\nTrain sample shape: {sample_volume.shape}")
        print(f"Train sample label: {sample_label}")
        print(f"Train sample patient ID: {sample_patient}")
    
    # NOTE(review): printed even when the training split is empty and no sample
    # was actually loaded.
    print("✅ Dataset loading successful!")
    
else:
    print("❌ Please fix data structure before proceeding")

Validating data structure at C:\Users\acer\Desktop\Project_TMJOA\Data\training_dataset_3D...
✅ train/0: 114 valid files
✅ train/1: 140 valid files
✅ val/0: 33 valid files
✅ val/1: 40 valid files
✅ test/0: 16 valid files
✅ test/1: 21 valid files
✅ Data structure validation passed!
Loaded train split with 254 samples:
  Class 0: 114 samples
  Class 1: 140 samples
Loaded val split with 73 samples:
  Class 0: 33 samples
  Class 1: 40 samples
Loaded test split with 37 samples:
  Class 0: 16 samples
  Class 1: 21 samples
Using balanced sampling for training data

DataLoaders created:
  Train: 254 samples, 63 batches
  Val:   73 samples, 19 batches
  Test:  37 samples, 10 batches

Train sample shape: torch.Size([1, 254, 254, 254])
Train sample label: 0
Train sample patient ID: 48-26453 L
✅ Dataset loading successful!


In [4]:
from utils.model import create_medical_resnet3d

# Build one model instance for inspection. NOTE(review): train_single_model() and
# train_ensemble_models() construct their own fresh models, so this instance is
# not the one they train.
print("Creating 3D ResNet model...")
model = create_medical_resnet3d(
    arch='resnet18',  # Start with ResNet-18 for smaller datasets
    num_classes=CONFIG['num_classes'],
    pretrained_path=None  # Add path to MedicalNet weights if available
)

print(f"Model created: {model.__class__.__name__}")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

# Test model with sample input
# One random tensor of shape (1, 1, *input_size) - presumably (batch, channel,
# D, H, W) - is pushed through the model before it is moved to `device`.
# A failure is only printed, so the cell carries on to the device move.
try:
    test_input = torch.randn(1, 1, *CONFIG['input_size'])
    with torch.no_grad():
        test_output = model(test_input)
    print(f"Model test successful - Output shape: {test_output.shape}")
except Exception as e:
    print(f"Model test failed: {e}")

# Move model to device
model = model.to(device)

Creating 3D ResNet model...
Model created: ResNet3D
Total parameters: 33,161,026
Trainable parameters: 33,161,026
Model test successful - Output shape: torch.Size([1, 2])


In [5]:
# Block 6: Training and Validation Functions with Enhanced Logging

import csv
from tqdm import tqdm
import os

class CSVLogger:
    """Per-epoch CSV logger for training metrics.

    Constructing the logger (re)creates the file and writes the header row;
    every ``log_epoch`` call re-opens the file in append mode and adds one row,
    so the log on disk is complete after each epoch.
    """

    def __init__(self, log_path):
        """Create the log file, truncating any existing one, and write the header.

        Args:
            log_path: Destination CSV path (``str`` or path-like). Missing parent
                directories are created; a bare file name is written to the
                current working directory.
        """
        self.log_path = log_path
        self.fieldnames = ['epoch', 'train_loss', 'train_acc', 'val_loss', 'val_acc', 'val_auc', 'lr', 'best_val_acc']

        # os.path.dirname() returns '' for a bare file name and os.makedirs('')
        # raises FileNotFoundError, so only create a directory when there is one.
        log_dir = os.path.dirname(log_path)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)

        # Initialize CSV file with headers
        with open(self.log_path, 'w', newline='') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=self.fieldnames)
            writer.writeheader()

    def log_epoch(self, epoch, train_loss, train_acc, val_loss, val_acc, val_auc, lr, best_val_acc):
        """Append the metrics of one epoch as a single CSV row.

        Args:
            epoch: Epoch identifier, written as given.
            train_loss, train_acc, val_loss, val_acc, val_auc, best_val_acc:
                Numeric metrics, written with six decimals.
            lr: Learning rate, written with eight decimals.
        """
        with open(self.log_path, 'a', newline='') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=self.fieldnames)
            writer.writerow({
                'epoch': epoch,
                'train_loss': f'{train_loss:.6f}',
                'train_acc': f'{train_acc:.6f}',
                'val_loss': f'{val_loss:.6f}',
                'val_acc': f'{val_acc:.6f}',
                'val_auc': f'{val_auc:.6f}',
                'lr': f'{lr:.8f}',
                'best_val_acc': f'{best_val_acc:.6f}'
            })

class ModelCheckpoint:
    """Write "best" and "last" training checkpoints into one directory.

    Files are named ``<model_name>_best.pt`` and ``<model_name>_last.pt`` and
    are overwritten on every save of the same kind.
    """

    def __init__(self, save_dir, model_name='model'):
        """Prepare the checkpoint directory.

        Args:
            save_dir: Directory for checkpoint files. It is created together
                with any missing parents.
            model_name: File-name prefix for the checkpoints.
        """
        self.save_dir = Path(save_dir)
        # parents=True so a nested, not-yet-existing directory does not raise
        # FileNotFoundError.
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.model_name = model_name
        self.best_val_acc = -1.0

    def save_checkpoint(self, model, optimizer, scheduler, epoch, metrics, is_best=False, is_last=False):
        """Serialise the training state to the best and/or last checkpoint file.

        Args:
            model: Module whose ``state_dict()`` is stored.
            optimizer: Optimizer whose ``state_dict()`` is stored.
            scheduler: LR scheduler, or a falsy value to store ``None``.
            epoch: Epoch identifier stored in the checkpoint.
            metrics: Metrics stored verbatim. When it is a dict with a
                ``'val_acc'`` entry and ``is_best`` is true, that value raises
                ``self.best_val_acc``.
            is_best: Write ``<model_name>_best.pt``.
            is_last: Write ``<model_name>_last.pt``.

        Nothing is written when both flags are false.
        """
        # Track the best accuracy here; previously the attribute stayed at its
        # initial -1.0, so every checkpoint stored a meaningless 'best_val_acc'.
        if is_best and isinstance(metrics, dict):
            val_acc = metrics.get('val_acc')
            if val_acc is not None:
                self.best_val_acc = max(self.best_val_acc, val_acc)

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
            'metrics': metrics,
            'best_val_acc': self.best_val_acc
        }

        if is_best:
            best_path = self.save_dir / f'{self.model_name}_best.pt'
            torch.save(checkpoint, best_path)
            print(f"💾 Saved best model: {best_path}")

        if is_last:
            last_path = self.save_dir / f'{self.model_name}_last.pt'
            torch.save(checkpoint, last_path)
            print(f"💾 Saved last model: {last_path}")

class EarlyStopping:
    """Stop training once the validation loss has stopped improving.

    Call the instance once per epoch with the validation loss and the model.
    It returns ``True`` when ``patience`` non-improving calls have accumulated
    since the last improvement, optionally reloading the weights captured at
    the best loss.
    """

    def __init__(self, patience=20, min_delta=0.001, restore_best_weights=True):
        """
        Args:
            patience: Number of non-improving calls tolerated before stopping.
            min_delta: Minimum decrease of the loss that counts as an improvement.
            restore_best_weights: Reload the best snapshot into the model when
                stopping is triggered.
        """
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        self.best_loss = None
        self.counter = 0
        self.best_weights = None

    def __call__(self, val_loss, model):
        """Record one epoch's validation loss.

        Args:
            val_loss: Validation loss of the epoch just finished.
            model: Model to snapshot on improvement and to restore on stop.

        Returns:
            bool: ``True`` if training should stop, ``False`` otherwise.
        """
        if self.best_loss is None:
            # First call: whatever we have is the best so far.
            self.best_loss = val_loss
            self.save_checkpoint(model)
        elif val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            self.save_checkpoint(model)
        else:
            self.counter += 1

        if self.counter >= self.patience:
            if self.restore_best_weights and self.best_weights is not None:
                model.load_state_dict(self.best_weights)
            return True
        return False

    def save_checkpoint(self, model):
        """Store an independent copy of the model's current weights.

        ``state_dict().copy()`` only copies the mapping: its tensors still share
        storage with the live parameters, so the "best" snapshot silently
        followed every later optimizer step and restoring it changed nothing.
        Cloning each tensor makes the snapshot independent of further training.
        """
        self.best_weights = {
            name: value.detach().clone() if torch.is_tensor(value) else value
            for name, value in model.state_dict().items()
        }

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Blend the loss against two label sets, as used for MixUp/CutMix batches.

    Args:
        criterion: Callable ``criterion(pred, target)`` returning a loss.
        pred: Model predictions for the mixed batch.
        y_a: Targets weighted by ``lam``.
        y_b: Targets weighted by ``1 - lam``.
        lam: Mixing coefficient.

    Returns:
        ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``.
    """
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b

def train_epoch(model, dataloader, criterion, optimizer, device, epoch, use_mixup=True, use_cutmix=True):
    """Train ``model`` for one epoch and return the mean loss and accuracy.

    For every batch one uniform random number picks the augmentation: below 0.3
    the batch goes through the global ``mixup`` callable (if ``use_mixup``),
    otherwise below 0.6 through the global ``cutmix`` callable (if
    ``use_cutmix``), otherwise it is trained unmodified. Gradients are clipped
    to a total norm of 1.0 before each optimizer step.

    Args:
        model: Network to train; switched to ``train()`` mode.
        dataloader: Iterable yielding ``(inputs, targets, _)`` batches.
        criterion: Loss callable ``criterion(outputs, targets)``.
        optimizer: Optimizer updating ``model``'s parameters.
        device: Device the batches are moved to.
        epoch: Epoch number, used only in the progress-bar label.
        use_mixup: Allow the MixUp branch.
        use_cutmix: Allow the CutMix branch.

    Returns:
        tuple[float, float]: ``(epoch_loss, epoch_acc)``. The loss is averaged
        over every sample processed. Accuracy is computed over the unmixed
        batches only, because mixed batches have no single hard label; it is 0.0
        when every batch was mixed.

    Note:
        Relies on the module-level names ``mixup``, ``cutmix``,
        ``mixup_criterion`` and ``tqdm``.
    """
    model.train()
    running_loss = 0.0
    running_corrects = 0   # plain int; the tensor/int mix used to crash on .double()
    samples_seen = 0       # every sample that contributed to the loss
    scored_samples = 0     # samples from unmixed batches (accuracy denominator)

    # Progress bar
    pbar = tqdm(dataloader, desc=f'Epoch {epoch:3d} [Train]',
                leave=False, ncols=100, ascii=True)

    for batch_idx, (inputs, targets, _) in enumerate(pbar):
        inputs, targets = inputs.to(device), targets.to(device)

        # Clear cache every few batches to prevent memory buildup
        if batch_idx % 5 == 0:
            torch.cuda.empty_cache()

        # Random choice between normal training, MixUp, and CutMix
        r = np.random.rand()
        optimizer.zero_grad()
        if use_mixup and r < 0.3:
            # MixUp augmentation
            inputs, targets_a, targets_b, lam = mixup(inputs, targets)
            outputs = model(inputs)
            loss = mixup_criterion(criterion, outputs, targets_a, targets_b, lam)
            mixed = True
        elif use_cutmix and r < 0.6:
            # CutMix augmentation
            inputs, targets_a, targets_b, lam = cutmix(inputs, targets)
            outputs = model(inputs)
            loss = mixup_criterion(criterion, outputs, targets_a, targets_b, lam)
            mixed = True
        else:
            # Normal training
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            mixed = False

        loss.backward()

        # Gradient clipping for stability
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        # Statistics. Use the real batch size: the final batch may be smaller,
        # and a sampler / drop_last can make the number of samples drawn differ
        # from len(dataloader.dataset).
        batch_size = targets.size(0)
        running_loss += loss.item() * batch_size
        samples_seen += batch_size
        if not mixed:
            # Keyed on the branch actually taken rather than on `r`, so plain
            # batches still count when an augmentation is disabled.
            preds = outputs.argmax(dim=1)
            running_corrects += (preds == targets).sum().item()
            scored_samples += batch_size

        # Update progress bar
        current_loss = running_loss / samples_seen
        current_acc = running_corrects / scored_samples if scored_samples else 0.0
        pbar.set_postfix({
            'Loss': f'{current_loss:.4f}',
            'Acc': f'{current_acc:.4f}',
            'GPU': f'{torch.cuda.memory_allocated()/1024**3:.1f}GB' if torch.cuda.is_available() else 'CPU'
        })

        # Drop references to the graph before the next batch allocates.
        del outputs, loss

        # Clear cache periodically
        if batch_idx % 10 == 0:
            torch.cuda.empty_cache()

    pbar.close()

    epoch_loss = running_loss / samples_seen if samples_seen else 0.0
    epoch_acc = running_corrects / scored_samples if scored_samples else 0.0

    # Final cache clear
    torch.cuda.empty_cache()

    return epoch_loss, epoch_acc

def validate_epoch(model, dataloader, criterion, device, epoch):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Args:
        model: Network to evaluate; switched to ``eval()`` mode.
        dataloader: Iterable yielding ``(inputs, targets, _)`` batches.
        criterion: Loss callable ``criterion(outputs, targets)``.
        device: Device the batches are moved to.
        epoch: Epoch number, used only in the progress-bar label.

    Returns:
        tuple: ``(epoch_loss, epoch_acc, auc, all_preds, all_labels, all_probs)``.
        Loss and accuracy are floats averaged over the samples seen. ``auc`` is
        the ROC AUC of the class-1 probability when exactly two distinct labels
        occur, else 0.0. The three lists hold per-sample predicted class, true
        label and softmax probability vector. An empty loader gives
        ``(0.0, 0.0, 0.0, [], [], [])``.

    Note:
        Relies on the module-level names ``tqdm``, ``F`` and ``roc_auc_score``.
    """
    model.eval()
    running_loss = 0.0
    running_corrects = 0
    samples_seen = 0
    all_preds = []
    all_labels = []
    all_probs = []

    # Progress bar
    pbar = tqdm(dataloader, desc=f'Epoch {epoch:3d} [Val]  ',
                leave=False, ncols=100, ascii=True)

    with torch.no_grad():
        for inputs, targets, _ in pbar:
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # Statistics, weighted by the real batch size (the last batch may
            # be smaller than dataloader.batch_size).
            batch_size = targets.size(0)
            running_loss += loss.item() * batch_size
            preds = outputs.argmax(dim=1)
            running_corrects += (preds == targets).sum().item()
            samples_seen += batch_size

            # Store predictions for metrics
            probs = F.softmax(outputs, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(targets.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())

            # Update progress bar
            pbar.set_postfix({
                'Loss': f'{running_loss / samples_seen:.4f}',
                'Acc': f'{running_corrects / samples_seen:.4f}'
            })

    pbar.close()

    epoch_loss = running_loss / samples_seen if samples_seen else 0.0
    epoch_acc = running_corrects / samples_seen if samples_seen else 0.0

    # AUC needs both classes present. Only the ValueError that roc_auc_score
    # raises for unusable input is mapped to 0.0; any other exception now
    # propagates instead of being swallowed by a bare `except`.
    auc = 0.0
    if len(np.unique(all_labels)) == 2:
        try:
            auc = roc_auc_score(all_labels, [p[1] for p in all_probs])
        except ValueError:
            auc = 0.0

    return epoch_loss, epoch_acc, auc, all_preds, all_labels, all_probs

def print_epoch_results(epoch, num_epochs, train_loss, train_acc, val_loss, val_acc, val_auc, 
                       lr, best_val_acc, time_elapsed):
    """Print a framed, human-readable summary of one finished epoch.

    Args:
        epoch: Zero-based epoch index (shown one-based).
        num_epochs: Total number of epochs, shown after the slash.
        train_loss, train_acc: Training loss and accuracy (fraction).
        val_loss, val_acc, val_auc: Validation loss, accuracy (fraction) and AUC.
        lr: Learning rate to report.
        best_val_acc: Best validation accuracy so far (fraction).
        time_elapsed: Epoch duration in seconds.

    A GPU memory line (currently allocated / peak, in GB) is added when CUDA
    is available.
    """
    rule = '=' * 80
    train_pct = train_acc * 100
    val_pct = val_acc * 100
    best_pct = best_val_acc * 100

    # Header frame.
    print(f"\n{rule}")
    print(f"EPOCH {epoch+1:3d}/{num_epochs} RESULTS")
    print(rule)

    # Run metadata, then a blank separator line.
    print(f"⏱️  Time: {time_elapsed:.1f}s")
    print(f"📚 Learning Rate: {lr:.8f}")
    print()

    # Metrics.
    print(f"🏋️  TRAINING   → Loss: {train_loss:.6f} | Acc: {train_acc:.4f} ({train_pct:.2f}%)")
    print(f"✅ VALIDATION → Loss: {val_loss:.6f} | Acc: {val_acc:.4f} ({val_pct:.2f}%) | AUC: {val_auc:.4f}")
    print(f"🏆 BEST VAL   → Acc: {best_val_acc:.4f} ({best_pct:.2f}%)")

    if torch.cuda.is_available():
        gib = 1024**3
        allocated = torch.cuda.memory_allocated() / gib
        peak = torch.cuda.max_memory_allocated() / gib
        print(f"🔧 GPU Memory: {allocated:.2f}GB / {peak:.2f}GB peak")

    print(rule)

def train_model_with_validation(model, train_loader, val_loader, criterion, optimizer, 
                               scheduler, num_epochs, device, fold=None, save_dir=None):
    """Run the full train/validate loop with logging, checkpoints and early stopping.

    Each epoch calls ``train_epoch`` and ``validate_epoch``, steps the scheduler,
    prints a report, appends a row to a CSV log and saves checkpoints. Training
    ends after ``num_epochs`` epochs or when ``EarlyStopping`` (patience taken
    from ``CONFIG['early_stopping_patience']``) fires. A final validation pass
    then produces the returned predictions and confusion matrix.

    Args:
        model: Network to train (already on ``device``).
        train_loader: Training loader passed to ``train_epoch``.
        val_loader: Validation loader passed to ``validate_epoch``.
        criterion: Loss callable.
        optimizer: Optimizer for ``model``.
        scheduler: LR scheduler or ``None``. ``ReduceLROnPlateau`` is stepped
            with the validation loss, anything else without arguments.
        num_epochs: Maximum number of epochs.
        device: Device used for training and validation.
        fold: Optional run identifier used in file names. Any value other than
            ``None`` or ``""`` counts, including ``0``.
        save_dir: Output directory; defaults to ``tmj_results/checkpoints``.

    Returns:
        tuple: ``(model, history)``. ``history`` holds the per-epoch metric
        lists, the final validation metrics, predictions, labels,
        probabilities, confusion matrix, ``save_dir`` and ``model_name``.
    """

    import time

    # Setup logging and checkpointing
    if save_dir is None:
        save_dir = Path("tmj_results") / "checkpoints"
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    # `if fold` treated fold 0 as "no fold", which made it share file names with
    # the un-named run; test for absence explicitly instead.
    has_fold = fold is not None and fold != ""
    model_name = f"tmj_model_fold{fold}" if has_fold else "tmj_model"

    # Initialize loggers
    csv_logger = CSVLogger(save_dir / f"{model_name}_training_log.csv")
    checkpoint_manager = ModelCheckpoint(save_dir, model_name)
    early_stopping = EarlyStopping(patience=CONFIG['early_stopping_patience'])

    # Initialize tracking
    train_losses, train_accs = [], []
    val_losses, val_accs, val_aucs = [], [], []

    # Best model tracking
    best_val_acc = 0.0

    print(f"\n🚀 Starting Training {'for ' + str(fold) if has_fold else ''}")
    print(f"📁 Logs and models will be saved to: {save_dir}")
    print(f"📊 CSV log: {model_name}_training_log.csv")
    print("="*80)

    for epoch in range(num_epochs):
        epoch_start_time = time.time()

        # Training phase
        train_loss, train_acc = train_epoch(
            model, train_loader, criterion, optimizer, device, epoch
        )

        # Validation phase (per-sample outputs are only needed after training)
        val_loss, val_acc, val_auc, _, _, _ = validate_epoch(
            model, val_loader, criterion, device, epoch
        )

        # Read the LR before stepping the scheduler so the report and the CSV
        # show the rate this epoch trained with, not the next epoch's.
        current_lr = optimizer.param_groups[0]['lr']

        # Learning rate scheduling (a missing scheduler is allowed, matching
        # ModelCheckpoint, which already tolerates scheduler=None)
        if scheduler is not None:
            if isinstance(scheduler, ReduceLROnPlateau):
                scheduler.step(val_loss)
            else:
                scheduler.step()

        # Update best accuracy
        is_best = val_acc > best_val_acc
        if is_best:
            best_val_acc = val_acc
        # Keep the value stored inside each checkpoint in sync with the loop.
        checkpoint_manager.best_val_acc = best_val_acc

        # Store metrics
        train_losses.append(train_loss)
        train_accs.append(train_acc)
        val_losses.append(val_loss)
        val_accs.append(val_acc)
        val_aucs.append(val_auc)

        # Calculate epoch time
        epoch_time = time.time() - epoch_start_time

        # Print detailed results
        print_epoch_results(epoch, num_epochs, train_loss, train_acc, val_loss, val_acc, 
                          val_auc, current_lr, best_val_acc, epoch_time)

        # Log to CSV
        csv_logger.log_epoch(epoch, train_loss, train_acc, val_loss, val_acc, val_auc, 
                           current_lr, best_val_acc)

        # Save checkpoints
        metrics = {
            'train_loss': train_loss,
            'train_acc': train_acc,
            'val_loss': val_loss,
            'val_acc': val_acc,
            'val_auc': val_auc
        }
        is_final_epoch = epoch == num_epochs - 1

        checkpoint_manager.save_checkpoint(
            model, optimizer, scheduler, epoch, metrics, 
            is_best=is_best, is_last=is_final_epoch
        )

        # Early stopping check
        if early_stopping(val_loss, model):
            print(f'\n🛑 Early stopping triggered at epoch {epoch+1}')
            print(f"   Best validation accuracy: {best_val_acc:.4f}")
            if not is_final_epoch:
                # The in-loop "last" save only fires on the final scheduled
                # epoch, so an early-stopped run used to leave no *_last.pt.
                # This one holds the state returned to the caller (the weights
                # EarlyStopping may just have reloaded).
                checkpoint_manager.save_checkpoint(
                    model, optimizer, scheduler, epoch, metrics, is_last=True
                )
            break

    # Final validation metrics
    final_val_loss, final_val_acc, final_val_auc, final_preds, final_labels, final_probs = validate_epoch(
        model, val_loader, criterion, device, num_epochs
    )

    # Confusion matrix
    cm = confusion_matrix(final_labels, final_preds)

    print("\n🏁 TRAINING COMPLETED!")
    print("📊 Final Results:")
    print(f"   Best Validation Accuracy: {best_val_acc:.4f}")
    print(f"   Final Validation Accuracy: {final_val_acc:.4f}")
    print(f"   Final Validation AUC: {final_val_auc:.4f}")
    print(f"📁 All logs and models saved to: {save_dir}")

    history = {
        'train_losses': train_losses,
        'train_accs': train_accs,
        'val_losses': val_losses,
        'val_accs': val_accs,
        'val_aucs': val_aucs,
        'final_val_acc': final_val_acc,
        'final_val_auc': final_val_auc,
        'best_val_acc': best_val_acc,
        'confusion_matrix': cm,
        'val_predictions': final_preds,
        'val_labels': final_labels,
        'val_probabilities': final_probs,
        'save_dir': save_dir,
        'model_name': model_name
    }

    return model, history

def create_optimizer_and_scheduler(model, train_loader):
    """Build the AdamW optimizer and the warm-restart cosine LR schedule.

    Hyper-parameters come from the module-level ``CONFIG`` dict
    (``learning_rate``, ``weight_decay``, ``num_epochs``).

    Args:
        model: Module whose parameters are optimised.
        train_loader: Unused; kept so existing call sites keep working.

    Returns:
        tuple: ``(optimizer, scheduler)``. The scheduler anneals the learning
        rate from ``learning_rate`` towards 1% of it and jumps back to
        ``learning_rate`` every ``num_epochs // 4`` epochs (at least 1).
    """

    # AdamW optimizer with weight decay
    optimizer = optim.AdamW(
        model.parameters(),
        lr=CONFIG['learning_rate'],
        weight_decay=CONFIG['weight_decay'],
        betas=(0.9, 0.999)
    )

    # Cosine annealing with warm restarts.
    # The previous CosineAnnealingLR(T_max=num_epochs // 4) never restarts: past
    # T_max it climbs back up along the cosine, so the restart the comments
    # described did not happen. CosineAnnealingWarmRestarts does. The max()
    # guards against a zero-length period when num_epochs < 4.
    restart_period = max(1, CONFIG['num_epochs'] // 4)  # Restart every 1/4 of total epochs
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer,
        T_0=restart_period,
        eta_min=CONFIG['learning_rate'] * 0.01
    )

    return optimizer, scheduler

# Static end-of-cell banner: the text is hard-coded and does not check which of
# the helpers above were actually defined.
print("Enhanced training functions with comprehensive logging:")
print("- 📊 CSV logging for all metrics")
print("- 💾 Automatic best and last model saving")
print("- 📈 Real-time progress bars for batches")
print("- 🖨️  Detailed epoch results printing")
print("- 🔧 GPU memory monitoring")
print("- ⏱️  Timing information")
print("- 🛑 Enhanced early stopping")
print("- 📁 Organized checkpoint management")

Enhanced training functions with comprehensive logging:
- 📊 CSV logging for all metrics
- 💾 Automatic best and last model saving
- 📈 Real-time progress bars for batches
- 🖨️  Detailed epoch results printing
- 🔧 GPU memory monitoring
- ⏱️  Timing information
- 🛑 Enhanced early stopping
- 📁 Organized checkpoint management


In [9]:
# Block 7: Training and Evaluation Functions for Pre-Split Data

from utils.model import FocalLoss

def train_single_model(train_loader, val_loader, config):
    """Build and train one ResNet-18 model on the pre-split data.

    Args:
        train_loader: Training loader.
        val_loader: Validation loader.
        config: Mapping providing ``'num_classes'`` and ``'num_epochs'``.

    Returns:
        tuple: ``(trained_model, history)`` exactly as returned by
        ``train_model_with_validation``.

    Note:
        Uses the module-level ``device`` and trains with
        ``FocalLoss(alpha=0.25, gamma=2.0)``.
    """
    print("Training single model on pre-split data...")
    print("-" * 50)

    # Fresh network on the target device.
    network = create_medical_resnet3d(
        arch='resnet18',
        num_classes=config['num_classes']
    )
    network = network.to(device)

    # Optimisation pieces: optimizer + scheduler, then the focal loss.
    optimizer, scheduler = create_optimizer_and_scheduler(network, train_loader)
    focal_loss = FocalLoss(alpha=0.25, gamma=2.0)

    # Hand everything to the shared training loop.
    fitted_model, history = train_model_with_validation(
        network, train_loader, val_loader, focal_loss, optimizer,
        scheduler, config['num_epochs'], device
    )

    print("Single model training completed!")
    print(f"  Final Val Accuracy: {history['final_val_acc']:.4f}")
    print(f"  Final Val AUC: {history['final_val_auc']:.4f}")

    return fitted_model, history

def train_ensemble_models(train_loader, val_loader, config, num_models=5):
    """Train several independently seeded models on the same pre-split data.

    Member ``i`` (zero-based) is seeded with ``42 + 10 * i`` via ``set_seed``,
    built as a fresh ResNet-18, and trained with
    ``train_model_with_validation`` using ``fold="Model_<i+1>"`` so each member
    gets its own log and checkpoint files.

    Args:
        train_loader: Training loader shared by all members.
        val_loader: Validation loader shared by all members.
        config: Mapping providing ``'num_classes'`` and ``'num_epochs'``.
        num_models: Number of ensemble members to train.

    Returns:
        tuple: ``(ensemble_models, ensemble_results)``: a list of per-member
        weight snapshots (name -> CPU tensor, loadable with
        ``load_state_dict``) and the list of per-member training histories.

    Note:
        Uses the module-level ``device`` and reseeds the global RNGs.
    """

    print(f"Training ensemble of {num_models} models on pre-split data...")
    print("=" * 60)

    ensemble_models = []
    ensemble_results = []

    for model_idx in range(num_models):
        print(f"\nTraining Ensemble Model {model_idx + 1}/{num_models}")
        print("-" * 40)

        # Set different random seed for each model
        set_seed(42 + model_idx * 10)

        # Create model with different initialization
        model = create_medical_resnet3d(
            arch='resnet18',
            num_classes=config['num_classes']
        ).to(device)

        # Create optimizer and scheduler
        optimizer, scheduler = create_optimizer_and_scheduler(model, train_loader)

        # Create loss function
        criterion = FocalLoss(alpha=0.25, gamma=2.0)

        # Train the model
        trained_model, history = train_model_with_validation(
            model, train_loader, val_loader, criterion, optimizer,
            scheduler, config['num_epochs'], device, fold=f"Model_{model_idx+1}"
        )

        # Store an independent CPU snapshot of the weights.
        # `state_dict().copy()` is a shallow copy: its tensors share storage
        # with the live parameters, so it kept every member's full weight set
        # on the training device and followed any later in-place change.
        snapshot = {
            name: value.detach().cpu().clone() if torch.is_tensor(value) else value
            for name, value in trained_model.state_dict().items()
        }
        ensemble_models.append(snapshot)
        ensemble_results.append(history)

        print(f"Model {model_idx + 1} - Val Accuracy: {history['final_val_acc']:.4f}")

    # Calculate ensemble statistics
    val_accs = [result['final_val_acc'] for result in ensemble_results]
    val_aucs = [result['final_val_auc'] for result in ensemble_results]

    print("\nEnsemble Training Results:")
    print(f"Individual accuracies: {[f'{acc:.4f}' for acc in val_accs]}")
    print(f"Individual AUCs: {[f'{auc:.4f}' for auc in val_aucs]}")
    print(f"Mean accuracy: {np.mean(val_accs):.4f} ± {np.std(val_accs):.4f}")
    print(f"Mean AUC: {np.mean(val_aucs):.4f} ± {np.std(val_aucs):.4f}")

    return ensemble_models, ensemble_results

def evaluate_ensemble(ensemble_models, test_loader, device):
    """Evaluate an ensemble by averaging the members' softmax probabilities.

    One ResNet-18 skeleton is built and each member's weights are loaded into
    it in turn. The per-sample class probabilities of all members are averaged
    and the arg-max of the average is the ensemble prediction.

    Args:
        ensemble_models: Non-empty sequence of state dicts for
            ``create_medical_resnet3d(arch='resnet18', ...)``.
        test_loader: Iterable yielding ``(inputs, targets, _)`` batches. It must
            yield the samples in the same order on every pass, because labels
            are taken from the first pass only.
        device: Device used for inference.

    Returns:
        dict: ``ensemble_accuracy``, ``ensemble_auc``, ``individual_accuracies``,
        ``individual_aucs``, ``ensemble_predictions``, ``ensemble_probabilities``,
        ``labels`` and ``confusion_matrix``. AUC values are 0.0 when there is
        only one output class or when ``roc_auc_score`` rejects the input.

    Raises:
        ValueError: If ``ensemble_models`` is empty or ``test_loader`` yields no
            samples.
    """
    if len(ensemble_models) == 0:
        raise ValueError("evaluate_ensemble() needs at least one set of model weights")

    def _binary_auc(labels, probs):
        """ROC AUC of the class-1 column; 0.0 when it cannot be computed."""
        if probs.shape[1] <= 1:
            return 0.0
        try:
            return roc_auc_score(labels, probs[:, 1])
        except ValueError:
            # Only the error roc_auc_score raises for unusable input is mapped
            # to 0.0; anything else propagates (this used to be a bare except).
            return 0.0

    print("Evaluating ensemble on test set...")

    # Create a model architecture for loading weights
    base_model = create_medical_resnet3d(
        arch='resnet18',
        num_classes=CONFIG['num_classes']
    ).to(device)

    all_ensemble_probs = []
    all_labels = []

    # Get predictions from each model
    for model_idx, model_weights in enumerate(ensemble_models):
        print(f"Getting predictions from model {model_idx + 1}/{len(ensemble_models)}...")

        base_model.load_state_dict(model_weights)
        base_model.eval()

        model_probs = []

        with torch.no_grad():
            for inputs, targets, _ in test_loader:
                inputs, targets = inputs.to(device), targets.to(device)

                outputs = base_model(inputs)
                probs = F.softmax(outputs, dim=1)

                model_probs.extend(probs.cpu().numpy())
                if model_idx == 0:  # Only store labels once
                    all_labels.extend(targets.cpu().numpy())

        if not model_probs:
            raise ValueError("test_loader produced no samples to evaluate")
        all_ensemble_probs.append(np.array(model_probs))

    # Average ensemble predictions
    ensemble_probs = np.mean(all_ensemble_probs, axis=0)
    ensemble_preds = np.argmax(ensemble_probs, axis=1)

    # Calculate metrics
    ensemble_acc = accuracy_score(all_labels, ensemble_preds)
    ensemble_auc = _binary_auc(all_labels, ensemble_probs)

    # Individual model metrics
    individual_accs = []
    individual_aucs = []
    for model_probs in all_ensemble_probs:
        model_preds = np.argmax(model_probs, axis=1)
        individual_accs.append(accuracy_score(all_labels, model_preds))
        individual_aucs.append(_binary_auc(all_labels, model_probs))

    print("\n🎯 ENSEMBLE EVALUATION RESULTS:")
    print("=" * 50)
    print(f"Individual model accuracies: {[f'{acc:.4f}' for acc in individual_accs]}")
    print(f"Individual model AUCs: {[f'{auc:.4f}' for auc in individual_aucs]}")
    print(f"Ensemble accuracy: {ensemble_acc:.4f}")
    print(f"Ensemble AUC: {ensemble_auc:.4f}")
    print(f"Best individual accuracy: {max(individual_accs):.4f}")
    print(f"Improvement over best individual: {ensemble_acc - max(individual_accs):+.4f}")

    # Confusion matrix
    cm = confusion_matrix(all_labels, ensemble_preds)
    print("\nConfusion Matrix:")
    print(cm)

    return {
        'ensemble_accuracy': ensemble_acc,
        'ensemble_auc': ensemble_auc,
        'individual_accuracies': individual_accs,
        'individual_aucs': individual_aucs,
        'ensemble_predictions': ensemble_preds,
        'ensemble_probabilities': ensemble_probs,
        'labels': all_labels,
        'confusion_matrix': cm
    }

def evaluate_single_model(model, test_loader, device):
    """Evaluate one trained model on the test loader.

    Args:
        model: Network to evaluate; switched to ``eval()`` mode.
        test_loader: Iterable yielding ``(inputs, targets, _)`` batches.
        device: Device the batches are moved to.

    Returns:
        dict: ``accuracy``, ``auc``, ``predictions``, ``probabilities``,
        ``labels`` and ``confusion_matrix``. ``auc`` is the ROC AUC of the
        class-1 probability; it is 0.0 when no samples were seen, when the
        model has a single output class, or when ``roc_auc_score`` rejects the
        input.
    """

    print("Evaluating single model on test set...")

    model.eval()
    all_preds = []
    all_probs = []
    all_labels = []

    with torch.no_grad():
        for inputs, targets, _ in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = model(inputs)
            probs = F.softmax(outputs, dim=1)
            preds = torch.argmax(outputs, dim=1)

            all_preds.extend(preds.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())
            all_labels.extend(targets.cpu().numpy())

    # Calculate metrics
    accuracy = accuracy_score(all_labels, all_preds)

    # The empty and single-class cases are handled explicitly, and only the
    # ValueError raised by roc_auc_score is mapped to 0.0. The previous bare
    # `except` hid every other failure behind an AUC of 0.0.
    auc = 0.0
    if all_probs and len(all_probs[0]) > 1:
        try:
            auc = roc_auc_score(all_labels, [p[1] for p in all_probs])
        except ValueError:
            auc = 0.0

    # Confusion matrix
    cm = confusion_matrix(all_labels, all_preds)

    print("\n🎯 SINGLE MODEL TEST RESULTS:")
    print("=" * 40)
    print(f"Test Accuracy: {accuracy:.4f}")
    print(f"Test AUC: {auc:.4f}")
    print("Confusion Matrix:")
    print(cm)

    return {
        'accuracy': accuracy,
        'auc': auc,
        'predictions': all_preds,
        'probabilities': all_probs,
        'labels': all_labels,
        'confusion_matrix': cm
    }

def plot_training_history(histories, title_prefix="", save_path=None):
    """Draw loss and accuracy curves for one or more training runs.

    A 2x2 grid is produced: training loss, validation loss, training accuracy
    and validation accuracy. Every history contributes one line per panel,
    labelled ``Model <n>``.

    Args:
        histories: Iterable of dicts with the keys ``'train_losses'``,
            ``'val_losses'``, ``'train_accs'`` and ``'val_accs'``.
        title_prefix: Text prepended to each panel title.
        save_path: If truthy, the figure is also saved there at 300 dpi.
    """

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # (axis, history key, legend suffix, panel title, y-axis label)
    panels = (
        (axes[0, 0], 'train_losses', 'Train', 'Training Loss', 'Loss'),
        (axes[0, 1], 'val_losses', 'Val', 'Validation Loss', 'Loss'),
        (axes[1, 0], 'train_accs', 'Train', 'Training Accuracy', 'Accuracy'),
        (axes[1, 1], 'val_accs', 'Val', 'Validation Accuracy', 'Accuracy'),
    )

    # One curve per run in every panel.
    for run_number, history in enumerate(histories, start=1):
        for ax, key, split, _, _ in panels:
            ax.plot(history[key], label=f'Model {run_number} {split}', alpha=0.7)

    # Shared decoration for all four panels.
    for ax, _, _, heading, y_label in panels:
        ax.set_title(f'{title_prefix}{heading}')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Training curves saved to {save_path}")

    plt.show()

def plot_confusion_matrix(cm, class_names=['Healthy', 'Osteoarthritis'], title='Confusion Matrix', save_path=None):
    """Render a confusion matrix as an annotated heat map.

    Args:
        cm: Confusion matrix of integer counts (rows shown as "Actual",
            columns as "Predicted").
        class_names: Tick labels used on both axes.
        title: Figure title.
        save_path: If truthy, the figure is also saved there at 300 dpi.
    """

    plt.figure(figsize=(8, 6))

    # Integer-annotated cells; the same labels go on both axes.
    heatmap_options = dict(
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=class_names,
        yticklabels=class_names,
    )
    sns.heatmap(cm, **heatmap_options)

    plt.title(title)
    plt.xlabel('Predicted')
    plt.ylabel('Actual')

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Confusion matrix saved to {save_path}")

    plt.show()

def comprehensive_model_evaluation(model_or_models, test_loader, device, is_ensemble=False):
    """Evaluate on the test set and plot the resulting confusion matrix.

    Args:
        model_or_models: A single model, or the collection of ensemble weights
            when ``is_ensemble`` is true.
        test_loader: Test loader handed to the evaluator.
        device: Device handed to the evaluator.
        is_ensemble: Use ``evaluate_ensemble`` instead of
            ``evaluate_single_model``.

    Returns:
        dict: The evaluator's result dictionary, unchanged.
    """

    # Choose the evaluator and the matching figure title up front.
    if is_ensemble:
        evaluator = evaluate_ensemble
        heading = 'Ensemble Model - Test Set Confusion Matrix'
    else:
        evaluator = evaluate_single_model
        heading = 'Single Model - Test Set Confusion Matrix'

    results = evaluator(model_or_models, test_loader, device)
    plot_confusion_matrix(results['confusion_matrix'], title=heading)

    return results

# Static end-of-cell banner listing the helpers defined in this cell; the text
# is hard-coded and not derived from the definitions themselves.
print("Training and evaluation functions updated for pre-split data:")
print("- train_single_model(): Train one model on pre-split data")
print("- train_ensemble_models(): Train multiple models with different seeds")
print("- evaluate_ensemble(): Comprehensive ensemble evaluation")
print("- evaluate_single_model(): Single model evaluation") 
print("- comprehensive_model_evaluation(): Evaluation with visualizations")
print("- Enhanced plotting functions with save options")

Training and evaluation functions updated for pre-split data:
- train_single_model(): Train one model on pre-split data
- train_ensemble_models(): Train multiple models with different seeds
- evaluate_ensemble(): Comprehensive ensemble evaluation
- evaluate_single_model(): Single model evaluation
- comprehensive_model_evaluation(): Evaluation with visualizations
- Enhanced plotting functions with save options


In [7]:
# Block 8: Main Training Execution for Pre-Split Data

def main_training_pipeline():
    """Main pipeline for TMJ classification model training with pre-split data.

    Steps:
        1. Validate the data directory and build the train/val/test loaders.
        2. Ask interactively which strategy to run: ``a`` (single model),
           ``b`` (ensemble) or ``c`` (both). Empty or unrecognized input falls
           back to the advertised default ``a``.
        3. Train the selected model(s).
        4. Evaluate on the test loader and plot the training curves.
        5. Save everything through ``save_results``.

    Returns:
        dict | None: The results dict for the chosen strategy (its
        ``'strategy'`` entry is ``'single_model'``, ``'ensemble'`` or
        ``'both'``), or ``None`` when validation or data loading failed.
    """

    print("Starting TMJ 3D CNN Training Pipeline with Pre-Split Data")
    print("=" * 70)

    # Step 1: Validate data structure and load datasets
    print("\n1. Validating data structure and loading datasets...")
    try:
        if not validate_data_structure(CONFIG['data_path']):
            print("❌ Data structure validation failed. Please fix and try again.")
            return None

        # Create all dataloaders
        train_loader, val_loader, test_loader, train_dataset, val_dataset, test_dataset = create_dataloaders(
            CONFIG['data_path'],
            CONFIG,
            use_balanced_sampling=True
        )

        print("✅ All datasets loaded successfully!")

        # Display dataset statistics
        print(f"\nDataset Statistics:")
        print(f"  Training:   {len(train_dataset)} samples")
        print(f"  Validation: {len(val_dataset)} samples")
        print(f"  Test:       {len(test_dataset)} samples")
        print(f"  Total:      {len(train_dataset) + len(val_dataset) + len(test_dataset)} samples")

    except Exception as e:
        print(f"❌ Error loading datasets: {e}")
        return None

    # Step 2: Choose training strategy
    print("\n2. Training Strategy Selection:")
    print("   a) Single model training (fastest, good for initial experiments)")
    print("   b) Ensemble training (best performance, multiple models)")
    print("   c) Both single and ensemble (comprehensive comparison)")

    strategy = input("Choose strategy (a/b/c) [default: a]: ").lower().strip()
    if strategy not in ('a', 'b', 'c'):
        # A typo must not silently trigger the most expensive run (single +
        # ensemble); use the default announced in the prompt instead.
        if strategy:
            print(f"⚠️  Unrecognized strategy '{strategy}', falling back to default 'a'.")
        strategy = 'a'

    results = {}

    if strategy == 'a':
        # Single model training
        print("\n3. Training Single Model...")
        print("=" * 50)

        model, history = train_single_model(train_loader, val_loader, CONFIG)

        # Evaluate on test set
        print("\n4. Evaluating on Test Set...")
        test_results = comprehensive_model_evaluation(
            model, test_loader, device, is_ensemble=False
        )

        # Plot training curves
        plot_training_history([history], "Single Model ")

        # Store results
        results = {
            'strategy': 'single_model',
            'model': model.state_dict(),
            'history': history,
            'test_results': test_results
        }

        print(f"\n🎯 FINAL RESULTS - Single Model:")
        print(f"  Validation Accuracy: {history['final_val_acc']:.4f}")
        print(f"  Test Accuracy: {test_results['accuracy']:.4f}")
        print(f"  Test AUC: {test_results['auc']:.4f}")

    elif strategy == 'b':
        # Ensemble training
        print("\n3. Training Ensemble Models...")
        print("=" * 50)

        ensemble_models, ensemble_histories = train_ensemble_models(
            train_loader, val_loader, CONFIG, num_models=CONFIG['ensemble_size']
        )

        # Evaluate ensemble on test set
        print("\n4. Evaluating Ensemble on Test Set...")
        test_results = comprehensive_model_evaluation(
            ensemble_models, test_loader, device, is_ensemble=True
        )

        # Plot training curves
        plot_training_history(ensemble_histories, "Ensemble ")

        # Store results
        results = {
            'strategy': 'ensemble',
            'models': ensemble_models,
            'histories': ensemble_histories,
            'test_results': test_results
        }

        print(f"\n🎯 FINAL RESULTS - Ensemble:")
        val_accs = [h['final_val_acc'] for h in ensemble_histories]
        print(f"  Validation Accuracy: {np.mean(val_accs):.4f} ± {np.std(val_accs):.4f}")
        print(f"  Test Accuracy: {test_results['ensemble_accuracy']:.4f}")
        print(f"  Test AUC: {test_results['ensemble_auc']:.4f}")

    else:
        # Both single and ensemble (strategy == 'c')
        print("\n3. Training Both Single Model and Ensemble...")
        print("=" * 60)

        # Train single model
        print("\n3a. Training Single Model...")
        single_model, single_history = train_single_model(train_loader, val_loader, CONFIG)

        # Train ensemble
        print("\n3b. Training Ensemble Models...")
        ensemble_models, ensemble_histories = train_ensemble_models(
            train_loader, val_loader, CONFIG, num_models=CONFIG['ensemble_size']
        )

        # Evaluate both on test set
        print("\n4. Evaluating Both Approaches on Test Set...")

        print("\n4a. Single Model Test Results:")
        single_test_results = comprehensive_model_evaluation(
            single_model, test_loader, device, is_ensemble=False
        )

        print("\n4b. Ensemble Test Results:")
        ensemble_test_results = comprehensive_model_evaluation(
            ensemble_models, test_loader, device, is_ensemble=True
        )

        # Plot training curves
        plot_training_history([single_history], "Single Model ")
        plot_training_history(ensemble_histories, "Ensemble ")

        # Store results
        results = {
            'strategy': 'both',
            'single_model': single_model.state_dict(),
            'single_history': single_history,
            'single_test_results': single_test_results,
            'ensemble_models': ensemble_models,
            'ensemble_histories': ensemble_histories,
            'ensemble_test_results': ensemble_test_results
        }

        print(f"\n🎯 FINAL COMPARISON:")
        print("=" * 40)
        print(f"Single Model:")
        print(f"  Test Accuracy: {single_test_results['accuracy']:.4f}")
        print(f"  Test AUC: {single_test_results['auc']:.4f}")
        print(f"Ensemble Model:")
        print(f"  Test Accuracy: {ensemble_test_results['ensemble_accuracy']:.4f}")
        print(f"  Test AUC: {ensemble_test_results['ensemble_auc']:.4f}")
        print(f"Ensemble Improvement: {ensemble_test_results['ensemble_accuracy'] - single_test_results['accuracy']:+.4f}")

    # Step 5: Save results
    print(f"\n5. Saving Results...")
    # save_results() branches on 'single_model' / 'ensemble' / 'both', not on
    # the menu letter, so hand it the canonical name stored in the results.
    # Passing the raw letter sent strategies 'a' and 'b' down the "both"
    # branch, which raised KeyError after training had already finished.
    save_results(results, results['strategy'])

    return results

def save_results(results, strategy_name):
    """Save training results, model weights and a text report under ``tmj_results/``.

    Three kinds of files are written, all stamped with the current time: a
    pickle holding ``results`` and ``CONFIG``, one or two ``.pt`` files
    written with ``torch.save``, and a plain-text summary report.

    Args:
        results: Results dict for one strategy. The keys that are read depend
            on the strategy: ``'model'``/``'history'``/``'test_results'`` for
            a single model, ``'models'``/``'histories'``/``'test_results'``
            for an ensemble, and the ``single_*``/``ensemble_*`` keys for both.
        strategy_name: ``'single_model'``, ``'ensemble'`` or ``'both'``. The
            menu letters ``'a'``/``'b'``/``'c'`` are accepted as aliases; any
            other value falls back to ``results['strategy']`` (or ``'both'``
            when that key is missing).
    """

    import pickle
    from datetime import datetime

    # Normalize the strategy so the branches below always match the keys that
    # are actually present in `results`. Without this, a menu letter such as
    # 'a' reached the "both" branch and failed with KeyError: 'single_model'.
    strategy_aliases = {'a': 'single_model', 'b': 'ensemble', 'c': 'both'}
    strategy_name = strategy_aliases.get(strategy_name, strategy_name)
    if strategy_name not in ('single_model', 'ensemble', 'both'):
        strategy_name = results.get('strategy', 'both')

    # Create results directory
    results_dir = Path("tmj_results")
    results_dir.mkdir(exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Save training results
    results_file = results_dir / f"{strategy_name}_results_{timestamp}.pkl"
    with open(results_file, 'wb') as f:
        pickle.dump({
            'results': results,
            'config': CONFIG,
            'strategy': strategy_name,
            'timestamp': timestamp
        }, f)

    # Save model weights; `models_label` is what gets reported to the user
    # (always file names, matching the other two summary lines).
    if strategy_name == 'single_model':
        models_file = results_dir / f"single_model_{timestamp}.pt"
        torch.save(results['model'], models_file)
        models_label = models_file.name
    elif strategy_name == 'ensemble':
        models_file = results_dir / f"ensemble_models_{timestamp}.pt"
        torch.save(results['models'], models_file)
        models_label = models_file.name
    else:  # both
        single_file = results_dir / f"single_model_{timestamp}.pt"
        ensemble_file = results_dir / f"ensemble_models_{timestamp}.pt"
        torch.save(results['single_model'], single_file)
        torch.save(results['ensemble_models'], ensemble_file)
        models_label = f"{single_file.name} & {ensemble_file.name}"

    # Save detailed report. The encoding is explicit because the report
    # contains "±", which the platform default codec may not be able to encode.
    report_file = results_dir / f"{strategy_name}_report_{timestamp}.txt"
    with open(report_file, 'w', encoding='utf-8') as f:
        f.write(f"TMJ 3D CNN Training Report - {strategy_name.upper()}\n")
        f.write("=" * 70 + "\n")
        f.write(f"Timestamp: {timestamp}\n")
        f.write(f"Strategy: {strategy_name}\n\n")

        f.write("Configuration:\n")
        f.write("-" * 20 + "\n")
        for key, value in CONFIG.items():
            f.write(f"  {key}: {value}\n")
        f.write("\n")

        if strategy_name == "single_model":
            history = results['history']
            test_results = results['test_results']
            f.write("Single Model Results:\n")
            f.write("-" * 25 + "\n")
            f.write(f"  Validation Accuracy: {history['final_val_acc']:.4f}\n")
            f.write(f"  Validation AUC: {history['final_val_auc']:.4f}\n")
            f.write(f"  Test Accuracy: {test_results['accuracy']:.4f}\n")
            f.write(f"  Test AUC: {test_results['auc']:.4f}\n")

        elif strategy_name == "ensemble":
            test_results = results['test_results']
            val_accs = [h['final_val_acc'] for h in results['histories']]
            val_aucs = [h['final_val_auc'] for h in results['histories']]
            individual_accs = [f'{acc:.4f}' for acc in test_results['individual_accuracies']]
            f.write("Ensemble Results:\n")
            f.write("-" * 20 + "\n")
            f.write(f"  Validation Accuracy: {np.mean(val_accs):.4f} ± {np.std(val_accs):.4f}\n")
            f.write(f"  Validation AUC: {np.mean(val_aucs):.4f} ± {np.std(val_aucs):.4f}\n")
            f.write(f"  Test Accuracy: {test_results['ensemble_accuracy']:.4f}\n")
            f.write(f"  Test AUC: {test_results['ensemble_auc']:.4f}\n")
            f.write(f"  Individual Test Accuracies: {individual_accs}\n")

        else:  # both
            single_test = results['single_test_results']
            ensemble_test = results['ensemble_test_results']
            f.write("Comparison Results:\n")
            f.write("-" * 22 + "\n")
            f.write("Single Model:\n")
            f.write(f"  Test Accuracy: {single_test['accuracy']:.4f}\n")
            f.write(f"  Test AUC: {single_test['auc']:.4f}\n")
            f.write("Ensemble Model:\n")
            f.write(f"  Test Accuracy: {ensemble_test['ensemble_accuracy']:.4f}\n")
            f.write(f"  Test AUC: {ensemble_test['ensemble_auc']:.4f}\n")
            improvement = ensemble_test['ensemble_accuracy'] - single_test['accuracy']
            f.write(f"  Ensemble Improvement: {improvement:+.4f}\n")

    print(f"✅ Results saved to: {results_dir}")
    print(f"  - Training results: {results_file.name}")
    print(f"  - Model weights: {models_label}")
    print(f"  - Summary report: {report_file.name}")

def load_and_inference(model_path, test_loader, model_type='single'):
    """Load a saved checkpoint and evaluate it on ``test_loader``.

    Args:
        model_path: File passed to ``torch.load`` (mapped onto the global
            ``device``).
        test_loader: Loader handed to the evaluation function.
        model_type: ``'ensemble'`` treats the loaded object as the collection
            of models expected by ``evaluate_ensemble``; any other value
            treats it as a state dict for a freshly built ``resnet18``.

    Returns:
        The result of ``evaluate_ensemble`` or ``evaluate_single_model``.
    """
    print(f"Loading {model_type} model for inference...")

    # Both checkpoint kinds are read the same way; only the handling differs.
    checkpoint = torch.load(model_path, map_location=device)

    if model_type == 'ensemble':
        print(f"Loaded ensemble of {len(checkpoint)} models")
        return evaluate_ensemble(checkpoint, test_loader, device)

    print("Loaded single model")

    # Rebuild the architecture, then restore the saved weights into it.
    network = create_medical_resnet3d(
        arch='resnet18',
        num_classes=CONFIG['num_classes']
    ).to(device)
    network.load_state_dict(checkpoint)

    return evaluate_single_model(network, test_loader, device)

def quick_train():
    """Run a short 5-epoch single-model training as a smoke test.

    ``CONFIG['num_epochs']`` is overridden for the duration of the call and
    restored afterwards, even if loading, training or evaluation raises.

    Returns:
        tuple: ``(model, history, test_results)`` from ``train_single_model``
        and ``evaluate_single_model``.
    """
    print("🚀 Quick Training Mode (reduced epochs for testing)")

    # Remember the configured epoch count so it can be put back afterwards.
    configured_epochs = CONFIG['num_epochs']
    CONFIG['num_epochs'] = 5

    try:
        # Only the three loaders are needed; the dataset objects are ignored.
        loaders_and_datasets = create_dataloaders(
            CONFIG['data_path'], CONFIG, use_balanced_sampling=True
        )
        train_loader, val_loader, test_loader, _, _, _ = loaders_and_datasets

        trained = train_single_model(train_loader, val_loader, CONFIG)
        model, history = trained

        test_results = evaluate_single_model(model, test_loader, device)

        print(f"✅ Quick training completed!")
        print(f"  Test Accuracy: {test_results['accuracy']:.4f}")

        return model, history, test_results

    finally:
        CONFIG['num_epochs'] = configured_epochs

def validate_training_setup():
    """Validate that everything is set up correctly for training.

    Runs six independent checks and prints one status line for each:
    data structure, GPU availability (informational, always counted as
    passed), loading one training batch, model creation, a forward pass on a
    random input, and write access to the ``tmj_results`` directory.

    Returns:
        bool: ``True`` when all six checks passed, ``False`` otherwise.
    """

    print("🔍 Validating Training Setup...")
    print("-" * 40)

    checks_passed = 0
    total_checks = 6

    # Check 1: Data structure
    if validate_data_structure(CONFIG['data_path']):
        print("✅ Data structure validation passed")
        checks_passed += 1
    else:
        print("❌ Data structure validation failed")

    # Check 2: GPU availability (a missing GPU is only a warning, not a failure)
    if torch.cuda.is_available():
        print(f"✅ GPU available: {torch.cuda.get_device_name()}")
        checks_passed += 1
    else:
        print("⚠️  GPU not available, will use CPU (slower)")
        checks_passed += 1

    # Check 3: Try loading a small batch
    try:
        train_loader, _, _, _, _, _ = create_dataloaders(CONFIG['data_path'], CONFIG)
        sample_batch = next(iter(train_loader))
        print(f"✅ Data loading works, batch shape: {sample_batch[0].shape}")
        checks_passed += 1
    except Exception as e:
        print(f"❌ Data loading failed: {e}")

    # Check 4: Model creation
    # Pre-bind the name so check 5 can tell "no model" apart from a genuine
    # forward-pass error instead of tripping over an undefined variable.
    test_model = None
    try:
        test_model = create_medical_resnet3d(arch='resnet18', num_classes=CONFIG['num_classes'])
        print(f"✅ Model creation successful")
        checks_passed += 1
    except Exception as e:
        print(f"❌ Model creation failed: {e}")

    # Check 5: Forward pass test
    if test_model is None:
        print("❌ Model forward pass skipped: model creation failed")
    else:
        try:
            test_input = torch.randn(1, 1, *CONFIG['input_size'])
            with torch.no_grad():
                test_output = test_model(test_input)
            print(f"✅ Model forward pass successful, output shape: {test_output.shape}")
            checks_passed += 1
        except Exception as e:
            print(f"❌ Model forward pass failed: {e}")

    # Check 6: Directory permissions (write and remove a scratch file)
    try:
        results_dir = Path("tmj_results")
        results_dir.mkdir(exist_ok=True)
        test_file = results_dir / "test.txt"
        test_file.write_text("test")
        test_file.unlink()
        print("✅ Results directory writable")
        checks_passed += 1
    except Exception as e:
        print(f"❌ Cannot write to results directory: {e}")

    print(f"\n📊 Setup Validation: {checks_passed}/{total_checks} checks passed")

    if checks_passed == total_checks:
        print("🎉 All checks passed! Ready to train.")
        return True
    else:
        print("⚠️  Some checks failed. Please fix issues before training.")
        return False

if __name__ == "__main__":
    if os.path.exists(CONFIG['data_path']):
        # The dataset folder is present, so run the full pre-flight checks.
        if validate_training_setup():
            print("\n".join([
                "\n" + "="*50,
                "🚀 READY TO START TRAINING!",
                "="*50,
                "Run: main_training_pipeline()",
                "Or for quick test: quick_train()",
                "="*50,
            ]))
        else:
            print("\n❌ Please fix the issues above before proceeding.")
    else:
        # No dataset folder: explain what is missing and the expected layout.
        print(f"❌ Error: Data path {CONFIG['data_path']} does not exist!")
        print("\n".join([
            "Please update CONFIG['data_path'] to point to your TMJ dataset",
            "Expected structure:",
            "  tmj_data/",
            "    ├── train/",
            "    │   ├── 0/  (class 0 files)",
            "    │   └── 1/  (class 1 files)",
            "    ├── val/",
            "    │   ├── 0/  (class 0 files)",
            "    │   └── 1/  (class 1 files)",
            "    └── test/",
            "        ├── 0/  (class 0 files)",
            "        └── 1/  (class 1 files)",
        ]))

# Usage banner listing the entry points defined in this cell.
print("\n".join([
    "\n" + "=" * 70,
    "TMJ 3D CNN TRAINING SCRIPT - PRE-SPLIT DATA VERSION",
    "=" * 70,
    "\nAvailable functions:",
    "• main_training_pipeline() - Complete training pipeline",
    "• quick_train() - Fast training for testing (5 epochs)",
    "• validate_training_setup() - Check if everything is ready",
    "• load_and_inference(model_path, test_loader) - Load and test models",
    "\nUpdate CONFIG['data_path'] then run main_training_pipeline()!",
    "=" * 70,
]))

🔍 Validating Training Setup...
----------------------------------------
Validating data structure at C:\Users\acer\Desktop\Project_TMJOA\Data\training_dataset_3D...
✅ train/0: 114 valid files
✅ train/1: 140 valid files
✅ val/0: 33 valid files
✅ val/1: 40 valid files
✅ test/0: 16 valid files
✅ test/1: 21 valid files
✅ Data structure validation passed!
✅ Data structure validation passed
⚠️  GPU not available, will use CPU (slower)
Loaded train split with 254 samples:
  Class 0: 114 samples
  Class 1: 140 samples
Loaded val split with 73 samples:
  Class 0: 33 samples
  Class 1: 40 samples
Loaded test split with 37 samples:
  Class 0: 16 samples
  Class 1: 21 samples
Using balanced sampling for training data

DataLoaders created:
  Train: 254 samples, 63 batches
  Val:   73 samples, 19 batches
  Test:  37 samples, 10 batches
✅ Data loading works, batch shape: torch.Size([4, 1, 254, 254, 254])
✅ Model creation successful
✅ Model forward pass successful, output shape: torch.Size([1, 2])
✅ R

In [10]:
# Keep the returned dict in a variable: as a bare last expression it would be
# echoed as cell output (model weights included) and then be lost.
training_results = main_training_pipeline()

Starting TMJ 3D CNN Training Pipeline with Pre-Split Data

1. Validating data structure and loading datasets...
Validating data structure at C:\Users\acer\Desktop\Project_TMJOA\Data\training_dataset_3D...
✅ train/0: 114 valid files
✅ train/1: 140 valid files
✅ val/0: 33 valid files
✅ val/1: 40 valid files
✅ test/0: 16 valid files
✅ test/1: 21 valid files
✅ Data structure validation passed!
Loaded train split with 254 samples:
  Class 0: 114 samples
  Class 1: 140 samples
Loaded val split with 73 samples:
  Class 0: 33 samples
  Class 1: 40 samples
Loaded test split with 37 samples:
  Class 0: 16 samples
  Class 1: 21 samples
Using balanced sampling for training data

DataLoaders created:
  Train: 254 samples, 63 batches
  Val:   73 samples, 19 batches
  Test:  37 samples, 10 batches
✅ All datasets loaded successfully!

Dataset Statistics:
  Training:   254 samples
  Validation: 73 samples
  Test:       37 samples
  Total:      364 samples

2. Training Strategy Selection:
   a) Single m

                                                                                                    

KeyboardInterrupt: 