In [1]:
import ast
import os
import pickle

In [2]:
import h5py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm import tqdm
import wfdb

In [3]:
# Lift the column limit so displayed DataFrames show every column instead of eliding some.
pd.set_option('display.max_columns', None)

# PTB-XL Dataset

Download the latest PTB-XL dataset from https://physionet.org/content/ptb-xl and extract it to `data/ptb_xl/raw`

In [4]:
# Directory holding the extracted PTB-XL download; the annotation CSVs and
# waveform records are read relative to this path.
DATA_ROOT = './data/ptb_xl/raw'
# Output directory for the artifacts this notebook writes
# (dataset.h5 and class_label_maps.pkl).
SAVE_DIR = './data/ptb_xl/processed'

The PTB-XL dataset provides the same recordings at two different sampling rates, 100 Hz and 500 Hz.

We use the 100 Hz version for this project.

In [5]:
# Sampling rate (Hz) of the records to load; with 100, load_raw_data reads the
# paths listed in the `filename_lr` column.
TARGET_FS = 100

The data and annotations can be processed similarly to `example_physionet.py`, which is provided with the PTB-XL dataset.

In [6]:
def load_raw_data(df, sampling_rate, data_dir):
    """Read the waveform of every record listed in ``df`` into one array.

    Args:
        df: Annotation table exposing ``filename_lr`` and ``filename_hr``,
            each an iterable of record paths relative to ``data_dir``.
        sampling_rate: ``100`` selects the ``filename_lr`` paths; any other
            value selects the ``filename_hr`` paths.
        data_dir: Directory the relative record paths are joined onto.

    Returns:
        ``np.ndarray`` built from the per-record signal arrays, in the order
        the paths appear in ``df``. The metadata returned by ``wfdb.rdsamp``
        is discarded.
    """
    relative_paths = df.filename_lr if sampling_rate == 100 else df.filename_hr

    signals = []
    for relative_path in relative_paths:
        signal, _meta = wfdb.rdsamp(os.path.join(data_dir, relative_path))
        signals.append(signal)

    return np.array(signals)

## Load and convert annotation data

In [7]:
# Read the per-record annotation table, indexed by ECG id.
annotation_csv = os.path.join(DATA_ROOT, 'ptbxl_database.csv')
df_ann = pd.read_csv(annotation_csv, index_col='ecg_id')
# `scp_codes` holds the string representation of a Python dict;
# ast.literal_eval transforms each string back into the actual dict object.
df_ann['scp_codes'] = df_ann['scp_codes'].apply(ast.literal_eval)
df_ann.head()

Unnamed: 0_level_0,patient_id,age,sex,height,weight,nurse,site,device,recording_date,report,scp_codes,heart_axis,infarction_stadium1,infarction_stadium2,validated_by,second_opinion,initial_autogenerated_report,validated_by_human,baseline_drift,static_noise,burst_noise,electrodes_problems,extra_beats,pacemaker,strat_fold,filename_lr,filename_hr
ecg_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1
1,15709.0,56.0,1,,63.0,2.0,0.0,CS-12 E,1984-11-09 09:17:34,sinusrhythmus periphere niederspannung,"{'NORM': 100.0, 'LVOLT': 0.0, 'SR': 0.0}",,,,,False,False,True,,", I-V1,",,,,,3,records100/00000/00001_lr,records500/00000/00001_hr
2,13243.0,19.0,0,,70.0,2.0,0.0,CS-12 E,1984-11-14 12:55:37,sinusbradykardie sonst normales ekg,"{'NORM': 80.0, 'SBRAD': 0.0}",,,,,False,False,True,,,,,,,2,records100/00000/00002_lr,records500/00000/00002_hr
3,20372.0,37.0,1,,69.0,2.0,0.0,CS-12 E,1984-11-15 12:49:10,sinusrhythmus normales ekg,"{'NORM': 100.0, 'SR': 0.0}",,,,,False,False,True,,,,,,,5,records100/00000/00003_lr,records500/00000/00003_hr
4,17014.0,24.0,0,,82.0,2.0,0.0,CS-12 E,1984-11-15 13:44:57,sinusrhythmus normales ekg,"{'NORM': 100.0, 'SR': 0.0}",,,,,False,False,True,", II,III,AVF",,,,,,3,records100/00000/00004_lr,records500/00000/00004_hr
5,17448.0,19.0,1,,70.0,2.0,0.0,CS-12 E,1984-11-17 10:43:15,sinusrhythmus normales ekg,"{'NORM': 100.0, 'SR': 0.0}",,,,,False,False,True,", III,AVR,AVF",,,,,,4,records100/00000/00005_lr,records500/00000/00005_hr


