In [0]:
%reset

# [ 0_IMPORTS ]

In [0]:
import datetime
import os
import random
import statistics as stats
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

import scipy.signal as scsig
from scipy.io import loadmat
from scipy.signal import medfilt
import pywt
from pywt import wavedec
#import hrv
#from ecgdetectors import Detectors

import tensorflow as tf
from tensorflow.keras.layers import Dense, Input, LSTM, Conv1D, MaxPooling1D, Flatten
from tensorflow.keras.models import Model, Sequential
from tensorflow import keras
from tensorflow.keras.utils import to_categorical
import tensorflow.keras.backend as kbend

from sklearn.cluster import DBSCAN
from sklearn import metrics
import scipy.spatial.distance as disf

# [ 1_GLOBAL ]

In [0]:

#===========================================================================================================
# DEFINE SOURCE DIRECTORIES
#===========================================================================================================
# All paths below derive from global_dir. The default is a Google Drive mount path (presumably Colab);
# for a local run (e.g. '/home/spooky/ecg/workdir') set the ECG_WORKDIR environment variable instead of
# editing this cell. ECG_ANNOT can likewise point to an annotations file that lives outside global_dir.
# > _ base working directory
global_dir = os.environ.get('ECG_WORKDIR', '/content/drive/My Drive/Masters/workdir/ecg_data')
print('GLOBAL DIR :: '+global_dir)

# >> global annotation and mapping, common for all ECG from https://physionet.org/about/database/
global_annot = os.environ.get('ECG_ANNOT', os.path.join(global_dir, 'annotations.txt'))
print('GLOBAL ANNOTATIONS :: '+global_annot)

# >> global model directory, contains model weights, use load_weights(), save_weights()
global_modeldir = os.path.join(global_dir, 'db_model')
os.makedirs(global_modeldir, exist_ok=True)
print('GLOBAL MODEL DIR :: '+global_modeldir)

# >> global dataset directory, contains manually generated datasets to be used for experiments
global_datadir = os.path.join(global_dir, 'db_dataset')
os.makedirs(global_datadir, exist_ok=True)
print('GLOBAL DATA DIR :: '+global_datadir)


#===========================================================================================================
# custom dataset directory
ds_name = 'custom_ds'
ds_dir = os.path.join(global_datadir, ds_name)
os.makedirs(ds_dir, exist_ok=True)
print('CUSTOM DATA DIR :: '+ds_dir)
#%%

#===========================================================================================================
#  Define annotation mapping dictionary (from global_annot)
#===========================================================================================================
# Every row of the tab-separated file is: <MIT label> <mapped integer label> <description>.
# Labels whose integer label is > -2 are beat annotations; the remaining ones are non-beat annotations.
g_map_data = np.loadtxt(global_annot, dtype='str', delimiter="\t")
g_map = {}                  # MIT label -> integer label
g_map_beat_ants = []        # MIT labels of beat annotations
g_map_non_beat_ants = []    # MIT labels of non-beat annotations
print('ANNOTATION MAPPING :: ')
for a in g_map_data:
    g_mit_label, g_int_label, g_beat_description = a[0], a[1], a[2]
    g_map[g_mit_label] = int(g_int_label)
    print('\t'.join((g_mit_label, g_int_label, g_beat_description)))
    # route the label into the beat or the non-beat list
    (g_map_beat_ants if g_map[g_mit_label] > -2 else g_map_non_beat_ants).append(g_mit_label)

print('')
print('Beat Annotations : [' + str(len(g_map_beat_ants))+'] :: ' + str(g_map_beat_ants))
print('Non-Beat Annotations : [' + str(len(g_map_non_beat_ants))+'] :: ' + str(g_map_non_beat_ants))
print('All Annotations : [' + str(len(g_map.keys()))+'] :: ' + str(g_map.keys()))
print('')
def mapstd(peak_label, mapping=None):
    """Map MIT annotation labels to their integer labels.

    Args:
        peak_label: sequence of MIT label strings (e.g. ``['N', 'V']``).
        mapping: dict from MIT label to integer label. Defaults to the global
            ``g_map`` built from ``global_annot``.

    Returns:
        np.ndarray of dtype int with one integer label per input label
        (shape (0,) for empty input).

    Raises:
        KeyError: if a label is not a key of ``mapping``.
    """
    if mapping is None:
        mapping = g_map
    return np.array([mapping[label] for label in peak_label], dtype='int')

# Integer label range: mapped labels run from 0 up to g_LMAX; -1 marks unmapped RR beats and
# anything below -1 marks non-beat annotations.
g_LMAX = np.array(list(g_map.values())).max()
g_LMIN = -1
print('Integer Label Range [ '+str(g_LMIN)+ ' : '+str(g_LMAX)+' ]')

# Class names and plot colours; entry i is used for beats with integer label i
# (presumably 'X' groups every non-normal class — confirm against annotations.txt).
g_LABELS = ['N', 'X']
g_COLOR = ['tab:green', 'tab:red']
#===========================================================================================================

#------------------------------------------------------------------------------------------------
# signal sampling params
#------------------------------------------------------------------------------------------------
BASIC_SRATE = 128   # Hz; SIG_II signals are indexed at this rate (annotation samples are rescaled to it)
print('Basic sampling rate(Hz): '+str(BASIC_SRATE))

v_dimC = int(round(3 * BASIC_SRATE))   # fixed input dimension of a beat vector: 3 seconds of samples
print('Fixed dimension of beat vector: '+str(v_dimC))

# data_type postfixes of the per-record .npy files: beat / non-beat annotations and signal types
g_BEAT_POSTFIX = 'BEAT'
g_NBEAT_POSTFIX = 'NBEAT'
g_SIG_II_POSTFIX = 'SIG_II'
g_FIX_II_POSTFIX = 'FIX_II'

#%%

#------------------------------------------------------------------------------------------------
# Heart-rate params
#------------------------------------------------------------------------------------------------
# An RR interval (sec) corresponds to a heart rate of 60/RRI (bpm), so the lowest heart rate gives the
# longest RR interval. The *_rri names follow the interval, not the rate (hig_rri comes from H_low).
H_min, H_low, H_hig, H_max = 20, 60, 100, 240   # bpm

max_rri = 60/H_min  # sec (3.0)
hig_rri = 60/H_low  # sec (1.0)
low_rri = 60/H_hig  # sec (0.6)
min_rri = 60/H_max  # sec (0.25)

#%%
#=========================================================================================================================
#======================= NEURAL NETWORK PERFORMANCE MEASURES
#=========================================================================================================================
# 3.3 :: define performance evaluation functions

def get_performance(conf_matrix):
    """Per-class (one-vs-rest) performance measures from a confusion matrix.

    Args:
        conf_matrix: square array-like of counts (numpy array or nested lists).
            Row i holds the samples whose actual class is i; column j holds
            those predicted as class j.

    Returns:
        float64 array of shape (n_classes, 8); row i is
        [TP, FN, FP, TN, ACC, PRE, SEN, SPF] for class i. A ratio whose
        denominator is zero (e.g. PRE of a class that is never predicted)
        is NaN; no RuntimeWarning is emitted for it.
    """
    cm = np.asarray(conf_matrix)
    nos_class = cm.shape[1]   # length of row 0, i.e. number of classes
    total = cm.sum()
    rows = []
    # 0/0 is expected for classes that are absent or never predicted
    with np.errstate(divide='ignore', invalid='ignore'):
        for i in range(nos_class):
            TP = cm[i, i]
            FP = cm[:, i].sum() - TP    # predicted as i, actually another class
            FN = cm[i, :].sum() - TP    # actually i, predicted as another class
            TN = total - FN - FP - TP

            ACC = (TP + TN) / (TP + FP + FN + TN)
            PRE = TP / (TP + FP)
            SEN = TP / (TP + FN)
            SPF = TN / (TN + FP)
            rows.append([TP, FN, FP, TN, ACC, PRE, SEN, SPF])
    return np.array(rows, dtype='float64').reshape(-1, 8)


#------------------------------------------------------------------PRINTING

def print_lstr(class_labels):
    """Header row for a printed confusion matrix: every label prefixed by a tab."""
    return ''.join('\t' + str(class_labels[k]) for k in range(len(class_labels)))

def print_cf_row(cf_row,nos_labels):
    """Tab-prefixed text of the first ``nos_labels`` entries of one confusion-matrix row."""
    return ''.join('\t' + str(cf_row[col]) for col in range(nos_labels))
def print_conf_matrix(conf_matrix, suffix, class_labels):
    """Render a confusion matrix as tab-separated text, one newline-terminated line per row.

    The first line is the 'A\\P' corner (presumably Actual\\Predicted) followed by the class labels;
    each following line is a class label followed by that row's counts. ``suffix`` is written at the
    start of every line (it acts as a line prefix).
    """
    n_labels = len(class_labels)
    lines = [suffix + 'A\\P' + print_lstr(class_labels)]
    lines.extend(suffix + str(class_labels[r]) + print_cf_row(conf_matrix[r], n_labels)
                 for r in range(n_labels))
    return ''.join(line + '\n' for line in lines)
def print_performance(perf_measures, class_labels):
    """Print ACC, PRE, SEN and SPF (columns 4..7, rounded to 2 decimals) for every class row.

    ``perf_measures`` is the (n_classes, 8) array returned by get_performance; returns None.
    """
    n_rows = perf_measures.shape[0]
    print('Performance for '+str(n_rows)+' classes')
    print ('Class\tACC\tPRE\tSEN\tSPF')
    for idx, row in enumerate(perf_measures):
        rounded = np.round(row, 2)
        print('\t'.join([str(class_labels[idx])] + [str(rounded[k]) for k in (4, 5, 6, 7)]))
    return
#------------------------------------------------------------------


# [ 2_CLASS_DEFS ]

In [0]:
#Class Definitions
#---------------------------------------------------------------------------------------------------------------------------------------------
# CLASS ecg_db : represents one ECG database
#---------------------------------------------------------------------------------------------------------------------------------------------
class ecg_db:
    """One ECG database whose files live in ``global_datadir/<dbname>_ds``.

    That directory must contain a ``RECORDS`` file listing one record name per
    line. ``ecg_record`` objects are created on demand and cached in ``recs_dict``.
    """

    def __init__(self, dbname,  exclude_recs, tag_recs, sampling_rate):
        """
        Args:
            dbname: database name; also the prefix of its data directory.
            exclude_recs: iterable of record names left out of ``self.recs``.
            tag_recs: iterable of record names stored in ``self.recs_tag``.
            sampling_rate: rate (Hz) in which the annotation sample numbers are expressed.
        """
        print('\nInitailze new ecg database ... ')
        self.name = dbname  #str
        self.srate = sampling_rate #float or int
        self.dir_ds = os.path.join(global_datadir , dbname+'_ds') #str
        self.recs_all = self._read_record_names(os.path.join(self.dir_ds, 'RECORDS'))  # set of str
        self.recs_exc = set(exclude_recs)
        self.recs = self.recs_all - self.recs_exc
        self.recs_tag = set(tag_recs)

        self.recs_dict = {} # initially empty, filled on demand by get_record
        self.info()

    @staticmethod
    def _read_record_names(path):
        """Return the set of record names listed one per line in ``path`` (blank lines skipped).

        Replaces ``np.loadtxt(path, dtype='str', delimiter="\\n")``: NumPy's rewritten
        loadtxt (1.23+) rejects a newline delimiter, and loadtxt returns a 0-d,
        non-iterable array when the file lists a single record. Text after '#' is
        ignored, as loadtxt treated it as a comment.
        """
        with open(path) as fh:
            names = (line.split('#', 1)[0].strip() for line in fh)
            return {name for name in names if name}

    def info(self):
        """Print a summary of this database; returns 0."""
        print( 'DB NAME :: '+ self.name)
        print( 'SAMPLING RATE :: '+ str(self.srate))
        print( 'DATA DIR :: ' + self.dir_ds )
        print( 'RECORD SET :: [' +str(len(self.recs))+'] ' + str(self.recs) )
        return 0

    def get_record(self,rec):
        """Return the ecg_record for ``rec``, creating and caching it on first use."""
        if rec not in self.recs_dict:
            self.recs_dict[rec] = ecg_record(self,rec)
        return self.recs_dict[rec]

    def get_random_record(self, recset):
        """Return the (cached) ecg_record of a record chosen uniformly from ``recset``.

        The names are sorted first, so the choice is reproducible under ``random.seed``
        (set iteration order of strings varies between interpreter runs).
        """
        return self.get_record(random.choice(sorted(recset)))

#---------------------------------------------------------------------------------------------------------------------------------------------

#---------------------------------------------------------------------------------------------------------------------------------------------
# CLASS ecg_record : represents one ECG Record in any database
#---------------------------------------------------------------------------------------------------------------------------------------------
# When True, ecg_record.load_data does not print a warning for data files it cannot load.
g_SUPRESS_DATA_WARNING=False
class ecg_record:
    """One record of an ecg_db.

    Each kind of per-record data (``data_type``, e.g. g_SIG_II_POSTFIX or
    g_BEAT_POSTFIX) is stored as ``<db.dir_ds>/<recname>_<data_type>.npy`` and
    cached in ``self.data`` once read.
    """

    def __init__(self, db, recname):
        self.db = db                                # ecg_db this record belongs to
        self.rec = recname                          # record name (str)
        self.name = db.name + '_'+ recname          # full name including db.name
        if recname in db.recs_all:
            if recname in db.recs_exc:
                print('WARNING:: Record "'+ recname +'" is marked excluded from database '+ db.name )
        else:
            print('WARNING:: Record "'+ recname +'" not found in database '+ db.name )
        self.data = {}                              # cache: data_type -> loaded array
        self.binfo = None                           # ecg_binfo, created by read_binfo()

    def _data_path(self, data_type):
        """Path of the .npy file that holds ``data_type`` for this record."""
        return os.path.join(self.db.dir_ds, self.rec + '_'+data_type+'.npy')

    def read_binfo(self):
        """Return the ecg_binfo of this record, building it on first use."""
        if self.binfo is None:
            self.binfo = ecg_binfo(self)
        return self.binfo

    def load_data(self, data_type):
        """Load ``data_type`` from disk (ignoring the cache) and cache the result.

        Returns an empty array when the file is missing or cannot be read as .npy;
        a warning is printed unless g_SUPRESS_DATA_WARNING is set. Other exceptions
        (e.g. KeyboardInterrupt, MemoryError) propagate.
        """
        ipath = self._data_path(data_type)
        try:
            self.data[data_type] = np.load(ipath)
        except (OSError, ValueError, EOFError):
            if not g_SUPRESS_DATA_WARNING:
                print('WARNING:: Cant load "'+data_type+ '" file at '+ str(ipath) )
            return np.array([])
        return self.data[data_type]

    def read_data(self, data_type):
        """Return ``data_type`` from the cache, loading it from disk on first access."""
        if data_type in self.data:
            return self.data[data_type]
        return self.load_data(data_type)

    def save_data(self, data_type, data_array):
        """Save ``data_array`` as ``data_type`` and return the file path.

        Any cached copy is dropped so that the next read_data returns what was saved.
        """
        ipath = self._data_path(data_type)
        np.save(ipath, data_array)
        self.data.pop(data_type, None)
        return ipath

    def del_data(self, data_type, vb=False):
        """Delete the ``data_type`` file (and its cached copy).

        Returns 1 if a file was removed, 0 if none existed; prints the path when ``vb``.
        """
        ipath = self._data_path(data_type)
        self.data.pop(data_type, None)
        if not os.path.exists(ipath):
            return 0
        if vb:
            print('Removing: '+str(ipath))
        os.remove(ipath)
        return 1

#---------------------------------------------------------------------------------------------------------------------------------------------

#------------------------------------------------------------------------------------------------

