In [None]:
from junifer.storage import HDF5FeatureStorage
from julearn.api import run_cross_validation
from julearn.pipeline import PipelineCreator
from julearn.viz import plot_scores
from julearn.stats.corrected_ttest import corrected_ttest
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import mean_absolute_error
import numpy as np

In [6]:
# Demographics: rename the ID column to "subject" and rewrite every ID as
# "sub-IXI<id>", the key the feature tables are merged on later.
df_demographics = (
    pd.read_csv('./data/IXI_demograpic_data.csv', sep=',')
    .rename(columns={"IXI_ID": "subject"})
    .assign(
        subject=lambda frame: frame['subject'].map(
            lambda ixi_id: f'sub-IXI{ixi_id}'
        )
    )
)

# Feature tables are read from the HDF5 storage; rows containing any missing
# value are discarded.
storage = HDF5FeatureStorage(uri='./data/IXI_Histograms_Parcels.hdf5')
df_parcellations = storage.read_df('VBM_GM_Schaefer100x7_mean_aggregation').dropna()
df_hists = storage.read_df('VBM_GM_Histogram_100bins_IXI_hist').dropna()


In [10]:
# Reload the demographics from the project-relative data directory (the same
# file the loading cell reads) rather than from a user-specific absolute
# path, so the notebook also runs on other machines.
df_demographics = (
    pd.read_csv('./data/IXI_demograpic_data.csv', sep=',')
    .rename(columns={"IXI_ID": "subject"})
)
# Rewrite the IDs as "sub-IXI<id>" so they can be used as the merge key.
df_demographics['subject'] = df_demographics['subject'].apply(lambda x: f'sub-IXI{x}')

# Histogram features: column labels are cast to str and used as feature names.
df_hists.columns = df_hists.columns.astype(str)
# NOTE(review): [1:100] skips the first column and keeps at most 99 columns.
# Confirm that column 0 is really not a feature (the model labels used later
# end in "_100").
X_hists = list(df_hists.columns)[1:100]
df_full_histograms = df_hists.merge(df_demographics, on="subject")

# Parcellation features: same treatment as the histograms.
df_parcellations.columns = df_parcellations.columns.astype(str)
# NOTE(review): same [1:100] slice as above - confirm the first column should
# be dropped.
X_parcels = list(df_parcellations.columns)[1:100]
df_full_parcellations = df_parcellations.merge(df_demographics, on="subject")

In [None]:
# Show the demographics column that is used as the target `y` in the
# classification runs.
df_demographics['SEX_ID (1=m, 2=f)']

In [None]:
# Support Vector Machine: z-score the features, then fit an SVM whose C is
# given as a (0.001, 100, "log-uniform") range to be tuned.
creator_svm = PipelineCreator(problem_type="classification")
creator_svm.add("zscore")
creator_svm.add("svm", C=(0.001, 100, "log-uniform"))

# Hyperparameter search settings: Optuna, 4 inner folds.
search_params_svm = dict(kind="optuna", cv=4)

scoring = ["balanced_accuracy", "accuracy"]

# Settings shared by both SVM runs; only the feature list and the data
# frame differ between them.
svm_cv_options = dict(
    y='SEX_ID (1=m, 2=f)',
    search_params=search_params_svm,
    model=creator_svm,
    return_train_score=True,
    return_inspector=True,
    cv=4,
    scoring=scoring,
)

# SVM on histograms
scores_hists_svm, model_hists_svm, inspector_hists_svm = run_cross_validation(
    X=X_hists, data=df_full_histograms, **svm_cv_options
)

# SVM on parcellations
scores_schaefer_svm, model_schaefer_svm, inspector_schaefer_svm = run_cross_validation(
    X=X_parcels, data=df_full_parcellations, **svm_cv_options
)





In [58]:
# Per-fold cross-validation scores of the SVM on the parcellation features.
scores_schaefer_svm

