# Imports & Installs

In [40]:
# !pip install mne
# !pip install autoreject
# !pip install pyxdf

In [41]:
import traceback

In [42]:
import os
import warnings
import numpy as np
import pandas as pd
from scipy.io import loadmat
import matplotlib.pyplot as plt
from scipy.stats import ttest_ind

In [43]:
import mne
from mne.preprocessing import ICA

In [44]:
from autoreject import AutoReject

In [45]:
import pyxdf

In [46]:
# Suppress every RuntimeWarning for the rest of this session (global filter).
warnings.filterwarnings('ignore', category=RuntimeWarning)
# Only show MNE log output at WARNING level or above (hides INFO progress messages).
mne.set_log_level('WARNING')

In [47]:
# import numpy as np
# import pandas as pd
# from scipy.io import loadmat
# import mne
# from autoreject import AutoReject
# from mne.preprocessing import ICA
# import pyxdf

# Manipulated Variables

In [48]:
# Recordings to process. Each tuple is:
#   (participant ID, EEG .xdf recording, cleaned LSL marker .csv, digitizer .txt)
# Commented-out entries are not processed (presumably excluded participants — confirm).
# NOTE(review): P02, P04 and P07 point at an .xdf named "sub-P001_..." inside their own
# folder, and the commented-out P01/P25 entries reference other participants' folders
# or files — verify these paths are intentional before re-enabling anything.
file_pairs = [
    # Participant Number, EEG Data, LSL Data, Digitization Data
    # ("P01", "Dataset/P002 06.18.2025/sub-P001_ses-S001_task-Default_run-001_eeg.xdf", "Dataset/P012 06.16.2025/P01_cleaned.csv", "Dataset/P001 06.16.2025/P001 06.16.2025.txt"),
    ("P02", "Dataset/P002 06.18.2025/sub-P001_ses-S001_task-Default_run-001_eeg.xdf", "Dataset/P002 06.18.2025/P02_cleaned.csv", "Dataset/P002 06.18.2025/P002 06.18.2025.txt"),
    ("P03", "Dataset/P003 06.20.2025/sub-P003/ses-S001/eeg/sub-P003_ses-S001_task-Default_run-001_eeg.xdf", "Dataset/P003 06.20.2025/P03_cleaned.csv", "Dataset/P003 06.20.2025/P003 06.20.2025.txt"),
    ("P04", "Dataset/P004 06.25.2025/sub-P001_ses-S001_task-Default_run-001_eeg.xdf", "Dataset/P004 06.25.2025/P04_cleaned.csv", "Dataset/P004 06.25.2025/p004.txt"),
    ("P05", "Dataset/P005 06.26.2025/sub-P005_ses-S001_task-Default_run-001_eeg.xdf", "Dataset/P005 06.26.2025/P05_cleaned.csv", "Dataset/P005 06.26.2025/P005 06.26.2025.txt"),
    # ("P06", "Dataset/P006 06.28.2025/sub-P006_ses-S001_task-Default_run-001_eeg.xdf", "Dataset/P006 06.28.2025/P06_cleaned.csv", "Dataset/P006 06.28.2025/P006 06.28.2025.txt"),
    ("P07", "Dataset/P007 06.30.2025/sub-P001_ses-S001_task-Default_run-001_eeg.xdf", "Dataset/P007 06.30.2025/P07_cleaned.csv", "Dataset/P007 06.30.2025/P007 06.30.2025.txt"),
    ("P08", "Dataset/P008 07.01.2025/sub-P008_ses-S001_task-Default_run-001_eeg.xdf", "Dataset/P008 07.01.2025/P08_cleaned.csv", "Dataset/P008 07.01.2025/P008 07.01.2025.txt"),
    # ("P09", "Dataset/P009 07.01.2025/sub-P009_ses-S001_task-Default_run-001_eeg.xdf", "Dataset/P009 07.01.2025/P09_cleaned.csv", "Dataset/P009 07.01.2025/P009 07.01.2025.txt"),
    ("P10", "Dataset/P010 07.02.2025/sub-P010_ses-S001_task-Default_run-001_eeg.xdf", "Dataset/P010 07.02.2025/P10_cleaned.csv", "Dataset/P010 07.02.2025/P010 07.02.2025.txt"),
    ("P11", "Dataset/P011 07.08.2025/sub-P011_ses-S001_task-Default_run-001_eeg.xdf", "Dataset/P011 07.08.2025/P11_cleaned.csv", "Dataset/P011 07.08.2025/P011 07.08.2025.txt"),
    ("P12", "Dataset/P012 07.08.2025/sub-P012_ses-S001_task-Default_run-001_eeg.xdf", "Dataset/P012 07.08.2025/P12_cleaned.csv", "Dataset/P012 07.08.2025/P012 07.08.2025.txt"),
    ("P13", "Dataset/P013 07.09.2025/sub-P013_ses-S001_task-Default_run-001_eeg.xdf", "Dataset/P013 07.09.2025/P13_cleaned.csv", "Dataset/P013 07.09.2025/P013 07.09.2025.txt"),
    ("P14", "Dataset/P014 07.10.2025/sub-P014_ses-S001_task-Default_run-001_eeg.xdf", "Dataset/P014 07.10.2025/P14_cleaned.csv", "Dataset/P014 07.10.2025/P014 07.10.2025.txt"),
    ("P15", "Dataset/P015 07.17.2025/sub-P015_ses-S001_task-Default_run-001_eeg.xdf", "Dataset/P015 07.17.2025/P15_cleaned.csv", "Dataset/P015 07.17.2025/P015 07.17.2025.txt"),
    ("P16", "Dataset/P016 07.18.2025/sub-P016_ses-S001_task-Default_run-001_eeg.xdf", "Dataset/P016 07.18.2025/P16_cleaned.csv", "Dataset/P016 07.18.2025/P016 07.18.2025.txt"),
    ("P17", "Dataset/P017 07.19.2025/sub-P017_ses-S001_task-Default_run-001_eeg.xdf", "Dataset/P017 07.19.2025/P17_cleaned.csv", "Dataset/P017 07.19.2025/P017 07.19.2025.txt"),
    ("P18", "Dataset/P018 07.19.2025/sub-P018_ses-S001_task-Default_run-001_eeg.xdf", "Dataset/P018 07.19.2025/P18_cleaned.csv", "Dataset/P018 07.19.2025/P018 07.19.2025.txt"),
    ("P19", "Dataset/P019 07.30.2025/sub-P019_ses-S001_task-Default_run-001_eeg.xdf", "Dataset/P019 07.30.2025/P19_cleaned.csv", "Dataset/P019 07.30.2025/P019 07.30.2025.txt"),
    ("P20", "Dataset/P020 07.24.2025/sub-P020_ses-S001_task-Default_run-001_eeg.xdf", "Dataset/P020 07.24.2025/P20_cleaned.csv", "Dataset/P020 07.24.2025/P020 07.24.2025.txt"),
    # ("P21", "Dataset/P021 09.02.2025/sub-P021_ses-S001_task-Default_run-001_eeg.xdf", "Dataset/P021 09.02.2025/P21_cleaned.csv", "Dataset/P021 09.02.2025/P021 09.02.2025.txt"),
    ("P22", "Dataset/P022 09.03.2025/sub-P022_ses-S001_task-Default_run-001_eeg.xdf", "Dataset/P022 09.03.2025/P22_cleaned.csv", "Dataset/P022 09.03.2025/P022 09.03.2025.txt"),
    ("P23", "Dataset/P023 09.16.2025/sub-P023_ses-S001_task-Default_run-001_eeg.xdf", "Dataset/P023 09.16.2025/P23_cleaned.csv", "Dataset/P023 09.16.2025/P023 09.16.2025.txt"),
    # ("P24", "Dataset/P024 09.17.2025/sub-P024_ses-S001_task-Default_run-001_eeg.xdf", "Dataset/P024 09.17.2025/P24_cleaned.csv", "Dataset/P024 09.17.2025/P024 09.17.2025.txt"),
    # ("P25", "Dataset/P025 09.18.2025/sub-P024_ses-S001_task-Default_run-001_eeg.xdf", "Dataset/P025 09.18.2025/P25_cleaned.csv", "Dataset/P025 09.18.2025/P024 09.18.2025.txt"),
    ("P26", "Dataset/P026 09.20.2025/sub-P026_ses-S001_task-Default_run-001_eeg.xdf", "Dataset/P026 09.20.2025/P26_cleaned.csv", "Dataset/P026 09.20.2025/P026 09.20.2025.txt"),
    ("P27", "Dataset/P027 09.20.2025/sub-P027_ses-S001_task-Default_run-001_eeg.xdf", "Dataset/P027 09.20.2025/P27_cleaned.csv", "Dataset/P027 09.20.2025/P027 09.20.2025.txt"),
    ("P28", "Dataset/P028 09.21.2025/sub-P028_ses-S001_task-Default_run-001_eeg.xdf", "Dataset/P028 09.21.2025/P28_cleaned.csv", "Dataset/P028 09.21.2025/P028 09.21.2025.txt"),
    ("P29", "Dataset/P029 09.21.2025/sub-P029_ses-S001_task-Default_run-001_eeg.xdf", "Dataset/P029 09.21.2025/P29_cleaned.csv", "Dataset/P029 09.21.2025/P029 09.21.2025.txt"),
    # ("P30", "Dataset/P030 09.21.2025/sub-P030_ses-S001_task-Default_run-001_eeg.xdf", "Dataset/P030 09.21.2025/P30_cleaned.csv", "Dataset/P030 09.21.2025/P030 09.21.2025.txt"),
    ("P31", "Dataset/P031 09.29.2025/sub-P031_ses-S001_task-Default_run-001_eeg.xdf", "Dataset/P031 09.29.2025/P31_cleaned.csv", "Dataset/P031 09.29.2025/P031 09.29.2025.txt"),
    ("P32", "Dataset/P032 10.02.2025/sub-P032_ses-S001_task-Default_run-001_eeg.xdf", "Dataset/P032 10.02.2025/P32_cleaned.csv", "Dataset/P032 10.02.2025/P032 10.02.2025.txt"),
    # ("P33", "Dataset/P033 10.02.2025/sub-P033_ses-S001_task-Default_run-001_eeg.xdf", "Dataset/P033 10.02.2025/P33_cleaned.csv", "Dataset/P033 10.02.2025/P033 10.02.2025.txt"),
    # ("P34", "Dataset/P034 10.08.2025/sub-P034_ses-S001_task-Default_run-001_eeg.xdf", "Dataset/P034 10.08.2025/P34_cleaned.csv", "Dataset/P034 10.08.2025/P034 10.08.2025.txt"),
]

