# [Working Title] Immediate and 6-month outcomes from the LEISURE study, a multidomain lifestyle intervention to reduce dementia risk in healthy older people

Andrews, S.C., Treacy, C., Pace, T., Levenstein, J., Quigley, B., Metse, A.P., Schaumberg, M.A., Villani, A., Campbell, A.J., Lagopoulos, J., & Hermens, D.

In [None]:
import warnings
# NOTE(review): this silences every Python warning for the whole session,
# including any MNE/FOOOF warnings about the data or deprecated APIs.
warnings.filterwarnings("ignore")

import mne, os
import seaborn as sns
import warnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from fooof import FOOOF 
from fooof.plts.spectra import plot_spectrum

# Set the current working directory to be the project main folder.
# Defaults to the original author's checkout; set the EEG_NOTEBOOKS_DIR
# environment variable to point at the project folder on another machine.
os.chdir(os.environ.get('EEG_NOTEBOOKS_DIR',
                        '/Users/aliciacampbell/Documents/code/EEG-notebooks'))

import basic.arrange_data as arrange
import signal_processing.pre_process_v2 as pre_process
import signal_processing.spectral_analysis as spectr

In [None]:
### DEFINE ###
# Study / condition selection
exp_folder = 'LEISURE'   # study sub-folder under every data/results root
exp = 'EC'               # resting-state condition folder (presumably eyes closed)
exp_timepoint = ['T3']   # timepoints to process; other options: 'T1', 'T2'

# Data and output roots, relative to the project main folder
raw_folder, clean_folder, spectra_folder = 'Data/Raw/', 'Data/Clean/', 'Data/Spectra/'
results_folder = 'Results/'
savefinal_folder = results_folder + exp_folder + '/'

In [None]:
### DEFINE ###
# Acquisition layout: BioSemi 32-channel cap plus external EXG electrodes
montage = 'biosemi32'
eog_channels = ['EXG3', 'EXG4', 'EXG5', 'EXG6']   # typed as EOG when reading the BDF
misc = ['EXG1', 'EXG2', 'EXG7', 'EXG8']           # typed as misc when reading the BDF
all_external_channels = [*eog_channels, *misc]
stimulus_channel = 'Status'   # trigger channel holding the event codes
erg_channel = 'Erg1'
reference = 'average'         # average reference across EEG channels
epochs_duration = 5           # epoch length passed to artefact rejection (presumably seconds)

# Band-pass FIR design: 0.5-30 Hz, zero-phase, Hamming-windowed firwin
filter_design = {
    'l_freq': 0.5,
    'h_freq': 30,
    'filter_length': 'auto',
    'method': 'fir',
    'l_trans_bandwidth': 'auto',
    'h_trans_bandwidth': 'auto',
    'phase': 'zero',
    'fir_window': 'hamming',
    'fir_design': 'firwin',
}

# Channel labels of the B, C and D electrode banks (32 each) of a larger
# BioSemi layout; a recording containing any of them used the wrong montage.
# C3/C4 carry a '-1' suffix here, presumably MNE's de-duplication of names
# that clash with the 32-channel cap's own C3/C4 -- TODO confirm.
wrong_mnt_chs = (
    [f"B{n}" for n in range(1, 33)]
    + [f"C{n}-1" if n in (3, 4) else f"C{n}" for n in range(1, 33)]
    + [f"D{n}" for n in range(1, 33)]
)

In [None]:
def find_event_pair(events, start_id, end_id):
    """Return the samples of the first ``start_id`` event and the next ``end_id`` after it.

    Parameters
    ----------
    events : sequence of rows
        Event rows such as those returned by ``mne.find_events``; only
        column 0 (sample) and column 2 (event code) of each row are read.
    start_id, end_id : int
        Event codes marking the start and the end of the segment of interest.

    Returns
    -------
    tuple
        ``(start_sample, end_sample)``, or ``(None, None)`` when no
        ``start_id`` event is followed by an ``end_id`` event.
    """
    # Single pass: if the first start code has no later end code, no later
    # start code can have one either, so only the first start matters.
    start_sample = None
    for row in events:
        if start_sample is None:
            if row[2] == start_id:
                start_sample = row[0]
        elif row[2] == end_id:
            return start_sample, row[0]
    return None, None  # no valid pair found