class ecg_binfo:
    """Beat information of one ecg_record, derived from its BEAT data.

    The BEAT data is read as a 2-column array: column 0 holds R-peak sample
    numbers (at ``rec.db.srate``) and column 1 the MIT beat labels. The first and
    last annotated beats are dropped so that every remaining beat has a previous
    and a next beat; all per-beat arrays have ``nos_rr_peaks`` entries and are
    aligned by index.
    """
    def __init__(self, rec):
        self.rec = rec          # the ecg_record this info belongs to
        rr_peaks_ants = rec.read_data(g_BEAT_POSTFIX)       # original annotations, rows like ['62531' 'N']
        # slice columns (all annotated beats)
        rr_peaks_int = rr_peaks_ants[:,0].astype('int')     # col0 : R-peak sample numbers
        rr_ants_str = rr_peaks_ants[:,1]                    # col1 : MIT labels
        # beat i below is annotated beat i+1; its neighbours are annotated beats i and i+2
        self.rr_peaks = rr_peaks_int[1:-1]                  # R peak (samples)
        self.rr_prev = rr_peaks_int[0:-2]                   # previous R peak (samples)
        self.rr_next = rr_peaks_int[2:]                     # next R peak (samples)
        self.nos_rr_peaks = len(self.rr_peaks)              # number of beats (first and last excluded)

        self.rr_labels = rr_ants_str[1:-1]                  # MIT label of the beat
        self.rr_plabels = rr_ants_str[0:-2]                 # MIT label of the previous beat
        self.rr_nlabels = rr_ants_str[2:]                   # MIT label of the next beat

        # integer labels via the global annotation map
        self.rr_int_labels = mapstd(self.rr_labels)
        self.rr_int_plabels = mapstd(self.rr_plabels)
        self.rr_int_nlabels = mapstd(self.rr_nlabels)

        # temporal info (seconds)
        self.rr_peaks_sec = self.rr_peaks / rec.db.srate                # time of the R peak
        self.rri_prev = (self.rr_peaks - self.rr_prev) / rec.db.srate   # previous RR interval
        self.rri_next = (self.rr_next - self.rr_peaks) / rec.db.srate   # next RR interval
        self.rri_delta = (self.rri_next - self.rri_prev)                # next minus previous RRI
        self.rri_dur = (self.rri_next + self.rri_prev)                  # previous R peak to next R peak
        # factor converting annotation sample numbers into indices of the BASIC_SRATE signal
        self.sr_ratio = BASIC_SRATE/self.rec.db.srate

    def get_signal_data_var(self, ith_peak):
        """Signal (data_type g_SIG_II_POSTFIX) from the previous to the next R peak of beat ``ith_peak``.

        Returns:
            (segment, peak_pos): the segment (both end samples included) and the
            index of this beat's R peak within it.
        """
        sel_sig = self.rec.read_data(g_SIG_II_POSTFIX)
        ff = int(self.rr_prev[ith_peak]*self.sr_ratio)
        tt = int(self.rr_next[ith_peak]*self.sr_ratio)
        pp = int(self.rr_peaks[ith_peak]*self.sr_ratio)
        return sel_sig[ff:tt+1], (pp-ff)

    def get_signal_data_fix(self, ith_peak, v_left_sec, v_right_sec):
        """Signal from ``v_left_sec`` before to ``v_right_sec`` after the R peak of beat ``ith_peak``.

        Returns (segment, peak_pos) like get_signal_data_var. Near the start or end
        of the record the window is truncated at the signal boundary, so the segment
        can be shorter than the usual (v_left_sec + v_right_sec) * BASIC_SRATE samples.
        """
        sel_sig = self.rec.read_data(g_SIG_II_POSTFIX)
        center = self.rr_peaks[ith_peak]*self.sr_ratio
        ff = int(center-(v_left_sec*BASIC_SRATE))
        tt = int(center+(v_right_sec*BASIC_SRATE))
        pp = int(center)
        # a negative start would wrap around to the end of the array and yield an empty/wrong slice
        ff = max(ff, 0)
        return sel_sig[ff:tt+1], (pp-ff)
    
#-----------------------------------------------------------------------------------------------


# [ 3_BUILD_ALL_DB ]

In [0]:
# Records excluded per database (NOTE(review): '102','104','107','217' are presumably the paced
# MIT-BIH records; the reason for '212','231','207' is not stated here)
mitdb_ex = ['102', '104', '107', '217',
            '212', '231',  '207']
svdb_ex = []
incartdb_ex = []

# name, excluded records, annotation sampling rate (Hz) — databases are created in this order
all_db = {}
for db_name, db_ex, db_srate in (('mitdb', mitdb_ex, 360),
                                 ('svdb', svdb_ex, 128),
                                 ('incartdb', incartdb_ex, 257)):
    all_db[db_name] = ecg_db(db_name, db_ex, [], db_srate)

print('')
print(all_db.keys())


## [ 3.0_REPORT_DBS ]

In [0]:
# Tab-separated table: per record, the count of every beat annotation label and the total
heading = 'DB_RECORD' + ''.join('\t'+str(lbl) for lbl in g_map_beat_ants) + '\tTOTAL'
print(heading)

for idb, sel_db in all_db.items():
    for irec in sel_db.recs:
        sel_rec = sel_db.get_record(irec)
        sbi = sel_rec.read_binfo()

        counts = [int(np.count_nonzero(sbi.rr_labels == lbl)) for lbl in g_map_beat_ants]
        count_all = sum(counts)
        rst = sel_rec.name + '\t' + ''.join(str(c)+'\t' for c in counts) + str(count_all)
        print(rst)



## [ 3.1_REP_NORMALS ]

In [0]:
#------------------------------------------------------------------------
# Parameters of the representative-normal extraction
#------------------------------------------------------------------------
lim_first_F_sec = 5*60     # only beats in the first 5 minutes (30*60 for 30 minutes) #<<====select this
work_db = all_db           #<<------------ select working db
lim_delta_rri = 0.04       # max |rri_next - rri_prev| (sec) for a beat to be used
lim_min_Nbeats = 3         # min number of consecutive qualifying beats per episode
g_REP_II_POSTFIX = 'REP_'+str(lim_first_F_sec)   # data_type under which each record's set is saved
#------------------------------------------------------------------------

rep_banner = '================================================='
for rep_line in (rep_banner,
                 'Representative Normal',
                 'lim_delta_rri:'+str(lim_delta_rri),
                 'lim_min_Nbeats_per_episode:'+str(lim_min_Nbeats),
                 'lim_first_F_sec:'+str(lim_first_F_sec),
                 'g_REP_II_POSTFIX:'+str(g_REP_II_POSTFIX),
                 rep_banner):
    print(rep_line)

timestamp_start = datetime.datetime.now()
# columns of the per-record report printed by the extraction loop
print('REC\ttotal_beats_in_n_query\ttotal_n_episodes\tvalid_n_episodes\tResult')

# Extraction loop. For every record of work_db it does three things:
#  1. finds "episodes" (runs of consecutive beat indices) of stable normal beats within the first
#     lim_first_F_sec seconds;
#  2. builds a mean and a median template beat per episode;
#  3. saves all episodes of the record as data_type g_REP_II_POSTFIX.
# One tab-separated report line is printed per record.
for idb in work_db.keys():
    sel_db = work_db[idb]
    # debug filter: uncomment to process a single database
    #if idb != 'mitdb':
    #    continue
    for irec in sel_db.recs:      
        # debug filter: uncomment to process a single record
        # if irec != '208':
        #      continue
        sel_rec = sel_db.get_record(irec)
        rst = sel_rec.name   # report line, extended with tab-separated fields below

#--------# load signal (read_data returns an empty array when the file cannot be loaded)------------
        sel_sig = sel_rec.read_data(g_SIG_II_POSTFIX)
        if len(sel_sig)<1:
            rst+='\tSignal doesnt exist, Skip this record'
            print(rst)
            continue
        
#--------# load beat info----------------------------------------
        sbi = sel_rec.read_binfo()

#-------# Normal Episodes----------------------------------------
        # Qualifying beats meet all of the following:
        #  - they lie within the first lim_first_F_sec seconds;
        #  - |rri_next - rri_prev| <= lim_delta_rri;
        #  - the beat and its previous and next beats all have integer label 0.
        ne_list = []
        ne_query = (
                    (sbi.rr_peaks_sec <= lim_first_F_sec) &
                    (np.absolute(sbi.rri_delta)<=lim_delta_rri) & 
                    (sbi.rr_int_labels==0)  &
                    (sbi.rr_int_plabels==0)  &
                    (sbi.rr_int_nlabels==0)
                  )
        
        ne_list = np.where(ne_query)[0]   # indices of qualifying beats (ascending)
        
        if len(ne_list)<lim_min_Nbeats:
            rst+='\tNot enough N beats within rri limits, Skip this record'
            print(rst )
            continue
        else:
            rst+='\t'+str(len(ne_list))   
            
        # Split ne_list into runs of consecutive indices; each episode is the half-open range
        # [first_index, first_index + run_length). The trailing -1 is a sentinel that never continues
        # a run, so the last run is always closed.
        ne_list1 = np.hstack((np.sort(ne_list),np.array([-1])))
        n_epi = []
        # extract episodes from ne_list
        n_s = ne_list[0]   # first index of the current run
        delta = 1          # length of the current run
        for i in range(1, len(ne_list1)):
            n_e = ne_list1[i]
            if n_e == n_s + delta:
                delta+=1
            else:
                # run ended; runs shorter than lim_min_Nbeats are skipped further below
                # if delta >=lim_min_Nbeats:
                n_epi.append([n_s,n_s+delta])
                n_s = n_e
                delta = 1

        # NOTE(review): with a non-empty ne_list the sentinel always closes at least one episode,
        # so this branch is not expected to trigger
        if len(n_epi)==0:
            rst+='\tNot enough episodes, Skip this record'
            print(rst)
            continue    
        else:
            rst+='\t'+ str(len(n_epi))

        # One row per accepted episode:
        # [nos_beats, signal_len, avg_dur, var_dur, mean template (v_dimC), median template (v_dimC)]
        
        mega_epi = np.zeros((0,2*v_dimC + 4),dtype='float')
        # 2*v_dimC for the resampled mean and median templates,
        # +4 for nos_beats, signal_len, avg_dur, var_dur
        
        # per-episode stats [episode_index, nos_beats, avg_dur, var_dur]; NOTE(review): not saved or used below
        epi_stats = np.zeros((0,4),dtype='float')
        #epi_short = 0
        for j in range(0, len(n_epi)):
            sepi = n_epi[j]
            bepi = sepi[1]-sepi[0]   # number of beats in this episode
            if bepi<lim_min_Nbeats:
                # too few beats in this episode, skip it
                # epi_short+=1
                continue
            else:
                # beat durations (previous R to next R, sec): mean and sample variance (statistics module)
                sdur =  sbi.rri_dur[sepi[0]:sepi[1]]
                sdur_avg = round(stats.mean(sdur),3)
                sdur_var = round(stats.variance(sdur),5)
                epi_stats = np.vstack((epi_stats,np.array([j,bepi,sdur_avg,sdur_var])))
                # print('Avg_duration = '+ str(sdur_avg)+ ' | bpm = '+ str(60/sdur_avg))
                # print('VAr_duration = '+ str(sdur_var))
                # print('#'+str(j)+'\t'+str(bepi)+'\t'+str(sdur_avg)+'\t'+str(sdur_var))
                
                # For each beat, cut the BASIC_SRATE signal from the previous to the next R peak and split it
                # at this beat's R peak into a left part (before the peak) and a right part (peak onward).
                # all_epi_signals / all_epi_pk are collected but not used below.
                all_epi_signals = []
                all_epi_pk = []
                sg_left=[]
                sg_right=[]
                sll,slr = [],[]
                
                for i in range(sepi[0],sepi[1]):
                    sg,pk = sbi.get_signal_data_var(i)
                    all_epi_signals.append(sg)
                    all_epi_pk.append(pk)
                    sg_left.append(sg[0:pk])
                    sg_right.append(sg[pk:])
                    sll.append(pk)
                    slr.append(len(sg)-pk)
                
                # Resample every left part to the longest left length and every right part to the longest
                # right length, so each row of sg_all2 has its R peak at column ll_max.
                ll_max = max(sll)
                lr_max = max(slr)
                l_max = ll_max+lr_max                
                
                sg_all2 =  np.zeros((0,l_max))
                for i in range(0, bepi):    
                    sg_all2 = np.vstack((
                                    sg_all2,
                                    np.hstack((
                                        scsig.resample(sg_left[i],ll_max), 
                                        scsig.resample(sg_right[i],lr_max)
                                        ))
                                    ))
                
                # column-wise mean and median over the aligned beats of the episode
                x3_men = np.zeros(l_max)
                x3_med = np.zeros(l_max)
                #x3_var = np.zeros(l_max)
                for i in range(0, l_max):
                    x3_men[i] = stats.mean(sg_all2[:,i])
                    x3_med[i] = stats.median(sg_all2[:,i])
                    #x3_var[i] = stats.variance(sg_all2[:,i])     
                
                # row: [nos_beats, signal_len, avg_dur, var_dur, mean, median];
                # signal_len / BASIC_SRATE = template duration in seconds
                rec_epi = np.array([ bepi,l_max,sdur_avg,sdur_var ])
                rec_epi = np.hstack((rec_epi,
                                     scsig.resample(x3_men,v_dimC),
                                     scsig.resample(x3_med,v_dimC) ))
                mega_epi = np.vstack((mega_epi, rec_epi ))
        rst+= '\t'+ str(len(mega_epi))
        # save only when at least one episode was accepted
        if len(mega_epi) > 0:      
            sel_rec.save_data(g_REP_II_POSTFIX, mega_epi)
            rst+= '\t Success'
        else:
            rst+= '\t Failed'
        print(rst)


print('\nDone')
timestamp_dur = datetime.datetime.now() - timestamp_start
print('Elapsed time = ' + str(timestamp_dur))
#------------------------------------------------------------------------
#------------------------------------------------------------------------


# [ 4_PLOT_SIGNALS ]


DB NAME :: mitdb

RECORD SET :: [48] 
{'114', '116', '111', '103', '230', '123', '232',
 '200', '234', '203', '209', '215', '231', '207', 
 '201', '202', '214', '101', '113', '210', '105', 
 '117', '119', '212', '122', '213', '205', '208', 
 '220', '233', '118', '124', '100', '228', '219', 
 '102', '108', '221', '107', '223', '106', '121', 
 '109', '217', '115', '222', '104', '112'}


DB NAME :: svdb

RECORD SET :: [78] 
{'802', '859', '845', '872', '879', '842', '843', 
 '868', '863', '891', '874', '848', '883', '854', 
 '829', '841', '844', '805', '889', '855', '850', 
 '856', '809', '825', '857', '812', '877', '811', 
 '840', '803', '810', '858', '887', '869', '881', 
 '822', '878', '867', '800', '804', '893', '876', 
 '849', '871', '875', '808', '862', '847', '846', 
 '827', '873', '821', '884', '828', '866', '824', 
 '865', '864', '820', '886', '888', '861', '826', 
 '801', '885', '860', '823', '851', '806', '892', 
 '852', '880', '807', '894', '853', '890', '882', '870'}


DB NAME :: incartdb

RECORD SET :: [75] 
{'I56', 'I07', 'I29', 'I33', 'I37', 'I42', 'I09', 
 'I12', 'I71', 'I04', 'I65', 'I74', 'I60', 'I61', 
 'I02', 'I57', 'I08', 'I05', 'I54', 'I67', 'I43', 
 'I19', 'I31', 'I27', 'I38', 'I39', 'I36', 'I17', 
 'I20', 'I52', 'I13', 'I45', 'I11', 'I73', 'I28', 
 'I01', 'I63', 'I49', 'I18', 'I46', 'I15', 'I26', 
 'I34', 'I50', 'I66', 'I72', 'I41', 'I53', 'I25', 
 'I62', 'I22', 'I35', 'I64', 'I55', 'I24', 'I03', 
 'I32', 'I48', 'I14', 'I68', 'I59', 'I70', 'I75', 
 'I47', 'I10', 'I06', 'I30', 'I40', 'I69', 'I51', 
 'I44', 'I16', 'I21', 'I58', 'I23'}

 dict_keys(['mitdb', 'svdb', 'incartdb'])

## [ 4.1_Select_Record ]

In [0]:

# Record inspected by the plotting cells below
idb, irec = 'mitdb', '214'

sel_db = all_db[idb]
sel_rec = sel_db.get_record(irec)
print(sel_rec.name)

# SIG_II signal (indexed at BASIC_SRATE) and the per-beat info of the record
sel_sig = sel_rec.read_data(g_SIG_II_POSTFIX)
print('Total signal length @ ' + str(BASIC_SRATE) + 'Hz = ' + str(sel_sig.shape))

sel_binfo = sel_rec.read_binfo()
print('Total beats = '+ str(sel_binfo.nos_rr_peaks))


### [ 4.1.1_Query_Meta_Data ]

In [0]:
ith_rpeak = 1417                                 # beat index (first/last annotated beats excluded)
tstamp = sel_binfo.rr_peaks_sec[ith_rpeak]       # R-peak sample number / db.srate, i.e. seconds
print(tstamp)

## [ 4.2_Prepare_Variables ]

In [0]:
#-----------------------------------------
# Beat indices per integer label: keys 0..len(g_LABELS)-1 for the classes, key -1 for unmapped beats
print('Type of beats')
sel_btypes = {}
xsum = 0
for i, lbl_name in enumerate(g_LABELS):
    class_idx = np.where(sel_binfo.rr_int_labels == i)[0]
    sel_btypes[i] = class_idx
    print(str(lbl_name)+'\t'+str(len(class_idx)))
    xsum += len(class_idx)
print('sum\t'+str(xsum))

sel_btypes[-1] = np.where(sel_binfo.rr_int_labels==-1)[0]
print('x\t'+str(len(sel_btypes[-1])))

# short aliases for the per-beat arrays used by the plots below
sel_rri_dur, sel_rri_delta = sel_binfo.rri_dur, sel_binfo.rri_delta
sel_rri_p, sel_rri_n = sel_binfo.rri_prev, sel_binfo.rri_next
sel_labels = sel_binfo.rr_int_labels
sel_xrange = np.arange(0,sel_binfo.nos_rr_peaks)


## [ 4.3_Plot_RRI_&_Duration ]

In [0]:
# Per-beat duration (previous R to next R, black) and |rri_delta| (red), beat classes marked at
# y=-0.25 (unmapped beats in black), and the lim_delta_rri threshold as a green line
fig = plt.figure(0, figsize=(10,5))
ax = fig.gca()
ax.set_ylim(-0.5, 3.5)

