# In [None]:
"""
DVG-M4oE: COMPLETE PRODUCTION CODE WITH INNOVATIONS - FIXED VERSION
- Integrated all components: Dataset, Model, Training
- MedNeXt backbone for feature extraction
- Focal Loss for class imbalance handling
- On-demand Frangi vessel mask generation if missing
- Comprehensive checkpoint system with torch.save
- Enhanced early stopping and metrics
- Self-contained script ready for RSNA 2025
"""

import sys
import os
import logging
import re
from glob import glob
from multiprocessing import Pool
from pathlib import Path
from typing import List, Tuple, Dict, Optional

import numpy as np
import pandas as pd
import pydicom
import nibabel as nib
from scipy import ndimage
from scipy.ndimage import distance_transform_edt
from skimage.filters import frangi
from skimage.morphology import remove_small_objects, skeletonize
from tqdm.auto import tqdm
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import StepLR
import torch.optim as optim

from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score, average_precision_score

import shutil
import traceback
from multiprocessing.dummy import Pool as ThreadPool
from datetime import datetime
import gc
import copy
import time
import warnings
warnings.filterwarnings('ignore')

# ================================================================================
# M4oE COMPONENTS
# ================================================================================
def softmax(x: torch.Tensor, dim) -> torch.Tensor:
    """Numerically stable softmax of ``x`` along ``dim``.

    The maximum along ``dim`` is subtracted before exponentiating so that
    large logits do not overflow.

    Args:
        x: Tensor of logits.
        dim: Dimension (anything ``torch.amax`` accepts) to normalise over.

    Returns:
        Tensor of the same shape as ``x`` whose entries sum to 1 along ``dim``.
    """
    shifted = x - x.amax(dim=dim, keepdim=True)
    exponentials = shifted.exp()
    return exponentials / exponentials.sum(dim=dim, keepdim=True)

class MutualInformationLoss(nn.Module):
    """Negative mutual information between axis 1 and axis 3 of routing logits.

    The 4-D input is turned into one joint distribution per batch element by
    a softmax over all non-batch entries. Axis 2 is then marginalised out and
    the mutual information between the remaining two axes is computed and
    summed over the batch.
    """

    def __init__(self, epsilon=1e-4):
        super().__init__()
        # Kept for interface compatibility. The forward pass has always used a
        # fixed 1e-10 stabiliser; that value is preserved so existing loss
        # magnitudes do not shift.
        self.epsilon = epsilon

    def forward(self, phi: torch.Tensor) -> torch.Tensor:
        """Return ``-I(axis1; axis3)`` summed over the batch.

        Args:
            phi: Logits of shape ``(batch, m, n, p)``.

        Returns:
            0-dim tensor holding the negated mutual information.

        Raises:
            ValueError: If ``phi`` is not 4-dimensional (from shape unpacking).
        """
        batch_size, m, n, p = phi.shape

        # Joint distribution over (m, n, p) for every batch element.
        joint = torch.softmax(phi.reshape(batch_size, m * n * p), dim=1)
        joint = joint.reshape(batch_size, m, n, p)

        p_m = joint.sum(dim=(2, 3))   # (batch, m) marginal
        p_t = joint.sum(dim=(1, 2))   # (batch, p) marginal
        p_mt = joint.sum(dim=2)       # (batch, m, p) joint with axis 2 summed out

        ratio = p_mt / (p_m.unsqueeze(2) * p_t.unsqueeze(1) + 1e-10)

        # When the softmax underflows, p_mt contains exact zeros; log(0) is
        # -inf and 0 * -inf is NaN, which used to poison the whole sum.
        # Flooring the ratio keeps the log finite so those terms contribute
        # 0 * finite = 0, and their gradient through the clamp is zero.
        floor = torch.finfo(ratio.dtype).tiny
        log_term = torch.log(ratio.clamp_min(floor))

        mutual_info = torch.sum(p_mt * log_term, dim=(0, 1, 2))

        return -mutual_info

