In [None]:
import os
import shutil

import mne
import numpy as np
import pandas as pd
import scipy

from prediction_utils import *

# ---------------------------------------------------------------------------
# Analysis configuration
# ---------------------------------------------------------------------------

# A subject is used for a response only if its average goodness of fit (GOF) for that
# response is at least this value; 'mep' responses bypass this check (modelling cell).
gof_thresh_average = 75

# Response types to model, and the position names whose per-subject label annotations
# make up the functional (custom-label) parcellation.
responses = ['n15', 'p30', 'n45', 'p60', 'mep']
pos_names = ['n15', 'p30', 'n45', 'p60', 'handknob']

# Source-estimate variant; part of the feature folder name and of the output folder name.
depth = "depth0.8"
where = f"source_{depth}"

feature_path_base = r"D:\REFTEP_ALL\Features_v2"
source_path_base = r"D:\REFTEP_ALL\Source_analysis"

sites = ['Tuebingen', 'Aalto']
freq_range_names = ['theta', 'alpha', 'beta', 'gamma']

# Anatomical parcellation: FreeSurfer 'aparc' labels on fsaverage, without 'unknown'.
labels_all_aparc = mne.read_labels_from_annot("fsaverage", "aparc", "both", subjects_dir=r"D:\REFTEP_ALL\REFTEP_reco\Aalto_recon_all")
label_names_aparc = [label.name for label in labels_all_aparc if "unknown" not in label.name]

# Which predictor groups to include in the models.
usepsd = True
usecoil = True
rejcoil = False  # use finite coil distance/angle thresholds (only when usecoil is True)
usephase = False
usepac = False

# Output folder. NOTE: an existing folder with this name is deleted first, so the
# results of a previous run with the same configuration are lost.
models_path = fr"D:\REFTEP_ALL\Models_Aalto_Tuebingen_phase_{usephase}_rejcoil_{rejcoil}_{depth}_514"
if os.path.exists(models_path):
    shutil.rmtree(models_path)
os.makedirs(models_path)

# usetime = [enabled, mode, per-site flags]; the meaning of the mode string and of the
# per-site flags is defined by prediction_utils (TODO confirm).
usetime = [True, 'sample', {'Aalto': False, 'Tuebingen': False}]
phasefreqs = ['alpha']
# PSD predictors handed to perform_lmm as `interaction_variables_with_time`
# (empty when time is not used).
interaction_variables_with_time = ['PSD_' + freq for freq in freq_range_names] if usetime[0] else []

# Default coil angle/position differences passed to load_data_subject
# (presumably fill values when the actual difference is unavailable — TODO confirm).
angle_diff_default_normal = 1.1898413884798587
angle_diff_default_dir = 1.2484358972410208
pos_diff_default = 2.125563855244972

# Coil distance / angle thresholds passed to load_data_subject. They are always
# defined, because the modelling cell passes them even when coil features are off.
if usecoil and rejcoil:
    distance_thresh = 5
    angle_distance_thresh = 10
else:
    # Infinite thresholds presumably never trigger, i.e. no coil-based rejection.
    distance_thresh = np.inf
    angle_distance_thresh = np.inf
    if not usecoil:
        print("not using coil")

combine_normal_ori = False
combine_all = True
re_formula_now = False  # passed to perform_lmm as `re_formula`

In [None]:
# Name of the functional (custom-label) parcellation. Its string form is used verbatim
# in the feature folder names, so it must stay exactly this str([...]).
custom_parctype = str(['n15', 'p30', 'n45', 'p60', 'handknob'])
# FreeSurfer subjects directory that holds fsaverage and the label annotations.
fsaverage_subjects_dir = r"D:\REFTEP_ALL\REFTEP_reco\Aalto_recon_all"


def _custom_spatial_names(subject):
    """Return the functional-label names of ``subject`` without subject and hemisphere.

    For every entry of ``pos_names`` the fsaverage annotation
    ``{subject}_{pos_name}_label_fsaverage`` is read and its non-'unknown' labels are
    collected. Each label name is split on "_", and fields 1-3 are re-joined with the
    last three characters of field 3 dropped (presumably the "-lh"/"-rh" hemisphere
    suffix), giving names of the form "around_xxx_label".
    Assumes label names have at least four "_"-separated fields (TODO confirm).
    """
    label_names = []
    for pos_name in pos_names:
        custom_label_file = f"{subject}_{pos_name}_label_fsaverage"
        labels = mne.read_labels_from_annot("fsaverage", custom_label_file, "both",
                                            subjects_dir=fsaverage_subjects_dir, verbose=False)
        label_names += [label.name for label in labels if "unknown" not in label.name]
    parts = [label_name.split("_") for label_name in label_names]
    return [p[1] + "_" + p[2] + "_" + p[3][:-3] for p in parts]


def _corrected_p_by_predictor(corrected_p_values, column, predictors):
    """Map each predictor to the first value of ``column`` in its row(s); NaN for the intercept."""
    return {var: (corrected_p_values[column][corrected_p_values['predictor'] == var].iloc[0]
                  if var != "Intercept" else np.nan)
            for var in predictors}


