Summary:
    
1. Separate the whole record into multiple batches; feed each batch to the trained model; get the probability of each point


2. Apply median filter to model output and binarize output using threshold = 0.5


3. Refine prediction:

    1) Filter out predicted QRS periods with width <100ms
    
    2) For two QRS periods that are too close (<0.2s), discard the narrower one
    
    3) For two QRS periods that are too far apart (>1.5s), re-search the gap for missing QRS using a lower threshold = 0.4
    
    4) Filter out predicted QRS periods with width <100ms again


4. Process true annotation and predicted R peak position; Calculate recall and precision (tolerance = 8 points)

In [1]:
import gc, h5py, time
import numpy as np
import scipy.signal
import tensorflow as tf
from Stationary_transform import *
from tensorflow import keras
from tensorflow.keras import initializers
from tensorflow.keras.models import load_model

import matplotlib.pyplot as plt
%matplotlib inline

### Model Prediction

In [2]:
from tensorflow.keras import backend as K

def f1_value(y_true, y_pred): # custom f1 score metric
    """Keras metric returning the F1 score of ``y_pred`` against ``y_true``.

    Both tensors are clipped to [0, 1] and rounded before counting, so the
    score is computed on hard 0/1 decisions. ``K.epsilon()`` is added to
    every denominator to avoid division by zero.
    """
    eps = K.epsilon()

    def _count(tensor):
        # Number of entries that round to 1 after clipping to [0, 1].
        return K.sum(K.round(K.clip(tensor, 0, 1)))

    hits = _count(y_true * y_pred)
    actual = _count(y_true)
    predicted = _count(y_pred)

    precision = hits / (predicted + eps)
    recall = hits / (actual + eps)
    return 2*(precision*recall)/(precision+recall+eps)

In [4]:
def get_prob(ecg,model_path): 
    """Run the trained model over the record and return one probability per point.

    The region ``[START, END)`` of ``ecg`` is split into ``BATCH_NUM`` batches.
    For every point a 512-sample window is transformed with ``decomp`` and
    flattened into the first 4096 features of a 4098-wide input row.

    Parameters
    ----------
    ecg : array indexed as ``ecg[a:b]`` and, per window, ``seq[j:j+512, :]``
        (the two-axis indexing implies a 2-D array; confirm against the caller).
    model_path : path handed to ``load_model``.

    Returns
    -------
    numpy.ndarray of shape ``(LEN,)`` and dtype float16.
    """
    batch_size = LEN // BATCH_NUM
    # Zero-initialise instead of np.empty: the last two feature columns
    # (4096 and 4097) are never written below, and when LEN is not a multiple
    # of BATCH_NUM the tail of `prob` is never written either. With np.empty
    # both would hold arbitrary memory.
    # NOTE(review): confirm what the model expects in columns 4096-4097.
    prob = np.zeros(LEN, 'float16')
    data_X = np.zeros((batch_size,4098),'float16')
    
    for i in range(BATCH_NUM):
        # 256 extra samples on each side so every point has a full window.
        seq = ecg[i*batch_size+START-256:(i+1)*batch_size+START+256]
        
        for j in range(batch_size):  # preprocessing original ecg sliding window
            data = decomp(seq[j:j+512,:],'db2',(512,8) )
            data_X[j][:4096] = data.reshape(4096)
        
        model = load_model(model_path,custom_objects={'f1_value': f1_value})
        prob[i*batch_size:(i+1)*batch_size] = model.predict(data_X).reshape(1, batch_size)[0] 
        del model # delete model after every batch prediction to avoid memory leak
        gc.collect()
        
    return prob

### Median Filter and Binarize