Unnamed: 0,fit_time,score_time,estimator,test_balanced_accuracy,train_balanced_accuracy,test_accuracy,train_accuracy,n_train,n_test,repeat,fold,cv_mdsum,model
0,0.493587,0.005574,"OptunaSearchCV(cv=KFold(n_splits=4, random_sta...",0.695046,0.941193,0.692913,0.942257,381,127,0,0,bc7087515161a73a5a6aff57863f3803,IXI_Schaefer_SVM_100
1,0.503228,0.006035,"OptunaSearchCV(cv=KFold(n_splits=4, random_sta...",0.670272,0.997006,0.669291,0.997375,381,127,0,1,bc7087515161a73a5a6aff57863f3803,IXI_Schaefer_SVM_100
2,0.502407,0.005967,"OptunaSearchCV(cv=KFold(n_splits=4, random_sta...",0.772597,0.970153,0.779528,0.971129,381,127,0,2,bc7087515161a73a5a6aff57863f3803,IXI_Schaefer_SVM_100
3,0.497863,0.00599,"OptunaSearchCV(cv=KFold(n_splits=4, random_sta...",0.691381,0.996815,0.685039,0.997375,381,127,0,3,bc7087515161a73a5a6aff57863f3803,IXI_Schaefer_SVM_100


In [59]:
# Per-fold cross-validation scores of the SVM on the histogram features.
scores_hists_svm

Unnamed: 0,fit_time,score_time,estimator,test_balanced_accuracy,train_balanced_accuracy,test_accuracy,train_accuracy,n_train,n_test,repeat,fold,cv_mdsum,model
0,0.496066,0.006022,"OptunaSearchCV(cv=KFold(n_splits=4, random_sta...",0.642931,0.698092,0.661417,0.706037,381,127,0,0,bc7087515161a73a5a6aff57863f3803,IXI_Histograms_SVM_100
1,0.486439,0.005433,"OptunaSearchCV(cv=KFold(n_splits=4, random_sta...",0.557973,0.851768,0.566929,0.855643,381,127,0,1,bc7087515161a73a5a6aff57863f3803,IXI_Histograms_SVM_100
2,0.482218,0.00551,"OptunaSearchCV(cv=KFold(n_splits=4, random_sta...",0.581169,0.879641,0.598425,0.884514,381,127,0,2,bc7087515161a73a5a6aff57863f3803,IXI_Histograms_SVM_100
3,0.50584,0.005418,"OptunaSearchCV(cv=KFold(n_splits=4, random_sta...",0.550546,0.581168,0.535433,0.643045,381,127,0,3,bc7087515161a73a5a6aff57863f3803,IXI_Histograms_SVM_100


In [None]:
# Random Forest: z-score the features, then fit a forest with fixed
# hyperparameters (depth 5, 100 trees).
creator_rf = PipelineCreator(problem_type="classification")
creator_rf.add("zscore")
creator_rf.add(
    "rf",
    max_depth=5,
    n_estimators=100,
)

# NOTE(review): no hyperparameter above is given as a range, and the recorded
# `estimator` output for this model is a plain pipeline rather than a search
# object, so these search settings appear to have no effect - confirm.
search_params_rf = {
    "kind": "grid",
    "cv": 4
}

# Defined here, as in the other model cells, so this cell does not rely on
# the SVM cell having been run first.
scoring = ["balanced_accuracy", "accuracy"]

# Random Forest on histograms
scores_hists_rf, model_hists_rf, inspector_hists_rf = run_cross_validation(
    X=X_hists,
    y='SEX_ID (1=m, 2=f)',
    data=df_full_histograms,
    search_params=search_params_rf,
    model=creator_rf,
    return_train_score=True,
    return_inspector=True,
    cv=4,
    scoring=scoring,
)

