In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from keras.layers import Conv1D
import wfdb                            # Package for loading the ecg and annotation
from sklearn.model_selection import train_test_split
from keras.models import Sequential
from keras.layers import Dense, Dropout
from tensorflow.keras.utils import to_categorical
from sklearn.metrics import roc_auc_score, accuracy_score, precision_score, recall_score
from sklearn.metrics import accuracy_score
import warnings
warnings.filterwarnings("ignore") 
import random
from scipy import signal
import pywt
import biosppy

# Random Initialization
# One seed constant shared by every generator seeded in this cell.
SEED = 42
random.seed(SEED)
# NumPy keeps its own global generator, independent of Python's `random`.
np.random.seed(SEED)
# NOTE(review): the Keras model built later presumably draws its initial weights
# from the backend's own RNG, which is not seeded here - confirm whether
# run-to-run reproducibility of training is required.

# Record names to load; each one is appended to `path` to locate its files.
record_list = ['100','101','102','103','104','105','106','107',
           '108','109','111','112','113','114','115',
           '117','118','119','121','122','123','124','200',
           '201','202','203','205','207','208','209','210',
           '212','213','214','215','217','219','220','221',
           '222','223','228','230','231','232','233','234']

# Prefix prepended to every record name.
path = "data/"

# Annotation symbols treated as non-beats (kept for reference; the dataset
# builder selects beats by whitelist instead).
nonbeat_symbols = ['[','!',']','x','(',')','p','t','u','`',
           '\'','^','|','~','+','s','T','*','D','=','"','@','Q','?']

# Beat symbols labelled as abnormal (target = 1).
abnormal = ['L','R','V','/','A','f','F','j','a','E','J','e','S']

# Beat symbol labelled as normal (target = 0).
normal = ['N']

# Seconds of signal taken on each side of a beat, and the sampling rate in Hz.
num_sec = 3
fs = 360

In [2]:

def bandpassFilter(data, fs=500, lowcut=3, highcut=12, order=3):
    """Apply a zero-phase Butterworth band-pass filter to a 1-D signal.

    Parameters
    ----------
    data : array_like
        Signal to filter.
    fs : float, optional
        Sampling rate of `data` in Hz. Defaults to 500, the value this helper
        used to hard-code; callers whose data was sampled at another rate
        (the notebook's config cell declares 360 Hz) should pass it explicitly.
    lowcut, highcut : float, optional
        Pass-band edges in Hz (defaults 3 and 12).
    order : int, optional
        Filter order handed to `signal.butter` (default 3).

    Returns
    -------
    ndarray
        `data` filtered forwards and backwards with `signal.filtfilt`.
    """
    # signal.butter expects the band edges as fractions of the Nyquist rate.
    nyq = 0.5 * fs
    low = lowcut / nyq
    high = highcut / nyq
    b, a = signal.butter(order, [low, high], btype='band')
    # filtfilt runs the filter in both directions, so no phase shift is added.
    filtered_data = signal.filtfilt(b, a, data)
    return filtered_data

def get_ecg(filename):
    """Load channel 0 of a WFDB record together with its "atr" annotations.

    Parameters
    ----------
    filename : str
        Record path without extension, as expected by `wfdb.rdrecord` and
        `wfdb.rdann`.

    Returns
    -------
    tuple
        (p_signal, annotation symbols, annotation sample positions), taken from
        the record's `p_signal` and the annotation's `symbol` / `sample`
        attributes respectively.
    """
    record = wfdb.rdrecord(filename, channels=[0])
    # Parse the annotation file once; symbols and positions come from the same
    # object, so a second read would only repeat the I/O.
    annotation = wfdb.rdann(filename, "atr")

    return record.p_signal, annotation.symbol, annotation.sample

