# Diagnosis Cross-Prediction

In [1]:
import pandas as pd
import numpy as np

%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
%matplotlib inline

import seaborn as sns
# Notebook-wide plotting defaults: "paper" context on a white grid, fonts scaled
# by 1.2, and 10x10 figures rendered at 300 dpi.
sns.set_theme(
    rc={"figure.figsize": (10, 10), "figure.dpi": 300},
    font_scale=1.2,
    style="whitegrid",
    context="paper",
)

## Get Data

In [2]:
from os.path import join
from common.data import get_data
from common.paths import BIOBANK_LABELS

# Label files for the three diagnosis groups, all resolved under BIOBANK_LABELS.
healthy = join(BIOBANK_LABELS, 'Subjects_with_WISC (healthy).csv')
adhd_group_one = join(BIOBANK_LABELS, 'Subjects_with_WISC (adhd 1).csv')
adhd_group_two = join(BIOBANK_LABELS, 'Subjects_with_WISC (adhd 2).csv')

# Load each group with the same first argument (5).
# NOTE(review): the meaning of 5 is defined by common.data.get_data -- TODO confirm
# and consider naming it as a constant.
# NOTE: `demographics` and `population` are rebound by every call, so after this
# cell they hold only the values returned for the "adhd 2" file.
X_healthy, Y_healthy, demographics, population = get_data(5, healthy)
X_adhd_one, Y_adhd_one, demographics, population = get_data(5, adhd_group_one)
X_adhd_two, Y_adhd_two, demographics, population = get_data(5, adhd_group_two)

# Sanity check: report the shape of each group's X.
print(f'X_healthy: {X_healthy.shape} | X_adhd_one: {X_adhd_one.shape} | X_adhd_two: {X_adhd_two.shape}')

X_healthy: (106, 34716) | X_adhd_one: (190, 34716) | X_adhd_two: (190, 34716)


In [3]:
from common.cross_prediction import get_group_cv_splits, get_group_order
from common.paths import CROSS_PRED_RESULTS
from common.results import CVResult, save_results
from common.scoring import (unimetric_scorer, 
                            custom_permutation_test_score, 
                            N_PERM, SCORING, RKF_10_10)
from sklearn.linear_model import Ridge
from sklearn.model_selection import cross_validate
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline

## Run for one target, one age bin

In [7]:
# Target column to predict; the same column is pulled from each group's Y.
# NOTE(review): the saved outputs in this notebook report WISC_PSI, so they appear
# to come from an earlier run with a different target -- re-run before trusting them.
selected_target = "WISC_FSIQ"
y_healthy, y_adhd_one, y_adhd_two = (
    group_targets[selected_target]
    for group_targets in (Y_healthy, Y_adhd_one, Y_adhd_two)
)

print(f'{selected_target}: {y_healthy.shape}, {y_adhd_one.shape}, {y_adhd_two.shape}')

WISC_PSI: (106,), (190,), (190,)


In [8]:
# Pair each group's X with its selected target as an (X, y) tuple.
# NOTE: `healthy` previously held the path of the healthy label file; from here on
# it is the (X, y) tuple.
healthy, adhd_one, adhd_two = (
    (X_healthy, y_healthy),
    (X_adhd_one, y_adhd_one),
    (X_adhd_two, y_adhd_two),
)

# Group order here (healthy, adhd_one, adhd_two) is the order used by the
# index-based lookups into diags_cv below.
diags = [healthy, adhd_one, adhd_two]
diags_cv = get_group_cv_splits(diags, RKF_10_10)

print(f'Healthy: {healthy[0].shape} | ADHD_ONE: {adhd_one[0].shape} | ADHD_TWO: {adhd_two[0].shape}')
print(f'healthy_cv: {len(diags_cv[0])} | adhd_one_cv: {len(diags_cv[1])} | adhd_two_cv: {len(diags_cv[2])}')

### Run permutation test (train group, test group)

In [9]:
%%time
# From previous results
diag_alphas = [5000, 35000, 35000]
diag_labels = ['Healthy', 'ADHD_One', 'ADHD_Two']

diag_order, diag_cv_order, diag_labels = get_group_order(diags, diags_cv, diag_labels)
results = []
perm_scores = []

for diag_alpha, diags, diags_cv, labels in zip(diag_alphas, diag_order, diag_cv_order, diag_labels):
    train_diag, test_diag_one, test_diag_two = diags[0], diags[1], diags[2]
    train_diag_cv, test_diag_one_cv, test_diag_two_cv = diags_cv[0], diags_cv[1], diags_cv[2]
    
    pipe = make_pipeline(StandardScaler(), Ridge(alpha=diag_alpha))
    rs, perms, ps = custom_permutation_test_score(
        pipe, train_diag, test_diag_one, test_diag_two, 
        train_diag_cv, test_diag_one_cv, test_diag_two_cv, N_PERM, unimetric_scorer)
    
    train_group = labels[0]
    for r, p, test_group in zip(rs, ps, labels):
        results.append(
            CVResult('ridge', selected_target, train_group, test_group, r, p, train_group, N_PERM)
        )
    perm_scores.append(perms)
    print(f'Train Group: {train_group}')

results_df = pd.DataFrame([r.to_dict() for r in results])
display(results_df.round(4))
filename = f'ridge_pts_diagnosis_cross_prediction.csv'
results_fp = save_results(results_df, filename, CROSS_PRED_RESULTS)
print('Results saved to:', results_fp)

Train Group: Healthy
Train Group: ADHD_One
Train Group: ADHD_Two


Unnamed: 0,Model,Target,Num Permutations,Train Group,Test Group,Score,P-value
0,ridge,WISC_PSI,500,Healthy,Healthy,-0.0356,0.5828
1,ridge,WISC_PSI,500,Healthy,ADHD_One,0.0758,0.1098
2,ridge,WISC_PSI,500,Healthy,ADHD_Two,0.0085,0.4551
3,ridge,WISC_PSI,500,ADHD_One,ADHD_One,0.0156,0.4571
4,ridge,WISC_PSI,500,ADHD_One,ADHD_Two,0.0538,0.2076
5,ridge,WISC_PSI,500,ADHD_One,Healthy,0.0281,0.3972
6,ridge,WISC_PSI,500,ADHD_Two,ADHD_Two,0.0027,0.523
7,ridge,WISC_PSI,500,ADHD_Two,Healthy,0.0525,0.2834
8,ridge,WISC_PSI,500,ADHD_Two,ADHD_One,0.0773,0.1138


CPU times: user 1d 9h 31min 11s, sys: 36min 4s, total: 1d 10h 7min 15s
Wall time: 3h 24min 43s


### Visualize permutation results