In [None]:
import mne
import numpy as np
import mne_connectivity
import os
import pactools

def extract_features(data, sfreq, freq_ranges, bw_scales, use_psd, use_pac, use_wpli, feature_path, subject):
    """Compute the requested spectral features for one subject and save them.

    For every band in ``freq_ranges`` a PSD is computed with
    ``get_bandpowers``. Depending on the flags, the band-averaged power and/or
    the wPLI connectivity of that band are written to ``feature_path``. When
    PAC is enabled, one comodulogram is computed and saved for every pair of
    distinct bands, with the band listed earlier in ``freq_ranges`` used as
    the low-frequency range and the later one as the high-frequency range.

    Parameters
    ----------
    data : array
        Signal array handed unchanged to ``get_bandpowers``, ``get_con`` and
        ``get_pac`` (``get_pac`` unpacks its shape as trials x positions x
        times).
    sfreq : float
        Sampling frequency of ``data``.
    freq_ranges : dict
        Maps a band name to its ``(fmin, fmax)`` limits.
    bw_scales : dict
        Maps a band name to the bandwidth scale passed to ``get_bandpowers``.
    use_psd : bool
        Whether to save band powers.
    use_pac : dict
        ``use_pac['bool']`` switches PAC computation on or off.
    use_wpli : dict
        ``use_wpli['bool']`` switches wPLI computation on or off;
        ``use_wpli['n_cycles'][band]`` is the number of cycles for that band.
    feature_path : str
        Directory the feature files are written to.
    subject : str
        Subject identifier, used as the prefix of every output file name.

    Returns
    -------
    None
        All results are written to disk.
    """
    freqs_arrays = {}
    for freq_identifier, (fmin, fmax) in freq_ranges.items():
        psds, freqs, bw = get_bandpowers(data, fmin, fmax, sfreq, bw_scales[freq_identifier])
        # frequency bins spaced one bandwidth apart, starting half a bandwidth
        # above fmin; these are the frequencies later handed to the PAC estimator
        f_bins = np.arange(fmin + bw/2,fmax,bw)
        freqs_arrays[freq_identifier] = f_bins
        print(f_bins)
        if use_psd:
            # collapse axis 2 of the PSD (the frequency axis for 3-D input)
            bandpowers = np.mean(psds,axis=2)
            filepath_bandpowers = os.path.join(feature_path,f"{subject}_{freq_identifier}_bandpowers")
            np.save(filepath_bandpowers, bandpowers)
        if use_wpli['bool']: #do connectivity or not
            n_cycles = use_wpli['n_cycles'][freq_identifier] #number of cycles to use for this freq band
            con = get_con(data,freqs, sfreq, n_cycles)
            filepath_wplis = os.path.join(feature_path,f"{subject}_{freq_identifier}_wplis")
            # save() takes only the target file name; passing the object
            # again as a second argument raises TypeError
            con.save(filepath_wplis)
    if use_pac['bool']: #calculate pac or not
        band_names = list(freq_ranges)
        # every unordered pair of bands exactly once, earlier band first
        for low_pos, low_band in enumerate(band_names):
            for high_band in band_names[low_pos + 1:]:
                freqs_1 = freqs_arrays[low_band] #freqs lower
                freqs_2 = freqs_arrays[high_band] #freqs upper
                pac = get_pac(data, sfreq, freqs_1, freqs_2) #get the pac
                filepath_pacs_now = os.path.join(feature_path,f"{subject}_{low_band}-{high_band}_pac")
                np.save(filepath_pacs_now,pac) #save the pac

