# Testing the Lead V6 extraction method on a few sample patients

In [4]:
import wfdb
import numpy as np
from scipy import signal as sig
from scipy.ndimage import median_filter
import matplotlib.pyplot as plt
from pathlib import Path

# Root of the local PTB-XL download; record paths below are relative to it
base_path = Path("/Users/rohanmotanavar/datasets/PTB_XL")

# Hand-picked sample records (500 Hz "_hr" versions, no file extension),
# grouped by the records500 sub-folder they live in
patients = [
    f"records500/{folder}/{record_id}_hr"
    for folder, record_ids in (
        ("00000", ("00001", "00300", "00092", "00427", "00896", "00900")),
        ("01000", ("01379", "01739")),
    )
    for record_id in record_ids
]

# Function to preprocess the ECG signal (baseline drift removal and noise reduction)
def preprocess_ecg_signal(ecg_signal, fs, median_window_sec=0.2, lowpass_cutoff=40):
    """
    Preprocess a single-lead ECG signal: remove baseline drift, then high-frequency noise.

    The baseline is estimated with a median filter and subtracted; the result is then
    smoothed with a 2nd-order Butterworth low-pass filter applied forward and backward
    (``filtfilt``), which gives zero phase shift.

    Parameters:
        ecg_signal (np.ndarray): The ECG signal (1-D array).
        fs (float): Sampling frequency in Hz.
        median_window_sec (float): Median-filter window length in seconds. It must span
            at least 3 samples; a 1-sample "baseline" equals the signal itself and would
            turn the output into all zeros.
        lowpass_cutoff (float): Low-pass cutoff in Hz, strictly between 0 and fs / 2.

    Returns:
        np.ndarray or None: The preprocessed signal (same length as the input), or None
        if validation or filtering fails (the error is printed).
    """
    try:
        # Validate inputs
        if not isinstance(ecg_signal, np.ndarray):
            raise TypeError("Input signal must be a numpy array.")
        if ecg_signal.size == 0:
            raise ValueError("Input signal is empty.")
        # A 2-D array would be median-filtered across rows (leads) as well as time.
        if ecg_signal.ndim != 1:
            raise ValueError(f"Input signal must be 1-D, got shape {ecg_signal.shape}.")
        if fs <= 0:
            raise ValueError("Sampling frequency must be positive.")
        if lowpass_cutoff <= 0 or lowpass_cutoff >= fs / 2:
            raise ValueError("Low-pass cutoff frequency must be between 0 and Nyquist frequency.")

        # Step 1: Remove baseline drift using median filtering.
        # Round (rather than truncate) seconds -> samples so float error cannot drop a
        # sample, then force an odd length so the window has a single centre sample.
        median_window_samples = int(round(median_window_sec * fs))
        if median_window_samples % 2 == 0:
            median_window_samples += 1
        if median_window_samples < 3:
            raise ValueError(
                f"Median window of {median_window_sec} s at {fs} Hz is shorter than 3 samples."
            )
        # NOTE(review): a single 200 ms median window may partially track wider waves
        # (e.g. T waves) and attenuate them when subtracted; confirm this is acceptable.
        baseline = median_filter(ecg_signal, size=median_window_samples)
        signal_detrended = ecg_signal - baseline

        # Step 2: Remove high-frequency noise using a low-pass Butterworth filter (order 2)
        nyquist_freq = fs / 2
        normalized_lowpass_cutoff = lowpass_cutoff / nyquist_freq
        b_low, a_low = sig.butter(2, normalized_lowpass_cutoff, btype='low', analog=False)
        # filtfilt runs the filter forward and backward, so the output has no phase lag
        filtered_signal = sig.filtfilt(b_low, a_low, signal_detrended)

        return filtered_signal

    except Exception as e:
        # Best-effort contract: callers check for None instead of catching exceptions
        print(f"Error in ECG preprocessing: {e}")
        return None