In [8]:
df_ann.shape

(21799, 27)

In [9]:
df_ann['patient_id'].unique().shape

(18869,)

## Diagnostic aggregation

In [10]:
# Load the SCP statement table and keep only the rows whose `diagnostic` flag equals 1.
scp_statements = pd.read_csv(os.path.join(DATA_ROOT, 'scp_statements.csv'), index_col=0)
is_diagnostic = scp_statements['diagnostic'] == 1
df_agg = scp_statements[is_diagnostic]

In [11]:
df_agg.head()

Unnamed: 0,description,diagnostic,form,rhythm,diagnostic_class,diagnostic_subclass,Statement Category,SCP-ECG Statement Description,AHA code,aECG REFID,CDISC Code,DICOM Code
NDT,non-diagnostic T abnormalities,1.0,1.0,,STTC,STTC,other ST-T descriptive statements,non-diagnostic T abnormalities,,,,
NST_,non-specific ST changes,1.0,1.0,,STTC,NST_,Basic roots for coding ST-T changes and abnorm...,non-specific ST changes,145.0,MDC_ECG_RHY_STHILOST,,
DIG,digitalis-effect,1.0,1.0,,STTC,STTC,other ST-T descriptive statements,suggests digitalis-effect,205.0,,,
LNGQT,long QT-interval,1.0,1.0,,STTC,STTC,other ST-T descriptive statements,long QT-interval,148.0,,,
NORM,normal ECG,1.0,,,NORM,NORM,Normal/abnormal,normal ECG,1.0,,,F-000B7


In [12]:
# Lookup table: SCP statement code (the table index) -> diagnostic subclass.
scp_to_diag_subclass_map = df_agg['diagnostic_subclass'].to_dict()

In [13]:
scp_to_diag_subclass_map

{'NDT': 'STTC',
 'NST_': 'NST_',
 'DIG': 'STTC',
 'LNGQT': 'STTC',
 'NORM': 'NORM',
 'IMI': 'IMI',
 'ASMI': 'AMI',
 'LVH': 'LVH',
 'LAFB': 'LAFB/LPFB',
 'ISC_': 'ISC_',
 'IRBBB': 'IRBBB',
 '1AVB': '_AVB',
 'IVCD': 'IVCD',
 'ISCAL': 'ISCA',
 'CRBBB': 'CRBBB',
 'CLBBB': 'CLBBB',
 'ILMI': 'IMI',
 'LAO/LAE': 'LAO/LAE',
 'AMI': 'AMI',
 'ALMI': 'AMI',
 'ISCIN': 'ISCI',
 'INJAS': 'AMI',
 'LMI': 'LMI',
 'ISCIL': 'ISCI',
 'LPFB': 'LAFB/LPFB',
 'ISCAS': 'ISCA',
 'INJAL': 'AMI',
 'ISCLA': 'ISCA',
 'RVH': 'RVH',
 'ANEUR': 'STTC',
 'RAO/RAE': 'RAO/RAE',
 'EL': 'STTC',
 'WPW': 'WPW',
 'ILBBB': 'ILBBB',
 'IPLMI': 'IMI',
 'ISCAN': 'ISCA',
 'IPMI': 'IMI',
 'SEHYP': 'SEHYP',
 'INJIN': 'IMI',
 'INJLA': 'AMI',
 'PMI': 'PMI',
 '3AVB': '_AVB',
 'INJIL': 'IMI',
 '2AVB': '_AVB'}

