# PhysioNet - ECG Image Digitization Challenge

## Problem Statement
Convert ECG images (scanned printouts, photos) into time-series data suitable for machine learning analysis. The challenge involves handling:
- Physical artifacts (rotations, misalignments, blurring)
- Multi-lead ECG interrelationships (12 leads)
- Variable sampling frequencies and signal lengths
- Different vendor formats and lead placements


## 1. Import Required Libraries

In [None]:
# Core libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Image processing
import cv2
from PIL import Image
from scipy import signal
from scipy.ndimage import rotate
from skimage import transform, filters

# Deep Learning
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.models as models

# Additional utilities
from tqdm import tqdm
import json
import os
import gc

# Reproducibility: seed every random number generator the notebook uses
torch.manual_seed(42)
np.random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

# Run on the GPU when one is visible, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Plot appearance; the palette is set after the style sheet on purpose
# (presumably so the style sheet does not replace the colour cycle)
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

## 2. Data Loading and Exploration

In [None]:
# Location of the competition data inside a Kaggle kernel
DATA_DIR = Path('/kaggle/input/physionet-ecg-image-digitization/')

# Print every file (with its size in KB) and directory below DATA_DIR
print("Dataset Directory Structure:")
print("=" * 60)
if DATA_DIR.exists():
    for item in sorted(DATA_DIR.rglob('*')):
        if item.is_file():
            print(f"  {item.relative_to(DATA_DIR)} - {item.stat().st_size / 1024:.2f} KB")
        elif item.is_dir():
            print(f"📁 {item.relative_to(DATA_DIR)}/")
else:
    print("⚠️  Data directory not found. Creating dummy structure for development...")
    # Outside Kaggle: fall back to a local, initially empty data directory
    DATA_DIR = Path('./data')
    DATA_DIR.mkdir(parents=True, exist_ok=True)

## 3. ECG Image Preprocessing Pipeline

Handle various image artifacts:
- Rotation correction
- Alignment and cropping
- Noise reduction
- Grid detection and removal
- Lead segmentation

In [None]:
class ECGImagePreprocessor:
    """
    Preprocesses ECG images to handle various artifacts and prepare for digitization.

    The pipeline (see ``preprocess``) deskews the page, crops it to its inked
    content, suppresses grid lines, denoises, equalizes contrast and resizes.
    """

    def __init__(self, target_size=(512, 512)):
        """
        Args:
            target_size: Size handed to ``cv2.resize`` as the last pipeline
                step, i.e. interpreted by OpenCV as (width, height).
        """
        self.target_size = target_size

    @staticmethod
    def _to_gray(image):
        """Return a single-channel view of ``image`` (3-D input is treated as BGR)."""
        if len(image.shape) == 3:
            return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        return image

    def detect_rotation(self, image):
        """
        Estimate the skew of an ECG image using Hough line detection.

        Args:
            image: BGR or grayscale image array.

        Returns:
            Median deviation (degrees) from horizontal of the detected
            near-horizontal lines, or 0.0 when no usable line is found.
            The value is meant to be passed to ``rotate_image``.
        """
        gray = self._to_gray(image)
        edges = cv2.Canny(gray, 50, 150, apertureSize=3)

        # Detect lines using Hough transform
        lines = cv2.HoughLines(edges, 1, np.pi / 180, 100)
        if lines is None:
            return 0.0

        # theta is the angle of the line's normal, so theta - 90 degrees is the
        # deviation of the line itself from horizontal
        angles = [np.degrees(theta) - 90 for _rho, theta in lines[:, 0]]
        # Only consider reasonable rotations (drops the vertical grid lines)
        angles = [angle for angle in angles if -45 < angle < 45]
        if not angles:
            return 0.0
        return float(np.median(angles))

    def rotate_image(self, image, angle):
        """
        Rotate image by the given angle, enlarging the canvas so nothing is cut off.

        Args:
            image: Image array (grayscale or BGR).
            angle: Rotation in degrees, as accepted by ``cv2.getRotationMatrix2D``.

        Returns:
            The rotated image with white-filled corners, or ``image`` itself
            when ``abs(angle) < 0.5``.
        """
        if abs(angle) < 0.5:  # Skip negligible rotations
            return image

        (h, w) = image.shape[:2]
        center = (w // 2, h // 2)
        M = cv2.getRotationMatrix2D(center, angle, 1.0)

        # Size of the axis-aligned box that contains the rotated image
        cos = np.abs(M[0, 0])
        sin = np.abs(M[0, 1])
        new_w = int((h * sin) + (w * cos))
        new_h = int((h * cos) + (w * sin))

        # Shift the transform so the rotated content is centred in the new canvas
        M[0, 2] += (new_w / 2) - center[0]
        M[1, 2] += (new_h / 2) - center[1]

        rotated = cv2.warpAffine(image, M, (new_w, new_h),
                                 flags=cv2.INTER_LINEAR,
                                 borderMode=cv2.BORDER_CONSTANT,
                                 borderValue=(255, 255, 255))
        return rotated

    def remove_grid(self, image):
        """
        Remove grid lines from ECG image using morphological operations.

        Long horizontal/vertical ink runs are located on a binarised copy of
        the image and then filled in by inpainting.

        Args:
            image: BGR or grayscale image array (dark ink on light paper).

        Returns:
            Grayscale image with the detected line pixels inpainted.
        """
        gray = self._to_gray(image)

        # The opening and cv2.inpaint both treat NON-ZERO pixels as foreground.
        # On the raw light-background image nearly every pixel is non-zero, so
        # the image is binarised first with ink as the 255-valued foreground.
        _, ink = cv2.threshold(gray, 0, 255,
                               cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU)

        # Horizontal grid lines
        horizontal_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (40, 1))
        detect_horizontal = cv2.morphologyEx(ink, cv2.MORPH_OPEN, horizontal_kernel, iterations=2)

        # Vertical grid lines
        vertical_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 40))
        detect_vertical = cv2.morphologyEx(ink, cv2.MORPH_OPEN, vertical_kernel, iterations=2)

        # Combine grid detection into one binary inpainting mask
        # NOTE(review): long perfectly flat trace segments can also survive the
        # opening; kernel length may need tuning per image resolution.
        grid_mask = cv2.bitwise_or(detect_horizontal, detect_vertical)

        # Remove grid by inpainting
        result = cv2.inpaint(gray, grid_mask, 3, cv2.INPAINT_TELEA)

        return result

    def denoise_image(self, image):
        """Apply denoising while preserving ECG signal edges."""
        # Use bilateral filter to reduce noise while preserving edges
        denoised = cv2.bilateralFilter(image, 9, 75, 75)
        return denoised

    def normalize_image(self, image):
        """Normalize image intensity and contrast."""
        # Apply CLAHE (Contrast Limited Adaptive Histogram Equalization)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        normalized = clahe.apply(image)
        return normalized

    def crop_to_content(self, image, margin=10):
        """
        Crop image to actual ECG content, removing excess borders.

        Args:
            image: BGR or grayscale image array.
            margin: Pixels of padding kept around the detected content.

        Returns:
            The cropped image, or ``image`` unchanged when no pixel is darker
            than the near-white threshold.
        """
        gray = self._to_gray(image)

        # Everything that is not (almost) pure white counts as content
        _, binary = cv2.threshold(gray, 250, 255, cv2.THRESH_BINARY_INV)

        # Find contours
        contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            return image

        # Union of the bounding boxes of all contours
        x_min, y_min = image.shape[1], image.shape[0]
        x_max, y_max = 0, 0
        for contour in contours:
            x, y, w, h = cv2.boundingRect(contour)
            x_min = min(x_min, x)
            y_min = min(y_min, y)
            x_max = max(x_max, x + w)
            y_max = max(y_max, y + h)

        # Add margin, clamped to the image borders
        x_min = max(0, x_min - margin)
        y_min = max(0, y_min - margin)
        x_max = min(image.shape[1], x_max + margin)
        y_max = min(image.shape[0], y_max + margin)

        return image[y_min:y_max, x_min:x_max]

    def preprocess(self, image_path):
        """
        Complete preprocessing pipeline.

        Args:
            image_path: Path of the image file to load.

        Returns:
            Single-channel image resized to ``self.target_size``.

        Raises:
            ValueError: If the image cannot be read by ``cv2.imread``.
        """
        # Load image
        image = cv2.imread(str(image_path))
        if image is None:
            raise ValueError(f"Could not load image: {image_path}")

        # Detect and correct rotation
        angle = self.detect_rotation(image)
        image = self.rotate_image(image, angle)

        # Crop to content
        image = self.crop_to_content(image)

        # Remove grid (also performs the grayscale conversion)
        no_grid = self.remove_grid(image)

        # Denoise
        denoised = self.denoise_image(no_grid)

        # Normalize
        normalized = self.normalize_image(denoised)

        # Resize to target size
        resized = cv2.resize(normalized, self.target_size, interpolation=cv2.INTER_AREA)

        return resized