# Function to extract Lead V6 and preprocess the signal
def extract_lead_v6(record_path, fs=500, plot=False, plot_dir="Visualization"):
    """
    Extract Lead V6 from a PTB-XL record, preprocess it, and optionally save a plot.

    Parameters:
        record_path (Path or str): Path to the record (without extension).
        fs (float): Expected sampling frequency in Hz.
        plot (bool): If True, save an original-vs-filtered figure as
            ``<plot_dir>/<record name>_lead_v6.png``.
        plot_dir (Path or str): Output directory for the figure; created if missing.

    Returns:
        np.ndarray: The preprocessed Lead V6 signal, or None if loading or
        preprocessing fails. A plotting failure is reported but does not discard an
        otherwise valid signal.
    """
    try:
        # Accept str as well as Path; .name is needed for the plot file name
        record_path = Path(record_path)
        record = wfdb.rdrecord(str(record_path))

        # Verify the sampling frequency
        if record.fs != fs:
            raise ValueError(f"Sampling frequency mismatch: expected {fs} Hz, got {record.fs} Hz")

        # Extract the signal data and verify the number of leads
        signals = record.p_signal
        if signals.shape[1] != 12:
            raise ValueError(f"Expected 12 leads, got {signals.shape[1]} leads")

        # Locate V6 by channel name rather than trusting column order; fall back to the
        # standard 12-lead layout (I, II, III, aVR, aVL, aVF, V1-V6 -> index 11) when the
        # header carries no 'V6' label.
        lead_names = [str(name).upper() for name in (record.sig_name or [])]
        v6_index = lead_names.index("V6") if "V6" in lead_names else 11
        lead_v6 = signals[:, v6_index]

        # Preprocess the signal (baseline drift removal and noise reduction)
        lead_v6_filtered = preprocess_ecg_signal(lead_v6, fs, median_window_sec=0.2, lowpass_cutoff=40)
        if lead_v6_filtered is None:
            raise ValueError("Signal preprocessing failed.")

    except Exception as e:
        print(f"Error processing {record_path}: {e}")
        return None

    # Plotting is a side feature: report its failures without discarding the signal
    if plot:
        try:
            _save_lead_v6_plot(lead_v6, lead_v6_filtered, fs, record_path.name, plot_dir)
        except Exception as e:
            print(f"Error saving Lead V6 plot for {record_path.name}: {e}")

    return lead_v6_filtered


def _save_lead_v6_plot(lead_v6, lead_v6_filtered, fs, record_name, plot_dir):
    """Save a two-panel figure (original vs. filtered Lead V6) to ``plot_dir``."""
    time = np.arange(len(lead_v6)) / fs
    fig = plt.figure(figsize=(12, 6))
    try:
        plt.subplot(2, 1, 1)
        plt.plot(time, lead_v6, label="Original Lead V6", color='red')
        plt.title(f"Original Lead V6 - {record_name}")
        plt.xlabel("Time (s)")
        plt.ylabel("Amplitude (mV)")
        plt.grid(True)
        plt.legend()

        plt.subplot(2, 1, 2)
        plt.plot(time, lead_v6_filtered, label="Filtered Lead V6", color='black')
        plt.title("Filtered Lead V6 (Baseline Drift and Noise Removed)")
        plt.xlabel("Time (s)")
        plt.ylabel("Amplitude (mV)")
        plt.grid(True)
        plt.legend()

        plt.tight_layout()

        plot_dir = Path(plot_dir)
        # savefig does not create missing directories
        plot_dir.mkdir(parents=True, exist_ok=True)
        fig.savefig(plot_dir / f"{record_name}_lead_v6.png")
    finally:
        # The figure is saved rather than displayed; always close it (even on error)
        # so a batch run does not accumulate open figures.
        plt.close(fig)

