In [16]:
import argparse
import os

import numpy as np
import pandas as pd
import wfdb
from tqdm import tqdm


# _LEAD_NAMES = ["I", "II", "III", "aVR", "aVL", "aVF", "V1", "V2", "V3", "V4", "V5", "V6"]
_LEAD_NAMES = ['I', 'II', 'III', 'AVR', 'AVL', 'AVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']

In [10]:
def get_parser():
    description = "Process WFDB ECG database."
    # parser = argparse.ArgumentParser(description=description)
    parser = argparse.ArgumentParser()
    parser.add_argument('-i',
                        '--input_dir',
                        type=str,
                        # required=True,
                        default='/tf/physionet.org/files/ptb-xl/1.0.3/records500',
                        help="Path to the WFDB ECG database directory.")
    parser.add_argument('-o',
                        '--output_dir',
                        type=str,
                        # required=True,
                        default='./ptbxl/ecgs500/',
                        help="Path to the directory where the preprocessed signals will be saved.")
    parser.add_argument('--index_path',
                        type=str,
                        default='./ptbxl/index.csv',
                        help="Path to the index file.")
    args = parser.parse_args("")
    return args

In [11]:
args = get_parser()
args

Namespace(input_dir='/tf/physionet.org/files/ptb-xl/1.0.3/records500', output_dir='./ptbxl/ecgs500/', index_path='./ptbxl/index.csv')

get_parser() 부분 실행 완 => args 처리됨. 이후 run 코드 순서대로 실행하면 됨!!

In [4]:
def find_records(root_dir):
    """Find all the .hea files in the root directory and its subdirectories.
    Args:
        root_dir (str): The directory to search for .hea files.
    Returns:
        records (set): A set of record names.
                       (e.g., ['database/1/ecg001', 'database/1/ecg001', ..., 'database/9/ecg991'])
    """
    records = set()
    for root, _, files in os.walk(root_dir):
        for file in files:
            extension = os.path.splitext(file)[1]
            if extension == '.hea':
                record = os.path.relpath(os.path.join(root, file), root_dir)[:-4]
                records.add(record)
    records = sorted(records)
    return records

In [5]:
# Identify the header fiels
record_rel_paths = find_records(args.input_dir)
record_rel_paths

