In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file found under the Kaggle input directory.
for input_path in (
    os.path.join(folder, name)
    for folder, _subdirs, names in os.walk('/kaggle/input')
    for name in names
):
    print(input_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
"""
Enhanced PhysioNet ECG Image Digitization Solution
Building upon the successful baseline with advanced improvements

Key Improvements:
1. Deeper ResNet-style architecture with better feature extraction
2. Multi-scale fusion for capturing both fine and coarse details
3. Attention mechanisms at multiple levels
4. Advanced signal post-processing with peak detection
5. Better normalization and denormalization strategies
6. Improved augmentation pipeline
7. Cross-validation friendly architecture
8. Enhanced loss function with dynamic weighting
9. Better handling of edge cases and signal alignment
10. Optimized hyperparameters based on convergence patterns
11. ADDED: Simple statistical post-processing for cleaner signals
"""

import os
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from scipy import signal as scipy_signal
from scipy.ndimage import gaussian_filter1d
from scipy.signal import find_peaks, medfilt
import albumentations as A
from albumentations.pytorch import ToTensorV2
import warnings
warnings.filterwarnings('ignore')

# ================================
# Configuration
# ================================
class Config:
    """Class-level settings read by the dataset, model, training and inference code."""

    # Number of CUDA devices visible when the class body is executed.
    # Removed again at the end of the class body so it does not become an attribute.
    _n_gpus = torch.cuda.device_count()

    # --- Hardware ---
    use_multi_gpu = _n_gpus > 1
    device_ids = [0] if not use_multi_gpu else list(range(_n_gpus))

    # --- Training (optimized based on convergence patterns) ---
    batch_size = 8 if _n_gpus <= 1 else 16
    epochs = 75  # Increased from 35 to utilize more time budget
    lr = 1.5e-4
    min_lr = 5e-7
    num_workers = 4

    # --- Image properties ---
    # img_* is the size of the blank fallback image; resize_* is the network input size.
    img_width = 2200
    img_height = 1700
    resize_width = 1024
    resize_height = 768

    # --- Signal properties ---
    base_fs = 500                # sampling rate used when a row carries no 'fs' value
    lead_ii_duration = 10.0      # seconds kept for lead II
    other_leads_duration = 2.5   # seconds kept for every other lead

    # --- Training optimizations ---
    warmup_epochs = 4
    gradient_clip = 0.5
    use_amp = True
    label_smoothing = 0.01

    # --- Test-time augmentation ---
    use_tta = True
    tta_count = 5

    # --- Signal processing ---
    use_peak_detection = True
    use_wavelet_denoising = False  # Optional advanced feature

    del _n_gpus


# Shared settings instance referenced throughout the file.
config = Config()

# ================================
# Advanced Signal Processing
# ================================
class SignalProcessor:
    """Stateless helpers for filtering, normalising and aligning 1-D signals."""

    @staticmethod
    def remove_baseline_wander(signal, fs=500):
        """Remove baseline wander with a zero-phase 0.5 Hz high-pass filter.

        Args:
            signal: 1-D array of samples.
            fs: Sampling rate in Hz, passed to ``scipy.signal.butter``.

        Returns:
            The filtered signal. Inputs that are too short for
            ``sosfiltfilt`` to pad are returned unchanged.
        """
        if len(signal) < 4:
            return signal

        # 3rd-order Butterworth high-pass at 0.5 Hz, as second-order sections.
        sos = scipy_signal.butter(3, 0.5, btype='highpass', fs=fs, output='sos')

        # sosfiltfilt rejects inputs that are not strictly longer than its
        # default padlen (formula taken from the SciPy documentation). Without
        # this check a short signal raised ValueError instead of passing through.
        zero_b2 = int((sos[:, 2] == 0).sum())
        zero_a2 = int((sos[:, 5] == 0).sum())
        padlen = 3 * (2 * sos.shape[0] + 1 - min(zero_b2, zero_a2))
        if len(signal) <= padlen:
            return signal

        return scipy_signal.sosfiltfilt(sos, signal)

    @staticmethod
    def denoise_signal(signal, window_length=5):
        """Median-filter ``signal`` with a window of ``window_length`` samples.

        Signals shorter than the window are returned unchanged.
        ``scipy.signal.medfilt`` requires an odd ``window_length``.
        """
        if len(signal) < window_length:
            return signal
        return medfilt(signal, kernel_size=window_length)

    @staticmethod
    def normalize_signal(signal, method='robust'):
        """Normalise ``signal`` with the chosen method.

        Args:
            signal: 1-D NumPy array.
            method: ``'robust'`` (subtract the median, divide by the IQR),
                ``'standard'`` (z-score) or ``'minmax'`` (rescale to [-1, 1]).
                Any other value leaves the signal untouched.

        Returns:
            The normalised signal. It is returned unchanged when the spread
            used as the divisor is not larger than 1e-6.
        """
        if method == 'robust':
            q25, q75 = np.percentile(signal, [25, 75])
            iqr = q75 - q25
            if iqr > 1e-6:
                signal = (signal - np.median(signal)) / iqr
        elif method == 'standard':
            if signal.std() > 1e-6:
                signal = (signal - signal.mean()) / signal.std()
        elif method == 'minmax':
            signal_min, signal_max = signal.min(), signal.max()
            if signal_max - signal_min > 1e-6:
                signal = (signal - signal_min) / (signal_max - signal_min)
                signal = signal * 2 - 1  # Scale to [-1, 1]

        return signal

    @staticmethod
    def align_signals(pred, target, max_shift=100):
        """Shift ``pred`` in time so that it best matches ``target``.

        The lag is the arg-max of the cross-correlation, limited to
        ``[-max_shift, max_shift]``. Samples vacated by the shift are filled
        by repeating the nearest edge value.

        Args:
            pred: 1-D array to be shifted.
            target: 1-D reference array.
            max_shift: Largest shift, in samples, that will be applied.

        Returns:
            The shifted copy of ``pred``. ``pred`` itself is returned when
            the lengths differ, the input is empty, or the best lag is zero.
        """
        n = len(pred)
        if n != len(target) or n == 0:
            return pred

        # np.correlate(a, v)[k] = sum_n a[n + k] * v[n]; with mode='same' and
        # equal lengths, index n // 2 corresponds to lag 0.
        correlation = np.correlate(target, pred, mode='same')
        lag = int(np.argmax(correlation)) - n // 2
        lag = int(np.clip(lag, -max_shift, max_shift))

        if lag > 0:
            # target[n + lag] matches pred[n], so pred must be delayed by
            # `lag` samples. The original code advanced it instead.
            return np.pad(pred[:-lag], (lag, 0), mode='edge')
        if lag < 0:
            # pred is late relative to target: advance it by |lag| samples.
            return np.pad(pred[-lag:], (0, -lag), mode='edge')
        return pred

# ================================
# Advanced Image Preprocessing
# ================================
class ECGImageProcessor:
    """Classical (OpenCV) preprocessing for scanned ECG images."""

    def __init__(self):
        # NOTE(review): these look like (start, end) fractions of the image
        # height per lead; nothing in the visible code reads them - confirm.
        self.lead_row_positions = {
            'I': (0, 0.15), 'II': (0.15, 0.30), 'III': (0.30, 0.45),
            'aVR': (0.45, 0.60), 'aVL': (0.60, 0.75), 'aVF': (0.75, 0.90),
            'V1': (0, 0.15), 'V2': (0.15, 0.30), 'V3': (0.30, 0.45),
            'V4': (0.45, 0.60), 'V5': (0.60, 0.75), 'V6': (0.75, 0.90),
        }

        # NOTE(review): presumably (start, end) fractions of the image width;
        # limb leads map to the left band, V1-V6 to the right band - confirm.
        self.lead_columns = {
            'I': (0.05, 0.48), 'II': (0.05, 0.48), 'III': (0.05, 0.48),
            'aVR': (0.05, 0.48), 'aVL': (0.05, 0.48), 'aVF': (0.05, 0.48),
            'V1': (0.52, 0.95), 'V2': (0.52, 0.95), 'V3': (0.52, 0.95),
            'V4': (0.52, 0.95), 'V5': (0.52, 0.95), 'V6': (0.52, 0.95),
        }

        self.signal_processor = SignalProcessor()

    def preprocess_image(self, image):
        """Enhance an ECG image and produce a binarised version of it.

        Pipeline: grayscale -> CLAHE -> bilateral filter -> subtraction of
        long horizontal/vertical structures (morphological opening) ->
        non-local-means denoising -> inverted adaptive threshold.

        Args:
            image: ``(H, W)`` grayscale, ``(H, W, 1)``, ``(H, W, 3)`` RGB or
                ``(H, W, 4)`` RGBA array. NOTE(review): CLAHE and the filters
                below are applied directly, so an 8-bit image is assumed.

        Returns:
            Tuple ``(binary, enhanced)``: the thresholded image and the
            CLAHE-enhanced grayscale image.
        """
        # Already-grayscale input used to crash in cvtColor(RGB2GRAY); accept
        # it directly and only convert genuine colour images.
        if image.ndim == 2:
            gray = image
        elif image.shape[2] == 1:
            gray = np.ascontiguousarray(image[:, :, 0])
        else:
            if image.shape[2] == 4:
                image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
            gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

        # Local contrast enhancement (single CLAHE pass on 8x8 tiles).
        clahe = cv2.createCLAHE(clipLimit=2.5, tileGridSize=(8, 8))
        enhanced = clahe.apply(gray)

        # Edge-preserving smoothing.
        filtered = cv2.bilateralFilter(enhanced, 7, 75, 75)

        # Grid removal: opening with a long thin kernel keeps only structures
        # at least that long; subtracting the result removes them. Repeated
        # for three kernel lengths, first horizontally, then vertically.
        grid_removed = filtered.copy()

        for kernel_width in [30, 40, 50]:
            h_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (kernel_width, 1))
            h_lines = cv2.morphologyEx(grid_removed, cv2.MORPH_OPEN, h_kernel, iterations=1)
            grid_removed = cv2.subtract(grid_removed, h_lines)

        for kernel_height in [30, 40, 50]:
            v_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, kernel_height))
            v_lines = cv2.morphologyEx(grid_removed, cv2.MORPH_OPEN, v_kernel, iterations=1)
            grid_removed = cv2.subtract(grid_removed, v_lines)

        # Non-local-means denoising.
        denoised = cv2.fastNlMeansDenoising(grid_removed, h=8)

        # Inverted adaptive threshold (block size 15, offset 3).
        binary = cv2.adaptiveThreshold(
            denoised, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
            cv2.THRESH_BINARY_INV, 15, 3
        )

        return binary, enhanced

