# 1b. EEG Feature Extraction

- 1b EEG Feature Extraction [Jump To](#run-1b-eeg-feature-extraction)


## Some Ref Links

- The Brain Imaging Data Structure (BIDS): https://bids.neuroimaging.io
- MNE-Python: https://mne.tools/stable/index.html
- SpecParam: https://specparam-tools.github.io and https://github.com/fooof-tools

## Dependencies

General dependencies:
- python = 3.11.13
- numpy = 2.0.2
- scipy = 1.15.3
- pandas = 2.2.3
- matplotlib = 3.10.3

ML dependencies:
- scikit-learn = 1.6.1

EEG specific dependencies:
- mne = 1.9.0
- mne-icalabel = 0.7.0
- autoreject = 0.4.3
- specparam = 2.0.0rc3

# Imports & Functions

## Imports

In [None]:
# General imports
import os
import sys
import gc
import warnings

from datetime import datetime
from pprint import pprint
import time
import pickle
import random
from collections import Counter

# Custom Functions
sys.path.append(os.path.abspath('../Notebooks/Utilities')) 
import cust_utilities as utils

# Maths, Pandas etc
import math
import numpy as np
import pandas as pd
import scipy as sci

# Plots
import matplotlib.pyplot as plt
plt.style.use('ggplot')
from matplotlib.backends.backend_pdf import PdfPages

# MNE-Python
import mne
# from mne.preprocessing import ICA
# from mne_icalabel import label_components
# from autoreject import AutoReject
# from autoreject import get_rejection_threshold

# SpecParam
from specparam import SpectralGroupModel
from specparam.plts.spectra import plot_spectra
# from specparam import __version__ as specparam_version
# print('Current SpecParam version:', specparam_version)


## EEG 

In [None]:
# Setup Montage, Channels & Regions
#

def channels_setup(eeg_raw, verbose=None):
    """
    Group the recording's channels into brain regions.

    Parameters
    ----------
    eeg_raw : mne Raw / Epochs
        Recording whose channels are grouped. Only ``ch_names`` is needed
        unless verbose diagnostics are enabled (then ``info`` is also used).
    verbose : bool | None
        Print montage/region diagnostics and plot the sensor positions.
        When None (default) the notebook-level ``VERBOSE`` global is used,
        if it exists.

    Returns
    -------
    dict
        ``{region: {'channels': [names], 'indices': [positions in ch_names]}}``.
        Each channel appears at most once per region. Regions with no
        channels present in the recording are omitted.
    """
    if verbose is None:
        verbose = bool(globals().get('VERBOSE', False))

    # Standard brain region groupings based on 10-20 system, with the 10-10 extension
    # https://www.sciencedirect.com/science/article/pii/S1388245717304832
    # NOTE(review): 'central' also holds parietal (P*) and temporal (T*/TP*) sites;
    # confirm this coarse three-region split is intended.
    BRAIN_REGIONS = {
        'frontal': ['Fp1', 'Fp2', 'Fpz', 'AF3', 'AF4', 'AF7', 'AF8', 'AFz',
                    'F1', 'F2', 'F3', 'F4', 'F5', 'F6', 'F7', 'F8', 'Fz',
                    'FC1', 'FC2', 'FC3', 'FC4', 'FC5', 'FC6', 'FCz',
                    'FT10', 'FT8', 'FT7', 'FT9'],
        'central': ['C1', 'C2', 'C5', 'C6',
                    'CP3', 'CP4', 'CPz',
                    'C3', 'Cz', 'C4', 'CP5', 'CP1', 'CP2', 'CP6',
                    'P1', 'P2', 'P3', 'P4', 'P5', 'P6', 'Pz',
                    'T7', 'T8', 'TP7', 'TP8', 'TP9', 'TP10', 'P7', 'P8'],
        'occipital': ['O1', 'Oz', 'O2', 'PO9', 'PO7', 'PO3', 'POz', 'PO4', 'PO8', 'PO10'],
    }

    used_channels = list(eeg_raw.ch_names)
    # Channel name -> position, built once instead of list.index per channel
    ch_position = {ch: pos for pos, ch in enumerate(used_channels)}

    # Create a channel group with a dict for each region and: channel names, channel indices
    channel_groups = {}
    for region, channels in BRAIN_REGIONS.items():
        # dict.fromkeys de-duplicates (keeping order) so no channel is
        # counted twice when a region's channels are later averaged
        region_channels = [ch for ch in dict.fromkeys(channels) if ch in ch_position]
        if region_channels:  # Only include if channels exist
            channel_groups[region] = {
                'channels': region_channels,
                'indices': [ch_position[ch] for ch in region_channels],
            }

    if verbose:
        # Comparison with Raw EEG and Reference Used
        std_1020_montage = mne.channels.make_standard_montage("standard_1020")
        montage_channels_set = set(std_1020_montage.ch_names)
        subject_channels_set = set(used_channels)
        regions_channels_set = {ch for chans in BRAIN_REGIONS.values() for ch in chans}

        print(f'Missing montage - subject: {montage_channels_set - subject_channels_set}')
        print(f'Missing subject - montage {subject_channels_set - montage_channels_set}')
        print(f'Missing regions - subject {regions_channels_set - subject_channels_set}')
        print(f'Missing subject - regions {subject_channels_set - regions_channels_set}')

        # Plot the positions, coloured by region
        ch_groups_list = [info['indices'] for info in channel_groups.values()]
        mne.viz.plot_sensors(eeg_raw.info, show_names=True, ch_groups=ch_groups_list, show=False, pointsize=65)
        plt.title('Channels with Brain Regions')
        plt.show()

    return channel_groups


In [None]:
# Function to plot EEG Time Series & Save to PDF
#

def plot_EEG_TD(eeg_FIF, channels, time_range_s=None, pdf_file=None, max_plots=5):
    """
    EEG Time Domain Plot.

    Plots up to ``max_plots`` channels (converted to µV) on one figure. If a
    PDF handle is given, the figure is also saved to it before being shown.

    Parameters
    ----------
    eeg_FIF : mne Raw or Epochs
        Data to plot. Epochs are averaged across epochs before plotting.
    channels : str or list of str, eg ['P5', 'Pz']
        Channel names; only the first ``max_plots`` are plotted.
    time_range_s : [start, stop] in seconds, optional
        Raw: converted to sample indices (start * sfreq, stop * sfreq).
        Epochs: keeps samples whose epoch time lies in [start, stop].
        None or empty plots the full time span.
    pdf_file : matplotlib PdfPages, optional
        If given, ``pdf_file.savefig()`` is called for the figure.
    max_plots : int
        Maximum number of channels plotted (default 5).
    """
    # A bare channel name would otherwise be iterated character by character
    if isinstance(channels, str):
        channels = [channels]
    channels = list(channels)[:max_plots]

    # Get signal and times for the selected channel(s)
    if eeg_FIF.__class__.__name__.startswith('Raw'):
        sfreq = eeg_FIF.info['sfreq']
        if time_range_s:
            start = int(time_range_s[0] * sfreq)
            stop = int(time_range_s[1] * sfreq)
        else:
            start, stop = 0, None
        signal, times = eeg_FIF.get_data(picks=channels,
                                         start=start, stop=stop,
                                         return_times=True)
    else:
        # Epoched data: (n_epochs, n_channels, n_times) -> mean over epochs
        signal = eeg_FIF.get_data(picks=channels).mean(axis=0)  # (n_channels, n_times)
        times = eeg_FIF.times  # (n_times,)
        if time_range_s:
            in_range = (times >= time_range_s[0]) & (times <= time_range_s[1])
            signal, times = signal[:, in_range], times[in_range]

    # plot and copy to pdf
    plt.figure(figsize=(12, 8))
    for i, ch_name in enumerate(channels):
        # get_data returns volts; scale to µV for display
        plt.plot(times, signal[i] * 1e6, label=ch_name, alpha=0.7, linewidth=0.6)
    plt.title('EEG Time Series - Channels: ' + ', '.join(channels))
    plt.xlabel('Time (s)')
    plt.ylabel('Amplitude (µV)')
    plt.legend()
    plt.tight_layout()
    if pdf_file:
        pdf_file.savefig()
    plt.show()

In [None]:
# Function to plot EEG Frequency Domain / PSD
#

def plot_EEG_PSD(spectra, channel, max_labels=10):
    """
    EEG Frequency Domain / PSD Plot.

    Shows three figures: all channels on log-log axes, all channels on
    log-linear axes (legend lists the first ``max_labels`` channels), and the
    selected channel on both scales via SpecParam's ``plot_spectra``.
    Epoched spectra are averaged across epochs first.

    Parameters
    ----------
    spectra : mne.Spectrum (or EpochsSpectrum)
    channel : str eg 'P5'
    max_labels : int
        Number of channels labelled in the all-channel legends (default 10).

    Raises
    ------
    ValueError
        If ``channel`` is not one of ``spectra.ch_names``.
    """
    freqs = spectra.freqs
    powers = spectra.get_data(return_freqs=False)

    # Average power across all epochs for each channel
    if powers.ndim == 3:  # shape: (n_epochs, n_channels, n_freqs)
        powers = powers.mean(axis=0)  # shape: (n_channels, n_freqs)

    # Validate the channel up front, before any figure is drawn
    ch_names = list(spectra.ch_names)
    if channel not in ch_names:
        raise ValueError(f"Channel '{channel}' not found in spectrum channels: {ch_names}")
    channel_power = powers[ch_names.index(channel)]

    # All channels, once per axis scaling
    for plot_func, scale_name in ((plt.loglog, 'Log-Log'), (plt.semilogy, 'Log-Lin')):
        plt.figure(figsize=(12, 6))
        for idx, ch in enumerate(ch_names):
            # Label only the first max_labels channels to keep the legend readable
            plot_func(freqs, powers[idx], alpha=0.5, label=ch if idx < max_labels else None)
        plt.title(f"Power Spectrum ({scale_name}) - All Channels ({spectra.method} method)")
        plt.xlabel("Frequency (Hz)")
        plt.ylabel("Power ($V^2/Hz$)")
        plt.legend(loc='upper right', fontsize='small', ncol=2)
        plt.tight_layout()
        plt.show()

    # The selected channel
    fig, (ax0, ax1) = plt.subplots(nrows=1, ncols=2, figsize=(15, 5))
    fig.suptitle(f'Power Spectrum - Channel {channel} ({spectra.method} method)', fontsize=20)
    ax0.set_title('Log/Log')
    ax1.set_title('Log/Lin')
    plot_spectra(freqs, channel_power, log_freqs=True, log_powers=True, ax=ax0, colors='blue')
    plot_spectra(freqs, channel_power, log_powers=True, ax=ax1, colors='blue')
    plt.show()


## SpecParam Functions

In [None]:
# Function to inspect quality of a SpecParam Group fit
#

def inspect_spm_group(fg, error_threshold=0.05, r2_threshold=0.1):
    """
    Summarise the fit quality of a SpecParam group fit.

    The warning flag is raised if any fit failed (null fits), if the mean
    error exceeds ``error_threshold``, if the mean R^2 is below
    ``r2_threshold``, or if any individual spectrum's error exceeds
    ``error_threshold``.

    Parameters
    ----------
    fg : SpectralGroupModel
        A fitted group model (uses ``get_params``, ``n_null_``, ``n_peaks_``
        and ``len``).
    error_threshold : float
        Maximum acceptable fit error (mean and per spectrum).
    r2_threshold : float
        Minimum acceptable mean R^2.

    Returns
    -------
    pd.Series
        SPM_fit_quality_warning, spectra_count, null_fits, error_mean,
        r2_mean, flagged_channels (list of spectrum indices, i.e. positions in
        the fitted power array), peaks_count_max, peaks_count_mean.
    """
    # One vectorised read of the per-spectrum metrics (replaces iterating the
    # group). Null fits presumably report NaN here, so the means ignore them
    # instead of becoming NaN and silently disabling the threshold checks.
    errors = np.asarray(fg.get_params('error'), dtype=float)
    r_squared = np.asarray(fg.get_params('r_squared'), dtype=float)
    n_peaks = np.asarray(fg.n_peaks_)

    error_mean = float(np.nanmean(errors)) if np.isfinite(errors).any() else float('nan')
    r_squared_mean = float(np.nanmean(r_squared)) if np.isfinite(r_squared).any() else float('nan')

    # Spectra whose individual error exceeds the threshold (NaN compares False)
    flagged_channels = np.flatnonzero(errors > error_threshold).tolist()

    fit_warning = bool(
        fg.n_null_ > 0
        or error_mean > error_threshold
        or r_squared_mean < r2_threshold
        or flagged_channels
    )

    group_summary = pd.Series({
        'SPM_fit_quality_warning': fit_warning,
        'spectra_count': len(fg),
        'null_fits': fg.n_null_,
        'error_mean': error_mean,
        'r2_mean': r_squared_mean,
        'flagged_channels': flagged_channels,
        'peaks_count_max': int(n_peaks.max()) if n_peaks.size else 0,
        'peaks_count_mean': float(n_peaks.mean()) if n_peaks.size else float('nan'),
        })

    return group_summary


In [None]:
# Function to plot a SpecParam Fit
#

def plot_spm_fit(fm, channel_name):
    """
    Print and plot one channel's SpecParam fit.

    Prints the model results, then shows a 1x3 figure:
      1. the model fit with shaded peaks,
      2. the model's 'peak' component in log space ('Flattened - Offset Removed'),
      3. the model fit with shaded peaks, drawn with ``plt_log=True``.

    Parameters
    ----------
    fm : SpecParam model (e.g. from ``SpectralGroupModel.get_model``)
    channel_name : str
        Used only in the printed header and the figure title.
    """
    print(f'SpecParam for Channel: {channel_name}')
    fm.print_results()

    frequencies = fm.freqs
    peak_component = fm.get_model(component='peak', space='log')

    fig, (ax_fit, ax_flat, ax_loglog) = plt.subplots(nrows=1, ncols=3, figsize=(18, 6))
    fig.suptitle(f'SpecParam Fit for Channel: {channel_name}', fontsize=20)

    # Titles are set before fm.plot draws into each axis
    ax_fit.set_title('Fit + Peaks')
    fm.plot(plot_peaks='shade', peak_kwargs=dict(color='green'), ax=ax_fit)

    ax_flat.set_title('Flattened - Offset Removed')
    ax_flat.plot(frequencies, peak_component)

    ax_loglog.set_title('Log/Log + Peaks')
    fm.plot(plot_peaks='shade', peak_kwargs=dict(color='green'), plt_log=True, ax=ax_loglog)

    plt.show()
    

In [None]:
# # Function to plot SpecParam Results
# #

# def plot_SpecParam(fg, channel_indx, channel_name):
#     """
#     EEG SpecParam Results Plot

#     Parameters
#     ----------
#     fg : SpecParamGroupModel
#     """
#     fm = fg.get_model(ind=channel_indx, regenerate=True)

#     print(f'SpecParam for Channel: {channel_name}')
#     fm.print_results()

#     # fm.plot()
#     fig, (ax0, ax1) = plt.subplots(nrows=1, ncols=2, figsize=(12, 8))
#     fig.suptitle(f'Spec Param for Channel: {channel_name}', fontsize=20)
#     ax0.set_title('Components - Log/Lin')
#     ax1.set_title('Components - Log/Log')
#     fm.plot(plot_peaks='shade', peak_kwargs={'color' : 'green'}, ax=ax0)
#     fm.plot(plot_peaks='shade', peak_kwargs={'color' : 'green'}, plt_log=True, ax=ax1)
#     plt.show()


In [None]:
# # Test Bands

# from specparam import Bands
# from specparam.data.periodic import get_band_peak, get_band_peak_group

# bands = Bands({'Delta': [0.5, 4],
#                'Theta': [4, 8],
#                'Alpha': [8, 13],
#                'Beta': [13, 30],
#                'Gamma': [30, 100]
#                })

# def test_get_bands(fm, channel_name):
    
#     all_bands = []
#     for band_label in bands.labels:
#         band_range = bands[band_label]
#         peak_details = (band_label, get_band_peak(fm, band_range, select_highest=False))
#         all_bands.append(peak_details)

#     print(f'For Channel: {channel_name} Bands: \n{all_bands}')


# Run: 1b EEG Feature Extraction

In [None]:
# Run Details & Parameters
#
# Sets the run parameters, resolves the preprocessed-data folder, creates a new
# results folder for this run, saves the run details there (zip-compressed
# pickle), and sets the test-mode switches used by the run cells below.

#---- Run Parameters --------------------------------
# Study Details
study_name = 'IOWA_Rest'

# Run Name, Data Source & Test Mode
# data_source names a 1a preprocessing run folder; its Cleaned_files sub-folder is read below
run_description = 'sample_test_regions'
data_source = '1a_EEG_Preprocessing_Run_20250705_sample_test2'
test_mode = True

# Preprocessing Parameters
# psd_params are passed to compute_psd() in the run cells (method / fmin / fmax / exclude)
psd_params = {'method': 'welch', #welch or multitaper (often used for epoched data)
              'fmin': 1, 'fmax': 100,       
              'exclude': []  # Includes all, even bad channels
              }
# specparam_params: the first five entries configure SpectralGroupModel, 'fit_window'
# is the frequency range passed to fit(), and the two 'fit_*_threshold' entries are
# the error / R^2 limits passed to inspect_spm_group().
# NOTE(review): 'peak_width_limits' is a mutable list; copy it before adjusting it per subject.
specparam_params = {'peak_width_limits': [1, 12],
                    'max_n_peaks': 10,
                    'min_peak_height': 0.1,
                    'peak_threshold': 2.0,
                    'aperiodic_mode': 'fixed',
                    'fit_window' : [1, 100],
                    'fit_error_threshold': 0.1,
                    'fit_r2_threshold': 0.9
                     }
#----------------------------------------------------

# Get existing study details, if exists
study_folder_path = utils.get_folder_path('Study_' + study_name)
study_info = pd.read_pickle(study_folder_path + '/study_inf.pkl', compression='zip')
study_subjects_df = pd.read_pickle(study_folder_path + '/study_subjects_df.pkl', compression='zip')

# EEG Preprocessed data
eeg_preprocessed_data_path = utils.get_folder_path(study_info['eeg_processing_results_path'] + '/' + data_source + '/Cleaned_files' )

# Setup the processing run and results folder & save params
# The folder name embeds today's date; exists_ok=False presumably makes a same-day
# re-run with the same run_description fail rather than overwrite — confirm in cust_utilities.
current_date = datetime.now().strftime('%Y%m%d')
run_name = f'1b_EEG_Features_Results_Run_{current_date}_{run_description}'
run_results_path = utils.extend_folder_path(study_info['eeg_processing_results_path'], run_name, exists_ok=False)

run_details = pd.Series({
    'study_name': study_name,
    'run_name': run_name,
    'eeg_preprocessed_data': data_source,
    'psd_params': psd_params,
    'specparam_params': specparam_params,
})
run_details.to_pickle(run_results_path + '/run_details.pkl', compression='zip')

# # Create empty study results
# eeg_results_superset_df = pd.DataFrame()

# Set progress messages, testing
# test_subjects are index labels of study_subjects_df (compared with the iterrows() index);
# test_channels are channel names used for the per-channel diagnostic plots.
if test_mode:
    VERBOSE = True
    test_subjects = [0,5,101]
    test_channels = ['F4', 'C5', 'Cz', 'P6']
else:
    VERBOSE = False
    test_subjects = []
    test_channels = []
    
# Drop helper names that are no longer needed by later cells
del current_date, study_name, run_description, study_folder_path


In [None]:
# Run all the steps for EEG feature extraction
#
# Work-in-progress loop: for each selected subject, loads the preprocessed
# epochs (.fif) and preprocessing results, and groups the channels by brain
# region. The PSD / SpecParam steps (2.) are not implemented in this cell yet.
# After the loop, eeg_preprocessed_epoch and channel_groups hold the last
# processed subject's data (the test_combine cell below uses them).

# Start Trace
summary = f'EEG Feature Extraction'
summary = summary + f"\n- Study: {study_info['study_name']} {study_info['dataset_ref']}"
summary = summary + f"\n- Run: {run_details['run_name']}"
summary = summary + f"\n- EEG Preprocessed Data Source: {run_details['eeg_preprocessed_data']}"
summary = summary + f"\n- PSD Params: {psd_params}"
summary = summary + f"\n- SpecParam Params: {specparam_params}"
print(summary)

# Loop through all selected subjects in the study
for idx, subject in study_subjects_df.iterrows():

    # Just sample a subset of subjects when in test mode
    # (idx is the study_subjects_df index label, matched against test_subjects)
    if test_mode and idx not in test_subjects:
        continue

    subject_id = subject['subject_id']
    print('\n-----------------------------------------------------------------------------------------------')
    print(f'Subject: {subject_id}')

    # 1. EEG Preprocessed Data Load
    #
    print('---Get EEG Preprocessed Data - FIF ----------------------------------')
    source_file_path = utils.get_file_path(eeg_preprocessed_data_path, f"{subject['subject_id']}_preprocessed_epo.fif")
    eeg_preprocessed_epoch = mne.read_epochs(source_file_path, preload=True, verbose=VERBOSE)
    preprocessing_results = pd.read_pickle(eeg_preprocessed_data_path + f'/{subject_id}_preprocessing_results.pkl', compression='zip')

    # Get Channels grouped by region, with name and index
    channel_groups = channels_setup(eeg_preprocessed_epoch)

    if VERBOSE:
        print(f"Description: {eeg_preprocessed_epoch.info['description']} on {eeg_preprocessed_epoch.info['meas_date']}")
        print(eeg_preprocessed_epoch)
        print(eeg_preprocessed_epoch.info)

        # plot_EEG_TD(eeg_preprocessed_epoch, test_channels) .. just the first epoch

    # 2. PSD & SpecParam
    #

    # EEG Specparam DF by subject
    # Group Fit Results? ... Summary for Channel, Regions group. Keep results or just report?

    # Get SpecParam for Regions
    # Get DF

    # Get SpecParam for All Channels
    # Get DF
    # Append to overall DF
    


## Regional SpecParam Comparison Plot (Draft)

In [None]:
def create_specparam_summary_plot(regional_psds, freqs, specparam_results):
    """
    Create a summary figure of per-region PSDs and aperiodic parameters.

    Top row: each region's PSD on log-log axes. Bottom row: a bar chart of
    that region's aperiodic offset and exponent.

    Parameters
    ----------
    regional_psds : dict[str, array-like]
        Region name -> power spectrum (same length as ``freqs``).
    freqs : array-like
        Frequencies (Hz) of the spectra.
    specparam_results : dict[str, dict]
        Region name -> dict with an ``'aperiodic_params'`` sequence. The first
        element is used as the offset and the LAST as the exponent. This
        assumes SpecParam's layout ([offset, exponent] for 'fixed',
        [offset, knee, exponent] for 'knee'); index 1 would be the knee in
        'knee' mode.

    Returns
    -------
    matplotlib.figure.Figure

    Raises
    ------
    ValueError
        If ``regional_psds`` is empty.
    """
    n_regions = len(regional_psds)
    if n_regions == 0:
        raise ValueError('regional_psds is empty: nothing to plot.')

    # squeeze=False keeps axes 2-D, shape (2, n_regions), even for one region
    fig, axes = plt.subplots(2, n_regions, figsize=(4 * n_regions, 8), squeeze=False)

    for col, (region, psd) in enumerate(regional_psds.items()):
        ax_psd, ax_aperiodic = axes[0, col], axes[1, col]

        # Original PSD
        ax_psd.loglog(freqs, psd, 'k-', alpha=0.7, label='Original PSD')
        ax_psd.set_title(f'{region.capitalize()} Region')
        ax_psd.set_xlabel('Frequency (Hz)')
        ax_psd.set_ylabel('Power')
        ax_psd.legend()

        # Aperiodic parameters: offset first, exponent last (see docstring)
        aperiodic = specparam_results[region]['aperiodic_params']
        offset, exponent = aperiodic[0], aperiodic[-1]

        ax_aperiodic.bar(['Offset', 'Exponent'], [offset, exponent])
        ax_aperiodic.set_title(f'{region.capitalize()} Aperiodic Params')
        ax_aperiodic.set_ylabel('Value')

    plt.tight_layout()
    return fig

In [None]:
# NOTE: test_combine() is defined in the NEXT cell — run that cell before this one.
# eeg_preprocessed_epoch is the last subject's epochs from the regions run loop above;
# test_combine also reads the notebook globals psd_params, specparam_params,
# channel_groups and VERBOSE. Returns the region-level SpecParam results DataFrame.
aa_test = test_combine(eeg_preprocessed_epoch)

In [None]:
def test_combine(eeg_sigs, channel_groups=None):
    """
    Exploratory check: SpecParam fit on region-averaged EEG signals.

    Computes, prints and plots the PSD of the original channels, then
    averages each brain region's channels with
    ``mne.channels.combine_channels``. It computes the PSD of the resulting
    one-channel-per-region signals and fits a SpectralGroupModel to the
    (epoch-averaged) power.

    Parameters
    ----------
    eeg_sigs : mne Epochs (or Raw)
    channel_groups : dict, optional
        Output of ``channels_setup``. Defaults to the notebook-level
        ``channel_groups`` global (previous behaviour).

    Returns
    -------
    pd.DataFrame
        ``SpectralGroupModel.to_df`` results, one row per region.

    Raises
    ------
    ValueError
        If no (non-empty) channel groups are available.

    Reads the module-level psd_params, specparam_params and VERBOSE; the shared
    specparam_params dict is NOT modified.
    """
    from mne.channels import combine_channels

    if channel_groups is None:
        channel_groups = globals().get('channel_groups')
    if not channel_groups:
        raise ValueError('No channel_groups available: pass one or run channels_setup() first.')

    print(eeg_sigs)
    print(eeg_sigs.info)

    # PSD
    # NB: PSD defaults to Welch for continuous, Multitaper for epoched data, here will typically force to Welch
    psd_kwargs = dict(method=psd_params['method'],
                      fmin=psd_params['fmin'],
                      fmax=psd_params['fmax'],
                      exclude=psd_params['exclude'])

    psds = eeg_sigs.compute_psd(**psd_kwargs)
    print(psds)
    print(psds.info)
    psds.plot()
    plt.show()

    # Average the channels of each region into one virtual channel per region
    region_channel_indices = {region: info['indices'] for region, info in channel_groups.items()}
    combined_eeg_sigs = combine_channels(eeg_sigs, region_channel_indices, method='mean')

    psds_combined = combined_eeg_sigs.compute_psd(**psd_kwargs)
    print(psds_combined)
    print(psds_combined.info)
    print(psds_combined.ch_names)
    psds_combined.plot()
    plt.show()

    # Get powers and frequency
    freqs = psds_combined.freqs
    powers = psds_combined.get_data(return_freqs=False)

    # SpecParam advises a lower peak-width limit of at least twice the frequency
    # resolution. Use a local copy so the shared run parameters are not mutated.
    lower, upper = specparam_params['peak_width_limits']
    freq_res = float(freqs[1] - freqs[0]) if len(freqs) > 1 else 0.0
    min_lower = 2 * freq_res
    peak_width_limits = [max(lower, min_lower), upper]
    print(f'Peak width limits - given: [{lower}, {upper}], '
          f'min lower (2 x {freq_res:.3f} Hz resolution): {min_lower:.3f}, using: {peak_width_limits}')

    # if Epoched then average powers
    # TODO: Does averaging lose information?
    if powers.ndim == 3:  # shape: (n_epochs, n_channels, n_freqs)
        powers = np.mean(powers, axis=0)  # shape: (n_channels, n_freqs)

    # Initialise SpecParam Group Model & fit it
    progress_flag = 'tqdm.notebook' if VERBOSE else None
    fg = SpectralGroupModel(peak_width_limits=peak_width_limits,
                            max_n_peaks=specparam_params['max_n_peaks'],
                            min_peak_height=specparam_params['min_peak_height'],
                            peak_threshold=specparam_params['peak_threshold'],
                            aperiodic_mode=specparam_params['aperiodic_mode'])
    fg.fit(freqs, powers, specparam_params['fit_window'], progress=progress_flag, n_jobs=-1)
    specparam_df = fg.to_df(specparam_params['max_n_peaks'])

    # Overall group results
    fg.print_results()
    fg.plot()
    plt.show()

    return specparam_df


# Run: Full Feature Extraction Loop (Draft)

In [None]:
# Run all the steps for EEG feature extraction
#
# For each selected subject: load the preprocessed epochs, compute the PSD,
# fit SpecParam to every channel, check the fit quality, and collate one
# results row per subject into eeg_results_superset_df (saved to the run folder).

# Start Trace
summary = 'EEG Feature Extraction'
summary += f"\n- Study: {study_info['study_name']} {study_info['dataset_ref']}"
summary += f"\n- Run: {run_details['run_name']}"
# run_details stores the data source under 'eeg_preprocessed_data'
summary += f"\n- Data Source: {run_details['eeg_preprocessed_data']}"
summary += f"\n- PSD Params: {psd_params}"
summary += f"\n- SpecParam Params: {specparam_params}"
print(summary)

# One pd.Series of results per subject; combined into a DataFrame once, after the loop
subject_results_rows = []

# Loop through all selected subjects in the study
for idx, subject in study_subjects_df.iterrows():

    # Just sample a subset of subjects when in test mode
    if test_mode and idx not in test_subjects:
        continue

    subject_id = subject['subject_id']
    print('\n-----------------------------------------------------------------------------------------------')
    print(f'Subject: {subject_id}')

    # 1. EEG Preprocessed Data Load
    #
    print('---Get EEG Preprocessed Data - FIF ----------------------------------')
    source_file_path = utils.get_file_path(eeg_preprocessed_data_path, f"{subject_id}_preprocessed_epo.fif")
    EEG_preprocessed = mne.read_epochs(source_file_path, preload=True, verbose=VERBOSE)
    preprocessing_results = pd.read_pickle(eeg_preprocessed_data_path + f'/{subject_id}_preprocessing_results.pkl', compression='zip')

    if VERBOSE:
        print(f"Description: {EEG_preprocessed.info['description']} on {EEG_preprocessed.info['meas_date']}")
        print(EEG_preprocessed)
        print(EEG_preprocessed.info)
        # plot_EEG_TD(EEG_preprocessed, test_channels)

    # 2. Power Spectra
    #
    print('---Power Spectra----------------------------------------------------')
    # NB: PSD defaults to Welch for continuous, Multitaper for epoched data, here will typically force to Welch
    psd = EEG_preprocessed.compute_psd(method=psd_params['method'],
                                       fmin=psd_params['fmin'],
                                       fmax=psd_params['fmax'], exclude=psd_params['exclude'])

    if VERBOSE:
        print(psd)
        print(psd.info)
        if test_channels:
            plot_EEG_PSD(psd, test_channels[0])

    # 3. Spectral Parameterisation
    #
    print('---EEG Spectral Parameterisation------------------------------------')
    # Get powers and frequency
    freqs = psd.freqs
    powers = psd.get_data(return_freqs=False)

    # SpecParam advises a lower peak-width limit of at least twice the frequency
    # resolution. Build a per-subject copy so the shared run parameters are never
    # mutated (otherwise the limit would ratchet from subject to subject).
    lower, upper = specparam_params['peak_width_limits']
    freq_res = float(freqs[1] - freqs[0]) if len(freqs) > 1 else 0.0
    min_lower = 2 * freq_res
    peak_width_limits = [max(lower, min_lower), upper]
    print(f'Peak width limits - given: [{lower}, {upper}], '
          f'min lower (2 x {freq_res:.3f} Hz resolution): {min_lower:.3f}, using: {peak_width_limits}')

    # if Epoched then average powers
    # TODO: Does averaging lose information?
    if powers.ndim == 3:  # shape: (n_epochs, n_channels, n_freqs)
        powers = np.mean(powers, axis=0)  # shape: (n_channels, n_freqs)

    # Initialise SpecParam Group Model & fit it
    progress_flag = 'tqdm.notebook' if VERBOSE else None
    fg = SpectralGroupModel(peak_width_limits=peak_width_limits,
                            max_n_peaks=specparam_params['max_n_peaks'],
                            min_peak_height=specparam_params['min_peak_height'],
                            peak_threshold=specparam_params['peak_threshold'],
                            aperiodic_mode=specparam_params['aperiodic_mode'])
    fg.fit(freqs, powers, specparam_params['fit_window'], progress=progress_flag, n_jobs=-1)
    specparam_df = fg.to_df(specparam_params['max_n_peaks'])

    # Inspect the group fit
    group_summary = inspect_spm_group(fg,
                                      error_threshold=specparam_params['fit_error_threshold'],
                                      r2_threshold=specparam_params['fit_r2_threshold'])
    if group_summary['SPM_fit_quality_warning']:
        warnings.warn(f'SpecParam Fit Issues for subject {subject_id}', UserWarning)

    if VERBOSE:
        # Overall group results
        print(f'Group Summary: \n{group_summary}')
        fg.print_results()
        fg.plot()
        plt.show()

        # Plot flagged channels (indices are positions in psd.ch_names)
        for channel_idx in group_summary['flagged_channels']:
            channel_name = psd.ch_names[channel_idx]
            fm = fg.get_model(ind=channel_idx, regenerate=True)
            plot_spm_fit(fm, channel_name)

        # Plot test channels that exist in this recording
        for channel_name in test_channels:
            if channel_name not in psd.ch_names:
                warnings.warn(f'Test channel {channel_name} not present for subject {subject_id}', UserWarning)
                continue
            channel_idx = psd.ch_names.index(channel_name)
            fm = fg.get_model(ind=channel_idx, regenerate=True)
            plot_spm_fit(fm, channel_name)

    # 4. Subject Results
    #
    print(f'---Collating Results for {subject_id} ------------------------------')

    # Flatten all the channels: one 'chn_<row>_<column>' entry per SpecParam value
    channel_series = [
        pd.Series(row.values, index=[f'chn_{ch_idx}_{col}' for col in row.index])
        for ch_idx, row in specparam_df.iterrows()
    ]
    subject_results = pd.concat([subject, preprocessing_results, group_summary, *channel_series])
    subject_results_rows.append(subject_results)

# Combine once (avoids re-copying the growing frame per subject)
if subject_results_rows:
    eeg_results_superset_df = pd.concat([row.to_frame().T for row in subject_results_rows],
                                        ignore_index=True)
else:
    eeg_results_superset_df = pd.DataFrame()

# Save all results, need to convert datatypes from generic dtype object
eeg_results_superset_df = eeg_results_superset_df.convert_dtypes()
eeg_results_superset_df.to_pickle(run_results_path + '/eeg_results_superset_df.pkl', compression='zip')


# Band Peak Extraction (Draft)

In [None]:
from specparam import Bands
from specparam.data.periodic import get_band_peak, get_band_peak_group

# Canonical EEG frequency bands (Hz)
# NOTE(review): Delta starts at 0.5 Hz but the fit window starts at 1 Hz — confirm intended.
bands = Bands({
    'Delta': [0.5, 4],
    'Theta': [4, 8],
    'Alpha': [8, 13],
    'Beta': [13, 30],
    'Gamma': [30, 100],
})

# Band peaks for every channel of `fg`, i.e. the group model of the LAST subject
# fitted by the run loop above. Presumably one peak row per channel — confirm
# against specparam's get_band_peak_group docs.
deltas, thetas, alphas, betas, gammas = (
    get_band_peak_group(fg, bands.Delta),
    get_band_peak_group(fg, bands.Theta),
    get_band_peak_group(fg, bands.Alpha),
    get_band_peak_group(fg, bands.Beta),
    get_band_peak_group(fg, bands.Gamma),
)



In [None]:
# # Plot test channels - Bands

# from specparam import Bands
# from specparam.data.periodic import get_band_peak, get_band_peak_group

# bands = Bands({'Delta': [0.5, 4],
#                'Theta': [4, 8],
#                'Alpha': [8, 13],
#                'Beta': [13, 30],
#                'Gamma': [30, 100]
#                })

# for channel_name in ['F4']:
#     channel_idx = psd.ch_names.index(channel_name)
#     fm = fg.get_model(ind=channel_idx, regenerate=True)

#     all_bands = []
#     for band_label in bands.labels:
#         band_range = bands[band_label]
#         peak_details = (band_label, get_band_peak(fm, band_range, select_highest=False))
#         print(peak_details)
        
#         # Ensure peak_details[1] is always a 2D array (n_peaks, 3)
#         # peak_arr = np.atleast_2d(peak_details[1])
#         # peak_details = (peak_details[0], peak_arr)
#         # all_bands.append(peak_details)
#     # plot_spm_fit_withbands(fm, channel_name, all_bands)
#     # print(all_bands)

#     # alpha_peak_max = get_band_peak(fm, bands['Alpha'], select_highest=True)

#     # # Get the label for the alpha band
#     # alpha_label = [label for label in bands.labels if bands[label] == bands['Alpha']][0]
#     # print(f"Band label for bands['Alpha']: {alpha_label}")

#     # plot_spm_fit_withbands(fm, channel_name, alpha_peak_max[0] )
#     # print(alpha_peak_max)




In [None]:
def check_nans(data, nan_policy='zero'):
    """
    Replace NaN values in an array, in place, according to a policy.

    Parameters
    ----------
    data : np.ndarray
        Array to clean; it is modified in place and also returned.
    nan_policy : {'zero', 'mean'}
        'zero' fills NaNs with 0; 'mean' fills them with the mean of the
        non-NaN values.

    Returns
    -------
    np.ndarray
        The same array object, with NaNs replaced.

    Raises
    ------
    ValueError
        If ``nan_policy`` is not recognised.
    """
    is_nan = np.isnan(data)

    if nan_policy == 'zero':
        fill_value = 0
    elif nan_policy == 'mean':
        # Computed before filling, so only the original non-NaN values count
        fill_value = np.nanmean(data)
    else:
        raise ValueError('Nan policy not understood.')

    data[is_nan] = fill_value
    return data

In [None]:
# Function to plot a SpecParam Fit - With Bands
#

def plot_spm_fit_withbands(fm, channel_name, all_bands):
    """
    Print and plot one channel's SpecParam fit, marking band peaks.

    Shows the same three panels as ``plot_spm_fit``. On the first
    ('Fit + Peaks') panel it also draws a dashed vertical line, one colour
    per band, at each band peak's centre frequency.

    Parameters
    ----------
    fm : SpecParam model (e.g. from ``SpectralGroupModel.get_model``)
    channel_name : str
        Used in the printed header and the figure title.
    all_bands : list of (band_label, band_peaks) tuples
        ``band_peaks`` as from ``get_band_peak(fm, band_range, select_highest=False)``:
        a single [CF, PW, BW] row or a 2-D array with one row per peak (None
        is skipped). Rows with a NaN centre frequency, presumably meaning "no
        peak in this band", are skipped.
    """
    print(f'SpecParam for Channel: {channel_name}')
    fm.print_results()

    freqs = fm.freqs
    flattened_data = fm.get_model(component='peak', space='log')

    # Three plots
    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(18, 6))
    fig.suptitle(f'SpecParam Fit for Channel: {channel_name}', fontsize=20)
    axes[0].set_title('Fit + Peaks')
    fm.plot(plot_peaks='shade', peak_kwargs={'color' : 'green'}, ax=axes[0])

    # Band peak markers go on panel 0 only: it is drawn without plt_log, so a
    # centre frequency in Hz can be used directly as the x position.
    band_cmap = plt.get_cmap('tab10')
    for band_idx, (band_label, band_peaks) in enumerate(all_bands):
        if band_peaks is None:
            continue
        # Normalise a single [CF, PW, BW] row to shape (1, 3)
        peak_rows = np.atleast_2d(np.asarray(band_peaks, dtype=float))
        if peak_rows.size == 0:
            continue
        centre_freqs = peak_rows[:, 0]
        centre_freqs = centre_freqs[~np.isnan(centre_freqs)]
        colour = band_cmap(band_idx % band_cmap.N)
        for peak_idx, cf in enumerate(centre_freqs):
            # One legend entry per band, however many peaks it has
            axes[0].axvline(x=cf, color=colour, linestyle='--', linewidth=2,
                            label=band_label if peak_idx == 0 else '_nolegend_')
    axes[0].legend()

    axes[1].set_title('Flattened - Offset Removed')
    axes[1].plot(freqs, flattened_data)
    axes[2].set_title('Log/Log + Peaks')
    fm.plot(plot_peaks='shade', peak_kwargs={'color' : 'green'}, plt_log=True, ax=axes[2])
    plt.show()
    