In [2]:
import wfdb                          # To read the ECG files
from wfdb import processing          # For QRS detection
import numpy as np                   # Numerical operations
import joblib                        # To load the saved model
import pywt                          # For wavelet feature extraction
import os                            # For file operations
import cv2                           # For image processing
from pdf2image import convert_from_path  # For PDF to image conversion
import warnings
import pickle
import sklearn

def digitize_ecg_from_pdf(pdf_path, output_file='calibrated_ecg.dat', save_segments=True):
    """
    Process an ECG PDF file and convert it to a .dat signal file.
    
    Args:
        pdf_path (str): Path to the ECG PDF file
        output_file (str): Path to save the output .dat file (default: 'calibrated_ecg.dat')
        save_segments (bool): Whether to save individual segments
    
    Returns:
        tuple: (path to the created .dat file, list of paths to segment files)
    """
    # Convert PDF to image
    images = convert_from_path(pdf_path)
    temp_image_path = 'temp_ecg_image.jpg'
    images[0].save(temp_image_path, 'JPEG')
    
    # Load the image
    img = cv2.imread(temp_image_path, cv2.IMREAD_GRAYSCALE)
    height, width = img.shape
    
    # Fixed calibration parameters
    calibration = {
        'seconds_per_pixel': 2.0 / 197.0,  # 197 pixels = 2 seconds
        'mv_per_pixel': 1.0 / 78.8,        # 78.8 pixels = 1 mV
    }
    
    # Calculate layer boundaries using percentages
    layer1_start = int(height * 35.35 / 100)
    layer1_end = int(height * 51.76 / 100)
    layer2_start = int(height * 51.82 / 100)
    layer2_end = int(height * 69.41 / 100)
    layer3_start = int(height * 69.47 / 100)
    layer3_end = int(height * 87.06 / 100)
    
    # Crop each layer
    layers = [
        img[layer1_start:layer1_end, :],  # Layer 1
        img[layer2_start:layer2_end, :],  # Layer 2
        img[layer3_start:layer3_end, :]   # Layer 3
    ]
    
    # Process each layer to extract waveform contours
    signals = []
    time_points = []
    layer_duration = 10.0  # Each layer is 10 seconds long
    
    for i, layer in enumerate(layers):
        # Binary thresholding
        _, binary = cv2.threshold(layer, 200, 255, cv2.THRESH_BINARY_INV)
        
        # Detect contours
        contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        waveform_contour = max(contours, key=cv2.contourArea)  # Largest contour is the ECG
        
        # Sort contour points and extract coordinates
        sorted_contour = sorted(waveform_contour, key=lambda p: p[0][0])
        x_coords = np.array([point[0][0] for point in sorted_contour])
        y_coords = np.array([point[0][1] for point in sorted_contour])
        
        # Calculate isoelectric line (one-third from the bottom)
        isoelectric_line_y = layer.shape[0] * 0.6
        
        # Convert to time using fixed layer duration
        x_min, x_max = np.min(x_coords), np.max(x_coords)
        time = (x_coords - x_min) / (x_max - x_min) * layer_duration
        
        # Calculate signal in millivolts and apply baseline correction
        signal_mv = (isoelectric_line_y - y_coords) * calibration['mv_per_pixel']
        signal_mv = signal_mv - np.mean(signal_mv)
        
        # Store the time points and calibrated signal
        time_points.append(time)
        signals.append(signal_mv)
    
    # Save individual segments if requested
    segment_files = []
    sampling_frequency = 500  # Standard ECG frequency
    samples_per_segment = int(layer_duration * sampling_frequency)  # 5000 samples per 10-second segment
    
    if save_segments:
        base_name = os.path.splitext(output_file)[0]
        
        for i, signal in enumerate(signals):
            # Interpolate to get evenly sampled signal
            segment_time = np.linspace(0, layer_duration, samples_per_segment)
            interpolated_signal = np.interp(segment_time, time_points[i], signals[i])
            
            # Normalize and scale
            interpolated_signal = interpolated_signal - np.mean(interpolated_signal)
            signal_peak = np.max(np.abs(interpolated_signal))
            
            if signal_peak > 0 and (signal_peak < 0.5 or signal_peak > 4.0):
                scaling_factor = 2.0 / signal_peak  # Target peak amplitude of 2.0 mV
                interpolated_signal = interpolated_signal * scaling_factor
            
            # Convert to 16-bit integers
            adc_gain = 1000.0
            int_signal = (interpolated_signal * adc_gain).astype(np.int16)
            
            # Save segment
            segment_file = f"{base_name}_segment{i+1}.dat"
            int_signal.reshape(-1, 1).tofile(segment_file)
            segment_files.append(segment_file)
    
    # Combine signals with proper time alignment for the full record
    total_duration = layer_duration * len(layers)
    num_samples = int(total_duration * sampling_frequency)
    combined_time = np.linspace(0, total_duration, num_samples)
    combined_signal = np.zeros(num_samples)
    
    # Place each lead at the correct time position
    for i, (time, signal) in enumerate(zip(time_points, signals)):
        start_time = i * layer_duration
        mask = (combined_time >= start_time) & (combined_time < start_time + layer_duration)
        relevant_times = combined_time[mask]
        interpolated_signal = np.interp(relevant_times, start_time + time, signal)
        combined_signal[mask] = interpolated_signal
    
    # Baseline correction and amplitude scaling
    combined_signal = combined_signal - np.mean(combined_signal)
    signal_peak = np.max(np.abs(combined_signal))
    target_amplitude = 2.0  # Target peak amplitude in mV
    
    if signal_peak > 0 and (signal_peak < 0.5 or signal_peak > 4.0):
        scaling_factor = target_amplitude / signal_peak
        combined_signal = combined_signal * scaling_factor
    
    # Convert to 16-bit integers and save as .dat file
    adc_gain = 1000.0  # Standard gain: 1000 units per mV
    int_signal = (combined_signal * adc_gain).astype(np.int16)
    int_signal.tofile(output_file)
    
    # Clean up temporary files
    if os.path.exists(temp_image_path):
        os.remove(temp_image_path)
    
    return output_file, segment_files

