In [1]:
# Data processing and model generation process
import hdf5storage
import numpy as np
from scipy.signal import butter, lfilter
import os
from scipy import signal
import matplotlib.pyplot as plt
from datetime import datetime

%matplotlib inline

def Standardization(Epochs):
    """Z-score every channel of an epoch array.

    For each index ``ch`` along axis 1, the slice ``Epochs[:, ch, :]`` is
    shifted by its mean and then divided by its standard deviation, both
    computed over all elements of that slice.

    Args:
        Epochs: NumPy array indexed as ``Epochs[:, channel, :]``.

    Returns:
        The standardized array. Floating-point (or complex) input is
        modified in place and the same object is returned. Any other dtype
        is first converted to a float copy, because writing the fractional
        results back into an integer array would silently truncate them.

    Note:
        A slice whose standard deviation is zero is divided by zero, so the
        result for that slice is not finite.
    """
    if not np.issubdtype(Epochs.dtype, np.inexact):
        Epochs = Epochs.astype(float)

    for ch in range(Epochs.shape[1]):
        # Basic slicing yields a view, so `channel` reflects each write-back.
        channel = Epochs[:, ch, :]
        Epochs[:, ch, :] = channel - np.mean(channel)
        Epochs[:, ch, :] = channel / np.std(channel)

    return Epochs

def Re_referencing(eegData, channelNum, sampleNum):
    """Subtract the across-channel mean from each channel (common average).

    Args:
        eegData: 2-D array indexed as ``eegData[channel, sample]``.
        channelNum: Number of leading rows of ``eegData`` to re-reference.
        sampleNum: Number of samples per row; must equal
            ``eegData.shape[1]`` for the row assignment to succeed.

    Returns:
        A new ``(channelNum, sampleNum)`` float array in which row ``i`` is
        ``eegData[i]`` minus the per-sample mean over all rows of
        ``eegData``. The input is not modified.
    """
    # The reference is the same for every channel, so compute it only once
    # instead of once per loop iteration.
    common_average = np.mean(eegData, axis=0)

    after_car = np.zeros((channelNum, sampleNum))
    for i in np.arange(channelNum):
        after_car[i, :] = eegData[i, :] - common_average
    return after_car

def butter_bandpass(lowcut, highcut, fs, order=5):
    """Design a Butterworth band-pass filter in transfer-function form.

    Args:
        lowcut: Lower band edge, in the same units as ``fs``.
        highcut: Upper band edge, in the same units as ``fs``.
        fs: Sampling frequency.
        order: Order passed to ``scipy.signal.butter``.

    Returns:
        The ``(b, a)`` coefficient pair produced by ``butter``.
    """
    # butter() expects band edges as fractions of the Nyquist frequency.
    nyquist = 0.5 * fs
    normalized_band = [lowcut / nyquist, highcut / nyquist]
    numerator, denominator = butter(order, normalized_band, btype='band')
    return numerator, denominator
def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
    """Apply a causal Butterworth band-pass filter along the last axis.

    The filter is designed and applied as cascaded second-order sections
    rather than as a single ``(b, a)`` transfer function. The transfer
    function form loses precision when a band edge is a tiny fraction of the
    Nyquist frequency (for example 0.1 Hz at several hundred Hz sampling),
    which can distort or destabilise the output; second-order sections avoid
    that while describing the same filter.

    Args:
        data: Array to filter; filtering runs along its last axis.
        lowcut: Lower band edge, in the same units as ``fs``.
        highcut: Upper band edge, in the same units as ``fs``.
        fs: Sampling frequency.
        order: Butterworth order passed to the designer.

    Returns:
        The filtered array, same shape as ``data``.
    """
    nyq = 0.5 * fs
    sos = signal.butter(order, [lowcut / nyq, highcut / nyq],
                        btype='band', output='sos')
    # One forward pass with zero initial state, as lfilter would do.
    return signal.sosfilt(sos, data)
