# Create Modified Reference Table from CSV

In [1]:
import pandas as pd

# Load the subject metadata, keeping only the four columns used in this notebook.
refTable = pd.read_csv('resources/subject-info.csv')[
    ['Patient ID', 'LVEF (%)', 'Cause of death', 'Exit of the study']
]

refTable = (
    refTable
    # Shorter, space-free names for the two columns referenced later on.
    .rename(columns={'Patient ID': 'Patient_ID', "LVEF (%)": "LVEF"})
    .astype({'LVEF': float})
    # A missing 'Exit of the study' entry is treated the same as the value 0.
    .fillna({'Exit of the study': 0})
    # Keep only rows where both 'Exit of the study' and 'Cause of death' equal 0.
    .loc[lambda table: (table['Exit of the study'] == 0) & (table['Cause of death'] == 0)]
    # Drop rows that still contain a missing value in any kept column (e.g. LVEF).
    .dropna()
    # Renumber the surviving rows 0..n-1.
    .reset_index(drop=True)
)
refTable.head()

Unnamed: 0,Patient_ID,LVEF,Cause of death,Exit of the study
0,P0001,35.0,0,0.0
1,P0002,35.0,0,0.0
2,P0003,39.0,0,0.0
3,P0004,38.0,0,0.0
4,P0005,34.0,0,0.0


In [2]:
# Persist the filtered table; the TFRecord driver cell further down reloads it from this path.
# index=False keeps the positional row index out of the file.
refTable.to_csv('resources/reference-table.csv', index = False)

# Signal Retrieval and Processing

## Signal processing (the whole dataset is downloaded to local storage with wget) [Employed]

### Imports

In [3]:
import os
import logging
import datetime
import re
from tqdm import tqdm
import pandas as pd
import numpy as np
import tensorflow as tf
import wfdb
from scipy.signal import butter, lfilter, filtfilt
from IPython import get_ipython
import matplotlib.pyplot as plt

### Logging

In [4]:
def setup_logger(log_dir):
    """Return this module's logger, configured to write to a new timestamped file.

    The log file is named ``<YYYYmmdd_HHMMSS>.log`` and is created inside
    ``log_dir`` (the directory is created when missing). The logger level is
    set to INFO and each line is formatted as
    ``<time> - <level> - <message>``.

    The function is safe to call repeatedly (for example when a notebook cell
    is re-run): file handlers attached by earlier calls are detached and
    closed first, so every message is written exactly once, to the newest file.

    Args:
        log_dir (str): Directory in which the log file is created.

    Returns:
        logging.Logger: The logger registered under this module's ``__name__``.
    """
    os.makedirs(log_dir, exist_ok=True)

    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    log_filename = f"{timestamp}.log"
    log_filepath = os.path.join(log_dir, log_filename)

    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)

    # logging.getLogger hands back the same object on every call, so without
    # this cleanup each call would stack another FileHandler: messages would be
    # duplicated into all previous log files and their handles never released.
    for stale_handler in list(logger.handlers):
        if isinstance(stale_handler, logging.FileHandler):
            logger.removeHandler(stale_handler)
            stale_handler.close()

    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler(log_filepath)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    return logger

### Filter for ECG Signals

In [5]:
def butter_bandpass(lowcut, highcut, fs, order=5):
    """Design a Butterworth band-pass filter.

    Both cut-off frequencies are divided by the Nyquist frequency (half of
    ``fs``) before being handed to ``scipy.signal.butter``.

    Args:
        lowcut: Lower cut-off frequency, in the same unit as ``fs``.
        highcut: Upper cut-off frequency, in the same unit as ``fs``.
        fs: Sampling frequency of the signal the filter will be applied to.
        order: Order passed to ``butter``.

    Returns:
        tuple: ``(b, a)`` numerator and denominator coefficient arrays.
    """
    nyquist = fs * 0.5
    normalized_band = [edge / nyquist for edge in (lowcut, highcut)]
    numerator, denominator = butter(order, normalized_band, btype='band')
    return numerator, denominator

