<h3><b><u> PREPROCESSING PIPELINE FOR MEG DATA </u></b></h3>

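<p5> Using MNE-Python, we have constructed a pipeline that takes raw MEG data through the main preprocessing stages: head position estimation, bad channel detection and interpolation, tSSS (temporal signal space separation), resampling, notch filtering, bandpass filtering, ECG/EOG artifact projection and epoching.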

Each stage can be amended as required, provided the documentation for the relevant MNE function is followed.

When you run the script, paste the path to your raw .fif file, exactly as copied, when prompted. The file is loaded, passed to the "__init__" constructor, and then taken through each stage of the processing pipeline, with intermediate files saved to and loaded from your current working directory as it goes.

For tSSS, you will need the head position file, which is written by estimate_continuous_head_pos (using mne.chpi.compute_head_pos). You will also need to point the object to the cross-talk and fine-calibration files for your Elekta Neuromag system.

For more information on constructing filters, familiarise yourself with the scipy.signal module, specifically second-order sections (output='sos') and the sosfilt and sosfiltfilt functions. MNE wraps these functions for its IIR filtering.

For further queries, please read through the MNE documentation in the first instance.</p5>
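As a rough illustration of the SciPy functions mentioned above, the sketch below designs a 4th-order Butterworth bandpass (1 to 40 Hz, the values used in bandpass_filter_butter) as second-order sections and applies it forwards and backwards with scipy.signal.sosfiltfilt. It runs on synthetic data, the helper names are hypothetical, and we assume (rather than guarantee) that MNE's IIR path with output='sos' and phase='zero-double' behaves equivalently.

```python
import numpy as np
from scipy import signal


def butter_bandpass_sos(l_freq, h_freq, sfreq, order=4):
    """Design a Butterworth bandpass as second-order sections (hypothetical helper).

    For a bandpass, scipy doubles the order: order=4 gives an 8th-order
    filter stored as 4 second-order sections.
    """
    return signal.butter(order, [l_freq, h_freq], btype="bandpass",
                         output="sos", fs=sfreq)


def zero_phase_bandpass(data, l_freq, h_freq, sfreq, order=4):
    """Filter forwards then backwards (zero phase shift), like phase='zero-double'."""
    sos = butter_bandpass_sos(l_freq, h_freq, sfreq, order)
    return signal.sosfiltfilt(sos, data, axis=-1)


# Synthetic signal at 500 Hz: 10 Hz (inside the passband) plus 100 Hz (outside it)
sfreq = 500.0
t = np.arange(int(10 * sfreq)) / sfreq
in_band = np.sin(2 * np.pi * 10 * t)
out_of_band = np.sin(2 * np.pi * 100 * t)

filtered = zero_phase_bandpass(in_band + out_of_band, 1.0, 40.0, sfreq)
```

Because the filter runs in both directions, the 10 Hz component comes back without a phase shift while the 100 Hz component is strongly attenuated; the forward-backward pass also squares the magnitude response.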

<p5> MNE documentation: https://mne.tools/stable/overview/index.html </p5>

In [1]:
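### Import required libraries ###
# Imported at module level so that every method (and the __main__ cell) can use them
import os
import warnings

import mne
import numpy as np

# Note: this suppresses ALL warnings, including MNE's own warnings
warnings.filterwarnings('ignore')


class meg_preprocessing_pipeline: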
    
    """
    Class containing the preprocessing pipeline for MEGNet.

    Filtering: The filter methods can be edited; however, we DO NOT recommend
    using mne.filter.create_filter, as this caused instability issues when
    running mne.find_events.

    More functionality is planned; this is still a work in progress.

    Resampling: Resamples the raw object to 500 Hz.

                Resampling before filtering and epoching reduces the memory
                consumption of the raw object. MNE applies anti-aliasing
                low-pass filtering at the new Nyquist frequency (250 Hz) as
                part of resampling. The MNE documentation recommends finding
                events on the original data before resampling.

    Bad channel interpolation: Returns an interpolated raw object.

                Interpolation repairs channels that are noisy or flat (static)
                using the signals from surrounding sensors. The data are
                therefore not discarded, and some information can be retained.

                Later versions will add the autoreject package.

    Notch filtering: Removes the 50 Hz power-line frequency.

                Notch filtering is applied at 50 Hz and its harmonics
                (100 Hz and 150 Hz).

    Butterworth bandpass filter: Returns a filtered raw object.

                Keyword args:
                    output='sos': second-order sections, as in scipy.signal

                    phase='zero-double': forward and backward pass (zero
                    phase), as in scipy.signal.sosfiltfilt

                    order=4: 4th-order filter

                    ftype='butter': Butterworth design, as in scipy.signal.butter

                The passband is currently 1 to 40 Hz; narrow it (e.g. to
                7 to 35 Hz) to capture activations within a specific band.

    tSSS Maxwell filtering:
        Returns: Maxwell-filtered raw object.

        Performs temporal signal space separation (tSSS) with a buffer
        length (st_duration, in seconds) derived from the recording length:

            st_duration = raw.times[-1] / nyquist_freq

        This divides the recording into evenly spaced tSSS buffers (roughly
        nyquist_freq of them), over which the temporal projection is computed.

    Intermediate files are saved to the current working directory and must
    be moved manually.

    Each method can be run on its own, but some read files written by earlier
    steps (the head position file and the events file).
    """
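    def __init__(self, raw):
        """Store the raw MNE object; the other attributes are filled in by later steps."""
        self.raw = raw
        self.head_pos = None
        self.events = None
        self.eog_events = None
        self.ecg_events = None
        self.eog_projs = None
        self.ecg_projs = None
        self.epochs = None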
    
    def get_subject_name(self, filename):
        """
        Extract the subject identifier from the filename given by the user.

        The file's base name (without extension, and cut at "_raw") must start
        with one of: "subject_", "sub_", "sub-", "-sub", "Sub_", "Sub-" or
        "-Sub". For "subject_" the prefix is stripped; for the other prefixes
        it is kept. Returns None if no prefix matches.
        """
        base_name = os.path.splitext(os.path.basename(filename))[0].split("_raw")[0]

        # "subject_" prefix: return only the part after the prefix
        if base_name.startswith("subject_"):
            return base_name.split("subject_")[1]

        # "sub"/"Sub" prefixes: keep the prefix in the returned name
        for prefix in ("sub_", "sub-", "-sub", "Sub_", "Sub-", "-Sub"):
            if base_name.startswith(prefix):
                return prefix + base_name.split(prefix)[1]

        print("No subject found in the filename. Please specify the subject within the filename to begin processing...")
        return None


    def resample_data(self):
        """Resample the raw data to 500 Hz (in place) and return self."""
        print("Resampling raw data...")
        self.raw.resample(500)
        print(f"Resampling complete at freq: {self.raw.info['sfreq']} Hz")
        return self

    def notch_filter(self):
        """
        Notch filter the MEG channels at the power-line frequency and its harmonics.

        The notch frequencies must match the power-line frequency of the
        recording site. For the dataset used here this is 50 Hz, so the
        filter is applied at 50 Hz and its harmonics.

        Adjust freqs for your own use case.

        freqs = 50 Hz up to the 3rd harmonic (150 Hz).
        """
        self.raw.notch_filter(
            freqs=[50, 100, 150],
            picks='meg',
            method='spectrum_fit',
            filter_length='auto',
            # fir_window and fir_design only apply when method='fir';
            # they have no effect with method='spectrum_fit'
            fir_window='hamming',
            fir_design='firwin2',
            n_jobs=-1,
            verbose=True,
        )

        return self

    def finding_bad_channels_maxwell(self):
        """
        Automatically detect noisy and flat MEG channels with
        mne.preprocessing.find_bad_channels_maxwell, so that they can later
        be interpolated.

        Any channels already listed in raw.info['bads'] are cleared first;
        the detected noisy and flat channels are then stored in their place.

        Detected channels:
            ~ Flat or static channels usually indicate a faulty sensor.
            ~ Noisy channels indicate external noise, sensor issues, etc.

        Note:
            Only raw.info['bads'] is modified (in place); detection runs on a
            copy of the data. Returns self.
        """
        from mne.preprocessing import find_bad_channels_maxwell

        # Start from an empty list of bad channels
        self.raw.info['bads'] = []

        # Run detection on a copy of the original raw data
        raw_check = self.raw.copy()
        auto_noisy_chs, auto_flat_chs, auto_scores = find_bad_channels_maxwell(
            raw_check, verbose=True, return_scores=True
        )

        # Concatenate the (empty) list of bad channels with the noisy and flat channels
        self.raw.info['bads'] = self.raw.info['bads'] + auto_noisy_chs + auto_flat_chs

        return self

    def interpolate_bads(self):
        """
        Interpolate bad channels in the MEG Data as marked by our bad channel detection method.

        This function firstly creates a copy of the raw data,
        then performs interpolation on the bad channels, based on good channels.

        It then assigns the interpolated data back to the original raw object.

        Returns:
            self: The updated object with interpolated bad channels.

        """
        # Create a copy of the raw data and perform interpolation
        interpolated_raw = self.raw.copy().interpolate_bads(reset_bads=True)

        # Assign interpolated data back to original raw
        self.raw = interpolated_raw

        subject = self.get_subject_name(self.raw.filenames[0])

        self.raw.save(f"{subject}_interpolated_bads_raw.fif",overwrite=True)

        return self
    
    def estimate_continuous_head_pos(self):

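        """
        Estimate continuous head position from the cHPI coils and write it to
        "{subject}_head_pos.pos", which apply_tsss_filter reads later.

        Input: the current self.raw object.

        mne.chpi.get_chpi_info returns the cHPI frequencies, channel indices
        and codes. Keep this call: although its return values are not used
        below, it fails early if the recording has no cHPI information.

        Returns: self, with the head positions stored in self.head_pos.
        """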

        self.raw.load_data()
        chpi_freqs, ch_idx, chpi_codes = mne.chpi.get_chpi_info(info=self.raw.info)
        chpi_amplitudes = mne.chpi.compute_chpi_amplitudes(self.raw)
        chpi_locs = mne.chpi.compute_chpi_locs(self.raw.info, chpi_amplitudes)
        self.head_pos = mne.chpi.compute_head_pos(self.raw.info, chpi_locs, verbose=True)
        
        subject = self.get_subject_name(self.raw.filenames[0])
        output_head_pos = f'{subject}_head_pos.pos'

        mne.chpi.write_head_pos(output_head_pos, self.head_pos)
        
        return self
    
    @staticmethod
    def map_to_power_of_two(event_id):
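        """
        Catch-all method that maps a detected event ID onto the expected codes.

        Expected event codes:
            [4 8 16 32]

        Detected event codes may be offset from these, for example:
            1. [5 9 17 33]   (each code + 1)
            2. [6 10 18 34]  (each code + 2)
               etc.

        Each ID is mapped to the nearest expected code (ties go to the
        smaller code). The initial trigger (255) is returned unchanged.

        Returns:
            The nearest of [4, 8, 16, 32], or 255.
        """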

        if event_id == 255:  # Ignore initial trigger
            return event_id

        event_values = [4, 8, 16, 32]
        event_pair_values = min(event_values, key=lambda x: abs(x - event_id))
        
        return event_pair_values
    
    def find_events(self):
        """
        Find events on the 'STI101' trigger channel, map them onto the
        expected codes and keep only the codes 4, 8, 16 and 32.

        The events are written to "{subject}_events.txt", which
        create_epochs reads later.

        Returns:
            self, with the selected events stored in self.events.
        """

        # Find events based on the 'STI101' stimulus channel.
        all_events = mne.find_events(self.raw, stim_channel='STI101',
                                    initial_event=False, verbose=True)

        # Map each detected event ID to the nearest expected code
        all_events[:, -1] = np.array([self.map_to_power_of_two(event_id) for event_id in all_events[:, -1]])

        # Keep only the expected event codes (this also drops the 255 trigger)
        self.events = mne.pick_events(all_events, include=[4, 8, 16, 32])
        print(f"First event codes selected from data: {self.events[:, -1][:4]}")

        subject = self.get_subject_name(self.raw.filenames[0])

        mne.write_events(f'{subject}_events.txt', self.events, overwrite=True)
        
        return self
        
    def bandpass_filter_butter(self):
        """
        Apply a Butterworth bandpass filter to the raw data.

        The filter is an IIR design stored as second-order sections (SOS),
        order 4, applied with phase='zero-double': a forward and a backward
        pass, giving zero phase shift (as with scipy.signal.sosfiltfilt).

        Both passband edges (1 Hz and 40 Hz) are capped at the Nyquist
        frequency so that the cut-offs stay within the valid frequency range.

        Returns:
            self: The modified object with filtered data.

        Note:
            This method modifies the 'raw' attribute of the object in place.
        """
        sfreq = self.raw.info['sfreq']
        nyquist_freq = sfreq / 2

        l_freq = min(1, nyquist_freq)
        h_freq = min(40, nyquist_freq)

        iir_params = dict(order=4, ftype='butter', output='sos')

        # filter_length only applies to FIR filters; it is not used with method='iir'
        self.raw = self.raw.filter(l_freq=l_freq, h_freq=h_freq,
                                   method='iir', phase='zero-double',
                                   iir_params=iir_params,
                                   filter_length='auto',
                                   verbose=True)

        subject = self.get_subject_name(self.raw.filenames[0])

        self.raw.save(f"{subject}_checkpoint_filtered_-raw.fif", overwrite=True)

        return self


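    def nyquist_st_duration(self):
        """
        Return the tSSS buffer length (st_duration, in seconds):
        recording length / Nyquist frequency.

        This gives roughly nyquist_freq equal buffers. According to the MNE
        docs, tSSS acts as an implicit high-pass filter at about
        1 / st_duration Hz.
        """
        sfreq = self.raw.info['sfreq']
        nyquist_freq = sfreq / 2  # e.g. 500 Hz for data acquired at 1000 Hz

        st_duration = self.raw.times[-1] / nyquist_freq
        return st_duration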
    
    def apply_tsss_filter(self):
        """
        Apply tSSS (temporal signal space separation) with head movement
        compensation. Requires the cross-talk and fine-calibration files.

        SSS separates signals arising inside the sensor helmet from external
        interference; tSSS additionally removes interference whose time
        courses are correlated between the internal and external subspaces
        (threshold st_correlation). st_duration is calculated as
        recording length / Nyquist frequency.

        head_pos: Calculated with mne.chpi.compute_head_pos and read from
                  "{subject}_head_pos.pos". See the MNE docs for further
                  details on head movement compensation.

        kwarg 'regularize': 'in' uses the same regularisation as the
                            "-regularize in" option of Elekta's MaxFilter;
                            this is the MNE default.

        Returns:
            self, with the tSSS-filtered raw object in self.raw.
        """
        # Edit these paths for your own system. If the cross-talk and
        # calibration files are not available, set them to None.
        calibration_file = r"D:\charl\Documents\CE901_MEG_DATA_AND_CODE\MEG_BIDS\MEG_BIDS\NEUROMAG_CROSSTALK\sss_cal_3120_20130327.dat"
        cross_talk_file = r"D:\charl\Documents\CE901_MEG_DATA_AND_CODE\MEG_BIDS\MEG_BIDS\NEUROMAG_CROSSTALK\ct_sparse.fif"

        st_duration = self.nyquist_st_duration()
        subject = self.get_subject_name(self.raw.filenames[0])
        head_pos = mne.chpi.read_head_pos(f"{subject}_head_pos.pos")

        self.raw = mne.preprocessing.maxwell_filter(
            self.raw,
            calibration=calibration_file,
            cross_talk=cross_talk_file,
            coord_frame='head',
            head_pos=head_pos,
            st_duration=st_duration,
            st_correlation=0.98,
            origin='auto',
            int_order=8,
            ext_order=3,
            regularize='in',
            verbose=True,
        )

        self.raw.save(f'{subject}_tsss_filtered_raw.fif', overwrite=True)

        return self

    def create_eog_ecg_projs(self):
        """
        Compute SSP projectors for heartbeat (ECG) and blink (EOG) artifacts
        and add them to the raw object.

        The projectors are computed from the ECG channel (ECG003) and the EOG
        channels (EOG001, EOG002) with mne.preprocessing.compute_proj_ecg and
        mne.preprocessing.compute_proj_eog (one gradiometer and one
        magnetometer projector each).

        The projectors are added but not yet applied; mne.Epochs applies
        them by default when the epochs are created.

        Returns:
            self: Object with the ECG/EOG projectors added to self.raw.
        """
        # self.raw = mne.io.read_raw_fif(r"checkpoint_filter-raw.fif", preload=True)
        # self.events = mne.read_events(r"events.txt")

        # Specifying the EOG and ECG channels
        eog_channel = ["EOG001", "EOG002"]
        ecg_channel = "ECG003"

        # Rejection criterion for gradiometers: 5000e-13 T/m = 5000 fT/cm
        reject = dict(grad=5000e-13)

        # Compute ECG projections
        ecg_projs, _ = mne.preprocessing.compute_proj_ecg(
            self.raw,
            ch_name=ecg_channel,
            n_grad=1,
            n_mag=1,
            no_proj=True,
            reject=reject,
        )

        # Compute EOG/blink projections
        eog_projs, _ = mne.preprocessing.compute_proj_eog(
            self.raw,
            ch_name=eog_channel,
            n_grad=1,
            n_mag=1,
            no_proj=True,
            reject=reject,
        )

        # Add projectors to the raw object, ready for epoch creation
        self.raw.add_proj(ecg_projs)
        self.raw.add_proj(eog_projs)

        #self.raw.save("tsss_eog_ecg_ssp_repaired_raw.fif", overwrite=True)

        subject = self.get_subject_name(self.raw.filenames[0])

        self.raw.save(f"{subject}_artifact_repaired-raw.fif", overwrite=True)

        return self


    def create_epochs(self):
        """
        Create epochs for the four imagery conditions from the events file
        written by find_events, then crop them to the window of interest.

        Epochs are created from -2 s to 6 s, baseline-corrected with the
        -0.2 s to 0 s interval, then cropped to 2 s to 6 s after trial onset.
        SSP projectors attached to the raw object are applied here (the
        mne.Epochs default).

        Returns:
            self, with the epochs stored in self.epochs.
        """
        #self.raw = mne.io.read_raw_fif(r"D:\charl\Documents\CE901_MEG_DATA_AND_CODE\PROCESSING_PIPE\tsss_eog_ecg_ssp_repaired_raw.fif", preload=True)
        # Event dictionary mapping event types to codes
        event_dict = {
            "hand_imagery": 4,
            "feet_imagery": 8,
            "subtraction_imagery": 16,
            "word_imagery": 32,
        }

        subject = self.get_subject_name(self.raw.filenames[0])
        self.events = mne.read_events(f"{subject}_events.txt")

        # reject_by_annotation: if True (the MNE default), epochs overlapping
        # annotations whose description begins with 'bad' are rejected.
        # It is set to False here, so no annotation-based rejection is done.
        self.epochs = mne.Epochs(self.raw, events=self.events,
                                 event_id=event_dict,
                                 tmin=-2, tmax=6,  # cropped to 2-6 s after trial onset below
                                 preload=True,
                                 reject=None,
                                 reject_by_annotation=False,
                                 baseline=(-0.2, 0.0),
                                 verbose=True)

        # Time window of interest: 2-6 s after trial onset
        self.epochs = self.epochs.crop(tmin=2, tmax=6)

        self.epochs.save(f"{subject}_epoched_data-epo.fif", overwrite=True)

        return self

    def apply_pipeline(self):
        """
        Applies the preprocessing steps to the raw data, in order.

        Steps:
        1. Estimating continuous head position.
        2. Finding bad channels with find_bad_channels_maxwell.
        3. Interpolating bad channels.
        4. Applying tSSS (Maxwell filtering).
        5. Resampling to 500 Hz.
        6. Notch filtering at 50 Hz and its harmonics.
        7. Applying a 4th-order Butterworth bandpass filter.
        8. Finding events in the data.
        9. Creating ECG/EOG SSP projectors.
        10. Creating epochs based on the events.

        Returns
        -------
        self: Instance of the class.
            The modified instance of the class with the applied
            preprocessing steps.
        """
        # Check that a subject name can be read from the filename,
        # since every saved file is named after it
        subject = self.get_subject_name(self.raw.filenames[0])
        if subject is None:
            raise ValueError("Could not determine the subject from the filename.")

        # Estimate continuous head position (cHPI)
        self.estimate_continuous_head_pos()

        # Find bad channels using find_bad_channels_maxwell
        self.finding_bad_channels_maxwell()

        # Bad channel interpolation
        self.interpolate_bads()

        # Apply tSSS
        self.apply_tsss_filter()

        # Resample data
        self.resample_data()

        # Notch filtering
        self.notch_filter()

        # Apply a 4th-order, zero-phase (forward-backward, scipy sosfiltfilt) Butterworth bandpass filter
        self.bandpass_filter_butter()

        # Find events
        self.find_events()

        # Create projectors for blinks and heartbeats
        self.create_eog_ecg_projs()

        # Create epochs
        self.create_epochs()

        return self

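The buffer-length heuristic in nyquist_st_duration can be checked with plain arithmetic. Below is a minimal sketch, assuming a hypothetical 10-minute recording at 1000 Hz (not a value from this dataset); the function mirrors the method but takes a sample count and sampling rate instead of an MNE Raw object. Since apply_pipeline runs tSSS before resampling, sfreq here is the acquisition rate. Following the MNE maxwell_filter documentation, the sketch also reports the implicit tSSS high-pass at roughly 1 / st_duration Hz.

```python
def nyquist_st_duration(n_samples, sfreq):
    """tSSS buffer length in seconds: recording length / Nyquist frequency.

    Hypothetical signature for illustration: takes a sample count and a
    sampling rate instead of an MNE Raw object.
    """
    nyquist_freq = sfreq / 2.0
    last_time = (n_samples - 1) / sfreq  # same value as raw.times[-1]
    return last_time / nyquist_freq


# Hypothetical recording: 600 s at 1000 Hz (not taken from the dataset)
sfreq = 1000.0
n_samples = int(600 * sfreq) + 1

st_duration = nyquist_st_duration(n_samples, sfreq)
n_buffers = (n_samples - 1) / sfreq / st_duration  # roughly the Nyquist frequency
implied_highpass_hz = 1.0 / st_duration            # implicit tSSS high-pass

print(f"st_duration = {st_duration:.2f} s")
print(f"buffers     = {n_buffers:.0f}")
print(f"high-pass   ~ {implied_highpass_hz:.2f} Hz")
```

Because the recording length cancels out, the heuristic always produces about nyquist_freq buffers; in this example the implied high-pass (about 0.83 Hz) sits just below the 1 Hz lower edge of the bandpass filter.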

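The map_to_power_of_two method maps one event ID at a time inside a list comprehension. A vectorised NumPy equivalent is sketched below, assuming the IDs arrive as an integer array (for example, the last column of an MNE events array); map_to_nearest_code is a hypothetical helper, not part of the pipeline. np.argmin returns the first minimum, so ties go to the smaller code, matching min() over the ascending list [4, 8, 16, 32].

```python
import numpy as np

EXPECTED_CODES = np.array([4, 8, 16, 32])
INITIAL_TRIGGER = 255  # passed through unchanged, as in map_to_power_of_two


def map_to_nearest_code(event_ids):
    """Map each detected event ID to the nearest expected code (hypothetical helper)."""
    event_ids = np.asarray(event_ids)
    # Absolute distance from every ID to every expected code: shape (n_events, 4)
    distances = np.abs(event_ids[:, None] - EXPECTED_CODES[None, :])
    # argmin picks the first minimum, so ties resolve to the smaller code
    nearest = EXPECTED_CODES[np.argmin(distances, axis=1)]
    # Leave the initial trigger untouched
    return np.where(event_ids == INITIAL_TRIGGER, event_ids, nearest)


# Hypothetical events array: columns are (sample, previous value, event ID)
events = np.array([
    [1000, 0, 5],
    [3000, 0, 10],
    [5000, 0, 18],
    [7000, 0, 34],
    [9000, 0, 255],
])
events[:, -1] = map_to_nearest_code(events[:, -1])
print(events[:, -1].tolist())  # [4, 8, 16, 32, 255]
```

As in find_events, the 255 trigger would still be removed afterwards by mne.pick_events with include=[4, 8, 16, 32].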
In [None]:
if __name__ == "__main__":

    # The filename input is a copied path to a raw .fif file
    import os

    import mne

    while True:
        # Remove surrounding whitespace and any quote characters from the pasted path
        filename = input("Please specify file path: ").strip().replace('"', '').replace("'", "")

        if os.path.isfile(filename):
            # Valid file path detected
            break

        # Invalid path: prompt again rather than continuing with a bad path
        print("Invalid file path specified. Please try again.")

    raw = mne.io.read_raw_fif(filename, preload=True)
    instance = meg_preprocessing_pipeline(raw)
    instance.apply_pipeline()
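The input loop above strips every quote character from the pasted path. A slightly stricter sketch is shown below, assuming paths are pasted with optional surrounding quotes (as Windows' "Copy as path" produces); clean_pasted_path and prompt_for_fif are hypothetical helpers, not part of the pipeline. Only the surrounding quotes are removed, so a path containing an apostrophe survives, and the number of attempts is capped.

```python
import os


def clean_pasted_path(text):
    """Remove surrounding whitespace and quotes from a pasted path (hypothetical helper)."""
    return text.strip().strip('"').strip("'")


def prompt_for_fif(input_fn=input, max_attempts=3):
    """Ask for a raw .fif path until an existing file is given (hypothetical helper)."""
    for _ in range(max_attempts):
        path = clean_pasted_path(input_fn("Please specify file path: "))
        if os.path.isfile(path) and path.endswith((".fif", ".fif.gz")):
            return path
        print("Invalid file path specified. Please try again.")
    raise FileNotFoundError(f"No valid .fif file given after {max_attempts} attempts.")
```

In the main cell, filename = prompt_for_fif() could then replace the while loop before calling mne.io.read_raw_fif(filename, preload=True).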