In [6]:
from junifer.storage import HDF5FeatureStorage
from julearn.api import run_cross_validation
from julearn.pipeline import PipelineCreator
from julearn.viz import plot_scores
from julearn.stats.corrected_ttest import corrected_ttest
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import mean_absolute_error
import numpy as np

In [7]:
df_demographics = pd.read_csv('./data/IXI_demograpic_data.csv',sep=',')
df_demographics.rename(columns={"IXI_ID": "subject"}, inplace=True)
df_demographics['subject'] = df_demographics['subject'].apply(lambda x: f'sub-IXI{x}')
storage = HDF5FeatureStorage(uri='./data/IXI_Histograms_Parcels_1000.hdf5')
df_parcellations = storage.read_df('VBM_GM_Schaefer1000x7_mean_aggregation')
df_hists = storage.read_df('VBM_GM_Histogram_1000bins_IXI_hist')

df_parcellations.dropna(inplace=True)
df_hists.dropna(inplace=True)

In [8]:
df_demographics = pd.read_csv('/home/hsreekri/Julearn_predictions/data/IXI_demograpic_data.csv',sep=',')
df_demographics.rename(columns={"IXI_ID": "subject"}, inplace=True)
df_demographics['subject'] = df_demographics['subject'].apply(lambda x: f'sub-IXI{x}')
df_hists.columns = df_hists.columns.astype(str)

X_hists = list(df_hists.columns)
X_hists = X_hists[1:100]
df_full_histograms = df_hists.merge(df_demographics, on="subject")

df_parcellations.columns = df_parcellations.columns.astype(str)
X_parcels = list(df_parcellations.columns)
X_parcels = X_parcels[1:100]
df_full_parcellations = df_parcellations.merge(df_demographics, on="subject")

In [9]:
# For SVM (Support Vector Machine)
creator_svm = PipelineCreator(problem_type="classification")
creator_svm.add("zscore")
creator_svm.add(
    "svm",
    C=(0.001, 100, "log-uniform"),
)

search_params_svm = {
    "kind": "optuna",
    "cv": 4
}

scoring = ["balanced_accuracy", "accuracy"]

# SVM on histograms
scores_hists_svm, model_hists_svm, inspector_hists_svm = run_cross_validation(
    X=X_hists,
    y='SEX_ID (1=m, 2=f)',
    data=df_full_histograms,
    search_params=search_params_svm,
    model=creator_svm,
    return_train_score=True,
    return_inspector=True,
    cv=4,
    scoring=scoring,
)