# Function to test the methodology on the specified patients
def test_on_specific_patients(base_path, patients_list):
    """
    Run Lead V6 extraction (with plot saving) on each listed record and report the outcome.

    Parameters:
        base_path (Path): Root of the PTB-XL dataset.
        patients_list (list of str): Record paths relative to base_path, without extension.
    """
    total = len(patients_list)

    for position, relative_path in enumerate(patients_list, start=1):
        full_path = base_path / relative_path
        print(f"Processing record {position}/{total}: {full_path.name}")

        # Skip records whose .dat signal file is missing instead of attempting to load them
        if not full_path.with_suffix(".dat").exists():
            print(f"Record {full_path} does not exist.")
            continue

        extracted = extract_lead_v6(full_path, fs=500, plot=True)
        if extracted is None:
            print(f"Failed to extract Lead V6 for {full_path.name}")
        else:
            print(f"Successfully extracted Lead V6 for {full_path.name}, length: {len(extracted)} samples")

# Run the test on the specified patients
if __name__ == "__main__":
    # Smoke-test extraction and preprocessing (with saved plots) on the hand-picked records
    print("Testing Lead V6 extraction and preprocessing on specified patients...")
    test_on_specific_patients(base_path=base_path, patients_list=patients)

Testing Lead V6 extraction and preprocessing on specified patients...
Processing record 1/8: 00001_hr
Successfully extracted Lead V6 for 00001_hr, length: 5000 samples
Processing record 2/8: 00300_hr
Successfully extracted Lead V6 for 00300_hr, length: 5000 samples
Processing record 3/8: 00092_hr
Successfully extracted Lead V6 for 00092_hr, length: 5000 samples
Processing record 4/8: 00427_hr
Successfully extracted Lead V6 for 00427_hr, length: 5000 samples
Processing record 5/8: 00896_hr
Successfully extracted Lead V6 for 00896_hr, length: 5000 samples
Processing record 6/8: 00900_hr
Successfully extracted Lead V6 for 00900_hr, length: 5000 samples
Processing record 7/8: 01379_hr
Successfully extracted Lead V6 for 01379_hr, length: 5000 samples
Processing record 8/8: 01739_hr
Successfully extracted Lead V6 for 01739_hr, length: 5000 samples


# Iterating over the entire PTB-XL dataset

In [2]:
import wfdb
import numpy as np
from scipy import signal as sig
from scipy.ndimage import median_filter
import matplotlib.pyplot as plt
from pathlib import Path
import h5py
from tqdm import tqdm
import pandas as pd

# Root of the local PTB-XL download and the metadata table listing every record
base_path = Path("/Users/rohanmotanavar/datasets/PTB_XL")
csv_path = base_path.joinpath("ptbxl_database.csv")