# ================================================================================
# FIXED MedNeXt ENCODER IMPLEMENTATION
# ================================================================================
class MedNeXtEncoder(nn.Module):
    """Simplified MedNeXt-style 3-D convolutional encoder.

    A stride-2 7x7x7 stem (32 channels) is followed by four stride-2
    double-convolution stages (64, 128, 256 and 512 channels). The result is
    global-average-pooled and linearly projected to ``feature_dim``.
    """

    def __init__(self, in_channels=1, feature_dim=512):
        super().__init__()

        widths = (32, 64, 128, 256, 512)

        self.stem = nn.Sequential(
            nn.Conv3d(in_channels, widths[0], kernel_size=7, stride=2, padding=3),
            nn.BatchNorm3d(widths[0]),
            nn.ReLU(inplace=True),
        )

        # Registers down1 .. down4 in order, each halving spatial resolution.
        for index, (c_in, c_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(self, f"down{index}", self._make_layer(c_in, c_out, stride=2))

        self.global_pool = nn.AdaptiveAvgPool3d(1)
        self.projection = nn.Linear(widths[-1], feature_dim)

    def _make_layer(self, in_channels, out_channels, stride=1):
        """Build two conv-BN-ReLU units; only the first one applies ``stride``."""
        units = []
        for unit_in, unit_stride in ((in_channels, stride), (out_channels, 1)):
            units.extend([
                nn.Conv3d(unit_in, out_channels, kernel_size=3, stride=unit_stride, padding=1),
                nn.BatchNorm3d(out_channels),
                nn.ReLU(inplace=True),
            ])
        return nn.Sequential(*units)

    def forward(self, x):
        """Encode a volume batch (as accepted by ``nn.Conv3d``) to ``(batch, feature_dim)``."""
        hidden = self.stem(x)
        for stage in (self.down1, self.down2, self.down3, self.down4):
            hidden = stage(hidden)

        pooled = self.global_pool(hidden).flatten(1)
        return self.projection(pooled)

# ================================================================================
# BLUEPRINT STAGES
# ================================================================================
class CTAAdaptiveProjection(nn.Module):
    """STAGE 1: collapse a 3-channel CTA volume to one channel.

    A bias-free 1x1x1 convolution whose weights start at 0.3, 0.6 and 0.1
    (labelled Brain, Blood and Bone in the source). The weights stay trainable.
    """

    def __init__(self):
        super().__init__()
        self.projection = nn.Conv3d(3, 1, kernel_size=1, bias=False)
        # [Brain, Blood, Bone] mixing weights, shaped like the conv kernel.
        initial_kernel = torch.tensor([0.3, 0.6, 0.1]).reshape(1, 3, 1, 1, 1)
        self.projection.weight.data = initial_kernel

    def forward(self, cta_3channel):
        """Return the weighted single-channel mix of the three input channels."""
        return self.projection(cta_3channel)

class VesselGuidedAttention(nn.Module):
    """STAGE 3: gate feature vectors with a signal derived from a vessel mask.

    The mask goes through a small 3-D conv block ending in a sigmoid, is
    average-pooled to one scalar per sample, and that scalar is mapped by a
    linear layer plus sigmoid to a per-feature gate.
    """

    def __init__(self, feature_dim=512):
        super().__init__()

        self.spatial_attention = nn.Sequential(
            nn.Conv3d(1, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv3d(16, 1, kernel_size=1),
            nn.Sigmoid(),
        )

        self.vessel_to_feature = nn.Sequential(
            nn.Linear(1, feature_dim),
            nn.Sigmoid(),
        )

    def forward(self, features, vessel_mask=None):
        """Return gated features, or ``features`` unchanged without a usable mask.

        A mask is considered unusable when it is ``None`` or its sum is not
        greater than zero.
        """
        mask_is_usable = vessel_mask is not None and vessel_mask.sum() > 0
        if not mask_is_usable:
            return features

        attention_map = self.spatial_attention(vessel_mask)
        pooled_scalar = F.adaptive_avg_pool3d(attention_map, 1).flatten(1)
        feature_gate = self.vessel_to_feature(pooled_scalar)
        return features * feature_gate

class ModalitySpecificMoE(nn.Module):
    """Dense mixture of MLP experts for one modality, gated with vessel context.

    Every expert runs on every sample. A softmax gate over the features
    concatenated with a one-column vessel context produces the mixing weights.
    """

    def __init__(self, num_experts=6, feature_dim=512):
        super().__init__()

        hidden_dim = feature_dim * 2
        self.experts = nn.ModuleList(
            nn.Sequential(
                nn.Linear(feature_dim, hidden_dim),
                nn.ReLU(inplace=True),
                nn.Dropout(0.1),
                nn.Linear(hidden_dim, feature_dim),
                nn.LayerNorm(feature_dim),
            )
            for _ in range(num_experts)
        )

        # The extra input column carries the scalar vessel context.
        self.vessel_gating = nn.Sequential(
            nn.Linear(feature_dim + 1, num_experts),
            nn.Softmax(dim=-1),
        )

    def forward(self, modality_features, vessel_context=None):
        """Mix expert outputs.

        Args:
            modality_features: ``(batch, feature_dim)`` features.
            vessel_context: Optional ``(batch, 1)`` context; a zero column is
                used when omitted.

        Returns:
            Tuple ``(output, expert_weights)`` where ``output`` has the shape
            of ``modality_features`` and ``expert_weights`` is
            ``(batch, num_experts)``.
        """
        if vessel_context is None:
            vessel_context = torch.zeros(
                modality_features.size(0), 1, device=modality_features.device
            )

        gate_in = torch.cat([modality_features, vessel_context], dim=1)
        expert_weights = self.vessel_gating(gate_in)

        per_expert = [expert(modality_features) for expert in self.experts]
        stacked = torch.stack(per_expert, dim=1)

        output = (expert_weights.unsqueeze(-1) * stacked).sum(dim=1)
        return output, expert_weights

class CrossModalTaskMoE(nn.Module):
    """Cross-modal task MoE that also exposes routing logits for the CMI loss.

    The input is the concatenation of four ``feature_dim`` vectors. A learned
    64-d task embedding joins the features to gate a dense mixture of experts,
    and a separate routing matrix yields expert-by-task logits.
    """

    def __init__(self, num_experts=8, num_tasks=14, feature_dim=512):
        super().__init__()

        fused_dim = feature_dim * 4

        self.task_embeddings = nn.Parameter(torch.randn(num_tasks, 64) * 0.02)

        self.experts = nn.ModuleList(
            nn.Sequential(
                nn.Linear(fused_dim, feature_dim * 2),
                nn.ReLU(inplace=True),
                nn.Dropout(0.1),
                nn.Linear(feature_dim * 2, feature_dim),
                nn.LayerNorm(feature_dim),
            )
            for _ in range(num_experts)
        )

        self.routing_matrix = nn.Parameter(torch.zeros(fused_dim, num_experts, num_tasks))
        nn.init.normal_(self.routing_matrix, mean=0, std=1 / fused_dim ** 0.5)

        self.task_gating = nn.Sequential(
            nn.Linear(fused_dim + 64, feature_dim),
            nn.ReLU(),
            nn.Linear(feature_dim, num_experts),
            nn.Softmax(dim=-1),
        )

    def forward(self, all_modality_features, task_id):
        """Fuse features for one task.

        Args:
            all_modality_features: ``(batch, 4 * feature_dim)`` tensor.
            task_id: Python ``int`` (broadcast over the batch) or an index
                tensor used directly to select embeddings.

        Returns:
            ``(fused_output, expert_weights, routing_logits)`` with shapes
            ``(batch, feature_dim)``, ``(batch, num_experts)`` and
            ``(batch, 1, num_experts, num_tasks)``.
        """
        batch = all_modality_features.size(0)

        task_embedding = self.task_embeddings[task_id]
        if isinstance(task_id, int):
            # A single embedding row is shared by every sample in the batch.
            task_embedding = task_embedding.unsqueeze(0).expand(batch, -1)

        # The features are treated as a length-1 sequence for the routing einsum.
        routing_logits = torch.einsum(
            "bmd,dnp->bmnp", all_modality_features.unsqueeze(1), self.routing_matrix
        )

        expert_weights = self.task_gating(
            torch.cat([all_modality_features, task_embedding], dim=1)
        )

        stacked = torch.stack(
            [expert(all_modality_features) for expert in self.experts], dim=1
        )
        fused_output = (expert_weights.unsqueeze(-1) * stacked).sum(dim=1)

        return fused_output, expert_weights, routing_logits

class SequentialExpertCommunication(nn.Module):
    """Sequential expert communication (chain-of-experts) refinement.

    One pool of experts is applied ``num_iterations`` times. Each iteration
    has its own router and adds the mixed expert output back onto the running
    representation.
    """

    def __init__(self, num_experts=6, num_iterations=3, feature_dim=512):
        super().__init__()

        self.communication_experts = nn.ModuleList(
            nn.Sequential(
                nn.Linear(feature_dim, feature_dim * 2),
                nn.ReLU(inplace=True),
                nn.Dropout(0.1),
                nn.Linear(feature_dim * 2, feature_dim),
                nn.LayerNorm(feature_dim),
            )
            for _ in range(num_experts)
        )

        # One router per iteration; the +1 input column is the vessel context.
        self.routers = nn.ModuleList(
            nn.Sequential(
                nn.Linear(feature_dim + 1, feature_dim // 2),
                nn.ReLU(),
                nn.Linear(feature_dim // 2, num_experts),
                nn.Softmax(dim=-1),
            )
            for _ in range(num_iterations)
        )

    def forward(self, x, vessel_context=None):
        """Refine ``x`` of shape ``(batch, feature_dim)`` and return the same shape.

        ``vessel_context`` is an optional ``(batch, 1)`` tensor; a zero column
        is used when it is ``None``.
        """
        batch, _ = x.shape  # also enforces a 2-D input

        if vessel_context is not None:
            context = vessel_context
        else:
            context = torch.zeros(batch, 1, device=x.device)

        refined = x
        for router in self.routers:
            weights = router(torch.cat([refined, context], dim=1))
            candidates = torch.stack(
                [expert(refined) for expert in self.communication_experts], dim=1
            )
            mixed = (weights.unsqueeze(-1) * candidates).sum(dim=1)
            refined = mixed + refined  # residual connection

        return refined

# ================================================================================
# CMI LOSS COMPUTATION
# ================================================================================
def compute_cmi_loss_robust(mtoe_routing_logits, modality_indicators, epsilon=1e-8):
    """Average a mutual-information loss over the task axis of routing logits.

    Args:
        mtoe_routing_logits: Tensor of shape ``(B, seq_len, N, P)`` or
            ``None``. The ``squeeze(2)`` below only removes the sequence axis
            when ``seq_len == 1``; other lengths are not handled.
        modality_indicators: ``(B, M)`` tensor; converted to float and used to
            multiply (not remove) the logits of each modality.
        epsilon: Forwarded to ``MutualInformationLoss``.

    Returns:
        0-dim tensor: the mean of the finite per-task losses, or ``0.0`` when
        the logits are ``None``, the indicator mask is all zero, or no task
        produced a finite value.
    """
    if mtoe_routing_logits is None:
        return torch.tensor(0.0)

    _batch, _seq_len, _num_experts, num_tasks = mtoe_routing_logits.shape
    num_modalities = modality_indicators.shape[1]

    # Loop-invariant: with an all-zero mask every task used to be skipped one
    # by one, ending in the same zero result returned here.
    modality_mask = modality_indicators.float()
    if modality_mask.sum() == 0:
        return torch.tensor(0.0, device=mtoe_routing_logits.device)

    mi_loss_fn = MutualInformationLoss(epsilon=epsilon)

    # Replicate the logits once per modality: (B, M, N, P) when seq_len == 1.
    routing_logits_modal = (
        mtoe_routing_logits.unsqueeze(1).repeat(1, num_modalities, 1, 1, 1).squeeze(2)
    )
    mask = modality_mask.unsqueeze(-1)

    # NOTE(review): each task slice is passed with a trailing singleton axis.
    # With that shape the MI loss's last-axis marginal is 1 and the joint
    # equals the first-axis marginal, so every term looks to be ~0 -- confirm
    # that this is the intended formulation.
    total_cmi = 0.0
    valid_tasks = 0
    for task_k in range(num_tasks):
        task_phi = (routing_logits_modal[:, :, :, task_k] * mask).unsqueeze(-1)

        try:
            cmi_task = mi_loss_fn(task_phi)
        except (RuntimeError, ValueError):
            # Best effort: a failing task is skipped, but unlike a bare
            # ``except`` this no longer swallows KeyboardInterrupt/SystemExit.
            continue

        if torch.isfinite(cmi_task):
            total_cmi += cmi_task
            valid_tasks += 1

    if valid_tasks == 0:
        return torch.tensor(0.0, device=mtoe_routing_logits.device)

    return total_cmi / valid_tasks

# ================================================================================
# FEATURE EXTRACTION
# ================================================================================
class FeatureExtraction(nn.Module):
    """Stage 2: one independent single-channel MedNeXt encoder per modality."""

    def __init__(self, target_dim: int = 512):
        super().__init__()

        # Encoders are keyed 'cta', 'mra', 't1', 't2' and created in that order.
        self.encoders = nn.ModuleDict({
            name: MedNeXtEncoder(in_channels=1, feature_dim=target_dim)
            for name in ('cta', 'mra', 't1', 't2')
        })

    def forward(self, modalities):
        """Encode each entry of ``modalities`` with the encoder of the same key.

        Args:
            modalities: Mapping from modality name to input volume.

        Returns:
            Dict with the same keys mapping to the encoder outputs.
        """
        return {
            name: self.encoders[name](volume)
            for name, volume in modalities.items()
        }

# ================================================================================
# MAIN MODEL
# ================================================================================
class DVG_M4oE_Complete_Blueprint(nn.Module):
    """Full DVG-M4oE pipeline wiring all seven blueprint stages together.

    Stages: CTA projection -> per-modality encoders -> vessel-guided gating ->
    modality-specific MoE -> sequential expert communication -> cross-modal
    task MoE -> one binary head per task.
    """

    def __init__(self, num_classes=14, feature_dim=512):
        super().__init__()

        self.num_classes = num_classes
        modality_names = ['cta', 'mra', 't1', 't2']

        # Stage 1: 3-channel CTA -> 1 channel.
        self.cta_projection = CTAAdaptiveProjection()

        # Stage 2: per-modality encoders.
        self.feature_extraction = FeatureExtraction(target_dim=feature_dim)

        # Stage 3: vessel-mask driven feature gating (shared by all modalities).
        self.vessel_attention = VesselGuidedAttention(feature_dim)

        # Stage 4: one modality-specific MoE per modality.
        self.msoe_modules = nn.ModuleDict({
            name: ModalitySpecificMoE(num_experts=6, feature_dim=feature_dim)
            for name in modality_names
        })

        # Stage 6: chain-of-experts refinement (shared by all modalities).
        self.expert_communication = SequentialExpertCommunication(
            num_experts=6, num_iterations=3, feature_dim=feature_dim
        )

        # Stage 5: cross-modal task MoE.
        self.mtoe_module = CrossModalTaskMoE(
            num_experts=8, num_tasks=num_classes, feature_dim=feature_dim
        )

        # Stage 7: one small MLP head per task, each emitting a single logit.
        self.task_heads = nn.ModuleList(
            nn.Sequential(
                nn.Linear(feature_dim, feature_dim // 2),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(feature_dim // 2, 1),
            )
            for _ in range(num_classes)
        )

        # Learned stand-in feature vector for a missing modality.
        self.missing_modality_embedding = nn.Parameter(torch.randn(feature_dim) * 0.02)

    def forward(self, batch):
        """Run the full pipeline.

        Args:
            batch: Dict with tensors under 'cta' (3-channel volume), 'mra',
                't1', 't2', and optionally 'vessel_mask'.

        Returns:
            Dict with 'predictions' (``(B, num_classes)`` logits),
            'msoe_weights', 'mtoe_weights', 'mtoe_routing_logits' and
            'features' (per-modality refined features).
        """
        cta_volume = batch['cta']
        batch_size = cta_volume.size(0)

        # Stage 1 + modality assembly.
        modalities = {
            'cta': self.cta_projection(cta_volume),
            'mra': batch['mra'],
            't1': batch['t1'],
            't2': batch['t2'],
        }

        # Stage 2.
        features = self.feature_extraction(modalities)

        # NOTE(review): this tests whether the *encoder output* sums to exactly
        # zero over the whole batch. The encoders contain biases and
        # normalisation, so this may rarely trigger even for all-zero inputs --
        # confirm how missing modalities are meant to be detected.
        for name, feat in list(features.items()):
            if feat.sum() == 0:
                features[name] = self.missing_modality_embedding.unsqueeze(0).expand(batch_size, -1)

        # Stage 3.
        vessel_mask = batch.get('vessel_mask', None)
        enhanced_features = {
            name: self.vessel_attention(feat, vessel_mask)
            for name, feat in features.items()
        }

        # Scalar vessel context per sample: mean of the mask, or zeros.
        has_vessels = vessel_mask is not None and vessel_mask.sum() > 0
        if has_vessels:
            vessel_context = F.adaptive_avg_pool3d(vessel_mask, 1).flatten(1)
        else:
            vessel_context = torch.zeros(batch_size, 1, device=cta_volume.device)

        # Stage 4.
        msoe_outputs = {}
        msoe_weights = {}
        for name, feat in enhanced_features.items():
            mixed, gate_weights = self.msoe_modules[name](feat, vessel_context)
            msoe_outputs[name] = mixed
            msoe_weights[f'msoe_{name}'] = gate_weights

        # Stage 6.
        communicated_outputs = {
            name: self.expert_communication(mixed, vessel_context)
            for name, mixed in msoe_outputs.items()
        }

        # Stage 5 + 7. Kept as a single interleaved loop so the order in which
        # dropout layers are invoked is unchanged.
        all_features = torch.cat(list(communicated_outputs.values()), dim=1)

        task_logits = []
        task_gate_weights = []
        task_routing = []
        for task_id, head in enumerate(self.task_heads[:self.num_classes]):
            fused, gate_weights, routing = self.mtoe_module(all_features, task_id)
            task_logits.append(head(fused))
            task_gate_weights.append(gate_weights)
            task_routing.append(routing)

        # NOTE(review): the routing logits returned by the task MoE do not
        # appear to depend on task_id, so this concatenation looks like
        # num_classes copies of the same tensor -- confirm this is intended.
        return {
            'predictions': torch.cat(task_logits, dim=-1),
            'msoe_weights': msoe_weights,
            'mtoe_weights': torch.stack(task_gate_weights, dim=1),
            'mtoe_routing_logits': torch.cat(task_routing, dim=-1),
            'features': communicated_outputs,
        }

# ================================================================================
# FOCAL LOSS
# ================================================================================
class FocalLoss(nn.Module):
    """Binary focal loss on logits, used to down-weight easy examples.

    Args:
        gamma: Focusing exponent applied to ``1 - p_t``.
        alpha: Optional multiplicative weight applied to every element
            (uniformly; it is not switched between positives and negatives).
        reduction: ``'mean'``, ``'sum'``, or anything else for no reduction.
    """

    def __init__(self, gamma=2.0, alpha=None, reduction='mean'):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, input, target):
        """Return the focal loss of logits ``input`` against ``target``."""
        bce = F.binary_cross_entropy_with_logits(input, target, reduction='none')

        # exp(-BCE) recovers the probability assigned to the true label.
        prob_true = torch.exp(-bce)
        loss = (1 - prob_true) ** self.gamma * bce

        if self.alpha is not None:
            loss = self.alpha * loss

        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss

# ================================================================================
# LOSS COMPUTATION
# ================================================================================
def compute_total_loss_corrected(predictions, targets, model_outputs, modality_indicators, beta=0.1, focal_gamma=2.0):
    """Combine the focal classification loss with the routing CMI term.

    Args:
        predictions: Logits passed to the focal loss.
        targets: Targets passed to the focal loss.
        model_outputs: Model output dict; 'mtoe_routing_logits' is used when
            present.
        modality_indicators: Forwarded to ``compute_cmi_loss_robust``.
        beta: Weight of the CMI term.
        focal_gamma: Focusing exponent for the focal loss.

    Returns:
        Dict with 'total_loss', 'classification_loss' and 'cmi_loss'.
    """
    classification_loss = FocalLoss(gamma=focal_gamma)(predictions, targets)

    has_routing_logits = 'mtoe_routing_logits' in model_outputs
    cmi_loss = (
        compute_cmi_loss_robust(model_outputs['mtoe_routing_logits'], modality_indicators)
        if has_routing_logits
        else torch.tensor(0.0)
    )

    # NOTE(review): the MI loss module already returns the *negated* mutual
    # information, so subtracting beta * cmi_loss adds +beta * MI to the
    # objective. Confirm that this sign is intended.
    return {
        'total_loss': classification_loss - beta * cmi_loss,
        'classification_loss': classification_loss,
        'cmi_loss': cmi_loss,
    }

# ================================================================================
# METRICS CLASS
# ================================================================================
class ProductionMetricsFixed:
    """Accumulates per-batch outputs and summarises them as RSNA metrics.

    Args:
        task_names: Full task names (stored; not otherwise used here).
        short_names: One short name per prediction column; these become the
            keys of the per-task AUC dict.
    """

    def __init__(self, task_names, short_names):
        self.task_names = task_names
        self.short_names = short_names
        self.reset()

    def reset(self):
        """Drop everything accumulated so far."""
        self.predictions = []
        self.targets = []
        self.losses = []
        self.classification_losses = []
        self.cmi_losses = []
        self.batch_times = []
        self.learning_rates = []

    def update(self, predictions, targets, total_loss, class_loss, cmi_loss, batch_time, lr=None):
        """Record one batch.

        Tensors are detached and moved to CPU. ``cmi_loss`` may be a tensor or
        a plain number; ``lr`` is only recorded when given.
        """
        self.predictions.append(predictions.detach().cpu())
        self.targets.append(targets.detach().cpu())
        self.losses.append(total_loss.detach().cpu().item())
        self.classification_losses.append(class_loss.detach().cpu().item())
        self.cmi_losses.append(cmi_loss.detach().cpu().item() if isinstance(cmi_loss, torch.Tensor) else cmi_loss)
        self.batch_times.append(batch_time)
        if lr is not None:
            self.learning_rates.append(lr)

    def compute_rsna_metrics(self):
        """Summarise the accumulated batches.

        Returns:
            ``{}`` when nothing has been recorded, otherwise a dict with mean
            losses, 'mean_auc' (over tasks with both classes present),
            'aneurysm_auc' (the 'Aneurysm' short name, 0.5 if absent),
            'task_aucs', 'samples_per_sec' and the last 'learning_rate'.
            Tasks whose AUC cannot be computed are reported as 0.5 and
            excluded from the mean.
        """
        if not self.predictions:
            return {}

        all_preds = torch.cat(self.predictions, dim=0).numpy()
        all_targets = torch.cat(self.targets, dim=0).numpy()

        task_aucs = {}
        valid_aucs = []

        for i, task_name in enumerate(self.short_names):
            try:
                column_targets = all_targets[:, i]
                if len(np.unique(column_targets)) > 1:
                    auc = roc_auc_score(column_targets, all_preds[:, i])
                    task_aucs[task_name] = auc
                    valid_aucs.append(auc)
                else:
                    # AUC is undefined with a single class present.
                    task_aucs[task_name] = 0.5
            except Exception:
                # Metrics are best-effort: one bad task must not abort the
                # epoch. Unlike a bare ``except`` this lets KeyboardInterrupt
                # and SystemExit propagate.
                task_aucs[task_name] = 0.5

        mean_auc = np.mean(valid_aucs) if valid_aucs else 0.5
        aneurysm_auc = task_aucs.get('Aneurysm', 0.5)

        # Throughput in *samples* per second: previously the number of batches
        # was divided by the elapsed time, and a zero total time raised
        # ZeroDivisionError.
        total_time = sum(self.batch_times)
        samples_per_sec = all_preds.shape[0] / total_time if total_time > 0 else 0

        return {
            'loss': np.mean(self.losses),
            'classification_loss': np.mean(self.classification_losses),
            'cmi_loss': np.mean(self.cmi_losses),
            'mean_auc': mean_auc,
            'aneurysm_auc': aneurysm_auc,
            'task_aucs': task_aucs,
            'samples_per_sec': samples_per_sec,
            'learning_rate': self.learning_rates[-1] if self.learning_rates else 0
        }

# ================================================================================
# CHECKPOINT MANAGER
# ================================================================================
class CheckpointManager:
    """Writes timestamped training checkpoints with ``torch.save``.

    Which pieces of state are stored is controlled by boolean flags in
    ``config`` ('save_checkpoints', 'save_model_state',
    'save_optimizer_state', 'save_scheduler_state', 'save_scaler_state',
    'save_config', 'save_training_history'). Files are named
    ``{model_name}_{type}[_{extra}]_{YYYYmmdd_HHMMSS}.pt`` inside
    ``config['checkpoint_dir']``, which is created if needed.
    """

    def __init__(self, config):
        self.config = config
        self.checkpoint_dir = Path(config['checkpoint_dir'])
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.model_name = config['model_name']

    def _optional_state(self, component, flag):
        """Return ``component.state_dict()`` if ``flag`` is enabled and the component exists."""
        if not self.config[flag] or component is None:
            return None
        return component.state_dict()

    @staticmethod
    def _safe_fragment(text):
        """Reduce free text to characters that are safe inside a file name."""
        return re.sub(r'[^A-Za-z0-9._-]+', '_', str(text)).strip('_')

    def _newest_checkpoint(self, checkpoints, checkpoint_type):
        """Return the most recently modified checkpoint of a type, or 'None'."""
        prefix = f"{self.model_name}_{checkpoint_type}_"
        matches = [path for path in checkpoints if path.name.startswith(prefix)]
        if not matches:
            return 'None'
        # The name breaks ties between files written within the same mtime tick.
        return str(max(matches, key=lambda path: (path.stat().st_mtime, path.name)))

    def save_checkpoint(self, model, optimizer, scheduler, scaler, epoch, train_results, val_results, history, config, checkpoint_type="latest", extra_info=""):
        """Serialise the training state to a new checkpoint file.

        Args:
            model: Module whose ``state_dict`` is stored (if enabled).
            optimizer, scheduler: Stored if enabled; ``None`` is tolerated and
                stored as ``None``.
            scaler: Optional AMP scaler; stored when truthy and enabled.
            epoch: Epoch index recorded in the checkpoint.
            train_results, val_results, history: Stored as given.
            config: Run configuration stored when 'save_config' is enabled.
            checkpoint_type: Tag embedded in the file name.
            extra_info: Optional free text added to the file name; characters
                outside ``[A-Za-z0-9._-]`` are replaced by ``_`` so text such
                as a stop reason cannot introduce path separators.

        Returns:
            Path of the written file, or ``None`` when checkpointing is
            disabled.
        """
        if not self.config['save_checkpoints']:
            return None

        checkpoint_data = {
            'epoch': epoch,
            'model_state_dict': model.state_dict() if self.config['save_model_state'] else None,
            'optimizer_state_dict': self._optional_state(optimizer, 'save_optimizer_state'),
            'scheduler_state_dict': self._optional_state(scheduler, 'save_scheduler_state'),
            'train_results': train_results,
            'val_results': val_results,
            'config': config if self.config['save_config'] else None,
            'timestamp': datetime.now().isoformat(),
            'checkpoint_type': checkpoint_type
        }

        if scaler and self.config['save_scaler_state']:
            checkpoint_data['scaler_state_dict'] = scaler.state_dict()

        if self.config['save_training_history']:
            checkpoint_data['training_history'] = history

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        safe_extra = self._safe_fragment(extra_info) if extra_info else ""
        if safe_extra:
            filename = f"{self.model_name}_{checkpoint_type}_{safe_extra}_{timestamp}.pt"
        else:
            filename = f"{self.model_name}_{checkpoint_type}_{timestamp}.pt"

        filepath = self.checkpoint_dir / filename

        torch.save(checkpoint_data, filepath)
        print(f"💾 Saved {checkpoint_type} checkpoint: {filename}")

        return filepath

    def save_latest_checkpoint(self, model, optimizer, scheduler, scaler, epoch, train_results, val_results, history, config):
        """Save a checkpoint tagged 'latest'."""
        return self.save_checkpoint(model, optimizer, scheduler, scaler, epoch, train_results, val_results, history, config, "latest")

    def save_best_checkpoint(self, model, optimizer, scheduler, scaler, epoch, train_results, val_results, history, config):
        """Save a checkpoint tagged 'best'."""
        return self.save_checkpoint(model, optimizer, scheduler, scaler, epoch, train_results, val_results, history, config, "best")

    def save_periodic_checkpoint(self, model, optimizer, scheduler, scaler, epoch, train_results, val_results, history, config):
        """Save a 'periodic' checkpoint every ``config['periodic_epochs']`` epochs, else return ``None``."""
        if (epoch + 1) % self.config['periodic_epochs'] == 0:
            return self.save_checkpoint(model, optimizer, scheduler, scaler, epoch, train_results, val_results, history, config, "periodic", f"epoch_{epoch+1}")
        return None

    def save_final_checkpoint(self, model, optimizer, scheduler, scaler, epoch, train_results, val_results, history, config, stop_reason=""):
        """Save a 'final' checkpoint, embedding the (sanitised) stop reason in the name."""
        extra = f"final_{stop_reason}" if stop_reason else "final"
        return self.save_checkpoint(model, optimizer, scheduler, scaler, epoch, train_results, val_results, history, config, "final", extra)

    def get_checkpoint_summary(self):
        """Summarise the checkpoints on disk for this model name.

        Returns:
            Dict with the directory, the number of ``{model_name}_*.pt``
            files, and the newest 'best' and 'latest' checkpoint paths (the
            string 'None' when there is none). The newest file is chosen by
            modification time rather than by arbitrary glob order.
        """
        checkpoints = list(self.checkpoint_dir.glob(f"{self.model_name}_*.pt"))
        return {
            'checkpoint_dir': str(self.checkpoint_dir),
            'total_checkpoints': len(checkpoints),
            'best_checkpoint': self._newest_checkpoint(checkpoints, 'best'),
            'latest_checkpoint': self._newest_checkpoint(checkpoints, 'latest')
        }

# ================================================================================
# ENHANCED EARLY STOPPING
# ================================================================================
class EnhancedEarlyStopping:
    """Early stopping on one primary validation metric with two patience limits.

    Reads 'primary_metric', 'primary_mode' ('max' or anything else for
    minimise), 'primary_patience' and 'primary_min_delta' from ``config``,
    plus the optional 'max_epochs_no_improvement' (default 8). A deep copy of
    the model weights is kept from the best epoch.
    """

    def __init__(self, config):
        self.primary_metric = config['primary_metric']
        self.primary_mode = config['primary_mode']
        self.primary_patience = config['primary_patience']
        self.primary_min_delta = config['primary_min_delta']

        maximise = self.primary_mode == 'max'
        self.best_primary_score = float('-inf') if maximise else float('inf')
        self.primary_patience_counter = 0
        self.best_epoch = -1
        self.best_model_state = None

        self.epochs_no_improvement = 0
        self.max_epochs_no_improvement = config.get('max_epochs_no_improvement', 8)

    def step(self, epoch, val_results, model, optimizer):
        """Register one epoch's validation results.

        Args:
            epoch: Epoch index recorded when the score improves.
            val_results: Dict of metrics; a missing primary metric counts as 0.0.
            model: Module whose ``state_dict`` is snapshotted on improvement.
            optimizer: Accepted for interface symmetry; not used.

        Returns:
            ``(should_stop, stop_reason, improved)``.
        """
        score = val_results.get(self.primary_metric, 0.0)

        # An improvement must beat the best score by more than min_delta.
        if self.primary_mode == 'max':
            improved = score > (self.best_primary_score + self.primary_min_delta)
        else:
            improved = score < (self.best_primary_score - self.primary_min_delta)

        if not improved:
            self.primary_patience_counter += 1
            self.epochs_no_improvement += 1
        else:
            self.best_primary_score = score
            self.primary_patience_counter = 0
            self.epochs_no_improvement = 0
            self.best_epoch = epoch
            self.best_model_state = copy.deepcopy(model.state_dict())

        # The primary patience limit takes precedence over the global limit.
        if self.primary_patience_counter >= self.primary_patience:
            return True, f"Primary metric ({self.primary_metric}) patience exceeded", improved
        if self.epochs_no_improvement >= self.max_epochs_no_improvement:
            return True, f"No improvement for {self.max_epochs_no_improvement} epochs", improved
        return False, "", improved

    def restore_best_model(self, model, optimizer):
        """Load the best snapshot into ``model``; return whether one existed."""
        if self.best_model_state is None:
            return False
        model.load_state_dict(self.best_model_state)
        return True

    def get_status_info(self):
        """Return a small dict describing the current stopping state."""
        patience_text = f"{self.primary_patience_counter}/{self.primary_patience}"
        best_text = f"{self.best_primary_score:.4f}"
        return {
            'primary_patience': patience_text,
            'primary_best': best_text,
            'epochs_no_improvement': self.epochs_no_improvement,
            'best_epoch': self.best_epoch
        }

# ================================================================================
# DATASET
# ================================================================================
class DVG_MultiModalRSNADataset(Dataset):
    """DVG-M4oE multi-modal dataset with Frangi vessel generation if missing.

    Each item is built around one CTA series from the CSV and bundles:
    the CTA volume (3 channels), optional MRA / T1-post / T2 volumes
    (1 channel each), a vessel mask (loaded from disk, or generated with a
    Frangi filter when no non-empty mask is found), the label vector, the
    per-modality presence flags and a small metadata dict. Volumes that are
    missing or cannot be read are represented by all-zero tensors.
    """

    # File suffixes tried, in order, when resolving a series UID to a file.
    _NIFTI_EXTENSIONS = ('.nii.gz', '.nii')

    def __init__(self, 
                 csv_path: str, 
                 data_root: str,
                 clean_files_path: Optional[str] = None,
                 split: str = 'train', 
                 train_ratio: float = 0.8,
                 target_size: Tuple[int, int, int] = (128, 128, 128),
                 verbose: bool = True,
                 debug_first_sample: bool = True):
        """Index the CSV, apply the clean-file filter and select the split.

        Args:
            csv_path: Preferred CSV location; ``<data_root>/rsna-csv/train.csv``
                is used as a fallback.
            data_root: Directory containing the ``rsna-*`` dataset folders.
            clean_files_path: Optional text file with one allowed entry per
                line; two default locations under ``data_root`` are tried
                when it is missing.
            split: ``'train'`` selects the first ``train_ratio`` fraction of
                the rows (in file order); any other value selects the rest.
            train_ratio: Fraction of rows assigned to the training split.
            target_size: Spatial size every volume is resampled to.
            verbose: Print which CSV / clean file was loaded.
            debug_first_sample: Stored on the instance; not used elsewhere
                in this class.

        Raises:
            FileNotFoundError: If none of the CSV candidates exists.
        """
        self.data_root = Path(data_root)
        self.split = split
        self.target_size = target_size
        self.verbose = verbose
        self.debug_first_sample = debug_first_sample
        self.first_sample_loaded = False
        
        # Kaggle dataset structure
        self.csv_root = self.data_root / "rsna-csv"
        self.cta1_root = self.data_root / "rsna-cta1" / "CTA1" 
        self.cta2_root = self.data_root / "rsna-cta2" / "CTA"
        self.remains_root = self.data_root / "rsna-remains"
        self.clean_files_root = self.data_root / "rsna-clean-files"
        
        # Load clean files: the first existing candidate wins.
        self.valid_files = None
        clean_file_candidates = [
            Path(clean_files_path) if clean_files_path else None,
            self.clean_files_root / "clean_files_kaggle.txt",
            self.data_root / "clean_files_kaggle.txt",
        ]
        
        for clean_candidate in clean_file_candidates:
            if clean_candidate and clean_candidate.exists():
                with open(clean_candidate, 'r') as f:
                    self.valid_files = {line.strip() for line in f if line.strip()}
                if self.verbose:
                    print(f"✅ Loaded {len(self.valid_files)} clean files")
                break
        
        # Load CSV: the first existing candidate wins.
        csv_candidates = [
            Path(csv_path),
            self.csv_root / "train.csv",
        ]
        
        df = None
        for csv_candidate in csv_candidates:
            if csv_candidate.exists():
                df = pd.read_csv(csv_candidate)
                if self.verbose:
                    print(f"✅ Loaded CSV: {csv_candidate}")
                break
        
        if df is None:
            raise FileNotFoundError("Could not find train.csv")
        
        # Filter by CTA modality
        if 'Modality' in df.columns:
            cta_df = df[df['Modality'] == 'CTA'].reset_index(drop=True)
        else:
            cta_df = df.copy()
        
        # Apply clean file filtering (skipped when the list is absent/empty).
        if self.valid_files:
            # astype(bool) keeps the mask boolean even when cta_df is empty.
            keep_mask = cta_df['SeriesInstanceUID'].apply(
                self._is_series_in_clean_files
            ).astype(bool)
            cta_df = cta_df[keep_mask].reset_index(drop=True)
        
        # Train/val split: contiguous slices in file order, no shuffling.
        n_train = int(len(cta_df) * train_ratio)
        if split == 'train':
            self.df = cta_df.iloc[:n_train].reset_index(drop=True)
        else:
            self.df = cta_df.iloc[n_train:].reset_index(drop=True)
        
        # Unfiltered table, used to look up series of other modalities.
        self.full_df = df
        
        # Labels
        self.label_cols = [
            'Left Infraclinoid Internal Carotid Artery',
            'Right Infraclinoid Internal Carotid Artery',
            'Left Supraclinoid Internal Carotid Artery', 
            'Right Supraclinoid Internal Carotid Artery',
            'Left Middle Cerebral Artery',
            'Right Middle Cerebral Artery',
            'Anterior Communicating Artery',
            'Left Anterior Cerebral Artery',
            'Right Anterior Cerebral Artery',
            'Left Posterior Communicating Artery',
            'Right Posterior Communicating Artery',
            'Basilar Tip',
            'Other Posterior Circulation',
            'Aneurysm Present'
        ]
    
    def _is_series_in_clean_files(self, series_uid: str) -> bool:
        """Return True if ``series_uid`` is allowed by the clean-file list.

        With no list loaded every series is allowed. Otherwise the UID must
        appear either bare or as ``CTA/<uid>`` / ``MRA/<uid>`` with a
        ``.nii.gz`` or ``.nii`` suffix.
        """
        if self.valid_files is None:
            return True
        possible_formats = [
            f"CTA/{series_uid}.nii.gz",
            f"CTA/{series_uid}.nii",
            f"MRA/{series_uid}.nii.gz",
            f"MRA/{series_uid}.nii",
            series_uid
        ]
        return any(fmt in self.valid_files for fmt in possible_formats)
    
    def __len__(self):
        """Number of series in the selected split."""
        return len(self.df)
    
    def find_file_with_extensions(self, base_path: Path, series_uid: str, 
                                   possible_extensions: Optional[List[str]] = None) -> Optional[Path]:
        """Return the first existing ``base_path/<series_uid><ext>``, else None.

        ``possible_extensions`` defaults to ``.nii.gz`` then ``.nii``. The
        default is ``None`` rather than a list literal so that no mutable
        object is shared between calls.
        """
        extensions = self._NIFTI_EXTENSIONS if possible_extensions is None else possible_extensions
        for ext in extensions:
            file_path = base_path / f"{series_uid}{ext}"
            if file_path.exists():
                return file_path
        return None
    
    def find_cta_file(self, series_uid: str) -> Optional[Path]:
        """Look for the CTA file of ``series_uid`` in both CTA folders."""
        for base in (self.cta1_root, self.cta2_root):
            found = self.find_file_with_extensions(base, series_uid)
            if found:
                return found
        return None

    @staticmethod
    def _has_signal(volume: torch.Tensor) -> bool:
        """Return True when ``volume`` holds at least one non-zero value.

        A plain ``sum() > 0`` test is not used because signed intensities
        can sum to a negative number even though data is present.
        """
        return bool((volume != 0).any())

    @staticmethod
    def _to_channel_first(volume: np.ndarray) -> np.ndarray:
        """Rearrange a loaded array into ``(channels, x, y, z)`` layout.

        3-D input gains a leading channel axis. 4-D/5-D input with a trailing
        axis of length 3 has that axis moved to the front; any other trailing
        axis is reduced to its first entry. Other ranks are returned as-is.
        """
        if volume.ndim == 5:
            if volume.shape[-1] == 3:
                return volume[..., 0, :].transpose(3, 0, 1, 2)
            return volume[..., 0, 0][np.newaxis, ...]
        if volume.ndim == 4:
            if volume.shape[-1] == 3:
                return volume.transpose(3, 0, 1, 2)
            # Keep only the first frame so that a spatial axis is never
            # mistaken for the channel axis.
            return volume[..., 0][np.newaxis, ...]
        if volume.ndim == 3:
            return volume[np.newaxis, ...]
        return volume

    @staticmethod
    def _match_channels(volume: np.ndarray, expected_channels: int) -> np.ndarray:
        """Crop or repeat the leading axis to reach ``expected_channels``.

        Only the 1-channel and 3-channel targets are adjusted; any other
        combination is returned unchanged.
        """
        current_channels = volume.shape[0]
        if current_channels == expected_channels:
            return volume
        if expected_channels == 1 and current_channels >= 1:
            return volume[0:1, ...]
        if expected_channels == 3:
            if current_channels == 1:
                return np.repeat(volume, 3, axis=0)
            if current_channels >= 3:
                return volume[:3, ...]
        return volume
    
    def load_and_resize_volume_128(self, nifti_path: Optional[Path], 
                                   expected_channels: int = 1) -> torch.Tensor:
        """Load a NIfTI volume and resample it to ``self.target_size``.

        Args:
            nifti_path: File to load; ``None`` or a non-existent path yields
                an all-zero tensor.
            expected_channels: Number of channels of the returned tensor.

        Returns:
            Float tensor of shape ``(expected_channels, *target_size)`` in
            the normal case. Any failure while loading or resizing is logged
            as a warning and an all-zero tensor is returned instead.
        """
        if nifti_path is None or not nifti_path.exists():
            return torch.zeros(expected_channels, *self.target_size, dtype=torch.float32)

        try:
            nifti_img = nib.load(str(nifti_path))
            volume = nifti_img.get_fdata().astype(np.float32)

            volume = self._to_channel_first(volume)
            volume = self._match_channels(volume, expected_channels)

            # Resize to the target size (a batch axis is added for interpolate).
            volume_tensor = torch.from_numpy(volume).unsqueeze(0).float()
            volume_resized = torch.nn.functional.interpolate(
                volume_tensor,
                size=self.target_size,
                mode='trilinear',
                align_corners=False
            )
            return volume_resized.squeeze(0)

        except Exception as exc:
            # Best effort: a broken file must not abort the whole epoch, but
            # the failure is reported instead of being swallowed silently.
            logging.getLogger(__name__).warning(
                "Failed to load volume %s: %s", nifti_path, exc
            )
            return torch.zeros(expected_channels, *self.target_size, dtype=torch.float32)
    
    def find_related_series(self, cta_series_uid: str, modality_folder: Path) -> Optional[Path]:
        """Locate a file in ``modality_folder`` to pair with a CTA series.

        Strategy 1 looks for a file named after the CTA series UID itself.
        Strategy 2 searches the full CSV for series of the folder's modality
        with the same ``PatientAge`` and ``PatientSex`` and returns the first
        one whose file exists. Returns ``None`` when nothing is found.
        """
        # Strategy 1: Direct SeriesUID match
        direct_match = self.find_file_with_extensions(modality_folder, cta_series_uid)
        if direct_match:
            return direct_match
        
        # Strategy 2: Find via patient matching
        # NOTE(review): age + sex is the only matching key, so this may pick
        # a series that belongs to a different patient - confirm intended.
        full_df = getattr(self, 'full_df', None)
        if full_df is None:
            return None
        # Without these columns no matching is possible; avoid a KeyError.
        if not {'Modality', 'PatientAge', 'PatientSex'}.issubset(full_df.columns):
            return None

        cta_rows = full_df[full_df['SeriesInstanceUID'] == cta_series_uid]
        if cta_rows.empty:
            return None
        cta_row = cta_rows.iloc[0]
        patient_age = cta_row.get('PatientAge', None)
        patient_sex = cta_row.get('PatientSex', None)

        # Folder name -> value used in the CSV 'Modality' column.
        modality_name_map = {
            'MRA': 'MRA',
            'MRI_T1post': 'MRI T1post',
            'MRI_T2': 'MRI T2'
        }
        target_modality = None
        for folder_name, csv_modality in modality_name_map.items():
            if folder_name in str(modality_folder):
                target_modality = csv_modality
                break

        if not target_modality or patient_age is None:
            return None

        candidate_rows = full_df[
            (full_df['Modality'] == target_modality) &
            (full_df['PatientAge'] == patient_age) &
            (full_df['PatientSex'] == patient_sex)
        ]
        for candidate_uid in candidate_rows['SeriesInstanceUID']:
            candidate_path = self.find_file_with_extensions(modality_folder, candidate_uid)
            if candidate_path:
                return candidate_path

        return None
    
    def generate_vessel_mask_frangi(self, volume: torch.Tensor) -> torch.Tensor:
        """Generate vessel mask using Frangi if missing.

        The Frangi response is thresholded at its 95th percentile and
        connected components smaller than 100 voxels are removed. The result
        is a float mask with a leading channel axis.
        """
        vol_np = volume.squeeze().cpu().numpy()
        if not vol_np.any():
            # An all-zero volume (e.g. a missing CTA) has no ridges; skip the
            # expensive filter and return an empty mask directly.
            return torch.zeros((1, *vol_np.shape), dtype=torch.float32)

        # black_ridges=False: the mask is derived from the CTA, where
        # contrast-filled vessels are bright; the skimage default (True)
        # enhances dark ridges. The former `scale_step=1` argument is dropped
        # because skimage only honours it together with `scale_range`.
        vol_enh = frangi(vol_np, sigmas=(1, 5), black_ridges=False)
        mask = vol_enh > np.percentile(vol_enh, 95)
        mask = remove_small_objects(mask, min_size=100)
        return torch.from_numpy(mask.astype(np.float32)).unsqueeze(0)
    
    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Assemble all volumes, labels and metadata for row ``idx``."""
        row = self.df.iloc[idx]
        series_uid = row['SeriesInstanceUID']
        
        # Load CTA
        cta_path = self.find_cta_file(series_uid)
        cta = self.load_and_resize_volume_128(cta_path, expected_channels=3)

        # Load other modalities
        mra_path = self.find_related_series(series_uid, self.remains_root / 'MRA')
        mra = self.load_and_resize_volume_128(mra_path, expected_channels=1)
        
        t1_path = self.find_related_series(series_uid, self.remains_root / 'MRI_T1post')
        t1 = self.load_and_resize_volume_128(t1_path, expected_channels=1)
        
        t2_path = self.find_related_series(series_uid, self.remains_root / 'MRI_T2')
        t2 = self.load_and_resize_volume_128(t2_path, expected_channels=1)
        
        # Vessel mask - try to find or generate with Frangi
        vessel_mask_path = self.find_file_with_extensions(
            self.remains_root / 'vessel_masks', 
            f"{series_uid}_vessel_mask"
        )
        vessel_mask = self.load_and_resize_volume_128(vessel_mask_path, expected_channels=1)
        
        if not self._has_signal(vessel_mask):
            vessel_mask = self.generate_vessel_mask_frangi(cta.mean(0, keepdim=True))

        # Modality indicators: 1.0 when the volume contains any data.
        modality_indicators = torch.tensor(
            [1.0 if self._has_signal(vol) else 0.0 for vol in (cta, mra, t1, t2)],
            dtype=torch.float32
        )

        # Labels: missing columns and NaN entries count as negative.
        labels = torch.tensor(
            [
                float(row[col]) if col in row and pd.notna(row[col]) else 0.0
                for col in self.label_cols
            ],
            dtype=torch.float32
        )

        # Metadata
        metadata = {
            'age': float(row.get('PatientAge', 0.0)),
            'sex': 1.0 if row.get('PatientSex', 'Unknown') == 'Male' else 0.0,
            'series_uid': series_uid,
            'modality_count': int(modality_indicators.sum().item()),
            'has_vessel_mask': 1.0 if self._has_signal(vessel_mask) else 0.0
        }

        return {
            'cta': cta,
            'mra': mra,
            't1': t1,
            't2': t2,
            'vessel_mask': vessel_mask,
            'labels': labels,
            'modality_indicators': modality_indicators,
            'metadata': metadata
        }

def dvg_collate_fn(batch: List[Dict]) -> Dict[str, torch.Tensor]:
    """Collate samples: stack tensors, keep ``metadata`` as a plain list.

    The field names are taken from the first sample. Every field except
    ``'metadata'`` is stacked along a new leading batch axis.
    """
    def _gather(field):
        values = [item[field] for item in batch]
        return values if field == 'metadata' else torch.stack(values, dim=0)

    return {field: _gather(field) for field in batch[0].keys()}

# ================================================================================
# HELPER FUNCTIONS
# ================================================================================
def create_dvg_m4oe_complete_blueprint(device: str = "cuda:0", num_gpus: int = 4, **kwargs):
    """Create DVG-M4oE model.

    Args:
        device: Device the model is moved to.
        num_gpus: Requested number of GPUs for ``nn.DataParallel``. When CUDA
            is available the value is capped at the number of visible GPUs,
            so the default of 4 no longer fails on hosts with fewer devices.
        **kwargs: Forwarded unchanged to ``DVG_M4oE_Complete_Blueprint``.

    Returns:
        The model, wrapped in ``nn.DataParallel`` when more than one GPU is
        used.
    """
    model = DVG_M4oE_Complete_Blueprint(**kwargs)
    model = model.to(device)

    # DataParallel rejects device ids that do not exist, so never ask for
    # more replicas than there are visible GPUs.
    if torch.cuda.is_available():
        num_gpus = min(num_gpus, torch.cuda.device_count())

    if num_gpus > 1:
        model = nn.DataParallel(model, device_ids=list(range(num_gpus)))
    
    return model

def get_dvg_dataloaders(
    csv_path: str = "/kaggle/input/rsna-csv/train.csv",
    data_root: str = "/kaggle/input",
    clean_files_path: Optional[str] = "/kaggle/input/rsna-clean-files/clean_files_kaggle.txt",
    batch_size: int = 4,
    num_workers: int = 4,
    target_size: Tuple[int, int, int] = (128, 128, 128),
    train_ratio: float = 0.8,
    verbose: bool = True,
    debug_first_sample: bool = True
) -> Tuple[DataLoader, DataLoader]:
    """Build the training and validation ``DataLoader`` pair.

    Both loaders wrap a ``DVG_MultiModalRSNADataset`` created from the same
    CSV / data root and use ``dvg_collate_fn``. The training loader shuffles
    and drops the last incomplete batch; the validation loader does neither.

    Args:
        csv_path: CSV file passed to both datasets.
        data_root: Dataset root directory passed to both datasets.
        clean_files_path: Optional clean-file list passed to both datasets.
        batch_size: Batch size of both loaders.
        num_workers: Worker processes per loader; workers are kept alive
            between epochs when this is greater than zero.
        target_size: Spatial size the volumes are resampled to.
        train_ratio: Fraction of rows assigned to the training split. It is
            forwarded to both datasets so the two splits stay complementary.
        verbose: Forwarded to both datasets.
        debug_first_sample: Forwarded to the training dataset only; the
            validation dataset always receives ``False``.

    Returns:
        ``(train_loader, val_loader)``.
    """
    # Arguments that must be identical for the two splits.
    shared_dataset_kwargs = dict(
        csv_path=csv_path,
        data_root=data_root,
        clean_files_path=clean_files_path,
        train_ratio=train_ratio,
        target_size=target_size,
        verbose=verbose,
    )

    train_dataset = DVG_MultiModalRSNADataset(
        split='train',
        debug_first_sample=debug_first_sample,
        **shared_dataset_kwargs
    )
    
    val_dataset = DVG_MultiModalRSNADataset(
        split='val',
        debug_first_sample=False,
        **shared_dataset_kwargs
    )

    keep_workers_alive = num_workers > 0

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
        collate_fn=dvg_collate_fn,
        persistent_workers=keep_workers_alive
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        collate_fn=dvg_collate_fn,
        persistent_workers=keep_workers_alive
    )

    return train_loader, val_loader

# ================================================================================
# CONFIGURATION CLASS
# ================================================================================
class ProductionConfigFixed:
    """Production configuration with innovations.

    Plain attribute container: every setting is assigned in ``__init__`` and
    may be overridden on the instance afterwards.
    """
    
    def __init__(self):
        # Data paths
        self.csv_path = "/kaggle/input/rsna-csv/train.csv"
        self.data_root = "/kaggle/input"
        self.clean_files_path = "/kaggle/input/rsna-clean-files/clean_files_kaggle.txt"
        
        # Training parameters
        self.num_epochs = 20
        self.batch_size = 4
        self.learning_rate = 1e-5
        self.weight_decay = 1e-4
        
        # Data loading
        # os.cpu_count() returns None when the count cannot be determined;
        # fall back to 0 so the value is always a valid worker count.
        self.num_workers = os.cpu_count() or 0
        self.pin_memory = True
        
        # Model parameters
        self.num_classes = 14
        self.feature_dim = 512
        
        # Training optimization
        self.grad_clip_norm = 1.0
        self.cmi_loss_weight = 0.1
        self.mixed_precision = True
        self.memory_cleanup_freq = 50
        
        # Focal loss parameters
        self.focal_gamma = 2.0
        self.focal_alpha = None
        
        # Checkpoint configuration
        self.checkpoint_config = {
            'save_checkpoints': True,
            'checkpoint_dir': '/kaggle/working/checkpoints',
            'model_name': 'DVG_M4oE_RSNA2025',
            'save_best_model': True,
            'save_latest_model': True,
            'save_final_model': True,
            'save_periodic': True,
            'periodic_epochs': 5,
            'save_model_state': True,
            'save_optimizer_state': True,
            'save_scheduler_state': True,
            'save_scaler_state': True,
            'save_training_history': True,
            'save_config': True,
        }
        
        # Early stopping configuration
        self.early_stopping_config = {
            'primary_metric': 'mean_auc',
            'primary_mode': 'max',
            'primary_patience': 5,
            'primary_min_delta': 1e-4,
            'max_epochs_no_improvement': 8,
        }
        
        # RSNA task names (one per class, in output order)
        self.rsna_task_names = [
            'Left Infraclinoid Internal Carotid Artery',
            'Right Infraclinoid Internal Carotid Artery', 
            'Left Supraclinoid Internal Carotid Artery',
            'Right Supraclinoid Internal Carotid Artery',
            'Left Middle Cerebral Artery',
            'Right Middle Cerebral Artery',
            'Anterior Communicating Artery',
            'Left Anterior Cerebral Artery', 
            'Right Anterior Cerebral Artery',
            'Left Posterior Communicating Artery',
            'Right Posterior Communicating Artery',
            'Basilar Tip',
            'Other Posterior Circulation',
            'Aneurysm Present'
        ]
        
        # Short display names, index-aligned with rsna_task_names
        self.task_short_names = [
            'L_Infra_ICA', 'R_Infra_ICA', 'L_Supra_ICA', 'R_Supra_ICA',
            'L_MCA', 'R_MCA', 'AComA', 'L_ACA', 'R_ACA', 
            'L_PComA', 'R_PComA', 'Basilar', 'Other_Post', 'Aneurysm'
        ]

# ================================================================================
# TRAINER CLASS
# ================================================================================
class ProductionTrainerFixed:
    """Production trainer with innovations integrated.

    Owns the model, optimizer, LR scheduler, optional AMP gradient scaler,
    metric trackers, checkpoint manager and early-stopping helper, and runs
    the epoch loop in :meth:`train`.
    """
    
    def __init__(self, config: "ProductionConfigFixed"):
        """Build every training component from ``config``."""
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.num_gpus = torch.cuda.device_count()
        
        # GPU optimizations
        if torch.cuda.is_available():
            torch.backends.cudnn.benchmark = True
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
        
        # Create model
        self.model = create_dvg_m4oe_complete_blueprint(
            device=self.device,
            num_gpus=self.num_gpus,
            num_classes=config.num_classes,
            feature_dim=config.feature_dim
        )
        
        # Unwrapped module, independent of DataParallel wrapping.
        if isinstance(self.model, nn.DataParallel):
            self.base_model = self.model.module
        else:
            self.base_model = self.model
        
        # Create optimizer
        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay
        )
        
        # Scheduler
        self.scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
            self.optimizer, 
            T_0=5,
            T_mult=2,
            eta_min=1e-6
        )
        
        # Mixed precision
        self.scaler = GradScaler() if config.mixed_precision else None
        
        # Metrics
        self.train_metrics = ProductionMetricsFixed(config.rsna_task_names, config.task_short_names)
        self.val_metrics = ProductionMetricsFixed(config.rsna_task_names, config.task_short_names)
        
        # Checkpoint manager
        self.checkpoint_manager = CheckpointManager(config.checkpoint_config)
        
        # Early stopping
        self.early_stopping = EnhancedEarlyStopping(config.early_stopping_config)
        
        # Training state
        self.training_history = []
        self.start_time = None
    
    def _create_modality_indicators(self, batch):
        """Create modality indicators.

        Returns a ``(B, 4)`` float tensor. Column 0 (CTA) is always 1; the
        columns for ``mra``, ``t1`` and ``t2`` are 1 when the volume's sum
        exceeds 0.01 and stay 1 when the key is absent from ``batch``.
        """
        B = batch['cta'].size(0)
        device = batch['cta'].device
        
        modality_indicators = torch.ones(B, 4, device=device, dtype=torch.float32)
        
        for column, modality in enumerate(['mra', 't1', 't2'], start=1):
            if modality in batch:
                modality_sum = batch[modality].sum(dim=(1,2,3,4))
                modality_indicators[:, column] = (modality_sum > 0.01).float()
        
        return modality_indicators

    def _move_batch_to_device(self, batch):
        """Move every tensor value of ``batch`` to ``self.device`` in place."""
        for key, value in batch.items():
            if isinstance(value, torch.Tensor):
                # non_blocking lets copies from pinned host memory overlap
                # with compute; it is ignored for unpinned tensors.
                batch[key] = value.to(self.device, non_blocking=True)
        return batch

    def _forward_and_loss(self, batch, modality_indicators):
        """Run the model and the loss, under autocast when AMP is active.

        Returns:
            ``(outputs, loss_dict, use_amp)`` where ``use_amp`` tells the
            caller whether the gradient scaler must be used for the backward
            pass.
        """
        use_amp = bool(self.config.mixed_precision) and self.scaler is not None
        # autocast(enabled=False) is a no-op context, so one code path
        # serves both precisions.
        with autocast(enabled=use_amp):
            outputs = self.model(batch)
            loss_dict = compute_total_loss_corrected(
                outputs['predictions'], batch['labels'], outputs, 
                modality_indicators, beta=self.config.cmi_loss_weight,
                focal_gamma=self.config.focal_gamma
            )
        return outputs, loss_dict, use_amp
    
    def train_epoch(self, train_loader, epoch):
        """Training epoch.

        Runs one optimisation step per batch (with gradient clipping when
        ``grad_clip_norm > 0``) and returns the metrics computed by
        ``self.train_metrics``.
        """
        self.model.train()
        self.train_metrics.reset()
        
        pbar = tqdm(
            enumerate(train_loader), 
            total=len(train_loader),
            desc=f"🚀 Epoch {epoch+1}",
            ncols=140
        )
        
        for batch_idx, batch in pbar:
            batch_start = time.time()
            
            self._move_batch_to_device(batch)
            modality_indicators = self._create_modality_indicators(batch)
            self.optimizer.zero_grad()
            
            outputs, loss_dict, use_amp = self._forward_and_loss(batch, modality_indicators)
            
            if use_amp:
                self.scaler.scale(loss_dict['total_loss']).backward()
                
                if self.config.grad_clip_norm > 0:
                    # Gradients must be unscaled before their norm is clipped.
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_clip_norm)
                
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                loss_dict['total_loss'].backward()
                
                if self.config.grad_clip_norm > 0:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_clip_norm)
                
                self.optimizer.step()
            
            batch_time = time.time() - batch_start
            current_lr = self.optimizer.param_groups[0]['lr']
            
            # Update metrics
            self.train_metrics.update(
                outputs['predictions'], batch['labels'],
                loss_dict['total_loss'], loss_dict['classification_loss'], loss_dict['cmi_loss'],
                batch_time, current_lr
            )
            
            # Update progress bar
            if batch_idx % 10 == 0:
                pbar.set_postfix({
                    'Loss': f"{loss_dict['total_loss'].item():.4f}",
                    'Class': f"{loss_dict['classification_loss'].item():.4f}",
                    'CMI': f"{loss_dict['cmi_loss'].item():.3f}",
                    'LR': f"{current_lr:.2e}",
                    'Speed': f"{self.config.batch_size/batch_time:.1f}sps"
                })
            
            # Memory cleanup (only meaningful with CUDA; a frequency of 0
            # disables it instead of raising ZeroDivisionError).
            cleanup_freq = self.config.memory_cleanup_freq
            if cleanup_freq and batch_idx % cleanup_freq == 0 and torch.cuda.is_available():
                torch.cuda.empty_cache()
        
        pbar.close()
        return self.train_metrics.compute_rsna_metrics()
    
    def validate_epoch(self, val_loader, epoch):
        """Validation epoch.

        Evaluates the model without gradients and returns the metrics
        computed by ``self.val_metrics``.
        """
        self.model.eval()
        self.val_metrics.reset()
        
        pbar = tqdm(
            val_loader,
            desc=f"📊 Validation {epoch+1}",
            ncols=120
        )
        
        with torch.no_grad():
            for batch in pbar:
                batch_start = time.time()
                
                self._move_batch_to_device(batch)
                modality_indicators = self._create_modality_indicators(batch)
                
                outputs, loss_dict, _ = self._forward_and_loss(batch, modality_indicators)
                
                batch_time = time.time() - batch_start
                
                self.val_metrics.update(
                    outputs['predictions'], batch['labels'],
                    loss_dict['total_loss'], loss_dict['classification_loss'], loss_dict['cmi_loss'],
                    batch_time
                )
        
        pbar.close()
        return self.val_metrics.compute_rsna_metrics()
    
    def print_results(self, epoch, train_results, val_results):
        """Print comprehensive results.

        Both result dicts must provide the keys ``loss``, ``mean_auc``,
        ``classification_loss`` and ``cmi_loss``.
        """
        print(f"\n🚀 EPOCH {epoch+1} RESULTS:")
        print("="*120)
        
        # Main metrics
        print(f"📊 Main Metrics:")
        print(f"   Train - Loss: {train_results['loss']:.4f}, AUC: {train_results['mean_auc']:.4f}")
        print(f"   Val   - Loss: {val_results['loss']:.4f}, AUC: {val_results['mean_auc']:.4f}")
        
        # Loss breakdown
        print(f"📉 Loss Breakdown:")
        print(f"   Train - Class: {train_results['classification_loss']:.4f}, CMI: {train_results['cmi_loss']:.4f}")
        print(f"   Val   - Class: {val_results['classification_loss']:.4f}, CMI: {val_results['cmi_loss']:.4f}")
        
        print("="*120)
    
    def train(self, train_loader, val_loader):
        """Complete production training.

        Runs up to ``config.num_epochs`` epochs of train + validate, steps
        the scheduler, feeds the validation results to early stopping,
        records history and writes checkpoints. A final checkpoint is written
        only if at least one epoch ran.

        Returns:
            ``self.training_history`` (one entry per completed epoch).
        """
        print(f"🚀 STARTING PRODUCTION DVG TRAINING:")
        print("="*120)
        
        self.start_time = time.time()

        # Pre-bind the loop results so the code after the loop is well
        # defined even when num_epochs is 0.
        epoch = train_results = val_results = None
        stop_reason = ""
        
        for epoch in range(self.config.num_epochs):
            print(f"\n🔄 EPOCH {epoch+1}/{self.config.num_epochs}")
            
            # Train and validate
            train_results = self.train_epoch(train_loader, epoch)
            val_results = self.validate_epoch(val_loader, epoch)
            
            # Scheduler step
            self.scheduler.step()
            
            # Print results
            self.print_results(epoch, train_results, val_results)
            
            # Early stopping check
            should_stop, stop_reason, improved = self.early_stopping.step(
                epoch, val_results, self.model, self.optimizer
            )
            
            # Save history
            self.training_history.append({
                'epoch': epoch,
                'train': train_results,
                'val': val_results,
                'timestamp': datetime.now().isoformat()
            })
            
            # Save checkpoints
            if self.config.checkpoint_config['save_checkpoints']:
                checkpoint_args = (
                    self.model, self.optimizer, self.scheduler, self.scaler,
                    epoch, train_results, val_results, self.training_history, self.config
                )
                self.checkpoint_manager.save_latest_checkpoint(*checkpoint_args)
                
                if improved:
                    self.checkpoint_manager.save_best_checkpoint(*checkpoint_args)
                
                self.checkpoint_manager.save_periodic_checkpoint(*checkpoint_args)
            
            # Check for early stopping
            if should_stop:
                print(f"\n🛑 EARLY STOPPING TRIGGERED: {stop_reason}")
                break
        
        # Save final checkpoint (skipped when no epoch was executed, as
        # there would be no results to store).
        if self.config.checkpoint_config['save_checkpoints'] and train_results is not None:
            self.checkpoint_manager.save_final_checkpoint(
                self.model, self.optimizer, self.scheduler, self.scaler,
                epoch, train_results, val_results, self.training_history, self.config, stop_reason
            )
        
        total_time = time.time() - self.start_time
        print(f"\n🎉 TRAINING COMPLETED!")
        print(f"   ⏱️ Total Time: {total_time/3600:.1f} hours")
        print(f"   🏆 Best Score: {self.early_stopping.best_primary_score:.4f}")
        
        return self.training_history

# ================================================================================
# MAIN FUNCTION
# ================================================================================
def train_production_dvg_corrected():
    """Run the full pipeline: config, dataloaders, trainer, training loop.

    Returns:
        ``(trainer, history)`` where ``history`` is whatever
        ``ProductionTrainerFixed.train`` returns.
    """
    banner = "="*140
    print(banner)
    print("🚀 DVG-M4oE PRODUCTION TRAINING - CORRECTED VERSION")
    print(banner)
    
    config = ProductionConfigFixed()
    
    print("📥 Creating production dataloaders...")
    loader_options = {
        'csv_path': config.csv_path,
        'data_root': config.data_root,
        'clean_files_path': config.clean_files_path,
        'batch_size': config.batch_size,
        'num_workers': config.num_workers,
        # Volumes are resampled to this size rather than the loader default.
        'target_size': (32, 358, 358),
        'verbose': True,
    }
    train_loader, val_loader = get_dvg_dataloaders(**loader_options)
    
    trainer = ProductionTrainerFixed(config)
    return trainer, trainer.train(train_loader, val_loader)

# ================================================================================
# EXECUTION
# ================================================================================
if __name__ == "__main__":
    # Script entry point: run the whole pipeline and, on any failure, report
    # the error with its traceback instead of letting it propagate.
    try:
        trainer, history = train_production_dvg_corrected()
        print("\n🏆 TRAINING COMPLETED SUCCESSFULLY!")
    except Exception as error:
        print(f"❌ Training failed: {error}")
        traceback.print_exc()
