<a href="https://colab.research.google.com/github/Charliebond125/MEGNet/blob/main/PreProcessing_Pipeline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%pip install mne

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting mne
  Downloading mne-1.4.2-py3-none-any.whl (7.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.7/7.7 MB[0m [31m53.6 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: mne
Successfully installed mne-1.4.2


In [None]:
import mne

from mne.preprocessing import (find_eog_events,
                                    find_ecg_events,
                                        create_eog_epochs,
                                            create_ecg_epochs,
                                        compute_proj_eog,
                                    compute_proj_ecg,
)

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

import warnings
# NOTE(review): this silences every warning process-wide, including any emitted
# by MNE; consider narrowing it (e.g. by category/module) while debugging.
warnings.filterwarnings('ignore')

In [None]:
%pip install tqdm

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
# Mount Google Drive (Colab only) so the .fif recordings under
# /content/drive/MyDrive can be read by later cells.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import mne
import tqdm
import os
import matplotlib.pyplot as plt
import warnings
import numpy as np

from mne.preprocessing import find_eog_events, find_ecg_events, create_eog_epochs, create_ecg_epochs, compute_proj_eog, compute_proj_ecg

# NOTE(review): blanket suppression hides MNE warnings too; narrow it if a step
# fails without an obvious explanation.
warnings.filterwarnings('ignore')


class meg_preprocessing_pipeline:
    """Step-by-step MEG preprocessing for a single ``mne.io.Raw`` recording.

    Each step mutates ``self.raw`` (and later ``self.epochs``) in place and
    returns ``self``, so steps can be chained or run one notebook cell at a
    time. ``apply_pipeline`` runs every step in order and writes checkpoint
    files to the current working directory.
    """

    # Written by ``find_events``; read by ``create_epochs`` only when no events
    # have been computed in this session.
    EVENTS_FNAME = 'sub_1_ses_2_events.txt'
    # Written by ``create_epochs``; read by ``baseline_normalization`` only when
    # no epochs are held in memory.
    EPOCHS_FNAME = 'epochs-epo.fif'

    def __init__(self, raw):
        """Wrap ``raw``; it is filtered in place, so load it with ``preload=True``."""
        self.raw = raw
        self.events = None       # events array from mne.find_events
        self.eog_events = None
        self.ecg_events = None
        self.eog_projs = None
        self.ecg_projs = None
        self.epochs = None

    def notch_filter(self):
        """Remove 50 Hz line noise and its harmonics up to 250 Hz from the MEG channels.

        ``method='spectrum_fit'`` is used; the FIR-specific arguments are passed
        through unchanged (per the MNE docs they only matter for ``method='fir'``).
        """
        picks = mne.pick_types(self.raw.info, meg=True)
        self.raw.notch_filter(freqs=[50, 100, 150, 200, 250], picks=picks,
                              method='spectrum_fit', filter_length='auto',
                              fir_window='hamming', fir_design='firwin',
                              n_jobs=1, verbose=True)
        return self

    def bandpass(self, l_freq=8, h_freq=30, order=4):
        """Zero-phase Butterworth IIR band-pass of ``self.raw`` (default 8-30 Hz).

        Args:
            l_freq: lower cut-off in Hz, clipped to the Nyquist frequency.
            h_freq: upper cut-off in Hz, clipped to the Nyquist frequency.
            order: Butterworth design order; ``phase='zero'`` runs the filter
                forward and backward, so the effective order is higher.
        """
        nyquist_freq = self.raw.info['sfreq'] / 2
        iir_params = dict(order=order, ftype='butter')
        self.raw = self.raw.filter(l_freq=min(l_freq, nyquist_freq),
                                   h_freq=min(h_freq, nyquist_freq),
                                   method='iir', phase='zero',
                                   iir_params=iir_params,
                                   verbose=True)
        return self

    def find_events(self):
        """Detect imagery triggers on STI101, annotate ``self.raw`` and save them.

        Sets ``self.events`` and writes the array to ``EVENTS_FNAME``.
        """
        mapping = {4: 'hand_imagery',
                   8: 'feet_imagery',
                   16: 'subtraction_imagery',
                   32: 'word_imagery'}

        self.events = mne.find_events(self.raw, stim_channel='STI101', uint_cast=True,
                                      shortest_event=32,
                                      initial_event=False,
                                      verbose=True)

        annot_from_events = mne.annotations_from_events(events=self.events,
                                                        event_desc=mapping,
                                                        sfreq=self.raw.info['sfreq'],
                                                        orig_time=self.raw.info['meas_date'])
        self.raw.set_annotations(annot_from_events)

        mne.write_events(self.EVENTS_FNAME, self.events, overwrite=True)
        return self

    def finding_bad_channels_std(self, factor=5):
        """Mark MEG channels whose std exceeds ``factor`` times their type's mean std.

        Magnetometers and gradiometers are scored separately, and non-MEG
        channels (stim, EOG, ECG) are not scored. Their units differ by orders
        of magnitude, so pooling them would make the mean meaningless.
        Replaces ``self.raw.info['bads']``.
        """
        bad_channels = []
        for ch_type in ('mag', 'grad'):
            picks = mne.pick_types(self.raw.info, meg=ch_type, exclude=[])
            if len(picks) == 0:
                continue
            stds = np.std(self.raw.get_data(picks=picks), axis=1)
            over = np.flatnonzero(stds > factor * np.mean(stds))
            bad_channels.extend(self.raw.ch_names[picks[i]] for i in over)
        self.raw.info['bads'] = bad_channels
        return self

    def finding_bad_channels_maxwell(self):
        """Add channels flagged by ``find_bad_channels_maxwell`` to ``info['bads']``.

        Detection runs on a copy with no channels pre-marked bad. Its result is
        merged with the bads already marked (e.g. by the std heuristic) rather
        than replacing them.
        """
        from mne.preprocessing import find_bad_channels_maxwell

        previous_bads = list(self.raw.info['bads'])
        raw_check = self.raw.copy()
        raw_check.info['bads'] = []
        auto_noisy_chs, auto_flat_chs, _auto_scores = find_bad_channels_maxwell(
            raw_check, verbose=True, return_scores=True)
        # dict.fromkeys drops duplicates while keeping first-seen order.
        self.raw.info['bads'] = list(dict.fromkeys(previous_bads + auto_noisy_chs + auto_flat_chs))
        return self

    def create_epochs(self, tmin=-1, tmax=5):
        """Compute ECG/EOG SSP projectors, attach them to ``self.raw`` and epoch trials.

        Projectors are added before epoching so the saved epochs carry them.
        ``proj=False`` keeps them unapplied to the epoch data. Epochs are saved
        to ``EPOCHS_FNAME``.
        """
        event_dict = {
            "Motor/hand_imagery": 4,
            "Motor/feet_imagery": 8,
            "Mental/subtraction_imagery": 16,
            "Mental/word_imagery": 32,
        }

        eog_channel = ["EOG001", "EOG002"]
        ecg_channel = "ECG003"

        reject = dict(grad=4000e-13)
        flat = dict(grad=1e-13)

        self.ecg_projs, self.ecg_events = mne.preprocessing.compute_proj_ecg(
            self.raw, ch_name=ecg_channel, n_grad=1, n_mag=1, reject=reject,
            no_proj=True)
        self.eog_projs, self.eog_events = mne.preprocessing.compute_proj_eog(
            self.raw, ch_name=eog_channel, n_grad=1, n_mag=1, n_eeg=0,
            reject=reject, no_proj=True)

        for label, projs in (('ECG', self.ecg_projs), ('EOG', self.eog_projs)):
            if projs is None:
                # compute_proj_* can return None (e.g. when every artifact epoch
                # is rejected); add_proj(None) raises TypeError.
                print(f"No {label} projectors computed; skipping.")
                continue
            self.raw.add_proj(projs)

        # Prefer the events found in this session; fall back to the file that
        # find_events writes.
        if self.events is None:
            self.events = mne.read_events(self.EVENTS_FNAME)

        self.epochs = mne.Epochs(self.raw, events=self.events,
                                 event_id=event_dict,
                                 tmin=tmin, tmax=tmax,
                                 reject=reject,
                                 flat=flat,
                                 preload=True,
                                 proj=False,
                                 reject_by_annotation=True,
                                 baseline=None,
                                 verbose=True)

        self.epochs.save(self.EPOCHS_FNAME, overwrite=True)
        return self

    def baseline_normalization(self):
        """Subtract each epoch's whole-window mean and save the result.

        Uses the in-memory epochs; when none exist, loads ``EPOCHS_FNAME``.
        Writes 'baseline_norm_epochs-epo.fif'.
        """
        if self.epochs is None:
            self.epochs = mne.read_epochs(self.EPOCHS_FNAME, preload=True)
        self.epochs.apply_baseline((None, None), verbose=True)
        self.epochs.save('baseline_norm_epochs-epo.fif', overwrite=True)
        return self

    def apply_pipeline(self):
        """Run every step in order, saving checkpoints along the way.

        Returns:
            ``(self, self.raw)``.
        """
        # NOTE(review): this directory is created but the checkpoints below are
        # written to the current working directory.
        os.makedirs("Checkpoint", exist_ok=True)

        self.notch_filter()
        self.bandpass()
        self.find_events()
        self.finding_bad_channels_std()
        self.finding_bad_channels_maxwell()
        self.raw.save('checkpoint_maxwell_filtered_raw.fif', overwrite=True)

        self.create_epochs()
        self.epochs.save('epochs_checkpoint-epo.fif', overwrite=True)

        self.baseline_normalization()
        self.epochs.save('baseline_norm_epochs_checkpoint-epo.fif', overwrite=True)

        return self, self.raw


In [None]:
# Load subject 1 / session 1. preload=True reads the samples into memory; the
# pipeline's filters modify them in place.
raw_1 = mne.io.read_raw_fif(r"/content/drive/MyDrive/MEG_FIF/MEG_BIDS/sub-1/ses-1/meg/sub-1_ses-1_task-bcimici_meg.fif", preload=True, verbose=False)

In [None]:
# Wrap the loaded recording in the preprocessing pipeline.
pipeline_1 = meg_preprocessing_pipeline(raw_1)
# NOTE(review): raw_2 is never loaded in this notebook, so pipeline_2 stays
# undefined and cells that use it raise NameError.
#pipeline_2 = meg_preprocessing_pipeline(raw_2)


In [None]:
# Run every preprocessing step in order (this also writes checkpoint files).
pipeline_1.apply_pipeline()

In [None]:
# Second recording: its pipeline is only constructed by a commented-out line,
# so skip (with a message) instead of raising NameError when it is absent.
if "pipeline_2" in globals():
    pipeline_2.apply_pipeline()
else:
    print("pipeline_2 is not defined; skipping the second recording.")

In [None]:
# Release the recording and its pipeline to free memory.
# NOTE(review): the step-by-step cells below still use pipeline_1; run this
# cell last, or they will fail with NameError.
del raw_1, pipeline_1

In [None]:
# Step-by-step alternative to apply_pipeline(): power-line notch filter.
pipeline_1.notch_filter()

Filtering raw data in 1 contiguous segment
Removed notch frequencies (Hz):
     50.00 : 120258 windows
    100.00 : 120258 windows
    150.00 : 120258 windows
    200.00 : 120258 windows
    249.00 : 120258 windows
    250.00 : 120258 windows
    251.00 : 120258 windows


<__main__.meg_preprocessing_pipeline at 0x7ff45e361a20>

In [None]:
# Band-pass the data with a zero-phase Butterworth IIR (8-30 Hz by default).
pipeline_1.bandpass()

Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 8 - 30 Hz

IIR filter parameters
---------------------
Butterworth bandpass zero-phase (two-pass forward and reverse) non-causal filter:
- Filter order 16 (effective, after forward-backward)
- Cutoffs at 8.00, 30.00 Hz: -6.02, -6.02 dB



<__main__.meg_preprocessing_pipeline at 0x7ff45e361a20>

In [None]:
# Find STI101 triggers, annotate the recording and write the events file.
pipeline_1.find_events()


200 events found
Event IDs: [ 4  8 16 32]
Overwriting existing file.


<__main__.meg_preprocessing_pipeline at 0x7ff45e361a20>

In [None]:
# Flag channels with abnormally high standard deviation as bad.
pipeline_1.finding_bad_channels_std()


In [None]:
# Flag noisy / flat MEG channels with MNE's Maxwell-based scan.
pipeline_1.finding_bad_channels_maxwell()


In [None]:
# Compute ECG/EOG SSP projectors and cut the task epochs.
pipeline_1.create_epochs()


In [None]:
# Baseline-correct the epochs using the whole-epoch mean, (None, None).
pipeline_1.baseline_normalization()

In [None]:
# Re-run epoching for each recording that has a pipeline in this session.
# pipeline_2 only exists if its (commented-out) construction was enabled.
pipeline_1.create_epochs()
if "pipeline_2" in globals():
    pipeline_2.create_epochs()
else:
    print("pipeline_2 is not defined; skipping the second recording.")

Running ECG SSP computation
Using channel ECG003 to identify heart beats.
Setting up band-pass filter from 5 - 35 Hz

FIR filter parameters
---------------------
Designing a two-pass forward and reverse, zero-phase, non-causal bandpass filter:
- Windowed frequency-domain design (firwin2) method
- Hann window
- Lower passband edge: 5.00
- Lower transition bandwidth: 0.50 Hz (-12 dB cutoff frequency: 4.75 Hz)
- Upper passband edge: 35.00 Hz
- Upper transition bandwidth: 0.50 Hz (-12 dB cutoff frequency: 35.25 Hz)
- Filter length: 10000 samples (10.000 sec)

Number of ECG events detected : 1793 (average pulse 54 / min.)
Computing projector
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 35 Hz

FIR filter parameters
---------------------
Designing a two-pass forward and reverse, zero-phase, non-causal bandpass filter:
- Windowed frequency-domain design (firwin2) method
- Hamming window
- Lower passband edge: 1.00
- Lower transition bandwidth: 0.50 Hz (-12 dB

[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.2s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.3s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.4s remaining:    0.0s
[Parallel(n_jobs=1)]: Done 309 out of 309 | elapsed:   32.7s finished


Not setting metadata
1793 matching events found
No baseline correction applied
Created an SSP operator (subspace dimension = 13)
13 projection items activated
Using data from preloaded Raw for 1793 events and 601 original time points ...
    Rejecting  epoch based on GRAD : ['MEG1212']
    Rejecting  epoch based on GRAD : ['MEG1212']
    Rejecting  epoch based on GRAD : ['MEG1212']
    Rejecting  epoch based on GRAD : ['MEG1212']
    Rejecting  epoch based on GRAD : ['MEG1212']
    Rejecting  epoch based on GRAD : ['MEG1212']
    Rejecting  epoch based on GRAD : ['MEG1212']
    Rejecting  epoch based on GRAD : ['MEG1212', 'MEG1433']
    Rejecting  epoch based on GRAD : ['MEG1212']
    Rejecting  epoch based on GRAD : ['MEG1212']
    Rejecting  epoch based on GRAD : ['MEG1212']
    Rejecting  epoch based on GRAD : ['MEG1212']
    Rejecting  epoch based on GRAD : ['MEG1212', 'MEG1433']
    Rejecting  epoch based on GRAD : ['MEG1212']
    Rejecting  epoch based on GRAD : ['MEG1212']
    R

[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s finished


Found 39 significant peaks
Number of EOG events detected: 39
Computing projector
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 35 Hz

FIR filter parameters
---------------------
Designing a two-pass forward and reverse, zero-phase, non-causal bandpass filter:
- Windowed frequency-domain design (firwin2) method
- Hamming window
- Lower passband edge: 1.00
- Lower transition bandwidth: 0.50 Hz (-12 dB cutoff frequency: 0.75 Hz)
- Upper passband edge: 35.00 Hz
- Upper transition bandwidth: 0.50 Hz (-12 dB cutoff frequency: 35.25 Hz)
- Filter length: 10000 samples (10.000 sec)



[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.2s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.3s remaining:    0.0s
[Parallel(n_jobs=1)]: Done 309 out of 309 | elapsed:   41.3s finished


Not setting metadata
39 matching events found
No baseline correction applied
Created an SSP operator (subspace dimension = 13)
13 projection items activated
Using data from preloaded Raw for 39 events and 401 original time points ...
    Rejecting  epoch based on GRAD : ['MEG1212', 'MEG1742']
    Rejecting  epoch based on GRAD : ['MEG1212']
    Rejecting  epoch based on GRAD : ['MEG1212']
    Rejecting  epoch based on GRAD : ['MEG1212']
    Rejecting  epoch based on GRAD : ['MEG1212']
    Rejecting  epoch based on GRAD : ['MEG1212']
    Rejecting  epoch based on GRAD : ['MEG1212']
    Rejecting  epoch based on GRAD : ['MEG1212']
    Rejecting  epoch based on GRAD : ['MEG1212']
    Rejecting  epoch based on GRAD : ['MEG1212']
    Rejecting  epoch based on GRAD : ['MEG1212']
    Rejecting  epoch based on GRAD : ['MEG1212']
    Rejecting  epoch based on GRAD : ['MEG1212']
    Rejecting  epoch based on GRAD : ['MEG1212']
    Rejecting  epoch based on GRAD : ['MEG1212']
    Rejecting  epoch

TypeError: 'NoneType' object is not iterable

In [None]:
""" Setting up a notch filter to control DC """

# Notch filter that suppresses power-line noise (50 Hz and its harmonics up to
# the 5th, 250 Hz) on the MEG channels of a copy of ``raw``.
# NOTE(review): ``raw`` is not defined by any visible cell; bind it to a loaded
# Raw (e.g. raw_1) before running this cell.
freqs = [50, 100, 150, 200, 250]
meg_picks = mne.pick_types(raw.info, meg=True)

custom_notch_filter = raw.copy().notch_filter(
    freqs=freqs,
    picks=meg_picks,
    method='spectrum_fit',
    filter_length='auto',
    # fir_window / fir_design are documented for method='fir' only; they are
    # kept here unchanged but do not shape the spectrum_fit result.
    fir_window='hamming',
    fir_design='firwin2',
    n_jobs=-1,
    verbose=True)

In [None]:
""" Finding Bad channels through maxwell """

def noisy_flat_channel_detection(file):
    """Append the channels MNE's Maxwell scan finds noisy or flat to ``file``'s bads.

    Existing entries in ``file.info['bads']`` are kept; the detected noisy
    channels follow, then the flat ones. Returns ``file``.
    """
    noisy, flat, _scores = mne.preprocessing.find_bad_channels_maxwell(
        file, return_scores=True, verbose=True)
    file.info['bads'] = file.info['bads'] + noisy + flat
    return file


In [None]:
def notch_filter(raw):
    """Remove 50 Hz line noise and its harmonics up to 250 Hz from ``raw``'s MEG channels.

    Filters in place with ``method='spectrum_fit'`` and returns ``raw``.
    """
    line_harmonics = [50 * k for k in range(1, 6)]
    raw.notch_filter(
        freqs=line_harmonics,
        picks=mne.pick_types(raw.info, meg=True),
        method='spectrum_fit',
        filter_length='auto',
        fir_window='hamming',
        fir_design='firwin',
        n_jobs=1,
        verbose=True,
    )
    return raw

def bandpass(raw):
    """Band-pass ``raw``'s data at 0.1-40 Hz with a zero-phase 4th-order Butterworth IIR.

    NOTE(review): unlike the neighbouring helpers, this returns the NumPy array
    produced by ``mne.filter.filter_data`` rather than the Raw object. It also
    filters every channel in ``raw.get_data()``; confirm that both are intended.
    """
    butter = {'order': 4, 'ftype': 'butter'}
    sampling_rate = raw.info['sfreq']
    return mne.filter.filter_data(
        raw.get_data(),
        sfreq=sampling_rate,
        l_freq=0.1,
        h_freq=40,
        method='iir',
        phase='zero',
        iir_params=butter,
        verbose=True,
    )

def find_events(raw):
    """Annotate imagery events from STI101 and attach ECG/EOG SSP projectors to ``raw``.

    Despite the name, this also computes ECG and EOG projectors (channels
    auto-detected by MNE) and adds them, unapplied, to ``raw``. A projector set
    that comes back as None is skipped with a message instead of crashing
    ``add_proj``.

    Returns:
        ``raw`` (modified in place).
    """
    mapping = {4: 'hand_imagery',
               8: 'feet_imagery',
               16: 'subtraction_imagery',
               32: 'word_imagery'}

    reject = dict(grad=4000e-13, eog=350e-6)

    events = mne.find_events(raw, stim_channel='STI101', uint_cast=True,
                             shortest_event=32,
                             initial_event=False,
                             verbose=True)

    annot_from_events = mne.annotations_from_events(events=events,
                                                    event_desc=mapping,
                                                    sfreq=raw.info['sfreq'],
                                                    orig_time=raw.info['meas_date'])
    raw.set_annotations(annot_from_events)

    ecg_projs, _ = mne.preprocessing.compute_proj_ecg(raw,
                                                      n_grad=1, n_mag=1, reject=reject,
                                                      no_proj=True)
    eog_projs, _ = mne.preprocessing.compute_proj_eog(raw,
                                                      n_grad=1, n_mag=1, reject=reject,
                                                      no_proj=True)

    for label, projs in (("ECG", ecg_projs), ("EOG", eog_projs)):
        if projs is None:
            # compute_proj_* can return None (e.g. all artifact epochs rejected);
            # add_proj(None) would raise TypeError.
            print(f"No {label} projectors computed; skipping.")
            continue
        raw.add_proj(projs)

    return raw

def finding_bad_channels_std(raw, factor=5):
    """Mark MEG channels of ``raw`` with abnormally high standard deviation as bad.

    Magnetometers and gradiometers are scored separately. A channel is bad when
    its std exceeds ``factor`` times the mean std of its own type. Non-MEG
    channels (stim, EOG, ECG) are not scored, because their scales differ by
    orders of magnitude and would dominate a pooled mean.
    ``raw.info['bads']`` is replaced with the result.

    Returns:
        ``raw`` (modified in place).
    """
    bad_channels = []
    for ch_type in ('mag', 'grad'):
        picks = mne.pick_types(raw.info, meg=ch_type, exclude=[])
        if len(picks) == 0:
            continue
        stds = np.std(raw.get_data(picks=picks), axis=1)
        over = np.flatnonzero(stds > factor * np.mean(stds))
        bad_channels.extend(raw.ch_names[picks[i]] for i in over)
    raw.info['bads'] = bad_channels
    return raw

""" Finding Bad channels through maxwell """

def noisy_flat_channel_detection(raw):
    """Extend ``raw.info['bads']`` with channels the Maxwell scan marks noisy or flat.

    Previously marked bads stay first, followed by the noisy and then the flat
    detections. Returns ``raw``.
    """
    scan = mne.preprocessing.find_bad_channels_maxwell(
        raw, return_scores=True, verbose=True)
    auto_noisy, auto_flat, _ = scan
    raw.info['bads'] = [*raw.info['bads'], *auto_noisy, *auto_flat]
    return raw

def nyquist_st_duration(raw):
    """Return 20 divided by the Nyquist frequency of ``raw`` (used as tSSS ``st_duration``).

    NOTE(review): at 1 kHz sampling this gives 0.04 s. Per the MNE docs,
    ``st_duration`` is the tSSS buffer length in seconds and acts like a
    1/st_duration high-pass, so confirm such a short buffer is intended.
    """
    half_rate = raw.info['sfreq'] / 2  # Nyquist frequency, e.g. 500 Hz at 1 kHz sampling
    return 20 / half_rate

def apply_tsss_filter(self):
    """Return a temporal-SSS Maxwell-filtered version of ``self.raw`` and checkpoint it.

    ``self`` must provide ``raw`` and a ``nyquist_st_duration()`` method. The
    filtered Raw is saved to 'checkpoint_raw.fif' and returned; ``self.raw`` is
    not reassigned.
    """
    buffer_seconds = self.nyquist_st_duration()
    filtered = mne.preprocessing.maxwell_filter(
        self.raw,
        coord_frame='head',
        st_duration=buffer_seconds,
        verbose=True,
    )
    filtered.save('checkpoint_raw.fif', overwrite=True)
    return filtered

In [None]:
""" Applying EOG ECG SSP Projectors """
def EOG_ECG_SSP_Projections(file):
    """Compute ECG and EOG SSP projectors for ``file`` and add them to it.

    ``file`` is presumably an ``mne.io.Raw``; it is passed to
    ``compute_proj_ecg``/``compute_proj_eog`` and must support ``add_proj``.
    The projectors are added unapplied, EOG first and then ECG. A set that comes
    back as None is skipped with a message.

    Returns:
        ``file`` (modified in place).
    """
    eog_channel = "EOG001"  # only specifying 1 eog channel due to additional noise when selecting two
    ecg_channel = "ECG003"

    # Kept in locals: this is a module-level function, so there is no ``self``.
    ecg_projs, _ecg_events = mne.preprocessing.compute_proj_ecg(file,
                                                                ch_name=ecg_channel,
                                                                n_grad=1,
                                                                n_mag=2,
                                                                )
    eog_projs, _eog_events = mne.preprocessing.compute_proj_eog(file,
                                                                ch_name=eog_channel,
                                                                n_grad=1,
                                                                n_mag=1,
                                                                )

    for label, projs in (("EOG", eog_projs), ("ECG", ecg_projs)):
        if projs is None:
            # add_proj(None) raises TypeError; compute_proj_* may return None
            # when no usable artifact epochs remain.
            print(f"No {label} projectors computed; skipping.")
            continue
        file.add_proj(projs)

    return file


""" Finding Bad channels through maxwell """

def noisy_flat_channel_detection(file):
    """Flag noisy and flat MEG channels via MNE's Maxwell scan and record them as bad.

    The detections are appended after any channels already in
    ``file.info['bads']``. Returns ``file``.
    """
    detected = mne.preprocessing.find_bad_channels_maxwell(
        file, return_scores=True, verbose=True)
    noisy_chs, flat_chs, _ = detected
    newly_bad = noisy_chs + flat_chs
    file.info['bads'] = file.info['bads'] + newly_bad
    return file

def nyquist_st_duration(self):
    """Return 20 / Nyquist frequency of ``self.raw``, in seconds, for tSSS.

    NOTE(review): this yields 0.04 s at 1 kHz sampling, far shorter than the
    multi-second buffers MNE documents for ``st_duration``; confirm intent.
    """
    nyquist_hz = self.raw.info['sfreq'] / 2  # half the sampling rate
    st_duration = 20 / nyquist_hz
    return st_duration

def apply_tsss_filter(self):
    """Replace ``self.raw`` with its temporal-SSS Maxwell-filtered version.

    ``st_duration`` comes from ``self.nyquist_st_duration()``. The new Raw is
    saved to 'checkpoint_raw.fif'.

    Returns:
        ``(self, self.raw)``.
    """
    buffer_seconds = self.nyquist_st_duration()
    self.raw = mne.preprocessing.maxwell_filter(
        self.raw,
        coord_frame='head',
        st_duration=buffer_seconds,
        verbose=True,
    )
    self.raw.save('checkpoint_raw.fif', overwrite=True)
    return self, self.raw