# In [None]:
'''
# Preprocessing procedure:
1. Load data
2. Run SSP (empty-room projectors)
3. Filter: notch, high-pass, low-pass
4. Run ICA (before epoching; a 1 Hz high-pass is recommended for ICA)
5. Epoch data (automatic rejection, baseline correction)
6. Annotate and delete bad epochs manually (currently disabled below)

# Post-preprocessing
7. Average epochs
8. Concatenate data at the subject level
9. Source reconstruction
10. Decoding
'''
import numpy as np
import os,sys
from os.path import join as pj
import time
import matplotlib
matplotlib.use('TkAgg') #   Qt5Agg #'TkAgg'
'''
oldMne = '/usr/local/neurosoft/anaconda3/lib/python3.8/site-packages/mne'
sys.path.remove(oldMne)
currMne = '/nfs/s2/userhome/tianjinhua/workingdir/code'
sys.path.append(currMne)
'''
import mne

# ---- data locations -------------------------------------------------------
# Study root; each subject has <rootDir>/<subj>/<taskName>/ with the raw runs.
rootDir = '/nfs/s2/userhome/tianjinhua/workingdir/meg/numerosity'
taskName = 'raw'
# Empty-room recording, read from <rootDir>/<subj>/ for the SSP step.
emptyroomm = 'emptyroom_tsss.fif'

# Subjects processed in this run. Other subjects of the study:
#   'subj002', 'subj003', 'subj004', 'subj005', 'subj006',
#   'subj007', 'subj008', 'subj009', 'subj010'
#   (earlier naming: 'gu_wenyu', 'sun_baojia', 'wang_yan')
subjName = ['subj005']

# ---- filtering / resampling parameters ------------------------------------
freqs = np.arange(50, 200, 50)  # notch frequencies 50, 100, 150 Hz (presumably mains + harmonics)
highcutoff = 1        # high-pass edge in Hz; ~1 Hz is recommended before ICA
lowcutoff = 40        # low-pass edge in Hz
newSamplingRate = 500  # target sampling rate in Hz after epoching
from mne.preprocessing import (create_eog_epochs, create_ecg_epochs, compute_proj_ecg, compute_proj_eog)

# Imported once here instead of on every loop iteration; binds the same
# module-level names the original per-file import did.
from mne.preprocessing import (ICA, create_eog_epochs, create_ecg_epochs, corrmap)

# Peak-to-peak amplitude rejection thresholds for epoching
# (MNE SI units: grad in T/m, mag in T).
reject = dict(grad=4000e-13, mag=4e-12)


def _ask_bad_components(n_components):
    """Interactively ask which ICA components to exclude.

    Re-prompts until the answer is a whitespace-separated list of integers in
    ``range(n_components)``. A blank answer means "exclude nothing".

    Parameters
    ----------
    n_components : int
        Number of fitted ICA components; valid indices are 0..n_components-1.

    Returns
    -------
    list of int
        Indices of the components to remove.
    """
    while True:
        answer = input('ICA components to exclude (space-separated, blank for none): ')
        try:
            comps = [int(token) for token in answer.split()]
        except ValueError:
            print('Please enter integer component indices only, e.g. "0 14 17".')
            continue
        invalid = [c for c in comps if not 0 <= c < n_components]
        if invalid:
            print('Out-of-range component(s): %s (valid: 0-%d)' % (invalid, n_components - 1))
            continue
        return comps