def split_dat_into_segments(file_path, segment_duration=10.0):
    """
    Split a DAT file into equal segments.
    
    Args:
        file_path (str): Path to the DAT file (without extension)
        segment_duration (float): Duration of each segment in seconds
        
    Returns:
        list: Paths to the segment files
    """
    try:
        # Load the signal
        signal_all_leads, fs = load_dat_signal(file_path)
        
        # Choose a lead
        if signal_all_leads.shape[1] == 1:
            lead_index = 0
        else:
            lead_priority = [1, 0]  # Try Lead II (index 1), then I (index 0)
            lead_index = next((i for i in lead_priority if i < signal_all_leads.shape[1]), 0)
            
        signal = signal_all_leads[:, lead_index]
        
        # Calculate samples per segment
        samples_per_segment = int(segment_duration * fs)
        total_samples = len(signal)
        num_segments = total_samples // samples_per_segment
        
        segment_files = []
        
        # Split and save each segment
        base_name = os.path.splitext(file_path)[0]
        
        for i in range(num_segments):
            start_idx = i * samples_per_segment
            end_idx = (i + 1) * samples_per_segment
            segment = signal[start_idx:end_idx]
            
            # Save segment
            segment_file = f"{base_name}_segment{i+1}.dat"
            segment.reshape(-1, 1).tofile(segment_file)
            segment_files.append(segment_file)
                
        return segment_files
        
    except Exception as e:
        return []