def butter_bandpass_filter(r_signal, lowcut, highcut, fs, order=4):
    """Band-pass filter ``r_signal`` with a forward-backward Butterworth filter.

    The coefficients come from ``butter_bandpass`` and are applied with
    ``scipy.signal.filtfilt``.

    Args:
        r_signal: Raw signal samples to filter.
        lowcut: Lower cut-off frequency, in the same unit as ``fs``.
        highcut: Upper cut-off frequency, in the same unit as ``fs``.
        fs: Sampling frequency of ``r_signal``.
        order: Filter order forwarded to ``butter_bandpass``.

    Returns:
        The filtered signal as returned by ``filtfilt``.
    """
    coefficients = butter_bandpass(lowcut, highcut, fs, order=order)
    return filtfilt(*coefficients, r_signal)

### Plot for Inspection

In [7]:
def plot_and_save_batch(signals, patient_ids, output_dir, filename, sampling_rate, plots_per_image=9):
    """Plot up to ``plots_per_image`` ECG traces on one grid figure and save it as an image.

    The subplot grid is sized from ``plots_per_image`` (the default of 9 gives
    a 3x3 grid). Slots that have no signal are removed from the figure.
    Nothing is written when ``signals`` is empty.

    Args:
        signals: Sequence of 1-D sample arrays; only the first
            ``plots_per_image`` entries are drawn.
        patient_ids: Sequence of labels; ``patient_ids[i]`` is used in the
            title of the subplot showing ``signals[i]``.
        output_dir (str): Directory for the image; created when missing.
        filename (str): File name of the image inside ``output_dir``.
        sampling_rate: Samples per second; sample indices are divided by it to
            build the x axis.
        plots_per_image (int): Number of subplot slots in the figure. Values
            below 1 are treated as 1.
    """
    num_signals = len(signals)
    if num_signals == 0:
        return

    # Size the grid from the requested slot count instead of assuming 3x3, so
    # more than nine slots do not index past the axes array and fewer than
    # nine do not leave empty axes behind.
    grid_slots = max(int(plots_per_image), 1)
    cols = int(np.ceil(np.sqrt(grid_slots)))
    rows = int(np.ceil(grid_slots / cols))
    num_drawn = min(num_signals, grid_slots)

    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid.
    fig, axes = plt.subplots(rows, cols, figsize=(15, 10), squeeze=False)
    try:
        for slot, ax in enumerate(axes.flatten()):
            if slot >= num_drawn:
                fig.delaxes(ax)  # Remove empty subplots
                continue
            trace = np.asarray(signals[slot])
            # Per-trace time axis, so traces of different lengths still plot.
            time = np.arange(trace.shape[0]) / sampling_rate
            ax.plot(time, trace)
            ax.grid(True)
            ax.set_title(f"Patient ID: {patient_ids[slot]}", fontsize=8)
            ax.tick_params(axis='both', which='major', labelsize=6)

        fig.tight_layout()
        os.makedirs(output_dir, exist_ok=True)
        fig.savefig(os.path.join(output_dir, filename))
    finally:
        # Always release the figure, even when drawing or saving fails.
        plt.close(fig)

### Processing into TFRecord