# Pre-processing: for every raw BDF recording of each selected timepoint, tidy
# the channel set, set the montage and average reference, crop to the resting
# state between event codes 8 and 2, band-pass filter and remove EOG artefacts
# (SSP), epoch with Autoreject artefact rejection, and save the clean epochs as
# <subject>_clean-epo.fif under clean_folder.
for timepoint in exp_timepoint:
    print('Checking files in', timepoint)
    # Set the directory in progress and find all BDF (raw EEG) files in there
    dir_inprogress = os.path.join(raw_folder, exp_folder, timepoint, exp)
    export_dir = os.path.join(clean_folder, exp_folder, timepoint, exp)
    # Create the output folder up front so the save at the end cannot fail
    # on a missing directory.
    os.makedirs(export_dir, exist_ok=True)
    file_dirs, subject_names = arrange.read_files(dir_inprogress, '.bdf')

    for i in range(len(file_dirs)):
        subject = subject_names[i]
        print(f'\n{subject}...')
        # Read in the raw EEG data
        raw = mne.io.read_raw_bdf(file_dirs[i], infer_types=True, eog=eog_channels, misc=misc,
                                  stim_channel=stimulus_channel, verbose=False)

        # Drop the Erg channel when present (uses the configured name)
        if erg_channel in raw.info['ch_names']:
            print(f'{erg_channel} channel detected and dropping..')
            raw = raw.drop_channels([erg_channel])

        # Recordings made with the wrong (larger) montage carry extra B/C/D-bank
        # channels. Drop only those actually present: drop_channels raises for
        # names missing from the recording, so dropping the full list would
        # fail whenever only some of those channels were recorded.
        extra_chs = [ch for ch in wrong_mnt_chs if ch in raw.info['ch_names']]
        if extra_chs:
            print('Wrong montage, removing bad channels and replacing some channel names..')
            print('C-channels before:', [c for c in raw.info['ch_names'] if c.startswith('C')])
            raw = raw.drop_channels(extra_chs)
            # Restore the plain C3/C4 labels, renaming only labels that exist
            renames = {old: new for old, new in (('C3-0', 'C3'), ('C4-0', 'C4'))
                       if old in raw.info['ch_names']}
            if renames:
                raw = raw.rename_channels(renames)
            print('C-channels after:', [c for c in raw.info['ch_names'] if c.startswith('C')])

        # Set the right montage (Biosemi32) and set reference as average across all channels
        raw = (raw.set_montage(mne.channels.make_standard_montage(montage), verbose=False)
                  .load_data()
                  .set_eeg_reference(ref_channels=reference, verbose=False))

        # Find event markers for the start and end of resting state recordings
        events = mne.find_events(raw, stim_channel=stimulus_channel, consecutive=False,
                                 output='offset', verbose=False)
        print(f'Datapoints and IDs of the events:\n{events}')

        s_minmax = find_event_pair(events, start_id=8, end_id=2)
        # Compare against None explicitly: a pair starting at sample 0 is valid
        found_pair = s_minmax[0] is not None and s_minmax[1] is not None
        if found_pair:
            print(f"Starting point sample: {s_minmax[0]}")
            print(f"Ending point sample: {s_minmax[1]}")
        else:
            print("No valid event pair (8->2) found.")

        # Use the markers to crop the EEG signal to leave only the actual EC resting state
        if found_pair:
            # Event samples include raw.first_samp while crop() times are
            # relative to the first sample of the data, so remove the offset.
            t_minmax = [(s - raw.first_samp) / raw.info['sfreq'] for s in s_minmax]
            raw_c = raw.crop(tmin=t_minmax[0], tmax=t_minmax[1])
            tlen = t_minmax[1] - t_minmax[0]
            if 230 <= tlen <= 250:
                print(f'Raw signal length = {tlen} s')
            else:
                print(f'WARN: Raw signal length = {tlen} s (i.e. not between 230-250 s)')
        else:
            print('WARN: No cropping done.')
            raw_c = raw.copy()
        # Remove the trigger channel in both branches so the saved epochs have
        # the same channel set whether or not cropping happened.
        raw_c = raw_c.drop_channels(stimulus_channel)

        # Filter the signal with bandpass filter and remove EOG artefacts with SSP
        filt = pre_process.filter_raw_data(raw_c, filter_design, line_remove=None, eog_channels=eog_channels,
                                           drop_external_channels=all_external_channels,
                                           plot_filt=False, savefig=False, verbose=False)

        # Divide the filtered signal to epochs and run Autoreject artefact rejection on the epochs
        epochs = pre_process.artefact_rejection(filt, subject, epo_duration=epochs_duration,
                                                pltfig=False, savefig=True, verbose=False)

        # Save the clean epochs
        epochs.save(fname=os.path.join(export_dir, f'{subject}_clean-epo.fif'), overwrite=True)

