In [None]:
import os
import pickle
import scipy as sci
import numpy as np
import scipy.stats as stats
from statsmodels.graphics.gofplots import qqplot
import matplotlib.pyplot as plt
from scipy.signal import periodogram
import mne
import glob
from autoreject import AutoReject
from mne.preprocessing import ICA
import mne_icalabel
from mne_icalabel import label_components
import pandas as pd
from autoreject import get_rejection_threshold

# Render and save every figure at 600 dpi.
plt.rcParams.update({'figure.dpi': 600, 'savefig.dpi': 600})


# # Check if CUDA is available
# if torch.cuda.is_available():
#     device = torch.device("cuda")
# else:
#     device = torch.device("cpu")
# print(device)
#
# # Cuda
# mne.set_config('MNE_CUDA_DEVICE', '0')
# mne.utils.set_config('MNE_USE_CUDA', 'true')

# AutoReject instance with Bayesian-optimised thresholds.
# NOTE(review): `ar` is assigned again further down in this file, so that later
# configuration is the one in effect when the pipeline runs.
ar = AutoReject(
    cv=3,
    n_interpolate=[1, 4, 8, 16],
    consensus=[0.5, 1],
    random_state=313,
    n_jobs=-1,
    thresh_method='bayesian_optimization',
    verbose=False,
)

data_path = '/home/hno/datasets/update_for_hamzeh/Python/IOWA_DATA/IOWA_Processing/Public_Datasets/Raw_data/'

# Recursively collect every BrainVision header below the raw-data folder.
# `data_path` already ends with '/', so the pattern keeps the same double slash
# as before and the returned path strings are unchanged.
files = glob.glob(f'{data_path}/**/*.vhdr', recursive=True)

save_root = '/home/hno/datasets/update_for_hamzeh/Python/IOWA_DATA/IOWA_Processing/results/'

# Sorted header paths, split into groups by a substring of the path.
sbj_files = sorted(file for file in files if '.vhdr' in file)
parkinson = [path for path in sbj_files if 'PD' in path]
healthy = [path for path in sbj_files if 'Con' in path]

# To match channels across all datasets, the San Diego dataset's channels were
# taken as the reference. Some datasets do not contain all of those channels,
# so the missing ones were excluded as well. The list below is the set of
# channels that appeared to be common to all existing datasets.
include_channels = [
    'Fz', 'Fp1', 'AF3', 'F7', 'F3', 'F4', 'AF4', 'Fp2',
    'F8', 'FC5', 'FC1', 'FC2', 'FC6', 'C3', 'Cz', 'C4',
    'CP1', 'CP2', 'CP5', 'CP6', 'PO4', 'PO3', 'P3', 'Pz',
    'P4', 'P7', 'P8', 'O1', 'O2', 'Oz', 'T7', 'T8',
]
# fmin = int(input('Enter Min Frequency: '))
# fmax = int(input('Enter Max Frequency: '))
# condition = ['PD vs HC', 'Parkinson', 'Control']
# f_range = [fmin, fmax]
# ######### Getting bad components ######
groups = ['pd', 'hc']

# Per-group bookkeeping containers, each keyed by group label.
ica_rejected, failed_subj, bad_subj = {}, {}, {}
for group in groups:
    ica_rejected[group], failed_subj[group], bad_subj[group] = {}, {}, {}
######################################################
pd_result = {}
hc_result = {}

In [None]:
# Define paths and channel information
# data_path = '/path/to/your/data'
# include_channels = ['Fp1', 'Fp2', 'F7', 'F3', 'Fz', 'F4', 'F8', 'T7', 'C3', 'Cz', 'C4', 'T8', 'P7', 'P3', 'Pz', 'P4', 'P8', 'O1', 'O2'] # Example channels

# Frequency bands for analysis: every [low, high] pair with a lower edge of
# 1 Hz or 3 Hz and an upper edge from 30 Hz to 90 Hz in 10 Hz steps
# (all 1 Hz bands first, then all 3 Hz bands).
freqs = [[low, high] for low in (1, 3) for high in range(30, 91, 10)]

# Initialize AutoReject. This rebinds the module-level `ar` created earlier in
# the file, so these are the settings process_group_data() actually uses.
ar = AutoReject(
    n_interpolate=[1, 2, 4],
    random_state=42,
    n_jobs=-1,
    verbose=False,
)