# Channel groupings by scalp region.
# NOTE(review): these use the recording's ORIGINAL labels, but load_set drops
# F7/F9/FC5/C3 (channels_to_drop) and renames Fp1/Fz/F3 to Fz/Cz/Pz (rename_map).
# 'Frontal' also lists the accelerometer channels ACC_*, which load_set types as
# 'misc'. Confirm how this mapping is consumed before relying on it.
regions = {
    'Frontal': ['Fp1', 'Fz', 'F3', 'F7', 'F9', 'FC3', 'FC5', 'ACC_X', 'ACC_Y', 'ACC_Z'],
    'Central': ['C3', 'Cz'],
    'Parietal': ['P3', 'P4', 'Pz'],
    'Occipital': ['O1', 'O2'],
}

# Marker codes for each condition, presumably [cond_start, cond_end] as passed to
# EEGProcessor (cf. the "42, 12" start/end example in extract_event_windows) — confirm.
conds = {
    'LowGermane': [41, 11],
    'HighGermane': [42, 12]
}

# Frequency bands in Hz as (low, high).
# NOTE(review): 7-8 Hz and 13-14 Hz fall between bands, Beta and Gamma share the
# 20 Hz edge, and Gamma extends to 100 Hz although load_set low-passes EEG at
# bandpass_high (30 Hz), so most of that range is filtered out.
bands = {
    'Delta': (0.5, 4),
    'Theta': (4, 7),
    'Alpha': (8, 13),
    'Beta': (14, 20),
    'Gamma': (20, 100)
}

# Target sampling rate (Hz) for raw.resample in load_set; also the assumed rate
# when the XDF stream does not report one.
downsample_value = 250.0
# Band-pass filter edges (Hz) applied to EEG channels in load_set.
bandpass_low = 0.1
bandpass_high = 30

# Root directory for all per-participant output files.
outputFolder = "GammaResults2"

# Channels removed right after loading (before renaming) in load_set.
channels_to_drop = ['F7', 'F9', 'FC5', 'FC1', 'C3']

# Rename Fp1 → Fz, Fz → Cz, F3 → Pz
# (applied after channels_to_drop; presumably the cap labels did not match the
# physical electrode sites — confirm)
rename_map = {
    'Fp1': 'Fz',
    'Fz': 'Cz',
    'F3': 'Pz'
}

# Channel names after renaming; typed 'eeg' in load_set.
eeg_channels = ['Fz', 'Cz', 'Pz']
# Accelerometer channels; typed 'misc' in load_set so EEG-only steps skip them.
acc_channels = ['ACC_X', 'ACC_Y', 'ACC_Z']

# ICA
# Number of ICA components; within the visible code only the commented-out
# apply_ica step uses it.
n_components = 3

# Digitization?
# True: load_set builds a montage from each participant's digitizer .txt file.
# False: load_set uses MNE's standard_1020 montage instead.
digitized = True

# 1. Define the labels we need from the raw file and their MNE mappings
# Mappings: d17=Fz, d18=Cz, d19=Pz
# Fiducials: nz (nasion), al (lpa), ar (rpa)
# Ignore cz since I have d18
# Ignore Iz since we only need three
# mne also doesn't require req or gnd
# Keys are labels as written in the digitizer file; values are the MNE names.
labels_to_find = {
    'nz': 'nasion',
    'al': 'lpa',
    'ar': 'rpa',
    'd17': 'Fz',
    'd18': 'Cz',
    'd19': 'Pz'
}

# EEG Analysis

In [49]:
##############################################################################################################################################################
# Step 1: Global Variables & Initialization
##############################################################################################################################################################

class EEGProcessor:
    
    # def __init__(self, set_file, marker_csv, participant, cond, cond_start, cond_end, prestim, poststim):
    def __init__(self, set_file, marker_csv, participant, digitizer_txt, cond, cond_start, cond_end, prestim, poststim):
        """
        Initializes EEGProcessor for EEG preprocessing and trial segmentation.

        Args:
        - set_file (str): Path to .xdf EEG data file.
        - marker_csv (str): Path to .csv file with event markers.
        - participant (str): Participant identifier.
        - cond (str): Experimental condition label.
        - cond_start, cond_end (float): Start and end times of condition.
        - prestim, poststim (int): Epoch times before/after stimulus (ms).
        
        Note:
        - prestim and poststim are converted from milliseconds to seconds by dividing by 1000.
        """
        
        # Participant specific
        print("\nStep 1: Obtain .XDF File")
        self.set_file = set_file
        self.marker_df = pd.read_csv(marker_csv)
        self.participant = participant
        self.digitizer_txt = digitizer_txt
        
        # Condition specific
        self.cond = cond
        self.cond_start = cond_start
        self.cond_end = cond_end
        
        # EEG analysis specific
        self.prestim = prestim / 1000
        self.poststim = poststim / 1000
        
        # Data structures for storage
        self.sfreq = None
        self.condition_timestamps = {}
        self.trials = []
        self.trl = []
        self.participant_ratios = []
        self.epochs_train = None
        self.epochs_test = None