# Random Forest on parcellations
scores_schaefer_rf, model_schaefer_rf, inspector_schaefer_rf = run_cross_validation(
    X=X_parcels,
    y='SEX_ID (1=m, 2=f)',
    data=df_full_parcellations,
    search_params=search_params_rf,
    model=creator_rf,
    return_train_score=True,
    return_inspector=True,
    cv=4,
    scoring=scoring,
)

In [55]:
# Per-fold cross-validation scores of the Random Forest on the histogram features.
scores_hists_rf

Unnamed: 0,fit_time,score_time,estimator,test_balanced_accuracy,train_balanced_accuracy,test_accuracy,train_accuracy,n_train,n_test,repeat,fold,cv_mdsum,model
0,0.152291,0.005824,"(SetColumnTypes(X_types={}), StandardScaler(),...",0.57727,0.926839,0.590551,0.926509,381,127,0,0,bc7087515161a73a5a6aff57863f3803,IXI_Histograms_RF_100
1,0.147353,0.005591,"(SetColumnTypes(X_types={}), StandardScaler(),...",0.582872,0.91537,0.590551,0.91601,381,127,0,1,bc7087515161a73a5a6aff57863f3803,IXI_Histograms_RF_100
2,0.151157,0.006829,"(SetColumnTypes(X_types={}), StandardScaler(),...",0.59013,0.929691,0.622047,0.934383,381,127,0,2,bc7087515161a73a5a6aff57863f3803,IXI_Histograms_RF_100
3,0.147764,0.00571,"(SetColumnTypes(X_types={}), StandardScaler(),...",0.622578,0.883744,0.614173,0.900262,381,127,0,3,bc7087515161a73a5a6aff57863f3803,IXI_Histograms_RF_100


In [56]:
# Per-fold cross-validation scores of the Random Forest on the parcellation features.
scores_schaefer_rf

Unnamed: 0,fit_time,score_time,estimator,test_balanced_accuracy,train_balanced_accuracy,test_accuracy,train_accuracy,n_train,n_test,repeat,fold,cv_mdsum,model
0,0.15833,0.005628,"(SetColumnTypes(X_types={}), StandardScaler(),...",0.64306,0.982558,0.669291,0.984252,381,127,0,0,bc7087515161a73a5a6aff57863f3803,IXI_Schaefer_RF_100
1,0.156213,0.005645,"(SetColumnTypes(X_types={}), StandardScaler(),...",0.597837,0.982036,0.622047,0.984252,381,127,0,1,bc7087515161a73a5a6aff57863f3803,IXI_Schaefer_RF_100
2,0.157967,0.005716,"(SetColumnTypes(X_types={}), StandardScaler(),...",0.636104,0.983145,0.669291,0.984252,381,127,0,2,bc7087515161a73a5a6aff57863f3803,IXI_Schaefer_RF_100
3,0.161046,0.006978,"(SetColumnTypes(X_types={}), StandardScaler(),...",0.673,0.959551,0.661417,0.965879,381,127,0,3,bc7087515161a73a5a6aff57863f3803,IXI_Schaefer_RF_100


In [57]:
# Extra Trees: z-score the features, then fit an extra-trees classifier with
# fixed hyperparameters (depth 5, 100 trees).
creator_et = PipelineCreator(problem_type="classification")
creator_et.add("zscore")
creator_et.add("et", max_depth=5, n_estimators=100)

# NOTE(review): no hyperparameter above is given as a range, and the recorded
# `estimator` output for this model is a plain pipeline rather than a search
# object, so these search settings appear to have no effect - confirm.
search_params_et = dict(kind="grid", cv=4)

scoring = ["balanced_accuracy", "accuracy"]

# Arguments common to both Extra Trees runs; the feature list and data frame
# are supplied per call.
et_cv_options = dict(
    y='SEX_ID (1=m, 2=f)',
    search_params=search_params_et,
    model=creator_et,
    return_train_score=True,
    return_inspector=True,
    cv=4,
    scoring=scoring,
)

