## DWI_RI_CCA_Analysis_PMD

In [None]:
from cca_zoo.models import SCCA_PMD
from cca_zoo.model_selection import permutation_test_score
import copy
from datetime import date
import dill
from glob import glob
import matplotlib.pyplot as plt
from matplotlib import rc
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.model_selection import KFold, train_test_split
from sklearn.utils import resample
from sklearn.preprocessing import StandardScaler
import statsmodels.api as sm
from statsmodels.stats.multitest import fdrcorrection as fdrcorr
from statsmodels.stats.outliers_influence import variance_inflation_factor
import statsmodels.formula.api as smf
from statsmodels.discrete.discrete_model import Poisson 
from scipy.stats import spearmanr

import time
import traceback

# Widen pandas' display limits so large frames print in full in the notebook
pd.set_option('display.max_rows', 999, 'display.max_columns', 999)

# ISO date string (YYYY-MM-DD) used to tag saved figures and tables
today = date.today().isoformat()

In [None]:
# Set plotting defaults: white grid background with a serif (Times New Roman) font.
# A single call is enough; set_style returns None, so wrapping it in another
# sns.set_style(...) only re-applied the current style.
sns.set_style('whitegrid', {'font.family': 'serif', 'font.serif': 'Times New Roman'})
sns.set_palette('nipy_spectral_r', n_colors=15)

In [None]:
# Data paths
candpath = '/gpfs/milgram/pi/gee_dylan/candlab'
datapath = candpath + '/analyses/shapes/dwi/QSIPrep'
newri = '/gpfs/milgram/pi/gee_dylan/lms233/RI_Data/coded_output'
analysis = '/gpfs/milgram/pi/gee_dylan/candlab/analyses/shapes/dwi/QSIPrep/analysis'


# Import data
full_df_raw = pd.read_csv(analysis+ '/DWI_RI_FullDataset_RegressedCovariates_InclSex_n=107_2024-04-16_GFA_QA_RD_ZIPBehavModel_ages0-17_RIAgeRegressed.csv') 

# Every downstream cell reads `full_df`, which was otherwise never defined (NameError).
# Work on a copy so the raw frame stays untouched for reference.
full_df = full_df_raw.copy()

In [None]:
# Examine variable multicollinearity: pairwise correlations between the `_regr`
# exposure-count columns for ages 0-17.
bx_df_coll = full_df.loc[:, "all_0.0_regr":'all_17.0_regr']
# Turn labels like 'all_3.0_regr' into readable ones like 'Exposures at age 3'
bx_df_coll.columns = bx_df_coll.columns.str.replace('_', ' ').str.replace('all', 'Exposures at age').str.replace('.0 regr', '')

corrMatrix = bx_df_coll.corr()  # pandas default method (Pearson)

fig, ax = plt.subplots(1, 1, figsize = (20,20))
sns.heatmap(corrMatrix, annot=True, ax=ax, vmin = -1, vmax=1, annot_kws = {'fontsize':12}, cmap= 'coolwarm')

# Date-stamped filename so reruns on other days do not overwrite earlier figures
plt.savefig(analysis + "/figures/Adversity_Heatmap_{}.png".format(today), dpi=300, transparent=True) 

In [None]:
# Sum the raw exposure counts by age across participants, including exposures whose
# age was not reported (coded as 999).
hist_df = full_df.loc[:, "all_0.0":"all_999.0"]
hist_df.columns = hist_df.columns.str.replace('_', ' ').str.replace('all', 'Exposures at age').str.replace('Exposures at age 999.0', 'Exposures with age not reported')
# Melt the frame directly: calling reset_index() first also melted the row index,
# adding a spurious "index" category whose "count" was the sum of row labels.
pivot_df = (hist_df.melt(var_name='Age of Exposure', value_name='Number of Exposures')
            .groupby('Age of Exposure').sum().reset_index())

# Plot ages in numeric order (string sorting would put 10 before 2)
order = ['Exposures at age {}.0'.format(age) for age in range(31)] + ['Exposures with age not reported']

fig, ax = plt.subplots(1, 1, figsize=(14, 10))
sns.barplot(x = 'Age of Exposure', y='Number of Exposures', data = pivot_df, ax=ax, order=order, palette = sns.color_palette("Reds_r", n_colors=round(len(pivot_df)*1.25)))
plt.xticks(rotation=90)
ax.set_title('Summed Adversity Exposures by Age Across Participants', size =24)
plt.tight_layout()

plt.savefig(analysis + "/figures/Adversity_Distribution_{}.png".format(today), dpi=300, transparent=True) 

### Visualization

In [None]:
# One histogram panel per age for the `_regr` exposure columns (ages 0-17)
analysis_cols = ['all_{}.0_regr'.format(age) for age in range(18)]

fig, ((ax1, ax2, ax3, ax4, ax5), 
      (ax6, ax7, ax8, ax9, ax10),
      (ax11, ax12, ax13, ax14, ax15), 
      (ax16, ax17, ax18, ax19, ax20)) = plt.subplots(4, 5, figsize = (25, 10)) 

axes = [ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8, ax9, ax10, ax11, ax12, ax13, ax14, ax15, ax16, ax17, ax18, ax19, ax20]

for i in range(0, len(analysis_cols)):
    dist_col = full_df[analysis_cols[i]] # Select column to plot
    sns.histplot(x = dist_col, ax=axes[i], bins=20)
    axes[i].set_xlim(-3, 7)

# The 4x5 grid has more panels than ages; hide the leftover ones so empty axes are not saved
for unused_ax in axes[len(analysis_cols):]:
    unused_ax.set_visible(False)

fig.tight_layout()
fig.savefig(analysis + "/figures/AdversityDistributions_{}.png".format(today), dpi=300, transparent=True) 

### Demographics

Sex coding: Female = 1.0; Male = 0.0

In [None]:
# Sample composition (sex: 1 = female, 0 = male; site_bin: 0 = BIC, 1 = MRRC)
n_total = len(full_df)
print("Overall sample size is {}".format(n_total))

n_female = len(full_df[full_df['sex'] == 1])
n_male = len(full_df[full_df['sex'] == 0])
print("{}% Female; {}% Male".format(round(n_female / n_total * 100, 3),
                                    round(n_male / n_total * 100, 3)))

print("Mean age: {}".format(round(full_df['age_at_scan'].mean(), 3)))

n_bic = len(full_df[full_df['site_bin'] == 0])
n_mrrc = len(full_df[full_df['site_bin'] == 1])
print("{} subjects scanned at BIC, {} at MRRC".format(n_bic, n_mrrc))

### CCA Analysis

In [None]:
# Resources: https://vitalflux.com/pca-explained-variance-concept-python-example/

def compute_covariance_explained(transformed_df1, transformed_df2):
    """Return the share of squared covariance carried by each pair of canonical variates.

    For every column position k, the sample covariance (ddof=1) between column k of
    ``transformed_df1`` and column k of ``transformed_df2`` is computed, and the result
    is cov_k**2 / sum_j cov_j**2 (R: diag(covmat)^2 / sum(diag(covmat)^2)).

    A frame is standardized with StandardScaler first only when its column means are
    not all close to zero; already-centred frames are used unchanged. Columns are
    matched by position, so ``transformed_df2`` needs at least as many columns as
    ``transformed_df1``.

    Returns
    -------
    numpy.ndarray
        One fraction per column of ``transformed_df1``.
    """
    scaler = StandardScaler()

    def _standardize_if_uncentred(frame):
        if np.allclose(frame.mean(), 0):
            return frame
        return pd.DataFrame(scaler.fit_transform(frame), columns=frame.columns)

    left = _standardize_if_uncentred(transformed_df1)
    right = _standardize_if_uncentred(transformed_df2)

    # Off-diagonal entry of each 2x2 covariance matrix = covariance of the k-th pair
    pair_cov = np.array(
        [np.cov(left.iloc[:, k], right.iloc[:, k], ddof=1)[0][1] for k in range(left.shape[1])],
        dtype=float,
    )
    squared = pair_cov ** 2
    return squared / np.sum(squared)

In [None]:
def get_overlap(weights, loadings, component):
    """Return the loadings on ``component`` for variables whose weight on it is non-zero.

    ``weights`` and ``loadings`` are DataFrames with one column per component; rows
    are matched by index label.
    """
    has_weight = weights[component].ne(0)
    return loadings[component].loc[has_weight]

### Analyses

#### Define Model Metrics

In [None]:
# Diffusion metric analyzed in this run: selects the '<metric>_..._regr' tract columns below
metric_name = 'rd'
metname_cap = 'RD'  # upper-case label used in figure titles and output filenames
# Tag identifying which bootstrap/permutation output directory to read
analysis_date = '2024-03-15_ZIPBehavModel_ages0-17_RIAgeRegressed'
# NOTE(review): '0.5_0.5' presumably matches cval = [.5, .5] used for the main fit — confirm
model_output = analysis + '/model_output_0.5_0.5_{}_{}'.format(metric_name, analysis_date)