In [10]:
def write_tfrecords(patient_ids, record_dir, refTable, output_dir, inspection_dir, sampling_rate, window_seconds=5, start_offset_seconds=5, lead_index=0, inspection_frequency=10, plots_per_image=9, logger=None):
    """
    Writes one single-example TFRecord per patient and saves inspection plots of the extracted ECG windows.

    For every patient ID the WFDB record ``<record_dir>/<pid>_H`` is read (channel ``lead_index`` only),
    a window of ``window_seconds`` starting ``start_offset_seconds`` into the recording is cut out, and it
    is written together with the patient's LVEF to
    ``<output_dir>/<pid>_lead_<lead_index>_window_<start_offset_seconds>s_<window_seconds>s.tfrecord``
    with the float features ``'signal'`` and ``'lvef'``.

    A patient is skipped (with an info message) when the record's sampling rate differs from
    ``sampling_rate``, the record has no lead, the signal is shorter than the requested window, or the
    patient is missing from ``refTable``. Any exception raised while handling a patient is logged as an
    error and processing continues with the next patient.

    Every time ``plots_per_image`` windows have been collected they are plotted with
    ``plot_and_save_batch`` into ``inspection_batch_<n>_part_1_lead_<lead_index>.png``, where ``<n>`` counts
    the saved images from 1. Windows left over at the end go to ``inspection_batch_final_lead_<lead_index>.png``.

    Args:
        patient_ids (list): List of patient IDs to process.
        record_dir (str): Directory containing the WFDB record files.
        refTable (pd.DataFrame): DataFrame with at least the columns 'Patient_ID' and 'LVEF'.
        output_dir (str): Directory to save the TFRecord files.
        inspection_dir (str): Directory to save inspection plot images.
        sampling_rate (int): Expected sampling rate of the ECG signals.
        window_seconds (int): The size of the window in seconds to extract.
        start_offset_seconds (int): Starting offset in seconds.
        lead_index (int): The index of the lead to process.
        inspection_frequency (int): Accepted for backward compatibility only; it no longer influences
            the output. It previously fed the image numbering, which let two images share one file name.
        plots_per_image (int): The fixed number of plots to include in each inspection image.
        logger (logging.Logger, optional): Logger object. When omitted, messages are printed.
    """
    if logger is None:
        def log_info(message):
            print(message)
        def log_error(message):
            print(f"Error: {message}")
    else:
        log_info = logger.info
        log_error = logger.error

    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(inspection_dir, exist_ok=True)
    all_windows_for_inspection = []
    inspection_patient_ids = []
    # Counts the inspection images written so far. Numbering the files from the
    # patient index (i // inspection_frequency) could produce the same number
    # twice and silently overwrite an earlier image.
    inspection_batch_number = 0

    def _float_feature(value):
        return tf.train.Feature(float_list=tf.train.FloatList(value=value))

    target_window_size = int(window_seconds * sampling_rate)
    start_index = int(start_offset_seconds * sampling_rate)
    end_index = start_index + target_window_size

    for pid in tqdm(patient_ids, desc=f"Writing TFRecords ({window_seconds} sec, Lead {lead_index})"):
        record_path = f"{record_dir}/{pid}_H"
        try:
            record = wfdb.rdrecord(record_path, channels=[lead_index])
            signals = record.p_signal.astype(np.float32)
            if record.fs != sampling_rate:
                log_info(f"Sampling rate mismatch for patient {pid}. Skipping.")
                continue
            if signals.shape[1] < 1:
                log_info(f"Less than 1 lead for patient {pid}. Skipping.")
                continue
            if len(signals) < end_index:
                log_info(f"Signal too short for patient {pid}. Skipping.")
                continue

            # Only one channel was requested, so it sits in column 0.
            window = signals[start_index:end_index, 0]

            patient_row = refTable[refTable['Patient_ID'] == pid]
            if patient_row.empty:
                log_info(f"Patient {pid} not in refTable. Skipping.")
                continue
            lvef = patient_row['LVEF'].values[0]

            tfrecord_path = os.path.join(output_dir, f"{pid}_lead_{lead_index}_window_{start_offset_seconds}s_{window_seconds}s.tfrecord")
            with tf.io.TFRecordWriter(tfrecord_path) as writer:
                feature = {
                    'signal': _float_feature(window.astype(np.float32).flatten()),
                    'lvef': _float_feature([float(lvef)]),
                }
                example = tf.train.Example(features=tf.train.Features(feature=feature))
                writer.write(example.SerializeToString())

            log_info(f"Processed and wrote TFRecord for patient {pid}.")

            all_windows_for_inspection.append(window)
            inspection_patient_ids.append(pid)

            if len(all_windows_for_inspection) >= plots_per_image:
                inspection_batch_number += 1
                batch_signals = all_windows_for_inspection[:plots_per_image]
                batch_pids = inspection_patient_ids[:plots_per_image]
                # Dequeue before plotting so a failing plot is not retried
                # for every subsequent patient.
                all_windows_for_inspection = all_windows_for_inspection[plots_per_image:]
                inspection_patient_ids = inspection_patient_ids[plots_per_image:]
                plot_filename = f"inspection_batch_{inspection_batch_number}_part_1_lead_{lead_index}.png"
                plot_and_save_batch(batch_signals, batch_pids, inspection_dir, plot_filename, sampling_rate, plots_per_image)

        except Exception as e:
            # Best effort: one bad record must not abort the whole run.
            log_error(f"Error processing patient {pid}: {e}")

    # Save any remaining windows
    if all_windows_for_inspection:
        plot_filename = f"inspection_batch_final_lead_{lead_index}.png"
        plot_and_save_batch(all_windows_for_inspection, inspection_patient_ids, inspection_dir, plot_filename, sampling_rate, plots_per_image)

