In [4]:
import wfdb                          # To read the ECG files
from wfdb import processing          # For QRS detection
import numpy as np                   # Numerical operations
import joblib                        # To load the saved model
import pywt                          # For wavelet feature extraction
import os                            # For file operations
import cv2                           # For image processing
from pdf2image import convert_from_path  # For PDF to image conversion
import warnings
import pickle
import sklearn

# Digitize an ECG PDF into a combined .dat record plus optional per-strip segment files
def _normalize_amplitude(signal_mv, target_amplitude=2.0, low_mv=0.5, high_mv=4.0):
    """Mean-centre a signal and rescale it when its peak amplitude looks implausible.

    The signal is rescaled to ``target_amplitude`` only when its absolute peak
    (after mean removal) lies outside ``[low_mv, high_mv]``; otherwise it is
    returned mean-centred but unscaled.

    Args:
        signal_mv (numpy.ndarray): Signal in millivolts.
        target_amplitude (float): Peak amplitude (mV) to rescale to.
        low_mv (float): Peaks below this are treated as too small.
        high_mv (float): Peaks above this are treated as too large.

    Returns:
        tuple: (centred and possibly rescaled signal, absolute peak before
        scaling, scaling factor applied or ``None`` if no rescaling was done)
    """
    centred = signal_mv - np.mean(signal_mv)
    peak = np.max(np.abs(centred)) if centred.size else 0.0
    if peak > 0 and (peak < low_mv or peak > high_mv):
        factor = target_amplitude / peak
        return centred * factor, peak, factor
    return centred, peak, None


def _to_int16(signal_mv, adc_gain=1000.0):
    """Convert millivolts to int16 ADC units (``adc_gain`` units per mV).

    Out-of-range values are clipped to the int16 limits instead of silently
    wrapping around, which a bare ``astype(np.int16)`` would do.
    """
    limits = np.iinfo(np.int16)
    return np.clip(signal_mv * adc_gain, limits.min, limits.max).astype(np.int16)


def _trace_layer(layer, mv_per_pixel, layer_duration, layer_number, debug=False):
    """Trace the ECG waveform inside one horizontal strip of the page image.

    Args:
        layer (numpy.ndarray): Grayscale crop of one rhythm strip.
        mv_per_pixel (float): Vertical calibration in mV per pixel.
        layer_duration (float): Seconds the traced span is mapped onto.
        layer_number (int): 1-based strip index, used in messages only.
        debug (bool): Whether to print debug information.

    Returns:
        tuple: (time in seconds, mean-centred signal in mV), both sorted by x.

    Raises:
        ValueError: If the strip contains no contour, or the selected contour
            has no horizontal extent.
    """
    # Pixels darker than 200 become foreground so the ink trace can be contoured.
    _, binary = cv2.threshold(layer, 200, 255, cv2.THRESH_BINARY_INV)
    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        raise ValueError(f"No waveform contour found in layer {layer_number}")
    # Assumes the largest-area contour is the ECG trace.
    waveform_contour = max(contours, key=cv2.contourArea)

    if debug:
        print(f"  - Found {len(contours)} contours")
        print(f"  - Selected contour with {len(waveform_contour)} points")

    # OpenCV contour points are stored as (N, 1, 2) arrays of (x, y);
    # a stable sort by x matches Python's sorted() on the same key.
    points = waveform_contour.reshape(-1, 2)
    order = np.argsort(points[:, 0], kind='stable')
    x_coords = points[order, 0]
    y_coords = points[order, 1]

    x_min, x_max = np.min(x_coords), np.max(x_coords)
    if x_max == x_min:
        raise ValueError(f"Waveform contour in layer {layer_number} has zero width")
    # Stretch the traced horizontal span onto the full strip duration.
    time = (x_coords - x_min) / (x_max - x_min) * layer_duration

    # Reference line at 60% of the strip height measured from the top. Image y
    # grows downward, so the subtraction makes upward deflections positive.
    # Its exact position does not matter because the mean is removed next.
    isoelectric_line_y = layer.shape[0] * 0.6
    signal_mv = (isoelectric_line_y - y_coords) * mv_per_pixel
    signal_mv = signal_mv - np.mean(signal_mv)

    if debug:
        print(f"  - Layer {layer_number} signal range: {np.min(signal_mv):.2f} mV to {np.max(signal_mv):.2f} mV")

    return time, signal_mv