In [None]:
def fit_model(in_xmat, in_ymat, c_val):
    """Fit an ``SCCA_PMD`` sparse CCA model on two views and return its outputs.

    Parameters
    ----------
    in_xmat, in_ymat : DataFrame or array
        The two views (here adversity exposures and tract metrics), same rows.
    c_val : list of float
        Passed to ``SCCA_PMD`` as ``c`` (the caller passes one value per view).

    Returns
    -------
    model : the fitted ``SCCA_PMD`` instance
    model_score : result of ``model.score`` on the training views
    model_results : the X-vs-Y entry (``[0][1]``) of ``model.pairwise_correlations``;
        later cells index it by component
    tranf_df_bx, tranf_df_fa : DataFrame
        Transformed X (``Variate_k_Stress``) and Y (``Variate_k_DTI``) scores, with any
        column containing NaN dropped.
    """
    n_components = in_xmat.shape[1]  # one latent dimension per X variable
    views = (in_xmat, in_ymat)

    # Fit CCA model
    model = SCCA_PMD(latent_dims=n_components,
                     random_state=0,
                     scale=True,
                     centre=True,
                     max_iter=1000,
                     c=c_val)

    model.fit(views)
    model_score = model.score(views)
    model_results = model.pairwise_correlations(views)[0][1]

    gen_colnames_stress = ['Variate_{}_Stress'.format(k) for k in range(1, n_components + 1)]
    gen_colnames_dti = ['Variate_{}_DTI'.format(k) for k in range(1, n_components + 1)]

    # Transform X and y matrices once: the fitted model's projection does not change
    # between calls, so a single transform serves both views.
    transformed = model.transform(views)
    tranf_df_bx = pd.DataFrame(transformed[0], columns=gen_colnames_stress).dropna(axis=1)
    tranf_df_fa = pd.DataFrame(transformed[1], columns=gen_colnames_dti).dropna(axis=1)

    return model, model_score, model_results, tranf_df_bx, tranf_df_fa

In [None]:
# Select columns for x-dataset to analyze (adversity exposure counts, ages 0-17)
scaler = StandardScaler()

analysis_columns = ['all_{}.0_regr'.format(age) for age in range(18)]

# X view: standardized exposures; missing counts are treated as zero exposures
in_xmat_data = pd.DataFrame(scaler.fit_transform(full_df[analysis_columns].replace(np.nan, 0.0)),
                            columns = analysis_columns)

# Y view: standardized tract metrics, from the AF-left column through the ST_PREM-right column
ymat_raw = full_df.loc[:, "{}_AF_left_regr".format(metric_name):"{}_ST_PREM_right_regr".format(metric_name)]
in_ymat_data = pd.DataFrame(scaler.fit_transform(ymat_raw), columns = ymat_raw.columns)

# `np.nan in df` only searches the COLUMN LABELS, so it could never detect missing values.
# Check the cell values instead (StandardScaler keeps NaNs in its output).
assert not in_xmat_data.isna().to_numpy().any(), 'NaN values in adversity (X) matrix'
assert not in_ymat_data.isna().to_numpy().any(), 'NaN values in diffusion (Y) matrix'

#Set regularization parameter (derived from SelectHyperparameters script)
cval = [.5, .5]

model, main_model_score, main_model_results, model_df_bx, model_df_fa = fit_model(in_xmat_data, in_ymat_data, cval)

# Compute covariance of components (Helpful: https://towardsdatascience.com/5-things-you-should-know-about-covariance-26b12a0516f1)
model_exp_var = compute_covariance_explained(model_df_bx, model_df_fa)

print("First component correlation strength: {}".format(main_model_score[0]))

In [None]:
# Examine model loadings of the main (non-bootstrapped) model
ex_comp = "Variate_1"  # component whose loadings are ranked below
n_components = in_xmat_data.shape[1]

gen_colnames = []
for i in range(0, n_components):
    gen_colnames.append('Variate_{}'.format(i+1))

# NOTE(review): despite the `*_weights_mmod` names, both frames are filled from
# model.get_factor_loadings(...) — confirm against the cca_zoo docs whether these are
# loadings (structure correlations) rather than canonical weights.

# Adversity view: rows = exposure columns, ranked by |value| on ex_comp
pheno_weights_mmod = pd.DataFrame(model.get_factor_loadings([in_xmat_data, in_ymat_data])[0], index = in_xmat_data.columns, 
                             columns = gen_colnames)
pheno_weights_mmod['abs_weights'] = abs(pheno_weights_mmod[ex_comp]) #Absolute value of weights to include negative associations
pheno_weights_mmod.sort_values(by = 'abs_weights', ascending = False, inplace=True)
 
# Diffusion view: rows = tract columns. The first sort is intermediate; the final order
# comes from the abs_weights sort.
neuro_weights_mmod = pd.DataFrame(model.get_factor_loadings([in_xmat_data, in_ymat_data])[1], index = in_ymat_data.columns,
                            columns = gen_colnames).sort_values(by = ex_comp, ascending = False)
neuro_weights_mmod['abs_weights'] = abs(neuro_weights_mmod[ex_comp])  #Absolute value of weights to include negative associations
neuro_weights_mmod.sort_values(by = 'abs_weights', ascending = False, inplace=True)

In [None]:
# Glob the per-iteration result files: Main_* = bootstrapped fits, Shuffled_* = shuffled null.
# glob() returns paths in arbitrary, filesystem-dependent order; sort each list so the
# iteration order (and any downstream pairing by position) is reproducible across runs.
main_res_files = sorted(glob(model_output + '/Main_Bootstrapped_Results_Correlation*.csv'))
main_cov_files = sorted(glob(model_output + '/Main_Bootstrapped_Results_Covariation*.csv'))
shuff_res_files = sorted(glob(model_output + '/Shuffled_Bootstrapped_Results_Correlation*.csv'))
shuff_cov_files = sorted(glob(model_output + '/Shuffled_Bootstrapped_Results_Covariation*.csv'))
assert len(main_res_files)==10000
assert len(main_cov_files)==10000
assert len(shuff_res_files)==10000
assert len(shuff_cov_files)==10000

In [None]:
# Report how many bootstrap and shuffled correlation files were found
for file_list in (main_res_files, shuff_res_files):
    print(len(file_list))

In [None]:
# Collect column 1 of every per-iteration result CSV (one Series per file).
# NOTE(review): column 0 is presumably the index saved with the CSV — confirm.
main_results = []
main_cov = []
main_res_list = []
main_cov_list = []
shuffled_res_list = []
shuffled_cov_list = []

for i, file_main_cor in enumerate(main_res_files):
    file_main_cov, file_shuff_cor, file_shuff_cov = main_cov_files[i], shuff_res_files[i], shuff_cov_files[i]
    main_cor, main_cov = pd.read_csv(file_main_cor), pd.read_csv(file_main_cov)
    shuff_cor, shuff_cov = pd.read_csv(file_shuff_cor), pd.read_csv(file_shuff_cov)
    for frame, bucket in ((main_cor, main_res_list),
                          (main_cov, main_cov_list),
                          (shuff_cor, shuffled_res_list),
                          (shuff_cov, shuffled_cov_list)):
        bucket.append(frame.iloc[:, 1])

In [None]:
# Stack the per-file Series into DataFrames: one row per iteration, one column per component
final_main_cor, final_main_cov, final_shuff_cor, final_shuff_cov = map(
    pd.DataFrame,
    (main_res_list, main_cov_list, shuffled_res_list, shuffled_cov_list),
)

In [None]:
# Permutation testing function (one-sided)
def compute_perm_val(val, score, *, plus_one=False):
    """Return the one-sided permutation p-value of ``val`` against a null distribution.

    The p-value is the fraction of null values that are greater than or equal to
    ``val``. NaNs in ``score`` never count as exceedances but do count towards the
    total.

    Parameters
    ----------
    val : float
        Observed statistic.
    score : array-like
        Null distribution (e.g. statistics from shuffled-data refits).
    plus_one : bool, keyword-only, default False
        If True, return (b + 1) / (m + 1) so the p-value can never be exactly zero.
        Off by default so previously reported p-values are reproduced.

    Returns
    -------
    float

    Raises
    ------
    ValueError
        If ``score`` is empty.
    """
    null = np.asarray(score, dtype=float).ravel()
    if null.size == 0:
        raise ValueError('compute_perm_val: the null distribution `score` is empty')
    exceed = int(np.count_nonzero(null >= val))
    if plus_one:
        return (exceed + 1) / (null.size + 1)
    return exceed / null.size

In [None]:
print('Main model correlation strength significance: \n')
n_components = 18

pvals_cor = []
pvals_cov = []

# Null distributions: rows are shuffled iterations, columns are components (checked once)
assert final_shuff_cor.shape == (10000, 18) # Ensure correct structure
assert final_shuff_cov.shape == (10000, 18)