In [8]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import butter, filtfilt, iirnotch
import pywt # For wavelet transform
import os # Added for file operations in write_tfrecords
import tensorflow as tf # Added for TFRecord operations
import wfdb # Added for WFDB record reading
from tqdm import tqdm # Added for progress bar

# --- 1. Simulated Noisy ECG (Standalone Example) ---
sampling_rate = 1000 # Hz
duration = 5 # seconds
num_samples = sampling_rate * duration
time = np.linspace(0, duration, num_samples, endpoint=False)

# Seeded generator: the simulated noise (and therefore every value computed
# below) is identical on each run of this cell.
rng = np.random.default_rng(42)

# Simulate a basic ECG-like signal as a sum of three harmonics of the heart-rate frequency
ecg_freq = 1.2 # Approximate heart rate in Hz (72 bpm)
ecg_signal = 0.6 * np.sin(2 * np.pi * ecg_freq * time) \
             + 0.2 * np.sin(2 * np.pi * 3 * ecg_freq * time) \
             + 0.1 * np.sin(2 * np.pi * 5 * ecg_freq * time)

# Add baseline wander (low-frequency noise)
baseline_wander_freq = 0.2 # Hz
baseline_wander = 0.15 * np.sin(2 * np.pi * baseline_wander_freq * time)

# Add powerline interference (e.g., 50 Hz or 60 Hz) with a random phase
powerline_freq = 50 # Hz (adjust to 60 Hz if in regions like North America)
powerline_noise = 0.1 * np.sin(2 * np.pi * powerline_freq * time + rng.random() * 2 * np.pi)

# Add high-frequency random noise (e.g., muscle artifacts, instrumentation noise)
random_noise = 0.05 * rng.standard_normal(num_samples)

# Combine to create the noisy ECG signal
noisy_ecg = ecg_signal + baseline_wander + powerline_noise + random_noise

# --- 2. Denoising Algorithm Steps (Standalone Example) ---
# This section remains as a standalone example of the denoising process.

# Step 2.1: Baseline Wander Removal (High-Pass Filter)
cutoff_highpass = 0.5 # Hz
nyquist = 0.5 * sampling_rate
normal_cutoff_highpass = cutoff_highpass / nyquist
b_hp, a_hp = butter(2, normal_cutoff_highpass, btype='high', analog=False)
ecg_no_baseline = filtfilt(b_hp, a_hp, noisy_ecg)

# Step 2.2: Powerline Interference Removal (Notch Filter)
notch_freq = powerline_freq # Frequency to remove (Hz)
quality_factor = 30 # Q-factor, determines the bandwidth of the notch filter
b_notch, a_notch = iirnotch(notch_freq, quality_factor, sampling_rate)
ecg_no_powerline = filtfilt(b_notch, a_notch, ecg_no_baseline)