abs_rri_delta = np.absolute(sel_rri_delta)
ax.plot(sel_rri_dur, linewidth=0.5, color='black')
ax.scatter(sel_xrange, sel_rri_dur, label='duration', marker='.', color='black')
ax.plot(abs_rri_delta, linewidth=0.5, color='red')
ax.scatter(sel_xrange, abs_rri_delta, label='delta', marker='.', color='red')

ax.hlines(0, 0, sel_binfo.nos_rr_peaks, linewidth=0.3)
beat_marks = [(-1, 'black', 'unknown')] + [(k, g_COLOR[k], g_LABELS[k]) for k in range(len(g_LABELS))]
for key, colour, name in beat_marks:
    beats = sel_btypes[key]
    ax.scatter(beats, np.zeros(len(beats))-0.25, marker='.', color=colour, label=name+' '+str(len(beats)))
ax.legend(bbox_to_anchor=(0,1.02,1,0.2), loc="lower left",
          mode="expand", borderaxespad=0, ncol=3)
ax.hlines(lim_delta_rri, 0, sel_binfo.nos_rr_peaks, linewidth=0.3, color='green')


## [ 4.4_Plot_RRIs_only ]

In [0]:
# Previous RR interval of every beat, beat classes marked at y=2, and the RRI limits derived from
# the heart-rate params (max/min in red shades, hig/low in green shades)
fig = plt.figure(1, figsize=(10,5))
ax = fig.gca()
ax.set_ylim(0, 3)
ax.set_title('RRIs(prev)')
ax.plot(sel_rri_p, linewidth=0.5, color='black')
ax.scatter(sel_xrange, sel_rri_p, label='RRi', marker='.', color='black')

beat_marks = [(-1, 'black', 'unknown')] + [(k, g_COLOR[k], g_LABELS[k]) for k in range(len(g_LABELS))]
for key, colour, name in beat_marks:
    beats = sel_btypes[key]
    ax.scatter(beats, np.zeros(len(beats))+2, marker='.', color=colour, label=name+' '+str(len(beats)))

for rri_limit, colour in ((max_rri, 'red'), (min_rri, 'tab:red'), (hig_rri, 'green'), (low_rri, 'tab:green')):
    ax.hlines(rri_limit, 0, sel_binfo.nos_rr_peaks, color=colour, linewidth=0.5)
ax.legend(bbox_to_anchor=(0,1.02,1,0.2), loc="lower left",
          mode="expand", borderaxespad=0, ncol=3)


## [ 4.5_Plot_ECG_Signal ]

In [0]:
# Plot a signal segment on an ECG-paper-like scale, with R peaks as x ticks labelled by beat annotation

#<<---------------------------------------------Select Paper Resolution
x_scale = 25 * 0.0393701   # 25 mm/sec -> inches/sec (1 mm = 0.0393701 inch)
y_scale = 10 * 0.0393701   # 10 mm/mV -> inches/mV
y_low, y_high = -2.5, 3.5
#<<--------------------------------------------------------------------

#<<---------------------------------------------Select ECG Segment (seconds)
fsec = 1430
tsec = fsec+(15)
dsec = tsec - fsec
#<<--------------------------------------------------------------------

# segment bounds as indices into the BASIC_SRATE signal
ff, tt = fsec * BASIC_SRATE, tsec * BASIC_SRATE
dd = tt - ff
bps = sel_sig[ff:tt]

# R peaks inside the segment: sample numbers are at db.srate, so rescale them to BASIC_SRATE and
# make them relative to the segment start
seg_mask = (sel_binfo.rr_peaks_sec >= fsec) & (sel_binfo.rr_peaks_sec < tsec)
dlabels = sel_binfo.rr_labels[seg_mask]
dticks = (sel_binfo.rr_peaks[seg_mask] / sel_rec.db.srate)*BASIC_SRATE - ff

plt.figure(2, figsize = (dsec*x_scale ,(y_high-y_low) * y_scale) )
plt.xlim(0, len(bps))
plt.ylim(y_low,y_high)
plt.xticks(dticks,dlabels)
plt.grid(axis='x')

plt.plot(bps, linewidth=0.5, color='black')
plt.hlines(0,0,len(bps), linewidth=0.3)


## [ 4.5b_Plot_Single_Beats ]

In [0]:
#<<---------------------------------------------Select Beats Index to be plotted
# Default: a random beat of class 1 (g_LABELS[1]); if the record has none, any random beat.
# To inspect a specific beat, set e.g. bi = 1730 here instead.
if len(sel_btypes[1]) > 0:
    bi = int(np.random.choice(sel_btypes[1]))
else:
    bi = random.randint(0, sel_binfo.nos_rr_peaks-1)
#<<-------------------------------------------------------------------------------
print('#'+ str(bi) + ' of ' + str(sel_binfo.nos_rr_peaks))

# beat meta data
slabel = sel_binfo.rr_labels[bi]
ilabel = sel_binfo.rr_int_labels[bi]
tstamp = sel_binfo.rr_peaks_sec[bi]
tdur = sel_binfo.rri_dur[bi]
tprev = sel_binfo.rri_prev[bi]
tnext = sel_binfo.rri_next[bi]

print('Label: '+ slabel + '['+ str(ilabel) +']')
print('Timestamp: '+ str(tstamp))
print('Duration: '+ str(tdur))
print('RRIs: '+ str(tprev)+ ','+ str(tnext))

#---------------------------] variable length: previous R peak to next R peak
sg,pk = sel_binfo.get_signal_data_var(bi)
plt.figure(1)
plt.ylim(-2,3.5)
print(sg.shape)
print(pk)
plt.plot(sg)
plt.vlines(pk,-2,3.5, linewidth=0.5)

#------------------------------] fixed length: 1.5 s on either side of the R peak
sg,pk = sel_binfo.get_signal_data_fix(bi,1.5,1.5)
plt.figure(2)
print(sg.shape)
print(pk)
plt.plot(sg)
plt.vlines(pk,0,2)

#------------------------------] Plotted Resampled signal and peak location
# X_res2 = scsig.resample(sg,v_dimC)
# plt.figure(3)
# print(X_res2.shape)
# plt.plot(X_res2)
# pk1 = pk*( len(X_res2)/(len(sg) ))
# plt.vlines(pk1,0,2)


## [ 4.6_Plot_Rep_Normal_Set ]

In [0]:
# Plot the representative-normal set saved by section 3.1 for the selected record

#<<---------------------------------------------Select rep_norm set to be plotted
# must match a set saved by section 3.1 ('REP_' + lim_first_F_sec)
g_REP_II_POSTFIX = 'REP_1800'
#g_REP_II_POSTFIX = 'REP_300'

# load_data returns an empty 1-D array when the file cannot be loaded
jx = sel_rec.load_data(g_REP_II_POSTFIX)
print (sel_rec.name+'\t'+str(jx.shape))
print ('Total reps = '+ str(len(jx)))

# Row layout: [nos_beats, signal_len, avg_dur, var_dur, mean template (v_dimC), median template (v_dimC)]
if jx.ndim == 2 and len(jx) > 0:
    # episode size vs. beat-duration variance, each point annotated with its episode index
    plt.figure(0)
    plt.xlabel('nos beats')
    plt.ylabel('duration variance')
    plt.ylim(-0.001, 0.01)

    plt.scatter(jx[:,0], jx[:,3],marker='.', color='black')
    for i in range(0,len(jx)):
        plt.annotate(str(i),xy=(jx[i,0], jx[i,3]))

    print('#epi\t#beats\tsig_len\tavg_dur\tvar_dur\tmax_dur')
    plt.figure(1)
    plt.ylim(-2,3.5)
    plt.xlim(-10,v_dimC+10)
    for cepi in range(0,len(jx)):
        iepi = jx[cepi]
        n0 = round(iepi[0])
        n1 = round(iepi[1])
        n2 = round(iepi[2],3)
        n3 = round(iepi[3],3)
        n4 = round(n1/BASIC_SRATE,3)   # signal_len is in BASIC_SRATE samples -> seconds
        print(str(cepi)+'\t'+str(n0)+'\t'+str(n1)+'\t'+str(n2)+'\t'+str(n3)+'\t'+str(n4))
        n_med = iepi[-v_dimC:]
        n_men = iepi[4:4+v_dimC]
        #<<---------------------------------------------Select Rep_N Median or Mean Beat to be plotted
        plt.plot(n_med, linewidth=0.3, color='tab:green')
        #plt.plot(n_men, linewidth=0.3, color='tab:blue')
else:
    print('No representative normal set "'+g_REP_II_POSTFIX+'" for '+sel_rec.name)


### [ 4.6.1_Plot_Single_Rep_Normal ]

In [0]:
#<<---------------------------------------------Select index from Rep_Normal Set 
sel_episode = 12

print (sel_rec.name+'\tepisode#'+str(sel_episode))

# One row of the rep-normal set:
# [nos_beats, signal_len, avg_dur, var_dur, mean beat (v_dimC values), ..., median beat (last v_dimC values)]
iepi = jx[sel_episode]
jx_men = iepi[4:4+v_dimC]   # mean beat
jx_med = iepi[-v_dimC:]     # median beat

n0, n1 = round(iepi[0]), round(iepi[1])
n2, n3 = round(iepi[2],3), round(iepi[3],3)
n4 = round(n1/(128),3)      # signal_len / 128 (NOTE(review): assumed sample rate — confirm)
print('\n'.join([
    '#beats = '+str(n0),
    'sig_len = '+str(n1),
    'avg_dur = '+str(n2),
    'var_dur = '+str(n3),
    'max_dur = '+str(n4),
]))

# each episode gets its own figure number (offset by 2 to avoid figures 0/1 used above)
plt.figure(sel_episode+2)
plt.title('#epi {}'.format(sel_episode))
#plt.plot(jx_men,label='mean')
plt.plot(jx_med,label='median')
plt.legend()


# [ 5_PREPARE_WORKING_DBs ]

In [0]:
#-----------------------------------------------------------------------------------------
# Working database selections.
# ecg_db(name, <record list>, [], <sample rate>)
# NOTE(review): the *_ex lists look like per-database record exclusions — confirm against ecg_db
#-----------------------------------------------------------------------------------------

#<<--------------------- FOR TRAINING model_01
mitdb_ex = ['102', '104', '107', '217',
            '212', '231', '207']
            #'232' <- no rep normal (currently not in the list)
svdb_ex = []  #<-- less than 10 nrep normals
train_db1 = {
    'mitdb': ecg_db('mitdb', mitdb_ex, [], 360),
    'svdb': ecg_db('svdb', svdb_ex, [], 128),
    #'incartdb': ecg_db('incartdb', [], [], 257),
}
print('')
print(train_db1.keys())


#<<--------------------- FOR TESTING model_01
incartdb_ex = []
test_db1 = {
    #'mitdb': ecg_db('mitdb', mitdb_ex, [], 360),
    #'svdb': ecg_db('svdb', svdb_ex, [], 128),
    'incartdb': ecg_db('incartdb', incartdb_ex, [], 257),
}
print('')
print(test_db1.keys())

# [ 6_GENERATE_DATASETS ]

In [0]:
# CELL 0
'''
1. Select train_db dict object [CELL 1]
2. Select g_CLASS_II_POSTFIX and limit values [CELL 2]
3. Select your query [CELL 3]
'''
#%% CELL 1
#------------------------------------------------------------------------
train_db = train_db1

#%% CELL 2
#------------------------------------------------------------------------
g_CLASS_II_POSTFIX = 'CLASS'   ##<<<<---------------- [select your post fix]
lim_lower, lim_upper = 40, 100 ##<<<<---------------- [select your beat limits]

# log file --------------------------------------------------------------------
log_file= os.path.join(global_datadir, g_CLASS_II_POSTFIX+'_db_build_log.txt') 
def print_log(log_string):
    """Append one line to the build log (writes to the global `log_handle` opened in CELL 3)."""
    log_handle.write(log_string+'\n')


def _class_beat_row(sbi, ibeat, log_prefix, label_tag):
    """Log beat `ibeat` of beat-info `sbi` and return its dataset row.

    Row layout: [original_duration(secs), peak_location(samples), resampled_signal...],
    i.e. v_dimC+2 values, matching the `sveb_sel` array built in CELL 3.
    `log_prefix` / `label_tag` only control the formatting of the log line.
    """
    sg,pk = sbi.get_signal_data_var(ibeat)
    print_log(log_prefix+'beat# '+str(ibeat) + '\t'+label_tag+'='+sbi.rr_labels[ibeat]
              + '\tTS='+str(round(sbi.rr_peaks_sec[ibeat],2))+'\tDUR='+str(sbi.rri_dur[ibeat]))
    sg_resamp = scsig.resample(sg,v_dimC)
    # rescale the R-peak position from the original to the resampled time axis
    pk_resamp = round(pk*( len(sg_resamp)/(len(sg) )))
    beat_duration = sbi.rri_dur[ibeat]
    return np.hstack((beat_duration,pk_resamp,sg_resamp))


def _find_class_episodes(beat_idx):
    """Split beat indices into episodes (runs of consecutive indices).

    `beat_idx` must be non-empty; it is sorted before grouping.
    Returns a list of [start, end) pairs (end exclusive) in ascending order.
    """
    # -1 sentinel can never continue a run, so it flushes the last episode
    padded = np.hstack((np.sort(beat_idx),np.array([-1])))
    episodes = []
    s_s = padded[0]
    delta = 1
    for s_e in padded[1:]:
        if s_e == s_s + delta:
            delta+=1
        else:
            episodes.append([s_s,s_s+delta])
            s_s = s_e
            delta = 1
    return episodes


#%% CELL 3
#------------------------------------------------------------------------

all_total_sveb = 0
log_handle = open(log_file,'w')
try:
    timestamp_start = datetime.datetime.now()
    print('\n Start Iteration for '+g_CLASS_II_POSTFIX+' \n')
    print_log(g_CLASS_II_POSTFIX+'_LOG_START ['+str(timestamp_start)+ ']')
    print_log('nos_beat_limits[lower,upper] = ['+ str(lim_lower)+ ','+ str(lim_upper)+']')
    for idb in train_db.keys():
        sel_db = train_db[idb]
        #if idb!='svdb':        # optional debug filter
        #    continue
        for irec in sel_db.recs:
            #if irec!='865':    # optional debug filter
            #    continue
            rst = ''
            sel_rec = sel_db.get_record(irec)
            sel_sig = sel_rec.read_data(g_SIG_II_POSTFIX)
            rst += str(sel_rec.name)+'\t'
            print_log('\n\n[Selected Record = '+str(sel_rec.name)+ ']')

            if len(sel_sig)<1:
                rst+='\tSignal doesnt exist, Skip this record'
                print(rst)
                print_log(' >>Signal doesnt exist, Skip this record')
                continue

            sbi = sel_rec.read_binfo()      # load beat info

            #----------------------------------------------------------------------------------
            # Template placeholder: an empty query selects no beats, so every record is skipped
            # until a boolean mask over sbi (e.g. sbi.rr_labels=='A') is written here.
            sveb_query = () ##<<<<<<<<<<<<<<<<<<<<---------------[Select your query]
            sveb_list = np.where(sveb_query)[0]
            #----------------------------------------------------------------------------------

            if len(sveb_list)<lim_lower:
                rst+='\tNot enough CLASS beats, Skip this record'
                print(rst )
                print_log(' >>Invalid query_count = '+str(len(sveb_list))+ ' - skip record')
                continue

            print_log(' >>Valid query_count = '+str(len(sveb_list)))
            rst+='\tquery:'+str(len(sveb_list))
            sveb_sel = np.zeros((0,v_dimC+2),dtype='float') # +2 for peak location and duration

            if len(sveb_list)<=lim_upper:
                # CASE_1: within selection limits -> keep every class beat
                print_log('CASE_1::query_count <= upper_limit : Need to select all class beats, iterate ...')
                rst+='\tquery<=upperlimit'
                for ibeat in sveb_list:
                    sveb_sel = np.vstack((sveb_sel, _class_beat_row(sbi, ibeat, '\t', 'Label')))
                print_log('CASE_1::End of Selection, beats_selected = '+ str(sveb_sel.shape))

            else:
                # CASE_2: more class beats than upper_limit -> sample them via episodes
                print_log('CASE_2::query_count > upper_limit : Do not select all class beats. Find episodes...')
                rst+='\tquery>upperlimit'
                sveb_epi = _find_class_episodes(sveb_list)

                if len(sveb_epi)==0:
                    rst+='\tImpossible::Not enough class episodes, Skip this record'#<<- this cannot happen
                    print(rst)
                    print_log('CASE_2::Impossible, no class episodes exist!!')
                    continue
                rst+='\t'+ str(len(sveb_epi))
                print_log(' >>found episodes : '+ str(len(sveb_epi))+'\n >>compare #episodes and upper selection limit..')

                if len(sveb_epi)<=lim_upper:
                    # CASE_2.1: few episodes -> take ~delta_ratio of the beats of every episode
                    delta_ratio = lim_upper/len(sveb_list)
                    print_log('CASE_2.1::Less episodes than upper limit : Selection ratio [upper_limit/total_beats] = '+str(round(delta_ratio,2)))
                    rst+='\t#epi<=upperlimit, delta_ratio='+ str(round(delta_ratio,2))
                    print_log(' >>Prepare rsel: take delta_ratio times beats from each episode, iterate...')

                    for iepi, i_episode in enumerate(sveb_epi):
                        beats_in_epi = i_episode[1]-i_episode[0]
                        beats_taken = max(int(beats_in_epi*delta_ratio), 1) # atleast take one beat from each episode
                        print_log('\tepisode# '+ str(iepi)+'='+str(i_episode)+
                              ' has '+ str(beats_in_epi)+ ' beats, randomly take '+
                              str(beats_taken)+ ' beats')

                        a = np.arange(i_episode[0],i_episode[1])
                        rsel = np.random.choice(a, size=beats_taken, replace=False, p=None)
                        print_log('\t >>Selected: '+ str(rsel)+' iterate...')
                        for ibeat in rsel:
                            sveb_sel = np.vstack((sveb_sel, _class_beat_row(sbi, ibeat, '\t\t', 'L')))

                    print_log('end of all episode iteration beats_selected(overall) = '+ str(sveb_sel.shape))

                else:
                    # CASE_2.2: more episodes than upper_limit -> stride through the episodes
                    # and take ONE random beat from each visited episode (lim_upper beats total)
                    delta_epi = int(len(sveb_epi) / lim_upper) # episode index step, >= 1 here
                    print_log('CASE_2.2::More episodes than upper limit, index step = int[total_episodes / upperlimit] = '+str(round(delta_epi,2)))
                    rst+='\t#epi>upperlimit delta_epi='+ str(round(delta_epi,2))
                    ist = 0
                    rsel=np.zeros(lim_upper,dtype='int')
                    print_log(' >>Prepare rsel: step indices, select ONE beat randomly from each episode, iterate...')

                    for ix in range(0,lim_upper):
                        i_episode = sveb_epi[ist]
                        beats_in_epi = i_episode[1]-i_episode[0]
                        j_sel = i_episode[0]
                        if beats_in_epi > 1:
                            j_sel = random.randint(i_episode[0],i_episode[1]-1)
                        print_log('\tepisode# '+ str(ist)+'='+str(i_episode)+
                              ' has '+ str(beats_in_epi)+ ' beats, take random # '+
                              str(j_sel))
                        rsel[ix] = j_sel
                        ist+=delta_epi

                    print_log('Selected Striding: '+ str(rsel)+' iterate...')
                    for ibeat in rsel:
                        sveb_sel = np.vstack((sveb_sel, _class_beat_row(sbi, ibeat, '\t', 'L')))
                    print_log('end of iteration beats_selected = '+ str(sveb_sel.shape))

            print_log('end of record, total_beats_selected = '+ str(sveb_sel.shape))
            rst+='\ttotal_CLASS_BEATS:'+ str(len(sveb_sel))
            print(rst)

            sel_rec.save_data(g_CLASS_II_POSTFIX,sveb_sel)
            all_total_sveb+=len(sveb_sel)
        # loop end record ----------------------------------------------------------------------------------
    # loop end database ----------------------------------------------------------------------------------

    timestamp_dur = datetime.datetime.now() - timestamp_start
    print('Elapsed time = ' + str(timestamp_dur))
    print('\nEnd of Procedure, grand_total_class_beats = '+ str(all_total_sveb))
    print_log('\nEnd of Procedure, grand_total_beats = '+ str(all_total_sveb))
    print_log('\n'+g_CLASS_II_POSTFIX+'_LOG_END, elapsed time = ['+str(timestamp_dur)+ ']')