##############################################################################################################################################################
# Step 2: Loading Data & Channel Information
##############################################################################################################################################################

    def parse_and_create_montage(self):
        """
        Parses the digitization file, extracts required points,
        and creates an MNE DigMontage object.
        """
        print(f"Parse and Create Montage: Parsing digitization file: {self.digitizer_txt}")
        
        parsed_coords = {}
        ch_positions = {}

        try:
            # Opens the file
            # r tells python to open in read only 
            # assigned to variable f
            with open(self.digitizer_txt, 'r') as f:
                # for each line in the file
                for line in f:
                    line = line.strip()
                    if not line or ':' not in line:
                        continue
                    
                    try:
                        label, coords_str = line.split(':', 1)
                        label = label.strip()
                        
                        if label in labels_to_find:
                            # Split coordinates and convert to float
                            coords_mm = [float(c) for c in coords_str.strip().split()]
                            
                            if len(coords_mm) == 3:
                                # I am using an older version of MNE
                                # Convert from mm to meters (MNE's standard)
                                coords_m = [c / 1000.0 for c in coords_mm]
                                # Save the coordinates *in meters*
                                parsed_coords[label] = coords_m
                            else:
                                print(f"Parse and Create Montage: Warning: Skipping malformed line for {label}: {line}")
                                
                    except Exception as e:
                        print(f"Parse and Create Montage: Warning: Skipping line due to parsing error: {line} | Error: {e}")
        
        except FileNotFoundError:
            print(f"Parse and Create Montage: ERROR: Digitization file not found at {self.digitization_file}. Cannot apply custom montage.")
            return None
        except Exception as e:
            print(f"Parse and Create Montage: ERROR: Could not read digitization file: {e}")
            return None

        # 2. Map parsed labels to MNE-required format
        try:
            ch_positions = {
                'Fz': parsed_coords['d17'],
                'Cz': parsed_coords['d18'],
                'Pz': parsed_coords['d19'],
            }
            
            lpa_coords = parsed_coords['al']
            nasion_coords = parsed_coords['nz']
            rpa_coords = parsed_coords['ar']
        
        except KeyError as e:
            print(f"Parse and Create Montage: ERROR: Missing required key in digitization file: {e}. Cannot create montage.")
            print(f"Parse and Create Montage: Required keys: {list(labels_to_find.keys())}")
            print(f"Parse and Create Montage: Found keys: {list(parsed_coords.keys())}")
            return None

        # 3. Create the DigMontage object
        print("Parse and Create Montage: Creating DigMontage object from parsed coordinates...")
        montage = mne.channels.make_dig_montage(
            ch_pos=ch_positions,
            lpa=lpa_coords,
            nasion=nasion_coords,
            rpa=rpa_coords
        )
        
        return montage

    def load_set(self):
        """
        Loads and preprocesses EEG data from a .xdf file.

        - Uses pyxdf to read the first stream whose type is 'eeg' and transposes
          its samples to MNE's (n_channels, n_samples) layout.
        - Sampling rate: first positive value among the stream-info keys
          'sample_rate', 'nominal_srate', 'effective_srate'; otherwise
          `downsample_value` is assumed.
        - Channel names come from the stream description; generic 'Ch1'..'ChN'
          names are used if they cannot be read.
        - Drops `channels_to_drop`, renames per `rename_map` (Fp1→Fz, Fz→Cz, F3→Pz)
          and types `eeg_channels` as 'eeg' and `acc_channels` as 'misc'.
        - If `digitized` is True, applies the custom montage from the digitizer
          file (skipped with a message when it cannot be built); otherwise the
          standard 10-20 montage.
        - Resamples to `downsample_value` Hz and band-passes EEG channels between
          `bandpass_low` and `bandpass_high` Hz.

        Returns:
        - Preprocessed MNE Raw EEG object.

        Raises:
        - ValueError: if the file contains no EEG stream.
        """

        print("\nStep 2: Loading .xdf file...")

        # An XDF file holds several streams; only one has type 'eeg', the others
        # are marker streams.
        streams, _ = pyxdf.load_xdf(self.set_file)
        eeg_stream = next((s for s in streams if s['info']['type'][0].lower() == 'eeg'), None)
        if eeg_stream is None:
            raise ValueError("No EEG stream found in the XDF file.")

        # time_series is (n_samples, n_channels); MNE expects (n_channels, n_samples).
        # NOTE(review): values are handed to MNE unscaled and MNE treats EEG data as
        # volts — confirm the unit of the LSL stream.
        data = np.array(eeg_stream['time_series']).T
        stream_info = eeg_stream['info']

        # Take the first usable rate; a missing, unparsable or non-positive value
        # (LSL reports 0 for irregular-rate streams) moves on to the next key.
        sfreq = None
        for key in ('sample_rate', 'nominal_srate', 'effective_srate'):
            if key not in stream_info:
                continue
            value = stream_info[key]
            # Most XDF header fields are one-element lists.
            if isinstance(value, (list, tuple)):
                value = value[0] if value else None
            try:
                candidate = float(value)
            except (TypeError, ValueError):
                continue
            if candidate > 0:
                sfreq = candidate
                print(f'Original sample rate found ({key}): {sfreq}')
                break
        if sfreq is None:
            sfreq = downsample_value
            print(f'Original sample rate was not found — using downsample_value: {sfreq}')
        print(f'Load Set: sfreq = {sfreq}')

        # Extract channel names from the stream description (may be absent/None).
        try:
            ch_names = [chan['label'][0] for chan in stream_info['desc'][0]['channels'][0]['channel']]
        except (KeyError, IndexError, TypeError):
            print('Was not successfully in pulling the channel labels')
            ch_names = [f"Ch{i+1}" for i in range(data.shape[0])]

        print(f'Load Set: Channel names: {ch_names}')

        # Build the MNE Raw object that the rest of the pipeline works on.
        info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=['eeg'] * len(ch_names))
        raw = mne.io.RawArray(data, info)

        # Drop unwanted channels before renaming (channels_to_drop uses original labels).
        if channels_to_drop:
            print(f"Dropping channels: {channels_to_drop}")
            raw.drop_channels(channels_to_drop)

        raw.rename_channels(rename_map)

        # EEG electrodes stay 'eeg'; accelerometers become 'misc' so EEG-only
        # steps (filtering, baseline) skip them.
        channel_types = {ch: 'eeg' for ch in eeg_channels}
        channel_types.update({ch: 'misc' for ch in acc_channels})
        raw.set_channel_types(channel_types)

        if digitized:
            print("Load Set: Apply custom digitization (montage) from file...")
            # parse_and_create_montage reports its own errors and returns None
            # on failure; never pass None to set_montage (that clears positions).
            montage = self.parse_and_create_montage()
            if montage is None:
                print("Load Set: Custom montage could not be created. Skipping custom montage.")
            else:
                raw.set_montage(montage, on_missing='ignore')
                print("Load Set: Custom digitization applied successfully.")
        else:
            montage = mne.channels.make_standard_montage('standard_1020')
            raw.set_montage(montage, on_missing='ignore')

        print(f"\nStep 2.1: Downsampling to {downsample_value} Hz...")
        raw.resample(downsample_value)

        print(f"\nStep 2.2: Applying bandpass filter ({bandpass_low} Hz to {bandpass_high} Hz)...")
        raw.filter(l_freq=bandpass_low, h_freq=bandpass_high, picks='eeg')

        print("Load Set: EEG data successfully loaded and preprocessed.")
        return raw

