# This preprocessing workflow was developed on 08/07/2019
* MNE event coding in VD: {'12.0': 5, '100.0': 2, '21.0': 8, '102.0': 4, '254.0': 11, '131.0': 6, '132.0': 7, '1.0': 1, '23.0': 10, '101.0': 3, '255.0': 12, '22.0': 9}; 6 = safe period, 7 = threat period, 11 = start, 12 = end
* Note that the event coding can differ between files
* Event recoding: a three-digit code with a `.0` suffix (float format), built as section + state + condition
    * section: 1, 2
    * state: VD:1, FA:2, OP:3
    * condition: safe:1, threat:2, baseline:3




In [98]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
===============================================
Preprocessing on Enrico data using MNE and ASR
===============================================
We first define the subject list, the state list and the rejection dicts, then import the EEGLAB-format
raw data with the MNE package. We apply:
1) a notch filter to remove powerline artifact (50 Hz)
2) a 1Hz-100Hz band-pass filter
====> output = subj0*number*_*state*_filt_raw.fif  
Then the data of the same section are concatenated
3) ASR and ICA fitting: this runs in parallel with the rest of the preprocessing;
    the goal is to store two sets of ICA component images and generate an exclude dict from them:
        3a) epoch the data so that autoreject can reject bad epochs, and inspect the rejection report ===> output: cleaned epochs
        3b) fit ICA and save the ICA components to reject - this is done with the function ica_component_selection()
        ====> output = fif file that save ica object and a rejecting component dict


Note: this version does not yet fulfil the cluster-run requirements

Suggestions:
1) decide on the information storage format
2) 

Updated on July 2019