In [None]:
### DEFINE ###
# Band searched for the individual peak (Hz)
bands = {'Alpha': [7, 14]}
# One 'Global' region made of all 32 scalp channels
brain_regions = {'Global': ['Fp1', 'AF3', 'F3', 'FC1', 'Fp2', 'AF4', 'F4', 'FC2', 'Fz',
                            'F7', 'FC5', 'T7', 'C3', 'F8', 'FC6', 'T8', 'C4', 'Cz',
                            'CP5', 'P3', 'P7', 'CP6', 'P4', 'P8', 'CP1', 'CP2', 'Pz',
                            'PO3', 'PO4', 'O1', 'O2', 'Oz']}

# Spectrum used to locate the individual peak; one of
# 'linear_flat', 'log_flat', 'linear_normal', 'log_normal'
ind_spectr_type = 'linear_flat'
plot_rich = True        # annotate plots with the exported parameters
savefig = True
savespectrum = False

psd_params = {
    'method': 'welch',
    'fminmax': [1, 30],
    'window': 'hamming',
    'window_duration': 2.5,
    'window_overlap': 0.5,
    'zero_padding': 39,
}
fooof_params = {
    'peak_width_limits': [1, 12],
    'max_n_peaks': float("inf"),
    'min_peak_height': 0.05,
    'peak_threshold': 2.0,
    'aperiodic_mode': 'fixed',
}

# Descriptive tag for the PSD settings, e.g. welch_1-30Hz_WIN=2.5s_hamming_OL=50.0%_ZP=97.5s
spectrum_name = (
    f"{psd_params['method']}_{psd_params['fminmax'][0]}-{psd_params['fminmax'][1]}Hz"
    f"_WIN={psd_params['window_duration']}s_{psd_params['window']}"
    f"_OL={psd_params['window_overlap'] * 100}%"
    f"_ZP={psd_params['zero_padding'] * psd_params['window_duration']}s"
)

# Global seaborn/matplotlib look for the figures drawn later in this notebook:
# white background with grid lines, and the 'muted' colour cycle.
sns.set_style("whitegrid")
sns.set_palette('muted')

In [None]:
# Spectral analysis: for every clean epoch file of each selected timepoint,
# compute the PSD, average it per brain region, fit FOOOF, extract individual
# band parameters, plot the fit, and collect results in per-region DataFrames
# (df_fooof_<region>, df_powerspectra_<region>, df_flatpowerspectra_<region>);
# df_fooof_<region> is exported to Excel at the end of each timepoint.
_valid_spectr_types = ('linear_flat', 'log_flat', 'linear_normal', 'log_normal')
if ind_spectr_type not in _valid_spectr_types:
    # Fail before any file is processed; otherwise flatten_spectrum would be
    # undefined (or left over from a previous run) inside the loop.
    raise ValueError(f'ind_spectr_type must be one of {_valid_spectr_types}, got {ind_spectr_type!r}')

band_name = list(bands.keys())[0]

# Plot styles (loop-invariant)
data_kwargs = {'color' : 'black', 'linewidth' : 1.4, 'label' : 'Original'}
model_kwargs = {'color' : 'red', 'linewidth' : 1.4, 'alpha' : 0.75, 'label' : 'Full model'}
aperiodic_kwargs = {'color' : 'blue', 'linewidth' : 1.4, 'alpha' : 0.75,
                    'linestyle' : 'dashed', 'label' : 'Aperiodic model'}
flat_kwargs = {'color' : 'black', 'linewidth' : 1.4}
hvline_kwargs = {'color' : 'blue', 'linewidth' : 1.0, 'linestyle' : 'dashed', 'alpha' : 0.75}