# Test EVERY component; range(n_components - 1) silently skipped the last one
for i in range(n_components):
    val = main_model_results[i] # ith modal result (correlation strength)
    cov_val = model_exp_var[i] # ith modal result (covariance explained)
    shuff_col = final_shuff_cor.iloc[:, i] # shuffled correlations for component i
    shuff_col_cov = final_shuff_cov.iloc[:, i] # shuffled covariance explained for component i
    pval = compute_perm_val(val, shuff_col)
    pval_cov = compute_perm_val(cov_val, shuff_col_cov)
    pvals_cor.append([i+1, pval])
    pvals_cov.append([i+1, pval_cov])
    print("For component {}, correlation p={}, covariance p={}".format(i+1, round(pval, 5), round(pval_cov, 5)))

In [None]:
print('AVERAGE correlation strength significance across bootstraps: \n')
n_components = 18

pvals_cor = []
pvals_cov = []

# Compare the MEDIAN bootstrapped statistic of each component with its shuffled null.
# Test every component; range(n_components - 1) silently skipped the last one.
for i in range(n_components):
    val = final_main_cor.iloc[:,i].median() # Bootstrapped values
    cov_val =  final_main_cov.iloc[:,i].median() # Bootstrapped values
    shuff_col = final_shuff_cor.iloc[:, i]
    assert len(shuff_col) == 10000
    shuff_col_cov = final_shuff_cov.iloc[:, i]
    assert len(shuff_col_cov) == 10000  # check the covariance null too, not shuff_col twice
    pval = compute_perm_val(val, shuff_col)
    pval_cov = compute_perm_val(cov_val, shuff_col_cov)
    pvals_cor.append([i+1, pval])
    pvals_cov.append([i+1, pval_cov])
    print("For component {}, correlation p={}, covariance p={}".format(i+1, round(pval, 5), round(pval_cov, 5)))

In [None]:
# Manual FDR correction across models-- variates that explain >.1 covariance
# The p-values below were copied by hand from the per-component significance output of
# each model run (listed in the comments). NOTE(review): they are hard-coded, so they
# must be re-copied whenever any of the GFA/QA/RD models is re-run.
## GFA Model:
# For component 1, correlation p=0.0053, covariance p=0.4503
# For component 2, correlation p=0.0156, covariance p=0.3229
# For component 3, correlation p=0.0732, covariance p=0.6595
# ## QA Model:
# For component 1, correlation p=0.0075, covariance p=0.4068
# For component 2, correlation p=0.0248, covariance p=0.3832
# For component 3, correlation p=0.0381, covariance p=0.6362
## RD Model:
# For component 1, correlation p=0.0113, covariance p=0.6138
# For component 2, correlation p=0.0109, covariance p=0.4264
# For component 3, correlation p=0.0514, covariance p=0.5287

manual_p_names = ['GFA_Comp1', 'GFA_Comp2', 'GFA_Comp3', 
                  'QA_Comp1', 'QA_Comp2', 'QA_Comp3', 
                  'RD_Comp1', 'RD_Comp2', 'RD_Comp3']
# Correlation p-values only, in the same order as manual_p_names
manual_ps = [0.0053, 0.0156, 0.0732, 
             0.0075, 0.0248, 0.0381, 
             0.0113, 0.0109, 0.0514] #Double checked 10/7/23
man_p_df = pd.DataFrame(manual_ps, index = manual_p_names, columns = ['pval'])

In [None]:
# fdrcorrection returns (reject mask, corrected p-values); default method is Benjamini-Hochberg
from statsmodels.stats.multitest import fdrcorrection as fdr
p_df = pd.DataFrame(pvals_cor, columns = ['Variate', 'pval'])  # not used further in this cell

# NOTE(review): this is an alias, not a copy — the columns added below also appear on man_p_df
test_df = man_p_df #
test_df['Passed'], test_df['FDR_pval'] = fdr(test_df['pval'])
test_df.sort_values(by='pval', ascending=True)

In [None]:
# Set variables for plotting: per-component summaries across iterations (one value per column)
cov_plot_df = final_main_cov.median(axis=0)
cor_plot_df = final_main_cor.median(axis=0)
# NOTE(review): the shuffled summary uses the MEAN while the bootstrapped ones use the
# median, yet the correlation plot labels both 'Medians' — confirm this is intended.
shuff_cor_plot_df = final_shuff_cor.mean(axis=0)

In [None]:
# Plot covariance explained (median across bootstraps), components ordered from largest to smallest
sns.set_palette('Greens_r', n_colors=25)

cov_df = pd.DataFrame(cov_plot_df, columns = ['cov'])
# Assumes n_components equals the number of columns in final_main_cov
cov_df['Variate'] = pd.Series(range(1,n_components+1)).astype(str)
cov_sorted = cov_df.sort_values(by='cov', ascending = False)
# survived_corr = cov_df.set_index('Variate').drop(failed).reset_index()

# Plot Covariances
plt.rcParams.update({'font.size': 16})

fig, ax1 = plt.subplots(1,1, figsize = (8,5));
# NOTE(review): `kind` is a catplot argument rather than a documented pointplot parameter,
# and `join`/`scale` may emit deprecation warnings on newer seaborn — confirm with the
# installed version.
scatplot1 = sns.pointplot(x = cov_sorted['Variate'], y = cov_sorted['cov'], kind = "line",
                         join = True, hue = cov_df['Variate'], ax=ax1, scale=1.5);

scatplot1.legend_.remove()  # hue only colors the points; the legend is not needed
fig.suptitle('Covariance Explained by Canonical Mode\n({} Model)'.format(metname_cap), fontsize = 22, fontweight='bold')
ax1.set_xlabel('Canonical Mode', fontsize = 20)
ax1.set_ylabel('Covariance Explained', fontsize = 20)

plt.xticks(fontsize = 12);
fig.tight_layout()
fig = scatplot1.get_figure()

fig.savefig(analysis + "/figures/CovarianceExplained_{}.png".format(metric_name), dpi=300, transparent=True) 

In [None]:
# Plot correlation strengths: bootstrapped summary per component, shuffled summary overlaid in gray
res_real = pd.DataFrame(cor_plot_df)
# Single-column frame, so the row-wise median simply copies that column
res_real['Medians'] = res_real.median(axis=1)
res_real.reset_index(inplace=True)  # 'index' column = 0-based component position

res_shuff = pd.DataFrame(shuff_cor_plot_df)
res_shuff['Medians'] = res_shuff.median(axis=1)
res_shuff.reset_index(inplace=True)

fig, ax1 = plt.subplots(1,1, figsize = (8,5));
# +1 converts 0-based positions to 1-based canonical mode numbers
sns.barplot(x = res_real['index']+1, y = res_real['Medians'], ax = ax1,
           linewidth=0)
sns.barplot(x = res_shuff['index']+1, y = res_shuff['Medians'], ax = ax1, color='gray', 
            alpha = .8, linewidth=0)
           
fig.suptitle('Correlation Strength by Canonical Mode\n({} Model)'.format(metname_cap), fontsize=22, fontweight='bold')

ax1.set_xlabel('Canonical Mode', fontsize=20)
ax1.set_ylabel('Correlation Strength', fontsize=20)
plt.ylim(0, 1)
fig.tight_layout()
fig.savefig(analysis + "/figures/CorrelationStrengths_{}.png".format(metric_name), dpi=300, transparent=True) 

In [None]:
# Glob model loadings (REAL). glob() order is arbitrary and filesystem-dependent; sort so
# iteration i means the same file on every run (the Wilcoxon tests later pair REAL and
# SHUFFLED loadings by position).
adv_loadingfiles = sorted(glob(model_output + '/Main_Loadings_AdvExp_iteration*.csv'))
dwi_loadingfiles = sorted(glob(model_output + '/Main_Loadings_DWI_iteration*.csv'))
 
assert len(adv_loadingfiles) == 10000
assert len(dwi_loadingfiles) == 10000

# Glob model loadings (SHUFFLED)
shuff_adv_loadingfiles = sorted(glob(model_output + '/Shuffled_Loadings_AdvExp_iteration*.csv'))
shuff_dwi_loadingfiles = sorted(glob(model_output + '/Shuffled_Loadings_DWI_iteration*.csv'))

assert len(shuff_adv_loadingfiles) == 10000
assert len(shuff_dwi_loadingfiles) == 10000

In [None]:
# Read in loadings (REAL). Each CSV is one iteration; after index_col=0 it must be
# (variables x components) to fit the slices below.
adversity_loadings = np.ones((18, n_components, len(adv_loadingfiles))) # Structure: variables, components, total iterations
diffusion_loadings = np.ones((43, n_components, len(dwi_loadingfiles))) # Structure: variables, components, total iterations

assert len(adv_loadingfiles) == len(dwi_loadingfiles)
for i, (adv_path, dwi_path) in enumerate(zip(adv_loadingfiles, dwi_loadingfiles)):
    adv_file = pd.read_csv(adv_path, index_col=0)
    adversity_loadings[:, :, i] = adv_file
    dwi_file = pd.read_csv(dwi_path, index_col=0)
    diffusion_loadings[:, :, i] = dwi_file
    
