In [None]:
import datetime
import os
import shutil

import mne
import numpy as np
import pandas as pd
from matplotlib.pyplot import savefig
from mne.preprocessing import ICA
from mne.preprocessing.ica import corrmap
from scipy import signal

In [None]:
class ICA:
    """Interactive labelling and removal of ICA components for EEG clips.

    Two-phase workflow, both phases keyed by ``(subj_id, clip_id)``:

    * ``label_exclude_list`` fits an ICA, plots it, asks the user which
      components to drop and appends the answer to the CSV log.
    * ``remove_exclude_components`` looks the answer up in that log and
      returns the clip with those components removed.

    Note: this class has the same name as ``mne.preprocessing.ICA``; the MNE
    implementation is therefore always referenced by its fully qualified name.
    """

    # Every write to the log is followed by a timestamped copy into this folder.
    _BACKUP_DIR = './log-exclude-list/'

    def __init__(self, channel_list, sampling_rate, log_filename):
        """Store the recording layout and load any existing exclusion log.

        Args:
            channel_list: channel names, one per row of the EEG arrays.
            sampling_rate: sampling frequency passed to ``mne.create_info``.
            log_filename: path of the CSV log of excluded components.
        """
        self.__ch_names = channel_list
        self.__sampling_rate = sampling_rate
        self.__no_of_channels = len(self.__ch_names)

        # Channel grouping by brain region for the montage
        # ['Fp1', 'Fp2', 'Fz', 'Cz', 'T3', 'T4', 'Pz', 'Oz'].
        self.__chs_dict = {
            'Frontal': ["Fz"],
            'CP': ["Cz"],
            'Temporal': ['T3', 'T4'],
            'OP': ["Oz", 'Pz'],
            'FP': ['Fp1', 'Fp2'],
        }
        self.__info = mne.create_info(self.__ch_names, self.__sampling_rate,
                                      ch_types=["eeg"] * self.__no_of_channels)
        self.__log_filename = log_filename

        # None when the log cannot be read yet; it is re-read lazily on first use.
        self.__exclude_df = self.__read_exclude_log()

    def __read_exclude_log(self):
        """Read the CSV log; return a DataFrame, or None if it is unreadable."""
        try:
            df = pd.read_csv(self.__log_filename)
        except (OSError, ValueError):
            # pandas' EmptyDataError / ParserError are ValueError subclasses.
            print('Cannot read from', self.__log_filename)
            return None
        # The log is written with ", " separators, so header names carry a
        # leading space (' clip_id'); normalise them once here.
        df.columns = [str(col).strip() for col in df.columns]
        return df

    # find min max of each channel in ica components
    # imported from: Karis Matchaparn
    def find_exclude_ica(self, ica_dict, channel, focus_part):
        """Return indices of components not dominated by ``focus_part``.

        Args:
            ica_dict: indexable by ``0..len-1``; each item maps channel name
                to that channel's value in the component.
            channel: mapping of brain-part name to list of channel names.
            focus_part: key of ``channel`` to compare against the others.

        Returns:
            List of component indices for which at least one brain part has a
            strictly larger channel sum than ``focus_part``. The list is also
            printed.
        """
        exclude_list = []
        for ica_idx in range(len(ica_dict)):
            weights = ica_dict[ica_idx]
            sum_chs = {part: sum(weights[ch] for ch in names)
                       for part, names in channel.items()}
            focus_sum = sum_chs[focus_part]
            if any(focus_sum < part_sum for part_sum in sum_chs.values()):
                exclude_list.append(ica_idx)

        print(exclude_list)
        return exclude_list

    def __make_raw(self, eeg):
        """Wrap a (channels x samples) array in an MNE Raw with the 10-20 montage."""
        raw = mne.io.RawArray(eeg, self.__info)
        # NOTE(review): read_montage only exists in older MNE releases — confirm
        # the pinned MNE version before upgrading.
        raw.set_montage(mne.channels.read_montage("standard_1020"))
        return raw

    def __fit_ica(self, raw):
        """Fit an extended-infomax ICA (fixed seed) on a copy of ``raw``."""
        ica = mne.preprocessing.ICA(method="extended-infomax", random_state=1)
        ica.fit(raw.copy())
        return ica

    @staticmethod
    def __parse_components(answer, max_idx):
        """Parse "a,b,c" into ints, each in ``[-1, max_idx]``.

        Raises:
            ValueError: on an empty/non-numeric token or out-of-range index.
        """
        picks = [int(token) for token in str(answer).split(',')]
        for pick in picks:
            if pick < -1 or pick > max_idx:
                raise ValueError('component index out of range: %d' % pick)
        return picks

    def label_exclude_list(self, eeg, subj_id, clip_id):
        """Plot the ICA of one clip, ask which components to drop, log the answer.

        The prompt repeats until the answer parses as comma-separated integers
        in the valid range. Only parse errors trigger a retry, so an interrupt
        or an I/O failure while writing the log is no longer swallowed.

        Returns:
            The list of component indices entered (``[-1]`` meaning none).
        """
        raw = self.__make_raw(eeg)
        ica = self.__fit_ica(raw)

        ica.plot_components(inst=raw)
        ica.plot_sources(raw)

        # Highest valid index comes from the fitted model instead of a fixed 7.
        max_idx = ica.n_components_ - 1
        prompt = ("Which components do you want to delete? (0-%d or -1 if no component) "
                  "(put , between each component no.): " % max_idx)

        while True:
            answer = input(prompt)
            try:
                exclude_list = self.__parse_components(answer, max_idx)
            except ValueError:
                print('Error: Please try to input component no. again (-1 to %d only)' % max_idx)
                continue
            break

        self.__log(subj_id, clip_id, '|'.join([str(x) for x in exclude_list]))
        print("Component to delete =", exclude_list)
        return exclude_list

    def __lookup_exclude(self, subj_id, clip_id):
        """Return the logged exclusion list for one clip.

        Raises:
            RuntimeError: if the log file cannot be read.
            IndexError: if the log has no row for this subject/clip.
        """
        if self.__exclude_df is None:
            self.__exclude_df = self.__read_exclude_log()
        df = self.__exclude_df
        if df is None:
            raise RuntimeError('Exclusion log %r is not available' % (self.__log_filename,))

        rows = df.loc[(df['subj_id'] == subj_id) & (df['clip_id'] == clip_id)]
        if rows.empty:
            raise IndexError('No exclusion entry for subj_id=%r, clip_id=%r' % (subj_id, clip_id))

        value = rows['exclude'].values[0]
        # pandas yields a string for "1|2" but a number when every row holds a
        # single index, so only split when there is something to split.
        tokens = value.split('|') if isinstance(value, str) else [value]
        return [int(token) for token in tokens]

    def remove_exclude_components(self, eeg, subj_id, clip_id):
        """Return the clip with its logged ICA components removed.

        Args:
            eeg: (channels x samples) array for one clip.
            subj_id, clip_id: key of the clip in the exclusion log.

        Returns:
            The data of the (possibly corrected) Raw object as returned by
            ``get_data()``. When the log says ``-1`` the data is returned
            without fitting an ICA at all.
        """
        exclude_list = self.__lookup_exclude(subj_id, clip_id)
        raw = self.__make_raw(eeg)

        if -1 in exclude_list:
            print('No component to delete.')
            return raw.get_data()

        print("Delete components :", exclude_list)
        ica = self.__fit_ica(raw)
        ica.exclude = exclude_list  # select components to exclude

        # Apply on a copy, as the original code did, so ``raw`` stays untouched.
        raw_corrected = raw.copy()
        ica.apply(raw_corrected)
        return raw_corrected.get_data()

    def __log(self, *msg):
        """Append one record to the log and back the log up.

        Each item of ``msg`` is written followed by ", " (subject id, clip id,
        excluded components joined with "|"). The whole log is then copied to
        a timestamped file in ``_BACKUP_DIR``, which is created if missing.
        """
        line = ''.join(str(m) + ', ' for m in msg)
        with open(self.__log_filename, "a") as log_file:
            log_file.write(line + '\n')
        print(msg)

        os.makedirs(self._BACKUP_DIR, exist_ok=True)
        now = datetime.datetime.now()
        shutil.copyfile(self.__log_filename,
                        self._BACKUP_DIR + 'ica_exclude_part' + str(now) + '.csv')