for parctype in ['aparc', custom_parctype]:
    # 1) Load the features of every subject once per parcellation; they do not depend
    #    on the response type. `subjects_loaded[i]` is the (site, subject) of `datas[i]`.
    datas = []
    subjects_loaded = []
    for site in sites:
        features_path_site = os.path.join(feature_path_base, f'Features_{site}')
        # os.listdir returns entries in arbitrary order: sort for reproducibility. The
        # responses below are looked up for exactly these subjects, in this order, so a
        # subject's features are always paired with that same subject's amplitudes.
        for subind, subject in enumerate(sorted(os.listdir(features_path_site))):
            if subind == 0:  # spatial names are read once per site, from its first subject
                if parctype == "aparc":
                    spatial_names = label_names_aparc
                elif parctype == custom_parctype:
                    spatial_names = _custom_spatial_names(subject)
                else:
                    raise ValueError(f"bad parameter: parctype: {parctype}")
            where_now = f'{where}/{subject}_{parctype}'  # where the features for this parcellation live
            data_subject = load_data_subject(subject, site, features_path_site, where_now, spatial_names, freq_range_names, usepsd, usecoil,
                                             usephase, usepac, phasefreqs, usetime, distance_thresh, angle_distance_thresh, angle_diff_default_normal, angle_diff_default_dir, pos_diff_default, combine_normal_ori, combine_all, usecoil_sigma=1e-3)
            datas.append(data_subject)
            subjects_loaded.append((site, subject))

    for response in responses:
        # 2) Keep the subjects whose average GOF for this response reaches the
        #    threshold ('mep' keeps every subject), together with their amplitudes.
        datas_to_use = []
        responses_all = []
        for data_subject, (site, subject) in zip(datas, subjects_loaded):
            source_path_site = os.path.join(source_path_base, f'Source_analysis_{site}')
            amplitudes, average_gof = load_responses_subject(source_path_site, subject, response)
            if average_gof >= gof_thresh_average or response == 'mep':
                datas_to_use.append(data_subject)
                responses_all.append(amplitudes)

        # 3) Fit one mixed model per spatial label. For the functional parcellation only
        #    the labels of the current response and the hand-knob labels are modelled.
        if parctype == 'aparc':
            spatial_names_now = spatial_names
        else:
            spatial_names_now = [name for name in spatial_names if response in name or 'handknob' in name]
        for name in spatial_names_now:
            df = create_df_with_features(datas_to_use, responses_all, name, freq_range_names, usephase, usepac, usecoil, phasefreqs, usetime, combine_normal_ori, combine_all)
            #df = df.dropna() #drop nan values
            df, explanatory_variables_numeric = create_scaled_df(df)  # scale explanatory variables
            #df_path = os.path.join(models_path,f'{response}_{name}_usecoil_{usecoil}_usepsd_{usepsd}_usephase_{usephase}_usepac_{usepac}_usetime_{usetime[0]}_{usetime[1]}_grouptype_Subject_dataframe.csv')
            #df.to_csv(df_path) #save the used dataframe for later assessment
            for ref_site in [None]:
                # Linear mixed-effects model grouped by subject.
                result = perform_lmm(df, explanatory_variables_numeric, 'Subject', interaction_variables_with_time, ref_site=ref_site, re_formula=re_formula_now)
                overall_r2 = calculate_r2(df, 'AMP', result)
                conditional_r2, marginal_r2 = calculate_r2_cond_marginal(result)
                dict_params = dict(result.params)
                dict_p_values = dict(result.pvalues)

                # Common stem of both output files for this response/label/configuration.
                file_stem = (f'{response}_{name}_usecoil_{usecoil}_usepsd_{usepsd}_usephase_{usephase}_usepac_{usepac}'
                             f'_usetime_{usetime[0]}_{usetime[1]}_grouptype_Subject_ref_site_{ref_site}')
                result.save(os.path.join(models_path, f'{file_stem}_model.pkl'))

                if usephase:
                    # Combine the sin and cos phase coefficients into one phase effect and p-value.
                    b_phase, p_phase = estimate_phase_coef('phase_sin_alpha', 'phase_cos_alpha', result)
                    dict_params['phase_alpha'] = b_phase
                    dict_p_values['phase_alpha'] = p_phase
                else:
                    p_phase = False

                # Save coefficients, raw and corrected p-values, and R2 values.
                corrected_p_values = perform_p_val_correction(result, p_phase)
                p_values_bonferroni = _corrected_p_by_predictor(corrected_p_values, 'bonferroni', dict_params)
                p_values_fdr = _corrected_p_by_predictor(corrected_p_values, 'fdr', dict_params)
                coefficients_df = pd.DataFrame({'coefficients': dict_params, 'pvalues_raw': dict_p_values, 'overall_r2': overall_r2,
                                                'conditional_r2': conditional_r2, 'marginal_r2': marginal_r2,
                                                'pvalues_bonferroni': p_values_bonferroni, 'pvalues_fdr': p_values_fdr})
                coefficients_df.to_csv(os.path.join(models_path, f'{file_stem}.csv'))


