In [None]:
import os
import scipy.io
from tqdm import tqdm
import numpy as np
from scipy.signal import butter, lfilter
import pickle

def butter_bandpass(lowcut, highcut, fs, order=5):
    """Design a digital Butterworth band-pass filter.

    lowcut, highcut : band edges, in the same units as ``fs``.
    fs              : sampling rate.
    order           : Butterworth order passed to ``scipy.signal.butter``.

    Returns the ``(b, a)`` transfer-function coefficients. The band edges are
    normalized by the Nyquist frequency (fs / 2) because ``butter`` is called
    without its own ``fs`` argument.
    """
    nyquist = 0.5 * fs
    normalized_band = [lowcut / nyquist, highcut / nyquist]
    return butter(order, normalized_band, btype='band')


def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
    """Band-pass ``data`` with a causal Butterworth filter.

    The coefficients come from ``butter_bandpass`` and are applied with
    ``scipy.signal.lfilter`` (single forward pass, default last axis).
    Returns the filtered array.
    """
    coefficients = butter_bandpass(lowcut, highcut, fs, order=order)
    return lfilter(*coefficients, data)

def butter_lowpass(lowcut, fs, order=5):
    """Design a digital Butterworth low-pass filter.

    lowcut : the low-pass cutoff frequency, in the same units as ``fs``
             (the parameter name is kept as-is for caller compatibility).
    fs     : sampling rate.
    order  : Butterworth order passed to ``scipy.signal.butter``.

    Returns the ``(b, a)`` transfer-function coefficients; the cutoff is
    normalized by the Nyquist frequency (fs / 2).
    """
    normalized_cutoff = lowcut / (0.5 * fs)
    return butter(order, normalized_cutoff, btype='low')


def butter_lowpass_filter(data, lowcut, fs, order=5):
    """Low-pass ``data`` with a causal Butterworth filter.

    ``lowcut`` is the cutoff frequency. The coefficients come from
    ``butter_lowpass`` and are applied with ``scipy.signal.lfilter``.
    Returns the filtered array.
    """
    coefficients = butter_lowpass(lowcut, fs, order=order)
    return lfilter(*coefficients, data)

In [None]:
# Location of the `WFDBRecords` folder of the Zheng et al. dataset (PhysioNet).
# Set the ZHENG_WFDB_DIR environment variable to point at a local copy instead of
# editing this placeholder path in the notebook.
base_dir = os.environ.get(
    "ZHENG_WFDB_DIR",
    "[path to Zheng et al. dataset downloaded from PhysioNet]/WFDBRecords",
)

In [None]:
files = []

# Walk the two fixed directory levels under base_dir and collect every .mat signal file.
# Entries are sorted because os.listdir returns them in arbitrary order; sorting makes
# the record order (and therefore the row order of X / labels) reproducible.
for subdir in sorted(os.listdir(base_dir)):
    subdir_path = os.path.join(base_dir, subdir)
    # all subdir lengths in this dataset are 2 characters; skip anything that is not a directory
    if len(subdir) != 2 or not os.path.isdir(subdir_path):
        continue
    for subsubdir in sorted(os.listdir(subdir_path)):
        subsubdir_path = os.path.join(subdir_path, subsubdir)
        # all subsubdir lengths in this dataset are 3 characters
        if len(subsubdir) != 3 or not os.path.isdir(subsubdir_path):
            continue
        for file in sorted(os.listdir(subsubdir_path)):
            # endswith, not substring: avoids picking up e.g. "x.mat.bak"
            if file.endswith(".mat"):
                files.append(os.path.join(subsubdir_path, file))

In [None]:
def extract_Dxs(file):
    """Return the diagnosis codes listed on the ``#Dx`` line of a record's header.

    Parameters
    ----------
    file : str
        Path to a record's ``.mat`` file, its ``.hea`` header, or the record path
        without an extension. It is mapped to the sibling ``.hea`` header file.

    Returns
    -------
    list of str
        The comma-separated entries after ``Dx:``, with surrounding whitespace
        stripped. If several ``#Dx`` lines are present, the last one is used.

    Raises
    ------
    ValueError
        If the header contains no ``#Dx`` line.
    """
    # Only rewrite the extension; a ".mat" elsewhere in the path (e.g. in a directory
    # name) must not be touched.
    root, ext = os.path.splitext(file)
    if ext == ".mat":
        file = root + ".hea"
    elif ext != ".hea":
        file += ".hea"

    dxs = None
    with open(file, "r") as f:
        for line in f:
            if "#Dx" in line:
                # Split on "Dx:" (not "Dx: ") so a missing space after the colon still parses.
                dxs = [dx.strip() for dx in line.split("Dx:", 1)[-1].split(",")]

    if dxs is None:
        raise ValueError(f"no '#Dx' line found in header file {file!r}")
    return dxs

