In [1]:
import pandas as pd

import numpy as np

from pathlib import Path

from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.metrics import roc_auc_score, r2_score
from sklearn.model_selection import train_test_split

from tqdm import tqdm

In [2]:
sst_embdeds = pd.read_pickle('/mrhome/vladyslavz/git/central-sulcus-analysis/data/via11/nobackup/contrastive_embeddings/synthseg-monai-BasicUnet-1.5x-half/sst_embeds.pkl')
sst_embdeds.encoder_embed = [x.flatten() for x in sst_embdeds.encoder_embed]

In [3]:
metrics_paths = Path('/mnt/projects/VIA_Vlad/nobackup/MP2RAGE_FS7_1_1/mindboggle').glob('via*')
metrics_paths = [x/'tables/left_cortical_surface/sulcus_shapes.csv' for x in metrics_paths]

In [4]:
features2analyze = ['area', 'travel depth: median', 'geodesic depth: median', 
                    'mean curvature: median',  'freesurfer curvature: MAD',
                    'freesurfer convexity (sulc): median', 'freesurfer thickness: median']

In [15]:
def extract_features(subj_path):
    metrics_df = pd.read_csv(subj_path)
    metrics_df = metrics_df.set_index('name')
    features = dict(metrics_df.loc['central sulcus', :])
    features['caseid'] = 'sub-' + subj_path.parent.parent.parent.name
    return features
features_df = []
for subj_path in tqdm(metrics_paths):
    features_df.append(extract_features(subj_path))
features_df = pd.DataFrame(features_df).set_index('caseid')
features_df

100%|██████████| 325/325 [00:01<00:00, 220.67it/s]


Unnamed: 0_level_0,ID,area,travel depth: median,travel depth: MAD,travel depth: mean,travel depth: SD,travel depth: skew,travel depth: kurtosis,travel depth: 25%,travel depth: 75%,...,Zernike moments: component 27,Zernike moments: component 28,Zernike moments: component 29,Zernike moments: component 30,Zernike moments: component 31,Zernike moments: component 32,Zernike moments: component 33,Zernike moments: component 34,Zernike moments: component 35,Zernike moments: component 36
caseid,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
sub-via003,4.0,4537.207206,9.628324,4.530343,10.543645,6.383913,-0.277243,-0.907948,5.502757,14.783726,...,0.000390,0.000014,0.000081,0.000244,0.000008,0.000051,4.701093e-06,0.000033,2.805619e-06,1.725859e-06
sub-via004,4.0,4051.200182,9.903018,4.455301,10.910019,6.478329,-0.012364,-0.547759,5.758049,14.900315,...,0.000111,0.000003,0.000020,0.000068,0.000002,0.000011,8.536653e-07,0.000007,4.702659e-07,2.634907e-07
sub-via005,4.0,4666.630017,9.708878,4.934100,10.757765,7.044286,-0.539304,-1.049993,5.279148,15.360837,...,0.000402,0.000017,0.000101,0.000298,0.000010,0.000068,6.277717e-06,0.000045,3.740225e-06,2.203411e-06
sub-via010,4.0,4502.578013,10.082948,4.474133,11.087404,6.501278,0.002494,-0.506488,6.003443,15.125625,...,0.000168,0.000004,0.000027,0.000095,0.000002,0.000015,1.179356e-06,0.000009,6.437780e-07,3.849158e-07
sub-via013,4.0,4293.910500,9.790410,4.848544,10.861683,6.735491,-0.236952,-0.901368,5.487458,15.330969,...,0.000254,0.000008,0.000050,0.000152,0.000005,0.000030,2.574580e-06,0.000019,1.406302e-06,7.754791e-07
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
sub-via517,4.0,4615.699580,8.955590,4.381562,9.811945,6.225344,-0.157878,-0.619091,4.812093,13.748234,...,0.000205,0.000005,0.000034,0.000123,0.000002,0.000017,9.308985e-07,0.000008,4.516468e-07,2.295352e-07
sub-via518,4.0,4146.014173,10.004529,4.925986,11.323651,7.236773,0.278796,-0.223592,5.575366,15.751502,...,0.000324,0.000011,0.000068,0.000214,0.000006,0.000039,3.007858e-06,0.000022,1.636744e-06,9.244507e-07
sub-via519,4.0,4440.950709,9.938319,4.799317,11.028758,6.792104,-0.022814,-0.783057,5.561841,15.363136,...,0.000196,0.000005,0.000035,0.000121,0.000003,0.000019,1.556666e-06,0.000012,9.505781e-07,6.104764e-07
sub-via521,4.0,4093.968616,10.663710,4.797794,11.374511,6.272890,-0.256788,-0.933075,6.206764,15.910649,...,0.000079,0.000002,0.000016,0.000054,0.000001,0.000009,6.566717e-07,0.000005,3.534783e-07,1.872134e-07


In [16]:
merged_df = sst_embdeds.merge(features_df, left_on='caseid', right_index=True)

In [17]:
f = 'area'
targ= 'mlp_embed'
for f in features_df.columns:
    train_df, test_df = train_test_split(merged_df, test_size=0.3)

    train_y = train_df[f].values
    train_X = np.vstack(train_df[targ].values)

    test_y = test_df[f].values
    test_X = np.vstack(test_df[targ].values)

    lr = LinearRegression()
    train_pred_y = lr.fit(train_X, train_y).predict(train_X)
    test_pred_y = lr.predict(test_X)

    if r2_score(test_y, test_pred_y) > 0:
        print('BINGO')
    print(f'For feature: {f}')
    print(f'R2 train: {r2_score(train_y, train_pred_y)}')
    print(f'R2 test: {r2_score(test_y, test_pred_y)}')
    print()
    print(f'MSE train: {np.mean((train_y - train_pred_y)**2)}')
    print(f'MSE test: {np.mean((test_y - test_pred_y)**2)}')
    print('____________________', '\n')

BINGO
For feature: ID
R2 train: 1.0
R2 test: 1.0

MSE train: 0.0
MSE test: 0.0
____________________ 

For feature: area
R2 train: 0.4355960650582884
R2 test: -0.0071633490225875285

MSE train: 125506.76304065528
MSE test: 399613.8394463524
____________________ 

For feature: travel depth: median
R2 train: 0.21502229827032138
R2 test: -2.6564359295823503

MSE train: 0.2617721208633029
MSE test: 1.7443792102940419
____________________ 

For feature: travel depth: MAD
R2 train: 0.22061756167515567
R2 test: -0.8969294376389703

MSE train: 0.08006191038984369
MSE test: 0.15281589422706215
____________________ 

For feature: travel depth: mean
R2 train: 0.26368796236859493
R2 test: -6.935666364616446

MSE train: 0.29011392849516887
MSE test: 2.750059023789677
____________________ 

For feature: travel depth: SD
R2 train: 0.34891872949714675
R2 test: -1018.4112038571966

MSE train: 0.10175679238740122
MSE test: 159.26695200904555
____________________ 

For feature: travel depth: skew
R2 train

In [None]:
from src.data.splits import 