# Imports and function definitions

In [None]:
# Imports
import mne
import json
%matplotlib qt
import os
from mne_icalabel import label_components


In [None]:

# Helper function definitions
def filter_raw(raw, order = 3):
    '''
    This function erases the first 5 seconds of the EEG, which include
    the calibration period.
    It then rereferences to the average of the channels.
    Finally, it applies the following filters to an MNE raw object:
    Band-pass from .5 to 50 Hz
    Notch at 47-53, 97-103 & 147-153 for the elimination of industrial noise
    and its harmonics.

    All 4 filters are applied sequentially and are 3nd order Butterworth with 
    0 padding

    Input: The raw object
    Returns: None (since the raw object is passed by reference and changes are in-place)

    '''

    raw.crop(tmin=5.0)
    raw = raw.set_eeg_reference(ref_channels='average')

    my_iir_params = dict(order=order, ftype='butter', output='ba', padlen=0)
    
    
    

    raw=raw.filter(l_freq=1, h_freq=100, method='iir', iir_params=my_iir_params)
    raw=raw.filter(l_freq=53, h_freq=47, method='iir', iir_params=my_iir_params)
    raw=raw.filter(l_freq=103, h_freq=97, method='iir', iir_params=my_iir_params)
    raw=raw.filter(l_freq=153, h_freq=147, method='iir', iir_params=my_iir_params)

    return None



# Set `path` in the next cell to the recording you want to process

In [None]:
# Path to the recording to process (.edf or .eeg); fill in before running.
path = ""

# .EDF

In [None]:
# Load the EDF recording with its data preloaded, and build the
# 'brainproducts-RNP-BA-128' standard montage used later for positions.
raw = mne.io.read_raw_edf(input_fname=path, preload=True)
rnet = mne.channels.make_standard_montage(kind='brainproducts-RNP-BA-128')

# Rename channels according to our setup. 

### New EEG files (recorded after around April 2022) — for older files, use the next cell instead

In [None]:
# Resolve duplicated / placeholder channel names written by the recorder
mapping = {
    'POO1-0': 'POO1',
    'POO1-1': 'POO10',
    'XX-0': 'FPZ',
    'XX-1': 'FCZ',
    'XX-2': 'AFF4',
    'TPP1': 'TPP10',
}
raw.rename_channels(mapping)

# Strip trailing 'H' characters so recording and montage names line up;
# montage names are additionally upper-cased.
new_mapping = {name: name.rstrip('H') for name in raw.ch_names}
raw.rename_channels(new_mapping)
rnet_fix_ch_names = [label.upper().rstrip('H') for label in rnet.ch_names]
montage_renames = dict(zip(rnet.ch_names, rnet_fix_ch_names))
rnet.rename_channels(montage_renames)

In [None]:
# Old EEG files (recorded before around April 2022): uncomment this cell and
# skip the previous one. Compared with the new-file mapping, it additionally
# resolves the duplicated FFC5 channels ('FFC5-0' -> 'FCC5', 'FFC5-1' -> 'FFC5').
# Duplicate name correction
# mapping = {'POO1-0':'POO1', 
#            'POO1-1':'POO10', 
#            'FFC5-0':'FCC5', 
#            'FFC5-1':'FFC5', 
#            'XX-0':'FPZ', 
#            'XX-1':'FCZ', 
#            'XX-2':'AFF4', 
#            'TPP1':'TPP10'}

# raw.rename_channels(mapping)

# # Drop trailing H from ch_names
# new_mapping = dict(zip(raw.ch_names, [ch.rstrip('H') for ch in raw.ch_names]))
# raw.rename_channels(new_mapping)
# rnet_fix_ch_names = [chan.upper().rstrip('H') for chan in rnet.ch_names]
# rnet.rename_channels(dict(zip(rnet.ch_names, rnet_fix_ch_names)))

In [None]:
# Remove every channel whose name contains 'DC' or 'GND', then attach the
# electrode positions from the montage.
raw.drop_channels([
    name for name in raw.ch_names
    if any(tag in name for tag in ('DC', 'GND'))
])
raw.set_montage(rnet)

# .EEG

In [None]:
# pandas is needed here to read the channel-name table; import it in this
# cell because it is not imported anywhere else.
import pandas as pd

raw = mne.io.read_raw_nihon(path, preload=True)

# Rescale the .eeg data channel by channel according to its stored gain.
# NOTE(review): the scale factors (1e-4, 1e-1) and the division by `cal` are
# kept unchanged; confirm them against the recording system's export settings.
# The comparisons use exact float equality, which assumes the reader stores
# the gains as exactly these constants.
extras = raw._raw_extras[0]
data = raw._data  # preloaded sample array, rescaled in place
for idx, gain in enumerate(extras['gains'][:, 0]):
    if gain == 0.001:
        data[idx] = (data[idx] * 0.0001) / extras['cal'][idx]
        extras['gains'][idx, 0] = 0.000001
    elif gain == 0.000001:
        data[idx] = (data[idx] * 0.1) / extras['cal'][idx]

