In [3]:
# ----------------- Load packages ------------------- #
import os
import numpy as np
import pandas as pd
import mne
from scipy.io import loadmat
from AutoPSG import *
from AutoPSG.artifact_detect import hypnogram_segment

In [6]:
def load_raw(file_path):
    """
    Load a raw recording with the MNE reader that matches its file extension.

    Parameters:
        file_path: str or os.PathLike
            Path to the recording. The extension is matched
            case-insensitively (e.g. both 'x.edf' and 'x.EDF' are accepted).

    Returns:
        The object returned by the matching ``mne.io.read_raw_*`` reader.
        ``preload=True`` is requested for every reader that accepts it.

    Raises:
        ValueError
            If the extension is not one of the supported formats.
    """
    # Extension -> name of the reader function inside mne.io. Names are
    # resolved lazily so an unsupported path fails with ValueError first.
    readers = {
        '.vhdr': 'read_raw_brainvision',
        '.edf': 'read_raw_edf',
        '.fif': 'read_raw_fif',
        '.set': 'read_raw_eeglab',
        '.mat': 'read_raw_fieldtrip',
        '.bdf': 'read_raw_bdf',
        '.cnt': 'read_raw_cnt',
    }

    path = os.fspath(file_path)
    lowered = path.lower()
    for extension, reader_name in readers.items():
        if not lowered.endswith(extension):
            continue
        reader = getattr(mne.io, reader_name)
        if extension == '.mat':
            # read_raw_fieldtrip(fname, info, data_name=...) requires `info`
            # and has no `preload` argument; info=None lets MNE build a
            # default measurement info from the FieldTrip structure.
            return reader(path, info=None)
        return reader(path, preload=True)

    raise ValueError(f"Unsupported file format: {file_path}")

In [5]:
def auto_artifact_detect_interp(raw, picks, reference, subject_id, epoch_length, epoch_slide = 0, stage=None, stim=None):
    """
    Run the automated artifact pipeline on one recording: bad-channel
    marking, epoch-wise artifact detection, and interpolation.

    Two blocking browser windows are opened (after bad-channel marking and
    after epoch marking) so the result can be inspected before continuing.

    Parameters:
        raw: mne.io.Raw
            The raw EEG data.
        picks: list or str
            Channels to include in processing.
        reference: list or str
            Reference channel(s) used, if average keep empty.
        subject_id: int or str
            Identifier for the subject; passed to hypnogram_segment when
            `stage` is not supplied.
        epoch_length: float
            Length of each epoch in seconds.
        epoch_slide: float
            Slide duration for overlapping epochs in seconds.
        stage: optional
            Pre-computed result of hypnogram_segment for this recording.
            When None (default) it is computed here from the data returned
            by badchan_var.
        stim: optional
            Accepted so call sites that pass stimulation onsets keep
            working; it is not used by any step in this function.

    Returns:
        mne.io.Raw
            Preprocessed raw EEG data with artifacts rejected and interpolated.
            Callers should use this return value rather than assume `raw`
            was modified in place.
    """
    # Options shared by every epoch-based detector below.
    epoch_opts = dict(epoch_length=epoch_length, epoch_slide=epoch_slide,
                      picks=picks, reference=reference)

    # ---- Bad Channel Identification and Visualization ----
    raw_copy = badchan_var(raw, picks=picks, reference=reference)
    raw_copy.plot(duration=60.0, scalings=dict(eeg=200e-6), bad_color='red', block=True)

    # ---- Artifact Detection ----
    # Reuse the caller's sleep-stage segmentation when given; otherwise build it.
    if stage is None:
        stage = hypnogram_segment(subject_id, raw_copy)
    raw_copy = ptpamp_epoch(raw_copy, threshold=600e-6, **epoch_opts)
    raw_copy = zscoreamp_epoch_epochwise(raw_copy, stage, threshold_z=5, **epoch_opts)
    raw_copy = zscoreamp_epoch_chanwise(raw_copy, threshold_z=5, **epoch_opts)
    threshold_badepoch(raw_copy, picks=picks, percent_chan_bad=10)
    raw_copy.plot(duration=60.0, scalings=dict(eeg=200e-6), bad_color='red', block=True)

    # ---- Artifact Interpolation ----
    # First whole bad channels (spline), then the per-epoch interpolation.
    raw_copy = raw_copy.interpolate_bads(reset_bads=True, mode='accurate', origin='auto', method='spline', exclude=(), verbose='ERROR')
    raw_copy = interpolate_by_epoch(raw_copy, picks=picks, window_ratio=0.05)

    return raw_copy