def get_bandpowers(data,fmin,fmax,sfreq, bw_scale):
    """Estimate the multitaper power spectral density between fmin and fmax.

    The multitaper bandwidth is ``bw_scale`` times the ratio of the sampling
    frequency to the number of samples on the last axis of ``data``.

    Returns
    -------
    psds, freqs
        The two outputs of ``mne.time_frequency.psd_array_multitaper``.
    bw : float
        The bandwidth that was used.
    """
    samples_per_segment = data.shape[-1]
    freq_step = sfreq / samples_per_segment
    bw = bw_scale * freq_step
    psds, freqs = mne.time_frequency.psd_array_multitaper(
        data,
        sfreq,
        fmin=fmin,
        fmax=fmax,
        bandwidth=bw,
        output='power',
    )
    return psds, freqs, bw

def get_con(data,freqs,sfreq,n_cycles): #get connectivity as wpli
    """Return the time-resolved wPLI connectivity of ``data``.

    Thin wrapper around ``mne_connectivity.spectral_connectivity_time`` with
    the multitaper mode, ``average=False`` and ``faverage=True``, run
    single-threaded and without verbose output.
    """
    options = {
        'freqs': freqs,
        'method': 'wpli',
        'average': False,
        'mode': 'multitaper',
        'sfreq': sfreq,
        'faverage': True,
        'n_cycles': n_cycles,
        'verbose': False,
        'n_jobs': 1,
    }
    return mne_connectivity.spectral_connectivity_time(data, **options)

def get_pac(data, sfreq, freqs_lower, freqs_upper):
    """Compute a PAC comodulogram for every trial and position of ``data``.

    ``data`` must be three-dimensional (trials x positions x times), where a
    position is a channel or a label. Each single-trial, single-position
    signal is fitted with its own ``pactools.Comodulogram`` estimator using
    the 'tort' method.

    Returns
    -------
    numpy.ndarray
        The ``comod_`` arrays stacked first over positions, then over trials.
    """
    n_trials, n_pos, _ = data.shape #n_pos is number of labels or number of channels

    def _fit_comod(signal):
        # a new estimator per signal; random state fixed for consistency
        est = pactools.Comodulogram(
            fs=sfreq,
            low_fq_range=freqs_lower,
            high_fq_range=freqs_upper,
            method='tort',
            progress_bar=False,
            random_state=42,
            n_jobs=1,
        )
        est.fit(signal)
        return est.comod_

    return np.array([
        np.array([_fit_comod(data[trial, pos, :]) for pos in range(n_pos)])
        for trial in range(n_trials)
    ])

In [None]:
import mne
import numpy as np
import os
from concurrent.futures import ThreadPoolExecutor
from queue import Queue
import logging

