In [35]:
# add bootstrapping of HC test to generate confidence interval of prediction performance.
# 12/1/2021

import os
from os import path
import numpy as np
import pandas as pd
from scipy.optimize import curve_fit
from sklearn.metrics import r2_score

def qudratic_r_squared_gender(chro_age, brain_age, gender):
    """
    compute the r squared of curve fit of chro_age vs. brain age.
    1. Use qudratic function to fit brain age with chronological age and gender.
    2. Then, compute r squared of brain age and predicted brain age with qudratic fit model.
    """
    
    def qudratic_fun(x, a, b, c, d, e, f):
        x1, x2 = x
        #return a + b*x1 + c*(x1**2) + d*x2 + e*x1*x2 f*(x1**2)*x2
        return a + b*x1 + c*(x1**2) + d*x2 + e*x1*x2 + f*(x1**2)*x2
    
    popt, pcov = curve_fit(qudratic_fun, (chro_age, gender), brain_age)
    brain_age_pred = qudratic_fun((chro_age, gender), popt[0], popt[1], popt[2], popt[3], popt[4], popt[5])

    r_square = r2_score(brain_age, brain_age_pred)
    return(r_square)



input_dir = "out03_age_prediction_hc2_stdz_age_reverse_notract_fa_scale_thresh.05"
output_dir = 'out04_age_prediction_hc2_stdz_age_reverse_notract_fa_scale_thresh.05'
cluster_col_index = 'cluster_gmmEEE4'


if not os.path.exists(output_dir):
    os.mkdir(output_dir)

data = pd.read_csv(input_dir + "/out03_scatter_data_ridge_disorder_" + cluster_col_index + ".csv", index_col = 0)

In [36]:
data

Unnamed: 0,cluster,SUBJID,Sex,group,chronological age,brain age
26,1.0,600129552715,0.0,smry_ptd,11.0,14.409815
33,1.0,600210241146,0.0,smry_ptd,18.0,18.283338
70,1.0,600682103788,0.0,smry_ptd,18.0,16.869028
71,1.0,600689706588,0.0,smry_ptd,16.0,14.456132
80,1.0,600778703691,0.0,smry_ptd,13.0,12.954733
...,...,...,...,...,...,...
867,all,609173350200,1.0,HC,12.0,12.602509
920,all,609706993828,1.0,HC,10.0,10.463315
923,all,609714765360,1.0,HC,13.0,14.022208
931,all,609802779962,0.0,HC,16.0,18.360447


In [59]:
## compute prediction performance on bootstrapping samples:

from sklearn.utils import resample

cluster_index = data['cluster'].unique()
group_index = data['group'].unique()
n_boot = 500

# create empty dataframe to save results:
column_index = pd.MultiIndex.from_product([cluster_index, ['mean', '95% CI lower', '95% CI upper']])

row_index = group_index
result_table = pd.DataFrame(index = row_index, columns = column_index)
# result_table = pd.concat([result_table, pd.DataFrame(index = ['all'], columns = column_index)])

for group in group_index:
    for cluster in cluster_index:
        print(group)
        print(cluster)

        data_cluster = data.loc[(data['cluster']==cluster) & (data['group']==group),:]
        X = data_cluster[['chronological age', 'Sex', 'brain age']].values

        r2_boot = []
        for boot in range(n_boot):

            X_boot = resample(X, n_samples=data_cluster.shape[0], 
                                  replace=True, stratify=X[:,0], random_state=boot)
            
            r2 = qudratic_r_squared_gender(X_boot[:,0], X_boot[:,2], X_boot[:,1])
            r2_boot.append(r2)
        
        r2_boot = np.array(r2_boot)
        m = np.mean(r2_boot)
        ci_top = np.percentile(r2_boot, 95)
        ci_bot = np.percentile(r2_boot, 5)
        
        result_table.loc[group, (cluster, 'mean')] = m
        result_table.loc[group, (cluster, '95% CI lower')] = ci_bot
        result_table.loc[group, (cluster, '95% CI upper')] = ci_top

        

smry_ptd
1.0
smry_ptd
2.0
smry_ptd
3.0
smry_ptd
4.0
smry_ptd
all
smry_dep
1.0
smry_dep
2.0
smry_dep
3.0
smry_dep
4.0
smry_dep
all
smry_phb
1.0
smry_phb
2.0
smry_phb
3.0
smry_phb
4.0
smry_phb
all
smry_soc
1.0
smry_soc
2.0
smry_soc
3.0
smry_soc
4.0
smry_soc
all
smry_add
1.0
smry_add
2.0
smry_add
3.0
smry_add
4.0
smry_add
all
smry_odd
1.0
smry_odd
2.0
smry_odd
3.0
smry_odd
4.0
smry_odd
all
HC
1.0
HC
2.0
HC
3.0
HC
4.0
HC
all


In [60]:
# result_table.sort_index(ascending=True, inplace = True)
result_table2 = result_table.reindex(['HC', 'smry_phb', 'smry_soc', 'smry_dep', 'smry_ptd', 'smry_odd', 'smry_add'])
result_table2 = result_table2.set_axis(['HC', 'Specific phobia', 'Social phobia', 'Depression', 'PTSD', 'ODD', 'ADHD'], axis=0)
result_table2

result_table2

Unnamed: 0_level_0,1.0,1.0,1.0,2.0,2.0,2.0,3.0,3.0,3.0,4.0,4.0,4.0,all,all,all
Unnamed: 0_level_1,mean,95% CI lower,95% CI upper,mean,95% CI lower,95% CI upper,mean,95% CI lower,95% CI upper,mean,95% CI lower,95% CI upper,mean,95% CI lower,95% CI upper
HC,0.733834,0.654927,0.809295,0.489276,0.365514,0.601812,0.734236,0.661606,0.801121,0.248268,0.126916,0.382633,0.756766,0.68362,0.825912
Specific phobia,0.670287,0.623889,0.714463,0.429798,0.351129,0.515054,0.57929,0.507877,0.646944,0.107717,0.0529266,0.168822,0.706906,0.65731,0.753257
Social phobia,0.583906,0.507445,0.659968,0.373365,0.2892,0.458088,0.483251,0.394333,0.567434,0.105968,0.0519738,0.169318,0.608581,0.544888,0.673126
Depression,0.586464,0.497877,0.684422,0.254627,0.150867,0.396532,0.64461,0.568629,0.716573,0.146993,0.0631179,0.235126,0.674934,0.619671,0.735703
PTSD,0.541285,0.440719,0.643,0.272575,0.123402,0.468479,0.47502,0.362737,0.582764,0.155392,0.0620738,0.281625,0.589354,0.490859,0.678429
ODD,0.474205,0.402744,0.544972,0.249405,0.178734,0.319506,0.440751,0.361911,0.523011,0.0633565,0.0268769,0.109862,0.54586,0.483539,0.60454
ADHD,0.598494,0.500216,0.690019,0.386152,0.273226,0.496955,0.596974,0.509714,0.673281,0.152453,0.0611222,0.250799,0.702935,0.616498,0.77622


In [62]:
result_table2.to_csv(output_dir + '/out04_r2_boot.csv')