Summary:
1. deal with the annotation

    1) Create an empty set.
    
    2) Read through ann.txt. For each N annotation, add the point and its ±20 neighbouring points to the set; skip X annotations.
   
    
2. segment and wavelet transform

    1) Randomly pick a midpoint and scan ±1.5 s around it for an R wave. If there is none, discard the midpoint.
    
    2) Keep track of the number of R and non-R samples selected. Each should be 7500 (total samples per record = 15000).
    
    
3. store in TFrecords


In [1]:
import os
import numpy as np
import tensorflow as tf
from Stationary_transform import *
from sklearn.model_selection import train_test_split

### Annotation Preprocessing

In [3]:
def annProcess(path,width=21):
    """Collect the annotated peak positions of one record.

    Reads ``path + '/ann.txt'``, a comma-separated file whose first field is a
    position and whose second field is an annotation type, and uses the size
    of ``path + '/ecg.dat'`` (2 bytes per sample) to bound the indices.

    Args:
        path: Directory holding ``ann.txt`` and ``ecg.dat``.
        width: Number of offsets ``0 .. width-1`` added on each side of every
            peak. The default of 21 covers the peak and +-20 samples.

    Returns:
        A tuple ``(R_set, R)``. ``R`` contains the converted position of every
        annotation that is not of type ``X``. ``R_set`` contains those
        positions plus their neighbourhoods: offsets below the peak are
        clipped at 0 and offsets above it are clipped at the ECG length.
    """
    # Only the sample count is needed, so derive it from the file size
    # instead of loading the whole recording ('>i2' is 2 bytes per sample).
    ecg_len = os.path.getsize(path+'/ecg.dat') // 2

    R = set()
    R_set = set()

    with open(path+'/ann.txt','r') as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank / trailing empty lines
            fields = line.split(',')
            # strip() so that a type in the last column ("123,X\n") still matches
            typ = fields[1].strip()
            if typ == 'X':
                continue
            # presumably converts milliseconds to samples at 256 Hz -- TODO confirm
            pos = round(int(fields[0])/1000*256)
            R.add(pos)
            R_set.add(pos)
            R_set.update(range(max(pos-width+1, 0), pos+1))
            R_set.update(range(pos, min(pos+width, ecg_len)))

    return R_set,R

### Sampling Non-noise Period

In [4]:
def validPeriod(n, R_set):
    """Check if at least 1 R wave is present in the period around ``n``.

    Returns True when any index within 383 samples of ``n`` (on either side,
    ``n`` itself included) is a member of ``R_set``; otherwise False.
    """
    return any(
        (n - offset) in R_set or (n + offset) in R_set
        for offset in range(384)
    )


# sampling noR and R period without noise
def sampling_normal(path,size_R,size_nR,R_set,record):
    """Randomly draw labelled 512-sample windows from one record.

    Midpoints are drawn with ``np.random.randint`` after seeding the global
    NumPy RNG with 40. A midpoint is used only if it has not been accepted
    before and ``validPeriod`` reports a peak nearby. It is labelled 1 when it
    is a member of ``R_set`` and 0 otherwise, and is stored only while the
    quota for its label is still open.

    The loop runs until both quotas are filled, so it does not terminate if
    the record cannot supply enough midpoints of either kind.

    Args:
        path: Record directory containing ``ecg.dat`` (big-endian int16).
        size_R: Number of label-1 windows to collect.
        size_nR: Number of label-0 windows to collect.
        R_set: Set of indices treated as belonging to a peak.
        record: Record identifier; concatenated into the progress message and
            stored (converted to float32) in column 512 of every row.

    Returns:
        ``(data_x, data_y)``: ``data_x`` is ``(size_R+size_nR, 514)`` float32
        with the window in columns 0-511, the record in column 512 and the
        midpoint in column 513; ``data_y`` holds the labels.
    """
    signal = np.fromfile(path +'/ecg.dat', '>i2')
    wanted_total = size_R+size_nR

    data_x = np.empty((wanted_total,514),'float32')
    data_y = np.empty((wanted_total),'int')

    quota = {1: size_R, 0: size_nR}
    taken = {1: 0, 0: 0}
    visited = set()  # midpoints that already passed the validity check
    count_total = 0

    np.random.seed(40)

    while taken[1]<quota[1] or taken[0]<quota[0]:

        if count_total%50==0:
            print(record + ' sampling '+str(count_total)+'/ '+str(wanted_total), end = '\r')

        n = np.random.randint(256,len(signal)-256)

        # skip midpoints seen before or without a peak within +-1.5s
        if n in visited or not validPeriod(n,R_set):
            continue
        visited.add(n)

        label = 1 if n in R_set else 0
        if taken[label]>=quota[label]:
            continue

        row = np.empty((514),'float32')
        row[:512] = signal[n-256:n+256]
        row[-2:] = [record,n]
        data_x[count_total] = row
        data_y[count_total] = label
        taken[label] += 1
        count_total += 1

    return data_x,data_y