# Read in loadings (SHUFFLED)
shuff_adv_loadings = np.ones((18, n_components, len(shuff_adv_loadingfiles))) # Structure: variables, components, total iterations
shuff_dwi_loadings = np.ones((43, n_components, len(shuff_dwi_loadingfiles))) # Structure: variables, components, total iterations

# Iterate over the SHUFFLED file lists themselves (the bound previously came from the REAL list)
assert len(shuff_adv_loadingfiles) == len(shuff_dwi_loadingfiles)
for i, (shuff_adv_path, shuff_dwi_path) in enumerate(zip(shuff_adv_loadingfiles, shuff_dwi_loadingfiles)):
    shuff_adv_file = pd.read_csv(shuff_adv_path, index_col=0)
    shuff_adv_loadings[:, :, i] = shuff_adv_file
    shuff_dwi_file = pd.read_csv(shuff_dwi_path, index_col=0)
    shuff_dwi_loadings[:, :, i] = shuff_dwi_file

In [None]:
# Median loading of every (variable, component) pair across iterations.
# Arrays are variables x components x iterations, so reduce over axis 2.
adv_loadingvals, dwi_loadingvals = (
    np.median(loading_stack, axis=2)
    for loading_stack in (adversity_loadings, diffusion_loadings)
)

In [None]:
# Plot histogram of diffusion loadings to get a sense of whether normally distributed
# (first DWI variable, first component, all bootstrap iterations)
sns.histplot(data = diffusion_loadings[0, 0, :], bins = 60)

In [None]:
# Examine model loadings (medians across bootstrap iterations)
ex_comp = "Variate_1" # Component to examine

gen_colnames = [] # List of generated column names for each variate
for i in range(0, n_components):
    gen_colnames.append('Variate_{}'.format(i+1))

# Compute loadings for REAL models: rows = adversity variables, ranked by |loading| on ex_comp
adv_loading_medians = pd.DataFrame(adv_loadingvals, index = in_xmat_data.columns, 
                             columns = gen_colnames)
adv_loading_medians['abs_loadings'] = abs(adv_loading_medians[ex_comp]) #Absolute value of loadings to include negative associations
adv_loading_medians.sort_values(by = 'abs_loadings', ascending = False, inplace=True)
 
# Rows = tract variables. The first sort is intermediate; the final order comes from the
# abs_loadings sort.
dwi_loading_medians = pd.DataFrame(dwi_loadingvals, index = in_ymat_data.columns,
                            columns = gen_colnames).sort_values(by = ex_comp, ascending = False)
dwi_loading_medians['abs_loadings'] = abs(dwi_loading_medians[ex_comp])  #Absolute value of loadings to include negative associations
dwi_loading_medians.sort_values(by = 'abs_loadings', ascending = False, inplace=True)

In [None]:
# Display the median adversity loadings on ex_comp (already sorted by |loading|)
adv_loading_medians.loc[:, ex_comp]

In [None]:
# Display the 20 largest-magnitude median DWI loadings on ex_comp
dwi_loading_medians.loc[:, ex_comp][0:20]

In [None]:
# Plot Adversity loadings (median across bootstrap iterations, component ex_comp)

pltdf2 = pd.DataFrame(adv_loading_medians.loc[:, ex_comp])
pltdf2.columns = pltdf2.columns.str.replace('_', ' ')

# 'all_3.0_regr' -> 'Exposures at age 3'
labels = pltdf2.index.str.replace('_', ' ').str.replace('all', 'Exposures at age').str.replace('.0 regr', '')
fig, ax = plt.subplots(1,1, figsize=(8,8))

sns.barplot(x=pltdf2.iloc[:,0], y = labels, palette = sns.color_palette("Reds_r", n_colors=round(len(pltdf2)*1.5)), ax=ax)
# str.lstrip('Variate_') strips a *set of characters*, not a prefix; take the text after
# the last underscore to get the component number instead.
ax.set_xlabel('Canonical Mode {}'.format(ex_comp.split('_')[-1]))
fig.suptitle("Median {} Model Loadings for Adversity Exposure".format(metname_cap), fontsize = 20,
            fontweight='bold')
fig.tight_layout() 
fig.savefig(analysis + "/figures/AdversityLoadings_{}.png".format(metric_name), dpi=300, transparent=True) 

In [None]:
# Determine loading significance: Wilcoxon signed-rank test of REAL vs SHUFFLED loadings
# on the first component. The test is paired, so iteration k of the REAL list is matched
# with iteration k of the SHUFFLED list (order set by the file lists).
from scipy.stats import mannwhitneyu, wilcoxon

print('Loading significance across bootstraps: \n')

adv_loadings = pd.DataFrame(adversity_loadings[:,0,:], index = in_xmat_data.columns) #[:,0,:] selects first component across variables/iterations
shuff_loadings = pd.DataFrame(shuff_adv_loadings[:,0,:], index = in_xmat_data.columns) #[:,0,:] selects first component across variables/iterations
# NOTE: this is the number of adversity VARIABLES (rows), reassigned to the shared name
n_components = len(adv_loadings)

# Create list and run significance tests
pvals_loadings_cor_adv = []

for i in range(0, n_components):
    main_col = adv_loadings.iloc[i, :]  # Row i: this variable across all bootstrap iterations
    shuff_col = shuff_loadings.iloc[i, :] # Select row representing that variable
    assert len(main_col) == 10000 # Ensure fitting across iterations
    assert len(shuff_col) == 10000 # Ensure fitting across iterations
    stat, pval_loadings = wilcoxon(main_col, shuff_col, alternative= 'two-sided')
    print("For {}, stat={}, p={}".format(adv_loadings.index[i], round(stat, 5), round(pval_loadings, 3)))
    pvals_loadings_cor_adv.append([adv_loadings.index[i], stat, pval_loadings])

In [None]:
# Save results in dataframe and perform FDR correction (`fdr` = statsmodels fdrcorrection)
stat_table_adv = pd.DataFrame(pvals_loadings_cor_adv, columns = ['Variable', 'Wilcoxon Stat', 'p-value'])
stat_table_adv['Passed'], stat_table_adv['FDR_pval'] = fdr(stat_table_adv['p-value'])
# Rounding is for the saved table only and happens after the correction
stat_table_adv['FDR_pval'] = round(stat_table_adv['FDR_pval'], 3)
stat_table_adv['p-value'] = round(stat_table_adv['p-value'], 3)
stat_table_adv.to_csv(analysis + '/figures/table_loadings_adversity_{}_{}.csv'.format(metric_name, today))

In [None]:
# Select variables that did NOT survive FDR correction; later cells drop these
adv_loads_insig = stat_table_adv.loc[~stat_table_adv['Passed'], 'Variable']

In [None]:
# Plot DWI loadings: the 15 largest-magnitude median loadings on component ex_comp

pltdf = pd.DataFrame(dwi_loading_medians.loc[:, ex_comp].iloc[:15])
pltdf.columns = pltdf.columns.str.replace('_', ' ')

# 'rd_AF_left_regr' -> 'RD AF left'. Replace 'rd_' (with underscore, like 'gfa_'/'qa_');
# replacing bare 'rd' produced double-spaced labels ('RD  AF left').
labels = pltdf.index.str.replace('gfa_', 'GFA ').str.replace('qa_', 'QA ').str.replace('rd_', 'RD ').str.replace('_', ' ').str.replace(' regr', '')

fig, ax = plt.subplots(1,1, figsize=(8,8))

sns.barplot(x=pltdf.iloc[:,0], y = labels, palette = sns.color_palette("Blues_r", n_colors=round(len(pltdf)*1.5)), ax=ax)

fig.suptitle("Median {} Model Loadings for Tract Integrity".format(metname_cap), ha='center',
            fontsize=20, fontweight='bold')
# str.lstrip('Variate_') strips a character set, not a prefix; take the trailing number instead
ax.set_xlabel('Canonical Mode {}'.format(ex_comp.split('_')[-1]))
fig.tight_layout()
fig.savefig(analysis + "/figures/DWILoadings_{}.png".format(metric_name), dpi=300, transparent=True) 

In [None]:
print('Loading significance across bootstraps: \n')

# First-component loadings: rows = tract variables, columns = iterations
dwi_loadings = pd.DataFrame(diffusion_loadings[:,0,:], index = in_ymat_data.columns)
shuff_dwis = pd.DataFrame(shuff_dwi_loadings[:,0,:], index = in_ymat_data.columns)

# Create list and run significance tests (paired Wilcoxon, REAL vs SHUFFLED by iteration position)
pvals_loadings_cor_dwi = []