# ================================
# Dataset with Advanced Features
# ================================
class ECGDataset(Dataset):
    """Dataset yielding ECG images and, for training, their 12-lead targets.

    Training items are ``(image_tensor, targets, base_id)``. ``targets`` is a
    float32 tensor with one row per lead. Each row holds that lead resampled
    to ``config.base_fs`` and zero-padded on the right to the longest lead,
    so rows can be stacked and samples can be batched.

    Test items are ``(image_tensor, base_id, fs)``.
    """

    def __init__(self, df, data_dir, transform=None, is_train=True, use_mixup=False):
        """
        Args:
            df: DataFrame with an ``id`` column and, optionally, ``fs``.
            data_dir: Directory holding the images (and CSVs for training).
            transform: Optional callable invoked as ``transform(image=img)``
                that returns a dict with an ``'image'`` entry.
            is_train: Selects the training file layout and return tuple.
            use_mixup: Stored as ``self.use_mixup`` (only when training);
                no mixup is implemented in this class.
        """
        self.df = df.reset_index(drop=True)
        self.data_dir = data_dir
        self.transform = transform
        self.is_train = is_train
        self.use_mixup = use_mixup and is_train
        self.processor = ECGImageProcessor()
        self.signal_processor = SignalProcessor()
        self.leads = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF',
                      'V1', 'V2', 'V3', 'V4', 'V5', 'V6']

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _read_rgb(img_path):
        """Return the image at ``img_path`` as an RGB array, or None if it cannot be read."""
        if not os.path.exists(img_path):
            return None
        img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
        if img is None:
            return None
        # OpenCV decodes colour images as BGR(A); convert every layout to RGB.
        if len(img.shape) == 2:
            return cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        if img.shape[2] == 4:
            return cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def load_image(self, base_id):
        """Load the image for ``base_id``, falling back to a blank white image.

        For training the first readable ``<id>/<id>-<segment>.png`` is used;
        for test the file is ``<id>.png``.
        """
        if self.is_train:
            segments = ['0001', '0003', '0004', '0005', '0006', '0009', '0010', '0011', '0012']
            candidates = [
                os.path.join(self.data_dir, str(base_id), f"{base_id}-{seg}.png")
                for seg in segments
            ]
        else:
            candidates = [os.path.join(self.data_dir, f"{base_id}.png")]

        for img_path in candidates:
            img = self._read_rgb(img_path)
            if img is not None:
                return img

        return np.ones((config.img_height, config.img_width, 3), dtype=np.uint8) * 255

    def _image_to_tensor(self, image):
        """Apply ``self.transform`` or, without one, resize and scale to [0, 1]."""
        if self.transform:
            return self.transform(image=image)['image']
        image_resized = cv2.resize(image, (config.resize_width, config.resize_height))
        return torch.from_numpy(image_resized.transpose(2, 0, 1)).float() / 255.0

    def _load_targets(self, base_id, fs):
        """Read ``<id>/<id>.csv`` and build the ``(n_leads, width)`` target array.

        Returns an all-zero array, after printing a warning, when the CSV
        cannot be read or processed.
        """
        n_leads = len(self.leads)
        # Every row is as wide as the longest decoder output so rows stack.
        row_width = int(config.base_fs * max(config.lead_ii_duration,
                                             config.other_leads_duration))
        targets = np.zeros((n_leads, row_width), dtype=np.float32)
        csv_path = os.path.join(self.data_dir, str(base_id), f"{base_id}.csv")

        try:
            signals_df = pd.read_csv(csv_path)

            for lead_idx, lead in enumerate(self.leads):
                duration = (config.lead_ii_duration if lead == 'II'
                            else config.other_leads_duration)
                native_len = int(fs * duration)             # samples at the recording rate
                model_len = int(config.base_fs * duration)  # samples at config.base_fs

                signal = signals_df[lead].values

                # Handle NaN values
                signal = pd.Series(signal).interpolate(
                    method='linear', limit_direction='both').fillna(0).values

                # Crop or edge-pad to the expected duration.
                if len(signal) > native_len:
                    signal = signal[:native_len]
                elif len(signal) < native_len:
                    signal = np.pad(signal, (0, native_len - len(signal)), mode='edge')

                signal = self.signal_processor.denoise_signal(signal)
                signal = self.signal_processor.remove_baseline_wander(signal, fs)

                # Resample recordings whose rate differs from config.base_fs so
                # that every sample yields rows of identical length.
                if native_len != model_len:
                    signal = np.interp(np.linspace(0, 1, model_len),
                                       np.linspace(0, 1, native_len),
                                       signal)

                # Lead II is four times longer than the others. Stacking the
                # raw arrays always raised, and the old catch-all `except`
                # then replaced every target with zeros. Writing into a
                # padded row avoids that.
                targets[lead_idx, :model_len] = signal

        except (OSError, KeyError, ValueError) as exc:
            # Best-effort fallback kept from the original, but no longer silent.
            print(f"Warning: could not load targets for {base_id}: {exc!r}")
            targets = np.zeros((n_leads, row_width), dtype=np.float32)

        return targets

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        base_id = row['id']
        fs = row.get('fs', config.base_fs)

        image = self.load_image(base_id)

        if self.is_train:
            target_signals = self._load_targets(base_id, fs)
            image_tensor = self._image_to_tensor(image)
            return image_tensor, torch.from_numpy(target_signals), base_id

        return self._image_to_tensor(image), base_id, fs

