# Create Modified Reference Table from CSV

In [1]:
import pandas as pd

# Build the reference table in a single chain: select the needed columns,
# shorten their names, coerce the two measurements to float, and treat a
# missing 'Exit of the study' entry as 0.
refTable = pd.read_csv('resources/subject-info.csv')
refTable = (
    refTable[['Patient ID', 'LVEF (%)', 'Creatinine (?mol/L)', 'Cause of death', 'Exit of the study']]
    .rename(columns={'Patient ID': 'Patient_ID', "LVEF (%)": "LVEF", "Creatinine (?mol/L)": "Creatinine"})
    .astype({'LVEF': float, 'Creatinine': float})
    .fillna({'Exit of the study': 0})
    # Keep only rows where both 'Exit of the study' and 'Cause of death' are 0.
    .loc[lambda frame: (frame['Exit of the study'] == 0) & (frame['Cause of death'] == 0)]
    # Discard rows that still contain any missing value.
    .dropna()
    # Renumber the surviving rows 0..n-1.
    .reset_index(drop=True)
)
refTable.head()

Unnamed: 0,Patient_ID,LVEF,Creatinine,Cause of death,Exit of the study
0,P0001,35.0,106.0,0,0.0
1,P0002,35.0,121.0,0,0.0
2,P0003,39.0,87.0,0,0.0
3,P0004,38.0,77.0,0,0.0
4,P0005,34.0,88.0,0,0.0


In [2]:
# Persist the filtered reference table; index=False keeps the row index out of the CSV.
refTable.to_csv('resources/reference-table.csv', index = False)

# Signal Retrieval and Processing

## Signal processing (the whole dataset is downloaded to local storage with wget) [Employed]

### Imports

In [3]:
import os
import logging
import datetime
import re
from tqdm import tqdm
import pandas as pd
import numpy as np
import tensorflow as tf
import wfdb
from scipy.signal import butter, lfilter, filtfilt
from IPython import get_ipython
import matplotlib.pyplot as plt

### Logging

In [4]:
def setup_logger(log_dir):
    """Sets up a logger that writes to a timestamped file in the specified directory.

    The logger is looked up by ``__name__``, so every call returns the same
    logger object. File handlers left over from earlier calls are closed and
    removed before the new one is attached; otherwise re-running this in the
    same kernel would write every message once per previous call and keep the
    old log files open.

    Args:
        log_dir (str): Directory for the log file; created if it does not exist.

    Returns:
        logging.Logger: Logger at INFO level with a single file handler that
        writes to ``<log_dir>/<YYYYmmdd_HHMMSS>.log``.
    """
    os.makedirs(log_dir, exist_ok=True)

    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    log_filename = f"{timestamp}.log"
    log_filepath = os.path.join(log_dir, log_filename)

    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)

    # Iterate over a copy: removeHandler() mutates logger.handlers.
    for stale_handler in list(logger.handlers):
        if isinstance(stale_handler, logging.FileHandler):
            logger.removeHandler(stale_handler)
            stale_handler.close()

    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler(log_filepath)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    return logger

### Filter for ECG Signals

In [5]:
def butter_bandpass(lowcut, highcut, fs, order=5):
    """Designs a Butterworth band-pass filter and returns its (b, a) coefficients.

    Args:
        lowcut: Lower band edge, in the same frequency unit as ``fs``.
        highcut: Upper band edge, in the same frequency unit as ``fs``.
        fs: Sampling frequency.
        order (int): Order passed to ``scipy.signal.butter``.

    Returns:
        tuple: Numerator ``b`` and denominator ``a`` coefficient arrays.
    """
    nyquist = 0.5 * fs
    # Band edges are handed to butter() as fractions of the Nyquist frequency.
    normalized_band = [lowcut / nyquist, highcut / nyquist]
    return butter(order, normalized_band, btype='band')