In [14]:
# One representative row per diagnostic subclass (the first occurrence is kept).
df_unique_diag_subclass = df_agg.drop_duplicates(subset='diagnostic_subclass')
# Lookup table: diagnostic subclass -> diagnostic superclass.
diag_subclass_to_superclass_map = (
    df_unique_diag_subclass
    .set_index('diagnostic_subclass')['diagnostic_class']
    .to_dict()
)

In [15]:
df_unique_diag_subclass[['diagnostic_class', 'diagnostic_subclass']].sort_values(by='diagnostic_class')

Unnamed: 0,diagnostic_class,diagnostic_subclass
IRBBB,CD,IRBBB
ILBBB,CD,ILBBB
CLBBB,CD,CLBBB
CRBBB,CD,CRBBB
LAFB,CD,LAFB/LPFB
WPW,CD,WPW
1AVB,CD,_AVB
IVCD,CD,IVCD
RAO/RAE,HYP,RAO/RAE
LAO/LAE,HYP,LAO/LAE


In [16]:
len(df_unique_diag_subclass['diagnostic_subclass'])

23

In [17]:
# Give each subclass a consecutive integer id, following the sorted order of the subclass names.
ordered_subclasses = sorted(diag_subclass_to_superclass_map)
subclass_label_map = dict(zip(ordered_subclasses, range(len(ordered_subclasses))))
subclass_label_map

{'AMI': 0,
 'CLBBB': 1,
 'CRBBB': 2,
 'ILBBB': 3,
 'IMI': 4,
 'IRBBB': 5,
 'ISCA': 6,
 'ISCI': 7,
 'ISC_': 8,
 'IVCD': 9,
 'LAFB/LPFB': 10,
 'LAO/LAE': 11,
 'LMI': 12,
 'LVH': 13,
 'NORM': 14,
 'NST_': 15,
 'PMI': 16,
 'RAO/RAE': 17,
 'RVH': 18,
 'SEHYP': 19,
 'STTC': 20,
 'WPW': 21,
 '_AVB': 22}

In [18]:
# Position of each diagnostic superclass inside the superclass label vector.
superclass_label_map = {
    name: position
    for position, name in enumerate(('NORM', 'CD', 'MI', 'HYP', 'STTC'))
}
superclass_label_map

{'NORM': 0, 'CD': 1, 'MI': 2, 'HYP': 3, 'STTC': 4}

In [19]:
def _max_likelihood_per_group(pairs, group_of):
    """Group ``(code, likelihood)`` pairs via ``group_of`` and keep the largest likelihood per group."""
    grouped = {}
    for code, likelihood in pairs:
        grouped.setdefault(group_of[code], []).append(likelihood)
    return {group: max(values) for group, values in grouped.items()}


def _likelihoods_to_label(likelihood_by_class, index_of):
    """Build a label vector with one slot per class listed in ``index_of``.

    Classes missing from ``likelihood_by_class`` stay at 0.0, a likelihood of
    0.0 (unknown) is stored as 0.5, and any other likelihood is divided by 100.
    """
    label = np.zeros(len(index_of))
    for name, likelihood in likelihood_by_class.items():
        label[index_of[name]] = 0.5 if likelihood == 0.0 else likelihood / 100.
    return label