finally:
    # close (and flush) the log even if a record raises mid-run, so the handle does not
    # leak and the partial log is kept for diagnosing the failure
    log_handle.close()


## [ 6.1_S_CLASS_DS ]

In [0]:
# CELL 0
'''
1. Build train_db dict object [CELL 1]
2. Select g_CLASS_II_POSTFIX and limit values [CELL 2]
3. Select your query [CELL 3]
'''
#%% CELL 1
#------------------------------------------------------------------------

train_db = train_db1

#%% CELL 2
#------------------------------------------------------------------------

g_CLASS_II_POSTFIX = 'S'   ##<<<<---------------- [select your post fix]
lim_lower, lim_upper = 40, 100 ##<<<<---------------- [select your beat limits]

# log file --------------------------------------------------------------------
log_file= os.path.join(global_datadir, g_CLASS_II_POSTFIX+'_db_build_log.txt') 
def print_log(log_string):
    log_handle.write(log_string+'\n')

#%% CELL 3
#------------------------------------------------------------------------

all_total_sveb = 0
log_handle = open(log_file,'w')
timestamp_start = datetime.datetime.now()
print('\n Start Iteration for '+g_CLASS_II_POSTFIX+' \n')
print_log(g_CLASS_II_POSTFIX+'_LOG_START ['+str(timestamp_start)+ ']')
print_log('nos_beat_limits[lower,upper] = ['+ str(lim_lower)+ ','+ str(lim_upper)+']')
for idb in train_db.keys():
    sel_db = train_db[idb]
    #if idb!='svdb':
    #    continue
    for irec in sel_db.recs:
    #    if irec!='865':
    #        continue
        rst = ''
        sel_rec = sel_db.get_record(irec)      
        sel_sig = sel_rec.read_data(g_SIG_II_POSTFIX)
        rst += str(sel_rec.name)+'\t'
        
        #============================================================
        print_log('\n\n[Selected Record = '+str(sel_rec.name)+ ']')
        #============================================================
        
        if len(sel_sig)<1:
            rst+='\tSignal doesnt exist, Skip this record'
            print(rst)
            
            #============================================================
            print_log(' >>Signal doesnt exist, Skip this record')
            #============================================================
                    
            continue
        
        sbi = sel_rec.read_binfo()      # load beat info
    

    #----------------------------------------------------------------------------------
    #----------------------------------------------------------------------------------
        sveb_list = []
        sveb_query = (
                    (sbi.rr_labels=='A')  |
                    (sbi.rr_labels=='a')  |
                    (sbi.rr_labels=='J')  |
                    (sbi.rr_labels=='S')  
                  )
        ##<<<<<<<<<<<<<<<<<<<<---------------[Select your query]
        sveb_list = np.where(sveb_query)[0]
    #----------------------------------------------------------------------------------
    #----------------------------------------------------------------------------------
    
    
        
        if len(sveb_list)<lim_lower:
            rst+='\tNot enough CLASS beats, Skip this record'
            print(rst )
            #============================================================
            print_log(' >>Invalid query_count = '+str(len(sveb_list))+ ' - skip record')
            #============================================================
            continue
        else:
            
            #============================================================
            print_log(' >>Valid query_count = '+str(len(sveb_list)))
            #============================================================
            
            
            rst+='\tquery:'+str(len(sveb_list))
            sveb_sel = np.zeros((0,v_dimC+2),dtype='float') # +2 for peak location and duration
            
            if len(sveb_list)<=lim_upper: 
                #============================================================
                print_log('CASE_1::query_count <= upper_limit : Need to select all class beats, iterate ...')
                #============================================================
                rst+='\tquery<=upperlimit'
                # is within selction limits, select all class beats
                for ibeat in sveb_list:
                    sg,pk = sbi.get_signal_data_var(ibeat)
                    
                    #============================================================
                    print_log('\tbeat# '+str(ibeat) + '\tLabel='+sbi.rr_labels[ibeat]+ '\tTS='+str(round(sbi.rr_peaks_sec[ibeat],2))+'\tDUR='+str(sbi.rri_dur[ibeat]))
                    #============================================================
                    
                    sg_resamp = scsig.resample(sg,v_dimC)
                    pk_resamp = round(pk*( len(sg_resamp)/(len(sg) )))
                    beat_duration = sbi.rri_dur[ibeat]
                    
                    #[ orignal_duration(secs), peak_location(samples),resampled_signal{array}]
                    a_resamp = np.hstack((beat_duration,pk_resamp,sg_resamp))
                    
                    #============================================================
                    #print_log('\tMeta: ['+ str(len(a_resamp)) +']:'+str(a_resamp[0:3]))
                    #============================================================
                    
                    sveb_sel = np.vstack((sveb_sel, a_resamp)) # done now save it
                    
                #============================================================
                print_log('CASE_1::End of Selection, beats_selected = '+ str(sveb_sel.shape))
                #============================================================
                   
            else: # len(sveb_list)>lim_upper
                # more than upper_limit find episodes
                
                #============================================================
                print_log('CASE_2::query_count > upper_limit : Do not select all class beats. Find episodes...')
                #============================================================
                
                rst+='\tquery>upperlimit'
                sveb_list1 = np.hstack((np.sort(sveb_list),np.array([-1])))
                
                sveb_epi = []
                # extract episodes from sveb_list
                s_s = sveb_list[0]
                delta = 1
                for i in range(1, len(sveb_list1)):
                    s_e = sveb_list1[i]
                    if s_e == s_s + delta:
                        delta+=1
                    else:
                        sveb_epi.append([s_s,s_s+delta])
                        s_s = s_e
                        delta = 1      
                        
                if len(sveb_epi)==0:
                    rst+='\tImpossible::Not enough class episodes, Skip this record'#<<- this cannot happen
                    print(rst)
                    print_log('CASE_2::Impossible, no class episodes exist!!')
                    continue    
                else:
                    rst+='\t'+ str(len(sveb_epi))    
                #============================================================
                print_log(' >>found episodes : '+ str(len(sveb_epi))+'\n >>compare #episodes and upper selection limit..')
                #============================================================               
  
                if len(sveb_epi)<=lim_upper:
                    delta_ratio = lim_upper/len(sveb_list)
                    
                    #============================================================
                    print_log('CASE_2.1::Less episodes than upper limit : Selection ratio [upper_limit/total_beats] = '+str(round(delta_ratio,2)))
                    #============================================================                      
                    
                    rst+='\t#epi<=upperlimit, delta_ratio='+ str(round(delta_ratio,2))
                    # take delta_ratio times beats from each episode
                    
                    #============================================================ 
                    print_log(' >>Prepare rsel: take delta_ratio times beats from each episode, iterate...')
                    #============================================================ 
                    
                    for iepi in range(0,len(sveb_epi)):
                        i_episode = sveb_epi[iepi]

                        beats_in_epi = i_episode[1]-i_episode[0]

                        beats_taken = int(int(beats_in_epi*delta_ratio))
                        if beats_taken==0:
                            beats_taken = 1 # atleast take one beat from each episode
                        
                        #============================================================ 
                        print_log('\tepisode# '+ str(iepi)+'='+str(i_episode)+
                              ' has '+ str(beats_in_epi)+ ' beats, randomly take '+
                              str(beats_taken)+ ' beats')                        
                        #============================================================ 

                        a = np.arange(i_episode[0],i_episode[1])
                        rsel = np.random.choice(a, size=beats_taken, replace=False, p=None)
                        
                        print_log('\t >>Selected: '+ str(rsel)+' iterate...')
                        for ibeat in rsel:
                            sg,pk = sbi.get_signal_data_var(ibeat)
                            
                            #============================================================
                            print_log('\t\tbeat# '+str(ibeat) + '\tL='+sbi.rr_labels[ibeat]+ '\tTS='+str(round(sbi.rr_peaks_sec[ibeat],2))+'\tDUR='+str(sbi.rri_dur[ibeat]))
                            #============================================================
                            
                            sg_resamp = scsig.resample(sg,v_dimC)
                            pk_resamp = round(pk*( len(sg_resamp)/(len(sg) )))
                            beat_duration = sbi.rri_dur[ibeat]
                            
                            #[ orignal_duration(secs), peak_location(samples),resampled_signal{array}]
                            a_resamp = np.hstack((beat_duration,pk_resamp,sg_resamp))
                            
                            #============================================================ 
                            #print_log('\tMeta: ['+ str(len(a_resamp)) +']:'+str(a_resamp[0:3]))
                            #============================================================ 
                            
                            sveb_sel = np.vstack((sveb_sel, a_resamp))  # done now save it
                            
                        #============================================================
                        #print_log('\tend of episode iteration sbeats_selected(this episode) = '+ str(sveb_sel.shape))
                        #============================================================
                        
                    #============================================================
                    print_log('end of all episode iteration beats_selected(overall) = '+ str(sveb_sel.shape))
                    #============================================================
                    
                else: # is in b/w upper and lower limit, select one from each episode

                    delta_epi = int(len(sveb_epi) / lim_upper) #how many extra episodes
                    
                    #============================================================
                    print_log('CASE_2.2::More episodes than upper limit, index step = int[total_episodes / upperlimit] = '+str(round(delta_epi,2)))
                    #============================================================  

                    rst+='\t#epi>upperlimit delta_epi='+ str(round(delta_epi,2))
                    ist = 0
                    rsel=np.zeros(lim_upper,dtype='int')
                    
                    ##============================================================  
                    print_log(' >>Prepare rsel: step indices, select ONE beat randomly from each episode, iterate...')
                    #============================================================  
                    
                    for ix in range(0,lim_upper):
                        i_episode = sveb_epi[ist]
                        beats_in_epi = i_episode[1]-i_episode[0]
                        j_sel = i_episode[0]
                        
                        if beats_in_epi > 1:
                            j_sel = random.randint(i_episode[0],i_episode[1]-1)
                            
                        #============================================================  
                        print_log('\tepisode# '+ str(ist)+'='+str(i_episode)+
                              ' has '+ str(beats_in_epi)+ ' beats, take random # '+
                              str(j_sel))                            
                        #============================================================  
                        
                        rsel[ix] = j_sel
                        ist+=delta_epi
                        
                    #============================================================  
                    print_log('Selected Striding: '+ str(rsel)+' iterate...')
                    #============================================================  
                    
                    for ibeat in rsel:
                        sg,pk = sbi.get_signal_data_var(ibeat)
                        
                        #============================================================
                        print_log('\tbeat# '+str(ibeat) + '\tL='+sbi.rr_labels[ibeat]+ '\tTS='+str(round(sbi.rr_peaks_sec[ibeat],2))+'\tDUR='+str(sbi.rri_dur[ibeat]))
                        #============================================================
                        
                        sg_resamp = scsig.resample(sg,v_dimC)
                        pk_resamp = round(pk*( len(sg_resamp)/(len(sg) )))
                        beat_duration = sbi.rri_dur[ibeat]
                        
                        #[ orignal_duration(secs), peak_location(samples),resampled_signal{array}]
                        a_resamp = np.hstack((beat_duration,pk_resamp,sg_resamp))
                        
                        #============================================================  
                        #print_log('\tMeta: ['+ str(len(a_resamp)) +']:'+str(a_resamp[0:3]))
                        #============================================================  
                        
                        sveb_sel = np.vstack((sveb_sel, a_resamp))    
                        
                    #============================================================
                    print_log('end of iteration beats_selected = '+ str(sveb_sel.shape))
                    #============================================================
                    
            #============================================================
            print_log('end of record, total_beats_selected = '+ str(sveb_sel.shape))
            #============================================================
            
            rst+='\ttotal_CLASS_BEATS:'+ str(len(sveb_sel))
            print(rst)   
            
            sel_rec.save_data(g_CLASS_II_POSTFIX,sveb_sel)
            
            all_total_sveb+=len(sveb_sel)            
            
    # loop end record ----------------------------------------------------------------------------------
# loop end database ----------------------------------------------------------------------------------

# Wrap-up of the class-dataset build: report elapsed wall time and the grand
# total of selected beats on stdout and in the log, then close the log file.
timestamp_dur = datetime.datetime.now() - timestamp_start
print('Elapsed time = ', timestamp_dur, sep='')
print('\nEnd of Procedure, grand_total_class_beats = ', all_total_sveb, sep='')
print_log(''.join(('\nEnd of Procedure, grand_total_beats = ', str(all_total_sveb))))
print_log(''.join(('\n', g_CLASS_II_POSTFIX, '_LOG_END, elapsed time = [', str(timestamp_dur), ']')))
log_handle.close()


## [ 6.2_V_CLASS_DS ]

In [0]:
# CELL 0
'''
1. Build train_db dict object [CELL 1]
2. Select g_CLASS_II_POSTFIX and limit values [CELL 2]
3. Select your query [CELL 3]
'''
#%% CELL 1
#------------------------------------------------------------------------
# Databases to scan: alias the dict `train_db1` (defined in an earlier cell,
# not visible here).
train_db = train_db1

#%% CELL 2
#------------------------------------------------------------------------
# Class postfix: names the saved per-record dataset and the build log.
g_CLASS_II_POSTFIX = 'V'   ##<<<<---------------- [select your post fix]
# Records with fewer matching beats than lim_lower are skipped; lim_upper
# bounds how many beats are sampled per record.
lim_lower = 40   ##<<<<---------------- [select your beat limits]
lim_upper = 100

# log file --------------------------------------------------------------------
log_file = os.path.join(global_datadir, ''.join((g_CLASS_II_POSTFIX, '_db_build_log.txt')))
def print_log(log_string):
    """Append `log_string` plus a newline to the build log.

    Writes to the module-level `log_handle`, which must already be open
    (it is opened by the dataset-building cell before any logging happens).
    """
    log_handle.writelines((log_string, '\n'))

#%% CELL 3
#------------------------------------------------------------------------