# Initialize preprocessor (1024x1024 output)
preprocessor = ECGImagePreprocessor(target_size=(1024, 1024))
print("✓ ECG Image Preprocessor initialized")

## 4. Custom Dataset Class

In [None]:
class ECGDataset(Dataset):
    """
    Custom Dataset for ECG image to time-series conversion.
    Handles loading images and corresponding ground-truth signals.
    """

    def __init__(self, image_paths, signal_paths=None, preprocessor=None,
                 transform=None, is_test=False):
        """
        Args:
            image_paths: List of paths to ECG images
            signal_paths: List of paths to ground-truth signal files (for training)
            preprocessor: ECGImagePreprocessor instance
            transform: PyTorch transforms for data augmentation
            is_test: Whether this is test data (no ground truth available)
        """
        self.image_paths = image_paths
        self.signal_paths = signal_paths
        self.preprocessor = preprocessor
        self.transform = transform
        self.is_test = is_test

        # Standard 12-lead ECG lead names
        self.lead_names = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF',
                          'V1', 'V2', 'V3', 'V4', 'V5', 'V6']

    def __len__(self):
        """Number of images in the dataset."""
        return len(self.image_paths)

    def load_signal(self, signal_path):
        """
        Load ground-truth ECG signal from file.

        Args:
            signal_path: ``str`` or ``pathlib.Path``. ``.npy`` files are read
                with ``np.load``, ``.csv`` files with ``pd.read_csv`` (first
                row taken as header); anything else is parsed as
                comma-delimited text.

        Returns:
            The file content as a float32 numpy array.
        """
        # Path() accepts both str and Path, unlike the str-only ``endswith``
        path = Path(signal_path)
        suffix = path.suffix.lower()

        if suffix == '.npy':
            data = np.load(path)
        elif suffix == '.csv':
            data = pd.read_csv(path).values
        else:
            # Try to read as text file
            data = np.loadtxt(path, delimiter=',')

        return data.astype(np.float32)

    def __getitem__(self, idx):
        """
        Return one sample.

        Returns:
            ``(image, record_id)`` when ``is_test`` is set, otherwise
            ``(image, signal)``. ``image`` is a float tensor of shape
            (1, H, W) scaled to [0, 1]; ``signal`` is a tensor with at least
            two dimensions.

        Raises:
            ValueError: If the image cannot be read, or if ground truth is
                requested but no ``signal_paths`` were supplied.
        """
        # Load and preprocess image
        image_path = self.image_paths[idx]

        if self.preprocessor:
            image = self.preprocessor.preprocess(image_path)
        else:
            image = cv2.imread(str(image_path), cv2.IMREAD_GRAYSCALE)
            if image is None:
                # Same failure mode and message as ECGImagePreprocessor.preprocess
                raise ValueError(f"Could not load image: {image_path}")
            image = cv2.resize(image, (1024, 1024))

        # Normalize to [0, 1]
        image = image.astype(np.float32) / 255.0

        # Convert to tensor with a leading channel axis: (1, H, W)
        image = torch.from_numpy(image).unsqueeze(0)

        # Apply transforms if any
        if self.transform:
            image = self.transform(image)

        if self.is_test:
            # Return only image and ID for test set
            record_id = Path(image_path).stem
            return image, record_id

        if self.signal_paths is None:
            raise ValueError("signal_paths is required when is_test is False")

        # Load ground-truth signal for training
        target = self.load_signal(self.signal_paths[idx])

        # Ensure signal is 2D: a 1-D array becomes a single row
        if target.ndim == 1:
            target = target.reshape(1, -1)

        return image, torch.from_numpy(target)

# Data augmentation transforms for training: each augmentation is applied
# independently with probability 0.3
train_transforms = transforms.Compose([
    transforms.RandomApply([
        transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))
    ], p=0.3),
    transforms.RandomApply([
        transforms.Lambda(lambda x: x + torch.randn_like(x) * 0.01)  # Add noise
    ], p=0.3),
])

print("✓ ECG Dataset class defined")

## 5. Deep Learning Model Architecture

**ECG Digitization Network**: A hybrid CNN-Transformer architecture that:
1. **CNN Encoder**: Extracts spatial features from ECG images
2. **Transformer Decoder**: Generates time-series sequences with attention mechanisms
3. **Multi-Lead Decoder**: Produces all 12 ECG leads simultaneously

