In [1]:
import os 
import wfdb
import ast
import pandas as pd 
import numpy as np 
from tqdm import tqdm
from sklearn.model_selection import train_test_split

pd.set_option('display.max_columns', None)

In [2]:
data_path = 'data'

In [3]:
annotation = pd.read_csv(os.path.join(data_path, 'ptbxl_database.csv'))
annotation.ecg_id -= 1 

In [4]:
annotation.head()

Unnamed: 0,ecg_id,patient_id,age,sex,height,weight,nurse,site,device,recording_date,report,scp_codes,heart_axis,infarction_stadium1,infarction_stadium2,validated_by,second_opinion,initial_autogenerated_report,validated_by_human,baseline_drift,static_noise,burst_noise,electrodes_problems,extra_beats,pacemaker,strat_fold,filename_lr,filename_hr,rhythm_class,is_arrhytmia
0,-2,15709.0,56.0,1,,63.0,2.0,0.0,CS-12 E,1984-11-09 09:17:34,sinusrhythmus periphere niederspannung,"{'NORM': 100.0, 'LVOLT': 0.0, 'SR': 0.0}",,,,,False,False,True,,", I-V1,",,,,,3,records100/00000/00001_lr,records500/00000/00001_hr,['sinus rhythm'],0
1,-1,13243.0,19.0,0,,70.0,2.0,0.0,CS-12 E,1984-11-14 12:55:37,sinusbradykardie sonst normales ekg,"{'NORM': 80.0, 'SBRAD': 0.0}",,,,,False,False,True,,,,,,,2,records100/00000/00002_lr,records500/00000/00002_hr,['sinus bradycardia'],1
2,0,20372.0,37.0,1,,69.0,2.0,0.0,CS-12 E,1984-11-15 12:49:10,sinusrhythmus normales ekg,"{'NORM': 100.0, 'SR': 0.0}",,,,,False,False,True,,,,,,,5,records100/00000/00003_lr,records500/00000/00003_hr,['sinus rhythm'],0
3,1,17014.0,24.0,0,,82.0,2.0,0.0,CS-12 E,1984-11-15 13:44:57,sinusrhythmus normales ekg,"{'NORM': 100.0, 'SR': 0.0}",,,,,False,False,True,", II,III,AVF",,,,,,3,records100/00000/00004_lr,records500/00000/00004_hr,['sinus rhythm'],0
4,2,17448.0,19.0,1,,70.0,2.0,0.0,CS-12 E,1984-11-17 10:43:15,sinusrhythmus normales ekg,"{'NORM': 100.0, 'SR': 0.0}",,,,,False,False,True,", III,AVR,AVF",,,,,,4,records100/00000/00005_lr,records500/00000/00005_hr,['sinus rhythm'],0


In [5]:
annotation.shape

(21837, 30)

In [6]:
annotation.scp_codes = annotation.scp_codes.apply(lambda x: ast.literal_eval(x))

In [7]:
example_wf = os.path.join(data_path, annotation.iloc[0].filename_lr)

In [8]:
signal_data, meta_data = wfdb.rdsamp(example_wf)

In [9]:
signal_data

array([[-0.119, -0.055,  0.064, ..., -0.026, -0.039, -0.079],
       [-0.116, -0.051,  0.065, ..., -0.031, -0.034, -0.074],
       [-0.12 , -0.044,  0.076, ..., -0.028, -0.029, -0.069],
       ...,
       [ 0.069,  0.   , -0.069, ...,  0.024, -0.041, -0.058],
       [ 0.086,  0.004, -0.081, ...,  0.242, -0.046, -0.098],
       [ 0.022, -0.031, -0.054, ...,  0.143, -0.035, -0.12 ]])

In [10]:
signal_data.shape

(1000, 12)

In [11]:
signal_data[:, 0].shape

(1000,)

In [12]:
meta_data

{'fs': 100,
 'sig_len': 1000,
 'n_sig': 12,
 'base_date': None,
 'base_time': None,
 'units': ['mV',
  'mV',
  'mV',
  'mV',
  'mV',
  'mV',
  'mV',
  'mV',
  'mV',
  'mV',
  'mV',
  'mV'],
 'sig_name': ['I',
  'II',
  'III',
  'AVR',
  'AVL',
  'AVF',
  'V1',
  'V2',
  'V3',
  'V4',
  'V5',
  'V6'],
 'comments': []}