#------------------------------------------------------------------------
# Build the g_CLASS_II_POSTFIX beat dataset for every record in train_db.
#
# Each selected beat becomes one row
#     [sbi.rri_dur[beat], peak index rescaled to v_dimC samples, v_dimC resampled samples]
# and each record's matrix is saved via sel_rec.save_data(g_CLASS_II_POSTFIX, ...).
# Records whose query matches fewer than lim_lower beats are skipped.
#
# Selection strategy per record:
#   CASE_1   query_count <= lim_upper -> keep every matching beat.
#   CASE_2.1 #episodes   <= lim_upper -> from every episode (run of consecutive
#            matching beat indices) randomly draw
#            max(1, int(len(episode) * lim_upper / query_count)) beats.
#   CASE_2.2 #episodes   >  lim_upper -> visit every int(#episodes / lim_upper)-th
#            episode (lim_upper of them) and draw ONE random beat from each.
#
# NOTE(review): beats are drawn with the unseeded `random` / `np.random` global
# generators, so re-running this cell produces a different dataset.
# NOTE(review): in CASE_2.1 the "at least one beat per episode" rule can push
# the total above lim_upper when there are many short episodes.
#------------------------------------------------------------------------

def _class_beat_row(sbi, ibeat, n_samples, log_prefix, label_tag):
    """Fetch beat `ibeat`, log it, and return its dataset row.

    The row is np.hstack of (sbi.rri_dur[ibeat], the peak index `pk` scaled by
    n_samples / len(signal) and rounded, scipy.signal.resample(signal, n_samples)).
    `log_prefix` and `label_tag` only shape the log line, so every selection
    case keeps its existing log format.
    """
    sg, pk = sbi.get_signal_data_var(ibeat)
    print_log(log_prefix + 'beat# ' + str(ibeat) + '\t' + label_tag + '=' + sbi.rr_labels[ibeat]
              + '\tTS=' + str(round(sbi.rr_peaks_sec[ibeat], 2))
              + '\tDUR=' + str(sbi.rri_dur[ibeat]))
    sg_resamp = scsig.resample(sg, n_samples)
    pk_resamp = round(pk * (len(sg_resamp) / len(sg)))
    #[ orignal_duration(secs), peak_location(samples),resampled_signal{array}]
    return np.hstack((sbi.rri_dur[ibeat], pk_resamp, sg_resamp))


def _class_stack_rows(rows, n_samples):
    """Stack beat rows into a (len(rows), n_samples + 2) matrix.

    Seeding the stack with an empty float block gives the same dtype as the
    previous row-by-row np.vstack onto np.zeros((0, n_samples + 2)), and returns
    that empty block when `rows` is empty. Stacking once avoids re-copying the
    whole matrix for every beat.
    """
    return np.vstack([np.zeros((0, n_samples + 2), dtype='float')] + list(rows))


def _class_find_episodes(beat_ids):
    """Split beat indices into episodes of consecutive indices.

    Returns a list of [start, stop) pairs (stop exclusive) in ascending order,
    e.g. [3, 4, 5, 9] -> [[3, 6], [9, 10]]; returns [] for empty input.
    """
    ids = np.sort(beat_ids)
    if len(ids) == 0:
        return []
    # a new episode starts wherever the gap to the previous index is not 1
    runs = np.split(ids, np.flatnonzero(np.diff(ids) != 1) + 1)
    return [[run[0], run[-1] + 1] for run in runs]


all_total_sveb = 0
# `with` closes (and flushes) the log even if a record raises mid-loop; the old
# explicit open()/close() pair leaked the handle on any exception. It runs at
# cell level, so `log_handle` is still the global that print_log() writes to.
with open(log_file, 'w') as log_handle:
    timestamp_start = datetime.datetime.now()
    print('\n Start Iteration for '+g_CLASS_II_POSTFIX+' \n')
    print_log(g_CLASS_II_POSTFIX+'_LOG_START ['+str(timestamp_start)+ ']')
    print_log('nos_beat_limits[lower,upper] = ['+ str(lim_lower)+ ','+ str(lim_upper)+']')
    for idb in train_db.keys():
        sel_db = train_db[idb]
        for irec in sel_db.recs:
            rst = ''
            sel_rec = sel_db.get_record(irec)
            sel_sig = sel_rec.read_data(g_SIG_II_POSTFIX)
            rst += str(sel_rec.name)+'\t'
            print_log('\n\n[Selected Record = '+str(sel_rec.name)+ ']')

            if len(sel_sig)<1:
                rst+='\tSignal doesnt exist, Skip this record'
                print(rst)
                print_log(' >>Signal doesnt exist, Skip this record')
                continue

            sbi = sel_rec.read_binfo()      # load beat info

            #------------------------------------------------------------------
            sveb_query = ((sbi.rr_labels=='V')) ##<<<<<<<<<<<<<<<<<<<<---------------[Select your query]
            sveb_list = np.where(sveb_query)[0]
            #------------------------------------------------------------------

            if len(sveb_list)<lim_lower:
                rst+='\tNot enough CLASS beats, Skip this record'
                print(rst )
                print_log(' >>Invalid query_count = '+str(len(sveb_list))+ ' - skip record')
                continue

            print_log(' >>Valid query_count = '+str(len(sveb_list)))
            rst+='\tquery:'+str(len(sveb_list))
            rows = []   # one row per selected beat, stacked once per case

            if len(sveb_list)<=lim_upper:
                # CASE_1: within selection limits -> take every class beat
                print_log('CASE_1::query_count <= upper_limit : Need to select all class beats, iterate ...')
                rst+='\tquery<=upperlimit'
                for ibeat in sveb_list:
                    rows.append(_class_beat_row(sbi, ibeat, v_dimC, '\t', 'Label'))
                sveb_sel = _class_stack_rows(rows, v_dimC)
                print_log('CASE_1::End of Selection, beats_selected = '+ str(sveb_sel.shape))

            else:
                # CASE_2: more beats than lim_upper -> sample from episodes
                print_log('CASE_2::query_count > upper_limit : Do not select all class beats. Find episodes...')
                rst+='\tquery>upperlimit'
                sveb_epi = _class_find_episodes(sveb_list)

                if len(sveb_epi)==0:
                    # unreachable for a non-empty sveb_list; kept as a safety net
                    rst+='\tImpossible::Not enough class episodes, Skip this record'
                    print(rst)
                    print_log('CASE_2::Impossible, no class episodes exist!!')
                    continue
                rst+='\t'+ str(len(sveb_epi))
                print_log(' >>found episodes : '+ str(len(sveb_epi))+'\n >>compare #episodes and upper selection limit..')

                if len(sveb_epi)<=lim_upper:
                    # CASE_2.1: take delta_ratio of each episode's beats (>= 1 each)
                    delta_ratio = lim_upper/len(sveb_list)
                    print_log('CASE_2.1::Less episodes than upper limit : Selection ratio [upper_limit/total_beats] = '+str(round(delta_ratio,2)))
                    rst+='\t#epi<=upperlimit, delta_ratio='+ str(round(delta_ratio,2))
                    print_log(' >>Prepare rsel: take delta_ratio times beats from each episode, iterate...')

                    for iepi, i_episode in enumerate(sveb_epi):
                        beats_in_epi = i_episode[1]-i_episode[0]
                        beats_taken = max(1, int(beats_in_epi*delta_ratio)) # atleast take one beat from each episode
                        print_log('\tepisode# '+ str(iepi)+'='+str(i_episode)+
                              ' has '+ str(beats_in_epi)+ ' beats, randomly take '+
                              str(beats_taken)+ ' beats')

                        a = np.arange(i_episode[0],i_episode[1])
                        rsel = np.random.choice(a, size=beats_taken, replace=False, p=None)
                        print_log('\t >>Selected: '+ str(rsel)+' iterate...')
                        for ibeat in rsel:
                            rows.append(_class_beat_row(sbi, ibeat, v_dimC, '\t\t', 'L'))

                    sveb_sel = _class_stack_rows(rows, v_dimC)
                    print_log('end of all episode iteration beats_selected(overall) = '+ str(sveb_sel.shape))

                else:
                    # CASE_2.2: more episodes than lim_upper -> stride over the
                    # episodes and take ONE random beat from each visited one
                    delta_epi = int(len(sveb_epi) / lim_upper) #episode index step (>= 1 here)
                    print_log('CASE_2.2::More episodes than upper limit, index step = int[total_episodes / upperlimit] = '+str(round(delta_epi,2)))
                    rst+='\t#epi>upperlimit delta_epi='+ str(round(delta_epi,2))
                    rsel=np.zeros(lim_upper,dtype='int')
                    print_log(' >>Prepare rsel: step indices, select ONE beat randomly from each episode, iterate...')

                    for ix in range(lim_upper):
                        ist = ix*delta_epi   # <= (lim_upper-1)*len/lim_upper, always in range
                        i_episode = sveb_epi[ist]
                        beats_in_epi = i_episode[1]-i_episode[0]
                        j_sel = i_episode[0]
                        if beats_in_epi > 1:
                            j_sel = random.randint(i_episode[0],i_episode[1]-1)
                        print_log('\tepisode# '+ str(ist)+'='+str(i_episode)+
                              ' has '+ str(beats_in_epi)+ ' beats, take random # '+
                              str(j_sel))
                        rsel[ix] = j_sel

                    print_log('Selected Striding: '+ str(rsel)+' iterate...')
                    for ibeat in rsel:
                        rows.append(_class_beat_row(sbi, ibeat, v_dimC, '\t', 'L'))
                    sveb_sel = _class_stack_rows(rows, v_dimC)
                    print_log('end of iteration beats_selected = '+ str(sveb_sel.shape))

            print_log('end of record, total_beats_selected = '+ str(sveb_sel.shape))
            rst+='\ttotal_CLASS_BEATS:'+ str(len(sveb_sel))
            print(rst)
            sel_rec.save_data(g_CLASS_II_POSTFIX,sveb_sel)
            all_total_sveb+=len(sveb_sel)
        # loop end record ------------------------------------------------------
    # loop end database --------------------------------------------------------

    timestamp_dur = datetime.datetime.now() - timestamp_start
    print('Elapsed time = ' + str(timestamp_dur))
    print('\nEnd of Procedure, grand_total_class_beats = '+ str(all_total_sveb))
    print_log('\nEnd of Procedure, grand_total_beats = '+ str(all_total_sveb))
    print_log('\n'+g_CLASS_II_POSTFIX+'_LOG_END, elapsed time = ['+str(timestamp_dur)+ ']')


## [ 6.3_F_CLASS_DS ]

In [0]:
# CELL 0
'''
1. Build train_db dict object [CELL 1]
2. Select g_CLASS_II_POSTFIX and limit values [CELL 2]
3. Select your query [CELL 3]
'''
#%% CELL 1
#------------------------------------------------------------------------
# Databases to scan: alias the dict `train_db1` (defined in an earlier cell,
# not visible here).
train_db = train_db1

#%% CELL 2
#------------------------------------------------------------------------
# Class postfix: names the saved per-record dataset and the build log.
g_CLASS_II_POSTFIX = 'F'   ##<<<<---------------- [select your post fix]
# Records with fewer matching beats than lim_lower are skipped; lim_upper
# bounds how many beats are sampled per record.
lim_lower = 40   ##<<<<---------------- [select your beat limits]
lim_upper = 100

# log file --------------------------------------------------------------------
log_file = os.path.join(global_datadir, ''.join((g_CLASS_II_POSTFIX, '_db_build_log.txt')))
def print_log(log_string):
    """Append `log_string` plus a newline to the build log.

    Writes to the module-level `log_handle`, which must already be open
    (it is opened by the dataset-building cell before any logging happens).
    """
    log_handle.writelines((log_string, '\n'))

#%% CELL 3
#------------------------------------------------------------------------

all_total_sveb = 0
log_handle = open(log_file,'w')
timestamp_start = datetime.datetime.now()
print('\n Start Iteration for '+g_CLASS_II_POSTFIX+' \n')
print_log(g_CLASS_II_POSTFIX+'_LOG_START ['+str(timestamp_start)+ ']')
print_log('nos_beat_limits[lower,upper] = ['+ str(lim_lower)+ ','+ str(lim_upper)+']')
for idb in train_db.keys():
    sel_db = train_db[idb]
    #if idb!='svdb':
    #    continue
    for irec in sel_db.recs:
    #    if irec!='865':
    #        continue
        rst = ''
        sel_rec = sel_db.get_record(irec)      
        sel_sig = sel_rec.read_data(g_SIG_II_POSTFIX)
        rst += str(sel_rec.name)+'\t'
        
        #============================================================
        print_log('\n\n[Selected Record = '+str(sel_rec.name)+ ']')
        #============================================================
        
        if len(sel_sig)<1:
            rst+='\tSignal doesnt exist, Skip this record'
            print(rst)
            
            #============================================================
            print_log(' >>Signal doesnt exist, Skip this record')
            #============================================================
                    
            continue
        
        sbi = sel_rec.read_binfo()      # load beat info
    

    #----------------------------------------------------------------------------------
    #----------------------------------------------------------------------------------
        sveb_list = []
        sveb_query = (
                    (sbi.rr_labels=='F')  |
                    (sbi.rr_labels=='e')  |
                    (sbi.rr_labels=='j')  |
                    (sbi.rr_labels=='n')  
                  )
        ##<<<<<<<<<<<<<<<<<<<<---------------[Select your query]
        sveb_list = np.where(sveb_query)[0]
    #----------------------------------------------------------------------------------
    #----------------------------------------------------------------------------------
    
    
        
        if len(sveb_list)<lim_lower:
            rst+='\tNot enough CLASS beats, Skip this record'
            print(rst )
            #============================================================
            print_log(' >>Invalid query_count = '+str(len(sveb_list))+ ' - skip record')
            #============================================================
            continue
        else:
            
            #============================================================
            print_log(' >>Valid query_count = '+str(len(sveb_list)))
            #============================================================
            
            
            rst+='\tquery:'+str(len(sveb_list))
            sveb_sel = np.zeros((0,v_dimC+2),dtype='float') # +2 for peak location and duration
            
            if len(sveb_list)<=lim_upper: 
                #============================================================
                print_log('CASE_1::query_count <= upper_limit : Need to select all class beats, iterate ...')
                #============================================================
                rst+='\tquery<=upperlimit'
                # is within selction limits, select all class beats
                for ibeat in sveb_list:
                    sg,pk = sbi.get_signal_data_var(ibeat)
                    
                    #============================================================
                    print_log('\tbeat# '+str(ibeat) + '\tLabel='+sbi.rr_labels[ibeat]+ '\tTS='+str(round(sbi.rr_peaks_sec[ibeat],2))+'\tDUR='+str(sbi.rri_dur[ibeat]))
                    #============================================================
                    
                    sg_resamp = scsig.resample(sg,v_dimC)
                    pk_resamp = round(pk*( len(sg_resamp)/(len(sg) )))
                    beat_duration = sbi.rri_dur[ibeat]
                    
                    #[ orignal_duration(secs), peak_location(samples),resampled_signal{array}]
                    a_resamp = np.hstack((beat_duration,pk_resamp,sg_resamp))
                    
                    #============================================================
                    #print_log('\tMeta: ['+ str(len(a_resamp)) +']:'+str(a_resamp[0:3]))
                    #============================================================
                    
                    sveb_sel = np.vstack((sveb_sel, a_resamp)) # done now save it
                    
                #============================================================
                print_log('CASE_1::End of Selection, beats_selected = '+ str(sveb_sel.shape))
                #============================================================
                   
            else: # len(sveb_list)>lim_upper
                # more than upper_limit find episodes
                
                #============================================================
                print_log('CASE_2::query_count > upper_limit : Do not select all class beats. Find episodes...')
                #============================================================
                
                rst+='\tquery>upperlimit'
                sveb_list1 = np.hstack((np.sort(sveb_list),np.array([-1])))
                
                sveb_epi = []
                # extract episodes from sveb_list
                s_s = sveb_list[0]
                delta = 1
                for i in range(1, len(sveb_list1)):
                    s_e = sveb_list1[i]
                    if s_e == s_s + delta:
                        delta+=1
                    else:
                        sveb_epi.append([s_s,s_s+delta])
                        s_s = s_e
                        delta = 1      
                        
                if len(sveb_epi)==0:
                    rst+='\tImpossible::Not enough class episodes, Skip this record'#<<- this cannot happen
                    print(rst)
                    print_log('CASE_2::Impossible, no class episodes exist!!')
                    continue    
                else:
                    rst+='\t'+ str(len(sveb_epi))    
                #============================================================
                print_log(' >>found episodes : '+ str(len(sveb_epi))+'\n >>compare #episodes and upper selection limit..')
                #============================================================               
  
                if len(sveb_epi)<=lim_upper:
                    delta_ratio = lim_upper/len(sveb_list)
                    
                    #============================================================
                    print_log('CASE_2.1::Less episodes than upper limit : Selection ratio [upper_limit/total_beats] = '+str(round(delta_ratio,2)))
                    #============================================================                      
                    
                    rst+='\t#epi<=upperlimit, delta_ratio='+ str(round(delta_ratio,2))
                    # take delta_ratio times beats from each episode
                    
                    #============================================================ 
                    print_log(' >>Prepare rsel: take delta_ratio times beats from each episode, iterate...')
                    #============================================================ 
                    
                    for iepi in range(0,len(sveb_epi)):
                        i_episode = sveb_epi[iepi]

                        beats_in_epi = i_episode[1]-i_episode[0]

                        beats_taken = int(int(beats_in_epi*delta_ratio))
                        if beats_taken==0:
                            beats_taken = 1 # atleast take one beat from each episode
                        
                        #============================================================ 
                        print_log('\tepisode# '+ str(iepi)+'='+str(i_episode)+
                              ' has '+ str(beats_in_epi)+ ' beats, randomly take '+
                              str(beats_taken)+ ' beats')                        
                        #============================================================ 

                        a = np.arange(i_episode[0],i_episode[1])
                        rsel = np.random.choice(a, size=beats_taken, replace=False, p=None)
                        
                        print_log('\t >>Selected: '+ str(rsel)+' iterate...')
                        for ibeat in rsel:
                            sg,pk = sbi.get_signal_data_var(ibeat)
                            
                            #============================================================
                            print_log('\t\tbeat# '+str(ibeat) + '\tL='+sbi.rr_labels[ibeat]+ '\tTS='+str(round(sbi.rr_peaks_sec[ibeat],2))+'\tDUR='+str(sbi.rri_dur[ibeat]))
                            #============================================================
                            
                            sg_resamp = scsig.resample(sg,v_dimC)
                            pk_resamp = round(pk*( len(sg_resamp)/(len(sg) )))
                            beat_duration = sbi.rri_dur[ibeat]
                            
                            #[ orignal_duration(secs), peak_location(samples),resampled_signal{array}]
                            a_resamp = np.hstack((beat_duration,pk_resamp,sg_resamp))
                            
                            #============================================================ 
                            #print_log('\tMeta: ['+ str(len(a_resamp)) +']:'+str(a_resamp[0:3]))
                            #============================================================ 
                            
                            sveb_sel = np.vstack((sveb_sel, a_resamp))  # done now save it
                            
                        #============================================================
                        #print_log('\tend of episode iteration sbeats_selected(this episode) = '+ str(sveb_sel.shape))
                        #============================================================
                        
                    #============================================================
                    print_log('end of all episode iteration beats_selected(overall) = '+ str(sveb_sel.shape))
                    #============================================================
                    
                else: # is in b/w upper and lower limit, select one from each episode

                    delta_epi = int(len(sveb_epi) / lim_upper) #how many extra episodes
                    
                    #============================================================
                    print_log('CASE_2.2::More episodes than upper limit, index step = int[total_episodes / upperlimit] = '+str(round(delta_epi,2)))
                    #============================================================  

                    rst+='\t#epi>upperlimit delta_epi='+ str(round(delta_epi,2))
                    ist = 0
                    rsel=np.zeros(lim_upper,dtype='int')
                    
                    ##============================================================  
                    print_log(' >>Prepare rsel: step indices, select ONE beat randomly from each episode, iterate...')
                    #============================================================  
                    
                    for ix in range(0,lim_upper):
                        i_episode = sveb_epi[ist]
                        beats_in_epi = i_episode[1]-i_episode[0]
                        j_sel = i_episode[0]
                        
                        if beats_in_epi > 1:
                            j_sel = random.randint(i_episode[0],i_episode[1]-1)
                            
                        #============================================================  
                        print_log('\tepisode# '+ str(ist)+'='+str(i_episode)+
                              ' has '+ str(beats_in_epi)+ ' beats, take random # '+
                              str(j_sel))                            
                        #============================================================  
                        
                        rsel[ix] = j_sel
                        ist+=delta_epi
                        
                    #============================================================  
                    print_log('Selected Striding: '+ str(rsel)+' iterate...')
                    #============================================================  
                    
                    for ibeat in rsel:
                        sg,pk = sbi.get_signal_data_var(ibeat)
                        
                        #============================================================
                        print_log('\tbeat# '+str(ibeat) + '\tL='+sbi.rr_labels[ibeat]+ '\tTS='+str(round(sbi.rr_peaks_sec[ibeat],2))+'\tDUR='+str(sbi.rri_dur[ibeat]))
                        #============================================================
                        
                        sg_resamp = scsig.resample(sg,v_dimC)
                        pk_resamp = round(pk*( len(sg_resamp)/(len(sg) )))
                        beat_duration = sbi.rri_dur[ibeat]
                        
                        #[ orignal_duration(secs), peak_location(samples),resampled_signal{array}]
                        a_resamp = np.hstack((beat_duration,pk_resamp,sg_resamp))
                        
                        #============================================================  
                        #print_log('\tMeta: ['+ str(len(a_resamp)) +']:'+str(a_resamp[0:3]))
                        #============================================================  
                        
                        sveb_sel = np.vstack((sveb_sel, a_resamp))    
                        
                    #============================================================
                    print_log('end of iteration beats_selected = '+ str(sveb_sel.shape))
                    #============================================================
                    
            #============================================================
            print_log('end of record, total_beats_selected = '+ str(sveb_sel.shape))
            #============================================================
            
            rst+='\ttotal_CLASS_BEATS:'+ str(len(sveb_sel))
            print(rst)   
            
            sel_rec.save_data(g_CLASS_II_POSTFIX,sveb_sel)
            
            all_total_sveb+=len(sveb_sel)            
            
    # loop end record ----------------------------------------------------------------------------------