# NOTE: the two name parsers below are tied to this dataset's file naming; adapt or drop them for your own data.
def parse_parkinson_name(path):
    """Build the subject label for a Parkinson-group recording path.

    Keeps the text after the last 'Raw_data', then after the last 'Rest/',
    and cuts it at the first '.vhdr'. A label of six characters or fewer is
    returned unchanged; a longer one is reduced to whatever follows its last
    'Rest_' and prefixed with 'px'.
    """
    below_root = path.rpartition('Raw_data')[2]
    stem = below_root.rpartition('Rest/')[2].partition('.vhdr')[0]
    if len(stem) <= 6:
        return stem
    suffix = stem.rpartition('Rest_')[2]
    return f"px{suffix}"

def parse_healthy_name(path):
    """Build the subject label for a Healthy-group recording path.

    Returns the text that follows the last '_Rest/' in ``path``, cut at the
    first '.vhdr'. If either marker is absent, that step leaves the text
    unchanged.
    """
    file_part = path.rpartition('_Rest/')[2]
    return file_part.partition('.vhdr')[0]




def process_group_data(subjects, group_key, name_parser_func, f_range, save_path):
    """
    Processes a group of subjects through the EEG pipeline.

    For each recording the function loads the BrainVision file, band-pass
    filters it (1-100 Hz), keeps only the channels listed in the module-level
    ``include_channels``, detrends, re-references to the average, fits an
    extended-infomax ICA, labels the components with ICLabel, removes every
    component not labelled "brain", cuts 1 s epochs, cleans them with a
    global rejection threshold followed by the module-level AutoReject
    instance ``ar``, and saves the result as ``<name>_epo.fif`` in
    ``save_path``.

    Args:
        subjects (list): List of file paths for the subjects in the group.
        group_key (str): A short key for the group (e.g., 'pd' or 'hc').
            Currently not used inside the function.
        name_parser_func (function): The function to use for parsing subject names.
        f_range (list): The [fmin, fmax] frequency range for the current analysis.
            Only used in the progress message while the FOOOF step is disabled.
        save_path (str): The directory path to save results. It must already exist.

    Returns:
        tuple: ``(results, failed_subjects, bad_subjects, rejected_ics)`` where
            ``results`` is a dict (left empty while the FOOOF step is disabled),
            ``failed_subjects`` is a list of paths whose processing raised,
            ``bad_subjects`` is a list of paths skipped because at least 80 %
            of their components were not labelled "brain", and
            ``rejected_ics`` maps subject name to the number of non-brain
            components found.
    """
    results = {}
    failed_subjects = []
    bad_subjects = []
    rejected_ics = {}

    for subj_path in subjects:
        # A single broken recording must not abort the whole batch: any error
        # is reported, the path is recorded, and the loop moves on.
        try:
            name = name_parser_func(subj_path)
            print(f"Processing Subject: {name} for frequency range {f_range} Hz...")

            # 1. Load and Filter Data
            raw = mne.io.read_raw_brainvision(subj_path, preload=True, verbose=False)
            filtered = raw.copy().pick(picks="eeg").filter(
                l_freq=1, h_freq=100, method='iir', phase='zero', verbose=False
            )

            # 2. Set Montage and Pre-process
            # Drop every channel that is not part of the common channel set.
            exclude_channels = list(set(filtered.info['ch_names']) - set(include_channels))
            filtered.drop_channels(exclude_channels, on_missing='raise', verbose=False)

            # Re-apply the montage already stored on the filtered recording
            # (the same source is used for every group).
            # NOTE(review): no external montage is loaded here - confirm the
            # recordings carry channel positions.
            montage = filtered.get_montage()
            filtered.set_montage(montage, match_case=False, verbose=False)

            filtered.apply_function(sci.signal.detrend, n_jobs=-1, channel_wise=True, type='linear', verbose=False)
            filtered.set_eeg_reference(ref_channels='average', verbose=False)

            # 3. ICA for Artifact Removal
            # One component fewer than the number of channels.
            ica = ICA(n_components=len(filtered.ch_names) - 1, max_iter="auto", method="infomax", random_state=313, fit_params=dict(extended=True))
            ica.fit(filtered.copy(), verbose=False)

            # 4. Label and Exclude Bad Components
            ica_labels = label_components(filtered, ica, method='iclabel')
            labels = ica_labels["labels"]
            # Indices come from enumerate(), so they are already unique.
            exclude_idx = [i for i, label in enumerate(labels) if label != "brain"]
            rejected_ics[name] = len(exclude_idx)

            # Reject subject if too many components are bad
            if len(exclude_idx) >= len(labels) * 0.8:
                print(f'Too many components rejected for Subj {name}. Skipping.')
                bad_subjects.append(subj_path)
                continue

            raw_ica = ica.apply(filtered.copy(), exclude=exclude_idx, verbose=False)

            # 5. Epoching and Auto-Rejection
            epochs = mne.make_fixed_length_epochs(raw_ica, duration=1, overlap=0, preload=True, verbose=False)
            epochs.apply_function(sci.signal.detrend, type='linear', verbose=False)

            # First a global peak-to-peak threshold, then AutoReject.
            reject_criteria = get_rejection_threshold(epochs, verbose=False)
            epochs.drop_bad(reject=reject_criteria, verbose=False)

            epochs_ar, _ = ar.fit_transform(epochs, return_log=True)
            epochs_ar.pick(['eeg'])

            print(f'Number of rejected ICs for {name}: {rejected_ics[name]} out of {len(labels)}')
            epochs_ar.save(os.path.join(save_path, f'{name}_epo.fif'), overwrite=True, verbose=False)

            # 6. Calculate FOOOF.
            # Disabled for now; `results` therefore stays empty.
            # results[name] = cal_foof(
            #     epochs_ar, f_range=f_range, low_freq_to_del=False,
            #     high_freq_to_del=False, remove_line_noise=False
            # )

        except Exception as e:
            print(f"!!! Failed to process subject {subj_path}. Error: {e}")
            failed_subjects.append(subj_path)

    return results, failed_subjects, bad_subjects, rejected_ics