# Step 2.3: High-Frequency Noise Removal (Wavelet Denoising)
wavelet = 'db4'
decomposition_level = 4
coeffs = pywt.wavedec(ecg_no_powerline, wavelet, level=decomposition_level)
# Threshold: half of std(last detail band) * sqrt(2 * ln(N)), N = signal length.
threshold = 0.5 * np.std(coeffs[-1]) * np.sqrt(2 * np.log(len(ecg_no_powerline)))
# Keep the approximation coefficients untouched; soft-threshold every detail band.
denoised_coeffs = [coeffs[0]]
for i in range(1, len(coeffs)):
    denoised_coeffs.append(pywt.threshold(coeffs[i], threshold, mode='soft'))
final_denoised_ecg = pywt.waverec(denoised_coeffs, wavelet)

# Force the reconstruction back to the input length (trim or zero-pad).
if len(final_denoised_ecg) > len(noisy_ecg):
    final_denoised_ecg = final_denoised_ecg[:len(noisy_ecg)]
elif len(final_denoised_ecg) < len(noisy_ecg):
    padding = np.zeros(len(noisy_ecg) - len(final_denoised_ecg))
    final_denoised_ecg = np.concatenate((final_denoised_ecg, padding))


# --- Integrated write_tfrecords function with Denoising ---
def _denoise_ecg_window(window, sampling_rate, params):
    """Denoise one 1-D ECG window: high-pass filter, powerline notch, then wavelet soft-thresholding.

    ``params`` must provide 'cutoff_highpass', 'powerline_freq', 'notch_quality_factor',
    'wavelet' and 'decomposition_level'. The result has the same number of samples as ``window``.
    """
    # Step 1: Baseline Wander Removal (High-Pass Filter)
    nyquist = 0.5 * sampling_rate
    b_hp, a_hp = butter(2, params['cutoff_highpass'] / nyquist, btype='high', analog=False)
    without_baseline = filtfilt(b_hp, a_hp, window)

    # Step 2: Powerline Interference Removal (Notch Filter)
    b_notch, a_notch = iirnotch(params['powerline_freq'], params['notch_quality_factor'], sampling_rate)
    without_powerline = filtfilt(b_notch, a_notch, without_baseline)

    # Step 3: High-Frequency Noise Removal (Wavelet Denoising)
    coeffs = pywt.wavedec(without_powerline, params['wavelet'], level=params['decomposition_level'])
    # Threshold: half of std(last detail band) * sqrt(2 * ln(N)), N = window length.
    threshold = 0.5 * np.std(coeffs[-1]) * np.sqrt(2 * np.log(len(without_powerline)))
    # Keep the approximation coefficients; soft-threshold every detail band.
    denoised_coeffs = [coeffs[0]] + [
        pywt.threshold(detail, threshold, mode='soft') for detail in coeffs[1:]
    ]
    cleaned_window = pywt.waverec(denoised_coeffs, params['wavelet'])

    # Ensure the length matches the original window length after reconstruction
    if len(cleaned_window) > len(window):
        cleaned_window = cleaned_window[:len(window)]
    elif len(cleaned_window) < len(window):
        padding = np.zeros(len(window) - len(cleaned_window))
        cleaned_window = np.concatenate((cleaned_window, padding))
    return cleaned_window