def Make_Average_Component(EpochsT, NumT, EpochsN, NumN, channelNum, epochSampleNum, componentNum):
    """Standardize both epoch sets and replace them with sliding-window means.

    Each set is passed through ``Standardization`` and then averaged over
    windows of ``componentNum`` consecutive epochs, the window advancing one
    epoch at a time.

    Args:
        EpochsT: Target epochs, indexed ``[epoch, channel, sample]``.
        NumT: Number of target epochs.
        EpochsN: Non-target epochs, same layout.
        NumN: Number of non-target epochs.
        channelNum: Channels per epoch.
        epochSampleNum: Samples per epoch.
        componentNum: Number of consecutive epochs averaged per window.

    Returns:
        ``[EpochsT_Aver, NumT_Aver, EpochsN_Aver, NumN_Aver]`` where each
        count is the original count minus ``componentNum`` and each array has
        shape ``(count, channelNum, epochSampleNum)``.
    """
    def _windowed_mean(epochs, windowCount):
        # Mean over epochs[start : start + componentNum] for each start.
        averaged = np.zeros((windowCount, channelNum, epochSampleNum))
        for start in range(windowCount):
            stop = start + componentNum
            averaged[start, :, :] = np.mean(epochs[start:stop, :, :], axis=0)
        return averaged

    EpochsT = Standardization(EpochsT)
    EpochsN = Standardization(EpochsN)

    NumT_Aver = NumT - componentNum
    NumN_Aver = NumN - componentNum

    EpochsT_Aver = _windowed_mean(EpochsT, NumT_Aver)
    EpochsN_Aver = _windowed_mean(EpochsN, NumN_Aver)

    return [EpochsT_Aver, NumT_Aver, EpochsN_Aver, NumN_Aver]


def Epoching(eegData, stims, samplingFreq, channelNum, epochSampleNum, offset, baseline):
    """Cut fixed-length epochs out of a continuous recording.

    Epoch ``j`` covers samples ``[start, start + epochSampleNum)`` where
    ``start = int(stims[0][j] + offset)``.

    Args:
        eegData: 2-D array indexed as ``eegData[channel, sample]``.
        stims: Array of shape ``(1, n)`` holding stimulus sample positions.
        samplingFreq: Unused; kept for call compatibility.
        channelNum: Number of channels in each epoch.
        epochSampleNum: Number of samples in each epoch.
        offset: Samples added to every stimulus position to get the start.
        baseline: Unused; kept for call compatibility.

    Returns:
        ``[Epochs, Num]`` where ``Epochs`` has shape
        ``(Num, channelNum, epochSampleNum)`` and ``Num`` is ``stims.shape[1]``.

    Raises:
        ValueError: If an epoch window starts before sample 0 or runs past
            the end of ``eegData``. A negative start would otherwise be
            read as an index from the end and return unrelated data.
    """
    Time_after = np.add(stims, offset).astype(int)
    Num = stims.shape[1]
    totalSamples = eegData.shape[1]

    Epochs = np.zeros((Num, channelNum, epochSampleNum))
    for j in range(Num):
        start = Time_after[0][j]
        stop = start + epochSampleNum
        if start < 0 or stop > totalSamples:
            raise ValueError(
                "epoch {0} spans samples [{1}, {2}) but the recording has "
                "{3} samples".format(j, start, stop, totalSamples))
        Epochs[j, :, :] = eegData[:, start:stop]

    return [Epochs, Num]


def resampling(Epochs, EpochNum, resampleRate, channelNum):
    """Resample every channel of every epoch to a fixed number of samples.

    Args:
        Epochs: Array indexed as ``Epochs[epoch, channel, sample]``.
        EpochNum: Number of epochs to process.
        resampleRate: Number of samples in each output trace (passed to
            ``scipy.signal.resample`` as the target length).
        channelNum: Number of channels to process per epoch.

    Returns:
        A new float array of shape ``(EpochNum, channelNum, resampleRate)``.
    """
    output = np.zeros((EpochNum, channelNum, resampleRate))
    # Visit every (epoch, channel) pair in row-major order.
    for epochIdx, chanIdx in np.ndindex(EpochNum, channelNum):
        trace = Epochs[epochIdx, chanIdx, :]
        output[epochIdx, chanIdx, :] = signal.resample(trace, resampleRate)
    return output