# loop end database ----------------------------------------------------------------------------------

# Wrap up the class-beat build: report elapsed time and the grand total, then close the log.
timestamp_dur = datetime.datetime.now() - timestamp_start
print(f'Elapsed time = {timestamp_dur}')
print(f'\nEnd of Procedure, grand_total_class_beats = {all_total_sveb}')
print_log(f'\nEnd of Procedure, grand_total_beats = {all_total_sveb}')
print_log(f'\n{g_CLASS_II_POSTFIX}_LOG_END, elapsed time = [{timestamp_dur}]')
log_handle.close()


## [ 6.4_N_CLASS_DS ]

In [0]:
# CELL 0 -- workflow for building the class-beat files of every record in train_db:
#   1. Build the train_db dict object                 [CELL 1]
#   2. Select g_CLASS_II_POSTFIX and the beat limits    [CELL 2]
#   3. Select your query                                [CELL 3]

#%% CELL 1
#------------------------------------------------------------------------
train_db = train_db1  # working database dict; iterated by key in CELL 3

#%% CELL 2
#------------------------------------------------------------------------
g_CLASS_II_POSTFIX = 'N'  # <<-- postfix of the class file written for each record
lim_lower = 40   # <<-- records with fewer query-matching beats are skipped
lim_upper = 100  # <<-- target upper bound on beats kept per record (approximate when sampling episodes)

# per-class build log, written into the global dataset directory
log_file = os.path.join(global_datadir, '{}_db_build_log.txt'.format(g_CLASS_II_POSTFIX))
def print_log(log_string):
    """Append ``log_string`` followed by a newline to the open global ``log_handle``.

    ``log_handle`` is the module-level file object opened by the build cell.
    """
    log_handle.writelines((log_string, '\n'))

#%% CELL 3
#------------------------------------------------------------------------
# Build the class dataset: for every record, select the beats matching the query below,
# resample each to v_dimC samples and save them as that record's g_CLASS_II_POSTFIX file.
# Saved row layout: [original_duration(secs), peak_location(resampled samples), resampled_signal...]
#
# NOTE(review): beat selection draws from the unseeded global `random` / `np.random` generators,
# so the built dataset changes between runs; seed them beforehand if reproducibility matters.
# The `sveb_*` names are shared with the preceding class-build cell; here they hold the queried class.


def _beat_row(sbi, ibeat, log_prefix='\t', label_tag='L'):
    """Log one selected beat and return its dataset row.

    Row = [sbi.rri_dur[ibeat], peak index rescaled to the resampled signal,
    signal resampled to v_dimC samples], i.e. v_dimC + 2 values.
    `log_prefix` and `label_tag` only shape the log line so each case keeps its log format.
    """
    sg, pk = sbi.get_signal_data_var(ibeat)
    print_log(log_prefix + 'beat# ' + str(ibeat) + '\t' + label_tag + '=' + sbi.rr_labels[ibeat]
              + '\tTS=' + str(round(sbi.rr_peaks_sec[ibeat], 2)) + '\tDUR=' + str(sbi.rri_dur[ibeat]))
    sg_resamp = scsig.resample(sg, v_dimC)
    pk_resamp = round(pk * (len(sg_resamp) / len(sg)))  # peak position in resampled coordinates
    return np.hstack((sbi.rri_dur[ibeat], pk_resamp, sg_resamp))


def _stack_rows(rows):
    """Stack beat rows in one call (no per-beat vstack); empty input gives shape (0, v_dimC + 2)."""
    return np.vstack([np.zeros((0, v_dimC + 2), dtype='float')] + list(rows))


def _find_episodes(beat_idx):
    """Split beat indices into runs of consecutive indices ("episodes").

    Returns a list of [start, end) pairs (end exclusive) in ascending order.
    """
    ordered = np.sort(beat_idx)
    if len(ordered) == 0:
        return []
    # a new episode starts wherever the step from the previous index is not exactly 1
    runs = np.split(ordered, np.flatnonzero(np.diff(ordered) != 1) + 1)
    return [[run[0], run[-1] + 1] for run in runs]


all_total_sveb = 0
# `with` binds the global log_handle used by print_log and guarantees the log is
# closed (and flushed) even if a record raises part-way through the build.
with open(log_file, 'w') as log_handle:
    timestamp_start = datetime.datetime.now()
    print('\n Start Iteration for '+g_CLASS_II_POSTFIX+' \n')
    print_log(g_CLASS_II_POSTFIX+'_LOG_START ['+str(timestamp_start)+ ']')
    print_log('nos_beat_limits[lower,upper] = ['+ str(lim_lower)+ ','+ str(lim_upper)+']')

    for idb in train_db.keys():
        sel_db = train_db[idb]
        for irec in sel_db.recs:
            rst = ''
            sel_rec = sel_db.get_record(irec)
            sel_sig = sel_rec.read_data(g_SIG_II_POSTFIX)
            rst += str(sel_rec.name)+'\t'
            print_log('\n\n[Selected Record = '+str(sel_rec.name)+ ']')

            if len(sel_sig)<1:
                rst+='\tSignal doesnt exist, Skip this record'
                print(rst)
                print_log(' >>Signal doesnt exist, Skip this record')
                continue

            sbi = sel_rec.read_binfo()      # load beat info

            #--------------------------------------------------------------------------
            sveb_query = (
                        #(np.absolute(sbi.rri_delta)<=lim_delta_rri) &
                        (sbi.rr_int_labels==0)  &
                        (sbi.rr_int_plabels==0)  &
                        (sbi.rr_int_nlabels==0)
                ) ##<<<<<<<<<<<<<<<<<<<<---------------[Select your query]
            sveb_list = np.where(sveb_query)[0]
            #--------------------------------------------------------------------------

            if len(sveb_list)<lim_lower:
                rst+='\tNot enough CLASS beats, Skip this record'
                print(rst )
                print_log(' >>Invalid query_count = '+str(len(sveb_list))+ ' - skip record')
                continue

            print_log(' >>Valid query_count = '+str(len(sveb_list)))
            rst+='\tquery:'+str(len(sveb_list))

            if len(sveb_list)<=lim_upper:
                # CASE_1: within selection limits -> keep every matching beat
                print_log('CASE_1::query_count <= upper_limit : Need to select all class beats, iterate ...')
                rst+='\tquery<=upperlimit'
                sveb_sel = _stack_rows([_beat_row(sbi, ibeat, label_tag='Label') for ibeat in sveb_list])
                print_log('CASE_1::End of Selection, beats_selected = '+ str(sveb_sel.shape))

            else:
                # CASE_2: more matching beats than the limit -> sample from runs of consecutive beats
                print_log('CASE_2::query_count > upper_limit : Do not select all class beats. Find episodes...')
                rst+='\tquery>upperlimit'

                sveb_epi = _find_episodes(sveb_list)
                if len(sveb_epi)==0:
                    rst+='\tImpossible::Not enough class episodes, Skip this record'#<<- this cannot happen
                    print(rst)
                    print_log('CASE_2::Impossible, no class episodes exist!!')
                    continue
                rst+='\t'+ str(len(sveb_epi))
                print_log(' >>found episodes : '+ str(len(sveb_epi))+'\n >>compare #episodes and upper selection limit..')

                if len(sveb_epi)<=lim_upper:
                    # CASE_2.1: few episodes -> take the same fraction of beats from every episode
                    delta_ratio = lim_upper/len(sveb_list)
                    print_log('CASE_2.1::Less episodes than upper limit : Selection ratio [upper_limit/total_beats] = '+str(round(delta_ratio,2)))
                    rst+='\t#epi<=upperlimit, delta_ratio='+ str(round(delta_ratio,2))
                    print_log(' >>Prepare rsel: take delta_ratio times beats from each episode, iterate...')

                    rows = []
                    for iepi, i_episode in enumerate(sveb_epi):
                        beats_in_epi = i_episode[1]-i_episode[0]
                        beats_taken = max(int(beats_in_epi*delta_ratio), 1)  # at least one beat per episode
                        print_log('\tepisode# '+ str(iepi)+'='+str(i_episode)+
                                  ' has '+ str(beats_in_epi)+ ' beats, randomly take '+
                                  str(beats_taken)+ ' beats')
                        rsel = np.random.choice(np.arange(i_episode[0],i_episode[1]), size=beats_taken, replace=False)
                        print_log('\t >>Selected: '+ str(rsel)+' iterate...')
                        rows.extend(_beat_row(sbi, ibeat, log_prefix='\t\t') for ibeat in rsel)
                    sveb_sel = _stack_rows(rows)
                    print_log('end of all episode iteration beats_selected(overall) = '+ str(sveb_sel.shape))

                else:
                    # CASE_2.2: more episodes than the limit -> stride over episodes, ONE random beat each
                    delta_epi = len(sveb_epi) // lim_upper  # >= 1 because len(sveb_epi) > lim_upper
                    print_log('CASE_2.2::More episodes than upper limit, index step = int[total_episodes / upperlimit] = '+str(delta_epi))
                    rst+='\t#epi>upperlimit delta_epi='+ str(delta_epi)
                    print_log(' >>Prepare rsel: step indices, select ONE beat randomly from each episode, iterate...')

                    rsel = np.zeros(lim_upper,dtype='int')
                    for ix in range(lim_upper):
                        ist = ix * delta_epi  # stays < len(sveb_epi) since delta_epi <= len/lim_upper
                        i_episode = sveb_epi[ist]
                        beats_in_epi = i_episode[1]-i_episode[0]
                        j_sel = i_episode[0]
                        if beats_in_epi > 1:
                            j_sel = random.randint(i_episode[0],i_episode[1]-1)
                        print_log('\tepisode# '+ str(ist)+'='+str(i_episode)+
                                  ' has '+ str(beats_in_epi)+ ' beats, take random # '+
                                  str(j_sel))
                        rsel[ix] = j_sel

                    print_log('Selected Striding: '+ str(rsel)+' iterate...')
                    sveb_sel = _stack_rows([_beat_row(sbi, ibeat) for ibeat in rsel])
                    print_log('end of iteration beats_selected = '+ str(sveb_sel.shape))

            print_log('end of record, total_beats_selected = '+ str(sveb_sel.shape))
            rst+='\ttotal_CLASS_BEATS:'+ str(len(sveb_sel))
            print(rst)

            sel_rec.save_data(g_CLASS_II_POSTFIX,sveb_sel)
            all_total_sveb+=len(sveb_sel)
        # loop end record ------------------------------------------------------------------------------
    # loop end database ----------------------------------------------------------------------------------

    timestamp_dur = datetime.datetime.now() - timestamp_start
    print('Elapsed time = ' + str(timestamp_dur))
    print('\nEnd of Procedure, grand_total_class_beats = '+ str(all_total_sveb))
    print_log('\nEnd of Procedure, grand_total_beats = '+ str(all_total_sveb))
    print_log('\n'+g_CLASS_II_POSTFIX+'_LOG_END, elapsed time = ['+str(timestamp_dur)+ ']')


# [ 7_COMPILE_DATASETS ]

## [ 7.1_Compile_Training_Set ]

In [0]:
#<<---------------------------------- Select save name of this dataset
ds_name = 'custom_set_1'
ds_path = os.path.join(ds_dir, ds_name+'.npy')  # NOTE(review): ds_dir must be defined earlier in the notebook

#<<---------------------------------- Select which working db to compile from
train_db = train_db1

#<<---------------------------------- SELECT PARAMS
g_REP_NORM_POSTFIX = 'REP_1800' # for training use rep norms from the full record
g_CLASS_N_POSTFIX = 'N'
g_CLASS_S_POSTFIX = 'S'
g_CLASS_V_POSTFIX = 'V'
g_CLASS_F_POSTFIX = 'F'
lim_min_rep_norms = 10 # at least this many rep norms must exist; this many are paired with every beat
#------------------------------------------------------------------------
# Output rows: [label, rep_normal(v_dimC), rep_normal_avg_dur, beat(v_dimC), beat_dur]
# label 0 = N beat, 1 = S/V/F beat. Each record contributes as many N rows as S/V/F rows,
# and every one of them is paired with the record's top lim_min_rep_norms representative normals.
# NOTE(review): the N-beat subsampling uses the unseeded global np.random generator.


