# CNN x ECG on PTB-XL Dataset

tbd description

## Import

In [1]:
import pandas as pd
import numpy as np
import wfdb
import ast
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset, DataLoader

## Load Data

tbd description

In [2]:
path = 'data/ptbxl/'

### Metadata

tbd description

In [3]:
df_metadata = pd.read_csv(path+'ptbxl_database.csv', index_col='ecg_id')

In [4]:
# add the base path to the filenames
df_metadata['filename_lr'] = df_metadata['filename_lr'].apply(lambda x: path + x)
df_metadata['filename_hr'] = df_metadata['filename_hr'].apply(lambda x: path + x)

In [5]:
# convert scp_codes strings, like "{'NORM': 100}", to dictionaries like {'NORM': 100}
df_metadata['scp_codes'] = df_metadata['scp_codes'].apply(lambda x: ast.literal_eval(x))

In [6]:
# The Y for machine learning
df_metadata

Unnamed: 0_level_0,patient_id,age,sex,height,weight,nurse,site,device,recording_date,report,...,validated_by_human,baseline_drift,static_noise,burst_noise,electrodes_problems,extra_beats,pacemaker,strat_fold,filename_lr,filename_hr
ecg_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1,15709.0,56.0,1,,63.0,2.0,0.0,CS-12 E,1984-11-09 09:17:34,sinusrhythmus periphere niederspannung,...,True,,", I-V1,",,,,,3,data/ptbxl/records100/00000/00001_lr,data/ptbxl/records500/00000/00001_hr
2,13243.0,19.0,0,,70.0,2.0,0.0,CS-12 E,1984-11-14 12:55:37,sinusbradykardie sonst normales ekg,...,True,,,,,,,2,data/ptbxl/records100/00000/00002_lr,data/ptbxl/records500/00000/00002_hr
3,20372.0,37.0,1,,69.0,2.0,0.0,CS-12 E,1984-11-15 12:49:10,sinusrhythmus normales ekg,...,True,,,,,,,5,data/ptbxl/records100/00000/00003_lr,data/ptbxl/records500/00000/00003_hr
4,17014.0,24.0,0,,82.0,2.0,0.0,CS-12 E,1984-11-15 13:44:57,sinusrhythmus normales ekg,...,True,", II,III,AVF",,,,,,3,data/ptbxl/records100/00000/00004_lr,data/ptbxl/records500/00000/00004_hr
5,17448.0,19.0,1,,70.0,2.0,0.0,CS-12 E,1984-11-17 10:43:15,sinusrhythmus normales ekg,...,True,", III,AVR,AVF",,,,,,4,data/ptbxl/records100/00000/00005_lr,data/ptbxl/records500/00000/00005_hr
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
21833,17180.0,67.0,1,,,1.0,2.0,AT-60 3,2001-05-31 09:14:35,ventrikulÄre extrasystole(n) sinustachykardie ...,...,True,,", alles,",,,1ES,,7,data/ptbxl/records100/21000/21833_lr,data/ptbxl/records500/21000/21833_hr
21834,20703.0,300.0,0,,,1.0,2.0,AT-60 3,2001-06-05 11:33:39,sinusrhythmus lagetyp normal qrs(t) abnorm ...,...,True,,,,,,,4,data/ptbxl/records100/21000/21834_lr,data/ptbxl/records500/21000/21834_hr
21835,19311.0,59.0,1,,,1.0,2.0,AT-60 3,2001-06-08 10:30:27,sinusrhythmus lagetyp normal t abnorm in anter...,...,True,,", I-AVR,",,,,,2,data/ptbxl/records100/21000/21835_lr,data/ptbxl/records500/21000/21835_hr
21836,8873.0,64.0,1,,,1.0,2.0,AT-60 3,2001-06-09 18:21:49,supraventrikulÄre extrasystole(n) sinusrhythmu...,...,True,,,,,SVES,,8,data/ptbxl/records100/21000/21836_lr,data/ptbxl/records500/21000/21836_hr


In [7]:
df_metadata.columns

Index(['patient_id', 'age', 'sex', 'height', 'weight', 'nurse', 'site',
       'device', 'recording_date', 'report', 'scp_codes', 'heart_axis',
       'infarction_stadium1', 'infarction_stadium2', 'validated_by',
       'second_opinion', 'initial_autogenerated_report', 'validated_by_human',
       'baseline_drift', 'static_noise', 'burst_noise', 'electrodes_problems',
       'extra_beats', 'pacemaker', 'strat_fold', 'filename_lr', 'filename_hr'],
      dtype='object')

### ECG Data

tbd description