for i in range(0, len(dwi_loadings)):
    main_col = dwi_loadings.iloc[i,:]  # Row i: this tract across all bootstrap iterations
    shuff_col = shuff_dwis.iloc[i, :] # Select row representing that variable
    assert len(main_col) == 10000
    assert len(shuff_col) == 10000
    stat, pval_loadings = wilcoxon(main_col, shuff_col, alternative= 'two-sided')
    print("For {}: stat={}, p={}".format(dwi_loadings.index[i], round(stat, 5), round(pval_loadings, 3)))
    pvals_loadings_cor_dwi.append([dwi_loadings.index[i], stat, pval_loadings])

In [None]:
# Save results in dataframe and perform FDR correction (`fdr` = statsmodels fdrcorrection)
stat_table_dwi = pd.DataFrame(pvals_loadings_cor_dwi, columns = ['Variable', 'Wilcoxon Stat', 'p-value'])
stat_table_dwi['Passed'], stat_table_dwi['FDR_pval'] = fdr(stat_table_dwi['p-value'])
# Rounding is for the saved table only and happens after the correction
stat_table_dwi['FDR_pval'] = round(stat_table_dwi['FDR_pval'], 3)
stat_table_dwi['p-value'] = round(stat_table_dwi['p-value'], 3)
stat_table_dwi.to_csv(analysis + '/figures/table_loadings_dwi_{}_{}.csv'.format(metric_name, today))

In [None]:
# Select variables that did NOT survive FDR correction; later cells drop these
dwi_loads_insig = stat_table_dwi.loc[~stat_table_dwi['Passed'], 'Variable']

In [None]:
def _summarize_bootstrap_loadings(loading_iters):
    """Summarize loadings per variable (rows) across bootstrap iterations (columns).

    Returns mean, standard deviation (np.std default, ddof=0), median and the
    2.5th/97.5th percentiles, plus 'Abs_Median', sorted by 'Abs_Median' descending.
    """
    summary = pd.DataFrame(index=loading_iters.index)
    summary['Mean (Bootstrapped Models)'] = np.mean(loading_iters, axis=1)
    summary['Standard Deviation (Bootstrapped Models)'] = np.std(loading_iters, axis=1)
    summary['Median (Bootstrapped Models)'] = np.median(loading_iters, axis=1)
    summary['2.5th percentile (Bootstrapped Models)'] = np.percentile(loading_iters, 2.5, axis=1)
    summary['97.5th percentile (Bootstrapped Models)'] = np.percentile(loading_iters, 97.5, axis=1)
    summary['Abs_Median'] = abs(summary['Median (Bootstrapped Models)'])
    return summary.sort_values(by='Abs_Median', ascending=False)


# First-component loadings: rows = variables, columns = iterations
adv_loading_iters_full = pd.DataFrame(adversity_loadings[:, 0, :], index = in_xmat_data.columns)
dwi_loading_iters_full = pd.DataFrame(diffusion_loadings[:, 0, :], index = in_ymat_data.columns)

# Drop variables that did not differ significantly from shuffled data
adv_loading_iters = adv_loading_iters_full.drop(adv_loads_insig)
dwi_loading_iters = dwi_loading_iters_full.drop(dwi_loads_insig)

adv_loading_ci = _summarize_bootstrap_loadings(adv_loading_iters)
adv_loading_ci.to_csv(analysis + '/figures/Results_Descriptive_Stats_AdversityLoadings_{}_{}.csv'.format(metname_cap, today))

dwi_loading_ci = _summarize_bootstrap_loadings(dwi_loading_iters)
dwi_loading_ci.to_csv(analysis + '/figures/Results_Descriptive_Stats_DiffusionLoadings_{}_{}.csv'.format(metname_cap, today))

In [None]:
#Drop vars that are not significant, then keep the strongest 12 (adversity) / 20 (DWI) rows.
# Both median frames were sorted by |loading| descending when they were built.
adv_loads_dropped = adv_loading_medians.drop(adv_loads_insig).iloc[0:12,:]
dwi_loads_dropped = dwi_loading_medians.drop(dwi_loads_insig).iloc[0:20,:]

In [None]:
# Long format: one row per (iteration, variable) with the loading in 'value'
adv_loading_iters_long1 = adv_loading_iters.transpose().reset_index().melt(id_vars = 'index', var_name = 'age_exp')
# Remove the literal '_regr' suffix. str.rstrip('_regr') strips any trailing run of the
# characters {_, r, e, g}, which only happened to work for these column names.
adv_loading_iters_long1['age_exp'] = adv_loading_iters_long1['age_exp'].str.replace(r'_regr$', '', regex=True)
dwi_loading_iters_long1 = dwi_loading_iters.transpose().reset_index().melt(id_vars = 'index', var_name = 'metric')
print(adv_loading_iters_long1.shape)

# Keep only the iterations of the plotted (top-ranked) variables.
# Sanity check: each kept variable should contribute one row per iteration. (Comparing
# against the unfiltered long frame reported a false 'Merge error' whenever more
# variables were significant than are plotted.)
adv_subset_df = pd.DataFrame(adv_loads_dropped.loc[:, ex_comp]).reset_index().rename(columns = {'index':'age_exp'})
adv_subset_df['age_exp'] = adv_subset_df['age_exp'].str.replace(r'_regr$', '', regex=True)
adv_loading_iters_long = pd.merge(adv_subset_df, adv_loading_iters_long1, how='inner')
if len(adv_loading_iters_long) != len(adv_subset_df) * adv_loading_iters.shape[1]:
    print('Merge error')

dwi_subset_df = pd.DataFrame(dwi_loads_dropped.loc[:, ex_comp]).reset_index().rename(columns = {'index':'metric'})
dwi_loading_iters_long = pd.merge(dwi_subset_df, dwi_loading_iters_long1, how='inner')
if len(dwi_loading_iters_long) != len(dwi_subset_df) * dwi_loading_iters.shape[1]:
    print('Merge error')

# Median loading per plotted variable ('value_sing'), ordered by magnitude
adv_loadings_long = adv_loads_dropped.loc[:, ex_comp].reset_index().melt(id_vars = 'index', value_name = 'value_sing').rename(columns = {'index':'age_exp'})
adv_loadings_long['value_sing_abs'] = abs(adv_loadings_long['value_sing'])
adv_loadings_long = adv_loadings_long.sort_values(by='value_sing_abs', ascending=False)
dwi_loadings_long = dwi_loads_dropped.loc[:, ex_comp].iloc[:20].reset_index().melt(id_vars = 'index', value_name = 'value_sing').rename(columns = {'index':'metric'})
dwi_loadings_long['value_sing_abs'] = abs(dwi_loadings_long['value_sing'])
dwi_loadings_long = dwi_loadings_long.sort_values(by='value_sing_abs', ascending=False)
print(adv_loadings_long.shape)

In [None]:
# Plot medians on top of bootstrapped distributions

fig, (ax, ax2) = plt.subplots(1, 2, figsize = (16,8))
# Adversity data. Build labels such as 'Exposures at age 3'. The regex removes a trailing
# '.0' with or without a ' regr' suffix: the iteration labels had '_regr' stripped in the
# long-format step while the median labels still carry it, and both must produce the same
# category so violins and median markers share one set of y positions.
adv_loading_iters_long['age_exp'] = adv_loading_iters_long['age_exp'].str.replace('_', ' ').str.replace('all', 'Exposures at age').str.replace(r'\.0( regr)?$', '', regex=True)
adv_loadings_long['age_exp'] = adv_loadings_long['age_exp'].str.replace('_', ' ').str.replace('all', 'Exposures at age').str.replace(r'\.0( regr)?$', '', regex=True)
sns.violinplot(data = adv_loading_iters_long,
              x = 'value',
              y = 'age_exp',
              saturation = .7,
              palette = sns.color_palette("Reds_r", n_colors=round(len(pltdf2)*1.25)),
              inner = None,
              ax=ax)
# Median marker: large black point with a smaller white point on top
sns.pointplot(data = adv_loadings_long,
              x = 'value_sing',
              y = 'age_exp',
              color = 'black',
              scale=1.5,
              join=False,
              ax=ax)
sns.pointplot(data = adv_loadings_long,
              x = 'value_sing',
              y = 'age_exp',
              color = 'white',
              scale=1,
              join=False,
              ax=ax)

fig.suptitle('Strongest Loadings across 10,000 iterations ({} Model)'.format(metname_cap), fontsize = 22, fontweight='bold')
ax.set_ylabel('Adversity Exposure by Age', fontsize = 20, fontweight='bold')
ax.set_xlabel('Bootstrapped Loading Values\n({} Model)'.format(metname_cap), fontsize = 20, fontweight='bold')

#DWI data: replace 'rd_' (like 'gfa_'/'qa_') so labels read 'RD AF left', not 'RD  AF left'
dwi_loading_iters_long['metric'] = dwi_loading_iters_long['metric'].str.replace('gfa_', 'GFA ').str.replace('qa_', 'QA ').str.replace('rd_', 'RD ').str.replace('_', ' ').str.replace(' regr', '')
dwi_loadings_long['metric'] = dwi_loadings_long['metric'].str.replace('gfa_', 'GFA ').str.replace('qa_', 'QA ').str.replace('rd_', 'RD ').str.replace('_', ' ').str.replace(' regr', '')

