In [3]:
pip install wfdb

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting wfdb
  Downloading wfdb-4.1.2-py3-none-any.whl (159 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m160.0/160.0 kB[0m [31m8.2 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: wfdb
Successfully installed wfdb-4.1.2


In [4]:
import os
import numpy as np
import wfdb

# Fetch the 'mitdb' database into <dataset_root>/download.
dataset_root = './dataset'
download_dir = os.path.join(dataset_root, 'download')
wfdb.dl_database(
    'mitdb',
    dl_dir=download_dir,
)

# Windowing: every beat is cut out as a fixed-length window of samples.
sample_rate = 360  # Hz
window_size = 720  # samples per window (2 seconds at 360 Hz)

# Record IDs assigned to each split.
train_record_list = [
    '101', '106', '108', '109', '112', '115', '116',
    '118', '119', '122', '124', '201', '203', '205',
    '207', '208', '209', '215', '220', '223', '230',
]
test_record_list = [
    '100', '103', '105', '111', '113', '117', '121',
    '123', '200', '210', '212', '213', '214', '219',
    '221', '222', '228', '231', '232', '233', '234',
]

# Beat annotations: each accepted annotation symbol collapses onto one of two
# classes, and a sample's integer label is that class's index in `labels`.
labels = ['N', 'V']
label_map = {
    **dict.fromkeys(['N', 'L', 'R', 'e', 'j'], 'N'),
    **dict.fromkeys(['V', 'E'], 'V'),
}
# Accepted symbols are exactly the keys of the mapping, in insertion order.
valid_symbols = list(label_map)

def _load_data(base_record, channel=0):
    """Read one downloaded record and its 'atr' annotations.

    Args:
        base_record: Record identifier; converted with ``str`` and joined
            onto ``download_dir`` to form the record path.
        channel: Column of the signal array to return (default 0).

    Returns:
        A tuple ``(signal, symbols, positions)``: the selected signal column,
        the annotation symbols, and the annotation sample positions.

    Raises:
        AssertionError: If the record's sampling frequency differs from
            ``sample_rate``.
    """
    path = os.path.join(download_dir, str(base_record))

    # Signal samples plus the record's metadata dictionary.
    samples, meta = wfdb.rdsamp(path)
    assert meta['fs'] == sample_rate

    # Annotations stored alongside the record with the 'atr' extension.
    ann = wfdb.rdann(path, 'atr')
    return samples[:, channel], ann.symbol, ann.sample

def _segment_data(signal, symbols, positions):
    """Cut fixed-length windows around annotated beats and label them.

    For every annotation whose symbol is in ``valid_symbols``, a window of
    ``window_size`` samples centred on the annotation position is taken from
    ``signal``. Windows that would run past either end of the signal are
    skipped.

    Args:
        signal: 1-D sequence of samples for one channel.
        symbols: Annotation symbols, one per annotation.
        positions: Sample index of each annotation, parallel to ``symbols``.

    Returns:
        A tuple ``(X, y)``: ``X`` stacks the extracted windows with shape
        ``(n_segments, window_size)``, and ``y`` holds for each window the
        index into ``labels`` of its mapped class. When no annotation
        qualifies, ``X`` has shape ``(0, window_size)`` and ``y`` is an empty
        integer array, so results can still be stacked with other records.

    Raises:
        AssertionError: If an extracted window does not have exactly
            ``window_size`` samples.
    """
    half = window_size // 2
    sig_len = len(signal)
    X, y = [], []
    for symbol, center in zip(symbols, positions):
        if symbol not in valid_symbols:
            continue
        start, end = center - half, center + half
        # Skip beats too close to either edge to fill a complete window.
        if start < 0 or end > sig_len:
            continue
        segment = signal[start:end]
        assert len(segment) == window_size, "Invalid length"
        X.append(segment)
        y.append(labels.index(label_map[symbol]))

    if not X:
        # np.array([]) would have shape (0,) and float dtype; that breaks
        # np.vstack across records and yields float labels. Return correctly
        # shaped and typed empties instead.
        empty_X = np.empty((0, window_size), dtype=np.asarray(signal).dtype)
        empty_y = np.empty((0,), dtype=int)
        return empty_X, empty_y
    return np.array(X), np.array(y)

def preprocess_dataset(record_list, mode):
    """Build and save the segment and label arrays for one split.

    Each record in ``record_list`` is loaded, standardized over the whole
    record (mean subtracted, then divided by the standard deviation) and cut
    into labelled windows. The stacked windows are saved to
    ``<dataset_root>/x_<mode>.npy`` and the concatenated labels to
    ``<dataset_root>/y_<mode>.npy``.

    Args:
        record_list: Record identifiers to process.
        mode: Split name used in the output file names.
    """
    segments_per_record = []
    labels_per_record = []
    for record in record_list:
        raw, symbols, positions = _load_data(record)
        standardized = (raw - np.mean(raw)) / np.std(raw)
        segments, targets = _segment_data(standardized, symbols, positions)
        segments_per_record.append(segments)
        labels_per_record.append(targets)

    out_dir = dataset_root
    os.makedirs(out_dir, exist_ok=True)
    np.save(os.path.join(out_dir, f"x_{mode!s}.npy"), np.vstack(segments_per_record))
    np.save(os.path.join(out_dir, f"y_{mode!s}.npy"), np.concatenate(labels_per_record))

# Write x_/y_ arrays for the train split, then for the test split.
preprocess_dataset(record_list=train_record_list, mode="train")
preprocess_dataset(record_list=test_record_list, mode="test")

Generating record list for: 100
Generating record list for: 101
Generating record list for: 102
Generating record list for: 103
Generating record list for: 104
Generating record list for: 105
Generating record list for: 106
Generating record list for: 107
Generating record list for: 108
Generating record list for: 109
Generating record list for: 111
Generating record list for: 112
Generating record list for: 113
Generating record list for: 114
Generating record list for: 115
Generating record list for: 116
Generating record list for: 117
Generating record list for: 118
Generating record list for: 119
Generating record list for: 121
Generating record list for: 122
Generating record list for: 123
Generating record list for: 124
Generating record list for: 200
Generating record list for: 201
Generating record list for: 202
Generating record list for: 203
Generating record list for: 205
Generating record list for: 207
Generating record list for: 208
Generating record list for: 209
Generati