In [8]:
def load_single_ecg(filepath:str):
    '''
    Load a single ECG file.
    :param filepath: Path to the ECG file.
    :return: The ECG signal as a numpy array (e.g. (1000, 12) for 1000 samples and 12 leads).
    '''
    signal, _ = wfdb.rdsamp(filepath)
    return signal

In [9]:
def load_ecg_data(df:pd.DataFrame, sampling_rate:int, max_workers:int=6) -> np.ndarray:
    '''
    Load raw ECG data from the PTB-XL database parallel.

    :param df: DataFrame containing metadata of ECG files
    :param sampling_rate: Sampling rate of the ECG data (100Hz or 500Hz)
    :param max_workers: Number of parallel workers to use for loading ECG files
    :return: Numpy array of ECG signals
    '''
    # lr = Low Resolution, 100Hz, hr = High Resolution, 500Hz
    filenames_col = 'filename_lr' if sampling_rate == 100 else 'filename_hr'
    filepaths = df[filenames_col].tolist()

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        signals = list(tqdm(
            executor.map(load_single_ecg, filepaths),
            total=len(filepaths),
            desc=f"Loading {sampling_rate}Hz ECG data",
            unit="files"
        ))

    return np.array(signals)

In [10]:
# The X for machine learning
data_100hz = load_ecg_data(df_metadata, 100)

Loading 100Hz ECG data: 100%|██████████| 21799/21799 [01:14<00:00, 293.44files/s]


In [11]:
# The X for machine learning
# data_500hz = load_ecg_data(df_metadata, 500)

### SCP Codes

tbd description

In [12]:
df_scp_statements = pd.read_csv(path+'scp_statements.csv', index_col=0)
df_scp_statements