for subj in subjName:
    rawDir = pj(rootDir, subj, taskName)

    # Output directory for this subject's preprocessed files.
    savePath = pj(rootDir, subj, 'preprocessed')
    os.makedirs(savePath, exist_ok=True)

    # The empty-room SSP projectors are the same for every run of a subject,
    # so they are computed at most once per subject, and only when a run
    # actually needs processing.
    empty_room_projs = None

    # Walk the subject's raw directory in a deterministic order; filter, run
    # ICA, epoch and save every tSSS run that has no output file yet.
    for file in sorted(os.listdir(rawDir)):
        if not file.endswith('_tsss.fif'):
            continue
        fileName = 'filterEpochICA_' + file
        savePath2 = pj(savePath, fileName)
        if os.path.exists(savePath2):
            continue  # already preprocessed; delete the output to redo this run

        filepath = pj(rawDir, file)
        raw = mne.io.read_raw_fif(filepath, allow_maxshield=True, preload=True)

        # --------------------------------------
        # 1.1 load data, run SSP, filter
        # --------------------------------------
        # 1. Repair environmental artifacts with empty-room SSP.
        raw.del_proj()  # discard the system-provided SSP projectors
        if empty_room_projs is None:
            empty_room_file = pj(rootDir, subj, emptyroomm)
            empty_room_raw = mne.io.read_raw_fif(empty_room_file)
            empty_room_raw.del_proj()
            empty_room_projs = mne.compute_proj_raw(empty_room_raw, n_grad=3, n_mag=3)
            del empty_room_raw
        # compute_proj_raw returns the gradiometer projectors before the
        # magnetometer ones, so [3:] keeps only the 3 magnetometer projectors.
        # add_proj copies them, so the cached list is not modified across runs.
        projs = empty_room_projs[3:]
        raw.add_proj(projs, remove_existing=True)
        raw.apply_proj()

        # 2. Filter the data: notch, then high-pass and low-pass (all in place).
        meg_picks = mne.pick_types(raw.info, meg=True, eeg=False, eog=False)
        raw.notch_filter(freqs=freqs, picks=meg_picks, method='spectrum_fit', n_jobs=4,
                         filter_length="auto", phase="zero", verbose=True)
        raw.filter(l_freq=highcutoff, h_freq=None)  # ~1 Hz high-pass recommended for ICA
        raw.filter(l_freq=None, h_freq=lowcutoff)

        # --------------------------------------
        # 1.2 run ICA and reject artifact components
        # --------------------------------------
        ica = ICA(n_components=60, random_state=97)
        ica.fit(raw)
        ica.plot_sources(raw, show_scrollbars=True)
        ica.plot_components()  # selections noted in earlier sessions: 0 14 17 21 32; 0 1 6; 2 10 32

        # Inspect the plots, then type the artifact components to remove.
        bad_comps = _ask_bad_components(ica.n_components_)
        ica.apply(raw, exclude=bad_comps)  # removes the components from raw in place
        raw.plot_psd(fmax=100)  # visual check of the cleaned spectrum

        # --------------------------------------
        # 1.3 Epoch and reject bad epochs
        # --------------------------------------
        events = mne.find_events(raw, shortest_event=2, min_duration=0.005)
        # Exclude event codes 99 and 101 (the 1-back trials).
        events_no_1back = mne.pick_events(events, exclude=[99, 101])

        # Epoch -0.2..1 s around each event with linear detrending and
        # (-0.2, 0) s baseline correction; epochs exceeding `reject` are
        # dropped. All events are kept (use events_no_1back[1:, :] to also
        # drop the first epoch).
        epochs = mne.Epochs(raw, events_no_1back, tmin=-0.2, tmax=1, baseline=(-0.2, 0),
                            reject=reject, preload=True, detrend=1, verbose=True)

        # Downsample to newSamplingRate.
        epochs.resample(sfreq=newSamplingRate, npad="auto", window="boxcar",
                        pad="edge", verbose=True)

        # Save the ICA-cleaned, epoched data.
        # NOTE: MNE may warn that epochs file names should end in -epo.fif;
        # the name is kept because the skip check above (and presumably
        # downstream scripts) rely on it.
        epochs.save(savePath2, overwrite=True)

        # Step 6 (currently disabled): annotate and drop bad epochs manually.
        #     fig = epochs.plot()
        #     fig.canvas.key_press_event('a')
        #     epochs.drop_bad()
        #     epochs.save(pj(savePath, 'filterEpochICAAj_' + file), overwrite=True)

        del raw, epochs
print('All Done')