##############################################################################################################################################################
# Step 3: Event Extraction
##############################################################################################################################################################

    def extract_event_windows(self, raw):
        """
        Segments raw EEG into training and test periods using event markers.

        - Identifies condition start/end markers (e.g., 42, 12).
        - Crops raw EEG data into training and test segments.
        - Saves `.fif` files and logs timing metadata.
        - Outputs condition timestamps and trial metadata as CSVs.

        Args:
        - raw (mne.io.Raw): Preprocessed raw EEG object.

        Returns:
        - Tuple of training and test Raw objects (or None if invalid).
        """

        print("\nStep 3.1: Extracting event windows for 'training' and 'test' periods...")

        # Prepare CSV
        if 'time' not in self.marker_df.columns or 'value' not in self.marker_df.columns:
            raise KeyError("The CSV file must contain 'time' and 'value' columns.")
        self.marker_df['value'] = pd.to_numeric(self.marker_df['value'], errors='coerce').dropna().astype(int)
        self.marker_df['time'] = pd.to_numeric(self.marker_df['time'], errors='coerce').dropna()
        self.marker_df = self.marker_df.dropna(subset=['time', 'value']).reset_index(drop=True)

        # Pull LSL from CSV
        events = self.marker_df['value']
        print(f"Extract Event Windows: LSL Values: {self.marker_df['value']}")
        times = self.marker_df['time']

        # Extract indices of specific event markers
        event_start_indices = [i for i, event in enumerate(events) if event == self.cond_start]
        print(f"Extract Event Windows: Indices of First Condition Start: {events[event_start_indices[0]]}")
        event_end_indices = [i for i, event in enumerate(events) if event == self.cond_end]

        if len(event_start_indices) < 7 or len(event_end_indices) < 7:
            print("Extract Event Windows: Insufficient markers for training or test period.")
            print(f'Extract Event Windows: Length of start indices: {len(event_start_indices)}')
            print(f'Extract Event Windows:Length of end indices: {len(event_end_indices)}')
            return None, None

        training_start = times[event_start_indices[0]]
        print(f'Extract Event Windows:Actual training start: {training_start}')
        training_end = times[event_end_indices[5]]
        print(f'Extract Event Windows:Actual training end: {training_end}')
        test_start = times[event_start_indices[-1]]
        test_end = times[event_end_indices[-1]]

        # Some clarification on what is happening here
        # Crop works with seconds. So taking the start sample / frequency provides seconds
        # Crops then copies the data from that start and end section
        raw_training = raw.copy().crop(tmin=training_start, tmax=training_end) if training_start < training_end else None
        raw_test = raw.copy().crop(tmin=test_start, tmax=test_end) if test_start < test_end else None

        # Save cropped data
        if raw_training:
            training_file = f"{outputFolder}/{participant}/fif/{self.cond}_training_data.fif"
            raw_training.save(training_file, overwrite=True)
            print(f"Extract Event Windows: Training data saved to '{training_file}'")
        else:
            print("Extract Event Windows: Invalid training period duration.")

        if raw_test:
            test_file = f"{outputFolder}/{participant}/fif/{self.cond}_test_data.fif"
            raw_test.save(test_file, overwrite=True)
            print(f"Extract Event Windows: Test data saved to '{test_file}'")
        else:
            print("Extract Event Windows: Invalid test period duration.")

        # Log durations
        if raw_training:
            print(f"Extract Event Windows: Training data duration: {(training_end - training_start)} seconds")
        if raw_test:
            print(f"Extract Event Windows: Test data duration: {(test_end - test_start)} seconds")

        # Save condition timestamps
        condition_timestamps = {
            'training_start': training_start,
            'training_end': training_end,
            'test_start': test_start,
            'test_end': test_end,
        }
        
        self.condition_timestamps = condition_timestamps
        pd.DataFrame([condition_timestamps]).to_csv(f"{outputFolder}/{participant}/{self.cond}_condition_timestamps.csv", index=False)
        print(f"Extract Event Windows: Condition timestamps saved to '{outputFolder}/{participant}/{self.cond}_condition_timestamps.csv'.")

        # Prepare and save trial data
        trials = []
        
        # Added 10.19 to remove the other if statement
        # Needed a valid check
        if len(event_start_indices) != len(event_end_indices):
            raise ValueError(
                f"Extract Event Windows: Mismatch between start ({len(event_start_indices)}) and end ({len(event_end_indices)}) event indices."
            )

        # Loop through each pair of start and end
        for i, start_index in enumerate(event_start_indices):
            # Ensure there's a corresponding end index
            if i < len(event_end_indices):
                end_index = event_end_indices[i]

                # Determine condition based on whether it's the last start/end pair
                condition = 'train' if i < len(event_start_indices) - 1 else 'test'

                # Add data for both the start and end events
                trials.append({'event_type': events[start_index], 'sample_idx': start_index, 'time': times[start_index], 'condition': condition})
                trials.append({'event_type': events[end_index], 'sample_idx': end_index, 'time': times[end_index], 'condition': condition})

        # # Handle last start/end pair for test condition (if required)
        # if len(event_start_indices) > len(event_end_indices):
        #     last_start_index = event_start_indices[-1]
        #     trials.append({'event_type': events[last_start_index], 'sample_idx': last_start_index, 'time': times[last_start_index], 'condition': 'test'})

        trials_df = pd.DataFrame(trials)
        trials_df.to_csv(f"{outputFolder}/{participant}/{self.cond}_trials.csv", index=False)

        print(f"Extract Event Windows: Trials saved to '{outputFolder}/{participant}/{self.cond}_trials.csv'.")

        return raw_training, raw_test