In [None]:
class PositionalEncoding(nn.Module):
    """
    Positional encoding for transformer.

    Adds a fixed sinusoidal table (sine on even feature indices, cosine on odd
    ones) to a batch-first input of shape (batch, seq_len, d_model).
    """

    def __init__(self, d_model, max_len=5000):
        """
        Args:
            d_model: Feature dimension of the sequences (odd values are supported).
            max_len: Longest sequence length the table covers.
        """
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Buffer: follows .to(device) and is saved in state_dict, but is not trained
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encoding of its first ``x.size(1)`` positions."""
        return x + self.pe[:, :x.size(1)]


class CNNEncoder(nn.Module):
    """
    Image feature extractor built on a ResNet34 backbone.

    Maps a grayscale image batch (batch, 1, H, W) to a sequence of 64 feature
    tokens with 512 channels each: (batch, 64, 512).
    """

    def __init__(self, pretrained=True):
        super().__init__()

        backbone = models.resnet34(pretrained=pretrained)

        # Single-channel stem in place of the RGB stem of the backbone
        self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        if pretrained:
            # Collapse the pretrained RGB filters into one channel by averaging
            rgb_filters = backbone.conv1.weight.data
            self.conv1.weight.data = rgb_filters.mean(dim=1, keepdim=True)

        # Remaining stem modules are reused from the backbone
        self.bn1 = backbone.bn1
        self.relu = backbone.relu
        self.maxpool = backbone.maxpool

        # The four residual stages
        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4

        # Pool to a fixed 8x8 grid regardless of the input resolution
        self.adaptive_pool = nn.AdaptiveAvgPool2d((8, 8))

        # Channel count produced by the last ResNet34 stage
        self.feature_dim = 512

    def forward(self, x):
        # x: (batch, 1, H, W)
        pipeline = (
            self.conv1, self.bn1, self.relu, self.maxpool,
            self.layer1, self.layer2, self.layer3, self.layer4,
            self.adaptive_pool,
        )
        for stage in pipeline:
            x = stage(x)

        # (batch, 512, 8, 8) -> (batch, 512, 64) -> (batch, 64, 512)
        tokens = x.view(x.size(0), self.feature_dim, -1)
        return tokens.permute(0, 2, 1)


class TransformerDecoder(nn.Module):
    """
    Batch-first transformer decoder stack that turns query sequences into
    time-series features by attending to encoded image features.
    """

    def __init__(self, d_model=512, nhead=8, num_layers=6,
                 dim_feedforward=2048, dropout=0.1):
        super().__init__()

        self.d_model = d_model

        # Sinusoidal position information added to the target sequence
        self.pos_encoder = PositionalEncoding(d_model)

        # One prototype layer; nn.TransformerDecoder stacks num_layers of them
        layer_options = dict(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        prototype_layer = nn.TransformerDecoderLayer(**layer_options)
        self.transformer_decoder = nn.TransformerDecoder(prototype_layer, num_layers)

    def forward(self, memory, tgt):
        """
        Args:
            memory: Encoded image features (batch, seq_len, d_model)
            tgt: Target sequence (batch, tgt_len, d_model)

        Returns:
            Decoder output for the position-encoded ``tgt``.
        """
        return self.transformer_decoder(self.pos_encoder(tgt), memory)


class ECGDigitizationModel(nn.Module):
    """
    Complete model for ECG image digitization.

    A CNN encoder produces 64 image tokens; a transformer decoder attends to
    them from one learnable query per output time step; a small MLP head per
    lead maps every decoded step to a sample value.
    """

    def __init__(self, num_leads=12, max_signal_length=5000,
                 d_model=512, nhead=8, num_decoder_layers=6):
        """
        Args:
            num_leads: Number of output leads (one projection head each).
            max_signal_length: Upper bound on the generated signal length and
                number of learnable query embeddings.
            d_model: Transformer feature dimension.
            nhead: Number of attention heads.
            num_decoder_layers: Number of transformer decoder layers.
        """
        super().__init__()

        self.num_leads = num_leads
        self.max_signal_length = max_signal_length
        self.d_model = d_model

        # CNN Encoder
        self.encoder = CNNEncoder(pretrained=True)

        # The encoder emits a fixed channel count; project only when the
        # transformer width differs (Identity keeps the default configuration
        # and its state_dict unchanged).
        if d_model == self.encoder.feature_dim:
            self.memory_proj = nn.Identity()
        else:
            self.memory_proj = nn.Linear(self.encoder.feature_dim, d_model)

        # Transformer Decoder
        self.decoder = TransformerDecoder(
            d_model=d_model,
            nhead=nhead,
            num_layers=num_decoder_layers,
            dim_feedforward=2048,
            dropout=0.1
        )

        # Learnable query embeddings for each time step
        self.query_embed = nn.Embedding(max_signal_length, d_model)

        # Output projection for each lead
        self.output_projection = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, 256),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(256, 1)
            ) for _ in range(num_leads)
        ])

        # Signal length prediction: a sigmoid ratio of max_signal_length,
        # computed from the 64 flattened image tokens
        self.length_predictor = nn.Sequential(
            nn.Linear(d_model * 64, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 1),
            nn.Sigmoid()
        )

    def forward(self, images, target_length=None):
        """
        Args:
            images: ECG images (batch, 1, H, W)
            target_length: Target signal length (optional)

        Returns:
            signals: Predicted ECG signals (batch, num_leads, signal_length)
            pred_length: Predicted signal length, shape (batch,)
        """
        batch_size = images.size(0)

        # Encode image: (batch, 64, d_model)
        memory = self.memory_proj(self.encoder(images))

        # Predict signal length
        memory_flat = memory.reshape(batch_size, -1)
        length_ratio = self.length_predictor(memory_flat)  # (batch, 1)
        pred_length = (length_ratio * self.max_signal_length).squeeze(-1)

        # Use target length if provided, otherwise the batch-mean prediction
        if target_length is None:
            signal_length = int(pred_length.mean().item())
        else:
            signal_length = int(target_length)

        # Keep the length inside [1, max_signal_length]: a prediction that
        # truncates to 0 would otherwise yield an empty query sequence
        signal_length = max(1, min(signal_length, self.max_signal_length))

        # Generate query embeddings; expand shares memory across the batch
        queries = self.query_embed.weight[:signal_length].unsqueeze(0).expand(batch_size, -1, -1)

        # Decode
        decoded = self.decoder(memory, queries)  # (batch, signal_length, d_model)

        # One head per lead, stacked to (batch, num_leads, signal_length)
        signals = torch.stack(
            [head(decoded).squeeze(-1) for head in self.output_projection],
            dim=1,
        )

        return signals, pred_length


# Initialize model and move it to the selected device
model = ECGDigitizationModel(
    num_leads=12,
    max_signal_length=5000,
    d_model=512,
    nhead=8,
    num_decoder_layers=6
).to(device)

# Count parameters (all, and those that receive gradients)
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print("✓ ECG Digitization Model initialized")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")

## 6. Signal Alignment and SNR Evaluation Metric

In [None]:
class ECGAlignmentAndSNR:
    """
    Implements the competition's modified SNR metric with signal alignment.

    Each predicted lead is shifted in time (within a bounded window) to best
    match the reference, its constant vertical offset is removed, and the SNR
    is computed from signal and noise power summed over all leads.
    """

    def __init__(self, sampling_rate=500, max_time_shift=0.2):
        """
        Args:
            sampling_rate: ECG sampling rate in Hz (default 500 Hz)
            max_time_shift: Maximum time shift in seconds (default 0.2s)
        """
        self.sampling_rate = sampling_rate
        self.max_shift_samples = int(max_time_shift * sampling_rate)

    @staticmethod
    def _snr_db_from_powers(signal_power, noise_power):
        """Convert summed signal/noise power to dB (+inf for zero noise, -inf for zero signal)."""
        if noise_power == 0:
            return np.inf
        if signal_power == 0:
            return -np.inf
        return 10 * np.log10(signal_power / noise_power)

    def find_time_alignment(self, pred_signal, true_signal):
        """
        Find optimal time shift using cross-correlation.

        Both signals are mean-centred first so that a constant vertical offset
        (which the metric removes separately) cannot bias the estimated lag.
        When several lags tie for the maximum, the one closest to zero wins,
        so flat signals are left unshifted.

        Args:
            pred_signal: Predicted signal (1D array)
            true_signal: Ground truth signal (1D array)

        Returns:
            optimal_shift: Shift in samples to apply to the prediction
                (positive = shift right); 0 for empty input.
        """
        # Compare only the common prefix of the two signals
        min_len = min(len(pred_signal), len(true_signal))
        if min_len == 0:
            return 0

        pred = np.asarray(pred_signal[:min_len], dtype=np.float64)
        true = np.asarray(true_signal[:min_len], dtype=np.float64)
        pred = pred - pred.mean()
        true = true - true.mean()

        # Compute cross-correlation; in 'same' mode zero lag sits at min_len // 2
        correlation = signal.correlate(true, pred, mode='same')

        # Restrict the search to the allowed shift range
        center = min_len // 2
        start_idx = max(0, center - self.max_shift_samples)
        end_idx = min(len(correlation), center + self.max_shift_samples + 1)

        local_correlation = correlation[start_idx:end_idx]
        lags = np.arange(start_idx, end_idx) - center

        peak_positions = np.flatnonzero(local_correlation == local_correlation.max())
        if peak_positions.size == 0:
            # Only reachable when the correlation contains NaN
            return 0
        best = peak_positions[np.argmin(np.abs(lags[peak_positions]))]

        return int(lags[best])

    def apply_time_shift(self, signal, shift):
        """
        Apply time shift to signal.

        Args:
            signal: Input signal
            shift: Shift amount in samples (positive = shift right)

        Returns:
            shifted_signal: Time-shifted signal of the same length; vacated
                samples are zero. ``signal`` itself is returned for shift 0.
        """
        if shift == 0:
            return signal

        shifted = np.zeros_like(signal)

        if shift > 0:
            # Shift right
            shifted[shift:] = signal[:-shift]
        else:
            # Shift left
            shifted[:shift] = signal[-shift:]

        return shifted

    def remove_vertical_offset(self, pred_signal, true_signal):
        """
        Remove constant vertical offset between signals.

        Args:
            pred_signal: Predicted signal
            true_signal: Ground truth signal

        Returns:
            corrected_signal: Prediction minus the mean prediction error
        """
        offset = np.mean(pred_signal - true_signal)
        return pred_signal - offset

    def calculate_snr(self, pred_signal, true_signal):
        """
        Calculate SNR in decibels.

        Args:
            pred_signal: Predicted signal
            true_signal: Ground truth signal

        Returns:
            snr_db: Signal-to-noise ratio in dB (+inf when the prediction is
                exact, -inf when the reference has zero power)
        """
        # Signal power (true signal)
        signal_power = np.sum(true_signal ** 2)

        # Noise power (reconstruction error)
        noise_power = np.sum((pred_signal - true_signal) ** 2)

        return self._snr_db_from_powers(signal_power, noise_power)

    def align_and_score(self, pred_signals, true_signals):
        """
        Align predicted signals with ground truth and calculate SNR.

        Args:
            pred_signals: Predicted signals (num_leads, signal_length)
            true_signals: Ground truth signals (num_leads, signal_length)

        Returns:
            snr_db: Overall SNR in dB, from powers summed over all leads
            aligned_signals: Time and vertically aligned predictions
        """
        num_leads = pred_signals.shape[0]
        aligned_signals = []

        total_signal_power = 0
        total_noise_power = 0

        for lead_idx in range(num_leads):
            pred = pred_signals[lead_idx]
            true = true_signals[lead_idx]

            # Ensure same length
            min_len = min(len(pred), len(true))
            pred = pred[:min_len]
            true = true[:min_len]

            # Find and apply time alignment
            shift = self.find_time_alignment(pred, true)
            pred_aligned = self.apply_time_shift(pred, shift)

            # Remove vertical offset
            pred_aligned = self.remove_vertical_offset(pred_aligned, true)

            aligned_signals.append(pred_aligned)

            # Accumulate powers across leads
            total_signal_power += np.sum(true ** 2)
            total_noise_power += np.sum((pred_aligned - true) ** 2)

        snr_db = self._snr_db_from_powers(total_signal_power, total_noise_power)

        return snr_db, np.array(aligned_signals)

    def batch_score(self, pred_batch, true_batch):
        """
        Calculate average SNR for a batch of ECG records.

        Args:
            pred_batch: Batch of predicted signals (batch, num_leads, signal_length)
            true_batch: Batch of ground truth signals (batch, num_leads, signal_length)

        Returns:
            avg_snr: Average SNR in dB over the records with a finite score;
                0.0 when no record has one
        """
        snr_scores = []

        for pred_record, true_record in zip(pred_batch, true_batch):
            snr_db, _ = self.align_and_score(pred_record, true_record)
            if not np.isinf(snr_db):
                snr_scores.append(snr_db)

        if not snr_scores:
            return 0.0

        return np.mean(snr_scores)


# Initialize SNR calculator (500 Hz sampling, shifts of up to 0.2 s allowed)
snr_calculator = ECGAlignmentAndSNR(sampling_rate=500, max_time_shift=0.2)
print("✓ SNR Calculator initialized")

## 7. Loss Functions

In [None]:
class ECGLoss(nn.Module):
    """
    Weighted sum of three terms for ECG digitization: time-domain MSE,
    MSE between FFT magnitudes, and MSE between first differences.
    """

    def __init__(self, alpha=1.0, beta=0.5, gamma=0.3):
        """
        Args:
            alpha: Weight for time-domain MSE loss
            beta: Weight for frequency-domain loss
            gamma: Weight for morphological loss
        """
        super().__init__()
        self.alpha, self.beta, self.gamma = alpha, beta, gamma
        self.mse = nn.MSELoss()

    def frequency_loss(self, pred, target):
        """MSE between the real-FFT magnitude spectra taken along the last axis."""
        pred_mag, target_mag = (
            torch.fft.rfft(tensor, dim=-1).abs() for tensor in (pred, target)
        )
        return F.mse_loss(pred_mag, target_mag)

    def morphological_loss(self, pred, target):
        """MSE between the sample-to-sample slopes along the third axis."""
        pred_slope = torch.diff(pred, dim=2)
        target_slope = torch.diff(target, dim=2)
        return F.mse_loss(pred_slope, target_slope)

    def forward(self, pred, target):
        """
        Args:
            pred: Predicted signals (batch, num_leads, signal_length)
            target: Ground truth signals (batch, num_leads, signal_length)

        Returns:
            The weighted total loss tensor and a dict holding the three
            unweighted terms as Python floats.
        """
        # Trim both tensors to their common length
        common_len = min(pred.size(-1), target.size(-1))
        pred, target = pred[..., :common_len], target[..., :common_len]

        time_loss = self.mse(pred, target)
        freq_loss = self.frequency_loss(pred, target)
        morph_loss = self.morphological_loss(pred, target)

        total_loss = (self.alpha * time_loss +
                      self.beta * freq_loss +
                      self.gamma * morph_loss)

        parts = {
            'time_loss': time_loss.item(),
            'freq_loss': freq_loss.item(),
            'morph_loss': morph_loss.item(),
        }
        return total_loss, parts


# Initialize loss function with the default term weights
criterion = ECGLoss(alpha=1.0, beta=0.5, gamma=0.3)
print("✓ Loss function initialized")

## 8. Training and Validation Pipeline

In [None]:
class ECGTrainer:
    """
    Training and validation pipeline for ECG digitization model.
    """
    
    def __init__(self, model, criterion, snr_calculator, device, 
                 learning_rate=1e-4, weight_decay=1e-5):
        self.model = model
        self.criterion = criterion
        self.snr_calculator = snr_calculator
        self.device = device
        
        # Optimizer with different learning rates for encoder and decoder
        encoder_params = list(model.encoder.parameters())
        other_params = [p for name, p in model.named_parameters() 
                       if not name.startswith('encoder')]
        
        self.optimizer = torch.optim.AdamW([
            {'params': encoder_params, 'lr': learning_rate * 0.1},  # Lower LR for pretrained encoder
            {'params': other_params, 'lr': learning_rate}
        ], weight_decay=weight_decay)
        
        # Learning rate scheduler
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='max', factor=0.5, patience=3, verbose=True
        )
        
        # History
        self.train_losses = []
        self.val_losses = []
        self.val_snrs = []
        self.best_snr = -np.inf
    
    def train_epoch(self, train_loader):
        """
        Run one optimisation pass over ``train_loader``.

        Returns:
            The mean batch loss and a dict with the mean of each loss component.
        """
        self.model.train()
        running_loss = 0
        component_totals = dict.fromkeys(('time_loss', 'freq_loss', 'morph_loss'), 0)

        progress = tqdm(train_loader, desc='Training')
        for step, (images, signals) in enumerate(progress):
            images, signals = images.to(self.device), signals.to(self.device)

            # Forward pass, decoding exactly as many samples as the target has
            pred_signals, pred_length = self.model(images, target_length=signals.size(-1))
            loss, components = self.criterion(pred_signals, signals)

            # Backward pass with gradient-norm clipping before the update
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()

            # Accumulate metrics
            batch_loss = loss.item()
            running_loss += batch_loss
            for name, value in components.items():
                component_totals[name] += value

            progress.set_postfix({
                'loss': f'{batch_loss:.4f}',
                'time': f'{components["time_loss"]:.4f}',
                'freq': f'{components["freq_loss"]:.4f}'
            })

            # Release cached GPU memory every tenth batch
            if step % 10 == 0:
                torch.cuda.empty_cache()

        # Turn the totals into per-batch means
        num_batches = len(train_loader)
        mean_components = {
            name: total / num_batches for name, total in component_totals.items()
        }
        return running_loss / num_batches, mean_components
    
    def validate(self, val_loader):
        """
        Evaluate the model on ``val_loader`` without gradient tracking.

        Returns:
            The mean batch loss and the mean SNR (dB) over all records with a
            finite score (0.0 when there are none).
        """
        self.model.eval()
        running_loss = 0
        finite_snrs = []

        with torch.no_grad():
            progress = tqdm(val_loader, desc='Validation')
            for images, signals in progress:
                images, signals = images.to(self.device), signals.to(self.device)

                # Forward pass and loss
                pred_signals, pred_length = self.model(images, target_length=signals.size(-1))
                loss, _ = self.criterion(pred_signals, signals)
                running_loss += loss.item()

                # Per-record SNR is computed on the CPU with numpy
                pred_np = pred_signals.cpu().numpy()
                true_np = signals.cpu().numpy()
                for pred_record, true_record in zip(pred_np, true_np):
                    snr_db, _ = self.snr_calculator.align_and_score(pred_record, true_record)
                    if np.isinf(snr_db):
                        continue  # infinite scores are left out of the average
                    finite_snrs.append(snr_db)

                # The progress bar shows the running mean once any score exists
                if finite_snrs:
                    progress.set_postfix({
                        'loss': f'{loss.item():.4f}',
                        'snr': f'{np.mean(finite_snrs):.2f} dB'
                    })

        mean_loss = running_loss / len(val_loader)
        if finite_snrs:
            return mean_loss, np.mean(finite_snrs)
        return mean_loss, 0.0
    
    def train(self, train_loader, val_loader, num_epochs, save_dir='./checkpoints'):
        """Run the complete training loop and write checkpoints to ``save_dir``.

        Each epoch trains on ``train_loader``, validates on ``val_loader``,
        appends the results to ``self.train_losses``, ``self.val_losses`` and
        ``self.val_snrs``, steps ``self.scheduler`` with the validation SNR and
        prints a summary. ``best_model.pth`` is rewritten whenever the
        validation SNR exceeds ``self.best_snr``; ``checkpoint_epoch_<n>.pth``
        is written every fifth epoch.

        Args:
            train_loader: Batches passed to ``self.train_epoch``.
            val_loader: Batches passed to ``self.validate``.
            num_epochs: Number of epochs to run.
            save_dir: Directory for checkpoint files; created if missing.
        """
        os.makedirs(save_dir, exist_ok=True)

        print(f"\n{'='*60}")
        print(f"Starting Training for {num_epochs} epochs")
        print(f"{'='*60}\n")

        for epoch in range(num_epochs):
            print(f"\nEpoch {epoch+1}/{num_epochs}")
            print("-" * 60)

            # Train
            train_loss, train_components = self.train_epoch(train_loader)
            self.train_losses.append(train_loss)

            # Validate
            val_loss, val_snr = self.validate(val_loader)
            self.val_losses.append(val_loss)
            self.val_snrs.append(val_snr)

            # Update learning rate
            self.scheduler.step(val_snr)

            # Update the running best *before* printing so the summary shows
            # the best value including this epoch rather than the previous one.
            improved = val_snr > self.best_snr
            if improved:
                self.best_snr = val_snr

            # Print epoch summary
            print(f"\n{'‚îÄ'*60}")
            print(f"Epoch {epoch+1} Summary:")
            print(f"  Train Loss: {train_loss:.4f}")
            print(f"    ‚îú‚îÄ Time Loss: {train_components['time_loss']:.4f}")
            print(f"    ‚îú‚îÄ Freq Loss: {train_components['freq_loss']:.4f}")
            print(f"    ‚îî‚îÄ Morph Loss: {train_components['morph_loss']:.4f}")
            print(f"  Val Loss: {val_loss:.4f}")
            print(f"  Val SNR: {val_snr:.2f} dB")
            print(f"  Best SNR: {self.best_snr:.2f} dB")
            print(f"{'‚îÄ'*60}")

            is_periodic = (epoch + 1) % 5 == 0
            if improved or is_periodic:
                # One payload shared by both checkpoint kinds. The metrics are
                # stored as plain Python floats (not NumPy scalars) so the file
                # can be read back by torch.load() in weights-only mode.
                checkpoint = {
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'snr': float(val_snr),
                    'loss': float(val_loss)
                }

            # Save best model
            if improved:
                torch.save(checkpoint, os.path.join(save_dir, 'best_model.pth'))
                print(f"‚úì Saved best model (SNR: {val_snr:.2f} dB)")

            # Save checkpoint every 5 epochs
            if is_periodic:
                torch.save(
                    checkpoint,
                    os.path.join(save_dir, f'checkpoint_epoch_{epoch+1}.pth')
                )

        print(f"\n{'='*60}")
        print(f"Training Complete!")
        print(f"Best Validation SNR: {self.best_snr:.2f} dB")
        print(f"{'='*60}\n")
    
    def plot_training_history(self):
        """Draw the recorded loss curves and validation SNR side by side.

        The left panel shows ``self.train_losses`` and ``self.val_losses``; the
        right panel shows ``self.val_snrs`` with a dashed line at
        ``self.best_snr``. The figure is displayed with ``plt.show()``.
        """
        fig, (loss_ax, snr_ax) = plt.subplots(1, 2, figsize=(15, 5))

        # Left panel: one curve per recorded loss series
        loss_curves = (
            (self.train_losses, 'Train Loss', 'o'),
            (self.val_losses, 'Val Loss', 's'),
        )
        for series, curve_label, curve_marker in loss_curves:
            loss_ax.plot(series, label=curve_label, marker=curve_marker)

        # Right panel: SNR per epoch plus a reference line at the best value
        snr_ax.plot(self.val_snrs, label='Val SNR', marker='o', color='green')
        best_label = f'Best SNR: {self.best_snr:.2f} dB'
        snr_ax.axhline(y=self.best_snr, color='r', linestyle='--', label=best_label)

        # Axis decoration shared by both panels
        panels = (
            (loss_ax, 'Loss', 'Training and Validation Loss'),
            (snr_ax, 'SNR (dB)', 'Validation SNR'),
        )
        for panel, y_label, panel_title in panels:
            panel.set_xlabel('Epoch')
            panel.set_ylabel(y_label)
            panel.set_title(panel_title)
            panel.legend()
            panel.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()

# Notebook progress marker: printed once the cell defining the trainer class has run.
print("‚úì Trainer class defined")

## 9. Data Preparation and Training Execution

**Note**: This section assumes the data is available at `/kaggle/input/physionet-ecg-image-digitization/`.
Update the paths to match the actual dataset structure.

In [None]:
# Run-wide settings read by the data-loading, training and inference cells
CONFIG = dict(
    data_dir=Path('/kaggle/input/physionet-ecg-image-digitization/'),
    batch_size=8,
    num_epochs=50,
    learning_rate=1e-4,
    num_workers=4,
    val_split=0.15,
    seed=42,
)

# Re-seed PyTorch and NumPy from the configured seed
torch.manual_seed(CONFIG['seed'])
np.random.seed(CONFIG['seed'])

def prepare_data_loaders(config):
    """
    Prepare training and validation data loaders.
    This is a template - adjust according to actual data structure.

    Images are collected from ``<data_dir>/train/images`` (``*.png`` and
    ``*.jpg``) and paired by file stem with
    ``<data_dir>/train/signals/<stem>.npy``. Images without a matching signal
    file are dropped. The datasets are built with the module-level
    ``preprocessor`` and ``train_transforms`` objects.

    Args:
        config: Dict providing ``data_dir`` (a ``Path``), ``val_split``,
            ``batch_size`` and ``num_workers``. An optional ``seed`` makes the
            train/validation split repeatable across calls.

    Returns:
        ``(train_loader, val_loader)``, or ``(None, None)`` when no
        image/signal pair is found.
    """
    # Example: Assuming images are in 'train/images/' and signals in 'train/signals/'
    train_image_dir = config['data_dir'] / 'train' / 'images'
    train_signal_dir = config['data_dir'] / 'train' / 'signals'

    # Get all image files
    image_paths = sorted(list(train_image_dir.glob('*.png')) +
                         list(train_image_dir.glob('*.jpg')))

    # Keep only images whose signal file exists (adjust extension as needed).
    # The images come straight from glob(), so only the signal needs checking.
    candidate_pairs = ((img, train_signal_dir / f"{img.stem}.npy") for img in image_paths)
    valid_samples = [(img, sig) for img, sig in candidate_pairs if sig.exists()]

    if not valid_samples:
        # No dummy data is generated here: callers receive (None, None) and
        # are expected to skip training.
        print("‚ö†Ô∏è  No training data found - no data loaders were created.")
        return None, None

    image_paths, signal_paths = zip(*valid_samples)

    # Train-validation split
    n_samples = len(image_paths)
    n_val = int(n_samples * config['val_split'])

    # Draw the permutation from a generator seeded by the config instead of the
    # global NumPy state, so re-running this cell cannot reshuffle samples
    # between the training and validation sets. RandomState(seed) yields the
    # same permutation as np.random.seed(seed) followed by
    # np.random.permutation(), i.e. the split of a fresh run is unchanged.
    seed = config.get('seed')
    rng = np.random if seed is None else np.random.RandomState(seed)
    indices = rng.permutation(n_samples)
    val_indices = indices[:n_val]
    train_indices = indices[n_val:]

    train_image_paths = [image_paths[i] for i in train_indices]
    train_signal_paths = [signal_paths[i] for i in train_indices]
    val_image_paths = [image_paths[i] for i in val_indices]
    val_signal_paths = [signal_paths[i] for i in val_indices]

    # Create datasets (augmentation is applied to the training set only)
    train_dataset = ECGDataset(
        train_image_paths,
        train_signal_paths,
        preprocessor=preprocessor,
        transform=train_transforms,
        is_test=False
    )

    val_dataset = ECGDataset(
        val_image_paths,
        val_signal_paths,
        preprocessor=preprocessor,
        transform=None,
        is_test=False
    )

    # Create data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=config['num_workers'],
        pin_memory=True
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config['batch_size'],
        shuffle=False,
        num_workers=config['num_workers'],
        pin_memory=True
    )

    print(f"‚úì Data prepared:")
    print(f"  Training samples: {len(train_dataset)}")
    print(f"  Validation samples: {len(val_dataset)}")
    print(f"  Batch size: {config['batch_size']}")

    return train_loader, val_loader

# Build the loaders once; the cells below reuse them
train_loader, val_loader = prepare_data_loaders(CONFIG)

# Report whether training can go ahead
if train_loader is None:
    print("\n‚ö†Ô∏è  Data not available. Training will be skipped.")
else:
    print("\n‚úì Data loaders ready for training")

In [None]:
# Wire the model, loss and SNR scorer defined earlier into a trainer
trainer = ECGTrainer(model=model, criterion=criterion,
                     snr_calculator=snr_calculator, device=device,
                     learning_rate=CONFIG['learning_rate'])

# Training needs both loaders; bail out with a hint when either is missing
if train_loader is None or val_loader is None:
    print("‚ö†Ô∏è  Skipping training - no data available")
    print("üí° Once data is available, uncomment and run this cell to train the model")
else:
    trainer.train(train_loader=train_loader, val_loader=val_loader,
                  num_epochs=CONFIG['num_epochs'], save_dir='./checkpoints')

    # Show the loss / SNR curves collected during the run
    trainer.plot_training_history()

## 10. Inference and Submission Generation

In [None]:
def load_best_model(checkpoint_path, model, device):
    """Restore weights from ``checkpoint_path`` into ``model`` if the file exists.

    When the checkpoint is missing, a warning is printed and ``model`` is
    returned untouched. Otherwise the ``model_state_dict`` entry is loaded
    (tensors mapped to ``device``) and the stored epoch and validation SNR are
    printed. ``model`` is returned in both cases.
    """
    # Guard clause: nothing to restore
    if not os.path.exists(checkpoint_path):
        print(f"‚ö†Ô∏è  Checkpoint not found: {checkpoint_path}")
        print("  Using untrained model for inference")
        return model

    saved_state = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(saved_state['model_state_dict'])

    print(f"‚úì Loaded model from {checkpoint_path}")
    print(f"  Epoch: {saved_state['epoch']}")
    print(f"  Validation SNR: {saved_state['snr']:.2f} dB")
    return model


def predict_on_test_set(model, test_loader, device):
    """Run ``model`` over every batch of ``test_loader`` and collect the outputs.

    Returns a list with one dict per sample holding the keys ``'id'``,
    ``'signal'`` and ``'length'``, in the order the loader yields them.
    """
    collected = []
    model.eval()

    with torch.no_grad():
        for batch_images, batch_ids in tqdm(test_loader, desc='Generating predictions'):
            # Forward pass on the target device
            signals_out, lengths_out = model(batch_images.to(device))

            # Move results back to host memory as NumPy arrays
            signals_np = signals_out.cpu().numpy()
            lengths_np = lengths_out.cpu().numpy()

            # One record per sample, keyed by the id supplied by the loader
            collected.extend(
                {
                    'id': sample_id,
                    'signal': signals_np[pos],
                    'length': lengths_np[pos],
                }
                for pos, sample_id in enumerate(batch_ids)
            )

    return collected


def create_submission(predictions, output_path='submission.csv'):
    """
    Create the submission CSV from a list of predictions.

    Format: two columns, ``id`` and ``value``. There is one row per sample of
    each lead, with ``id`` built as ``"{record_id}_{lead_name}_{time_step}"``.

    Args:
        predictions: Iterable of dicts with an ``'id'`` and a ``'signal'``
            entry; ``signal`` is indexed as ``signal[lead_idx]`` for the 12
            leads listed below.
        output_path: Destination of the CSV file.

    Returns:
        The ``pandas.DataFrame`` that was written to ``output_path``.
    """
    submission_data = []

    lead_names = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF',
                  'V1', 'V2', 'V3', 'V4', 'V5', 'V6']

    # NOTE(review): pred['length'] is not used here - every sample of each lead
    # row is written. Confirm whether signals should be trimmed to that length.
    for pred in tqdm(predictions, desc='Creating submission'):
        record_id = pred['id']
        signal = pred['signal']  # (num_leads, signal_length)

        for lead_idx, lead_name in enumerate(lead_names):
            submission_data.extend(
                {'id': f"{record_id}_{lead_name}_{time_step}", 'value': value}
                for time_step, value in enumerate(signal[lead_idx])
            )

    # Name the columns explicitly so that an empty prediction list still
    # produces a CSV with the "id,value" header instead of a header-less file.
    submission_df = pd.DataFrame(submission_data, columns=['id', 'value'])

    # Save to CSV
    submission_df.to_csv(output_path, index=False)
    print(f"‚úì Submission saved to {output_path}")
    print(f"  Total predictions: {len(submission_df):,}")

    return submission_df


# Load best model
# Rebinds the module-level `model` to whatever load_best_model() returns, so the
# inference and visualization cells below use that object.
best_model_path = './checkpoints/best_model.pth'
model = load_best_model(best_model_path, model, device)

In [None]:
# Prepare test data loader
def prepare_test_loader(config):
    """Build a ``DataLoader`` over the images in ``<data_dir>/test/images``.

    Returns ``None`` (after printing a warning) when the directory does not
    exist or contains no ``*.png`` / ``*.jpg`` files; otherwise returns an
    unshuffled loader over an ``ECGDataset`` created with ``is_test=True``.
    """
    images_root = config['data_dir'] / 'test' / 'images'

    if not images_root.exists():
        print("‚ö†Ô∏è  Test data directory not found")
        return None

    # Collect both supported extensions, then order the paths
    found_images = [*images_root.glob('*.png'), *images_root.glob('*.jpg')]
    found_images.sort()

    if not found_images:
        print("‚ö†Ô∏è  No test images found")
        return None

    # No signals and no augmentation at inference time
    dataset = ECGDataset(
        found_images,
        signal_paths=None,
        preprocessor=preprocessor,
        transform=None,
        is_test=True
    )

    loader = DataLoader(
        dataset,
        batch_size=config['batch_size'],
        shuffle=False,
        num_workers=config['num_workers'],
        pin_memory=True
    )

    print(f"‚úì Test data prepared: {len(dataset)} samples")
    return loader

# Build the loader over the test images
test_loader = prepare_test_loader(CONFIG)

# Run inference and write the submission only when test data exists
if test_loader is None:
    print("‚ö†Ô∏è  Test data not available. Skipping prediction generation.")
    print("üí° Once test data is available, run this cell to generate predictions")
else:
    predictions = predict_on_test_set(model, test_loader, device)

    # Write the CSV and keep the frame for inspection
    submission_df = create_submission(predictions, output_path='submission.csv')

    # Show the first rows of the result
    print("\nSubmission preview:")
    print(submission_df.head(20))

## 11. Visualization and Analysis

In [None]:
def _to_displayable_image(image):
    """Return ``image`` as an array laid out the way ``imshow`` accepts.

    A 3-D channel-first array ``(C, H, W)`` with ``C`` in ``{1, 3, 4}`` is moved
    to channel-last, and a single trailing channel is dropped, so 3-D inputs
    end up as ``(H, W)``, ``(H, W, 3)`` or ``(H, W, 4)``. Arrays of any other
    rank are returned as they are.
    """
    image = np.asarray(image)
    if image.ndim != 3:
        return image

    channel_counts = (1, 3, 4)
    if image.shape[-1] not in channel_counts and image.shape[0] in channel_counts:
        image = np.moveaxis(image, 0, -1)
    if image.shape[-1] == 1:
        image = image[..., 0]
    return image


def visualize_ecg_prediction(image, true_signal, pred_signal, snr_db=None, 
                            lead_names=None, save_path=None):
    """
    Visualize ECG image alongside ground truth and predicted signals.

    The image is drawn in the top row of the figure and each lead gets its own
    row below it, with the ground truth and the prediction overlaid.

    Args:
        image: Input ECG image; ``(H, W)``, channel-last ``(H, W, C)`` or
            channel-first ``(C, H, W)``
        true_signal: Ground truth signals (num_leads, signal_length)
        pred_signal: Predicted signals (num_leads, signal_length)
        snr_db: SNR score in dB; shown in the title of the first lead when given
        lead_names: Names of ECG leads (defaults to the standard 12 leads)
        save_path: Path to save figure; nothing is saved when falsy
    """
    if lead_names is None:
        lead_names = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 
                     'V1', 'V2', 'V3', 'V4', 'V5', 'V6']
    
    num_leads = len(lead_names)
    
    # Create figure
    fig = plt.figure(figsize=(20, 15))
    gs = fig.add_gridspec(num_leads + 1, 2, hspace=0.4, wspace=0.3)
    
    # Plot image. imshow() rejects channel-first and (H, W, 1) arrays, which is
    # what a squeezed image tensor looks like, so normalise the layout first.
    ax_img = fig.add_subplot(gs[0, :])
    ax_img.imshow(_to_displayable_image(image), cmap='gray')
    ax_img.set_title('ECG Image', fontsize=14, fontweight='bold')
    ax_img.axis('off')
    
    # Plot each lead
    for i, lead_name in enumerate(lead_names):
        ax = fig.add_subplot(gs[i + 1, :])
        
        # Plot ground truth
        ax.plot(true_signal[i], label='Ground Truth', color='blue', 
               linewidth=1.5, alpha=0.7)
        
        # Plot prediction
        ax.plot(pred_signal[i], label='Prediction', color='red', 
               linewidth=1.5, alpha=0.7, linestyle='--')
        
        ax.set_ylabel(lead_name, fontsize=11, fontweight='bold')
        ax.grid(True, alpha=0.3)
        
        # Legend and overall title live on the first lead only
        if i == 0:
            ax.legend(loc='upper right')
            if snr_db is not None:
                ax.set_title(f'12-Lead ECG Signals (SNR: {snr_db:.2f} dB)', 
                           fontsize=12, fontweight='bold')
        
        # Only the bottom row carries the x-axis label
        if i == num_leads - 1:
            ax.set_xlabel('Sample', fontsize=11)
    
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"‚úì Figure saved to {save_path}")
    
    plt.show()


def visualize_sample_predictions(model, dataset, num_samples=3, device='cpu'):
    """
    Plot model predictions for randomly chosen items of ``dataset``.

    Up to ``num_samples`` distinct indices are drawn with ``np.random.choice``.
    For each one the model is run on the image, the output is scored against
    the ground truth with the module-level ``snr_calculator``, and the aligned
    prediction is passed to ``visualize_ecg_prediction``.
    """
    model.eval()

    draw_count = min(num_samples, len(dataset))
    chosen = np.random.choice(len(dataset), size=draw_count, replace=False)

    with torch.no_grad():
        for sample_idx in chosen:
            img_tensor, target = dataset[sample_idx]

            # Add a batch dimension and ask for an output as long as the target
            prediction, _ = model(
                img_tensor.unsqueeze(0).to(device),
                target_length=target.size(-1),
            )

            # Back to NumPy for scoring and plotting
            img_arr = img_tensor.squeeze().cpu().numpy()
            target_arr = target.cpu().numpy()
            prediction_arr = prediction.squeeze().cpu().numpy()

            # Score the prediction; the aligned version is what gets plotted
            score_db, aligned = snr_calculator.align_and_score(prediction_arr, target_arr)

            visualize_ecg_prediction(img_arr, target_arr, aligned, snr_db=score_db)

            print(f"\nSample {sample_idx} - SNR: {score_db:.2f} dB")
            print("-" * 60)


# Plot a few validation predictions, provided a non-empty validation set exists
if val_loader is None or len(val_loader.dataset) <= 0:
    print("‚ö†Ô∏è  No validation data available for visualization")
    print("üí° This visualization will work once you have trained the model with actual data")
else:
    print("Visualizing sample predictions on validation set...\n")
    visualize_sample_predictions(model, val_loader.dataset, num_samples=3, device=device)