# ================================
# Enhanced Model Architecture
# ================================
class CBAM(nn.Module):
    """Convolutional Block Attention Module: a channel gate followed by a spatial gate."""

    def __init__(self, channels, reduction=16):
        super().__init__()
        hidden = channels // reduction

        # Channel attention: one bottleneck MLP shared by both pooled descriptors.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

        # Spatial attention: 7x7 conv over the [mean, max] channel statistics.
        self.conv_spatial = nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False)

    def forward(self, x):
        """Return ``x`` scaled by its channel gate and then by its spatial gate."""
        # Channel gate, shape (B, C, 1, 1).
        pooled_avg = self.avg_pool(x).flatten(1)
        pooled_max = self.max_pool(x).flatten(1)
        channel_gate = self.sigmoid(self.fc(pooled_avg) + self.fc(pooled_max))
        x = x * channel_gate[:, :, None, None]

        # Spatial gate, shape (B, 1, H, W).
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True).values
        stats = torch.cat([channel_mean, channel_max], dim=1)
        spatial_gate = self.sigmoid(self.conv_spatial(stats))

        return x * spatial_gate

class ResidualBlock(nn.Module):
    """Two 3x3 conv/BN layers with CBAM attention and a residual connection."""

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.cbam = CBAM(out_channels)

        # The skip path is projected with a 1x1 conv only when the main path
        # changes the tensor shape; otherwise it is the identity.
        shape_changes = stride != 1 or in_channels != out_channels
        if shape_changes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(cbam(main_path(x)) + shortcut(x))``."""
        main = self.conv1(x)
        main = F.relu(self.bn1(main))
        main = self.bn2(self.conv2(main))
        main = self.cbam(main)
        return F.relu(main + self.shortcut(x))

class MultiScaleFusion(nn.Module):
    """Fuse parallel 1x1, 3x3 and 5x5 conv branches with a 1x1 convolution."""

    def __init__(self, in_channels, out_channels):
        super().__init__()

        def conv_bn_relu(kernel_size):
            # "Same" padding keeps the spatial size identical across branches.
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size, padding=kernel_size // 2),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            )

        self.branch1 = conv_bn_relu(1)
        self.branch2 = conv_bn_relu(3)
        self.branch3 = conv_bn_relu(5)
        self.fusion = nn.Conv2d(out_channels * 3, out_channels, 1)

    def forward(self, x):
        """Concatenate the three branch outputs on the channel axis and mix them."""
        branches = (self.branch1, self.branch2, self.branch3)
        stacked = torch.cat([branch(x) for branch in branches], dim=1)
        return self.fusion(stacked)

class AdvancedTCN(nn.Module):
    """Dilated 1-D convolutional decoder from a feature vector to a signal.

    NOTE(review): ``forward`` stretches a length-1 sequence to
    ``output_length``, so every time step starts from the same feature
    vector. Only the zero padding of the convolutions can make the output
    vary along time. Confirm this is the intended design.
    """

    def __init__(self, input_dim, output_length, num_channels=(512, 512, 384, 256)):
        """
        Args:
            input_dim: Size of the input feature vector.
            output_length: Number of samples in the produced signal.
            num_channels: Output channels of each dilated conv stage; stage
                ``i`` uses dilation ``2 ** i``. The default is a tuple rather
                than a list so it cannot be mutated between instances.
        """
        super().__init__()

        stages = []
        in_channels = input_dim
        for i, out_channels in enumerate(num_channels):
            dilation = 2 ** i
            stages.append(nn.Sequential(
                # padding == dilation keeps the length fixed for kernel_size=3.
                nn.Conv1d(in_channels, out_channels, kernel_size=3,
                          padding=dilation, dilation=dilation),
                nn.BatchNorm1d(out_channels),
                nn.ReLU(inplace=True),
                nn.Dropout(0.2)
            ))
            in_channels = out_channels

        self.tcn = nn.Sequential(*stages)
        # Despite its name this head does not change the length; it reduces
        # the channels to a single output channel.
        self.upsampler = nn.Sequential(
            nn.Conv1d(num_channels[-1], num_channels[-1], kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv1d(num_channels[-1], 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv1d(128, 1, kernel_size=1)
        )
        self.output_length = output_length

    def forward(self, x):
        """Map ``(batch, input_dim)`` features to ``(batch, output_length)`` signals."""
        x = x.unsqueeze(-1)  # (batch, input_dim, 1)
        x = F.interpolate(x, size=self.output_length, mode='linear', align_corners=False)
        x = self.tcn(x)
        x = self.upsampler(x)  # (batch, 1, output_length)
        return x.squeeze(1)

class EnhancedECGNet(nn.Module):
    """CNN encoder with one AdvancedTCN decoder per ECG lead.

    ``forward`` returns a dict mapping each lead name to a
    ``(batch, n_samples)`` tensor.
    """

    def __init__(self):
        super().__init__()

        # Initial stem: 7x7 stride-2 conv followed by a stride-2 max pool.
        self.stem = nn.Sequential(
            nn.Conv2d(3, 64, 7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, stride=2, padding=1)
        )

        # Encoder stages with increasing channels.
        self.layer1 = self._make_layer(64, 128, 3, stride=1)
        self.layer2 = self._make_layer(128, 256, 3, stride=2)
        self.layer3 = self._make_layer(256, 512, 4, stride=2)
        self.layer4 = self._make_layer(512, 512, 3, stride=2)

        # Multi-scale fusion
        self.fusion = MultiScaleFusion(512, 512)

        # Global pooling collapses the feature map to one 512-d vector.
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.4)

        # Lead-specific decoders. Output lengths come from the same Config
        # durations the dataset uses to size its targets, instead of the
        # previously hard-coded 10 s / 2.5 s.
        self.lead_decoders = nn.ModuleDict()
        for lead in ['I', 'II', 'III', 'aVR', 'aVL', 'aVF',
                     'V1', 'V2', 'V3', 'V4', 'V5', 'V6']:
            duration = (config.lead_ii_duration if lead == 'II'
                        else config.other_leads_duration)
            target_len = int(config.base_fs * duration)
            self.lead_decoders[lead] = AdvancedTCN(512, target_len)

    def _make_layer(self, in_channels, out_channels, num_blocks, stride):
        """Build a stage of ``num_blocks`` residual blocks; only the first one strides."""
        layers = [ResidualBlock(in_channels, out_channels, stride)]
        for _ in range(1, num_blocks):
            layers.append(ResidualBlock(out_channels, out_channels, 1))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Encode the image batch ``x`` and decode one signal per lead."""
        # Encoder
        x = self.stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # Multi-scale fusion
        x = self.fusion(x)

        # Global features: (batch, 512)
        x = self.global_pool(x)
        x = x.view(x.size(0), -1)
        x = self.dropout(x)

        # Every decoder receives the same global feature vector.
        outputs = {}
        for lead, decoder in self.lead_decoders.items():
            outputs[lead] = decoder(x)

        return outputs