def _pair_with_rep_normals(rep_rows, sel_data):
    """Pair every labelled beat in `sel_data` with each representative normal in `rep_rows`.

    rep_rows : rows of the REP file; column 2 = avg episode duration (sec),
               last v_dimC columns = median signal.
    sel_data : rows [label, duration(sec), peak(samples), resampled signal(v_dimC)].
    Returns rows [label, rep_signal, rep_avg_dur, beat_signal, beat_dur] in rep-major order
    (all of sel_data for the first rep row, then for the next, ...), built without
    per-row vstack.
    """
    n_sel = len(sel_data)
    blocks = [np.zeros((0, 1 + v_dimC + 1 + v_dimC + 1))]
    for rr in rep_rows:
        blocks.append(np.hstack((
            sel_data[:, 0:1],                    # label (int label was stacked in column 0)
            np.tile(rr[-v_dimC:], (n_sel, 1)),   # median representation of the normal
            np.full((n_sel, 1), rr[2]),          # avg duration of that normal episode
            sel_data[:, -v_dimC:],               # resampled beat signal
            sel_data[:, 1:2],                    # original duration of the beat in seconds
        )))
    return np.vstack(blocks)


g_SUPRESS_DATA_WARNING=True # suppress 'file not found' warnings from ecg_record class
try:
    timestamp_start = datetime.datetime.now()
    print('\n Start Iteration for train_db \n')
    print('DB_REC\t#r\t#n\t#s\t#v\t#f\t#a\tN~A\tselN\tselA\tselNA\trec_input')
    mega_parts = [np.zeros((0,1+v_dimC+1+v_dimC+1))]
    for idb in train_db.keys():
        sel_db = train_db[idb]
        for irec in sel_db.recs:
            sel_rec = sel_db.get_record(irec)
            rst=sel_rec.name+'\t'

            # first load the representative normal file and check there are enough of them
            npy_rep = sel_rec.read_data(g_REP_NORM_POSTFIX)
            nos_rep = len(npy_rep)
            if nos_rep<lim_min_rep_norms:
                rst+=' >> Not enough Representative Normals, skip record\t'
                print(rst)
                continue
            rst+=str(nos_rep)+'\t'

            # Now load N,S,V,F beats; read_data returns a blank array if the file doesn't exist
            npy_N = sel_rec.read_data(g_CLASS_N_POSTFIX)
            npy_S = sel_rec.read_data(g_CLASS_S_POSTFIX)
            npy_V = sel_rec.read_data(g_CLASS_V_POSTFIX)
            npy_F = sel_rec.read_data(g_CLASS_F_POSTFIX)

            nos_N, nos_S, nos_V, nos_F = len(npy_N), len(npy_S), len(npy_V), len(npy_F)
            nos_A = nos_S + nos_V + nos_F
            rst+=str(nos_N)+'\t'+str(nos_S)+'\t'+str(nos_V)+'\t'+str(nos_F)+'\t'+str(nos_A)+'\t'

            if nos_N == 0 or nos_A==0:
                rst+=' >> Cannot continue with zero beats, skip record\t'
                print(rst)
                continue

            # give empty classes a 2-D (0, v_dimC+2) shape so they stack with the others
            if nos_S == 0:
                npy_S = np.zeros((0,v_dimC+2))
            if nos_V == 0:
                npy_V = np.zeros((0,v_dimC+2))
            if nos_F == 0:
                npy_F = np.zeros((0,v_dimC+2))

            # all abnormal beats, label = 1
            sel_A = np.hstack((
                            np.ones((nos_A,1)),             #<<----label
                            np.vstack((npy_S,npy_V,npy_F))  #<<----data
                            ))

            # exactly nos_A normal beats, label = 0
            idx_N = np.arange(0,nos_N)
            if nos_N>=nos_A:
                # random subset of the N beats
                rst+= ' N>=A\t'
                sel_N = np.hstack((
                    np.zeros((nos_A,1)),
                    npy_N[ np.random.choice( idx_N, size=nos_A, replace=False ) ]
                                ))
            else:
                # repeat all N beats as often as they fit, then top up with a random subset
                rst+= ' N<A\t'
                times_repeat, nos_repeat = divmod(nos_A, nos_N)
                sel_N = np.hstack((
                            np.zeros((nos_A,1)),
                            np.vstack(
                                [npy_N] * times_repeat
                                + [npy_N[ np.random.choice( idx_N, size=nos_repeat, replace=False ) ]]
                                    )
                                ))
            rst+= str(len(sel_N))+ '\t'+str(len(sel_A))+'\t'

            sel_data = np.vstack((sel_A,sel_N))
            rst+= str(len(sel_data))+'\t'

            # Select the representative normals with the HIGHEST beat counts (column 0 = nos_beats).
            # An ascending argsort here would keep the lowest-count episodes instead.
            # structure of npy_rep is
            #               0 nos_beats,
            #               1 l_max(resampled length in samples),
            #               2 avg_dur(sec),
            #               3 var_dur,
            #               4:4+v_dimC mean_signal(array of len v_dimC),
            #               -v_dimC: median_signal(array of len v_dimC)
            order = np.argsort(-npy_rep[:, 0], kind='stable')
            npy_rep_sort = npy_rep[order][0:lim_min_rep_norms]

            # total samples = (nos_A + nos_A) * lim_min_rep_norms
            # inputs=[input_N, input_N_dur,input_C, input_C_dur]
            rec_input = _pair_with_rep_normals(npy_rep_sort, sel_data)
            rst+= str(len(rec_input))
            print(rst)

            mega_parts.append(rec_input)

    mega_input = np.vstack(mega_parts)
    print('\nDone! Total beats = '+ str(len(mega_input)))
    timestamp_dur = datetime.datetime.now() - timestamp_start
    print('Elapsed time = ' + str(timestamp_dur))

    # save mega_input
    np.save(ds_path,mega_input)
    print('saved at '+ str(ds_path))
finally:
    g_SUPRESS_DATA_WARNING=False # resume 'file not found' warnings, even if the compile failed


## [ 7.2_Compile_Testing_Set ]

In [0]:
#<<---------------------------------- Select save name of this dataset
ds_name = 'custom_set_test_1'
ds_path = os.path.join(ds_dir, ds_name + '.npy')

#<<---------------------------------- Select which representative-normal file to pair beats with
g_REP_NORM_POSTFIX = 'REP_300'
lim_min_rep_norms = 1
#<<---------------------------------- Select which working db to compile from
test_db = test_db1
#<<---------------------------------- Seed for picking representative normals (None = non-deterministic)
ds_seed = 0

# Row layout of the compiled set:
#   [label | rep-normal median (v_dimC) | rep avg duration | beat (v_dimC) | beat duration]
n_cols = 1 + v_dimC + 1 + v_dimC + 1
rng = np.random.default_rng(ds_seed)
rec_blocks = []  # one array per record, stacked once at the end (avoids quadratic np.vstack)

g_SUPRESS_DATA_WARNING = True  # supress 'file not found' warnings from ecg_record class
print('db_rec\t#beats\t#r\t#stacked')
timestamp_start = datetime.datetime.now()
try:
    for idb in test_db.keys():
        sel_db = test_db[idb]
        for irec in sel_db.recs:
            sel_rec = sel_db.get_record(irec)
            rst = sel_rec.name + '\t'

            # The signal is loaded here only to check that the record is available.
            sel_sig = sel_rec.read_data(g_SIG_II_POSTFIX)
            if len(sel_sig) < 1:
                print(rst + ' >>Signal cannot be loaded\t')
                continue

            # beat info
            sbi = sel_rec.read_binfo()
            rst += str(sbi.nos_rr_peaks) + '\t'

            # representative-normal episodes to pair each beat with
            npy_rep = sel_rec.read_data(g_REP_NORM_POSTFIX)
            nos_rep = len(npy_rep)
            if nos_rep < lim_min_rep_norms:
                print(rst + ' >> Not enough Representative Normals\t')
                continue
            rst += str(nos_rep) + '\t'

            # One row per beat: [int label, original duration (s), beat resampled to v_dimC]
            beat_rows = []
            for ibeat in range(sbi.nos_rr_peaks):
                sg, _ = sbi.get_signal_data_var(ibeat)
                beat_rows.append(np.hstack((sbi.rr_int_labels[ibeat],
                                            sbi.rri_dur[ibeat],
                                            scsig.resample(sg, v_dimC))))
            all_beats = np.asarray(beat_rows, dtype=float).reshape(-1, v_dimC + 2)

            # Pair each beat with a random representative normal; sample with replacement
            # only when there are fewer representatives than beats.
            n_beats = len(all_beats)
            npy_rep_sel = npy_rep[rng.choice(nos_rep, size=n_beats, replace=nos_rep < n_beats)]

            rec_input = np.hstack((
                all_beats[:, 0:1],          # input_label
                npy_rep_sel[:, -v_dimC:],   # input_N: median representation
                npy_rep_sel[:, 2:3],        # input_N_dur: avg duration of episode
                all_beats[:, -v_dimC:],     # input_C: resampled beat
                all_beats[:, 1:2],          # input_C_dur: original beat duration (s)
            ))

            rst += str(len(rec_input))
            print(rst)
            rec_blocks.append(rec_input)
finally:
    # Always re-enable warnings, even if a record raised mid-way.
    g_SUPRESS_DATA_WARNING = False  # resume 'file not found' warnings from ecg_record class

mega_input = np.vstack(rec_blocks) if rec_blocks else np.zeros((0, n_cols))
np.save(ds_path, mega_input)
print('saved at ' + str(ds_path))
print('Done! Final Input shape:' + str(mega_input.shape))
timestamp_dur = datetime.datetime.now() - timestamp_start
print('Elapsed time = ' + str(timestamp_dur))

## [ 7.3_Compile_Testing_Set_Class ]

In [0]:
#<<---------------------------------- Select save name of this dataset
ds_name = 'custom_set_test_1_Class_N'
ds_path = os.path.join(ds_dir, ds_name + '.npy')

#<<---------------------------------- Select which representative-normal file to pair beats with
g_REP_NORM_POSTFIX = 'REP_300'
lim_min_rep_norms = 1
#<<---------------------------------- Select which working db to compile from
test_db = test_db1
#<<---------------------------------- Select which beat labels (sbi.rr_labels) to include
#   S = [A,a,J,S]   V = [V]   F = [F,e,j,n]   N = [N,L,R]
class_labels = ['N', 'L', 'R']
#<<---------------------------------- Seed for picking representative normals (None = non-deterministic)
ds_seed = 0

# Row layout of the compiled set:
#   [label | rep-normal median (v_dimC) | rep avg duration | beat (v_dimC) | beat duration]
n_cols = 1 + v_dimC + 1 + v_dimC + 1
rng = np.random.default_rng(ds_seed)
rec_blocks = []  # one array per record, stacked once at the end (avoids quadratic np.vstack)

g_SUPRESS_DATA_WARNING = True  # supress 'file not found' warnings from ecg_record class
print('db_rec\t#beats\t#r\t#stacked')
timestamp_start = datetime.datetime.now()
try:
    for idb in test_db.keys():
        sel_db = test_db[idb]
        for irec in sel_db.recs:
            sel_rec = sel_db.get_record(irec)
            rst = sel_rec.name + '\t'

            # The signal is loaded here only to check that the record is available.
            sel_sig = sel_rec.read_data(g_SIG_II_POSTFIX)
            if len(sel_sig) < 1:
                print(rst + ' >>Signal cannot be loaded\t')
                continue

            # beat info
            sbi = sel_rec.read_binfo()
            rst += str(sbi.nos_rr_peaks) + '\t'

            # representative-normal episodes to pair each beat with
            npy_rep = sel_rec.read_data(g_REP_NORM_POSTFIX)
            nos_rep = len(npy_rep)
            if nos_rep < lim_min_rep_norms:
                print(rst + ' >> Not enough Representative Normals\t')
                continue
            rst += str(nos_rep) + '\t'

            # indices of the beats whose annotation label is one of class_labels
            class_list = np.flatnonzero(np.isin(sbi.rr_labels, class_labels))
            if len(class_list) == 0:
                print(rst + ' >> Not enough class beats')
                continue

            # One row per selected beat: [int label, original duration (s), beat resampled to v_dimC]
            beat_rows = []
            for ibeat in class_list:
                sg, _ = sbi.get_signal_data_var(ibeat)
                beat_rows.append(np.hstack((sbi.rr_int_labels[ibeat],
                                            sbi.rri_dur[ibeat],
                                            scsig.resample(sg, v_dimC))))
            all_beats = np.asarray(beat_rows, dtype=float).reshape(-1, v_dimC + 2)

            # Pair each beat with a random representative normal; sample with replacement
            # only when there are fewer representatives than beats.
            n_beats = len(all_beats)
            npy_rep_sel = npy_rep[rng.choice(nos_rep, size=n_beats, replace=nos_rep < n_beats)]

            rec_input = np.hstack((
                all_beats[:, 0:1],          # input_label
                npy_rep_sel[:, -v_dimC:],   # input_N: median representation
                npy_rep_sel[:, 2:3],        # input_N_dur: avg duration of episode
                all_beats[:, -v_dimC:],     # input_C: resampled beat
                all_beats[:, 1:2],          # input_C_dur: original beat duration (s)
            ))

            rst += str(len(rec_input))
            print(rst)
            rec_blocks.append(rec_input)
finally:
    # Always re-enable warnings, even if a record raised mid-way.
    g_SUPRESS_DATA_WARNING = False  # resume 'file not found' warnings from ecg_record class

mega_input = np.vstack(rec_blocks) if rec_blocks else np.zeros((0, n_cols))
np.save(ds_path, mega_input)
print('saved at ' + str(ds_path))
print('Done! Final Input shape:' + str(mega_input.shape))
timestamp_dur = datetime.datetime.now() - timestamp_start
print('Elapsed time = ' + str(timestamp_dur))

# [ 8_MODEL_DEFINITIONS ]

In [0]:
# --- Callbacks ----------------------------------------------------------------
# Early stopping on *training* accuracy: stop once it has not improved by at
# least 1e-5 for 2 epochs; the last (not the best) weights are kept.
cb_esr = tf.keras.callbacks.EarlyStopping(
    monitor='accuracy',
    min_delta=1e-5,
    patience=2,
    verbose=0,
    mode='auto',
    baseline=None,
    restore_best_weights=False,
)
cb_listr = [cb_esr]

# --- Model compile settings ---------------------------------------------------
m_cost = 'sparse_categorical_crossentropy'  # integer class labels vs. softmax output
m_opt = 'rmsprop'

def _conv_feature_branch(x, fl_filters):
    """Apply the 1-D CNN feature extractor used by each input branch of model_01.

    Three Conv1D (leaky ReLU, 'valid' padding, stride 1) + MaxPooling1D(2) stages
    with 30, 20 and 10 filters (kernel sizes fl_filters, 3, 3), then Flatten.
    New layers are created on every call, so the N and C branches do NOT share
    weights.

    Args:
        x: Keras tensor laid out channels_last, i.e. (batch, steps, channels).
        fl_filters: kernel size of the first Conv1D layer.

    Returns:
        Flattened feature tensor of shape (batch, features).
    """
    x = Conv1D(30, fl_filters, activation=tf.nn.leaky_relu)(x)
    x = MaxPooling1D(pool_size=2)(x)
    x = Conv1D(20, 3, activation=tf.nn.leaky_relu)(x)
    x = MaxPooling1D(pool_size=2)(x)
    x = Conv1D(10, 3, activation=tf.nn.leaky_relu)(x)
    x = MaxPooling1D(pool_size=2)(x)
    return Flatten()(x)


def model_01(print_summary, input_shape_N, input_shape_C, fl_filters, nos_output):
    """Build and compile the two-branch (normal reference vs. candidate beat) classifier.

    Each branch runs its signal through its own CNN feature extractor and appends
    the branch's scalar duration input. The two feature vectors are concatenated
    and passed through Dense(20, relu) -> Dense(10, relu) -> Dense(nos_output, softmax).

    Args:
        print_summary: if True, print the Keras model summary.
        input_shape_N: shape of the representative-normal signal input, e.g. (v_dimC, 1).
        input_shape_C: shape of the candidate-beat signal input, e.g. (v_dimC, 1).
        fl_filters: kernel size of the first Conv1D layer in each branch.
        nos_output: number of output classes (softmax units).

    Returns:
        A Model compiled with m_cost / m_opt and an 'accuracy' metric, taking inputs
        [input_N, input_N_dur, input_C, input_C_dur] in that order.
    """
    # Layers are created in the same order as before (N branch, then C branch, then
    # head) so weight files saved with the previous layout still load by topology.

    # NORMAL (reference) branch ------------------------------------------------
    input_N = Input(shape=input_shape_N, name="input_N")
    input_N_dur = Input(shape=(1,), name="input_N_dur")
    feat_N = _conv_feature_branch(input_N, fl_filters)
    den_Ndur_concat = tf.keras.layers.Concatenate(
        axis=1, name="dense_N_concat")([feat_N, input_N_dur])

    # COMPARE (candidate beat) branch ------------------------------------------
    input_C = Input(shape=input_shape_C, name="input_C")
    input_C_dur = Input(shape=(1,), name="input_C_dur")
    feat_C = _conv_feature_branch(input_C, fl_filters)
    den_Cdur_concat = tf.keras.layers.Concatenate(
        axis=1, name="dense_C_concat")([feat_C, input_C_dur])

    # Joint N+C classification head --------------------------------------------
    den_NC_concat = tf.keras.layers.Concatenate(
        axis=1, name="dense_NC_concat")([den_Ndur_concat, den_Cdur_concat])
    den_NC_01 = Dense(20, activation=tf.nn.relu, name="DENSE_NC_01")(den_NC_concat)
    den_NC_02 = Dense(10, activation=tf.nn.relu, name="DENSE_NC_02")(den_NC_01)
    den_out = Dense(nos_output, activation=tf.nn.softmax, name="OUTPUT_FC")(den_NC_02)

    model = Model(inputs=[input_N, input_N_dur, input_C, input_C_dur], outputs=den_out)
    model.compile(
        loss=m_cost,
        optimizer=m_opt,
        metrics=['accuracy'],
    )

    if print_summary:
        model.summary()  # summary() prints itself and returns None; don't wrap in print()
    return model