@author: Gansheng TAN aegean0045@outlook.com    based on Manu's codes
"""

##############################################################  Set-up ################################################
import mne
import importlib
import numpy as np
from mne.report import Report
from autoreject import AutoReject
from autoreject import compute_thresholds
from autoreject import get_rejection_threshold 
import matplotlib.pyplot as plt  # noqa
import matplotlib.patches as patches  # noqa
from autoreject import set_matplotlib_defaults  # noqa
from Autoreject_report_plot import Autoreject_report_plot #Gansheng
%matplotlib qt
mne.set_log_level('WARNING')

##################### OS path in INSERM computer #####################################################################
# Folder holding the EEGLAB recordings; the electrode-location file sits in the same folder.
raw_data_path = '/home/gansheng.tan/process_mne/INSERM_EEG_Enrico_Proc/data_eeglab/raw_data/'
montage_fname = raw_data_path + 'Biosemi64_MAS_EOG.locs'

########################################## Initialization parameters ##########################################
subj_list = ['94']
section_list = ['1']
# The order of this list is the order in which the recordings are concatenated.
# Full list of states: ['VD', 'FA', 'OP']
state_list = ['VD', 'FA']
# Frequencies (Hz) passed to the notch filter.
power_freq_array = [50]

# Per section: {subject: [states whose raw recording is excluded]}.
reject_raw_data_section1 = {
    '07': ['OP'],
    '10': ['FA', 'VD'],
    '21': ['VD'],
    '36': ['OP'],
}
reject_raw_data_section2 = {
    '07': ['OP'],
    '10': ['VD'],
    '21': ['FA', 'VD'],
    '22': ['OP'],
    '57': ['OP', 'FA'],
    '82': ['FA', 'OP', 'VD'],
}

# Bad-channel rejection is not applied during preprocessing; bad channels are marked by visual
# inspection later. Channels listed here are excluded from the ICA analysis.
# Format: {subject: {state + section: [channel names]}},
# e.g. {'94': {'FA1': ['FP1', 'FP2'], 'VD1': ['Cz']}}
bad_channel = {'94': {'FA1': ['Pz']}}


################################ step00: cut, filter and concatenate the recordings of one section ############

###### set up montage
montage_biosemi = mne.channels.read_montage(montage_fname)

# Section id -> {subject: [rejected states]}. A plain dict lookup replaces the
# former eval('reject_raw_data_section' + section).
reject_raw_data_by_section = {'1': reject_raw_data_section1,
                              '2': reject_raw_data_section2}

###### preproc for each raw file
# NOTE(review): mne_annotation_recode_by_adding is defined in a later notebook
# cell; when run as a plain script it must be defined/imported before this loop.
for subj in subj_list:
    for section in section_list:
        # A KeyError here means section_list names a section without a rejection table.
        section_rejects = reject_raw_data_by_section[section]
        reject_state = section_rejects.get(subj, [])
        if subj in section_rejects:
            print("the rejected states of subject {} in section {} are {}".format(subj, section, reject_state))
        conctn_list = []
        conctn_anno_list = []
        for state in state_list:
            if state in reject_state:
                continue
            raw_fname = raw_data_path + 'subj0' + subj + '_' + state + section + '_mast.set'
            raw = mne.io.read_raw_eeglab(raw_fname, montage_biosemi, verbose='INFO', preload=True, eog='auto')

            # keep only the span between the first '254.0' (start of recording)
            # and the first '255.0' (end of recording) event
            events, events_coding = mne.events_from_annotations(raw)
            events = np.asarray(events)
            sfreq = raw.info['sfreq']
            start = events[events[:, 2] == events_coding['254.0']][0][0]
            stop = events[events[:, 2] == events_coding['255.0']][0][0]
            raw_cut_filt = raw.copy()
            raw_cut_filt.crop(tmin=start / sfreq, tmax=stop / sfreq)
            raw_cut_filt.notch_filter(freqs=power_freq_array)
            raw_cut_filt.filter(l_freq=1, h_freq=100)

            ############ annotation engineering ################
            # keep only the safe ('131.0'), threat ('132.0') and end ('255.0')
            # markers; deleting all indices in one call avoids the shifting-index loop
            drop_idx = [idx for idx, annot in enumerate(raw_cut_filt.annotations)
                        if annot['description'] not in ('131.0', '132.0', '255.0')]
            if drop_idx:
                raw_cut_filt.annotations.delete(drop_idx)
            mne_annotation_recode_by_adding(section=section, state=state,
                                            annotations=raw_cut_filt.annotations)

            for annot in raw_cut_filt.annotations:
                print(annot)
            raw_cut_filt.plot(title='raw plot after cut', scalings=150e-6)
            conctn_anno_list.append(raw_cut_filt.annotations)
            conctn_list.append(raw_cut_filt)

        if not conctn_list:
            # every state of this subject/section was rejected: nothing to concatenate
            print("no recording left for subject {} in section {}".format(subj, section))
            continue

        # Concatenate the kept recordings along the time axis, using the FIRST
        # recording's info (the list is not mutated, so a single kept recording works too).
        full_sfreq = conctn_list[0].info['sfreq']
        full_array = np.concatenate([part.get_data() for part in conctn_list], axis=1)
        raw_full = mne.io.RawArray(full_array, info=conctn_list[0].info)
        raw_full.plot(scalings=150e-6)

        # Merge the annotations: each recording's onsets are postponed by the
        # total duration of the recordings placed before it.
        # NOTE(review): assumes onsets are relative to the start of each cropped
        # recording (orig_time None), as the printed onsets suggest — verify.
        full_onset, full_duration, full_description = [], [], []
        offset = 0.0
        for part, part_anno in zip(conctn_list, conctn_anno_list):
            full_onset.extend(np.asarray(part_anno.onset, dtype=float) + offset)
            full_duration.extend(part_anno.duration)
            full_description.extend(part_anno.description)
            offset += part.n_times / full_sfreq
        full_annotation = mne.Annotations(onset=full_onset, duration=full_duration,
                                          description=full_description)
        raw_full.set_annotations(full_annotation)

        print('full anno')
        for annot in raw_full.annotations:
            print(annot)

Reading /home/gansheng.tan/process_mne/INSERM_EEG_Enrico_Proc/data_eeglab/raw_data/subj094_VD1_mast.fdt
Reading 0 ... 308223  =      0.000 ...   601.998 secs...
Setting up band-stop filter from 49 - 51 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 49.38
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 49.12 Hz)
- Upper passband edge: 50.62 Hz
- Upper transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 50.88 Hz)
- Filter length: 3379 samples (6.600 sec)

Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 1e+02 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuatio

In [82]:
# Quick check: merge the annotations of the first two kept recordings and print each entry.
anno_test_full = conctn_anno_list[0] + conctn_anno_list[1]
for annot in anno_test_full:
    print(annot)

OrderedDict([('onset', 0.0), ('duration', 0.0), ('description', '113.0'), ('orig_time', None)])
OrderedDict([('onset', 0.0), ('duration', 0.0), ('description', '123.0'), ('orig_time', None)])
OrderedDict([('onset', 160.529296875), ('duration', 0.0), ('description', '122.0'), ('orig_time', None)])
OrderedDict([('onset', 167.14453125), ('duration', 0.0), ('description', '112.0'), ('orig_time', None)])
OrderedDict([('onset', 206.1328125), ('duration', 0.0), ('description', '121.0'), ('orig_time', None)])
OrderedDict([('onset', 212.822265625), ('duration', 0.0), ('description', '111.0'), ('orig_time', None)])
OrderedDict([('onset', 251.73046875), ('duration', 0.0), ('description', '121.0'), ('orig_time', None)])
OrderedDict([('onset', 259.119140625), ('duration', 0.0), ('description', '112.0'), ('orig_time', None)])
OrderedDict([('onset', 297.328125), ('duration', 0.0), ('description', '122.0'), ('orig_time', None)])
OrderedDict([('onset', 304.826171875), ('duration', 0.0), ('description',

Traceback (most recent call last):
  File "/home/gansheng.tan/mne/local/lib/python3.5/site-packages/matplotlib/cbook/__init__.py", line 215, in process
    func(*args, **kwargs)
  File "/home/gansheng.tan/mne/local/lib/python3.5/site-packages/mne/viz/utils.py", line 1008, in _mouse_click
    params['plot_fun']()
  File "/home/gansheng.tan/mne/local/lib/python3.5/site-packages/mne/viz/raw.py", line 1104, in _plot_raw_traces
    segment_color = params['segment_colors'][dscr]
KeyError: '123.0'
Traceback (most recent call last):
  File "/home/gansheng.tan/mne/local/lib/python3.5/site-packages/matplotlib/cbook/__init__.py", line 215, in process
    func(*args, **kwargs)
  File "/home/gansheng.tan/mne/local/lib/python3.5/site-packages/mne/viz/utils.py", line 1008, in _mouse_click
    params['plot_fun']()
  File "/home/gansheng.tan/mne/local/lib/python3.5/site-packages/mne/viz/raw.py", line 1104, in _plot_raw_traces
    segment_color = params['segment_colors'][dscr]
KeyError: '123.0'


In [99]:
# Display the measurement info of the concatenated recording raw_full.
raw_full.info

<Info | 17 non-empty fields
    bads : list | 0 items
    ch_names : list | Fp1, AF7, AF3, F1, F3, F5, F7, FT7, FC5, ...
    chs : list | 66 items (EEG: 64, EOG: 2)
    comps : list | 0 items
    custom_ref_applied : bool | False
    dev_head_t : Transform | 3 items
    dig : list | 66 items (66 EEG)
    events : list | 0 items
    highpass : float | 1.0 Hz
    hpi_meas : list | 0 items
    hpi_results : list | 0 items
    lowpass : float | 100.0 Hz
    meas_date : NoneType | unspecified
    nchan : int | 66
    proc_history : list | 0 items
    projs : list | 0 items
    sfreq : float | 512.0 Hz
    acq_pars : NoneType
    acq_stim : NoneType
    ctf_head_t : NoneType
    description : NoneType
    dev_ctf_t : NoneType
    experimenter : NoneType
    file_id : NoneType
    gantry_angle : NoneType
    hpi_subsystem : NoneType
    kit_system_id : NoneType
    line_freq : NoneType
    meas_id : NoneType
    proj_id : NoneType
    proj_name : NoneType
    subject_info : NoneType
    xplot

In [88]:
    # Display the shape of the data array of the last processed recording.
    # NOTE(review): `_data` is a private MNE attribute; the leading indent only works
    # inside a notebook cell and would be an IndentationError in a plain .py file.
    raw_cut_filt._data.shape

(132, 537442)

# Below are personalised functions -> to be moved to a .py script and imported

In [63]:
import warnings

def mne_annotation_postpone(pptime, annotations):
    """Postpone every annotation onset by ``pptime``, in place.

    NOTE(review): the original body of this function was empty (a SyntaxError);
    this implementation is inferred from the function name and its parameters —
    confirm it matches the intended use.

    Parameters
    ----------
    pptime : float
        Shift added to every onset (presumably seconds, the unit of MNE
        annotation onsets — confirm).
    annotations : mne.Annotations
        Annotations whose ``onset`` array is replaced by the shifted one.

    Returns
    -------
    mne.Annotations
        The same ``annotations`` object, returned for convenience.
    """
    annotations.onset = annotations.onset + pptime
    return annotations

def mne_annotation_recode_by_adding(section,state,annotations):
    """Replace the safe/threat markers of one recording with recoded ones.

    Every annotation whose description is '131.0' or '132.0' is recoded via
    mne_annotation_recode_info_extract (section + state + condition code),
    the original '131.0'/'132.0' entries are deleted, and a baseline marker
    from mne_annotation_add_baseline is appended at onset 0.0 with duration
    0.0. All other annotations are left untouched.

    Parameters
    ----------
    section : str
        Section id, '1' or '2'.
    state : str
        'VD', 'FA' or 'OP'.
    annotations : mne.Annotations
        Modified in place.

    Returns
    -------
    bool
        Always True.
    """
    raw_codes = ('131.0', '132.0')
    onset, duration, description = [], [], []
    drop_idx = []
    for idx, annot in enumerate(annotations):
        if annot['description'] not in raw_codes:
            continue
        drop_idx.append(idx)
        onset, duration, description = mne_annotation_recode_info_extract(
            section=section, state=state, original_annotation=annot,
            onset=onset, duration=duration, description=description)
    # one batched delete instead of deleting item by item with a shifting index
    if drop_idx:
        annotations.delete(drop_idx)
    onset.append(0.0)
    duration.append(0.0)
    description.append(mne_annotation_add_baseline(section=section, state=state))
    annotations.append(onset=onset, duration=duration, description=description)
    print ('annotation engineering succeed')
    return True

def mne_annotation_add_baseline(section,state):
    """Return the baseline event code for a recording of ``state`` in ``section``.

    The code is '<section><state digit>3.0' with state VD=1, FA=2, OP=3 and
    condition digit 3 = baseline, e.g. section '1', state 'FA' -> '123.0'.

    Parameters
    ----------
    section : str
        Section id, '1' or '2'.
    state : str
        'VD', 'FA' or 'OP'.

    Returns
    -------
    str
        The baseline code, or '999.0' (after emitting a UserWarning) when the
        section or the state is not recognised.
    """
    state_digits = {'VD': '1', 'FA': '2', 'OP': '3'}
    if section not in ('1', '2'):
        # UserWarning instead of DeprecationWarning: the default warning filters
        # hide DeprecationWarning, so these problems used to pass silently.
        warnings.warn("add baseline function only apply on rawfile having 2 sections", UserWarning)
    elif state not in state_digits:
        warnings.warn("unknown state detected", UserWarning)
    else:
        return section + state_digits[state] + '3.0'
    return '999.0'
        

def mne_annotation_recode_info_extract(section,state,original_annotation,onset,duration,description):
    """Append the recoded version of one safe/threat annotation to the output lists.

    The new description is '<section><state digit><condition digit>.0' with
    state VD=1, FA=2, OP=3 and condition '131.0' (safe)=1, '132.0' (threat)=2,
    e.g. section '2', state 'FA', '132.0' -> '222.0'. Onset and duration are
    copied unchanged.

    Parameters
    ----------
    section : str
        Section id, '1' or '2'.
    state : str
        'VD', 'FA' or 'OP'.
    original_annotation : mapping
        One annotation entry with 'onset', 'duration' and 'description' keys.
    onset, duration, description : list
        Accumulators; the recoded entry is appended to them in place.

    Returns
    -------
    tuple of list
        ``(onset, duration, description)`` — the same list objects passed in.
        Nothing is appended (a message is printed instead) when the section,
        the state or the original description is not recognised.
    """
    state_digits = {'VD': '1', 'FA': '2', 'OP': '3'}
    condition_digits = {'131.0': '1', '132.0': '2'}
    if section not in ('1', '2'):
        print('3rd section dected, please check annotations')
    elif state not in state_digits:
        print('this function only detect VD, FA, OP states, please check original annotations')
    elif original_annotation['description'] not in condition_digits:
        print('this function only detect safe and threat period, please check original annotations')
    else:
        onset.append(original_annotation['onset'])
        duration.append(original_annotation['duration'])
        # table-driven code; the old hand-written tree emitted '123.0' for
        # section 2 / OP / threat instead of '232.0'
        description.append(section + state_digits[state]
                           + condition_digits[original_annotation['description']] + '.0')
    return(onset,duration,description)
        