In [5]:
def binarize(med_prob, threshold, start, end):
    """Threshold ``med_prob[start:end]`` into a 0/1 integer array.

    Parameters
    ----------
    med_prob : sequence of probabilities.
    threshold : values strictly below it map to 0, everything else to 1.
    start, end : slice bounds into ``med_prob``.

    Returns
    -------
    numpy.ndarray of dtype int and length ``end - start``. If the slice is
    shorter than ``end - start`` the remaining entries are 0.
    """
    segment = np.asarray(med_prob[start:end])
    # Zero-initialised so a short slice cannot leave uninitialised memory.
    binarize_prob = np.zeros((end-start),'int')
    # `segment < threshold` mirrors the original comparison exactly, so a
    # value that is not below the threshold (including NaN) maps to 1.
    binarize_prob[:len(segment)] = np.where(segment < threshold, 0, 1)
    return binarize_prob

### Prediction Refine

In [None]:
def counter(start, prob, num, direction): 
    """Return the length of the run of ``num`` in ``prob`` beginning at ``start``.

    The run is followed forwards when ``direction == 1`` and backwards for
    any other value. An out-of-range ``start`` yields 0.
    """
    step = 1 if direction == 1 else -1
    size = len(prob)
    pos = start
    run_length = 0
    while 0 <= pos < size and prob[pos] == num:
        run_length += 1
        pos += step
    return run_length


'''Filter out predicted QRS periods with width <100ms'''

def filter_by_width(binarize_prob):    
    """Zero out every run of 1s that is at most ``WIDTH_LOW`` samples wide.

    The array is modified in place and also returned. The scan is bounded by
    the length of the array that was passed in, not by the global ``LEN``,
    so the function works for any input length.
    """
    total = len(binarize_prob)
    i = 0
    while i < total:
        if binarize_prob[i] != 1:
            i += 1
            continue
        count = counter(i, binarize_prob, 1, 1)
        if count <= WIDTH_LOW:
            # Run too narrow to be a QRS complex: erase it.
            for j in range(i, i + count):
                binarize_prob[j] = 0
        # Jump past the whole run, whether or not it was erased.
        i += count
    return binarize_prob

In [None]:
'''For two QRS periods that are too close (<0.2s), discard the narrower one'''


def remove(start, end, num, prob):  
    """Overwrite ``prob[start]`` .. ``prob[end - 1]`` with ``num``.

    Used to erase a designated period. ``prob`` is modified in place and
    returned.
    """
    idx = start
    while idx < end:
        prob[idx] = num
        idx += 1
    return prob

def remove_smaller(start, gap, prob): 
    """Erase the narrower of the two runs of 1s on either side of a gap.

    ``start`` is the first index of the gap and ``gap`` is its length. The
    run ending at ``start - 1`` is compared with the run beginning at
    ``start + gap``; on a tie the later run is erased.

    Returns ``(prob, remove_prev)`` where ``remove_prev`` is True when the
    earlier run was the one erased.
    """
    gap_end = start + gap
    left_width = counter(start - 1, prob, 1, 0)   # run before the gap
    right_width = counter(gap_end, prob, 1, 1)    # run after the gap

    remove_prev = left_width < right_width
    if remove_prev:
        prob = remove(start - left_width, start, 0, prob)
    else:
        prob = remove(gap_end, gap_end + right_width, 0, prob)
    return prob, remove_prev

def remove_dis_low(k, gap, binarize_prob): 
    """Merge away periods that sit too close to the one ending just before ``k``.

    While the gap starting at ``k`` is shorter than ``DIS_LOW`` the narrower
    of the two neighbouring periods is discarded and the gap is re-measured.
    The loop stops once the gap is wide enough, once the earlier period has
    been the one removed, or once the gap reaches the end of the array.

    Returns ``(binarize_prob, gap)`` with the updated gap length.
    """
    total = len(binarize_prob)
    remove_prev = False
    # `k + gap < total` guards the case where the gap runs to the end of the
    # record. There is then no later period, remove_smaller erases nothing,
    # the gap never changes, and the loop would never terminate.
    while gap < DIS_LOW and not remove_prev and k + gap < total:
        binarize_prob, remove_prev = remove_smaller(k, gap, binarize_prob)
        gap = counter(k, binarize_prob, 0, 1)
    return binarize_prob, gap

