In [None]:
import gc
import math
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split
from tensorflow import keras
from tensorflow.keras import initializers
from tensorflow.keras import layers
from tensorflow.keras import regularizers
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.models import load_model

import CustomAttention

%matplotlib inline

In [None]:
# parse examples
def decode(example):
    """Parse one serialized example into an ``(ecg, label)`` pair.

    Args:
        example: A single serialized record, as yielded by the TFRecord
            dataset this function is mapped over.

    Returns:
        tuple: ``(ecg, label)`` where ``ecg`` is a float32 tensor of shape
        ``[4098]`` and ``label`` is an int64 tensor of shape ``[1]``.
    """
    # Both features are fixed-length, so parsing fails on records whose
    # stored lengths do not match this schema.
    schema = {
        'ecg': tf.io.FixedLenFeature(shape=[4098], dtype=tf.float32),
        'label': tf.io.FixedLenFeature(shape=[1], dtype=tf.int64),
    }
    parsed = tf.io.parse_single_example(example, schema)
    return parsed['ecg'], parsed['label']

In [None]:
# initialization
# Dataset sizes: sample_size sets steps_per_epoch for training, val_size is
# used as the batch size of the validation dataset.
sample_size, val_size = 15048 * 67, 152 * 67

# Optimiser / training-loop settings.
LEARNING_RATE = 1e-2
EPOCH_NUM = 50
BATCH_SIZE = 256

# tf.data settings: shuffle buffer length and parallelism of the decode map.
BUFFER_SIZE = 48000
NUM_PARALEEL_CALLS = 8  # (sic) spelling kept: other cells reference this exact name

In [None]:
# Collect the record files that sit directly inside the training directory.
# next(os.walk(...)) yields only the top-level (dirpath, dirnames, filenames)
# triple, so sub-directories are not traversed; the default triple turns a
# missing directory into the explicit error below instead of an IndexError.
train_dir = '/tmpdata/train1/'
train_names = next(os.walk(train_dir), (train_dir, [], []))[2]
# Sorted so the file order is the same on every run.
filenames = [os.path.join(train_dir, name) for name in sorted(train_names)]
if not filenames:
    raise FileNotFoundError('no training files found in ' + train_dir)
print(filenames)

# Training pipeline: shuffle the serialized records, repeat indefinitely (the
# epoch length is set by steps_per_epoch in fit), decode, batch, and keep one
# batch prefetched.
dataset_train = (
    tf.data.TFRecordDataset(filenames)
    .shuffle(buffer_size=BUFFER_SIZE)
    .repeat()
    .map(decode, num_parallel_calls=NUM_PARALEEL_CALLS)
    .batch(batch_size=BATCH_SIZE)
    .prefetch(1)
)

In [None]:
# Collect the record files that sit directly inside the validation directory.
# next(os.walk(...)) yields only the top-level (dirpath, dirnames, filenames)
# triple, so sub-directories are not traversed; the default triple turns a
# missing directory into the explicit error below instead of an IndexError.
val_dir = '/tmpdata/val1/'
val_names = next(os.walk(val_dir), (val_dir, [], []))[2]
# Sorted so the sample order (and therefore sample indices) is stable across runs.
filenames = [os.path.join(val_dir, name) for name in sorted(val_names)]
if not filenames:
    raise FileNotFoundError('no validation files found in ' + val_dir)
print(filenames)

# Validation pipeline: decode every record and batch with val_size, so the
# first batch holds up to val_size samples.
dataset_val = (
    tf.data.TFRecordDataset(filenames)
    .map(decode, num_parallel_calls=NUM_PARALEEL_CALLS)
    .batch(val_size)
)

In [None]:
from tensorflow.keras import backend as K

def f1(y_true, y_pred):
    """Custom F1-score metric built from Keras backend ops.

    Targets and predictions are clipped to ``[0, 1]`` and rounded before
    counting, so predictions are effectively thresholded at 0.5.

    Args:
        y_true: Tensor of ground-truth labels.
        y_pred: Tensor of predicted scores, same shape as ``y_true``.

    Returns:
        Scalar tensor ``2 * P * R / (P + R + eps)`` where ``P`` and ``R`` are
        precision and recall, each stabilised with ``K.epsilon()``.
    """
    eps = K.epsilon()

    def _count_ones(tensor):
        # Clip to [0, 1], round to the nearest integer, then sum the ones.
        return K.sum(K.round(K.clip(tensor, 0, 1)))

    hits = _count_ones(y_true * y_pred)   # true positives
    actual = _count_ones(y_true)          # all positives in the labels
    flagged = _count_ones(y_pred)         # all positives in the predictions

    precision = hits / (flagged + eps)
    recall = hits / (actual + eps)
    return 2 * (precision * recall) / (precision + recall + eps)

In [None]:
# Load the saved network. CustomAttention must be registered so Keras can
# rebuild the layer; f1 is registered as well (as find_misclassify does) so
# loading also works if the file was saved with that metric attached.
model = load_model(
    'model_dense_drop.h5',
    custom_objects={'CustomAttention': CustomAttention.CustomAttention, 'f1': f1},
)

# mode='max' is required: in the default 'auto' mode only monitors whose name
# contains 'acc' are maximised, so 'val_f1' would be treated as a quantity to
# minimise and the learning rate would be cut at the wrong time.
reduce_lr = ReduceLROnPlateau(monitor='val_f1', mode='max', factor=0.7,
                              patience=5, min_lr=0.0005)

# No trailing comma after the call: with one, model_check becomes a 1-tuple
# and the callbacks list passed to fit() no longer contains a Callback.
model_check = ModelCheckpoint(filepath="../../model/model_{epoch:03d}.h5",
                              monitor='val_loss', mode='min')

