In [None]:
"""
PhysioNet ECG Image Digitization Competition - Complete Solution
================================================================

This script provides a comprehensive solution for the PhysioNet ECG Image Digitization competition.
It includes image preprocessing, lead extraction, signal reconstruction, and evaluation metrics.

Author: Muhammad Qasim Shabbir
Date: 2025
"""

# Install required libraries if not available
import subprocess
import sys

def install_package(package):
    """Install *package* with pip, using the interpreter running this script.

    A non-zero pip exit status is reported on stdout instead of being raised,
    so the caller continues (and will hit its own ImportError if the package
    is still missing).
    """
    pip_command = [sys.executable, "-m", "pip", "install", package]
    try:
        subprocess.check_call(pip_command)
    except subprocess.CalledProcessError:
        print(f"Failed to install {package}")

# Install tqdm for progress bars
try:
    from tqdm import tqdm
except ImportError:
    install_package("tqdm")
    from tqdm import tqdm

import cv2
import numpy as np
import pandas as pd
import os
import glob
from pathlib import Path
from typing import List, Tuple, Dict, Optional
import matplotlib.pyplot as plt
from scipy import signal as scipy_signal
from scipy.optimize import minimize_scalar
from scipy.ndimage import gaussian_filter1d
import warnings
warnings.filterwarnings('ignore')

# GPU and parallel processing imports
try:
    import cupy as cp
    import cupyx.scipy.ndimage as cp_ndimage
    import cupyx.scipy.signal as cp_signal
    GPU_AVAILABLE = True
    print("CuPy available - GPU acceleration enabled")
except ImportError:
    GPU_AVAILABLE = False
    print("CuPy not available - using CPU only")

try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.utils.data import DataLoader, Dataset
    TORCH_AVAILABLE = True
    print("PyTorch available - Deep learning models enabled")
except ImportError:
    TORCH_AVAILABLE = False
    print("PyTorch not available - using traditional methods only")

from multiprocessing import Pool, cpu_count
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
import threading

