# Imports & Installs

In [200]:
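# Install the dependencies once per environment. Use %pip (not !pip / pip) so the
# packages land in the environment of the running kernel.
# %pip install mne
# %pip install autoreject
# %pip install pyxdf
# %pip install numpy
# %pip install pandas
# %pip install scipy
# %pip install matplotlib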

In [201]:
import os
import warnings
import numpy as np
import pandas as pd
from scipy.io import loadmat
import matplotlib.pyplot as plt
from scipy.stats import ttest_ind
import mne
from mne.preprocessing import ICA
from autoreject import AutoReject
import pyxdf
import glob

# Compress Warnings

In [202]:
# Ignore runtime warnings for clean output (came with the SK code).
# Caution: this also silences numpy RuntimeWarnings such as "Mean of empty slice",
# so NaN-producing steps will not announce themselves.
warnings.filterwarnings('ignore', category=RuntimeWarning)
# Without this, MNE's "Effective window size : 1.024 (s)" message would print
# thousands of times (once per Welch call) (J added).
mne.set_log_level('WARNING')

# Step 0: Obtain XDF and CSV File

In [203]:
def get_set_file(participant):
    """Return the path of the participant's XDF recording.

    Searches ``Raw/<participant>`` (relative to the current working directory)
    for ``*.xdf`` files.

    Parameters
    ----------
    participant : str
        Participant folder name, e.g. ``"P01"``.

    Returns
    -------
    str
        Path of the alphabetically first ``.xdf`` file in the folder.

    Raises
    ------
    FileNotFoundError
        If the folder contains no ``.xdf`` file.
    """
    # Build the folder path for the given participant
    folder = os.path.join("Raw", participant)
    # glob returns matches in arbitrary, filesystem-dependent order; sort them so
    # the same file is picked on every run when a folder holds several recordings.
    files = sorted(glob.glob(os.path.join(folder, "*.xdf")))
    if not files:
        raise FileNotFoundError(f"No .xdf file found in {folder}")
    # Return the first file (adjust if more than one recording per participant is expected)
    return files[0]

# Step 1: Global Variables (Determine your variables here)

In [204]:
# Participants are expected to have .xdf files and .csv files
participants = ["P01", "P07", "P15"]

# Marker values (from the LSL marker CSV) that open and close each condition.
# Keys are the condition labels used in output file names.
conditions = {
    "LL": {"cond_start": 42, "cond_end": 12},
    "LH": {"cond_start": 43, "cond_end": 13},
    "HL": {"cond_start": 44, "cond_end": 14},
    "HH": {"cond_start": 46, "cond_end": 16},
}

# Define regions and their corresponding channels
# NOTE(review): ACC_X/ACC_Y/ACC_Z look like accelerometer channels rather than
# frontal EEG sites, and these labels differ from the names load_set assigns
# ("Fz", "Cz", "Pz", "Acc1".."Acc3") -- confirm before using this table.
regions = {
    'Frontal': ['Fp1', 'Fz', 'F3', 'F7', 'F9', 'FC3', 'FC5', 'ACC_X', 'ACC_Y', 'ACC_Z'],
    'Central': ['C3'],
    'Parietal': ['P3', 'P4', 'PZ'],
    'Occipital': ['O1', 'O2'],
}

# Frequency analysis: band name -> (low, high) edges in Hz.
# The analysis functions select bins with inclusive bounds on both sides, so a
# bin exactly at a shared edge (4 Hz, 20 Hz) is counted in both neighbouring bands.
# load_set band-passes the EEG at 0.1-30 Hz, so 'Gamma' above 30 Hz carries
# little signal in this pipeline.
bands = {
    'Delta': (0.5, 4),
    'Theta': (4, 7),
    'Alpha': (8, 13),
    'Beta': (14, 20),
    'Gamma': (20, 100)
}

In [205]:
# Settings
# Flags use 1 for "on" and 0 for "off".
# NOTE(review): none of these names is read by the functions defined in the
# cells shown here -- presumably the driver cell consumes them; verify.

# Analysis Type
Frequency = 1
ERP_option = 0

# Choose Preprocessing Options
ICA_option = 1

# Yes or No: Epochs?
Epochs = 1
Interval = 2 # Epoch length in seconds

# Step 2: Loading Data & Channel Information

In [206]:
# Process Pipeline: channel_info = load_channel_info(channel_mat)

def load_channel_info(channel_mat):
    """Read channel positions from a ``channel.mat`` file and pair them with fixed names.

    Parameters
    ----------
    channel_mat : str
        Path handed to ``scipy.io.loadmat``. The file must contain a ``Channel``
        struct with a ``Loc`` field.

    Returns
    -------
    dict
        ``{'names': list of the six hard-coded labels, 'locs': numpy array}``.
        The names are NOT read from the file; only the locations are.
    """
    print("\nStep 2.1: Loading channel.mat file for channel names...")

    # Loads the .mat file with the 6 channels
    mat_data = loadmat(channel_mat)
    # Prints the whole loaded dict, which can be a large dump in the notebook output.
    print(f"Load Channel Info: Channel.mat Data: {mat_data}")

    # Channel names are hard-coded; the commented line below is the earlier
    # version that read them from the .mat file.
    # channel_names = [str(mat_data['Channel']['Name'][0][i][0]) for i in range(mat_data['Channel']['Name'].shape[1])] # Extracts the channel names from a matlab file and saves in a []
    channel_names = ["Fz", "Cz", "Pz", "Acc1", "Acc2", "Acc3"]
    print(f"Load Channel Info: Hardcoded Channel Names: {channel_names}")

    # Take the first three rows of every channel's 'Loc' entry (x, y, z), stack
    # them, squeeze singleton axes and transpose.
    # Presumably the result is (3, n_channels): load_set zips the names with
    # `locs.T`, i.e. one xyz triple per channel -- TODO confirm against the file.
    channel_locs = np.array([mat_data['Channel']['Loc'][0][i][:3] for i in range(mat_data['Channel']['Loc'].shape[1])]).squeeze().T

    # Creates a dictionary for the channel names with their associated locations
    return {'names': channel_names, 'locs': channel_locs}

In [207]:
# Process Pipeline: raw = load_set(set_file, channel_info, fixed_channels)

# def load_set(set_file, channel_info, fixed_channels):
def load_set(set_file, channel_info):
    """Load the EEG stream of an XDF recording into a preprocessed MNE Raw object.

    Steps: read the file with pyxdf, take the first stream whose type is "eeg",
    wrap it in an ``mne.io.RawArray``, drop the hard-coded surplus channels,
    rename the remaining channels to ``channel_info['names']``, attach a montage
    built from ``channel_info['locs']``, resample to 250 Hz and band-pass the
    EEG channels at 0.1-30 Hz.

    Parameters
    ----------
    set_file : str
        Path of the ``.xdf`` recording.
    channel_info : dict
        ``{'names': [...], 'locs': array}`` as returned by ``load_channel_info``.

    Returns
    -------
    mne.io.RawArray
        The resampled, filtered recording.

    Raises
    ------
    ValueError
        If the file holds no EEG stream, or the channel count after dropping
        does not match ``channel_info['names']``.
    """
    print("\nStep 2.2: Loading .xdf file...")

    # Load the .xdf file
    streams, _ = pyxdf.load_xdf(set_file)

    # Find the EEG stream
    eeg_stream = next((s for s in streams if s['info']['type'][0].lower() == 'eeg'), None)
    if eeg_stream is None:
        raise ValueError("No EEG stream found in the XDF file.")

    # Samples arrive as (n_samples, n_channels); MNE wants (n_channels, n_samples)
    data = np.array(eeg_stream['time_series']).T

    # XDF stream headers store the declared rate under 'nominal_srate'. The
    # previous lookup of 'sample_rate' never matched, so the 250 Hz fallback was
    # always used silently. 'sample_rate' is still tried for compatibility.
    sfreq = None
    for rate_key in ('nominal_srate', 'sample_rate'):
        try:
            candidate = float(eeg_stream['info'][rate_key][0])
        except (KeyError, IndexError, TypeError, ValueError):
            continue
        if candidate > 0:  # a nominal rate of 0 marks an irregular stream
            sfreq = candidate
            break
    if sfreq is None:
        print("Load Set: Sampling rate not found in stream header; assuming 250 Hz.")
        sfreq = 250.0
    print(f'Load Set: sfreq: {sfreq}')

    # Extract channel names from the stream; fall back to generic labels when the
    # header has no usable channel description.
    try:
        ch_names = [chan['label'][0] for chan in eeg_stream['info']['desc'][0]['channels'][0]['channel']]
    except (KeyError, IndexError, TypeError):
        ch_names = [f"Ch{i+1}" for i in range(data.shape[0])]
    # Probably will provide the original 11 names
    print(f'Load Set: Channel names: {ch_names}')

    # Create the MNE Raw object
    info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=['eeg'] * len(ch_names))
    raw = mne.io.RawArray(data, info)

    # Replace channel names using the .mat file info
    ch_names_from_mat = channel_info['names']
    print(f"Load Set: Original 6 hardcoded channel names: {ch_names_from_mat}")
    if len(raw.ch_names) != len(ch_names_from_mat):
        # Hard-coded list of recorded channels that have no entry in the .mat file.
        # Only drop the ones that are actually present so a slightly different
        # recording layout does not abort here with an unrelated error.
        channels_to_drop = [
            ch for ch in ['F7', 'F9', 'FC5', 'FC1', 'C3'] if ch in raw.ch_names
        ]
        if channels_to_drop:
            print(f"Load Set: Dropping channels: {channels_to_drop}")
            raw.drop_channels(channels_to_drop)

    # The positional rename below is only meaningful with matching counts.
    if len(raw.ch_names) != len(ch_names_from_mat):
        raise ValueError(
            f"Load Set: Expected {len(ch_names_from_mat)} channels after dropping, "
            f"but the recording has {len(raw.ch_names)}: {raw.ch_names}"
        )

    # Rename channels positionally and build the montage
    raw.rename_channels(dict(zip(raw.ch_names, ch_names_from_mat)))
    montage = mne.channels.make_dig_montage(ch_pos=dict(zip(ch_names_from_mat, channel_info['locs'].T)))

    # Set channel types (EEG vs. accelerometer): accelerometer channels must not
    # be filtered or fed to ICA as EEG.
    channel_types = {ch: 'eeg' for ch in ['Fz', 'Cz', 'Pz']}
    channel_types.update({ch: 'misc' for ch in ['Acc1', 'Acc2', 'Acc3']})
    raw.set_channel_types(channel_types)

    raw.set_montage(montage)

    # Downsample to 250 Hz and apply a bandpass filter
    print("\nStep 2.2.1: Downsampling to 250 Hz...")
    raw.resample(250)
    print("\nStep 2.2.2: Applying bandpass filter (0.1 Hz to 30 Hz)...")
    raw.filter(l_freq=0.1, h_freq=30, picks='eeg')

    print("Load Set: EEG data successfully loaded and preprocessed.")
    return raw

# Step 3: Event Extraction

In [208]:
# Process Pipeline: raw_training, raw_test = extract_event_windows(raw, marker_df, cond_start, cond_end)

def extract_event_windows(raw, marker_df, cond_start, cond_end, output_loc, cond):
    """Crop the training and test segments of one condition out of ``raw``.

    The training window runs from the first ``cond_start`` marker to the sixth
    ``cond_end`` marker; the test window runs from the last ``cond_start`` to the
    last ``cond_end`` marker. Marker times are used directly as seconds for
    ``raw.crop``. Both segments, the window timestamps and a per-marker trial
    table are written below ``output_loc``.

    Parameters
    ----------
    raw : mne.io.BaseRaw
        Continuous recording.
    marker_df : pandas.DataFrame
        Marker table with ``time`` and ``value`` columns. Both columns are
        coerced to numeric IN PLACE on the caller's frame.
    cond_start, cond_end : int
        Marker values opening / closing a block of this condition.
    output_loc : str
        Output directory (created if missing, together with its ``fif`` folder).
    cond : str
        Condition label used as file-name prefix.

    Returns
    -------
    tuple
        ``(raw_training, raw_test)``; an element is ``None`` when its window has
        non-positive duration, and both are ``None`` when fewer than seven start
        or end markers exist.

    Raises
    ------
    KeyError
        If ``marker_df`` lacks the ``time`` or ``value`` column.
    """
    print("\nStep 3.1: Extracting event windows for 'training' and 'test' periods...")

    if 'time' not in marker_df.columns or 'value' not in marker_df.columns:
        raise KeyError("The CSV file must contain 'time' and 'value' columns.")
    # Coerce on the caller's frame (later pipeline steps reuse it); rows that fail
    # to parse become NaN and are removed from the local copy below.
    marker_df['value'] = pd.to_numeric(marker_df['value'], errors='coerce').dropna().astype(int)
    marker_df['time'] = pd.to_numeric(marker_df['time'], errors='coerce').dropna()
    # Reset the index so positions and labels agree after removing rows
    marker_df = marker_df.dropna(subset=['time', 'value']).reset_index(drop=True)

    events = marker_df['value']
    print(f"Extract Event Windows: LSL Values: {marker_df['value']}")
    times = marker_df['time']

    # Row positions of the start / end markers of this condition
    start_rows = [i for i, event in enumerate(events) if event == cond_start]
    end_rows = [i for i, event in enumerate(events) if event == cond_end]
    # Guarded: indexing [0] on an empty list used to raise IndexError before the
    # "insufficient markers" check below could run.
    if start_rows:
        print(f"Extract Event Windows: Indices of First Condition Start: {events.iloc[start_rows[0]]}")

    # Six training blocks plus one test block -> at least seven of each marker
    if len(start_rows) < 7 or len(end_rows) < 7:
        print("Extract Event Windows: Insufficient markers for training or test period.")
        print(f'Extract Event Windows: Length of 42/46 indices: {len(start_rows)}')
        print(f'Extract Event Windows: Length of 12/16 indices: {len(end_rows)}')
        return None, None

    training_start = times.iloc[start_rows[0]]  # first start marker
    print(f'Extract Event Windows: Actual training start: {training_start}')
    training_end = times.iloc[end_rows[5]]  # sixth end marker
    print(f'Extract Event Windows: Actual training end: {training_end}')
    test_start = times.iloc[start_rows[-1]]  # last start marker
    test_end = times.iloc[end_rows[-1]]  # last end marker

    # raw.crop works in seconds and the marker times are already in seconds, so
    # they are passed through unchanged (no division by the sampling rate).
    raw_training = raw.copy().crop(tmin=training_start, tmax=training_end) if training_start < training_end else None
    raw_test = raw.copy().crop(tmin=test_start, tmax=test_end) if test_start < test_end else None

    # Make sure the output folders exist before anything is written
    os.makedirs(os.path.join(output_loc, "fif"), exist_ok=True)

    # Save cropped data
    if raw_training is not None:
        training_file = f"{output_loc}/fif/{cond}_training_data.fif"
        raw_training.save(training_file, overwrite=True)
        print(f"Extract Event Windows: Training data saved to '{training_file}'")
    else:
        print("Extract Event Windows: Invalid training period duration.")

    if raw_test is not None:
        test_file = f"{output_loc}/fif/{cond}_test_data.fif"
        raw_test.save(test_file, overwrite=True)
        print(f"Extract Event Windows: Test data saved to '{test_file}'")
    else:
        print("Extract Event Windows: Invalid test period duration.")

    # Log durations
    if raw_training is not None:
        print(f"Extract Event Windows: Training data duration: {(training_end - training_start)} seconds")
    if raw_test is not None:
        print(f"Extract Event Windows: Test data duration: {(test_end - test_start)} seconds")

    # Save condition timestamps
    condition_timestamps = {
        'training_start': training_start,
        'training_end': training_end,
        'test_start': test_start,
        'test_end': test_end,
    }

    pd.DataFrame([condition_timestamps]).to_csv(f"{output_loc}/{cond}_condition_timestamps.csv", index=False)
    print(f"Extract Event Windows: Condition timestamps saved to '{output_loc}/{cond}_condition_timestamps.csv'.")

    def _marker_record(row, condition):
        # 'sample_idx' is the row position in the cleaned marker table, not an
        # EEG sample number; the key is kept for existing consumers of the CSV.
        return {'event_type': events.iloc[row], 'sample_idx': row, 'time': times.iloc[row], 'condition': condition}

    # One start and one end record per start/end pair; the last pair is the test block
    trials = []
    last_pair = len(start_rows) - 1
    for i, (start_row, end_row) in enumerate(zip(start_rows, end_rows)):
        condition = 'train' if i < last_pair else 'test'
        trials.append(_marker_record(start_row, condition))
        trials.append(_marker_record(end_row, condition))

    # A trailing start marker without an end marker is still logged as test
    if len(start_rows) > len(end_rows):
        trials.append(_marker_record(start_rows[-1], 'test'))

    # Create a DataFrame and save to a CSV
    trials_df = pd.DataFrame(trials)
    trials_df.to_csv(f"{output_loc}/{cond}_trials.csv", index=False)

    print(f"Extract Event Windows: Trials saved to '{output_loc}/{cond}_trials.csv'.")

    return raw_training, raw_test

# Step 4: Preprocessing Functions (Baseline, Autoreject, ICA)

In [209]:
# Process Pipeline: raw_training = apply_baseline_correction(raw_training, marker_df, condition_type)
# Process Pipeline: raw_test = apply_baseline_correction(raw_test, marker_df, condition_type)

def _first_marker_sample(marker_df, marker_value, sfreq, first_samp):
    """Return the segment-relative sample index of the first ``marker_value`` marker, or None."""
    if 'time' not in marker_df.columns or 'value' not in marker_df.columns:
        return None
    marker_times = pd.to_numeric(
        marker_df.loc[marker_df['value'] == marker_value, 'time'], errors='coerce'
    ).dropna()
    if marker_times.empty:
        return None
    # Marker times are treated as seconds from the start of the recording, the
    # same convention extract_event_windows uses when cropping; first_samp
    # shifts them into the (possibly cropped) segment.
    return int(round(float(marker_times.iloc[0]) * sfreq)) - int(first_samp)


def apply_baseline_correction(raw, marker_df, condition_type, output_loc, cond):
    """Subtract each EEG channel's baseline mean from ``raw`` and save the result.

    The baseline is the rest interval between the first marker 200 and the first
    marker 210 when that interval lies inside ``raw``; otherwise the first 15
    seconds of ``raw`` (or all of it, if shorter) are used.

    Parameters
    ----------
    raw : mne.io.BaseRaw
        Preloaded segment; modified in place.
    marker_df : pandas.DataFrame
        Marker table with ``time`` (seconds) and ``value`` columns.
    condition_type : str
        ``"train"`` or ``"test"``; used in the output file name.
    output_loc : str
        Output directory; the file goes into its ``fif`` sub-folder.
    cond : str
        Condition label used as file-name prefix.

    Returns
    -------
    mne.io.BaseRaw
        The same ``raw`` object.
    """
    print("\nStep 4.1: Applying baseline correction...")

    sfreq = raw.info['sfreq']
    n_times = len(raw.times)

    print("Apply Baseline Correction: Detecting rest intervals from markers...")

    # Convert marker TIMES to sample indices. The marker table's row numbers
    # were previously used as sample indices, which averaged over an arbitrary
    # handful of samples at the start of the segment.
    rest_start = _first_marker_sample(marker_df, 200, sfreq, raw.first_samp)
    rest_end = _first_marker_sample(marker_df, 210, sfreq, raw.first_samp)

    window = None
    source = "rest interval"
    if rest_start is None or rest_end is None:
        print("Apply Baseline Correction: Warning: Rest interval markers not detected.")
    elif rest_end <= rest_start:
        print("Apply Baseline Correction: Warning: Detected rest interval is invalid (end occurs before or at start).")
    elif rest_start < 0 or rest_end > n_times:
        print("Apply Baseline Correction: Warning: Detected rest interval lies outside this data segment.")
    else:
        print(f"Apply Baseline Correction: Using detected rest interval for baseline correction: Start={rest_start}, End={rest_end}")
        window = (rest_start, rest_end)

    if window is None:
        print("Apply Baseline Correction: Falling back to default 15s baseline...")
        source = "default interval"
        # First 15 seconds of the segment. Start and end used to be swapped,
        # which made this fallback unreachable.
        baseline_end = min(int(15 * sfreq), n_times)
        if baseline_end > 0:
            print(f"Apply Baseline Correction: Using default baseline: Start=0, End={baseline_end}")
            window = (0, baseline_end)
        else:
            print("Apply Baseline Correction: Error: Default baseline interval is invalid. Skipping baseline correction.")

    if window is not None:
        start, stop = window
        try:
            # keepdims keeps the per-channel mean as a column so it broadcasts
            # over the time axis.
            raw.apply_function(
                lambda x: x - np.mean(x[:, start:stop], axis=-1, keepdims=True),
                picks='eeg', channel_wise=False
            )
            print(f"Apply Baseline Correction: Baseline correction using {source} applied successfully.")
        except Exception as e:
            # Best effort: the uncorrected data is still saved below.
            print(f"Apply Baseline Correction: Error applying baseline correction with {source}: {e}")

    # Save the baseline-corrected file
    os.makedirs(os.path.join(output_loc, "fif"), exist_ok=True)
    print(f"Saving baseline-corrected file to '{output_loc}/fif/{cond}_D_bc_{condition_type}.fif'...")
    raw.save(f"{output_loc}/fif/{cond}_D_bc_{condition_type}.fif", overwrite=True)

    return raw

In [210]:
# Process Pipeline: raw_training, ica_training = apply_ica(raw_training, fixed_channels, condition_type)
# Process Pipeline: raw_test, ica_test = apply_ica(raw_test, fixed_channels, condition_type)

def apply_ica(raw, fixed_channels, condition_type, output_loc, cond):
    """Fit a 3-component ICA on the EEG channels, remove EOG/ECG components, save.

    Parameters
    ----------
    raw : mne.io.BaseRaw
        Preloaded segment; cleaned in place by ``ica.apply``.
    fixed_channels : iterable of str
        Channel names to use as EOG references; only those present in ``raw``
        are used.
    condition_type : str
        ``"train"`` or ``"test"``; used in the output file name.
    output_loc : str
        Output directory; the file goes into its ``fif`` sub-folder.
    cond : str
        Condition label used as file-name prefix.

    Returns
    -------
    tuple
        ``(raw, ica)`` -- the cleaned data and the fitted ICA object.
    """
    print("\nStep 4.2: Applying ICA...")

    # Hard-coded: the number of EEG channels is not expected to change
    n_components = 3
    ica = ICA(n_components=n_components, random_state=97, max_iter=800)

    # Select EEG channels only for ICA
    picks = mne.pick_types(raw.info, eeg=True, exclude='bads')
    print(f"Apply ICA: EEG Picks: {picks}")
    ica.fit(raw, picks=picks)

    # There is no dedicated EOG channel; use whichever of fixed_channels exist
    eog_indices = []
    eog_channels = [ch for ch in fixed_channels if ch in raw.ch_names]
    if eog_channels:
        eog_indices, _ = ica.find_bads_eog(raw, ch_name=eog_channels)
        print(f"Apply ICA: Found EOG components: {eog_indices}")
    else:
        print("Apply ICA: No EOG channels found; skipping EOG artifact detection.")

    # Detect ECG artifacts if possible
    try:
        ecg_indices, _ = ica.find_bads_ecg(raw, method='correlation')
        print(f"Apply ICA: Found ECG components: {ecg_indices}")
    except ValueError as e:
        print(f"Apply ICA: ECG artifact detection skipped: {e}")
        ecg_indices = []

    # A component flagged by both detectors must be listed (and counted) once
    ica.exclude = sorted({int(idx) for idx in list(eog_indices) + list(ecg_indices)})
    print(f"Apply ICA: Excluded components: {ica.exclude}")

    # Apply ICA to the raw data
    raw = ica.apply(raw)
    print(f"Apply ICA: ICA applied with {n_components} components. Excluded {len(ica.exclude)} components.")

    # Save the cleaned EEG files
    raw.save(f"{output_loc}/fif/{cond}_D_cleaned_{condition_type}.fif", overwrite=True)
    print(f"Apply ICA: Cleaned EEG file saved as '{output_loc}/fif/{cond}_D_cleaned_{condition_type}.fif'.")

    return raw, ica

In [211]:
# Process Pipeline: raw_training = apply_autoreject(raw_training)
# Process Pipeline: raw_test = apply_autoreject(raw_test)

def apply_autoreject(raw):
    """Clean ``raw`` with AutoReject on 2-second epochs and return continuous data.

    The data is cut into fixed-length epochs, AutoReject is fitted with one
    consensus value at a time (1.0, then 0.1, then 0.5) and the first successful
    result is concatenated back into a continuous Raw object. Epochs rejected by
    AutoReject are absent from the result, so it can be shorter than the input.

    Parameters
    ----------
    raw : mne.io.BaseRaw
        Segment to clean.

    Returns
    -------
    mne.io.BaseRaw
        A new ``RawArray`` of the cleaned epochs, or the unchanged ``raw`` if
        every attempt fails.
    """
    print("\nStep 4.3: Applying AutoReject (Repair Only)...")

    # The epoch length can be changed here between 1 and 2 seconds
    epochs = mne.make_fixed_length_epochs(raw, duration=2.0, preload=True)

    # Try different consensus levels
    for consensus in (1.0, 0.1, 0.5):
        try:
            print(f"Apply AutoReject: Trying consensus={consensus}...")
            ar = AutoReject(
                # Max interpolation options
                n_interpolate=[2, 5, len(raw.ch_names)],
                # Consensus threshold
                consensus=[consensus],
                thresh_method='bayesian_optimization',
                random_state=42,
                n_jobs=-1,
                verbose=True
            )

            # Fit AutoReject
            ar.fit(epochs)
            epochs_clean = ar.transform(epochs, return_log=False)

            # get_data() is (n_epochs, n_channels, n_times). Bring the channel
            # axis to the front BEFORE flattening; reshaping the 3-D array
            # directly interleaves samples from different channels and epochs.
            cleaned = epochs_clean.get_data()
            if cleaned.shape[0] == 0:
                raise ValueError("no epochs left after AutoReject")
            n_channels = cleaned.shape[1]
            continuous = cleaned.transpose(1, 0, 2).reshape(n_channels, -1)
            # Use the epochs' own info so the channel list matches the array
            raw_clean = mne.io.RawArray(continuous, epochs_clean.info)

            print(f"Apply AutoReject: Success with consensus={consensus}!")
            print(f"Apply AutoReject: Channels after AutoReject: {raw_clean.ch_names}")
            return raw_clean
        except Exception as e:
            # Deliberate best effort: move on to the next consensus value
            print(f"Apply AutoReject: Failed with consensus={consensus}: {e}")

    print("Apply AutoReject: AutoReject failed. Returning original raw data.")
    return raw

# Step 5: Epoching

In [212]:
# Process Pipeline: epochs_training = epoch_data(raw_training, "train", duration=1.0, overlap=0.0)
# Process Pipeline: epochs_test = epoch_data(raw_test, "test", duration=1.0, overlap=0.0)

def epoch_data(raw, condition_label, output_loc, cond, duration=1.0, overlap=0.0):
    """Cut ``raw`` into fixed-length epochs and log the epoch table to CSV.

    Parameters
    ----------
    raw : mne.io.BaseRaw
        Continuous segment.
    condition_label : str
        ``"train"`` or ``"test"``; used in the CSV file name.
    output_loc : str
        Output directory for the CSV.
    cond : str
        Condition label used as file-name prefix.
    duration : float
        Epoch length in seconds.
    overlap : float
        Overlap between consecutive epochs in seconds; must satisfy
        ``0 <= overlap < duration``.

    Returns
    -------
    mne.Epochs or None
        The epochs, or ``None`` when the parameters are invalid or no event
        could be created.
    """
    print("\nStep 5.1: Epoching data...")

    sfreq = raw.info['sfreq']
    start = 0
    stop = raw.n_times / sfreq
    print(f'Epoch Data: Raw n times: {raw.n_times}')
    print(f'Epoch Data: Raw info sfreq: {sfreq}')
    print(f"Epoch Data: Raw data range: Start={start}, Stop={stop}, Duration={duration}, Overlap={overlap}")

    try:
        if overlap < 0 or overlap >= duration:
            raise ValueError(f"Overlap must be >=0 but < duration ({duration}), got {overlap}")

        events = mne.make_fixed_length_events(
            raw, id=1, start=start, stop=stop, duration=duration, overlap=overlap
        )
        print(f"Epoch Data: Generated {len(events)} fixed-length events.")
    except ValueError as e:
        print(f"Epoch Data: Error generating events: {e}")
        return None

    if len(events) == 0:
        print("Epoch Data: No fixed-length events created.")
        return None

    # tmax is inclusive in mne.Epochs. Stopping one sample early gives exactly
    # duration * sfreq samples per epoch and keeps the last epoch from running
    # past the end of the data (where it would be dropped as too short).
    epochs = mne.Epochs(
        raw, events, tmin=0, tmax=duration - 1.0 / sfreq, baseline=None, detrend=1, preload=True
    )

    # Log only the epochs that survived loading, so the CSV rows line up with
    # the returned epochs.
    samples_per_epoch = int(sfreq * duration)
    trial_data = [
        {
            # Convert to seconds
            'start': event[0] / sfreq,
            'end': (event[0] + samples_per_epoch) / sfreq,
            'samples': samples_per_epoch,
            'type': event[2]
        }
        for event in epochs.events
    ]

    # Save trial data to a CSV file
    trial_df = pd.DataFrame(trial_data)
    filename = f"{output_loc}/{cond}_epoch_trl_{condition_label}.csv"
    trial_df.to_csv(filename, index=False)
    print(f"Epoch Data: Epoch trial data saved to '{filename}'.")

    return epochs

# Step 6: Analysis Functions (FFT, Band Power, PSD, etc.)

In [213]:
# Process Pipeline: compute_and_save_band_power(epochs_training, "train")
# Process Pipeline: compute_and_save_band_power(epochs_test, "test")

def compute_and_save_band_power(epochs, condition_name, output_loc, cond):
    """Write the mean Welch power per band, epoch and channel to a CSV file.

    For every epoch and every channel a Welch PSD (0.1-30 Hz, 128-sample
    segments) is computed and averaged inside each band of the module-level
    ``bands`` table (inclusive edges). A band without any frequency bin gets 0.

    Parameters
    ----------
    epochs : mne.Epochs
        Epoched data.
    condition_name : str
        ``"train"`` or ``"test"``; used in the file name.
    output_loc : str
        Output directory.
    cond : str
        Condition label used as file-name prefix.
    """
    print(f"\nStep 6.1: Computing band power for {condition_name}...")

    # The band edges come from the global `bands` dictionary.
    rows = []

    for epoch_number, single_epoch in enumerate(epochs.get_data(), start=1):
        for channel_pos, channel_label in enumerate(epochs.ch_names):

            spectrum, freq_axis = mne.time_frequency.psd_array_welch(
                single_epoch[channel_pos],
                sfreq=epochs.info['sfreq'],
                fmin=0.1,
                fmax=30.0,
                n_per_seg=128,
            )

            # Band columns first, then the identifiers (this fixes the CSV column order)
            row = {}
            for band_label, (low_edge, high_edge) in bands.items():
                in_band = (freq_axis >= low_edge) & (freq_axis <= high_edge)
                if in_band.any():
                    row[f'{band_label}_power'] = spectrum[in_band].mean()
                else:
                    row[f'{band_label}_power'] = 0

            row['Epoch'] = epoch_number
            row['Channel'] = channel_label
            rows.append(row)

    output_file = f"{output_loc}/{cond}_band_power_{condition_name}.csv"
    pd.DataFrame(rows).to_csv(output_file, index=False)
    print(f"Band power results saved to '{output_file}'.")

In [214]:
# Process Pipeline: compute_psd_and_ratios(raw_training, "train")
# Process Pipeline: compute_psd_and_ratios(raw_test, "test")

def compute_psd_and_ratios(raw, phase, output_loc, cond=None, participant=None):
    """Compute per-channel band PSDs and two band ratios, and save both as CSV.

    Writes ``<cond>_psd_results_<phase>.csv`` (mean PSD per band and channel)
    and ``<cond>_ratios_analysis_<phase>.csv`` (Pz alpha / Fz theta, and
    beta / (alpha + theta) averaged over Fz, Cz, Pz) into ``output_loc``.
    Any error during the computation is printed, not raised.

    Parameters
    ----------
    raw : mne.io.BaseRaw
        Continuous segment.
    phase : str
        ``"train"`` or ``"test"``; used in the file names.
    output_loc : str
        Output directory.
    cond : str, optional
        Condition label used as file-name prefix. When omitted, a module-level
        variable named ``cond`` is used if one exists (the previous implicit
        behaviour), otherwise ``"unknown"``.
    participant : str, optional
        Participant id stored in the ratio table. Resolved like ``cond``.
    """
    print("\nStep 6.2: Computing PSD and ratios for analysis...")

    # These two names used to be read as undeclared globals. Without a matching
    # global the resulting NameError was swallowed by the except below and no
    # file was written. Explicit arguments win; the globals remain a fallback
    # so existing three-argument calls keep working.
    if cond is None:
        cond = globals().get('cond')
    if participant is None:
        participant = globals().get('participant')
    if cond is None:
        print("Compute PSD & Ratios: Warning: No condition label given; using 'unknown'.")
        cond = "unknown"
    if participant is None:
        print("Compute PSD & Ratios: Warning: No participant id given; using 'unknown'.")
        participant = "unknown"

    psd_results = []

    try:
        # Compute PSD for each channel
        for channel_idx, channel_name in enumerate(raw.ch_names):
            # Extract channel data
            psd_data = raw.get_data(picks=[channel_idx])
            # NOTE(review): n_per_seg asks for 4-second segments but n_fft is left
            # at the library default, which may cap the effective segment length
            # -- confirm against the psd_array_welch documentation.
            psd_values, freqs = mne.time_frequency.psd_array_welch(
                psd_data, sfreq=raw.info['sfreq'], fmin=0.1, fmax=30.0, n_per_seg=int(4 * raw.info['sfreq'])
            )
            # PSD values are returned in a nested array (one row per picked channel)
            psd_values = psd_values[0]

            # Mean power per band (inclusive edges); NaN when no bin falls in the band
            psd_band_values = {'Channel': channel_name}
            for band, (fmin, fmax) in bands.items():
                band_mask = (freqs >= fmin) & (freqs <= fmax)
                psd_band_values[f'psd_{band}'] = np.mean(psd_values[band_mask]) if np.any(band_mask) else np.nan

            psd_results.append(psd_band_values)

        # Save PSD results to a CSV file
        psd_df = pd.DataFrame(psd_results)
        psd_df.to_csv(f'{output_loc}/{cond}_psd_results_{phase}.csv', index=False)
        print(f"Compute PSD & Ratios: PSD results saved to '{output_loc}/{cond}_psd_results_{phase}.csv'.")

        if psd_df.empty:
            print("Compute PSD & Ratios: Error: PSD results are empty. Cannot compute ratios.")
            return

        # Retrieve values for specific channels
        pz_alpha = psd_df.loc[psd_df['Channel'] == 'Pz', 'psd_Alpha'].values
        fz_theta = psd_df.loc[psd_df['Channel'] == 'Fz', 'psd_Theta'].values

        # Handle cases where the channel values are missing or zero
        if len(pz_alpha) > 0:
            pz_alpha = pz_alpha[0]
        else:
            print("Compute PSD & Ratios: Warning: Missing value for 'Pz Alpha'. Setting to NaN.")
            pz_alpha = np.nan

        # A zero denominator is treated like a missing value
        if len(fz_theta) > 0 and fz_theta[0] != 0:
            fz_theta = fz_theta[0]
        else:
            print("Compute PSD & Ratios: Warning: 'Fz Theta' is zero or invalid. Setting to NaN.")
            fz_theta = np.nan

        # Calculate alpha/theta ratios
        alpha_theta_ratio = pz_alpha / fz_theta if not np.isnan(fz_theta) else np.nan
        if np.isnan(fz_theta):
            print("Compute PSD & Ratios: Warning: Cannot compute Alpha/Theta ratio due to missing or zero 'Fz Theta'.")

        # The second ratio uses the average across the three EEG channels
        selected_channels = ['Fz', 'Cz', 'Pz']
        available_channels = psd_df['Channel'].values
        missing_channels = [ch for ch in selected_channels if ch not in available_channels]
        if missing_channels:
            print(f"Compute PSD & Ratios: Warning: Missing channels in data: {missing_channels}")

        # Filter for available channels
        psd_filtered = psd_df[psd_df['Channel'].isin(selected_channels)]

        # Calculate mean power across selected channels
        mean_alpha = psd_filtered['psd_Alpha'].mean()
        mean_theta = psd_filtered['psd_Theta'].mean()
        mean_beta = psd_filtered['psd_Beta'].mean()

        # Handle NaN values
        if np.isnan(mean_alpha) or np.isnan(mean_theta) or np.isnan(mean_beta):
            print("Compute PSD & Ratios: Warning: Missing values in the selected channels. Ratios may be inaccurate.")
            beta_combined_ratio = np.nan
        else:
            beta_combined_ratio = mean_beta / (mean_alpha + mean_theta)

        # Add PID and save ratio results (a single-row table)
        ratio_df = pd.DataFrame([{
            'Participant ID': participant,
            'Pz Alpha': pz_alpha,
            'Fz Theta': fz_theta,
            'Pz Alpha / Fz Theta': alpha_theta_ratio,
            'Beta / (Alpha + Theta)': beta_combined_ratio,
        }])
        ratio_df.to_csv(f'{output_loc}/{cond}_ratios_analysis_{phase}.csv', index=False)
        print(f"Compute PSD & Ratios: Ratios analysis saved to '{output_loc}/{cond}_ratios_analysis_{phase}.csv'.")

    except Exception as e:
        # Deliberate best effort: report and let the pipeline continue
        print(f"Error during PSD or ratio computation: {e}")

    # Plot an example PSD (of the last channel processed)
    # psd_values, freqs = mne.time_frequency.psd_array_welch(psd_data, sfreq=raw.info['sfreq'], fmin=0.1, fmax=30.0)
    # plt.plot(freqs, psd_values[0])
    # plt.title('PSD of Fz')
    # plt.xlabel('Frequency (Hz)')
    # plt.ylabel('Power (uV^2/Hz)')
    # plt.show()

In [215]:
def finalize_gamma_table(epochs, condition_name, participant, cond):
    """
    Build the log/z-scored gamma power tables for one condition.

    For each epoch, the mean Welch power in 20-100 Hz is computed per channel,
    log-transformed (non-positive power is replaced by 1e-10 first) and then
    z-scored across the channels of that epoch. Two CSV files are written to
    ``Results/<participant>/``: one row per epoch, and one row holding the
    per-channel average over all epochs.

    Parameters:
        epochs (mne.Epochs): The epochs object for the given condition, or None.
        condition_name (str): Either "train" or "test".
        participant (str): Participant identifier used for output file path.
        cond (str): Additional condition identifier for the output file name.
    """
    if epochs is None:
        print(f"{condition_name.capitalize()} epochs not found.")
        return

    print(f"\nFinalizing gamma table for {condition_name}...")

    band_low, band_high = 20, 100
    sampling_rate = epochs.info['sfreq']

    def _log_band_power(signal):
        # Mean PSD inside the gamma band, floored before taking the logarithm
        spectrum, _ = mne.time_frequency.psd_array_welch(
            signal,
            sfreq=sampling_rate,
            fmin=band_low,
            fmax=band_high,
            n_per_seg=128
        )
        band_power = spectrum.mean()
        if band_power <= 0:
            band_power = 1e-10  # safeguard against non-positive power
        return np.log(band_power)

    per_epoch_rows = []
    for epoch_number, epoch in enumerate(epochs.get_data(), start=1):
        log_powers = np.array(
            [_log_band_power(epoch[pos]) for pos, _ in enumerate(epochs.ch_names)]
        )
        centre = log_powers.mean()
        spread = log_powers.std()

        # Identical values across channels would divide by zero: report zeros
        if spread == 0:
            z_scores = np.zeros_like(log_powers)
        else:
            z_scores = (log_powers - centre) / spread

        row = {"Epoch": epoch_number}
        for pos, channel_label in enumerate(epochs.ch_names):
            row[channel_label] = z_scores[pos]
        per_epoch_rows.append(row)

    # Per-epoch table
    df = pd.DataFrame(per_epoch_rows)
    output_file = f"Results/{participant}/{cond}_log_z_gamma_power_{condition_name}.csv"
    df.to_csv(output_file, index=False)
    print(f"Finalized gamma table saved to '{output_file}'.")

    # Channel-wise average over epochs, stored as a single row
    channel_means = df.drop(columns=["Epoch"]).mean()
    avg_df = pd.DataFrame(channel_means).transpose()
    avg_df.insert(0, "Condition", condition_name)
    avg_output_file = f"Results/{participant}/{cond}_avg_log_z_gamma_power_{condition_name}.csv"
    avg_df.to_csv(avg_output_file, index=False)
    print(f"Averaged gamma power (log_z) saved to '{avg_output_file}'.")

# Step 7: Creating Summary Tables

In [216]:
# Ah never even got to this initially - would be nice to revisit

# compute_training_test_ratios()
# save_participant_data()
# compare_conditions()

# Last Step 8: Process Pipeline

In [217]:
# Params: set_file, marker_csv, channel_mat, fixed_channels, prestim, poststim, baseline_window

# sfreq = None
# condition_timestamps = {}
# trl = []
# participant_ratios = []
# epochs_train = None
# epochs_test = None

In [218]:
def _process_segment(raw_segment, marker_df, fixed_channels, label, participant, cond, output_loc):
    """Run Steps 4-6 (preprocessing, epoching, analysis) on one extracted segment.

    Parameters
    ----------
    raw_segment : object returned by ``extract_event_windows`` for this segment.
    marker_df : pandas.DataFrame
        Marker table, forwarded to ``apply_baseline_correction``.
    fixed_channels : list of str
        Forwarded to ``apply_ica``.
    label : str
        Segment label ("train" or "test"); forwarded to every helper and
        upper-cased for the banner line.
    participant, cond, output_loc : str
        Forwarded to the helpers that name or write their output files.

    Returns
    -------
    None
    """
    print(f"\n***********{label.upper()}***********")
    print("\nStep 4: Preprocessing Functions")
    raw_segment = apply_baseline_correction(raw_segment, marker_df, label, output_loc, cond)
    # apply_ica returns a pair; the second element is not used by this pipeline.
    raw_segment, _ = apply_ica(raw_segment, fixed_channels, label, output_loc, cond)
    raw_segment = apply_autoreject(raw_segment)

    print("\nStep 5: Epoching")
    epochs = epoch_data(raw_segment, label, output_loc, cond, duration=1.0, overlap=0.0)
    if not epochs:
        # Nothing to analyse (falsy result from epoch_data).
        return

    print("\nStep 6: Analysis Functions")
    compute_and_save_band_power(epochs, label, output_loc, cond)
    # The PSD/ratio step works on the preprocessed continuous data, not the epochs.
    compute_psd_and_ratios(raw_segment, label, output_loc)
    finalize_gamma_table(epochs, label, participant, cond)
    # TODO: compute_fft(epochs, label) and compute_band_averages(epochs, label, regions)
    # were disabled in the original notebook and are still not wired in.


def process_pipeline(set_file, marker_csv, channel_mat, fixed_channels, prestim, poststim, baseline_window, participant, cond, cond_start, cond_end, output_loc):
    """Run the full processing pipeline for one participant and one condition.

    Loads the marker CSV and channel info, loads the recording, extracts the
    training and test segments for the condition, then runs preprocessing,
    epoching and analysis on each non-empty segment (training first, then test).

    Parameters
    ----------
    set_file : str
        Path of the recording, passed to ``load_set``.
    marker_csv : str
        Path of the marker CSV, read with ``pandas.read_csv``.
    channel_mat : str
        Path passed to ``load_channel_info``.
    fixed_channels : list of str
        Forwarded to ``apply_ica``.
    prestim, poststim : number
        Given in milliseconds and converted to seconds here.
        NOTE(review): the converted values are not used anywhere else in this function.
    baseline_window : sequence
        NOTE(review): accepted but not used in this function.
    participant : str
        Participant id, forwarded to ``finalize_gamma_table``.
    cond : str
        Condition name, forwarded to the helpers.
    cond_start, cond_end : int
        Marker values delimiting the condition, forwarded to ``extract_event_windows``.
    output_loc : str
        Output directory, forwarded to the helpers.

    Returns
    -------
    None
    """
    print("Starting processing pipeline...\n")

    print("\nStep 1: Global Variables Initialized")

    print("\nStep 2: Data Loading & Channel Information")

    # Load marker CSV and channel info
    marker_df = pd.read_csv(marker_csv)
    channel_info = load_channel_info(channel_mat)
    print(f'Process_Pipeline: Channel Info: {channel_info}')

    # Callers pass milliseconds; dividing by 1000 converts to seconds.
    prestim = prestim / 1000
    poststim = poststim / 1000

    # fixed_channels is not passed to load_set (only to apply_ica later on).
    raw = load_set(set_file, channel_info)

    print("\nStep 3: Event Extraction")

    # Extract training and test segments based on marker events
    raw_training, raw_test = extract_event_windows(raw, marker_df, cond_start, cond_end, output_loc, cond)

    # Both segments go through the identical Step 4-6 sequence; a falsy segment is skipped.
    for label, raw_segment in (("train", raw_training), ("test", raw_test)):
        if raw_segment:
            _process_segment(raw_segment, marker_df, fixed_channels, label, participant, cond, output_loc)

    # TODO: cross-segment summaries (compute_training_test_ratios, save_participant_data,
    # compare_conditions) were never enabled; see "Step 7: Creating Summary Tables".

In [219]:
def main():
    """Run ``process_pipeline`` for every participant/condition combination.

    Iterates over the module-level ``participants`` list and ``conditions``
    dict. For each participant the recording file, marker CSV path and output
    directories are resolved once; the pipeline is then run once per condition.

    Raises
    ------
    FileNotFoundError
        Propagated from ``get_set_file`` when a participant has no .xdf file.
    """
    for participant in participants:
        # These values depend only on the participant, so resolve them once
        # rather than once per condition.
        set_file = get_set_file(participant)
        marker_csv = f"Raw/{participant}/{participant}events_data.csv"
        output_loc = f"Results/{participant}"

        # Creates output_loc and its "fif" sub-folder in one call; exist_ok
        # avoids the check-then-create race of os.path.exists + os.makedirs.
        os.makedirs(f"{output_loc}/fif", exist_ok=True)

        for cond, params in conditions.items():
            print("Step 0: Obtain .XDF File")

            process_pipeline(
                set_file=set_file,
                marker_csv=marker_csv,
                channel_mat="Raw/new_channel_allocations.mat",
                fixed_channels=['Fp1', 'Fz', 'F3', 'F7', 'F9', 'FC5', 'FC1', 'C3', 'ACC_X', 'ACC_Y', 'ACC_Z'],
                prestim=0,       # milliseconds
                poststim=1000,   # milliseconds
                baseline_window=[2762.35, 2777.488],
                participant=participant,
                cond=cond,
                cond_start=params["cond_start"],
                cond_end=params["cond_end"],
                output_loc=output_loc
            )

# Entry point: run the whole pipeline when this code executes as "__main__".
if __name__ == "__main__":
    main()

Step 0: Obtain .XDF File
Starting processing pipeline...


Step 1: Global Variables Initialized

Step 2: Data Loading & Channel Information

Step 2.1: Loading channel.mat file for channel names...
Load Channel Info: Channel.mat Data: {'__header__': b'MATLAB 5.0 MAT-file Platform: posix, Created on: Wed Jan 15 12:32:34 2025', '__version__': '1.0', '__globals__': [], 'Channel': array([[(array(['Fz  ', 'Cz  ', 'Pz  ', 'Acc1', 'Acc2', 'Acc3'], dtype='<U4'), array([[0.1, 0. , 0. ],
               [0.2, 0.1, 0. ],
               [0.3, 0.2, 0. ],
               [0. , 0.1, 0.2],
               [0. , 0.2, 0.3],
               [0. , 0.3, 0.4]]))                                                                     ]],
      dtype=[('Name', 'O'), ('Loc', 'O')])}
Load Channel Info: Hardcoded Channel Names: ['Fz', 'Cz', 'Pz', 'Acc1', 'Acc2', 'Acc3']
Process_Pipeline: Channel Info: {'names': ['Fz', 'Cz', 'Pz', 'Acc1', 'Acc2', 'Acc3'], 'locs': array([[0.1, 0.2, 0.3],
       [0. , 0.1, 0.2],
       [0. 