In [13]:
meta_data['sig_name']

['I', 'II', 'III', 'AVR', 'AVL', 'AVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']

In [14]:
len(meta_data['sig_name'])

12

In [15]:
def load_raw_data(df, sampling_rate, path):
    if sampling_rate == 100:
        data = [wfdb.rdsamp(os.path.join(path, file)) for file in tqdm(df.filename_lr)]
    else:
        data = [wfdb.rdsamp(os.path.join(path, file)) for file in tqdm(df.filename_hr)]
    
    data = np.array([signal for signal, meta in data])
    return data


sampling_rate=100

data = load_raw_data(annotation, sampling_rate, data_path)

100%|████████████████████████████████████████████████████████████████████████████| 21837/21837 [07:41<00:00, 47.29it/s]


In [16]:
sample = data[0]

In [17]:
sample[:10, :]

array([[-0.119, -0.055,  0.064,  0.086, -0.091,  0.004, -0.069, -0.031,
         0.   , -0.026, -0.039, -0.079],
       [-0.116, -0.051,  0.065,  0.083, -0.09 ,  0.006, -0.064, -0.036,
        -0.003, -0.031, -0.034, -0.074],
       [-0.12 , -0.044,  0.076,  0.082, -0.098,  0.016, -0.058, -0.034,
        -0.01 , -0.028, -0.029, -0.069],
       [-0.117, -0.038,  0.08 ,  0.077, -0.098,  0.021, -0.05 , -0.03 ,
        -0.015, -0.023, -0.022, -0.064],
       [-0.103, -0.031,  0.072,  0.066, -0.087,  0.021, -0.045, -0.027,
        -0.02 , -0.019, -0.018, -0.058],
       [-0.097, -0.025,  0.071,  0.061, -0.084,  0.023, -0.036, -0.025,
        -0.009, -0.014, -0.012, -0.052],
       [-0.119, -0.014,  0.106,  0.066, -0.112,  0.046, -0.029, -0.012,
         0.005, -0.008, -0.007, -0.048],
       [-0.096,  0.008,  0.104,  0.044, -0.1  ,  0.056, -0.023,  0.003,
         0.018,  0.002, -0.001, -0.041],
       [-0.048,  0.044,  0.092,  0.002, -0.07 ,  0.068, -0.015,  0.018,
         0.021,  0.009, 

In [18]:
np.save('signals.npy', data) 

In [19]:
agg_df = pd.read_csv(os.path.join(data_path, 'scp_statements.csv'), index_col=0)
agg_df = agg_df[agg_df.rhythm == 1]

In [20]:
agg_df.head()

Unnamed: 0,description,diagnostic,form,rhythm,diagnostic_class,diagnostic_subclass,Statement Category,SCP-ECG Statement Description,AHA code,aECG REFID,CDISC Code,DICOM Code
SR,sinus rhythm,,,1.0,,,Statements related to impulse formation (abnor...,sinus rhythm,20.0,MDC_ECG_RHY_SINUS_RHY,,
AFIB,atrial fibrillation,,,1.0,,,Statements related to impulse formation (abnor...,atrial fibrillation,50.0,MDC_ECG_RHY_ATR_FIB,,D3-31520
STACH,sinus tachycardia,,,1.0,,,Statements related to impulse formation (abnor...,sinus tachycardia,21.0,MDC_ECG_RHY_SINUS_TACHY,,
SARRH,sinus arrhythmia,,,1.0,,,Statements related to impulse formation (abnor...,sinus arrhythmia,23.0,MDC_ECG_RHY_SINUS_ARRHY,,
SBRAD,sinus bradycardia,,,1.0,,,Statements related to impulse formation (abnor...,sinus bradycardia,22.0,MDC_ECG_RHY_SINUS_BRADY,,


In [21]:
rhythm_classes = set(agg_df.description.to_list())
normal_rhythm = set(['sinus rhythm', 'normal functioning artificial pacemaker'])
arrhytmia_classes = rhythm_classes - normal_rhythm

In [22]:
len(arrhytmia_classes)

10

In [23]:
agg_df.shape

(12, 12)

In [24]:
def aggregate_rhythm(y_dic):
    tmp = []
    for key in y_dic.keys():
        if key in agg_df.index:
            tmp.append(agg_df.loc[key].description)
    return list(set(tmp))

In [25]:
annotation['rhythm_class'] = annotation.scp_codes.apply(aggregate_rhythm)

In [26]:
annotation.head()

Unnamed: 0,ecg_id,patient_id,age,sex,height,weight,nurse,site,device,recording_date,report,scp_codes,heart_axis,infarction_stadium1,infarction_stadium2,validated_by,second_opinion,initial_autogenerated_report,validated_by_human,baseline_drift,static_noise,burst_noise,electrodes_problems,extra_beats,pacemaker,strat_fold,filename_lr,filename_hr,rhythm_class,is_arrhytmia
0,-2,15709.0,56.0,1,,63.0,2.0,0.0,CS-12 E,1984-11-09 09:17:34,sinusrhythmus periphere niederspannung,"{'NORM': 100.0, 'LVOLT': 0.0, 'SR': 0.0}",,,,,False,False,True,,", I-V1,",,,,,3,records100/00000/00001_lr,records500/00000/00001_hr,[sinus rhythm],0
1,-1,13243.0,19.0,0,,70.0,2.0,0.0,CS-12 E,1984-11-14 12:55:37,sinusbradykardie sonst normales ekg,"{'NORM': 80.0, 'SBRAD': 0.0}",,,,,False,False,True,,,,,,,2,records100/00000/00002_lr,records500/00000/00002_hr,[sinus bradycardia],1
2,0,20372.0,37.0,1,,69.0,2.0,0.0,CS-12 E,1984-11-15 12:49:10,sinusrhythmus normales ekg,"{'NORM': 100.0, 'SR': 0.0}",,,,,False,False,True,,,,,,,5,records100/00000/00003_lr,records500/00000/00003_hr,[sinus rhythm],0
3,1,17014.0,24.0,0,,82.0,2.0,0.0,CS-12 E,1984-11-15 13:44:57,sinusrhythmus normales ekg,"{'NORM': 100.0, 'SR': 0.0}",,,,,False,False,True,", II,III,AVF",,,,,,3,records100/00000/00004_lr,records500/00000/00004_hr,[sinus rhythm],0
4,2,17448.0,19.0,1,,70.0,2.0,0.0,CS-12 E,1984-11-17 10:43:15,sinusrhythmus normales ekg,"{'NORM': 100.0, 'SR': 0.0}",,,,,False,False,True,", III,AVR,AVF",,,,,,4,records100/00000/00005_lr,records500/00000/00005_hr,[sinus rhythm],0


In [27]:
def get_target_column(row, arrhytmia_classes):
    cur_rhythm = set(row['rhythm_class'])
    
    normal_rhythm = set(['sinus rhythm', 'normal functioning artificial pacemaker'])
    if len(cur_rhythm & arrhytmia_classes) > 0:
        return 1 
    return 0 

In [28]:
annotation.rhythm_class.value_counts()

[sinus rhythm]                                                                           16721
[atrial fibrillation]                                                                     1484
[sinus tachycardia]                                                                        805
[]                                                                                         771
[sinus arrhythmia]                                                                         765
[sinus bradycardia]                                                                        629
[normal functioning artificial pacemaker]                                                  287
[supraventricular arrhythmia]                                                              150
[bigeminal pattern (unknown origin, SV or Ventricular), sinus rhythm]                       37
[atrial flutter]                                                                            36
[bigeminal pattern (unknown origin, SV or Ventricu

In [29]:
annotation['is_arrhytmia'] = annotation.apply(lambda row: get_target_column(row, arrhytmia_classes), axis=1)

In [30]:
annotation['is_arrhytmia'].value_counts()

0    17782
1     4055
Name: is_arrhytmia, dtype: int64

In [31]:
annotation.to_csv(os.path.join(data_path, 'ptbxl_database.csv'), index=False)

# train test split

In [32]:
TEST_SIZE = 0.25

In [33]:
target = annotation['is_arrhytmia'].to_numpy()

In [34]:
np.save(os.path.join(data_path, 'target.npy'), target)

In [35]:
X_train, X_test, y_train, y_test = train_test_split(data, target, test_size=TEST_SIZE, random_state=42)

In [38]:
np.save(os.path.join(data_path, 'train_data.npy'), X_train)
np.save(os.path.join(data_path, 'train_target.npy'), y_train) 
            
np.save(os.path.join(data_path, 'test_data.npy'), X_test)
np.save(os.path.join(data_path, 'test_target.npy'), y_test) 