In [2]:
!pip install wfdb

Collecting wfdb
  Downloading wfdb-4.3.0-py3-none-any.whl.metadata (3.8 kB)
Downloading wfdb-4.3.0-py3-none-any.whl (163 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m163.8/163.8 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25hInstalling collected packages: wfdb
Successfully installed wfdb-4.3.0


In [2]:
#!/usr/bin/env python3
"""
Enhanced CNN-Transformer hybrid for apnea detection targeting 90%+ accuracy.
Key improvements:
- Deeper architecture with better feature extraction
- Advanced augmentation techniques
- Improved loss function with focal loss
- Better normalization and regularization
"""

import argparse
import os
import time
from pathlib import Path
from collections import Counter

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

try:
    import wfdb
except Exception:
    wfdb = None

from sklearn.metrics import roc_auc_score, precision_score, recall_score, f1_score

# ----------------------------- Utilities ---------------------------------

def set_seed(seed=42):
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# ----------------------------- Focal Loss ---------------------------------

class FocalLoss(nn.Module):
    """Focal loss for handling class imbalance better than CE."""
    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
    
    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none', weight=self.alpha)
        pt = torch.exp(-ce_loss)
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss
        
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss

# ----------------------------- Enhanced Model ------------------------------

class SEBlock(nn.Module):
    """Squeeze-and-Excitation block for channel attention."""
    def __init__(self, channels, reduction=8):
        super().__init__()
        self.fc1 = nn.Linear(channels, channels // reduction)
        self.fc2 = nn.Linear(channels // reduction, channels)
    
    def forward(self, x):
        # x: (B, C, L)
        b, c, _ = x.size()
        y = F.adaptive_avg_pool1d(x, 1).view(b, c)
        y = F.relu(self.fc1(y))
        y = torch.sigmoid(self.fc2(y)).view(b, c, 1)
        return x * y

class EnhancedResidualBlock(nn.Module):
    """Enhanced residual block with SE attention."""
    def __init__(self, channels, kernel_size=3, dropout=0.1):
        super().__init__()
        self.conv1 = nn.Conv1d(channels, channels, kernel_size=kernel_size, padding=kernel_size//2)
        self.conv2 = nn.Conv1d(channels, channels, kernel_size=kernel_size, padding=kernel_size//2)
        self.norm1 = nn.BatchNorm1d(channels)
        self.norm2 = nn.BatchNorm1d(channels)
        self.se = SEBlock(channels)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        residual = x
        x = self.norm1(x)
        x = F.gelu(self.conv1(x))
        x = self.dropout(x)
        x = self.norm2(x)
        x = self.conv2(x)
        x = self.se(x)
        return F.gelu(residual + x)

class TemporalAttention(nn.Module):
    """Temporal attention mechanism."""
    def __init__(self, d_model, num_heads=8, dropout=0.1):
        super().__init__()
        self.attention = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 4, d_model),
            nn.Dropout(dropout)
        )
    
    def forward(self, x):
        # x: (B, L, C)
        attn_out, _ = self.attention(x, x, x)
        x = self.norm1(x + attn_out)
        ff_out = self.ff(x)
        x = self.norm2(x + ff_out)
        return x

class EnhancedApneaModel(nn.Module):
    """Enhanced multi-scale CNN-Transformer with advanced features."""
    def __init__(self, d_model=256, n_cnn_layers=8, n_attn_layers=2, dropout=0.3):
        super().__init__()
        
        # Multi-scale input projection
        channels_per_scale = [d_model//4, d_model//4, d_model//2]  # Sums to d_model
        self.input_proj = nn.ModuleList([
            nn.Conv1d(1, channels_per_scale[i], kernel_size=k, padding=k//2) 
            for i, k in enumerate([3, 5, 7])
        ])
        self.input_combine = nn.Conv1d(d_model, d_model, kernel_size=1)
        self.input_norm = nn.BatchNorm1d(d_model)
        
        # Deep CNN feature extraction with varying kernel sizes
        self.cnn_blocks = nn.ModuleList()
        for i in range(n_cnn_layers):
            kernel_size = 3 if i % 2 == 0 else 5
            self.cnn_blocks.append(EnhancedResidualBlock(d_model, kernel_size, dropout))
        
        # Downsample for attention
        self.downsample = nn.Conv1d(d_model, d_model, kernel_size=3, stride=2, padding=1)
        
        # Temporal attention layers
        self.attn_layers = nn.ModuleList([
            TemporalAttention(d_model, num_heads=8, dropout=dropout)
            for _ in range(n_attn_layers)
        ])
        
        # Multi-scale feature aggregation
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.global_max_pool = nn.AdaptiveMaxPool1d(1)
        
        # Enhanced classification head
        self.classifier = nn.Sequential(
            nn.Linear(d_model * 4, d_model * 2),
            nn.BatchNorm1d(d_model * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 2, d_model),
            nn.BatchNorm1d(d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, 2)
        )
        
    def forward(self, x):
        # x: (B, L, 1)
        x = x.transpose(1, 2)  # (B, 1, L)
        
        # Multi-scale input
        multi_scale = [proj(x) for proj in self.input_proj]
        x = torch.cat(multi_scale, dim=1)  # (B, d_model, L)
        x = self.input_combine(x)
        x = self.input_norm(x)
        
        # Deep CNN feature extraction
        for block in self.cnn_blocks:
            x = block(x)
        
        # Global pooling features
        x_avg = self.global_pool(x).squeeze(-1)  # (B, d_model)
        x_max = self.global_max_pool(x).squeeze(-1)  # (B, d_model)
        
        # Downsample and apply attention
        x_down = self.downsample(x)  # (B, d_model, L/2)
        x_seq = x_down.transpose(1, 2)  # (B, L/2, d_model)
        
        for attn_layer in self.attn_layers:
            x_seq = attn_layer(x_seq)
        
        x_attn_avg = x_seq.mean(dim=1)  # (B, d_model)
        x_attn_max = x_seq.max(dim=1)[0]  # (B, d_model)
        
        # Concatenate all features
        x_combined = torch.cat([x_avg, x_max, x_attn_avg, x_attn_max], dim=-1)
        
        # Classification
        logits = self.classifier(x_combined)
        return logits

# --------------------------- Enhanced Dataset ---------------------------

class EnhancedApneaDataset(Dataset):
    """Enhanced dataset with better preprocessing and augmentation."""

    def __init__(self, data_dir: str, record_names: list = None, cache_dir: str = None,
                 segment_length: int = 6000, stride: int = 3000, split='train', augment=True):
        super().__init__()
        self.segment_length = int(segment_length)
        self.stride = int(stride)
        self.split = split
        self.augment = augment and (split == 'train')
        
        cache_dir = Path(cache_dir) if cache_dir else Path(data_dir)
        cache_file = cache_dir / f'apnea_cache_enhanced_{split}_{segment_length}.pt'

        if cache_file.exists():
            print(f"Loading cached {split} dataset from {cache_file}")
            data = torch.load(cache_file)
            self.segments = data['segments']
            self.labels = data['labels']
        else:
            assert wfdb is not None, "wfdb not available"
            assert record_names is not None, "record_names required"
            
            self.segments = []
            self.labels = []
            self.data_dir = Path(data_dir)
            
            print(f"Processing {len(record_names)} records for {split}...")
            for i, rec in enumerate(record_names):
                print(f"  [{i+1}/{len(record_names)}] {rec}...", end='\r')
                self._load_record(rec)
            
            if len(self.segments) == 0:
                raise RuntimeError("No segments loaded")
            
            self.segments = torch.tensor(np.stack(self.segments, axis=0), dtype=torch.float32)
            self.labels = torch.tensor(self.labels, dtype=torch.long)
            
            print(f"\nSaving {split} cache to {cache_file}")
            torch.save({'segments': self.segments, 'labels': self.labels}, cache_file)

        if self.segments.ndim == 2:
            self.segments = self.segments.unsqueeze(-1)

        print(f"{split.capitalize()}: {len(self.segments)} segments. "
              f"Class dist: {Counter(self.labels.tolist())}")

    def _load_record(self, record_name: str):
        try:
            record = wfdb.rdrecord(str(self.data_dir / record_name))
            signal = record.p_signal[:, 0].astype(np.float32)
    
            # Handle NaNs
            if np.isnan(signal).any():
                nans = np.isnan(signal)
                not_nans = ~nans
                if not_nans.sum() > 0:
                    signal[nans] = np.interp(np.flatnonzero(nans), np.flatnonzero(not_nans), signal[not_nans])
                else:
                    signal = np.zeros_like(signal)
    
            annotation = wfdb.rdann(str(self.data_dir / record_name), 'apn')
            
            n_minutes = len(signal) // 6000
            minute_labels = np.zeros(n_minutes, dtype=int)
            
            for i, symbol in enumerate(annotation.symbol):
                if symbol == 'A':
                    sample = annotation.sample[i]
                    minute = sample // 6000
                    if minute < n_minutes:
                        minute_labels[minute] = 1
    
            n_samples = len(signal)
            for start in range(0, n_samples - self.segment_length + 1, self.stride):
                end = start + self.segment_length
                seg = signal[start:end].astype(np.float32)
    
                # Robust normalization with better scaling
                seg_mean = np.nanmean(seg)
                seg_std = np.nanstd(seg)
                if np.isnan(seg_std) or seg_std < 1e-8:
                    seg = seg - seg_mean
                else:
                    seg = (seg - seg_mean) / (seg_std + 1e-8)
                
                # Clip extreme values
                seg = np.clip(seg, -10, 10)
    
                minute = start // 6000
                if minute < len(minute_labels):
                    label = minute_labels[minute]
                    self.segments.append(seg)
                    self.labels.append(int(label))
                    
        except Exception as e:
            print(f"\nError loading {record_name}: {e}")
    
    def _augment(self, seg):
        """Apply augmentation to segment."""
        if np.random.random() < 0.3:
            # Add Gaussian noise
            noise = np.random.normal(0, 0.05, seg.shape)
            seg = seg + noise
        
        if np.random.random() < 0.3:
            # Scale
            scale = np.random.uniform(0.9, 1.1)
            seg = seg * scale
        
        if np.random.random() < 0.2:
            # Time shift
            shift = np.random.randint(-50, 50)
            seg = np.roll(seg, shift, axis=0)
        
        return seg
            
    def __len__(self):
        return self.segments.shape[0]

    def __getitem__(self, idx):
        seg = self.segments[idx]
        label = self.labels[idx]
        
        if self.augment:
            seg = seg.numpy()
            seg = self._augment(seg)
            seg = torch.from_numpy(seg)
        
        # Ensure shape is (L, 1) not (L,) or (L, 1, 1)
        if seg.ndim == 1:
            seg = seg.unsqueeze(-1)
        elif seg.ndim == 3:
            seg = seg.squeeze()
            if seg.ndim == 1:
                seg = seg.unsqueeze(-1)
        
        return seg, label

# -------------------------- Training / Validation ------------------------

def compute_class_weights(labels_tensor):
    counts = Counter(labels_tensor.tolist())
    total = sum(counts.values())
    num_classes = len(counts)
    weights = [total / (num_classes * counts.get(i, 1)) for i in range(num_classes)]
    return torch.tensor(weights, dtype=torch.float32)

def train_epoch(model, dataloader, criterion, optimizer, device, epoch, scaler=None):
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    
    num_batches = len(dataloader)
    print_freq = max(1, num_batches // 20)
    
    start_time = time.time()

    for batch_idx, (data, target) in enumerate(dataloader, 1):
        data = data.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        optimizer.zero_grad()

        with torch.amp.autocast('cuda', enabled=(device.type == 'cuda')):
            output = model(data)
            loss = criterion(output, target)
        
        if torch.isnan(loss):
            print(f"\nWARNING: NaN loss at batch {batch_idx}, skipping...")
            continue

        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

        total_loss += loss.item()
        pred = output.argmax(dim=1)
        correct += pred.eq(target).sum().item()
        total += target.size(0)
        
        if batch_idx % print_freq == 0 or batch_idx == num_batches:
            curr_acc = 100.0 * correct / total
            curr_loss = total_loss / batch_idx
            elapsed = time.time() - start_time
            speed = batch_idx / elapsed
            eta = (num_batches - batch_idx) / speed if speed > 0 else 0
            
            print(f"  Epoch {epoch} [{batch_idx:4d}/{num_batches}] "
                  f"Loss: {curr_loss:.4f} Acc: {curr_acc:.2f}% "
                  f"Speed: {speed:.1f} b/s ETA: {eta:.0f}s", end='\r')

    print()
    avg_loss = total_loss / len(dataloader)
    accuracy = 100.0 * correct / total if total > 0 else 0.0
    return avg_loss, accuracy

def validate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_targets = []
    all_probs = []

    with torch.no_grad():
        for data, target in dataloader:
            data = data.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            output = model(data)
            loss = criterion(output, target)

            total_loss += loss.item()
            probs = F.softmax(output, dim=1)[:, 1]
            pred = output.argmax(dim=1)

            correct += pred.eq(target).sum().item()
            total += target.size(0)

            all_preds.extend(pred.cpu().numpy().tolist())
            all_targets.extend(target.cpu().numpy().tolist())
            all_probs.extend(probs.cpu().numpy().tolist())

    avg_loss = total_loss / len(dataloader)
    accuracy = 100.0 * correct / total if total > 0 else 0.0
    
    precision = precision_score(all_targets, all_preds, zero_division=0)
    recall = recall_score(all_targets, all_preds, zero_division=0)
    f1 = f1_score(all_targets, all_preds, zero_division=0)
    
    return avg_loss, accuracy, np.array(all_preds), np.array(all_targets), np.array(all_probs), precision, recall, f1

# ------------------------------ Main ------------------------------------

def main(args):
    set_seed(args.seed)

    DATA_DIR = Path(args.data_dir)
    if not DATA_DIR.exists():
        raise FileNotFoundError(f"Data directory not found: {DATA_DIR}")

    # Find valid records
    record_files = list(DATA_DIR.glob('*.hea'))
    all_records = [f.stem for f in record_files]
    valid_records = [rec for rec in all_records 
                    if (DATA_DIR / (rec + '.apn')).exists() and not rec.endswith('er')]
    
    if len(valid_records) == 0:
        raise RuntimeError(f"No valid records found")

    print(f"Found {len(valid_records)} valid records")

    # Split records
    import random
    valid_records_shuffled = valid_records.copy()
    random.Random(args.seed).shuffle(valid_records_shuffled)
    split_idx = int(len(valid_records_shuffled) * args.train_split)
    train_records = valid_records_shuffled[:split_idx]
    val_records = valid_records_shuffled[split_idx:]
    print(f"Train: {len(train_records)} records, Val: {len(val_records)} records\n")

    # Create datasets with augmentation
    cache_dir = args.cache_dir if args.cache_dir else str(DATA_DIR)
    train_dataset = EnhancedApneaDataset(
        str(DATA_DIR), record_names=train_records, cache_dir=cache_dir,
        segment_length=args.segment_length, stride=args.stride, split='train', augment=True
    )
    val_dataset = EnhancedApneaDataset(
        str(DATA_DIR), record_names=val_records, cache_dir=cache_dir,
        segment_length=args.segment_length, stride=args.stride, split='val', augment=False
    )

    # DataLoaders
    num_workers = 2 if str(DATA_DIR).startswith('/kaggle') else 4
    
    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True
    )

    # Setup device and model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"\nUsing device: {device}")
    if device.type == 'cuda':
        print(f"GPU: {torch.cuda.get_device_name(0)}")

    model = EnhancedApneaModel(
        d_model=args.d_model, 
        n_cnn_layers=args.n_cnn_layers,
        n_attn_layers=args.n_attn_layers,
        dropout=args.dropout
    ).to(device)
    
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {total_params:,}\n")

    # Focal loss for better class imbalance handling
    class_weights = compute_class_weights(train_dataset.labels).to(device)
    print(f"Class weights: {class_weights}")
    criterion = FocalLoss(alpha=class_weights, gamma=2.0)
    
    optimizer = torch.optim.AdamW(
        model.parameters(), 
        lr=args.lr, 
        weight_decay=args.weight_decay,
        betas=(0.9, 0.999)
    )
    
    # Cosine annealing with warmup
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, 
        T_0=10,
        T_mult=2,
        eta_min=1e-6
    )

    scaler = torch.amp.GradScaler('cuda') if device.type == 'cuda' else None

    best_val_acc = 0.0
    best_val_f1 = 0.0
    no_improve = 0

    # Training loop
    print("\nStarting training...")
    print("="*90)
    
    for epoch in range(1, args.epochs + 1):
        epoch_start = time.time()
        
        train_loss, train_acc = train_epoch(
            model, train_loader, criterion, optimizer, device, epoch, scaler=scaler
        )
        
        val_loss, val_acc, val_preds, val_targets, val_probs, precision, recall, f1 = validate(
            model, val_loader, criterion, device
        )

        try:
            auc = roc_auc_score(val_targets, val_probs)
        except Exception:
            auc = 0.0

        scheduler.step()
        epoch_time = time.time() - epoch_start
        
        print(f"Epoch {epoch:2d}/{args.epochs} - Time: {epoch_time:.1f}s - LR: {optimizer.param_groups[0]['lr']:.2e}")
        print(f"  Train - Loss: {train_loss:.4f}, Acc: {train_acc:.2f}%")
        print(f"  Val   - Loss: {val_loss:.4f}, Acc: {val_acc:.2f}%, AUC: {auc:.4f}")
        print(f"  Val   - Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")

        improved = False
        if val_acc > best_val_acc or (val_acc == best_val_acc and f1 > best_val_f1):
            best_val_acc = val_acc
            best_val_f1 = f1
            no_improve = 0
            improved = True
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
                'val_auc': auc,
                'val_f1': f1
            }, args.best_model_path)
            print(f"  ✓ New best! (Acc: {val_acc:.2f}%, F1: {f1:.4f})")
        else:
            no_improve += 1
            print(f"  No improvement ({no_improve}/{args.patience})")

        print("-"*90)

        if no_improve >= args.patience:
            print(f"\nEarly stopping after {epoch} epochs")
            break

    print(f"\n{'='*90}")
    print(f"Training finished!")
    print(f"Best validation - Accuracy: {best_val_acc:.2f}%, F1: {best_val_f1:.4f}")
    print(f"{'='*90}")

if __name__ == '__main__':
    kaggle_data = '/kaggle/input/vincent2/apnea-ecg-database-1.0.0'
    colab_data = '/content/apnea-ecg/1.0.0'
    
    if Path(kaggle_data).exists():
        default_data_dir = kaggle_data
        default_cache_dir = '/kaggle/working'
        default_model_path = '/kaggle/working/best_model_enhanced.pth'
    elif Path(colab_data).exists():
        default_data_dir = colab_data
        default_cache_dir = '/content'
        default_model_path = '/content/best_model_enhanced.pth'
    else:
        default_data_dir = None
        default_cache_dir = None
        default_model_path = 'best_model_enhanced.pth'
    
    parser = argparse.ArgumentParser(description='Enhanced apnea detection (90%+ target)')
    parser.add_argument('--data-dir', type=str, default=default_data_dir)
    parser.add_argument('--cache-dir', type=str, default=default_cache_dir)
    parser.add_argument('--segment-length', type=int, default=6000)  # 60s segments
    parser.add_argument('--stride', type=int, default=3000)  # 50% overlap
    parser.add_argument('--batch-size', type=int, default=32)  # Larger model needs smaller batch
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--weight-decay', type=float, default=1e-4)
    parser.add_argument('--d-model', type=int, default=256)
    parser.add_argument('--n-cnn-layers', type=int, default=8)
    parser.add_argument('--n-attn-layers', type=int, default=2)
    parser.add_argument('--dropout', type=float, default=0.3)
    parser.add_argument('--train-split', type=float, default=0.8)
    parser.add_argument('--num-workers', type=int, default=4)
    parser.add_argument('--patience', type=int, default=15)
    parser.add_argument('--best-model-path', type=str, default=default_model_path)
    parser.add_argument('--seed', type=int, default=42)

    args, _ = parser.parse_known_args()
    
    if args.data_dir is None:
        raise SystemExit(f"\nERROR: Dataset not found\n")
    
    print("="*90)
    print("ENHANCED MODEL CONFIGURATION (Target: 90%+ Accuracy)")
    print("="*90)
    print(f"  Data:          {args.data_dir}")
    print(f"  Cache:         {args.cache_dir}")
    print(f"  Model save:    {args.best_model_path}")
    print(f"  Segment:       {args.segment_length} samples (60s) with augmentation")
    print(f"  Batch size:    {args.batch_size}")
    print(f"  Epochs:        {args.epochs}")
    print(f"  Learning rate: {args.lr}")
    print(f"  Model:         d_model={args.d_model}, cnn_layers={args.n_cnn_layers}, "
          f"attn_layers={args.n_attn_layers}, dropout={args.dropout}")
    print(f"  Loss:          Focal Loss (gamma=2.0)")
    print("="*90 + "\n")
    
    main(args)

ENHANCED MODEL CONFIGURATION (Target: 90%+ Accuracy)
  Data:          /kaggle/input/vincent2/apnea-ecg-database-1.0.0
  Cache:         /kaggle/working
  Model save:    /kaggle/working/best_model_enhanced.pth
  Segment:       6000 samples (60s) with augmentation
  Batch size:    32
  Epochs:        100
  Learning rate: 0.0001
  Model:         d_model=256, cnn_layers=8, attn_layers=2, dropout=0.3
  Loss:          Focal Loss (gamma=2.0)

Found 43 valid records
Train: 34 records, Val: 9 records

Processing 34 records for train...
  [34/34] a20....
Saving train cache to /kaggle/working/apnea_cache_enhanced_train_6000.pt
Train: 33411 segments. Class dist: Counter({0: 21112, 1: 12299})
Processing 9 records for val...
  [9/9] a01....
Saving val cache to /kaggle/working/apnea_cache_enhanced_val_6000.pt
Val: 8622 segments. Class dist: Counter({0: 4687, 1: 3935})

Using device: cuda
GPU: Tesla P100-PCIE-16GB
Model parameters: 6,842,498

Class weights: tensor([0.7913, 1.3583], device='cuda:0')

St

OutOfMemoryError: CUDA out of memory. Tried to allocate 8.58 GiB. GPU 0 has a total capacity of 15.89 GiB of which 4.98 GiB is free. Process 2338 has 10.90 GiB memory in use. Of the allocated memory 10.49 GiB is allocated by PyTorch, and 132.44 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [3]:
#!/usr/bin/env python3
"""
Memory-efficient high-performance model for 90%+ apnea detection accuracy.
Optimized for Tesla P100 16GB GPU.
"""

import argparse
import os
import time
from pathlib import Path
from collections import Counter

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

try:
    import wfdb
except Exception:
    wfdb = None

from sklearn.metrics import roc_auc_score, precision_score, recall_score, f1_score

# ----------------------------- Utilities ---------------------------------

def set_seed(seed=42):
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# ----------------------------- Efficient Model ---------------------------

class EfficientResBlock(nn.Module):
    """Efficient residual block with depthwise separable convolutions."""
    def __init__(self, channels, kernel_size=7):
        super().__init__()
        # Depthwise
        self.depthwise = nn.Conv1d(channels, channels, kernel_size, padding=kernel_size//2, groups=channels)
        # Pointwise
        self.pointwise = nn.Conv1d(channels, channels, 1)
        self.norm = nn.BatchNorm1d(channels)
        self.dropout = nn.Dropout(0.1)
        
    def forward(self, x):
        residual = x
        x = self.depthwise(x)
        x = self.pointwise(x)
        x = self.norm(x)
        x = self.dropout(x)
        return F.gelu(residual + x)


class EfficientApneaNet(nn.Module):
    """Memory-efficient architecture for high accuracy."""
    def __init__(self, d_model=128, n_blocks=6, dropout=0.2):
        super().__init__()
        
        # Input stem with multi-scale
        self.stem = nn.Sequential(
            nn.Conv1d(1, d_model//2, kernel_size=7, padding=3, stride=2),
            nn.BatchNorm1d(d_model//2),
            nn.GELU(),
            nn.Conv1d(d_model//2, d_model, kernel_size=5, padding=2, stride=2),
            nn.BatchNorm1d(d_model),
            nn.GELU(),
        )
        
        # Efficient residual blocks with varying receptive fields
        self.blocks = nn.ModuleList([
            EfficientResBlock(d_model, kernel_size=7 if i % 2 == 0 else 11)
            for i in range(n_blocks)
        ])
        
        # Lightweight attention
        self.channel_attn = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Conv1d(d_model, d_model//4, 1),
            nn.GELU(),
            nn.Conv1d(d_model//4, d_model, 1),
            nn.Sigmoid()
        )
        
        # Temporal attention (memory efficient)
        self.temp_attn = nn.MultiheadAttention(d_model, num_heads=4, dropout=dropout, batch_first=True)
        self.temp_norm = nn.LayerNorm(d_model)
        
        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(d_model * 3, d_model),
            nn.BatchNorm1d(d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, 2)
        )
        
    def forward(self, x):
        # x: (B, L, 1) -> (B, 1, L)
        x = x.transpose(1, 2)
        
        # Stem reduces sequence length by 4x
        x = self.stem(x)  # (B, d_model, L/4)
        
        # Residual blocks
        for block in self.blocks:
            x = block(x)
        
        # Channel attention
        attn_weights = self.channel_attn(x)
        x = x * attn_weights
        
        # Global features
        x_avg = F.adaptive_avg_pool1d(x, 1).squeeze(-1)  # (B, d_model)
        x_max = F.adaptive_max_pool1d(x, 1).squeeze(-1)  # (B, d_model)
        
        # Temporal attention on further downsampled sequence
        x_seq = F.adaptive_avg_pool1d(x, 50).transpose(1, 2)  # (B, 50, d_model)
        x_attn, _ = self.temp_attn(x_seq, x_seq, x_seq)
        x_attn = self.temp_norm(x_attn + x_seq)
        x_attn = x_attn.mean(dim=1)  # (B, d_model)
        
        # Combine features
        x_combined = torch.cat([x_avg, x_max, x_attn], dim=1)
        
        # Classify
        logits = self.classifier(x_combined)
        return logits


# --------------------------- Dataset ---------------------------

class ApneaDataset(Dataset):
    """Optimized dataset with data augmentation."""

    def __init__(self, data_dir: str, record_names: list = None, cache_dir: str = None,
                 segment_length: int = 6000, stride: int = 3000, split='train', augment=True):
        super().__init__()
        self.segment_length = int(segment_length)
        self.stride = int(stride)
        self.split = split
        self.augment = augment and (split == 'train')
        
        cache_dir = Path(cache_dir) if cache_dir else Path(data_dir)
        cache_file = cache_dir / f'apnea_{split}_{segment_length}_{stride}.pt'

        if cache_file.exists():
            print(f"Loading cached {split} from {cache_file}")
            data = torch.load(cache_file)
            self.segments = data['segments']
            self.labels = data['labels']
        else:
            assert wfdb is not None, "wfdb required"
            assert record_names is not None, "record_names required"
            
            self.segments = []
            self.labels = []
            self.data_dir = Path(data_dir)
            
            print(f"Processing {len(record_names)} records for {split}...")
            for i, rec in enumerate(record_names):
                print(f"  [{i+1}/{len(record_names)}] {rec}...", end='\r')
                self._load_record(rec)
            
            if len(self.segments) == 0:
                raise RuntimeError("No segments loaded")
            
            self.segments = torch.tensor(np.stack(self.segments, axis=0), dtype=torch.float32)
            self.labels = torch.tensor(self.labels, dtype=torch.long)
            
            print(f"\nSaving cache to {cache_file}")
            torch.save({'segments': self.segments, 'labels': self.labels}, cache_file)

        if self.segments.ndim == 2:
            self.segments = self.segments.unsqueeze(-1)

        print(f"{split.capitalize()}: {len(self.segments)} segments, "
              f"Class: {Counter(self.labels.tolist())}")

    def _load_record(self, record_name: str):
        try:
            record = wfdb.rdrecord(str(self.data_dir / record_name))
            signal = record.p_signal[:, 0].astype(np.float32)
    
            if np.isnan(signal).any():
                nans = np.isnan(signal)
                not_nans = ~nans
                if not_nans.sum() > 0:
                    signal[nans] = np.interp(np.flatnonzero(nans), np.flatnonzero(not_nans), signal[not_nans])
                else:
                    signal = np.zeros_like(signal)
    
            annotation = wfdb.rdann(str(self.data_dir / record_name), 'apn')
            
            n_minutes = len(signal) // 6000
            minute_labels = np.zeros(n_minutes, dtype=int)
            
            for i, symbol in enumerate(annotation.symbol):
                if symbol == 'A':
                    sample = annotation.sample[i]
                    minute = sample // 6000
                    if minute < n_minutes:
                        minute_labels[minute] = 1
    
            n_samples = len(signal)
            for start in range(0, n_samples - self.segment_length + 1, self.stride):
                end = start + self.segment_length
                seg = signal[start:end].astype(np.float32)
    
                seg_mean = np.nanmean(seg)
                seg_std = np.nanstd(seg)
                if np.isnan(seg_std) or seg_std < 1e-8:
                    seg = seg - seg_mean
                else:
                    seg = (seg - seg_mean) / (seg_std + 1e-8)
                
                seg = np.clip(seg, -10, 10)
    
                minute = start // 6000
                if minute < len(minute_labels):
                    label = minute_labels[minute]
                    self.segments.append(seg)
                    self.labels.append(int(label))
                    
        except Exception as e:
            print(f"\nError loading {record_name}: {e}")
    
    def _augment(self, seg):
        """Light augmentation."""
        seg = seg.numpy() if torch.is_tensor(seg) else seg
        
        if np.random.random() < 0.3:
            # Gaussian noise
            seg = seg + np.random.normal(0, 0.03, seg.shape).astype(np.float32)
        
        if np.random.random() < 0.2:
            # Scale
            seg = seg * np.random.uniform(0.95, 1.05)
        
        return torch.from_numpy(seg)
            
    def __len__(self):
        return self.segments.shape[0]

    def __getitem__(self, idx):
        seg = self.segments[idx]
        label = self.labels[idx]
        
        if self.augment:
            seg = self._augment(seg)
        
        if seg.ndim == 1:
            seg = seg.unsqueeze(-1)
        elif seg.ndim == 3:
            seg = seg.squeeze(1)
        
        return seg, label

# -------------------------- Training ------------------------

def compute_class_weights(labels_tensor):
    counts = Counter(labels_tensor.tolist())
    total = sum(counts.values())
    num_classes = len(counts)
    weights = [total / (num_classes * counts.get(i, 1)) for i in range(num_classes)]
    return torch.tensor(weights, dtype=torch.float32)

def train_epoch(model, dataloader, criterion, optimizer, scheduler, device, epoch, scaler=None):
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    
    num_batches = len(dataloader)
    print_freq = max(1, num_batches // 15)
    start_time = time.time()

    for batch_idx, (data, target) in enumerate(dataloader, 1):
        data = data.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        with torch.amp.autocast('cuda', enabled=(device.type == 'cuda')):
            output = model(data)
            loss = criterion(output, target)
        
        if torch.isnan(loss):
            print(f"\nWARNING: NaN loss, skipping batch {batch_idx}")
            continue

        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
        
        scheduler.step()

        total_loss += loss.item()
        pred = output.argmax(dim=1)
        correct += pred.eq(target).sum().item()
        total += target.size(0)
        
        if batch_idx % print_freq == 0 or batch_idx == num_batches:
            curr_acc = 100.0 * correct / total
            curr_loss = total_loss / batch_idx
            speed = batch_idx / (time.time() - start_time)
            eta = (num_batches - batch_idx) / speed if speed > 0 else 0
            
            print(f"  Ep {epoch} [{batch_idx:4d}/{num_batches}] "
                  f"Loss: {curr_loss:.4f} Acc: {curr_acc:.2f}% "
                  f"({speed:.1f} b/s, ETA: {eta:.0f}s)", end='\r')

    print()
    avg_loss = total_loss / len(dataloader)
    accuracy = 100.0 * correct / total if total > 0 else 0.0
    return avg_loss, accuracy

def validate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_targets = []
    all_probs = []

    with torch.no_grad():
        for data, target in dataloader:
            data = data.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            output = model(data)
            loss = criterion(output, target)

            total_loss += loss.item()
            probs = F.softmax(output, dim=1)[:, 1]
            pred = output.argmax(dim=1)

            correct += pred.eq(target).sum().item()
            total += target.size(0)

            all_preds.extend(pred.cpu().tolist())
            all_targets.extend(target.cpu().tolist())
            all_probs.extend(probs.cpu().tolist())

    avg_loss = total_loss / len(dataloader)
    accuracy = 100.0 * correct / total if total > 0 else 0.0
    
    precision = precision_score(all_targets, all_preds, zero_division=0)
    recall = recall_score(all_targets, all_preds, zero_division=0)
    f1 = f1_score(all_targets, all_preds, zero_division=0)
    
    return avg_loss, accuracy, np.array(all_preds), np.array(all_targets), np.array(all_probs), precision, recall, f1

# ------------------------------ Main ------------------------------------

def main(args):
    set_seed(args.seed)

    DATA_DIR = Path(args.data_dir)
    if not DATA_DIR.exists():
        raise FileNotFoundError(f"Data directory not found: {DATA_DIR}")

    record_files = list(DATA_DIR.glob('*.hea'))
    all_records = [f.stem for f in record_files]
    valid_records = [rec for rec in all_records 
                    if (DATA_DIR / (rec + '.apn')).exists() and not rec.endswith('er')]
    
    if len(valid_records) == 0:
        raise RuntimeError("No valid records found")

    print(f"Found {len(valid_records)} valid records")

    import random
    valid_records_shuffled = valid_records.copy()
    random.Random(args.seed).shuffle(valid_records_shuffled)
    split_idx = int(len(valid_records_shuffled) * args.train_split)
    train_records = valid_records_shuffled[:split_idx]
    val_records = valid_records_shuffled[split_idx:]
    print(f"Train: {len(train_records)}, Val: {len(val_records)}\n")

    cache_dir = args.cache_dir if args.cache_dir else str(DATA_DIR)
    train_dataset = ApneaDataset(
        str(DATA_DIR), train_records, cache_dir,
        args.segment_length, args.stride, 'train', augment=True
    )
    val_dataset = ApneaDataset(
        str(DATA_DIR), val_records, cache_dir,
        args.segment_length, args.stride, 'val', augment=False
    )

    num_workers = 2 if str(DATA_DIR).startswith('/kaggle') else 4
    
    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True
    )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")
    if device.type == 'cuda':
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        torch.cuda.empty_cache()

    model = EfficientApneaNet(
        d_model=args.d_model, n_blocks=args.n_blocks, dropout=args.dropout
    ).to(device)
    
    print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}\n")

    class_weights = compute_class_weights(train_dataset.labels).to(device)
    print(f"Class weights: {class_weights}")
    criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.1)
    
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay
    )
    
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=args.lr, epochs=args.epochs,
        steps_per_epoch=len(train_loader), pct_start=0.1
    )

    scaler = torch.amp.GradScaler('cuda') if device.type == 'cuda' else None

    best_val_acc = 0.0
    best_val_f1 = 0.0
    no_improve = 0

    print("\nStarting training...")
    print("="*80)
    
    for epoch in range(1, args.epochs + 1):
        epoch_start = time.time()
        
        train_loss, train_acc = train_epoch(
            model, train_loader, criterion, optimizer, scheduler, device, epoch, scaler
        )
        
        val_loss, val_acc, _, val_targets, val_probs, precision, recall, f1 = validate(
            model, val_loader, criterion, device
        )

        try:
            auc = roc_auc_score(val_targets, val_probs)
        except:
            auc = 0.0

        epoch_time = time.time() - epoch_start
        
        print(f"Epoch {epoch:2d}/{args.epochs} ({epoch_time:.1f}s)")
        print(f"  Train: Loss={train_loss:.4f}, Acc={train_acc:.2f}%")
        print(f"  Val:   Loss={val_loss:.4f}, Acc={val_acc:.2f}%, AUC={auc:.4f}")
        print(f"         Prec={precision:.3f}, Rec={recall:.3f}, F1={f1:.3f}")

        if val_acc > best_val_acc or (val_acc >= best_val_acc and f1 > best_val_f1):
            best_val_acc = val_acc
            best_val_f1 = f1
            no_improve = 0
            torch.save({
                'epoch': epoch, 'model_state_dict': model.state_dict(),
                'val_acc': val_acc, 'val_auc': auc, 'val_f1': f1
            }, args.best_model_path)
            print(f"  ✓ Best! (Acc={val_acc:.2f}%, F1={f1:.3f})")
        else:
            no_improve += 1
            print(f"  No improvement ({no_improve}/{args.patience})")

        print("-"*80)

        if no_improve >= args.patience:
            print(f"\nEarly stop at epoch {epoch}")
            break

    print(f"\n{'='*80}")
    print(f"BEST - Accuracy: {best_val_acc:.2f}%, F1: {best_val_f1:.3f}")
    print(f"{'='*80}")

if __name__ == '__main__':
    kaggle_data = '/kaggle/input/vincent2/apnea-ecg-database-1.0.0'
    colab_data = '/content/apnea-ecg/1.0.0'
    
    if Path(kaggle_data).exists():
        default_data_dir = kaggle_data
        default_cache_dir = '/kaggle/working'
        default_model_path = '/kaggle/working/best_model.pth'
    elif Path(colab_data).exists():
        default_data_dir = colab_data
        default_cache_dir = '/content'
        default_model_path = '/content/best_model.pth'
    else:
        default_data_dir = None
        default_cache_dir = None
        default_model_path = 'best_model.pth'
    
    parser = argparse.ArgumentParser()
    parser.add_argument('--data-dir', type=str, default=default_data_dir)
    parser.add_argument('--cache-dir', type=str, default=default_cache_dir)
    parser.add_argument('--segment-length', type=int, default=6000)  # 60s
    parser.add_argument('--stride', type=int, default=3000)  # 50% overlap
    parser.add_argument('--batch-size', type=int, default=48)  # Optimized for P100
    parser.add_argument('--epochs', type=int, default=80)
    parser.add_argument('--lr', type=float, default=3e-4)
    parser.add_argument('--weight-decay', type=float, default=1e-4)
    parser.add_argument('--d-model', type=int, default=128)  # Efficient size
    parser.add_argument('--n-blocks', type=int, default=8)  # Deep but efficient
    parser.add_argument('--dropout', type=float, default=0.2)
    parser.add_argument('--train-split', type=float, default=0.8)
    parser.add_argument('--patience', type=int, default=15)
    parser.add_argument('--best-model-path', type=str, default=default_model_path)
    parser.add_argument('--seed', type=int, default=42)

    args, _ = parser.parse_known_args()
    
    if args.data_dir is None:
        raise SystemExit("ERROR: Dataset not found")
    
    print("="*80)
    print("MEMORY-EFFICIENT MODEL (Target: 90%+ Accuracy)")
    print("="*80)
    print(f"  Data:       {args.data_dir}")
    print(f"  Segment:    {args.segment_length} samples (60s), stride={args.stride}")
    print(f"  Batch:      {args.batch_size}")
    print(f"  Epochs:     {args.epochs}")
    print(f"  Model:      d_model={args.d_model}, blocks={args.n_blocks}")
    print(f"  Optimizer:  AdamW (lr={args.lr}, wd={args.weight_decay})")
    print("="*80 + "\n")
    
    main(args)

MEMORY-EFFICIENT MODEL (Target: 90%+ Accuracy)
  Data:       /kaggle/input/vincent2/apnea-ecg-database-1.0.0
  Segment:    6000 samples (60s), stride=3000
  Batch:      48
  Epochs:     80
  Model:      d_model=128, blocks=8
  Optimizer:  AdamW (lr=0.0003, wd=0.0001)

Found 43 valid records
Train: 34, Val: 9

Processing 34 records for train...
  [34/34] a20....
Saving cache to /kaggle/working/apnea_train_6000_3000.pt
Train: 33411 segments, Class: Counter({0: 21112, 1: 12299})
Processing 9 records for val...
  [9/9] a01....
Saving cache to /kaggle/working/apnea_val_6000_3000.pt
Val: 8622 segments, Class: Counter({0: 4687, 1: 3935})
Device: cuda
GPU: Tesla P100-PCIE-16GB
Parameters: 310,818

Class weights: tensor([0.7913, 1.3583], device='cuda:0')

Starting training...
  Ep 1 [ 697/697] Loss: 0.5680 Acc: 74.52% (15.0 b/s, ETA: 0s))
Epoch  1/80 (50.3s)
  Train: Loss=0.5680, Acc=74.52%
  Val:   Loss=0.6762, Acc=64.59%, AUC=0.7138
         Prec=0.615, Rec=0.600, F1=0.607
  ✓ Best! (Acc=64.5

In [4]:
#!/usr/bin/env python3
"""
Production-grade memory-efficient model for 90%+ apnea detection accuracy.
Optimized for Tesla P100 16GB GPU with comprehensive validation and error handling.

Features:
- Record-level cross-validation to prevent data leakage
- Comprehensive input validation and error handling
- Enhanced data augmentation
- Test set evaluation
- Detailed logging and diagnostics
- Configurable sampling rate and parameters
"""

import argparse
import logging
import os
import sys
import time
import traceback
from pathlib import Path
from collections import Counter
from typing import List, Tuple, Dict, Optional, Any

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

try:
    import wfdb
    WFDB_AVAILABLE = True
except ImportError:
    WFDB_AVAILABLE = False
    wfdb = None

from sklearn.metrics import (
    roc_auc_score, precision_score, recall_score, 
    f1_score, confusion_matrix, classification_report
)

# ----------------------------- Configuration ---------------------------------

class Config:
    """Centralized configuration with validation."""
    
    def __init__(self, **kwargs):
        # Data parameters
        self.sampling_rate = kwargs.get('sampling_rate', 100)  # Hz
        self.segment_duration = kwargs.get('segment_duration', 60)  # seconds
        self.segment_length = self.sampling_rate * self.segment_duration
        self.stride_ratio = kwargs.get('stride_ratio', 0.5)  # 50% overlap
        self.stride = int(self.segment_length * self.stride_ratio)
        
        # Model parameters
        self.d_model = kwargs.get('d_model', 128)
        self.n_blocks = kwargs.get('n_blocks', 8)
        self.dropout = kwargs.get('dropout', 0.2)
        
        # Training parameters
        self.batch_size = kwargs.get('batch_size', 48)
        self.epochs = kwargs.get('epochs', 80)
        self.lr = kwargs.get('lr', 3e-4)
        self.weight_decay = kwargs.get('weight_decay', 1e-4)
        self.patience = kwargs.get('patience', 15)
        
        # Splits (train/val/test)
        self.train_split = kwargs.get('train_split', 0.7)
        self.val_split = kwargs.get('val_split', 0.15)
        self.test_split = kwargs.get('test_split', 0.15)
        
        # Paths
        self.data_dir = kwargs.get('data_dir')
        self.cache_dir = kwargs.get('cache_dir')
        self.output_dir = kwargs.get('output_dir', '.')
        
        # Other
        self.seed = kwargs.get('seed', 42)
        self.num_workers = kwargs.get('num_workers', 4)
        
        self.validate()
    
    def validate(self):
        """Validate configuration parameters."""
        assert self.sampling_rate > 0, "Sampling rate must be positive"
        assert self.segment_duration > 0, "Segment duration must be positive"
        assert 0 < self.stride_ratio <= 1, "Stride ratio must be in (0, 1]"
        assert self.d_model > 0 and self.d_model % 2 == 0, "d_model must be positive and even"
        assert self.n_blocks > 0, "n_blocks must be positive"
        assert 0 <= self.dropout < 1, "Dropout must be in [0, 1)"
        assert self.batch_size > 0, "Batch size must be positive"
        assert self.epochs > 0, "Epochs must be positive"
        assert self.lr > 0, "Learning rate must be positive"
        assert self.patience > 0, "Patience must be positive"
        
        # Split validation
        split_sum = self.train_split + self.val_split + self.test_split
        assert abs(split_sum - 1.0) < 1e-6, f"Splits must sum to 1.0, got {split_sum}"
        
        # Path validation
        if self.data_dir:
            data_path = Path(self.data_dir)
            if not data_path.exists():
                raise FileNotFoundError(f"Data directory not found: {data_path}")

# ----------------------------- Logging Setup ---------------------------------

def setup_logging(output_dir: str, verbose: bool = True) -> logging.Logger:
    """Setup comprehensive logging."""
    log_dir = Path(output_dir)
    log_dir.mkdir(parents=True, exist_ok=True)
    
    log_file = log_dir / f'training_{time.strftime("%Y%m%d_%H%M%S")}.log'
    
    # Create logger
    logger = logging.getLogger('ApneaDetection')
    logger.setLevel(logging.DEBUG if verbose else logging.INFO)
    
    # File handler (detailed)
    fh = logging.FileHandler(log_file)
    fh.setLevel(logging.DEBUG)
    fh.setFormatter(logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    ))
    
    # Console handler (concise)
    ch = logging.StreamHandler()
    ch.setLevel(logging.INFO)
    ch.setFormatter(logging.Formatter('%(message)s'))
    
    logger.addHandler(fh)
    logger.addHandler(ch)
    
    return logger

# ----------------------------- Utilities ---------------------------------

def set_seed(seed: int = 42):
    """Set random seeds for reproducibility."""
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

def validate_signal_shape(signal: np.ndarray, expected_length: int, record_name: str):
    """Validate signal dimensions and length."""
    if signal.ndim != 1:
        raise ValueError(f"Record {record_name}: Expected 1D signal, got shape {signal.shape}")
    
    if len(signal) < expected_length:
        raise ValueError(
            f"Record {record_name}: Signal too short ({len(signal)} < {expected_length})"
        )

# ----------------------------- Enhanced Data Augmentation ---------------------------------

class SignalAugmenter:
    """Comprehensive signal augmentation for time-series ECG data."""
    
    def __init__(self, sampling_rate: int = 100):
        self.sampling_rate = sampling_rate
    
    def apply(self, signal: np.ndarray, augment_prob: float = 0.5) -> np.ndarray:
        """Apply random augmentations."""
        if np.random.random() > augment_prob:
            return signal
        
        signal = signal.copy()
        
        # Gaussian noise (40% chance)
        if np.random.random() < 0.4:
            noise_level = np.random.uniform(0.02, 0.05)
            signal += np.random.normal(0, noise_level, signal.shape).astype(np.float32)
        
        # Amplitude scaling (30% chance)
        if np.random.random() < 0.3:
            scale = np.random.uniform(0.9, 1.1)
            signal *= scale
        
        # Time warping (20% chance)
        if np.random.random() < 0.2:
            signal = self._time_warp(signal)
        
        # Baseline wander (15% chance)
        if np.random.random() < 0.15:
            signal = self._add_baseline_wander(signal)
        
        # Random sign flip (10% chance) - physiologically valid
        if np.random.random() < 0.1:
            signal *= -1
        
        return signal
    
    def _time_warp(self, signal: np.ndarray, sigma: float = 0.2) -> np.ndarray:
        """Apply smooth time warping."""
        length = len(signal)
        # Create smooth warping curve
        warp = np.cumsum(np.random.normal(1.0, sigma, length))
        warp = warp / warp[-1] * (length - 1)  # Normalize to signal length
        warp = np.clip(warp, 0, length - 1)
        
        # Interpolate
        indices = np.arange(length)
        warped = np.interp(indices, warp, signal)
        return warped.astype(np.float32)
    
    def _add_baseline_wander(self, signal: np.ndarray) -> np.ndarray:
        """Add low-frequency baseline wander."""
        length = len(signal)
        # Create low-frequency sine wave
        freq = np.random.uniform(0.1, 0.3)  # Hz
        phase = np.random.uniform(0, 2 * np.pi)
        amplitude = np.random.uniform(0.05, 0.15)
        
        t = np.arange(length) / self.sampling_rate
        baseline = amplitude * np.sin(2 * np.pi * freq * t + phase)
        
        return signal + baseline.astype(np.float32)

# ----------------------------- Efficient Model ---------------------------

class EfficientResBlock(nn.Module):
    """Efficient residual block with depthwise separable convolutions."""
    
    def __init__(self, channels: int, kernel_size: int = 7, dropout: float = 0.1):
        super().__init__()
        if channels <= 0:
            raise ValueError(f"channels must be positive, got {channels}")
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(f"kernel_size must be positive odd number, got {kernel_size}")
        
        self.depthwise = nn.Conv1d(
            channels, channels, kernel_size, 
            padding=kernel_size//2, groups=channels
        )
        self.pointwise = nn.Conv1d(channels, channels, 1)
        self.norm = nn.BatchNorm1d(channels)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = self.depthwise(x)
        x = self.pointwise(x)
        x = self.norm(x)
        x = self.dropout(x)
        return F.gelu(residual + x)


class EfficientApneaNet(nn.Module):
    """Production-grade memory-efficient architecture for sleep apnea detection."""
    
    def __init__(self, d_model: int = 128, n_blocks: int = 8, 
                 dropout: float = 0.2, input_length: int = 6000):
        super().__init__()
        
        # Validate inputs
        if d_model <= 0 or d_model % 2 != 0:
            raise ValueError(f"d_model must be positive even integer, got {d_model}")
        if n_blocks <= 0:
            raise ValueError(f"n_blocks must be positive, got {n_blocks}")
        if not 0 <= dropout < 1:
            raise ValueError(f"dropout must be in [0, 1), got {dropout}")
        
        self.d_model = d_model
        self.n_blocks = n_blocks
        self.input_length = input_length
        
        # Input stem with multi-scale feature extraction
        self.stem = nn.Sequential(
            nn.Conv1d(1, d_model//2, kernel_size=7, padding=3, stride=2),
            nn.BatchNorm1d(d_model//2),
            nn.GELU(),
            nn.Conv1d(d_model//2, d_model, kernel_size=5, padding=2, stride=2),
            nn.BatchNorm1d(d_model),
            nn.GELU(),
        )
        
        # Efficient residual blocks with varying receptive fields
        self.blocks = nn.ModuleList([
            EfficientResBlock(d_model, kernel_size=7 if i % 2 == 0 else 11, dropout=dropout)
            for i in range(n_blocks)
        ])
        
        # Lightweight channel attention (squeeze-and-excitation)
        self.channel_attn = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Conv1d(d_model, d_model//4, 1),
            nn.GELU(),
            nn.Conv1d(d_model//4, d_model, 1),
            nn.Sigmoid()
        )
        
        # Temporal attention (memory efficient)
        self.temp_attn = nn.MultiheadAttention(
            d_model, num_heads=4, dropout=dropout, batch_first=True
        )
        self.temp_norm = nn.LayerNorm(d_model)
        
        # Classification head with batch normalization
        self.classifier = nn.Sequential(
            nn.Linear(d_model * 3, d_model),
            nn.BatchNorm1d(d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, 2)
        )
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.
        
        Args:
            x: Input tensor of shape (B, L, 1) or (B, 1, L)
        
        Returns:
            logits: Output logits of shape (B, 2)
        """
        # Validate input
        if x.ndim != 3:
            raise ValueError(f"Expected 3D input (B, L, C) or (B, C, L), got shape {x.shape}")
        
        # Handle both (B, L, 1) and (B, 1, L) formats
        if x.shape[-1] == 1:
            x = x.transpose(1, 2)  # (B, L, 1) -> (B, 1, L)
        
        if x.shape[1] != 1:
            raise ValueError(f"Expected 1 input channel, got {x.shape[1]} channels")
        
        # Stem reduces sequence length by 4x
        x = self.stem(x)  # (B, d_model, L/4)
        
        # Residual blocks
        for block in self.blocks:
            x = block(x)
        
        # Channel attention
        attn_weights = self.channel_attn(x)
        x = x * attn_weights
        
        # Global features via pooling
        x_avg = F.adaptive_avg_pool1d(x, 1).squeeze(-1)  # (B, d_model)
        x_max = F.adaptive_max_pool1d(x, 1).squeeze(-1)  # (B, d_model)
        
        # Temporal attention on downsampled sequence
        x_seq = F.adaptive_avg_pool1d(x, 50).transpose(1, 2)  # (B, 50, d_model)
        x_attn, _ = self.temp_attn(x_seq, x_seq, x_seq)
        x_attn = self.temp_norm(x_attn + x_seq)
        x_attn = x_attn.mean(dim=1)  # (B, d_model)
        
        # Combine features
        x_combined = torch.cat([x_avg, x_max, x_attn], dim=1)
        
        # Classify
        logits = self.classifier(x_combined)
        return logits

# --------------------------- Dataset with Record-Level Split ---------------------------

class ApneaDataset(Dataset):
    """
    Optimized dataset with proper record-level splitting to prevent data leakage.
    """

    def __init__(self, 
                 data_dir: str,
                 record_names: List[str],
                 config: Config,
                 cache_dir: Optional[str] = None,
                 split: str = 'train',
                 logger: Optional[logging.Logger] = None):
        super().__init__()
        
        if not WFDB_AVAILABLE:
            raise ImportError("wfdb package is required. Install with: pip install wfdb")
        
        self.config = config
        self.split = split
        self.logger = logger or logging.getLogger('ApneaDataset')
        self.augmenter = SignalAugmenter(config.sampling_rate) if split == 'train' else None
        
        cache_dir = Path(cache_dir) if cache_dir else Path(data_dir)
        cache_file = cache_dir / f'apnea_{split}_{config.segment_length}_{config.stride}_v2.pt'

        if cache_file.exists():
            self.logger.info(f"Loading cached {split} from {cache_file}")
            try:
                data = torch.load(cache_file, weights_only=True)
                self.segments = data['segments']
                self.labels = data['labels']
                self.record_ids = data.get('record_ids', [])
                self.logger.info(f"Loaded {len(self.segments)} segments from cache")
            except Exception as e:
                self.logger.error(f"Failed to load cache: {e}")
                raise
        else:
            self.segments = []
            self.labels = []
            self.record_ids = []
            self.data_dir = Path(data_dir)
            
            if not record_names:
                raise ValueError("record_names cannot be empty")
            
            self.logger.info(f"Processing {len(record_names)} records for {split}...")
            
            failed_records = []
            for i, rec in enumerate(record_names):
                try:
                    self.logger.info(f"  [{i+1}/{len(record_names)}] Processing {rec}...")
                    self._load_record(rec)
                except Exception as e:
                    self.logger.error(f"Failed to load {rec}: {e}")
                    self.logger.debug(traceback.format_exc())
                    failed_records.append(rec)
            
            if failed_records:
                self.logger.warning(f"Failed to load {len(failed_records)} records: {failed_records}")
            
            if len(self.segments) == 0:
                raise RuntimeError(f"No segments loaded for {split} split from {len(record_names)} records")
            
            self.segments = torch.tensor(np.stack(self.segments, axis=0), dtype=torch.float32)
            self.labels = torch.tensor(self.labels, dtype=torch.long)
            
            self.logger.info(f"Saving cache to {cache_file}")
            try:
                torch.save({
                    'segments': self.segments,
                    'labels': self.labels,
                    'record_ids': self.record_ids
                }, cache_file)
            except Exception as e:
                self.logger.warning(f"Failed to save cache: {e}")

        # Ensure proper shape
        if self.segments.ndim == 2:
            self.segments = self.segments.unsqueeze(-1)

        # Log class distribution
        class_dist = Counter(self.labels.tolist())
        self.logger.info(
            f"{split.capitalize()}: {len(self.segments)} segments, "
            f"Class distribution: {dict(class_dist)}"
        )

    def _load_record(self, record_name: str):
        """Load and process a single record with comprehensive error handling."""
        try:
            # Read signal
            record = wfdb.rdrecord(str(self.data_dir / record_name))
            
            if record.p_signal is None or record.p_signal.shape[0] == 0:
                raise ValueError(f"Empty signal in record {record_name}")
            
            if record.p_signal.shape[1] < 1:
                raise ValueError(f"No channels in record {record_name}")
            
            signal = record.p_signal[:, 0].astype(np.float32)
            
            # Validate signal
            validate_signal_shape(signal, self.config.segment_length, record_name)
            
            # Handle NaN values
            if np.isnan(signal).any():
                nans = np.isnan(signal)
                not_nans = ~nans
                if not_nans.sum() > 0:
                    signal[nans] = np.interp(
                        np.flatnonzero(nans),
                        np.flatnonzero(not_nans),
                        signal[not_nans]
                    )
                else:
                    raise ValueError(f"Record {record_name} contains only NaN values")
            
            # Read annotations
            annotation = wfdb.rdann(str(self.data_dir / record_name), 'apn')
            
            # Create minute-level labels
            n_minutes = len(signal) // self.config.segment_length
            minute_labels = np.zeros(n_minutes, dtype=int)
            
            apnea_count = 0
            for i, symbol in enumerate(annotation.symbol):
                if symbol == 'A':
                    apnea_count += 1
                    sample = annotation.sample[i]
                    minute = sample // self.config.segment_length
                    if 0 <= minute < n_minutes:
                        minute_labels[minute] = 1
            
            self.logger.debug(f"  {record_name}: {len(signal)} samples, {apnea_count} apnea events")
            
            # Extract segments
            n_samples = len(signal)
            segments_added = 0
            
            for start in range(0, n_samples - self.config.segment_length + 1, self.config.stride):
                end = start + self.config.segment_length
                seg = signal[start:end].astype(np.float32)
                
                # Normalize segment
                seg_mean = np.mean(seg)
                seg_std = np.std(seg)
                
                if np.isnan(seg_std) or seg_std < 1e-8:
                    seg = seg - seg_mean
                else:
                    seg = (seg - seg_mean) / (seg_std + 1e-8)
                
                # Clip outliers
                seg = np.clip(seg, -10, 10)
                
                # Assign label based on minute
                minute = start // self.config.segment_length
                if minute < len(minute_labels):
                    label = minute_labels[minute]
                    self.segments.append(seg)
                    self.labels.append(int(label))
                    self.record_ids.append(record_name)
                    segments_added += 1
            
            self.logger.debug(f"  {record_name}: Added {segments_added} segments")
                    
        except FileNotFoundError as e:
            raise FileNotFoundError(f"Record files not found for {record_name}: {e}")
        except Exception as e:
            raise RuntimeError(f"Error processing record {record_name}: {e}")
    
    def __len__(self) -> int:
        return self.segments.shape[0]

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        seg = self.segments[idx]
        label = self.labels[idx]
        
        # Apply augmentation for training
        if self.augmenter is not None:
            seg = seg.numpy() if torch.is_tensor(seg) else seg
            seg = self.augmenter.apply(seg, augment_prob=0.5)
            seg = torch.from_numpy(seg)
        
        # Ensure shape is (L, 1)
        if seg.ndim == 1:
            seg = seg.unsqueeze(-1)
        elif seg.ndim == 3:
            seg = seg.squeeze(0)
        
        return seg, label

# -------------------------- Training Utilities ------------------------

def compute_class_weights(labels_tensor: torch.Tensor, logger: logging.Logger) -> torch.Tensor:
    """Compute balanced class weights."""
    counts = Counter(labels_tensor.tolist())
    total = sum(counts.values())
    num_classes = len(counts)
    
    if num_classes != 2:
        logger.warning(f"Expected 2 classes, found {num_classes}")
    
    weights = []
    for i in range(num_classes):
        if counts[i] == 0:
            logger.warning(f"Class {i} has 0 samples!")
            weights.append(1.0)
        else:
            weights.append(total / (num_classes * counts[i]))
    
    return torch.tensor(weights, dtype=torch.float32)

def train_epoch(model: nn.Module, 
                dataloader: DataLoader,
                criterion: nn.Module,
                optimizer: torch.optim.Optimizer,
                scheduler: torch.optim.lr_scheduler._LRScheduler,
                device: torch.device,
                epoch: int,
                logger: logging.Logger,
                scaler: Optional[torch.amp.GradScaler] = None) -> Tuple[float, float]:
    """Train for one epoch with comprehensive logging."""
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    skipped_batches = 0
    
    num_batches = len(dataloader)
    print_freq = max(1, num_batches // 15)
    start_time = time.time()

    for batch_idx, (data, target) in enumerate(dataloader, 1):
        try:
            data = data.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)

            optimizer.zero_grad(set_to_none=True)

            with torch.amp.autocast('cuda', enabled=(device.type == 'cuda')):
                output = model(data)
                loss = criterion(output, target)
            
            # Check for NaN
            if torch.isnan(loss) or torch.isinf(loss):
                logger.warning(f"Invalid loss at batch {batch_idx}, skipping")
                skipped_batches += 1
                continue

            if scaler is not None:
                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
            
            scheduler.step()

            total_loss += loss.item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
            total += target.size(0)
            
            if batch_idx % print_freq == 0 or batch_idx == num_batches:
                curr_acc = 100.0 * correct / total if total > 0 else 0.0
                curr_loss = total_loss / (batch_idx - skipped_batches) if batch_idx > skipped_batches else 0.0
                speed = batch_idx / (time.time() - start_time)
                eta = (num_batches - batch_idx) / speed if speed > 0 else 0
                current_lr = optimizer.param_groups[0]['lr']
                
                logger.info(
                    f"  Ep {epoch} [{batch_idx:4d}/{num_batches}] "
                    f"Loss: {curr_loss:.4f} Acc: {curr_acc:.2f}% "
                    f"LR: {current_lr:.2e} ({speed:.1f} b/s, ETA: {eta:.0f}s)"
                )
        
        except Exception as e:
            logger.error(f"Error in batch {batch_idx}: {e}")
            logger.debug(traceback.format_exc())
            skipped_batches += 1
            continue

    if skipped_batches > 0:
        logger.warning(f"Skipped {skipped_batches} batches due to errors")

    avg_loss = total_loss / (num_batches - skipped_batches) if num_batches > skipped_batches else float('inf')
    accuracy = 100.0 * correct / total if total > 0 else 0.0
    
    return avg_loss, accuracy

def validate(model: nn.Module,
             dataloader: DataLoader,
             criterion: nn.Module,
             device: torch.device,
             logger: logging.Logger) -> Tuple[float, float, np.ndarray, np.ndarray, np.ndarray, float, float, float]:
    """Validate model with comprehensive metrics."""
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_targets = []
    all_probs = []

    with torch.no_grad():
        for data, target in dataloader:
            try:
                data = data.to(device, non_blocking=True)
                target = target.to(device, non_blocking=True)
                output = model(data)
                loss = criterion(output, target)

                if not (torch.isnan(loss) or torch.isinf(loss)):
                    total_loss += loss.item()

                probs = F.softmax(output, dim=1)[:, 1]
                pred = output.argmax(dim=1)

                correct += pred.eq(target).sum().item()
                total += target.size(0)

                all_preds.extend(pred.cpu().tolist())
                all_targets.extend(target.cpu().tolist())
                all_probs.extend(probs.cpu().tolist())
            
            except Exception as e:
                logger.error(f"Error in validation batch: {e}")
                continue

    avg_loss = total_loss / len(dataloader) if len(dataloader) > 0 else float('inf')
    accuracy = 100.0 * correct / total if total > 0 else 0.0
    
    # Compute metrics with error handling
    try:
        precision = precision_score(all_targets, all_preds, zero_division=0)
        recall = recall_score(all_targets, all_preds, zero_division=0)
        f1 = f1_score(all_targets, all_preds, zero_division=0)
    except Exception as e:
        logger.error(f"Error computing metrics: {e}")
        precision = recall = f1 = 0.0
    
    return (avg_loss, accuracy, np.array(all_preds), np.array(all_targets),
            np.array(all_probs), precision, recall, f1)

def evaluate_test_set(model: nn.Module,
                     dataloader: DataLoader,
                     device: torch.device,
                     logger: logging.Logger,
                     output_dir: str):
    """Comprehensive evaluation on test set."""
    logger.info("\n" + "="*80)
    logger.info("FINAL TEST SET EVALUATION")
    logger.info("="*80)
    
    model.eval()
    all_preds = []
    all_targets = []
    all_probs = []
    
    with torch.no_grad():
        for data, target in dataloader:
            try:
                data = data.to(device, non_blocking=True)
                target = target.to(device, non_blocking=True)
                output = model(data)
                
                probs = F.softmax(output, dim=1)[:, 1]
                pred = output.argmax(dim=1)
                
                all_preds.extend(pred.cpu().tolist())
                all_targets.extend(target.cpu().tolist())
                all_probs.extend(probs.cpu().tolist())
            except Exception as e:
                logger.error(f"Error in test batch: {e}")
                continue
    
    # Convert to numpy
    all_preds = np.array(all_preds)
    all_targets = np.array(all_targets)
    all_probs = np.array(all_probs)
    
    # Compute comprehensive metrics
    accuracy = 100.0 * np.sum(all_preds == all_targets) / len(all_targets)
    precision = precision_score(all_targets, all_preds, zero_division=0)
    recall = recall_score(all_targets, all_preds, zero_division=0)
    f1 = f1_score(all_targets, all_preds, zero_division=0)
    
    try:
        auc = roc_auc_score(all_targets, all_probs)
    except Exception as e:
        logger.warning(f"Could not compute AUC: {e}")
        auc = 0.0
    
    # Confusion matrix
    cm = confusion_matrix(all_targets, all_preds)
    
    # Log results
    logger.info(f"\nTest Set Performance:")
    logger.info(f"  Accuracy:  {accuracy:.2f}%")
    logger.info(f"  Precision: {precision:.4f}")
    logger.info(f"  Recall:    {recall:.4f}")
    logger.info(f"  F1 Score:  {f1:.4f}")
    logger.info(f"  AUC:       {auc:.4f}")
    logger.info(f"\nConfusion Matrix:")
    logger.info(f"  {cm}")
    
    # Classification report
    logger.info(f"\nDetailed Classification Report:")
    report = classification_report(all_targets, all_preds, 
                                   target_names=['Normal', 'Apnea'],
                                   digits=4)
    logger.info(f"\n{report}")
    
    # Save results
    results_file = Path(output_dir) / 'test_results.txt'
    with open(results_file, 'w') as f:
        f.write("="*80 + "\n")
        f.write("FINAL TEST SET EVALUATION\n")
        f.write("="*80 + "\n\n")
        f.write(f"Accuracy:  {accuracy:.2f}%\n")
        f.write(f"Precision: {precision:.4f}\n")
        f.write(f"Recall:    {recall:.4f}\n")
        f.write(f"F1 Score:  {f1:.4f}\n")
        f.write(f"AUC:       {auc:.4f}\n\n")
        f.write("Confusion Matrix:\n")
        f.write(f"{cm}\n\n")
        f.write("Classification Report:\n")
        f.write(report)
    
    logger.info(f"\nResults saved to {results_file}")
    
    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'auc': auc,
        'confusion_matrix': cm
    }

# ------------------------------ Main ------------------------------------

def split_records(records: List[str], 
                 train_split: float, 
                 val_split: float,
                 test_split: float,
                 seed: int,
                 logger: logging.Logger) -> Tuple[List[str], List[str], List[str]]:
    """
    Split records into train/val/test ensuring no data leakage.
    Each record appears in exactly one split.
    """
    import random
    
    # Validate splits
    split_sum = train_split + val_split + test_split
    if abs(split_sum - 1.0) > 1e-6:
        raise ValueError(f"Splits must sum to 1.0, got {split_sum}")
    
    # Shuffle records
    records_shuffled = records.copy()
    random.Random(seed).shuffle(records_shuffled)
    
    # Calculate split indices
    n_records = len(records_shuffled)
    train_end = int(n_records * train_split)
    val_end = train_end + int(n_records * val_split)
    
    train_records = records_shuffled[:train_end]
    val_records = records_shuffled[train_end:val_end]
    test_records = records_shuffled[val_end:]
    
    logger.info(f"\nRecord-level split (seed={seed}):")
    logger.info(f"  Train: {len(train_records)} records")
    logger.info(f"  Val:   {len(val_records)} records")
    logger.info(f"  Test:  {len(test_records)} records")
    logger.info(f"  Total: {len(records)} records")
    
    # Verify no overlap
    train_set = set(train_records)
    val_set = set(val_records)
    test_set = set(test_records)
    
    if train_set & val_set:
        raise RuntimeError("Train and validation sets overlap!")
    if train_set & test_set:
        raise RuntimeError("Train and test sets overlap!")
    if val_set & test_set:
        raise RuntimeError("Validation and test sets overlap!")
    
    return train_records, val_records, test_records

def main(config: Config, logger: logging.Logger):
    """Main training pipeline with comprehensive error handling."""
    
    set_seed(config.seed)
    
    # Validate data directory
    DATA_DIR = Path(config.data_dir)
    if not DATA_DIR.exists():
        raise FileNotFoundError(f"Data directory not found: {DATA_DIR}")

    # Find all valid records
    record_files = list(DATA_DIR.glob('*.hea'))
    all_records = [f.stem for f in record_files]
    
    # Filter valid records (must have .apn annotation, exclude error records)
    valid_records = []
    for rec in all_records:
        if rec.endswith('er'):
            logger.debug(f"Skipping error record: {rec}")
            continue
        
        apn_file = DATA_DIR / (rec + '.apn')
        if not apn_file.exists():
            logger.debug(f"Skipping {rec}: no .apn annotation")
            continue
        
        valid_records.append(rec)
    
    if len(valid_records) == 0:
        raise RuntimeError("No valid records found in data directory")

    logger.info(f"\nFound {len(all_records)} total records")
    logger.info(f"Valid records with annotations: {len(valid_records)}")

    # Record-level split to prevent data leakage
    train_records, val_records, test_records = split_records(
        valid_records,
        config.train_split,
        config.val_split,
        config.test_split,
        config.seed,
        logger
    )
    
    if len(test_records) == 0:
        logger.warning("No test records! Consider adjusting split ratios.")

    # Create datasets
    logger.info("\nCreating datasets...")
    try:
        train_dataset = ApneaDataset(
            str(DATA_DIR), train_records, config,
            config.cache_dir, 'train', logger
        )
        val_dataset = ApneaDataset(
            str(DATA_DIR), val_records, config,
            config.cache_dir, 'val', logger
        )
        
        if test_records:
            test_dataset = ApneaDataset(
                str(DATA_DIR), test_records, config,
                config.cache_dir, 'test', logger
            )
        else:
            test_dataset = None
    except Exception as e:
        logger.error(f"Failed to create datasets: {e}")
        logger.debug(traceback.format_exc())
        raise

    # Create data loaders
    train_loader = DataLoader(
        train_dataset, 
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=True,
        persistent_workers=True if config.num_workers > 0 else False
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=True,
        persistent_workers=True if config.num_workers > 0 else False
    )
    
    if test_dataset:
        test_loader = DataLoader(
            test_dataset,
            batch_size=config.batch_size,
            shuffle=False,
            num_workers=config.num_workers,
            pin_memory=True
        )

    # Setup device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"\nDevice: {device}")
    if device.type == 'cuda':
        logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
        logger.info(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
        torch.cuda.empty_cache()

    # Create model
    try:
        model = EfficientApneaNet(
            d_model=config.d_model,
            n_blocks=config.n_blocks,
            dropout=config.dropout,
            input_length=config.segment_length
        ).to(device)
        
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        logger.info(f"\nModel: EfficientApneaNet")
        logger.info(f"  Total parameters:     {total_params:,}")
        logger.info(f"  Trainable parameters: {trainable_params:,}")
        logger.info(f"  Model size: ~{total_params * 4 / 1e6:.2f} MB (fp32)")
    except Exception as e:
        logger.error(f"Failed to create model: {e}")
        logger.debug(traceback.format_exc())
        raise

    # Setup training
    class_weights = compute_class_weights(train_dataset.labels, logger)
    class_weights = class_weights.to(device)
    logger.info(f"\nClass weights: {class_weights.tolist()}")
    
    criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.1)
    
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.lr,
        weight_decay=config.weight_decay,
        betas=(0.9, 0.999)
    )
    
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=config.lr,
        epochs=config.epochs,
        steps_per_epoch=len(train_loader),
        pct_start=0.1,
        anneal_strategy='cos'
    )

    scaler = torch.amp.GradScaler('cuda') if device.type == 'cuda' else None

    # Training loop
    best_val_acc = 0.0
    best_val_f1 = 0.0
    no_improve = 0
    best_model_path = Path(config.output_dir) / 'best_model.pth'

    logger.info("\n" + "="*80)
    logger.info("STARTING TRAINING")
    logger.info("="*80)
    
    training_history = {
        'train_loss': [], 'train_acc': [],
        'val_loss': [], 'val_acc': [],
        'val_precision': [], 'val_recall': [], 'val_f1': [], 'val_auc': []
    }
    
    try:
        for epoch in range(1, config.epochs + 1):
            epoch_start = time.time()
            
            # Train
            train_loss, train_acc = train_epoch(
                model, train_loader, criterion, optimizer,
                scheduler, device, epoch, logger, scaler
            )
            
            # Validate
            val_loss, val_acc, _, val_targets, val_probs, precision, recall, f1 = validate(
                model, val_loader, criterion, device, logger
            )

            # Compute AUC
            try:
                auc = roc_auc_score(val_targets, val_probs)
            except Exception as e:
                logger.warning(f"Could not compute AUC: {e}")
                auc = 0.0

            epoch_time = time.time() - epoch_start
            
            # Log epoch summary
            logger.info(f"\nEpoch {epoch:2d}/{config.epochs} completed in {epoch_time:.1f}s")
            logger.info(f"  Train: Loss={train_loss:.4f}, Acc={train_acc:.2f}%")
            logger.info(f"  Val:   Loss={val_loss:.4f}, Acc={val_acc:.2f}%, AUC={auc:.4f}")
            logger.info(f"         Precision={precision:.4f}, Recall={recall:.4f}, F1={f1:.4f}")

            # Store history
            training_history['train_loss'].append(train_loss)
            training_history['train_acc'].append(train_acc)
            training_history['val_loss'].append(val_loss)
            training_history['val_acc'].append(val_acc)
            training_history['val_precision'].append(precision)
            training_history['val_recall'].append(recall)
            training_history['val_f1'].append(f1)
            training_history['val_auc'].append(auc)

            # Save best model
            is_best = (val_acc > best_val_acc) or (val_acc >= best_val_acc and f1 > best_val_f1)
            
            if is_best:
                best_val_acc = val_acc
                best_val_f1 = f1
                no_improve = 0
                
                try:
                    torch.save({
                        'epoch': epoch,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'scheduler_state_dict': scheduler.state_dict(),
                        'val_acc': val_acc,
                        'val_auc': auc,
                        'val_f1': f1,
                        'config': config.__dict__,
                        'training_history': training_history
                    }, best_model_path)
                    logger.info(f"  ✓ Best model saved! (Acc={val_acc:.2f}%, F1={f1:.4f})")
                except Exception as e:
                    logger.error(f"Failed to save model: {e}")
            else:
                no_improve += 1
                logger.info(f"  No improvement ({no_improve}/{config.patience})")

            logger.info("-"*80)

            # Early stopping
            if no_improve >= config.patience:
                logger.info(f"\nEarly stopping triggered at epoch {epoch}")
                break

    except KeyboardInterrupt:
        logger.info("\n\nTraining interrupted by user")
    except Exception as e:
        logger.error(f"\n\nTraining failed with error: {e}")
        logger.debug(traceback.format_exc())
        raise

    # Training summary
    logger.info(f"\n{'='*80}")
    logger.info("TRAINING COMPLETED")
    logger.info(f"{'='*80}")
    logger.info(f"Best Validation - Accuracy: {best_val_acc:.2f}%, F1: {best_val_f1:.4f}")
    logger.info(f"Model saved to: {best_model_path}")

    # Test set evaluation
    if test_dataset and best_model_path.exists():
        logger.info("\nLoading best model for test set evaluation...")
        try:
            checkpoint = torch.load(best_model_path, weights_only=False)
            model.load_state_dict(checkpoint['model_state_dict'])
            
            test_results = evaluate_test_set(
                model, test_loader, device, logger, config.output_dir
            )
            
            logger.info(f"\n{'='*80}")
            logger.info("FINAL TEST RESULTS")
            logger.info(f"{'='*80}")
            logger.info(f"  Accuracy:  {test_results['accuracy']:.2f}%")
            logger.info(f"  Precision: {test_results['precision']:.4f}")
            logger.info(f"  Recall:    {test_results['recall']:.4f}")
            logger.info(f"  F1 Score:  {test_results['f1']:.4f}")
            logger.info(f"  AUC:       {test_results['auc']:.4f}")
            logger.info(f"{'='*80}")
            
        except Exception as e:
            logger.error(f"Failed to evaluate test set: {e}")
            logger.debug(traceback.format_exc())
    
    elif not test_dataset:
        logger.info("\nNo test set available for final evaluation")

    logger.info("\nTraining pipeline completed successfully!")

if __name__ == '__main__':
    # Auto-detect environment
    kaggle_data = '/kaggle/input/vincent2/apnea-ecg-database-1.0.0'
    colab_data = '/content/apnea-ecg/1.0.0'
    
    if Path(kaggle_data).exists():
        default_data_dir = kaggle_data
        default_cache_dir = '/kaggle/working'
        default_output_dir = '/kaggle/working'
    elif Path(colab_data).exists():
        default_data_dir = colab_data
        default_cache_dir = '/content'
        default_output_dir = '/content'
    else:
        default_data_dir = None
        default_cache_dir = None
        default_output_dir = '.'
    
    # Argument parser
    parser = argparse.ArgumentParser(
        description='Production-grade sleep apnea detection from ECG signals',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    
    # Data arguments
    parser.add_argument('--data-dir', type=str, default=default_data_dir,
                       help='Path to apnea-ecg database')
    parser.add_argument('--cache-dir', type=str, default=default_cache_dir,
                       help='Directory for caching processed data')
    parser.add_argument('--output-dir', type=str, default=default_output_dir,
                       help='Directory for outputs (models, logs)')
    
    # Signal processing
    parser.add_argument('--sampling-rate', type=int, default=100,
                       help='ECG sampling rate in Hz')
    parser.add_argument('--segment-duration', type=int, default=60,
                       help='Segment duration in seconds')
    parser.add_argument('--stride-ratio', type=float, default=0.5,
                       help='Stride ratio for overlapping segments')
    
    # Model architecture
    parser.add_argument('--d-model', type=int, default=128,
                       help='Model dimension')
    parser.add_argument('--n-blocks', type=int, default=8,
                       help='Number of residual blocks')
    parser.add_argument('--dropout', type=float, default=0.2,
                       help='Dropout rate')
    
    # Training
    parser.add_argument('--batch-size', type=int, default=48,
                       help='Batch size')
    parser.add_argument('--epochs', type=int, default=80,
                       help='Number of epochs')
    parser.add_argument('--lr', type=float, default=3e-4,
                       help='Learning rate')
    parser.add_argument('--weight-decay', type=float, default=1e-4,
                       help='Weight decay')
    parser.add_argument('--patience', type=int, default=15,
                       help='Early stopping patience')
    
    # Data splits
    parser.add_argument('--train-split', type=float, default=0.7,
                       help='Training set ratio')
    parser.add_argument('--val-split', type=float, default=0.15,
                       help='Validation set ratio')
    parser.add_argument('--test-split', type=float, default=0.15,
                       help='Test set ratio')
    
    # Other
    parser.add_argument('--seed', type=int, default=42,
                       help='Random seed')
    parser.add_argument('--num-workers', type=int, default=4,
                       help='Number of dataloader workers')
    parser.add_argument('--verbose', action='store_true',
                       help='Enable verbose logging')

    args, _ = parser.parse_known_args()
    
    # Validate data directory
    if args.data_dir is None:
        print("ERROR: Dataset not found. Please specify --data-dir")
        print("\nLooking for apnea-ecg-database in:")
        print(f"  - {kaggle_data}")
        print(f"  - {colab_data}")
        sys.exit(1)
    
    # Create output directory
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    
    # Setup logging
    logger = setup_logging(str(output_dir), args.verbose)
    
    # Create config
    try:
        config = Config(**vars(args))
    except Exception as e:
        logger.error(f"Invalid configuration: {e}")
        sys.exit(1)
    
    # Log configuration
    logger.info("="*80)
    logger.info("SLEEP APNEA DETECTION - PRODUCTION GRADE MODEL")
    logger.info("="*80)
    logger.info(f"\nConfiguration:")
    logger.info(f"  Data directory:    {config.data_dir}")
    logger.info(f"  Cache directory:   {config.cache_dir}")
    logger.info(f"  Output directory:  {config.output_dir}")
    logger.info(f"\nSignal Processing:")
    logger.info(f"  Sampling rate:     {config.sampling_rate} Hz")
    logger.info(f"  Segment duration:  {config.segment_duration}s ({config.segment_length} samples)")
    logger.info(f"  Stride ratio:      {config.stride_ratio} ({config.stride} samples)")
    logger.info(f"\nModel Architecture:")
    logger.info(f"  d_model:           {config.d_model}")
    logger.info(f"  n_blocks:          {config.n_blocks}")
    logger.info(f"  dropout:           {config.dropout}")
    logger.info(f"\nTraining:")
    logger.info(f"  Batch size:        {config.batch_size}")
    logger.info(f"  Epochs:            {config.epochs}")
    logger.info(f"  Learning rate:     {config.lr}")
    logger.info(f"  Weight decay:      {config.weight_decay}")
    logger.info(f"  Patience:          {config.patience}")
    logger.info(f"\nData Splits:")
    logger.info(f"  Train:             {config.train_split:.1%}")
    logger.info(f"  Validation:        {config.val_split:.1%}")
    logger.info(f"  Test:              {config.test_split:.1%}")
    logger.info(f"\nOther:")
    logger.info(f"  Random seed:       {config.seed}")
    logger.info(f"  Num workers:       {config.num_workers}")
    logger.info("="*80 + "\n")
    
    # Run training
    try:
        main(config, logger)
    except Exception as e:
        logger.error(f"\nFATAL ERROR: {e}")
        logger.debug(traceback.format_exc())
        sys.exit(1)

SLEEP APNEA DETECTION - PRODUCTION GRADE MODEL

Configuration:
  Data directory:    /kaggle/input/vincent2/apnea-ecg-database-1.0.0
  Cache directory:   /kaggle/working
  Output directory:  /kaggle/working

Signal Processing:
  Sampling rate:     100 Hz
  Segment duration:  60s (6000 samples)
  Stride ratio:      0.5 (3000 samples)

Model Architecture:
  d_model:           128
  n_blocks:          8
  dropout:           0.2

Training:
  Batch size:        48
  Epochs:            80
  Learning rate:     0.0003
  Weight decay:      0.0001
  Patience:          15

Data Splits:
  Train:             70.0%
  Validation:        15.0%
  Test:              15.0%

Other:
  Random seed:       42
  Num workers:       4


Found 86 total records
Valid records with annotations: 43

Record-level split (seed=42):
  Train: 30 records
  Val:   6 records
  Test:  7 records
  Total: 43 records

Creating datasets...
Processing 30 records for train...
  [1/30] Processing b02...
  [2/30] Processing a10...
  [

Traceback (most recent call last):
  File "/tmp/ipykernel_48/3253938270.py", line 1279, in <cell line: 0>
    main(config, logger)
  File "/tmp/ipykernel_48/3253938270.py", line 1033, in main
    train_loss, train_acc = train_epoch(
                            ^^^^^^^^^^^^
  File "/tmp/ipykernel_48/3253938270.py", line 597, in train_epoch
    for batch_idx, (data, target) in enumerate(dataloader, 1):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 708, in __next__
    data = self._next_data()
           ^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1480, in _next_data
    return self._process_data(data)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1505, in _process_data
    data.reraise()
  File "/usr/local/lib/python3.11/dist-packages/torch/_utils.py", line 733, in reraise
    raise exception
ValueError: Caught Val

TypeError: object of type 'NoneType' has no len()

In [None]:
#!/usr/bin/env python3
"""
Enhanced high-performance model for 90%+ apnea detection accuracy.
Implements R-R interval extraction and advanced preprocessing from research paper.
Optimized for Tesla P100 16GB GPU.
"""

import argparse
import os
import time
from pathlib import Path
from collections import Counter

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from scipy import signal as scipy_signal
from scipy.interpolate import CubicSpline

try:
    import wfdb
except Exception:
    wfdb = None

from sklearn.metrics import roc_auc_score, precision_score, recall_score, f1_score, confusion_matrix

# ----------------------------- Utilities ---------------------------------

def set_seed(seed=42):
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# ----------------------------- R-Peak Detection & Preprocessing ---------------------------------

def detect_r_peaks_hamilton(ecg_signal, fs=100):
    """
    Hamilton R-peak detection algorithm.
    Based on: Hamilton, P. S. (2002). Open source ECG analysis.
    """
    # Bandpass filter (5-15 Hz)
    nyquist = fs / 2
    low = 5 / nyquist
    high = 15 / nyquist
    b, a = scipy_signal.butter(2, [low, high], btype='band')
    filtered = scipy_signal.filtfilt(b, a, ecg_signal)
    
    # Derivative
    diff_signal = np.diff(filtered)
    
    # Squaring
    squared = diff_signal ** 2
    
    # Moving window integration (150ms window)
    window_size = int(0.15 * fs)
    integrated = np.convolve(squared, np.ones(window_size) / window_size, mode='same')
    
    # Find peaks
    threshold = 0.35 * np.max(integrated)
    peaks = []
    refractory = int(0.2 * fs)  # 200ms refractory period
    
    for i in range(1, len(integrated) - 1):
        if integrated[i] > threshold and integrated[i] > integrated[i-1] and integrated[i] > integrated[i+1]:
            if len(peaks) == 0 or i - peaks[-1] > refractory:
                # Find actual R-peak in original signal around this location
                search_start = max(0, i - int(0.05 * fs))
                search_end = min(len(ecg_signal), i + int(0.05 * fs))
                local_max_idx = search_start + np.argmax(np.abs(ecg_signal[search_start:search_end]))
                peaks.append(local_max_idx)
    
    return np.array(peaks)

def median_filter_rr(rr_intervals, threshold=0.3):
    """
    Median filter for removing physiologically uninterpretable R-R intervals.
    Based on: Chen et al. median filtering approach.
    """
    if len(rr_intervals) < 3:
        return rr_intervals
    
    filtered = rr_intervals.copy()
    median_rr = np.median(rr_intervals)
    
    for i in range(len(rr_intervals)):
        if np.abs(rr_intervals[i] - median_rr) > threshold * median_rr:
            # Replace with median of neighbors
            if i == 0:
                filtered[i] = rr_intervals[i+1]
            elif i == len(rr_intervals) - 1:
                filtered[i] = rr_intervals[i-1]
            else:
                filtered[i] = np.median([rr_intervals[i-1], rr_intervals[i+1]])
    
    return filtered

def extract_rr_features(ecg_segment, fs=100):
    """
    Extract R-R intervals and R-peak amplitudes from ECG segment.
    Apply cubic interpolation at 3 Hz as per paper.
    """
    try:
        # Detect R-peaks
        r_peaks = detect_r_peaks_hamilton(ecg_segment, fs)
        
        if len(r_peaks) < 2:
            # Return zeros if not enough peaks
            target_length = int(60 * 3)  # 60 seconds at 3 Hz = 180 samples
            return np.zeros(target_length), np.zeros(target_length)
        
        # Calculate R-R intervals (in seconds)
        rr_intervals = np.diff(r_peaks) / fs
        
        # Apply median filter
        rr_intervals = median_filter_rr(rr_intervals)
        
        # Extract R-peak amplitudes
        r_amplitudes = ecg_segment[r_peaks[1:]]  # Align with RR intervals
        
        # Time points for interpolation
        rr_times = r_peaks[1:] / fs  # Time of each RR interval
        
        # Cubic interpolation at 3 Hz
        target_fs = 3
        target_length = int(60 * target_fs)  # 60 seconds at 3 Hz
        target_times = np.linspace(rr_times[0], rr_times[-1], target_length)
        
        # Interpolate RR intervals
        if len(rr_times) >= 4:  # Need at least 4 points for cubic spline
            cs_rr = CubicSpline(rr_times, rr_intervals)
            rr_interpolated = cs_rr(target_times)
            
            # Interpolate R-peak amplitudes
            cs_amp = CubicSpline(rr_times, r_amplitudes)
            amp_interpolated = cs_amp(target_times)
        else:
            # Fall back to linear interpolation
            rr_interpolated = np.interp(target_times, rr_times, rr_intervals)
            amp_interpolated = np.interp(target_times, rr_times, r_amplitudes)
        
        # Normalize
        rr_interpolated = (rr_interpolated - np.mean(rr_interpolated)) / (np.std(rr_interpolated) + 1e-8)
        amp_interpolated = (amp_interpolated - np.mean(amp_interpolated)) / (np.std(amp_interpolated) + 1e-8)
        
        return rr_interpolated, amp_interpolated
        
    except Exception as e:
        target_length = int(60 * 3)
        return np.zeros(target_length), np.zeros(target_length)

# ----------------------------- Enhanced Model ---------------------------

class MultiHeadSelfAttention(nn.Module):
    """Enhanced self-attention mechanism."""
    def __init__(self, d_model, num_heads=8, dropout=0.1):
        super().__init__()
        self.attention = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        # x: (B, L, d_model)
        attn_out, _ = self.attention(x, x, x)
        return self.norm(x + self.dropout(attn_out))

class EnhancedResBlock(nn.Module):
    """Enhanced residual block with depthwise separable convolutions."""
    def __init__(self, channels, kernel_size=7, dilation=1):
        super().__init__()
        padding = (kernel_size - 1) * dilation // 2
        self.depthwise = nn.Conv1d(channels, channels, kernel_size, padding=padding, 
                                   groups=channels, dilation=dilation)
        self.pointwise = nn.Conv1d(channels, channels, 1)
        self.norm = nn.BatchNorm1d(channels)
        self.dropout = nn.Dropout(0.1)
        
    def forward(self, x):
        residual = x
        x = self.depthwise(x)
        x = self.pointwise(x)
        x = self.norm(x)
        x = self.dropout(x)
        return F.gelu(residual + x)

class EnhancedApneaNet(nn.Module):
    """Enhanced architecture for 90%+ accuracy with multi-modal inputs."""
    def __init__(self, d_model=192, n_blocks=10, dropout=0.2):
        super().__init__()
        
        # Separate processing for raw ECG
        self.ecg_stem = nn.Sequential(
            nn.Conv1d(1, d_model//2, kernel_size=7, padding=3, stride=2),
            nn.BatchNorm1d(d_model//2),
            nn.GELU(),
            nn.Conv1d(d_model//2, d_model, kernel_size=5, padding=2, stride=2),
            nn.BatchNorm1d(d_model),
            nn.GELU(),
        )
        
        # Separate processing for RR intervals (180 samples at 3Hz)
        self.rr_processor = nn.Sequential(
            nn.Conv1d(1, d_model//2, kernel_size=5, padding=2),
            nn.BatchNorm1d(d_model//2),
            nn.GELU(),
            nn.Conv1d(d_model//2, d_model, kernel_size=5, padding=2),
            nn.BatchNorm1d(d_model),
            nn.GELU(),
        )
        
        # Separate processing for R-peak amplitudes
        self.amp_processor = nn.Sequential(
            nn.Conv1d(1, d_model//2, kernel_size=5, padding=2),
            nn.BatchNorm1d(d_model//2),
            nn.GELU(),
            nn.Conv1d(d_model//2, d_model, kernel_size=5, padding=2),
            nn.BatchNorm1d(d_model),
            nn.GELU(),
        )
        
        # Fusion layer
        self.fusion = nn.Conv1d(d_model * 3, d_model, 1)
        
        # Enhanced residual blocks with varying receptive fields and dilations
        self.blocks = nn.ModuleList([
            EnhancedResBlock(d_model, kernel_size=7 if i % 3 == 0 else (11 if i % 3 == 1 else 5),
                           dilation=1 if i < n_blocks//2 else 2)
            for i in range(n_blocks)
        ])
        
        # Multi-head self-attention layers
        self.attention_layers = nn.ModuleList([
            MultiHeadSelfAttention(d_model, num_heads=8, dropout=dropout)
            for _ in range(2)
        ])
        
        # Channel attention
        self.channel_attn = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Conv1d(d_model, d_model//4, 1),
            nn.GELU(),
            nn.Conv1d(d_model//4, d_model, 1),
            nn.Sigmoid()
        )
        
        # Enhanced classification head
        self.classifier = nn.Sequential(
            nn.Linear(d_model * 4, d_model * 2),
            nn.BatchNorm1d(d_model * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 2, d_model),
            nn.BatchNorm1d(d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, 2)
        )
        
    def forward(self, ecg, rr, amp):
        # Process ECG: (B, 6000, 1) -> (B, 1, 6000)
        ecg = ecg.transpose(1, 2)
        ecg_feat = self.ecg_stem(ecg)  # (B, d_model, L/4)
        
        # Process RR intervals: (B, 180, 1) -> (B, 1, 180)
        rr = rr.transpose(1, 2)
        rr_feat = self.rr_processor(rr)  # (B, d_model, 180)
        
        # Process R-peak amplitudes: (B, 180, 1) -> (B, 1, 180)
        amp = amp.transpose(1, 2)
        amp_feat = self.amp_processor(amp)  # (B, d_model, 180)
        
        # Align temporal dimensions
        target_length = min(ecg_feat.size(2), rr_feat.size(2), amp_feat.size(2))
        ecg_feat = F.adaptive_avg_pool1d(ecg_feat, target_length)
        rr_feat = F.adaptive_avg_pool1d(rr_feat, target_length)
        amp_feat = F.adaptive_avg_pool1d(amp_feat, target_length)
        
        # Fuse features
        x = torch.cat([ecg_feat, rr_feat, amp_feat], dim=1)  # (B, 3*d_model, L)
        x = self.fusion(x)  # (B, d_model, L)
        
        # Residual blocks
        for block in self.blocks:
            x = block(x)
        
        # Channel attention
        attn_weights = self.channel_attn(x)
        x = x * attn_weights
        
        # Self-attention
        x_seq = x.transpose(1, 2)  # (B, L, d_model)
        for attn_layer in self.attention_layers:
            x_seq = attn_layer(x_seq)
        x = x_seq.transpose(1, 2)  # (B, d_model, L)
        
        # Global pooling
        x_avg = F.adaptive_avg_pool1d(x, 1).squeeze(-1)
        x_max = F.adaptive_max_pool1d(x, 1).squeeze(-1)
        x_std = torch.std(x, dim=2)
        x_last = x[:, :, -1]
        
        # Combine features
        x_combined = torch.cat([x_avg, x_max, x_std, x_last], dim=1)
        
        # Classify
        logits = self.classifier(x_combined)
        return logits

# --------------------------- Enhanced Dataset ---------------------------

class EnhancedApneaDataset(Dataset):
    """Enhanced dataset with R-R interval and R-peak amplitude extraction."""

    def __init__(self, data_dir: str, record_names: list = None, cache_dir: str = None,
                 segment_length: int = 6000, stride: int = 3000, split='train', augment=True):
        super().__init__()
        self.segment_length = int(segment_length)
        self.stride = int(stride)
        self.split = split
        self.augment = augment and (split == 'train')
        
        cache_dir = Path(cache_dir) if cache_dir else Path(data_dir)
        cache_file = cache_dir / f'enhanced_apnea_{split}_{segment_length}_{stride}_v2.pt'

        if cache_file.exists():
            print(f"Loading cached {split} from {cache_file}")
            data = torch.load(cache_file)
            self.ecg_segments = data['ecg_segments']
            self.rr_segments = data['rr_segments']
            self.amp_segments = data['amp_segments']
            self.labels = data['labels']
        else:
            assert wfdb is not None, "wfdb required"
            assert record_names is not None, "record_names required"
            
            self.ecg_segments = []
            self.rr_segments = []
            self.amp_segments = []
            self.labels = []
            self.data_dir = Path(data_dir)
            
            print(f"Processing {len(record_names)} records for {split}...")
            for i, rec in enumerate(record_names):
                print(f"  [{i+1}/{len(record_names)}] {rec}...", end='\r')
                self._load_record(rec)
            
            if len(self.ecg_segments) == 0:
                raise RuntimeError("No segments loaded")
            
            self.ecg_segments = torch.tensor(np.stack(self.ecg_segments, axis=0), dtype=torch.float32)
            self.rr_segments = torch.tensor(np.stack(self.rr_segments, axis=0), dtype=torch.float32)
            self.amp_segments = torch.tensor(np.stack(self.amp_segments, axis=0), dtype=torch.float32)
            self.labels = torch.tensor(self.labels, dtype=torch.long)
            
            print(f"\nSaving cache to {cache_file}")
            torch.save({
                'ecg_segments': self.ecg_segments,
                'rr_segments': self.rr_segments,
                'amp_segments': self.amp_segments,
                'labels': self.labels
            }, cache_file)

        if self.ecg_segments.ndim == 2:
            self.ecg_segments = self.ecg_segments.unsqueeze(-1)
        if self.rr_segments.ndim == 2:
            self.rr_segments = self.rr_segments.unsqueeze(-1)
        if self.amp_segments.ndim == 2:
            self.amp_segments = self.amp_segments.unsqueeze(-1)

        print(f"{split.capitalize()}: {len(self.ecg_segments)} segments, "
              f"Class: {Counter(self.labels.tolist())}")

    def _load_record(self, record_name: str):
        try:
            record = wfdb.rdrecord(str(self.data_dir / record_name))
            signal = record.p_signal[:, 0].astype(np.float32)
    
            if np.isnan(signal).any():
                nans = np.isnan(signal)
                not_nans = ~nans
                if not_nans.sum() > 0:
                    signal[nans] = np.interp(np.flatnonzero(nans), np.flatnonzero(not_nans), signal[not_nans])
                else:
                    signal = np.zeros_like(signal)
    
            annotation = wfdb.rdann(str(self.data_dir / record_name), 'apn')
            
            n_minutes = len(signal) // 6000
            minute_labels = np.zeros(n_minutes, dtype=int)
            
            for i, symbol in enumerate(annotation.symbol):
                if symbol == 'A':
                    sample = annotation.sample[i]
                    minute = sample // 6000
                    if minute < n_minutes:
                        minute_labels[minute] = 1
    
            n_samples = len(signal)
            for start in range(0, n_samples - self.segment_length + 1, self.stride):
                end = start + self.segment_length
                ecg_seg = signal[start:end].astype(np.float32)
    
                # Normalize ECG
                seg_mean = np.nanmean(ecg_seg)
                seg_std = np.nanstd(ecg_seg)
                if np.isnan(seg_std) or seg_std < 1e-8:
                    ecg_seg = ecg_seg - seg_mean
                else:
                    ecg_seg = (ecg_seg - seg_mean) / (seg_std + 1e-8)
                ecg_seg = np.clip(ecg_seg, -10, 10)
                
                # Extract R-R intervals and R-peak amplitudes
                rr_seg, amp_seg = extract_rr_features(signal[start:end], fs=100)
    
                minute = start // 6000
                if minute < len(minute_labels):
                    label = minute_labels[minute]
                    self.ecg_segments.append(ecg_seg)
                    self.rr_segments.append(rr_seg)
                    self.amp_segments.append(amp_seg)
                    self.labels.append(int(label))
                    
        except Exception as e:
            print(f"\nError loading {record_name}: {e}")
    
    def _augment(self, ecg, rr, amp):
        """Enhanced augmentation for all modalities."""
        ecg = ecg.numpy() if torch.is_tensor(ecg) else ecg
        rr = rr.numpy() if torch.is_tensor(rr) else rr
        amp = amp.numpy() if torch.is_tensor(amp) else amp
        
        if np.random.random() < 0.3:
            # Gaussian noise
            ecg = ecg + np.random.normal(0, 0.02, ecg.shape).astype(np.float32)
            rr = rr + np.random.normal(0, 0.02, rr.shape).astype(np.float32)
            amp = amp + np.random.normal(0, 0.02, amp.shape).astype(np.float32)
        
        if np.random.random() < 0.2:
            # Scale
            scale = np.random.uniform(0.95, 1.05)
            ecg = ecg * scale
            amp = amp * scale
        
        return torch.from_numpy(ecg), torch.from_numpy(rr), torch.from_numpy(amp)
            
    def __len__(self):
        return self.ecg_segments.shape[0]

    def __getitem__(self, idx):
        ecg = self.ecg_segments[idx]
        rr = self.rr_segments[idx]
        amp = self.amp_segments[idx]
        label = self.labels[idx]
        
        if self.augment:
            ecg, rr, amp = self._augment(ecg, rr, amp)
        
        # Ensure correct shapes
        if ecg.ndim == 1:
            ecg = ecg.unsqueeze(-1)
        if rr.ndim == 1:
            rr = rr.unsqueeze(-1)
        if amp.ndim == 1:
            amp = amp.unsqueeze(-1)
        
        return ecg, rr, amp, label

# -------------------------- Training with Metrics ------------------------

def compute_class_weights(labels_tensor):
    counts = Counter(labels_tensor.tolist())
    total = sum(counts.values())
    num_classes = len(counts)
    weights = [total / (num_classes * counts.get(i, 1)) for i in range(num_classes)]
    return torch.tensor(weights, dtype=torch.float32)

def compute_metrics(y_true, y_pred):
    """Compute sensitivity, specificity, and other metrics."""
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
    
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0  # Also called recall or TPR
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0  # TNR
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    f1 = 2 * precision * sensitivity / (precision + sensitivity) if (precision + sensitivity) > 0 else 0
    accuracy = (tp + tn) / (tp + tn + fp + fn)
    
    return {
        'sensitivity': sensitivity,
        'specificity': specificity,
        'precision': precision,
        'f1': f1,
        'accuracy': accuracy
    }

def train_epoch(model, dataloader, criterion, optimizer, scheduler, device, epoch, scaler=None):
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    
    num_batches = len(dataloader)
    print_freq = max(1, num_batches // 15)
    start_time = time.time()

    for batch_idx, (ecg, rr, amp, target) in enumerate(dataloader, 1):
        ecg = ecg.to(device, non_blocking=True)
        rr = rr.to(device, non_blocking=True)
        amp = amp.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        with torch.amp.autocast('cuda', enabled=(device.type == 'cuda')):
            output = model(ecg, rr, amp)
            loss = criterion(output, target)
        
        if torch.isnan(loss):
            print(f"\nWARNING: NaN loss, skipping batch {batch_idx}")
            continue

        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
        
        scheduler.step()

        total_loss += loss.item()
        pred = output.argmax(dim=1)
        correct += pred.eq(target).sum().item()
        total += target.size(0)
        
        if batch_idx % print_freq == 0 or batch_idx == num_batches:
            curr_acc = 100.0 * correct / total
            curr_loss = total_loss / batch_idx
            speed = batch_idx / (time.time() - start_time)
            eta = (num_batches - batch_idx) / speed if speed > 0 else 0
            
            print(f"  Ep {epoch} [{batch_idx:4d}/{num_batches}] "
                  f"Loss: {curr_loss:.4f} Acc: {curr_acc:.2f}% "
                  f"({speed:.1f} b/s, ETA: {eta:.0f}s)", end='\r')

    print()
    avg_loss = total_loss / len(dataloader)
    accuracy = 100.0 * correct / total if total > 0 else 0.0
    return avg_loss, accuracy

def validate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0.0
    all_preds = []
    all_targets = []
    all_probs = []

    with torch.no_grad():
        for ecg, rr, amp, target in dataloader:
            ecg = ecg.to(device, non_blocking=True)
            rr = rr.to(device, non_blocking=True)
            amp = amp.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            output = model(ecg, rr, amp)
            loss = criterion(output, target)

            total_loss += loss.item()
            probs = F.softmax(output, dim=1)[:, 1]
            pred = output.argmax(dim=1)

            all_preds.extend(pred.cpu().tolist())
            all_targets.extend(target.cpu().tolist())
            all_probs.extend(probs.cpu().tolist())

    avg_loss = total_loss / len(dataloader)
    
    # Compute all metrics including sensitivity and specificity
    metrics = compute_metrics(all_targets, all_preds)
    
    try:
        auc = roc_auc_score(all_targets, all_probs)
    except:
        auc = 0.0
    
    return avg_loss, metrics, np.array(all_probs), auc

# ------------------------------ Main ------------------------------------

def main(args):
    set_seed(args.seed)

    DATA_DIR = Path(args.data_dir)
    if not DATA_DIR.exists():
        raise FileNotFoundError(f"Data directory not found: {DATA_DIR}")

    record_files = list(DATA_DIR.glob('*.hea'))
    all_records = [f.stem for f in record_files]
    valid_records = [rec for rec in all_records 
                    if (DATA_DIR / (rec + '.apn')).exists() and not rec.endswith('er')]
    
    if len(valid_records) == 0:
        raise RuntimeError("No valid records found")

    print(f"Found {len(valid_records)} valid records")

    import rand

In [4]:
#!/usr/bin/env python3
"""
Enhanced high-performance model for 90%+ apnea detection accuracy.
Includes R-R interval extraction, R-peak amplitude processing, and advanced metrics.
Optimized for Tesla P100 16GB GPU.
"""

import argparse
import os
import time
from pathlib import Path
from collections import Counter

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

try:
    import wfdb
except Exception:
    wfdb = None

try:
    from scipy import signal as scipy_signal
    from scipy.interpolate import interp1d
except:
    scipy_signal = None
    interp1d = None

from sklearn.metrics import roc_auc_score, precision_score, recall_score, f1_score, confusion_matrix

# ----------------------------- Utilities ---------------------------------

def set_seed(seed=42):
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


# ----------------------------- R-Peak Detection --------------------------

def detect_r_peaks(ecg_signal, fs=100):
    """
    Detect R-peaks using a simple but effective algorithm.
    Based on Hamilton method principles.
    """
    # Bandpass filter (5-15 Hz for QRS complex)
    if scipy_signal is None:
        # Fallback: simple moving average
        window = int(0.12 * fs)
        filtered = np.convolve(ecg_signal, np.ones(window)/window, mode='same')
    else:
        nyquist = fs / 2
        low = 5 / nyquist
        high = 15 / nyquist
        b, a = scipy_signal.butter(2, [low, high], btype='band')
        filtered = scipy_signal.filtfilt(b, a, ecg_signal)
    
    # Derivative (emphasize QRS slope)
    diff = np.diff(filtered)
    diff = np.abs(diff)
    
    # Moving average
    window = int(0.08 * fs)
    integrated = np.convolve(diff, np.ones(window)/window, mode='same')
    
    # Threshold-based detection
    threshold = np.mean(integrated) + 0.5 * np.std(integrated)
    
    # Find peaks
    peaks = []
    refractory = int(0.2 * fs)  # 200ms refractory period
    
    for i in range(1, len(integrated) - 1):
        if integrated[i] > threshold and integrated[i] > integrated[i-1] and integrated[i] > integrated[i+1]:
            if len(peaks) == 0 or i - peaks[-1] > refractory:
                # Find actual peak in original signal around this point
                search_window = int(0.05 * fs)
                start = max(0, i - search_window)
                end = min(len(ecg_signal), i + search_window)
                actual_peak = start + np.argmax(ecg_signal[start:end])
                peaks.append(actual_peak)
    
    return np.array(peaks)


def median_filter_rr(rr_intervals, threshold=0.3):
    """
    Remove physiologically uninterpretable R-R intervals using median filter.
    As described in Chen et al.
    """
    if len(rr_intervals) < 3:
        return rr_intervals
    
    filtered = []
    for i, rr in enumerate(rr_intervals):
        if i == 0 or i == len(rr_intervals) - 1:
            filtered.append(rr)
            continue
        
        # Get local median
        local = rr_intervals[max(0, i-2):min(len(rr_intervals), i+3)]
        median = np.median(local)
        
        # Check if current RR is within threshold of median
        if abs(rr - median) / median < threshold:
            filtered.append(rr)
        else:
            filtered.append(median)
    
    return np.array(filtered)


def extract_rr_features(ecg_signal, fs=100, target_length=180):
    """
    Extract R-R intervals and R-peak amplitudes from ECG signal.
    Apply cubic interpolation at 3 Hz as described in the paper.
    """
    # Detect R-peaks
    r_peaks = detect_r_peaks(ecg_signal, fs)
    
    if len(r_peaks) < 2:
        # Return zeros if no valid peaks
        return np.zeros(target_length), np.zeros(target_length)
    
    # Extract R-R intervals (in seconds)
    rr_intervals = np.diff(r_peaks) / fs
    
    # Apply median filter
    rr_intervals = median_filter_rr(rr_intervals)
    
    # Extract R-peak amplitudes
    r_amplitudes = ecg_signal[r_peaks[1:]]  # Align with RR intervals
    
    # Create time axis for interpolation
    time_rr = np.cumsum(rr_intervals)
    time_rr = np.concatenate([[0], time_rr])
    
    # Extend arrays to match
    rr_intervals_ext = np.concatenate([[rr_intervals[0]], rr_intervals])
    r_amplitudes_ext = np.concatenate([[r_amplitudes[0]], r_amplitudes])
    
    # Target time axis at 3 Hz
    total_time = len(ecg_signal) / fs
    target_time = np.linspace(0, total_time, target_length)
    
    # Cubic interpolation
    if interp1d is not None and len(time_rr) > 3:
        try:
            f_rr = interp1d(time_rr, rr_intervals_ext, kind='cubic', fill_value='extrapolate')
            f_amp = interp1d(time_rr, r_amplitudes_ext, kind='cubic', fill_value='extrapolate')
            
            rr_interpolated = f_rr(target_time)
            amp_interpolated = f_amp(target_time)
        except:
            # Fallback to linear
            f_rr = interp1d(time_rr, rr_intervals_ext, kind='linear', fill_value='extrapolate')
            f_amp = interp1d(time_rr, r_amplitudes_ext, kind='linear', fill_value='extrapolate')
            
            rr_interpolated = f_rr(target_time)
            amp_interpolated = f_amp(target_time)
    else:
        # Simple linear interpolation fallback
        rr_interpolated = np.interp(target_time, time_rr, rr_intervals_ext)
        amp_interpolated = np.interp(target_time, time_rr, r_amplitudes_ext)
    
    # Clip outliers
    rr_interpolated = np.clip(rr_interpolated, 0.3, 2.0)  # 30-200 bpm
    
    return rr_interpolated, amp_interpolated


# ----------------------------- Enhanced Model ---------------------------

class MultiScaleConvBlock(nn.Module):
    """Multi-scale convolution block for capturing different temporal patterns."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels // 3, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(in_channels, out_channels // 3, kernel_size=5, padding=2)
        self.conv3 = nn.Conv1d(in_channels, out_channels // 3, kernel_size=7, padding=3)
        self.bn = nn.BatchNorm1d(out_channels)
        
    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(x)
        x3 = self.conv3(x)
        x = torch.cat([x1, x2, x3], dim=1)
        x = self.bn(x)
        return F.gelu(x)


class EnhancedResBlock(nn.Module):
    """Enhanced residual block with squeeze-excitation attention."""
    def __init__(self, channels, kernel_size=7):
        super().__init__()
        self.conv1 = nn.Conv1d(channels, channels, kernel_size, padding=kernel_size//2, groups=channels)
        self.conv2 = nn.Conv1d(channels, channels, 1)
        self.bn = nn.BatchNorm1d(channels)
        
        # Squeeze-and-Excitation
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Conv1d(channels, channels // 4, 1),
            nn.GELU(),
            nn.Conv1d(channels // 4, channels, 1),
            nn.Sigmoid()
        )
        self.dropout = nn.Dropout(0.15)
        
    def forward(self, x):
        residual = x
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.bn(x)
        
        # SE attention
        se_weight = self.se(x)
        x = x * se_weight
        
        x = self.dropout(x)
        return F.gelu(residual + x)


class BiLSTMBlock(nn.Module):
    """Bidirectional LSTM for temporal pattern learning."""
    def __init__(self, input_size, hidden_size, num_layers=2, dropout=0.2):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size, hidden_size, num_layers,
            batch_first=True, bidirectional=True, dropout=dropout if num_layers > 1 else 0
        )
        self.layer_norm = nn.LayerNorm(hidden_size * 2)
        
    def forward(self, x):
        # x: (B, C, L) -> (B, L, C)
        x = x.transpose(1, 2)
        x, _ = self.lstm(x)
        x = self.layer_norm(x)
        # (B, L, C) -> (B, C, L)
        return x.transpose(1, 2)


class EnhancedApneaNet(nn.Module):
    """
    Enhanced architecture combining:
    - Multi-channel input (raw ECG, R-R intervals, R-peak amplitudes)
    - Multi-scale convolutions
    - Residual blocks with SE attention
    - BiLSTM for temporal modeling
    - Multi-head attention
    """
    def __init__(self, d_model=160, n_blocks=8, dropout=0.2):
        super().__init__()
        
        # Three-channel processing: raw ECG, RR intervals, R-peak amplitudes
        self.ecg_stem = nn.Sequential(
            nn.Conv1d(1, d_model // 3, kernel_size=7, padding=3, stride=2),
            nn.BatchNorm1d(d_model // 3),
            nn.GELU(),
        )
        
        self.rr_stem = nn.Sequential(
            nn.Conv1d(1, d_model // 3, kernel_size=5, padding=2),
            nn.BatchNorm1d(d_model // 3),
            nn.GELU(),
        )
        
        self.amp_stem = nn.Sequential(
            nn.Conv1d(1, d_model // 3, kernel_size=5, padding=2),
            nn.BatchNorm1d(d_model // 3),
            nn.GELU(),
        )
        
        # Multi-scale fusion
        self.fusion = MultiScaleConvBlock(d_model, d_model)
        
        # Residual blocks
        self.blocks = nn.ModuleList([
            EnhancedResBlock(d_model, kernel_size=7 if i % 2 == 0 else 11)
            for i in range(n_blocks)
        ])
        
        # BiLSTM for temporal dependencies
        self.bilstm = BiLSTMBlock(d_model, d_model // 2, num_layers=2, dropout=dropout)
        
        # Multi-head attention
        self.attention = nn.MultiheadAttention(d_model, num_heads=8, dropout=dropout, batch_first=True)
        self.attn_norm = nn.LayerNorm(d_model)
        
        # Global pooling
        self.gap = nn.AdaptiveAvgPool1d(1)
        self.gmp = nn.AdaptiveMaxPool1d(1)
        
        # Classification head with more capacity
        self.classifier = nn.Sequential(
            nn.Linear(d_model * 3, d_model * 2),
            nn.BatchNorm1d(d_model * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 2, d_model),
            nn.BatchNorm1d(d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, 2)
        )
        
    def forward(self, x):
        # x: (B, L, 3) - [raw_ecg, rr_intervals, r_amplitudes]
        
        # Split channels
        ecg = x[:, :, 0:1].transpose(1, 2)  # (B, 1, L)
        rr = x[:, :, 1:2].transpose(1, 2)   # (B, 1, L')
        amp = x[:, :, 2:3].transpose(1, 2)  # (B, 1, L')
        
        # Process each channel
        ecg_feat = self.ecg_stem(ecg)  # (B, d/3, L/2)
        
        # Upsample RR and amp to match ECG if needed
        if rr.shape[-1] != ecg_feat.shape[-1]:
            rr = F.interpolate(rr, size=ecg_feat.shape[-1], mode='linear', align_corners=False)
            amp = F.interpolate(amp, size=ecg_feat.shape[-1], mode='linear', align_corners=False)
        
        rr_feat = self.rr_stem(rr)    # (B, d/3, L')
        amp_feat = self.amp_stem(amp) # (B, d/3, L')
        
        # Concatenate and fuse
        x = torch.cat([ecg_feat, rr_feat, amp_feat], dim=1)  # (B, d, L')
        x = self.fusion(x)
        
        # Residual blocks
        for block in self.blocks:
            x = block(x)
        
        # BiLSTM
        x = self.bilstm(x)  # (B, d, L')
        
        # Attention
        x_seq = x.transpose(1, 2)  # (B, L', d)
        x_attn, _ = self.attention(x_seq, x_seq, x_seq)
        x_attn = self.attn_norm(x_attn + x_seq)
        x = x_attn.transpose(1, 2)  # (B, d, L')
        
        # Global pooling
        x_avg = self.gap(x).squeeze(-1)  # (B, d)
        x_max = self.gmp(x).squeeze(-1)  # (B, d)
        x_attn_pool = x_attn.mean(dim=1)  # (B, d)
        
        # Combine
        x_combined = torch.cat([x_avg, x_max, x_attn_pool], dim=1)  # (B, 3*d)
        
        # Classify
        logits = self.classifier(x_combined)
        return logits


# --------------------------- Enhanced Dataset ---------------------------

class EnhancedApneaDataset(Dataset):
    """
    Enhanced dataset with R-R interval and R-peak amplitude extraction.
    """

    def __init__(self, data_dir: str, record_names: list = None, cache_dir: str = None,
                 segment_length: int = 6000, stride: int = 3000, split='train', augment=True):
        super().__init__()
        self.segment_length = int(segment_length)
        self.stride = int(stride)
        self.split = split
        self.augment = augment and (split == 'train')
        self.fs = 100  # Sampling frequency
        self.rr_length = 180  # 60s * 3Hz
        
        cache_dir = Path(cache_dir) if cache_dir else Path(data_dir)
        cache_file = cache_dir / f'enhanced_apnea_{split}_{segment_length}_{stride}.pt'

        if cache_file.exists():
            print(f"Loading cached {split} from {cache_file}")
            data = torch.load(cache_file)
            self.segments = data['segments']
            self.labels = data['labels']
        else:
            assert wfdb is not None, "wfdb required"
            assert record_names is not None, "record_names required"
            
            self.segments = []
            self.labels = []
            self.data_dir = Path(data_dir)
            
            print(f"Processing {len(record_names)} records for {split} with R-R extraction...")
            for i, rec in enumerate(record_names):
                print(f"  [{i+1}/{len(record_names)}] {rec}...", end='\r')
                self._load_record(rec)
            
            if len(self.segments) == 0:
                raise RuntimeError("No segments loaded")
            
            self.segments = torch.tensor(np.stack(self.segments, axis=0), dtype=torch.float32)
            self.labels = torch.tensor(self.labels, dtype=torch.long)
            
            print(f"\nSaving cache to {cache_file}")
            torch.save({'segments': self.segments, 'labels': self.labels}, cache_file)

        print(f"{split.capitalize()}: {len(self.segments)} segments (3-channel), "
              f"Class: {Counter(self.labels.tolist())}")

    def _load_record(self, record_name: str):
        try:
            record = wfdb.rdrecord(str(self.data_dir / record_name))
            signal = record.p_signal[:, 0].astype(np.float32)
    
            # Handle NaNs
            if np.isnan(signal).any():
                nans = np.isnan(signal)
                not_nans = ~nans
                if not_nans.sum() > 0:
                    signal[nans] = np.interp(np.flatnonzero(nans), np.flatnonzero(not_nans), signal[not_nans])
                else:
                    signal = np.zeros_like(signal)
    
            annotation = wfdb.rdann(str(self.data_dir / record_name), 'apn')
            
            n_minutes = len(signal) // 6000
            minute_labels = np.zeros(n_minutes, dtype=int)
            
            for i, symbol in enumerate(annotation.symbol):
                if symbol == 'A':
                    sample = annotation.sample[i]
                    minute = sample // 6000
                    if minute < n_minutes:
                        minute_labels[minute] = 1
    
            n_samples = len(signal)
            for start in range(0, n_samples - self.segment_length + 1, self.stride):
                end = start + self.segment_length
                seg = signal[start:end].astype(np.float32)
    
                # Normalize raw ECG
                seg_mean = np.nanmean(seg)
                seg_std = np.nanstd(seg)
                if np.isnan(seg_std) or seg_std < 1e-8:
                    seg_norm = seg - seg_mean
                else:
                    seg_norm = (seg - seg_mean) / (seg_std + 1e-8)
                seg_norm = np.clip(seg_norm, -10, 10)
                
                # Extract R-R intervals and R-peak amplitudes
                rr_intervals, r_amplitudes = extract_rr_features(seg, self.fs, self.rr_length)
                
                # Normalize RR intervals
                rr_mean = np.mean(rr_intervals)
                rr_std = np.std(rr_intervals)
                if rr_std > 1e-8:
                    rr_intervals = (rr_intervals - rr_mean) / (rr_std + 1e-8)
                else:
                    rr_intervals = rr_intervals - rr_mean
                rr_intervals = np.clip(rr_intervals, -10, 10)
                
                # Normalize R-peak amplitudes
                amp_mean = np.mean(r_amplitudes)
                amp_std = np.std(r_amplitudes)
                if amp_std > 1e-8:
                    r_amplitudes = (r_amplitudes - amp_mean) / (amp_std + 1e-8)
                else:
                    r_amplitudes = r_amplitudes - amp_mean
                r_amplitudes = np.clip(r_amplitudes, -10, 10)
                
                # Interpolate RR and amplitudes to match segment length
                if len(rr_intervals) != self.segment_length:
                    rr_intervals = np.interp(
                        np.linspace(0, 1, self.segment_length),
                        np.linspace(0, 1, len(rr_intervals)),
                        rr_intervals
                    )
                    r_amplitudes = np.interp(
                        np.linspace(0, 1, self.segment_length),
                        np.linspace(0, 1, len(r_amplitudes)),
                        r_amplitudes
                    )
                
                # Stack: [raw_ecg, rr_intervals, r_amplitudes]
                multi_channel = np.stack([seg_norm, rr_intervals, r_amplitudes], axis=-1)
    
                minute = start // 6000
                if minute < len(minute_labels):
                    label = minute_labels[minute]
                    self.segments.append(multi_channel)
                    self.labels.append(int(label))
                    
        except Exception as e:
            print(f"\nError loading {record_name}: {e}")
    
    def _augment(self, seg):
        """Enhanced augmentation for multi-channel data."""
        seg = seg.numpy() if torch.is_tensor(seg) else seg
        
        # Random noise on raw ECG channel
        if np.random.random() < 0.3:
            seg[:, 0] = seg[:, 0] + np.random.normal(0, 0.03, seg.shape[0]).astype(np.float32)
        
        # Random scaling on all channels
        if np.random.random() < 0.2:
            scale = np.random.uniform(0.95, 1.05)
            seg = seg * scale
        
        # Time shift
        if np.random.random() < 0.3:
            shift = np.random.randint(-100, 100)
            seg = np.roll(seg, shift, axis=0)
        
        return torch.from_numpy(seg)
            
    def __len__(self):
        return self.segments.shape[0]

    def __getitem__(self, idx):
        seg = self.segments[idx]
        label = self.labels[idx]
        
        if self.augment:
            seg = self._augment(seg)
        
        return seg, label


# -------------------------- Training & Metrics ------------------------

def compute_class_weights(labels_tensor):
    """Compute balanced class weights with stronger emphasis on minority class."""
    counts = Counter(labels_tensor.tolist())
    total = sum(counts.values())
    num_classes = len(counts)
    # Use sqrt for more balanced weighting
    weights = [np.sqrt(total / (num_classes * counts.get(i, 1))) for i in range(num_classes)]
    return torch.tensor(weights, dtype=torch.float32)


def compute_metrics(y_true, y_pred, y_probs):
    """Compute comprehensive metrics including sensitivity and specificity."""
    cm = confusion_matrix(y_true, y_pred)
    
    # Sensitivity (Recall) = TP / (TP + FN)
    sensitivity = recall_score(y_true, y_pred, zero_division=0)
    
    # Specificity = TN / (TN + FP)
    if cm.shape == (2, 2):
        tn, fp, fn, tp = cm.ravel()
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
    else:
        specificity = 0
    
    precision = precision_score(y_true, y_pred, zero_division=0)
    f1 = f1_score(y_true, y_pred, zero_division=0)
    
    try:
        auc = roc_auc_score(y_true, y_probs)
    except:
        auc = 0.0
    
    return {
        'sensitivity': sensitivity,
        'specificity': specificity,
        'precision': precision,
        'f1': f1,
        'auc': auc
    }


def train_epoch(model, dataloader, criterion, optimizer, scheduler, device, epoch, scaler=None):
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    
    num_batches = len(dataloader)
    print_freq = max(1, num_batches // 15)
    start_time = time.time()

    for batch_idx, (data, target) in enumerate(dataloader, 1):
        data = data.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        with torch.amp.autocast('cuda', enabled=(device.type == 'cuda')):
            output = model(data)
            loss = criterion(output, target)
        
        if torch.isnan(loss):
            print(f"\nWARNING: NaN loss, skipping batch {batch_idx}")
            continue

        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
        
        scheduler.step()

        total_loss += loss.item()
        pred = output.argmax(dim=1)
        correct += pred.eq(target).sum().item()
        total += target.size(0)
        
        if batch_idx % print_freq == 0 or batch_idx == num_batches:
            curr_acc = 100.0 * correct / total
            curr_loss = total_loss / batch_idx
            speed = batch_idx / (time.time() - start_time)
            eta = (num_batches - batch_idx) / speed if speed > 0 else 0
            
            print(f"  Ep {epoch} [{batch_idx:4d}/{num_batches}] "
                  f"Loss: {curr_loss:.4f} Acc: {curr_acc:.2f}% "
                  f"({speed:.1f} b/s, ETA: {eta:.0f}s)", end='\r')

    print()
    avg_loss = total_loss / len(dataloader)
    accuracy = 100.0 * correct / total if total > 0 else 0.0
    return avg_loss, accuracy


def validate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_targets = []
    all_probs = []

    with torch.no_grad():
        for data, target in dataloader:
            data = data.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            output = model(data)
            loss = criterion(output, target)

            total_loss += loss.item()
            probs = F.softmax(output, dim=1)[:, 1]
            pred = output.argmax(dim=1)

            correct += pred.eq(target).sum().item()
            total += target.size(0)

            all_preds.extend(pred.cpu().tolist())
            all_targets.extend(target.cpu().tolist())
            all_probs.extend(probs.cpu().tolist())

    avg_loss = total_loss / len(dataloader)
    accuracy = 100.0 * correct / total if total > 0 else 0.0
    
    metrics = compute_metrics(all_targets, all_preds, all_probs)
    
    return avg_loss, accuracy, np.array(all_preds), np.array(all_targets), np.array(all_probs), metrics


# ------------------------------ Main ------------------------------------

def main(args):
    set_seed(args.seed)

    DATA_

In [8]:
#!/usr/bin/env python3
"""
High-performance apnea detection with R-R interval extraction (Target: 90%+ accuracy)
Based on PhysioNet Apnea-ECG Database methodology
"""

import argparse
import os
import time
from pathlib import Path
from collections import Counter

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from scipy import signal as scipy_signal
from scipy.interpolate import interp1d

try:
    import wfdb
except Exception:
    wfdb = None

from sklearn.metrics import roc_auc_score, precision_score, recall_score, f1_score, confusion_matrix

# ----------------------------- Utilities ---------------------------------

def set_seed(seed=42):
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# ----------------------------- R-Peak Detection & R-R Interval Extraction ---------------------------------

def detect_r_peaks_hamilton(ecg_signal, fs=100):
    """
    Hamilton R-peak detection algorithm
    Returns indices of R-peaks
    """
    # Bandpass filter (5-15 Hz)
    b, a = scipy_signal.butter(2, [5, 15], btype='band', fs=fs)
    filtered = scipy_signal.filtfilt(b, a, ecg_signal)
    
    # Derivative
    diff_signal = np.diff(filtered)
    
    # Squaring
    squared = diff_signal ** 2
    
    # Moving average integration (150ms window)
    window_size = int(0.15 * fs)
    integrated = np.convolve(squared, np.ones(window_size)/window_size, mode='same')
    
    # Find peaks
    threshold = np.mean(integrated) + 0.5 * np.std(integrated)
    peaks = []
    refractory = int(0.2 * fs)  # 200ms refractory period
    
    for i in range(1, len(integrated) - 1):
        if integrated[i] > threshold:
            if integrated[i] > integrated[i-1] and integrated[i] > integrated[i+1]:
                if not peaks or (i - peaks[-1]) > refractory:
                    peaks.append(i)
    
    return np.array(peaks)

def median_filter_rr(rr_intervals, window=5):
    """
    Median filter for removing physiologically uninterpretable R-R intervals
    Based on Chen et al. methodology
    """
    if len(rr_intervals) < window:
        return rr_intervals
    
    filtered = rr_intervals.copy()
    median_rr = np.median(rr_intervals)
    
    for i in range(len(rr_intervals)):
        # Check if RR interval is physiologically valid (300ms - 2000ms)
        if rr_intervals[i] < 0.3 or rr_intervals[i] > 2.0:
            filtered[i] = median_rr
            continue
        
        # Median filter
        start = max(0, i - window//2)
        end = min(len(rr_intervals), i + window//2 + 1)
        window_vals = rr_intervals[start:end]
        local_median = np.median(window_vals)
        
        # Replace outliers (> 20% deviation from local median)
        if abs(rr_intervals[i] - local_median) > 0.2 * local_median:
            filtered[i] = local_median
    
    return filtered

def extract_rr_features(ecg_segment, fs=100):
    """
    NaN-safe R-peak detection + 3 Hz interpolation.
    Returns:
        rr_interp  : length 180  (seconds * 3 Hz)
        ramp_interp: length 180
    """
    # ---- 1. R-peak detection -------------------------------------------------
    r_peaks = detect_r_peaks_hamilton(ecg_segment, fs)
    if len(r_peaks) < 3:                       # not enough peaks → dummy
        return np.zeros(180, dtype=np.float32), np.zeros(180, dtype=np.float32)

    # ---- 2. RR intervals -----------------------------------------------------
    rr_sec = np.diff(r_peaks) / fs
    rr_sec = median_filter_rr(rr_sec)          # outlier removal
    rr_times = r_peaks[1:] / fs                # time stamp of each RR

    # ---- 3. R-peak amplitudes ------------------------------------------------
    ramp = ecg_segment[r_peaks[1:]]

    # ---- 4. Cubic interpolation to 3 Hz -------------------------------------
    targ_t = np.linspace(0, 60, 180)
    kind = 'cubic' if len(rr_sec) >= 4 else 'linear'

    f_rr   = interp1d(rr_times, rr_sec,   kind=kind, bounds_error=False, fill_value=(rr_sec[0], rr_sec[-1]))
    f_amp  = interp1d(rr_times, ramp,     kind=kind, bounds_error=False, fill_value=(ramp[0],   ramp[-1]))

    rr_out   = np.clip(f_rr(targ_t), 0.3, 2.0).astype(np.float32)
    ramp_out = f_amp(targ_t).astype(np.float32)

    # ---- 5. Normalise --------------------------------------------------------
    rr_out   = (rr_out   - rr_out.mean())   / (rr_out.std()   + 1e-8)
    ramp_out = (ramp_out - ramp_out.mean()) / (ramp_out.std() + 1e-8)
    return rr_out, ramp_out

# ----------------------------- Improved Model ---------------------------------

class FocalLoss(nn.Module):
    """Focal Loss for handling class imbalance"""
    def __init__(self, alpha=0.25, gamma=2.0, weight=None):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.weight = weight
        
    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, weight=self.weight, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        return focal_loss.mean()

class MultiScaleBlock(nn.Module):
    """Multi-scale feature extraction block"""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # ensure total output channels == out_channels
        c1 = out_channels // 3
        c2 = out_channels // 3
        c3 = out_channels - c1 - c2  # remainder so c1 + c2 + c3 == out_channels

        self.conv1 = nn.Conv1d(in_channels, c1, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(in_channels, c2, kernel_size=5, padding=2)
        self.conv3 = nn.Conv1d(in_channels, c3, kernel_size=7, padding=3)
        self.bn = nn.BatchNorm1d(out_channels)
        
    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(x)
        x3 = self.conv3(x)
        out = torch.cat([x1, x2, x3], dim=1)
        return F.gelu(self.bn(out))


class EnhancedResBlock(nn.Module):
    """Enhanced residual block with squeeze-excitation"""
    def __init__(self, channels, kernel_size=7):
        super().__init__()
        self.conv1 = nn.Conv1d(channels, channels, kernel_size, padding=kernel_size//2, groups=channels)
        self.conv2 = nn.Conv1d(channels, channels, 1)
        self.norm = nn.BatchNorm1d(channels)
        
        # Squeeze-Excitation
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Conv1d(channels, channels//8, 1),
            nn.GELU(),
            nn.Conv1d(channels//8, channels, 1),
            nn.Sigmoid()
        )
        self.dropout = nn.Dropout(0.15)
        
    def forward(self, x):
        residual = x
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.norm(x)
        
        # Apply SE attention
        se_weight = self.se(x)
        x = x * se_weight
        
        x = self.dropout(x)
        return F.gelu(residual + x)

class ImprovedApneaNet(nn.Module):
    def __init__(self, d_model=256, n_blocks=10, dropout=0.15):
        super().__init__()
        # … rest of your code unchanged …
        
                # ====== Replace the three stem definitions with this block ======

        # Ensure the three modality-channel outputs sum to d_model (avoid integer division loss)
        c1 = d_model // 3
        c2 = d_model // 3
        c3 = d_model - c1 - c2  # remaining channels so c1+c2+c3 == d_model

        # ECG pathway (6000 samples) -> c1 channels
        self.ecg_stem = nn.Sequential(
            nn.Conv1d(1, c1, kernel_size=15, padding=7, stride=4),
            nn.BatchNorm1d(c1),
            nn.GELU(),
            nn.Conv1d(c1, c1, kernel_size=7, padding=3, stride=2),
            nn.BatchNorm1d(c1),
            nn.GELU(),
        )

        # RR interval pathway (180 samples @ 3Hz) -> c2 channels
        self.rr_stem = nn.Sequential(
            nn.Conv1d(1, c2, kernel_size=7, padding=3, stride=2),
            nn.BatchNorm1d(c2),
            nn.GELU(),
            nn.Conv1d(c2, c2, kernel_size=5, padding=2),
            nn.BatchNorm1d(c2),
            nn.GELU(),
        )

        # R-amplitude pathway (180 samples @ 3Hz) -> c3 channels
        self.ramp_stem = nn.Sequential(
            nn.Conv1d(1, c3, kernel_size=7, padding=3, stride=2),
            nn.BatchNorm1d(c3),
            nn.GELU(),
            nn.Conv1d(c3, c3, kernel_size=5, padding=2),
            nn.BatchNorm1d(c3),
            nn.GELU(),
        )

        # =================================================================

        
        # Multi-scale fusion
        self.fusion = MultiScaleBlock(d_model, d_model)
        
        # Enhanced residual blocks
        self.blocks = nn.ModuleList([
            EnhancedResBlock(d_model, kernel_size=7 if i % 2 == 0 else 11)
            for i in range(n_blocks)
        ])
        
        # Temporal attention with larger context
        self.temp_attn = nn.MultiheadAttention(d_model, num_heads=8, dropout=dropout, batch_first=True)
        self.temp_norm = nn.LayerNorm(d_model)
        self.temp_ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 2, d_model)
        )
        
        # Feature aggregation
        self.global_pool = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten()
        )
        
        # Enhanced classifier with multiple paths
        self.classifier = nn.Sequential(
            nn.Linear(d_model * 4, d_model * 2),
            nn.BatchNorm1d(d_model * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 2, d_model),
            nn.BatchNorm1d(d_model),
            nn.GELU(),
            nn.Dropout(dropout * 0.5),
            nn.Linear(d_model, 2)
        )
        
    def forward(self, ecg, rr, ramp):
        # Process each modality
        # ecg: (B, 6000, 1) -> (B, 1, 6000)
        # rr, ramp: (B, 180, 1) -> (B, 1, 180)
        ecg = ecg.transpose(1, 2)
        rr = rr.transpose(1, 2)
        ramp = ramp.transpose(1, 2)
        
        ecg_feat = self.ecg_stem(ecg)  # (B, d//3, L1)
        rr_feat = self.rr_stem(rr)      # (B, d//3, L2)
        ramp_feat = self.ramp_stem(ramp)  # (B, d//3, L2)
        
        # Align sequence lengths and concatenate
        target_len = min(ecg_feat.size(2), rr_feat.size(2), ramp_feat.size(2))
        ecg_feat = F.adaptive_avg_pool1d(ecg_feat, target_len)
        rr_feat = F.adaptive_avg_pool1d(rr_feat, target_len)
        ramp_feat = F.adaptive_avg_pool1d(ramp_feat, target_len)
        
        x = torch.cat([ecg_feat, rr_feat, ramp_feat], dim=1)  # (B, d_model, L)
        
        # Multi-scale fusion
        x = self.fusion(x)
        
        # Residual blocks
        for block in self.blocks:
            x = block(x)
        
        # Multiple pooling strategies
        x_avg = F.adaptive_avg_pool1d(x, 1).squeeze(-1)  # (B, d_model)
        x_max = F.adaptive_max_pool1d(x, 1).squeeze(-1)  # (B, d_model)
        x_std = torch.std(x, dim=2)  # (B, d_model)
        
        # Temporal attention
        x_seq = x.transpose(1, 2)  # (B, L, d_model)
        x_attn, _ = self.temp_attn(x_seq, x_seq, x_seq)
        x_attn = self.temp_norm(x_attn + x_seq)
        x_attn = x_attn + self.temp_ffn(x_attn)
        x_attn = x_attn.mean(dim=1)  # (B, d_model)
        
        # Combine all features
        x_combined = torch.cat([x_avg, x_max, x_std, x_attn], dim=1)
        
        # Classify
        logits = self.classifier(x_combined)
        return logits

# --------------------------- Enhanced Dataset ---------------------------

class EnhancedApneaDataset(Dataset):
    """Dataset with R-R interval and R-amplitude extraction"""

    def __init__(self, data_dir: str, record_names: list = None, cache_dir: str = None,
                 segment_length: int = 6000, stride: int = 3000, split='train', augment=True):
        super().__init__()
        self.segment_length = int(segment_length)
        self.stride = int(stride)
        self.split = split
        self.augment = augment and (split == 'train')
        
        cache_dir = Path(cache_dir) if cache_dir else Path(data_dir)
        cache_file = cache_dir / f'apnea_enhanced_{split}_{segment_length}_{stride}.pt'

        if cache_file.exists():
            print(f"Loading cached {split} from {cache_file}")
            data = torch.load(cache_file)
            self.ecg_segments = data['ecg_segments']
            self.rr_segments = data['rr_segments']
            self.ramp_segments = data['ramp_segments']
            self.labels = data['labels']
        else:
            assert wfdb is not None, "wfdb required"
            assert record_names is not None, "record_names required"
            
            self.ecg_segments = []
            self.rr_segments = []
            self.ramp_segments = []
            self.labels = []
            self.data_dir = Path(data_dir)
            
            print(f"Processing {len(record_names)} records for {split} (with R-R extraction)...")
            for i, rec in enumerate(record_names):
                print(f"  [{i+1}/{len(record_names)}] {rec}...", end='\r')
                self._load_record(rec)
            
            if len(self.ecg_segments) == 0:
                raise RuntimeError("No segments loaded")
            
            self.ecg_segments = torch.tensor(np.stack(self.ecg_segments, axis=0), dtype=torch.float32)
            self.rr_segments = torch.tensor(np.stack(self.rr_segments, axis=0), dtype=torch.float32)
            self.ramp_segments = torch.tensor(np.stack(self.ramp_segments, axis=0), dtype=torch.float32)
            self.labels = torch.tensor(self.labels, dtype=torch.long)
            
            print(f"\nSaving cache to {cache_file}")
            torch.save({
                'ecg_segments': self.ecg_segments,
                'rr_segments': self.rr_segments,
                'ramp_segments': self.ramp_segments,
                'labels': self.labels
            }, cache_file)

        if self.ecg_segments.ndim == 2:
            self.ecg_segments = self.ecg_segments.unsqueeze(-1)
        if self.rr_segments.ndim == 2:
            self.rr_segments = self.rr_segments.unsqueeze(-1)
        if self.ramp_segments.ndim == 2:
            self.ramp_segments = self.ramp_segments.unsqueeze(-1)

        print(f"{split.capitalize()}: {len(self.ecg_segments)} segments, "
              f"Class: {Counter(self.labels.tolist())}")

        def _load_record(self, record_name: str):
            try:
                rec = wfdb.rdrecord(str(self.data_dir / record_name))
                sig = rec.p_signal[:, 0].astype(np.float32)
                if np.isnan(sig).any():                       # NaN interpolation
                    nans = np.isnan(sig)
                    sig[nans] = np.interp(np.flatnonzero(nans),np.flatnonzero(~nans), sig[~nans])

                ann = wfdb.rdann(str(self.data_dir / record_name), 'apn')
                n_min = len(sig) // 6000
                mins  = np.zeros(n_min, dtype=int)
                for samp, sym in zip(ann.sample, ann.symbol):
                    if sym == 'A':
                        m = samp // 6000
                        if m < n_min:
                            mins[m] = 1

                for start in range(0, len(sig) - self.segment_length + 1, self.stride):
                    seg = sig[start:start + self.segment_length]
                # ---- normalise ECG -------------------------------------------------
                    seg = (seg - seg.mean()) / (seg.std() + 1e-8)
                    seg = np.clip(seg, -10, 10)

                    rr, ramp = extract_rr_features(seg, fs=100)

                    minute = start // 6000
                    if minute < len(mins):
                        self.ecg_segments.append(seg)
                        self.rr_segments.append(rr)
                        self.ramp_segments.append(ramp)
                        self.labels.append(int(mins[minute]))
            except Exception as e:
                print(f'\nSkip {record_name}: {e}')
    
    def _augment(self, ecg, rr, ramp):
        ecg, rr, ramp = map(lambda x: x.numpy() if torch.is_tensor(x) else x, (ecg, rr, ramp))

        if np.random.rand() < 0.5:
            ecg += np.random.normal(0, 0.02, ecg.shape).astype(np.float32)
            rr  += np.random.normal(0, 0.01, rr.shape).astype(np.float32)
            ramp+= np.random.normal(0, 0.01, ramp.shape).astype(np.float32)

        if np.random.rand() < 0.3:
            scale = np.random.uniform(0.9, 1.1)
            ecg *= scale
            ramp*= scale

        if np.random.rand() < 0.2:
            shift = np.random.randint(-150, 150)
            ecg = np.roll(ecg, shift)

        return map(torch.from_numpy, (ecg, rr, ramp))
            
    def __len__(self):
        return self.ecg_segments.shape[0]

    def __getitem__(self, idx):
        ecg = self.ecg_segments[idx]
        rr = self.rr_segments[idx]
        ramp = self.ramp_segments[idx]
        label = self.labels[idx]
        
        if self.augment:
            ecg, rr, ramp = self._augment(ecg, rr, ramp)
        
        # Ensure correct shape
        if ecg.ndim == 1:
            ecg = ecg.unsqueeze(-1)
        if rr.ndim == 1:
            rr = rr.unsqueeze(-1)
        if ramp.ndim == 1:
            ramp = ramp.unsqueeze(-1)
        
        return ecg, rr, ramp, label

# -------------------------- Training ------------------------

def compute_class_weights(labels_tensor):
    counts = Counter(labels_tensor.tolist())
    total = sum(counts.values())
    num_classes = len(counts)
    weights = [total / (num_classes * counts.get(i, 1)) for i in range(num_classes)]
    return torch.tensor(weights, dtype=torch.float32)

def train_epoch(model, dataloader, criterion, optimizer, scheduler, device, epoch, scaler=None):
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    
    num_batches = len(dataloader)
    print_freq = max(1, num_batches // 15)
    start_time = time.time()

    for batch_idx, (ecg, rr, ramp, target) in enumerate(dataloader, 1):
        ecg = ecg.to(device, non_blocking=True)
        rr = rr.to(device, non_blocking=True)
        ramp = ramp.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        with torch.amp.autocast('cuda', enabled=(device.type == 'cuda')):
            output = model(ecg, rr, ramp)
            loss = criterion(output, target)
        
        if torch.isnan(loss):
            print(f"\nWARNING: NaN loss, skipping batch {batch_idx}")
            continue

        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
        
        scheduler.step()

        total_loss += loss.item()
        pred = output.argmax(dim=1)
        correct += pred.eq(target).sum().item()
        total += target.size(0)
        
        if batch_idx % print_freq == 0 or batch_idx == num_batches:
            curr_acc = 100.0 * correct / total
            curr_loss = total_loss / batch_idx
            speed = batch_idx / (time.time() - start_time)
            eta = (num_batches - batch_idx) / speed if speed > 0 else 0
            
            print(f"  Ep {epoch} [{batch_idx:4d}/{num_batches}] "
                  f"Loss: {curr_loss:.4f} Acc: {curr_acc:.2f}% "
                  f"({speed:.1f} b/s, ETA: {eta:.0f}s)", end='\r')

    print()
    avg_loss = total_loss / len(dataloader)
    accuracy = 100.0 * correct / total if total > 0 else 0.0
    return avg_loss, accuracy

def validate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_targets = []
    all_probs = []

    with torch.no_grad():
        for ecg, rr, ramp, target in dataloader:
            ecg = ecg.to(device, non_blocking=True)
            rr = rr.to(device, non_blocking=True)
            ramp = ramp.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            
            output = model(ecg, rr, ramp)
            loss = criterion(output, target)

            total_loss += loss.item()
            probs = F.softmax(output, dim=1)[:, 1]
            pred = output.argmax(dim=1)

            correct += pred.eq(target).sum().item()
            total += target.size(0)

            all_preds.extend(pred.cpu().tolist())
            all_targets.extend(target.cpu().tolist())
            all_probs.extend(probs.cpu().tolist())

    avg_loss = total_loss / len(dataloader)
    accuracy = 100.0 * correct / total if total > 0 else 0.0
    
    precision = precision_score(all_targets, all_preds, zero_division=0)
    recall = recall_score(all_targets, all_preds, zero_division=0)
    f1 = f1_score(all_targets, all_preds, zero_division=0)
    
    # Calculate specificity and sensitivity
    tn, fp, fn, tp = confusion_matrix(all_targets, all_preds).ravel()
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0.0  # Same as recall
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0
    
    return avg_loss, accuracy, np.array(all_preds), np.array(all_targets), np.array(all_probs), precision, recall, f1, sensitivity, specificity

# ------------------------------ Main ------------------------------------

def main(args):
    set_seed(args.seed)

    DATA_DIR = Path(args.data_dir)
    if not DATA_DIR.exists():
        raise FileNotFoundError(f"Data directory not found: {DATA_DIR}")

    record_files = list(DATA_DIR.glob('*.hea'))
    all_records = [f.stem for f in record_files]
    valid_records = [rec for rec in all_records 
                    if (DATA_DIR / (rec + '.apn')).exists() and not rec.endswith('er')]
    
    if len(valid_records) == 0:
        raise RuntimeError("No valid records found")

    print(f"Found {len(valid_records)} valid records")

    import random
    valid_records_shuffled = valid_records.copy()
    random.Random(args.seed).shuffle(valid_records_shuffled)
    split_idx = int(len(valid_records_shuffled) * args.train_split)
    train_records = valid_records_shuffled[:split_idx]
    val_records = valid_records_shuffled[split_idx:]
    print(f"Train: {len(train_records)}, Val: {len(val_records)}\n")

        # inside main(...) — create datasets
    cache_dir = args.cache_dir if args.cache_dir else str(DATA_DIR)
    train_dataset = EnhancedApneaDataset(
        str(DATA_DIR), train_records, cache_dir,
        args.segment_length, args.stride, 'train', augment=True
    )
    val_dataset = EnhancedApneaDataset(
        str(DATA_DIR), val_records, cache_dir,
        args.segment_length, args.stride, 'val', augment=False
    )

    num_workers = 2 if str(DATA_DIR).startswith('/kaggle') else 4

    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True
    )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")
    if device.type == 'cuda':
        try:
            print(f"GPU: {torch.cuda.get_device_name(0)}")
        except Exception:
            pass
        torch.cuda.empty_cache()

    model = ImprovedApneaNet(
        d_model=args.d_model, n_blocks=args.n_blocks, dropout=args.dropout
    ).to(device)

    print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}\n")

    class_weights = compute_class_weights(train_dataset.labels).to(device)
    print(f"Class weights: {class_weights}")

    # Use Focal Loss for better class imbalance handling
        # stable, NaN-free loss with class weights
    criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.05)

    optimizer = torch.optim.AdamW(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay
    )

    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=args.lr, epochs=args.epochs,
        steps_per_epoch=len(train_loader), pct_start=0.25
    )

    scaler = torch.amp.GradScaler() if device.type == 'cuda' else None

    best_val_acc = 0.0
    best_val_f1 = 0.0
    no_improve = 0

    print("\nStarting training...")
    print("="*100)

    for epoch in range(1, args.epochs + 1):
        epoch_start = time.time()

        train_loss, train_acc = train_epoch(
            model, train_loader, criterion, optimizer, scheduler, device, epoch, scaler
        )

        val_loss, val_acc, _, val_targets, val_probs, precision, recall, f1, sensitivity, specificity = validate(
            model, val_loader, criterion, device
        )

        try:
            auc = roc_auc_score(val_targets, val_probs)
        except Exception:
            auc = 0.0

        epoch_time = time.time() - epoch_start

        print(f"Epoch {epoch:2d}/{args.epochs} ({epoch_time:.1f}s)")
        print(f"  Train: Loss={train_loss:.4f}, Acc={train_acc:.2f}%")
        print(f"  Val:   Loss={val_loss:.4f}, Acc={val_acc:.2f}%, AUC={auc:.4f}")
        print(f"         Prec={precision:.3f}, Rec={recall:.3f}, F1={f1:.3f}")
        print(f"         Sensitivity={sensitivity:.3f}, Specificity={specificity:.3f}")

        if val_acc > best_val_acc or (val_acc >= best_val_acc and f1 > best_val_f1):
            best_val_acc = val_acc
            best_val_f1 = f1
            no_improve = 0
            torch.save({
                'epoch': epoch, 'model_state_dict': model.state_dict(),
                'val_acc': val_acc, 'val_auc': auc, 'val_f1': f1,
                'sensitivity': sensitivity, 'specificity': specificity
            }, args.best_model_path)
            print(f"  ✓ Best! (Acc={val_acc:.2f}%, F1={f1:.3f}, Sens={sensitivity:.3f}, Spec={specificity:.3f})")
        else:
            no_improve += 1
            print(f"  No improvement ({no_improve}/{args.patience})")

        print("-"*100)

        if no_improve >= args.patience:
            print(f"\nEarly stop at epoch {epoch}")
            break

    print(f"\n{'='*100}")
    print(f"BEST - Accuracy: {best_val_acc:.2f}%, F1: {best_val_f1:.3f}")
    print(f"{'='*100}")


if __name__ == '__main__':
    kaggle_data = '/kaggle/input/vincent2/apnea-ecg-database-1.0.0'
    colab_data = '/content/apnea-ecg/1.0.0'
    if Path(kaggle_data).exists():
        default_data_dir = kaggle_data
        default_cache_dir = '/kaggle/working'
        default_model_path = '/kaggle/working/best_model.pth'
    elif Path(colab_data).exists():
        default_data_dir = colab_data
        default_cache_dir = '/content'
        default_model_path = '/content/best_model.pth'
    else:
        default_data_dir = None
        default_cache_dir = None
        default_model_path = 'best_model.pth'

    parser = argparse.ArgumentParser()
    parser.add_argument('--data-dir', type=str, default=default_data_dir)
    parser.add_argument('--cache-dir', type=str, default=default_cache_dir)
    parser.add_argument('--segment-length', type=int, default=6000)
    parser.add_argument('--stride', type=int, default=2400)  # 60% overlap for more data
    parser.add_argument('--batch-size', type=int, default=48)  # Adjusted for multi-modal
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--weight-decay', type=float, default=1e-4)
    parser.add_argument('--d-model', type=int, default=256)
    parser.add_argument('--n-blocks', type=int, default=10)
    parser.add_argument('--dropout', type=float, default=0.15)
    parser.add_argument('--train-split', type=float, default=0.8)
    parser.add_argument('--patience', type=int, default=20)
    parser.add_argument('--best-model-path', type=str, default=default_model_path)
    parser.add_argument('--seed', type=int, default=42)

    args, _ = parser.parse_known_args()

    if args.data_dir is None:
        raise SystemExit("ERROR: Dataset not found")

    print("="*100)
    print("ENHANCED MODEL WITH R-R INTERVALS (Target: 90%+ Accuracy)")
    print("="*100)
    print(f"  Data:       {args.data_dir}")
    print(f"  Segment:    {args.segment_length} samples (60s), stride={args.stride}")
    print(f"  Features:   ECG + R-R Intervals + R-peak Amplitudes")
    print(f"  Batch:      {args.batch_size}")
    print(f"  Epochs:     {args.epochs}")
    print(f"  Model:      d_model={args.d_model}, blocks={args.n_blocks}")
    print(f"  Optimizer:  AdamW (lr={args.lr}, wd={args.weight_decay})")
    print(f"  Loss:       Focal Loss (alpha=0.25, gamma=2.0)")
    print("="*100 + "\n")

    main(args)



ENHANCED MODEL WITH R-R INTERVALS (Target: 90%+ Accuracy)
  Data:       /kaggle/input/vincent2/apnea-ecg-database-1.0.0
  Segment:    6000 samples (60s), stride=2400
  Features:   ECG + R-R Intervals + R-peak Amplitudes
  Batch:      48
  Epochs:     100
  Model:      d_model=256, blocks=10
  Optimizer:  AdamW (lr=0.001, wd=0.0001)
  Loss:       Focal Loss (alpha=0.25, gamma=2.0)

Found 43 valid records
Train: 34, Val: 9

Loading cached train from /kaggle/working/apnea_enhanced_train_6000_2400.pt
Train: 41751 segments, Class: Counter({0: 26395, 1: 15356})
Loading cached val from /kaggle/working/apnea_enhanced_val_6000_2400.pt
Val: 10777 segments, Class: Counter({0: 5862, 1: 4915})
Device: cuda
GPU: Tesla P100-PCIE-16GB
Parameters: 2,496,778

Class weights: tensor([0.7909, 1.3594], device='cuda:0')

Starting training...



  Ep 1 [  58/870] Loss: 0.6449 Acc: 66.10% (21.4 b/s, ETA: 38s)


  Ep 1 [ 116/870] Loss: 0.6121 Acc: 68.31% (22.7 b/s, ETA: 33s)



  Ep 1 [ 174/870] Loss: 0.5864 Ac

KeyboardInterrupt: 

In [9]:
#!/usr/bin/env python3
"""
High-performance apnea detection with R-R interval extraction (Target: 90%+ accuracy)
Based on PhysioNet Apnea-ECG Database methodology
"""

import argparse
import os
import time
from pathlib import Path
from collections import Counter

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from scipy import signal as scipy_signal
from scipy.interpolate import interp1d

try:
    import wfdb
except Exception:
    wfdb = None

from sklearn.metrics import roc_auc_score, precision_score, recall_score, f1_score, confusion_matrix

# ----------------------------- Utilities ---------------------------------

def set_seed(seed=42):
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# ----------------------------- R-Peak Detection & R-R Interval Extraction ---------------------------------

def detect_r_peaks_hamilton(ecg_signal, fs=100):
    """
    Hamilton R-peak detection algorithm
    Returns indices of R-peaks
    """
    # Bandpass filter (5-15 Hz)
    b, a = scipy_signal.butter(2, [5, 15], btype='band', fs=fs)
    filtered = scipy_signal.filtfilt(b, a, ecg_signal)
    
    # Derivative
    diff_signal = np.diff(filtered)
    
    # Squaring
    squared = diff_signal ** 2
    
    # Moving average integration (150ms window)
    window_size = int(0.15 * fs)
    integrated = np.convolve(squared, np.ones(window_size)/window_size, mode='same')
    
    # Find peaks
    threshold = np.mean(integrated) + 0.5 * np.std(integrated)
    peaks = []
    refractory = int(0.2 * fs)  # 200ms refractory period
    
    for i in range(1, len(integrated) - 1):
        if integrated[i] > threshold:
            if integrated[i] > integrated[i-1] and integrated[i] > integrated[i+1]:
                if not peaks or (i - peaks[-1]) > refractory:
                    peaks.append(i)
    
    return np.array(peaks)

def median_filter_rr(rr_intervals, window=5):
    """
    Median filter for removing physiologically uninterpretable R-R intervals
    Based on Chen et al. methodology
    """
    if len(rr_intervals) < window:
        return rr_intervals
    
    filtered = rr_intervals.copy()
    median_rr = np.median(rr_intervals)
    
    for i in range(len(rr_intervals)):
        # Check if RR interval is physiologically valid (300ms - 2000ms)
        if rr_intervals[i] < 0.3 or rr_intervals[i] > 2.0:
            filtered[i] = median_rr
            continue
        
        # Median filter
        start = max(0, i - window//2)
        end = min(len(rr_intervals), i + window//2 + 1)
        window_vals = rr_intervals[start:end]
        local_median = np.median(window_vals)
        
        # Replace outliers (> 20% deviation from local median)
        if abs(rr_intervals[i] - local_median) > 0.2 * local_median:
            filtered[i] = local_median
    
    return filtered

def extract_rr_features(ecg_segment, fs=100):
    """
    NaN-safe R-peak detection + 3 Hz interpolation.
    Returns:
        rr_interp  : length 180  (seconds * 3 Hz)
        ramp_interp: length 180
    """
    # ---- 1. R-peak detection -------------------------------------------------
    r_peaks = detect_r_peaks_hamilton(ecg_segment, fs)
    if len(r_peaks) < 3:                       # not enough peaks → dummy
        return np.zeros(180, dtype=np.float32), np.zeros(180, dtype=np.float32)

    # ---- 2. RR intervals -----------------------------------------------------
    rr_sec = np.diff(r_peaks) / fs
    rr_sec = median_filter_rr(rr_sec)          # outlier removal
    rr_times = r_peaks[1:] / fs                # time stamp of each RR

    # ---- 3. R-peak amplitudes ------------------------------------------------
    ramp = ecg_segment[r_peaks[1:]]

    # ---- 4. Cubic interpolation to 3 Hz -------------------------------------
    targ_t = np.linspace(0, 60, 180)
    kind = 'cubic' if len(rr_sec) >= 4 else 'linear'

    f_rr   = interp1d(rr_times, rr_sec,   kind=kind, bounds_error=False, fill_value=(rr_sec[0], rr_sec[-1]))
    f_amp  = interp1d(rr_times, ramp,     kind=kind, bounds_error=False, fill_value=(ramp[0],   ramp[-1]))

    rr_out   = np.clip(f_rr(targ_t), 0.3, 2.0).astype(np.float32)
    ramp_out = f_amp(targ_t).astype(np.float32)

    # ---- 5. Normalise with safety checks -------------------------------------
    rr_std = rr_out.std()
    ramp_std = ramp_out.std()
    
    if rr_std > 1e-6:
        rr_out = (rr_out - rr_out.mean()) / rr_std
    else:
        rr_out = rr_out - rr_out.mean()
    
    if ramp_std > 1e-6:
        ramp_out = (ramp_out - ramp_out.mean()) / ramp_std
    else:
        ramp_out = ramp_out - ramp_out.mean()
    
    # Clip to prevent extreme values
    rr_out = np.clip(rr_out, -10, 10)
    ramp_out = np.clip(ramp_out, -10, 10)
    
    return rr_out, ramp_out

# ----------------------------- Improved Model ---------------------------------

class MultiScaleBlock(nn.Module):
    """Multi-scale feature extraction block"""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # ensure total output channels == out_channels
        c1 = out_channels // 3
        c2 = out_channels // 3
        c3 = out_channels - c1 - c2  # remainder so c1 + c2 + c3 == out_channels

        self.conv1 = nn.Conv1d(in_channels, c1, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(in_channels, c2, kernel_size=5, padding=2)
        self.conv3 = nn.Conv1d(in_channels, c3, kernel_size=7, padding=3)
        self.bn = nn.BatchNorm1d(out_channels)
        
    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(x)
        x3 = self.conv3(x)
        out = torch.cat([x1, x2, x3], dim=1)
        return F.gelu(self.bn(out))


class EnhancedResBlock(nn.Module):
    """Enhanced residual block with squeeze-excitation"""
    def __init__(self, channels, kernel_size=7):
        super().__init__()
        self.conv1 = nn.Conv1d(channels, channels, kernel_size, padding=kernel_size//2, groups=channels)
        self.conv2 = nn.Conv1d(channels, channels, 1)
        self.norm = nn.BatchNorm1d(channels)
        
        # Squeeze-Excitation with stability improvements
        se_channels = max(8, channels // 8)  # Ensure at least 8 channels
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Conv1d(channels, se_channels, 1),
            nn.GELU(),
            nn.Conv1d(se_channels, channels, 1),
            nn.Sigmoid()
        )
        self.dropout = nn.Dropout(0.15)
        
    def forward(self, x):
        residual = x
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.norm(x)
        
        # Apply SE attention with safety check
        se_weight = self.se(x)
        # Clamp to prevent extreme values
        se_weight = torch.clamp(se_weight, 0.0, 1.0)
        x = x * se_weight
        
        x = self.dropout(x)
        return F.gelu(residual + x)

class ImprovedApneaNet(nn.Module):
    def __init__(self, d_model=256, n_blocks=10, dropout=0.15):
        super().__init__()
        
        # Ensure the three modality-channel outputs sum to d_model
        c1 = d_model // 3
        c2 = d_model // 3
        c3 = d_model - c1 - c2

        # ECG pathway (6000 samples) -> c1 channels
        self.ecg_stem = nn.Sequential(
            nn.Conv1d(1, c1, kernel_size=15, padding=7, stride=4),
            nn.BatchNorm1d(c1),
            nn.GELU(),
            nn.Conv1d(c1, c1, kernel_size=7, padding=3, stride=2),
            nn.BatchNorm1d(c1),
            nn.GELU(),
        )

        # RR interval pathway (180 samples @ 3Hz) -> c2 channels
        self.rr_stem = nn.Sequential(
            nn.Conv1d(1, c2, kernel_size=7, padding=3, stride=2),
            nn.BatchNorm1d(c2),
            nn.GELU(),
            nn.Conv1d(c2, c2, kernel_size=5, padding=2),
            nn.BatchNorm1d(c2),
            nn.GELU(),
        )

        # R-amplitude pathway (180 samples @ 3Hz) -> c3 channels
        self.ramp_stem = nn.Sequential(
            nn.Conv1d(1, c3, kernel_size=7, padding=3, stride=2),
            nn.BatchNorm1d(c3),
            nn.GELU(),
            nn.Conv1d(c3, c3, kernel_size=5, padding=2),
            nn.BatchNorm1d(c3),
            nn.GELU(),
        )
        
        # Multi-scale fusion
        self.fusion = MultiScaleBlock(d_model, d_model)
        
        # Enhanced residual blocks
        self.blocks = nn.ModuleList([
            EnhancedResBlock(d_model, kernel_size=7 if i % 2 == 0 else 11)
            for i in range(n_blocks)
        ])
        
        # Temporal attention with larger context
        self.temp_attn = nn.MultiheadAttention(d_model, num_heads=8, dropout=dropout, batch_first=True)
        self.temp_norm = nn.LayerNorm(d_model)
        self.temp_ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 2, d_model)
        )
        
        # Feature aggregation
        self.global_pool = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten()
        )
        
        # Enhanced classifier with multiple paths
        self.classifier = nn.Sequential(
            nn.Linear(d_model * 4, d_model * 2),
            nn.BatchNorm1d(d_model * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 2, d_model),
            nn.BatchNorm1d(d_model),
            nn.GELU(),
            nn.Dropout(dropout * 0.5),
            nn.Linear(d_model, 2)
        )
        
        # Initialize weights
        self._initialize_weights()
        
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
        
    def forward(self, ecg, rr, ramp):
        # Validate inputs
        if torch.isnan(ecg).any() or torch.isinf(ecg).any():
            ecg = torch.nan_to_num(ecg, nan=0.0, posinf=10.0, neginf=-10.0)
        if torch.isnan(rr).any() or torch.isinf(rr).any():
            rr = torch.nan_to_num(rr, nan=0.0, posinf=10.0, neginf=-10.0)
        if torch.isnan(ramp).any() or torch.isinf(ramp).any():
            ramp = torch.nan_to_num(ramp, nan=0.0, posinf=10.0, neginf=-10.0)
        
        # Process each modality
        ecg = ecg.transpose(1, 2)
        rr = rr.transpose(1, 2)
        ramp = ramp.transpose(1, 2)
        
        ecg_feat = self.ecg_stem(ecg)
        rr_feat = self.rr_stem(rr)
        ramp_feat = self.ramp_stem(ramp)
        
        # Align sequence lengths and concatenate
        target_len = min(ecg_feat.size(2), rr_feat.size(2), ramp_feat.size(2))
        ecg_feat = F.adaptive_avg_pool1d(ecg_feat, target_len)
        rr_feat = F.adaptive_avg_pool1d(rr_feat, target_len)
        ramp_feat = F.adaptive_avg_pool1d(ramp_feat, target_len)
        
        x = torch.cat([ecg_feat, rr_feat, ramp_feat], dim=1)
        
        # Multi-scale fusion
        x = self.fusion(x)
        
        # Residual blocks
        for block in self.blocks:
            x = block(x)
        
        # Multiple pooling strategies with safety
        x_avg = F.adaptive_avg_pool1d(x, 1).squeeze(-1)
        x_max = F.adaptive_max_pool1d(x, 1).squeeze(-1)
        x_std = torch.std(x, dim=2) + 1e-8  # Add epsilon for stability
        
        # Temporal attention
        x_seq = x.transpose(1, 2)
        x_attn, _ = self.temp_attn(x_seq, x_seq, x_seq)
        x_attn = self.temp_norm(x_attn + x_seq)
        x_attn = x_attn + self.temp_ffn(x_attn)
        x_attn = x_attn.mean(dim=1)
        
        # Combine all features
        x_combined = torch.cat([x_avg, x_max, x_std, x_attn], dim=1)
        
        # Final NaN check before classifier
        if torch.isnan(x_combined).any() or torch.isinf(x_combined).any():
            x_combined = torch.nan_to_num(x_combined, nan=0.0, posinf=10.0, neginf=-10.0)
        
        # Classify
        logits = self.classifier(x_combined)
        return logits

# --------------------------- Enhanced Dataset ---------------------------

class EnhancedApneaDataset(Dataset):
    """Dataset with R-R interval and R-amplitude extraction"""

    def __init__(self, data_dir: str, record_names: list = None, cache_dir: str = None,
                 segment_length: int = 6000, stride: int = 3000, split='train', augment=True):
        super().__init__()
        self.segment_length = int(segment_length)
        self.stride = int(stride)
        self.split = split
        self.augment = augment and (split == 'train')
        
        cache_dir = Path(cache_dir) if cache_dir else Path(data_dir)
        cache_file = cache_dir / f'apnea_enhanced_{split}_{segment_length}_{stride}.pt'

        if cache_file.exists():
            print(f"Loading cached {split} from {cache_file}")
            data = torch.load(cache_file)
            self.ecg_segments = data['ecg_segments']
            self.rr_segments = data['rr_segments']
            self.ramp_segments = data['ramp_segments']
            self.labels = data['labels']
        else:
            assert wfdb is not None, "wfdb required"
            assert record_names is not None, "record_names required"
            
            self.ecg_segments = []
            self.rr_segments = []
            self.ramp_segments = []
            self.labels = []
            self.data_dir = Path(data_dir)
            
            print(f"Processing {len(record_names)} records for {split} (with R-R extraction)...")
            for i, rec in enumerate(record_names):
                print(f"  [{i+1}/{len(record_names)}] {rec}...", end='\r')
                self._load_record(rec)
            
            if len(self.ecg_segments) == 0:
                raise RuntimeError("No segments loaded")
            
            self.ecg_segments = torch.tensor(np.stack(self.ecg_segments, axis=0), dtype=torch.float32)
            self.rr_segments = torch.tensor(np.stack(self.rr_segments, axis=0), dtype=torch.float32)
            self.ramp_segments = torch.tensor(np.stack(self.ramp_segments, axis=0), dtype=torch.float32)
            self.labels = torch.tensor(self.labels, dtype=torch.long)
            
            print(f"\nSaving cache to {cache_file}")
            torch.save({
                'ecg_segments': self.ecg_segments,
                'rr_segments': self.rr_segments,
                'ramp_segments': self.ramp_segments,
                'labels': self.labels
            }, cache_file)

        if self.ecg_segments.ndim == 2:
            self.ecg_segments = self.ecg_segments.unsqueeze(-1)
        if self.rr_segments.ndim == 2:
            self.rr_segments = self.rr_segments.unsqueeze(-1)
        if self.ramp_segments.ndim == 2:
            self.ramp_segments = self.ramp_segments.unsqueeze(-1)

        print(f"{split.capitalize()}: {len(self.ecg_segments)} segments, "
              f"Class: {Counter(self.labels.tolist())}")

    def _load_record(self, record_name: str):
        try:
            rec = wfdb.rdrecord(str(self.data_dir / record_name))
            sig = rec.p_signal[:, 0].astype(np.float32)
            if np.isnan(sig).any():
                nans = np.isnan(sig)
                sig[nans] = np.interp(np.flatnonzero(nans), np.flatnonzero(~nans), sig[~nans])

            ann = wfdb.rdann(str(self.data_dir / record_name), 'apn')
            n_min = len(sig) // 6000
            mins = np.zeros(n_min, dtype=int)
            for samp, sym in zip(ann.sample, ann.symbol):
                if sym == 'A':
                    m = samp // 6000
                    if m < n_min:
                        mins[m] = 1

            for start in range(0, len(sig) - self.segment_length + 1, self.stride):
                seg = sig[start:start + self.segment_length]
                
                # Normalise ECG with safety
                seg_std = seg.std()
                if seg_std > 1e-6:
                    seg = (seg - seg.mean()) / seg_std
                else:
                    seg = seg - seg.mean()
                seg = np.clip(seg, -10, 10)

                rr, ramp = extract_rr_features(seg, fs=100)

                minute = start // 6000
                if minute < len(mins):
                    self.ecg_segments.append(seg)
                    self.rr_segments.append(rr)
                    self.ramp_segments.append(ramp)
                    self.labels.append(int(mins[minute]))
        except Exception as e:
            print(f'\nSkip {record_name}: {e}')
    
    def _augment(self, ecg, rr, ramp):
        ecg, rr, ramp = map(lambda x: x.numpy() if torch.is_tensor(x) else x, (ecg, rr, ramp))

        if np.random.rand() < 0.5:
            ecg += np.random.normal(0, 0.02, ecg.shape).astype(np.float32)
            rr += np.random.normal(0, 0.01, rr.shape).astype(np.float32)
            ramp += np.random.normal(0, 0.01, ramp.shape).astype(np.float32)

        if np.random.rand() < 0.3:
            scale = np.random.uniform(0.9, 1.1)
            ecg *= scale
            ramp *= scale

        if np.random.rand() < 0.2:
            shift = np.random.randint(-150, 150)
            ecg = np.roll(ecg, shift)

        # Clip after augmentation
        ecg = np.clip(ecg, -10, 10)
        rr = np.clip(rr, -10, 10)
        ramp = np.clip(ramp, -10, 10)

        return tuple(map(torch.from_numpy, (ecg, rr, ramp)))
            
    def __len__(self):
        return self.ecg_segments.shape[0]

    def __getitem__(self, idx):
        ecg = self.ecg_segments[idx]
        rr = self.rr_segments[idx]
        ramp = self.ramp_segments[idx]
        label = self.labels[idx]
        
        if self.augment:
            ecg, rr, ramp = self._augment(ecg, rr, ramp)
        
        # Ensure correct shape
        if ecg.ndim == 1:
            ecg = ecg.unsqueeze(-1)
        if rr.ndim == 1:
            rr = rr.unsqueeze(-1)
        if ramp.ndim == 1:
            ramp = ramp.unsqueeze(-1)
        
        return ecg, rr, ramp, label

# -------------------------- Training ------------------------

def compute_class_weights(labels_tensor):
    counts = Counter(labels_tensor.tolist())
    total = sum(counts.values())
    num_classes = len(counts)
    weights = [total / (num_classes * counts.get(i, 1)) for i in range(num_classes)]
    return torch.tensor(weights, dtype=torch.float32)

def train_epoch(model, dataloader, criterion, optimizer, scheduler, device, epoch, scaler=None):
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    
    num_batches = len(dataloader)
    print_freq = max(1, num_batches // 15)
    start_time = time.time()

    for batch_idx, (ecg, rr, ramp, target) in enumerate(dataloader, 1):
        ecg = ecg.to(device, non_blocking=True)
        rr = rr.to(device, non_blocking=True)
        ramp = ramp.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        with torch.amp.autocast('cuda', enabled=(device.type == 'cuda')):
            output = model(ecg, rr, ramp)
            loss = criterion(output, target)
        
        if torch.isnan(loss) or torch.isinf(loss):
            print(f"\nWARNING: NaN/Inf loss, skipping batch {batch_idx}")
            continue

        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
        
        scheduler.step()

        total_loss += loss.item()
        pred = output.argmax(dim=1)
        correct += pred.eq(target).sum().item()
        total += target.size(0)
        
        if batch_idx % print_freq == 0 or batch_idx == num_batches:
            curr_acc = 100.0 * correct / total
            curr_loss = total_loss / batch_idx
            speed = batch_idx / (time.time() - start_time)
            eta = (num_batches - batch_idx) / speed if speed > 0 else 0
            
            print(f"  Ep {epoch} [{batch_idx:4d}/{num_batches}] "
                  f"Loss: {curr_loss:.4f} Acc: {curr_acc:.2f}% "
                  f"({speed:.1f} b/s, ETA: {eta:.0f}s)", end='\r')

    print()
    avg_loss = total_loss / len(dataloader)
    accuracy = 100.0 * correct / total if total > 0 else 0.0
    return avg_loss, accuracy

def validate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_targets = []
    all_probs = []

    with torch.no_grad():
        for ecg, rr, ramp, target in dataloader:
            ecg = ecg.to(device, non_blocking=True)
            rr = rr.to(device, non_blocking=True)
            ramp = ramp.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            
            output = model(ecg, rr, ramp)
            loss = criterion(output, target)

            total_loss += loss.item()
            probs = F.softmax(output, dim=1)[:, 1]
            pred = output.argmax(dim=1)

            correct += pred.eq(target).sum().item()
            total += target.size(0)

            all_preds.extend(pred.cpu().tolist())
            all_targets.extend(target.cpu().tolist())
            all_probs.extend(probs.cpu().tolist())

    avg_loss = total_loss / len(dataloader)
    accuracy = 100.0 * correct / total if total > 0 else 0.0
    
    precision = precision_score(all_targets, all_preds, zero_division=0)
    recall = recall_score(all_targets, all_preds, zero_division=0)
    f1 = f1_score(all_targets, all_preds, zero_division=0)
    
    # Calculate specificity and sensitivity
    tn, fp, fn, tp = confusion_matrix(all_targets, all_preds).ravel()
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0
    
    return avg_loss, accuracy, np.array(all_preds), np.array(all_targets), np.array(all_probs), precision, recall, f1, sensitivity, specificity

# ------------------------------ Main ------------------------------------

def main(args):
    set_seed(args.seed)

    DATA_DIR = Path(args.data_dir)
    if not DATA_DIR.exists():
        raise FileNotFoundError(f"Data directory not found: {DATA_DIR}")

    record_files = list(DATA_DIR.glob('*.hea'))
    all_records = [f.stem for f in record_files]
    valid_records = [rec for rec in all_records 
                    if (DATA_DIR / (rec + '.apn')).exists() and not rec.endswith('er')]
    
    if len(valid_records) == 0:
        raise RuntimeError("No valid records found")

    print(f"Found {len(valid_records)} valid records")

    import random
    valid_records_shuffled = valid_records.copy()
    random.Random(args.seed).shuffle(valid_records_shuffled)
    split_idx = int(len(valid_records_shuffled) * args.train_split)
    train_records = valid_records_shuffled[:split_idx]
    val_records = valid_records_shuffled[split_idx:]
    print(f"Train: {len(train_records)}, Val: {len(val_records)}\n")

    cache_dir = args.cache_dir if args.cache_dir else str(DATA_DIR)
    train_dataset = EnhancedApneaDataset(
        str(DATA_DIR), train_records, cache_dir,
        args.segment_length, args.stride, 'train', augment=True
    )
    val_dataset = EnhancedApneaDataset(
        str(DATA_DIR), val_records, cache_dir,
        args.segment_length, args.stride, 'val', augment=False
    )

    num_workers = 2 if str(DATA_DIR).startswith('/kaggle') else 4

    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True
    )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")
    if device.type == 'cuda':
        try:
            print(f"GPU: {torch.cuda.get_device_name(0)}")
        except Exception:
            pass
        torch.cuda.empty_cache()

    model = ImprovedApneaNet(
        d_model=args.d_model, n_blocks=args.n_blocks, dropout=args.dropout
    ).to(device)

    print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}\n")

    class_weights = compute_class_weights(train_dataset.labels).to(device)
    print(f"Class weights: {class_weights}")

    # Use Focal Loss for better class imbalance handling
        # stable, NaN-free loss with class weights
    criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.05)

    optimizer = torch.optim.AdamW(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay
    )

    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=args.lr, epochs=args.epochs,
        steps_per_epoch=len(train_loader), pct_start=0.25
    )

    scaler = torch.amp.GradScaler() if device.type == 'cuda' else None

    best_val_acc = 0.0
    best_val_f1 = 0.0
    no_improve = 0

    print("\nStarting training...")
    print("="*100)

    for epoch in range(1, args.epochs + 1):
        epoch_start = time.time()

        train_loss, train_acc = train_epoch(
            model, train_loader, criterion, optimizer, scheduler, device, epoch, scaler
        )

        val_loss, val_acc, _, val_targets, val_probs, precision, recall, f1, sensitivity, specificity = validate(
            model, val_loader, criterion, device
        )

        try:
            auc = roc_auc_score(val_targets, val_probs)
        except Exception:
            auc = 0.0

        epoch_time = time.time() - epoch_start

        print(f"Epoch {epoch:2d}/{args.epochs} ({epoch_time:.1f}s)")
        print(f"  Train: Loss={train_loss:.4f}, Acc={train_acc:.2f}%")
        print(f"  Val:   Loss={val_loss:.4f}, Acc={val_acc:.2f}%, AUC={auc:.4f}")
        print(f"         Prec={precision:.3f}, Rec={recall:.3f}, F1={f1:.3f}")
        print(f"         Sensitivity={sensitivity:.3f}, Specificity={specificity:.3f}")

        if val_acc > best_val_acc or (val_acc >= best_val_acc and f1 > best_val_f1):
            best_val_acc = val_acc
            best_val_f1 = f1
            no_improve = 0
            torch.save({
                'epoch': epoch, 'model_state_dict': model.state_dict(),
                'val_acc': val_acc, 'val_auc': auc, 'val_f1': f1,
                'sensitivity': sensitivity, 'specificity': specificity
            }, args.best_model_path)
            print(f"  ✓ Best! (Acc={val_acc:.2f}%, F1={f1:.3f}, Sens={sensitivity:.3f}, Spec={specificity:.3f})")
        else:
            no_improve += 1
            print(f"  No improvement ({no_improve}/{args.patience})")

        print("-"*100)

        if no_improve >= args.patience:
            print(f"\nEarly stop at epoch {epoch}")
            break

    print(f"\n{'='*100}")
    print(f"BEST - Accuracy: {best_val_acc:.2f}%, F1: {best_val_f1:.3f}")
    print(f"{'='*100}")


if __name__ == '__main__':
    kaggle_data = '/kaggle/input/vincent2/apnea-ecg-database-1.0.0'
    colab_data = '/content/apnea-ecg/1.0.0'
    if Path(kaggle_data).exists():
        default_data_dir = kaggle_data
        default_cache_dir = '/kaggle/working'
        default_model_path = '/kaggle/working/best_model.pth'
    elif Path(colab_data).exists():
        default_data_dir = colab_data
        default_cache_dir = '/content'
        default_model_path = '/content/best_model.pth'
    else:
        default_data_dir = None
        default_cache_dir = None
        default_model_path = 'best_model.pth'

    parser = argparse.ArgumentParser()
    parser.add_argument('--data-dir', type=str, default=default_data_dir)
    parser.add_argument('--cache-dir', type=str, default=default_cache_dir)
    parser.add_argument('--segment-length', type=int, default=6000)
    parser.add_argument('--stride', type=int, default=2400)  # 60% overlap for more data
    parser.add_argument('--batch-size', type=int, default=48)  # Adjusted for multi-modal
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--weight-decay', type=float, default=1e-4)
    parser.add_argument('--d-model', type=int, default=256)
    parser.add_argument('--n-blocks', type=int, default=10)
    parser.add_argument('--dropout', type=float, default=0.15)
    parser.add_argument('--train-split', type=float, default=0.8)
    parser.add_argument('--patience', type=int, default=20)
    parser.add_argument('--best-model-path', type=str, default=default_model_path)
    parser.add_argument('--seed', type=int, default=42)

    args, _ = parser.parse_known_args()

    if args.data_dir is None:
        raise SystemExit("ERROR: Dataset not found")

    print("="*100)
    print("ENHANCED MODEL WITH R-R INTERVALS (Target: 90%+ Accuracy)")
    print("="*100)
    print(f"  Data:       {args.data_dir}")
    print(f"  Segment:    {args.segment_length} samples (60s), stride={args.stride}")
    print(f"  Features:   ECG + R-R Intervals + R-peak Amplitudes")
    print(f"  Batch:      {args.batch_size}")
    print(f"  Epochs:     {args.epochs}")
    print(f"  Model:      d_model={args.d_model}, blocks={args.n_blocks}")
    print(f"  Optimizer:  AdamW (lr={args.lr}, wd={args.weight_decay})")
    print(f"  Loss:       Focal Loss (alpha=0.25, gamma=2.0)")
    print("="*100 + "\n")

    main(args)

ENHANCED MODEL WITH R-R INTERVALS (Target: 90%+ Accuracy)
  Data:       /kaggle/input/vincent2/apnea-ecg-database-1.0.0
  Segment:    6000 samples (60s), stride=2400
  Features:   ECG + R-R Intervals + R-peak Amplitudes
  Batch:      48
  Epochs:     100
  Model:      d_model=256, blocks=10
  Optimizer:  AdamW (lr=0.001, wd=0.0001)
  Loss:       Focal Loss (alpha=0.25, gamma=2.0)

Found 43 valid records
Train: 34, Val: 9

Loading cached train from /kaggle/working/apnea_enhanced_train_6000_2400.pt
Train: 41751 segments, Class: Counter({0: 26395, 1: 15356})
Loading cached val from /kaggle/working/apnea_enhanced_val_6000_2400.pt
Val: 10777 segments, Class: Counter({0: 5862, 1: 4915})
Device: cuda
GPU: Tesla P100-PCIE-16GB
Parameters: 2,496,778

Class weights: tensor([0.7909, 1.3594], device='cuda:0')

Starting training...




  Ep 1 [ 870/870] Loss: 0.5580 Acc: 74.69% (21.3 b/s, ETA: 0s))
Epoch  1/100 (43.8s)
  Train: Loss=0.5580, Acc=74.69%
  Val:   Loss=0.5033, Acc=78.05%, AUC=0.8956
         Prec=0.694, Rec=0.928, F1=0.794
         Sensitivity=0.928, Specificity=0.657
  ✓ Best! (Acc=78.05%, F1=0.794, Sens=0.928, Spec=0.657)
----------------------------------------------------------------------------------------------------
  Ep 2 [ 870/870] Loss: 0.3917 Acc: 85.49% (22.1 b/s, ETA: 0s))
Epoch  2/100 (42.3s)
  Train: Loss=0.3917, Acc=85.49%
  Val:   Loss=0.5056, Acc=80.02%, AUC=0.8718
         Prec=0.753, Rec=0.837, F1=0.793
         Sensitivity=0.837, Specificity=0.769
  ✓ Best! (Acc=80.02%, F1=0.793, Sens=0.837, Spec=0.769)
----------------------------------------------------------------------------------------------------
  Ep 3 [ 870/870] Loss: 0.3555 Acc: 87.77% (22.3 b/s, ETA: 0s))
Epoch  3/100 (42.2s)
  Train: Loss=0.3555, Acc=87.77%
  Val:   Loss=0.4170, Acc=83.59%, AUC=0.9167
         Prec=0.799, 

In [2]:
#!/usr/bin/env python3
"""
High-performance apnea detection with R-R interval extraction (Target: 90%+ accuracy)
Based on PhysioNet Apnea-ECG Database methodology
"""

import argparse
import os
import time
from pathlib import Path
from collections import Counter

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from scipy import signal as scipy_signal
from scipy.interpolate import interp1d

try:
    import wfdb
except Exception:
    wfdb = None

from sklearn.metrics import roc_auc_score, precision_score, recall_score, f1_score, confusion_matrix

# ----------------------------- Utilities ---------------------------------

def set_seed(seed=42):
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# ----------------------------- R-Peak Detection & R-R Interval Extraction ---------------------------------

def detect_r_peaks_hamilton(ecg_signal, fs=100):
    """
    Hamilton R-peak detection algorithm
    Returns indices of R-peaks
    """
    # Bandpass filter (5-15 Hz)
    b, a = scipy_signal.butter(2, [5, 15], btype='band', fs=fs)
    filtered = scipy_signal.filtfilt(b, a, ecg_signal)
    
    # Derivative
    diff_signal = np.diff(filtered)
    
    # Squaring
    squared = diff_signal ** 2
    
    # Moving average integration (150ms window)
    window_size = int(0.15 * fs)
    integrated = np.convolve(squared, np.ones(window_size)/window_size, mode='same')
    
    # Find peaks
    threshold = np.mean(integrated) + 0.5 * np.std(integrated)
    peaks = []
    refractory = int(0.2 * fs)  # 200ms refractory period
    
    for i in range(1, len(integrated) - 1):
        if integrated[i] > threshold:
            if integrated[i] > integrated[i-1] and integrated[i] > integrated[i+1]:
                if not peaks or (i - peaks[-1]) > refractory:
                    peaks.append(i)
    
    return np.array(peaks)

def median_filter_rr(rr_intervals, window=5):
    """
    Median filter for removing physiologically uninterpretable R-R intervals
    Based on Chen et al. methodology
    """
    if len(rr_intervals) < window:
        return rr_intervals
    
    filtered = rr_intervals.copy()
    median_rr = np.median(rr_intervals)
    
    for i in range(len(rr_intervals)):
        # Check if RR interval is physiologically valid (300ms - 2000ms)
        if rr_intervals[i] < 0.3 or rr_intervals[i] > 2.0:
            filtered[i] = median_rr
            continue
        
        # Median filter
        start = max(0, i - window//2)
        end = min(len(rr_intervals), i + window//2 + 1)
        window_vals = rr_intervals[start:end]
        local_median = np.median(window_vals)
        
        # Replace outliers (> 20% deviation from local median)
        if abs(rr_intervals[i] - local_median) > 0.2 * local_median:
            filtered[i] = local_median
    
    return filtered

def extract_rr_features(ecg_segment, fs=100):
    """
    NaN-safe R-peak detection + 3 Hz interpolation.
    Returns:
        rr_interp  : length 180  (seconds * 3 Hz)
        ramp_interp: length 180
    """
    # ---- 1. R-peak detection -------------------------------------------------
    r_peaks = detect_r_peaks_hamilton(ecg_segment, fs)
    if len(r_peaks) < 3:                       # not enough peaks → dummy
        return np.zeros(180, dtype=np.float32), np.zeros(180, dtype=np.float32)

    # ---- 2. RR intervals -----------------------------------------------------
    rr_sec = np.diff(r_peaks) / fs
    rr_sec = median_filter_rr(rr_sec)          # outlier removal
    rr_times = r_peaks[1:] / fs                # time stamp of each RR

    # ---- 3. R-peak amplitudes ------------------------------------------------
    ramp = ecg_segment[r_peaks[1:]]

    # ---- 4. Cubic interpolation to 3 Hz -------------------------------------
    targ_t = np.linspace(0, 60, 180)
    kind = 'cubic' if len(rr_sec) >= 4 else 'linear'

    f_rr   = interp1d(rr_times, rr_sec,   kind=kind, bounds_error=False, fill_value=(rr_sec[0], rr_sec[-1]))
    f_amp  = interp1d(rr_times, ramp,     kind=kind, bounds_error=False, fill_value=(ramp[0],   ramp[-1]))

    rr_out   = np.clip(f_rr(targ_t), 0.3, 2.0).astype(np.float32)
    ramp_out = f_amp(targ_t).astype(np.float32)

    # ---- 5. Normalise with safety checks -------------------------------------
    rr_std = rr_out.std()
    ramp_std = ramp_out.std()
    
    if rr_std > 1e-6:
        rr_out = (rr_out - rr_out.mean()) / rr_std
    else:
        rr_out = rr_out - rr_out.mean()
    
    if ramp_std > 1e-6:
        ramp_out = (ramp_out - ramp_out.mean()) / ramp_std
    else:
        ramp_out = ramp_out - ramp_out.mean()
    
    # Clip to prevent extreme values
    rr_out = np.clip(rr_out, -10, 10)
    ramp_out = np.clip(ramp_out, -10, 10)
    
    return rr_out, ramp_out

# ----------------------------- Improved Model ---------------------------------

class MultiScaleBlock(nn.Module):
    """Multi-scale feature extraction block"""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # ensure total output channels == out_channels
        c1 = out_channels // 3
        c2 = out_channels // 3
        c3 = out_channels - c1 - c2  # remainder so c1 + c2 + c3 == out_channels

        self.conv1 = nn.Conv1d(in_channels, c1, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(in_channels, c2, kernel_size=5, padding=2)
        self.conv3 = nn.Conv1d(in_channels, c3, kernel_size=7, padding=3)
        self.bn = nn.BatchNorm1d(out_channels)
        
    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(x)
        x3 = self.conv3(x)
        out = torch.cat([x1, x2, x3], dim=1)
        return F.gelu(self.bn(out))


class EnhancedResBlock(nn.Module):
    """Enhanced residual block with squeeze-excitation"""
    def __init__(self, channels, kernel_size=7):
        super().__init__()
        self.conv1 = nn.Conv1d(channels, channels, kernel_size, padding=kernel_size//2, groups=channels)
        self.conv2 = nn.Conv1d(channels, channels, 1)
        self.norm = nn.BatchNorm1d(channels)
        
        # Squeeze-Excitation with stability improvements
        se_channels = max(8, channels // 8)  # Ensure at least 8 channels
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Conv1d(channels, se_channels, 1),
            nn.GELU(),
            nn.Conv1d(se_channels, channels, 1),
            nn.Sigmoid()
        )
        self.dropout = nn.Dropout(0.15)
        
    def forward(self, x):
        residual = x
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.norm(x)
        
        # Apply SE attention with safety check
        se_weight = self.se(x)
        # Clamp to prevent extreme values
        se_weight = torch.clamp(se_weight, 0.0, 1.0)
        x = x * se_weight
        
        x = self.dropout(x)
        return F.gelu(residual + x)

class ImprovedApneaNet(nn.Module):
    def __init__(self, d_model=256, n_blocks=10, dropout=0.15):
        super().__init__()
        
        # Ensure the three modality-channel outputs sum to d_model
        c1 = d_model // 3
        c2 = d_model // 3
        c3 = d_model - c1 - c2

        # ECG pathway (6000 samples) -> c1 channels
        self.ecg_stem = nn.Sequential(
            nn.Conv1d(1, c1, kernel_size=15, padding=7, stride=4),
            nn.BatchNorm1d(c1),
            nn.GELU(),
            nn.Conv1d(c1, c1, kernel_size=7, padding=3, stride=2),
            nn.BatchNorm1d(c1),
            nn.GELU(),
        )

        # RR interval pathway (180 samples @ 3Hz) -> c2 channels
        self.rr_stem = nn.Sequential(
            nn.Conv1d(1, c2, kernel_size=7, padding=3, stride=2),
            nn.BatchNorm1d(c2),
            nn.GELU(),
            nn.Conv1d(c2, c2, kernel_size=5, padding=2),
            nn.BatchNorm1d(c2),
            nn.GELU(),
        )

        # R-amplitude pathway (180 samples @ 3Hz) -> c3 channels
        self.ramp_stem = nn.Sequential(
            nn.Conv1d(1, c3, kernel_size=7, padding=3, stride=2),
            nn.BatchNorm1d(c3),
            nn.GELU(),
            nn.Conv1d(c3, c3, kernel_size=5, padding=2),
            nn.BatchNorm1d(c3),
            nn.GELU(),
        )
        
        # Multi-scale fusion
        self.fusion = MultiScaleBlock(d_model, d_model)
        
        # Enhanced residual blocks
        self.blocks = nn.ModuleList([
            EnhancedResBlock(d_model, kernel_size=7 if i % 2 == 0 else 11)
            for i in range(n_blocks)
        ])
        
        # Temporal attention with larger context
        self.temp_attn = nn.MultiheadAttention(d_model, num_heads=8, dropout=dropout, batch_first=True)
        self.temp_norm = nn.LayerNorm(d_model)
        self.temp_ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 2, d_model)
        )
        
        # Feature aggregation
        self.global_pool = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten()
        )
        
        # Enhanced classifier with multiple paths
        self.classifier = nn.Sequential(
            nn.Linear(d_model * 4, d_model * 2),
            nn.BatchNorm1d(d_model * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 2, d_model),
            nn.BatchNorm1d(d_model),
            nn.GELU(),
            nn.Dropout(dropout * 0.5),
            nn.Linear(d_model, 2)
        )
        
        # Initialize weights
        self._initialize_weights()
        
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
        
    def forward(self, ecg, rr, ramp):
        # Validate inputs
        if torch.isnan(ecg).any() or torch.isinf(ecg).any():
            ecg = torch.nan_to_num(ecg, nan=0.0, posinf=10.0, neginf=-10.0)
        if torch.isnan(rr).any() or torch.isinf(rr).any():
            rr = torch.nan_to_num(rr, nan=0.0, posinf=10.0, neginf=-10.0)
        if torch.isnan(ramp).any() or torch.isinf(ramp).any():
            ramp = torch.nan_to_num(ramp, nan=0.0, posinf=10.0, neginf=-10.0)
        
        # Process each modality
        ecg = ecg.transpose(1, 2)
        rr = rr.transpose(1, 2)
        ramp = ramp.transpose(1, 2)
        
        ecg_feat = self.ecg_stem(ecg)
        rr_feat = self.rr_stem(rr)
        ramp_feat = self.ramp_stem(ramp)
        
        # Align sequence lengths and concatenate
        target_len = min(ecg_feat.size(2), rr_feat.size(2), ramp_feat.size(2))
        ecg_feat = F.adaptive_avg_pool1d(ecg_feat, target_len)
        rr_feat = F.adaptive_avg_pool1d(rr_feat, target_len)
        ramp_feat = F.adaptive_avg_pool1d(ramp_feat, target_len)
        
        x = torch.cat([ecg_feat, rr_feat, ramp_feat], dim=1)
        
        # Multi-scale fusion
        x = self.fusion(x)
        
        # Residual blocks
        for block in self.blocks:
            x = block(x)
        
        # Multiple pooling strategies with safety
        x_avg = F.adaptive_avg_pool1d(x, 1).squeeze(-1)
        x_max = F.adaptive_max_pool1d(x, 1).squeeze(-1)
        x_std = torch.std(x, dim=2) + 1e-8  # Add epsilon for stability
        
        # Temporal attention
        x_seq = x.transpose(1, 2)
        x_attn, _ = self.temp_attn(x_seq, x_seq, x_seq)
        x_attn = self.temp_norm(x_attn + x_seq)
        x_attn = x_attn + self.temp_ffn(x_attn)
        x_attn = x_attn.mean(dim=1)
        
        # Combine all features
        x_combined = torch.cat([x_avg, x_max, x_std, x_attn], dim=1)
        
        # Final NaN check before classifier
        if torch.isnan(x_combined).any() or torch.isinf(x_combined).any():
            x_combined = torch.nan_to_num(x_combined, nan=0.0, posinf=10.0, neginf=-10.0)
        
        # Classify
        logits = self.classifier(x_combined)
        return logits

# --------------------------- Enhanced Dataset ---------------------------

class EnhancedApneaDataset(Dataset):
    """Dataset with R-R interval and R-amplitude extraction"""

    def __init__(self, data_dir: str, record_names: list = None, cache_dir: str = None,
                 segment_length: int = 6000, stride: int = 3000, split='train', augment=True):
        super().__init__()
        self.segment_length = int(segment_length)
        self.stride = int(stride)
        self.split = split
        self.augment = augment and (split == 'train')
        
        cache_dir = Path(cache_dir) if cache_dir else Path(data_dir)
        cache_file = cache_dir / f'apnea_enhanced_{split}_{segment_length}_{stride}.pt'

        if cache_file.exists():
            print(f"Loading cached {split} from {cache_file}")
            data = torch.load(cache_file)
            self.ecg_segments = data['ecg_segments']
            self.rr_segments = data['rr_segments']
            self.ramp_segments = data['ramp_segments']
            self.labels = data['labels']
        else:
            assert wfdb is not None, "wfdb required"
            assert record_names is not None, "record_names required"
            
            self.ecg_segments = []
            self.rr_segments = []
            self.ramp_segments = []
            self.labels = []
            self.data_dir = Path(data_dir)
            
            print(f"Processing {len(record_names)} records for {split} (with R-R extraction)...")
            for i, rec in enumerate(record_names):
                print(f"  [{i+1}/{len(record_names)}] {rec}...", end='\r')
                self._load_record(rec)
            
            if len(self.ecg_segments) == 0:
                raise RuntimeError("No segments loaded")
            
            self.ecg_segments = torch.tensor(np.stack(self.ecg_segments, axis=0), dtype=torch.float32)
            self.rr_segments = torch.tensor(np.stack(self.rr_segments, axis=0), dtype=torch.float32)
            self.ramp_segments = torch.tensor(np.stack(self.ramp_segments, axis=0), dtype=torch.float32)
            self.labels = torch.tensor(self.labels, dtype=torch.long)
            
            print(f"\nSaving cache to {cache_file}")
            torch.save({
                'ecg_segments': self.ecg_segments,
                'rr_segments': self.rr_segments,
                'ramp_segments': self.ramp_segments,
                'labels': self.labels
            }, cache_file)

        if self.ecg_segments.ndim == 2:
            self.ecg_segments = self.ecg_segments.unsqueeze(-1)
        if self.rr_segments.ndim == 2:
            self.rr_segments = self.rr_segments.unsqueeze(-1)
        if self.ramp_segments.ndim == 2:
            self.ramp_segments = self.ramp_segments.unsqueeze(-1)

        print(f"{split.capitalize()}: {len(self.ecg_segments)} segments, "
              f"Class: {Counter(self.labels.tolist())}")

    def _load_record(self, record_name: str):
        try:
            rec = wfdb.rdrecord(str(self.data_dir / record_name))
            sig = rec.p_signal[:, 0].astype(np.float32)
            if np.isnan(sig).any():
                nans = np.isnan(sig)
                sig[nans] = np.interp(np.flatnonzero(nans), np.flatnonzero(~nans), sig[~nans])

            ann = wfdb.rdann(str(self.data_dir / record_name), 'apn')
            n_min = len(sig) // 6000
            mins = np.zeros(n_min, dtype=int)
            for samp, sym in zip(ann.sample, ann.symbol):
                if sym == 'A':
                    m = samp // 6000
                    if m < n_min:
                        mins[m] = 1

            for start in range(0, len(sig) - self.segment_length + 1, self.stride):
                seg = sig[start:start + self.segment_length]
                
                # Normalise ECG with safety
                seg_std = seg.std()
                if seg_std > 1e-6:
                    seg = (seg - seg.mean()) / seg_std
                else:
                    seg = seg - seg.mean()
                seg = np.clip(seg, -10, 10)

                rr, ramp = extract_rr_features(seg, fs=100)

                minute = start // 6000
                if minute < len(mins):
                    self.ecg_segments.append(seg)
                    self.rr_segments.append(rr)
                    self.ramp_segments.append(ramp)
                    self.labels.append(int(mins[minute]))
        except Exception as e:
            print(f'\nSkip {record_name}: {e}')
    
    def _augment(self, ecg, rr, ramp):
        ecg, rr, ramp = map(lambda x: x.numpy() if torch.is_tensor(x) else x, (ecg, rr, ramp))

        if np.random.rand() < 0.5:
            ecg += np.random.normal(0, 0.02, ecg.shape).astype(np.float32)
            rr += np.random.normal(0, 0.01, rr.shape).astype(np.float32)
            ramp += np.random.normal(0, 0.01, ramp.shape).astype(np.float32)

        if np.random.rand() < 0.3:
            scale = np.random.uniform(0.9, 1.1)
            ecg *= scale
            ramp *= scale

        if np.random.rand() < 0.2:
            shift = np.random.randint(-150, 150)
            ecg = np.roll(ecg, shift)

        # Clip after augmentation
        ecg = np.clip(ecg, -10, 10)
        rr = np.clip(rr, -10, 10)
        ramp = np.clip(ramp, -10, 10)

        return tuple(map(torch.from_numpy, (ecg, rr, ramp)))
            
    def __len__(self):
        return self.ecg_segments.shape[0]

    def __getitem__(self, idx):
        ecg = self.ecg_segments[idx]
        rr = self.rr_segments[idx]
        ramp = self.ramp_segments[idx]
        label = self.labels[idx]
        
        if self.augment:
            ecg, rr, ramp = self._augment(ecg, rr, ramp)
        
        # Ensure correct shape
        if ecg.ndim == 1:
            ecg = ecg.unsqueeze(-1)
        if rr.ndim == 1:
            rr = rr.unsqueeze(-1)
        if ramp.ndim == 1:
            ramp = ramp.unsqueeze(-1)
        
        return ecg, rr, ramp, label

# -------------------------- Training ------------------------

def compute_class_weights(labels_tensor):
    counts = Counter(labels_tensor.tolist())
    total = sum(counts.values())
    num_classes = len(counts)
    weights = [total / (num_classes * counts.get(i, 1)) for i in range(num_classes)]
    return torch.tensor(weights, dtype=torch.float32)

def train_epoch(model, dataloader, criterion, optimizer, scheduler, device, epoch, scaler=None):
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    
    num_batches = len(dataloader)
    print_freq = max(1, num_batches // 15)
    start_time = time.time()

    for batch_idx, (ecg, rr, ramp, target) in enumerate(dataloader, 1):
        ecg = ecg.to(device, non_blocking=True)
        rr = rr.to(device, non_blocking=True)
        ramp = ramp.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        with torch.amp.autocast('cuda', enabled=(device.type == 'cuda')):
            output = model(ecg, rr, ramp)
            loss = criterion(output, target)
        
        if torch.isnan(loss) or torch.isinf(loss):
            print(f"\nWARNING: NaN/Inf loss, skipping batch {batch_idx}")
            continue

        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
        
        scheduler.step()

        total_loss += loss.item()
        pred = output.argmax(dim=1)
        correct += pred.eq(target).sum().item()
        total += target.size(0)
        
        if batch_idx % print_freq == 0 or batch_idx == num_batches:
            curr_acc = 100.0 * correct / total
            curr_loss = total_loss / batch_idx
            speed = batch_idx / (time.time() - start_time)
            eta = (num_batches - batch_idx) / speed if speed > 0 else 0
            
            print(f"  Ep {epoch} [{batch_idx:4d}/{num_batches}] "
                  f"Loss: {curr_loss:.4f} Acc: {curr_acc:.2f}% "
                  f"({speed:.1f} b/s, ETA: {eta:.0f}s)", end='\r')

    print()
    avg_loss = total_loss / len(dataloader)
    accuracy = 100.0 * correct / total if total > 0 else 0.0
    return avg_loss, accuracy

def validate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_targets = []
    all_probs = []

    with torch.no_grad():
        for ecg, rr, ramp, target in dataloader:
            ecg = ecg.to(device, non_blocking=True)
            rr = rr.to(device, non_blocking=True)
            ramp = ramp.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            
            output = model(ecg, rr, ramp)
            loss = criterion(output, target)

            total_loss += loss.item()
            probs = F.softmax(output, dim=1)[:, 1]
            pred = output.argmax(dim=1)

            correct += pred.eq(target).sum().item()
            total += target.size(0)

            all_preds.extend(pred.cpu().tolist())
            all_targets.extend(target.cpu().tolist())
            all_probs.extend(probs.cpu().tolist())

    avg_loss = total_loss / len(dataloader)
    accuracy = 100.0 * correct / total if total > 0 else 0.0
    
    precision = precision_score(all_targets, all_preds, zero_division=0)
    recall = recall_score(all_targets, all_preds, zero_division=0)
    f1 = f1_score(all_targets, all_preds, zero_division=0)
    
    # Calculate specificity and sensitivity
    tn, fp, fn, tp = confusion_matrix(all_targets, all_preds).ravel()
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0
    
    return avg_loss, accuracy, np.array(all_preds), np.array(all_targets), np.array(all_probs), precision, recall, f1, sensitivity, specificity

# ------------------------------ Main ------------------------------------

def main(args):
    set_seed(args.seed)

    DATA_DIR = Path(args.data_dir)
    if not DATA_DIR.exists():
        raise FileNotFoundError(f"Data directory not found: {DATA_DIR}")

    record_files = list(DATA_DIR.glob('*.hea'))
    all_records = [f.stem for f in record_files]
    valid_records = [rec for rec in all_records 
                    if (DATA_DIR / (rec + '.apn')).exists() and not rec.endswith('er')]
    
    if len(valid_records) == 0:
        raise RuntimeError("No valid records found")

    print(f"Found {len(valid_records)} valid records")

    import random
    valid_records_shuffled = valid_records.copy()
    random.Random(args.seed).shuffle(valid_records_shuffled)
    split_idx = int(len(valid_records_shuffled) * args.train_split)
    train_records = valid_records_shuffled[:split_idx]
    val_records = valid_records_shuffled[split_idx:]
    print(f"Train: {len(train_records)}, Val: {len(val_records)}\n")

    cache_dir = args.cache_dir if args.cache_dir else str(DATA_DIR)
    train_dataset = EnhancedApneaDataset(
        str(DATA_DIR), train_records, cache_dir,
        args.segment_length, args.stride, 'train', augment=True
    )
    val_dataset = EnhancedApneaDataset(
        str(DATA_DIR), val_records, cache_dir,
        args.segment_length, args.stride, 'val', augment=False
    )

    num_workers = 2 if str(DATA_DIR).startswith('/kaggle') else 4

    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True
    )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")
    if device.type == 'cuda':
        try:
            print(f"GPU: {torch.cuda.get_device_name(0)}")
        except Exception:
            pass
        torch.cuda.empty_cache()

    model = ImprovedApneaNet(
        d_model=args.d_model, n_blocks=args.n_blocks, dropout=args.dropout
    ).to(device)

    print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}\n")

    class_weights = compute_class_weights(train_dataset.labels).to(device)
    print(f"Class weights: {class_weights}")

    # Use Focal Loss for better class imbalance handling
        # stable, NaN-free loss with class weights
    criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.05)

    optimizer = torch.optim.AdamW(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay
    )

    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=args.lr, epochs=args.epochs,
        steps_per_epoch=len(train_loader), pct_start=0.2
    )

    scaler = torch.amp.GradScaler() if device.type == 'cuda' else None

    best_val_acc = 0.0
    best_val_f1 = 0.0
    no_improve = 0

    print("\nStarting training...")
    print("="*100)

    for epoch in range(1, args.epochs + 1):
        epoch_start = time.time()

        train_loss, train_acc = train_epoch(
            model, train_loader, criterion, optimizer, scheduler, device, epoch, scaler
        )

        val_loss, val_acc, _, val_targets, val_probs, precision, recall, f1, sensitivity, specificity = validate(
            model, val_loader, criterion, device
        )

        try:
            auc = roc_auc_score(val_targets, val_probs)
        except Exception:
            auc = 0.0

        epoch_time = time.time() - epoch_start

        print(f"Epoch {epoch:2d}/{args.epochs} ({epoch_time:.1f}s)")
        print(f"  Train: Loss={train_loss:.4f}, Acc={train_acc:.2f}%")
        print(f"  Val:   Loss={val_loss:.4f}, Acc={val_acc:.2f}%, AUC={auc:.4f}")
        print(f"         Prec={precision:.3f}, Rec={recall:.3f}, F1={f1:.3f}")
        print(f"         Sensitivity={sensitivity:.3f}, Specificity={specificity:.3f}")

        if val_acc > best_val_acc or (val_acc >= best_val_acc and f1 > best_val_f1):
            best_val_acc = val_acc
            best_val_f1 = f1
            no_improve = 0
            torch.save({
                'epoch': epoch, 'model_state_dict': model.state_dict(),
                'val_acc': val_acc, 'val_auc': auc, 'val_f1': f1,
                'sensitivity': sensitivity, 'specificity': specificity
            }, args.best_model_path)
            print(f"  ✓ Best! (Acc={val_acc:.2f}%, F1={f1:.3f}, Sens={sensitivity:.3f}, Spec={specificity:.3f})")
        else:
            no_improve += 1
            print(f"  No improvement ({no_improve}/{args.patience})")

        print("-"*100)

        if no_improve >= args.patience:
            print(f"\nEarly stop at epoch {epoch}")
            break

    print(f"\n{'='*100}")
    print(f"BEST - Accuracy: {best_val_acc:.2f}%, F1: {best_val_f1:.3f}")
    print(f"{'='*100}")


if __name__ == '__main__':
    kaggle_data = '/kaggle/input/vincent2/apnea-ecg-database-1.0.0'
    colab_data = '/content/apnea-ecg/1.0.0'
    if Path(kaggle_data).exists():
        default_data_dir = kaggle_data
        default_cache_dir = '/kaggle/working'
        default_model_path = '/kaggle/working/best_model.pth'
    elif Path(colab_data).exists():
        default_data_dir = colab_data
        default_cache_dir = '/content'
        default_model_path = '/content/best_model.pth'
    else:
        default_data_dir = None
        default_cache_dir = None
        default_model_path = 'best_model.pth'

    parser = argparse.ArgumentParser()
    parser.add_argument('--data-dir', type=str, default=default_data_dir)
    parser.add_argument('--cache-dir', type=str, default=default_cache_dir)
    parser.add_argument('--segment-length', type=int, default=6000)
    parser.add_argument('--stride', type=int, default=2400)  # 60% overlap for more data
    parser.add_argument('--batch-size', type=int, default=32)  # Adjusted for multi-modal
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--weight-decay', type=float, default=0)
    parser.add_argument('--d-model', type=int, default=256)
    parser.add_argument('--n-blocks', type=int, default=10)
    parser.add_argument('--dropout', type=float, default=0.15)
    parser.add_argument('--train-split', type=float, default=0.8)
    parser.add_argument('--patience', type=int, default=20)
    parser.add_argument('--best-model-path', type=str, default=default_model_path)
    parser.add_argument('--seed', type=int, default=42)

    args, _ = parser.parse_known_args()

    if args.data_dir is None:
        raise SystemExit("ERROR: Dataset not found")

    print("="*100)
    print("ENHANCED MODEL WITH R-R INTERVALS (Target: 90%+ Accuracy)")
    print("="*100)
    print(f"  Data:       {args.data_dir}")
    print(f"  Segment:    {args.segment_length} samples (60s), stride={args.stride}")
    print(f"  Features:   ECG + R-R Intervals + R-peak Amplitudes")
    print(f"  Batch:      {args.batch_size}")
    print(f"  Epochs:     {args.epochs}")
    print(f"  Model:      d_model={args.d_model}, blocks={args.n_blocks}")
    print(f"  Optimizer:  AdamW (lr={args.lr}, wd={args.weight_decay})")
    print(f"  Loss:       Focal Loss (alpha=0.25, gamma=2.0)")
    print("="*100 + "\n")

    main(args)

ENHANCED MODEL WITH R-R INTERVALS (Target: 90%+ Accuracy)
  Data:       /kaggle/input/vincent2/apnea-ecg-database-1.0.0
  Segment:    6000 samples (60s), stride=2400
  Features:   ECG + R-R Intervals + R-peak Amplitudes
  Batch:      32
  Epochs:     100
  Model:      d_model=256, blocks=10
  Optimizer:  AdamW (lr=0.0001, wd=0)
  Loss:       Focal Loss (alpha=0.25, gamma=2.0)

Found 43 valid records
Train: 34, Val: 9

Processing 34 records for train (with R-R extraction)...
  [34/34] a20....
Saving cache to /kaggle/working/apnea_enhanced_train_6000_2400.pt
Train: 41751 segments, Class: Counter({0: 26395, 1: 15356})
Processing 9 records for val (with R-R extraction)...
  [9/9] a01....
Saving cache to /kaggle/working/apnea_enhanced_val_6000_2400.pt
Val: 10777 segments, Class: Counter({0: 5862, 1: 4915})
Device: cuda
GPU: Tesla P100-PCIE-16GB
Parameters: 2,496,778

Class weights: tensor([0.7909, 1.3594], device='cuda:0')

Starting training...




  Ep 1 [1305/1305] Loss: 0.6828 Acc: 69.07% (24.3 b/s, ETA: 0s))
Epoch  1/100 (58.0s)
  Train: Loss=0.6828, Acc=69.07%
  Val:   Loss=0.5591, Acc=73.87%, AUC=0.8367
         Prec=0.666, Rec=0.855, F1=0.749
         Sensitivity=0.855, Specificity=0.641
  ✓ Best! (Acc=73.87%, F1=0.749, Sens=0.855, Spec=0.641)
----------------------------------------------------------------------------------------------------
  Ep 2 [1305/1305] Loss: 0.5557 Acc: 75.20% (25.4 b/s, ETA: 0s))
Epoch  2/100 (55.5s)
  Train: Loss=0.5557, Acc=75.20%
  Val:   Loss=0.4985, Acc=78.19%, AUC=0.8756
         Prec=0.716, Rec=0.865, F1=0.783
         Sensitivity=0.865, Specificity=0.712
  ✓ Best! (Acc=78.19%, F1=0.783, Sens=0.865, Spec=0.712)
----------------------------------------------------------------------------------------------------
  Ep 3 [1305/1305] Loss: 0.5022 Acc: 78.56% (25.3 b/s, ETA: 0s))
Epoch  3/100 (55.6s)
  Train: Loss=0.5022, Acc=78.56%
  Val:   Loss=0.5167, Acc=77.36%, AUC=0.8655
         Prec=0.71

In [3]:
#!/usr/bin/env python3
"""
High-performance apnea detection with R-R interval extraction (Target: 90%+ accuracy)
Based on PhysioNet Apnea-ECG Database methodology
"""

import argparse
import os
import time
from pathlib import Path
from collections import Counter

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from scipy import signal as scipy_signal
from scipy.interpolate import interp1d

try:
    import wfdb
except Exception:
    wfdb = None

from sklearn.metrics import roc_auc_score, precision_score, recall_score, f1_score, confusion_matrix

# ----------------------------- Utilities ---------------------------------

def set_seed(seed=42):
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# ----------------------------- R-Peak Detection & R-R Interval Extraction ---------------------------------

def detect_r_peaks_hamilton(ecg_signal, fs=100):
    """
    Enhanced Hamilton R-peak detection with better noise handling
    """
    # Remove baseline wander (0.5 Hz high-pass)
    b_high, a_high = scipy_signal.butter(2, 0.5, btype='high', fs=fs)
    ecg_signal = scipy_signal.filtfilt(b_high, a_high, ecg_signal)
    
    # Bandpass filter (5-15 Hz) - tighter for QRS
    b, a = scipy_signal.butter(3, [5, 15], btype='band', fs=fs)
    filtered = scipy_signal.filtfilt(b, a, ecg_signal)
    
    # Derivative with normalization
    diff_signal = np.diff(filtered)
    diff_signal = diff_signal / (np.std(diff_signal) + 1e-8)
    
    # Squaring
    squared = diff_signal ** 2
    
    # Moving average integration (150ms window)
    window_size = int(0.15 * fs)
    integrated = np.convolve(squared, np.ones(window_size)/window_size, mode='same')
    
    # Adaptive thresholding with percentile
    threshold = np.percentile(integrated, 75) + 0.3 * np.std(integrated)
    
    # Find peaks with minimum distance constraint
    from scipy.signal import find_peaks
    peaks, properties = find_peaks(integrated, 
                                   height=threshold,
                                   distance=int(0.2 * fs))  # 200ms minimum
    
    # Validate peaks are actual R-peaks by checking original signal
    valid_peaks = []
    for peak in peaks:
        # Check window around peak in original signal
        win_start = max(0, peak - int(0.05 * fs))
        win_end = min(len(ecg_signal), peak + int(0.05 * fs))
        if win_end > win_start:
            local_max = np.argmax(np.abs(ecg_signal[win_start:win_end])) + win_start
            valid_peaks.append(local_max)
    
    return np.array(valid_peaks)

def median_filter_rr(rr_intervals, window=5):
    """
    Median filter for removing physiologically uninterpretable R-R intervals
    Based on Chen et al. methodology
    """
    if len(rr_intervals) < window:
        return rr_intervals
    
    filtered = rr_intervals.copy()
    median_rr = np.median(rr_intervals)
    
    for i in range(len(rr_intervals)):
        # Check if RR interval is physiologically valid (300ms - 2000ms)
        if rr_intervals[i] < 0.3 or rr_intervals[i] > 2.0:
            filtered[i] = median_rr
            continue
        
        # Median filter
        start = max(0, i - window//2)
        end = min(len(rr_intervals), i + window//2 + 1)
        window_vals = rr_intervals[start:end]
        local_median = np.median(window_vals)
        
        # Replace outliers (> 20% deviation from local median)
        if abs(rr_intervals[i] - local_median) > 0.2 * local_median:
            filtered[i] = local_median
    
    return filtered

def extract_rr_features(ecg_segment, fs=100):
    """
    Enhanced RR extraction with additional HRV features
    Returns 360 features: [rr_interp(180), ramp_interp(180)]
    """
    r_peaks = detect_r_peaks_hamilton(ecg_segment, fs)
    if len(r_peaks) < 5:
        return np.zeros(180, dtype=np.float32), np.zeros(180, dtype=np.float32)

    # RR intervals
    rr_sec = np.diff(r_peaks) / fs
    rr_sec = median_filter_rr(rr_sec)
    rr_times = r_peaks[1:] / fs
    
    # R-peak amplitudes
    ramp = ecg_segment[r_peaks[1:]]
    
    # Interpolate to 3 Hz (180 samples for 60s)
    targ_t = np.linspace(0, 60, 180)
    
    # Use better interpolation
    if len(rr_sec) >= 10:
        kind = 'cubic'
    elif len(rr_sec) >= 4:
        kind = 'quadratic'
    else:
        kind = 'linear'
    
    f_rr = interp1d(rr_times, rr_sec, kind=kind, bounds_error=False, 
                    fill_value='extrapolate')
    f_amp = interp1d(rr_times, ramp, kind=kind, bounds_error=False, 
                     fill_value='extrapolate')
    
    rr_out = f_rr(targ_t)
    ramp_out = f_amp(targ_t)
    
    # Clip to physiological range
    rr_out = np.clip(rr_out, 0.3, 2.0)
    
    # Robust normalization (using IQR instead of std)
    def robust_normalize(x):
        q1, q3 = np.percentile(x, [25, 75])
        iqr = q3 - q1
        if iqr > 1e-6:
            return np.clip((x - np.median(x)) / iqr, -5, 5)
        else:
            return x - np.median(x)
    
    rr_out = robust_normalize(rr_out).astype(np.float32)
    ramp_out = robust_normalize(ramp_out).astype(np.float32)
    
    return rr_out, ramp_out

# ----------------------------- Improved Model ---------------------------------

class MultiScaleBlock(nn.Module):
    """Multi-scale feature extraction block"""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # ensure total output channels == out_channels
        c1 = out_channels // 3
        c2 = out_channels // 3
        c3 = out_channels - c1 - c2  # remainder so c1 + c2 + c3 == out_channels

        self.conv1 = nn.Conv1d(in_channels, c1, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(in_channels, c2, kernel_size=5, padding=2)
        self.conv3 = nn.Conv1d(in_channels, c3, kernel_size=7, padding=3)
        self.bn = nn.BatchNorm1d(out_channels)
        
    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(x)
        x3 = self.conv3(x)
        out = torch.cat([x1, x2, x3], dim=1)
        return F.gelu(self.bn(out))


class EnhancedResBlock(nn.Module):
    """Enhanced residual block with squeeze-excitation"""
    def __init__(self, channels, kernel_size=7):
        super().__init__()
        self.conv1 = nn.Conv1d(channels, channels, kernel_size, padding=kernel_size//2, groups=channels)
        self.conv2 = nn.Conv1d(channels, channels, 1)
        self.norm = nn.BatchNorm1d(channels)
        
        # Squeeze-Excitation with stability improvements
        se_channels = max(8, channels // 8)  # Ensure at least 8 channels
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Conv1d(channels, se_channels, 1),
            nn.GELU(),
            nn.Conv1d(se_channels, channels, 1),
            nn.Sigmoid()
        )
        self.dropout = nn.Dropout(0.15)
        
    def forward(self, x):
        residual = x
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.norm(x)
        
        # Apply SE attention with safety check
        se_weight = self.se(x)
        # Clamp to prevent extreme values
        se_weight = torch.clamp(se_weight, 0.0, 1.0)
        x = x * se_weight
        
        x = self.dropout(x)
        return F.gelu(residual + x)

class ImprovedApneaNet(nn.Module):
    def __init__(self, d_model=256, n_blocks=10, dropout=0.15):
        super().__init__()
        
        # Ensure the three modality-channel outputs sum to d_model
        c1 = d_model // 3
        c2 = d_model // 3
        c3 = d_model - c1 - c2

        # ECG pathway (6000 samples) -> c1 channels
        self.ecg_stem = nn.Sequential(
            nn.Conv1d(1, c1, kernel_size=15, padding=7, stride=4),
            nn.BatchNorm1d(c1),
            nn.GELU(),
            nn.Conv1d(c1, c1, kernel_size=7, padding=3, stride=2),
            nn.BatchNorm1d(c1),
            nn.GELU(),
        )

        # RR interval pathway (180 samples @ 3Hz) -> c2 channels
        self.rr_stem = nn.Sequential(
            nn.Conv1d(1, c2, kernel_size=7, padding=3, stride=2),
            nn.BatchNorm1d(c2),
            nn.GELU(),
            nn.Conv1d(c2, c2, kernel_size=5, padding=2),
            nn.BatchNorm1d(c2),
            nn.GELU(),
        )

        # R-amplitude pathway (180 samples @ 3Hz) -> c3 channels
        self.ramp_stem = nn.Sequential(
            nn.Conv1d(1, c3, kernel_size=7, padding=3, stride=2),
            nn.BatchNorm1d(c3),
            nn.GELU(),
            nn.Conv1d(c3, c3, kernel_size=5, padding=2),
            nn.BatchNorm1d(c3),
            nn.GELU(),
        )
        
        # Multi-scale fusion
        self.fusion = MultiScaleBlock(d_model, d_model)
        
        # Enhanced residual blocks
        self.blocks = nn.ModuleList([
            EnhancedResBlock(d_model, kernel_size=7 if i % 2 == 0 else 11)
            for i in range(n_blocks)
        ])
        
        # Temporal attention with larger context
        self.temp_attn = nn.MultiheadAttention(d_model, num_heads=8, dropout=dropout, batch_first=True)
        self.temp_norm = nn.LayerNorm(d_model)
        self.temp_ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 2, d_model)
        )
        
        # Feature aggregation
        self.global_pool = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten()
        )
        
        # Enhanced classifier with multiple paths
        # Attention-based pooling
        self.attention_pool = nn.Sequential(
            nn.Conv1d(d_model, d_model // 4, 1),
            nn.GELU(),
            nn.Conv1d(d_model // 4, 1, 1),
            nn.Softmax(dim=2)
        )
        
        # Enhanced classifier with deeper network
        self.classifier = nn.Sequential(
            nn.Linear(d_model * 4, d_model * 3),
            nn.BatchNorm1d(d_model * 3),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 3, d_model * 2),
            nn.BatchNorm1d(d_model * 2),
            nn.GELU(),
            nn.Dropout(dropout * 0.7),
            nn.Linear(d_model * 2, d_model),
            nn.BatchNorm1d(d_model),
            nn.GELU(),
            nn.Dropout(dropout * 0.5),
            nn.Linear(d_model, 2)
        )
        
        # Initialize weights
        self._initialize_weights()
        
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
        
    def forward(self, ecg, rr, ramp):
        # Validate inputs
        if torch.isnan(ecg).any() or torch.isinf(ecg).any():
            ecg = torch.nan_to_num(ecg, nan=0.0, posinf=10.0, neginf=-10.0)
        if torch.isnan(rr).any() or torch.isinf(rr).any():
            rr = torch.nan_to_num(rr, nan=0.0, posinf=10.0, neginf=-10.0)
        if torch.isnan(ramp).any() or torch.isinf(ramp).any():
            ramp = torch.nan_to_num(ramp, nan=0.0, posinf=10.0, neginf=-10.0)
        
        # Process each modality
        ecg = ecg.transpose(1, 2)
        rr = rr.transpose(1, 2)
        ramp = ramp.transpose(1, 2)
        
        ecg_feat = self.ecg_stem(ecg)
        rr_feat = self.rr_stem(rr)
        ramp_feat = self.ramp_stem(ramp)
        
        # Align sequence lengths and concatenate
        target_len = min(ecg_feat.size(2), rr_feat.size(2), ramp_feat.size(2))
        ecg_feat = F.adaptive_avg_pool1d(ecg_feat, target_len)
        rr_feat = F.adaptive_avg_pool1d(rr_feat, target_len)
        ramp_feat = F.adaptive_avg_pool1d(ramp_feat, target_len)
        
        x = torch.cat([ecg_feat, rr_feat, ramp_feat], dim=1)
        
        # Multi-scale fusion
        x = self.fusion(x)
        
        # Residual blocks
        for block in self.blocks:
            x = block(x)
        
        # Multiple pooling strategies with safety
        # Attention-weighted pooling
        attn_weights = self.attention_pool(x)
        x_attn_pool = (x * attn_weights).sum(dim=2)
        
        # Multiple pooling strategies
        x_avg = F.adaptive_avg_pool1d(x, 1).squeeze(-1)
        x_max = F.adaptive_max_pool1d(x, 1).squeeze(-1)
        x_std = torch.std(x, dim=2) + 1e-8
        
        # Temporal attention
        x_seq = x.transpose(1, 2)
        x_attn, _ = self.temp_attn(x_seq, x_seq, x_seq)
        x_attn = self.temp_norm(x_attn + x_seq)
        x_attn = x_attn + self.temp_ffn(x_attn)
        x_attn = x_attn.mean(dim=1)
        
        # Combine all features
        x_combined = torch.cat([x_avg, x_max, x_std, x_attn], dim=1)
        
        # Final NaN check before classifier
        if torch.isnan(x_combined).any() or torch.isinf(x_combined).any():
            x_combined = torch.nan_to_num(x_combined, nan=0.0, posinf=10.0, neginf=-10.0)
        
        # Classify
        logits = self.classifier(x_combined)
        return logits

# --------------------------- Enhanced Dataset ---------------------------

class EnhancedApneaDataset(Dataset):
    """Dataset with R-R interval and R-amplitude extraction"""

    def __init__(self, data_dir: str, record_names: list = None, cache_dir: str = None,
                 segment_length: int = 6000, stride: int = 3000, split='train', augment=True):
        super().__init__()
        self.segment_length = int(segment_length)
        self.stride = int(stride)
        self.split = split
        self.augment = augment and (split == 'train')
        
        cache_dir = Path(cache_dir) if cache_dir else Path(data_dir)
        cache_file = cache_dir / f'apnea_enhanced_{split}_{segment_length}_{stride}.pt'

        if cache_file.exists():
            print(f"Loading cached {split} from {cache_file}")
            data = torch.load(cache_file)
            self.ecg_segments = data['ecg_segments']
            self.rr_segments = data['rr_segments']
            self.ramp_segments = data['ramp_segments']
            self.labels = data['labels']
        else:
            assert wfdb is not None, "wfdb required"
            assert record_names is not None, "record_names required"
            
            self.ecg_segments = []
            self.rr_segments = []
            self.ramp_segments = []
            self.labels = []
            self.data_dir = Path(data_dir)
            
            print(f"Processing {len(record_names)} records for {split} (with R-R extraction)...")
            for i, rec in enumerate(record_names):
                print(f"  [{i+1}/{len(record_names)}] {rec}...", end='\r')
                self._load_record(rec)
            
            if len(self.ecg_segments) == 0:
                raise RuntimeError("No segments loaded")
            
            self.ecg_segments = torch.tensor(np.stack(self.ecg_segments, axis=0), dtype=torch.float32)
            self.rr_segments = torch.tensor(np.stack(self.rr_segments, axis=0), dtype=torch.float32)
            self.ramp_segments = torch.tensor(np.stack(self.ramp_segments, axis=0), dtype=torch.float32)
            self.labels = torch.tensor(self.labels, dtype=torch.long)
            
            print(f"\nSaving cache to {cache_file}")
            torch.save({
                'ecg_segments': self.ecg_segments,
                'rr_segments': self.rr_segments,
                'ramp_segments': self.ramp_segments,
                'labels': self.labels
            }, cache_file)

        if self.ecg_segments.ndim == 2:
            self.ecg_segments = self.ecg_segments.unsqueeze(-1)
        if self.rr_segments.ndim == 2:
            self.rr_segments = self.rr_segments.unsqueeze(-1)
        if self.ramp_segments.ndim == 2:
            self.ramp_segments = self.ramp_segments.unsqueeze(-1)

        print(f"{split.capitalize()}: {len(self.ecg_segments)} segments, "
              f"Class: {Counter(self.labels.tolist())}")

    def _load_record(self, record_name: str):
        try:
            rec = wfdb.rdrecord(str(self.data_dir / record_name))
            sig = rec.p_signal[:, 0].astype(np.float32)
            if np.isnan(sig).any():
                nans = np.isnan(sig)
                sig[nans] = np.interp(np.flatnonzero(nans), np.flatnonzero(~nans), sig[~nans])

            ann = wfdb.rdann(str(self.data_dir / record_name), 'apn')
            n_min = len(sig) // 6000
            mins = np.zeros(n_min, dtype=int)
            for samp, sym in zip(ann.sample, ann.symbol):
                if sym == 'A':
                    m = samp // 6000
                    if m < n_min:
                        mins[m] = 1

            for start in range(0, len(sig) - self.segment_length + 1, self.stride):
                seg = sig[start:start + self.segment_length]
                
                # Normalise ECG with safety
                seg_std = seg.std()
                if seg_std > 1e-6:
                    seg = (seg - seg.mean()) / seg_std
                else:
                    seg = seg - seg.mean()
                seg = np.clip(seg, -10, 10)

                rr, ramp = extract_rr_features(seg, fs=100)

                minute = start // 6000
                if minute < len(mins):
                    self.ecg_segments.append(seg)
                    self.rr_segments.append(rr)
                    self.ramp_segments.append(ramp)
                    self.labels.append(int(mins[minute]))
        except Exception as e:
            print(f'\nSkip {record_name}: {e}')
    
    def _augment(self, ecg, rr, ramp):
        ecg, rr, ramp = map(lambda x: x.numpy() if torch.is_tensor(x) else x, (ecg, rr, ramp))

    # Flatten to 1D for augmentation, will reshape back at the end
        ecg = ecg.squeeze()
        rr = rr.squeeze()
        ramp = ramp.squeeze()

    # Time shift (more aggressive)
        if np.random.rand() < 0.4:
            shift = np.random.randint(-300, 300)
            ecg = np.roll(ecg, shift)

    # Gaussian noise (adjusted per signal type)
        if np.random.rand() < 0.6:
            ecg += np.random.normal(0, 0.03, ecg.shape).astype(np.float32)
            rr += np.random.normal(0, 0.02, rr.shape).astype(np.float32)
            ramp += np.random.normal(0, 0.02, ramp.shape).astype(np.float32)

    # Amplitude scaling
        if np.random.rand() < 0.4:
            ecg_scale = np.random.uniform(0.85, 1.15)
            ramp_scale = np.random.uniform(0.9, 1.1)
            ecg *= ecg_scale
            ramp *= ramp_scale

    # Baseline wander simulation for ECG
        if np.random.rand() < 0.3:
            baseline = np.sin(np.linspace(0, 2*np.pi, len(ecg))) * np.random.uniform(0.05, 0.15)
            ecg += baseline

    # RR interval perturbation (simulating heart rate variability changes)
        if np.random.rand() < 0.3:
            rr_noise = scipy_signal.savgol_filter(np.random.randn(len(rr)), 15, 3) * 0.05
            rr += rr_noise

    # Clip after augmentation
        ecg = np.clip(ecg, -10, 10)
        rr = np.clip(rr, -10, 10)
        ramp = np.clip(ramp, -10, 10)

    # Convert back to torch tensors
        return tuple(map(torch.from_numpy, (ecg, rr, ramp)))
            
    def __len__(self):
        return self.ecg_segments.shape[0]

    def __getitem__(self, idx):
        ecg = self.ecg_segments[idx]
        rr = self.rr_segments[idx]
        ramp = self.ramp_segments[idx]
        label = self.labels[idx]
        
        if self.augment:
            ecg, rr, ramp = self._augment(ecg, rr, ramp)
        
        # Ensure correct shape
        if ecg.ndim == 1:
            ecg = ecg.unsqueeze(-1)
        if rr.ndim == 1:
            rr = rr.unsqueeze(-1)
        if ramp.ndim == 1:
            ramp = ramp.unsqueeze(-1)
        
        return ecg, rr, ramp, label

# -------------------------- Training ------------------------

def compute_class_weights(labels_tensor):
    counts = Counter(labels_tensor.tolist())
    total = sum(counts.values())
    num_classes = len(counts)
    weights = [total / (num_classes * counts.get(i, 1)) for i in range(num_classes)]
    return torch.tensor(weights, dtype=torch.float32)

class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0, weight=None):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.weight = weight
        
    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none', weight=self.weight)
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        return focal_loss.mean()

def train_epoch(model, dataloader, criterion, optimizer, scheduler, device, epoch, scaler=None):
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    
    num_batches = len(dataloader)
    print_freq = max(1, num_batches // 15)
    start_time = time.time()

    for batch_idx, (ecg, rr, ramp, target) in enumerate(dataloader, 1):
        ecg = ecg.to(device, non_blocking=True)
        rr = rr.to(device, non_blocking=True)
        ramp = ramp.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        with torch.amp.autocast('cuda', enabled=(device.type == 'cuda')):
            output = model(ecg, rr, ramp)
            loss = criterion(output, target)
        
        if torch.isnan(loss) or torch.isinf(loss):
            print(f"\nWARNING: NaN/Inf loss, skipping batch {batch_idx}")
            continue

        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
        
        scheduler.step()

        total_loss += loss.item()
        pred = output.argmax(dim=1)
        correct += pred.eq(target).sum().item()
        total += target.size(0)
        
        if batch_idx % print_freq == 0 or batch_idx == num_batches:
            curr_acc = 100.0 * correct / total
            curr_loss = total_loss / batch_idx
            speed = batch_idx / (time.time() - start_time)
            eta = (num_batches - batch_idx) / speed if speed > 0 else 0
            
            print(f"  Ep {epoch} [{batch_idx:4d}/{num_batches}] "
                  f"Loss: {curr_loss:.4f} Acc: {curr_acc:.2f}% "
                  f"({speed:.1f} b/s, ETA: {eta:.0f}s)", end='\r')

    print()
    avg_loss = total_loss / len(dataloader)
    accuracy = 100.0 * correct / total if total > 0 else 0.0
    return avg_loss, accuracy

def validate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_targets = []
    all_probs = []

    with torch.no_grad():
        for ecg, rr, ramp, target in dataloader:
            ecg = ecg.to(device, non_blocking=True)
            rr = rr.to(device, non_blocking=True)
            ramp = ramp.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            
            output = model(ecg, rr, ramp)
            loss = criterion(output, target)

            total_loss += loss.item()
            probs = F.softmax(output, dim=1)[:, 1]
            pred = output.argmax(dim=1)

            correct += pred.eq(target).sum().item()
            total += target.size(0)

            all_preds.extend(pred.cpu().tolist())
            all_targets.extend(target.cpu().tolist())
            all_probs.extend(probs.cpu().tolist())

    avg_loss = total_loss / len(dataloader)
    accuracy = 100.0 * correct / total if total > 0 else 0.0
    
    precision = precision_score(all_targets, all_preds, zero_division=0)
    recall = recall_score(all_targets, all_preds, zero_division=0)
    f1 = f1_score(all_targets, all_preds, zero_division=0)
    
    # Calculate specificity and sensitivity
    tn, fp, fn, tp = confusion_matrix(all_targets, all_preds).ravel()
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0
    
    return avg_loss, accuracy, np.array(all_preds), np.array(all_targets), np.array(all_probs), precision, recall, f1, sensitivity, specificity

# ------------------------------ Main ------------------------------------

def main(args):
    set_seed(args.seed)

    DATA_DIR = Path(args.data_dir)
    if not DATA_DIR.exists():
        raise FileNotFoundError(f"Data directory not found: {DATA_DIR}")

    record_files = list(DATA_DIR.glob('*.hea'))
    all_records = [f.stem for f in record_files]
    valid_records = [rec for rec in all_records 
                    if (DATA_DIR / (rec + '.apn')).exists() and not rec.endswith('er')]
    
    if len(valid_records) == 0:
        raise RuntimeError("No valid records found")

    print(f"Found {len(valid_records)} valid records")

    import random
    valid_records_shuffled = valid_records.copy()
    random.Random(args.seed).shuffle(valid_records_shuffled)
    split_idx = int(len(valid_records_shuffled) * args.train_split)
    train_records = valid_records_shuffled[:split_idx]
    val_records = valid_records_shuffled[split_idx:]
    print(f"Train: {len(train_records)}, Val: {len(val_records)}\n")

    cache_dir = args.cache_dir if args.cache_dir else str(DATA_DIR)
    train_dataset = EnhancedApneaDataset(
        str(DATA_DIR), train_records, cache_dir,
        args.segment_length, args.stride, 'train', augment=True
    )
    val_dataset = EnhancedApneaDataset(
        str(DATA_DIR), val_records, cache_dir,
        args.segment_length, args.stride, 'val', augment=False
    )

    num_workers = 2 if str(DATA_DIR).startswith('/kaggle') else 4

    train_loader = DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True
    )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")
    if device.type == 'cuda':
        try:
            print(f"GPU: {torch.cuda.get_device_name(0)}")
        except Exception:
            pass
        torch.cuda.empty_cache()

    model = ImprovedApneaNet(
        d_model=args.d_model, n_blocks=args.n_blocks, dropout=args.dropout
    ).to(device)

    print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}\n")

    class_weights = compute_class_weights(train_dataset.labels).to(device)
    print(f"Class weights: {class_weights}")

    # Use Focal Loss for better class imbalance handling
        # stable, NaN-free loss with class weights
    criterion = FocalLoss(alpha=0.25, gamma=2.0, weight=class_weights)

    optimizer = torch.optim.AdamW(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay
    )

    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=args.lr, epochs=args.epochs,
        steps_per_epoch=len(train_loader), pct_start=0.2
    )

    scaler = torch.amp.GradScaler() if device.type == 'cuda' else None

    best_val_acc = 0.0
    best_val_f1 = 0.0
    no_improve = 0

    print("\nStarting training...")
    print("="*100)

    # class FocalLoss(nn.Module):
    #     def __init__(self, alpha=0.25, gamma=2.0, weight=None):
    #         super().__init__()
    #         self.alpha = alpha
    #         self.gamma = gamma
    #         self.weight = weight
        
    #     def forward(self, inputs, targets):
    #         ce_loss = F.cross_entropy(inputs, targets, reduction='none', weight=self.weight)
    #         pt = torch.exp(-ce_loss)
    #         focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
    #         return focal_loss.mean()

    for epoch in range(1, args.epochs + 1):
        epoch_start = time.time()

        train_loss, train_acc = train_epoch(
            model, train_loader, criterion, optimizer, scheduler, device, epoch, scaler
        )

        val_loss, val_acc, _, val_targets, val_probs, precision, recall, f1, sensitivity, specificity = validate(
            model, val_loader, criterion, device
        )

        try:
            auc = roc_auc_score(val_targets, val_probs)
        except Exception:
            auc = 0.0

        epoch_time = time.time() - epoch_start

        print(f"Epoch {epoch:2d}/{args.epochs} ({epoch_time:.1f}s)")
        print(f"  Train: Loss={train_loss:.4f}, Acc={train_acc:.2f}%")
        print(f"  Val:   Loss={val_loss:.4f}, Acc={val_acc:.2f}%, AUC={auc:.4f}")
        print(f"         Prec={precision:.3f}, Rec={recall:.3f}, F1={f1:.3f}")
        print(f"         Sensitivity={sensitivity:.3f}, Specificity={specificity:.3f}")

        if val_acc > best_val_acc or (val_acc >= best_val_acc and f1 > best_val_f1):
            best_val_acc = val_acc
            best_val_f1 = f1
            no_improve = 0
            torch.save({
                'epoch': epoch, 'model_state_dict': model.state_dict(),
                'val_acc': val_acc, 'val_auc': auc, 'val_f1': f1,
                'sensitivity': sensitivity, 'specificity': specificity
            }, args.best_model_path)
            print(f"  ✓ Best! (Acc={val_acc:.2f}%, F1={f1:.3f}, Sens={sensitivity:.3f}, Spec={specificity:.3f})")
        else:
            no_improve += 1
            print(f"  No improvement ({no_improve}/{args.patience})")

        print("-"*100)

        if no_improve >= args.patience:
            print(f"\nEarly stop at epoch {epoch}")
            break

    print(f"\n{'='*100}")
    print(f"BEST - Accuracy: {best_val_acc:.2f}%, F1: {best_val_f1:.3f}")
    print(f"{'='*100}")


if __name__ == '__main__':
    kaggle_data = '/kaggle/input/vincent2/apnea-ecg-database-1.0.0'
    colab_data = '/content/apnea-ecg/1.0.0'
    if Path(kaggle_data).exists():
        default_data_dir = kaggle_data
        default_cache_dir = '/kaggle/working'
        default_model_path = '/kaggle/working/best_model.pth'
    elif Path(colab_data).exists():
        default_data_dir = colab_data
        default_cache_dir = '/content'
        default_model_path = '/content/best_model.pth'
    else:
        default_data_dir = None
        default_cache_dir = None
        default_model_path = 'best_model.pth'

    parser = argparse.ArgumentParser()
    parser.add_argument('--data-dir', type=str, default=default_data_dir)
    parser.add_argument('--cache-dir', type=str, default=default_cache_dir)
    parser.add_argument('--segment-length', type=int, default=6000)
    parser.add_argument('--stride', type=int, default=2000)  # 67% overlap - more training data
    parser.add_argument('--batch-size', type=int, default=24)  # Smaller batch for better gradients
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--lr', type=float, default=8e-5)
    parser.add_argument('--weight-decay', type=float, default=0.01)
    parser.add_argument('--d-model', type=int, default=256)
    parser.add_argument('--n-blocks', type=int, default=10)
    parser.add_argument('--dropout', type=float, default=0.2)
    parser.add_argument('--train-split', type=float, default=0.8)
    parser.add_argument('--patience', type=int, default=20)
    parser.add_argument('--best-model-path', type=str, default=default_model_path)
    parser.add_argument('--seed', type=int, default=42)

    args, _ = parser.parse_known_args()

    if args.data_dir is None:
        raise SystemExit("ERROR: Dataset not found")

    print("="*100)
    print("ENHANCED MODEL WITH R-R INTERVALS (Target: 90%+ Accuracy)")
    print("="*100)
    print(f"  Data:       {args.data_dir}")
    print(f"  Segment:    {args.segment_length} samples (60s), stride={args.stride}")
    print(f"  Features:   ECG + R-R Intervals + R-peak Amplitudes")
    print(f"  Batch:      {args.batch_size}")
    print(f"  Epochs:     {args.epochs}")
    print(f"  Model:      d_model={args.d_model}, blocks={args.n_blocks}")
    print(f"  Optimizer:  AdamW (lr={args.lr}, wd={args.weight_decay})")
    print(f"  Loss:       Focal Loss (alpha=0.25, gamma=2.0)")
    print("="*100 + "\n")

    main(args)

ENHANCED MODEL WITH R-R INTERVALS (Target: 90%+ Accuracy)
  Data:       /kaggle/input/vincent2/apnea-ecg-database-1.0.0
  Segment:    6000 samples (60s), stride=2000
  Features:   ECG + R-R Intervals + R-peak Amplitudes
  Batch:      24
  Epochs:     100
  Model:      d_model=256, blocks=10
  Optimizer:  AdamW (lr=8e-05, wd=0.01)
  Loss:       Focal Loss (alpha=0.25, gamma=2.0)

Found 43 valid records
Train: 34, Val: 9

Processing 34 records for train (with R-R extraction)...
  [34/34] a20....
Saving cache to /kaggle/working/apnea_enhanced_train_6000_2000.pt
Train: 50104 segments, Class: Counter({0: 31658, 1: 18446})
Processing 9 records for val (with R-R extraction)...
  [9/9] a01....
Saving cache to /kaggle/working/apnea_enhanced_val_6000_2000.pt
Val: 12933 segments, Class: Counter({0: 7032, 1: 5901})
Device: cuda
GPU: Tesla P100-PCIE-16GB
Parameters: 3,170,955

Class weights: tensor([0.7913, 1.3581], device='cuda:0')

Starting training...
  Ep 1 [2088/2088] Loss: 0.1182 Acc: 45.10% 