In [None]:
# Set default paths and parameters in preprocessing
# Batch cell: every supported recording in `raw_file` is loaded, resampled,
# filtered, re-referenced, artifact-cleaned, run through ICA and saved as
# <subject>.fif in `processed_file`.
os.chdir('/Desktop/replace_w_your_dir')  # placeholder: point this at the project directory
raw_file = 'Raw/'
processed_file = 'Preproc/'
scoring_file = 'Scoring/'
stim_file = None
montage_temp = 'standard_1005'
online_refer = 'FCz'
rerefer = ['M1', 'M2']
epoch_length = 10

# Only extensions that load_raw() can open. Companion files such as
# BrainVision '.eeg' or EEGLAB '.fdt' must not be listed: they sit next to the
# '.vhdr' / '.set' header and would make load_raw() raise ValueError.
supported_extensions = ('.vhdr', '.edf', '.fif', '.set', '.bdf', '.cnt')

# Channel type / name corrections, applied only to channels present in a file.
channel_types = {
    "M1": "eeg", "M2": "eeg", "I1": "eeg", "I2": "eeg", "EOG1": "eog",
    "EOG2": "eog", "EMG1": "emg", "EMG2": "emg", "EMG3": "emg", "ECG": "ecg"
}
channel_renames = {'FPz': 'Fpz'}

# raw.save() does not create missing directories.
os.makedirs(processed_file, exist_ok=True)

####################################### Automated Preprocessing Pipeline #######################################
file_paths = [os.path.join(raw_file, f) for f in sorted(os.listdir(raw_file)) if f.endswith(supported_extensions)]
for file_path in file_paths:
    # `subject_id` rather than `id`, which would shadow the builtin.
    subject_id = os.path.splitext(os.path.basename(file_path))[0]
    print(f"Start preprocessing for subject {subject_id}.")
    raw = load_raw(file_path)
    raw.set_annotations(None)

    # set_channel_types / rename_channels raise on unknown channel names, so
    # restrict both mappings to the channels this recording actually has.
    raw.set_channel_types({ch: kind for ch, kind in channel_types.items() if ch in raw.ch_names})
    raw.rename_channels({old: new for old, new in channel_renames.items() if old in raw.ch_names})

    # Downsampling and filtering
    raw.resample(sfreq=200)
    raw.notch_filter(freqs=60, method='fir', fir_window='hamming', phase='zero')
    raw.filter(l_freq=0.1, h_freq=35, picks=['eeg', 'eog'], method='fir', fir_window='hamming', phase='zero')
    if 'emg' in raw.get_channel_types():
        # After resampling to 200 Hz the Nyquist frequency is 100 Hz, and a
        # low-pass edge at Nyquist is rejected by MNE's filter design. The
        # resampling step has already band-limited the data, so only the
        # 10 Hz high-pass is applied to EMG.
        raw.filter(l_freq=10, h_freq=None, picks=['emg'], method='fir', fir_window='hamming', phase='zero')

    # Add online reference channels back
    raw = mne.add_reference_channels(raw, ref_channels=online_refer)

    # Set the montage
    montage = mne.channels.make_standard_montage(montage_temp)
    raw.set_montage(montage)

    # Set EEG reference
    raw.set_eeg_reference(ref_channels=rerefer)

    # Load stimulation onsets
    if stim_file is not None:
        stim = load_stim(subject_id=subject_id, raw=raw)
        annotate_stim(raw, stim)
    else:
        stim = None

    # Detect and interpolate artifacts
    ## visual_artifact_inspect(raw)
    # Keep the returned object: the cleaned data is what the function returns.
    # Sleep-stage segmentation for the detection step is computed inside it.
    raw = auto_artifact_detect_interp(raw, picks=['eeg'], reference=rerefer,
                                      subject_id=subject_id, epoch_length=epoch_length)

    # Segment the cleaned data into sleep stages for ICA (same order as the
    # manual single-subject workflow, which re-segments after interpolation).
    stage = hypnogram_segment(subject_id=subject_id, raw=raw)

    # Apply ICA
    # NOTE(review): the return value is ignored here, which presumes
    # apply_ica_by_stage modifies `raw` in place — confirm against AutoPSG.
    apply_ica_by_stage(raw=raw, stage=stage, stim=stim, n_component=0.99,
                       kurt_threshold_z=3, run_ecg=False, auto=True)

    save_path = os.path.join(processed_file, f"{subject_id}.fif")
    raw.save(save_path, overwrite=True)

    raw.plot_psd()
    # Release the recording before loading the next one.
    del raw, stage, subject_id