sns.violinplot(data = dwi_loading_iters_long,
              x = 'value',
              y = 'metric',
              palette = sns.color_palette("Blues_r", n_colors=round(len(pltdf2)*1.25)),
              inner = None,
              color = 'white',
              saturation = .7,
              ax=ax2)
sns.pointplot(data = dwi_loadings_long,
              x = 'value_sing',
              y = 'metric',
              color = 'black',
              join=False,
              scale = 1.5,
              ax=ax2,
              fillstyle=None)
sns.pointplot(data = dwi_loadings_long,
              x = 'value_sing',
              y = 'metric',
              color = 'white',
              join=False,
              scale = 1,
              ax=ax2,
              fillstyle=None)

ax2.set_ylabel('White Matter Tracts', fontsize = 20, fontweight='bold')
ax2.set_xlabel('Bootstrapped Loading Values\n({} Model)'.format(metname_cap), fontsize = 20, fontweight='bold')
plt.tight_layout()

fig.savefig(analysis + "/figures/Bootstrapped_MainModel_Loadings_{}.png".format(metname_cap), dpi=300, transparent=True) 

### Symptom Analyses

In [None]:
# Read in model weights and loadings
# Collect the per-iteration transformed-data CSVs for each model (GFA, QA, RD).
# glob() returns files in arbitrary order; sorting makes the order deterministic. Because the
# adversity and DWI files share the same 'iteration*' suffix, index i then refers to the same
# iteration in the *_adv_* and *_dwi_* lists, as long as both folders hold the same iteration numbers.


def _transformed_files(metric, side):
    """Sorted paths of Transformed_Data_<side>_iteration*.csv in one model's output folder."""
    pattern = analysis + '/model_output_0.5_0.5_{}_{}/Transformed_Data_{}_iteration*.csv'.format(metric, analysis_date, side)
    return sorted(glob(pattern))


# GFA model
gfa_adv_tranf_files = _transformed_files('gfa', 'AdvExp')
gfa_dwi_tranf_files = _transformed_files('gfa', 'DWI')

# QA Model
qa_adv_tranf_files = _transformed_files('qa', 'AdvExp')
qa_dwi_tranf_files = _transformed_files('qa', 'DWI')

# RD Model
rd_adv_tranf_files = _transformed_files('rd', 'AdvExp')
rd_dwi_tranf_files = _transformed_files('rd', 'DWI')

# Each folder should contain one file per iteration (10,000). If not, the next cell can
# find the missing iteration numbers.
for _list_name, _files in (('gfa_adv_tranf_files', gfa_adv_tranf_files),
                           ('gfa_dwi_tranf_files', gfa_dwi_tranf_files),
                           ('qa_adv_tranf_files', qa_adv_tranf_files),
                           ('qa_dwi_tranf_files', qa_dwi_tranf_files),
                           ('rd_adv_tranf_files', rd_adv_tranf_files),
                           ('rd_dwi_tranf_files', rd_dwi_tranf_files)):
    assert len(_files) == 10000, '{} has {} files, expected 10000'.format(_list_name, len(_files))

In [None]:
# Uncomment and run this if one of the model output folders does not contain 10,000 files; it lists the missing iteration numbers (swap in the file list you want to check)

# x = list(range(0, 10000))

# nos = []
# for i in range(0, len(qa_adv_tranf_files)):
#     iterno = int(qa_adv_tranf_files[i].split('iteration')[-1].rstrip('.csv'))
#     nos.append(iterno)
    
# [ele for ele in x if ele not in nos]

In [None]:
# Column labels for the 18 canonical variates, numbered from 1:
# 'Variate_<k>_Stress' for the adversity side and 'Variate_<k>_DWI' for the DWI side.
cols_adv = ['Variate_{}_Stress'.format(k) for k in range(1, 19)]
cols_dwi = ['Variate_{}_DWI'.format(k) for k in range(1, 19)]

In [None]:
# Read in transformed data
# Stack each model's per-iteration transformed data into arrays of shape
# (participants, variates, iterations) = (107, 18, 10000); slice [:, :, i] holds file i of its list.
n_tranf_subjects, n_tranf_variates, n_tranf_iters = 107, 18, 10000

# The file counts do not change while reading, so they are checked once, up front, instead of on
# every iteration of the read loop.
for _files in (gfa_adv_tranf_files, gfa_dwi_tranf_files,
               qa_adv_tranf_files, qa_dwi_tranf_files,
               rd_adv_tranf_files, rd_dwi_tranf_files):
    assert len(_files) == n_tranf_iters


def _stack_transformed(files):
    """Read each CSV in `files` (first column used as row index) into one stacked array.

    Returns an array of shape (n_tranf_subjects, n_tranf_variates, len(files)). Slots start as
    NaN rather than 1, so a slot that was never filled from a file cannot pass as real data.
    """
    stacked = np.full((n_tranf_subjects, n_tranf_variates, len(files)), np.nan)
    for k, path in enumerate(files):
        stacked[:, :, k] = pd.read_csv(path, index_col=0, header = 0)
    return stacked


gfa_adv_transformed = _stack_transformed(gfa_adv_tranf_files)
gfa_dwi_transformed = _stack_transformed(gfa_dwi_tranf_files)

qa_adv_transformed = _stack_transformed(qa_adv_tranf_files)
qa_dwi_transformed = _stack_transformed(qa_dwi_tranf_files)

rd_adv_transformed = _stack_transformed(rd_adv_tranf_files)
rd_dwi_transformed = _stack_transformed(rd_dwi_tranf_files)

In [None]:
# Compute means across iterations
# Average every participant's variate scores over the iteration axis (axis 2) and label the columns:
# adversity frames use cols_adv, DWI frames use cols_dwi.
gfa_tranf_df_adv, qa_tranf_df_adv, rd_tranf_df_adv = [
    pd.DataFrame(np.mean(stack, axis=2), columns = cols_adv)
    for stack in (gfa_adv_transformed, qa_adv_transformed, rd_adv_transformed)
]
gfa_tranf_df_dwi, qa_tranf_df_dwi, rd_tranf_df_dwi = [
    pd.DataFrame(np.mean(stack, axis=2), columns = cols_dwi)
    for stack in (gfa_dwi_transformed, qa_dwi_transformed, rd_dwi_transformed)
]

assert rd_tranf_df_adv.shape == (107, 18)

In [None]:
ex_comp = "Variate_1"

from scipy.stats import ttest_ind, spearmanr
# For each model, test whether the mean transformed scores for component `ex_comp` differ between
# sex == 0 and sex == 1 (independent-samples t-test). Also test whether they correlate with
# years of education and combined income (Spearman), ignoring missing values.
adv_dfs = [gfa_tranf_df_adv, qa_tranf_df_adv, rd_tranf_df_adv]
dwi_dfs = [gfa_tranf_df_dwi, qa_tranf_df_dwi, rd_tranf_df_dwi]

test_metrics = ['GFA', "QA", "RD"]
for metric_name, tranf_df_adv, tranf_df_dwi in zip(test_metrics, adv_dfs, dwi_dfs):
    # NOTE(review): an axis=1 concat aligns rows by index label, which assumes full_df has the
    # same 0..n-1 index as the transformed frames (built with a default index) — confirm.
    models_corr_df = pd.concat([full_df, tranf_df_dwi, tranf_df_adv], axis=1) #.dropna(axis=0)
    m_df = models_corr_df[models_corr_df['sex'] == 0]
    f_df = models_corr_df[models_corr_df['sex'] == 1]
    stress_col = '{}_Stress'.format(ex_comp)  # adversity ("Trauma") variate
    dwi_col = '{}_DWI'.format(ex_comp)        # diffusion variate

    # r1/r2 are t statistics; r3-r6 are Spearman correlations. Each pair is computed for the
    # variate named in its printed label below.
    r1, p1 = ttest_ind(m_df[stress_col], f_df[stress_col], nan_policy= 'omit')
    r2, p2 = ttest_ind(m_df[dwi_col], f_df[dwi_col], nan_policy= 'omit')
    r3, p3 = spearmanr(models_corr_df['years_education'], models_corr_df[stress_col], nan_policy= 'omit')
    r4, p4 = spearmanr(models_corr_df['years_education'], models_corr_df[dwi_col], nan_policy= 'omit')
    r5, p5 = spearmanr(models_corr_df['combined_income'], models_corr_df[stress_col], nan_policy= 'omit')
    r6, p6 = spearmanr(models_corr_df['combined_income'], models_corr_df[dwi_col], nan_policy= 'omit')

    print("\n{} model:\n".format(metric_name))
    print('Trauma Exp {} sex difference: {}, {}'.format(ex_comp, round(r1, 5), round(p1, 3)))
    print('Diffusion {} sex difference: {}, {}'.format(ex_comp, round(r2, 5), round(p2, 3)))
    print('Trauma comp correlation with education: {}, {}'.format(round(r3, 5), round(p3, 5)))
    print('Diffusion comp correlation with education: {}, {}'.format(round(r4, 5), round(p4, 5)))
    print('Trauma comp correlation with income: {}, {}'.format(round(r5, 5), round(p5, 5)))
    print('Diffusion comp correlation with income: {}, {}'.format(round(r6, 5), round(p6, 5)))