In [None]:
'''For two QRS periods that are too far apart (>1.5s), re-search the gap for missing QRS using a lower threshold = 0.4'''

def search_new(start, gap, med_prob,binarize_prob): # re-search gap for missing QRS with threshold = 0.4
    """Re-binarize an over-long gap with ``THRESHOLD2`` to recover missed periods.

    The gap ``[start, start + gap)`` of ``med_prob`` is thresholded again. If
    the new segment begins with 1s, that leading run is cleared so the
    segment starts with a gap. The segment is written back into
    ``binarize_prob`` and the too-close check is applied to the new leading gap.

    Returns ``(binarize_prob, gap)`` as produced by ``remove_dis_low``.
    """
    stop = start + gap
    segment = binarize(med_prob, THRESHOLD2, start, stop)

    leading_zeros = counter(0, segment, 0, 1)
    if leading_zeros == 0:
        # Segment opens with 1s: clear that run, then measure the gap again.
        head_width = counter(0, segment, 1, 1)
        segment = remove(0, head_width, 0, segment)
        leading_zeros = counter(0, segment, 0, 1)

    binarize_prob[start:stop] = segment
    return remove_dis_low(start, leading_zeros, binarize_prob)

In [6]:
'''refine process main'''

def refine(binarize_prob, med_prob):
    """Walk the binary prediction and repair gaps between consecutive periods.

    Once the first 1 has been seen, every gap of 0s is measured. Gaps shorter
    than ``DIS_LOW`` (and not touching the end of the array) go through
    ``remove_dis_low``; gaps longer than ``DIS_HIGH`` are re-searched with
    ``search_new``. ``binarize_prob`` is updated in place and returned.
    """
    total = len(binarize_prob)
    seen_period = False
    pos = 0
    while pos < total:
        if seen_period and binarize_prob[pos] == 0:
            gap = counter(pos, binarize_prob, 0, 1)
            too_close = gap < DIS_LOW and pos + gap < total
            if too_close:
                binarize_prob, gap = remove_dis_low(pos, gap, binarize_prob)
            if gap > DIS_HIGH:  # too distant
                binarize_prob, gap = search_new(pos, gap, med_prob, binarize_prob)
            # Skip the whole gap in one step.
            pos += gap
            continue
        if binarize_prob[pos] == 1:
            seen_period = True
        pos += 1
    return binarize_prob

### Calculate Recall and Precision

In [None]:
def locate(binarize_prob,ecg):
    """Mark one predicted R position inside every run of 1s.

    For each run, the sample of ``ecg[START:END]`` with the largest absolute
    value inside the run is flagged. Returns an int array the same length as
    ``binarize_prob`` holding 1 at each predicted R position and 0 elsewhere.
    """
    total = len(binarize_prob)
    Rs = np.zeros(total, 'int')
    pos = 0
    while pos < total:
        if binarize_prob[pos] != 1:
            pos += 1
            continue
        width = counter(pos, binarize_prob, 1, 1)
        segment = ecg[START:END][pos:pos + width]
        peak_offset = np.argmax(np.abs(segment))
        Rs[pos + peak_offset] = 1
        pos += width
    return Rs

def get_ann(binarize_prob):
    """Read the reference annotation file and return the true R positions.

    Each line is split on commas; the first field is a position and the
    second an annotation label. The position is rescaled by ``/1000*256``
    (presumably milliseconds to samples at 256 Hz; confirm against the
    annotation format). Labels equal to ``'X'`` and positions before
    ``START`` are skipped, and reading stops at the first position at or
    beyond ``END``.

    Returns an int array the same length as ``binarize_prob`` with 1 at every
    annotated position, relative to ``START``.
    """
    annot = np.zeros((len(binarize_prob)),'int')
    with open('../../record/20010/ann.txt','r') as f:
        for line in f:
            if not line.strip():
                # Tolerate blank lines instead of failing on unpacking.
                continue
            pos, ann = line.rstrip().split(',')[:2]
            pos = round(int(pos)/1000*256)
            if ann =='X' or pos<START:
                continue
            # `>=` rather than `>`: pos == END maps to index END-START, which
            # is one past the last valid slot. The second test keeps the
            # write in bounds when the array is shorter than END-START.
            if pos >= END or pos - START >= len(annot):
                break
            annot[pos-START] = 1
    return annot

