In [2]:
import pickle
import numpy as np
import pandas as pd
import json
from random import sample
from tensorflow import keras
from tensorflow.keras.models import Sequential, Model, load_model, model_from_json
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.layers import Dense, Conv1D, InputLayer, Masking, MaxPooling1D, GlobalAveragePooling1D, Dropout
from biosppy.signals.tools import filter_signal
from biosppy.signals import ecg

In [3]:
class CNN:
    """1-D convolutional network that labels ECG windows as AF / non-AF.

    Training samples are fixed-size windows cut from two-channel filtered ECG
    recordings.  Each window starts at R-peak ``p - 6`` and ends at R-peak
    ``p``; its label describes the beat at R-peak ``p - 1``.  Windows shorter
    than ``WINDOW_LEN`` samples are zero-padded at the end and longer ones are
    truncated, so every sample matches the network's input shape.

    Attributes:
        model: the (uncompiled until ``fit``) Keras ``Sequential`` model.
        Xs: list of ``(WINDOW_LEN, N_CHANNELS)`` float32 arrays, one per window.
        ys: list of integer labels (1 = AF, 0 = non-AF), parallel to ``Xs``.
        callbacks: Keras callbacks passed to ``model.fit``.
    """

    # Input geometry expected by the first layer of the network.
    WINDOW_LEN = 1300
    N_CHANNELS = 2
    # Number of R-R intervals covered by one window (peaks p-6 .. p).
    BEATS_PER_WINDOW = 6

    def __init__(self):
        self.model = self.build_model()
        self.Xs = []
        self.ys = []
        self.callbacks = [
            keras.callbacks.ModelCheckpoint(
                filepath='models.{epoch:02d}-{val_loss:.2f}.h5',
                save_best_only=True
            )
        ]

    def build_model(self):
        """Build and return the uncompiled convolutional network.

        Returns:
            A ``Sequential`` model taking ``(WINDOW_LEN, N_CHANNELS)`` inputs
            and producing a 2-way softmax (AF, non-AF).
        """
        # All-zero timesteps (the padding added by get_train) are flagged by
        # the Masking layer.
        # NOTE(review): confirm that the Conv1D layers actually honour the
        # mask in the TensorFlow version in use.
        return Sequential([
            Masking(mask_value=0,
                    input_shape=(self.WINDOW_LEN, self.N_CHANNELS)),
            Conv1D(16, 13, activation='relu'),
            Conv1D(16, 13, activation='relu'),
            MaxPooling1D(3),
            Conv1D(32, 13, activation='relu'),
            Conv1D(32, 13, activation='relu'),
            GlobalAveragePooling1D(),
            Dense(32, activation='relu'),
            Dense(16, activation='relu'),
            Dense(2, activation='softmax'),  # af, non-af
        ])

    def _pad_window(self, segment):
        """Return ``segment`` as a zero-padded ``(WINDOW_LEN, N_CHANNELS)`` array.

        Args:
            segment: sequence of per-sample channel values (nested list or
                array) with at most ``WINDOW_LEN`` rows being used.

        Raises:
            ValueError: if a non-empty segment is not ``(n, N_CHANNELS)``.
        """
        window = np.zeros((self.WINDOW_LEN, self.N_CHANNELS), dtype=np.float32)
        seg = np.asarray(segment, dtype=np.float32)
        if seg.size == 0:
            return window
        if seg.ndim != 2 or seg.shape[1] != self.N_CHANNELS:
            raise ValueError(
                'expected a signal segment of shape (n, %d), got %s'
                % (self.N_CHANNELS, seg.shape)
            )
        seg = seg[:self.WINDOW_LEN]
        window[:len(seg)] = seg
        return window

    @staticmethod
    def _label_for(record, peak):
        """Return 1 if the beat at sample index ``peak`` is AF, else 0.

        ``class_true == 1`` labels every beat of the record as AF;
        ``class_true == 2`` labels a beat as AF only when it lies inside one
        of the inclusive ``af_start``/``af_end`` intervals; any other value
        labels the beat as non-AF.
        """
        # NOTE(review): presumably 1 = persistent AF and 2 = paroxysmal AF;
        # confirm against the dataset description.
        record_class = record['class_true']
        if record_class == 1:
            return 1
        if record_class == 2:
            intervals = zip(record['af_start'], record['af_end'])
            return int(any(lo <= peak <= hi for lo, hi in intervals))
        return 0

    def get_train(self, train):
        """Cut labelled windows out of ``train`` and append them to Xs / ys.

        Args:
            train: sequence of records, each a mapping with the keys
                ``'rpeaks'``, ``'sig_filtered'``, ``'class_true'`` and, for
                ``class_true == 2``, ``'af_start'`` / ``'af_end'``.

        Records with fewer than ``BEATS_PER_WINDOW + 1`` R-peaks contribute
        no windows.  Results accumulate: calling this twice on the same data
        stores every window twice.
        """
        total = len(train)
        for i, t in enumerate(train):
            print(str(i+1) + '/' + str(total), end='\r')
            rpeaks = t['rpeaks']
            sig = t['sig_filtered']
            # The first usable window ends at the seventh peak (index 6).
            for p in range(self.BEATS_PER_WINDOW, len(rpeaks)):
                start = rpeaks[p - self.BEATS_PER_WINDOW]
                # Inclusive of the closing peak, but never more than
                # WINDOW_LEN samples in total.
                stop = min(rpeaks[p] + 1, start + self.WINDOW_LEN)
                self.Xs.append(self._pad_window(sig[start:stop]))
                # The label always refers to the second-to-last peak.
                self.ys.append(self._label_for(t, rpeaks[p - 1]))

    def filter_signals(self, sig, fs):
        """Band-pass filter (0.5-45 Hz, FIR, order 50) both ECG channels.

        Args:
            sig: 2-D array indexed as ``sig[:, channel]``.
            fs: sampling rate in Hz.

        Returns:
            A list with one filtered 1-D signal per channel.
        """
        return [
            filter_signal(sig[:, channel],
                          ftype='FIR',
                          band='bandpass',
                          order=50,
                          frequency=[0.5, 45],
                          sampling_rate=fs)[0]
            for channel in range(self.N_CHANNELS)
        ]

    def get_rpeaks(self, sig_filtered, fs):
        """Detect R-peaks on the first channel with the Christov segmenter."""
        rpeaks = ecg.christov_segmenter(sig_filtered[0], fs)[0]
        return rpeaks

    # only for test data, training data has already been preprocessed
    def preprocessing(self, test):
        """Add ``'sig_filtered'`` and ``'rpeaks'`` to every record of ``test``.

        The records are modified in place; the same sequence is returned.
        ``'sig_filtered'`` is stored transposed, i.e. indexed as
        ``[sample, channel]``.
        """
        total = len(test)
        for i in range(total):
            print(str(i+1) + '/' + str(total), end='\r')
            # filter signals and get rpeaks
            sig = self.filter_signals(test[i]['sig'], test[i]['fs'])
            rpeaks = self.get_rpeaks(sig, test[i]['fs'])

            test[i]['sig_filtered'] = np.transpose(sig)
            test[i]['rpeaks'] = rpeaks

        return test

    def fit(self, train):
        """Extract windows from ``train``, compile the model and train it.

        Args:
            train: records in the format expected by ``get_train``.

        Raises:
            ValueError: if no training window could be extracted.
        """
        self.get_train(train)
        if not self.Xs:
            raise ValueError('no training windows could be extracted from train')

        # Keras needs one dense array; every window already has the same shape.
        features = np.asarray(self.Xs, dtype=np.float32)
        targets = to_categorical(self.ys, 2)

        self.model.compile(loss='categorical_crossentropy',
                           optimizer='adam',
                           metrics=['accuracy'])

        self.model.summary()

        self.model.fit(features,
                       targets,
                       batch_size=200,
                       epochs=10,
                       verbose=1,
                       validation_split=0.2,
                       callbacks=self.callbacks)

    def score(self, test):
        """Placeholder: evaluation is not implemented yet."""
        pass