In [None]:
# -------- Set Parameters --------- #
# Manual, single-subject version of the cleaning workflow: load one inspected
# recording, mark bad channels/epochs, interpolate, run ICA, save the result.
subject_id = 990
rerefer = ['M1', 'M2']
picks = 'eeg'
epoch_length = 10
epoch_slide = None

# Epoch/channel options passed unchanged to every epoch-based detector.
_epoch_opts = {
    'epoch_length': epoch_length,
    'epoch_slide': epoch_slide,
    'picks': picks,
    'reference': rerefer,
}


def _inspect(data):
    """Open the blocking MNE browser used at each visual checkpoint."""
    data.plot(duration=60.0, scalings=dict(eeg=200e-6), bad_color='red', block=True)


# ----------- Load Data ----------- #
_input_fif = f'~/Desktop/EMO_PSG_fMRI/data/inspected_data/EMO{subject_id}_N_cfv.fif'
raw = mne.io.read_raw_fif(_input_fif, preload=True)
raw.set_annotations(None)  # start from a recording without annotations

# ----- Bad Channel Detection ----- #
raw = badchan_var(raw, picks=picks, reference=rerefer)
_inspect(raw)

# ------ Artifact Detection ------- #
# Detectors run in sequence, each receiving the output of the previous one.
stage = hypnogram_segment(subject_id, raw)
raw = ptpamp_epoch(raw, threshold=600e-6, **_epoch_opts)
raw = zscoreamp_epoch_epochwise(raw, stage, threshold_z=5, **_epoch_opts)
raw = zscoreamp_epoch_chanwise(raw, threshold_z=5, **_epoch_opts)
raw = flat_epoch(raw, threshold_sd=0.2e-6, min_flat_length=2, window_slide=0.5, **_epoch_opts)
threshold_badepoch(raw, picks=picks, percent_chan_bad=10)
_inspect(raw)

# ---- Artifact Interpolation ----- #
raw = raw.interpolate_bads(
    reset_bads=True,
    mode='accurate',
    origin='auto',
    method='spline',
    exclude=(),
    verbose='ERROR',
)
raw = interpolate_by_epoch(raw, picks=picks, window_ratio=0.05)

# ------------ Run ICA ------------ #
# Segment again so ICA receives stages built from the interpolated data.
stage_cleaned = hypnogram_segment(subject_id, raw)
raw = apply_ica_to_segments(raw, stage_cleaned, n_component=0.99, kurt_threshold_z=3, run_ecg=False)
_inspect(raw)  ## Check if ICA removed artifacts in bad epochs

# ------- Save cleaned data ------- #
_output_fif = f'~/Desktop/EMO_PSG_fMRI/data/preprocessed_data_temporary/EMO{subject_id}_cleaned.fif'
raw.save(_output_fif, overwrite=True)
del raw, stage, stage_cleaned
del _inspect, _epoch_opts, _input_fif, _output_fif