<a href="https://colab.research.google.com/github/Charliebond125/MEGNet/blob/main/MEGProcessing_Pipeline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import mne
import tqdm
import os
import matplotlib.pyplot as plt
import warnings
import numpy as np

from mne.preprocessing import find_eog_events, find_ecg_events, create_eog_epochs, create_ecg_epochs, compute_proj_eog, compute_proj_ecg

# Silence every warning process-wide (no category/module filter is given), which
# includes MNE's runtime warnings. NOTE(review): this also hides potentially useful
# diagnostics from the preprocessing steps below; consider narrowing the filter.
warnings.filterwarnings('ignore')


class meg_preprocessing_pipeline:
    """Chainable MEG preprocessing steps built on MNE-Python.

    Each step mutates/replaces ``self.raw`` (or ``self.epochs``) and returns
    ``self`` so steps can be chained. Intermediate files (``events.txt``,
    ``epochs-epo.fif``, ``baseline_norm_epochs-epo.fif``) are written to the
    current working directory.
    """

    def __init__(self, raw):
        self.raw = raw
        self.events = None      # event array from find_events()/create_epochs()
        self.eog_events = None
        self.ecg_events = None
        self.eog_projs = None   # SSP projectors from compute_proj_eog
        self.ecg_projs = None   # SSP projectors from compute_proj_ecg
        self.epochs = None

    def notch_filter(self, freqs=(50, 100, 150, 200, 250)):
        """Remove line noise from the MEG channels in place.

        Parameters
        ----------
        freqs : sequence of float
            Line-noise frequencies in Hz (default: 50 Hz and harmonics up to
            250 Hz). Entries at or above the Nyquist frequency are dropped,
            since they cannot be filtered at this sampling rate.

        Returns
        -------
        self
        """
        nyquist = self.raw.info['sfreq'] / 2
        usable = [f for f in freqs if f < nyquist]
        if not usable:
            return self
        picks = mne.pick_types(self.raw.info, meg=True)
        self.raw.notch_filter(freqs=usable, picks=picks,
                              method='spectrum_fit', filter_length='auto',
                              fir_window='hamming', fir_design='firwin',
                              verbose=True)
        return self

    def bandpass(self, l_freq=8, h_freq=35, order=4):
        """Band-pass filter ``self.raw`` with a zero-phase Butterworth IIR filter.

        Parameters
        ----------
        l_freq : float
            High-pass edge in Hz (default 8). Must be below Nyquist.
        h_freq : float | None
            Low-pass edge in Hz (default 35). If it is at or above Nyquist,
            the low-pass is dropped (such an edge is invalid for the filter and
            would remove nothing), leaving a pure high-pass filter.
        order : int
            Butterworth filter order (default 4).

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If ``l_freq`` is not below the Nyquist frequency.
        """
        nyquist = self.raw.info['sfreq'] / 2
        if l_freq >= nyquist:
            raise ValueError(
                f"l_freq ({l_freq} Hz) must be below the Nyquist frequency "
                f"({nyquist} Hz)")
        if h_freq is not None and h_freq >= nyquist:
            h_freq = None

        iir_params = dict(order=order, ftype='butter')
        self.raw = self.raw.filter(l_freq=l_freq, h_freq=h_freq,
                                   method='iir', phase='zero',
                                   iir_params=iir_params,
                                   filter_length='auto',
                                   verbose=True)
        return self

    def find_events(self):
        """Detect trigger events on STI101 and turn them into annotations.

        The detected event array is stored on ``self.events``, converted to
        annotations on ``self.raw`` using the trigger-code mapping below, and
        written to ``events.txt`` in the working directory.

        Returns
        -------
        self
        """
        mapping = {4: 'hand_imagery',
                   8: 'feet_imagery',
                   16: 'subtraction_imagery',
                   32: 'word_imagery'}

        self.events = mne.find_events(self.raw, stim_channel='STI101',
                                      uint_cast=True,
                                      shortest_event=32,
                                      initial_event=False,
                                      verbose=True)

        annot_from_events = mne.annotations_from_events(
            events=self.events,
            event_desc=mapping,
            sfreq=self.raw.info['sfreq'],
            orig_time=self.raw.info['meas_date'])
        self.raw.set_annotations(annot_from_events)

        mne.write_events('events.txt', self.events, overwrite=True)
        return self

    def finding_bad_channels_std(self, factor=5):
        """Mark channels whose standard deviation is an outlier within their type.

        A channel is marked bad when its std exceeds ``factor`` times the mean
        std of all channels of the same type (as reported by
        ``raw.get_channel_types()``). Grouping by type keeps channels with
        different units (e.g. magnetometers, gradiometers, stim/EOG/ECG) from
        distorting each other's threshold. ``raw.info['bads']`` is overwritten
        with the result, in channel order.

        Parameters
        ----------
        factor : float
            Multiple of the per-type mean std above which a channel is bad.

        Returns
        -------
        self
        """
        stds = np.std(self.raw.get_data(), axis=1)
        ch_types = np.asarray(self.raw.get_channel_types())
        is_bad = np.zeros(stds.shape[0], dtype=bool)
        for ch_type in np.unique(ch_types):
            in_type = ch_types == ch_type
            is_bad[in_type] = stds[in_type] > factor * stds[in_type].mean()
        self.raw.info['bads'] = [self.raw.ch_names[i]
                                 for i in np.flatnonzero(is_bad)]
        return self

    def finding_bad_channels_maxwell(self):
        """Replace ``raw.info['bads']`` with channels found by Maxwell-filter checks.

        Any previously marked bad channels are cleared first, then the noisy
        and flat channels reported by
        ``mne.preprocessing.find_bad_channels_maxwell`` become the new bads.

        Returns
        -------
        self
        """
        from mne.preprocessing import find_bad_channels_maxwell
        self.raw.info['bads'] = []

        raw_check = self.raw.copy()
        auto_noisy_chs, auto_flat_chs, auto_scores = find_bad_channels_maxwell(
            raw_check, verbose=True, return_scores=True)
        self.raw.info['bads'] = auto_noisy_chs + auto_flat_chs
        return self

    def interpolate_bads(self):
        """Interpolate bad channels on a copy of ``self.raw`` and clear the bads list.

        Returns
        -------
        self
        """
        self.raw = self.raw.copy().interpolate_bads(reset_bads=True)
        return self

    def create_epochs(self):
        """Compute ECG/EOG SSP projectors and epoch the data around task events.

        Epochs span 1 s to 7 s after each event, without baseline correction,
        with peak-to-peak (grad > 4000e-13) and flatness (grad < 1e-13)
        rejection. The events come from ``self.events`` if ``find_events`` has
        run, otherwise from ``events.txt`` in the working directory. The
        epochs are saved to ``epochs-epo.fif``.

        NOTE(review): the projectors are stored on ``self.ecg_projs`` /
        ``self.eog_projs`` but are never added to ``self.raw`` or the epochs,
        so ``proj='delayed'`` has no SSP projectors from this step to defer.
        Confirm whether they should be attached (e.g. ``epochs.add_proj``).

        Returns
        -------
        self
        """
        event_dict = {
            "Motor/hand_imagery": 4,
            "Motor/feet_imagery": 8,
            "Mental/subtraction_imagery": 16,
            "Mental/word_imagery": 32,
        }

        eog_channel = ["EOG001", "EOG002"]
        ecg_channel = "ECG003"

        reject = dict(grad=4000e-13)
        flat = dict(grad=1e-13)

        self.ecg_projs, _ = mne.preprocessing.compute_proj_ecg(
            self.raw, ch_name=ecg_channel,
            n_grad=1, n_mag=1, reject=reject,
            no_proj=True)

        self.eog_projs, _ = mne.preprocessing.compute_proj_eog(
            self.raw, ch_name=eog_channel,
            n_grad=1, n_mag=1, n_eeg=0,
            reject=reject,
            no_proj=True)

        # Prefer the in-memory events; only fall back to the file that
        # find_events() writes when this step is run on its own.
        if self.events is None:
            self.events = mne.read_events('events.txt')

        self.epochs = mne.Epochs(self.raw, events=self.events,
                                 event_id=event_dict,
                                 tmin=1, tmax=7,
                                 reject=reject,
                                 flat=flat,
                                 preload=False,
                                 proj='delayed',
                                 reject_by_annotation=True,
                                 baseline=None,
                                 verbose=True)

        self.epochs.save('epochs-epo.fif', overwrite=True)
        return self

    def baseline_normalization(self, fname='epochs-epo.fif'):
        """Load saved epochs, subtract the whole-epoch mean, and save the result.

        Parameters
        ----------
        fname : str | path-like
            Epochs file to load. Defaults to the file written by
            ``create_epochs`` in the working directory.

        Returns
        -------
        self
            ``self.epochs`` holds the baseline-corrected epochs, which are also
            saved to ``baseline_norm_epochs-epo.fif``.
        """
        self.epochs = mne.read_epochs(fname, preload=True)
        # (None, None) uses the entire epoch as the baseline interval.
        self.epochs.apply_baseline((None, None), verbose=True)
        self.epochs.save('baseline_norm_epochs-epo.fif', overwrite=True)
        return self

    def apply_pipeline(self):
        """Run band-pass, event detection, bad-channel handling, epoching and baselining.

        ``notch_filter`` and ``finding_bad_channels_std`` are not part of this
        sequence. A ``Checkpoint`` directory is created, but the steps above
        write their files to the current working directory.

        Returns
        -------
        tuple
            ``(self, self.raw)``.
        """
        os.makedirs("Checkpoint", exist_ok=True)

        self.bandpass()
        self.find_events()
        self.finding_bad_channels_maxwell()
        self.interpolate_bads()
        self.create_epochs()
        self.baseline_normalization()

        return self, self.raw