### Sampling True Noise

In [6]:
# sampling true noise from each record
def sampling_noise(path,R,record):
    """Cut label-0 windows out of long gaps between consecutive peaks.

    Every pair of neighbouring positions in ``R`` that is more than 1064
    samples apart (1024 plus the +-20 margin around each peak) is treated as
    a noisy stretch. Window midpoints are taken every 25 samples from just
    after the first peak up to the second one.

    Midpoints whose 512-sample window would not fit inside the recording are
    skipped, and collection stops once 150000 windows have been gathered.

    Args:
        path: Record directory containing ``ecg.dat`` (big-endian int16).
        R: Iterable of peak positions (sample indices).
        record: Record identifier stored (as float32) in column 512.

    Returns:
        ``(x_noise, y_noise, count_total)``. ``x_noise`` is
        ``(count_total, 514)`` float32 with the window in columns 0-511, the
        record in column 512 and the midpoint in column 513; ``y_noise`` is
        all zeros. When no window is found the result is ``([], [], 0)``.
    """
    ecg = np.fromfile(path +'/ecg.dat', '>i2')
    ecg_len = len(ecg)

    max_samples = 150000  # arbitrary upper bound on windows taken per record
    peaks = sorted(R)
    np.random.seed(40)  # kept: callers may rely on the global RNG being reset here

    # Rows are accumulated in a list so that no oversized buffer is allocated
    # up front when the number of windows is unknown.
    rows = []
    for prev, loc in zip(peaks, peaks[1:]):
        if loc-prev<=1064:
            continue
        if len(rows)>=max_samples:
            break
        for n in range(prev+1,loc,25):
            if len(rows)>=max_samples:
                break
            if n-256<0 or n+256>ecg_len:
                continue  # window would run off the recording
            x = np.empty((514),'float32')
            x[:512] = ecg[n-256:n+256]
            x[-2:] = [record,n]
            rows.append(x)

    count_total = len(rows)
    print('     count_total '+str(count_total))

    if count_total ==0: # when no noise in the record
        return [],[],0

    x_noise = np.stack(rows)
    y_noise = np.zeros((count_total),'int')

    return x_noise,y_noise,count_total

### Sampling Regular Noise

In [7]:
# Create sine/square/triangle waves as regular noise

def get_sin(freq_low, freq_high, size=747):
    """Generate sine-wave rows labelled 0.

    For every integer ``freq`` in ``[freq_low, freq_high]`` one row is made
    per integer phase shift ``phi`` in ``range(int(256/freq))``. Each row is
    ``sin(B*t - B*phi)`` with ``B = 2*pi*freq/256`` and ``t = 0..511``,
    followed by two zero columns.

    The default ``size`` of 747 is exactly the number of rows produced by the
    range 1..10. If the range yields fewer rows the remaining ones stay zero;
    if it yields more, generation stops after ``size`` rows.

    Args:
        freq_low: Lowest frequency (positive integer).
        freq_high: Highest frequency, inclusive.
        size: Number of rows in the returned arrays.

    Returns:
        ``(data_x, data_y)`` with shapes ``(size, 514)`` float32 and
        ``(size,)`` int (all zeros).
    """
    # Zero-initialised so that rows which are never written are well defined.
    data_x, data_y = np.zeros((size,514),'float32'),np.zeros((size),'int')

    count = 0
    time = np.arange(512)
    for freq in range(freq_low, freq_high+1):  # get sine with multiple frequency and phase
        B = 2*np.pi*freq/256
        for phi in range(int(256/freq)):
            if count>=size:
                return data_x,data_y
            print('Get sine '+str(count)+'/ '+str(size), end = '\r')
            sine = np.sin(B*time-B*phi)
            data_x[count] = np.append(sine,[0,0])
            data_y[count] = 0
            count+=1
    return data_x,data_y
    
