In [None]:
"""
ECG Digitization - ROBUST Training Pipeline (NaN Loss Fixed)

FIXES FOR NaN LOSS:
1. Signal normalization - ECG signals can have extreme values
2. Gradient clipping - Prevent exploding gradients
3. Learning rate warmup - Start with small LR
4. Loss masking validation - Check for invalid values
5. Input validation - Check data quality
6. Stable initialization - Proper weight initialization
"""

import os
import gc
import warnings
from pathlib import Path

import numpy as np
import pandas as pd
import cv2
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
import torchvision.models as models

import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.model_selection import train_test_split

warnings.filterwarnings('ignore')

# ===================== CONFIGURATION =====================
class RobustConfig:
    """Paths and hyper-parameters for the ECG digitization training run.

    Every setting is a class attribute, so it can be read straight off the
    class. Instantiating the class creates ``MODEL_DIR`` (and any missing
    parents) if it is not there yet.
    """

    # ---- Filesystem locations ----
    DATA_DIR = Path('/kaggle/input/physionet-ecg-image-digitization')
    OUTPUT_DIR = Path('/kaggle/working')
    MODEL_DIR = OUTPUT_DIR / 'models'

    # ---- Input images / batching ----
    IMG_SIZE = (384, 768)        # (height, width) that images are resized to
    BATCH_SIZE = 4               # kept slightly larger for more stable gradients
    GRAD_ACCUM_STEPS = 4         # micro-batches accumulated per optimizer step

    # ---- Architecture ----
    MODEL_NAME = 'efficientnet_b0'  # not read by the code visible in this file

    # ---- Optimisation (chosen conservatively to avoid divergence) ----
    EPOCHS = 20
    LR = 1e-4                    # low peak learning rate
    WEIGHT_DECAY = 1e-5
    WARMUP_EPOCHS = 2            # linear LR ramp before cosine decay
    MAX_GRAD_NORM = 1.0          # global gradient-norm clipping threshold

    # ---- Numerical stability switches ----
    SIGNAL_NORMALIZATION = True  # rescale target ECG signals per lead
    USE_MIXED_PRECISION = True
    LOSS_SCALE = 1.0             # not read by the code visible in this file

    # ---- DataLoader ----
    NUM_WORKERS = 2
    PIN_MEMORY = True

    # ---- Diagnostics ----
    DEBUG_MODE = True            # verbose checks and a small data subset
    CHECK_NAN_FREQUENCY = 10     # inspect inputs for NaN every N batches

    # ---- ECG signal layout ----
    LEADS = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']
    MAX_SIGNAL_LENGTH = 5000     # samples per lead after truncation/padding

    # ---- Compute device (resolved once, at class-definition time) ----
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

    def __init__(self):
        # Make sure checkpoints have somewhere to go.
        self.MODEL_DIR.mkdir(parents=True, exist_ok=True)

# Single shared configuration instance read by every component below.
config = RobustConfig()

# Echo the stability-related settings once at start-up.
print("\n".join([
    "Robust Configuration:",
    f"  Signal Normalization: {config.SIGNAL_NORMALIZATION}",
    f"  Learning Rate: {config.LR}",
    f"  Warmup Epochs: {config.WARMUP_EPOCHS}",
    f"  Gradient Clipping: {config.MAX_GRAD_NORM}",
    f"  Debug Mode: {config.DEBUG_MODE}",
]))