# Collect the distinct diagnosis codes that appear across all records.
# NOTE(review): `Dxs` does not appear to be used by the later cells shown here (the class
# list is rebuilt from `labels`); it looks exploratory. Verify before removing.
Dxs = set()
for path in tqdm(files):
    Dxs.update(extract_Dxs(path))
# sorted() gives a reproducible order; the iteration order of a set of strings changes
# between runs because of Python's string hash randomization.
Dxs = sorted(Dxs)

In [None]:
# Load every record, resample its first 12 channels to 1200 samples each, and keep
# the record's diagnosis codes alongside it.
X = []
labels = []

for path in tqdm(files):
    signal = np.float32(scipy.io.loadmat(path)['val'])

    # One row per channel; scipy.signal.resample is Fourier-based.
    resampled = np.zeros((12, 1200))
    for channel in range(12):
        resampled[channel] = scipy.signal.resample(signal[channel], 1200)

    X.append(resampled)
    labels.append(extract_Dxs(path))

# Stack into float32 and add a singleton axis: (records, 12, 1200) -> (records, 1, 12, 1200).
X = np.float32(X)
X = X.reshape(X.shape + (1,))
X = np.transpose(X, (0, 3, 1, 2))

In [None]:
# Per-channel mean and standard deviation over all records and time samples.
# Axis 2 of X indexes the 12 channels.
MEANS = []
STDEVS = []

for channel in tqdm(range(12)):
    channel_values = X[:, :, channel, :]
    MEANS.append(np.mean(channel_values))
    STDEVS.append(np.std(channel_values))

In [None]:
# Z-score each channel with the statistics computed above.
# NOTE(review): this overwrites X in place, so re-running this cell normalizes twice.
for channel, (mu, sigma) in enumerate(zip(MEANS, STDEVS)):
    X[:, :, channel, :] = (X[:, :, channel, :] - mu) / sigma

In [None]:
# Preliminary band-pass (0.55-20 at fs=120), applied independently to every channel of every record.
# NOTE: this is not the final bandpass filter used in the paper, this was a preliminary filter used in the preprocess step.
PRELIM_LOWCUT, PRELIM_HIGHCUT, PRELIM_FS = 0.55, 20, 120

# Design the filter once; the coefficients were previously recomputed for every (record, channel).
b_prelim, a_prelim = butter_bandpass(PRELIM_LOWCUT, PRELIM_HIGHCUT, PRELIM_FS)

for i in tqdm(range(0, len(X))):
    # X[i] has shape (1, 12, 1200). lfilter works along the last (time) axis, so each
    # channel is filtered independently, exactly as the former per-channel loop did.
    # The float64 result is cast back to X's float32 on assignment.
    X[i] = lfilter(b_prelim, a_prelim, X[i])

In [None]:
# Multi-hot encode the per-record diagnosis lists collected in `labels`.
# Columns are the distinct codes, sorted so the column order is reproducible.
# With list(set(...)), string hash randomization changed the order between runs.
classes = sorted({dx for record_dxs in labels for dx in record_dxs})
class_index = {dx: col for col, dx in enumerate(classes)}

# convert labels to multi-hot encoding: Y[r, c] == 1 iff record r lists classes[c]
Y = np.zeros((len(labels), len(classes)), dtype=int)
for row, record_dxs in enumerate(labels):
    for dx in record_dxs:
        Y[row, class_index[dx]] = 1

In [None]:
# Persist the filtered signal tensor, multi-hot labels and class list.
# NOTE: unpickling can execute arbitrary code; only load this file from a trusted source.
payload = {"data": X, "labels": Y, "classes": classes}
with open("preprocessed_data.pkl", mode="wb") as out_file:
    pickle.dump(payload, out_file)