# Imports and function definitions

In [None]:
# Imports
import mne
%matplotlib qt
import os

# Helper function definitions
def filter_raw(raw, order = 2):
    '''
    Crop, re-reference and filter an MNE raw object in place.

    Steps, in the order they are applied:
    1. Erase the first 5 seconds of the EEG, which include the
       calibration period.
    2. Re-reference to the average of the channels.
    3. Band-pass from 1 to 50 Hz.
    4. Band-stop (notch) at 47-53, 97-103 & 147-153 Hz for the elimination
       of industrial noise and its harmonics.

    All 4 filters are applied sequentially and are Butterworth IIR filters
    of the requested order with 0 padding.

    Input:
        raw:   The raw object
        order: Order of every Butterworth filter (default 2)
    Returns: None (since the raw object is passed by reference and changes
             are in-place)

    Note: every call crops another 5 seconds and filters again, so the
    function should be run only once per recording.
    '''

    raw.crop(tmin=5.0)
    raw = raw.set_eeg_reference(ref_channels='average')

    my_iir_params = dict(order=order, ftype='butter', output='ba', padlen=0)

    # (l_freq, h_freq) pairs. The first is the band-pass; in the others
    # l_freq > h_freq, which asks MNE for a band-stop around the mains
    # frequency (50 Hz) and its first two harmonics.
    bands = (
        (1, 50),
        (53, 47),
        (103, 97),
        (153, 147),
    )
    for l_freq, h_freq in bands:
        raw = raw.filter(l_freq=l_freq, h_freq=h_freq, method='iir',
                         iir_params=my_iir_params)

    return None

def continuous_regions(condition):
    """Finds contiguous True regions of the boolean array "condition".

    Input: a 1D boolean array (any 1D array-like is accepted; its elements
           are interpreted by truthiness).
    Returns: a 2D integer array with one row per region, where the first
             column is the start index of the region and the second column
             is the end index (exclusive, i.e. condition[start:end] is the
             region). The shape is (0, 2) when there is no True region,
             including for empty input.
    Raises: ValueError if "condition" is not one-dimensional.
    """
    import numpy as np

    # Casting to bool makes every change a True<->False flip, so the number
    # of boundaries found below is always even after the edge fix-ups.
    flags = np.asarray(condition, dtype=bool)
    if flags.ndim != 1:
        raise ValueError('condition must be one-dimensional')
    if flags.size == 0:
        return np.empty((0, 2), dtype=np.intp)

    # Find the indices where "condition" changes. The +1 shifts each index
    # to the right, onto the first sample after the change.
    idx = np.flatnonzero(flags[1:] != flags[:-1]) + 1

    if flags[0]:
        # If the start of condition is True prepend a 0
        idx = np.r_[0, idx]

    if flags[-1]:
        # If the end of condition is True, append the length of the array
        idx = np.r_[idx, flags.size]

    # Reshape the result into two columns: (start, end) per region
    return idx.reshape(-1, 2)

# @ LEVON - https://stackoverflow.com/questions/10996140/how-to-remove-specific-elements-in-a-numpy-array

def CropCalibrationZeros(raw):
    """Estimate crop limits (in seconds) from flat runs on the first channel.

    The smallest absolute amplitude of the first channel is taken as the
    "zero" level. Every run of consecutive samples sitting exactly at that
    level and longer than 1/20 of a second contributes its start and end
    time as a candidate boundary. Then:
      - start is the last boundary earlier than 15 s (0.0 if there is none),
      - end is the first boundary later than 50 s (the total duration if
        there is none).
    NOTE(review): the flat runs are presumably the calibration periods
    written by the recorder -- confirm against the acquisition setup.

    Input: The raw object (has been tested for .EDF files)
    Returns: (start, end, total_duration), all in seconds
    """
    import numpy as np

    signal = raw.get_data()[0]  # First channel
    magnitude = np.abs(signal)
    EEG_zero = magnitude.min()

    # Kept as a float so that fractional sampling rates are not truncated
    # when converting sample indices to seconds.
    sfreq = float(raw.info['sfreq'])  # Sampling frequency
    total_duration = (signal.shape[0] - 1) / sfreq

    # Defaults: keep the whole recording unless a flat run says otherwise.
    # Initialising start here also stops it from silently taking the value
    # of a loop variable (a sample index, not seconds).
    start = 0.0
    end = total_duration

    min_run = sfreq / 20  # shortest flat run (in samples) that counts
    boundaries = []
    for first, last in continuous_regions(magnitude == EEG_zero):
        if last - first > min_run:
            boundaries.append(first / sfreq)
            boundaries.append(last / sfreq)

    for t in boundaries:
        if t < 15:
            start = t
        if t > 50:
            end = t
            break

    return start, end, total_duration

# Change path in the next cell accordingly

In [None]:
# Path to the .edf recording to preprocess. It is read with
# mne.io.read_raw_edf in the next cell and reused at the end of the notebook
# to derive the name of the saved .fif file.
path = ''

In [None]:
# Read the EDF recording; preload=True loads the data into memory, which the
# in-place filtering and interpolation steps below need.
raw = mne.io.read_raw_edf(path, preload=True)
# Standard electrode layout that is attached to the recording with
# raw.set_montage once the channel names of both objects have been aligned.
rnet = mne.channels.make_standard_montage('brainproducts-RNP-BA-128')

# Rename channels according to our setup. 
## It might need adjusting in the future

### New EEG files (after around April 2022) - for old files, use the next cell