In [None]:
class EEGPreprocessing:
    """Load raw EEG clips, derive labels, and build band-filtered features.

    The pipeline per clip is: common average reference -> ICA component
    removal (driven by a CSV log that the user fills in interactively) ->
    one Butterworth band-pass per entry of the band list.
    """

    def __init__(self, clip_duration=56):
        """Set paths, emotion names, frequency bands and channel layout.

        Args:
            clip_duration: divisor used to derive the sampling rate from the
                number of samples per clip (defaults to the previously
                hard-coded 56).
        """
        self.a = 6  # not read anywhere in the visible code; kept for compatibility
        self.__clip_duration = clip_duration
        self.__input_path = '../data/EEG/'
        self.__emotions = ['Happiness', 'Fear', 'Excitement', \
                           'Arousal', 'Valence', 'Reward']
        self.__band_list = [{ 'name': 'theta', 'low': 3, 'high': 7},
                            { 'name': 'alpha', 'low': 8, 'high': 13},
                            { 'name': 'beta', 'low': 14, 'high': 29},
                            { 'name': 'gamma', 'low': 30, 'high': 47},
                           ]
        self.__channel_list = ['Fp1', 'Fp2', 'Fz', 'Cz', 'T3', 'T4', 'Pz', 'Oz']
        self.__log_filename = './ica_exclude_part.csv'

    def read_signals(self):
        """Load ``raw/EEG.npy`` from the input folder and print its dimensions."""
        self.__data = np.load(os.path.join(self.__input_path, 'raw/EEG.npy'))
        self.get_input_data_info()

    def set_labels(self, thresholds = None, emotions = None):
        """Turn the raw scores in ``raw/result.npy`` into class labels on disk.

        Args:
            thresholds: ``None``/empty -> search thresholds 1..9 per emotion for
                the most balanced binary split; one value -> binary labels
                (``> t`` is 1, otherwise 0); two values ``[low, high]`` ->
                three classes (``< low`` is 0, ``[low, high)`` is 1,
                ``>= high`` is 2).
            emotions: emotion names to process in three-class mode; ``None``
                means all emotions.
        """
        self.__label = np.load(os.path.join(self.__input_path, 'raw/result.npy'))

        if thresholds and len(thresholds) == 1:
            self.__save_binary_labels(thresholds)
        elif thresholds:
            self.__save_three_class_labels(thresholds, emotions)
        else:
            self.__save_balanced_binary_labels()

    def __save_binary_labels(self, thresholds):
        """Binarise every score with the single threshold and save per emotion."""
        above_thres = self.__label > thresholds[0]
        below_thres = self.__label <= thresholds[0]
        self.__label[above_thres] = 1
        self.__label[below_thres] = 0

        for i, emo in enumerate(self.__emotions):
            label_res = np.array(self.__label[:, :, i].reshape(1, -1)[0])

            print ('save results:', emo, label_res.shape)
            print (len(label_res[label_res==0]), len(label_res[label_res==1]) )
            print (label_res)
            print ()
            np.savez(os.path.join('../data/score/label/result_' + emo + '_binclass'),
                     y = label_res,
                     threshold = thresholds)

    @staticmethod
    def __three_class(values, low, high):
        """Map scores to 0 / 1 / 2 using masks taken from the original values.

        All masks are computed before anything is written, so a freshly
        assigned class id can never be re-matched by a later threshold.
        """
        below = values < low
        middle = (values >= low) & (values < high)
        above = values >= high

        classes = values.copy()
        classes[below] = 0
        classes[middle] = 1
        classes[above] = 2
        return classes

    def __save_three_class_labels(self, thresholds, emotions):
        """Save three-class labels for the selected emotions."""
        if emotions is None:
            emotions = self.__emotions

        for i, emo in enumerate(self.__emotions):
            if emo not in emotions:
                continue

            # Work on a copy so the loaded score array is left untouched.
            scores = self.__label[:, :, i].reshape(-1).copy()
            print (scores)
            label_res = self.__three_class(scores, thresholds[0], thresholds[1])

            print ('save results:', emo, label_res.shape)
            print (label_res)

            for class_id in range(0, 3):
                print( class_id, len(label_res[label_res==class_id]))
            print()

            np.savez(os.path.join('../data/score/label/result_' + emo + '_3class'),
                     y = label_res,
                     threshold = thresholds)

    def __save_balanced_binary_labels(self):
        """Per emotion, pick the threshold in 1..9 with the most even 0/1 split."""
        for i, emo in enumerate(self.__emotions):
            print ('Finding threshold for', emo)
            scores = self.__label[:, :, i].reshape(-1)
            # No sentinel value: the first candidate always becomes the
            # provisional best, whatever the dataset size.
            min_diff = None
            best_threshold = -1
            used_labels = None

            for current_thres in range(1, 10):
                label_res = np.copy(scores)
                above_thres = scores > current_thres
                below_thres = scores <= current_thres
                label_res[above_thres] = 1
                label_res[below_thres] = 0

                diff = abs(len(label_res[label_res==0]) - len(label_res[label_res==1]))
                print ('diff', diff)
                if min_diff is None or diff < min_diff:
                    used_labels = label_res
                    best_threshold = current_thres
                    min_diff = diff

            print ("Threshold = ", best_threshold, min_diff )
            print (len(used_labels[used_labels==0]), len(used_labels[used_labels==1]) )

            print ('save results:', emo, used_labels.shape)
            print (used_labels)
            print()
            np.savez(os.path.join(self.__input_path, 'result_' + emo + '_binclass'),
                     y = used_labels,
                     threshold = best_threshold)

    def get_input_data_info(self):
        """Derive and print dataset dimensions from the loaded 4-D array.

        The array is unpacked as (subjects, clips, channels, samples); the
        sampling rate is samples divided by the configured clip duration.
        """
        shp = self.__data.shape
        self.__no_of_subj, self.__no_of_clips, self.__no_of_channels, self.__no_of_sampling = shp
        self.__sampling_rate = self.__no_of_sampling / self.__clip_duration

        print ('====== input data ======')
        print ('No. of subject:', self.__no_of_subj)
        print ('No. of clips:', self.__no_of_clips)
        print ('No. of channels:', self.__no_of_channels)
        print ('No. of points (time series data):', self.__no_of_sampling)
        print ('Samping rate:', self.__sampling_rate)
        print()

    def __get_already_input_txt(self):
        """Read ``[subject, clip]`` pairs from a space-separated log.

        Returns an empty list (after printing 'No file') when the file is
        missing or a line does not start with two integers. Not called by the
        visible code; the CSV reader below is the one in use.
        """
        try:
            with open(self.__log_filename, 'r') as f:
                results = []
                for line in f:
                    spl = line.split(' ')
                    results.append([int(spl[0]), int(spl[1])])
        except (OSError, ValueError, IndexError):
            print ('No file')
            return []
        print (results)
        return results

    def __get_already_input_csv(self):
        """Return ``(subject, clip)`` int tuples from the first two CSV columns.

        Returns an empty list (after printing 'No file') when the log is
        missing, empty or malformed.
        """
        try:
            df = pd.read_csv(self.__log_filename)
            return [(int(subj), int(clip))
                    for subj, clip in zip(df.iloc[:, 0], df.iloc[:, 1])]
        except (OSError, ValueError, IndexError):
            # pandas' EmptyDataError / ParserError are ValueError subclasses.
            print ('No file')
            return []

    def __ask_already_done(self):
        """Ask whether to restart labelling; return the clips already labelled.

        Restarting (confirmed with a second prompt) resets the log to just its
        header. Every other answer keeps the existing log.
        """
        rerun = input("Do you want to re-run all? (y/n): ")
        if rerun.lower() == 'n':
            return self.__get_already_input_csv()

        rerun = input("Are you sure that you want to re-run all? (y/n): ")
        if rerun.lower() == 'y':
            with open(self.__log_filename, "w") as f:
                f.write('subj_id, clip_id, exclude,\n')
            return []

        print ('Continue using previous data..')
        return self.__get_already_input_csv()

    def preprocessing(self):
        """Run one pass over all clips and return the band-filtered array.

        While any clip lacks an entry in the exclusion log, the pass only
        collects the missing labels interactively and the returned array stays
        all zeros; run it again afterwards. Once every clip is labelled the
        pass applies CAR, ICA removal and band filtering.

        Returns:
            Array of shape (subjects*clips, channels, bands, samples).
        """
        print ('====== Preprocessing ======')
        total_clips = self.__no_of_subj * self.__no_of_clips
        self.__data_prep = np.zeros((total_clips, self.__no_of_channels,
                                     len(self.__band_list), self.__no_of_sampling))
        self.__info = mne.create_info(ch_names = self.__channel_list,
                   sfreq = self.__sampling_rate,
                   ch_types = 'eeg')

        # A set ignores duplicate log rows, which would otherwise inflate the count.
        already_done = set(self.__ask_already_done())
        # Built after the prompt so it sees the log as it is for this run.
        self.__ica = ICA(self.__channel_list, self.__sampling_rate, self.__log_filename)

        expected = {(subj, clip)
                    for subj in range(self.__no_of_subj)
                    for clip in range(self.__no_of_clips)}
        labelling = not expected.issubset(already_done)

        for subject_id, data in enumerate(self.__data):
            #each subject
            for clip_id, dt in enumerate(data):
                #each clip
                print ('\nPreprocessing: subject_id =', subject_id, 'clip_id =', clip_id )

                if labelling:
                    print( 'Continue labelling..')
                    if (subject_id, clip_id) in already_done:
                        print ('already_done', (subject_id, clip_id))
                        continue
                    raw = mne.io.RawArray(data = dt, info = self.__info)
                    after_car = self.__calculate_CAR(raw)
                    self.__ica.label_exclude_list(after_car, subject_id, clip_id)
                    continue

                print ('All data is labeled for ICA: Actual removing components from saved csv..')
                raw = mne.io.RawArray(data = dt, info = self.__info)
                after_car = self.__calculate_CAR(raw)
                after_ica = self.__ica.remove_exclude_components(after_car, subject_id, clip_id)

                #! already done notch at 50 Hz
                index = self.__no_of_clips*subject_id + clip_id
                for channel_id, ch in enumerate(after_ica):
                    #each channel
                    self.__data_prep[index, channel_id, :, :] = self.__bands_filter(ch)

        return self.__data_prep

    def __bands_filter(self, data):
        """Return a (bands x samples) array: one band-passed copy per band."""
        results = np.zeros(shape = (len(self.__band_list), len(data)))
        for row, band in enumerate(self.__band_list):
            results[row] = self.__bandpass_filter(data, band['low'], band['high'])
        return results

    def __bandpass_filter(self, data, low, high):
        """Apply a 2nd-order Butterworth band-pass between ``low`` and ``high`` Hz.

        Cut-offs are normalised by the Nyquist frequency derived from the
        sampling rate, and the filter is run forward only (``lfilter``).
        """
        nyq = 0.5 * self.__sampling_rate
        b, a = signal.butter(2, [low / nyq, high / nyq], btype='band')
        return signal.lfilter(b, a, data)

    def __calculate_ASR(self, data):
        """Debug stub: print the geometric median of ``data``.

        NOTE(review): ``geometric_median`` is not defined or imported in the
        visible code, and nothing visible calls this method — confirm before use.
        """
        geo = geometric_median(data)
        print (data)
        print (geo.shape, geo)
        print ()

    def __calculate_CAR(self, raw):
        """Re-reference ``raw`` to the channel average and return its data array."""
        # Bad EEG channels are automatically excluded if they are properly set in info['bads']
        raw_car, _ = mne.set_eeg_reference(raw, 'average', projection=True)
        applied = raw_car.apply_proj()
        return applied.get_data()

    def save_to_numpy(self, result):
        """Save ``result`` to ``preprocessed/EEG_ICA.npy`` under the input folder."""
        np.save(os.path.join(self.__input_path, 'preprocessed/EEG_ICA.npy'), result)


In [None]:
# Build the pipeline, load the raw recordings and run one preprocessing pass;
# the shape of the returned array is shown first, then the array itself.
eegPreprocessing = EEGPreprocessing()
eegPreprocessing.read_signals()
result = eegPreprocessing.preprocessing()
print(result.shape, result, sep='\n')

In [None]:
# Persist the array returned by preprocessing() to 'preprocessed/EEG_ICA.npy'
# under the pipeline's input folder.
eegPreprocessing.save_to_numpy(result)