# Imports

In [None]:
from ecg_plotting import *
import IPython
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras import datasets, layers, models

# Data

In [None]:
# Download the 'backgrounds' and 'ptb_xl' datasets with the project helper
# `download_dataset` (implementation lives in dataset_downloader, not shown here).

from dataset_downloader import download_dataset
# NOTE(review): `ptb_xl_dh` is not used in the cells visible here; presumably a
# later cell needs it — remove if not. The module path `datasets...` is resolved
# by the import system (presumably the project's own package), independent of
# the `datasets` *name* bound earlier from tensorflow.keras.
import datasets.ptb_xl.data_handling as ptb_xl_dh

download_dataset('backgrounds')
download_dataset('ptb_xl')

## Processing and Labeling

In [None]:
from datasets.ptb_xl.data_handling import get_ecg_array

# Request up to `num_samples` records at 500 Hz together with their metadata.
num_samples = 100
# Time-axis length the network expects (the model's input_shape is (4096, 12, 1)).
target_len = 4096
X, Y_label = get_ecg_array(sampling_rate=500, max_samples=num_samples)

# Centre-crop the time axis to `target_len` samples and append a channel axis.
# Assumes X is (records, time, 12 leads) — TODO confirm against get_ecg_array.
# For 5000-sample records this selects exactly X[:, 452:4548, :]; the window is
# derived from the data rather than hard-coded for 10 s at 500 Hz.
X = np.asarray(X)
if X.shape[1] < target_len:
    raise ValueError(
        f"ECG records have {X.shape[1]} time samples; need at least {target_len}"
    )
crop_start = (X.shape[1] - target_len) // 2
X = np.reshape(X[:, crop_start:crop_start + target_len, :], (-1, target_len, 12, 1))

# Binary target: 1.0 if any of a record's SCP codes is in `danger_list`
# (presumably the PTB-XL myocardial-infarction classes), else 0.0.
danger_list = ["IMI", "ASMI", "ILMI", "AMI", "LMI", "IPLMI", "IPMI", "PMI"]
danger_codes = set(danger_list)  # built once, not once per record
# Y is sized from the labels actually returned, not from `num_samples`, so a
# short read (fewer records than requested) no longer raises IndexError.
# Each `scp_codes` entry is a mapping keyed by SCP code; isdisjoint iterates its keys.
Y = np.array(
    [0.0 if danger_codes.isdisjoint(codes) else 1.0 for codes in Y_label['scp_codes']],
    dtype=float,
)

# Network Definition