def map_scp_to_diag_class(row):
    """Derive diagnostic sub-/superclass likelihoods and label vectors for one record.

    ``row.scp_codes`` is a dict of ``statement: likelihood`` pairs, e.g.
    ``{'ASMI': 15.0, 'LVH': 100.0, 'ISC_': 100.0, 'PVC': 100.0, 'ABQRS': 0.0, 'AFIB': 0.0}``.

    Steps:
      1. SCP codes without an entry in ``scp_to_diag_subclass_map`` are ignored.
      2. Each remaining code is mapped to its diagnostic subclass; when several
         codes share a subclass, the larger likelihood is kept.
      3. Each subclass is mapped to its superclass through
         ``diag_subclass_to_superclass_map``, again keeping the larger
         likelihood when several subclasses share a superclass.
      4. Both likelihood dicts are turned into multi-label vectors indexed by
         ``superclass_label_map`` / ``subclass_label_map``.

    Either dict can be empty (no SCP code mapped to a subclass); the matching
    label is then an all-zero vector.

    NOTE(review): the maximum is taken on the raw likelihoods, so an unknown
    (0.0) entry only yields 0.5 when no other code of the same class carries a
    positive likelihood -- confirm this is the intended precedence.

    Args:
        row: Object exposing the ``scp_codes`` dict of one annotation record.

    Returns:
        Tuple ``(diag_superclass, diag_subclass, label_superclass, label_subclass)``:
        two ``{class: likelihood}`` dicts followed by two ``np.ndarray`` labels.
    """
    known_codes = [
        (code, likelihood)
        for code, likelihood in row.scp_codes.items()
        if code in scp_to_diag_subclass_map
    ]

    diag_subclass = _max_likelihood_per_group(known_codes, scp_to_diag_subclass_map)
    diag_superclass = _max_likelihood_per_group(diag_subclass.items(), diag_subclass_to_superclass_map)

    label_superclass = _likelihoods_to_label(diag_superclass, superclass_label_map)
    label_subclass = _likelihoods_to_label(diag_subclass, subclass_label_map)

    return diag_superclass, diag_subclass, label_superclass, label_subclass

In [20]:
# Map SCP codes to diagnostic super/subclasses: the four values returned for
# each row are expanded into the four new columns listed below, in that order.
diag_columns = ['diagnostic_superclass', 'diagnostic_subclass', 'label_superclass', 'label_subclass']
df_ann[diag_columns] = df_ann.apply(map_scp_to_diag_class, axis=1, result_type='expand')

In [21]:
df_ann

Unnamed: 0_level_0,patient_id,age,sex,height,weight,nurse,site,device,recording_date,report,scp_codes,heart_axis,infarction_stadium1,infarction_stadium2,validated_by,second_opinion,initial_autogenerated_report,validated_by_human,baseline_drift,static_noise,burst_noise,electrodes_problems,extra_beats,pacemaker,strat_fold,filename_lr,filename_hr,diagnostic_superclass,diagnostic_subclass,label_superclass,label_subclass
ecg_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1
1,15709.0,56.0,1,,63.0,2.0,0.0,CS-12 E,1984-11-09 09:17:34,sinusrhythmus periphere niederspannung,"{'NORM': 100.0, 'LVOLT': 0.0, 'SR': 0.0}",,,,,False,False,True,,", I-V1,",,,,,3,records100/00000/00001_lr,records500/00000/00001_hr,{'NORM': 100.0},{'NORM': 100.0},"[1.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
2,13243.0,19.0,0,,70.0,2.0,0.0,CS-12 E,1984-11-14 12:55:37,sinusbradykardie sonst normales ekg,"{'NORM': 80.0, 'SBRAD': 0.0}",,,,,False,False,True,,,,,,,2,records100/00000/00002_lr,records500/00000/00002_hr,{'NORM': 80.0},{'NORM': 80.0},"[0.8, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
3,20372.0,37.0,1,,69.0,2.0,0.0,CS-12 E,1984-11-15 12:49:10,sinusrhythmus normales ekg,"{'NORM': 100.0, 'SR': 0.0}",,,,,False,False,True,,,,,,,5,records100/00000/00003_lr,records500/00000/00003_hr,{'NORM': 100.0},{'NORM': 100.0},"[1.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
4,17014.0,24.0,0,,82.0,2.0,0.0,CS-12 E,1984-11-15 13:44:57,sinusrhythmus normales ekg,"{'NORM': 100.0, 'SR': 0.0}",,,,,False,False,True,", II,III,AVF",,,,,,3,records100/00000/00004_lr,records500/00000/00004_hr,{'NORM': 100.0},{'NORM': 100.0},"[1.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
5,17448.0,19.0,1,,70.0,2.0,0.0,CS-12 E,1984-11-17 10:43:15,sinusrhythmus normales ekg,"{'NORM': 100.0, 'SR': 0.0}",,,,,False,False,True,", III,AVR,AVF",,,,,,4,records100/00000/00005_lr,records500/00000/00005_hr,{'NORM': 100.0},{'NORM': 100.0},"[1.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
21833,17180.0,67.0,1,,,1.0,2.0,AT-60 3,2001-05-31 09:14:35,ventrikulÄre extrasystole(n) sinustachykardie ...,"{'NDT': 100.0, 'PVC': 100.0, 'VCLVH': 0.0, 'ST...",LAD,,,1.0,False,True,True,,", alles,",,,1ES,,7,records100/21000/21833_lr,records500/21000/21833_hr,{'STTC': 100.0},{'STTC': 100.0},"[0.0, 0.0, 0.0, 0.0, 1.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
21834,20703.0,300.0,0,,,1.0,2.0,AT-60 3,2001-06-05 11:33:39,sinusrhythmus lagetyp normal qrs(t) abnorm ...,"{'NORM': 100.0, 'ABQRS': 0.0, 'SR': 0.0}",MID,Stadium II-III,,1.0,False,True,True,,,,,,,4,records100/21000/21834_lr,records500/21000/21834_hr,{'NORM': 100.0},{'NORM': 100.0},"[1.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
21835,19311.0,59.0,1,,,1.0,2.0,AT-60 3,2001-06-08 10:30:27,sinusrhythmus lagetyp normal t abnorm in anter...,"{'ISCAS': 50.0, 'SR': 0.0}",MID,,,1.0,True,True,True,,", I-AVR,",,,,,2,records100/21000/21835_lr,records500/21000/21835_hr,{'STTC': 50.0},{'ISCA': 50.0},"[0.0, 0.0, 0.0, 0.0, 0.5]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.0, 0.0, ..."
21836,8873.0,64.0,1,,,1.0,2.0,AT-60 3,2001-06-09 18:21:49,supraventrikulÄre extrasystole(n) sinusrhythmu...,"{'NORM': 100.0, 'SR': 0.0}",LAD,,,1.0,False,True,True,,,,,SVES,,8,records100/21000/21836_lr,records500/21000/21836_hr,{'NORM': 100.0},{'NORM': 100.0},"[1.0, 0.0, 0.0, 0.0, 0.0]","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."