##############################################################################################################################################################
# Step 4: Preprocessing Functions (Baseline, Autoreject, ICA)
##############################################################################################################################################################

    def apply_baseline_correction(self, raw, rest_events=None):
        """
        Applies baseline correction to EEG data using rest markers or a fallback.

        - Rest interval: first marker 200 up to the next marker 31 or 32
          (ONR2 Germane did not record the 210 rest-end marker).
        - The markers' times (seconds) are converted to sample offsets within
          `raw`, clipped to its length.
        - Falls back to the first 15 s of `raw` when no usable rest interval
          lies inside it.
        - Subtracts each EEG channel's baseline mean from that whole channel.
        - Saves the result to '<outputFolder>/<participant>/fif/<cond>_D_bc.fif'.

        Args:
        - raw (mne.io.Raw): Raw EEG data to be baseline corrected; modified in
          place by `raw.apply_function`.
        - rest_events: Unused; kept for backward compatibility.

        Returns:
        - Baseline-corrected Raw EEG object (the same object as `raw`).
        """

        print("\nStep 4.1: Applying baseline correction...")
        sfreq = raw.info['sfreq']
        n_times = raw.n_times

        print("Apply Baseline Correction: Detecting rest intervals from markers...")

        # ONR2 Germane didn't use 210, so the rest ends at the next 31 or 32.
        print("Apply Baseline Correction: Detecting rest intervals from markers (Start: 200, End: 31 or 32)...")
        rest_start_rows = self.marker_df[self.marker_df['value'] == 200].index.tolist()
        rest_start = rest_start_rows[0] if rest_start_rows else None
        rest_end = None
        if rest_start is not None:
            after_start = self.marker_df.loc[rest_start + 1:]
            rest_end_rows = after_start[after_start['value'].isin([31, 32])].index.tolist()
            rest_end = rest_end_rows[0] if rest_end_rows else None

        # rest_start / rest_end are marker-table ROW labels, not sample indices.
        # Convert the markers' timestamps into sample offsets within this
        # (possibly cropped) raw: raw.first_time is where it starts on the
        # original recording's time axis.
        # Assumes marker times use the same axis as the crops made in
        # extract_event_windows — confirm.
        start_sample = end_sample = None
        if rest_start is not None and rest_end is not None:
            try:
                rest_t0 = float(self.marker_df.loc[rest_start, 'time'])
                rest_t1 = float(self.marker_df.loc[rest_end, 'time'])
            except (KeyError, TypeError, ValueError) as e:
                print(f"Apply Baseline Correction: Could not read rest marker times: {e}")
            else:
                start_sample = min(max(int(round((rest_t0 - raw.first_time) * sfreq)), 0), n_times)
                end_sample = min(max(int(round((rest_t1 - raw.first_time) * sfreq)), 0), n_times)

        if start_sample is not None and end_sample > start_sample:
            print(f"Apply Baseline Correction: Using detected rest interval for baseline correction: Start={start_sample}, End={end_sample}")
            try:
                # x is (n_eeg_channels, n_times); the per-channel rest mean is
                # reshaped to a column so it broadcasts over every sample.
                raw.apply_function(
                    lambda x: x - np.mean(x[:, start_sample:end_sample], axis=-1)[:, None],
                    picks='eeg', channel_wise=False
                )
                print("Apply Baseline Correction: Baseline correction using rest interval applied successfully.")
            except Exception as e:
                print(f"Apply Baseline Correction: Error applying baseline correction with rest interval: {e}")
        else:
            if rest_start is None or rest_end is None:
                print("Warning: Rest interval markers not detected.")
            elif start_sample is not None:
                print("Warning: Detected rest interval is invalid or lies outside this recording segment.")

            # Fallback: first 15 s of this raw (raw.times starts at 0).
            print("Apply Baseline Correction: Falling back to default 15s baseline...")
            baseline_start = int(max(0, raw.times[0] * sfreq))
            baseline_end = int(min(n_times, (raw.times[0] + 15) * sfreq))

            if baseline_end > baseline_start:
                print(f"Apply Baseline Correction: Using default baseline: Start={baseline_start}, End={baseline_end}")
                try:
                    raw.apply_function(
                        lambda x: x - np.mean(x[:, baseline_start:baseline_end], axis=-1)[:, None],
                        picks='eeg', channel_wise=False
                    )
                    print("Apply Baseline Correction: Baseline correction using default interval applied successfully.")
                except Exception as e:
                    print(f"Error applying baseline correction with default interval: {e}")
            else:
                print("Error: Default baseline interval is invalid. Skipping baseline correction.")

        # Save the baseline-corrected file
        fif_dir = f"{outputFolder}/{self.participant}/fif"
        os.makedirs(fif_dir, exist_ok=True)
        out_file = f"{fif_dir}/{self.cond}_D_bc.fif"
        print(f"Apply Baseline Correction: Saving baseline-corrected file to '{out_file}'...")
        raw.save(out_file, overwrite=True)

        return raw

    # def apply_ica(self, raw):
    #     """
    #     Applies ICA to remove general artifacts from EEG data.

    #     - Performs ICA decomposition using 3 components.
    #     - Applies ICA to cleaned EEG signal.
    #     - Saves the resulting EEG data to disk.

    #     Args:
    #     - raw (mne.io.Raw): EEG data to process.

    #     Returns:
    #     - Tuple of cleaned EEG (Raw) and ICA object.
    #     """
        
    #     print("\nStep 4.2: Applying ICA...")

    #     # n_components less than or equal to the number of eeg channels
    #     # I set n_components to 3
    #     # Really not recommended to only run with 3 channels
    #     ica = ICA(n_components=n_components, random_state=97, max_iter=800)

    #     # Select EEG channels only for ICA
    #     picks = mne.pick_types(raw.info, eeg=True, exclude='bads')
    #     print(f"Apply ICA: EEG Picks: {picks}")
    #     ica.fit(raw, picks=picks)

    #     # Apply ICA to the raw data
    #     raw = ica.apply(raw)
    #     print(f"Apply ICA: ICA applied with {n_components} components.")

    #     # Save the cleaned EEG file
    #     raw.save(f"{outputFolder}/{participant}/fif/{self.cond}_D_cleaned.fif", overwrite=True)
    #     print(f"Apply ICA: Cleaned EEG file saved as '{outputFolder}/{participant}/fif/{self.cond}_D_cleaned.fif'.")

    #     return raw, ica

    def apply_autoreject(self, raw):
        """
        Applies AutoReject to detect and repair artifacts in EEG data.

        - Segments raw data into 2-second fixed-length epochs.
        - Tries several consensus thresholds in order; the first one for which
          fitting/transforming succeeds (and leaves at least one epoch) is used.
        - Reconstructs a continuous Raw object from the cleaned epochs.

        Note: ``ar.transform`` drops epochs it marks as bad, so the returned
        Raw can be shorter than the input.

        Args:
        - raw (mne.io.Raw): Raw EEG data to process.

        Returns:
        - Cleaned Raw EEG data if successful, otherwise original input.
        """

        print("\nStep 4.3: Applying AutoReject (Repair Only)...")

        # We can change the duration here between 1 and 2 seconds
        epochs = mne.make_fixed_length_epochs(raw, duration=2.0, preload=True)

        # the consensus parameter determines how strict the rejection is across trials
        consensus_values = [1.0, 0.1, 0.5, 0.3, 0.75]
        for consensus in consensus_values:
            try:
                print(f"Apply AutoReject: Trying consensus={consensus}...")
                ar = AutoReject(
                    # Max interpolation options
                    n_interpolate=[2, 5, len(raw.ch_names)],
                    # Consensus threshold
                    consensus=[consensus],
                    thresh_method='bayesian_optimization',
                    random_state=42,
                    n_jobs=-1,
                    verbose=True
                )

                # Fit AutoReject, then repair/drop epochs
                ar.fit(epochs)
                epochs_clean, reject_log = ar.transform(epochs, return_log=True)

                # RejectLog.labels: 0 = good, 1 = bad (not interpolated),
                # 2 = interpolated. Only label 2 counts as a repaired channel.
                total_repairs = int(np.sum(reject_log.labels == 2))
                total_bad_epochs = int(np.sum(reject_log.bad_epochs))

                print(f"Apply AutoReject: Found {total_bad_epochs} bad epochs.")
                print(f"Apply AutoReject: Repaired a total of {total_repairs} channels.")

                if total_repairs == 0:
                    print("Apply AutoReject: No channels were repaired. Data was already clean.")

                # get_data() is (n_epochs, n_channels, n_times).
                data = epochs_clean.get_data()
                n_epochs, n_channels, n_times = data.shape
                if n_epochs == 0:
                    # An empty Raw is useless downstream; try a different consensus.
                    print(f"Apply AutoReject: All epochs rejected with consensus={consensus}.")
                    continue

                # Move channels first before flattening so each row is one
                # channel's samples in time order. Reshaping the raw 3-D array
                # directly would interleave channels and epochs.
                continuous = data.transpose(1, 0, 2).reshape(n_channels, n_epochs * n_times)
                raw_clean = mne.io.RawArray(continuous, epochs_clean.info)

                print(f"Apply AutoReject: Success with consensus={consensus}!")
                print(f"Apply AutoReject: Channels after AutoReject: {raw_clean.ch_names}")
                return raw_clean
            except Exception as e:
                print(f"Apply AutoReject: Failed with consensus={consensus}: {e}")

        print("Apply AutoReject: AutoReject failed. Returning original raw data.")
        return raw

##############################################################################################################################################################
# Step 5: Epoching
##############################################################################################################################################################

    def epoch_data(self, raw, type, duration=1.0, overlap=0.0):
        """
        Segments raw EEG into fixed-length epochs and saves trial metadata.

        - Creates non-event-based epochs with optional overlap.
        - Each epoch holds exactly int(sfreq * duration) samples.
        - Stores each kept epoch's start, end, and sample count to CSV and
          appends the same records to ``self.trl``.
        - Returns an MNE Epochs object or None on failure.

        Args:
        - raw (mne.io.Raw): EEG data to segment.
        - type (str): Label for the epoch type ('train', 'test', etc.).
        - duration (float): Length of each epoch in seconds.
        - overlap (float): Overlap between epochs in seconds (0 <= overlap < duration).

        Returns:
        - MNE Epochs object or None if no events could be created.
        """

        print("\nStep 5.1: Epoching data...")

        sfreq = raw.info['sfreq']
        start = 0
        stop = raw.n_times / sfreq

        print(f'Epoch Data: Raw n times: {raw.n_times}')
        print(f'Epoch Data: Raw info sfreq: {sfreq}')
        print(f"Epoch Data: Raw data range: Start={start}, Stop={stop}, Duration={duration}, Overlap={overlap}")

        try:
            if overlap < 0 or overlap >= duration:
                raise ValueError(f"Epoch Data: Overlap must be >=0 but < duration ({duration}), got {overlap}")

            events = mne.make_fixed_length_events(
                raw, id=1, start=start, stop=stop, duration=duration, overlap=overlap
            )
            print(f"Epoch Data: Generated {len(events)} fixed-length events.")
        except ValueError as e:
            print(f"Epoch Data: Error generating events: {e}")
            return None

        if len(events) == 0:
            print("Epoch Data: No fixed-length events created.")
            return None

        # MNE's tmax is inclusive, so tmax=duration would give duration*sfreq + 1
        # samples per epoch, and the final epoch would run past the end of the
        # data and be dropped. Ending one sample early (as
        # mne.make_fixed_length_epochs does) gives exactly duration*sfreq samples.
        epochs = mne.Epochs(
            raw, events, tmin=0, tmax=duration - 1.0 / sfreq, baseline=None, detrend=1, preload=True
        )

        # Save trial information for the epochs that were actually kept
        # (epochs.events excludes any epochs MNE dropped).
        n_samples = int(sfreq * duration)
        trial_data = []
        for event in epochs.events:
            trial_info = {
                'start': event[0] / sfreq,
                'end': (event[0] + n_samples) / sfreq,
                'samples': n_samples,
                'type': event[2]
            }
            self.trl.append(trial_info)
            trial_data.append(trial_info)

        # Save trial data to a CSV file
        trial_df = pd.DataFrame(trial_data)
        filename = f"{outputFolder}/{participant}/{self.cond}_epoch_trl_{type}.csv"
        trial_df.to_csv(filename, index=False)
        print(f"Epoch Data: Epoch trial data saved to '{filename}'.")

        return epochs