# ===================== SIGNAL NORMALIZATION =====================
class SignalNormalizer:
    """Per-lead robust rescaling of ECG signals to keep regression targets small."""

    @staticmethod
    def normalize(signals):
        """Return a float32 copy of ``signals`` with every lead robustly rescaled.

        Each lead (row) is clipped to [-10, 10], centred on its mean and divided
        by its 5th-95th percentile spread. The result is clipped to [-5, 5].
        Leads whose spread is <= 0.01 (e.g. flat or all-missing leads) are
        only clipped, not rescaled.

        NaN samples are ignored when computing the mean and percentiles, and
        they remain NaN in the output so callers can detect and mask them.
        Previously a single NaN made ``np.percentile`` return NaN, which
        silently skipped normalization for the whole lead.

        Args:
            signals: 2-D array-like, one lead per row
                (iterated along axis 0).

        Returns:
            np.ndarray of dtype float32 with the same shape as ``signals``.
        """
        # Work in float64 so integer inputs are not truncated and stats are
        # accurate. Clipping turns +/-inf into +/-10 and leaves NaN as NaN.
        signals = np.clip(np.asarray(signals, dtype=np.float64), -10, 10)

        normalized = np.empty_like(signals)
        for i, lead in enumerate(signals):
            finite = lead[np.isfinite(lead)]
            if finite.size:
                p5, p95 = np.percentile(finite, [5, 95])
                spread = p95 - p5
            else:
                spread = 0.0  # nothing to estimate statistics from

            if spread > 0.01:  # avoid blowing up near-constant leads
                normalized[i] = (lead - finite.mean()) / (spread + 1e-8)
            else:
                normalized[i] = lead

        # Final clip bounds the targets; NaN entries are preserved.
        return np.clip(normalized, -5, 5).astype(np.float32)

    @staticmethod
    def denormalize(signals, stats):
        """Return ``signals`` unchanged.

        The per-lead statistics are not stored by ``normalize``, so ``stats``
        is currently ignored and predictions stay in the normalized scale the
        model was trained on.
        """
        return signals

# ===================== ROBUST DATASET =====================
class RobustECGDataset(Dataset):
    """ECG image -> multi-lead signal dataset with normalization and loss masks.

    Each item is a dict:
      * ``image``   - output of ``transform`` (a float CHW tensor when no
                      transform is given).
      * ``signals`` - float32 tensor (len(config.LEADS), MAX_SIGNAL_LENGTH).
      * ``mask``    - float32 tensor of the same shape. It is 1 only where the
                      target sample is real and finite, and 0 for padding,
                      missing (non-finite) values and failed loads, so those
                      positions never contribute to the loss.
    """

    def __init__(self, df, image_dir, transform=None, is_training=True):
        self.df = df.reset_index(drop=True)
        self.image_dir = image_dir
        self.transform = transform
        self.is_training = is_training
        self.normalizer = SignalNormalizer()

        # Cheap sanity check of the first few samples' files.
        if config.DEBUG_MODE:
            self.validate_paths()

    def validate_paths(self):
        """Report how many expected image/CSV files are missing among the first 10 rows."""
        print("Validating dataset paths...")
        missing = 0
        for idx in range(min(10, len(self.df))):
            ecg_id = self.df.iloc[idx]['id']
            sample_dir = self.image_dir / str(ecg_id)
            if not (sample_dir / f"{ecg_id}-0001.png").exists():
                missing += 1
            if not (sample_dir / f"{ecg_id}.csv").exists():
                missing += 1

        if missing > 0:
            print(f"⚠ Warning: {missing} files missing in first 10 samples")
        else:
            print("✓ All paths valid")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        num_leads = len(config.LEADS)
        try:
            row = self.df.iloc[idx]
            ecg_id = row['id']
            sample_dir = self.image_dir / str(ecg_id)

            # Training samples a random image segment as augmentation;
            # evaluation always uses the first one.
            if self.is_training:
                segment = np.random.choice(['0001', '0003', '0004'])
            else:
                segment = '0001'

            img = cv2.imread(str(sample_dir / f"{ecg_id}-{segment}.png"))
            if img is None:
                # Missing/unreadable image: uniform mid-gray placeholder.
                img = np.full((*config.IMG_SIZE, 3), 128, dtype=np.uint8)
            else:
                # cv2.resize takes (width, height).
                img = cv2.resize(img, (config.IMG_SIZE[1], config.IMG_SIZE[0]))
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

            # Load leads, truncating to MAX_SIGNAL_LENGTH and zero-padding the rest.
            signal_df = pd.read_csv(sample_dir / f"{ecg_id}.csv", usecols=config.LEADS)
            n_rows = min(len(signal_df), config.MAX_SIGNAL_LENGTH)
            signals = np.zeros((num_leads, config.MAX_SIGNAL_LENGTH), dtype=np.float32)
            for i, lead in enumerate(config.LEADS):
                signals[i, :n_rows] = signal_df[lead].values[:n_rows]

            if config.SIGNAL_NORMALIZATION:
                signals = self.normalizer.normalize(signals)

            # Remember which samples are usable BEFORE replacing non-finite
            # values, so the filler values are never supervised.
            valid = np.isfinite(signals)
            if not valid.all():
                if config.DEBUG_MODE:
                    print(f"Warning: Non-finite values in signals for {ecg_id}")
                signals = np.nan_to_num(signals, nan=0.0, posinf=5.0, neginf=-5.0)

            if self.transform:
                img = self.transform(image=img)['image']
            else:
                # Without a transform, still return a model-ready CHW float tensor.
                img = torch.from_numpy(np.ascontiguousarray(img.transpose(2, 0, 1))).float() / 255.0

            # Supervise only real (non-padded) and finite samples. Also cap by
            # the rows actually loaded, in case the CSV is shorter than sig_len.
            actual_len = min(int(row['sig_len']), n_rows)
            mask = np.zeros((num_leads, config.MAX_SIGNAL_LENGTH), dtype=np.float32)
            mask[:, :actual_len] = 1.0
            mask *= valid

            return {
                'image': img,
                'signals': torch.from_numpy(signals),
                'mask': torch.from_numpy(mask)
            }

        except Exception as e:
            print(f"Error loading sample {idx}: {e}")
            # Placeholder with an all-zero mask so it adds nothing to the loss
            # (an all-ones mask would teach the model to output zeros).
            return {
                'image': torch.zeros(3, config.IMG_SIZE[0], config.IMG_SIZE[1]),
                'signals': torch.zeros(num_leads, config.MAX_SIGNAL_LENGTH),
                'mask': torch.zeros(num_leads, config.MAX_SIGNAL_LENGTH)
            }