def write_tfrecords(patient_ids, record_dir, refTable, output_dir, inspection_dir, sampling_rate, window_seconds=5, start_offset_seconds=5, lead_index=0, inspection_frequency=10, plots_per_image=9, logger=None,
                    denoising_params=None):
    """
    Writes one single-example TFRecord per patient from a denoised ECG window and saves inspection plots.

    For every patient ID the WFDB record ``<record_dir>/<pid>_H`` is read (channel ``lead_index`` only),
    a window of ``window_seconds`` starting ``start_offset_seconds`` into the recording is cut out,
    denoised (high-pass filter, powerline notch, wavelet soft-thresholding) and written together with the
    patient's LVEF to
    ``<output_dir>/<pid>_lead_<lead_index>_window_<start_offset_seconds>s_<window_seconds>s.tfrecord``
    with the float features ``'signal'`` and ``'lvef'``.

    A patient is skipped (with an info message) when the record's sampling rate differs from
    ``sampling_rate``, the record has no lead, the signal is shorter than the requested window, or the
    patient is missing from ``refTable``. Any exception raised while handling a patient is logged as an
    error and processing continues with the next patient.

    Every time ``plots_per_image`` denoised windows have been collected they are plotted with
    ``plot_and_save_batch`` into ``inspection_batch_<n>_part_1_lead_<lead_index>.png``, where ``<n>`` counts
    the saved images from 1. Windows left over at the end go to ``inspection_batch_final_lead_<lead_index>.png``.

    Args:
        patient_ids (list): List of patient IDs to process.
        record_dir (str): Directory containing the WFDB record files.
        refTable (pd.DataFrame): DataFrame with at least the columns 'Patient_ID' and 'LVEF'.
        output_dir (str): Directory to save the TFRecord files.
        inspection_dir (str): Directory to save inspection plot images.
        sampling_rate (int): Expected sampling rate of the ECG signals.
        window_seconds (int): The size of the window in seconds to extract.
        start_offset_seconds (int): Starting offset in seconds.
        lead_index (int): The index of the lead to process.
        inspection_frequency (int): Accepted for backward compatibility only; it no longer influences
            the output. It previously fed the image numbering, which let two images share one file name.
        plots_per_image (int): The fixed number of plots to include in each inspection image.
        logger (logging.Logger, optional): Logger object. When omitted, messages are printed.
        denoising_params (dict, optional): Dictionary of denoising parameters.
                                          Recognised keys: 'powerline_freq', 'cutoff_highpass',
                                          'notch_quality_factor', 'wavelet', 'decomposition_level'.
                                          Any key that is not supplied falls back to its default, so a
                                          partial dictionary is accepted. If None, all defaults are used.
    """
    if logger is None:
        def log_info(message):
            print(message)
        def log_error(message):
            print(f"Error: {message}")
    else:
        log_info = logger.info
        log_error = logger.error

    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(inspection_dir, exist_ok=True)
    all_windows_for_inspection = []
    inspection_patient_ids = []
    # Counts the inspection images written so far. Numbering the files from the
    # patient index (i // inspection_frequency) could produce the same number
    # twice and silently overwrite an earlier image.
    inspection_batch_number = 0

    def _float_feature(value):
        return tf.train.Feature(float_list=tf.train.FloatList(value=value))

    target_window_size = int(window_seconds * sampling_rate)
    start_index = int(start_offset_seconds * sampling_rate)
    end_index = start_index + target_window_size

    # Denoising parameters (default values if not provided)
    default_denoising_params = {
        'powerline_freq': 50,
        'cutoff_highpass': 0.5,
        'notch_quality_factor': 30,
        'wavelet': 'db4',
        'decomposition_level': 4
    }
    # Overlay the caller's values on the defaults; using the caller's dict as-is
    # made every patient fail with a KeyError when a key was left out.
    params = {**default_denoising_params, **(denoising_params or {})}

    for pid in tqdm(patient_ids, desc=f"Writing TFRecords ({window_seconds} sec, Lead {lead_index})"):
        record_path = f"{record_dir}/{pid}_H"
        try:
            record = wfdb.rdrecord(record_path, channels=[lead_index])
            signals = record.p_signal.astype(np.float32)
            if record.fs != sampling_rate:
                log_info(f"Sampling rate mismatch for patient {pid}. Skipping.")
                continue
            if signals.shape[1] < 1:
                log_info(f"Less than 1 lead for patient {pid}. Skipping.")
                continue
            if len(signals) < end_index:
                log_info(f"Signal too short for patient {pid}. Skipping.")
                continue

            # Only one channel was requested, so it sits in column 0.
            window = signals[start_index:end_index, 0]

            # Look the label up before filtering so patients without one cost no denoising work.
            patient_row = refTable[refTable['Patient_ID'] == pid]
            if patient_row.empty:
                log_info(f"Patient {pid} not in refTable. Skipping.")
                continue
            lvef = patient_row['LVEF'].values[0]

            cleaned_window = _denoise_ecg_window(window, sampling_rate, params)

            tfrecord_path = os.path.join(output_dir, f"{pid}_lead_{lead_index}_window_{start_offset_seconds}s_{window_seconds}s.tfrecord")
            with tf.io.TFRecordWriter(tfrecord_path) as writer:
                feature = {
                    'signal': _float_feature(cleaned_window.astype(np.float32).flatten()), # Use cleaned_window
                    'lvef': _float_feature([float(lvef)]),
                }
                example = tf.train.Example(features=tf.train.Features(feature=feature))
                writer.write(example.SerializeToString())

            log_info(f"Processed and wrote TFRecord for patient {pid}.")

            all_windows_for_inspection.append(cleaned_window) # Append cleaned_window
            inspection_patient_ids.append(pid)

            if len(all_windows_for_inspection) >= plots_per_image:
                inspection_batch_number += 1
                batch_signals = all_windows_for_inspection[:plots_per_image]
                batch_pids = inspection_patient_ids[:plots_per_image]
                # Dequeue before plotting so a failing plot is not retried
                # for every subsequent patient.
                all_windows_for_inspection = all_windows_for_inspection[plots_per_image:]
                inspection_patient_ids = inspection_patient_ids[plots_per_image:]
                plot_filename = f"inspection_batch_{inspection_batch_number}_part_1_lead_{lead_index}.png"
                plot_and_save_batch(batch_signals, batch_pids, inspection_dir, plot_filename, sampling_rate, plots_per_image)

        except Exception as e:
            # Best effort: one bad record must not abort the whole run.
            log_error(f"Error processing patient {pid}: {e}")

    # Save any remaining windows
    if all_windows_for_inspection:
        plot_filename = f"inspection_batch_final_lead_{lead_index}.png"
        plot_and_save_batch(all_windows_for_inspection, inspection_patient_ids, inspection_dir, plot_filename, sampling_rate, plots_per_image)


