In [34]:
import mne
import numpy as np

In [None]:
import pandas as pd
participants = pd.read_table('../Dataset/participants.tsv')  # tab-separated participant metadata
participants.head()  # preview the first five rows

In [None]:
participants.keys()  # column names of participants.tsv

In [None]:
import os

# Participant IDs as a plain list (assumed to look like sub-pd01, sub-hc01, ...).
subjects = participants["participant_id"].tolist()
# Medication sessions recorded for the Parkinson's patients.
sessions = ['ses-off', 'ses-on']

print(subjects)

In [None]:
import os

# NOTE(review): this cell recomputes exactly what the previous cell computes
# (attribute access instead of item access). Re-running it is harmless, and
# the cell could be removed.
subjects = participants.participant_id.tolist()  # assumed IDs like sub-pd01, sub-hc01
sessions = ['ses-off', 'ses-on']  # PD medication sessions

print(subjects)

In [None]:
def _rest_bdf_path(subject, session):
    # Path of the resting-state BDF recording for one subject/session.
    return f"../Dataset/{subject}/{session}/eeg/{subject}_{session}_task-rest_eeg.bdf"


# PD subjects get one file per entry in `sessions` ('ses-on' -> ON list, any other
# session -> OFF list). Controls have a single 'ses-hc' recording. IDs matching
# neither 'pd' nor 'hc' are skipped.
subject_files_pd_on = []
subject_files_pd_off = []
subject_files_hc = []

for subject in subjects:
    if 'pd' in subject:
        for session in sessions:
            destination = subject_files_pd_on if session == 'ses-on' else subject_files_pd_off
            destination.append(_rest_bdf_path(subject, session))
    elif 'hc' in subject:
        subject_files_hc.append(_rest_bdf_path(subject, 'ses-hc'))

print(subject_files_pd_on)
print(subject_files_pd_off)
print(subject_files_hc)

In [40]:
def set_montage(raw_data):
    """Attach the standard BioSemi 32-channel electrode layout to ``raw_data``.

    Channels missing from the template only trigger a warning
    (``on_missing='warn'``). The montage is applied to ``raw_data`` itself, and
    that same object is returned.
    """
    biosemi_layout = mne.channels.make_standard_montage('biosemi32')
    raw_data.set_montage(biosemi_layout, on_missing='warn')
    return raw_data

In [41]:
def bandpass_filter(raw_data, l_freq=0.5, h_freq=50.0):
    """Band-pass ``raw_data`` between ``l_freq`` and ``h_freq`` Hz.

    Calls ``raw_data.filter`` and returns the same ``raw_data`` object
    (not whatever ``filter`` returns).
    """
    passband = {'l_freq': l_freq, 'h_freq': h_freq}
    raw_data.filter(**passband)
    return raw_data

In [42]:
# Use T7 or T8 as proxy ECG channels (experimental approach)
def find_ecg_via_temporal_channels(ica, raw_data, ch_name='T7'):
    """Add ICA components that look ECG-like on a temporal EEG channel to ``ica.exclude``.

    No dedicated ECG channel is used here; ``ch_name`` (default 'T7', e.g. 'T8'
    as an alternative) serves as the proxy reference for ``ica.find_bads_ecg``.

    Parameters
    ----------
    ica : mne.preprocessing.ICA
        Fitted ICA whose ``exclude`` list is extended.
    raw_data : mne.io.Raw
        Data the ICA was fitted on.
    ch_name : str
        Channel passed to ``find_bads_ecg`` as the ECG proxy.

    Returns
    -------
    ica
        The same ICA object. ``exclude`` is replaced by a new list that holds
        the previous entries followed by any newly flagged indices, without
        duplicates.
    """
    ecg_indices, _ecg_scores = ica.find_bads_ecg(raw_data, ch_name=ch_name)
    # Build a new list instead of `+=`: `ica.exclude` may alias a list the caller
    # still holds (e.g. the EOG indices). Also skip components that were
    # already excluded.
    merged = list(ica.exclude)
    for idx in ecg_indices:
        if idx not in merged:
            merged.append(idx)
    ica.exclude = merged
    return ica

In [43]:
from mne.preprocessing import ICA
def apply_ica(raw_data, n_components=32):
    """Fit ICA, drop ocular and ECG-like components, and return the cleaned data.

    Parameters
    ----------
    raw_data : mne.io.Raw
        Filtered continuous data to decompose.
    n_components : int
        Number of ICA components. NOTE(review): 32 assumes one component per
        channel of the biosemi32 set; ICA cannot return more components than
        channels, so confirm this if channels are ever dropped.

    Returns
    -------
    Whatever ``ICA.apply`` returns for ``raw_data``.
    """
    # Fixed random_state keeps the decomposition reproducible across runs.
    decomposition = ICA(n_components=n_components, random_state=97, max_iter="auto")
    decomposition.fit(raw_data)

    # Eye-blink components: frontal channels Fp2/F8 act as EOG proxies, threshold 1.96.
    blink_components, _blink_scores = decomposition.find_bads_eog(
        raw_data, ch_name=['Fp2', 'F8'], threshold=1.96
    )
    decomposition.exclude = blink_components

    # Experimental: also flag ECG-like components via a temporal channel.
    decomposition = find_ecg_via_temporal_channels(decomposition, raw_data)

    return decomposition.apply(raw_data)

In [44]:
def segment_data(raw_data, duration=1.0):
    """Cut continuous data into consecutive, non-overlapping fixed-length epochs.

    Parameters
    ----------
    raw_data : mne.io.Raw
        Preprocessed continuous recording.
    duration : float
        Epoch length in seconds.

    Returns
    -------
    ndarray, shape (n_epochs, n_channels, n_times)
        Epoch data with ``n_times == round(duration * sfreq)``, e.g. typically
        (180, 32, 512) for 180 s of 32-channel data at 512 Hz.
    """
    events = mne.make_fixed_length_events(raw_data, duration=duration)
    # mne.Epochs includes the sample at tmax, so tmax=duration would give
    # duration * sfreq + 1 samples (513 at 512 Hz). Stop one sample earlier so
    # each epoch spans exactly `duration` seconds, as mne.make_fixed_length_epochs does.
    tmax = duration - 1.0 / raw_data.info['sfreq']
    epochs = mne.Epochs(raw_data, events, tmin=0, tmax=tmax, baseline=None, preload=True)
    return epochs.get_data()

In [45]:
# EEG frequency bands in Hz as (fmin, fmax), passed to the PSD routine.
# NOTE(review): 12-13 Hz belongs to no band (alpha stops at 12, beta starts at 13);
# confirm this gap is intended.
freq_bands = dict(
    delta=(1, 4),
    theta=(4, 8),
    alpha=(8, 12),
    beta=(13, 30),
    gamma=(30, 48),  # stays below bandpass_filter's default 50 Hz upper edge
)

In [46]:

def compute_psd(eeg_data, sfreq, bands=None):
    """Compute the mean multitaper power in each frequency band.

    Parameters
    ----------
    eeg_data : ndarray, shape (..., n_times)
        Time series with time on the last axis, e.g. (n_epochs, n_channels, n_times).
    sfreq : float
        Sampling rate in Hz.
    bands : dict | None
        Mapping band name -> (fmin, fmax) in Hz. Defaults to the notebook-level
        ``freq_bands``.

    Returns
    -------
    dict
        Band name -> ndarray of shape ``eeg_data.shape[:-1]``: the PSD averaged
        over the frequency bins that fall inside the band. Bins sitting exactly
        on a shared edge (e.g. 4 Hz) may count toward both neighbouring bands.
    """
    if bands is None:
        bands = freq_bands
    psd_features = {}
    for band, (low, high) in bands.items():
        # Returns (psd, freqs); psd keeps the leading dims and puts frequency last.
        psd_band, _freqs = mne.time_frequency.psd_array_multitaper(
            eeg_data, sfreq=sfreq, fmin=low, fmax=high, adaptive=True, normalization='full'
        )
        # Average across the band's frequency bins (the last axis), NOT across time.
        # axis=-1 works for any number of leading dimensions.
        psd_features[band] = psd_band.mean(axis=-1)
    return psd_features

In [47]:
# Define a function to process a single subject and extract PSD features
def process_subject(raw_data, sfreq, epoch_duration=1.0):
    """Preprocess one subject's recording and return its per-band PSD features.

    Pipeline: montage -> band-pass filter -> ICA artifact removal ->
    fixed-length epoching -> multitaper band power.

    Parameters
    ----------
    raw_data : mne.io.Raw
        Preloaded recording. The montage, filter and ICA steps operate on this object.
    sfreq : float
        Sampling rate in Hz. Used both to size each epoch and for the PSD.
    epoch_duration : float
        Epoch length in seconds (default 1.0, the original pipeline's value).

    Returns
    -------
    dict
        Band name -> feature array as produced by ``compute_psd``.
    """
    set_montage(raw_data)
    raw_filtered = bandpass_filter(raw_data)
    raw_filtered = apply_ica(raw_filtered)
    epoch_array = segment_data(raw_filtered, duration=epoch_duration)  # (n_epochs, n_channels, n_times)

    # Keep exactly epoch_duration * sfreq samples per epoch (previously a
    # hard-coded 512). This also trims the extra sample mne.Epochs adds when
    # tmax equals the epoch length.
    samples_per_epoch = int(round(epoch_duration * sfreq))
    eeg_data = epoch_array[:, :, :samples_per_epoch]

    return compute_psd(eeg_data, sfreq)

In [48]:
# Accumulator filled by collect_psd_features: for each band, a list of
# per-subject feature arrays ('data') and a parallel list of integer labels ('label').
all_psd_features = {
    band: {'data': [], 'label': []}
    for band in ('delta', 'theta', 'alpha', 'beta', 'gamma')
}

# Loop over all subjects and store PSD features
def collect_psd_features(subject_files, sfreq, label):
    """Extract PSD features for every file and append them to ``all_psd_features``.

    Parameters
    ----------
    subject_files : list of str
        Paths to the ``.bdf`` recordings.
    sfreq : float
        Expected sampling rate in Hz. Must match each recording's actual rate,
        because it is forwarded to the PSD computation.
    label : int
        Class label stored for every subject (this notebook uses 1 for PD, 0 for HC).

    Returns
    -------
    dict
        The notebook-level ``all_psd_features`` accumulator. It is mutated in
        place, so repeated calls (or re-running the cell) keep appending.

    Raises
    ------
    ValueError
        If a recording's sampling rate differs from ``sfreq``.
    """
    for subject_file in subject_files:
        raw_data = mne.io.read_raw_bdf(subject_file, preload=True)

        # A wrong sfreq would silently mislabel every PSD frequency bin, so fail loudly.
        file_sfreq = raw_data.info['sfreq']
        if abs(file_sfreq - sfreq) > 1e-6:
            raise ValueError(
                f"{subject_file}: sampling rate {file_sfreq} Hz does not match sfreq={sfreq} Hz"
            )

        raw_data.crop(tmax=180.)  # Keep first 180 seconds
        # Drop EXG1-EXG8 and Status so only the scalp channels remain
        # (presumably the 32 EEG electrodes; confirm against the recording).
        raw_data = raw_data.drop_channels(['EXG1', 'EXG2', 'EXG3', 'EXG4', 'EXG5', 'EXG6', 'EXG7', 'EXG8', 'Status'])

        # Returns {band: array}, e.g. (180, 32) per band for a 3-minute recording
        psd_features = process_subject(raw_data, sfreq)
        print("Shape of psd_features[delta] in each subject:", np.asarray(psd_features['delta']).shape)

        # Append this subject's matrix and its label for every band
        for band in all_psd_features.keys():
            all_psd_features[band]['data'].append(psd_features[band])
            all_psd_features[band]['label'].append(label)

    return all_psd_features

# all_psd_features = {
#     'alpha': {
#         'data': [array(180x32), array(180x32)],  # 2 subjects' data
#         'label': [1, 1]  # 2 corresponding labels
#     },
#     'delta': {
#         'data': [array(180x32), array(180x32)],
#         'label': [1, 1]
#     },
#     ...
# }



Each subject’s PSD feature for the δ band should have a shape of $180 \times 32$ (180 one-second epochs × 32 channels).

In [49]:
# Sampling rate in Hz, passed to collect_psd_features for each group below.
sfreq = 512

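For the PD group: use either the ON- or the OFF-medication session (PD_on or PD_off), not both, because all results accumulate in the same feature store.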

In [None]:
# Choose which PD medication session to compare against the healthy controls.
# 'off' reproduces the previous run. save_psd_features() defaults to the
# "new_data/pd_off vs hc" directory, so pass output_dir explicitly when using 'on'.
pd_session = 'off'  # 'on' or 'off'
pd_session_files = {'on': subject_files_pd_on, 'off': subject_files_pd_off}[pd_session]
psd_features_pd = collect_psd_features(pd_session_files, sfreq, 1)

In [None]:
alpha_store = all_psd_features["alpha"]
print(len(alpha_store['data']))        # subjects collected so far
print(len(alpha_store['data'][0]))     # rows (epochs) in the first subject's matrix
print(len(alpha_store['data'][0][0]))  # columns (channels) per row
print(alpha_store['label'])

For healthy controls (HC, label 0)

In [None]:
psd_features_hc = collect_psd_features(subject_files_hc, sfreq, label=0)  # healthy controls

In [None]:
alpha_store = all_psd_features["alpha"]
print(len(alpha_store['data']))        # subjects collected so far (PD + HC)
print(len(alpha_store['data'][0]))     # rows (epochs) in the first subject's matrix
print(len(alpha_store['data'][0][0]))  # columns (channels) per row
print(alpha_store['label'])

In [54]:
# all_psd_features = {
#     'alpha': {
#         'data': [array(180x32), array(180x32),.....],  # 31 subjects' data (PD group vs HC; PD_OFF in the current run)
#         'label': [1, 1, ...., 0]  # 31 corresponding labels
#     },
#     'delta': {
#         'data': [array(180x32), array(180x32), .....],
#         'label': [1, 1, ....]
#     },
#     ...
# }


In [55]:
import pickle

def save_psd_features(all_psd_features, output_dir="new_data/pd_off vs hc"):
    """Pickle each band's feature dict to ``<output_dir>/<band>.pkl``.

    Parameters
    ----------
    all_psd_features : dict
        Band name -> {'data': [...], 'label': [...]}.
    output_dir : str
        Destination folder. It is created if missing, and existing files are overwritten.
    """
    import os

    os.makedirs(output_dir, exist_ok=True)  # no-op when the folder already exists

    for band in all_psd_features:
        band_data = all_psd_features[band]
        destination = os.path.join(output_dir, f"{band}.pkl")
        with open(destination, 'wb') as handle:
            pickle.dump(band_data, handle)
        print(f"Saved {band} features to {destination}")

# alpha.pkl{
#     "data": [array(180x32), array(180x32), ...],  # 31 subjects' data
#     "label": [1, 1, ..., 0]  # 31 corresponding labels
# }


In [None]:
# Example usage
save_psd_features(all_psd_features)

<h1>Loading the saved .pkl feature files</h1>

In [None]:
def load_psd_features(band_name, input_dir="new_data/pd_off vs hc"):
    """Load one band's pickled feature dict from ``<input_dir>/<band_name>.pkl``.

    SECURITY: ``pickle.load`` can execute arbitrary code. Only load files this
    notebook produced, never files from untrusted sources.
    """
    source = os.path.join(input_dir, f"{band_name}.pkl")
    with open(source, 'rb') as handle:
        return pickle.load(handle)

# Example usage: load the alpha-band features back from disk
# (the variable name previously did not match the band actually loaded).
alpha_data = load_psd_features("alpha")
print(len(alpha_data["data"]))      # number of subjects, e.g. 31 for the PD-vs-HC run
print(alpha_data["data"][0].shape)  # expected (180, 32): epochs x channels
print(alpha_data["label"])          # e.g. [1, 1, ..., 0]