# ===================== TRANSFORMS =====================
def get_robust_transforms(is_training):
    """Build the albumentations pipeline for training or evaluation.

    Both pipelines end with ImageNet mean/std normalization and conversion to
    a tensor. Training additionally prepends light Gaussian noise and
    brightness/contrast jitter, each applied with probability 0.3.
    """
    steps = []
    if is_training:
        steps.extend([
            A.GaussNoise(var_limit=(5, 15), p=0.3),
            A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.3),
        ])
    steps.append(A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
    steps.append(ToTensorV2())
    return A.Compose(steps)

# ===================== ROBUST MODEL =====================
class RobustECGModel(nn.Module):
    """EfficientNet-B0 image encoder with an MLP head regressing all lead samples.

    Output shape is (batch, 12, config.MAX_SIGNAL_LENGTH), bounded to [-5, 5]
    by a scaled tanh.
    """

    def __init__(self):
        super().__init__()

        # Backbone: ImageNet-pretrained EfficientNet-B0 with its classifier removed.
        self.backbone = self._build_backbone()
        # Derive the pooled feature width from the head instead of hard-coding it.
        self.feature_dim = self.backbone.classifier[-1].in_features
        self.backbone.classifier = nn.Identity()

        # Decoder: LayerNorm after each hidden layer for stable activations.
        self.decoder = nn.Sequential(
            nn.Linear(self.feature_dim, 512),
            nn.LayerNorm(512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 256),
            nn.LayerNorm(256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 12 * config.MAX_SIGNAL_LENGTH)
        )

        self._initialize_weights()

    @staticmethod
    def _build_backbone():
        """Create EfficientNet-B0 with ImageNet weights.

        Uses the ``weights=`` API on torchvision versions that have it. Older
        versions fall back to the deprecated ``pretrained=True``; both select
        the same ImageNet-1K checkpoint.
        """
        weights_enum = getattr(models, 'EfficientNet_B0_Weights', None)
        if weights_enum is not None:
            return models.efficientnet_b0(weights=weights_enum.DEFAULT)
        return models.efficientnet_b0(pretrained=True)

    def _initialize_weights(self):
        """Xavier-uniform decoder weights with a small gain, and zero biases."""
        for m in self.decoder.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight, gain=0.1)  # small initial outputs
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Map an image batch to per-lead signals of shape (B, 12, MAX_SIGNAL_LENGTH)."""
        features = self.backbone(x)

        # Guard against non-finite backbone activations (e.g. fp16 overflow).
        if not torch.isfinite(features).all():
            print("Warning: NaN in backbone features!")
            features = torch.nan_to_num(features, nan=0.0, posinf=1.0, neginf=-1.0)

        out = self.decoder(features)
        out = out.view(-1, 12, config.MAX_SIGNAL_LENGTH)

        # Bound outputs to [-5, 5].
        return torch.tanh(out) * 5.0

# ===================== ROBUST LOSS =====================
class RobustLoss(nn.Module):
    """Masked mean-squared error that sanitizes non-finite inputs.

    ``forward(pred, target, mask)`` replaces NaN/Inf entries of ``pred`` and
    ``target`` and then returns the squared error summed over masked positions,
    divided by ``mask.sum()``.

    If the resulting loss is still non-finite, a warning is printed and the
    loss is returned unchanged so callers can detect it with ``torch.isfinite``
    and skip the batch. The previous behaviour substituted a detached constant
    ``1.0``. That masked the problem, made ``backward()`` fail (the constant
    has no graph), and polluted validation averages with a fake value.
    """

    def __init__(self):
        super().__init__()

    def forward(self, pred, target, mask):
        # Sanitize inputs so a few bad values cannot poison the whole batch.
        if not torch.isfinite(pred).all():
            print("⚠ NaN/Inf in predictions!")
            pred = torch.nan_to_num(pred, nan=0.0, posinf=5.0, neginf=-5.0)

        if not torch.isfinite(target).all():
            print("⚠ NaN/Inf in targets!")
            target = torch.nan_to_num(target, nan=0.0, posinf=5.0, neginf=-5.0)

        # Zero out unsupervised positions before squaring, so large values
        # there cannot overflow.
        pred = pred * mask
        target = target * mask

        loss = F.mse_loss(pred, target, reduction='none')
        loss = (loss * mask).sum() / (mask.sum() + 1e-8)

        if not torch.isfinite(loss):
            # Propagate so the caller's isfinite() check can skip this batch.
            print("⚠ NaN/Inf in loss!")

        return loss

# ===================== ROBUST TRAINER =====================
class RobustTrainer:
    """Training/validation loop with AMP, gradient accumulation, clipping and NaN guards.

    All hyper-parameters are read from the module-level ``config``.
    """

    def __init__(self, model, train_loader, val_loader):
        self.model = model.to(config.DEVICE)
        self.train_loader = train_loader
        self.val_loader = val_loader

        # CUDA autocast/GradScaler only do anything on a CUDA device; enable
        # them only there instead of letting them warn and disable on CPU.
        self.use_amp = bool(config.USE_MIXED_PRECISION) and str(config.DEVICE).startswith('cuda')

        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=config.LR,
            weight_decay=config.WEIGHT_DECAY,
            eps=1e-8
        )

        # Per-epoch LR schedule: linear warmup, then cosine decay.
        self.scheduler = self.get_scheduler()

        self.scaler = GradScaler(enabled=self.use_amp)
        self.criterion = RobustLoss()

        self.best_loss = float('inf')
        self.global_step = 0  # counts successfully processed micro-batches

    def get_scheduler(self):
        """Return a LambdaLR that ramps linearly for WARMUP_EPOCHS, then cosine-decays.

        The scheduler is stepped once per epoch (see ``fit``).
        """
        from torch.optim.lr_scheduler import LambdaLR

        warmup = config.WARMUP_EPOCHS
        # max(1, ...) avoids division by zero if EPOCHS == WARMUP_EPOCHS.
        decay_epochs = max(1, config.EPOCHS - warmup)

        def lr_lambda(epoch):
            if epoch < warmup:
                return (epoch + 1) / warmup
            progress = (epoch - warmup) / decay_epochs
            return 0.5 * (1 + np.cos(np.pi * progress))

        return LambdaLR(self.optimizer, lr_lambda)

    def _autocast(self):
        """Mixed-precision context for forward passes (a no-op when AMP is off)."""
        return torch.autocast(device_type='cuda', dtype=torch.float16, enabled=self.use_amp)

    def _optimizer_step(self):
        """Apply the accumulated gradients: unscale, clip, step, update scaler, reset."""
        # Unscale first so clipping sees true gradient magnitudes.
        self.scaler.unscale_(self.optimizer)
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(),
            config.MAX_GRAD_NORM
        )
        if config.DEBUG_MODE and grad_norm > 10:
            print(f"⚠ Large gradient norm: {grad_norm:.2f}")
        # scaler.step skips the update if inf/NaN gradients were found.
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()

    def check_model_health(self):
        """Return False (and report the first offender) if any parameter is non-finite."""
        for name, param in self.model.named_parameters():
            if not torch.isfinite(param).all():
                print(f"⚠ NaN/Inf in parameter: {name}")
                return False
        return True

    def train_epoch(self, epoch):
        """Run one training epoch and return the mean (un-accumulated) loss.

        Gradients are accumulated over GRAD_ACCUM_STEPS successful
        micro-batches per optimizer step. A final partial window is still
        applied at the end of the epoch instead of being discarded. Batches
        with a non-finite loss are skipped without throwing away gradients
        already accumulated in the current window.

        Returns NaN when no batch produced a usable loss.
        """
        self.model.train()
        total_loss = 0.0
        num_batches = 0
        pending = 0  # micro-batches backpropagated since the last optimizer step

        self.optimizer.zero_grad()

        pbar = tqdm(self.train_loader, desc=f'Epoch {epoch+1}/{config.EPOCHS}')

        for step, batch in enumerate(pbar):
            try:
                images = batch['image'].to(config.DEVICE, non_blocking=True)
                signals = batch['signals'].to(config.DEVICE, non_blocking=True)
                mask = batch['mask'].to(config.DEVICE, non_blocking=True)

                # Periodic input sanity checks.
                if config.DEBUG_MODE and step % config.CHECK_NAN_FREQUENCY == 0:
                    if not torch.isfinite(images).all():
                        print(f"⚠ NaN in images at step {step}")
                    if not torch.isfinite(signals).all():
                        print(f"⚠ NaN in signals at step {step}")

                with self._autocast():
                    pred = self.model(images)
                    loss = self.criterion(pred, signals, mask)
                    loss = loss / config.GRAD_ACCUM_STEPS

                if not torch.isfinite(loss):
                    # No backward ran, so gradients already accumulated in this
                    # window are still valid; keep them.
                    print(f"⚠ NaN loss at step {step}, skipping batch")
                    continue

                self.scaler.scale(loss).backward()
                pending += 1

                if pending == config.GRAD_ACCUM_STEPS:
                    self._optimizer_step()
                    pending = 0

                total_loss += loss.item() * config.GRAD_ACCUM_STEPS
                num_batches += 1

                avg_loss = total_loss / num_batches
                pbar.set_postfix({
                    'loss': f'{avg_loss:.6f}',
                    'lr': f'{self.optimizer.param_groups[0]["lr"]:.2e}'
                })

                self.global_step += 1

            except RuntimeError as e:
                print(f"Error at step {step}: {e}")
                # Gradients may be partially written; drop the whole window.
                self.optimizer.zero_grad()
                pending = 0
                continue

        # Apply the leftover partial accumulation window instead of letting
        # the next epoch's zero_grad() discard it.
        if pending:
            try:
                self._optimizer_step()
            except RuntimeError as e:
                print(f"Error at step {len(self.train_loader) - 1}: {e}")
                self.optimizer.zero_grad()

        if num_batches == 0:
            return float('nan')

        return total_loss / num_batches

    def validate(self):
        """Return the mean loss over validation batches with a finite loss (NaN if none)."""
        self.model.eval()
        total_loss = 0.0
        num_batches = 0

        with torch.no_grad():
            for batch in tqdm(self.val_loader, desc='Validation'):
                try:
                    images = batch['image'].to(config.DEVICE, non_blocking=True)
                    signals = batch['signals'].to(config.DEVICE, non_blocking=True)
                    mask = batch['mask'].to(config.DEVICE, non_blocking=True)

                    with self._autocast():
                        pred = self.model(images)
                        loss = self.criterion(pred, signals, mask)

                    if torch.isfinite(loss):
                        total_loss += loss.item()
                        num_batches += 1

                except RuntimeError as e:
                    print(f"Error in validation: {e}")
                    continue

        if num_batches == 0:
            return float('nan')

        return total_loss / num_batches

    def fit(self):
        """Train for config.EPOCHS epochs, saving the best model by validation loss.

        Stops early if any parameter becomes non-finite. Returns the best
        finite validation loss seen (inf if none).
        """
        print("\nStarting training...")
        print(f"Total parameters: {sum(p.numel() for p in self.model.parameters()):,}")

        for epoch in range(config.EPOCHS):
            if not self.check_model_health():
                print("⚠ Model has NaN parameters, stopping training")
                break

            train_loss = self.train_epoch(epoch)
            val_loss = self.validate()

            print(f'Epoch {epoch+1}: Train={train_loss:.6f}, Val={val_loss:.6f}')

            if np.isfinite(val_loss) and val_loss < self.best_loss:
                self.best_loss = val_loss
                self.save_model('best_model.pth')
                print(f'✓ Best model saved (val={val_loss:.6f})')

            # LR schedule is per epoch.
            self.scheduler.step()

            if epoch % 2 == 0:
                gc.collect()
                torch.cuda.empty_cache()

        return self.best_loss

    def save_model(self, filename):
        """Save the model's state_dict to config.MODEL_DIR / filename."""
        torch.save(
            self.model.state_dict(),
            config.MODEL_DIR / filename
        )