['00000/00001_hr',
 '00000/00002_hr',
 '00000/00003_hr',
 '00000/00004_hr',
 '00000/00005_hr',
 '00000/00006_hr',
 '00000/00007_hr',
 '00000/00008_hr',
 '00000/00009_hr',
 '00000/00010_hr',
 '00000/00011_hr',
 '00000/00012_hr',
 '00000/00013_hr',
 '00000/00014_hr',
 '00000/00015_hr',
 '00000/00016_hr',
 '00000/00017_hr',
 '00000/00018_hr',
 '00000/00019_hr',
 '00000/00020_hr',
 '00000/00021_hr',
 '00000/00022_hr',
 '00000/00023_hr',
 '00000/00024_hr',
 '00000/00025_hr',
 '00000/00026_hr',
 '00000/00027_hr',
 '00000/00028_hr',
 '00000/00029_hr',
 '00000/00030_hr',
 '00000/00031_hr',
 '00000/00032_hr',
 '00000/00033_hr',
 '00000/00034_hr',
 '00000/00035_hr',
 '00000/00036_hr',
 '00000/00037_hr',
 '00000/00038_hr',
 '00000/00039_hr',
 '00000/00040_hr',
 '00000/00041_hr',
 '00000/00042_hr',
 '00000/00043_hr',
 '00000/00044_hr',
 '00000/00045_hr',
 '00000/00046_hr',
 '00000/00047_hr',
 '00000/00048_hr',
 '00000/00049_hr',
 '00000/00050_hr',
 '00000/00051_hr',
 '00000/00052_hr',
 '00000/0005

In [6]:
len(record_rel_paths), record_rel_paths[:4], record_rel_paths[-4:]

(21799,
 ['00000/00001_hr', '00000/00002_hr', '00000/00003_hr', '00000/00004_hr'],
 ['21000/21834_hr', '21000/21835_hr', '21000/21836_hr', '21000/21837_hr'])

In [7]:
print(f"Found {len(record_rel_paths)} records.")

Found 21799 records.


In [42]:
# Prepare an index dataframe
index_df = pd.DataFrame(columns = ["RELATIVE_FILE_PATH", "FILE_NAME", "SAMPLE_RATE", "SOURCE"])
index_df

Unnamed: 0,RELATIVE_FILE_PATH,FILE_NAME,SAMPLE_RATE,SOURCE


In [21]:
def moving_window_crop(x: np.ndarray, crop_length: int, crop_stride: int) -> np.ndarray:
    """Crop the input sequence with a moving window.
    """
    if crop_length > x.shape[1]:
        raise ValueError(f"crop_length must be smaller than the length of x ({x.shape[1]}).")
    start_idx = np.arange(0, x.shape[1] - crop_length + 1, crop_stride)
    return [x[:, i:i + crop_length] for i in start_idx]

In [43]:
for record_rel_path in tqdm(record_rel_paths):
    record_rel_dir, record_name = os.path.split(record_rel_path)
    save_dir = os.path.join(args.output_dir, record_rel_dir)
    os.makedirs(save_dir, exist_ok=True)
    source_name = record_rel_dir.split("/")[0]
    signal, record_info = wfdb.rdsamp(os.path.join(args.input_dir, record_rel_path))
    lead_idx = np.array([record_info["sig_name"].index(lead_name) for lead_name in _LEAD_NAMES])
    signal = signal[:, lead_idx]
    fs = record_info["fs"]
    signal_length = record_info["sig_len"]
    if signal_length < 10 * fs:  # Exclude the ECGs with lengths of less than 10 seconds
        continue
    cropped_signals = moving_window_crop(signal.T, crop_length=10 * fs, crop_stride=10 * fs)
    for idx, cropped_signal in enumerate(cropped_signals):
        if cropped_signal.shape[1] != 10 * fs or np.isnan(cropped_signal).any():
            continue
        pd.to_pickle(cropped_signal.astype(np.float32),
                     os.path.join(save_dir, f"{record_name}_{idx}.pkl"))
        index_df.loc[num_saved] = [f"{record_rel_path}_{idx}.pkl",
                                   f"{record_name}_{idx}.pkl",
                                   fs,
                                   source_name]
        num_saved += 1

print(f"Saved {num_saved} cropped signals.")
os.makedirs(os.path.dirname(args.index_path), exist_ok=True)
index_df.to_csv(args.index_path, index=False)

100%|████████████████████████████████████████████████████████████| 21799/21799 [17:20<00:00, 20.96it/s]


Saved 21799 cropped signals.


In [44]:
index_df

Unnamed: 0,RELATIVE_FILE_PATH,FILE_NAME,SAMPLE_RATE,SOURCE
0,00000/00001_hr_0.pkl,00001_hr_0.pkl,500,00000
1,00000/00002_hr_0.pkl,00002_hr_0.pkl,500,00000
2,00000/00003_hr_0.pkl,00003_hr_0.pkl,500,00000
3,00000/00004_hr_0.pkl,00004_hr_0.pkl,500,00000
4,00000/00005_hr_0.pkl,00005_hr_0.pkl,500,00000
...,...,...,...,...
21794,21000/21833_hr_0.pkl,21833_hr_0.pkl,500,21000
21795,21000/21834_hr_0.pkl,21834_hr_0.pkl,500,21000
21796,21000/21835_hr_0.pkl,21835_hr_0.pkl,500,21000
21797,21000/21836_hr_0.pkl,21836_hr_0.pkl,500,21000


In [17]:
# Save all the cropped signals
num_saved = 0
for record_rel_path in tqdm(record_rel_paths):
    record_rel_dir, record_name = os.path.split(record_rel_path)
    save_dir = os.path.join(args.output_dir, record_rel_dir)
    os.makedirs(save_dir, exist_ok=True)
    source_name = record_rel_dir.split("/")[0]
    signal, record_info = wfdb.rdsamp(os.path.join(args.input_dir, record_rel_path))
    lead_idx = np.array([record_info["sig_name"].index(lead_name) for lead_name in _LEAD_NAMES])
    signal = signal[:, lead_idx]
    fs = record_info["fs"]
    signal_length = record_info["sig_len"]
    # if signal_length < 10 * fs:  # Exclude the ECGs with lengths of less than 10 seconds
    #     continue
    # cropped_signals = moving_window_crop(signal.T, crop_length=10 * fs, crop_stride=10 * fs)
    # for idx, cropped_signal in enumerate(cropped_signals):
    #     if cropped_signal.shape[1] != 10 * fs or np.isnan(cropped_signal).any():
    #         continue
    #     pd.to_pickle(cropped_signal.astype(np.float32),
    #                  os.path.join(save_dir, f"{record_name}_{idx}.pkl"))
    #     index_df.loc[num_saved] = [f"{record_rel_path}_{idx}.pkl",
    #                                f"{record_name}_{idx}.pkl",
    #                                fs,
    #                                source_name]
    #     num_saved += 1
    break

  0%|                                                                        | 0/21799 [00:00<?, ?it/s]


In [18]:
record_rel_dir, record_name, save_dir, source_name

('00000', '00001_hr', './ptbxl/ecgs500/00000', '00000')

In [19]:
signal

array([[-0.115, -0.05 ,  0.065, ..., -0.035, -0.035, -0.075],
       [-0.115, -0.05 ,  0.065, ..., -0.035, -0.035, -0.075],
       [-0.115, -0.05 ,  0.065, ..., -0.035, -0.035, -0.075],
       ...,
       [ 0.21 ,  0.205, -0.005, ...,  0.185,  0.17 ,  0.18 ],
       [ 0.21 ,  0.205, -0.005, ...,  0.185,  0.17 ,  0.18 ],
       [ 0.21 ,  0.205, -0.005, ...,  0.185,  0.17 ,  0.18 ]])

In [20]:
record_info

{'fs': 500,
 'sig_len': 5000,
 'n_sig': 12,
 'base_date': None,
 'base_time': None,
 'units': ['mV',
  'mV',
  'mV',
  'mV',
  'mV',
  'mV',
  'mV',
  'mV',
  'mV',
  'mV',
  'mV',
  'mV'],
 'sig_name': ['I',
  'II',
  'III',
  'AVR',
  'AVL',
  'AVF',
  'V1',
  'V2',
  'V3',
  'V4',
  'V5',
  'V6'],
 'comments': []}

In [None]:
if signal_length < 10 * fs:  # Exclude the ECGs with lengths of less than 10 seconds
    continue
cropped_signals = moving_window_crop(signal.T, crop_length=10 * fs, crop_stride=10 * fs)
for idx, cropped_signal in enumerate(cropped_signals):
    if cropped_signal.shape[1] != 10 * fs or np.isnan(cropped_signal).any():
        continue
    pd.to_pickle(cropped_signal.astype(np.float32),
                 os.path.join(save_dir, f"{record_name}_{idx}.pkl"))
    index_df.loc[num_saved] = [f"{record_rel_path}_{idx}.pkl",
                               f"{record_name}_{idx}.pkl",
                               fs,
                               source_name]
    num_saved += 1

In [22]:
if signal_length < 10 * fs:
    print(1)
else: print(0)

0


In [24]:
signal_length

5000

In [25]:
fs * 10

5000

In [26]:
cropped_signals = moving_window_crop(signal.T, crop_length=10 * fs, crop_stride=10 * fs)
cropped_signals

[array([[-0.115, -0.115, -0.115, ...,  0.21 ,  0.21 ,  0.21 ],
        [-0.05 , -0.05 , -0.05 , ...,  0.205,  0.205,  0.205],
        [ 0.065,  0.065,  0.065, ..., -0.005, -0.005, -0.005],
        ...,
        [-0.035, -0.035, -0.035, ...,  0.185,  0.185,  0.185],
        [-0.035, -0.035, -0.035, ...,  0.17 ,  0.17 ,  0.17 ],
        [-0.075, -0.075, -0.075, ...,  0.18 ,  0.18 ,  0.18 ]])]

In [28]:
np.array(signal).shape

(5000, 12)

In [29]:
np.array(cropped_signals).shape

(1, 12, 5000)

In [33]:
for idx, cropped_signal in enumerate(cropped_signals):
    print(cropped_signal)
    break

[[-0.115 -0.115 -0.115 ...  0.21   0.21   0.21 ]
 [-0.05  -0.05  -0.05  ...  0.205  0.205  0.205]
 [ 0.065  0.065  0.065 ... -0.005 -0.005 -0.005]
 ...
 [-0.035 -0.035 -0.035 ...  0.185  0.185  0.185]
 [-0.035 -0.035 -0.035 ...  0.17   0.17   0.17 ]
 [-0.075 -0.075 -0.075 ...  0.18   0.18   0.18 ]]


In [35]:
[f"{record_rel_path}_{idx}.pkl", f"{record_name}_{idx}.pkl", fs, source_name]

['00000/00001_hr_0.pkl', '00001_hr_0.pkl', 500, '00000']

In [37]:
Y = pd.read_csv('/tf/physionet.org/files/ptb-xl/1.0.3/ptbxl_database.csv', index_col='ecg_id')
Y

Unnamed: 0_level_0,patient_id,age,sex,height,weight,nurse,site,device,recording_date,report,...,validated_by_human,baseline_drift,static_noise,burst_noise,electrodes_problems,extra_beats,pacemaker,strat_fold,filename_lr,filename_hr
ecg_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1,15709.0,56.0,1,,63.0,2.0,0.0,CS-12 E,1984-11-09 09:17:34,sinusrhythmus periphere niederspannung,...,True,,", I-V1,",,,,,3,records100/00000/00001_lr,records500/00000/00001_hr
2,13243.0,19.0,0,,70.0,2.0,0.0,CS-12 E,1984-11-14 12:55:37,sinusbradykardie sonst normales ekg,...,True,,,,,,,2,records100/00000/00002_lr,records500/00000/00002_hr
3,20372.0,37.0,1,,69.0,2.0,0.0,CS-12 E,1984-11-15 12:49:10,sinusrhythmus normales ekg,...,True,,,,,,,5,records100/00000/00003_lr,records500/00000/00003_hr
4,17014.0,24.0,0,,82.0,2.0,0.0,CS-12 E,1984-11-15 13:44:57,sinusrhythmus normales ekg,...,True,", II,III,AVF",,,,,,3,records100/00000/00004_lr,records500/00000/00004_hr
5,17448.0,19.0,1,,70.0,2.0,0.0,CS-12 E,1984-11-17 10:43:15,sinusrhythmus normales ekg,...,True,", III,AVR,AVF",,,,,,4,records100/00000/00005_lr,records500/00000/00005_hr
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
21833,17180.0,67.0,1,,,1.0,2.0,AT-60 3,2001-05-31 09:14:35,ventrikulÄre extrasystole(n) sinustachykardie ...,...,True,,", alles,",,,1ES,,7,records100/21000/21833_lr,records500/21000/21833_hr
21834,20703.0,300.0,0,,,1.0,2.0,AT-60 3,2001-06-05 11:33:39,sinusrhythmus lagetyp normal qrs(t) abnorm ...,...,True,,,,,,,4,records100/21000/21834_lr,records500/21000/21834_hr
21835,19311.0,59.0,1,,,1.0,2.0,AT-60 3,2001-06-08 10:30:27,sinusrhythmus lagetyp normal t abnorm in anter...,...,True,,", I-AVR,",,,,,2,records100/21000/21835_lr,records500/21000/21835_hr
21836,8873.0,64.0,1,,,1.0,2.0,AT-60 3,2001-06-09 18:21:49,supraventrikulÄre extrasystole(n) sinusrhythmu...,...,True,,,,,SVES,,8,records100/21000/21836_lr,records500/21000/21836_hr


In [39]:
Y.scp_codes # = Y.scp_codes.apply(lambda x: ast.literal_eval(x))

ecg_id
1                 {'NORM': 100.0, 'LVOLT': 0.0, 'SR': 0.0}
2                             {'NORM': 80.0, 'SBRAD': 0.0}
3                               {'NORM': 100.0, 'SR': 0.0}
4                               {'NORM': 100.0, 'SR': 0.0}
5                               {'NORM': 100.0, 'SR': 0.0}
                               ...                        
21833    {'NDT': 100.0, 'PVC': 100.0, 'VCLVH': 0.0, 'ST...
21834             {'NORM': 100.0, 'ABQRS': 0.0, 'SR': 0.0}
21835                           {'ISCAS': 50.0, 'SR': 0.0}
21836                           {'NORM': 100.0, 'SR': 0.0}
21837                           {'NORM': 100.0, 'SR': 0.0}
Name: scp_codes, Length: 21799, dtype: object

In [41]:
import ast
Y.scp_codes = Y.scp_codes.apply(lambda x: ast.literal_eval(x))
Y.scp_codes

ecg_id
1                 {'NORM': 100.0, 'LVOLT': 0.0, 'SR': 0.0}
2                             {'NORM': 80.0, 'SBRAD': 0.0}
3                               {'NORM': 100.0, 'SR': 0.0}
4                               {'NORM': 100.0, 'SR': 0.0}
5                               {'NORM': 100.0, 'SR': 0.0}
                               ...                        
21833    {'NDT': 100.0, 'PVC': 100.0, 'VCLVH': 0.0, 'ST...
21834             {'NORM': 100.0, 'ABQRS': 0.0, 'SR': 0.0}
21835                           {'ISCAS': 50.0, 'SR': 0.0}
21836                           {'NORM': 100.0, 'SR': 0.0}
21837                           {'NORM': 100.0, 'SR': 0.0}
Name: scp_codes, Length: 21799, dtype: object