Make sure that there is no record for which only one of the superclass/subclass labels is given.

In [22]:
def _has_any_entry(likelihood_by_class):
    """Return True when the per-record class dict holds at least one entry."""
    return len(likelihood_by_class) > 0

# Flag the records that received at least one superclass / subclass annotation.
df_ann['superclass_labeled'] = df_ann['diagnostic_superclass'].apply(_has_any_entry)
df_ann['subclass_labeled'] = df_ann['diagnostic_subclass'].apply(_has_any_entry)

In [23]:
has_superclass = df_ann['superclass_labeled']
has_subclass = df_ann['subclass_labeled']
# No record may carry a subclass label without a superclass label ...
assert not (~has_superclass & has_subclass).any()
# ... nor a superclass label without a subclass label.
assert not (has_superclass & ~has_subclass).any()

In [24]:
# A record counts as labeled only when both the superclass and the subclass flag are set.
df_ann['is_labeled'] = df_ann[['superclass_labeled', 'subclass_labeled']].all(axis=1)

## Load raw signal data

In [25]:
# Read one record from the records100 folder to inspect its signal array and metadata.
example_record = os.path.join(DATA_ROOT, 'records100/00000/00001_lr')
ex_data, ex_meta = wfdb.rdsamp(example_record)

In [26]:
ex_data

array([[-0.119, -0.055,  0.064, ..., -0.026, -0.039, -0.079],
       [-0.116, -0.051,  0.065, ..., -0.031, -0.034, -0.074],
       [-0.12 , -0.044,  0.076, ..., -0.028, -0.029, -0.069],
       ...,
       [ 0.069,  0.   , -0.069, ...,  0.024, -0.041, -0.058],
       [ 0.086,  0.004, -0.081, ...,  0.242, -0.046, -0.098],
       [ 0.022, -0.031, -0.054, ...,  0.143, -0.035, -0.12 ]],
      shape=(1000, 12))