# Function to preprocess the ECG signal (baseline drift removal and noise reduction)
def preprocess_ecg_signal(ecg_signal, fs, median_window_sec=0.2, lowpass_cutoff=40):
    """
    Preprocess a single-lead ECG signal: remove baseline drift, then high-frequency noise.

    The baseline is estimated with a median filter and subtracted; the result is then
    smoothed with a 2nd-order Butterworth low-pass filter applied forward and backward
    (``filtfilt``), which gives zero phase shift.

    Parameters:
        ecg_signal (np.ndarray): The ECG signal (1-D array).
        fs (float): Sampling frequency in Hz.
        median_window_sec (float): Median-filter window length in seconds. It must span
            at least 3 samples; a 1-sample "baseline" equals the signal itself and would
            turn the output into all zeros.
        lowpass_cutoff (float): Low-pass cutoff in Hz, strictly between 0 and fs / 2.

    Returns:
        np.ndarray or None: The preprocessed signal (same length as the input), or None
        if validation or filtering fails (the error is printed).
    """
    try:
        # Validate inputs
        if not isinstance(ecg_signal, np.ndarray):
            raise TypeError("Input signal must be a numpy array.")
        if ecg_signal.size == 0:
            raise ValueError("Input signal is empty.")
        # A 2-D array would be median-filtered across rows (leads) as well as time.
        if ecg_signal.ndim != 1:
            raise ValueError(f"Input signal must be 1-D, got shape {ecg_signal.shape}.")
        if fs <= 0:
            raise ValueError("Sampling frequency must be positive.")
        if lowpass_cutoff <= 0 or lowpass_cutoff >= fs / 2:
            raise ValueError("Low-pass cutoff frequency must be between 0 and Nyquist frequency.")

        # Step 1: Remove baseline drift using median filtering.
        # Round (rather than truncate) seconds -> samples so float error cannot drop a
        # sample, then force an odd length so the window has a single centre sample.
        median_window_samples = int(round(median_window_sec * fs))
        if median_window_samples % 2 == 0:
            median_window_samples += 1
        if median_window_samples < 3:
            raise ValueError(
                f"Median window of {median_window_sec} s at {fs} Hz is shorter than 3 samples."
            )
        # NOTE(review): a single 200 ms median window may partially track wider waves
        # (e.g. T waves) and attenuate them when subtracted; confirm this is acceptable.
        baseline = median_filter(ecg_signal, size=median_window_samples)
        signal_detrended = ecg_signal - baseline

        # Step 2: Remove high-frequency noise using a low-pass Butterworth filter (order 2)
        nyquist_freq = fs / 2
        normalized_lowpass_cutoff = lowpass_cutoff / nyquist_freq
        b_low, a_low = sig.butter(2, normalized_lowpass_cutoff, btype='low', analog=False)
        # filtfilt runs the filter forward and backward, so the output has no phase lag
        filtered_signal = sig.filtfilt(b_low, a_low, signal_detrended)

        return filtered_signal

    except Exception as e:
        # Best-effort contract: callers check for None instead of catching exceptions
        print(f"Error in ECG preprocessing: {e}")
        return None

# Function to extract Lead V6 and preprocess the signal
def extract_lead_v6(record_path, fs=500):
    """
    Extract Lead V6 from a PTB-XL record and preprocess the signal.

    Parameters:
        record_path (Path or str): Path to the record (without extension).
        fs (float): Expected sampling frequency in Hz.

    Returns:
        np.ndarray: The preprocessed Lead V6 signal, or None if processing fails
        (the error is printed).
    """
    try:
        # Load the ECG record
        record = wfdb.rdrecord(str(record_path))

        # Verify the sampling frequency
        if record.fs != fs:
            raise ValueError(f"Sampling frequency mismatch: expected {fs} Hz, got {record.fs} Hz")

        # Extract the signal data and verify the number of leads
        signals = record.p_signal
        if signals.shape[1] != 12:
            raise ValueError(f"Expected 12 leads, got {signals.shape[1]} leads")

        # Locate V6 by channel name rather than trusting column order; fall back to the
        # standard 12-lead layout (I, II, III, aVR, aVL, aVF, V1-V6 -> index 11) when the
        # header carries no 'V6' label.
        lead_names = [str(name).upper() for name in (record.sig_name or [])]
        v6_index = lead_names.index("V6") if "V6" in lead_names else 11
        lead_v6 = signals[:, v6_index]

        # Preprocess the signal (baseline drift removal and noise reduction)
        lead_v6_filtered = preprocess_ecg_signal(lead_v6, fs, median_window_sec=0.2, lowpass_cutoff=40)
        if lead_v6_filtered is None:
            raise ValueError("Signal preprocessing failed.")

        return lead_v6_filtered

    except Exception as e:
        # Best-effort per record: the batch loop skips records that return None
        print(f"Error processing {record_path}: {e}")
        return None