Unnamed: 0,description,diagnostic,form,rhythm,diagnostic_class,diagnostic_subclass,Statement Category,SCP-ECG Statement Description,AHA code,aECG REFID,CDISC Code,DICOM Code
NDT,non-diagnostic T abnormalities,1.0,1.0,,STTC,STTC,other ST-T descriptive statements,non-diagnostic T abnormalities,,,,
NST_,non-specific ST changes,1.0,1.0,,STTC,NST_,Basic roots for coding ST-T changes and abnorm...,non-specific ST changes,145.0,MDC_ECG_RHY_STHILOST,,
DIG,digitalis-effect,1.0,1.0,,STTC,STTC,other ST-T descriptive statements,suggests digitalis-effect,205.0,,,
LNGQT,long QT-interval,1.0,1.0,,STTC,STTC,other ST-T descriptive statements,long QT-interval,148.0,,,
NORM,normal ECG,1.0,,,NORM,NORM,Normal/abnormal,normal ECG,1.0,,,F-000B7
...,...,...,...,...,...,...,...,...,...,...,...,...
BIGU,"bigeminal pattern (unknown origin, SV or Ventr...",,,1.0,,,Statements related to ectopic rhythm abnormali...,"bigeminal pattern (unknown origin, SV or Ventr...",,,,
AFLT,atrial flutter,,,1.0,,,Statements related to impulse formation (abnor...,atrial flutter,51.0,MDC_ECG_RHY_ATR_FLUT,,
SVTAC,supraventricular tachycardia,,,1.0,,,Statements related to impulse formation (abnor...,supraventricular tachycardia,55.0,MDC_ECG_RHY_SV_TACHY,,D3-31290
PSVT,paroxysmal supraventricular tachycardia,,,1.0,,,Statements related to impulse formation (abnor...,paroxysmal supraventricular tachycardia,,MDC_ECG_RHY_SV_TACHY_PAROX,,


In [13]:
# filter statements to only include diagnostic statements
# a SCP code can be a diagnostic, form and/or a rythm statement
df_scp_statements = df_scp_statements[df_scp_statements['diagnostic'] == 1]
display(df_scp_statements)

Unnamed: 0,description,diagnostic,form,rhythm,diagnostic_class,diagnostic_subclass,Statement Category,SCP-ECG Statement Description,AHA code,aECG REFID,CDISC Code,DICOM Code
NDT,non-diagnostic T abnormalities,1.0,1.0,,STTC,STTC,other ST-T descriptive statements,non-diagnostic T abnormalities,,,,
NST_,non-specific ST changes,1.0,1.0,,STTC,NST_,Basic roots for coding ST-T changes and abnorm...,non-specific ST changes,145.0,MDC_ECG_RHY_STHILOST,,
DIG,digitalis-effect,1.0,1.0,,STTC,STTC,other ST-T descriptive statements,suggests digitalis-effect,205.0,,,
LNGQT,long QT-interval,1.0,1.0,,STTC,STTC,other ST-T descriptive statements,long QT-interval,148.0,,,
NORM,normal ECG,1.0,,,NORM,NORM,Normal/abnormal,normal ECG,1.0,,,F-000B7
IMI,inferior myocardial infarction,1.0,,,MI,IMI,Myocardial Infarction,inferior myocardial infarction,161.0,,,
ASMI,anteroseptal myocardial infarction,1.0,,,MI,AMI,Myocardial Infarction,anteroseptal myocardial infarction,165.0,,,
LVH,left ventricular hypertrophy,1.0,,,HYP,LVH,Ventricular Hypertrophy,left ventricular hypertrophy,142.0,,C71076,
LAFB,left anterior fascicular block,1.0,,,CD,LAFB/LPFB,Intraventricular and intra-atrial Conduction d...,left anterior fascicular block,101.0,MDC_ECG_BEAT_BLK_ANT_L_HEMI,C62267,D3-33140
ISC_,non-specific ischemic,1.0,,,STTC,ISC_,Basic roots for coding ST-T changes and abnorm...,ischemic ST-T changes,226.0,,,


In [14]:
# e.g. for dict_scp_codes: {'NORM': 100, 'MI': 80}

def aggregate_diagnostic(dict_scp_codes: dict) -> list:
    '''
    Lookup every SCP Code in the dataframe and add scp code in a unique list.
    '''
    list_scp_codes = []

    for key in dict_scp_codes.keys():
        if key in df_scp_statements.index:
            list_scp_codes.append(df_scp_statements.loc[key]['diagnostic_class'])
    
    # remove duplicates by converting to a set and back to a list
    list_scp_codes = list(set(list_scp_codes))

    return list_scp_codes

In [15]:
# add column diagnostic superclass
tqdm.pandas(desc="Processing diagnostic superclass")
df_metadata['diagnostic_superclass'] = df_metadata['scp_codes'].progress_apply(aggregate_diagnostic)

Processing diagnostic superclass: 100%|██████████| 21799/21799 [00:01<00:00, 12206.72it/s]


In [16]:
# first step for a simple neuronal network
df_metadata['mi_label'] = df_metadata['diagnostic_superclass'].apply(lambda x: 1 if 'MI' in x else 0)

In [17]:
df_metadata

Unnamed: 0_level_0,patient_id,age,sex,height,weight,nurse,site,device,recording_date,report,...,static_noise,burst_noise,electrodes_problems,extra_beats,pacemaker,strat_fold,filename_lr,filename_hr,diagnostic_superclass,mi_label
ecg_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1,15709.0,56.0,1,,63.0,2.0,0.0,CS-12 E,1984-11-09 09:17:34,sinusrhythmus periphere niederspannung,...,", I-V1,",,,,,3,data/ptbxl/records100/00000/00001_lr,data/ptbxl/records500/00000/00001_hr,[NORM],0
2,13243.0,19.0,0,,70.0,2.0,0.0,CS-12 E,1984-11-14 12:55:37,sinusbradykardie sonst normales ekg,...,,,,,,2,data/ptbxl/records100/00000/00002_lr,data/ptbxl/records500/00000/00002_hr,[NORM],0
3,20372.0,37.0,1,,69.0,2.0,0.0,CS-12 E,1984-11-15 12:49:10,sinusrhythmus normales ekg,...,,,,,,5,data/ptbxl/records100/00000/00003_lr,data/ptbxl/records500/00000/00003_hr,[NORM],0
4,17014.0,24.0,0,,82.0,2.0,0.0,CS-12 E,1984-11-15 13:44:57,sinusrhythmus normales ekg,...,,,,,,3,data/ptbxl/records100/00000/00004_lr,data/ptbxl/records500/00000/00004_hr,[NORM],0
5,17448.0,19.0,1,,70.0,2.0,0.0,CS-12 E,1984-11-17 10:43:15,sinusrhythmus normales ekg,...,,,,,,4,data/ptbxl/records100/00000/00005_lr,data/ptbxl/records500/00000/00005_hr,[NORM],0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
21833,17180.0,67.0,1,,,1.0,2.0,AT-60 3,2001-05-31 09:14:35,ventrikulÄre extrasystole(n) sinustachykardie ...,...,", alles,",,,1ES,,7,data/ptbxl/records100/21000/21833_lr,data/ptbxl/records500/21000/21833_hr,[STTC],0
21834,20703.0,300.0,0,,,1.0,2.0,AT-60 3,2001-06-05 11:33:39,sinusrhythmus lagetyp normal qrs(t) abnorm ...,...,,,,,,4,data/ptbxl/records100/21000/21834_lr,data/ptbxl/records500/21000/21834_hr,[NORM],0
21835,19311.0,59.0,1,,,1.0,2.0,AT-60 3,2001-06-08 10:30:27,sinusrhythmus lagetyp normal t abnorm in anter...,...,", I-AVR,",,,,,2,data/ptbxl/records100/21000/21835_lr,data/ptbxl/records500/21000/21835_hr,[STTC],0
21836,8873.0,64.0,1,,,1.0,2.0,AT-60 3,2001-06-09 18:21:49,supraventrikulÄre extrasystole(n) sinusrhythmu...,...,,,,SVES,,8,data/ptbxl/records100/21000/21836_lr,data/ptbxl/records500/21000/21836_hr,[NORM],0


## Data Split for a simple NN

tbd description

In [18]:
X = data_100hz # OR data_500hz !
Y = df_metadata

In [19]:
print("Distribution of MI (0/1):\n", df_metadata['mi_label'].value_counts(normalize=True))

Distribution of MI (0/1):
 mi_label
0    0.749117
1    0.250883
Name: proportion, dtype: float64


### Test Split

In [20]:
# 20% data for testing, 80% for training and validation
df_train_valid, df_test = train_test_split(
    df_metadata,
    test_size=0.2,
    stratify=df_metadata['mi_label'],
    random_state=42
)

In [21]:
print("Distribution of MI in df_train_valid:\n", df_train_valid['mi_label'].value_counts(normalize=True))
print('---')
print("Distribution of MI in df_test:\n", df_test['mi_label'].value_counts(normalize=True))

Distribution of MI in df_train_valid:
 mi_label
0    0.749126
1    0.250874
Name: proportion, dtype: float64
---
Distribution of MI in df_test:
 mi_label
0    0.749083
1    0.250917
Name: proportion, dtype: float64


### Train and Validation Split

In [22]:
# 20% for test, 10% für validation, 70% for training
df_train, df_val = train_test_split(
    df_train_valid,
    test_size=0.125, # 10% of the original data
    stratify=df_train_valid['mi_label'],
    random_state=42
)

In [23]:
print("Distribution of MI in df_train:\n", df_train['mi_label'].value_counts(normalize=True))
print('---')
print("Distribution of MI in df_valid:\n", df_val['mi_label'].value_counts(normalize=True))

Distribution of MI in df_train:
 mi_label
0    0.749132
1    0.250868
Name: proportion, dtype: float64
---
Distribution of MI in df_valid:
 mi_label
0    0.749083
1    0.250917
Name: proportion, dtype: float64


In [24]:
print(f"Size of Train-Set: {len(df_test)}")
print(f"Size of Validation-Set: {len(df_val)}")
print(f"Size of Test-Set: {len(df_train)}")

Size of Train-Set: 4360
Size of Validation-Set: 2180
Size of Test-Set: 15259


## Normalization

### Calculation of Z-Score-Nomalization (μ, σ) on training data

In [25]:
# turn array into one long 1D array
flattened_train = data_100hz.flatten() 
# flattened_train = data_500hz.flatten() 

In [26]:
# check data type
flattened_train = flattened_train.astype(np.float32)

In [27]:
# calculate meana and standard deviation
train_mean = np.mean(flattened_train)
train_std = np.std(flattened_train)

In [28]:
# Prevend division by zero
# if the standard deviation is very small, set it to a small positive value
if train_std < 1e-6:
    train_std = 1e-6

In [29]:
print(f"Calculated mean (μ) on training data: {train_mean:.6f}")
print(f"Calculated standard deviation (σ) on training data: {train_std:.6f}")

Calculated mean (μ) on training data: -0.000781
Calculated standard deviation (σ) on training data: 0.235392


In [30]:
# save values for normalization on the test and validation set
global_train_mean = train_mean
global_train_std = train_std

### Definition of the PyTorch Dataset Class

In [31]:
class ECGDataset(Dataset):
    '''
    Custom PyTorch Dataset class for loading and preprocessing ECG data.
    '''
    
    def __init__(self, df:pd.DataFrame, mean:float, std:float):
        '''
        Constructor for ECGDataset.

        :param df: DataFrame containing ECG metadata and paths to the ECG files.
        :param mean: Mean value (μ) from training data for normalization.
        :param std: Standard deviation (σ) from training data for normalization.
        '''
        self.df = df
        self.mean = mean
        self.std = std
        self.std_stable = self.std + 1e-8  # Prevent division by zero
    

    def __len__(self):
        '''
        Returns the number of samples in the dataset.

        :return: Number of samples.
        '''
        return len(self.df)


    def __getitem__(self, idx):
        '''
        Loads and preprocesses a single ESG sample based on its index.
        Called by the PyTorch DataLoader when it needs a new sample.

        :param idx: Index of the sample to load.
        :return: A tuple containing the ECG signal as a PyTorch tensor and the corresponding label.
        '''

        # 1. Get information from first sample
        row = self.df.iloc[idx]
        filepath = row['filename_lr']  # or 'filename_hr' for 500Hz
        label = row['mi_label']

        # 2. Load ECG data
        signal, _ = wfdb.rdsamp(filepath)

        # 3. Flatten the signal to a 1D array
        signal = signal.flatten()

        # 4. Normalize the signal (Z-Score-Normalization: (X - μ) / σ)
        signal = (signal - self.mean) / self.std_stable

        # 5. Converto to PyTorch tensor
        signal_tensor = torch.tensor(signal, dtype=torch.float32)
        label_tensor = torch.tensor(label, dtype=torch.float32)

        return signal_tensor, label_tensor

In [32]:
train_ECGDataset = ECGDataset(df=df_train, mean=global_train_mean, std=global_train_std)
val_ECGDataset = ECGDataset(df=df_val, mean=global_train_mean, std=global_train_std)
test_ECGDataset = ECGDataset(df=df_test, mean=global_train_mean, std=global_train_std)

In [33]:
print(f"Size of training dataset: {len(train_ECGDataset)} Samples")
print(f"Size of validation dataset: {len(val_ECGDataset)} Samples")
print(f"Size of testing dataset: {len(test_ECGDataset)} Samples")

Size of training dataset: 15259 Samples
Size of validation dataset: 2180 Samples
Size of testing dataset: 4360 Samples


In [34]:
# Check if loading a single sample works and is correct
try:
    sample_signal, sample_label = train_ECGDataset[0]
    print(f"First sample:")
    print(f"Signal Shape: {sample_signal.shape}")
    print(f"Signal Mean: {sample_signal.mean().item():.4f}")
    print(f"Signal Std: {sample_signal.std().item():.4f}")
    print(f"Label: {sample_label.item()}")
except Exception as e:
    print(f"Error: {e}")

First sample:
Signal Shape: torch.Size([12000])
Signal Mean: 0.0014
Signal Std: 0.6086
Label: 0.0


### Definition of the PyTorch DataLoader

In [35]:
# set the size of how many samples in the neuronal network are processed at once bevore the weights are updated
# 64 or 128 are common start values
# the bigger the batch size,
BATCH_SIZE = 64

num_workers = 0 # because of some windows issues with multiprocessing

In [36]:
train_loader = DataLoader(
    dataset=train_ECGDataset,
    batch_size=BATCH_SIZE,
    shuffle=True, # Shuffle the trainings data to improve generalization
    num_workers=num_workers
)

In [37]:
val_loader = DataLoader(
    dataset=val_ECGDataset,
    batch_size=BATCH_SIZE,
    shuffle=False, # No Suffle, to ensure the same order of evaluation
    num_workers=num_workers
)

In [38]:
test_loader = DataLoader(
    dataset=test_ECGDataset,
    batch_size=BATCH_SIZE,
    shuffle=False, # No Suffle, to ensure the same order of evaluation
    num_workers=num_workers
)

In [40]:
# check if the DataLoader works and returns the expected shapes and values
try:
    # 'iter(train_loader)' to iterate over the dataloader like a list
    # 'next(...)' gets the next batch, a pair of signal tensor and label tensor
    first_batch_signals, first_batch_labels = next(iter(train_loader))

    print(f"Shape of signal batch (should be (Batch_Size, flattened_ECG_length)): {first_batch_signals.shape}")
    print(f"Shape of label batch (should be (Batch_Size,) or (Batch_Size, 1)): {first_batch_labels.shape}")

    print(f"Mean of the first batch (signal): {first_batch_signals.mean().item():.4f}")
    print(f"Standard deviation of the first batch (signal): {first_batch_signals.std().item():.4f}")
    
    print(f"First 5 labels of the batch: {first_batch_labels[:5].tolist()}")

except Exception as e:
    print(f"Error while testing the DataLoader: {e}")

Shape of signal batch (should be (Batch_Size, flattened_ECG_length)): torch.Size([64, 12000])
Shape of label batch (should be (Batch_Size,) or (Batch_Size, 1)): torch.Size([64])
Mean of the first batch (signal): 0.0006
Standard deviation of the first batch (signal): 0.9358
First 5 labels of the batch: [1.0, 0.0, 1.0, 1.0, 1.0]