In [None]:
# Duplicate name correction: map the ambiguous labels found in the file
# (duplicates suffixed -0/-1/-2 and placeholder XX entries) to unique names.
mapping = {
    'POO1-0': 'POO1',
    'POO1-1': 'POO10',
    'XX-0': 'FPZ',
    'XX-1': 'FCZ',
    'XX-2': 'AFF4',
    'TPP1': 'TPP10',
}
raw.rename_channels(mapping)

# Drop trailing H from ch_names, in the recording and (after upper-casing)
# in the montage, so that both use the same labels.
new_mapping = {ch: ch.rstrip('H') for ch in raw.ch_names}
raw.rename_channels(new_mapping)
rnet_fix_ch_names = [chan.upper().rstrip('H') for chan in rnet.ch_names]
rnet.rename_channels({old: new for old, new in zip(rnet.ch_names, rnet_fix_ch_names)})

In [None]:
# Duplicate name correction for the alternative channel layout: same as the
# previous cell plus the duplicated FFC5 labels.
mapping = {
    'POO1-0': 'POO1',
    'POO1-1': 'POO10',
    'FFC5-0': 'FCC5',
    'FFC5-1': 'FFC5',
    'XX-0': 'FPZ',
    'XX-1': 'FCZ',
    'XX-2': 'AFF4',
    'TPP1': 'TPP10',
}

raw.rename_channels(mapping)

# Drop trailing H from ch_names, in the recording and (after upper-casing)
# in the montage, so that both use the same labels.
new_mapping = {ch: ch.rstrip('H') for ch in raw.ch_names}
raw.rename_channels(new_mapping)
rnet_fix_ch_names = [chan.upper().rstrip('H') for chan in rnet.ch_names]
rnet.rename_channels({old: new for old, new in zip(rnet.ch_names, rnet_fix_ch_names)})

In [None]:
# Channels whose label contains DC or GND are removed before the montage is
# attached.
non_eeg = [name for name in raw.ch_names
           if any(tag in name for tag in ('DC', 'GND'))]
raw.drop_channels(non_eeg)
raw.set_montage(rnet)
# Power spectral density of the untouched data, used to spot bad channels.
raw.plot_psd()

# Interpolating bad channels. The channels that are always interpolated are 'AFF4', 'FPZ' and 'FCZ'; add any others that look bad in the PSD above.

In [None]:
# Mark the channels to be rebuilt from their neighbours; extend this list
# with any other channel that looks bad in the PSD plot.
raw.info['bads'] = ['AFF4', 'FPZ', 'FCZ']
# reset_bads=True clears the bad-channel list once the data are interpolated.
raw.interpolate_bads(reset_bads=True)

In [None]:
# PSD again, to check the effect of the interpolation above.
raw.plot_psd()

In [None]:
# Browse the time series before any filtering is applied.
raw.plot()

In [None]:
# Crops the first 5 s, re-references and filters raw in place (returns None).
# Run this cell only once: every run crops another 5 s and filters again.
filter_raw(raw)

In [None]:
# PSD after filter_raw, to verify the band-pass and the notches.
raw.plot_psd()

In [None]:
# Browse the cropped, re-referenced and filtered time series.
raw.plot()

# For ICA, we keep as many components as are needed to explain 98 % of the cumulative variance of the data. In cases where a lot of artifacts exist, a good choice is changing ".98" to 40 components.

In [None]:
# n_components=.98 is a fraction of cumulative variance, not a count; replace
# it with an integer (e.g. 40) to request a fixed number of components.
# NOTE(review): no random_state is passed, so the decomposition is presumably
# not reproducible from run to run -- confirm whether that matters here.
ica = mne.preprocessing.ICA(n_components=.98, method='picard', fit_params=dict(extended=True) )
# Fit on the filtered recording.
ica.fit(raw)

In [None]:
# Plot the fitted components for visual inspection.
ica.plot_components()

In [None]:
# Component time courses for the filtered recording.
# NOTE(review): presumably the components to remove are marked here (stored
# in ica.exclude) before the next cell is run -- confirm.
ica.plot_sources(raw)

In [None]:
raw.plot(title="Before ICA")
# Work on a copy so that raw stays available for the before/after comparison.
appd = raw.copy()
# Modifies appd in place.
# NOTE(review): nothing in this notebook sets ica.exclude in code, so it is
# presumably filled interactively in the plots above -- confirm.
ica.apply(appd)
appd.plot(title="After ICA")

In [None]:
# PSD of the data before ICA cleaning, for comparison with the next cell.
raw.plot_psd()

In [None]:
# PSD of the ICA-cleaned copy.
appd.plot_psd()

# Interpolate channels not cleaned with ICA component removal

In [None]:
# Channels that still look bad after ICA.
# NOTE(review): this list is presumably specific to one recording -- adjust
# it for each file.
appd.info['bads'] = ['AF7', 'AF8', 'F7']
# reset_bads=True clears the bad-channel list once the data are interpolated.
appd.interpolate_bads(reset_bads=True)

In [None]:
# Final PSD check of the cleaned and interpolated data before saving.
appd.plot_psd()

# Save as .fif

In [None]:
# Save next to the input file as "<input name without extension>_raw.fif".
# os.path.splitext only removes the final extension, so the directory part
# (including a leading "/" or a drive letter) and any other dots in the file
# name are preserved; splitting on os.sep and re-joining would turn an
# absolute path into a relative one.
new_path = os.path.splitext(path)[0] + '_raw.fif'
appd.save(new_path)#, overwrite=True)