def load_dat_signal(file_path, n_leads=12, n_samples=5000, dtype=np.int16):
    """
    Load a DAT file containing ECG signal data.
    
    Args:
        file_path (str): Path to the DAT file (without extension)
        n_leads (int): Number of leads in the signal
        n_samples (int): Number of samples per lead
        dtype: Data type of the signal
        
    Returns:
        tuple: (numpy array of signal data, sampling frequency)
    """
    try:
        # Handle both cases: with and without .dat extension
        if file_path.endswith('.dat'):
            dat_path = file_path
        else:
            dat_path = file_path + '.dat'
            
        raw = np.fromfile(dat_path, dtype=dtype)
        
        # Try to infer number of leads if read size doesn't match expected
        if raw.size != n_leads * n_samples:
            # Check if single lead
            if raw.size == n_samples:
                signal = raw.reshape(n_samples, 1)
                return signal, 500
                
            # Try common lead counts
            possible_leads = [1, 2, 3, 6, 12]
            for possible_lead_count in possible_leads:
                if raw.size % possible_lead_count == 0:
                    actual_samples = raw.size // possible_lead_count
                    signal = raw.reshape(actual_samples, possible_lead_count)
                    return signal, 500
            
            # If we can't determine it reliably, reshape as single lead
            signal = raw.reshape(-1, 1)
            return signal, 500
            
        # Normal case when size matches expectation
        signal = raw.reshape(n_samples, n_leads)
        return signal, 500  # Signal + sampling frequency
    except Exception as e:
        # Return empty signal with single channel
        return np.zeros((n_samples, 1)), 500

def extract_features_from_signal(signal):
    """
    Extract features from an ECG signal.
    
    Args:
        signal (numpy.ndarray): ECG signal
        
    Returns:
        list: Features extracted from the signal
    """
    features = []
    features.append(np.mean(signal))
    features.append(np.std(signal))
    features.append(np.median(signal))
    features.append(np.min(signal))
    features.append(np.max(signal))
    features.append(np.percentile(signal, 25))
    features.append(np.percentile(signal, 75))
    features.append(np.mean(np.diff(signal)))
        
    coeffs = pywt.wavedec(signal, 'db4', level=5)
    for i, coeff in enumerate(coeffs):
        features.append(np.mean(coeff))
        features.append(np.std(coeff))
        features.append(np.min(coeff))
        features.append(np.max(coeff))
        
    return features

def classify_new_ecg(file_path, model):
    """
    Classify a new ECG file.
    
    Args:
        file_path (str): Path to the ECG file (without extension)
        model: The trained model for classification
        
    Returns:
        str: Classification result ("Normal", "Abnormal", or error message)
    """
    try:
        signal_all_leads, fs = load_dat_signal(file_path)
        
        # Choose lead for analysis - priority order
        if signal_all_leads.shape[1] == 1:
            lead_index = 0
        else:
            lead_priority = [1, 0]  # Try Lead II (index 1), then I (index 0)
            lead_index = next((i for i in lead_priority if i < signal_all_leads.shape[1]), 0)

        # Extract the signal
        signal = signal_all_leads[:, lead_index]
        
        # Normalize signal
        signal = (signal - np.mean(signal)) / np.std(signal)

        # Detect QRS complexes
        try:
            xqrs = processing.XQRS(sig=signal, fs=fs)
            xqrs.detect()
            r_peaks = xqrs.qrs_inds
        except Exception:
            r_peaks = processing.gqrs_detect(sig=signal, fs=fs)

        # Check if we found enough QRS complexes
        if len(r_peaks) < 5:
            return "Insufficient beats"

        # Calculate RR intervals and QRS durations
        rr_intervals = np.diff(r_peaks) / fs
        qrs_durations = np.array([r_peaks[i] - r_peaks[i - 1] for i in range(1, len(r_peaks))])

        # Extract features
        features = extract_features_from_signal(signal)
        
        # Add rhythm features
        features.extend([
            len(r_peaks),
            np.mean(rr_intervals) if len(rr_intervals) > 0 else 0,
            np.std(rr_intervals) if len(rr_intervals) > 0 else 0,
            np.median(rr_intervals) if len(rr_intervals) > 0 else 0,
            np.mean(qrs_durations) if len(qrs_durations) > 0 else 0,
            np.std(qrs_durations) if len(qrs_durations) > 0 else 0
        ])
        
        # Make prediction
        prediction = model.predict([features])[0]
        result = "Abnormal" if prediction == 1 else "Normal"
        
        return result

    except Exception as e:
        error_msg = f"Error: {str(e)}"
        return error_msg

