In [None]:
%reset

Once deleted, variables cannot be recovered. Proceed (y/[n])? y


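Snippet for reporting elapsed time: record `timestamp_start` before the work to be timed, then compute and print `timestamp_dur` after it.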

In [None]:
# Elapsed-time snippet: take a start timestamp, run the work, then print the difference.
# datetime is imported here because this cell sits above the notebook's import cell,
# so on a fresh kernel the name would otherwise be undefined.
import datetime

timestamp_start = datetime.datetime.now()

# ... place the code to be timed here ...

timestamp_dur = datetime.datetime.now() - timestamp_start
print('Elapsed time = ' + str(timestamp_dur))

# [ IMPORTS ]

In [None]:
import datetime
import os
import random
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

import statistics as stats
import scipy.signal as scsig

import tensorflow as tf
import keras
from tensorflow.keras.layers import Dense, Input, LSTM, Conv1D, MaxPooling1D, Flatten, BatchNormalization
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.utils import to_categorical

import warnings
warnings.filterwarnings("ignore", category=RuntimeWarning)

# [ GLOBAL ]

In [None]:
# Common sampling rate (Hz); every signal is resampled to this rate for consistency
BASIC_SRATE = 128
print('Basic sampling rate(Hz): ', BASIC_SRATE, sep='')


# -------------------------------------------------
# Working directories
# -------------------------------------------------

# root of the working tree
global_dir = '/content/drive/My Drive/Masters/workdir/ecg_data'
print('GLOBAL DIR :: ', global_dir, sep='')

# signal data in MATLAB (.mat) format
global_matdir = os.path.join(global_dir, 'db_mat')
print('GLOBAL MAT DIR :: ', global_matdir, sep='')

# signal and meta data in numpy (.npy) format
global_npydir = os.path.join(global_dir, 'db_npy')
print('GLOBAL NPY DIR :: ', global_npydir, sep='')

# manually generated datasets
global_dsdir = os.path.join(global_dir, 'db_dataset')
print('GLOBAL DATSET DIR :: ', global_dsdir, sep='')

# model weights and test results (for load_weights() / save_weights())
global_modeldir = os.path.join(global_dir, 'db_model')
print('GLOBAL MODEL DIR :: ', global_modeldir, sep='')


# -------------------------------------------------
# Annotations
# -------------------------------------------------

# annotation mapping files used by the experiments
global_antdir = os.path.join(global_dir, 'db_ant')
print('GLOBAL ANNOTATION DIR :: ', global_antdir, sep='')


# -------------------------------------------------
# File identifiers
# -------------------------------------------------
# Suffixes naming the kind of data stored in each per-record .npy file.
g_BA = 'BA'         # beat annotations (at original sampling rate)
g_NBA = 'NBA'       # non-beat annotations (at original sampling rate)
g_RAW2 = 'RAW2'     # raw lead-2 signal from the mat file
g_BLF2 = 'BLF2'     # baseline-fitted signal
g_RES2 = 'RES2'     # signal resampled to BASIC_SRATE

g_SIG2 = 'SIG2'     # signal with manual gain removed
g_RPEAK = 'RRP'     # resampled R-peaks

Basic sampling rate(Hz): 128
GLOBAL DIR :: /content/drive/My Drive/Masters/workdir/ecg_data
GLOBAL MAT DIR :: /content/drive/My Drive/Masters/workdir/ecg_data/db_mat
GLOBAL NPY DIR :: /content/drive/My Drive/Masters/workdir/ecg_data/db_npy
GLOBAL DATSET DIR :: /content/drive/My Drive/Masters/workdir/ecg_data/db_dataset
GLOBAL MODEL DIR :: /content/drive/My Drive/Masters/workdir/ecg_data/db_model
GLOBAL ANNOTATION DIR :: /content/drive/My Drive/Masters/workdir/ecg_data/db_ant


# [ CLASS DEFINITIONS ]

In [None]:
#---------------------------------------------------------------------------------------------------------------------------------------------
# CLASS ecg_db : represents one ECG database
#---------------------------------------------------------------------------------------------------------------------------------------------
class ecg_db:
    """One ECG database stored as per-record .npy files under ``global_npydir``.

    Attributes:
        name:      database name (str).
        dir_npy:   directory ``<global_npydir>/<dbname>_npy`` holding the data files.
        recs:      set of record names read from the ``RECORDS`` file in ``dir_npy``.
        recs_tag:  set of record names selected ("tagged") for an experiment.
        recs_dict: cache of ``ecg_record`` objects, filled on demand by ``get_record``.
    """

    def __init__(self, dbname,  tag_recs):
        """Read the record list of database ``dbname`` and print a summary.

        Args:
            dbname:   database name; also determines the data directory name.
            tag_recs: iterable of record names to tag initially.

        Raises:
            OSError: if the ``RECORDS`` file cannot be opened.
        """
        print('\nInitailze new ecg database ... ')
        self.name = dbname  #str
        self.dir_npy = os.path.join(global_npydir , dbname+'_npy') #str
        self.recs = self._load_record_names(os.path.join(self.dir_npy, 'RECORDS')) #set
        self.recs_tag = set(tag_recs)
        self.recs_dict = {} # initially empty, will be loaded on demand using function 'get_record'
        self.info()

    @staticmethod
    def _load_record_names(records_path):
        """Return the set of non-blank, whitespace-stripped lines of ``records_path``.

        The file is read directly instead of via ``np.loadtxt(delimiter="\\n")``:
        for a one-line file loadtxt yields a 0-d array that ``set()`` cannot iterate,
        and newer NumPy releases are understood to reject a newline delimiter.
        """
        with open(records_path, 'r') as records_file:
            return {line.strip() for line in records_file if line.strip()}

    def info(self):
        """Print name, data directory, records and tagged records; returns 0."""
        print( 'DB NAME :: '+ self.name)
        print( 'DATA DIR :: ' + self.dir_npy )
        print( 'RECORDS :: [' +str(len(self.recs))+'] ' + str(self.recs) )
        print( 'TAG RECORDS :: [' +str(len(self.recs_tag))+'] ' + str(self.recs_tag))
        return 0

    def get_record(self,rec):
        """Return the cached ``ecg_record`` for ``rec``, creating it on first use."""
        if rec not in self.recs_dict:
            self.recs_dict[rec] = ecg_record(self,rec)
        return self.recs_dict[rec]

    def get_random_record(self, recset):
        """Return the record object for a name chosen at random from ``recset``.

        Raises:
            IndexError: if ``recset`` is empty (raised by ``random.choice``).
        """
        # same caching path as get_record, so both stay consistent
        return self.get_record(random.choice(list(recset)))

#---------------------------------------------------------------------------------------------------------------------------------------------

#---------------------------------------------------------------------------------------------------------------------------------------------
# CLASS ecg_record : represents one ECG Record in any database
#---------------------------------------------------------------------------------------------------------------------------------------------
# When True, failed .npy loads return an empty array silently instead of printing a warning.
g_SUPRESS_DATA_WARNING=False
class ecg_record:
    """One ECG record of a database, with lazily loaded and cached .npy data.

    Attributes:
        db:        the owning database object (provides ``name``, ``recs``, ``dir_npy``).
        rec:       record name (str).
        name:      full name ``<db.name>_<recname>``.
        data_npy:  cache of arrays loaded from the database directory, keyed by data type.
        data_temp: cache of arrays loaded from caller-supplied directories, keyed by data type.
        binfo:     cached ``ecg_binfo`` instance, or None until first requested.
    """

    def __init__(self, db, recname):
        self.db = db                                # class:{ecg_db}    object this record belongs to
        self.rec = recname                          # string            name of this record
        self.name = db.name + '_'+ recname          # string            full name including db.name
        if recname not in db.recs:
            # only a warning: the object is still created for unknown record names
            print('WARNING:: Record "'+ recname +'" not found in database '+ db.name )
        self.data_npy = {}                          # dict of data file content used in self.read_data('key')
        self.data_temp = {}                         # dict of data file content used in self.read_data_temp('key')
        self.binfo = None                           # class binfo

    # ---------------------------------------------- get instance of binfo class
    def read_binfo(self):
        """Return the cached ``ecg_binfo`` for this record, building it on first call."""
        if self.binfo is None:
            self.binfo = ecg_binfo(self)
        return self.binfo

    def refresh_binfo(self):
        """Rebuild the ``ecg_binfo`` for this record unconditionally and return it."""
        self.binfo = ecg_binfo(self)
        return self.binfo

    # ---------------------------------------------- shared loader
    @staticmethod
    def _load_npy(ipath, data_type, cache):
        """Load ``ipath`` with ``np.load`` and store it in ``cache[data_type]``.

        On any load error nothing is cached, a warning is printed (unless
        ``g_SUPRESS_DATA_WARNING`` is set) and an empty array is returned.
        ``except Exception`` is used rather than a bare ``except`` so that
        KeyboardInterrupt / SystemExit are not swallowed.
        """
        try:
            cache[data_type] = np.load(ipath)  # cached so the next read skips the disk
            return cache[data_type]
        except Exception:
            if not g_SUPRESS_DATA_WARNING:
                print('WARNING:: Cant load "'+data_type+ '" file at '+ str(ipath) )
            return np.array([])

    # ---------------------------------------------- data reading for npydir
    def load_data(self, data_type):
        """(Re)load ``<rec>_<data_type>.npy`` from the database directory.

        Returns the loaded array, or an empty array if loading fails.
        """
        ipath = os.path.join(self.db.dir_npy, self.rec + '_'+data_type+'.npy')
        return self._load_npy(ipath, data_type, self.data_npy)

    def read_data(self, data_type):
        """Return the cached array for ``data_type``, loading it from disk if needed."""
        if data_type in self.data_npy:
            return self.data_npy[data_type]
        return self.load_data(data_type)

    # ---------------------------------------------- for tempdir
    def load_data_temp(self, data_type, dir_path):
        """(Re)load ``<rec>_<data_type>.npy`` from ``dir_path``.

        Returns the loaded array, or an empty array if loading fails.
        """
        ipath = os.path.join(dir_path, self.rec + '_'+data_type+'.npy')
        return self._load_npy(ipath, data_type, self.data_temp)

    def read_data_temp(self, data_type, dir_path):
        """Return the cached temp array for ``data_type``, loading it from ``dir_path`` if needed.

        Note: the cache is keyed by ``data_type`` only, so a cached entry is returned
        even if ``dir_path`` differs from the directory it was loaded from.
        """
        if data_type in self.data_temp:
            return self.data_temp[data_type]
        return self.load_data_temp(data_type, dir_path)

    def save_data_temp(self, data_type, data_array, dir_path):
        """Save ``data_array`` as ``<rec>_<data_type>.npy`` in ``dir_path``; return the path."""
        ipath = os.path.join(dir_path, self.rec + '_'+data_type+'.npy')
        np.save(ipath, data_array)
        return ipath

    def del_data_temp(self, data_type, dir_path, vb):
        """Delete ``<rec>_<data_type>.npy`` from ``dir_path`` if it exists.

        Args:
            vb: when truthy, print the path being removed.

        Returns:
            1 if a file was removed, 0 if it did not exist.
        """
        ipath = os.path.join(dir_path, self.rec + '_'+data_type+'.npy')
        if os.path.exists(ipath):
            if vb:
                print('Removing: '+str(ipath))
            os.remove(ipath)
            return 1
        else:
            return 0
#------------------------------------------------------------------------------------------------

#---------------------------------------------------------------------------------------------------------------------------------------------
# CLASS ecg_binfo : information about beats in a record
#---------------------------------------------------------------------------------------------------------------------------------------------
class ecg_binfo:
    """Beat-level information for one record, derived from its R-peak annotation array.

    The annotation array is obtained with ``rec.read_data(g_RPEAK)``; column 0 is used as
    the R-peak sample index and column 1 as the annotation label. The first and last
    annotated peaks act only as neighbours, so every per-beat array below has
    ``rp_count = len(annotations) - 2`` entries.

    Attributes:
        rp_curr / rp_prev / rp_next:    sample index of each beat's R peak and of its neighbours.
        rl_curr / rl_prev / rl_next:    original annotation labels, aligned with the above.
        rli_curr / rli_prev / rli_next: mapped labels, filled by ``map_ants2int`` (empty until then).
        rp_sec:    R-peak position in seconds (sample index / BASIC_SRATE).
        rri_prev / rri_next: RR interval to the previous / next peak, in seconds.
        rri_delta: ``rri_next - rri_prev`` in seconds.
        rri_dur:   ``rri_next + rri_prev`` in seconds.
    """

    def __init__(self, rec):

        # the record object
        self.rec = rec

        # read original annotations
        r_peaks_ants = rec.read_data(g_RPEAK)       # resampled ant file

        # calculate count of R peaks (excluding first and last)
        self.rp_count = len(r_peaks_ants) - 2

        # Extract location and labels of peaks
        r_peaks_int = r_peaks_ants[:,0].astype('int')
        r_ants_str = r_peaks_ants[:,1]

        # Location (first and last peak are excluded from "current")
        self.rp_curr = r_peaks_int[1:-1]    # current R peak
        self.rp_prev = r_peaks_int[0:-2]    # previous R peak (in samples)
        self.rp_next = r_peaks_int[2:]      # next R peak (in samples)

        # Label
        self.rl_curr = r_ants_str[1:-1]
        self.rl_prev = r_ants_str[0:-2]
        self.rl_next = r_ants_str[2:]

        # mapped Label
        self.rli_prev = []
        self.rli_curr = []
        self.rli_next = []

        # calculate temporal info
        self.rp_sec = self.rp_curr / BASIC_SRATE                      # peak location (in sec)
        self.rri_prev = (self.rp_curr - self.rp_prev) / BASIC_SRATE   # prev RRI (in sec)
        self.rri_next = (self.rp_next - self.rp_curr) / BASIC_SRATE   # next RRI (in sec)
        self.rri_delta = (self.rri_next - self.rri_prev)              # difference b/w prev and next RRI (in sec)
        self.rri_dur = (self.rri_next + self.rri_prev)                # total duration from prev to next R-peak


    def get_signal_data_var(self, ith_peak):
        """Return the variable-length segment from the previous to the next R peak.

        Returns:
            (segment, peak_offset, segment_length) where ``segment`` is
            ``signal[prev : next+1]`` of the ``g_SIG2`` data and ``peak_offset`` is
            the position of the current peak inside it.
        """
        sel_sig = self.rec.read_data(g_SIG2)
        ff = self.rp_prev[ith_peak]
        tt = self.rp_next[ith_peak]
        pp = self.rp_curr[ith_peak]
        return sel_sig[ff:tt+1], (pp-ff), (tt+1-ff) #<- also return position of peak

    def get_signal_data_fix(self, ith_peak, v_left_sec, v_right_sec):
        """Same as ``get_signal_data_fix_samples`` with the window given in seconds."""
        return self.get_signal_data_fix_samples(ith_peak,int(v_left_sec*BASIC_SRATE),int(v_right_sec*BASIC_SRATE))

    def get_signal_data_fix_samples(self, ith_peak, v_left, v_right):
        """Return a fixed-length segment of ``v_left + v_right`` samples around a peak.

        The segment covers ``[peak - v_left, peak + v_right)`` of the ``g_SIG2`` data.
        Where that window extends past either end of the signal, the missing samples
        are zero-padded so the length stays constant.

        Returns:
            (segment, pl) where ``pl`` is the peak sample index plus the left padding.
        """
        sel_sig = self.rec.read_data(g_SIG2)
        pp = self.rp_curr[ith_peak]
        ff = pp - v_left
        tt = pp + v_right

        f_pad, t_pad = 0, 0
        if ff < 0:
            f_pad = 0 - ff
            ff = 0

        if tt > len(sel_sig):
            # the original assigned to a misspelled name here, so the right-hand
            # padding was never applied and end-of-signal segments came out short
            t_pad = tt - len(sel_sig)
            tt = len(sel_sig)

        sel_part = np.hstack((
            np.zeros(f_pad),
            sel_sig[ff:tt],
            np.zeros(t_pad),
            ))

        # NOTE(review): this equals v_left (offset inside the segment) only when the
        # left side was padded; otherwise it is the absolute sample index. Kept as-is
        # because callers are not visible here - confirm which one is intended.
        pl = pp + f_pad
        return sel_part, pl #<- also return position of peak

    def _hr_between(self, ff, tt):
        """Heart rate from the current peaks lying in the sample range [ff, tt].

        Returns:
            (beats_per_second, duration_sec), where the duration spans the first to
            the last peak found; (0, 0) when fewer than two peaks are in range.
        """
        qq = np.where((self.rp_curr >= ff) & (self.rp_curr <= tt))[0]
        nq = len(qq)
        if nq < 2:
            return 0, 0
        dd = (self.rp_curr[qq[-1]] - self.rp_curr[qq[0]]) / BASIC_SRATE
        # nq peaks delimit nq-1 beat intervals over dd seconds
        return (nq - 1) / dd, dd

    def get_local_hrT(self,local_window_start,local_window_end):
        """Heart rate within an absolute time window given in seconds.

        The window is truncated to the span between the first and last annotated peak.

        Returns:
            (beats_per_second, duration_sec); (0, 0) if fewer than two peaks fall inside.
        """
        lws = local_window_start*BASIC_SRATE # in samples
        lwe = local_window_end*BASIC_SRATE # in samples
        ff = max( lws ,self.rp_prev[0])
        tt = min( lwe ,self.rp_next[-1])
        return self._hr_between(ff, tt)


    def get_local_hr(self,ith_peak, local_window_left,local_window_right):
        """Heart rate in a window (seconds to the left / right) around the ith peak.

        The window is truncated to the span between the first and last annotated peak.

        Returns:
            (beats_per_second, duration_sec); (0, 0) if fewer than two peaks fall inside.
        """
        lwl = local_window_left*BASIC_SRATE # in samples
        lwr = local_window_right*BASIC_SRATE # in samples
        ff = max(self.rp_curr[ith_peak] - lwl ,self.rp_prev[0])
        tt = min(self.rp_curr[ith_peak] + lwr ,self.rp_next[-1])
        return self._hr_between(ff, tt)

    def get_local_hrA(self, local_window_left,local_window_right):
        """Apply ``get_local_hr`` to every peak.

        Returns:
            (local_bps, local_dd): float arrays of length ``rp_count`` holding the
            heart rate in beats per second and the measured duration in seconds.
        """
        lwl = local_window_left*BASIC_SRATE # in samples
        lwr = local_window_right*BASIC_SRATE # in samples
        local_bps = np.zeros(self.rp_count,dtype='float')
        local_dd = np.zeros(self.rp_count,dtype='float')
        for ith_peak in range(0, self.rp_count):
            ff = max(self.rp_curr[ith_peak] - lwl ,self.rp_prev[0])
            tt = min(self.rp_curr[ith_peak] + lwr ,self.rp_next[-1])
            local_bps[ith_peak], local_dd[ith_peak] = self._hr_between(ff, tt)
        return local_bps, local_dd

    def map_ants2int(self,map_dict):
        """Fill ``rli_prev`` / ``rli_curr`` / ``rli_next`` by mapping labels through ``map_dict``.

        Nothing is done if ``rli_curr`` already has ``rp_count`` entries.

        Raises:
            KeyError: if an annotation label is not a key of ``map_dict``.
        """
        if len(self.rli_curr)!=self.rp_count:
            # first label + all current labels + last label = the full annotation sequence
            labels = [self.rl_prev[0]] + list(self.rl_curr) + [self.rl_next[-1]]
            # building the array from the mapped values lets numpy size the string
            # dtype itself; a pre-allocated dtype='str' array would cut every mapped
            # label down to a single character
            temp = np.array([map_dict[label] for label in labels], dtype='str')
            self.rli_curr = temp[1:-1]
            self.rli_prev = temp[0:-2]
            self.rli_next = temp[2:]



# [ BUILD STANDARD DBs ]

In [None]:
print('Buidling standard databases')
# One ecg_db object per standard database; no records are tagged at this point.
std_mitdb = ecg_db(dbname='mitdb', tag_recs=[])
std_svdb = ecg_db(dbname='svdb', tag_recs=[])
std_incartdb = ecg_db(dbname='incartdb', tag_recs=[])

Buidling standard databases

Initailze new ecg database ... 
DB NAME :: mitdb
DATA DIR :: /content/drive/My Drive/Masters/workdir/ecg_data/db_npy/mitdb_npy
RECORDS :: [48] {'122', '223', '118', '102', '221', '107', '202', '201', '214', '203', '222', '105', '212', '233', '121', '210', '106', '209', '215', '108', '104', '123', '231', '115', '116', '119', '234', '213', '200', '228', '114', '117', '103', '113', '232', '101', '124', '220', '208', '100', '230', '207', '205', '109', '217', '219', '112', '111'}
TAG RECORDS :: [0] set()

Initailze new ecg database ... 
DB NAME :: svdb
DATA DIR :: /content/drive/My Drive/Masters/workdir/ecg_data/db_npy/svdb_npy
RECORDS :: [78] {'864', '847', '869', '868', '854', '801', '844', '845', '802', '880', '885', '856', '841', '800', '811', '849', '884', '887', '826', '865', '823', '803', '843', '848', '883', '806', '881', '829', '820', '872', '810', '860', '859', '886', '812', '878', '871', '877', '890', '861', '891', '862', '863', '807', '824', '879', '

# [ PERFORMANCE MEASURES ]

In [None]:
#=========================================================================================================================
#======================= NEURAL NETWORK PERFORMANCE MEASURES
#=========================================================================================================================
# 3.3 :: define performance evaluation functions

def get_performance(conf_matrix):
    """Compute one-vs-rest performance measures for every class of a confusion matrix.

    Rows of ``conf_matrix`` are treated as actual classes and columns as predicted
    classes; the number of classes is the length of row 0.

    Returns:
        Array with one row per class and 8 columns:
        TP, FN, FP, TN, ACC, PRE (precision), SEN (sensitivity), SPF (specificity).
    """
    n_classes = len(conf_matrix[0, :])
    grand_total = np.sum(conf_matrix)

    # start from an empty (0, 8) block so a zero-class input still yields shape (0, 8)
    stacked = [np.zeros((0, 8), dtype='float64')]
    for k in range(n_classes):
        tp = conf_matrix[k, k]
        fp = np.sum(conf_matrix[:, k]) - tp      # predicted k, actually another class
        fn = np.sum(conf_matrix[k, :]) - tp      # actually k, predicted another class
        tn = grand_total - fn - fp - tp          # everything not involving class k

        accuracy = (tp + tn) / (tp + fp + fn + tn)
        precision = tp / (tp + fp)
        sensitivity = tp / (tp + fn)
        specificity = tn / (tn + fp)

        stacked.append(np.array([tp, fn, fp, tn, accuracy, precision, sensitivity, specificity]))
    return np.vstack(stacked)


#------------------------------------------------------------------PRINTING

def print_lstr(class_labels):
    """Build the confusion-matrix header row: each label preceded by a tab."""
    return ''.join('\t' + str(class_labels[k]) for k in range(len(class_labels)))

def print_cf_row(cf_row,nos_labels):
    """Format the first ``nos_labels`` entries of one confusion-matrix row, each preceded by a tab."""
    cells = ['\t' + str(cf_row[col]) for col in range(nos_labels)]
    return ''.join(cells)
def print_conf_matrix(conf_matrix, suffix, class_labels):
    """Render a confusion matrix as tab-separated text.

    The first line is a header (``A\\P`` followed by the labels); each following line
    is a class label followed by that row's counts. ``suffix`` is prepended to every
    line. Returns the text; nothing is printed.
    """
    n_labels = len(class_labels)
    lines = [suffix + 'A\\P' + print_lstr(class_labels) + '\n']
    for row in range(n_labels):
        lines.append(suffix + str(class_labels[row]) + print_cf_row(conf_matrix[row], n_labels) + '\n')
    return ''.join(lines)
def print_performance(perf_measures, class_labels):
    """Print ACC, PRE, SEN and SPF (columns 4-7 of ``perf_measures``) per class.

    Values are rounded to two decimals; one tab-separated line is printed per row of
    ``perf_measures``, prefixed with the matching entry of ``class_labels``.
    """
    n_rows = len(perf_measures[:, 0])
    print('Performance for '+str(n_rows)+' classes')
    print ('Class\tACC\tPRE\tSEN\tSPF')
    for k in range(n_rows):
        rounded = np.round(perf_measures[k, :], 2)
        fields = [str(class_labels[k])] + [str(rounded[col]) for col in (4, 5, 6, 7)]
        print('\t'.join(fields))
    return
#------------------------------------------------------------------

def plot_ecg_segment(signal_info, signal_array, fsec, tsec, x_scale, y_scale, y_low, y_high, mticks_pos, show_rris, predx, predy, a_color, gain=1):
    """Plot a time slice of an ECG signal together with per-beat labels and predictions.

    Up to three figures are shown, in this order: prediction markers, the signal
    itself, and (if ``show_rris``) RR-interval duration and absolute delta per beat.

    Args:
        signal_info: beat-info object; reads ``rec.name``, ``rp_sec``, ``rp_curr``,
            ``rl_curr``, ``rli_curr``, ``rri_dur`` and ``rri_delta``
            (presumably an ``ecg_binfo`` - confirm against callers).
        signal_array: signal samples, sliced by sample index.
        fsec, tsec: start and end of the slice in seconds (converted with BASIC_SRATE).
        x_scale, y_scale: figure-size factors applied to the duration and the y range.
        y_low, y_high: y-axis limits of the signal figure.
        mticks_pos: y position of the mapped-label markers in the signal figure.
        show_rris: when truthy, also draw the RR-interval figure.
        predx: per-beat score array; indexed by beat, then ``argmax(axis=1)`` picks the
            predicted class.
        predy: per-beat values used as ground truth (compared against 0 and 1).
        a_color: marker colour for beats whose predicted class is not 0.
        gain: factor multiplied into the plotted signal.

    Returns:
        0 if ``signal_array`` is empty; otherwise ``(bps, dticks, dlabels)``: the
        plotted (gained) signal slice, the beat positions relative to the slice start,
        and the original labels of those beats.
    """
    # plot signal segments
    #<<---------------------------------------------Select ECG Segment
    dsec = tsec - fsec
    print(signal_info.rec.name)
    if len(signal_array)==0:
        print('WARNING::Signal was not loaded.')
        return 0
    else:
        ff = int(fsec * BASIC_SRATE)
        tt = int(tsec * BASIC_SRATE)
        dd = tt - ff

        bps = signal_array[ff:tt] * gain  # signal data * gain

        # indices of the beats whose R peak lies inside [fsec, tsec)
        lim_query = np.where((signal_info.rp_sec >= fsec) & (signal_info.rp_sec < tsec))[0]

        dticks = signal_info.rp_curr[lim_query]-ff  # tick position
        nos_ticks = len(dticks)

        dlabels = signal_info.rl_curr[lim_query]    # original labels
        dilabels = signal_info.rli_curr[lim_query]  # mapped labels
        ditruth = predy[lim_query]                  # despite its name, predy is used as the truth
        dpredx = predx[lim_query]
        diPred = dpredx.argmax(axis=1)              # predicted class index per beat
        diPredCol = np.zeros(nos_ticks, dtype='U15') 
        for i in range(0,nos_ticks):
            if diPred[i]==0:
                diPredCol[i]= 'tab:green'      
            else:
                diPredCol[i]= a_color  
        dicolors = np.zeros(nos_ticks, dtype='U15') # get color representation
        for i in range(0,nos_ticks):
            # colour is looked up from the module-level g_STD_LABELS by mapped label
            dicolors[i]= g_STD_LABELS[dilabels[i]]

        print('Time Interval{'+str(dsec)+'s}:['+str(fsec)+':'+str(tsec)+']')
        if nos_ticks > 0:
            print('Beat Interval{'+str(nos_ticks)+'#}:['+str(lim_query[0])+':'+str(lim_query[-1])+']')
        else:
            print('Beat Interval{'+str(nos_ticks)+'#}')

        # prepare figure: predictions
        plt.figure('ecg predictions', figsize = (dsec*x_scale ,(y_high-y_low) * y_scale) )
        plt.xlim(0, len(bps))
        plt.ylim(-0.5,1.1)
        plt.yticks([])
        plt.xticks(dticks,dlabels)
        plt.grid(axis='x')
        plt.hlines(0,0,len(bps), linewidth=0.3)
        plt.hlines(0.5,0,len(bps), linewidth=0.3)
        plt.hlines(1,0,len(bps), linewidth=0.3)
        # squares: colour of the mapped label; circles: colour of the predicted class
        plt.scatter(dticks,np.zeros(nos_ticks)-0.40,marker='s',color=dicolors)
        plt.scatter(dticks,np.zeros(nos_ticks)-0.20,marker='o',color=diPredCol)

        # score of the predicted (arg-max) class for every beat
        pred_str = np.zeros(len(diPred))
        for i in range(0,len(diPred)):
            pred_str[i]=dpredx[i][diPred[i]]
            
        plt.scatter(dticks,pred_str,marker='.',color='black')

        # where predy is 0 and predx is 1
        d_aN_pA = dticks[np.where((ditruth==0)&(diPred==1))[0]]
        plt.scatter(d_aN_pA,np.zeros(len(d_aN_pA)),marker='x',color='tab:green')

        # where predy is 1 and predx is 0
        d_aA_pN = dticks[np.where((ditruth==1)&(diPred==0))[0]]
        plt.scatter(d_aA_pN,np.zeros(len(d_aA_pN)),marker='x',color='tab:red')

        #plt.scatter(dticks,dpredx[:,0],marker='x',color='tab:green')
        #plt.scatter(dticks,dpredx[:,1],marker='x',color=a_color)

        plt.tight_layout()
        plt.show()
        

        # prepare figure: signal
        plt.figure('ecg signal', figsize = (dsec*x_scale ,(y_high-y_low) * y_scale) )
        plt.xlim(0, len(bps))
        plt.ylim(y_low,y_high)
        plt.yticks([])
        plt.xticks(dticks,dlabels)
        #x_grid = np.arange(0,tt-ff, 1*BASIC_SRATE)
        #plt.xticks(x_grid)
        plt.grid(axis='x')
        # plot signal and baseline
        plt.plot(bps, linewidth=0.5, color='black')
        plt.hlines(0,0,len(bps), linewidth=0.3)
        # plot mapped labels
        plt.scatter(dticks,np.zeros(nos_ticks)+mticks_pos,marker='s',color=dicolors)

        # finalize
        plt.tight_layout()
        plt.show()
        
        if show_rris:
             ddur = signal_info.rri_dur[lim_query]       # duration
             ddel = np.absolute(signal_info.rri_delta[lim_query] )      # delta rri
             # prepare figure: rri,delta rri
             my_low, my_high = -0.1, 3.5
             plt.figure('ecg meta', figsize = (dsec*x_scale ,(my_high-my_low) * 1.5*y_scale) )
             plt.xlim(0, len(bps))
             plt.ylim(my_low,my_high)
             plt.yticks([])
             plt.xticks(dticks,dlabels)
             #x_grid = np.arange(0,tt-ff, 1*BASIC_SRATE)
             #plt.xticks(x_grid)
             plt.grid(axis='x')
     
             # plot grid and baseline
             plt.hlines(0,0,len(bps), linewidth=0.3,color='red')
             #for j in [0.5,1,1.5,2,2.5,3]:
             #    plt.hlines(j,0,len(bps), linewidth=0.3,color='black')
     
             # plot mapped labels
             plt.scatter(dticks,ddur,marker='s',color=dicolors)
             plt.scatter(dticks,ddel,marker='o',color='tab:purple')
             plt.plot(dticks,ddur,color='black',linewidth=0.5,linestyle='dotted')
             plt.plot(dticks,ddel,color='black',linewidth=0.5,linestyle='dotted')
             # finalize
             plt.tight_layout()
             plt.show()

        return bps,dticks,dlabels





---

END OF SHARED SECTION

---



# [EXP DATA DICT]

In [None]:
# Records left out of each database for this experiment, grouped by reason
mitdb_ex = {
    '102', '104', '107', '217',   # paced
    '207',                        # VFlutter
    '212', '231',                 # both N and BBB
    '108',                        # bad signal
    '202', '203',                 # bad labeling
}
svdb_ex = set()
incartdb_ex = set()

# Tag every record that is not excluded
std_mitdb.recs_tag = std_mitdb.recs.difference(mitdb_ex)
std_svdb.recs_tag = std_svdb.recs.difference(svdb_ex)
std_incartdb.recs_tag = std_incartdb.recs.difference(incartdb_ex)

# Database combinations keyed by name; the suffix letters say which databases
# are included (m = mitdb, s = svdb, i = incartdb)
std_db_msi = {'mitdb': std_mitdb, 'svdb': std_svdb, 'incartdb': std_incartdb}
std_db_ms = {'mitdb': std_mitdb, 'svdb': std_svdb}
std_db_mi = {'mitdb': std_mitdb, 'incartdb': std_incartdb}
std_db_si = {'svdb': std_svdb, 'incartdb': std_incartdb}
std_db_m = {'mitdb': std_mitdb}
std_db_s = {'svdb': std_svdb}
std_db_i = {'incartdb': std_incartdb}



# [ --- EXP_3 : RNN/LSTM --- ]

This experiment classifies each beat as either Normal or Abnormal. The annotations used are N, S, and V (F beats are included in the V type).

# [ VIEW ANNOTATION MAPPERS ]

In [None]:
# List the annotation label/map files available in the annotation directory, sorted by name
ls_ants = np.sort(os.listdir(global_antdir))
print('Available annotation files [', len(ls_ants), ']', sep='')
for ls_ant in ls_ants:
    print(ls_ant)
print('--------------------------')

Available annotation files [6]
default_labels.txt
default_map.txt
exp01_labels.txt
exp01_map.txt
nsvf_labels.txt
nsvf_map.txt
--------------------------


# [ MAP ANNOTATIONS ]

In [None]:
# Label table and annotation map selected for this experiment
sel_labels = os.path.join(global_antdir, 'exp01_labels.txt')
sel_map = os.path.join(global_antdir, 'exp01_map.txt')

# ----------------------------------------------------------------------
# Standard labels: tab-separated rows of
#   [0] standard label (char)   [1] mapped color (str)   [2] description (str)
# ----------------------------------------------------------------------
sel_labels_data = np.loadtxt(sel_labels, dtype='str', delimiter="\t")
g_STD_LABELS = {}
print('\nStandard Labels::')
for a in sel_labels_data:
    g_STD_LABELS[a[0]] = a[1]
    print(a[0], a[1], a[2], sep='\t')

# ----------------------------------------------------------------------
# Mapping data: tab-separated rows of
#   [0] original PhysioNet label (char)   [1] mapped standard label (char)   [2] description (str)
# ----------------------------------------------------------------------
ant_map_data = np.loadtxt(sel_map, dtype='str', delimiter="\t")
g_STD_NO_MAP = '_'                      # placeholder label for annotations without a standard class
g_STD_LABELS[g_STD_NO_MAP] = 'black'    # color used for unmapped beats
g_STD_MAP = {}                          # mapping dictionary: original label -> standard label
print('\nMapping::')
for a in ant_map_data:
    g_STD_MAP[a[0]] = a[1]
    print(a[0], a[1], a[2], sep='\t')
print('\n',g_STD_MAP.keys())


# Apply the mapping to the beat info of every tagged record in all three databases
for idb, sel_db in std_db_msi.items():
    for irec in sel_db.recs_tag:
        sel_rec = sel_db.get_record(irec)
        sel_info = sel_rec.read_binfo()
        sel_info.map_ants2int(g_STD_MAP)



Standard Labels::
N	green	Normal
S	red	Supraventricular Premature
V	blue	Ventricular Premature

Mapping::
N	N	Normal beat
L	N	Left bundle branch block beat
R	N	Right bundle branch block beat
B	N	Bundle branch block beat (unspecified)
A	S	Atrial premature beat
a	S	Aberrated atrial premature beat
J	S	Nodal (junctional) premature beat
S	S	Supraventricular premature or ectopic beat (atrial or nodal)
V	V	Premature ventricular contraction
r	V	R-on-T premature ventricular contraction
F	V	Fusion of ventricular and normal beat
e	_	Atrial escape beat
j	_	Nodal (junctional) escape beat
n	_	Supraventricular escape beat (atrial or nodal)
E	_	Ventricular escape beat
/	_	Paced beat
f	_	Fusion of paced and normal beat
Q	_	Unclassifiable 
?	_	Beat not classified during learning
[	_	Start of ventricular flutter/fibrillation
!	_	Ventricular flutter wave
]	_	End of ventricular flutter/fibrillation
x	_	Non-conducted P-wave (blocked APC)
(	_	Waveform onset
)	_	Waveform end
p	_	Peak of P-wave
t	_	Peak of T-

# [ MODEL ]

In [None]:
# variable length input LSTM model

# Loss and optimizer given as string identifiers.
# NOTE(review): presumably passed to the Keras model's compile() further down - confirm.
cost = 'binary_crossentropy'
opt = 'adam'

def get_modelLSTM_01(print_summary):

# NR INPUT SIDE ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

    iNR_0 = Input( shape=(None,1), name = "in_NR_0" )
    #bn_N =                              BatchNormalization()(iNR_0)
    lNR_1 =  LSTM                       (150,
                                        activation='tanh', 
                                        recurrent_activation='sigmoid', 
                                        use_bias=True, 
                                        kernel_initializer='glorot_uniform', 
                                        recurrent_initializer='orthogonal', 
                                        bias_initializer='zeros', unit_forget_bias=True, 
                                        kernel_regularizer=None, 
                                        recurrent_regularizer=None, 
                                        bias_regularizer=None, 
                                        activity_regularizer=None, 
                                        kernel_constraint=None, 
                                        recurrent_constraint=None, 
                                        bias_constraint=None, 
                                        dropout=0.0, 
                                        recurrent_dropout=0.0, 
                                        implementation=2, 
                                        return_sequences=True, 
                                        return_state=False, 
                                        go_backwards=False,
                                        stateful=False, 
                                        unroll=False,
                                        name='LSTM_NR_1')(iNR_0) 
    
    lNR_2 = LSTM                        (50,
                                        activation='tanh', 
                                        recurrent_activation='sigmoid', 
                                        use_bias=True, 
                                        kernel_initializer='glorot_uniform', 
                                        recurrent_initializer='orthogonal', 
                                        bias_initializer='zeros', unit_forget_bias=True, 
                                        kernel_regularizer=None, 
                                        recurrent_regularizer=None, 
                                        bias_regularizer=None, 
                                        activity_regularizer=None, 
                                        kernel_constraint=None, 
                                        recurrent_constraint=None, 
                                        bias_constraint=None, 
                                        dropout=0.0, 
                                        recurrent_dropout=0.0, 
                                        implementation=2, 
                                        return_sequences=False, 
                                        return_state=False, 
                                        go_backwards=False,
                                        stateful=False, 
                                        unroll=False,
                                        name='LSTM_NR_2')(lNR_1) 

# CR INPUT SIDE ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

    iCR_0 = Input( shape=(None,1), name = "in_CR_0" )
    #bn_C =                              BatchNormalization()(iCR_0)
    lCR_1 =  LSTM                       (150,
                                        activation='tanh', 
                                        recurrent_activation='sigmoid', 
                                        use_bias=True, 
                                        kernel_initializer='glorot_uniform', 
                                        recurrent_initializer='orthogonal', 
                                        bias_initializer='zeros', unit_forget_bias=True, 
                                        kernel_regularizer=None, 
                                        recurrent_regularizer=None, 
                                        bias_regularizer=None, 
                                        activity_regularizer=None, 
                                        kernel_constraint=None, 
                                        recurrent_constraint=None, 
                                        bias_constraint=None, 
                                        dropout=0.0, 
                                        recurrent_dropout=0.0, 
                                        implementation=2, 
                                        return_sequences=True, 
                                        return_state=False, 
                                        go_backwards=False,
                                        stateful=False, 
                                        unroll=False,
                                        name='LSTM_CR_1')(iCR_0) 
    
    lCR_2 = LSTM                        (50,
                                        activation='tanh', 
                                        recurrent_activation='sigmoid', 
                                        use_bias=True, 
                                        kernel_initializer='glorot_uniform', 
                                        recurrent_initializer='orthogonal', 
                                        bias_initializer='zeros', unit_forget_bias=True, 
                                        kernel_regularizer=None, 
                                        recurrent_regularizer=None, 
                                        bias_regularizer=None, 
                                        activity_regularizer=None, 
                                        kernel_constraint=None, 
                                        recurrent_constraint=None, 
                                        bias_constraint=None, 
                                        dropout=0.0, 
                                        recurrent_dropout=0.0, 
                                        implementation=2, 
                                        return_sequences=False, 
                                        return_state=False, 
                                        go_backwards=False,
                                        stateful=False, 
                                        unroll=False,
                                        name='LSTM_CR_2')(lCR_1)    

# CONCAT NR and CR ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

    d_concat =  tf.concat([lNR_2, lCR_2],axis=1, name = "dense_NC")

# DENSE SIDE ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

    den_fc0 = Dense(20, activation=tf.nn.leaky_relu, name = "DENSE_FC0")(d_concat)
    den_fc1 = Dense(15, activation=tf.nn.leaky_relu, name = "DENSE_FC1")(den_fc0)

# OUTPUT SIDE ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    den_out = Dense(1, activation=tf.nn.sigmoid, name = "OUTPUT_FC")(den_fc1)

# =========================================================================================

    model=Model(inputs=[iNR_0,iCR_0], outputs=den_out)

    #-------------------------------------
    #model.get_layer(name="LSTM_50").trainable=is_trainable
    #-------------------------------------

    model.compile(loss=cost, optimizer=opt, metrics=['accuracy'])
    if print_summary:
        print(model.summary())
    return model



# [ DATA GENERATOR ]

In [None]:
class DGen:
    """Endless sample generator feeding the two-input LSTM model.

    Each entry of `samplelist` is indexed as
    (database key, record tag, NR index, CR index, label); the last three
    fields are converted with int(). Signals are fetched on demand through
    `ecgdb[db].get_record(rec).read_binfo().get_signal_data_var(index)`.
    """

    def __init__(self, samplelist, ecgdb, shuffle_count):
        """
        Args:
            samplelist: indexable sequence of sample descriptors (see class doc).
                It is shuffled in place by `read_data_sample` on wrap-around.
            ecgdb: mapping from database key to a database object.
            shuffle_count: number of in-place shuffles applied each time the
                training generator wraps around (0 disables shuffling).
        """
        self.ecgdb = ecgdb
        self.shuffle=shuffle_count
        self.sample_list = samplelist
        self.nos_samples = len(samplelist)
        self.isample = -1
        # filled by read_data_sample_predict(); defined here so it always exists
        self.true_labels = []

    def _load_sample(self, index):
        """Load one sample and return (NR_sig, CR_sig, label) in batch form.

        The signals get a trailing feature axis and a leading batch axis
        (a 1-D signal of length T becomes shape (1, T, 1)); the label has
        shape (1, 1).
        """
        cdata = self.sample_list[index]

        data_db = cdata[0]
        data_rec = cdata[1]
        data_NR = int(cdata[2])
        data_CR = int(cdata[3])
        data_L = np.array([[int(cdata[4])]])

        data_binfo = self.ecgdb[data_db].get_record(data_rec).read_binfo()
        # only the first element of the returned triple is used
        NR_sig,_,_ = data_binfo.get_signal_data_var(data_NR)
        CR_sig,_,_ = data_binfo.get_signal_data_var(data_CR)

        NR_sig = np.expand_dims(np.expand_dims(NR_sig,axis=-1),axis=0)
        CR_sig = np.expand_dims(np.expand_dims(CR_sig,axis=-1),axis=0)
        return NR_sig, CR_sig, data_L

    def read_data_sample(self): # sel_dbA,sel_recA,sel_NR,sel_A,sel_A_label
        """Yield ([NR_sig, CR_sig], label) forever, one sample per batch.

        After the last sample the list is shuffled in place `self.shuffle`
        times and iteration restarts from index 0.
        """
        self.isample = -1
        while True:
            self.isample+=1
            if self.isample>=self.nos_samples:
                for _ in range(self.shuffle):
                    np.random.shuffle(self.sample_list)
                self.isample = 0

            NR_sig, CR_sig, data_L = self._load_sample(self.isample)
            yield [NR_sig,CR_sig],data_L

    def read_data_sample_predict(self): # sel_dbA,sel_recA,sel_NR,sel_A,sel_A_label
        """Yield [NR_sig, CR_sig] forever and record the labels seen.

        `self.true_labels` is reset when iteration starts and then receives
        one shape-(1,) label array per yielded sample, in yield order.
        Samples are never shuffled; iteration wraps to index 0 at the end.
        """
        self.true_labels = []
        self.isample = -1
        while True:
            self.isample+=1
            # wrap on the real sample count (the attribute previously read
            # here, `ifilelen`, is never assigned and raised AttributeError)
            if self.isample>=self.nos_samples:
                self.isample = 0

            NR_sig, CR_sig, data_L = self._load_sample(self.isample)
            self.true_labels.append(data_L[0])
            yield [NR_sig,CR_sig]


# [ TRAINING ]

## [ LOAD DATASET ]

In [None]:
# Names of the dataset files (without extension) to merge into one sample list.
ds_list = ['train_X']
timestamp_start = datetime.datetime.now()

# Start from an empty (0, 5) string array so the merged result has the same
# shape and dtype behaviour even when ds_list is empty.
ds_parts = [np.zeros((0,5),dtype='U10')]
for ds_name in ds_list:
    ds_path = os.path.join(global_dsdir,ds_name+'.npy')
    ds_this = np.load(ds_path)
    print('loaded:',ds_this.shape, ds_path)
    ds_parts.append(ds_this)

# stack every loaded dataset row-wise in one go
ds_str = np.vstack(ds_parts)

# generator over the merged samples; the list is reshuffled twice per wrap-around
genD = DGen(ds_str,std_db_m,2)
print(genD.isample,'/',genD.nos_samples)

loaded: (414, 5) /content/drive/My Drive/Masters/workdir/ecg_data/db_dataset/train_X.npy
-1 / 414


## [ PERFORM TRAINING ]

In [None]:
# Train a freshly built model on the endless sample generator and time the run.
timestamp_start = datetime.datetime.now()
model = get_modelLSTM_01(True)

# Model.fit accepts generators directly; fit_generator is deprecated.
# One epoch = one pass over every sample (each generator step is one sample).
hx = model.fit(genD.read_data_sample(),
               epochs=30,
               steps_per_epoch=genD.nos_samples,
               verbose=1,
               #validation_data=unet_gen(),
               #validation_steps=1
               )

timestamp_dur = datetime.datetime.now() - timestamp_start
print('Elapsed time = ' + str(timestamp_dur))

Model: "model_11"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
in_NR_0 (InputLayer)            [(None, None, 1)]    0                                            
__________________________________________________________________________________________________
in_CR_0 (InputLayer)            [(None, None, 1)]    0                                            
__________________________________________________________________________________________________
LSTM_NR_1 (LSTM)                (None, None, 150)    91200       in_NR_0[0][0]                    
__________________________________________________________________________________________________
LSTM_CR_1 (LSTM)                (None, None, 150)    91200       in_CR_0[0][0]                    
___________________________________________________________________________________________

KeyboardInterrupt: ignored

## [ SAVE MODEL ]

In [None]:
# Persist the weights of the current model as <global_modeldir>/<ds_model>.h5
ds_model = 'model_X'
model_path = os.path.join(global_modeldir, f'{ds_model}.h5')
model.save_weights(model_path)
print(f'Saved Model Weights at : {model_path}')

Saved Model Weights at : /content/drive/My Drive/Masters/workdir/ecg_data/db_exp/model_X.h5


# [ TESTING ]

## [ EVAL DATASET ]

In [None]:
# Evaluate saved weights on the 'lstm' dataset and report [loss, accuracy].
#
# NOTE(review): this cell uses genD.set_up(), genD.iepoch, genD.nos_epochs and
# genD.iepochlen and passes a directory path as the sample list. The DGen class
# visible in the [ DATA GENERATOR ] section takes an in-memory sample list and
# defines none of those members -- confirm which DGen version this cell expects.
ds_path = os.path.join(global_dsdir,'lstm')
print(ds_path)
genD = DGen(ds_path,std_db_m,2)
genD.set_up()
print(genD.isample,'/',genD.iepoch)
print(genD.nos_epochs,genD.iepochlen)
A_Label = 'V'
g_LABELS = ['N',A_Label]
ds_model = 'model_X'       # SELECT MODEL WEIGHTS TO TEST UPON
model_path = os.path.join(global_modeldir, ds_model+'.h5')
model=get_modelLSTM_01(False)
model.load_weights(model_path)
print('Loaded Model weights '+ str(model_path))
##<----------------------------------------------
timestamp_start = datetime.datetime.now()
print('Manual Prediction on : ' , ds_path)
# Model.evaluate accepts generators directly; evaluate_generator is deprecated.
predx = model.evaluate(genD.read_data_sample(),
                       steps=genD.nos_epochs*genD.iepochlen,
                       verbose=1)

print(predx)
timestamp_dur = datetime.datetime.now() - timestamp_start
print('Elapsed time = ' + str(timestamp_dur))


/content/drive/My Drive/Masters/workdir/ecg_data/db_exp/lstm
-1 / 0
30 992
Loaded Model weights /content/drive/My Drive/Masters/workdir/ecg_data/db_exp/model_X.h5
Manual Prediction on :  /content/drive/My Drive/Masters/workdir/ecg_data/db_exp/lstm
Instructions for updating:
Please use Model.evaluate, which supports generators.
[0.6777151823043823, 0.9697580933570862]
Elapsed time = 0:16:19.189684


## [ PREDICT DATASET ]

In [None]:
# Predict on the 'lstm' dataset with saved weights and build a confusion matrix.
#
# NOTE(review): this cell uses genD.set_up(), genD.iepoch, genD.nos_epochs and
# genD.iepochlen and passes a directory path as the sample list. The DGen class
# visible in the [ DATA GENERATOR ] section takes an in-memory sample list and
# defines none of those members -- confirm which DGen version this cell expects.
ds_path = os.path.join(global_dsdir,'lstm')
print(ds_path)
genD = DGen(ds_path,std_db_m,2)
genD.set_up()
print(genD.isample,'/',genD.iepoch)
print(genD.nos_epochs,genD.iepochlen)
A_Label = 'V'
g_LABELS = ['N',A_Label]
ds_model = 'model_X'       # SELECT MODEL WEIGHTS TO TEST UPON
model_path = os.path.join(global_modeldir, ds_model+'.h5')
model=get_modelLSTM_01(False)
model.load_weights(model_path)
print('Loaded Model weights '+ str(model_path))
##<----------------------------------------------
timestamp_start = datetime.datetime.now()
print('Manual Prediction on : ' , ds_path)
# Model.predict accepts generators directly; predict_generator is deprecated.
predx = model.predict(genD.read_data_sample_predict(),
                      steps=genD.nos_epochs*genD.iepochlen,
                      verbose=1)
#------------------------------------------------------------ manual prediction
# labels recorded by the generator, in the order the samples were yielded
data_y = np.array(genD.true_labels)
cmx_local = np.zeros((len(g_LABELS),len(g_LABELS)),dtype='int32')

# Turn raw model outputs into class indices.
# A multi-column output is decoded with argmax; a single sigmoid column
# (what get_modelLSTM_01 produces) must be thresholded instead, because
# argmax over one column is always 0.
pred_scores = np.asarray(predx)
if pred_scores.ndim > 1 and pred_scores.shape[-1] > 1:
    cmx2_local = pred_scores.argmax(axis=1)
else:
    cmx2_local = (pred_scores.reshape(-1) >= 0.5).astype('int64')

# rows = actual label, columns = predicted label
for i in range(0,len(cmx2_local)):
    alabel = int(np.ravel(data_y[i])[0])  # entries may be shape-(1,) arrays
    plabel = int(cmx2_local[i])
    cmx_local[alabel,plabel]+=1
print('\tConfusion Matrix')
print(print_conf_matrix( cmx_local,'', g_LABELS)) #logit('\t'+str(cmx))
print_performance( get_performance(cmx_local) ,g_LABELS ) 
#------------------------------------------------------------
timestamp_dur = datetime.datetime.now() - timestamp_start
print('Elapsed time = ' + str(timestamp_dur))


/content/drive/My Drive/Masters/workdir/ecg_data/db_exp/lstm
-1 / 0
30 992
Loaded Model weights /content/drive/My Drive/Masters/workdir/ecg_data/db_exp/model_X.h5
Manual Prediction on :  /content/drive/My Drive/Masters/workdir/ecg_data/db_exp/lstm
Instructions for updating:
Please use Model.predict, which supports generators.
	Confusion Matrix
A\P	N	V
N	14871	9
V	13989	891

Performance for 2 classes
Class	ACC	PRE	SEN	SPF
N	0.53	0.52	1.0	0.06
V	0.53	0.99	0.06	1.0
Elapsed time = 0:16:18.877093