# Extra Trees on histograms
scores_hists_et, model_hists_et, inspector_hists_et = run_cross_validation(
    X=X_hists, data=df_full_histograms, **et_cv_options
)

# Extra Trees on parcellations
scores_schaefer_et, model_schaefer_et, inspector_schaefer_et = run_cross_validation(
    X=X_parcels, data=df_full_parcellations, **et_cv_options
)


  warn_with_log(

  warn_with_log(



In [54]:
# Per-fold cross-validation scores of Extra Trees on the histogram features.
scores_hists_et

Unnamed: 0,fit_time,score_time,estimator,test_balanced_accuracy,train_balanced_accuracy,test_accuracy,train_accuracy,n_train,n_test,repeat,fold,cv_mdsum,model
0,0.064334,0.00585,"(SetColumnTypes(X_types={}), StandardScaler(),...",0.606811,0.84102,0.622047,0.84252,381,127,0,0,bc7087515161a73a5a6aff57863f3803,IXI_Histograms_ET_100
1,0.062281,0.005584,"(SetColumnTypes(X_types={}), StandardScaler(),...",0.599723,0.822262,0.622047,0.83727,381,127,0,1,bc7087515161a73a5a6aff57863f3803,IXI_Histograms_ET_100
2,0.062742,0.005985,"(SetColumnTypes(X_types={}), StandardScaler(),...",0.62961,0.820379,0.661417,0.832021,381,127,0,2,bc7087515161a73a5a6aff57863f3803,IXI_Histograms_ET_100
3,0.061993,0.005606,"(SetColumnTypes(X_types={}), StandardScaler(),...",0.594759,0.739806,0.582677,0.784777,381,127,0,3,bc7087515161a73a5a6aff57863f3803,IXI_Histograms_ET_100


In [53]:
# Per-fold cross-validation scores of Extra Trees on the parcellation features.
scores_schaefer_et

Unnamed: 0,fit_time,score_time,estimator,test_balanced_accuracy,train_balanced_accuracy,test_accuracy,train_accuracy,n_train,n_test,repeat,fold,cv_mdsum,model
0,0.062685,0.005615,"(SetColumnTypes(X_types={}), StandardScaler(),...",0.620098,0.927326,0.645669,0.934383,381,127,0,0,bc7087515161a73a5a6aff57863f3803,IXI_Schaefer_ET_100
1,0.062277,0.00561,"(SetColumnTypes(X_types={}), StandardScaler(),...",0.578974,0.874909,0.622047,0.889764,381,127,0,1,bc7087515161a73a5a6aff57863f3803,IXI_Schaefer_ET_100
2,0.06239,0.005459,"(SetColumnTypes(X_types={}), StandardScaler(),...",0.585065,0.928718,0.645669,0.934383,381,127,0,2,bc7087515161a73a5a6aff57863f3803,IXI_Schaefer_ET_100
3,0.062312,0.005519,"(SetColumnTypes(X_types={}), StandardScaler(),...",0.537258,0.754777,0.519685,0.7979,381,127,0,3,bc7087515161a73a5a6aff57863f3803,IXI_Schaefer_ET_100


In [None]:
# Gradient Boosting: z-score the features, then fit a gradient-boosting
# classifier with a fixed learning rate of 0.02.
creator_gb = PipelineCreator(problem_type="classification")
creator_gb.add("zscore")
creator_gb.add("gradientboost", learning_rate=0.02)

# NOTE(review): the learning rate is a single value, not a range, and the
# recorded `estimator` output for this model is a plain pipeline rather than
# a search object, so these search settings appear to have no effect - confirm.
search_params_gb = dict(kind="grid", cv=4)

scoring = ["balanced_accuracy", "accuracy"]

# Arguments common to both Gradient Boosting runs; the feature list and data
# frame are supplied per call.
gb_cv_options = dict(
    y='SEX_ID (1=m, 2=f)',
    search_params=search_params_gb,
    model=creator_gb,
    return_train_score=True,
    return_inspector=True,
    cv=4,
    scoring=scoring,
)

# Gradient Boosting on histograms
scores_hists_gb, model_hists_gb, inspector_hists_gb = run_cross_validation(
    X=X_hists, data=df_full_histograms, **gb_cv_options
)

# Gradient Boosting on parcellations
scores_schaefer_gb, model_schaefer_gb, inspector_schaefer_gb = run_cross_validation(
    X=X_parcels, data=df_full_parcellations, **gb_cv_options
)


In [48]:
# Per-fold cross-validation scores of Gradient Boosting on the histogram features.
scores_hists_gb

Unnamed: 0,fit_time,score_time,estimator,test_balanced_accuracy,train_balanced_accuracy,test_accuracy,train_accuracy,n_train,n_test,repeat,fold,cv_mdsum,model
0,0.70662,0.00412,"(SetColumnTypes(X_types={}), StandardScaler(),...",0.597136,0.868032,0.622047,0.868766,381,127,0,0,bc7087515161a73a5a6aff57863f3803,IXI_Histograms_GB_100
1,0.701836,0.004736,"(SetColumnTypes(X_types={}), StandardScaler(),...",0.563129,0.904779,0.574803,0.910761,381,127,0,1,bc7087515161a73a5a6aff57863f3803,IXI_Histograms_GB_100
2,0.697406,0.003826,"(SetColumnTypes(X_types={}), StandardScaler(),...",0.616623,0.905625,0.645669,0.910761,381,127,0,2,bc7087515161a73a5a6aff57863f3803,IXI_Histograms_GB_100
3,0.698807,0.003355,"(SetColumnTypes(X_types={}), StandardScaler(),...",0.592275,0.844575,0.582677,0.868766,381,127,0,3,bc7087515161a73a5a6aff57863f3803,IXI_Histograms_GB_100


In [49]:
# Per-fold cross-validation scores of Gradient Boosting on the parcellation features.
scores_schaefer_gb

Unnamed: 0,fit_time,score_time,estimator,test_balanced_accuracy,train_balanced_accuracy,test_accuracy,train_accuracy,n_train,n_test,repeat,fold,cv_mdsum,model
0,0.790873,0.003504,"(SetColumnTypes(X_types={}), StandardScaler(),...",0.662539,0.945797,0.677165,0.950131,381,127,0,0,bc7087515161a73a5a6aff57863f3803,IXI_Schaefer_GB_100
1,0.785952,0.004637,"(SetColumnTypes(X_types={}), StandardScaler(),...",0.706866,0.944065,0.716535,0.947507,381,127,0,1,bc7087515161a73a5a6aff57863f3803,IXI_Schaefer_GB_100
2,0.784314,0.00333,"(SetColumnTypes(X_types={}), StandardScaler(),...",0.682597,0.92826,0.708661,0.931759,381,127,0,2,bc7087515161a73a5a6aff57863f3803,IXI_Schaefer_GB_100
3,0.789366,0.003952,"(SetColumnTypes(X_types={}), StandardScaler(),...",0.654123,0.899994,0.645669,0.91601,381,127,0,3,bc7087515161a73a5a6aff57863f3803,IXI_Schaefer_GB_100


In [43]:
# The XGB scores are not computed in this notebook; they are read back from
# previously saved CSV files (histograms first, then parcellations).
scores_hists_xgb, scores_schaefer_xgb = (
    pd.read_csv(csv_name)
    for csv_name in ('IXI_XGB_scores_hists.csv', 'IXI_XGB_scores_shaefer.csv')
)


In [47]:
# XGB scores on the histogram features, as loaded from 'IXI_XGB_scores_hists.csv'.
scores_hists_xgb

Unnamed: 0,fit_time,score_time,estimator,test_balanced_accuracy,train_balanced_accuracy,test_accuracy,train_accuracy,n_train,n_test,repeat,fold,cv_mdsum,model
0,0.152105,0.014713,"Pipeline(steps=[('set_column_types', SetColumn...",0.626548,0.918632,0.645669,0.918635,381,127,0,0,bc7087515161a73a5a6aff57863f3803,IXI_Histograms_XGB_100
1,0.149924,0.014601,"Pipeline(steps=[('set_column_types', SetColumn...",0.549044,0.913761,0.559055,0.918635,381,127,0,1,bc7087515161a73a5a6aff57863f3803,IXI_Histograms_XGB_100
2,0.154075,0.014558,"Pipeline(steps=[('set_column_types', SetColumn...",0.650649,0.928746,0.661417,0.931759,381,127,0,2,bc7087515161a73a5a6aff57863f3803,IXI_Histograms_XGB_100
3,0.149626,0.014585,"Pipeline(steps=[('set_column_types', SetColumn...",0.582216,0.912079,0.574803,0.92126,381,127,0,3,bc7087515161a73a5a6aff57863f3803,IXI_Histograms_XGB_100


In [50]:
# XGB scores on the parcellation features, as loaded from 'IXI_XGB_scores_shaefer.csv'.
scores_schaefer_xgb

Unnamed: 0,fit_time,score_time,estimator,test_balanced_accuracy,train_balanced_accuracy,test_accuracy,train_accuracy,n_train,n_test,repeat,fold,cv_mdsum,model
0,0.26568,0.01611,"Pipeline(steps=[('set_column_types', SetColumn...",0.652735,0.980166,0.669291,0.981627,381,127,0,0,bc7087515161a73a5a6aff57863f3803,IXI_Schaefer_XGB_100
1,0.20903,0.015271,"Pipeline(steps=[('set_column_types', SetColumn...",0.685739,0.969696,0.692913,0.971129,381,127,0,1,bc7087515161a73a5a6aff57863f3803,IXI_Schaefer_XGB_100
2,0.201536,0.014705,"Pipeline(steps=[('set_column_types', SetColumn...",0.64961,0.966291,0.677165,0.968504,381,127,0,2,bc7087515161a73a5a6aff57863f3803,IXI_Schaefer_XGB_100
3,0.198308,0.01475,"Pipeline(steps=[('set_column_types', SetColumn...",0.698957,0.966873,0.692913,0.971129,381,127,0,3,bc7087515161a73a5a6aff57863f3803,IXI_Schaefer_XGB_100


In [51]:
# Write an identifying label into the `model` column of every score table
# (presumably what the plotting call uses to tell the models apart - confirm).
for scores_frame, model_label in (
    (scores_schaefer_svm, 'IXI_Schaefer_SVM_100'),
    (scores_hists_svm, 'IXI_Histograms_SVM_100'),
    (scores_schaefer_rf, 'IXI_Schaefer_RF_100'),
    (scores_hists_rf, 'IXI_Histograms_RF_100'),
    (scores_schaefer_et, 'IXI_Schaefer_ET_100'),
    (scores_hists_et, 'IXI_Histograms_ET_100'),
    (scores_schaefer_gb, 'IXI_Schaefer_GB_100'),
    (scores_hists_gb, 'IXI_Histograms_GB_100'),
    (scores_schaefer_xgb, 'IXI_Schaefer_XGB_100'),
    (scores_hists_xgb, 'IXI_Histograms_XGB_100'),
):
    scores_frame['model'] = model_label

In [52]:
# Plot the score tables of all ten model / feature-set combinations together.
plot_scores(
    scores_schaefer_svm,
    scores_hists_svm,
    scores_schaefer_rf,
    scores_hists_rf,
    scores_schaefer_et,
    scores_hists_et,
    scores_schaefer_gb,
    scores_hists_gb,
    scores_schaefer_xgb,
    scores_hists_xgb,
)

BokehModel(combine_events=True, render_bundle={'docs_json': {'6381c4ad-e869-4e1f-9833-13a3f7666d15': {'version…