def classify_ecg(file_path, model, is_pdf=False):
    """
    Wrapper function that handles both PDF and DAT ECG files with segment voting.
    
    Args:
        file_path (str): Path to the ECG file (.pdf or without extension for .dat)
        model: The trained model for classification
        is_pdf (bool): Whether the input file is a PDF (True) or DAT (False)
        
    Returns:
        str: Classification result ("Normal", "Abnormal", or error message)
    """
    try:
        # Check if model is valid
        if model is None:
            return "Error: Model not loaded. Please check model compatibility."
            
        if is_pdf:
            # Extract file name without extension for output
            base_name = os.path.splitext(os.path.basename(file_path))[0]
            output_dat = f"{base_name}_digitized.dat"
            
            # Digitize the PDF to a DAT file and get segment files
            dat_path, segment_files = digitize_ecg_from_pdf(
                pdf_path=file_path, 
                output_file=output_dat
            )
        else:
            # For DAT files, we need to split into segments
            segment_files = split_dat_into_segments(file_path)
            
            if not segment_files:
                # If splitting failed, try classifying the whole file
                return classify_new_ecg(file_path, model)
        
        # Process each segment and collect votes
        segment_results = []
        
        for i, segment_file in enumerate(segment_files):
            # Get file path without extension
            segment_path = os.path.splitext(segment_file)[0]
            
            # Classify this segment
            result = classify_new_ecg(segment_path, model)
            segment_results.append(result)
            
            # Remove temporary segment files
            try:
                os.remove(segment_file)
            except:
                pass
        
        # Count results and use majority voting
        if segment_results:
            normal_count = segment_results.count("Normal")
            abnormal_count = segment_results.count("Abnormal")
            
            # Decision rules:
            # 1. If any segment is abnormal, classify as abnormal
            # 2. Only classify as normal if majority of segments are normal
            if abnormal_count > normal_count:
                final_result = "Abnormal"
            elif normal_count > abnormal_count:
                final_result = "Normal"
            else:
                final_result = "Inconclusive"
                
            return final_result
        else:
            return "Error: No valid segments to classify"
        
    except Exception as e:
        error_msg = f"Classification error: {str(e)}"
        return error_msg

# Load the saved model
try:
    model_path = 'voting_classifier.pkl'
    if os.path.exists(model_path):
        voting_loaded = joblib.load(model_path)
    else:
        # Try to find the model in the current or parent directories
        for root, dirs, files in os.walk('.'):
            for file in files:
                if file.endswith('.pkl') and 'voting' in file.lower():
                    model_path = os.path.join(root, file)
                    voting_loaded = joblib.load(model_path)
                    break
            if 'voting_loaded' in locals():
                break
                
        if 'voting_loaded' not in locals():
            voting_loaded = None
except Exception as e:
    voting_loaded = None

# Simple test for the classify_ecg function
test_pdf_path = "sample.pdf"
if os.path.exists(test_pdf_path) and voting_loaded is not None:
    result_pdf = classify_ecg(test_pdf_path, voting_loaded, is_pdf=True)
    print(f"Classification result: {result_pdf}")

Learning initial signal parameters...
Failed to find 8 beats during learning.
Initializing using default parameters
Running QRS detection...
QRS detection complete.
Learning initial signal parameters...
Failed to find 8 beats during learning.
Initializing using default parameters
Running QRS detection...
QRS detection complete.
Learning initial signal parameters...
Failed to find 8 beats during learning.
Initializing using default parameters
Running QRS detection...
QRS detection complete.
Classification result: Normal