# ================================
# Custom Loss Function
# ================================
class EnhancedSNRLoss(nn.Module):
    """Negative-SNR loss blended with Smooth-L1.

    ``loss = 0.7 * (-mean SNR in dB) + 0.3 * mean SmoothL1``, clamped to
    ``[-100, 100]``.
    """

    def __init__(self, eps=1e-10):
        """
        Args:
            eps: Small constant added to both the signal power and the noise
                power so the SNR stays finite.
        """
        super().__init__()
        self.mse_loss = nn.MSELoss(reduction='none')
        self.smooth_l1 = nn.SmoothL1Loss(reduction='none')
        self.eps = eps

    def forward(self, pred, target):
        """Return the scalar loss for ``(batch, length)`` predictions and targets."""
        # Compute in float32 so that eps is not rounded away for
        # half-precision inputs.
        pred = pred.float()
        target = target.float()

        # Per-sample noise power.
        mse = self.mse_loss(pred, target).mean(dim=-1)

        # Smooth L1 for robustness
        smooth = self.smooth_l1(pred, target).mean(dim=-1)

        # Per-sample signal power.
        signal_power = (target ** 2).mean(dim=-1)

        # eps in the numerator as well: an all-zero target previously gave
        # log10(0) = -inf, an infinite loss and NaN gradients.
        snr = 10 * torch.log10((signal_power + self.eps) / (mse + self.eps))
        snr_loss = -snr.mean()

        # Combined loss
        total_loss = 0.7 * snr_loss + 0.3 * smooth.mean()

        # NOTE: outside [-100, 100] the clamp passes no gradient.
        return torch.clamp(total_loss, min=-100, max=100)