# SVM on parcellations
scores_schaefer_svm, model_schaefer_svm, inspector_schaefer_svm = run_cross_validation(
    X=X_parcels,
    y='SEX_ID (1=m, 2=f)',
    data=df_full_parcellations,
    search_params=search_params_svm,
    model=creator_svm,
    return_train_score=True,
    return_inspector=True,
    cv=4,
    scoring=scoring,
)





  warn_with_log(

  pipeline = search(  # type: ignore

  new_object = klass(**new_object_params)

[I 2024-10-14 14:18:17,619] A new study created in memory with name: no-name-46a4f60d-c4a9-467d-b2e1-845e5039f3d4
[I 2024-10-14 14:18:17,679] Trial 0 finished with value: 0.556359649122807 and parameters: {'svm__C': 2.55003308735002}. Best is trial 0 with value: 0.556359649122807.
[I 2024-10-14 14:18:17,728] Trial 1 finished with value: 0.5669407894736842 and parameters: {'svm__C': 18.31102122191294}. Best is trial 1 with value: 0.5669407894736842.
[I 2024-10-14 14:18:17,775] Trial 2 finished with value: 0.5302905701754386 and parameters: {'svm__C': 0.017945466299201736}. Best is trial 1 with value: 0.5669407894736842.
[I 2024-10-14 14:18:17,821] Trial 3 finished with value: 0.5302905701754386 and parameters: {'svm__C': 0.036906727566289606}. Best is trial 1 with value: 0.5669407894736842.
[I 2024-10-14 14:18:17,866] Trial 4 finished with value: 0.5538103070175437 and parameters: {'svm__C

In [10]:
scores_schaefer_svm

Unnamed: 0,fit_time,score_time,estimator,test_balanced_accuracy,train_balanced_accuracy,test_accuracy,train_accuracy,n_train,n_test,repeat,fold,cv_mdsum
0,0.512334,0.005657,"OptunaSearchCV(cv=KFold(n_splits=4, random_sta...",0.674288,0.847461,0.692913,0.850394,381,127,0,0,bc7087515161a73a5a6aff57863f3803
1,0.504159,0.005261,"OptunaSearchCV(cv=KFold(n_splits=4, random_sta...",0.700025,0.996875,0.700787,0.997375,381,127,0,1,bc7087515161a73a5a6aff57863f3803
2,0.507236,0.006657,"OptunaSearchCV(cv=KFold(n_splits=4, random_sta...",0.666752,1.0,0.677165,1.0,381,127,0,2,bc7087515161a73a5a6aff57863f3803
3,0.49968,0.006484,"OptunaSearchCV(cv=KFold(n_splits=4, random_sta...",0.684524,0.996875,0.685039,0.997375,381,127,0,3,bc7087515161a73a5a6aff57863f3803


In [11]:
scores_hists_svm

Unnamed: 0,fit_time,score_time,estimator,test_balanced_accuracy,train_balanced_accuracy,test_accuracy,train_accuracy,n_train,n_test,repeat,fold,cv_mdsum
0,0.490699,0.005363,"OptunaSearchCV(cv=KFold(n_splits=4, random_sta...",0.58201,0.861428,0.614173,0.863517,381,127,0,0,bc7087515161a73a5a6aff57863f3803
1,0.462206,0.005237,"OptunaSearchCV(cv=KFold(n_splits=4, random_sta...",0.5,0.5,0.503937,0.580052,381,127,0,1,bc7087515161a73a5a6aff57863f3803
2,0.488528,0.005359,"OptunaSearchCV(cv=KFold(n_splits=4, random_sta...",0.531871,0.908573,0.551181,0.910761,381,127,0,2,bc7087515161a73a5a6aff57863f3803
3,0.469568,0.005499,"OptunaSearchCV(cv=KFold(n_splits=4, random_sta...",0.5,0.5,0.503937,0.580052,381,127,0,3,bc7087515161a73a5a6aff57863f3803


In [12]:
# For Random Forest
creator_rf = PipelineCreator(problem_type="classification")
creator_rf.add("zscore")
creator_rf.add(
    "rf",
    max_depth=5,
    n_estimators=100,
    
)

search_params_rf = {
    "kind": "grid",
    "cv": 4
}

# Random Forest on histograms
scores_hists_rf, model_hists_rf, inspector_hists_rf = run_cross_validation(
    X=X_hists,
    y='SEX_ID (1=m, 2=f)',
    data=df_full_histograms,
    search_params=search_params_rf,
    model=creator_rf,
    return_train_score=True,
    return_inspector=True,
    cv=4,
    scoring=scoring,
)

# Random Forest on parcellations
scores_schaefer_rf, model_schaefer_rf, inspector_schaefer_rf = run_cross_validation(
    X=X_parcels,
    y='SEX_ID (1=m, 2=f)',
    data=df_full_parcellations,
    search_params=search_params_rf,
    model=creator_rf,
    return_train_score=True,
    return_inspector=True,
    cv=4,
    scoring=scoring,
)

  warn_with_log(

  warn_with_log(



In [13]:
scores_hists_rf

Unnamed: 0,fit_time,score_time,estimator,test_balanced_accuracy,train_balanced_accuracy,test_accuracy,train_accuracy,n_train,n_test,repeat,fold,cv_mdsum
0,0.126313,0.006009,"(SetColumnTypes(X_types={}), StandardScaler(),...",0.58598,0.87285,0.598425,0.87664,381,127,0,0,bc7087515161a73a5a6aff57863f3803
1,0.12463,0.005336,"(SetColumnTypes(X_types={}), StandardScaler(),...",0.542163,0.870475,0.543307,0.889764,381,127,0,1,bc7087515161a73a5a6aff57863f3803
2,0.128457,0.006558,"(SetColumnTypes(X_types={}), StandardScaler(),...",0.49694,0.837009,0.535433,0.853018,381,127,0,2,bc7087515161a73a5a6aff57863f3803
3,0.125067,0.005342,"(SetColumnTypes(X_types={}), StandardScaler(),...",0.588418,0.8486,0.590551,0.871391,381,127,0,3,bc7087515161a73a5a6aff57863f3803


In [14]:
scores_schaefer_rf

Unnamed: 0,fit_time,score_time,estimator,test_balanced_accuracy,train_balanced_accuracy,test_accuracy,train_accuracy,n_train,n_test,repeat,fold,cv_mdsum
0,0.151841,0.005412,"(SetColumnTypes(X_types={}), StandardScaler(),...",0.703039,0.964323,0.716535,0.965879,381,127,0,0,bc7087515161a73a5a6aff57863f3803
1,0.15163,0.005343,"(SetColumnTypes(X_types={}), StandardScaler(),...",0.746652,0.963363,0.748031,0.968504,381,127,0,1,bc7087515161a73a5a6aff57863f3803
2,0.153525,0.005396,"(SetColumnTypes(X_types={}), StandardScaler(),...",0.616905,0.96879,0.637795,0.971129,381,127,0,2,bc7087515161a73a5a6aff57863f3803
3,0.152385,0.005373,"(SetColumnTypes(X_types={}), StandardScaler(),...",0.59685,0.946338,0.598425,0.952756,381,127,0,3,bc7087515161a73a5a6aff57863f3803


In [15]:
# For Extra Trees Classifier
creator_et = PipelineCreator(problem_type="classification")
creator_et.add("zscore")
creator_et.add(
    "et",
    max_depth=5,
    n_estimators=100,
)

search_params_et = {
    "kind": "grid",
    "cv": 4
}

scoring = ["balanced_accuracy", "accuracy"]

# Extra Trees on histograms
scores_hists_et, model_hists_et, inspector_hists_et = run_cross_validation(
    X=X_hists,
    y='SEX_ID (1=m, 2=f)',
    data=df_full_histograms,
    search_params=search_params_et,
    model=creator_et,
    return_train_score=True,
    return_inspector=True,
    cv=4,
    scoring=scoring,
)

# Extra Trees on parcellations
scores_schaefer_et, model_schaefer_et, inspector_schaefer_et = run_cross_validation(
    X=X_parcels,
    y='SEX_ID (1=m, 2=f)',
    data=df_full_parcellations,
    search_params=search_params_et,
    model=creator_et,
    return_train_score=True,
    return_inspector=True,
    cv=4,
    scoring=scoring,
)


  warn_with_log(

  warn_with_log(



In [16]:
scores_hists_et

Unnamed: 0,fit_time,score_time,estimator,test_balanced_accuracy,train_balanced_accuracy,test_accuracy,train_accuracy,n_train,n_test,repeat,fold,cv_mdsum
0,0.066674,0.00555,"(SetColumnTypes(X_types={}), StandardScaler(),...",0.600082,0.811397,0.637795,0.818898,381,127,0,0,bc7087515161a73a5a6aff57863f3803
1,0.064677,0.005447,"(SetColumnTypes(X_types={}), StandardScaler(),...",0.636161,0.781575,0.637795,0.813648,381,127,0,1,bc7087515161a73a5a6aff57863f3803
2,0.066113,0.005368,"(SetColumnTypes(X_types={}), StandardScaler(),...",0.505099,0.779328,0.551181,0.800525,381,127,0,2,bc7087515161a73a5a6aff57863f3803
3,0.065478,0.005453,"(SetColumnTypes(X_types={}), StandardScaler(),...",0.525422,0.767675,0.527559,0.800525,381,127,0,3,bc7087515161a73a5a6aff57863f3803


In [17]:
scores_schaefer_et

Unnamed: 0,fit_time,score_time,estimator,test_balanced_accuracy,train_balanced_accuracy,test_accuracy,train_accuracy,n_train,n_test,repeat,fold,cv_mdsum
0,0.060422,0.00541,"(SetColumnTypes(X_types={}), StandardScaler(),...",0.640197,0.928965,0.669291,0.931759,381,127,0,0,bc7087515161a73a5a6aff57863f3803
1,0.060109,0.005315,"(SetColumnTypes(X_types={}), StandardScaler(),...",0.666915,0.81875,0.669291,0.847769,381,127,0,1,bc7087515161a73a5a6aff57863f3803
2,0.060692,0.005352,"(SetColumnTypes(X_types={}), StandardScaler(),...",0.583248,0.884555,0.614173,0.892388,381,127,0,2,bc7087515161a73a5a6aff57863f3803
3,0.060246,0.005345,"(SetColumnTypes(X_types={}), StandardScaler(),...",0.611979,0.8611,0.614173,0.88189,381,127,0,3,bc7087515161a73a5a6aff57863f3803


In [18]:
# For Gradient Boosting Classifier
creator_gb = PipelineCreator(problem_type="classification")
creator_gb.add("zscore")
creator_gb.add(
    "gradientboost",
    learning_rate = 0.02,
)

search_params_gb = {
    "kind": "grid",
    "cv": 4
}

scoring = ["balanced_accuracy", "accuracy"]

# Gradient Boosting on histograms
scores_hists_gb, model_hists_gb, inspector_hists_gb = run_cross_validation(
    X=X_hists,
    y='SEX_ID (1=m, 2=f)',
    data=df_full_histograms,
    search_params=search_params_gb,
    model=creator_gb,
    return_train_score=True,
    return_inspector=True,
    cv=4,
    scoring=scoring,
)

# Gradient Boosting on parcellations
scores_schaefer_gb, model_schaefer_gb, inspector_schaefer_gb = run_cross_validation(
    X=X_parcels,
    y='SEX_ID (1=m, 2=f)',
    data=df_full_parcellations,
    search_params=search_params_gb,
    model=creator_gb,
    return_train_score=True,
    return_inspector=True,
    cv=4,
    scoring=scoring,
)


  warn_with_log(

  warn_with_log(



In [19]:
scores_hists_gb

Unnamed: 0,fit_time,score_time,estimator,test_balanced_accuracy,train_balanced_accuracy,test_accuracy,train_accuracy,n_train,n_test,repeat,fold,cv_mdsum
0,0.516409,0.003923,"(SetColumnTypes(X_types={}), StandardScaler(),...",0.529847,0.872214,0.566929,0.87664,381,127,0,0,bc7087515161a73a5a6aff57863f3803
1,0.500183,0.003257,"(SetColumnTypes(X_types={}), StandardScaler(),...",0.558036,0.854313,0.559055,0.874016,381,127,0,1,bc7087515161a73a5a6aff57863f3803
2,0.503608,0.003226,"(SetColumnTypes(X_types={}), StandardScaler(),...",0.600714,0.825815,0.622047,0.84252,381,127,0,2,bc7087515161a73a5a6aff57863f3803
3,0.498622,0.003219,"(SetColumnTypes(X_types={}), StandardScaler(),...",0.525794,0.848925,0.527559,0.868766,381,127,0,3,bc7087515161a73a5a6aff57863f3803


In [20]:
scores_schaefer_gb

Unnamed: 0,fit_time,score_time,estimator,test_balanced_accuracy,train_balanced_accuracy,test_accuracy,train_accuracy,n_train,n_test,repeat,fold,cv_mdsum
0,0.753393,0.003393,"(SetColumnTypes(X_types={}), StandardScaler(),...",0.725767,0.935505,0.732283,0.937008,381,127,0,0,bc7087515161a73a5a6aff57863f3803
1,0.756039,0.003277,"(SetColumnTypes(X_types={}), StandardScaler(),...",0.778398,0.956575,0.779528,0.96063,381,127,0,1,bc7087515161a73a5a6aff57863f3803
2,0.756951,0.00325,"(SetColumnTypes(X_types={}), StandardScaler(),...",0.611295,0.949261,0.606299,0.950131,381,127,0,2,bc7087515161a73a5a6aff57863f3803
3,0.7596,0.003384,"(SetColumnTypes(X_types={}), StandardScaler(),...",0.667907,0.929313,0.669291,0.937008,381,127,0,3,bc7087515161a73a5a6aff57863f3803


In [21]:
scores_hists_xgb = pd.read_csv('IXI_XGB_scores_hists.csv')
scores_schaefer_xgb = pd.read_csv('IXI_XGB_scores_shaefer.csv')


In [22]:
scores_hists_xgb

Unnamed: 0,fit_time,score_time,estimator,test_balanced_accuracy,train_balanced_accuracy,test_accuracy,train_accuracy,n_train,n_test,repeat,fold,cv_mdsum
0,0.152105,0.014713,"Pipeline(steps=[('set_column_types', SetColumn...",0.626548,0.918632,0.645669,0.918635,381,127,0,0,bc7087515161a73a5a6aff57863f3803
1,0.149924,0.014601,"Pipeline(steps=[('set_column_types', SetColumn...",0.549044,0.913761,0.559055,0.918635,381,127,0,1,bc7087515161a73a5a6aff57863f3803
2,0.154075,0.014558,"Pipeline(steps=[('set_column_types', SetColumn...",0.650649,0.928746,0.661417,0.931759,381,127,0,2,bc7087515161a73a5a6aff57863f3803
3,0.149626,0.014585,"Pipeline(steps=[('set_column_types', SetColumn...",0.582216,0.912079,0.574803,0.92126,381,127,0,3,bc7087515161a73a5a6aff57863f3803


In [23]:
scores_schaefer_xgb

Unnamed: 0,fit_time,score_time,estimator,test_balanced_accuracy,train_balanced_accuracy,test_accuracy,train_accuracy,n_train,n_test,repeat,fold,cv_mdsum
0,0.26568,0.01611,"Pipeline(steps=[('set_column_types', SetColumn...",0.652735,0.980166,0.669291,0.981627,381,127,0,0,bc7087515161a73a5a6aff57863f3803
1,0.20903,0.015271,"Pipeline(steps=[('set_column_types', SetColumn...",0.685739,0.969696,0.692913,0.971129,381,127,0,1,bc7087515161a73a5a6aff57863f3803
2,0.201536,0.014705,"Pipeline(steps=[('set_column_types', SetColumn...",0.64961,0.966291,0.677165,0.968504,381,127,0,2,bc7087515161a73a5a6aff57863f3803
3,0.198308,0.01475,"Pipeline(steps=[('set_column_types', SetColumn...",0.698957,0.966873,0.692913,0.971129,381,127,0,3,bc7087515161a73a5a6aff57863f3803


In [26]:
scores_schaefer_svm['model'] = 'IXI_Schaefer_SVM_1000'
scores_hists_svm['model'] = 'IXI_Histograms_SVM_1000'

scores_schaefer_rf['model'] = 'IXI_Schaefer_RF_1000'
scores_hists_rf['model'] = 'IXI_Histograms_RF_1000'

scores_schaefer_et['model'] = 'IXI_Schaefer_ET_1000'
scores_hists_et['model'] = 'IXI_Histograms_ET_1000'

scores_schaefer_gb['model'] = 'IXI_Schaefer_GB_1000'
scores_hists_gb['model'] = 'IXI_Histograms_GB_1000'

scores_schaefer_xgb['model'] = 'IXI_Schaefer_XGB_1000'
scores_hists_xgb['model'] = 'IXI_Histograms_XGB_1000'

In [27]:
plot_scores(scores_schaefer_svm,scores_hists_svm,scores_schaefer_rf,scores_hists_rf,scores_schaefer_et,scores_hists_et,scores_schaefer_gb,scores_hists_gb,scores_schaefer_xgb,scores_hists_xgb)

BokehModel(combine_events=True, render_bundle={'docs_json': {'4f19e648-8961-48d2-93a8-3e3af2558093': {'version…