# Rename channels from a two-column CSV: first column = name in the file,
# second column = name in our setup.
channels_csv = ''  # /path/to/channels.csv -- must be set before running
channel_table = pd.read_csv(channels_csv)
mapping = dict(zip(channel_table.iloc[:, 0].values,
                   channel_table.iloc[:, 1].values))
raw.rename_channels(mapping)

# Drop every channel whose name contains 'DC' or 'GND'
raw.drop_channels([ch for ch in raw.ch_names if 'DC' in ch or 'GND' in ch])

# Set every remaining channel's type to EEG, then attach the montage
raw.set_channel_types({ch: 'eeg' for ch in raw.ch_names})
rnet = mne.channels.make_standard_montage('brainproducts-RNP-BA-128')
raw.set_montage(rnet)

### Mark bad channels. The channels interpolated by default are 'AFF4', 'FPZ' and 'FCZ'

In [None]:
# Default bad channels for this setup. Interpolation is not done here; it is
# applied later to the ICA-cleaned copy.
raw.info['bads'] = [
    'AFF4',
    'FPZ',
    'FCZ',
]
# raw.interpolate_bads(reset_bads=True)

In [None]:
# Crop, average-reference and filter `raw` in place (3rd-order Butterworth)
filter_raw(raw, order=3)

## Manual channel and time region rejection

Here, single channels that deviate significantly in the butterfly plot (toggle it by pressing `b` in the plot window) should be marked as bad. Time regions that produce artifacts in multiple channels should be annotated as `BAD`. Do not annotate regions with eye artifacts (blinks or saccadic movements); these are handled later by ICA. Also inspect the PSD for signs of noisy channels.

In [None]:
# Power spectrum of the filtered data, used to spot noisy channels
raw.plot_psd(show=True)

In [None]:
# Interactive browser for the manual channel / time-region rejection
raw.plot(show=True)

### For ICA, keep enough components to explain 99.99 % of the cumulative variance. This should yield more than 70 components

In [None]:
# Extended infomax, so the decomposition is comparable with ICLabel.
# A float n_components keeps as many components as needed to reach that
# fraction of explained variance; random_state makes the fit reproducible.
ica = mne.preprocessing.ICA(
    n_components=0.9999,
    method='infomax',
    fit_params={'extended': True},
    random_state=42,
)
ica.fit(raw)

In [None]:
# Time courses of the ICA components computed from `raw`
ica.plot_sources(inst=raw)

In [None]:
# Scalp topographies of the ICA components
ica.plot_components(show=True)

In [None]:
# Classify each ICA component with the ICLabel model
pred = label_components(raw, ica, method='iclabel')


In [None]:
# Print component index, predicted ICLabel class and its probability
for idx, (label, proba) in enumerate(zip(pred['labels'], pred['y_pred_proba'])):
    print(idx, label, proba)

In [None]:
# Force a garbage-collection pass to release memory
import gc

gc.collect()

In [None]:
# Apply the ICA to a copy so `raw` keeps the pre-ICA signal for comparison
appd = raw.copy()
ica.apply(inst=appd)
# raw.plot(title="Before ICA")
# appd.plot(title="After ICA")

In [None]:
# Spectrum before ICA cleaning
raw.plot_psd(show=True)

In [None]:
# Spectrum after ICA cleaning
appd.plot_psd(show=True)

# Interpolate the bad channels, which ICA component removal does not repair

In [None]:
# Interpolate the bad channels of the ICA-cleaned copy and clear its bads list
appd.interpolate_bads(reset_bads=True, mode='accurate')

### Save a JSON dictionary with the bad channels and rejected ICA components, the fitted ICA, and the processed raw data as FIF

In [None]:
# Gather bad channels and rejected components into one dictionary.
# Copies are taken so this cell never reorders `ica.exclude` itself;
# int() keeps numpy integer indices from breaking JSON serialisation.
bad_channels = list(raw.info['bads'])
rejected_components = sorted(int(comp) for comp in ica.exclude)
rejections = {
    "bad_channels": bad_channels,
    "rejected_components": rejected_components,
}

# Outputs go next to the source recording, named after it with its final
# extension replaced, e.g. /data/sub.01.edf -> /data/sub.01_rejections.json.
# os.path.splitext keeps absolute paths, drive letters and dotted stems intact.
stem = os.path.splitext(path)[0]
rejections_path = stem + '_rejections.json'
ica_path = stem + '-ica.fif'
raw_path = stem + '_raw.fif'

with open(rejections_path, 'w', encoding='utf-8') as f:
    json.dump(rejections, f, indent=4)
print(rejections_path)

ica.save(ica_path)
print(ica_path)

appd.save(raw_path)
print(raw_path)