def get_square(freq_low, freq_high, size=747):
    """Generate square-wave rows labelled 0.

    For every integer ``freq`` in ``[freq_low, freq_high]`` a 1024-sample
    wave alternating between +1 and -1 is built from ``8*freq`` segments of
    ``int(256/(2*freq))`` samples each (any remainder stays 0). One row is
    then cropped per integer phase ``phi`` in ``range(int(256/freq))`` as
    ``wave[phi:phi+512]``, followed by two zero columns.

    The default ``size`` of 747 is exactly the number of rows produced by the
    range 1..10. If the range yields fewer rows the remaining ones stay zero;
    if it yields more, generation stops after ``size`` rows.

    Args:
        freq_low: Lowest frequency (positive integer).
        freq_high: Highest frequency, inclusive.
        size: Number of rows in the returned arrays.

    Returns:
        ``(data_x, data_y)`` with shapes ``(size, 514)`` float32 and
        ``(size,)`` int (all zeros).
    """
    # Zero-initialised so that rows which are never written are well defined.
    data_x, data_y = np.zeros((size,514),'float32'),np.zeros((size),'int')
    count = 0

    for freq in range(freq_low, freq_high+1):
        # build a length-1024 wave; only 512-sample crops of it are used
        square = np.zeros(1024,'float32')
        half_period = int(256/(2*freq))
        for seg in range(4*2*freq):
            level = 1 if seg%2==0 else -1
            square[seg*half_period:(seg+1)*half_period] = level

        for phi in range(int(256/freq)): # sliding window selects the phase
            if count>=size:
                return data_x, data_y
            print('Get square '+str(count)+'/ '+str(size), end = '\r')
            data_x[count]=np.append(square[phi:phi+512],[0,0])
            data_y[count] = 0
            count+=1
    return data_x, data_y
    