class ECGDigitizer:
    """
    Complete ECG Image Digitization Solution for PhysioNet Competition
    Enhanced with GPU acceleration and parallel processing
    """
    
    def __init__(self, use_gpu=True, num_workers=None):
        self.leads = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']
        self.lead_durations = {
            'II': 20.0,  # 10 seconds for lead II
            **{lead: 3.5 for lead in self.leads if lead != 'II'}  # 2.5 seconds for all other leads
        }
        
        # GPU and parallel processing setup
        self.use_gpu = use_gpu and GPU_AVAILABLE
        self.num_workers = num_workers or min(cpu_count(), 8)
        
        if self.use_gpu:
            # Initialize GPU memory pools for better performance
            cp.cuda.MemoryPool().set_limit(size=15 * 1024**3)  # 15GB per GPU
            self.gpu_devices = cp.cuda.runtime.getDeviceCount()
            print(f"GPU acceleration enabled with {self.gpu_devices} devices")
            
            # Set up multi-GPU processing
            if self.gpu_devices >= 2:
                self.gpu_0 = cp.cuda.Device(0)
                self.gpu_1 = cp.cuda.Device(1)
                print("Dual GPU setup detected - using both GPUs")
            else:
                self.gpu_0 = cp.cuda.Device(0)
                self.gpu_1 = None
        else:
            print("Using CPU processing with multiprocessing")
            
        # Initialize PyTorch models if available
        if TORCH_AVAILABLE and self.use_gpu:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            self.model = self._create_ecg_model()
            print(f"PyTorch model initialized on {self.device}")
        else:
            self.device = None
            self.model = None
    
    def _create_ecg_model(self):
        """Build the CNN for ECG signal reconstruction and move it to self.device.

        Returns:
            The model instance, or None when PyTorch could not be imported.
        """
        if not TORCH_AVAILABLE:
            return None

        class ECGSignalNet(nn.Module):
            """Convolutional encoder followed by a fully connected decoder."""

            def __init__(self):
                super().__init__()

                # Encoder: three 3x3 convolution stages on a single-channel
                # image, pooled down to a fixed 8x8 feature map.
                encoder_layers = [
                    nn.Conv2d(1, 32, 3, padding=1),
                    nn.ReLU(),
                    nn.MaxPool2d(2),
                    nn.Conv2d(32, 64, 3, padding=1),
                    nn.ReLU(),
                    nn.MaxPool2d(2),
                    nn.Conv2d(64, 128, 3, padding=1),
                    nn.ReLU(),
                    nn.AdaptiveAvgPool2d((8, 8)),
                ]
                self.encoder = nn.Sequential(*encoder_layers)

                # Decoder: maps the flattened features to 1000 output samples
                # (the maximum signal length); Tanh bounds them to [-1, 1].
                decoder_layers = [
                    nn.Linear(128 * 8 * 8, 512),
                    nn.ReLU(),
                    nn.Dropout(0.3),
                    nn.Linear(512, 256),
                    nn.ReLU(),
                    nn.Dropout(0.3),
                    nn.Linear(256, 1000),
                    nn.Tanh(),
                ]
                self.decoder = nn.Sequential(*decoder_layers)

            def forward(self, x):
                features = self.encoder(x)
                # Keep the batch dimension, flatten everything else.
                flat = features.view(features.size(0), -1)
                return self.decoder(flat)

        return ECGSignalNet().to(self.device)
    
    def preprocess_image_gpu(self, image_path: str) -> np.ndarray:
        """
        GPU-accelerated image preprocessing.

        Blur and morphological clean-up run on the GPU through CuPy; the
        adaptive threshold runs on the CPU through OpenCV. Delegates to
        preprocess_image when GPU processing is disabled.

        Args:
            image_path: Path to the ECG image

        Returns:
            Preprocessed binary image as uint8 with values 0/255, the same
            convention preprocess_image uses.

        Raises:
            ValueError: If the image cannot be loaded.
        """
        if not self.use_gpu:
            return self.preprocess_image(image_path)

        # Load image on CPU first
        image = cv2.imread(image_path)
        if image is None:
            raise ValueError(f"Could not load image: {image_path}")

        # Convert to grayscale and move to GPU
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        gpu_image = cp.asarray(gray, dtype=cp.float32)

        # GPU-accelerated Gaussian blur
        gpu_blurred = cp_ndimage.gaussian_filter(gpu_image, sigma=1.0)

        # Adaptive thresholding is done with OpenCV, so the blurred image
        # makes a round trip through host memory.
        blurred_cpu = cp.asnumpy(gpu_blurred)
        binary = cv2.adaptiveThreshold(
            blurred_cpu.astype(np.uint8), 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
            cv2.THRESH_BINARY_INV, 11, 2
        )

        # Move back to GPU for morphological operations
        gpu_binary = cp.asarray(binary)

        # Closing followed by opening with a 3x3 structuring element
        kernel = cp.ones((3, 3), dtype=cp.uint8)
        gpu_binary = cp_ndimage.binary_closing(gpu_binary, structure=kernel)
        gpu_binary = cp_ndimage.binary_opening(gpu_binary, structure=kernel)

        # The binary_* operations yield a boolean mask; convert it back to
        # 0/255 uint8 so both preprocessing paths return the same kind of image.
        return cp.asnumpy(gpu_binary).astype(np.uint8) * 255
    
    def preprocess_image(self, image_path: str) -> np.ndarray:
        """
        Turn an ECG image file into a cleaned binary mask (CPU / OpenCV path).

        Args:
            image_path: Path to the ECG image

        Returns:
            Binary image produced by an inverted adaptive threshold, then
            cleaned by a morphological close and open.

        Raises:
            ValueError: If the image cannot be loaded.
        """
        bgr = cv2.imread(image_path)
        if bgr is None:
            raise ValueError(f"Could not load image: {image_path}")

        # Grayscale, then a 5x5 Gaussian blur to suppress noise.
        smoothed = cv2.GaussianBlur(cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY), (5, 5), 0)

        # Inverted adaptive threshold: dark ink becomes foreground (255).
        mask = cv2.adaptiveThreshold(
            smoothed, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
            cv2.THRESH_BINARY_INV, 11, 2
        )

        # Close first, then open, both with a 3x3 structuring element.
        structuring_element = np.ones((3, 3), np.uint8)
        for operation in (cv2.MORPH_CLOSE, cv2.MORPH_OPEN):
            mask = cv2.morphologyEx(mask, operation, structuring_element)

        return mask
    
    def detect_lead_regions(self, binary_image: np.ndarray) -> List[Tuple[int, int, int, int]]:
        """
        Detect individual lead regions in the ECG image using improved method
        
        Args:
            binary_image: Preprocessed binary image
            
        Returns:
            List of bounding boxes (x, y, w, h) for each lead
        """
        # Ensure binary image is uint8 type for OpenCV
        if binary_image.dtype != np.uint8:
            binary_image = binary_image.astype(np.uint8)
        
        # Method 1: Use horizontal projection to find lead regions
        h, w = binary_image.shape
        horizontal_projection = np.sum(binary_image, axis=1)
        
        # Find peaks in horizontal projection (potential lead regions)
        lead_regions = []
        in_lead = False
        start_y = 0
        
        threshold = np.mean(horizontal_projection) * 0.5
        
        for y in range(h):
            if horizontal_projection[y] > threshold and not in_lead:
                # Start of a lead region
                start_y = y
                in_lead = True
            elif horizontal_projection[y] <= threshold and in_lead:
                # End of a lead region
                end_y = y
                lead_height = end_y - start_y
                
                if lead_height > 10:  # Minimum height for a lead
                    # Extract the lead region
                    lead_region = binary_image[start_y:end_y, :]
                    
                    # Find the actual width of the signal
                    vertical_projection = np.sum(lead_region, axis=0)
                    signal_indices = np.where(vertical_projection > 0)[0]
                    
                    if len(signal_indices) > 0:
                        start_x = signal_indices[0]
                        end_x = signal_indices[-1] + 1
                        lead_width = end_x - start_x
                        
                        if lead_width > 50:  # Minimum width for a lead
                            lead_regions.append((start_x, start_y, lead_width, lead_height))
                
                in_lead = False
        
        # If no leads found with projection method, fall back to contour method
        if len(lead_regions) == 0:
            contours, _ = cv2.findContours(binary_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            
            for contour in contours:
                area = cv2.contourArea(contour)
                if area > 500:  # Lower threshold
                    x, y, w, h = cv2.boundingRect(contour)
                    aspect_ratio = w / h
                    if 1.5 < aspect_ratio < 30:  # More flexible aspect ratio
                        lead_regions.append((x, y, w, h))
        
        # Sort by y-coordinate (top to bottom)
        lead_regions.sort(key=lambda x: x[1])
        
        # Limit to reasonable number of leads (ECG typically has 12 leads)
        if len(lead_regions) > 15:
            lead_regions = lead_regions[:15]
        
        return lead_regions
    
    def extract_lead_signal(self, binary_image: np.ndarray, bbox: Tuple[int, int, int, int]) -> np.ndarray:
        """
        Extract time series signal from a lead region using improved ECG-specific method
        
        Args:
            binary_image: Preprocessed binary image
            bbox: Bounding box (x, y, w, h) of the lead region
            
        Returns:
            Extracted signal as 1D array
        """
        x, y, w, h = bbox
        lead_region = binary_image[y:y+h, x:x+w]
        
        # Check if region is valid
        if lead_region.size == 0 or w < 10 or h < 5:
            return np.array([])
        
        # Method 1: Find the center line using weighted average
        signal = np.zeros(w)
        
        for col in range(w):
            column = lead_region[:, col]
            # Find all non-zero pixels in this column
            nonzero_indices = np.where(column > 0)[0]
            
            if len(nonzero_indices) > 0:
                # Use weighted average of non-zero pixels
                weights = column[nonzero_indices]
                weighted_center = np.average(nonzero_indices, weights=weights)
                signal[col] = weighted_center
            else:
                # If no signal in this column, use previous value or interpolate
                if col > 0:
                    signal[col] = signal[col-1]
                else:
                    signal[col] = h // 2  # Default to middle
        
        # Smooth the signal
        if len(signal) > 10:
            signal = gaussian_filter1d(signal.astype(float), sigma=1.0)
        
        # Convert to ECG-like signal (invert Y-axis since ECG traces go up for positive values)
        signal = h - signal
        
        # Normalize the signal
        if len(signal) > 0:
            mean_val = np.mean(signal)
            std_val = np.std(signal)
            if std_val > 1e-8:
                signal = (signal - mean_val) / std_val
            else:
                signal = signal - mean_val
        
        return signal
    
    def resample_signal_gpu(self, input_signal: np.ndarray, target_length: int) -> np.ndarray:
        """GPU-accelerated signal resampling"""
        if not self.use_gpu:
            return self.resample_signal(input_signal, target_length)
            
        if len(input_signal) == 0:
            return np.zeros(target_length)
        
        if len(input_signal) == target_length:
            return input_signal
        
        # Move signal to GPU
        gpu_signal = cp.asarray(input_signal, dtype=cp.float32)
        
        # GPU-accelerated resampling using CuPy
        try:
            gpu_resampled = cp_signal.resample(gpu_signal, target_length)
            return cp.asnumpy(gpu_resampled)
        except:
            # Fallback to CPU if GPU resampling fails
            return self.resample_signal(input_signal, target_length)
    
    def resample_signal(self, input_signal: np.ndarray, target_length: int) -> np.ndarray:
        """
        Resample signal to target length
        
        Args:
            input_signal: Input signal
            target_length: Target length for resampling
            
        Returns:
            Resampled signal
        """
        if len(input_signal) == 0:
            return np.zeros(target_length)
        
        if len(input_signal) == target_length:
            return input_signal
            
        # Use scipy's resample for better quality
        resampled = scipy_signal.resample(input_signal, target_length)
        return resampled
    
    def process_batch_gpu(self, image_paths: List[str], fs_values: List[int], leads: List[str]) -> List[np.ndarray]:
        """Process multiple images in parallel using GPU"""
        if not self.use_gpu or len(image_paths) == 0:
            return []
        
        results = []
        
        # Split work between GPUs if available
        if self.gpu_devices >= 2:
            mid = len(image_paths) // 2
            batch1 = image_paths[:mid]
            batch2 = image_paths[mid:]
            
            # Process on GPU 0
            with self.gpu_0:
                results1 = self._process_batch_single_gpu(batch1, fs_values[:mid], leads[:mid])
            
            # Process on GPU 1
            with self.gpu_1:
                results2 = self._process_batch_single_gpu(batch2, fs_values[mid:], leads[mid:])
            
            results = results1 + results2
        else:
            # Single GPU processing
            with self.gpu_0:
                results = self._process_batch_single_gpu(image_paths, fs_values, leads)
        
        return results
    
    def _process_batch_single_gpu(self, image_paths: List[str], fs_values: List[int], leads: List[str]) -> List[np.ndarray]:
        """Process batch on a single GPU"""
        results = []
        
        for i, (image_path, fs, lead) in enumerate(zip(image_paths, fs_values, leads)):
            try:
                # Preprocess image on GPU
                binary_image = self.preprocess_image_gpu(image_path)
                
                # Detect lead regions
                lead_regions = self.detect_lead_regions(binary_image)
                
                if lead_regions:
                    # Extract signal from first region
                    signal = self.extract_lead_signal(binary_image, lead_regions[0])
                    
                    # Resample on GPU
                    duration = self.lead_durations[lead]
                    target_length = int(fs * duration)
                    resampled_signal = self.resample_signal_gpu(signal, target_length)
                    
                    results.append(resampled_signal)
                else:
                    # Return zero signal if no leads detected
                    duration = self.lead_durations[lead]
                    target_length = int(fs * duration)
                    results.append(np.zeros(target_length))
                    
            except Exception as e:
                print(f"Error processing {image_path}: {str(e)}")
                duration = self.lead_durations[lead]
                target_length = int(fs * duration)
                results.append(np.zeros(target_length))
        
        return results
    
    def process_ecg_image(self, image_path: str, fs: int, lead: str = 'II') -> np.ndarray:
        """
        Digitize one ECG image into a signal for the requested lead.

        The image is preprocessed, lead regions are detected, and the trace
        of the first detected region is extracted and resampled to
        fs * lead duration samples. Detected regions are not mapped to
        specific leads; `lead` only determines the output length.

        Args:
            image_path: Path to the ECG image
            fs: Sampling frequency
            lead: ECG lead to extract

        Returns:
            Extracted and resampled signal, or zeros of the same length
            when no lead region is detected.
        """
        # Pick the GPU or CPU implementation of each stage once
        preprocess = self.preprocess_image_gpu if self.use_gpu else self.preprocess_image
        resample = self.resample_signal_gpu if self.use_gpu else self.resample_signal

        binary_image = preprocess(image_path)
        lead_regions = self.detect_lead_regions(binary_image)

        def target_length():
            return int(fs * self.lead_durations[lead])

        if not lead_regions:
            return np.zeros(target_length())

        trace = self.extract_lead_signal(binary_image, lead_regions[0])
        return resample(trace, target_length())
    
    def calculate_snr(self, predicted: np.ndarray, ground_truth: np.ndarray, fs: int) -> float:
        """
        Calculate modified Signal-to-Noise Ratio (SNR) as per competition metric
        
        Args:
            predicted: Predicted signal
            ground_truth: Ground truth signal
            fs: Sampling frequency
            
        Returns:
            SNR in decibels
        """
        # Handle edge cases
        if len(predicted) == 0 or len(ground_truth) == 0:
            return -np.inf
        
        # Check if signals are all zeros
        if np.all(predicted == 0) or np.all(ground_truth == 0):
            return -np.inf
        
        if len(predicted) != len(ground_truth):
            # Resample predicted to match ground truth length
            predicted = self.resample_signal(predicted, len(ground_truth))
        
        # Ensure signals are not all the same value
        if np.std(ground_truth) == 0 or np.std(predicted) == 0:
            return -np.inf
        
        # Time alignment (find optimal shift up to 0.2 seconds)
        max_shift = int(0.2 * fs)
        best_correlation = -np.inf
        best_shift = 0
        
        for shift in range(-max_shift, max_shift + 1):
            if shift == 0:
                shifted_pred = predicted
            elif shift > 0:
                shifted_pred = np.pad(predicted[shift:], (0, shift), mode='constant')
            else:
                shifted_pred = np.pad(predicted[:shift], (-shift, 0), mode='constant')
            
            if len(shifted_pred) == len(ground_truth):
                try:
                    correlation = np.corrcoef(ground_truth, shifted_pred)[0, 1]
                    if not np.isnan(correlation) and correlation > best_correlation:
                        best_correlation = correlation
                        best_shift = shift
                except:
                    continue
        
        # Apply best shift
        if best_shift > 0:
            aligned_pred = np.pad(predicted[best_shift:], (0, best_shift), mode='constant')
        elif best_shift < 0:
            aligned_pred = np.pad(predicted[:best_shift], (-best_shift, 0), mode='constant')
        else:
            aligned_pred = predicted
        
        # Vertical alignment (remove constant offset)
        offset = np.mean(ground_truth) - np.mean(aligned_pred)
        aligned_pred = aligned_pred + offset
        
        # Calculate SNR
        signal_power = np.sum(ground_truth ** 2)
        noise_power = np.sum((ground_truth - aligned_pred) ** 2)
        
        if noise_power == 0 or signal_power == 0:
            return -np.inf
        
        snr_linear = signal_power / noise_power
        
        if snr_linear <= 0 or np.isnan(snr_linear) or np.isinf(snr_linear):
            return -np.inf
        
        snr_db = 10 * np.log10(snr_linear)
        
        # Check for valid result
        if np.isnan(snr_db) or np.isinf(snr_db):
            return -np.inf
        
        return snr_db
    
    def process_training_data(self, train_dir: str, train_csv_path: str, use_parallel=False) -> Dict:
        """
        Evaluate the digitizer on the training set.

        Reads the training CSV, prints a short run summary and hands the
        table to the parallel runner (only when requested and the table has
        more than 10 rows) or to the sequential runner.

        Args:
            train_dir: Directory containing training images
            train_csv_path: Path to training CSV file
            use_parallel: Whether to use parallel processing

        Returns:
            Dictionary with processing results and metrics, as produced by
            the selected runner.
        """
        train_df = pd.read_csv(train_csv_path)

        print(f"Processing {len(train_df)} training samples...")
        print(f"GPU acceleration: {'Enabled' if self.use_gpu else 'Disabled'}")
        print(f"Parallel processing: {'Enabled' if use_parallel else 'Disabled'}")

        # Parallel processing only pays off for larger tables
        if use_parallel and len(train_df) > 10:
            runner = self._process_training_data_parallel
        else:
            runner = self._process_training_data_sequential

        return runner(train_df, train_dir)
    
    def _process_training_data_sequential(self, train_df: pd.DataFrame, train_dir: str) -> Dict:
        """
        Sequential processing of training data.

        For every row of train_df the ground-truth CSV
        <train_dir>/<id>/<id>.csv is loaded, the first image matching
        <id>-*.png (in sorted order) is digitized once per lead, and the
        per-lead SNRs are averaged.

        Args:
            train_df: Training table with at least the columns 'id' and 'fs'
            train_dir: Directory containing one sub-directory per record

        Returns:
            Dictionary with
              'snr_scores': mean SNR of every record with a finite lead SNR,
              'processed_ids': the ids of those records,
              'errors': messages for records whose processing raised.
        """
        results = {
            'snr_scores': [],
            'processed_ids': [],
            'errors': []
        }

        progress_bar = tqdm(train_df.iterrows(), total=len(train_df), desc="Processing training data")

        for _, row in progress_bar:
            ecg_id = row['id']
            fs = row['fs']

            record_dir = os.path.join(train_dir, str(ecg_id))
            gt_path = os.path.join(record_dir, f"{ecg_id}.csv")
            if not os.path.exists(gt_path):
                progress_bar.set_description(f"Skipping ID {ecg_id} - No ground truth")
                continue

            try:
                ground_truth_df = pd.read_csv(gt_path)

                # Looked up once per record. glob returns files in arbitrary
                # order, so sort to make "first available image" reproducible.
                image_files = sorted(glob.glob(os.path.join(record_dir, f"{ecg_id}-*.png")))

                lead_snrs = []
                if image_files:
                    for lead in self.leads:
                        if lead not in ground_truth_df.columns:
                            continue

                        # Keep only the non-NaN ground-truth samples; need at
                        # least 10 of them (also covers empty / all-NaN leads)
                        gt_signal = ground_truth_df[lead].values
                        valid_indices = ~np.isnan(gt_signal)
                        if np.sum(valid_indices) < 10:
                            continue
                        gt_signal_clean = gt_signal[valid_indices]

                        predicted_signal = self.process_ecg_image(image_files[0], fs, lead)
                        if len(predicted_signal) == 0:
                            continue

                        # Resample predicted to match ground truth length
                        if len(predicted_signal) != len(gt_signal_clean):
                            predicted_signal = self.resample_signal(predicted_signal, len(gt_signal_clean))

                        snr = self.calculate_snr(predicted_signal, gt_signal_clean, fs)
                        if np.isfinite(snr):
                            lead_snrs.append(snr)
            except Exception as e:
                # Record the failure and move on, as the parallel runner
                # does, instead of aborting the whole evaluation.
                error_msg = f"Error processing ID {ecg_id}: {str(e)}"
                results['errors'].append(error_msg)
                print(error_msg)
                continue

            if lead_snrs:
                avg_snr = np.mean(lead_snrs)
                results['snr_scores'].append(avg_snr)
                results['processed_ids'].append(ecg_id)
                progress_bar.set_description(f"ID {ecg_id} - SNR: {avg_snr:.2f} dB")
            else:
                progress_bar.set_description(f"ID {ecg_id} - No valid signals")

        return results
    
    def _process_training_data_parallel(self, train_df: pd.DataFrame, train_dir: str) -> Dict:
        """Parallel processing of training data using multiprocessing"""
        results = {
            'snr_scores': [],
            'processed_ids': [],
            'errors': []
        }
        
        # Prepare data for parallel processing
        processing_data = []
        for idx, row in train_df.iterrows():
            ecg_id = row['id']
            fs = row['fs']
            sig_len = row['sig_len']
            
            gt_path = os.path.join(train_dir, str(ecg_id), f"{ecg_id}.csv")
            if os.path.exists(gt_path):
                processing_data.append((ecg_id, fs, sig_len, gt_path, train_dir))
        
        # Process in parallel
        with ProcessPoolExecutor(max_workers=self.num_workers) as executor:
            future_to_data = {
                executor.submit(self._process_single_training_sample, data): data 
                for data in processing_data
            }
            
            for future in future_to_data:
                try:
                    result = future.result()
                    if result:
                        results['snr_scores'].append(result['snr'])
                        results['processed_ids'].append(result['ecg_id'])
                except Exception as e:
                    ecg_id = future_to_data[future][0]
                    error_msg = f"Error processing ID {ecg_id}: {str(e)}"
                    results['errors'].append(error_msg)
                    print(error_msg)
        
        return results
    
    def _process_single_training_sample(self, data):
        """Process a single training sample (for parallel processing)"""
        ecg_id, fs, sig_len, gt_path, train_dir = data
        
        # Create a new digitizer instance for this process (without PyTorch model)
        digitizer = ECGDigitizer(use_gpu=False, num_workers=1)
        
        ground_truth_df = pd.read_csv(gt_path)
        
        # Process each lead
        lead_snrs = []
        for lead in digitizer.leads:
            if lead in ground_truth_df.columns:
                # Get ground truth signal
                gt_signal = ground_truth_df[lead].values
                
                # Skip if ground truth is all NaN or empty
                if len(gt_signal) == 0 or np.all(np.isnan(gt_signal)):
                    continue
                
                # Remove NaN values from ground truth
                valid_indices = ~np.isnan(gt_signal)
                if np.sum(valid_indices) < 10:  # Need at least 10 valid points
                    continue
                
                gt_signal_clean = gt_signal[valid_indices]
                
                # Find corresponding image (use first available segment)
                image_pattern = os.path.join(train_dir, str(ecg_id), f"{ecg_id}-*.png")
                image_files = glob.glob(image_pattern)
                
                if image_files:
                    # Process the first available image
                    predicted_signal = digitizer.process_ecg_image(image_files[0], fs, lead)
                    
                    # Skip if predicted signal is empty
                    if len(predicted_signal) == 0:
                        continue
                    
                    # Resample predicted to match ground truth length
                    if len(predicted_signal) != len(gt_signal_clean):
                        predicted_signal = digitizer.resample_signal(predicted_signal, len(gt_signal_clean))
                    
                    # Calculate SNR
                    snr = digitizer.calculate_snr(predicted_signal, gt_signal_clean, fs)
                    if not np.isnan(snr) and not np.isinf(snr):
                        lead_snrs.append(snr)
        
        if lead_snrs:
            avg_snr = np.mean(lead_snrs)
            return {'ecg_id': ecg_id, 'snr': avg_snr}
        
        return None
    
    def generate_submission(self, test_dir: str, test_csv_path: str, output_path: str, use_parallel=False):
        """
        Build the competition submission table and write it to disk.

        Reads the test CSV, prints a short run summary and builds the rows
        with the parallel generator (only when requested and the table has
        more than 50 rows) or the sequential one.

        Args:
            test_dir: Directory containing test images
            test_csv_path: Path to test CSV file
            output_path: Path for output submission file
            use_parallel: Whether to use parallel processing

        Returns:
            The submission DataFrame that was written to output_path.
        """
        test_df = pd.read_csv(test_csv_path)

        print(f"Generating submission for {len(test_df)} test samples...")
        print(f"GPU acceleration: {'Enabled' if self.use_gpu else 'Disabled'}")
        print(f"Parallel processing: {'Enabled' if use_parallel else 'Disabled'}")

        # Parallel generation is reserved for larger tables
        if use_parallel and len(test_df) > 50:
            build_rows = self._generate_submission_parallel
        else:
            build_rows = self._generate_submission_sequential

        submission_df = pd.DataFrame(build_rows(test_df, test_dir))
        submission_df.to_csv(output_path, index=False)
        print(f"Submission saved to: {output_path}")

        return submission_df
    
    def _generate_submission_sequential(self, test_df: pd.DataFrame, test_dir: str) -> List[Dict]:
        """
        Build the submission rows one test entry at a time.

        Each row of test_df ('id', 'lead', 'fs', 'number_of_rows') expands
        to number_of_rows entries with id "<id>_<row>_<lead>". The values
        come from digitizing <test_dir>/<id>.png, resampled to
        number_of_rows samples when needed; a missing image produces zeros.

        Args:
            test_df: Test table
            test_dir: Directory containing test images

        Returns:
            List of {'id': ..., 'value': ...} dictionaries.
        """
        submission_data = []

        progress_bar = tqdm(test_df.iterrows(), total=len(test_df), desc="Generating submission")

        for _, row in progress_bar:
            ecg_id, lead, fs, num_rows = row['id'], row['lead'], row['fs'], row['number_of_rows']
            image_path = os.path.join(test_dir, f"{ecg_id}.png")

            if not os.path.exists(image_path):
                print(f"Test image not found: {image_path}")
                # Fill with zeros if image not found
                submission_data.extend(
                    {'id': f"{ecg_id}_{row_id}_{lead}", 'value': 0.0}
                    for row_id in range(num_rows)
                )
                continue

            signal = self.process_ecg_image(image_path, fs, lead)

            # Ensure signal has correct length
            if len(signal) != num_rows:
                resample = self.resample_signal_gpu if self.use_gpu else self.resample_signal
                signal = resample(signal, num_rows)

            available = len(signal)
            submission_data.extend(
                {
                    'id': f"{ecg_id}_{row_id}_{lead}",
                    'value': signal[row_id] if row_id < available else 0.0
                }
                for row_id in range(num_rows)
            )

            progress_bar.set_description(f"Processing {ecg_id} - {lead}")

        return submission_data
    
    def _generate_submission_parallel(self, test_df: pd.DataFrame, test_dir: str) -> List[Dict]:
        """Parallel submission generation using multiprocessing"""
        submission_data = []
        
        # Prepare data for parallel processing
        processing_data = []
        for idx, row in test_df.iterrows():
            ecg_id = row['id']
            lead = row['lead']
            fs = row['fs']
            num_rows = row['number_of_rows']
            image_path = os.path.join(test_dir, f"{ecg_id}.png")
            processing_data.append((ecg_id, lead, fs, num_rows, image_path))
        
        # Process in parallel
        with ProcessPoolExecutor(max_workers=self.num_workers) as executor:
            future_to_data = {
                executor.submit(self._process_single_test_sample, data): data 
                for data in processing_data
            }
            
            for future in future_to_data:
                try:
                    result = future.result()
                    if result:
                        submission_data.extend(result)
                except Exception as e:
                    ecg_id = future_to_data[future][0]
                    print(f"Error processing test ID {ecg_id}: {str(e)}")
                    # Fill with zeros if processing fails
                    lead, fs, num_rows = future_to_data[future][1:4]
                    for row_id in range(num_rows):
                        submission_id = f"{ecg_id}_{row_id}_{lead}"
                        submission_data.append({
                            'id': submission_id,
                            'value': 0.0
                        })
        
        return submission_data
    
    def _process_single_test_sample(self, data):
        """
        Turn one test record into its submission entries (worker-side task).

        Args:
            data: Tuple (ecg_id, lead, fs, num_rows, image_path).

        Returns:
            A list of num_rows dicts {'id': '<ecg_id>_<row_id>_<lead>',
            'value': sample}. All values are 0.0 when image_path does not
            exist.
        """
        ecg_id, lead, fs, num_rows, image_path = data

        # A fresh CPU-only, single-worker digitizer is built for every call
        # instead of reusing self.
        digitizer = ECGDigitizer(use_gpu=False, num_workers=1)

        if not os.path.exists(image_path):
            # Fill with zeros if image not found
            return [
                {'id': f"{ecg_id}_{i}_{lead}", 'value': 0.0}
                for i in range(num_rows)
            ]

        trace = digitizer.process_ecg_image(image_path, fs, lead)

        # Bring the extracted trace to the requested number of samples
        if len(trace) != num_rows:
            trace = digitizer.resample_signal(trace, num_rows)

        n_available = len(trace)
        return [
            {
                'id': f"{ecg_id}_{i}_{lead}",
                'value': trace[i] if i < n_available else 0.0
            }
            for i in range(num_rows)
        ]

def main():
    """
    Run the complete ECG digitization pipeline on the competition data.

    Steps:
        1. Build an ECGDigitizer (GPU requested, 8 workers) and print its
           configuration.
        2. If train.csv and the train directory exist, evaluate on the
           training data and print SNR statistics.
        3. If test.csv and the test directory exist, write 'submission.csv'
           to the current working directory and print its first rows.

    The data location is hard-coded to the Kaggle input directory; adjust
    data_dir to match your own layout.
    """
    print("PhysioNet ECG Image Digitization - GPU-Accelerated Complete Solution")
    print("=" * 70)

    # Initialize the digitizer with GPU acceleration
    digitizer = ECGDigitizer(use_gpu=True, num_workers=8)

    # Define paths (adjust these to match your data structure)
    data_dir = "/kaggle/input/physionet-ecg-image-digitization"
    train_csv_path = os.path.join(data_dir, "train.csv")
    test_csv_path = os.path.join(data_dir, "test.csv")
    train_dir = os.path.join(data_dir, "train")
    test_dir = os.path.join(data_dir, "test")

    print("\nSystem Configuration:")
    print(f"  - GPU Acceleration: {'Enabled' if digitizer.use_gpu else 'Disabled'}")
    print(f"  - Number of Workers: {digitizer.num_workers}")
    if digitizer.use_gpu:
        print(f"  - GPU Devices: {digitizer.gpu_devices}")
        # NOTE(review): 15GB per device is a hard-coded figure, not a value
        # queried from the hardware - confirm it matches the target machine.
        print("  - GPU Memory per Device: 15GB")
        print(f"  - Total GPU Memory: {digitizer.gpu_devices * 15}GB")

    # Check if training data is available
    if os.path.exists(train_csv_path) and os.path.exists(train_dir):
        print("\n1. Processing Training Data with GPU Acceleration...")
        training_results = digitizer.process_training_data(train_dir, train_csv_path, use_parallel=False)

        snr_scores = training_results['snr_scores']
        if snr_scores:
            avg_snr = np.mean(snr_scores)
            std_snr = np.std(snr_scores)
            print("\nTraining Results:")
            print(f"  - Processed {len(training_results['processed_ids'])} samples")
            print(f"  - Average SNR: {avg_snr:.2f} ± {std_snr:.2f} dB")
            print(f"  - Best SNR: {max(snr_scores):.2f} dB")
            print(f"  - Worst SNR: {min(snr_scores):.2f} dB")
            print(f"  - Errors: {len(training_results['errors'])}")
        else:
            print("No training data processed successfully.")
    else:
        # Mirror the test-data branch so a skipped step is never silent
        print("\nTraining data not found. Skipping training evaluation.")

    # Generate submission if test data is available
    if os.path.exists(test_csv_path) and os.path.exists(test_dir):
        print("\n2. Generating Submission with GPU Acceleration...")
        submission_path = "submission.csv"
        submission_df = digitizer.generate_submission(test_dir, test_csv_path, submission_path, use_parallel=False)
        print(f"Submission generated with {len(submission_df)} predictions")

        # Show sample of submission
        print("\nSample submission entries:")
        print(submission_df.head(10))
    else:
        print("\nTest data not found. Skipping submission generation.")

    print("\n" + "=" * 70)
    print("GPU-Accelerated ECG Digitization Solution Completed Successfully!")
    print("=" * 70)

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()