#==============================================================================


# [ 9_EXPERIMENT ]

In [0]:
# List the compiled datasets and saved model-weight files that are available.
# os.listdir returns entries in arbitrary order, so sort them for a stable listing.
ls_datasets = sorted(os.listdir(ds_dir))
print('Available Datasets:')
for ds_i in ls_datasets:
    print(ds_i)
print('--------------------------')

ls_models = sorted(os.listdir(global_modeldir))
print('Available Models:')
for ms_i in ls_models:
    print(ms_i)
print('--------------------------')

## [ 9.1_TRAINING ]

In [0]:
# Training
##<----------------------------------------------
ds_name = 'custom_set_1'    # SELECT DATASET FOR LOADING TRAINING DATA FROM
ds_model = 'model_1'        # SELECT NAME FOR SAVING MODEL WEIGHTS
##<----------------------------------------------
m_fl_filters = 7            # kernel size of the first Conv1D layer in each branch of model_01
m_nos_output = 2            # number of output classes (softmax units)

ds_path = os.path.join(ds_dir, ds_name + '.npy')
megaset = np.load(ds_path)
print('Loaded data from ' + str(ds_path) + '\nShape=' + str(megaset.shape))

# Training data structure (one row per sample):
#   input_label    0
#   input_N        1 : 1+v_dimC
#   input_N_dur    1+v_dimC : 1+v_dimC+1
#   input_C        1+v_dimC+1 : 1+v_dimC+1+v_dimC
#   input_C_dur    1+v_dimC+1+v_dimC : 1+v_dimC+1+v_dimC+1
# Fail early with a clear message if the file was compiled with a different v_dimC.
assert megaset.ndim == 2 and megaset.shape[1] == 2 * v_dimC + 3, (
    'dataset rows have ' + str(megaset.shape[-1]) + ' columns, expected ' + str(2 * v_dimC + 3))

x_labels = megaset[:, 0]
x_norm = np.expand_dims(megaset[:, 1:1 + v_dimC], axis=2)                 # (n, v_dimC, 1) for Conv1D
x_norm_dur = megaset[:, 1 + v_dimC:1 + v_dimC + 1]                        # (n, 1)
x_comp = np.expand_dims(megaset[:, 1 + v_dimC + 1:1 + v_dimC + 1 + v_dimC], axis=2)
x_comp_dur = megaset[:, 1 + v_dimC + 1 + v_dimC:1 + v_dimC + 1 + v_dimC + 1]

data_x = [x_norm, x_norm_dur, x_comp, x_comp_dur]  # order must match model_01's inputs
data_y = x_labels

# model_01(print_summary, input_shape_N, input_shape_C, fl_filters, nos_output)
model = model_01(True, (v_dimC, 1), (v_dimC, 1), m_fl_filters, m_nos_output)

# Start Training---------------------------------------------------------
timestamp_start = datetime.datetime.now()

history = model.fit(
    data_x, data_y,
    batch_size=1000,
    epochs=300,            # upper bound; the EarlyStopping callback in cb_listr may stop sooner
    callbacks=cb_listr,
    shuffle=True,
    verbose=1)

timestamp_dur = datetime.datetime.now() - timestamp_start
print('Elapsed time = ' + str(timestamp_dur))
# End Training---------------------------------------------------------

fig, ax = plt.subplots()
ax.plot(history.history['accuracy'], color='green')
ax.set(title='ACC: ' + ds_name, xlabel='epoch', ylabel='training accuracy')
plt.show()

fig, ax = plt.subplots()
ax.plot(history.history['loss'], color='red')
ax.set(title='LOSS: ' + ds_name, xlabel='epoch', ylabel='training loss')
plt.show()

# save this model
save_model_name = ds_model + '.h5'        # save model weights to this file
svmpth = os.path.join(global_modeldir, save_model_name)
model.save_weights(svmpth)
print('Saved Model Weights at : ' + str(svmpth))



## [ 9.2_TESTING ]

In [0]:
# Testing
##<----------------------------------------------
ds_name = 'custom_set_test_1_Class_N'    # SELECT DATASET FOR LOADING TESTING DATA FROM
ds_model = 'model_1'       # SELECT MODEL WEIGHTS TO TEST UPON
##<----------------------------------------------
m_fl_filters = 7           # must match the value used when ds_model was trained
m_nos_output = 2           # must match the value used when ds_model was trained

ds_path = os.path.join(ds_dir, ds_name + '.npy')
megaset = np.load(ds_path)
print('Loaded data from ' + str(ds_path) + '\nShape=' + str(megaset.shape))

# Testing data structure (one row per sample):
#   input_label    0
#   input_N        1 : 1+v_dimC
#   input_N_dur    1+v_dimC : 1+v_dimC+1
#   input_C        1+v_dimC+1 : 1+v_dimC+1+v_dimC
#   input_C_dur    1+v_dimC+1+v_dimC : 1+v_dimC+1+v_dimC+1
assert megaset.ndim == 2 and megaset.shape[1] == 2 * v_dimC + 3, (
    'dataset rows have ' + str(megaset.shape[-1]) + ' columns, expected ' + str(2 * v_dimC + 3))

x_labels = megaset[:, 0]
x_norm = np.expand_dims(megaset[:, 1:1 + v_dimC], axis=2)
x_norm_dur = megaset[:, 1 + v_dimC:1 + v_dimC + 1]
x_comp = np.expand_dims(megaset[:, 1 + v_dimC + 1:1 + v_dimC + 1 + v_dimC], axis=2)
x_comp_dur = megaset[:, 1 + v_dimC + 1 + v_dimC:1 + v_dimC + 1 + v_dimC + 1]

data_x = [x_norm, x_norm_dur, x_comp, x_comp_dur]  # order must match model_01's inputs
data_y = x_labels

# model_01(print_summary, input_shape_N, input_shape_C, fl_filters, nos_output)
model = model_01(False, (v_dimC, 1), (v_dimC, 1), m_fl_filters, m_nos_output)
load_model_name = ds_model + '.h5'     # model used for testing
load_model_path = os.path.join(global_modeldir, load_model_name)
model.load_weights(load_model_path)
print('Loaded Model weights ' + str(load_model_path))

# manual prediction
print('Manual Prediction on : ' + ds_name)
predx = model.predict(data_x)          # (n_samples, m_nos_output) softmax probabilities
cmx2_global = predx.argmax(axis=1)     # predicted class index per sample
# Confusion matrix: rows = actual label, columns = predicted label.
# np.add.at accumulates correctly when the same (actual, predicted) pair repeats.
cmx_global = np.zeros((len(g_LABELS), len(g_LABELS)), dtype='int32')
np.add.at(cmx_global, (data_y.astype(int), cmx2_global), 1)
print('\tConfusion Matrix')
print(print_conf_matrix(cmx_global, '', g_LABELS))
print_performance(get_performance(cmx_global), g_LABELS)
#------------------------------------------------------------

## [ 9.3_RECORD_TESTING ]

In [0]:
# Pick the single ECG record to evaluate beat-by-beat.
test_db = all_db            #<<---- db dict to read ecg data
idb = 'incartdb'            #<<---- ecg_db
irec = 'I42'                #<<---- ecg_record

sel_db = test_db[idb]
sel_rec = sel_db.get_record(irec)

# beat annotations / RR information for this record
sbi = sel_rec.read_binfo()
print(f'Total beats = {sbi.nos_rr_peaks}')

# signal selected by g_SIG_II_POSTFIX; an empty result means it could not be read
sel_sig = sel_rec.read_data(g_SIG_II_POSTFIX)
print('Signal loaded succesfully' if len(sel_sig) >= 1 else 'Signal cannot be loaded')

#<<---------------------------------- Select save name of this dataset
ds_name = f'rec_test_{idb}_{irec}'
ds_path = os.path.join(ds_dir, f'{ds_name}.npy')

### [ 9.3.1_Compile_Test_Set ]

In [0]:

#<<---------------------------------- Select which representative normal to use
g_REP_NORM_POSTFIX = 'REP_300'
lim_min_rep_norms = 1
#<<---------------------------------- Seed for picking representative normals (None = non-deterministic)
ds_seed = 0

timestamp_start = datetime.datetime.now()
print('db_rec\t#beats\t#r\t#stacked')

# NOTE: nothing is saved when a check below fails; the next cell would then load
# whatever file (if any) already exists at ds_path.
rst = sel_rec.name + '\t'
if len(sel_sig) < 1:
    rst += ' >>Signal cannot be loaded\t'
    print(rst)
else:
    rst += str(sbi.nos_rr_peaks) + '\t'

    # representative-normal episodes to pair each beat with
    npy_rep = sel_rec.read_data(g_REP_NORM_POSTFIX)
    nos_rep = len(npy_rep)

    if nos_rep < lim_min_rep_norms:
        rst += ' >> Not enough Representative Normals\t'
        print(rst)
    else:
        rst += str(nos_rep) + '\t'

        # Every beat of the record, in beat order, so row i of the saved set is beat i
        # (the plotting cell relies on this alignment).
        # Row: [int label, original duration (s), beat resampled to v_dimC]
        beat_rows = []
        for ibeat in range(sbi.nos_rr_peaks):
            sg, _ = sbi.get_signal_data_var(ibeat)
            beat_rows.append(np.hstack((sbi.rr_int_labels[ibeat],
                                        sbi.rri_dur[ibeat],
                                        scsig.resample(sg, v_dimC))))
        all_beats = np.asarray(beat_rows, dtype=float).reshape(-1, v_dimC + 2)

        # Pair each beat with a random representative normal; sample with replacement
        # only when there are fewer representatives than beats.
        rng = np.random.default_rng(ds_seed)
        n_beats = len(all_beats)
        npy_rep_sel = npy_rep[rng.choice(nos_rep, size=n_beats, replace=nos_rep < n_beats)]

        rec_input = np.hstack((
            all_beats[:, 0:1],          # input_label
            npy_rep_sel[:, -v_dimC:],   # input_N: median representation
            npy_rep_sel[:, 2:3],        # input_N_dur: avg duration of episode
            all_beats[:, -v_dimC:],     # input_C: resampled beat
            all_beats[:, 1:2],          # input_C_dur: original beat duration (s)
        ))

        rst += str(len(rec_input))
        print(rst)

        np.save(ds_path, rec_input)
        print('saved at ' + str(ds_path))
        print('Done! Final Input shape:' + str(rec_input.shape))

timestamp_dur = datetime.datetime.now() - timestamp_start
print('Elapsed time = ' + str(timestamp_dur))


### [ 9.3.2_Test_Record ]

In [0]:
# Testing
##<----------------------------------------------
ds_model = 'model_1'       # SELECT MODEL WEIGHTS TO TEST UPON
##<----------------------------------------------
m_fl_filters = 7           # must match the value used when ds_model was trained
m_nos_output = 2           # must match the value used when ds_model was trained

# ds_name / ds_path come from the record-selection cell (rec_test_<db>_<rec>).
megaset = np.load(ds_path)
print('Loaded data from ' + str(ds_path) + '\nShape=' + str(megaset.shape))

# Testing data structure (one row per beat, in beat order):
#   input_label    0
#   input_N        1 : 1+v_dimC
#   input_N_dur    1+v_dimC : 1+v_dimC+1
#   input_C        1+v_dimC+1 : 1+v_dimC+1+v_dimC
#   input_C_dur    1+v_dimC+1+v_dimC : 1+v_dimC+1+v_dimC+1
assert megaset.ndim == 2 and megaset.shape[1] == 2 * v_dimC + 3, (
    'dataset rows have ' + str(megaset.shape[-1]) + ' columns, expected ' + str(2 * v_dimC + 3))

x_labels = megaset[:, 0]
x_norm = np.expand_dims(megaset[:, 1:1 + v_dimC], axis=2)
x_norm_dur = megaset[:, 1 + v_dimC:1 + v_dimC + 1]
x_comp = np.expand_dims(megaset[:, 1 + v_dimC + 1:1 + v_dimC + 1 + v_dimC], axis=2)
x_comp_dur = megaset[:, 1 + v_dimC + 1 + v_dimC:1 + v_dimC + 1 + v_dimC + 1]

data_x = [x_norm, x_norm_dur, x_comp, x_comp_dur]  # order must match model_01's inputs
data_y = x_labels

# model_01(print_summary, input_shape_N, input_shape_C, fl_filters, nos_output)
model = model_01(False, (v_dimC, 1), (v_dimC, 1), m_fl_filters, m_nos_output)
load_model_name = ds_model + '.h5'     # model used for testing
load_model_path = os.path.join(global_modeldir, load_model_name)
model.load_weights(load_model_path)
print('Loaded Model weights ' + str(load_model_path))

# manual prediction
print('Manual Prediction on : ' + ds_name)
predx = model.predict(data_x)          # (n_samples, m_nos_output) softmax probabilities
cmx2_global = predx.argmax(axis=1)     # predicted class index per beat
# Confusion matrix: rows = actual label, columns = predicted label.
cmx_global = np.zeros((len(g_LABELS), len(g_LABELS)), dtype='int32')
np.add.at(cmx_global, (data_y.astype(int), cmx2_global), 1)
print('\tConfusion Matrix')
print(print_conf_matrix(cmx_global, '', g_LABELS))
print_performance(get_performance(cmx_global), g_LABELS)
#------------------------------------------------------------

# Per-beat error markers for plotting: -1 where mis-classified, 0 elsewhere.
#   cmx_false_N: predicted Normal (0) but actually Abnormal (1)
#   cmx_false_A: predicted Abnormal (1) but actually Normal (0)
cmx_false_N = np.where((cmx2_global == 0) & (data_y == 1), -1.0, 0.0)
cmx_false_A = np.where((cmx2_global == 1) & (data_y == 0), -1.0, 0.0)


### [ 9.3.3_Plot_Test_Results ]

In [0]:
# Plot a window of the record's signal with each beat's prediction marked at its R-peak.

#<<---------------------------------------------Select Paper Resolution
x_scale = 25 * 0.0393701  # 25 mm/sec -> inches/sec
y_scale = 10 * 0.0393701  # 10 mm/mV -> inches/mV
y_low = -2.5
y_high = 3.5
#<<--------------------------------------------------------------------

#<<---------------------------------------------Select ECG Segment
fsec = 10
tsec = fsec + 90
dsec = tsec - fsec
#<<--------------------------------------------------------------------

# Sample indices of the window; assumes sel_sig is sampled at BASIC_SRATE — TODO confirm.
ff = fsec * BASIC_SRATE
tt = tsec * BASIC_SRATE
dd = tt - ff

bps = sel_sig[ff:tt]

# Beats whose R-peak time falls inside [fsec, tsec); the mask is computed once and
# reused for every per-beat array below.
in_win = (sbi.rr_peaks_sec >= fsec) & (sbi.rr_peaks_sec < tsec)
# R-peak positions rescaled from sel_rec.db.srate to BASIC_SRATE, as offsets into bps.
dticks = (sbi.rr_peaks[in_win] / sel_rec.db.srate) * BASIC_SRATE - ff
dlabels = sbi.rr_labels[in_win]
dcmx2 = cmx2_global[in_win]      # predicted class per beat (0 = Normal, 1 = Abnormal)
dfalseN = cmx_false_N[in_win]    # -1 where predicted Normal but actually Abnormal
dfalseA = cmx_false_A[in_win]    # -1 where predicted Abnormal but actually Normal

# A fresh figure each run, so figsize is always honoured.
fig, ax = plt.subplots(figsize=(dsec * x_scale, (y_high - y_low) * y_scale))
ax.set_xlim(0, len(bps))
ax.set_ylim(y_low, y_high)
ax.set_xticks(dticks)
ax.set_xticklabels(dlabels)  # annotated beat label at each R-peak
ax.grid(axis='x')

# Squares at y=1 / y=0 show the predicted class; crosses at y=-1 mark errors.
ax.scatter(dticks[dcmx2 == 1], dcmx2[dcmx2 == 1], marker='s', color='tab:red',
           label='predicted Abnormal')
ax.scatter(dticks[dcmx2 == 0], dcmx2[dcmx2 == 0], marker='s', color='tab:green',
           label='predicted Normal')
ax.scatter(dticks[dfalseN == -1], dfalseN[dfalseN == -1], marker='x', color='tab:red',
           label='false Normal (actually Abnormal)')
ax.scatter(dticks[dfalseA == -1], dfalseA[dfalseA == -1], marker='x', color='tab:green',
           label='false Abnormal (actually Normal)')

ax.plot(bps, linewidth=0.5, color='black')
ax.hlines(0, 0, len(bps), linewidth=0.3)
ax.set(title=sel_rec.name + ': ' + str(fsec) + '-' + str(tsec) + ' s', ylabel='amplitude')
ax.legend(loc='upper right', fontsize='small')
plt.show()