# ===================== MAIN =====================
def main():
    """Run the end-to-end training pipeline.

    Loads ``train.csv`` and splits the unique ``id`` values 85/15 into
    train/validation. In DEBUG_MODE it subsamples at most 200/50 rows with a
    fixed seed, so debug runs are reproducible. It then builds datasets,
    loaders, the model and trainer, trains, and finally frees memory.
    """
    print("="*60)
    print("ROBUST TRAINING PIPELINE (NaN Fix)")
    print("="*60)

    print("\nLoading metadata...")
    train_df = pd.read_csv(config.DATA_DIR / 'train.csv')
    print(f"Total samples: {len(train_df)}")

    print("\nDataset statistics:")
    print(f"  Sampling frequencies: {train_df['fs'].unique()}")
    print(f"  Signal lengths: {train_df['sig_len'].describe()}")

    # Split on unique ids so one record never lands in both splits.
    train_ids, val_ids = train_test_split(
        train_df['id'].unique(),
        test_size=0.15,
        random_state=42
    )

    train_subset = train_df[train_df['id'].isin(train_ids)].reset_index(drop=True)
    val_subset = train_df[train_df['id'].isin(val_ids)].reset_index(drop=True)

    # Small, seeded subset for quick debugging runs.
    if config.DEBUG_MODE:
        train_subset = train_subset.sample(
            min(200, len(train_subset)), random_state=42
        ).reset_index(drop=True)
        val_subset = val_subset.sample(
            min(50, len(val_subset)), random_state=42
        ).reset_index(drop=True)

    print(f"Train: {len(train_subset)}, Val: {len(val_subset)}")

    print("\nCreating datasets...")
    train_dataset = RobustECGDataset(
        train_subset,
        config.DATA_DIR / 'train',
        transform=get_robust_transforms(True),
        is_training=True
    )

    val_dataset = RobustECGDataset(
        val_subset,
        config.DATA_DIR / 'train',
        transform=get_robust_transforms(False),
        is_training=False
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=True,
        num_workers=config.NUM_WORKERS,
        pin_memory=config.PIN_MEMORY,
        drop_last=True  # avoid tiny trailing batches
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=False,
        num_workers=config.NUM_WORKERS,
        pin_memory=config.PIN_MEMORY
    )

    print("✓ Data loaders ready")

    print("\nCreating model...")
    model = RobustECGModel()

    trainer = RobustTrainer(model, train_loader, val_loader)
    best_loss = trainer.fit()

    print(f"\n✓ Training complete! Best val loss: {best_loss:.6f}")

    # Release model/optimizer memory.
    del model, trainer
    gc.collect()
    torch.cuda.empty_cache()

# Entry point: run the full training pipeline when this file/cell is executed.
if __name__ == '__main__':
    main()