def process_subject(source_site, features_site, subject, typenow, use_pac, use_wpli, use_psd, bw_scales, fr_names, tmin, tmax):
    """Load one subject's source estimates and extract features from them.

    Only ``typenow == 'source_depth0.8'`` is supported; any other value
    prints a notice and returns. For each of the two parcellations the
    per-epoch ``.npy`` source estimates are loaded, cropped to the samples
    whose time lies in ``[tmin, tmax]`` and passed to ``extract_features``.

    Any exception is caught and logged together with its traceback, so one
    failing subject does not stop a batch run.

    Parameters
    ----------
    source_site : str
        Directory holding one sub-directory per subject with the source data.
    features_site : str
        Directory holding one sub-directory per subject for the features; it
        must already contain ``<subject>-freq_ranges_dict.npy``.
    subject : str
        Subject identifier (directory name and file-name prefix).
    typenow : str
        Kind of data to process.
    use_pac, use_wpli, use_psd, bw_scales
        Forwarded unchanged to ``extract_features``.
    fr_names : list of str
        Keys to select from the stored frequency-range dictionary.
    tmin, tmax : float
        Inclusive limits, on the time axis of the epochs, of the window to
        analyse.

    Returns
    -------
    None
    """
    try:
        logging.info(f"Processing {subject}")
        sourcepath_subject = os.path.join(source_site,subject)
        featurepath_subject = os.path.join(features_site,subject,typenow)

        if typenow != 'source_depth0.8':
            print("wrong typenow (only source curretly)")
            return

        # The epochs file and the frequency ranges are the same for every
        # parcellation, so read them once instead of once per loop pass.
        epochs_filepath = os.path.join(sourcepath_subject,f'{subject}_final_eeg-epo.fif')
        epochs = mne.read_epochs(epochs_filepath) #only needed for the time axis and the sampling freq
        sfreq = epochs.info['sfreq'] #used sampling frequency
        times = epochs.times
        indices = np.flatnonzero((times >= tmin) & (times <= tmax)) #indices to use from source estimates
        if indices.size == 0:
            raise ValueError(f"No samples between tmin={tmin} and tmax={tmax} in {epochs_filepath}")

        freq_range_dict = np.load(os.path.join(features_site,subject,f'{subject}-freq_ranges_dict.npy'),allow_pickle=True).item()
        freq_ranges = {fr_name:freq_range_dict[fr_name] for fr_name in fr_names} #do not take alpha peak or the boolean related to it

        # The first parcellation name is the literal string form of the label
        # list, because that text is what appears in the file and folder names.
        for parctype in [str(['n15', 'p30', 'n45', 'p60', 'handknob']),'aparc']:
            featurepath_subject_parctype = os.path.join(featurepath_subject,str(subject + "_" + parctype))
            source_estimates_path_subject = os.path.join(sourcepath_subject,f'{subject}_stcs_in_fsaverage_{parctype}_depth0.8')
            n_files = len([file for file in os.listdir(source_estimates_path_subject) if f"fsaverage_{parctype}_epoch" in file]) #number of files of this type in the folder
            # files are numbered 0..n_files-1; load them in epoch order
            data_all = np.array([np.load(os.path.join(source_estimates_path_subject,f"{subject}-fsaverage_{parctype}_epoch_{k}.npy")) for k in range(n_files)]) #load all data
            data = data_all[:,:,indices] #trials x labels x timepoints

            os.makedirs(featurepath_subject_parctype, exist_ok=True) #make dir if needed

            extract_features(data, sfreq, freq_ranges, bw_scales, use_psd, use_pac, use_wpli, featurepath_subject_parctype, subject)
        logging.info(f"Completed processing {subject}")
    except Exception as e:
        # best-effort batch processing: log (with traceback) and carry on
        logging.exception(f"Error processing subject {subject}: {str(e)}")

In [None]:
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# --- feature selection -------------------------------------------------------
fr_names = ['theta','alpha','beta','gamma'] #frequency bands to process
use_psd = True #save band powers
use_pac = {'bool':False} #phase-amplitude coupling switched off
#whether to calculate wpli or not and the associated n_cycles values for each freq band
use_wpli = {
    'bool':False,
    'n_cycles':{'theta':1, 'alpha':5, 'beta':7, 'gamma':7},
}
#scales for calculating bandwidths for different freq ranges
bw_scales = {'theta':2, 'alpha':2, 'beta':3, 'gamma':6}

# --- data cropping times -----------------------------------------------------
tmin = -1.015
tmax = -0.015
#max_workers = 26 #n workers

# --- run every subject of every site, one after another ----------------------
for typenow in ['source_depth0.8']:
    for siteind, site in enumerate(['Tuebingen','Aalto']):
        source_site = rf"D:\REFTEP_ALL\Source_analysis\Source_analysis_{site}"
        features_site = rf"D:\REFTEP_ALL\Features_v2\Features_{site}"
        # subject folders are the entries whose name contains "sub"
        subjects = [entry for entry in os.listdir(source_site) if "sub" in entry]
        for subject in subjects:
            process_subject(
                source_site=source_site,
                features_site=features_site,
                subject=subject,
                typenow=typenow,
                use_pac=use_pac,
                use_wpli=use_wpli,
                use_psd=use_psd,
                bw_scales=bw_scales,
                fr_names=fr_names,
                tmin=tmin,
                tmax=tmax,
            )