def get_triangle(freq_low, freq_high,size=747):
    """Generate triangle-wave rows labelled 0.

    For every integer ``freq`` in ``[freq_low, freq_high]`` one period of
    ``256//freq`` samples is drawn (rising 0 -> 1, falling 1 -> -1, rising
    -1 -> 0) and repeated ``4*freq`` times inside a 1024-sample buffer. One
    row is then cropped per integer phase ``phi`` in ``range(int(256/freq))``
    as ``wave[phi:phi+512]``, followed by two zero columns.

    The default ``size`` of 747 is exactly the number of rows produced by the
    range 1..10. If the range yields fewer rows the remaining ones stay zero;
    if it yields more, generation stops after ``size`` rows.

    Args:
        freq_low: Lowest frequency (positive integer).
        freq_high: Highest frequency, inclusive.
        size: Number of rows in the returned arrays.

    Returns:
        ``(data_x, data_y)`` with shapes ``(size, 514)`` float32 and
        ``(size,)`` int (all zeros).
    """
    # Zero-initialised so that rows which are never written are well defined.
    data_x, data_y = np.zeros((size,514),'float32'),np.zeros((size),'int')
    count = 0

    for freq in range(freq_low, freq_high+1):
        # build a length-1024 triangle wave; only 512-sample crops are used
        tran = np.zeros(1024,'float32')
        part = 256/(4*freq)      # length of a quarter period (may be fractional)
        slope = 1/(256/(4*freq))
        period = 256//freq
        for i in range(period):  # draw a single period
            quarter = (i//part)%4
            if quarter==0:
                tran[i] = slope*i
            elif quarter==3:
                tran[i] = slope*i+(-1-slope*(3*part))
            else:
                tran[i] = -slope*i+slope*(part*2)
        for rep in range(1,4*freq): # repeat the period to fill the buffer
            tran[rep*period:(rep+1)*period] = tran[:period]

        for phi in range(int(256/freq)): # sliding window selects the phase
            if count>=size:
                return data_x, data_y
            print('Get triangle '+str(count)+'/ '+str(size), end = '\r')
            data_x[count]=np.append(tran[phi:phi+512],[0,0])
            data_y[count] = 0
            count+=1
    return data_x, data_y

### TFRecord Storage

In [8]:
def putSample(writer, size, data_x, data_y):
    """Serialise the first ``size`` samples as ``tf.train.Example`` records.

    For each row, everything except the last two columns is passed to
    ``decomp(..., 'db2', (512,8))`` and the result is flattened to 4096
    values; the row's last two columns are appended unchanged, giving a
    4098-element float32 vector stored under ``'ecg'``. The matching entry of
    ``data_y`` is stored under ``'label'``.

    Args:
        writer: Object with a ``write(bytes)`` method that receives each
            serialised example.
        size: Number of leading rows to store.
        data_x: Indexable collection of sample rows.
        data_y: Indexable collection of integer labels.
    """
    for idx in range(size):

        if idx%500 ==0:
            print('      store '+str(idx)+'/ '+str(size), end = '\r')

        label = data_y[idx]
        row = data_x[idx]

        # decompose the signal part, then re-attach the two trailing columns
        packed = np.empty((4098),'float32')
        packed[:4096] = decomp(row[:-2],'db2',(512,8)).reshape(4096)
        packed[-2:] = row[-2:]

        ecg_feature = tf.train.Feature(float_list=tf.train.FloatList(value=packed))
        label_feature = tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
        features = tf.train.Features(feature = {'ecg': ecg_feature, 'label': label_feature})

        example = tf.train.Example(features = features)
        writer.write(example.SerializeToString())

### Other Helper

In [9]:
def findSubdir(path):
    """Return the names of the directories directly inside ``path``.

    Only the first entry produced by ``os.walk`` is consumed, so the rest of
    the tree below ``path`` is not traversed.

    Raises:
        IndexError: If ``os.walk`` yields nothing for ``path`` (for example
            when the directory does not exist).
    """
    try:
        _, subdirs, _ = next(os.walk(path))
    except StopIteration:
        raise IndexError('no directory listing available for ' + repr(path)) from None
    return subdirs

def shuffle_and_split(data_x,data_y,seed = 10,size = 0.03):
    """Shuffle samples and labels, then split off a validation part.

    ``data_x`` and ``data_y`` are shuffled in place. The global NumPy RNG is
    re-seeded with ``seed`` before each shuffle, which is meant to apply the
    same permutation to both so that rows and labels stay aligned. The split
    itself is done by ``train_test_split``.

    Args:
        data_x: Array of samples (modified in place).
        data_y: Array of labels (modified in place).
        seed: Seed used for both in-place shuffles.
        size: Value passed to ``train_test_split`` as ``test_size``.

    Returns:
        ``(x_train, x_val, y_train, y_val)``.
    """
    np.random.seed(seed)
    np.random.shuffle(data_x)
    np.random.seed(seed)
    np.random.shuffle(data_y)
    # honour the caller's ``size`` instead of a hard-coded 0.03
    x_train, x_val, y_train, y_val = train_test_split(data_x,data_y, test_size=size)
    return x_train, x_val, y_train, y_val

### Main

In [10]:
# Collect every record directory under db/lydhdb (the plot directory is not a record)
lydhdb = findSubdir('db/lydhdb')
lydhdb.remove('lorenz_plots')
# NOTE(review): `sort` is not defined in this file; presumably it arrives via
# `from Stationary_transform import *` -- confirm.
lydhdb = sort(lydhdb)

total_noise = 0  # running count of true-noise windows over all records

# One shared validation file; training data is spread over nine files.
# Both writers are closed by their `with` blocks, so no explicit close() is needed.
with tf.io.TFRecordWriter('/tmpdata/val.tfrecords') as val_writer:
    for i in range(0,9):
        with tf.io.TFRecordWriter('/tmpdata/train_'+str(i+1)+'.tfrecords') as train_writer:
            # eight records per training file; the last file takes whatever remains
            records = lydhdb[i*8:(i+1)*8] if i!=8 else lydhdb[i*8:]
            for record in records:
                path = 'db/lydhdb/'+record

                R_set,R = annProcess(path)

                x_normal, y_normal = sampling_normal(path,15500,10800,R_set,record) # sampling R and nR
                print(record + ' normal sampling finished')
                x_train, x_val, y_train, y_val = shuffle_and_split(x_normal, y_normal,seed = 10,size = 0.03)

                putSample(train_writer, len(x_train), x_train,y_train) # store samples to TFrecord
                putSample(val_writer, len(x_val), x_val,y_val)
                print('      normal store finished       ')


                x_noise, y_noise,count_noise = sampling_noise(path,R,record) # sampling true noise
                total_noise+=count_noise
                print('      noise sampling finished')
                if len(x_noise)!=0:
                    x_train, x_val, y_train, y_val = shuffle_and_split(x_noise,y_noise,seed = 10,size = 0.03)

                    putSample(train_writer, len(x_train), x_train,y_train) # store samples to TFrecord
                    putSample(val_writer, len(x_val), x_val,y_val)
                    print('     noise store finished       ')

            if i==8:
                # synthetic regular noise goes into the last training file only
                x_sine,y_sine = get_sin(1,10)
                x_square, y_square = get_square(1,10)
                x_tran, y_tran = get_triangle(1,10)
                putSample(train_writer, len(x_sine), x_sine,y_sine)
                putSample(train_writer, len(x_square), x_square, y_square)
                putSample(train_writer, len(x_tran), x_tran, y_tran)
                print('sine, square, traingle finished')

20000 normal sampling finished
      normal store finished       
     count_total 0
      noise sampling finished
20001 normal sampling finished
      normal store finished       
     count_total 125
      noise sampling finished
     noise store finished       
20002 normal sampling finished
      normal store finished       
     count_total 0
      noise sampling finished
20003 normal sampling finished
      normal store finished       
     count_total 902
      noise sampling finished
     noise store finished       
20004 normal sampling finished
      normal store finished       
     count_total 1567
      noise sampling finished
     noise store finished       
20005 normal sampling finished
      normal store finished       
     count_total 0
      noise sampling finished
20006 normal sampling finished
      normal store finished       
     count_total 43
      noise sampling finished
     noise store finished       
20007 normal sampling finished
      normal store finis