for timepoint in exp_timepoint:
    print('Checking files in', timepoint)
    # Set the directory in progress and find all clean epoch (FIF) files in there
    dir_inprogress = os.path.join(clean_folder, exp_folder, timepoint, exp)
    file_dirs, subject_names = arrange.read_files(dir_inprogress, '_clean-epo.fif')
    arrange.create_results_folders(exp_folder=exp_folder, results_folder=results_folder, fooof=True)
    if len(file_dirs) == 0:
        # Without this guard the export loop below would use an undefined (or a
        # previous timepoint's) df_psds_regions.
        print(f'WARN: No clean epoch files found in {dir_inprogress}, skipping {timepoint}.')
        continue

    for i in range(len(file_dirs)):
        subject = subject_names[i]
        # Read in the clean EEG data
        epochs = mne.read_epochs(fname='{}/{}_clean-epo.fif'.format(dir_inprogress, subject),
                                 verbose=False)

        # Calculate Welch's power spectrum density
        [psds, freqs] = spectr.calculate_psd(epochs, subject, method=psd_params['method'],
                                             fminmax=psd_params['fminmax'], window=psd_params['window'],
                                             window_duration=psd_params['window_duration'],
                                             window_overlap=psd_params['window_overlap'],
                                             zero_padding=psd_params['zero_padding'],
                                             verbose=True, plot=False)

        # Average all epochs and channels together -> (freq bins,) shape
        if i == 0:
            psds_allch = np.zeros(shape=(len(file_dirs), len(freqs)))
        psds_allch[i] = psds.mean(axis=(0, 1))

        # Average all epochs together for each channel and also for each region
        psds = psds.mean(axis=0)
        df_psds_ch = arrange.array_to_df(subject, epochs, psds).reset_index().drop(columns='Subject')
        df_psds_regions = arrange.df_channels_to_regions(df_psds_ch, brain_regions)\
                                 .reset_index().drop(columns='Subject')

        # Go through all regions of interest
        for region in df_psds_regions.columns:
            if i == 0:
                globals()["df_fooof_" + region] = pd.DataFrame(index=subject_names)
                globals()["df_powerspectra_" + region] = pd.DataFrame(columns=freqs, index=subject_names)
                globals()["df_flatpowerspectra_" + region] = pd.DataFrame(columns=freqs, index=subject_names)
            df_fooof = globals()["df_fooof_" + region]

            psds_temp = df_psds_regions[region].to_numpy()

            # Fit the spectrum with FOOOF (specparam)
            fm = FOOOF(**fooof_params, verbose=True)
            fm.fit(freqs, psds_temp, psd_params['fminmax'])

            # Spectrum used for the individual peak, per the chosen amplitude scale
            if ind_spectr_type == 'linear_flat':
                flatten_spectrum = 10 ** fm._spectrum_flat
                flat_spectr_ylabel = 'Flattened power (µV\u00b2/Hz)'
            elif ind_spectr_type == 'log_flat':
                flatten_spectrum = fm._spectrum_flat
                flat_spectr_ylabel = 'Flattened log10-transformed power'
            elif ind_spectr_type == 'linear_normal':
                flatten_spectrum = psds_temp
                flat_spectr_ylabel = 'Power (µV\u00b2/Hz)'
            else:  # 'log_normal' (validated before the loop)
                flatten_spectrum = np.log10(psds_temp)
                flat_spectr_ylabel = 'Log10-transformed power'

            # Find individual band parameters
            cf, pw, bw, abs_bp, rel_bp = spectr.find_ind_band(flatten_spectrum, freqs,
                                                              bands[band_name], bw_size=6)

            # Plot power spectrum model + aperiodic fit
            fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10, 4), dpi=100)
            plot_spectrum(fm.freqs, fm.power_spectrum, ax=axs[0], **data_kwargs)
            plot_spectrum(fm.freqs, fm.fooofed_spectrum_, ax=axs[0], **model_kwargs)
            plot_spectrum(fm.freqs, fm._ap_fit, ax=axs[0], **aperiodic_kwargs)
            axs[0].set_xlim(psd_params['fminmax'])
            axs[0].grid(linewidth=0.2)
            axs[0].set_xlabel('Frequency (Hz)')
            axs[0].set_ylabel('Log10-transformed power')
            axs[0].set_title('Original power spectrum with model fit')
            axs[0].legend()

            # Flattened spectrum plot (i.e., minus aperiodic fit)
            plot_spectrum(fm.freqs, flatten_spectrum, ax=axs[1], **flat_kwargs)
            axs[1].plot(cf, pw, '*', color='blue', label='{} peak'.format(band_name))
            axs[1].set_xlim(psd_params['fminmax'])
            # 10% headroom above the peak; for non-positive (log) peaks pw*1.1
            # would fall below the peak, so add |pw|/10 instead.
            y_top = pw * 1.1 if pw > 0 else pw + 0.1 * abs(pw)
            if ind_spectr_type in ('linear_flat', 'linear_normal'):
                axs[1].set_ylim([0, y_top])
            else:
                axs[1].set_ylim([None, y_top])
            # axvline/axhline take axis fractions: convert the peak position
            # using the limits actually in force on this axis.
            (x_lo, x_hi), (y_lo, y_hi) = axs[1].get_xlim(), axs[1].get_ylim()
            axs[1].axvline(x=cf, ymin=0, ymax=(pw - y_lo) / (y_hi - y_lo), **hvline_kwargs)
            axs[1].axhline(y=pw, xmin=0, xmax=(cf - x_lo) / (x_hi - x_lo), **hvline_kwargs)
            axs[1].axvspan(bw[0], bw[1], alpha=0.1, color='green', label='{} band'.format(band_name))
            axs[1].grid(linewidth=0.2)
            axs[1].set_xlabel('Frequency (Hz)')
            axs[1].set_ylabel(flat_spectr_ylabel)
            axs[1].set_title('Power spectrum with individual alpha')
            axs[1].legend()

            # If true, plot all the exported variables on the plots
            if plot_rich:
                axs[0].annotate('Error: ' + str(np.round(fm.get_params('error'), 4)) +
                                '\nR\u00b2: ' + str(np.round(fm.get_params('r_squared'), 4)),
                                (0.1, 0.16), xycoords='figure fraction', color='red', fontsize=8.5)
                axs[0].annotate('Exponent: ' + str(np.round(fm.get_params('aperiodic_params','exponent'), 4)) +
                                '\nOffset: ' + str(np.round(fm.get_params('aperiodic_params','offset'), 4)),
                                (0.19, 0.16), xycoords='figure fraction', color='blue', fontsize=8.5)
                axs[1].text(cf+1, pw, 'IAF: '+str(np.round(cf, 4)),
                            verticalalignment='top', color='blue', fontsize=8.5)
                axs[1].annotate('BW: '+str(np.round(bw[0], 4))+' - '+str(np.round(bw[1], 4))+
                                '\nIABP: '+str(np.round(abs_bp, 4)),
                                (0.75, 0.16), xycoords='figure fraction', color='green', fontsize=8.5)

            plt.suptitle('{} region ({})'.format(region, subject))
            plt.tight_layout()
            if savefig:
                _dir = os.path.join(results_folder, exp_folder, timepoint, exp, 'FOOOF', 'outcomes_paper',
                                    region, ind_spectr_type, 'plots')
                os.makedirs(_dir, exist_ok=True)
                plt.savefig(fname=os.path.join(_dir, f'{exp}_{subject}_{region}.png'), dpi=300)

            plt.show()
            # Release the figure; one is created per subject and region
            plt.close(fig)

            # Add model parameters to the dataframe (columns created in this order)
            row_values = {
                'Exponent': fm.get_params('aperiodic_params', 'exponent'),
                'Offset': fm.get_params('aperiodic_params', 'offset'),
                '{} CF'.format(band_name): cf,
                '{} PW'.format(band_name): pw,
                '{} BW'.format(band_name): str(bw),
                '{} absolute power'.format(band_name): abs_bp,
                '{} relative power'.format(band_name): rel_bp,
                'R_2': fm.get_params('r_squared'),
                'Error': fm.get_params('error'),
            }
            row_label = df_fooof.index[i]
            for column, value in row_values.items():
                df_fooof.loc[row_label, column] = value

            # Add the model's (log10) input spectrum and the flattened spectrum to the dataframes
            globals()["df_powerspectra_" + region].loc[subject] = fm.power_spectrum
            globals()["df_flatpowerspectra_" + region].loc[subject] = flatten_spectrum

    # Export the parameter table for all regions; the file is written next to
    # the <ind_spectr_type> folder as <region>/<ind_spectr_type>.xlsx
    for region in df_psds_regions.columns:
        _dir = os.path.join(results_folder, exp_folder, timepoint, exp, 'FOOOF', 'outcomes_paper',
                            region, ind_spectr_type)
        print(_dir)
        os.makedirs(_dir, exist_ok=True)
        globals()["df_fooof_" + region].to_excel(f"{_dir}.xlsx")
        display(globals()["df_fooof_" + region])