In [None]:
class Residual(layers.Layer):
    """Residual block: a two-stage conv main path summed with a 1x1-conv skip path.

    Main path: [Conv2D(16x1, 'same') -> MaxPool(2x1) -> BatchNorm -> ReLU ->
    Dropout] applied twice, so the time axis is downsampled by 4. Skip path:
    Conv2D(1x1) -> MaxPool(4x1) -> BatchNorm -> ReLU -> Dropout, which yields the
    same time length (floor(floor(n/2)/2) == floor(n/4)). Both paths produce
    ``last_num_filters + 64`` channels and are added elementwise.

    Args:
        last_num_filters: number of channels fed into the block; the block
            outputs ``last_num_filters + 64`` channels.
        dropout_rate: dropout probability for the block's Dropout layers. If
            omitted, the notebook-level global ``dropout_p`` is read when the
            block is constructed (the previous, implicit behaviour).
        **kwargs: forwarded to ``layers.Layer`` (``name``, ``dtype``, ...).
    """

    def __init__(self, last_num_filters, *, dropout_rate=None, **kwargs):
        super().__init__(**kwargs)
        self.last_num_filters = last_num_filters
        # Backward-compatible default: existing `Residual(n)` calls keep using the
        # global `dropout_p`; passing dropout_rate removes that hidden coupling.
        self.dropout_rate = dropout_p if dropout_rate is None else dropout_rate
        num_filters = last_num_filters + 64

        # Attribute names layer_1 ... layer_15 are kept as-is because object-based
        # checkpoints key the weights by attribute name.
        # NOTE(review): BatchNormalization(axis=[1, 2]) keeps one mean/variance (and
        # gamma/beta) per (time step, lead) position, pooled over the batch and
        # channel axes. The conventional choice for conv features is axis=-1
        # (per channel); left unchanged to preserve the architecture.

        # Main path, first conv stage.
        self.layer_1 = layers.Conv2D(num_filters, (16, 1), activation=None, padding='same')
        self.layer_2 = layers.MaxPooling2D((2, 1))  # TODO: pooling not stated precisely in the paper — verify
        self.layer_3 = layers.BatchNormalization(axis=[1, 2])
        self.layer_4 = layers.Activation(activation='relu')
        self.layer_5 = layers.Dropout(self.dropout_rate)  # against overfitting

        # Main path, second conv stage.
        self.layer_6 = layers.Conv2D(num_filters, (16, 1), activation=None, padding='same')
        self.layer_7 = layers.MaxPooling2D((2, 1))  # TODO: pooling not stated precisely in the paper — verify
        self.layer_8 = layers.BatchNormalization(axis=[1, 2])
        self.layer_9 = layers.Activation(activation='relu')
        self.layer_10 = layers.Dropout(self.dropout_rate)  # against overfitting

        # Skip path: 1x1 conv to match channel count, 4x pooling to match time length.
        self.layer_11 = layers.Conv2D(num_filters, (1, 1), activation=None, padding='same')
        self.layer_12 = layers.MaxPooling2D((4, 1))  # TODO: pooling not stated precisely in the paper — verify
        self.layer_13 = layers.BatchNormalization(axis=[1, 2])
        self.layer_14 = layers.Activation(activation='relu')
        self.layer_15 = layers.Dropout(self.dropout_rate)  # against overfitting

    def call(self, x, training=None):
        """Return main_path(x) + skip_path(x).

        ``training`` is forwarded explicitly to the BatchNormalization and
        Dropout layers; when it is None, Keras resolves the mode as usual.
        """
        skip = x

        x = self.layer_1(x)
        x = self.layer_2(x)
        x = self.layer_3(x, training=training)
        x = self.layer_4(x)
        x = self.layer_5(x, training=training)
        x = self.layer_6(x)
        x = self.layer_7(x)
        x = self.layer_8(x, training=training)
        x = self.layer_9(x)
        x = self.layer_10(x, training=training)

        skip = self.layer_11(skip)
        skip = self.layer_12(skip)
        skip = self.layer_13(skip, training=training)
        skip = self.layer_14(skip)
        skip = self.layer_15(skip, training=training)

        # Elementwise sum; avoids instantiating a new Add layer on every call.
        return x + skip

    def get_config(self):
        """Return constructor arguments so models using this layer can be saved and reloaded."""
        config = super().get_config()
        config.update(
            {"last_num_filters": self.last_num_filters, "dropout_rate": self.dropout_rate}
        )
        return config

In [None]:
# Dropout probability for the stem below. Residual reads this global when it is
# constructed, so it must be defined before any Residual(...) call.
dropout_p = 0.1

model = models.Sequential()

# Stem: Conv2D(16x1) -> BatchNorm -> ReLU -> Dropout on a (4096, 12, 1) input.
# NOTE(review): this conv uses Keras' default 'valid' padding, so the time axis
# shrinks to 4081 here, unlike the 'same'-padded convs inside Residual.
model.add(layers.Conv2D(64, (16, 1), activation=None, input_shape=(4096, 12, 1)))
# NOTE(review): axis=[1, 2] keeps one set of BN statistics per (time, lead)
# position, pooled over batch and channels — not the usual per-channel (axis=-1).
model.add(layers.BatchNormalization(axis=[1, 2]))
model.add(layers.Activation(activation='relu'))
model.add(layers.Dropout(dropout_p))  # regularisation against overfitting

# Four residual blocks with 64, 128, 192 and 256 input channels (each adds 64).
# `last_num_filters` itself is the loop variable; the final bump leaves it at
# 320, the channel count produced by the last block.
for last_num_filters in range(64, 64 * 5, 64):
    model.add(Residual(last_num_filters))
last_num_filters += 64

# Classification head: flatten, then one sigmoid unit for the binary target.
model.add(layers.Flatten())
model.add(layers.Dense(1, activation='sigmoid'))

model.summary()

## Compile

In [None]:
# Adam + binary cross-entropy for the single sigmoid output; report accuracy per epoch.
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

## Train

In [None]:
# Check whether TensorFlow can see a GPU.
# `tf.test.is_gpu_available()` is deprecated; its documented replacement is
# `tf.config.list_physical_devices('GPU')`. Wrapped in bool() so the cell still
# displays True/False.
bool(tf.config.list_physical_devices('GPU'))

In [None]:
# Train on all loaded records (no validation split): 40 epochs, 2 records per batch.
model.fit(x=X, y=Y, batch_size=2, epochs=40)