In [11]:
# Example usage
if __name__ == '__main__':
    # Everything this run produces (logs, inspection plots, TFRecords) goes under data_prep_dir.
    data_prep_dir = "./data_preparation"
    os.makedirs(data_prep_dir, exist_ok=True)
    log_directory = os.path.join(data_prep_dir, "logs_highres")
    inspection_directory = os.path.join(data_prep_dir, "inspection_plots_highres")

    logger = setup_logger(log_directory)
    logger.info("Starting TFRecord generation with inspection (9 plots per image).")

    # Reload the filtered reference table written by the first section of the notebook.
    refTable = pd.read_csv('resources/reference-table.csv')
    all_patients = refTable['Patient_ID'].tolist()
    num_patients = len(all_patients)

    # Location of the WFDB records. Set the MUSIC_HIGHRES_ECG_DIR environment
    # variable to run on another machine; the original local path is the fallback.
    record_directory = os.environ.get("MUSIC_HIGHRES_ECG_DIR") or "F:\\physionet.org\\files\\music-sudden-cardiac-death\\1.0.1\\High-resolution_ECG"
    if not os.path.isdir(record_directory):
        # Without this hint the only symptom would be one error line per patient in the log.
        logger.warning(f"Record directory not found: {record_directory}")

    output_directory = os.path.join(data_prep_dir, "tfrecords-5seconds-singlelead-highres")
    reference_table = refTable
    sampling_rate = 1000
    train_patient_ids = all_patients[:num_patients]  # currently every patient in the table
    lead_to_process = 0 #Index of lead
    window_size_seconds = 5
    inspection_frequency = 5
    plots_per_image = 9  # Set the desired number of plots per image

    write_tfrecords(
        train_patient_ids,
        record_directory,
        reference_table,
        output_directory,
        inspection_directory,
        sampling_rate,
        window_seconds=window_size_seconds,
        lead_index=lead_to_process,
        inspection_frequency=inspection_frequency,
        plots_per_image=plots_per_image,
        logger=logger,
    )

    logger.info("TFRecord generation with inspection (9 plots per image) completed.")

Writing TFRecords (5 sec, Lead 0): 100%|██████████| 695/695 [01:14<00:00,  9.29it/s]