In [27]:
ex_data.shape

(1000, 12)

In [28]:
ex_meta

{'fs': 100,
 'sig_len': 1000,
 'n_sig': 12,
 'base_date': None,
 'base_time': None,
 'units': ['mV',
  'mV',
  'mV',
  'mV',
  'mV',
  'mV',
  'mV',
  'mV',
  'mV',
  'mV',
  'mV',
  'mV'],
 'sig_name': ['I',
  'II',
  'III',
  'AVR',
  'AVL',
  'AVF',
  'V1',
  'V2',
  'V3',
  'V4',
  'V5',
  'V6'],
 'comments': []}

## Create HDF5 Dataset

In [29]:
# Read the waveform of every annotated record at the target sampling rate.
raw_data = load_raw_data(df=df_ann, sampling_rate=TARGET_FS, data_dir=DATA_ROOT)
raw_data.shape

(21799, 1000, 12)

In [30]:
# Stack the per-record superclass label vectors into one 2-D array (one row per record).
label_superclass = np.vstack(df_ann['label_superclass'].tolist())
label_superclass.shape

(21799, 5)

In [31]:
# Stack the per-record subclass label vectors into one 2-D array (one row per record).
label_subclass = np.vstack(df_ann['label_subclass'].tolist())
label_subclass.shape

(21799, 23)

In [32]:
# Per-record flag telling whether the record carries diagnostic labels.
is_labeled = df_ann['is_labeled'].to_numpy()
is_labeled.shape

(21799,)

In [33]:
# Per-record value of the `strat_fold` column from the annotation table.
strat_fold = df_ann['strat_fold'].to_numpy()
strat_fold.shape

(21799,)

In [34]:
os.makedirs(SAVE_DIR, exist_ok=True)

# (dataset name, in-memory array, on-disk dtype), in the order they are stored.
dataset_specs = [
    ('ecg_data', raw_data, 'float32'),
    ('label_superclass', label_superclass, 'float32'),
    ('label_subclass', label_subclass, 'float32'),
    ('is_labeled', is_labeled, 'bool'),
    ('strat_fold', strat_fold, 'int32'),
]

# Create an HDF5 file with row-wise chunking for both data and labels.
with h5py.File(os.path.join(SAVE_DIR, 'dataset.h5'), 'w') as hf:
    # Each chunk holds exactly one row: (1, <remaining dimensions of the array>).
    for name, array, dtype in dataset_specs:
        hf.create_dataset(name, array.shape, dtype=dtype, chunks=(1, *array.shape[1:]), compression=None)

    # Write data and labels once every dataset exists.
    for name, array, _dtype in dataset_specs:
        hf[name][:] = array

In [35]:
def print_structure(name, obj):
    """Print one `name: object` line for an entry visited in the HDF5 file."""
    print(f"{name}: {obj}")

# Open the HDF5 file in read mode and walk through every group and dataset it contains.
with h5py.File(os.path.join(SAVE_DIR, 'dataset.h5'), "r") as hf:
    hf.visititems(print_structure)

ecg_data: <HDF5 dataset "ecg_data": shape (21799, 1000, 12), type "<f4">
is_labeled: <HDF5 dataset "is_labeled": shape (21799,), type "|b1">
label_subclass: <HDF5 dataset "label_subclass": shape (21799, 23), type "<f4">
label_superclass: <HDF5 dataset "label_superclass": shape (21799, 5), type "<f4">
strat_fold: <HDF5 dataset "strat_fold": shape (21799,), type "<i4">


In [36]:
# Bundle the label lookup tables under the keys they are saved with.
class_label_maps = dict(
    superclass_label_map=superclass_label_map,
    subclass_label_map=subclass_label_map,
    diag_subclass_to_superclass_map=diag_subclass_to_superclass_map,
)

In [37]:
# Pickle the label lookup tables into the output directory.
maps_path = os.path.join(SAVE_DIR, 'class_label_maps.pkl')
with open(maps_path, 'wb') as maps_file:
    pickle.dump(class_label_maps, maps_file, protocol=pickle.HIGHEST_PROTOCOL)