# --- Main Execution Loop ---

# Per-group settings: which recordings to process, the short key used in the
# summary dictionaries, the name parser, and the output pickle file names.
# The table does not depend on the frequency range, so it is built once.
group_configs = {
    'parkinson': {
        'subjects': parkinson,
        'key': 'pd',
        'name_parser': parse_parkinson_name,
        'rejected_ics_fname': 'PD_rejected_ICs.pkl',
        'results_fname': 'parkinson.pkl'
    },
    'healthy': {
        'subjects': healthy,
        'key': 'hc',
        'name_parser': parse_healthy_name,
        'rejected_ics_fname': 'HC_rejected_ICs.pkl',
        'results_fname': 'healthy.pkl'
    }
}

# NOTE(review): while the FOOOF step in process_group_data() is disabled, the
# frequency range only affects the output folder and the progress messages, so
# every pass repeats the same preprocessing.
for fmin, fmax in freqs:
    f_range = [fmin, fmax]
    save_fooof = f'/space/slow/hno/update_for_hamzeh/Python/IOWA_DATA/IOWA_Processing/Range_{fmin}-{fmax}_Hz'
    os.makedirs(save_fooof, exist_ok=True)
    print(f'\n--- Starting Analysis for Frequency Range {f_range} Hz ---')

    # Failed / skipped recordings for this range, keyed by group key.
    all_failed = {}
    all_bad = {}

    for group_name, config in group_configs.items():
        print(f"\n... Processing {group_name.capitalize()} Group ...")

        group_results, failed, bad, ics = process_group_data(
            subjects=config['subjects'],
            group_key=config['key'],
            name_parser_func=config['name_parser'],
            f_range=f_range,
            save_path=save_fooof
        )

        all_failed[config['key']] = failed
        all_bad[config['key']] = bad

        # Persist this group's outputs: rejected-IC counts first, then results.
        for fname, payload in (
            (config['rejected_ics_fname'], ics),
            (config['results_fname'], group_results),
        ):
            with open(os.path.join(save_fooof, fname), 'wb') as f:
                pickle.dump(payload, f)

    # Persist the combined failed / bad subject lists for this range.
    for fname, payload in (('failed_subj.pkl', all_failed), ('bad_subj.pkl', all_bad)):
        with open(os.path.join(save_fooof, fname), 'wb') as f:
            pickle.dump(payload, f)

print("\n--- All Processing Complete ---")