def plotGraph(filename):
    """Load one recording and return its resampled target/non-target epochs.

    Despite the name, nothing is plotted. The steps are: load the file,
    keep seven channels, re-reference, band-pass filter (0.1-30, order 4),
    cut one-second epochs at each stimulus, and resample every epoch to 300
    samples.

    Args:
        filename: Path handed to ``hdf5storage.loadmat``. The file must
            provide ``eegData``, ``samplingFreq``, ``stimsT`` and ``stimsN``.

    Returns:
        ``[EpochsT_Aver, EpochsN_Aver]``: resampled target and non-target
        epochs, each of shape ``(n_epochs, 7, 300)``.
    """
    channelNum = 7
    resampleRate = 300  # samples per epoch after resampling

    mat = hdf5storage.loadmat(filename)
    eegData = mat['eegData']
    samplingFreq = mat['samplingFreq'][0,0]
    stimsN = mat['stimsN']
    stimsT = mat['stimsT']

    # vr300 7 channel
    # [P4, Fz, Pz, P3, PO8, PO7, Oz]
    # [19, 31, 13, 12, 20, 11, 16]
    # The index list below is those numbers minus one (presumably 1-based
    # channel numbers converted to 0-based row indices -- confirm).
    channelIndex = [18, 30, 12, 11, 19, 10, 15]
    eegData = eegData[channelIndex]

    ## Preprocessing process
    eegData = Re_referencing(eegData, channelNum, eegData.shape[1])

    # Bandpass Filter
    eegData = butter_bandpass_filter(eegData, 0.1, 30, samplingFreq, 4)

    # Epoching: one second of data starting at each stimulus.
    epochSampleNum = int(np.floor(1.0 * samplingFreq))
    offset = int(np.floor(0.0 * samplingFreq))
    baseline = int(np.floor(1.0 * samplingFreq))
    [EpochsT, NumT] = Epoching(eegData, stimsT, samplingFreq, channelNum, epochSampleNum, offset, baseline)
    [EpochsN, NumN] = Epoching(eegData, stimsN, samplingFreq, channelNum, epochSampleNum, offset, baseline)

    # Sliding-window averaging is currently disabled:
    #[EpochsT_Aver, NumT_Aver, EpochsN_Aver, NumN_Aver] = Make_Average_Component(EpochsT, NumT, EpochsN, NumN, channelNum, epochSampleNum, 0)
    EpochsT_Aver = resampling(EpochsT, NumT, resampleRate, channelNum)
    EpochsN_Aver = resampling(EpochsN, NumN, resampleRate, channelNum)

    return [EpochsT_Aver, EpochsN_Aver]


def main():
    """Preprocess every subject, cache the result, and export it as .npz.

    If both cache text files exist they are loaded; otherwise every subject
    file ``S01`` .. ``S55`` (minus those listed in ``subject_num``) is run
    through ``plotGraph`` and the stacked epochs are written to the cache.
    The target and non-target arrays are then printed by shape and saved to
    ``./resampled_data_.npz``.
    """
    channelNum = 7 # (n_components)
    epochSampleNum = 300
    epochNum = 150        # target epochs stored per subject
    nonTargetNum = 750    # non-target epochs stored per subject
    subjectTotal = 55
    subject_num = []      # subject numbers to leave out
    root = 'D:\\P300_biosemi_55\\S'

    T_path = 'D:\\P300_biosemi_55\\New\\T_all.out'
    N_path = 'D:\\P300_biosemi_55\\New\\N_all.out'

    includedSubjects = subjectTotal - len(subject_num)
    T_shape = (epochNum * includedSubjects, channelNum, epochSampleNum)
    N_shape = (nonTargetNum * includedSubjects, channelNum, epochSampleNum)

    # Both cache files are needed; rebuild if either one is missing rather
    # than failing on the load of the absent file.
    if not (os.path.exists(T_path) and os.path.exists(N_path)):
        T_all = np.zeros(T_shape)
        N_all = np.zeros(N_shape)

        slot = 0  # position of the next processed subject in the stacks
        for i in np.arange(1, subjectTotal + 1):
            if i in subject_num:
                continue
            # Subject files are numbered with two digits: S01, S02, ...
            filename = root + str(i).zfill(2)
            T_epochs, N_epochs = plotGraph(filename)
            T_all[epochNum*slot:epochNum*(slot+1), :, :] = T_epochs
            N_all[nonTargetNum*slot:nonTargetNum*(slot+1), :, :] = N_epochs
            slot += 1
            print("subject {0} is preprocessed".format(str(i)))

        # savetxt only handles 1-D/2-D data, so flatten each epoch to a row.
        np.savetxt(T_path, np.reshape(T_all, (T_shape[0], channelNum*epochSampleNum)))
        np.savetxt(N_path, np.reshape(N_all, (N_shape[0], channelNum*epochSampleNum)))
    else:
        T_all = np.reshape(np.loadtxt(T_path), T_shape)
        N_all = np.reshape(np.loadtxt(N_path), N_shape)

    target_data = T_all
    non_target_data = N_all
    print(target_data.shape)
    print(non_target_data.shape)

    # Optional channel-first layout (disabled):
    # target_data = np.transpose(target_data, [1, 0, 2])
    # non_target_data = np.transpose(non_target_data, [1, 0, 2])
    output_file = "./resampled_data_.npz"
    # asarray avoids duplicating these large arrays when they are already float.
    np.savez(output_file, np.asarray(target_data, dtype=float), np.asarray(non_target_data, dtype=float))

    
# Run the preprocessing pipeline only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()


(8250, 7, 300)
(41250, 7, 300)