In [13]:
'''compare true R position and predicted R position; calculate recall and precision'''

def correspond(ary, index, tolerance=None):
    """Look for a 1 in ``ary`` within ``tolerance`` points starting at ``index``.

    Parameters
    ----------
    ary : 0/1 sequence to search.
    index : first position examined.
    tolerance : number of positions examined; defaults to the global ``ERROR``.

    Returns
    -------
    The index just after the first 1 found, or -1 when the window holds none.
    The window is clamped to the end of ``ary`` so a search near the tail
    cannot index out of bounds.
    """
    if tolerance is None:
        tolerance = ERROR
    stop = min(index + tolerance, len(ary))
    for pos in range(index, stop):
        if ary[pos] == 1:
            return pos + 1
    return -1

def measure(Rs, annot):
    """Compare predicted and annotated R positions.

    Both arrays are scanned together. A predicted mark with an annotation in
    the following tolerance window (see ``correspond``), or an annotation
    with a prediction in that window, counts as a true positive. An
    unmatched prediction is a false positive and an unmatched annotation a
    false negative.

    Returns ``(sensitivity, precision)``.
    """
    matched = 0   # true positives
    missed = 0    # false negatives
    spurious = 0  # false positives

    total = len(annot)
    pos = 0
    while pos < total:
        if Rs[pos] == 1:
            resume = correspond(annot, pos)
            if resume != -1:
                matched += 1
                pos = resume  # continue just past the matched annotation
                continue
            spurious += 1
        elif annot[pos] == 1:
            resume = correspond(Rs, pos)
            if resume != -1:
                matched += 1
                pos = resume  # continue just past the matched prediction
                continue
            missed += 1
        pos += 1

    se = matched/(matched+missed+K.epsilon())
    prec = matched/(matched+spurious+K.epsilon())
    return se, prec

### Main

In [None]:
# Global configuration read by the functions in this notebook.
# `ecg` must already exist in the notebook namespace when this cell runs.
START = 256              # first sample that is predicted
END = len(ecg)-256       # one past the last sample that is predicted
LEN = END-START          # number of predicted samples
# NOTE(review): RECORD is not referenced by the functions shown here, and the
# hard-coded file paths point at record 20010 - confirm which one is intended.
RECORD = 20009
BATCH_NUM = 128          # number of batches the record is split into
THRESHOLD1 = 0.5 # threshold for the first binarization
THRESHOLD2 = 0.4 # threshold for the second (re-search) pass
DIS_LOW = 52 # 0.2s: minimum allowed distance between two runs of 1s
# NOTE(review): assuming 256 samples per second, 307 samples is about 1.2 s;
# 1.5 s would be 384 samples - confirm the intended value.
DIS_HIGH = 307 # 1.5s: maximum allowed distance between two runs of 1s
WIDTH_LOW = 24 # 100ms: narrowest allowed width of a run of 1s
ERROR = 8 # number of points a prediction may deviate from the annotation

In [None]:
def main(ecg_path ='../../record/20010/ecg.dat', model_path = '../../model/model_30.h5'):
    """Run the full pipeline on one record and report sensitivity and precision.

    Steps: read the big-endian int16 ECG file, predict per-point
    probabilities, median-filter and binarize them, filter and refine the
    predicted periods, locate the R peaks, and score them against the
    annotation.

    Returns ``(sensitivity, precision)`` and also prints them.
    """
    # Load the raw record and keep only the first 23903744 samples.
    ecg = np.fromfile(ecg_path, '>i2')[:23903744]

    # Model output, smoothed with a median filter.
    med_prob = scipy.signal.medfilt(get_prob(ecg, model_path))

    # Binarize, drop narrow periods, refine gaps, then drop narrow periods again.
    periods = filter_by_width(binarize(med_prob,THRESHOLD1,0,LEN))
    periods = filter_by_width(refine(periods, med_prob))

    Rs = locate(periods,ecg)
    annot = get_ann(periods)
    se, prec = measure(Rs, annot)
    print('sensitivity: '+str(se)+ '   precision: '+ str(prec)) # model 30 error = 8
    
    return se, prec