# ================================
# Training Functions
# ================================
def train_epoch(model, dataloader, criterion, optimizer, scaler, device, epoch):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss.

    Args:
        model: Network returning a dict of per-lead predictions.
        dataloader: Yields ``(images, targets, ids)``; ``targets[:, i]`` is
            the target row of the i-th lead.
        criterion: Loss called as ``criterion(pred, target)``.
        optimizer: Optimiser stepped once per batch.
        scaler: ``GradScaler`` or ``None``; mixed precision is used only when
            a scaler is given and ``config.use_amp`` is set.
        device: Device the batch is moved to.
        epoch: Current epoch index (not used inside this function).
    """
    lead_names = ('I', 'II', 'III', 'aVR', 'aVL', 'aVF',
                  'V1', 'V2', 'V3', 'V4', 'V5', 'V6')

    def batch_loss(images, targets_by_lead):
        # Forward pass plus the per-lead loss, averaged over all target leads.
        outputs = model(images)
        total = 0
        for lead, target in targets_by_lead.items():
            if lead not in outputs:
                continue
            pred = outputs[lead]
            # Compare only the overlapping prefix of prediction and target.
            common = min(pred.shape[-1], target.shape[-1])
            total += criterion(pred[:, :common], target[:, :common])
        return total / len(targets_by_lead)

    model.train()
    loss_sum = 0.0

    for images, targets, _ in dataloader:
        images = images.to(device)
        targets_by_lead = {
            lead: targets[:, column].to(device)
            for column, lead in enumerate(lead_names)
        }

        optimizer.zero_grad()

        if scaler and config.use_amp:
            # Mixed-precision step: scaled backward, unscale, clip, step.
            with autocast():
                loss = batch_loss(images, targets_by_lead)

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.gradient_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            # Full-precision step.
            loss = batch_loss(images, targets_by_lead)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.gradient_clip)
            optimizer.step()

        loss_sum += loss.item()

    return loss_sum / len(dataloader)

def main():
    """Train EnhancedECGNet on the competition training set and return the model.

    The checkpoint with the lowest training loss is written to
    ``best_model.pth``. Training stops early after 18 consecutive epochs
    without improvement.
    """
    cuda_ready = torch.cuda.is_available()
    device = torch.device('cuda' if cuda_ready else 'cpu')
    print(f'Using device: {device}')
    if cuda_ready:
        print(f'Available GPUs: {torch.cuda.device_count()}')

    # ---- Data ----
    train_df = pd.read_csv('/kaggle/input/physionet-ecg-image-digitization/train.csv')
    print(f'Training samples: {len(train_df)}')

    augmentations = [
        A.Resize(config.resize_height, config.resize_width),
        A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.5),
        A.GaussNoise(var_limit=(5.0, 15.0), p=0.3),
        A.RandomGamma(gamma_limit=(90, 110), p=0.3),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]
    train_transform = A.Compose(augmentations)

    train_dataset = ECGDataset(
        train_df,
        '/kaggle/input/physionet-ecg-image-digitization/train',
        transform=train_transform,
        is_train=True,
        use_mixup=True,
    )
    loader_options = dict(
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=True,
        persistent_workers=True,
        drop_last=True,
    )
    train_loader = DataLoader(train_dataset, **loader_options)

    # ---- Model ----
    model = EnhancedECGNet()
    if config.use_multi_gpu:
        print(f"Using DataParallel with {len(config.device_ids)} GPUs")
        model = nn.DataParallel(model, device_ids=config.device_ids)
    model = model.to(device)

    # ---- Optimisation ----
    criterion = EnhancedSNRLoss()
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.lr,
        weight_decay=5e-5,
        betas=(0.9, 0.999),
    )
    # Cosine schedule with warm restarts; it is stepped only after the
    # manual linear warmup below has finished.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer,
        T_0=6,
        T_mult=2,
        eta_min=config.min_lr,
    )
    scaler = GradScaler() if config.use_amp else None

    # ---- Training loop ----
    best_loss = float('inf')
    patience = 18  # Increased from 8 to allow more training time
    stale_epochs = 0

    for epoch in range(config.epochs):
        warming_up = epoch < config.warmup_epochs
        if warming_up:
            # Linear ramp up to config.lr across the warmup epochs.
            warm_lr = config.lr * (epoch + 1) / config.warmup_epochs
            for group in optimizer.param_groups:
                group['lr'] = warm_lr

        loss = train_epoch(model, train_loader, criterion, optimizer, scaler, device, epoch)

        if not warming_up:
            scheduler.step()

        current_lr = optimizer.param_groups[0]['lr']
        print(f'Epoch {epoch+1}/{config.epochs}, Loss: {loss:.4f}, LR: {current_lr:.6f}')

        improved = loss < best_loss
        if not improved:
            stale_epochs += 1
            if stale_epochs >= patience:
                print(f'Early stopping at epoch {epoch+1}')
                break
            continue

        best_loss = loss
        stale_epochs = 0
        # Save the bare network's weights, not the DataParallel wrapper's.
        bare_model = model.module if config.use_multi_gpu else model
        save_dict = {
            'epoch': epoch,
            'model_state_dict': bare_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': best_loss,
        }
        torch.save(save_dict, 'best_model.pth')
        print(f'  -> Saved (loss: {best_loss:.4f})')

    return model

# ================================
# Enhanced Inference
# ================================
def predict_with_tta(model, image, device,
                     mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Predict all leads for ``image``, averaging over brightness variants.

    Args:
        model: Network returning a dict ``lead -> tensor``.
        image: Normalised image batch of shape ``(B, C, H, W)``; ``C`` must
            match the length of ``mean``/``std``.
        device: Unused; kept for backward compatibility.
        mean: Per-channel mean the image was normalised with. The default
            matches the ``A.Normalize`` call in ``create_submission``.
        std: Per-channel standard deviation used for normalisation.

    Returns:
        Dict ``lead -> tensor`` holding the mean prediction over the original
        image and, when ``config.use_tta`` is set, four brightness-scaled
        copies. ``config.tta_count`` is not read here.
    """
    model.eval()
    predictions = []

    with torch.no_grad():
        # Original
        predictions.append(model(image))

        if config.use_tta:
            # The horizontal-flip variant was removed. A flipped ECG image
            # reverses the time axis and mirrors the layout, yet its
            # prediction was averaged as if it were time-aligned with the
            # others.

            # Brightness is scaled in pixel space. The previous code clamped
            # the normalised tensor itself to [0, 1], which discarded every
            # value outside that range.
            mean_t = torch.tensor(mean, dtype=image.dtype, device=image.device).view(1, -1, 1, 1)
            std_t = torch.tensor(std, dtype=image.dtype, device=image.device).view(1, -1, 1, 1)
            pixels = image * std_t + mean_t  # back to the [0, 1] pixel range

            for factor in [0.93, 0.97, 1.03, 1.07]:
                adjusted = torch.clamp(pixels * factor, 0, 1)
                predictions.append(model((adjusted - mean_t) / std_t))

    # Average predictions lead by lead.
    avg_pred = {}
    for lead in predictions[0]:
        avg_pred[lead] = torch.stack([p[lead] for p in predictions]).mean(dim=0)

    return avg_pred