def pan_tompkins(ecg_data, fs=500, rel_amplitude=0.5, min_rr=120, max_rr=200):
    """Threshold-based peak picker on a differentiated, squared, smoothed signal.

    Steps: first difference -> squaring -> moving average over a 20 ms window
    -> keep indices whose envelope exceeds `rel_amplitude` times its maximum.

    Parameters
    ----------
    ecg_data : sequence of float
        1-D signal.
    fs : float, optional
        Sampling rate in Hz; only used to size the 20 ms averaging window.
    rel_amplitude : float, optional
        Detection threshold as a fraction of the envelope maximum.
    min_rr : int, optional
        Minimum distance, in samples, between two retained peaks.
    max_rr : int, optional
        Gap, in samples, after which the most recent peak is discarded (see
        the review note in the loop).

    Returns
    -------
    ndarray of int
        Indices of the retained peaks; empty when fewer than two samples are
        given.
    """
    # A difference needs at least two samples; nothing can be detected otherwise.
    if len(ecg_data) < 2:
        return np.array([], dtype=int)

    diff_data = np.diff(ecg_data)
    squared_data = diff_data**2
    # Guard against a zero-length window when fs is below 50 Hz.
    window_size = max(1, int(0.02 * fs))
    average_data = np.convolve(squared_data, np.ones(window_size) / window_size, mode='same')
    high_threshold = rel_amplitude * np.max(average_data)

    peaks = []
    # np.diff drops one sample, so the envelope is normally one element shorter
    # than ecg_data; bounding the loop by both lengths avoids indexing past it.
    for i in range(min(len(ecg_data), len(average_data))):
        if average_data[i] > high_threshold:
            peaks.append(i)
        if len(peaks) > 1 and i - peaks[-2] < min_rr:
            # The newest candidate is closer than min_rr to the previous peak.
            peaks.pop()
        elif len(peaks) > 2 and i - peaks[-1] > max_rr:
            # NOTE(review): this drops the latest peak whenever more than max_rr
            # samples pass without a new detection (and more than two peaks are
            # held), so trailing peaks can be discarded - confirm the intent.
            peaks.pop()

    return np.array(peaks, dtype=int)

def preProcessing(filename):
    """Detect R-peak sample positions in a raw signal file.

    Pipeline: read the file as int16 samples -> linear detrend -> 4-level
    'sym4' wavelet decomposition with two detail bands zeroed -> wavelet
    reconstruction -> `bandpassFilter` -> biosppy Christov segmenter.

    Parameters
    ----------
    filename : str
        Path of the file to read (callers pass the record name plus ".dat").

    Returns
    -------
    The 'rpeaks' entry of the segmenter result (R-peak sample indices).
    """
    sample_freq = 360  # Hz

    # NOTE(review): the file is read as a flat stream of int16 values - confirm
    # that this matches the on-disk layout of the record (sample packing and
    # channel interleaving are not decoded here).
    with open(filename, 'rb') as file:
        data = np.fromfile(file, dtype='int16')

    # correcting the baseline of the data
    corrected_data = signal.detrend(data)

    # applying wavelet transform on the baseline corrected data and then
    # ignoring the high frequency and low frequency components
    # wavedec returns [cA4, cD4, cD3, cD2, cD1]: index 1 is the coarsest detail
    # band and index 4 the finest one.
    arr = pywt.wavedec(corrected_data, 'sym4', level=4)
    arr[1] = np.zeros_like(arr[1])
    arr[4] = np.zeros_like(arr[4])

    wavdec_filtered_signal = pywt.waverec(arr, 'sym4')

    # NOTE(review): bandpassFilter is called with its defaults - confirm that
    # the sampling rate it assumes agrees with sample_freq.
    final_signal = bandpassFilter(wavdec_filtered_signal)

    # The segmenter must be told the rate the samples were actually taken at,
    # so the rate declared in this function is passed rather than a second literal.
    results = biosppy.signals.ecg.christov_segmenter(signal=final_signal,
                                                     sampling_rate=sample_freq)

    r_peaks = results['rpeaks']

    return r_peaks




In [3]:
# One complete record consists of the signal data and its annotations. Each annotation file can be divided into
# two parts: the symbols representing the peaks (N, A, etc.) and the locations of these symbols in the ECG data,
# taken from the annotation object's `sample` attribute