##############################################################################################################################################################
# Step 6: Analysis Functions (FFT, Band Power, PSD, etc.)
##############################################################################################################################################################

    def compute_and_save_band_power(self, epochs, condition_name):
        """
        Writes mean band power for every (epoch, channel) pair to a CSV file.

        - Welch PSD (0.1–30 Hz, 128-sample segments) is computed separately
          for each channel of each epoch.
        - For each band in the module-level ``bands`` mapping, the PSD is
          averaged over bins inside [fmin, fmax]; a band with no bins gets 0.
        - Columns are '<band>_power' for every band, then 'Epoch' (1-based)
          and 'Channel'.

        Args:
        - epochs (mne.Epochs): Epochs to analyze.
        - condition_name (str): Label for the condition (e.g., 'train').

        Output:
        - '{cond}_band_power_{condition_name}.csv' in the participant folder.
        """

        print(f"\nStep 6.1: Computing band power for {condition_name}...")

        sfreq = epochs.info['sfreq']
        rows = []

        for epoch_number, trial in enumerate(epochs.get_data(), start=1):
            # trial holds one row of samples per channel, in ch_names order
            for channel_name, signal in zip(epochs.ch_names, trial):
                spectrum, freq_axis = mne.time_frequency.psd_array_welch(
                    signal, sfreq=sfreq, fmin=0.1, fmax=30.0, n_per_seg=128
                )

                row = {}
                for band, (lo, hi) in bands.items():
                    in_band = (freq_axis >= lo) & (freq_axis <= hi)
                    row[f'{band}_power'] = spectrum[in_band].mean() if in_band.any() else 0
                row['Epoch'] = epoch_number
                row['Channel'] = channel_name
                rows.append(row)

        output_file = f"{outputFolder}/{participant}/{self.cond}_band_power_{condition_name}.csv"
        pd.DataFrame(rows).to_csv(output_file, index=False)
        print(f"Compute Band Power: Band power results saved to '{output_file}'.")

    def compute_psd_and_ratios(self, raw, type):
        """
        Computes per-channel band PSD and two EEG power ratios, then saves them to CSV.

        - Uses Welch's method (0.1–30 Hz, 4-second segments) on every channel.
        - Averages PSD within each band of the module-level ``bands`` mapping
          (NaN when a band has no frequency bins).
        - Ratio 1: Pz Alpha / Fz Theta.
        - Ratio 2: mean Beta / (mean Alpha + mean Theta) over whichever of
          Fz, Cz, Pz are present.
        - Any exception is caught and printed so the pipeline can continue;
          in that case the ratio CSV is not written.

        Args:
        - raw (mne.io.Raw): Preprocessed EEG data.
        - type (str): Condition label (e.g., "train" or "test").

        Outputs:
        - '{cond}_psd_results_{type}.csv': band PSD per channel.
        - '{cond}_ratios_analysis_{type}.csv': one row of ratio summaries.
        """

        print("\nStep 6.2: Computing PSD and ratios for analysis...")

        try:
            sfreq = raw.info['sfreq']
            # psd_array_welch silently shrinks n_per_seg down to n_fft when
            # n_per_seg > n_fft (n_fft defaults to 256 samples). n_fft must be
            # raised as well, otherwise the 4-second window never takes effect.
            seg_len = int(4 * sfreq)
            psd_values, freqs = mne.time_frequency.psd_array_welch(
                raw.get_data(), sfreq=sfreq, fmin=0.1, fmax=30.0,
                n_fft=seg_len, n_per_seg=seg_len
            )  # psd_values: one PSD row per channel, in raw.ch_names order

            # Mean power per band for each channel
            psd_results = []
            for channel_name, channel_psd in zip(raw.ch_names, psd_values):
                psd_band_values = {'Channel': channel_name}
                for band, (fmin, fmax) in bands.items():
                    band_mask = (freqs >= fmin) & (freqs <= fmax)
                    psd_band_values[f'psd_{band}'] = np.mean(channel_psd[band_mask]) if np.any(band_mask) else np.nan
                psd_results.append(psd_band_values)

            psd_df = pd.DataFrame(psd_results)
            psd_file = f'{outputFolder}/{participant}/{self.cond}_psd_results_{type}.csv'
            psd_df.to_csv(psd_file, index=False)
            print(f"Compute PSD & Ratios: PSD results saved to '{psd_file}'.")

            if psd_df.empty:
                print("Compute PSD & Ratios: Error: PSD results are empty. Cannot compute ratios.")
                return

            # Band column names are capitalised to match the keys of `bands`
            pz_alpha = psd_df.loc[psd_df['Channel'] == 'Pz', 'psd_Alpha'].values
            fz_theta = psd_df.loc[psd_df['Channel'] == 'Fz', 'psd_Theta'].values

            # Handle cases where the channel values are missing or zero
            if len(pz_alpha) > 0:
                pz_alpha = pz_alpha[0]
            else:
                print("Compute PSD & Ratios: Warning: Missing value for 'Pz Alpha'. Setting to NaN.")
                pz_alpha = np.nan

            if len(fz_theta) > 0 and fz_theta[0] != 0:
                fz_theta = fz_theta[0]
            else:
                print("Compute PSD & Ratios: Warning: 'Fz Theta' is zero or invalid. Setting to NaN.")
                fz_theta = np.nan

            # Ratio 1: Pz alpha / Fz theta
            if not np.isnan(fz_theta):
                alpha_theta_ratio = pz_alpha / fz_theta
            else:
                alpha_theta_ratio = np.nan
                print("Compute PSD & Ratios: Warning: Cannot compute Alpha/Theta ratio due to missing or zero 'Fz Theta'.")

            # Ratio 2: average beta / (alpha + theta) over the midline channels
            selected_channels = ['Fz', 'Cz', 'Pz']
            available_channels = set(psd_df['Channel'].values)
            missing_channels = [ch for ch in selected_channels if ch not in available_channels]
            if missing_channels:
                print(f"Compute PSD & Ratios: Warning: Missing channels in data: {missing_channels}")

            psd_filtered = psd_df[psd_df['Channel'].isin(selected_channels)]
            mean_alpha = psd_filtered['psd_Alpha'].mean()
            mean_theta = psd_filtered['psd_Theta'].mean()
            mean_beta = psd_filtered['psd_Beta'].mean()

            if np.isnan(mean_alpha) or np.isnan(mean_theta) or np.isnan(mean_beta):
                print("Compute PSD & Ratios: Warning: Missing values in the selected channels. Ratios may be inaccurate.")
                beta_combined_ratio = np.nan
            elif mean_alpha + mean_theta == 0:
                # Avoid a division by zero producing inf
                print("Compute PSD & Ratios: Warning: Alpha + Theta power is zero. Setting Beta / (Alpha + Theta) to NaN.")
                beta_combined_ratio = np.nan
            else:
                beta_combined_ratio = mean_beta / (mean_alpha + mean_theta)

            ratio_df = pd.DataFrame([{
                'Participant ID': participant,
                'Pz Alpha': pz_alpha,
                'Fz Theta': fz_theta,
                'Pz Alpha / Fz Theta': alpha_theta_ratio,
                'Beta / (Alpha + Theta)': beta_combined_ratio
            }])
            ratio_file = f'{outputFolder}/{participant}/{self.cond}_ratios_analysis_{type}.csv'
            ratio_df.to_csv(ratio_file, index=False)
            print(f"Compute PSD & Ratios: Ratios analysis saved to '{ratio_file}'.")

        except Exception as e:
            # Best effort: report and let the rest of the pipeline continue
            print(f"Compute PSD & Ratios: Error during PSD or ratio computation: {e}")

    def finalize_gamma_table(self, condition_name):
        """
        Builds per-epoch, log-transformed, z-scored 20–100 Hz power per channel.

        - Picks ``self.epochs_train`` or ``self.epochs_test`` from ``condition_name``.
        - For every channel in an epoch: Welch PSD over 20–100 Hz, mean power
          (clamped to 1e-10 when non-positive), then natural log.
        - Z-scores the log powers across channels within each epoch (all zeros
          when the spread is zero).
        - Writes the per-epoch table and a one-row per-channel average.

        Args:
        - condition_name (str): 'train' or 'test' to specify which epochs to process.

        Outputs:
        - '{cond}_log_z_gamma_power_{condition_name}.csv' (per epoch) and
          '{cond}_avg_log_z_gamma_power_{condition_name}.csv' (channel means).
        """

        print(f"\nStep 6.3: Finalizing gamma table for {condition_name}...")

        # condition -> (attribute holding its epochs, message when that attribute is None)
        sources = {
            "train": ("epochs_train", "Gamma Table: Training epochs not found."),
            "test": ("epochs_test", "Gamma Table: Test epochs not found."),
        }
        source = sources.get(condition_name)
        if source is None:
            print("Gamma Table: Invalid condition name.")
            return
        attr_name, missing_message = source
        epochs = getattr(self, attr_name)
        if epochs is None:
            print(missing_message)
            return

        gamma_low, gamma_high = 20, 100
        sfreq = epochs.info['sfreq']
        channel_names = epochs.ch_names

        def log_gamma(signal):
            # Mean Welch power in the gamma range, log-transformed
            spectrum, _ = mne.time_frequency.psd_array_welch(
                signal, sfreq=sfreq, fmin=gamma_low, fmax=gamma_high, n_per_seg=128
            )
            band_power = spectrum.mean()
            # Guard against log of zero/negative power; NaN passes through unchanged
            return np.log(1e-10 if band_power <= 0 else band_power)

        rows = []
        for epoch_number, trial in enumerate(epochs.get_data(), start=1):
            log_powers = np.array([log_gamma(signal) for signal in trial])
            spread = log_powers.std()
            if spread == 0:
                z_scores = np.zeros_like(log_powers)
            else:
                z_scores = (log_powers - log_powers.mean()) / spread

            row = {"Epoch": epoch_number}
            row.update(zip(channel_names, z_scores))
            rows.append(row)

        # Per-epoch table
        table = pd.DataFrame(rows)
        output_file = f"{outputFolder}/{participant}/{self.cond}_log_z_gamma_power_{condition_name}.csv"
        table.to_csv(output_file, index=False)
        print(f"Gamma Table: Finalized gamma table saved to '{output_file}'.")

        # One row: mean z-scored log gamma per channel, prefixed by the condition
        channel_means = table.drop(columns=["Epoch"]).mean()
        avg_df = channel_means.to_frame().T
        avg_df.insert(0, "Condition", condition_name)
        avg_output_file = f"{outputFolder}/{participant}/{self.cond}_avg_log_z_gamma_power_{condition_name}.csv"
        avg_df.to_csv(avg_output_file, index=False)
        print(f"Gamma Table: Averaged gamma power (log_z) saved to '{avg_output_file}'.")

    def compute_fft(self, epochs, condition_name):
        """
        Writes the FFT magnitude spectrum of every (epoch, channel) pair to CSV.

        - Each channel signal goes through a real FFT of length int(sfreq)
          (one second of samples: longer signals are truncated, shorter ones
          zero-padded).
        - Stores the frequency axis and the magnitudes |FFT| (not squared
          power) as Python lists in the 'Frequencies' and 'FFT_Values' columns.

        Args:
        - epochs (mne.Epochs): Segmented EEG data.
        - condition_name (str): Label for the condition (e.g., 'train').

        Output:
        - '{cond}_fft_{condition_name}.csv' in the participant folder.
        """

        print(f"\nStep 6.4: Computing FFT for {condition_name}...")

        sfreq = epochs.info['sfreq']
        window = int(sfreq)  # one second worth of samples
        # The frequency axis depends only on window and sfreq, so every row shares it
        freq_axis = np.fft.rfftfreq(window, d=1/sfreq)

        rows = [
            {
                'Epoch': epoch_number,
                'Channel': channel_name,
                'Frequencies': freq_axis.tolist(),
                'FFT_Values': np.abs(np.fft.rfft(signal, n=window)).tolist(),
            }
            for epoch_number, trial in enumerate(epochs.get_data(), start=1)
            for channel_name, signal in zip(epochs.ch_names, trial)
        ]

        output_file = f"{outputFolder}/{participant}/{self.cond}_fft_{condition_name}.csv"
        pd.DataFrame(rows).to_csv(output_file, index=False)
        print(f"FFT: FFT results saved to '{output_file}'.")

    def compute_band_averages(self, epochs, condition_name):
        """
        Summarises mean band power per brain region and writes it to CSV.

        - Multitaper PSD (1–100 Hz) is computed once for all epochs.
        - For each region in the module-level ``regions`` mapping, only the
          channels present in the epochs are used; regions with none are skipped.
        - For each band in ``bands``, power is averaged over epochs, the
          region's channels and the band's frequency bins; bands with no
          bins are reported and skipped.
        - Errors are caught and printed rather than raised.

        Args:
        - epochs (mne.Epochs): Segmented EEG data.
        - condition_name (str): Label for the condition (e.g., "train").

        Output:
        - '{cond}_band_averages_{condition_name}.csv' with Region, Band, Average_Power.
        """

        print(f"\nStep 6.5: Computing band averages for {condition_name}...")

        summary_rows = []

        try:
            spectrum = epochs.compute_psd(method='multitaper', fmin=1, fmax=100)
            power = spectrum.get_data()  # Shape: (epochs, channels, frequencies)
            freq_axis = spectrum.freqs

            for region, members in regions.items():
                # Channel selection does not depend on the band, so resolve it once per region
                present = [ch for ch in members if ch in epochs.ch_names]
                if not present:
                    continue
                channel_idx = [epochs.ch_names.index(ch) for ch in present]

                for band, (lo, hi) in bands.items():
                    in_band = (freq_axis >= lo) & (freq_axis <= hi)
                    if not in_band.any():
                        print(f"No frequencies found for band {band} in region {region}. Skipping.")
                        continue

                    # Average over epochs and frequency bins, leaving one value per channel
                    per_channel = power[:, channel_idx, :][:, :, in_band].mean(axis=(0, 2))
                    summary_rows.append({
                        'Region': region,
                        'Band': band,
                        'Average_Power': per_channel.mean(),
                    })

            output_file = f"{outputFolder}/{participant}/{self.cond}_band_averages_{condition_name}.csv"
            pd.DataFrame(summary_rows).to_csv(output_file, index=False)
            print(f"Compute Band Averages: Band averages saved to '{output_file}'.")

        except Exception as e:
            print(f"Compute Band Averages: Error during band average computation: {e}")

