Summary:

1. Segment record to 512 points width periods. For each period, do stationary wavelet transform and use YOLO model to predict. 


2. Do nonmax suppression:

    1) for each grid, discard predicted box with p<=0.5
    
    2) for every two adjacent period, calculate the IOU. If IOU>=0.5, discard the box with lower possibility.
    

3. Calculate the general position using the predicted relative position


4. Draw the graph after each step/ Calculate recall and sensitivity

In [15]:
import gc, h5py, os
import numpy as np
import tensorflow as tf
from tensorflow import keras
from Stationary_transform import *
from tensorflow.keras.models import load_model

import matplotlib.pyplot as plt
%matplotlib inline

### Prediction

In [16]:
from tensorflow.keras import backend as K

def f1_value(y_true, y_pred): # custom f1 score metric
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
    predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
    precision = true_positives / (predicted_positives + K.epsilon())
    recall = true_positives / (possible_positives + K.epsilon())
    f1_val = 2*(precision*recall)/(precision+recall+K.epsilon())
    return f1_val

def yolo_loss(y_true, y_pred):
    true_x = y_true[...,1]
    pred_x = y_pred[...,1]
    pred_prob = y_pred[...,0]
    object_mask = y_true[...,0]
    
    object_loss = 3*object_mask*K.square(1-pred_prob)
    no_object_loss = 2*(1-object_mask)*K.square(0-pred_prob)
    x_loss = 5*object_mask*K.square(true_x-pred_x)
    loss = object_loss + no_object_loss + x_loss
    return loss

In [17]:
def predict(ecg, yolo_model):
    
    size = len(ecg)//512

    data_x = np.empty((size,4098,1),'float32')

    for i in range(size):
        sample = decomp(ecg[i*512:(i+1)*512],'db2',(512,8))
        data_x[i][:4096] = sample.reshape(4096,1)

    data_y = yolo_model.predict(data_x) # ? * 8 * 2
    return data_y

### Nonmax Suppression and Locate

In [19]:
'''discard predicted box with p<=0.5'''


def binarize(data_y):
    data_y = np.reshape(data_y,(-1,2)) # (?*8) * 2

    mask = data_y[:,0]>0.5
    mask = np.expand_dims(mask.astype('int'),0)
    mask = np.transpose(np.insert(mask,1,mask[0],axis = 0))

    data_y = data_y*mask # (?*8) * 2
    return data_y

In [21]:
def iou(gridl, gridr):
    
    gridl_x, gridr_x = gridl[1],gridr[1]
    gridl_l, gridl_r = gridl_x-25, gridl_x+25
    gridr_l, gridr_r = gridr_x-25, gridr_x+25
    
    intersect = max(gridl_r-gridr_l,0)
    union = 50+50-intersect    
    iou = intersect/union
    
    return iou

In [22]:
'''For every two adjacent periods, calculate the IOU. If IOU>=0.5, discard the box with lower possibility.
   Store all predicted R peak in R array
'''

def nonmax_suppress(data_y, ecg):
    R = np.zeros((len(ecg)),'int')

    for i in range(data_y.shape[0]):
        if i<len(data_y)-1:
            if data_y[i,0] != 0 and data_y[i+1,0] != 0: # when 2 adjacent periods both predict R
                iou_score = iou(data_y[i],data_y[i+1])
                if iou_score>=0.5:
                    min_index = np.argmin([data_y[i,0],data_y[i+1,0]])+i
                    max_index = np.argmax([data_y[i,0],data_y[i+1,0]])+i
                    data_y[min_index] = [0,0]
                    R[int((max_index+data_y[max_index,1])*64)] = 1
                else:
                    R[int((i+data_y[i,1])*64)] = 1
            if data_y[i,0] != 0 and data_y[i+1,0] == 0: # when only the former period predict R
                R[int((i+data_y[i,1])*64)] = 1
        if i == len(data_y)-1 and data_y[i,0] != 0:
            R[int((i+data_y[i,1])*64)] = 1
            
    return R