def digitize_ecg_from_pdf(pdf_path, output_file='calibrated_ecg.dat', debug=False, save_segments=True):
    """
    Process an ECG PDF file and convert it to a .dat signal file.

    The first page is rendered to an image. Three horizontal strips at fixed
    fractions of the page height are traced, and each strip is treated as
    10 seconds of signal. The traces are resampled at 500 Hz on a grid with
    exact 1/fs spacing and written as raw int16 samples at 1000 units per mV.

    Args:
        pdf_path (str): Path to the ECG PDF file
        output_file (str): Path to save the output .dat file (default: 'calibrated_ecg.dat')
        debug (bool): Whether to print debug information
        save_segments (bool): Whether to save each strip as its own
            ``<output_file stem>_segment<N>.dat`` file

    Returns:
        tuple: (path to the created .dat file, list of paths to segment files;
        the list is empty when ``save_segments`` is False)

    Raises:
        ValueError: If the PDF has no pages, the rendered image cannot be read,
            or a strip contains no traceable waveform. Errors from pdf2image /
            OpenCV propagate unchanged.
    """
    if debug:
        print(f"Starting ECG digitization from PDF: {pdf_path}")

    # Render the first page to a temporary JPEG and read it back in grayscale.
    # The temporary file is removed even when rendering or reading fails.
    temp_image_path = 'temp_ecg_image.jpg'
    try:
        images = convert_from_path(pdf_path)
        if not images:
            raise ValueError(f"PDF has no pages: {pdf_path}")
        images[0].save(temp_image_path, 'JPEG')

        if debug:
            print(f"Converted PDF to image: {temp_image_path}")

        img = cv2.imread(temp_image_path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            raise ValueError(f"Could not read rendered image: {temp_image_path}")
    finally:
        if os.path.exists(temp_image_path):
            os.remove(temp_image_path)
            if debug:
                print(f"Removed temporary image: {temp_image_path}")

    height, width = img.shape

    if debug:
        print(f"Image dimensions: {width}x{height}")

    # Fixed calibration parameters (not measured from the page itself)
    calibration = {
        'seconds_per_pixel': 2.0 / 197.0,  # 197 pixels = 2 seconds
        'mv_per_pixel': 1.0 / 78.8,        # 78.8 pixels = 1 mV
    }

    if debug:
        print(f"Calibration parameters: {calibration}")

    # Vertical extent of each rhythm strip as (start %, end %) of page height.
    # NOTE(review): tuned for one specific report layout; confirm for other sources.
    layer_bounds_pct = [(35.35, 51.76), (51.82, 69.41), (69.47, 87.06)]
    layer_bounds = [(int(height * start / 100), int(height * end / 100))
                    for start, end in layer_bounds_pct]

    if debug:
        for number, (start, end) in enumerate(layer_bounds, start=1):
            print(f"Layer {number} boundaries: {start}-{end}")

    layers = [img[start:end, :] for start, end in layer_bounds]

    # Trace each strip into (time, mV) samples
    layer_duration = 10.0  # Each layer is treated as 10 seconds long
    time_points = []
    signals = []
    for number, layer in enumerate(layers, start=1):
        if debug:
            print(f"Processing layer {number}...")
        time, signal_mv = _trace_layer(
            layer, calibration['mv_per_pixel'], layer_duration, number, debug=debug
        )
        time_points.append(time)
        signals.append(signal_mv)

    sampling_frequency = 500  # Standard ECG frequency
    samples_per_segment = int(layer_duration * sampling_frequency)  # 5000 samples per 10-second segment

    # Save individual segments if requested
    segment_files = []
    if save_segments:
        base_name = os.path.splitext(output_file)[0]
        # Sample times k / fs, so consecutive samples are exactly 1/fs apart.
        segment_time = np.arange(samples_per_segment) / sampling_frequency

        for number, (time, signal_mv) in enumerate(zip(time_points, signals), start=1):
            resampled = np.interp(segment_time, time, signal_mv)
            resampled, _, _ = _normalize_amplitude(resampled)

            segment_file = f"{base_name}_segment{number}.dat"
            _to_int16(resampled).reshape(-1, 1).tofile(segment_file)
            segment_files.append(segment_file)

            if debug:
                print(f"Saved segment {number} to {segment_file}")

    # Combine the strips back to back into one continuous record
    total_duration = layer_duration * len(layers)
    num_samples = int(total_duration * sampling_frequency)
    # k / fs for k in [0, num_samples). Unlike linspace(0, total, n), this leaves
    # no trailing sample outside every strip's [start, start + duration) window
    # (such a sample would otherwise stay 0).
    combined_time = np.arange(num_samples) / sampling_frequency
    combined_signal = np.zeros(num_samples)

    if debug:
        print(f"Combining signals with {sampling_frequency} Hz sampling rate, total duration: {total_duration}s")

    for index, (time, signal_mv) in enumerate(zip(time_points, signals)):
        start_time = index * layer_duration
        mask = (combined_time >= start_time) & (combined_time < start_time + layer_duration)
        combined_signal[mask] = np.interp(combined_time[mask], start_time + time, signal_mv)

        if debug:
            print(f"  - Added layer {index+1} signal from {start_time}s to {start_time + layer_duration}s")

    # Baseline correction and amplitude scaling (target peak 2.0 mV)
    combined_signal, signal_peak, scaling_factor = _normalize_amplitude(combined_signal, target_amplitude=2.0)

    if debug:
        print(f"Signal peak before scaling: {signal_peak:.2f} mV")
        if scaling_factor is not None:
            print(f"Applied scaling factor: {scaling_factor:.2f}")
            print(f"Signal peak after scaling: {np.max(np.abs(combined_signal)):.2f} mV")

    # Convert to 16-bit integers (1000 units per mV) and save as raw .dat
    int_signal = _to_int16(combined_signal)
    int_signal.tofile(output_file)

    if debug:
        print(f"Saved signal to {output_file} with {len(int_signal)} samples")
        print(f"Integer signal range: {np.min(int_signal)} to {np.max(int_signal)}")

    return output_file, segment_files

# Split one lead of a DAT recording into fixed-length segment files
def split_dat_into_segments(file_path, segment_duration=10.0, debug=False):
    """
    Split one lead of a DAT file into consecutive equal-length segment files.

    The recording is loaded with ``load_dat_signal``. For multi-lead data,
    column 1 is used (presumably Lead II); column 0 is used for a single lead.
    Each full segment is written as raw samples (same dtype as loaded) to
    ``<file_path without extension>_segment<N>.dat``. Trailing samples that do
    not fill a whole segment are dropped.

    Args:
        file_path (str): Path to the DAT file (with or without ``.dat``)
        segment_duration (float): Duration of each segment in seconds
        debug (bool): Whether to print debug information

    Returns:
        list: Paths to the segment files. On any error an empty list is
        returned, and segment files already written by this call are removed,
        so a failed split leaves nothing behind.
    """
    segment_files = []
    try:
        signal_all_leads, fs = load_dat_signal(file_path, debug=debug)

        if debug:
            print(f"Loaded signal with shape {signal_all_leads.shape}")

        # Column 1 when there is more than one lead, otherwise column 0
        lead_index = 1 if signal_all_leads.shape[1] > 1 else 0
        signal = signal_all_leads[:, lead_index]

        samples_per_segment = int(segment_duration * fs)
        if samples_per_segment <= 0:
            raise ValueError(
                f"segment_duration={segment_duration}s is shorter than one sample at {fs} Hz"
            )
        num_segments = len(signal) // samples_per_segment

        if debug:
            print(f"Splitting signal into {num_segments} segments of {segment_duration} seconds each")

        base_name = os.path.splitext(file_path)[0]

        for i in range(num_segments):
            start = i * samples_per_segment
            segment = signal[start:start + samples_per_segment]

            segment_file = f"{base_name}_segment{i+1}.dat"
            # Record the path before writing so a partially written file is also cleaned up
            segment_files.append(segment_file)
            segment.reshape(-1, 1).tofile(segment_file)

            if debug:
                print(f"Saved segment {i+1} to {segment_file}")

        return segment_files

    except Exception as e:
        if debug:
            print(f"Error splitting DAT file: {str(e)}")
        # The caller only sees [] and will not delete anything, so clean up here.
        for segment_file in segment_files:
            try:
                os.remove(segment_file)
            except OSError:
                pass  # best effort: the file may never have been created
        return []

# Load a headerless .dat ECG file into a (samples, leads) array
def load_dat_signal(file_path, n_leads=12, n_samples=5000, dtype=np.int16, debug=False, fs=500):
    """
    Load a headerless DAT file containing ECG signal data.

    Shape inference:
    - exactly ``n_leads * n_samples`` values: reshaped to
      ``(n_samples, n_leads)``, i.e. values are assumed to be interleaved
      sample-major (one row per sample, one column per lead);
    - any other size: read as a single lead ``(raw.size, 1)``. Without a
      header, N interleaved leads cannot be told apart from one longer lead.

    Args:
        file_path (str): Path to the DAT file (with or without ``.dat``)
        n_leads (int): Expected number of leads
        n_samples (int): Expected number of samples per lead
        dtype: Data type of the stored samples
        debug (bool): Whether to print debug information
        fs (int): Sampling frequency (Hz) to report; the file does not store it

    Returns:
        tuple: (numpy array of shape (samples, leads), sampling frequency).
        On any read error, returns an all-zero ``(n_samples, 1)`` array of
        ``dtype`` instead of raising.
    """
    try:
        # Handle both cases: with and without .dat extension
        dat_path = file_path if file_path.endswith('.dat') else file_path + '.dat'

        if debug:
            print(f"Loading signal from: {dat_path}")

        raw = np.fromfile(dat_path, dtype=dtype)

        if debug:
            print(f"Raw data size: {raw.size}")

        # Normal case when size matches expectation
        if raw.size == n_leads * n_samples:
            return raw.reshape(n_samples, n_leads), fs

        if debug:
            print(f"Unexpected size: {raw.size}, expected {n_leads * n_samples}")
            print("Attempting to infer number of leads...")

        if raw.size == n_samples:
            if debug:
                print("Detected single lead signal")
            return raw.reshape(n_samples, 1), fs

        # Any size is divisible by 1, so a single lead is the only interpretation
        # that can be chosen without extra information.
        if debug:
            print(f"Treating data as a single lead with {raw.size} samples")
        return raw.reshape(-1, 1), fs

    except Exception as e:
        if debug:
            print(f"Error loading DAT file: {str(e)}")
        # Best-effort fallback: an empty single-lead signal of the requested dtype,
        # so anything written from it (e.g. segment files) keeps the same on-disk format.
        return np.zeros((n_samples, 1), dtype=dtype), fs

# Statistical and wavelet features used as classifier inputs
def extract_features_from_signal(signal, debug=False):
    """
    Build the statistical + wavelet feature vector for one ECG signal.

    The first 8 entries are: mean, standard deviation, median, minimum,
    maximum, 25th and 75th percentiles, and mean first difference. Then, for
    every band of a 5-level 'db4' wavelet decomposition (approximation first,
    then detail bands), 4 entries are added: mean, std, min, max of that band.

    Args:
        signal (numpy.ndarray): 1-D ECG signal
        debug (bool): Whether to print debug information

    Returns:
        list: Feature values in the fixed order described above
    """
    if debug:
        print("Extracting features from signal...")

    feats = [
        np.mean(signal),
        np.std(signal),
        np.median(signal),
        np.min(signal),
        np.max(signal),
        np.percentile(signal, 25),
        np.percentile(signal, 75),
        np.mean(np.diff(signal)),
    ]

    if debug:
        print("Computing wavelet decomposition...")

    for level, band in enumerate(pywt.wavedec(signal, 'db4', level=5)):
        feats.extend((np.mean(band), np.std(band), np.min(band), np.max(band)))
        # Only the approximation band (level 0) is echoed in debug mode
        if debug and level == 0:
            print(f"Wavelet features for level {level}: mean={np.mean(band):.4f}, std={np.std(band):.4f}")

    if debug:
        print(f"Extracted {len(feats)} features")

    return feats

# Classify a single .dat ECG record with the trained model
def classify_new_ecg(file_path, model, debug=False):
    """
    Classify a single ECG record stored as a headerless .dat file.

    Steps:
    - load the record with ``load_dat_signal`` and pick one lead;
    - z-score normalise it;
    - detect R peaks (XQRS, falling back to GQRS);
    - pass statistical, wavelet and rhythm features to ``model.predict``.

    Args:
        file_path (str): Path to the ECG file (with or without ``.dat``)
        model: Fitted classifier exposing ``predict``. A prediction equal to 1
            is reported as "Abnormal", anything else as "Normal".
        debug (bool): Whether to print debug information

    Returns:
        str: One of:
        - "Normal" or "Abnormal";
        - "Insufficient beats" for a flat/empty signal or fewer than 5
          detected R peaks;
        - an "Error: ..." message.
        Exceptions are never propagated.
    """
    try:
        if debug:
            print(f"Classifying ECG from: {file_path}")

        signal_all_leads, fs = load_dat_signal(file_path, debug=debug)

        if debug:
            print(f"Loaded signal with shape {signal_all_leads.shape}, sampling rate {fs} Hz")

        # Choose lead for analysis: column 1 (presumably Lead II) when several are present
        n_leads = signal_all_leads.shape[1]
        if n_leads == 1:
            lead_index = 0
            if debug:
                print("Using single lead")
        else:
            lead_index = 1 if n_leads > 1 else 0
            if debug:
                print(f"Using lead index {lead_index}")

        signal = signal_all_leads[:, lead_index]

        # Z-score normalisation. A flat or empty signal has no usable standard
        # deviation; dividing by it would turn the whole signal into NaN.
        signal_std = np.std(signal)
        if not np.isfinite(signal_std) or signal_std == 0:
            if debug:
                print("Signal is flat or empty; no beats can be detected")
            return "Insufficient beats"
        signal = (signal - np.mean(signal)) / signal_std

        if debug:
            print("Signal normalized")
            print("Detecting QRS complexes...")

        # Detect QRS complexes
        try:
            xqrs = processing.XQRS(sig=signal, fs=fs)
            xqrs.detect()
            r_peaks = xqrs.qrs_inds
            if debug:
                print(f"Detected {len(r_peaks)} QRS complexes with XQRS method")
        except Exception as e:
            if debug:
                print(f"XQRS detection failed: {str(e)}")
                print("Falling back to GQRS detector")
            r_peaks = processing.gqrs_detect(sig=signal, fs=fs)
            if debug:
                print(f"Detected {len(r_peaks)} QRS complexes with GQRS method")

        if len(r_peaks) < 5:
            if debug:
                print(f"Insufficient beats detected: {len(r_peaks)}")
            return "Insufficient beats"

        # Spacing between successive R peaks, in samples and in seconds.
        # NOTE: the original code labelled the sample spacing "QRS durations";
        # it is NOT a QRS width. It is kept unchanged because the model was
        # presumably trained on these exact values.
        rr_samples = np.diff(r_peaks)
        rr_intervals = rr_samples / fs

        if debug:
            print(f"Mean RR interval: {np.mean(rr_intervals):.4f} s")

        features = extract_features_from_signal(signal, debug=debug)

        # Rhythm features. At least 5 peaks guarantees at least 4 intervals,
        # so no empty-array guards are needed.
        features.extend([
            len(r_peaks),
            np.mean(rr_intervals),
            np.std(rr_intervals),
            np.median(rr_intervals),
            np.mean(rr_samples),
            np.std(rr_samples),
        ])

        if debug:
            print(f"Final feature vector length: {len(features)}")

        prediction = model.predict([features])[0]
        result = "Abnormal" if prediction == 1 else "Normal"

        if debug:
            print(f"Classification result: {result} (prediction value: {prediction})")

        return result

    except Exception as e:
        error_msg = f"Error: {str(e)}"
        if debug:
            print(error_msg)
        return error_msg

# Classify a PDF or DAT ECG by majority vote over its 10-second segments
def classify_ecg(file_path, model, is_pdf=False, debug=False):
    """
    Wrapper that handles both PDF and DAT ECG files with segment voting.

    PDF input:
    - digitized with ``digitize_ecg_from_pdf``, which writes
      ``<pdf stem>_digitized.dat`` to the current directory (it is not
      deleted here);
    - its per-strip segment files are then classified.

    DAT input:
    - split with ``split_dat_into_segments``;
    - if that yields no segments, the whole file is classified directly with
      ``classify_new_ecg`` and that result is returned.

    Segment files are deleted after classification, even on failure.

    Args:
        file_path (str): Path to the ECG file (.pdf, or .dat with/without extension)
        model: The trained model for classification
        is_pdf (bool): Whether the input file is a PDF (True) or DAT (False)
        debug (bool): Enable debug output

    Returns:
        str: One of:
        - "Normal" or "Abnormal" when that label has strictly more segment
          votes;
        - "Inconclusive" on a tie, including when no segment produced a valid
          label;
        - the direct ``classify_new_ecg`` result when a DAT file could not be
          split;
        - an error message.
        Exceptions are never propagated.
    """
    try:
        if model is None:
            return "Error: Model not loaded. Please check model compatibility."

        if is_pdf:
            if debug:
                print(f"Processing PDF file: {file_path}")

            # Extract file name without extension for output
            base_name = os.path.splitext(os.path.basename(file_path))[0]
            output_dat = f"{base_name}_digitized.dat"

            dat_path, segment_files = digitize_ecg_from_pdf(
                pdf_path=file_path,
                output_file=output_dat,
                debug=debug,
            )

            if debug:
                print(f"Digitized ECG saved to: {dat_path}")
                print(f"Created {len(segment_files)} segment files")
        else:
            if debug:
                print(f"Processing DAT file: {file_path}")

            segment_files = split_dat_into_segments(file_path, debug=debug)

            if not segment_files:
                # If splitting failed, try classifying the whole file
                return classify_new_ecg(file_path, model, debug=debug)

        segment_results = []
        try:
            for i, segment_file in enumerate(segment_files, start=1):
                if debug:
                    print(f"\n--- Processing Segment {i} ---")

                # classify_new_ecg expects the path without the .dat extension
                segment_path = os.path.splitext(segment_file)[0]
                result = classify_new_ecg(segment_path, model, debug=debug)

                if debug:
                    print(f"Segment {i} classification: {result}")

                segment_results.append(result)
        finally:
            # Segment files are temporary; remove them even if classification failed.
            for segment_file in segment_files:
                try:
                    os.remove(segment_file)
                except OSError as cleanup_error:
                    if debug:
                        print(f"Could not remove segment file {segment_file}: {cleanup_error}")
                else:
                    if debug:
                        print(f"Removed temporary segment file: {segment_file}")

        if not segment_results:
            return "Error: No valid segments to classify"

        normal_count = segment_results.count("Normal")
        abnormal_count = segment_results.count("Abnormal")
        error_count = len(segment_results) - normal_count - abnormal_count

        if debug:
            print("\n--- Voting Results ---")
            print(f"Normal votes: {normal_count}")
            print(f"Abnormal votes: {abnormal_count}")
            print(f"Errors/Inconclusive: {error_count}")

        # Simple majority between "Normal" and "Abnormal" votes. Errors and
        # "Insufficient beats" do not vote, and a tie is "Inconclusive".
        if abnormal_count > normal_count:
            final_result = "Abnormal"
        elif normal_count > abnormal_count:
            final_result = "Normal"
        else:
            final_result = "Inconclusive"

        if debug:
            print(f"Final decision: {final_result}")

        return final_result

    except Exception as e:
        error_msg = f"Classification error: {str(e)}"
        if debug:
            print(error_msg)
        return error_msg

# Load the saved model
# Load the trained classifier into `voting_loaded`. Fallback order:
#   1. ./voting_classifier.pkl
#   2. the first '*voting*.pkl' found under the current directory tree
#   3. an unfitted dummy VotingClassifier (its predict() will fail)
# On any loading error a PlaceholderModel whose predict() raises is used instead.
# NOTE(security): joblib.load unpickles arbitrary objects and can execute code;
# only load model files from trusted locations.
voting_loaded = None  # explicit sentinel; `'name' in locals()` is fooled by earlier notebook cells
try:
    model_path = 'voting_classifier.pkl'
    if os.path.exists(model_path):
        voting_loaded = joblib.load(model_path)
        print(f"Successfully loaded classification model from {model_path}")
    else:
        print(f"Model file not found: {model_path}")
        print("Attempting to locate model file...")

        # Search the current directory and its subdirectories (os.walk does not
        # visit parents). Sorting makes the pick deterministic among several candidates.
        for root, dirs, files in os.walk('.'):
            dirs.sort()
            candidate = next(
                (name for name in sorted(files)
                 if name.endswith('.pkl') and 'voting' in name.lower()),
                None,
            )
            if candidate is not None:
                model_path = os.path.join(root, candidate)
                print(f"Found potential model file: {model_path}")
                voting_loaded = joblib.load(model_path)
                print(f"Successfully loaded model from {model_path}")
                break

        if voting_loaded is None:
            # If we still can't find it, create a dummy model for demonstration
            print("No model found. Creating a dummy model for demonstration.")
            from sklearn.ensemble import VotingClassifier
            from sklearn.tree import DecisionTreeClassifier

            dummy_model = DecisionTreeClassifier(max_depth=2)
            voting_loaded = VotingClassifier(estimators=[('dt', dummy_model)], voting='soft')
            # Unfitted: predict() raises, so classifications come back as errors
            print("WARNING: Using a dummy model that won't provide valid predictions")
except Exception as e:
    print(f"Error loading model: {str(e)}")
    print("Creating a placeholder model object")

    class PlaceholderModel:
        """Stand-in whose predict() always raises, so callers report an error."""

        def predict(self, X):
            raise RuntimeError("No valid model loaded")

    voting_loaded = PlaceholderModel()

# Smoke test: classify sample.pdf from the working directory if it exists.
print("\n----- Testing PDF file classification with segment voting -----")
test_pdf_path = "sample.pdf"
try:
    if not os.path.exists(test_pdf_path):
        print(f"Test PDF file not found: {test_pdf_path}")
        print("Skipping PDF file classification test.")
    else:
        print(f"Found test PDF file: {test_pdf_path}")
        result_pdf = classify_ecg(test_pdf_path, voting_loaded, is_pdf=True, debug=True)
        print(f"PDF file classification result (segment voting): {result_pdf}")
except Exception as err:
    print("Error testing PDF file: " + str(err))

Successfully loaded classification model from voting_classifier.pkl

----- Testing PDF file classification with segment voting -----
Found test PDF file: sample.pdf
Processing PDF file: sample.pdf
Starting ECG digitization from PDF: sample.pdf
Converted PDF to image: temp_ecg_image.jpg
Image dimensions: 2200x1700
Calibration parameters: {'seconds_per_pixel': 0.01015228426395939, 'mv_per_pixel': 0.012690355329949238}
Layer 1 boundaries: 600-879
Layer 2 boundaries: 880-1179
Layer 3 boundaries: 1180-1480
Processing layer 1...
  - Found 3912 contours
  - Selected contour with 3830 points
  - Layer 1 signal range: -0.81 mV to 1.49 mV
Processing layer 2...
  - Found 4864 contours
  - Selected contour with 3570 points
  - Layer 2 signal range: -0.56 mV to 1.18 mV
Processing layer 3...
  - Found 4067 contours
  - Selected contour with 3917 points
  - Layer 3 signal range: -0.49 mV to 1.20 mV
Saved segment 1 to sample_digitized_segment1.dat
Saved segment 2 to sample_digitized_segment2.dat
Saved