def butter_bandpass_filter(r_signal, lowcut, highcut, fs, order=5):
    """Band-pass filters a signal with a forward-backward Butterworth filter.

    The coefficients come from ``butter_bandpass`` and are applied with
    ``scipy.signal.filtfilt``.

    Args:
        r_signal: Raw signal to filter.
        lowcut: Lower band edge, in the same frequency unit as ``fs``.
        highcut: Upper band edge, in the same frequency unit as ``fs``.
        fs: Sampling frequency.
        order (int): Butterworth order forwarded to ``butter_bandpass``.

    Returns:
        The filtered signal as returned by ``filtfilt``.
    """
    numerator, denominator = butter_bandpass(lowcut, highcut, fs, order=order)
    return filtfilt(numerator, denominator, r_signal)

### Plot for Inspection

In [6]:
def plot_and_save_batch(signals, patient_ids, output_dir, filename, sampling_rate, plots_per_image=9):
    """Plots up to ``plots_per_image`` ECG signals on one grid figure and saves it as an image.

    Each panel shows one signal against time in seconds (sample index divided
    by ``sampling_rate``) with a grid and the patient ID as title. Panels that
    receive no signal are removed from the figure. Nothing is drawn or written
    when ``signals`` is empty.

    Args:
        signals: Sequence of 1-D signals (anything ``np.asarray`` accepts).
            Only the first ``plots_per_image`` entries are drawn.
        patient_ids: Patient IDs in the same order as ``signals``.
        output_dir (str): Directory for the image; created if missing.
        filename (str): File name of the image inside ``output_dir``.
        sampling_rate: Samples per second, used to build the time axis.
        plots_per_image (int): Maximum number of panels. The grid is three
            columns wide with at least three rows; more rows are added when
            more than nine panels are requested.
    """
    num_signals = len(signals)
    if num_signals == 0:
        return

    cols = 3
    # Ceiling division without importing math; never fewer than the 3 rows the
    # default 9-panel layout uses.
    rows = max(3, -(-plots_per_image // cols))
    panels_to_draw = min(num_signals, plots_per_image)

    # squeeze=False guarantees a 2-D axes array regardless of grid size.
    fig, axes = plt.subplots(rows, cols, figsize=(15, 10 * rows / 3), squeeze=False)
    try:
        for panel_index, ax in enumerate(axes.flatten()):
            if panel_index < panels_to_draw:
                trace = np.asarray(signals[panel_index])
                # Per-signal time axis so signals of different length still plot.
                time = np.arange(trace.shape[0]) / sampling_rate
                ax.plot(time, trace)
                ax.grid(True)
                ax.set_title(f"Patient ID: {patient_ids[panel_index]}", fontsize=8)
                ax.tick_params(axis='both', which='major', labelsize=6)
            else:
                fig.delaxes(ax)  # Remove empty subplots

        fig.tight_layout()
        os.makedirs(output_dir, exist_ok=True)
        fig.savefig(os.path.join(output_dir, filename))
    finally:
        # Always release the figure, even if saving fails, so repeated calls
        # in a long loop do not accumulate open figures.
        plt.close(fig)

### Processing into TFRecord

In [7]:
def write_tfrecords(patient_ids, record_dir, refTable, output_dir, inspection_dir, sampling_rate, window_seconds=5, start_offset_seconds=3600, lead_index=0, inspection_frequency=10, plots_per_image=9, logger=None):
    """
    Writes one TFRecord file per patient and saves inspection images of the filtered ECG windows.

    For every patient ID the WFDB record is read (single lead), a fixed window
    starting at ``start_offset_seconds`` is cut out, band-pass filtered
    (1-40 Hz, order 5) and written together with the patient's LVEF and
    creatinine values from ``refTable``. A patient is skipped with a log line
    when the record's sampling rate differs from ``sampling_rate``, the record
    is too short for the window, or the patient is missing from ``refTable``.
    Any exception raised while handling a patient is logged and processing
    continues with the next patient.

    Filtered windows are plotted in groups of ``plots_per_image``. Each group
    is saved as ``inspection_batch_<n>_lead_<lead_index>.png`` with a running
    number ``n``, and any leftover windows go to
    ``inspection_batch_final_lead_<lead_index>.png``.

    Args:
        patient_ids (list): List of patient IDs to process.
        record_dir (str): Directory containing the WFDB record files.
        refTable (pd.DataFrame): DataFrame with 'Patient_ID', 'LVEF' and 'Creatinine' columns.
        output_dir (str): Directory to save the TFRecord files.
        inspection_dir (str): Directory to save inspection plot images.
        sampling_rate (int): Expected sampling rate of the ECG signals.
        window_seconds (int): The size of the window in seconds to extract.
        start_offset_seconds (int): Starting offset in seconds.
        lead_index (int): The index of the lead to process.
        inspection_frequency (int): Accepted for backward compatibility only; it
            does not influence the output. It used to enter the inspection file
            name, where it made names collide (and images overwrite each other)
            whenever it was larger than ``plots_per_image``.
        plots_per_image (int): The fixed number of plots to include in each inspection image.
        logger (logging.Logger, optional): Logger object. Messages are printed when omitted.
    """
    if logger is None:
        def log_info(message):
            print(message)
        def log_error(message):
            print(f"Error: {message}")
    else:
        log_info = logger.info
        log_error = logger.error

    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(inspection_dir, exist_ok=True)
    all_windows_for_inspection = []
    inspection_patient_ids = []
    # Running number of saved inspection images; guarantees unique file names.
    inspection_batches_saved = 0

    def _float_feature(value):
        return tf.train.Feature(float_list=tf.train.FloatList(value=value))

    target_window_size = int(window_seconds * sampling_rate)
    start_index = int(start_offset_seconds * sampling_rate)
    stop_index = start_index + target_window_size

    for pid in tqdm(patient_ids, desc=f"Writing TFRecords ({window_seconds} sec, Lead {lead_index})"):
        record_path = os.path.join(record_dir, str(pid))
        try:
            # NOTE(review): this loads the complete recording although only one
            # window is used; wfdb.rdrecord accepts sampfrom/sampto if load
            # time becomes a concern.
            record = wfdb.rdrecord(record_path, channels=[lead_index])
            signals = record.p_signal.astype(np.float32)
            fs = record.fs
            if fs != sampling_rate:
                log_info(f"Sampling rate mismatch for patient {pid}. Skipping.")
                continue
            if signals.shape[1] < 1:
                log_info(f"Less than 1 lead for patient {pid}. Skipping.")
                continue
            if len(signals) < stop_index:
                log_info(f"Signal too short for patient {pid}. Skipping.")
                continue

            # Column 0 is the requested lead because only that channel was read.
            window = signals[start_index:stop_index, 0]
            filtered_signal = butter_bandpass_filter(window, 1, 40, fs, order=5)

            patient_row = refTable[refTable['Patient_ID'] == pid]
            if patient_row.empty:
                log_info(f"Patient {pid} not in refTable. Skipping.")
                continue
            lvef = patient_row['LVEF'].values[0]
            creatinine = patient_row['Creatinine'].values[0]

            tfrecord_path = os.path.join(output_dir, f"{pid}_lead_{lead_index}_window_{start_offset_seconds}s_{window_seconds}s.tfrecord")
            with tf.io.TFRecordWriter(tfrecord_path) as writer:
                feature = {
                    'signal': _float_feature(filtered_signal.astype(np.float32).flatten()),
                    'lvef': _float_feature([float(lvef)]),
                    'creatinine': _float_feature([float(creatinine)])
                }
                example = tf.train.Example(features=tf.train.Features(feature=feature))
                writer.write(example.SerializeToString())

            log_info(f"Processed and wrote TFRecord for patient {pid}.")

            all_windows_for_inspection.append(filtered_signal)
            inspection_patient_ids.append(pid)

            if len(all_windows_for_inspection) >= plots_per_image:
                inspection_batches_saved += 1
                plot_filename = f"inspection_batch_{inspection_batches_saved}_lead_{lead_index}.png"
                batch_signals = all_windows_for_inspection[:plots_per_image]
                batch_pids = inspection_patient_ids[:plots_per_image]
                plot_and_save_batch(batch_signals, batch_pids, inspection_dir, plot_filename, sampling_rate, plots_per_image)
                all_windows_for_inspection = all_windows_for_inspection[plots_per_image:]
                inspection_patient_ids = inspection_patient_ids[plots_per_image:]

        except Exception as e:
            # Best effort: one unreadable record must not abort the whole run.
            log_error(f"Error processing patient {pid}: {e}")

    # Save any remaining windows
    if all_windows_for_inspection:
        plot_filename = f"inspection_batch_final_lead_{lead_index}.png"
        plot_and_save_batch(all_windows_for_inspection, inspection_patient_ids, inspection_dir, plot_filename, sampling_rate, plots_per_image)

In [8]:
# Example usage
if __name__ == '__main__':
    # Everything this run produces (logs, inspection plots, TFRecords) lives under data_prep_dir.
    data_prep_dir = "./data_preparation"
    os.makedirs(data_prep_dir, exist_ok=True)
    log_directory = os.path.join(data_prep_dir, "logs")
    inspection_directory = os.path.join(data_prep_dir, "inspection_plots")

    # Run parameters
    sampling_rate = 200
    lead_to_process = 1 #Index of lead
    window_size_seconds = 5
    inspection_frequency = 5
    plots_per_image = 9  # Set the desired number of plots per image

    logger = setup_logger(log_directory)
    logger.info(f"Starting TFRecord generation with inspection ({plots_per_image} plots per image).")

    refTable = pd.read_csv('resources/reference-table.csv')
    all_patients = refTable['Patient_ID'].tolist()
    num_patients = len(all_patients)

    # The dataset location is machine-specific: set the MUSIC_HOLTER_DIR
    # environment variable to override it; the original path is the fallback.
    record_directory = os.environ.get(
        "MUSIC_HOLTER_DIR",
        "F:\\physionet.org\\files\\music-sudden-cardiac-death\\1.0.1\\Holter_ECG",
    )
    output_directory = os.path.join(data_prep_dir, "tfrecords-5seconds-singlelead")
    reference_table = refTable
    train_patient_ids = all_patients[:num_patients]

    # Optional settings are passed by keyword so they cannot silently bind to
    # the wrong parameter if the signature is ever reordered.
    write_tfrecords(
        train_patient_ids,
        record_directory,
        reference_table,
        output_directory,
        inspection_directory,
        sampling_rate,
        window_seconds=window_size_seconds,
        lead_index=lead_to_process,
        inspection_frequency=inspection_frequency,
        plots_per_image=plots_per_image,
        logger=logger,
    )

    logger.info(f"TFRecord generation with inspection ({plots_per_image} plots per image) completed.")

Writing TFRecords (5 sec, Lead 1): 100%|██████████| 695/695 [12:44<00:00,  1.10s/it]


In [None]:
import wfdb
import matplotlib.pyplot as plt

# Template cell: replace the placeholder with the base name of a record that
# wfdb can read (e.g. the .dat/.hea pair from PhysioNet).
record = wfdb.rdrecord('your_ecg_file_base_name', channels=[0])  # first channel only
signal = record.p_signal
time = np.arange(signal.shape[0]) / record.fs  # sample index -> seconds

fig, ax = plt.subplots()
ax.plot(time, signal)
ax.set_xlabel('Time (s)')
ax.set_ylabel('Amplitude (mV)')
ax.set_title('ECG Signal')
ax.grid(True)
plt.show()

print(record) # Print record information

## In-place signal retrieval and processing from the PhysioNet database [Alternative]