In [11]:
# Run the whole pipeline with the default record and model paths.
main()

Batch 0
time used for decompose: 36.89652705192566
Time used for predict 0: 15.880534648895264
Batch 1
time used for decompose: 35.361093521118164
Time used for predict 1: 12.282869815826416
Batch 2
time used for decompose: 35.56828761100769
Time used for predict 2: 12.082835674285889
Batch 3
time used for decompose: 35.56056356430054
Time used for predict 3: 12.053177833557129
Batch 4
time used for decompose: 35.19124889373779
Time used for predict 4: 12.070293664932251
Batch 5
time used for decompose: 35.69074869155884
Time used for predict 5: 12.030882120132446
Batch 6
time used for decompose: 35.46727013587952
Time used for predict 6: 12.065332174301147
Batch 7
time used for decompose: 35.295260190963745
Time used for predict 7: 12.092155933380127
Batch 8
time used for decompose: 35.34883236885071
Time used for predict 8: 12.07216501235962
Batch 9
time used for decompose: 36.00576210021973
Time used for predict 9: 12.148287773132324
Batch 10
time used for decompose: 35.620114564895

### Plot after Each Step

In [None]:
# Step 1: raw model probability.
# NOTE(review): `model` is not defined in the cells shown here, and get_prob
# passes its second argument to load_model as a path - confirm the intended value.
prob = get_prob(ecg, model)

figure = plt.figure(figsize = (10,5)) # plot the original ecg along with the probability
plt.plot(prob,color = "orange")
plt.plot(ecg[START:END]/2048, color = "blue")
plt.show()

In [None]:
# Step 2: probability after the median filter.
med_prob = scipy.signal.medfilt(prob)

figure = plt.figure(figsize = (10,5)) # plot the original ecg along with the filtered probability
plt.plot(med_prob,color = "orange")
plt.plot(ecg[START:END]/2048, color = "blue")
plt.show()

In [None]:
# Step 3: binarize the filtered probability with THRESHOLD1.
binarize_prob = binarize(med_prob,THRESHOLD1,0,LEN)

figure = plt.figure(figsize = (10,5)) # plot the original ecg along with the binarized prediction
plt.plot(binarize_prob,color = "orange")
plt.plot(ecg[START:END]/2048, color = "blue")
plt.show()

In [None]:
# Step 4: drop predicted periods that are too narrow.
binarize_prob = filter_by_width(binarize_prob)

figure = plt.figure(figsize = (10,5)) # plot the original ecg along with the width-filtered prediction
plt.plot(binarize_prob,color = "orange")
plt.plot(ecg[START:END]/2048, color = "blue")
plt.show()

In [None]:
# Step 5: refine the gaps between periods, then filter by width once more.
binarize_prob = refine(binarize_prob,med_prob)
binarize_prob = filter_by_width(binarize_prob)

figure = plt.figure(figsize = (10,5)) # plot the original ecg along with the refined prediction
plt.plot(binarize_prob,color = "orange")
plt.plot(ecg[START:END]/2048, color = "blue")

plt.show()

In [None]:
# Step 6: predicted R positions (red) against annotated R positions (green).
Rs = locate(binarize_prob,ecg)
annot = get_ann(binarize_prob)

figure = plt.figure(figsize = (10,5)) 
plt.plot(binarize_prob,color = "orange")
plt.plot(ecg[START:END]/2048, color = "blue")

# One vertical red line per predicted R position.
for i,R in enumerate(Rs):
    if R ==1:
        plt.axvline(i, color = 'red')
        
# One vertical green line per annotated R position.
for i,ann in enumerate(annot):
    if ann==1:
        plt.axvline(i, color = 'green')