# Function to process the entire dataset using the CSV file and store in HDF5 and CSV with a progress bar
def process_entire_dataset(csv_path, base_path, output_hdf5_path="preprocessed_lead_v6.h5",
                           output_csv_path="preprocessed_lead_v6.csv", signal_length=5000):
    """
    Extract and preprocess Lead V6 for every record listed in ptbxl_database.csv.

    Each successfully processed signal is stored in an HDF5 file (one gzip-compressed
    dataset per ``ecg_id``, with ``patient_id``, ``filename_hr`` and ``record_path``
    attributes) and written as one row of a wide CSV file for viewing
    (``ecg_id, patient_id, filename_hr, sample_0 ... sample_{signal_length - 1}``).

    CSV rows are streamed to disk as records are processed, so memory use does not grow
    with the number of records.

    Parameters:
        csv_path (Path or str): Path to the ptbxl_database.csv file.
        base_path (Path or str): Base path to the PTB-XL dataset.
        output_hdf5_path (str): Path to the output HDF5 file.
        output_csv_path (str): Path to the output CSV file for viewing.
        signal_length (int): Number of sample columns in the CSV. Signals of any other
            length are still stored in the HDF5 file but left out of the CSV.
    """
    import csv  # local import: only this function writes CSV rows directly

    base_path = Path(base_path)

    # Load the record index
    print("Loading ptbxl_database.csv...")
    df = pd.read_csv(csv_path)
    records = df[['ecg_id', 'patient_id', 'filename_hr']]
    total_records = len(records)
    print(f"Total records to process: {total_records}")

    header = ['ecg_id', 'patient_id', 'filename_hr'] + [f'sample_{i}' for i in range(signal_length)]
    processed_records = 0
    csv_rows = 0
    length_mismatches = 0

    print(f"Writing preprocessed signals to {output_csv_path} for viewing as they are processed...")
    with h5py.File(output_hdf5_path, 'w') as h5file:
        with open(output_csv_path, 'w', newline='', encoding='utf-8') as csv_file:
            writer = csv.writer(csv_file)
            # Header is written up front so the CSV is well-formed even if nothing succeeds
            writer.writerow(header)

            rows = records.itertuples(index=False, name=None)
            for ecg_id, patient_id, filename_hr in tqdm(rows, total=total_records,
                                                        desc="Processing PTB-XL dataset", unit="record"):
                record_path = base_path / filename_hr

                lead_v6 = extract_lead_v6(record_path, fs=500)
                if lead_v6 is None:
                    continue  # extract_lead_v6 has already printed the reason

                # Use the ecg_id as the HDF5 dataset key for uniqueness
                dset = h5file.create_dataset(str(ecg_id), data=lead_v6, compression='gzip', compression_opts=4)
                dset.attrs['patient_id'] = patient_id
                dset.attrs['filename_hr'] = filename_hr
                dset.attrs['record_path'] = str(record_path)
                processed_records += 1

                # The CSV has a fixed number of sample columns; a signal of a different
                # length would otherwise produce a ragged or truncated row.
                if len(lead_v6) != signal_length:
                    length_mismatches += 1
                    tqdm.write(f"Skipping CSV row for ecg_id {ecg_id}: expected {signal_length} samples, "
                               f"got {len(lead_v6)}")
                    continue

                writer.writerow([ecg_id, patient_id, filename_hr, *lead_v6.tolist()])
                csv_rows += 1

    print(f"\nProcessing complete. Successfully processed: {processed_records}/{total_records} records")
    if length_mismatches:
        print(f"{length_mismatches} signal(s) had an unexpected length and were stored in HDF5 only.")
    print(f"CSV file saved successfully: {output_csv_path} ({csv_rows} rows)")

# Run the processing for the entire dataset
if __name__ == "__main__":
    # Full-dataset pass: writes the HDF5 archive plus the wide CSV for inspection
    print("Starting processing of the entire PTB-XL dataset using ptbxl_database.csv...")
    process_entire_dataset(
        csv_path=csv_path,
        base_path=base_path,
        output_hdf5_path="preprocessed_lead_v6.h5",
        output_csv_path="preprocessed_lead_v6.csv",
    )

Starting processing of the entire PTB-XL dataset using ptbxl_database.csv...
Loading ptbxl_database.csv...
Total records to process: 21799


Processing PTB-XL dataset: 100%|██████████| 21799/21799 [01:57<00:00, 184.77record/s]



Processing complete. Successfully processed: 21799/21799 records

Saving preprocessed signals to preprocessed_lead_v6.csv for viewing...
CSV file saved successfully: preprocessed_lead_v6.csv