In [None]:
# TODO: compare the main-model transformed data to the top loading values.
# For now, join each model's mean transformed variates (DWI columns, then adversity columns)
# onto full_df, column-wise.
gfa_corr_df, qa_corr_df, rd_corr_df = (
    pd.concat([full_df, dwi_part, adv_part], axis=1)  #.dropna(axis=0)
    for dwi_part, adv_part in (
        (gfa_tranf_df_dwi, gfa_tranf_df_adv),
        (qa_tranf_df_dwi, qa_tranf_df_adv),
        (rd_tranf_df_dwi, rd_tranf_df_adv),
    )
)

In [None]:
fig, axes = plt.subplots(4, 4, figsize = (20,20))
#Set metric we want to plot
name_cap_univ = 'GFA'

# Colours, plus the four adversity ages (grid rows) and four tracts (grid columns), for the chosen
# metric. The choice does not depend on the grid cell, so it is made once, before plotting.
if name_cap_univ == 'QA':
    plot_color = '#ac0e7c'
    line_color = '#860b60'
    dwi_names = ['qa_CST_right_regr', 'qa_POPT_left_regr', 'qa_STR_right_regr', 'qa_ST_FO_left_regr']
    adv_names = ['all_4.0_regr', 'all_5.0_regr', 'all_8.0_regr', 'all_3.0_regr']
elif name_cap_univ == 'GFA':
    plot_color = '#e69e38'
    line_color = '#b87e2c'
    dwi_names = ['gfa_CST_left_regr', 'gfa_POPT_right_regr', 'gfa_FPT_left_regr', 'gfa_SLF_III_right_regr']
    adv_names = ['all_1.0_regr', 'all_0.0_regr', 'all_8.0_regr', 'all_4.0_regr']
elif name_cap_univ == 'RD':
    plot_color = '#0e3bac'
    line_color = '#0b2e86'
    dwi_names = ['rd_T_PREM_left_regr', 'rd_POPT_left_regr', 'rd_FPT_left_regr', 'rd_SLF_III_left_regr']
    adv_names = ['all_5.0_regr', 'all_4.0_regr', 'all_10.0_regr', 'all_3.0_regr']
else:
    # Stop here: continuing would either fail or plot with names left over from an earlier run.
    raise ValueError('Error--none of three metric names match: {}'.format(name_cap_univ))

# Tract labels: 'gfa_CST_left_regr' -> 'GFA CST left'. The column prefix is lower-case, so it is the
# lower-cased metric name that must be replaced.
metric_prefix = name_cap_univ.lower() + '_'

# Run plots
for i, adv_name in enumerate(adv_names):
    for j, dwi_name in enumerate(dwi_names):
        sns.regplot(x = full_df[adv_name],
                    y = full_df[dwi_name],
                    ax=axes[i, j], scatter_kws={"color": plot_color},
                    line_kws={"color": line_color})
        axes[i, j].set_xlabel(adv_name.replace('all_', 'Exposures at age ').replace('.0_regr',''))
        axes[i, j].set_ylabel(dwi_name.replace(metric_prefix, name_cap_univ + ' ', 1).replace('_regr', '').replace('_', ' '))

# Title and layout are applied once, after every panel is drawn.
fig.suptitle('Univariate associations among adversity exposure by age and white matter tract integrity ({})\n(covariates regressed)'.format(name_cap_univ),
            size=20, weight = 'bold')
plt.tight_layout()
fig.savefig(analysis + "/figures/{}_Univariate_Plots_{}.png".format(name_cap_univ, today), dpi=300, transparent=True) 
print(analysis + "/figures/{}_Univariate_Plots_{}.png".format(name_cap_univ, today))

In [None]:
name_cap = 'GFA'
fig, ax1 = plt.subplots(1,1, figsize=(6,6))
# NOTE(review): ec_filt_df (rows with any exposure at age 8 or age 5) is not used in this cell.
ec_filt_df = gfa_corr_df[(gfa_corr_df['all_8.0'] > 0) | (gfa_corr_df['all_5.0'] > 0)]

# Scatter + linear fit of the adversity variate against the DWI variate for component `ex_comp`.
# x/y are passed by keyword (seaborn >= 0.12 no longer accepts positional data vectors), and the
# target axis is given explicitly rather than relying on the current axes.
sns.regplot(x=gfa_corr_df['{}_Stress'.format(ex_comp)], y=gfa_corr_df['{}_DWI'.format(ex_comp)],
            ax=ax1, scatter_kws={"color": '#e69e38'}, line_kws={"color": "#b87e2c"})

ax1.grid(False)
# ax1.set_xlim(-0.04, 0.04)
# ax1.set_ylim(-0.04, 0.06)
ax1.set_xlabel('Adversity Variate\n({} Model; Mode 1)'.format(name_cap), fontsize=20, weight='bold')
ax1.set_ylabel('Tract Integrity Variate\n({} Model; Mode 1)'.format(name_cap), fontsize=20, weight='bold')
fig.tight_layout()
fig.savefig(analysis + "/figures/Comp1Adv_Comp1DWI_{}.png".format(name_cap), dpi=300, transparent=True) 

In [None]:
name_cap = 'QA'
fig, ax1 = plt.subplots(1,1, figsize=(6,6))
# NOTE(review): ec_filt_df (rows with any exposure at age 4 or age 5) is not used in this cell.
ec_filt_df = qa_corr_df[(qa_corr_df['all_4.0'] > 0) | (qa_corr_df['all_5.0'] > 0)]

# Scatter + linear fit of the adversity variate against the DWI variate for component `ex_comp`.
# x/y are passed by keyword (seaborn >= 0.12 no longer accepts positional data vectors), and the
# target axis is given explicitly rather than relying on the current axes.
sns.regplot(x=qa_corr_df['{}_Stress'.format(ex_comp)], y=qa_corr_df['{}_DWI'.format(ex_comp)],
            ax=ax1, scatter_kws={"color": '#ac0e7c'}, line_kws={"color": "#860b60"})

ax1.grid(False)
# ax1.set_xlim(-0.04, 0.04)
# ax1.set_ylim(-0.04, 0.06)
ax1.set_xlabel('Adversity Variate\n({} Model; Mode 1)'.format(name_cap), fontsize=20, weight='bold')
ax1.set_ylabel('Tract Integrity Variate\n({} Model; Mode 1)'.format(name_cap), fontsize=20, weight='bold')
fig.tight_layout()
fig.savefig(analysis + "/figures/Comp1Adv_Comp1DWI_{}.png".format(name_cap), dpi=300, transparent=True) 
plt.show()

In [None]:
name_cap = 'RD'
fig, ax1 = plt.subplots(1,1, figsize=(6,6))
# NOTE(review): ec_filt_df (rows with any exposure at age 4 or age 5) is not used in this cell.
ec_filt_df = rd_corr_df[(rd_corr_df['all_4.0'] > 0) | (rd_corr_df['all_5.0'] > 0)]

# Scatter + linear fit of the adversity variate against the DWI variate for component `ex_comp`.
# x/y are passed by keyword (seaborn >= 0.12 no longer accepts positional data vectors), and the
# target axis is given explicitly rather than relying on the current axes.
sns.regplot(x=rd_corr_df['{}_Stress'.format(ex_comp)], y=rd_corr_df['{}_DWI'.format(ex_comp)],
            ax=ax1, scatter_kws={"color": '#0e3bac'}, line_kws={"color": "#0b2e86"})

ax1.grid(False)
# ax1.set_xlim(-0.04, 0.04)
# ax1.set_ylim(-0.04, 0.06)
ax1.set_xlabel('Adversity Variate\n({} Model; Mode 1)'.format(name_cap), fontsize=20, weight='bold')
ax1.set_ylabel('Tract Integrity Variate\n({} Model; Mode 1)'.format(name_cap), fontsize=20, weight='bold')
fig.tight_layout()
fig.savefig(analysis + "/figures/Comp1Adv_Comp1DWI_{}.png".format(name_cap), dpi=300, transparent=True) 

In [None]:
# Combine adversity and DWI dfs
# Prefix each model's variate columns with its metric (e.g. 'gfa_Variate_1_DWI') so the three
# models can sit side by side. A prefix is added only if it is not already present, so re-running
# this cell does not produce names like 'gfa_gfa_Variate_1_DWI'.
for _prefix, _frame in (('gfa_', gfa_tranf_df_adv), ('qa_', qa_tranf_df_adv), ('rd_', rd_tranf_df_adv),
                        ('gfa_', gfa_tranf_df_dwi), ('qa_', qa_tranf_df_dwi), ('rd_', rd_tranf_df_dwi)):
    if not _frame.columns.str.startswith(_prefix).all():
        _frame.columns = _prefix + _frame.columns  # renames the shared frame in place