##############################################################################################################################################################
# Step 7: Creating Summary Tables
##############################################################################################################################################################

    def update_master_ratios(self, cond_type):
        """
        Updates the master ratio summary CSV with current condition ratios.

        - Writes {cond}_ratio1 (Pz Alpha / Fz Theta) and {cond}_ratio2
          (Beta / (Alpha + Theta)) into the row '{participant}_{cond_type}'.
        - Creates the master file if missing, pre-seeded with train/test rows
          for P01–P34; rows for other participants are appended on first write.
        - Skips the update with a message if the per-condition ratio file is
          missing or empty (e.g. because compute_psd_and_ratios failed and
          did not write it), instead of aborting the whole batch.

        Args:
        - cond_type (str): Either "train" or "test", used to locate input file and label rows.

        Output:
        - Updated 'All_Ratios_Summary.csv' with new or modified values.
        """

        master_path = f"{outputFolder}/All_Ratios_Summary.csv"
        cond_file = f"{outputFolder}/{self.participant}/{self.cond}_ratios_analysis_{cond_type}.csv"
        row_id = f"{self.participant}_{cond_type}"

        if not os.path.exists(cond_file):
            print(f"Summary Table: Ratio file '{cond_file}' not found. Skipping update for {row_id}.")
            return

        df_cond = pd.read_csv(cond_file)
        if df_cond.empty:
            print(f"Summary Table: Ratio file '{cond_file}' is empty. Skipping update for {row_id}.")
            return

        ratio1 = df_cond['Pz Alpha / Fz Theta'].values[0]
        ratio2 = df_cond['Beta / (Alpha + Theta)'].values[0]

        col1 = f"{self.cond}_ratio1"
        col2 = f"{self.cond}_ratio2"

        # Initialize or load the master file
        if os.path.exists(master_path):
            df_master = pd.read_csv(master_path, index_col=0)
        else:
            participants = [f"P{str(i).zfill(2)}" for i in range(1, 35)]  # P01..P34
            row_labels = [f"{p}_train" for p in participants] + [f"{p}_test" for p in participants]
            df_master = pd.DataFrame(index=row_labels)

        # Ensure the columns exist
        if col1 not in df_master.columns:
            df_master[col1] = pd.NA
        if col2 not in df_master.columns:
            df_master[col2] = pd.NA

        # Update (or append, via .loc enlargement) the row with ratios
        df_master.loc[row_id, col1] = ratio1
        df_master.loc[row_id, col2] = ratio2

        # Save updated summary
        df_master.to_csv(master_path)
        print(f"Summary Table: Updated {master_path} for {row_id}.")