def get_ecg_(filename):
    """Load a record like the plain loader, but with detected peak positions.

    Channel 0 and the "atr" annotation symbols are read with wfdb, while the
    locations are the R-peak indices returned by `preProcessing` for the
    record's ".dat" file.

    NOTE(review): symbols come from the annotation file and locations from the
    detector, so the two sequences are not guaranteed to have the same length
    or to line up one-to-one - confirm before pairing them.

    Returns
    -------
    tuple
        (p_signal, annotation symbols, detected peak locations)
    """
    ecg_samples = wfdb.rdrecord(filename, channels=[0]).p_signal
    beat_symbols = wfdb.rdann(filename, "atr").symbol
    detected_locations = preProcessing(filename+".dat")

    return (ecg_samples, beat_symbols, detected_locations)

In [4]:
def build_XY(p_signal, df_ann, num_cols, abnormal_symbols=None):
    """Cut one fixed-length window per annotated beat and label it.

    Parameters
    ----------
    p_signal : 1-D array
        Signal samples of one record.
    df_ann : pandas.DataFrame
        Beat annotations with columns `atr_sample` (sample index of the beat)
        and `atr_sym` (annotation symbol).
    num_cols : int
        Window length in samples. The window starts `num_cols // 2` samples
        before the beat, so it is derived from this argument alone and not
        from notebook-level settings.
    abnormal_symbols : collection of str, optional
        Symbols labelled 1. Defaults to the notebook-level `abnormal` list.

    Returns
    -------
    X : ndarray, shape (n_kept, num_cols)
        One window per kept beat.
    Y : ndarray, shape (n_kept, 1)
        1 when the beat symbol is in `abnormal_symbols`, else 0.
    sym : list of str
        Original symbol of each kept beat.

    Beats whose window would run past either end of the signal are skipped.
    """
    if abnormal_symbols is None:
        # Fall back to the notebook-level list of abnormal beat symbols.
        abnormal_symbols = abnormal

    num_rows = len(df_ann)

    # Allocate for the worst case (every beat kept) and trim afterwards.
    X = np.zeros((num_rows, num_cols))
    Y = np.zeros((num_rows, 1))
    sym = []

    half_window = num_cols // 2
    n_samples = len(p_signal)

    # keep track of rows actually filled
    max_row = 0

    for atr_sample, atr_sym in zip(df_ann.atr_sample.values, df_ann.atr_sym.values):
        left = atr_sample - half_window
        right = left + num_cols
        # Only full-length windows are usable as fixed-size rows.
        if left < 0 or right > n_samples:
            continue
        X[max_row] = p_signal[left:right]
        Y[max_row] = int(atr_sym in abnormal_symbols)
        sym.append(atr_sym)
        max_row += 1

    X = X[:max_row, :]
    Y = Y[:max_row, :]
    return X, Y, sym

In [5]:
def make_dataset(pts, num_sec, fs):
    """Build the beat-window dataset for a list of records, ignoring non-beats.

    Parameters
    ----------
    pts : list of str
        Record names; each is appended to the notebook-level `path`.
    num_sec : int
        Number of seconds to include before and after each beat.
    fs : int
        Sampling frequency in Hz.

    Returns
    -------
    X_all : ndarray, shape (nbeats, 2 * num_sec * fs)
        One signal window per beat.
    Y_all : ndarray, shape (nbeats, 1)
        1 when the beat is abnormal, else 0.
    sym_all : list of str, length nbeats
        Annotation symbol of each beat.
    """
    num_cols = 2*num_sec*fs

    # Collect per-record arrays and join them once at the end; appending to a
    # growing array inside the loop would recopy every earlier row each time.
    X_parts = []
    Y_parts = []
    sym_all = []

    for pt in pts:
        record_path = path + pt

        p_signal, atr_sym, atr_sample = get_ecg(record_path)

        # grab the first signal
        p_signal = p_signal[:,0]

        # make df to exclude the nonbeats: keep only whitelisted beat symbols
        df_ann = pd.DataFrame({'atr_sym':atr_sym, 'atr_sample':atr_sample})
        df_ann = df_ann.loc[df_ann.atr_sym.isin(abnormal + ['N'])]

        # NOTE(review): build_XY only receives num_cols - confirm it sizes its
        # window from that value so num_sec / fs passed here are honoured.
        X, Y, sym = build_XY(p_signal, df_ann, num_cols)
        X_parts.append(X)
        Y_parts.append(Y)
        sym_all.extend(sym)

    # With no records there is nothing to concatenate; return empty arrays of
    # the documented shapes.
    if not X_parts:
        return np.zeros((0, num_cols)), np.zeros((0, 1)), sym_all

    X_all = np.concatenate(X_parts, axis=0)
    Y_all = np.concatenate(Y_parts, axis=0)

    return X_all, Y_all, sym_all