In [23]:
'''Process annotation and create annot storaging all true R position in record'''

def ann_process(ecg,path):
    annot = np.zeros((len(ecg)),'int')

    with open(path+'/ann.txt','r') as f:
        for line in f:
            pos, ann = line.rstrip().split(',')[:2]
            pos = round(int(pos)/1000*256)
            if ann =='X':
                continue
            elif pos>len(ecg):
                break
            else:
                annot[pos] = 1
    f.close()
    return annot

### Recall and Precision 

In [26]:
'''Calculate recall and precision based on true R position and predicted R position'''

def correspond(ary, index):
    for i in range(11):
        if ary[index+i]==1:
            return index+i+1
    return -1

def measure(Rs, annot, record):
    
    nn,no,on, i = 0,0,0,0 #TP, FN, FP

    f_no = open('../../record/yolo_no_'+record+'2.txt','w')
    f_on = open('../../record/yolo_on_'+record+'2.txt','w')
    
    while i<len(annot):
        if Rs[i] == 1:
            index = correspond(annot, i)
            if index == -1:
                f_on.write(str(i)+'\n')
                on += 1
            else:
                nn += 1
                i = index
                continue
        elif annot[i] == 1:
            index = correspond(Rs, i)
            if index == -1:
                f_no.write(str(i)+'\n')
                no += 1
            else:
                nn += 1
                i = index
                continue
        i+=1
    f_on.close()
    f_no.close()
    se = nn/(nn+no+K.epsilon())
    prec = nn/(nn+on+K.epsilon())
    return se, prec 
 

### Main

In [None]:
def main(model_path = '../../model/new_yolo_model/model_035.h5', database_path = '../../record/'):
    
    files = sorted([i for i in os.walk(database_path)][0][1])
    files.remove('.ipynb_checkpoints')

    sensitivity, precision = 0,0

    for record in files:

        yolo_model = load_model(model_path,custom_objects={'yolo_loss': yolo_loss, 'f1_value': f1_value})

        path = database_path+record
        ecg = np.fromfile(path+'/ecg.dat', '>i2') # plot original ecg sequence
        ecg = ecg[:len(ecg)-(len(ecg))%512]

        data_y = predict(ecg, yolo_model)
        data_y = binarize(data_y)
        R = nonmax_suppress(data_y, ecg)
        annot = ann_process(ecg,path)

        se,prec = measure(R,annot,record)
        sensitivity += se
        precision += prec
        f1 = (2*se*prec)/(se+prec)
        print(record + ': sensitivity: '+str(se)+ '   precision: '+ str(prec)+ '   f1: '+ str(f1))

        del yolo_model
        gc.collect()

    print('Average sensitivity: ', str(sensitivity/67))
    print('Average precision: ', str(precision/67))

In [30]:
main()

20000: sensitivity: 0.9994905805449832   precision: 0.9998749999990386   f1: 0.9996827533157088
20001: sensitivity: 0.9953870642876529   precision: 0.9996411327145596   f1: 0.9975095629516318
20002: sensitivity: 0.9963535218186629   precision: 0.9999234049063412   f1: 0.9981352714042399
20003: sensitivity: 0.9690740666471757   precision: 0.9959610613333476   f1: 0.9823336205907364
20004: sensitivity: 0.988600924829741   precision: 0.9994111305378648   f1: 0.9939766363762561
20005: sensitivity: 0.9931582304159243   precision: 0.9985232108495438   f1: 0.9958334948248171
20006: sensitivity: 0.9981317319597768   precision: 0.9998854152731892   f1: 0.999007804002176
20007: sensitivity: 0.9900753014359707   precision: 0.9996579415849746   f1: 0.9948435463086717
20008: sensitivity: 0.9972780868147705   precision: 0.9993806205693534   f1: 0.9983282466805901
20009: sensitivity: 0.9863467821511853   precision: 0.9991750055798748   f1: 0.9927194530419255
20010: sensitivity: 0.762608128667018   pr