##############################################################################################################################################################
# Step 8: Process Pipeline
##############################################################################################################################################################

    def process(self):
        """
        Runs the full EEG processing pipeline for a participant.

        - Loads the dataset and extracts training and test windows.
        - Each available window goes through the same pipeline: baseline
          correction, AutoReject, epoching, sensor plot, band power, PSD &
          ratios, master summary, gamma table, FFT and regional band averages.
        - Saves all results to participant-specific output folders.

        No arguments. Results are written to disk.
        """

        # Load the dataset
        raw = self.load_set()

        # Extract training and test event windows
        print("\nStep 3: Event Extraction")
        raw_training, raw_test = self.extract_event_windows(raw)

        # Train and test windows share one pipeline; a missing/empty window is skipped
        if raw_training:
            self._process_window(raw_training, "train")
        if raw_test:
            self._process_window(raw_test, "test")

    def _process_window(self, raw_window, label):
        """
        Runs preprocessing, epoching, analysis and summary for one window.

        Args:
        - raw_window (mne.io.Raw): Training or test segment of the recording.
        - label (str): "train" or "test"; used in file names, the summary row
          and the ``self.epochs_<label>`` attribute.
        """
        print(f"\n***********{label.upper()}***********")
        print("\nStep 4: Preprocessing Functions")
        raw_window = self.apply_baseline_correction(raw_window, None)
        # ICA removed: not recommended with only three electrodes
        raw_window = self.apply_autoreject(raw_window)

        print("\nStep 5: Epoching")
        epochs = self.epoch_data(raw_window, label, duration=1.0, overlap=0.0)
        if not epochs:
            return

        self._save_sensor_plot(epochs, label)

        print("\nStep 6: Analysis Functions")
        self.compute_and_save_band_power(epochs, label)
        self.compute_psd_and_ratios(raw_window, label)

        print("\nStep 7: Summary Table")
        self.update_master_ratios(label)

        # finalize_gamma_table reads the epochs back from self.epochs_<label>
        setattr(self, f"epochs_{label}", epochs)
        self.finalize_gamma_table(label)

        self.compute_fft(epochs, label)
        self.compute_band_averages(epochs, label)

    def _save_sensor_plot(self, epochs, label):
        """Saves a 2D topomap of sensor positions (checks the montage); best effort."""
        print(f"\nStep 5.5: Generating Sensor Location Plot ({label.capitalize()})...")
        try:
            fig = epochs.plot_sensors(show=False, kind='topomap')
        except Exception as e:
            print(f"Warning: Could not plot/save sensor plot: {e}")
            return

        try:
            plot_path = f"{outputFolder}/{self.participant}/plots"
            os.makedirs(plot_path, exist_ok=True)
            fig_path = f"{plot_path}/{self.cond}_{label}_sensor_locations.png"
            fig.savefig(fig_path)
            print(f"Sensor location plot saved to {fig_path}")
        except Exception as e:
            print(f"Warning: Could not plot/save sensor plot: {e}")
        finally:
            # Always release the figure, even if saving failed
            plt.close(fig)

def main(participant, set_file, marker_csv, digitizer_txt, cond, cond_start,
         cond_end, prestim=0, poststim=1000):
    """
    Initialize an EEGProcessor for one participant/condition and run its pipeline.

    Args:
    - participant (str): Participant ID (e.g., "P01").
    - set_file (str): Path to the .xdf EEG data file.
    - marker_csv (str): Path to the marker CSV file.
    - digitizer_txt (str): Path to the digitization text file that is used
      to build the custom electrode montage.
    - cond (str): Condition label (e.g., "rest", "task").
    - cond_start (int): Marker value indicating condition start.
    - cond_end (int): Marker value indicating condition end.
    - prestim (int, optional): Passed unchanged to EEGProcessor as
      ``prestim``. Defaults to 0. (Units are presumably milliseconds;
      confirm against EEGProcessor.)
    - poststim (int, optional): Passed unchanged to EEGProcessor as
      ``poststim``. Defaults to 1000. (Units are presumably milliseconds;
      confirm against EEGProcessor.)

    Returns:
    - None. All results are produced as side effects of
      ``EEGProcessor.process()``.
    """

    print("\nStep 0: Variables Initialized")
    processor = EEGProcessor(
        set_file=set_file,
        marker_csv=marker_csv,
        participant=participant,
        digitizer_txt=digitizer_txt,

        cond=cond,
        cond_start=cond_start,
        cond_end=cond_end,

        # Exposed as parameters (previously hard-coded to 0 / 1000) so that
        # callers can change the window; the defaults keep the old behavior.
        prestim=prestim,
        poststim=poststim)
    processor.process()

if __name__ == "__main__":
    # Batch-run EEG processing for every participant and every condition.
    #
    # For each (participant, EEG .xdf file, marker CSV, digitizer file)
    # entry in `file_pairs`, and for each condition in `conds` (mapping a
    # condition label to its (start marker, end marker) pair), print a
    # banner and run the full pipeline through `main`.
    _BANNER = '********************************************************************************************************************************'

    for participant, set_file, marker_csv, digitizer_txt in file_pairs:
        for cond, (cond_start, cond_end) in conds.items():
            print(_BANNER)
            print(f"Processing {participant} as {set_file} with {marker_csv} and {digitizer_txt}")
            print(f"Processing {cond} with {cond_start} to {cond_end}")
            print(_BANNER)
            main(participant, set_file, marker_csv, digitizer_txt, cond, cond_start, cond_end)


********************************************************************************************************************************
Processing P02 as Dataset/P002 06.18.2025/sub-P001_ses-S001_task-Default_run-001_eeg.xdf with Dataset/P002 06.18.2025/P02_cleaned.csv and Dataset/P002 06.18.2025/P002 06.18.2025.txt
Processing LowGermane with 41 to 11
********************************************************************************************************************************

Step 0: Variables Initialized

Step 1: Obtain .XDF File

Step 2: Loading .xdf file...
Original sample rate found (nominal_srate): 500.0
Load Set: sfreq = 500.0
Load Set: Channel names: ['Fp1', 'Fz', 'F3', 'F7', 'F9', 'FC5', 'FC1', 'C3', 'ACC_X', 'ACC_Y', 'ACC_Z']
Dropping channels: ['F7', 'F9', 'FC5', 'FC1', 'C3']
Load Set: Apply custom digitization (montage) from file...
Parse and Create Montage: Parsing digitization file: Dataset/P002 06.18.2025/P002 06.18.2025.txt
Parse and Create Montage: Creating DigMontage objec