In [6]:

# Build the beat-window matrix, the binary labels and the beat symbols for every
# record in record_list, using the num_sec / fs values from the config cell.
X_all, Y_all, sym_all = make_dataset(record_list, num_sec, fs)


In [7]:
# Hold out 33% of the beats for validation; random_state fixes the shuffle.
# NOTE(review): the split is made per beat, not per record, so windows from the
# same record (and, since each window spans num_sec seconds on both sides of a
# beat, overlapping stretches of signal) can land in both sets - confirm this
# is acceptable for the reported validation scores.
X_train, X_valid, y_train, y_valid = train_test_split(X_all, Y_all, test_size=0.33, random_state=42)

In [6]:
# Small fully connected binary classifier: one ReLU hidden layer, dropout for
# regularization, and a single sigmoid output unit.
model = Sequential([
    Dense(32, activation='relu', input_dim=X_train.shape[1]),
    Dropout(rate=0.25),
    Dense(1, activation='sigmoid'),
])

# Binary cross-entropy matches the single sigmoid output; accuracy is tracked
# during training.
model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy'],
)

In [9]:
# Train for 10 epochs in mini-batches of 32. No validation data is passed, so
# only training metrics are logged per epoch; the returned History object is
# displayed as the cell result.
model.fit(X_train, y_train, batch_size = 32, epochs= 10, verbose = 1)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.src.callbacks.History at 0x2b3a016a280>



In [12]:
def print_report(y_actual, y_pred, thresh):
    """Print and return accuracy and precision at a decision threshold.

    Parameters
    ----------
    y_actual : array-like
        True binary labels.
    y_pred : array-like
        Predicted scores; a score strictly greater than `thresh` counts as a
        positive prediction.
    thresh : float
        Decision threshold applied to `y_pred`.

    Returns
    -------
    tuple
        (accuracy, precision)
    """
    # Binarize once and reuse the result for both metrics.
    predicted_positive = (y_pred > thresh)
    accuracy = accuracy_score(y_actual, predicted_positive)
    precision = precision_score(y_actual, predicted_positive)

    print('Accuracy:%.3f'%accuracy)
    print('Precision:%.3f'%precision)
    print(' ')

    return accuracy, precision

# Predicted scores for both splits.
y_train_pred = model.predict(X_train, verbose=1)
y_valid_pred = model.predict(X_valid, verbose=1)

# Decision threshold: the fraction of positive labels in the training set
# (labels are 0/1 in a single column, so the column mean is that fraction).
thresh = y_train[:, 0].mean()

# Report the metrics at that threshold for each split.
print('On Train Data')
print_report(y_actual=y_train, y_pred=y_train_pred, thresh=thresh)
print('On Valid Data')
print_report(y_actual=y_valid, y_pred=y_valid_pred, thresh=thresh)

print(X_train.shape)

On Train Data
Accuracy:0.978
Precision:0.979
 
On Valid Data
Accuracy:0.970
Precision:0.966
 
(71485, 2160)


In [13]:
# test_list = []
# idx = []

# test_list.append('116')

# x, y, s = make_dataset(test_list, 3, 360)

# for i in range(len(y)):
#     if y[i]==0: idx.append(i)


In [1]:
# arr = model.predict(x[idx])

# for i in range(len(arr)):
#     if(arr[i]>=0.5):
#         arr[i] = int(1)
#     else:
#         arr[i] = int(0)

# n = y[idx]

# c = 0
# j = 
# for i in range(len(arr)):
#     n[i] = int(n[i])
#     if arr[i]==n[i]: c += 1

# print(c)
# print(len(arr))