In [4]:
# TRAINING_DATA_PATH = '../data/train_preprocessed.pkl'
# TEST_DATA_PATH = '../data/test.pkl'

# # load data
# with open(TRAINING_DATA_PATH, 'rb') as file:
#     train = pickle.load(file)

# with open(TEST_DATA_PATH, 'rb') as file:
#     test = pickle.load(file)

In [5]:
# Location of the sampled training set, relative to this notebook.
TRAINING_DATA_PATH = '../data/train_sampled_2.json'

# Read the whole JSON document and decode it into `train`.
with open(TRAINING_DATA_PATH, 'r') as file:
    train = json.loads(file.read())

In [6]:
# further process training data
# train_saved = train
# for i in range(len(train)):
#     print(str(i+1) + '/' + str(len(train)), end='\r')
#     train[i]['af_start'] = [train[i]['beat_loc'][x] for x in train[i]['af_start_scripts']]
#     train[i]['af_end'] = [train[i]['beat_loc'][x] for x in train[i]['af_end_scripts']]
#     del train[i]['af_start_scripts']
#     del train[i]['af_end_scripts']
#     del train[i]['beat_loc']

In [8]:
# Peek at the fields available on the first training record.
train[0].keys()

dict_keys(['record_name', 'fs', 'sig_filtered', 'rpeaks', 'class_true', 'af_start', 'af_end'])

In [None]:
# Build a fresh network, then extract windows from `train` and fit on them.
# fit() prints the model summary before training starts.
cnn = CNN()
cnn.fit(train)

2022-12-02 23:13:51.679009: I tensorflow/compiler/jit/xla_cpu_device.cc:41] Not creating XLA devices, tf_xla_enable_xla_devices not set
2022-12-02 23:13:51.681462: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.2
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
masking (Masking)            (None, 1300, 2)           0         
_________________________________________________________________
conv1d (Conv1D)              (None, 1288, 16)          432       
_________________________________________________________________
conv1d_1 (Conv1D)            (None, 1276, 16)          3344      
_________________________________________________________________
max_pooling1d (MaxPooling1D) (None, 425, 16)           0         
_________________________________________________________________
conv1d_2 (Conv1D)            (None, 413, 32)           6688      
_________________________________________________________________
conv1d_3 (Conv1D)            (None, 401, 32)           13344     
_________________________________________________________________
global_average_pooling1d (Gl (None, 32)                0

In [None]:
# Debug check: show which container type holds the accumulated input windows.
type(cnn.Xs)