model.compile(
    loss="binary_crossentropy",
    optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE),
    metrics=['accuracy', tf.keras.metrics.Precision(), tf.keras.metrics.Recall(), f1],
)

# dataset_train repeats forever, so the epoch length comes from steps_per_epoch;
# it has to be an integer, hence the ceil of samples / batch size.
history = model.fit(dataset_train, epochs=EPOCH_NUM,
                    steps_per_epoch=math.ceil(sample_size / BATCH_SIZE),
                    validation_data=dataset_val, verbose=1,
                    callbacks=[reduce_lr, model_check])

# store model
model.save("model_attention_trained.h5")

In [None]:
def plot_metric(rate,size,history):
    """Plot every training metric next to its validation counterpart.

    One figure is drawn per metric that has both a training entry ``name``
    and a validation entry ``'val_' + name`` in ``history.history``. Keys
    without a validation partner (for example a logged learning rate) are
    skipped instead of raising ``KeyError``.

    Args:
        rate: Learning rate, shown in the figure title only.
        size: Batch size, shown in the figure title only.
        history: Object with a ``history`` dict mapping metric names to
            per-epoch value sequences (as returned by ``model.fit``).
    """
    logs = history.history
    # Select by name rather than by position so the function does not depend
    # on how many metrics were compiled or on the order of the keys.
    metrics = [key for key in logs
               if not key.startswith('val_') and 'val_' + key in logs]

    for metric in metrics:
        plt.figure(figsize=(10,3))
        plt.plot(logs[metric])
        plt.plot(logs['val_'+metric])
        plt.title(metric+ ' rate: '+str(rate)+' size: '+str(size))
        plt.ylabel(metric)
        plt.xlabel('epoch')

        plt.legend(['train', 'val'], loc='upper left')
        plt.show()
        
# Label the figures with the configured hyper-parameters (same values as the
# previously hard-coded 0.01 and 256) so the titles cannot drift from the run.
plot_metric(LEARNING_RATE, BATCH_SIZE, history)

In [None]:
def plot_attention(model):
    """Plot twice the sigmoid of the first row of the model's first weight array.

    Args:
        model: Object with a ``get_weights()`` method whose first array has a
            first row of exactly 512 values (any layout that flattens to 512).
    """
    plt.figure(figsize = (10,5))      # plot attention weight
    # reshape(512) doubles as a shape check: it raises if the row does not
    # hold exactly 512 values.
    weights = np.asarray(model.get_weights()[0][0]).reshape(512)
    # Element-wise logistic sigmoid of the raw weights.
    sig = 1/(1 + np.exp(-weights))

    plt.plot(2*sig)
    # Reference lines: the midpoint of the 512 positions, and positions 236 / 276 in red.
    plt.axvline(256)
    plt.axvline(236, color = 'red')
    plt.axvline(276, color = 'red')
    
# Plot the attention weights of the notebook-level `model` (loaded and fitted in the training cell).
plot_attention(model)

In [None]:
def find_misclassify(model_path, dataset_val):
    """Return the indices of samples that a saved model classifies wrongly.

    Only the first batch yielded by ``dataset_val`` is evaluated. Predictions
    are thresholded with ``> 0.5`` and compared with the labels; one
    probability per sample is expected.

    Args:
        model_path: Path of the saved Keras model to load.
        dataset_val: Dataset exposing ``as_numpy_iterator()`` that yields
            ``(x, y_true)`` batches.

    Returns:
        numpy.ndarray: 1-D integer array of positions within the first batch
        whose thresholded prediction differs from the label.

    Raises:
        ValueError: If ``dataset_val`` yields no batch.
    """
    model = load_model(model_path, custom_objects={'CustomAttention': CustomAttention.CustomAttention,'f1':f1})

    # Pull the first batch once; inputs and labels come from the same read
    # instead of materialising the whole dataset twice.
    first_batch = next(iter(dataset_val.as_numpy_iterator()), None)
    if first_batch is None:
        raise ValueError('dataset_val yielded no batches')
    x, y_true = first_batch[0], first_batch[1]

    test_prob = np.asarray(model.predict(x))  # predicted probability of each sample
    y_hat = (test_prob > 0.5).reshape(-1)     # turn the probability into 0 / 1
    labels = np.asarray(y_true).reshape(-1)

    # Positions where prediction and label disagree.
    return np.flatnonzero(y_hat != labels)

In [None]:
def plot_misclassify(misclass, num_to_plot = 10, x = None, y_true = None):
    """Plot designated number of misclassified samples and their position in record.

    Args:
        misclass: Sequence of sample indices to plot (e.g. the result of
            ``find_misclassify``).
        num_to_plot: Maximum number of samples to draw.
        x: Optional array of input samples, indexable by the values in
            ``misclass``. When omitted, the first batch of the notebook-level
            ``dataset_val`` is used (the same batch ``find_misclassify`` reads).
        y_true: Optional array of labels aligned with ``x``; same fallback.
    """
    if x is None or y_true is None:
        # These arrays used to be read from undefined globals; fetch them here
        # so the function works on its own.
        batch = next(iter(dataset_val.as_numpy_iterator()))
        x = batch[0] if x is None else x
        y_true = batch[1] if y_true is None else y_true

    for i in misclass[:num_to_plot]:
        plt.figure(figsize = (8,2))
        # View the first 4096 values as 512 rows x 8 columns and draw column 0.
        plt.plot(x[i][:4096].reshape(512,8).T[0])
        plt.axvline(236)
        plt.axvline(276)
        # Title: the sample's label followed by the last two values of its input vector.
        plt.title(str(y_true[i]) + ' ' + str(x[i][-2:]))