def create_submission(model, device):
    """Predict every test image, post-process the signals and write ``submission.csv``.

    Args:
        model: Trained network, passed on to ``predict_with_tta``.
        device: Device the image tensors are moved to.

    Returns:
        The submission DataFrame with columns ``id``
        (``"<image id>_<row index>_<lead>"``) and ``value``.
    """
    test_df = pd.read_csv('/kaggle/input/physionet-ecg-image-digitization/test.csv')

    test_transform = A.Compose([
        A.Resize(config.resize_height, config.resize_width),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2()
    ])

    # Group once instead of filtering the whole frame for every image;
    # sort=False keeps the images in order of first appearance.
    rows_by_image = test_df.groupby('id', sort=False)
    n_images = rows_by_image.ngroups
    print(f"Processing {n_images} unique test images...")

    signal_processor = SignalProcessor()
    precordial_leads = {'V1', 'V2', 'V3', 'V4', 'V5', 'V6'}
    ids = []
    values = []
    skipped = []

    for img_idx, (base_id, img_rows) in enumerate(rows_by_image, 1):
        print(f"Processing image {img_idx}/{n_images}...", end='\r')

        img_path = f'/kaggle/input/physionet-ecg-image-digitization/test/{base_id}.png'
        image = None
        if os.path.exists(img_path):
            image = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
        if image is None:
            # Still skipped (as before), but reported after the loop.
            skipped.append(base_id)
            continue

        # OpenCV decodes colour images as BGR(A); convert everything to RGB.
        if len(image.shape) == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        elif image.shape[2] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
        else:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Transform and predict all 12 leads at once.
        transformed = test_transform(image=image)
        image_tensor = transformed['image'].unsqueeze(0).to(device)
        preds = predict_with_tta(model, image_tensor, device)

        for lead, num_rows in zip(img_rows['lead'], img_rows['number_of_rows']):
            num_rows = int(num_rows)
            signal = preds[lead].cpu().numpy().flatten()

            # Post-processing: median filter, 2nd/98th percentile clipping,
            # light Gaussian smoothing, then a final 3-tap median filter.
            signal = signal_processor.denoise_signal(signal, window_length=5)
            p2 = np.percentile(signal, 2)
            p98 = np.percentile(signal, 98)
            signal = np.clip(signal, p2, p98)
            signal = gaussian_filter1d(signal, sigma=0.75)
            signal = medfilt(signal, kernel_size=3)

            # Resample to the requested length in both directions. The
            # previous code truncated over-long predictions, which kept only
            # the first part of the predicted window.
            if len(signal) != num_rows:
                x_old = np.linspace(0, 1, len(signal))
                x_new = np.linspace(0, 1, num_rows)
                signal = np.interp(x_new, x_old, signal)

            # Denormalize with lead-specific scaling
            if lead == 'II':
                signal = signal * 0.55
            elif lead in precordial_leads:
                signal = signal * 0.48
            else:
                signal = signal * 0.52

            ids.extend(f"{base_id}_{row_id}_{lead}" for row_id in range(num_rows))
            values.extend(signal.tolist())

    print()  # New line after progress
    if skipped:
        print(f"Skipped {len(skipped)} missing or unreadable images: {skipped[:10]}")

    submission_df = pd.DataFrame({'id': ids, 'value': values})
    submission_df.to_csv('submission.csv', index=False)
    print(f"\n{'='*80}")
    print(f"Submission created: {len(submission_df)} rows")
    print(f"{'='*80}")
    print("\nFirst 30 rows:")
    print(submission_df.head(30))
    print("\nLast 10 rows:")
    print(submission_df.tail(10))

    return submission_df

# ================================
# Execute
# ================================
if __name__ == '__main__':
    # Script entry point: train, reload the best checkpoint, write the submission.
    banner = "=" * 80

    print(banner)
    print("Enhanced ECG Digitization - Training Started")
    print(banner)

    model = main()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    checkpoint = torch.load('best_model.pth', map_location=device)

    # The checkpoint holds the bare network's weights, so load them into the
    # module wrapped by DataParallel when several GPUs are in use.
    bare_model = model.module if config.use_multi_gpu else model
    bare_model.load_state_dict(checkpoint['model_state_dict'])

    print(f"\nLoaded best model (epoch {checkpoint['epoch']+1}, loss: {checkpoint['loss']:.4f})")

    submission_df = create_submission(model, device)

    print(banner)
    print("Complete!")
    print(banner)