full_adv_df = pd.concat([full_df, gfa_tranf_df_adv, qa_tranf_df_adv, rd_tranf_df_adv], axis=1)
full_dwi_df = pd.concat([full_df, gfa_tranf_df_dwi, qa_tranf_df_dwi, rd_tranf_df_dwi], axis=1) #.dropna(axis=0)

full_all_df = pd.concat([full_adv_df, gfa_tranf_df_dwi, qa_tranf_df_dwi, rd_tranf_df_dwi], axis=1)

In [None]:
# Test for multicollinearity
test_met = 'DWI'
sns.histplot(full_dwi_df['gfa_Variate_1_DWI'])


def _vif(frame):
    """Variance inflation factor of each column: the diagonal of the inverse correlation matrix.

    (Approach from https://stackoverflow.com/questions/42658379/variance-inflation-factor-in-python)
    """
    return pd.Series(np.linalg.inv(frame.corr().to_numpy()).diagonal(),
                     index=frame.columns,
                     name='VIF')


# drop=True: a plain reset_index() would insert the old row labels as an 'index' column, which
# would then enter the correlation matrix and distort every VIF.
vif_df = full_dwi_df[['gfa_Variate_1_DWI', 'qa_Variate_1_DWI',
                      'rd_Variate_1_DWI']].dropna(axis=0).reset_index(drop=True)
print(_vif(vif_df))

vif_df2 = full_adv_df[['gfa_Variate_1_Stress', 'qa_Variate_1_Stress',
                       'rd_Variate_1_Stress']].dropna(axis=0).reset_index(drop=True)
print(_vif(vif_df2))

In [None]:
from statsmodels.stats.multitest import fdrcorrection as fdr
from statsmodels.discrete.count_model import ZeroInflatedPoisson
from statsmodels.discrete.discrete_model import NegativeBinomial
from statsmodels.regression.linear_model import OLS
from statsmodels.discrete.discrete_model import Logit

scaler = StandardScaler()
# Transformed symptom scores (sqrt or log of score + 1) for later models and plots.
full_all_df['tsc_sum_tranf'] = np.sqrt(full_df['tsc_sum'] + 1)
full_all_df['ri_ptsd_tranf'] = np.log(full_df['ri_ptsd_total'] + 1)
full_all_df['total_probs_tranf'] = np.log(full_df['Total_Problems_Total'] + 1)

# Recode diagnostic info: group code 2 -> 0 and code 3 -> missing; any other code is kept as-is.
# These are plain value replacements; the integer codes are not regex patterns.
# NOTE(review): confirm that group 3 should become missing (NaN) rather than being merged into 0.
full_all_df['diagnostic_group_bin'] = full_all_df['diagnostic_group'].replace({2: 0, 3: np.nan})

mode = 'DWI'  # which side's variates enter the model: 'DWI' or 'Stress'
yvar = 'Total_Problems_TScore' 
sns.displot(full_all_df[yvar])
plt.show()
# If variance is greater than mean, data is overdispersed, Poisson contraindicated
print('\nMean = {}, var = {}\n'.format(full_all_df[yvar].mean(), full_all_df[yvar].var()))

# Complete cases on the outcome and the SES covariates used as predictors below.
bx_df = full_all_df[['tsc_sum', 'age_at_scan', 'asr_age','sex', 'gender', 'diagnostic_group', 'combined_income', 'years_education', 'Total_Problems_Total',
                    'gfa_Variate_1_DWI', 'qa_Variate_1_DWI', 'rd_Variate_1_DWI', 'gfa_Variate_1_Stress', 'qa_Variate_1_Stress', 'Total_Problems_TScore',
                     'rd_Variate_1_Stress', 'total_probs_tranf', 'tsc_sum_tranf', 'Internalizing_Problems_Total',  'site_bin',
                    'Externalizing_Problems_Total', 'ri_ptsd_tranf', 'ri_ptsd_total', 'ri_ptsd_past', 'diagnostic_group_bin']].dropna(axis=0, 
                                                                                    subset =[yvar,
                                                                             'combined_income', 
                                                                             'years_education',
                                                                             # 'Total_Problems_Total'
                                                                            ]).reset_index() # 'combined_income', 'years_education'

# Outcome, plus z-scored predictors: each model's mode-1 variate for `mode` and the SES covariates.
en_df = pd.to_numeric(bx_df[yvar])
ex_df = bx_df[['{}_Variate_1_{}'.format(metric, mode) for metric in ('gfa', 'qa', 'rd')]
              + ['combined_income', 'years_education']]
# The rebuilt frame gets a fresh 0..n-1 index, matching en_df because bx_df was reset above.
ex_df = pd.DataFrame(scaler.fit_transform(ex_df), columns = ex_df.columns)
ex_df = sm.add_constant(ex_df)

symptom_mod = OLS(endog = en_df, exog = ex_df) 
symptom_results = symptom_mod.fit()

print(symptom_results.summary(), '\nBIC: ', symptom_results.bic )
# print('For FDR: p-val for DWI QA component is {}'.format(symptom_results.pvalues['qa_Variate_1_DWI']))

In [None]:
# Transformed PTSD symptom scores against the mode-1 white matter variate of the QA model (left)
# and the RD model (right). Uses bx_df as built by the symptom-model cell.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize = (14, 6))
ptsd_panels = [
    (ax1, 'qa_Variate_1_DWI', 'White Matter Variate 1 (QA model)', '#407caa', '#163d6d'),
    (ax2, 'rd_Variate_1_DWI', 'White Matter Variate 1 (RD model)', '#8cb0cc', '#407caa'),
]
for panel_ax, x_col, x_label, dot_color, fit_color in ptsd_panels:
    sns.regplot(y='ri_ptsd_tranf', x=x_col, data=bx_df, ax=panel_ax,
                scatter_kws={"color": dot_color}, line_kws={"color": fit_color})
    panel_ax.set_ylabel('Transformed Post-Traumatic Stress\nSymptoms Scores')
    panel_ax.set_xlabel(x_label)
plt.tight_layout()

ptsd_fig_path = analysis + "/figures/TSC_DWI_StressPlot_{}.png".format(today)
fig.savefig(ptsd_fig_path, dpi=300, transparent=True)
print(ptsd_fig_path)

In [None]:
# See whether the top individual tracts in the QA loadings are associated with symptoms,
# fitting one model per tract (intercept + that tract).

yvar = 'ri_ptsd_tranf'

# QA tracts to test one at a time.
# NOTE(review): the last four names lack the '_regr' suffix that the others carry, so they may
# refer to different (possibly not covariate-regressed) columns — confirm this is intended.
tracts = ["qa_CST_left_regr", "qa_CST_right_regr", "qa_POPT_left_regr", "qa_POPT_right_regr",
          "qa_STR_right_regr", "qa_STR_left_regr", "qa_CC_1_regr", 'qa_FPT_right',
          'qa_T_PREM_right', 'qa_T_PREM_left', 'qa_ST_FO_left']

# Complete cases on the outcome, covariates, adversity ages and all tracts.
bx_df = full_all_df[[yvar, 'combined_income', 'years_education', 'age_at_scan',
                     'sex', "all_8.0_regr", "all_7.0_regr",
                     'all_4.0_regr', 'all_5.0_regr', 'all_6.0_regr'] + tracts].dropna(axis=0)

en_df = bx_df[[yvar]]

# NOTE(review): yvar is log(ri_ptsd_total + 1), a non-integer outcome, yet a Poisson (count) model
# is fitted; the tracts are also tested separately without multiple-comparison correction.
for tract in tracts:
    ex_df = sm.add_constant(bx_df[[tract]])
    testmod1 = Poisson(endog = en_df, exog = ex_df)
    testmod1_results = testmod1.fit()
    print(testmod1_results.summary())

In [None]:
# Spearman correlations among the RD tract measures (tract-by-tract heatmap), to see which
# tracts vary together.

alldwidf3 = full_dwi_df.loc[:, "rd_AF_left_regr":"rd_ST_PREM_right_regr"]
# 'rd_AF_left_regr' -> 'rd AF left'. The space before 'regr' is dropped too, so labels carry no
# trailing blank.
alldwidf3.columns = alldwidf3.columns.str.replace('_', ' ', regex=False).str.replace(' regr', '', regex=False)

alldwidf = alldwidf3 #pd.concat([alldwidf1, alldwidf2, alldwidf3], axis=1)

corrMatrix2 = alldwidf3.corr(method='spearman')

fig, ax = plt.subplots(1, 1, figsize = (30,30))
sns.heatmap(corrMatrix2, annot=True, ax=ax, vmin = -1, vmax=1, annot_kws = {'fontsize':8}, cmap= 'coolwarm')

fig.savefig(analysis + "/figures/DWI_Heatmap_RD_{}.png".format(today), dpi=300, transparent=True) 

In [None]:
print('{}/figures'.format(analysis))  # show where the figures above were written