# Experiments

In [85]:
import pickle
import re
import pandas as pd
import numpy as np
from data import *
from bnMLP import *
import os
from torch.utils.data import TensorDataset, DataLoader, random_split

from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import train_test_split
import torch
from sklearn.svm import SVC
from sklearn.neighbors import NearestCentroid
from sklearn.metrics import roc_auc_score, pairwise_distances
from utils import *

In [18]:
GLOBAL_CUTOFFS = {}
for a, b in pd.read_csv('cutoffs.csv').values:
    GLOBAL_CUTOFFS[a] = b
    # break

COHORTS = ['Autoimmune: Other', 'Cancer: Other', 'HIV', 'Healthy Control',
           'IBD', 'Multiple Myeloma', 'Transplant']


In [6]:
with open('final_models.pkl', 'rb') as f:
    data = pickle.load(f)


In [8]:
results_6_excel, results_8_excel = data['dict6'], data['dict8']

In [40]:
if not os.path.exists('final_results'):
    os.makedirs('final_results')

#### Post to Post by cohort

In [10]:

columns = []
for t in [0.75, 0.8, 0.85, 0.9, 0.95]:
    columns.append(f'Sensitivity at {t} Specificity')
for t in [0.75, 0.8, 0.85, 0.9, 0.95]:
    columns.append(f'Specificity at {t} Sensitivity')
columns = [
    'Low Responders Accuracy',
    'High Responders Accuracy',
    'Harmonic Mean',
    'AUC'
] + columns


split_pattern = r',\s*(?![^()]*\))'

records = []
for k, v in results_8_excel.items():
    parts = re.split(split_pattern, k)
    seed, cohort, task, setting, clf, coeffs, usage = parts
    records.append({
        'seed': int(seed),
        'cohort': cohort,
        'task': task,
        'setting': setting,
        'clf': clf,
        'coeffs': coeffs,
        'usage': usage,
        **{name: val for name, val in zip(columns, v[-1])}
    })

df = pd.DataFrame(records)

group_cols = ['cohort', 'task', 'setting', 'clf', 'coeffs', 'usage']
metric_cols = columns

agg_df = df.groupby(group_cols)[metric_cols].agg(['mean', 'std']).reset_index()
agg_df.columns = ['_'.join(col).strip('_') for col in agg_df.columns.values]

# print(agg_df)


In [11]:
agg_df.head()

Unnamed: 0,cohort,task,setting,clf,coeffs,usage,Low Responders Accuracy_mean,Low Responders Accuracy_std,High Responders Accuracy_mean,High Responders Accuracy_std,...,Specificity at 0.75 Sensitivity_mean,Specificity at 0.75 Sensitivity_std,Specificity at 0.8 Sensitivity_mean,Specificity at 0.8 Sensitivity_std,Specificity at 0.85 Sensitivity_mean,Specificity at 0.85 Sensitivity_std,Specificity at 0.9 Sensitivity_mean,Specificity at 0.9 Sensitivity_std,Specificity at 0.95 Sensitivity_mean,Specificity at 0.95 Sensitivity_std
0,Autoimmune: Other,8,binary,mlp,"(4, 1)",Ab raw,0.483657,0.283375,0.535814,0.270041,...,0.236413,0.073648,0.188761,0.059383,0.129714,0.058552,0.058277,0.039342,0.040396,0.033726
1,Autoimmune: Other,8,binary,mlp,"(4, 1)",Demographic,0.521686,0.27946,0.520352,0.275685,...,0.265945,0.072609,0.222303,0.066424,0.16089,0.05467,0.072352,0.044463,0.055754,0.041882
2,Autoimmune: Other,8,binary,nc,"(4, 1)",Ab raw,0.397133,0.063008,0.624284,0.057271,...,0.220943,0.061691,0.184239,0.052999,0.130657,0.046604,0.062564,0.03367,0.043646,0.030186
3,Autoimmune: Other,8,binary,nc,"(4, 1)",Demographic,0.426038,0.06722,0.621502,0.053877,...,0.234517,0.050559,0.198028,0.05059,0.14264,0.044893,0.061371,0.032921,0.047356,0.031698
4,Autoimmune: Other,8,binary,svm,"(4, 1)",Ab raw,0.706571,0.055719,0.335949,0.052944,...,0.239227,0.055156,0.202758,0.050026,0.147002,0.046978,0.06779,0.032058,0.05465,0.032557


In [19]:
X, y , emb_list, real_size, encoder = load_6(binary=True, cutoffs=None)
counts = {}
for cohort in COHORTS:
    
    # idx = X[:, 0] == 'Healthy Control'
    idx = X[:, 0] == cohort
    
    all_low_count = (y[idx] == 0).astype(float).sum()
    all_high_count = (y[idx] == 1).astype(float).sum()
    print(all_low_count, all_high_count, cohort)
    counts[f'{cohort},Test_N_Low'] = all_low_count
    counts[f'{cohort},Train_N_Low'] = all_low_count
    counts[f'{cohort},Test_N_High'] = all_high_count
    counts[f'{cohort},Train_N_High'] = all_high_count
counts

72.0 163.0 Autoimmune: Other
9.0 102.0 Cancer: Other
14.0 70.0 HIV
24.0 465.0 Healthy Control
4.0 33.0 IBD
60.0 92.0 Multiple Myeloma
81.0 73.0 Transplant


{'Autoimmune: Other,Test_N_Low': 72.0,
 'Autoimmune: Other,Train_N_Low': 72.0,
 'Autoimmune: Other,Test_N_High': 163.0,
 'Autoimmune: Other,Train_N_High': 163.0,
 'Cancer: Other,Test_N_Low': 9.0,
 'Cancer: Other,Train_N_Low': 9.0,
 'Cancer: Other,Test_N_High': 102.0,
 'Cancer: Other,Train_N_High': 102.0,
 'HIV,Test_N_Low': 14.0,
 'HIV,Train_N_Low': 14.0,
 'HIV,Test_N_High': 70.0,
 'HIV,Train_N_High': 70.0,
 'Healthy Control,Test_N_Low': 24.0,
 'Healthy Control,Train_N_Low': 24.0,
 'Healthy Control,Test_N_High': 465.0,
 'Healthy Control,Train_N_High': 465.0,
 'IBD,Test_N_Low': 4.0,
 'IBD,Train_N_Low': 4.0,
 'IBD,Test_N_High': 33.0,
 'IBD,Train_N_High': 33.0,
 'Multiple Myeloma,Test_N_Low': 60.0,
 'Multiple Myeloma,Train_N_Low': 60.0,
 'Multiple Myeloma,Test_N_High': 92.0,
 'Multiple Myeloma,Train_N_High': 92.0,
 'Transplant,Test_N_Low': 81.0,
 'Transplant,Train_N_Low': 81.0,
 'Transplant,Test_N_High': 73.0,
 'Transplant,Train_N_High': 73.0}

In [20]:
agg_df['Train_Timepoint'] = 'Post-B'
agg_df['Test_Timepoint'] = 'Post-B'
agg_df['Test_Cohort'] = agg_df['cohort']
def update_usage(row, what):
    cohort = row['cohort']
    return counts[f'{cohort},{what}']

for what in ['Train_N_Low', 'Train_N_High', 'Test_N_High', 'Test_N_Low']:
    agg_df[what] = agg_df.apply(lambda z:update_usage(z, what=what), axis=1)
    
agg_df = agg_df.rename(columns={
    'cohort': 'Train_Cohort', 
    'clf': 'Classifier', 
    'AUC_mean': 'CV_AUC_Mean',
    'AUC_std': 'CV_AUC_SD',
    'Sensitivity at 0.8 Specificity_mean': 'CV_Sen_at_80_Spe_Mean',
    'Sensitivity at 0.8 Specificity_std': 'CV_Sen_at_80_Spe_SD',
    'Low Responders Accuracy_mean': 'CV_Acc_Low_Mean',
    'High Responders Accuracy_mean': 'CV_Acc_High_Mean',
    'Low Responders Accuracy_std': 'CV_Acc_Low_SD',
    'High Responders Accuracy_std': 'CV_Acc_High_SD',
})

def update_usage(row):
    usage = row['usage']
    setting = row['setting']
    
    # Case 1: if usage is "Ab raw"
    
    out = 'Ab_raw'
    
    # Case 2: if usage is "Demographic"
    if setting.strip().lower() == 'binary':
        out += '+ Ab_binary'
        
    if usage.strip().lower() == 'demographic':
        out += '+ demographic'
    
    # Case 3: modify further based on 'setting'
    
    if setting.strip().lower() == 'no-cutoff':
        # don't add anything
        pass
    
    return out

# Apply transformation
agg_df['Feature_Set'] = agg_df.apply(update_usage, axis=1)
agg_df[['Train_Timepoint', 'Train_Cohort', 'Train_N_Low', 'Train_N_High', 'Feature_Set', 'Test_Timepoint', 'Test_Cohort', 'Test_N_Low' ,'Test_N_High', 'Classifier',	'CV_AUC_Mean', 'CV_AUC_SD',	'CV_Sen_at_80_Spe_Mean', 'CV_Sen_at_80_Spe_SD', 'CV_Acc_Low_Mean', 'CV_Acc_Low_SD',	'CV_Acc_High_Mean',	'CV_Acc_High_SD']]



Unnamed: 0,Train_Timepoint,Train_Cohort,Train_N_Low,Train_N_High,Feature_Set,Test_Timepoint,Test_Cohort,Test_N_Low,Test_N_High,Classifier,CV_AUC_Mean,CV_AUC_SD,CV_Sen_at_80_Spe_Mean,CV_Sen_at_80_Spe_SD,CV_Acc_Low_Mean,CV_Acc_Low_SD,CV_Acc_High_Mean,CV_Acc_High_SD
0,Post-B,Autoimmune: Other,72.0,163.0,Ab_raw+ Ab_binary,Post-B,Autoimmune: Other,72.0,163.0,mlp,0.519177,0.037570,0.280638,0.057416,0.483657,0.283375,0.535814,0.270041
1,Post-B,Autoimmune: Other,72.0,163.0,Ab_raw+ Ab_binary+ demographic,Post-B,Autoimmune: Other,72.0,163.0,mlp,0.545551,0.040479,0.310752,0.070853,0.521686,0.279460,0.520352,0.275685
2,Post-B,Autoimmune: Other,72.0,163.0,Ab_raw+ Ab_binary,Post-B,Autoimmune: Other,72.0,163.0,nc,0.528434,0.039630,0.302295,0.055403,0.397133,0.063008,0.624284,0.057271
3,Post-B,Autoimmune: Other,72.0,163.0,Ab_raw+ Ab_binary+ demographic,Post-B,Autoimmune: Other,72.0,163.0,nc,0.541059,0.036547,0.320124,0.052592,0.426038,0.067220,0.621502,0.053877
4,Post-B,Autoimmune: Other,72.0,163.0,Ab_raw+ Ab_binary,Post-B,Autoimmune: Other,72.0,163.0,svm,0.531261,0.035584,0.286267,0.048841,0.706571,0.055719,0.335949,0.052944
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
67,Post-B,Transplant,81.0,73.0,Ab_raw+ demographic,Post-B,Transplant,81.0,73.0,mlp,0.552797,0.040570,0.332228,0.065942,0.564265,0.248461,0.509495,0.239451
68,Post-B,Transplant,81.0,73.0,Ab_raw,Post-B,Transplant,81.0,73.0,nc,0.510863,0.043189,0.286088,0.054508,0.520029,0.065614,0.479533,0.062681
69,Post-B,Transplant,81.0,73.0,Ab_raw+ demographic,Post-B,Transplant,81.0,73.0,nc,0.526472,0.046465,0.30361,0.068458,0.554243,0.074772,0.500381,0.059136
70,Post-B,Transplant,81.0,73.0,Ab_raw,Post-B,Transplant,81.0,73.0,svm,0.520031,0.047456,0.292162,0.067263,0.836809,0.058186,0.187333,0.058779


In [23]:
out.to_csv('final_results/model_performance_summary_v1_revised.csv')
out

Unnamed: 0,Train_Timepoint,Train_Cohort,Train_N_Low,Train_N_High,Feature_Set,Test_Timepoint,Test_Cohort,Test_N_Low,Test_N_High,Classifier,CV_AUC_Mean,CV_AUC_SD,CV_Sen_at_80_Spe_Mean,CV_Sen_at_80_Spe_SD,CV_Acc_Low_Mean,CV_Acc_Low_SD,CV_Acc_High_Mean,CV_Acc_High_SD
0,Post-B,Autoimmune: Other,72.0,163.0,Ab_raw+ Ab_binary,Post-B,Autoimmune: Other,72.0,163.0,mlp,0.519177,0.037570,0.280638,0.057416,0.483657,0.283375,0.535814,0.270041
1,Post-B,Autoimmune: Other,72.0,163.0,Ab_raw+ Ab_binary+ demographic,Post-B,Autoimmune: Other,72.0,163.0,mlp,0.545551,0.040479,0.310752,0.070853,0.521686,0.279460,0.520352,0.275685
2,Post-B,Autoimmune: Other,72.0,163.0,Ab_raw+ Ab_binary,Post-B,Autoimmune: Other,72.0,163.0,nc,0.528434,0.039630,0.302295,0.055403,0.397133,0.063008,0.624284,0.057271
3,Post-B,Autoimmune: Other,72.0,163.0,Ab_raw+ Ab_binary+ demographic,Post-B,Autoimmune: Other,72.0,163.0,nc,0.541059,0.036547,0.320124,0.052592,0.426038,0.067220,0.621502,0.053877
4,Post-B,Autoimmune: Other,72.0,163.0,Ab_raw+ Ab_binary,Post-B,Autoimmune: Other,72.0,163.0,svm,0.531261,0.035584,0.286267,0.048841,0.706571,0.055719,0.335949,0.052944
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
67,Post-B,Transplant,81.0,73.0,Ab_raw+ demographic,Post-B,Transplant,81.0,73.0,mlp,0.552797,0.040570,0.332228,0.065942,0.564265,0.248461,0.509495,0.239451
68,Post-B,Transplant,81.0,73.0,Ab_raw,Post-B,Transplant,81.0,73.0,nc,0.510863,0.043189,0.286088,0.054508,0.520029,0.065614,0.479533,0.062681
69,Post-B,Transplant,81.0,73.0,Ab_raw+ demographic,Post-B,Transplant,81.0,73.0,nc,0.526472,0.046465,0.30361,0.068458,0.554243,0.074772,0.500381,0.059136
70,Post-B,Transplant,81.0,73.0,Ab_raw,Post-B,Transplant,81.0,73.0,svm,0.520031,0.047456,0.292162,0.067263,0.836809,0.058186,0.187333,0.058779


#### Post to Post All

In [35]:
columns = []
for t in [0.75, 0.8, 0.85, 0.9, 0.95]:
    columns.append(f'Sensitivity at {t} Specificity')
for t in [0.75, 0.8, 0.85, 0.9, 0.95]:
    columns.append(f'Specificity at {t} Sensitivity')
columns = [
    'Low Responders Accuracy',
    'High Responders Accuracy',
    'Harmonic Mean',
    'AUC'
] + columns
for k, v in results_6_excel.items():
    seed, _, setting, clf, __, usage = re.split(r',\s*(?![^()]*\))', k)


records = []
for k, v in results_6_excel.items():
    parts = re.split(split_pattern, k)
    seed, task, setting, clf, coeffs, usage = parts
    records.append({
        'seed': int(seed),
        'cohort': 'All',
        'task': task,
        'setting': setting,
        'clf': clf,
        'coeffs': coeffs,
        'usage': usage,
        **{name: val for name, val in zip(columns, v[-1])}
    })

df = pd.DataFrame(records)

group_cols = ['task', 'cohort', 'setting', 'clf', 'coeffs', 'usage']
metric_cols = columns

agg_df = df.groupby(group_cols)[metric_cols].agg(['mean', 'std']).reset_index()
agg_df.columns = ['_'.join(col).strip('_') for col in agg_df.columns.values]
agg_df

Unnamed: 0,task,cohort,setting,clf,coeffs,usage,Low Responders Accuracy_mean,Low Responders Accuracy_std,High Responders Accuracy_mean,High Responders Accuracy_std,...,Specificity at 0.75 Sensitivity_mean,Specificity at 0.75 Sensitivity_std,Specificity at 0.8 Sensitivity_mean,Specificity at 0.8 Sensitivity_std,Specificity at 0.85 Sensitivity_mean,Specificity at 0.85 Sensitivity_std,Specificity at 0.9 Sensitivity_mean,Specificity at 0.9 Sensitivity_std,Specificity at 0.95 Sensitivity_mean,Specificity at 0.95 Sensitivity_std
0,6,All,binary,mlp,"(4, 1)",Ab raw,0.572921,0.172749,0.67822,0.162297,...,0.493206,0.038858,0.431945,0.049203,0.327982,0.051969,0.193796,0.071317,0.09531,0.057351
1,6,All,binary,mlp,"(4, 1)",Demographic,0.644741,0.169995,0.696249,0.138374,...,0.60616,0.036515,0.55136,0.04381,0.459103,0.061108,0.324699,0.080503,0.177426,0.091796
2,6,All,binary,nc,"(4, 1)",Ab raw,0.54907,0.0428,0.66208,0.055566,...,0.333472,0.051213,0.277888,0.050271,0.207805,0.040829,0.134128,0.032574,0.077983,0.023814
3,6,All,binary,nc,"(4, 1)",Demographic,0.589658,0.042097,0.713142,0.04981,...,0.38929,0.073052,0.329967,0.067699,0.250464,0.063405,0.162638,0.044944,0.093714,0.030963
4,6,All,binary,svm,"(4, 1)",Ab raw,0.691062,0.037217,0.555638,0.037663,...,0.438089,0.051454,0.366809,0.051421,0.275348,0.048729,0.177031,0.041524,0.099713,0.031732
5,6,All,binary,svm,"(4, 1)",Demographic,0.744039,0.043451,0.603672,0.035942,...,0.552353,0.070853,0.479273,0.082779,0.376729,0.086989,0.258246,0.081154,0.158006,0.063979
6,6,All,no-cutoff,mlp,"(4, 1)",Ab raw,0.584875,0.156188,0.670334,0.157725,...,0.48819,0.041704,0.431632,0.045136,0.334727,0.056983,0.204605,0.062188,0.107816,0.06086
7,6,All,no-cutoff,mlp,"(4, 1)",Demographic,0.666021,0.156551,0.715709,0.126016,...,0.648492,0.035424,0.601295,0.037576,0.522265,0.047696,0.398905,0.072268,0.239715,0.110474
8,6,All,no-cutoff,nc,"(4, 1)",Ab raw,0.547397,0.045727,0.656191,0.058674,...,0.325801,0.051683,0.27178,0.048664,0.201948,0.038854,0.129076,0.02858,0.077481,0.022917
9,6,All,no-cutoff,nc,"(4, 1)",Demographic,0.614024,0.039258,0.730845,0.045814,...,0.41372,0.089823,0.354238,0.08473,0.269227,0.073797,0.173214,0.054522,0.102504,0.03808


In [62]:
agg_df['Train_Timepoint'] = 'Post-B'
agg_df['Test_Timepoint'] = 'Post-B'
agg_df['Test_Cohort'] = 'All'


X, y , emb_list, real_size, encoder = load_6(binary=True, cutoffs=None)
counts = {}
# for cohort in COHORTS:
    
    # idx = X[:, 0] == 'Healthy Control'
    # idx = X[:, 0] == cohort
all_low_count = (y == 0).astype(float).sum()
all_high_count = (y == 1).astype(float).sum()
counts[f'Test_N_Low'] = all_low_count
counts[f'Train_N_Low'] = all_low_count
counts[f'Test_N_High'] = all_high_count
counts[f'Train_N_High'] = all_high_count


def update_usage(row, what):
    return counts[f'{what}']

# Apply transformation
for what in ['Train_N_Low', 'Train_N_High', 'Test_N_High', 'Test_N_Low']:
    agg_df[what] = agg_df.apply(lambda z:update_usage(z, what=what), axis=1)
    
agg_df = agg_df.rename(columns={
    'cohort': 'Train_Cohort', 
    'clf': 'Classifier', 
    'AUC_mean': 'CV_AUC_Mean',
    'AUC_std': 'CV_AUC_SD',
    'Sensitivity at 0.8 Specificity_mean': 'CV_Sen_at_80_Spe_Mean',
    'Sensitivity at 0.8 Specificity_std': 'CV_Sen_at_80_Spe_SD',
    'Low Responders Accuracy_mean': 'CV_Acc_Low_Mean',
    'High Responders Accuracy_mean': 'CV_Acc_High_Mean',
    'Low Responders Accuracy_std': 'CV_Acc_Low_SD',
    'High Responders Accuracy_std': 'CV_Acc_High_SD',
})

def update_usage(row):
    # Start with base usage name
    usage = row['usage']
    setting = row['setting']
    
    
    # Case 1: if usage is "Ab raw"
    
    out = 'Ab_raw'
    
    
    # Case 2: if usage is "Demographic"
    if setting.strip().lower() == 'binary':
        out += '+ Ab_binary'
        
    if usage.strip().lower() == 'demographic':
        out += '+ demographic'
    
    # Case 3: modify further based on 'setting'
    
    if setting.strip().lower() == 'no-cutoff':
        # don't add anything
        pass
    
    return out

# Apply transformation
agg_df['Feature_Set'] = agg_df.apply(update_usage, axis=1)
agg_df[['Train_Timepoint', 'Train_Cohort', 'Train_N_Low', 'Train_N_High', 'Feature_Set', 'Test_Timepoint', 'Test_Cohort', 'Test_N_Low' ,'Test_N_High', 'Classifier', 'CV_AUC_Mean', 'CV_AUC_SD',	'CV_Sen_at_80_Spe_Mean', 'CV_Sen_at_80_Spe_SD', 'CV_Acc_Low_Mean', 'CV_Acc_Low_SD',	'CV_Acc_High_Mean',	'CV_Acc_High_SD']]



Unnamed: 0,Train_Timepoint,Train_Cohort,Train_N_Low,Train_N_High,Feature_Set,Test_Timepoint,Test_Cohort,Test_N_Low,Test_N_High,Classifier,CV_AUC_Mean,CV_AUC_SD,CV_Sen_at_80_Spe_Mean,CV_Sen_at_80_Spe_SD,CV_Acc_Low_Mean,CV_Acc_Low_SD,CV_Acc_High_Mean,CV_Acc_High_SD
0,Post-B,All,264.0,998.0,Ab_raw+ Ab_binary,Post-B,All,264.0,998.0,mlp,0.684762,0.013148,0.475634,0.029027,0.572921,0.172749,0.67822,0.162297
1,Post-B,All,264.0,998.0,Ab_raw+ Ab_binary+ demographic,Post-B,All,264.0,998.0,mlp,0.74264,0.018529,0.558895,0.037587,0.644741,0.169995,0.696249,0.138374
2,Post-B,All,264.0,998.0,Ab_raw+ Ab_binary,Post-B,All,264.0,998.0,nc,0.585022,0.039592,0.329084,0.062969,0.54907,0.0428,0.66208,0.055566
3,Post-B,All,264.0,998.0,Ab_raw+ Ab_binary+ demographic,Post-B,All,264.0,998.0,nc,0.622994,0.050419,0.385864,0.080376,0.589658,0.042097,0.713142,0.04981
4,Post-B,All,264.0,998.0,Ab_raw+ Ab_binary,Post-B,All,264.0,998.0,svm,0.666246,0.017892,0.46645,0.030549,0.691062,0.037217,0.555638,0.037663
5,Post-B,All,264.0,998.0,Ab_raw+ Ab_binary+ demographic,Post-B,All,264.0,998.0,svm,0.721375,0.023013,0.548728,0.03228,0.744039,0.043451,0.603672,0.035942
6,Post-B,All,264.0,998.0,Ab_raw,Post-B,All,264.0,998.0,mlp,0.687912,0.020413,0.480037,0.037763,0.584875,0.156188,0.670334,0.157725
7,Post-B,All,264.0,998.0,Ab_raw+ demographic,Post-B,All,264.0,998.0,mlp,0.765951,0.017666,0.594364,0.035159,0.666021,0.156551,0.715709,0.126016
8,Post-B,All,264.0,998.0,Ab_raw,Post-B,All,264.0,998.0,nc,0.580031,0.04189,0.326875,0.067398,0.547397,0.045727,0.656191,0.058674
9,Post-B,All,264.0,998.0,Ab_raw+ demographic,Post-B,All,264.0,998.0,nc,0.640067,0.052502,0.416181,0.083975,0.614024,0.039258,0.730845,0.045814


In [63]:
out = agg_df[['Train_Timepoint', 'Train_Cohort', 'Train_N_Low', 'Train_N_High', 'Feature_Set', 'Test_Timepoint', 'Test_Cohort', 'Test_N_Low' ,'Test_N_High', 'Classifier',	'CV_AUC_Mean', 'CV_AUC_SD',	'CV_Sen_at_80_Spe_Mean', 'CV_Sen_at_80_Spe_SD', 'CV_Acc_Low_Mean', 'CV_Acc_Low_SD',	'CV_Acc_High_Mean',	'CV_Acc_High_SD']]
out.to_csv('final_results/model_performance_summary_v2_revised.csv')

#### Post to Post from All to cohort

In [73]:
metrics = {}
device = 'cuda' if torch.cuda.is_available() else 'cpu'

for k, v in results_6_excel.items():
    seed, task, setting, clf, coeffs, usage = re.split(r',\s*(?![^()]*\))', k)

    seed = int(seed)
    if setting == 'binary':
        cutoffs = GLOBAL_CUTOFFS
    else:
        cutoffs = None

    if 'raw' in usage:
        use_cat = False
    else:
        use_cat = True
    if use_cat:
            start_from = len(emb_list)
    else:
        start_from = 11
        
        
    
    X, y , emb_list, real_size, encoder = load_6(binary=True, cutoffs=cutoffs)
    size = real_size if use_cat else real_size - 1
    
    en = {}
    for k_cohort, v_cohort in (X[:, [0, 5]]):
        en[k_cohort] = v_cohort
     
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=seed)
    
    
    for_stratify = y
    for fold, (train_idx, test_idx) in enumerate(skf.split(X, for_stratify), start=0):
        #################################################### Dataset
        X_train, X_test = X[train_idx], X[test_idx]
        y_train, y_test = y[train_idx], y[test_idx]
        
        
        X_train_tensor = torch.from_numpy(X_train[:, start_from:].astype(np.float32))
        y_train_tensor = torch.from_numpy(y_train).float()
        
        X_test_tensor = torch.from_numpy(X_test[:, start_from:].astype(np.float32))
        y_test_tensor = torch.from_numpy(y_test).float()
        
        
        train_idx, val_idx = train_test_split(
            np.arange(len(X_train_tensor)),
            test_size=0.2,
            stratify=y_train_tensor.numpy() ,
            random_state=seed
        )
        
        # create datasets
        train_dataset = TensorDataset(X_train_tensor[train_idx], y_train_tensor[train_idx])
        val_dataset   = TensorDataset(X_train_tensor[val_idx],  y_train_tensor[val_idx])
        
        # data loaders
        train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, generator=torch.Generator().manual_seed(seed))
        val_loader   = DataLoader(val_dataset,   batch_size=len(val_dataset), shuffle=False)
        test_loader = DataLoader(TensorDataset(X_test_tensor, y_test_tensor),
                                 batch_size=len(X_test_tensor), shuffle=False)
        
        m = BottleneckMLP(emb_list, size, start_from, 8, (128, 64, 2), p=0.2, use_cat=use_cat, binary=True).to(device)
        
        m.load_state_dict(v[0][fold])
        
        
        for cohort in en.keys():
            for x_batch, y_batch in test_loader:
                pass
            
            index = X_test[:, 5] == en[cohort]
            
            
            
            if (y_batch[index] == 0).float().sum() < 1:
                print(cohort, fold, seed, 'not enough negative samples')
                continue
            if clf == 'mlp':
                with torch.no_grad():
                    m.eval()
                    logits, two_d = m(x_batch[index].to(device))
            else:
                with torch.no_grad():
                    x_for_clf = []
                    y_for_clf = []
                    for x_batch_train, y_batch_train in train_loader:
                        with torch.no_grad():
                            m.eval()
                            _, two_d = m(x_batch_train.to(device))
                        x_for_clf.extend(two_d)
                        y_for_clf.extend(y_batch_train)
                x_for_clf = torch.stack(x_for_clf).cpu().numpy()
                y_for_clf = torch.stack(y_for_clf).cpu().numpy()
                if clf == 'svm':
                    one = (y_for_clf == 1).astype(float).sum()
                    zero = (y_for_clf == 0).astype(float).sum()
                    w = (one/zero)
                    delta = 1
        
                    classifier = SVC(kernel="rbf", class_weight={0: w, 1: 1}, probability=True)
                    classifier.fit(x_for_clf, y_for_clf)
                    with torch.no_grad():
                        m.eval()
                        x_in = m(x_batch[index].to(device))[-1].cpu().numpy()
                    logits = classifier.decision_function(x_in)
                elif clf == 'nc':
                    classifier = NearestCentroid(metric="euclidean")
                    classifier.fit(x_for_clf, y_for_clf)
                    with torch.no_grad():
                        m.eval()
                        x_in = m(x_batch[index].to(device))[-1].cpu().numpy()
                    logits = classifier.predict(x_in)

            if 1:
                if clf == 'mlp':
                    y_pred = (logits.reshape(-1) > 0.5).int().cpu()
                elif clf == 'svm':
                    y_pred = torch.from_numpy(logits >= delta).int()
                elif clf == 'nc':
                    
                    y_pred = torch.from_numpy(logits).int()
                a_test = (y_pred[y_batch[index]==0] == y_batch[index][y_batch[index]==0]).float().mean().item() + 1e-20
                b_test = (y_pred[y_batch[index]==1] == y_batch[index][y_batch[index]==1]).float().mean().item() + 1e-20
                c_test = 2 / ((1/a_test) + 1/b_test)
                if clf == 'svm':
                    logits = torch.sigmoid(torch.from_numpy(logits))
                elif clf == 'nc':
                    D = pairwise_distances(x_in, classifier.centroids_, metric=classifier.metric)
                    pos_idx = np.where(classifier.classes_ == 1)[0][0]
                    logits = torch.from_numpy(-D[:, pos_idx])
                    

                try:
                    auc_test = roc_auc_score(y_batch[index].cpu().numpy(), logits.cpu().numpy())
                except:
                    auc_test = None
                sen_test = []
                spec_test = []
                for t in [0.75, 0.8, 0.85, 0.9, 0.95]:
                    sen_, spec_ = at_95(logits.cpu().numpy(), y_batch[index], t=t)
                    sen_test.append(sen_)
                    spec_test.append(spec_)
                    
            
            metrics[f'{seed}, {fold}, {cohort}, {clf}, {setting}, {usage}'] = ([a_test, b_test, c_test, auc_test, *spec_test, *sen_test])                        
    
        # print('-----------------')

IBD 1 0 not enough negative samples
Cancer: Other 3 0 not enough negative samples
IBD 4 0 not enough negative samples
IBD 1 0 not enough negative samples
Cancer: Other 3 0 not enough negative samples
IBD 4 0 not enough negative samples
IBD 1 0 not enough negative samples
Cancer: Other 3 0 not enough negative samples
IBD 4 0 not enough negative samples
IBD 1 0 not enough negative samples
Cancer: Other 3 0 not enough negative samples
IBD 4 0 not enough negative samples
IBD 1 0 not enough negative samples
Cancer: Other 3 0 not enough negative samples
IBD 4 0 not enough negative samples
IBD 1 0 not enough negative samples
Cancer: Other 3 0 not enough negative samples
IBD 4 0 not enough negative samples
IBD 1 0 not enough negative samples
Cancer: Other 3 0 not enough negative samples
IBD 4 0 not enough negative samples
IBD 1 0 not enough negative samples
Cancer: Other 3 0 not enough negative samples
IBD 4 0 not enough negative samples
IBD 1 0 not enough negative samples
Cancer: Other 3 0 no

In [87]:
columns = []
for t in [0.75, 0.8, 0.85, 0.9, 0.95]:
    columns.append(f'Sensitivity at {t} Specificity')
for t in [0.75, 0.8, 0.85, 0.9, 0.95]:
    columns.append(f'Specificity at {t} Sensitivity')
columns = [
    'Low Responders Accuracy',
    'High Responders Accuracy',
    'Harmonic Mean',
    'AUC'
] + columns



records = []
for k, v in metrics.items():
    seed, fold, cohort, clf, setting, usage = k.split(', ')
    records.append({
        'seed': int(seed),
        'fold': int(fold),
        'cohort': cohort,
        'setting': setting,
        'clf': clf,
        'usage': usage,
        **{name: val for name, val in zip(columns, v)}
    })

df = pd.DataFrame(records)

group_cols = ['cohort', 'setting', 'clf', 'usage']
metric_cols = columns

agg_df = df.groupby(group_cols)[metric_cols].agg(['mean', 'std']).reset_index()
agg_df.columns = ['_'.join(col).strip('_') for col in agg_df.columns.values]
agg_df

Unnamed: 0,cohort,setting,clf,usage,Low Responders Accuracy_mean,Low Responders Accuracy_std,High Responders Accuracy_mean,High Responders Accuracy_std,Harmonic Mean_mean,Harmonic Mean_std,...,Specificity at 0.75 Sensitivity_mean,Specificity at 0.75 Sensitivity_std,Specificity at 0.8 Sensitivity_mean,Specificity at 0.8 Sensitivity_std,Specificity at 0.85 Sensitivity_mean,Specificity at 0.85 Sensitivity_std,Specificity at 0.9 Sensitivity_mean,Specificity at 0.9 Sensitivity_std,Specificity at 0.95 Sensitivity_mean,Specificity at 0.95 Sensitivity_std
0,Autoimmune: Other,binary,mlp,Ab raw,0.332560,0.263971,0.707106,0.215322,0.355744,0.229747,...,0.164627,0.204925,0.121301,0.180765,0.085408,0.155961,0.029827,0.091218,0.003152,0.035180
1,Autoimmune: Other,binary,mlp,Demographic,0.411552,0.281038,0.582601,0.244102,0.365014,0.204857,...,0.160907,0.174994,0.114341,0.155003,0.077651,0.135490,0.029032,0.084865,0.001691,0.019171
2,Autoimmune: Other,binary,nc,Ab raw,0.324891,0.208503,0.678197,0.165884,0.387952,0.197113,...,0.208600,0.208040,0.154906,0.184771,0.117507,0.162262,0.046473,0.102790,0.002454,0.021393
3,Autoimmune: Other,binary,nc,Demographic,0.316463,0.214761,0.658394,0.171643,0.369742,0.199861,...,0.196191,0.196362,0.142819,0.168677,0.109442,0.156061,0.044357,0.103235,0.002403,0.019635
4,Autoimmune: Other,binary,svm,Ab raw,0.465705,0.228333,0.583775,0.155030,0.472762,0.174827,...,0.191810,0.189244,0.151582,0.176172,0.105943,0.154581,0.043714,0.099126,0.002810,0.021536
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
79,Transplant,no-cutoff,mlp,Demographic,0.758551,0.230812,0.248275,0.235312,0.276364,0.204195,...,0.162217,0.150194,0.133221,0.136061,0.093980,0.117515,0.052908,0.096070,0.024917,0.070492
80,Transplant,no-cutoff,nc,Ab raw,0.457266,0.182022,0.486646,0.180783,0.426059,0.136331,...,0.188999,0.171921,0.147059,0.150917,0.110886,0.134572,0.063422,0.108781,0.023550,0.068592
81,Transplant,no-cutoff,nc,Demographic,0.725728,0.182564,0.252982,0.177594,0.324263,0.177518,...,0.170369,0.169583,0.130543,0.154496,0.094894,0.133636,0.054325,0.100453,0.023777,0.067905
82,Transplant,no-cutoff,svm,Ab raw,0.668722,0.179407,0.356127,0.178480,0.420354,0.153710,...,0.184902,0.173290,0.139231,0.152063,0.102060,0.135579,0.056252,0.104977,0.022163,0.064406


In [88]:
X, y , emb_list, real_size, encoder = load_6(binary=True, cutoffs=None)
counts = {}
for cohort in COHORTS:
    
    
    all_low_count_train = (y == 0).astype(float).sum()
    all_high_count_train = (y == 1).astype(float).sum()
    idx = X[:, 0] == cohort
    all_low_count = (y[idx] == 0).astype(float).sum()
    all_high_count = (y[idx] == 1).astype(float).sum()
    counts[f'{cohort},Test_N_Low'] = all_low_count
    counts[f'All,Train_N_Low'] = all_low_count_train
    counts[f'{cohort},Test_N_High'] = all_high_count
    counts[f'All,Train_N_High'] = all_high_count_train
counts

{'Autoimmune: Other,Test_N_Low': 72.0,
 'All,Train_N_Low': 264.0,
 'Autoimmune: Other,Test_N_High': 163.0,
 'All,Train_N_High': 998.0,
 'Cancer: Other,Test_N_Low': 9.0,
 'Cancer: Other,Test_N_High': 102.0,
 'HIV,Test_N_Low': 14.0,
 'HIV,Test_N_High': 70.0,
 'Healthy Control,Test_N_Low': 24.0,
 'Healthy Control,Test_N_High': 465.0,
 'IBD,Test_N_Low': 4.0,
 'IBD,Test_N_High': 33.0,
 'Multiple Myeloma,Test_N_Low': 60.0,
 'Multiple Myeloma,Test_N_High': 92.0,
 'Transplant,Test_N_Low': 81.0,
 'Transplant,Test_N_High': 73.0}

In [89]:
agg_df['Train_Timepoint'] = 'Post-B'
agg_df['Test_Timepoint'] = 'Post-B'

agg_df['Test_Cohort'] = agg_df['cohort']
agg_df['Train_Cohort'] = 'All'



def update_usage(row, what):
    cohort = row['cohort']
    try:
        return counts[f'{cohort},{what}']
    except:
        return counts[f'All,{what}']

# Apply transformation
for what in ['Train_N_Low', 'Train_N_High', 'Test_N_High', 'Test_N_Low']:
    agg_df[what] = agg_df.apply(lambda z:update_usage(z, what=what), axis=1)
    
    
agg_df = agg_df.rename(columns={
    # 'cohort': 'Train_Cohort', 
    'clf': 'Classifier', 
    'AUC_mean': 'CV_AUC_Mean',
    'AUC_std': 'CV_AUC_SD',
    'Sensitivity at 0.8 Specificity_mean': 'CV_Sen_at_80_Spe_Mean',
    'Sensitivity at 0.8 Specificity_std': 'CV_Sen_at_80_Spe_SD',
    'Low Responders Accuracy_mean': 'CV_Acc_Low_Mean',
    'High Responders Accuracy_mean': 'CV_Acc_High_Mean',
    'Low Responders Accuracy_std': 'CV_Acc_Low_SD',
    'High Responders Accuracy_std': 'CV_Acc_High_SD',
})

def update_usage(row):
    # Start with base usage name
    usage = row['usage']
    setting = row['setting']
    
    
    # Case 1: if usage is "Ab raw"
    
    out = 'Ab_raw'
    
    
    # Case 2: if usage is "Demographic"
    if setting.strip().lower() == 'binary':
        out += '+ Ab_binary'
        
    if usage.strip().lower() == 'demographic':
        out += '+ demographic'
    
    # Case 3: modify further based on 'setting'
    
    if setting.strip().lower() == 'no-cutoff':
        # don't add anything
        pass
    
    return out

# Apply transformation
agg_df['Feature_Set'] = agg_df.apply(update_usage, axis=1)
agg_df[['Train_Timepoint', 'Train_Cohort', 'Train_N_Low', 'Train_N_High', 'Feature_Set', 'Test_Timepoint', 'Test_Cohort', 'Test_N_Low' ,'Test_N_High', 'Classifier', 'CV_AUC_Mean', 'CV_AUC_SD',	'CV_Sen_at_80_Spe_Mean', 'CV_Sen_at_80_Spe_SD', 'CV_Acc_Low_Mean', 'CV_Acc_Low_SD',	'CV_Acc_High_Mean',	'CV_Acc_High_SD']]



Unnamed: 0,Train_Timepoint,Train_Cohort,Train_N_Low,Train_N_High,Feature_Set,Test_Timepoint,Test_Cohort,Test_N_Low,Test_N_High,Classifier,CV_AUC_Mean,CV_AUC_SD,CV_Sen_at_80_Spe_Mean,CV_Sen_at_80_Spe_SD,CV_Acc_Low_Mean,CV_Acc_Low_SD,CV_Acc_High_Mean,CV_Acc_High_SD
0,Post-B,All,264.0,998.0,Ab_raw+ Ab_binary,Post-B,Autoimmune: Other,72.0,163.0,mlp,0.521514,0.158315,0.353736,0.225988,0.332560,0.263971,0.707106,0.215322
1,Post-B,All,264.0,998.0,Ab_raw+ Ab_binary+ demographic,Post-B,Autoimmune: Other,72.0,163.0,mlp,0.494844,0.155857,0.320218,0.215230,0.411552,0.281038,0.582601,0.244102
2,Post-B,All,264.0,998.0,Ab_raw+ Ab_binary,Post-B,Autoimmune: Other,72.0,163.0,nc,0.563639,0.162460,0.411397,0.238082,0.324891,0.208503,0.678197,0.165884
3,Post-B,All,264.0,998.0,Ab_raw+ Ab_binary+ demographic,Post-B,Autoimmune: Other,72.0,163.0,nc,0.543587,0.157837,0.386387,0.223567,0.316463,0.214761,0.658394,0.171643
4,Post-B,All,264.0,998.0,Ab_raw+ Ab_binary,Post-B,Autoimmune: Other,72.0,163.0,svm,0.538643,0.160017,0.370269,0.221003,0.465705,0.228333,0.583775,0.155030
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
79,Post-B,All,264.0,998.0,Ab_raw+ demographic,Post-B,Transplant,81.0,73.0,mlp,0.505808,0.120246,0.258536,0.214210,0.758551,0.230812,0.248275,0.235312
80,Post-B,All,264.0,998.0,Ab_raw,Post-B,Transplant,81.0,73.0,nc,0.529683,0.134833,0.364987,0.191334,0.457266,0.182022,0.486646,0.180783
81,Post-B,All,264.0,998.0,Ab_raw+ demographic,Post-B,Transplant,81.0,73.0,nc,0.510128,0.135874,0.325188,0.188201,0.725728,0.182564,0.252982,0.177594
82,Post-B,All,264.0,998.0,Ab_raw,Post-B,Transplant,81.0,73.0,svm,0.510525,0.137290,0.329528,0.193784,0.668722,0.179407,0.356127,0.178480


In [82]:
out = agg_df[['Train_Timepoint', 'Train_Cohort', 'Train_N_Low', 'Train_N_High', 'Feature_Set', 'Test_Timepoint', 'Test_Cohort', 'Test_N_Low' ,'Test_N_High', 'Classifier',	'CV_AUC_Mean', 'CV_AUC_SD',	'CV_Sen_at_80_Spe_Mean', 'CV_Sen_at_80_Spe_SD', 'CV_Acc_Low_Mean', 'CV_Acc_Low_SD',	'CV_Acc_High_Mean',	'CV_Acc_High_SD']]
out.to_csv('final_results/model_performance_summary_v3_revised.csv')

In [312]:
best_for_each_cohort = {'Multiple Myeloma': ('mlp', 'no-cutoff'),
 'Autoimmune: Other': ('mlp', 'no-cutoff'),
 'Healthy Control': ('mlp', 'no-cutoff'),
 'Cancer: Other': ('mlp', 'no-cutoff'),
 'HIV': ('mlp', 'no-cutoff'),
 'Transplant': ('mlp', 'no-cutoff'),
 'IBD': ('mlp', 'no-cutoff')
                       }

In [846]:
# metrics = {}
# this is for 8
out_auc = {}

for k , v in results_8_excel.items():
    split_pattern = r',\s*(?![^()]*\))'
    
    
    seed, cohort, task, setting, clf, _, demographic = re.split(split_pattern, k)
    if setting == 'binary':
        continue
    if demographic != 'Demographic':
        continue

    if clf != 'mlp':
        continue

    
    seed = int(seed)
    
    if setting == 'binary':
        cutoffs = GLOBAL_CUTOFFS
    else:
        cutoffs = None
    X, y , emb_list, real_size, encoder = load_6(binary=True, cutoffs=cutoffs, qcut=False)
    idx = X[:, 0] == cohort
    
    X = X[idx]
    y = y[idx]
    
    en = {}
    for k_cohort, v_cohort in (X[:, [0, 5]]):
        en[k_cohort] = v_cohort
     
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=seed)
    
    
    for_stratify = y
    for fold, (train_idx, test_idx) in enumerate(skf.split(X, for_stratify), start=0):
        #################################################### Dataset
        X_train, X_test = X[train_idx], X[test_idx]
        y_train, y_test = y[train_idx], y[test_idx]
        from torch.utils.data import TensorDataset, DataLoader, random_split
        
        start_from = len(emb_list)
        size = real_size 
        X_train_tensor = torch.from_numpy(X_train[:, start_from:].astype(np.float32))
        y_train_tensor = torch.from_numpy(y_train).float()
        
        X_test_tensor = torch.from_numpy(X_test[:, start_from:].astype(np.float32))
        y_test_tensor = torch.from_numpy(y_test).float()
        
        from sklearn.model_selection import train_test_split
        # stratified split using sklearn
        train_idx, val_idx = train_test_split(
            np.arange(len(X_train_tensor)),
            test_size=0.2,
            stratify=y_train_tensor.numpy() ,
            random_state=seed
        )
        
        # create datasets
        train_dataset = TensorDataset(X_train_tensor[train_idx], y_train_tensor[train_idx])
        val_dataset   = TensorDataset(X_train_tensor[val_idx],  y_train_tensor[val_idx])
        
        # data loaders
        train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, generator=torch.Generator().manual_seed(seed))
        val_loader   = DataLoader(val_dataset,   batch_size=len(val_dataset), shuffle=False)
        test_loader = DataLoader(TensorDataset(X_test_tensor, y_test_tensor),
                                 batch_size=len(X_test_tensor), shuffle=False)
        m = BottleneckMLP(emb_list, real_size, len(emb_list), 8, (128, 64, 2), p=0.2, use_cat=True, binary=True).to(device)
        m.load_state_dict(v[0][fold])
        
        for cohort in en.keys():
            
            if best_for_each_cohort[cohort] != (clf, setting):
                continue
            else:
                for x_batch, y_batch in test_loader:
                    pass
                index = x_batch[:, 0] == en[cohort]
                if (y_batch[index] == 0).float().sum() < 1:
                    print(cohort, fold, seed, 'not enough negative samples')
                    continue
                if clf == 'mlp':
                    with torch.no_grad():
                        m.eval()
                        logits, two_d = m(x_batch[index].to(device))
                else:
                    with torch.no_grad():
                        x_for_clf = []
                        y_for_clf = []
                        for x_batch_train, y_batch_train in train_loader:
                            with torch.no_grad():
                                m.eval()
                                _, two_d = m(x_batch_train.to(device))
                            x_for_clf.extend(two_d)
                            y_for_clf.extend(y_batch_train)
                    x_for_clf = torch.stack(x_for_clf).cpu().numpy()
                    y_for_clf = torch.stack(y_for_clf).cpu().numpy()
                    if clf == 'svm':
                        one = (y_for_clf == 1).astype(float).sum()
                        zero = (y_for_clf == 0).astype(float).sum()
                        w = (one/zero)
                        delta = 1
            
                        classifier = SVC(kernel="rbf", class_weight={0: w, 1: 1}, probability=True)
                        classifier.fit(x_for_clf, y_for_clf)
                        with torch.no_grad():
                            m.eval()
                            x_in = m(x_batch[index].to(device))[-1].cpu().numpy()
                        logits = classifier.decision_function(x_in)
                    elif clf == 'nc':
                        classifier = NearestCentroid(metric="euclidean")
                        classifier.fit(x_for_clf, y_for_clf)
                        with torch.no_grad():
                            m.eval()
                            x_in = m(x_batch[index].to(device))[-1].cpu().numpy()
                        logits = classifier.predict(x_in)
    
                if 1:
                    if clf == 'mlp':
                        y_pred = (logits.reshape(-1) > 0.5).int().cpu()
                    elif clf == 'svm':
                        y_pred = torch.from_numpy(logits >= delta).int()
                    elif clf == 'nc':
                        
                        y_pred = torch.from_numpy(logits).int()
                    a_test = (y_pred[y_batch[index]==0] == y_batch[index][y_batch[index]==0]).float().mean().item() + 1e-20
                    b_test = (y_pred[y_batch[index]==1] == y_batch[index][y_batch[index]==1]).float().mean().item() + 1e-20
                    c_test = 2 / ((1/a_test) + 1/b_test)
                    if clf == 'svm':
                        logits = torch.sigmoid(torch.from_numpy(logits))
                    elif clf == 'nc':
                        D = pairwise_distances(x_in, classifier.centroids_, metric=classifier.metric)
                        pos_idx = np.where(classifier.classes_ == 1)[0][0]
                        logits = torch.from_numpy(-D[:, pos_idx])
                        # logits = torch.sigmoid(torch.from_numpy(logits))
    
                    try:
                        auc_test = roc_auc_score(y_batch[index].cpu().numpy(), logits.cpu().numpy())
                        # print(y_batch[index])
                        # print(logits)
                        out_auc[f'{cohort},{clf},{setting},{fold},{seed}'] = (y_batch[index].cpu().numpy(), logits.cpu().numpy())
                        # print(cohort)
                        # print(clf)
                        # print(setting)
                        # print(fold)
                        # print(seed)
                        # # raise
                        # fpr, tpr, thr = roc_curve(y_batch[index].cpu().numpy(), logits.cpu().numpy())
                        # roc_auc = auc(fpr, tpr)
                        
                        # RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc, estimator_name="LogReg").plot()
                        # plt.title(f"ROC curve (AUC = {roc_auc:.3f})")
                        # plt.show()
                        # # raise
                    except:
                        auc_test = None
                    # raise
                    sen_test = []
                    spec_test = []
                    for t in [0.75, 0.8, 0.85, 0.9, 0.95]:
                        sen_, spec_ = at_95(logits.cpu().numpy(), y_batch[index], t=t)
                        sen_test.append(sen_)
                        spec_test.append(spec_)
                    
            
            # metrics[f'{seed}, {fold}, {cohort}, {clf}, {setting}'] = ([a_test, b_test, c_test, auc_test, *spec_test, *sen_test])                        
    


In [848]:
to_json = {}
for k, v in out_auc.items():
    to_json[k] = (v[0].reshape(-1).tolist(), v[1].reshape(-1).tolist())

# out_auc
# import json
with open('data_for_auc_mlp-no cutoff-cohort specific-4,1-demographic-post to post.json', 'w') as fp:
    json.dump(to_json,fp)

In [849]:
# metrics = {}
out_auc = {}

for k , v in results_6_excel.items():
    
    split_pattern = r',\s*(?![^()]*\))'
    
    
    seed, task, setting, clf, _, demographic = re.split(split_pattern, k)
    
    
    seed = int(seed)

    if setting == 'binary':
        continue
    if demographic != 'Demographic':
        continue

    if clf != 'mlp':
        continue
    
    if setting == 'binary':
        cutoffs = GLOBAL_CUTOFFS
    else:
        cutoffs = None
    X, y , emb_list, real_size, encoder = load_6(binary=True, cutoffs=cutoffs, qcut=False)
    en = {}
    for k_cohort, v_cohort in (X[:, [0, 5]]):
        en[k_cohort] = v_cohort
     
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=seed)
    
    
    for_stratify = y
    for fold, (train_idx, test_idx) in enumerate(skf.split(X, for_stratify), start=0):
        #################################################### Dataset
        X_train, X_test = X[train_idx], X[test_idx]
        y_train, y_test = y[train_idx], y[test_idx]
        from torch.utils.data import TensorDataset, DataLoader, random_split
        
        start_from = len(emb_list)
        size = real_size 
        X_train_tensor = torch.from_numpy(X_train[:, start_from:].astype(np.float32))
        y_train_tensor = torch.from_numpy(y_train).float()
        
        X_test_tensor = torch.from_numpy(X_test[:, start_from:].astype(np.float32))
        y_test_tensor = torch.from_numpy(y_test).float()
        
        from sklearn.model_selection import train_test_split
        # stratified split using sklearn
        train_idx, val_idx = train_test_split(
            np.arange(len(X_train_tensor)),
            test_size=0.2,
            stratify=y_train_tensor.numpy() ,
            random_state=seed
        )
        
        # create datasets
        train_dataset = TensorDataset(X_train_tensor[train_idx], y_train_tensor[train_idx])
        val_dataset   = TensorDataset(X_train_tensor[val_idx],  y_train_tensor[val_idx])
        
        # data loaders
        train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, generator=torch.Generator().manual_seed(seed))
        val_loader   = DataLoader(val_dataset,   batch_size=len(val_dataset), shuffle=False)
        test_loader = DataLoader(TensorDataset(X_test_tensor, y_test_tensor),
                                 batch_size=len(X_test_tensor), shuffle=False)
        m = BottleneckMLP(emb_list, real_size, len(emb_list), 8, (128, 64, 2), p=0.2, use_cat=True, binary=True).to(device)
        m.load_state_dict(v[0][fold])
        
        for cohort in en.keys():
            
            if best_for_each_cohort[cohort] != (clf, setting):
                continue
            else:
                for x_batch, y_batch in test_loader:
                    pass
                index = x_batch[:, 0] == en[cohort]
                if (y_batch[index] == 0).float().sum() < 1:
                    print(cohort, fold, seed, 'not enough negative samples')
                    continue
                if clf == 'mlp':
                    with torch.no_grad():
                        m.eval()
                        logits, two_d = m(x_batch[index].to(device))
                else:
                    with torch.no_grad():
                        x_for_clf = []
                        y_for_clf = []
                        for x_batch_train, y_batch_train in train_loader:
                            with torch.no_grad():
                                m.eval()
                                _, two_d = m(x_batch_train.to(device))
                            x_for_clf.extend(two_d)
                            y_for_clf.extend(y_batch_train)
                    x_for_clf = torch.stack(x_for_clf).cpu().numpy()
                    y_for_clf = torch.stack(y_for_clf).cpu().numpy()
                    if clf == 'svm':
                        one = (y_for_clf == 1).astype(float).sum()
                        zero = (y_for_clf == 0).astype(float).sum()
                        w = (one/zero)
                        delta = 1
            
                        classifier = SVC(kernel="rbf", class_weight={0: w, 1: 1}, probability=True)
                        classifier.fit(x_for_clf, y_for_clf)
                        with torch.no_grad():
                            m.eval()
                            x_in = m(x_batch[index].to(device))[-1].cpu().numpy()
                        logits = classifier.decision_function(x_in)
                    elif clf == 'nc':
                        classifier = NearestCentroid(metric="euclidean")
                        classifier.fit(x_for_clf, y_for_clf)
                        with torch.no_grad():
                            m.eval()
                            x_in = m(x_batch[index].to(device))[-1].cpu().numpy()
                        logits = classifier.predict(x_in)
    
                if 1:
                    if clf == 'mlp':
                        y_pred = (logits.reshape(-1) > 0.5).int().cpu()
                    elif clf == 'svm':
                        y_pred = torch.from_numpy(logits >= delta).int()
                    elif clf == 'nc':
                        
                        y_pred = torch.from_numpy(logits).int()
                    a_test = (y_pred[y_batch[index]==0] == y_batch[index][y_batch[index]==0]).float().mean().item() + 1e-20
                    b_test = (y_pred[y_batch[index]==1] == y_batch[index][y_batch[index]==1]).float().mean().item() + 1e-20
                    c_test = 2 / ((1/a_test) + 1/b_test)
                    if clf == 'svm':
                        logits = torch.sigmoid(torch.from_numpy(logits))
                    elif clf == 'nc':
                        D = pairwise_distances(x_in, classifier.centroids_, metric=classifier.metric)
                        pos_idx = np.where(classifier.classes_ == 1)[0][0]
                        logits = torch.from_numpy(-D[:, pos_idx])
                        # logits = torch.sigmoid(torch.from_numpy(logits))
    
                    try:
                        auc_test = roc_auc_score(y_batch[index].cpu().numpy(), logits.cpu().numpy())
                        # print(y_batch[index])
                        # print(logits)
                        out_auc[f'{cohort},{clf},{setting},{fold},{seed}'] = (y_batch[index].cpu().numpy(), logits.cpu().numpy())
                        # print(cohort)
                        # print(clf)
                        # print(setting)
                        # print(fold)
                        # print(seed)
                        # # raise
                        # fpr, tpr, thr = roc_curve(y_batch[index].cpu().numpy(), logits.cpu().numpy())
                        # roc_auc = auc(fpr, tpr)
                        
                        # RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc, estimator_name="LogReg").plot()
                        # plt.title(f"ROC curve (AUC = {roc_auc:.3f})")
                        # plt.show()
                        # # raise
                    except:
                        auc_test = None
                    # raise
                    sen_test = []
                    spec_test = []
                    for t in [0.75, 0.8, 0.85, 0.9, 0.95]:
                        sen_, spec_ = at_95(logits.cpu().numpy(), y_batch[index], t=t)
                        sen_test.append(sen_)
                        spec_test.append(spec_)
                    
            
            # metrics[f'{seed}, {fold}, {cohort}, {clf}, {setting}'] = ([a_test, b_test, c_test, auc_test, *spec_test, *sen_test])                        
    


IBD 1 0 not enough negative samples
Cancer: Other 3 0 not enough negative samples
IBD 4 0 not enough negative samples
IBD 0 1 not enough negative samples
Cancer: Other 1 1 not enough negative samples
IBD 1 1 not enough negative samples
HIV 4 1 not enough negative samples
IBD 0 2 not enough negative samples
IBD 1 2 not enough negative samples
Cancer: Other 2 2 not enough negative samples
Cancer: Other 1 3 not enough negative samples
IBD 2 3 not enough negative samples
Cancer: Other 3 3 not enough negative samples
IBD 4 3 not enough negative samples
IBD 3 4 not enough negative samples
Cancer: Other 1 5 not enough negative samples
IBD 2 5 not enough negative samples
Cancer: Other 0 6 not enough negative samples
IBD 2 6 not enough negative samples
IBD 3 6 not enough negative samples
IBD 1 7 not enough negative samples
IBD 4 7 not enough negative samples
IBD 0 8 not enough negative samples
Cancer: Other 1 8 not enough negative samples
IBD 2 8 not enough negative samples
HIV 3 8 not enough n

In [826]:
out_auc['Multiple Myeloma,mlp,no-cutoff,0,0'][0]

array([1., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 1.,
       0., 0., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32)

In [850]:
to_json = {}
for k, v in out_auc.items():
    to_json[k] = (v[0].reshape(-1).tolist(), v[1].reshape(-1).tolist())

In [851]:
out_auc
import json
with open('data_for_auc_mlp-no cutoff-universal-4,1-demographic-post to post.json', 'w') as fp:
    json.dump(to_json,fp)

In [861]:
# metrics = {}
out_auc = {}

for k, v in results_6_excel.items():
    # print(k)
    # raise
    seed, task, setting, clf, coeffs, usage = re.split(r',\s*(?![^()]*\))', k)
    # print(usage)
    seed = int(seed)
    if setting == 'binary':
        continue
    if usage != 'Demographic':
        continue

    if clf != 'mlp':
        continue

    # print(seed, task, setting, clf, coeffs, usage)
    if setting == 'binary':
        cutoffs = GLOBAL_CUTOFFS
    else:
        cutoffs = None

    if 'raw' in usage:
        use_cat = False
    else:
        use_cat = True
    if use_cat:
            start_from = len(emb_list)
    else:
        start_from = 11
        
    
    X6, y6 , pid6, emb_list6, real_size6, encoder6 = load_6(binary=True, cutoffs=cutoffs, qcut=False, include_pid=True)
    
    X2, y2 , pid2, emb_list2, real_size2, encoder2 = load_2(binary=True, cutoffs=cutoffs, qcut=False, include_pid=True)
    
    size = real_size6 if use_cat else real_size6 - 1
    
    en = {}
    for k_cohort, v_cohort in (X6[:, [0, 5]]):
        en[k_cohort] = v_cohort
     
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=seed)
    
    
    for_stratify = y6
    for fold, (train_idx, test_idx) in enumerate(skf.split(X6, for_stratify), start=0):
        
        index = np.isin(pid2, pid6[test_idx])
        
        X_train = X6[train_idx]
        y_train = y6[train_idx]
        X_test = X2[index]
        y_test = y2[index]
        
        
        
        #################################################### Dataset
        # X_train, X_test = X6[train_idx], X6[test_idx]
        # y_train, y_test = y6[train_idx], y6[test_idx]
        from torch.utils.data import TensorDataset, DataLoader, random_split
        
        # start_from = len(emb_list)
        # size = real_size 
        X_train_tensor = torch.from_numpy(X_train[:, start_from:].astype(np.float32))
        y_train_tensor = torch.from_numpy(y_train).float()
        
        X_test_tensor = torch.from_numpy(X_test[:, start_from:].astype(np.float32))
        y_test_tensor = torch.from_numpy(y_test).float()
        
        from sklearn.model_selection import train_test_split
        # stratified split using sklearn
        train_idx, val_idx = train_test_split(
            np.arange(len(X_train_tensor)),
            test_size=0.2,
            stratify=y_train_tensor.numpy() ,
            random_state=seed
        )
        
        # create datasets
        train_dataset = TensorDataset(X_train_tensor[train_idx], y_train_tensor[train_idx])
        val_dataset   = TensorDataset(X_train_tensor[val_idx],  y_train_tensor[val_idx])
        
        # data loaders
        train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, generator=torch.Generator().manual_seed(seed))
        val_loader   = DataLoader(val_dataset,   batch_size=len(val_dataset), shuffle=False)
        test_loader = DataLoader(TensorDataset(X_test_tensor, y_test_tensor),
                                 batch_size=len(X_test_tensor), shuffle=False)
        # print(size, start_from)
        m = BottleneckMLP(emb_list, size, start_from, 8, (128, 64, 2), p=0.2, use_cat=use_cat, binary=True).to(device)
        # print(m)
        # print(fold, k, len(emb_list))
        # for a, b in v[0][fold].items():
            # print(a, b.shape)
        m.load_state_dict(v[0][fold])
        
        
        for cohort in en.keys():
            for x_batch, y_batch in test_loader:
                pass
            # index = x_batch[:, 0] == en[cohort]
            index = X_test[:, 5] == en[cohort]
            # print(cohort, en[cohort])
            
            # print(y_batch[index])
            # print(x_batch.shape, X_test[:, :].shape)
            
            
            if (y_batch[index] == 0).float().sum() < 1:
                print(cohort, fold, seed, 'not enough negative samples')
                continue
            if clf == 'mlp':
                with torch.no_grad():
                    m.eval()
                    logits, two_d = m(x_batch[index].to(device))
            else:
                with torch.no_grad():
                    x_for_clf = []
                    y_for_clf = []
                    for x_batch_train, y_batch_train in train_loader:
                        with torch.no_grad():
                            m.eval()
                            _, two_d = m(x_batch_train.to(device))
                        x_for_clf.extend(two_d)
                        y_for_clf.extend(y_batch_train)
                x_for_clf = torch.stack(x_for_clf).cpu().numpy()
                y_for_clf = torch.stack(y_for_clf).cpu().numpy()
                if clf == 'svm':
                    one = (y_for_clf == 1).astype(float).sum()
                    zero = (y_for_clf == 0).astype(float).sum()
                    w = (one/zero)
                    delta = 1
        
                    classifier = SVC(kernel="rbf", class_weight={0: w, 1: 1}, probability=True)
                    classifier.fit(x_for_clf, y_for_clf)
                    with torch.no_grad():
                        m.eval()
                        x_in = m(x_batch[index].to(device))[-1].cpu().numpy()
                    logits = classifier.decision_function(x_in)
                elif clf == 'nc':
                    classifier = NearestCentroid(metric="euclidean")
                    classifier.fit(x_for_clf, y_for_clf)
                    with torch.no_grad():
                        m.eval()
                        x_in = m(x_batch[index].to(device))[-1].cpu().numpy()
                    logits = classifier.predict(x_in)

            if 1:
                if clf == 'mlp':
                    y_pred = (logits.reshape(-1) > 0.5).int().cpu()
                elif clf == 'svm':
                    y_pred = torch.from_numpy(logits >= delta).int()
                elif clf == 'nc':
                    
                    y_pred = torch.from_numpy(logits).int()
                a_test = (y_pred[y_batch[index]==0] == y_batch[index][y_batch[index]==0]).float().mean().item() + 1e-20
                b_test = (y_pred[y_batch[index]==1] == y_batch[index][y_batch[index]==1]).float().mean().item() + 1e-20
                c_test = 2 / ((1/a_test) + 1/b_test)
                if clf == 'svm':
                    logits = torch.sigmoid(torch.from_numpy(logits))
                elif clf == 'nc':
                    D = pairwise_distances(x_in, classifier.centroids_, metric=classifier.metric)
                    pos_idx = np.where(classifier.classes_ == 1)[0][0]
                    logits = torch.from_numpy(-D[:, pos_idx])
                    # logits = torch.sigmoid(torch.from_numpy(logits))

                try:
                    auc_test = roc_auc_score(y_batch[index].cpu().numpy(), logits.cpu().numpy())
                    out_auc[f'{cohort},{clf},{setting},{fold},{seed}'] = (y_batch[index].cpu().numpy(), logits.cpu().numpy())
                except:
                    auc_test = None
                sen_test = []
                spec_test = []
                for t in [0.75, 0.8, 0.85, 0.9, 0.95]:
                    sen_, spec_ = at_95(logits.cpu().numpy(), y_batch[index], t=t)
                    sen_test.append(sen_)
                    spec_test.append(spec_)
                    
            
            # metrics[f'{seed}, {fold}, {cohort}, {clf}, {setting}, {usage}'] = ([a_test, b_test, c_test, auc_test, *spec_test, *sen_test])                        
    
        # print('-----------------')

IBD 1 0 not enough negative samples
Cancer: Other 3 0 not enough negative samples
IBD 3 0 not enough negative samples
IBD 4 0 not enough negative samples
IBD 0 1 not enough negative samples
Cancer: Other 1 1 not enough negative samples
IBD 1 1 not enough negative samples
IBD 2 1 not enough negative samples
HIV 4 1 not enough negative samples
IBD 0 2 not enough negative samples
IBD 1 2 not enough negative samples
Cancer: Other 2 2 not enough negative samples
IBD 0 3 not enough negative samples
Cancer: Other 1 3 not enough negative samples
IBD 2 3 not enough negative samples
Cancer: Other 3 3 not enough negative samples
IBD 4 3 not enough negative samples
Cancer: Other 1 4 not enough negative samples
IBD 2 4 not enough negative samples
Cancer: Other 3 4 not enough negative samples
IBD 3 4 not enough negative samples
Cancer: Other 1 5 not enough negative samples
IBD 2 5 not enough negative samples
IBD 4 5 not enough negative samples
Cancer: Other 0 6 not enough negative samples
IBD 0 6 no



IBD 4 16 not enough negative samples
IBD 0 17 not enough negative samples
IBD 1 17 not enough negative samples
HIV 3 17 not enough negative samples
IBD 3 17 not enough negative samples
Cancer: Other 0 18 not enough negative samples
IBD 0 18 not enough negative samples
IBD 2 18 not enough negative samples
IBD 3 18 not enough negative samples
HIV 0 19 not enough negative samples
IBD 0 19 not enough negative samples
Cancer: Other 2 19 not enough negative samples
IBD 2 19 not enough negative samples
Cancer: Other 3 19 not enough negative samples
IBD 3 20 not enough negative samples
IBD 4 20 not enough negative samples
Cancer: Other 2 21 not enough negative samples
IBD 2 21 not enough negative samples
Cancer: Other 3 21 not enough negative samples
IBD 3 21 not enough negative samples
Cancer: Other 0 22 not enough negative samples
IBD 0 22 not enough negative samples
Cancer: Other 2 22 not enough negative samples
IBD 3 22 not enough negative samples
IBD 0 23 not enough negative samples
Cance



IBD 0 33 not enough negative samples
IBD 1 33 not enough negative samples
Cancer: Other 4 33 not enough negative samples
IBD 4 33 not enough negative samples
IBD 0 34 not enough negative samples
IBD 4 34 not enough negative samples
Cancer: Other 0 35 not enough negative samples
IBD 0 35 not enough negative samples
Cancer: Other 2 35 not enough negative samples
IBD 2 35 not enough negative samples
IBD 3 35 not enough negative samples
Cancer: Other 0 36 not enough negative samples
IBD 1 36 not enough negative samples
Cancer: Other 3 36 not enough negative samples
IBD 4 36 not enough negative samples
Cancer: Other 0 37 not enough negative samples
IBD 0 37 not enough negative samples
IBD 4 37 not enough negative samples
Cancer: Other 0 38 not enough negative samples
Cancer: Other 2 38 not enough negative samples
IBD 2 38 not enough negative samples
IBD 3 38 not enough negative samples
IBD 4 38 not enough negative samples
IBD 0 39 not enough negative samples
IBD 1 39 not enough negative sam

In [862]:
to_json = {}
for k, v in out_auc.items():
    to_json[k] = (v[0].reshape(-1).tolist(), v[1].reshape(-1).tolist())

# out_auc
# import json
with open('data_for_auc_mlp-no cutoff-universal-4,1-demographic-post to pre.json', 'w') as fp:
    json.dump(to_json,fp)

In [864]:
# metrics = {}
out_acu = {}

for k, v in results_8_excel.items():
    # print(k)
    # raise
    # print(k)
    # raise
    
    # print(usage)
    

    split_pattern = r',\s*(?![^()]*\))'
    
    
    seed, cohort, task, setting, clf, _, demographic = re.split(split_pattern, k)
    seed = int(seed)
    
    if setting == 'binary':
        continue
    if demographic != 'Demographic':
        continue

    if clf != 'mlp':
        continue
    if setting == 'binary':
        cutoffs = GLOBAL_CUTOFFS
    else:
        cutoffs = None

    if 'raw' in demographic:
        use_cat = False
    else:
        use_cat = True
    if use_cat:
            start_from = len(emb_list)
    else:
        start_from = 11
        
    
    X6, y6 , pid6, emb_list6, real_size6, encoder6 = load_6(binary=True, cutoffs=cutoffs, qcut=False, include_pid=True)
    
    X2, y2 , pid2, emb_list2, real_size2, encoder2 = load_2(binary=True, cutoffs=cutoffs, qcut=False, include_pid=True)
    
    idx = X6[:, 0] == cohort
    X6 = X6[idx]
    y6 = y6[idx]
    pid6 = pid6[idx]

    idx = X2[:, 0] == cohort
    X2 = X2[idx]
    y2 = y2[idx]
    pid2 = pid2[idx]
    
    # if (y6[idx] == 0).sum() < 5:
    #     continue
    size = real_size6 if use_cat else real_size6 - 1
    
    # en = {}
    # for k_cohort, v_cohort in (X6[idx, [0, 5]]):
    #     en[k_cohort] = v_cohort
     
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=seed)
    
    
    for_stratify = y6
    for fold, (train_idx, test_idx) in enumerate(skf.split(X6, for_stratify), start=0):
        
        index = np.isin(pid2, pid6[test_idx])
        
        X_train = X6[train_idx]
        y_train = y6[train_idx]
        X_test = X2[index]
        y_test = y2[index]

        if (y_test == 0).sum() < 1:
            print('not enough samples for', cohort, fold, seed)
            continue
        
        
        #################################################### Dataset
        # X_train, X_test = X6[train_idx], X6[test_idx]
        # y_train, y_test = y6[train_idx], y6[test_idx]
        from torch.utils.data import TensorDataset, DataLoader, random_split
        
        # start_from = len(emb_list)
        # size = real_size 
        X_train_tensor = torch.from_numpy(X_train[:, start_from:].astype(np.float32))
        y_train_tensor = torch.from_numpy(y_train).float()
        
        X_test_tensor = torch.from_numpy(X_test[:, start_from:].astype(np.float32))
        y_test_tensor = torch.from_numpy(y_test).float()
        
        from sklearn.model_selection import train_test_split
        # stratified split using sklearn
        train_idx, val_idx = train_test_split(
            np.arange(len(X_train_tensor)),
            test_size=0.2,
            stratify=y_train_tensor.numpy() ,
            random_state=seed
        )
        
        # create datasets
        train_dataset = TensorDataset(X_train_tensor[train_idx], y_train_tensor[train_idx])
        val_dataset   = TensorDataset(X_train_tensor[val_idx],  y_train_tensor[val_idx])
        
        # data loaders
        train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, generator=torch.Generator().manual_seed(seed))
        val_loader   = DataLoader(val_dataset,   batch_size=len(val_dataset), shuffle=False)
        test_loader = DataLoader(TensorDataset(X_test_tensor, y_test_tensor),
                                 batch_size=len(X_test_tensor), shuffle=False)
        # print(size, start_from)
        m = BottleneckMLP(emb_list, size, start_from, 8, (128, 64, 2), p=0.2, use_cat=use_cat, binary=True).to(device)
        # print(m)
        # print(fold, k, len(emb_list))
        # for a, b in v[0][fold].items():
            # print(a, b.shape)
        m.load_state_dict(v[0][fold])
        
        
        # for cohort in en.keys():
        for x_batch, y_batch in test_loader:
            pass
        # index = x_batch[:, 0] == en[cohort]
        # index = X_test[:, 5] == en[cohort]
        # print(cohort, en[cohort])
        
        # print(y_batch[index])
        # print(x_batch.shape, X_test[:, :].shape)
        
        
        # if (y_batch[index] == 0).float().sum() < 1:
        #     print(cohort, fold, seed, 'not enough negative samples')
        #     continue
        if clf == 'mlp':
            with torch.no_grad():
                m.eval()
                logits, two_d = m(x_batch.to(device))
        else:
            with torch.no_grad():
                x_for_clf = []
                y_for_clf = []
                for x_batch_train, y_batch_train in train_loader:
                    with torch.no_grad():
                        m.eval()
                        _, two_d = m(x_batch_train.to(device))
                    x_for_clf.extend(two_d)
                    y_for_clf.extend(y_batch_train)
            x_for_clf = torch.stack(x_for_clf).cpu().numpy()
            y_for_clf = torch.stack(y_for_clf).cpu().numpy()
            if clf == 'svm':
                one = (y_for_clf == 1).astype(float).sum()
                zero = (y_for_clf == 0).astype(float).sum()
                w = (one/zero)
                delta = 1
    
                classifier = SVC(kernel="rbf", class_weight={0: w, 1: 1}, probability=True)
                classifier.fit(x_for_clf, y_for_clf)
                with torch.no_grad():
                    m.eval()
                    x_in = m(x_batch.to(device))[-1].cpu().numpy()
                logits = classifier.decision_function(x_in)
            elif clf == 'nc':
                classifier = NearestCentroid(metric="euclidean")
                classifier.fit(x_for_clf, y_for_clf)
                with torch.no_grad():
                    m.eval()
                    x_in = m(x_batch.to(device))[-1].cpu().numpy()
                logits = classifier.predict(x_in)

        if 1:
            if clf == 'mlp':
                y_pred = (logits.reshape(-1) > 0.5).int().cpu()
            elif clf == 'svm':
                y_pred = torch.from_numpy(logits >= delta).int()
            elif clf == 'nc':
                
                y_pred = torch.from_numpy(logits).int()
            a_test = (y_pred[y_batch==0] == y_batch[y_batch==0]).float().mean().item() + 1e-20
            b_test = (y_pred[y_batch==1] == y_batch[y_batch==1]).float().mean().item() + 1e-20
            c_test = 2 / ((1/a_test) + 1/b_test)
            if clf == 'svm':
                logits = torch.sigmoid(torch.from_numpy(logits))
            elif clf == 'nc':
                D = pairwise_distances(x_in, classifier.centroids_, metric=classifier.metric)
                pos_idx = np.where(classifier.classes_ == 1)[0][0]
                logits = torch.from_numpy(-D[:, pos_idx])
                # logits = torch.sigmoid(torch.from_numpy(logits))

            try:
                auc_test = roc_auc_score(y_batch.cpu().numpy(), logits.cpu().numpy())
                out_auc[f'{cohort},{clf},{setting},{fold},{seed}'] = (y_batch[index].cpu().numpy(), logits.cpu().numpy())
            except:
                auc_test = None
            sen_test = []
            spec_test = []
            for t in [0.75, 0.8, 0.85, 0.9, 0.95]:
                sen_, spec_ = at_95(logits.cpu().numpy(), y_batch, t=t)
                sen_test.append(sen_)
                spec_test.append(spec_)
                
        
        # metrics[f'{seed}, {cohort}, {fold}, {clf}, {setting}, {usage}'] = ([a_test, b_test, c_test, auc_test, *spec_test, *sen_test])                        
        

not enough samples for Cancer: Other 4 0
not enough samples for Cancer: Other 1 1
not enough samples for Cancer: Other 1 2
not enough samples for Cancer: Other 0 3
not enough samples for Cancer: Other 1 5
not enough samples for Cancer: Other 3 6
not enough samples for Cancer: Other 4 8
not enough samples for Cancer: Other 2 9
not enough samples for Cancer: Other 0 10
not enough samples for Cancer: Other 1 10
not enough samples for Cancer: Other 1 11
not enough samples for Cancer: Other 3 13
not enough samples for Cancer: Other 1 14
not enough samples for Cancer: Other 4 16
not enough samples for Cancer: Other 1 18
not enough samples for Cancer: Other 2 19
not enough samples for Cancer: Other 1 21
not enough samples for Cancer: Other 0 22
not enough samples for Cancer: Other 1 25
not enough samples for Cancer: Other 1 28
not enough samples for Cancer: Other 1 29
not enough samples for Cancer: Other 1 30
not enough samples for Cancer: Other 0 31
not enough samples for Cancer: Other 2 32


In [865]:
to_json = {}
for k, v in out_auc.items():
    to_json[k] = (v[0].reshape(-1).tolist(), v[1].reshape(-1).tolist())

# out_auc
# import json
with open('data_for_auc_mlp-no cutoff-cohort specific-4,1-demographic-post to pre.json', 'w') as fp:
    json.dump(to_json,fp)

## Train on 6 test on 2

#### Post to Pre from All to cohort

In [86]:
metrics = {}

for k, v in results_6_excel.items():
    
    seed, task, setting, clf, coeffs, usage = re.split(r',\s*(?![^()]*\))', k)
    seed = int(seed)
    if setting == 'binary':
        cutoffs = GLOBAL_CUTOFFS
    else:
        cutoffs = None

    if 'raw' in usage:
        use_cat = False
    else:
        use_cat = True
    if use_cat:
            start_from = len(emb_list)
    else:
        start_from = 11
        
    
    X6, y6 , pid6, emb_list6, real_size6, encoder6 = load_6(binary=True, cutoffs=cutoffs, include_pid=True)
    
    X2, y2 , pid2, emb_list2, real_size2, encoder2 = load_2(binary=True, cutoffs=cutoffs, include_pid=True)
    
    size = real_size6 if use_cat else real_size6 - 1
    
    en = {}
    for k_cohort, v_cohort in (X6[:, [0, 5]]):
        en[k_cohort] = v_cohort
     
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=seed)
    
    
    for_stratify = y6
    for fold, (train_idx, test_idx) in enumerate(skf.split(X6, for_stratify), start=0):
        
        index = np.isin(pid2, pid6[test_idx])
        
        X_train = X6[train_idx]
        y_train = y6[train_idx]
        X_test = X2[index]
        y_test = y2[index]
        
        #################################################### Dataset
        
        X_train_tensor = torch.from_numpy(X_train[:, start_from:].astype(np.float32))
        y_train_tensor = torch.from_numpy(y_train).float()
        
        X_test_tensor = torch.from_numpy(X_test[:, start_from:].astype(np.float32))
        y_test_tensor = torch.from_numpy(y_test).float()
        
        
        # stratified split using sklearn
        train_idx, val_idx = train_test_split(
            np.arange(len(X_train_tensor)),
            test_size=0.2,
            stratify=y_train_tensor.numpy() ,
            random_state=seed
        )
        
        # create datasets
        train_dataset = TensorDataset(X_train_tensor[train_idx], y_train_tensor[train_idx])
        val_dataset   = TensorDataset(X_train_tensor[val_idx],  y_train_tensor[val_idx])
        
        # data loaders
        train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, generator=torch.Generator().manual_seed(seed))
        val_loader   = DataLoader(val_dataset,   batch_size=len(val_dataset), shuffle=False)
        test_loader = DataLoader(TensorDataset(X_test_tensor, y_test_tensor),
                                 batch_size=len(X_test_tensor), shuffle=False)
        m = BottleneckMLP(emb_list, size, start_from, 8, (128, 64, 2), p=0.2, use_cat=use_cat, binary=True).to(device)
        m.load_state_dict(v[0][fold])
        
        
        for cohort in en.keys():
            for x_batch, y_batch in test_loader:
                pass
            
            index = X_test[:, 5] == en[cohort]
            
            if (y_batch[index] == 0).float().sum() < 1:
                print(cohort, fold, seed, 'not enough negative samples')
                continue
            if clf == 'mlp':
                with torch.no_grad():
                    m.eval()
                    logits, two_d = m(x_batch[index].to(device))
            else:
                with torch.no_grad():
                    x_for_clf = []
                    y_for_clf = []
                    for x_batch_train, y_batch_train in train_loader:
                        with torch.no_grad():
                            m.eval()
                            _, two_d = m(x_batch_train.to(device))
                        x_for_clf.extend(two_d)
                        y_for_clf.extend(y_batch_train)
                x_for_clf = torch.stack(x_for_clf).cpu().numpy()
                y_for_clf = torch.stack(y_for_clf).cpu().numpy()
                if clf == 'svm':
                    one = (y_for_clf == 1).astype(float).sum()
                    zero = (y_for_clf == 0).astype(float).sum()
                    w = (one/zero)
                    delta = 1
        
                    classifier = SVC(kernel="rbf", class_weight={0: w, 1: 1}, probability=True)
                    classifier.fit(x_for_clf, y_for_clf)
                    with torch.no_grad():
                        m.eval()
                        x_in = m(x_batch[index].to(device))[-1].cpu().numpy()
                    logits = classifier.decision_function(x_in)
                elif clf == 'nc':
                    classifier = NearestCentroid(metric="euclidean")
                    classifier.fit(x_for_clf, y_for_clf)
                    with torch.no_grad():
                        m.eval()
                        x_in = m(x_batch[index].to(device))[-1].cpu().numpy()
                    logits = classifier.predict(x_in)

            if 1:
                if clf == 'mlp':
                    y_pred = (logits.reshape(-1) > 0.5).int().cpu()
                elif clf == 'svm':
                    y_pred = torch.from_numpy(logits >= delta).int()
                elif clf == 'nc':
                    
                    y_pred = torch.from_numpy(logits).int()
                a_test = (y_pred[y_batch[index]==0] == y_batch[index][y_batch[index]==0]).float().mean().item() + 1e-20
                b_test = (y_pred[y_batch[index]==1] == y_batch[index][y_batch[index]==1]).float().mean().item() + 1e-20
                c_test = 2 / ((1/a_test) + 1/b_test)
                if clf == 'svm':
                    logits = torch.sigmoid(torch.from_numpy(logits))
                elif clf == 'nc':
                    D = pairwise_distances(x_in, classifier.centroids_, metric=classifier.metric)
                    pos_idx = np.where(classifier.classes_ == 1)[0][0]
                    logits = torch.from_numpy(-D[:, pos_idx])
                    # logits = torch.sigmoid(torch.from_numpy(logits))

                try:
                    auc_test = roc_auc_score(y_batch[index].cpu().numpy(), logits.cpu().numpy())
                except:
                    auc_test = None
                sen_test = []
                spec_test = []
                for t in [0.75, 0.8, 0.85, 0.9, 0.95]:
                    sen_, spec_ = at_95(logits.cpu().numpy(), y_batch[index], t=t)
                    sen_test.append(sen_)
                    spec_test.append(spec_)
                    
            
            metrics[f'{seed}, {fold}, {cohort}, {clf}, {setting}, {usage}'] = ([a_test, b_test, c_test, auc_test, *spec_test, *sen_test])                        
    

IBD 1 0 not enough negative samples
Cancer: Other 3 0 not enough negative samples
IBD 3 0 not enough negative samples
IBD 4 0 not enough negative samples
IBD 1 0 not enough negative samples
Cancer: Other 3 0 not enough negative samples
IBD 3 0 not enough negative samples
IBD 4 0 not enough negative samples
IBD 1 0 not enough negative samples
Cancer: Other 3 0 not enough negative samples
IBD 3 0 not enough negative samples
IBD 4 0 not enough negative samples
IBD 1 0 not enough negative samples
Cancer: Other 3 0 not enough negative samples
IBD 3 0 not enough negative samples
IBD 4 0 not enough negative samples
IBD 1 0 not enough negative samples
Cancer: Other 3 0 not enough negative samples
IBD 3 0 not enough negative samples
IBD 4 0 not enough negative samples
IBD 1 0 not enough negative samples
Cancer: Other 3 0 not enough negative samples
IBD 3 0 not enough negative samples
IBD 4 0 not enough negative samples
IBD 1 0 not enough negative samples
Cancer: Other 3 0 not enough negative sa



IBD 4 16 not enough negative samples
IBD 0 16 not enough negative samples
Cancer: Other 2 16 not enough negative samples
IBD 2 16 not enough negative samples




IBD 4 16 not enough negative samples
IBD 0 16 not enough negative samples
Cancer: Other 2 16 not enough negative samples
IBD 2 16 not enough negative samples




IBD 4 16 not enough negative samples
IBD 0 16 not enough negative samples
Cancer: Other 2 16 not enough negative samples
IBD 2 16 not enough negative samples




IBD 4 16 not enough negative samples
IBD 0 16 not enough negative samples




Cancer: Other 2 16 not enough negative samples
IBD 2 16 not enough negative samples
IBD 4 16 not enough negative samples
IBD 0 16 not enough negative samples




Cancer: Other 2 16 not enough negative samples
IBD 2 16 not enough negative samples
IBD 4 16 not enough negative samples
IBD 0 16 not enough negative samples




Cancer: Other 2 16 not enough negative samples
IBD 2 16 not enough negative samples
IBD 4 16 not enough negative samples
IBD 0 16 not enough negative samples
Cancer: Other 2 16 not enough negative samples




IBD 2 16 not enough negative samples
IBD 4 16 not enough negative samples
IBD 0 16 not enough negative samples




Cancer: Other 2 16 not enough negative samples
IBD 2 16 not enough negative samples
IBD 4 16 not enough negative samples
IBD 0 16 not enough negative samples




Cancer: Other 2 16 not enough negative samples
IBD 2 16 not enough negative samples
IBD 4 16 not enough negative samples
IBD 0 16 not enough negative samples




Cancer: Other 2 16 not enough negative samples
IBD 2 16 not enough negative samples
IBD 4 16 not enough negative samples
IBD 0 16 not enough negative samples




Cancer: Other 2 16 not enough negative samples
IBD 2 16 not enough negative samples
IBD 4 16 not enough negative samples
IBD 0 17 not enough negative samples
IBD 1 17 not enough negative samples
HIV 3 17 not enough negative samples
IBD 3 17 not enough negative samples
IBD 0 17 not enough negative samples
IBD 1 17 not enough negative samples
HIV 3 17 not enough negative samples
IBD 3 17 not enough negative samples
IBD 0 17 not enough negative samples
IBD 1 17 not enough negative samples
HIV 3 17 not enough negative samples
IBD 3 17 not enough negative samples
IBD 0 17 not enough negative samples
IBD 1 17 not enough negative samples
HIV 3 17 not enough negative samples
IBD 3 17 not enough negative samples
IBD 0 17 not enough negative samples
IBD 1 17 not enough negative samples
HIV 3 17 not enough negative samples
IBD 3 17 not enough negative samples
IBD 0 17 not enough negative samples
IBD 1 17 not enough negative samples
HIV 3 17 not enough negative samples
IBD 3 17 not enough negative



IBD 3 32 not enough negative samples
Cancer: Other 4 32 not enough negative samples
IBD 4 32 not enough negative samples
IBD 0 32 not enough negative samples
IBD 2 32 not enough negative samples




IBD 3 32 not enough negative samples
Cancer: Other 4 32 not enough negative samples
IBD 4 32 not enough negative samples
IBD 0 32 not enough negative samples
IBD 2 32 not enough negative samples
IBD 3 32 not enough negative samples
Cancer: Other 4 32 not enough negative samples




IBD 4 32 not enough negative samples
IBD 0 32 not enough negative samples
IBD 2 32 not enough negative samples
IBD 3 32 not enough negative samples
Cancer: Other 4 32 not enough negative samples




IBD 4 32 not enough negative samples
IBD 0 32 not enough negative samples




IBD 2 32 not enough negative samples
IBD 3 32 not enough negative samples
Cancer: Other 4 32 not enough negative samples
IBD 4 32 not enough negative samples
IBD 0 32 not enough negative samples




IBD 2 32 not enough negative samples
IBD 3 32 not enough negative samples
Cancer: Other 4 32 not enough negative samples
IBD 4 32 not enough negative samples
IBD 0 32 not enough negative samples




IBD 2 32 not enough negative samples
IBD 3 32 not enough negative samples
Cancer: Other 4 32 not enough negative samples
IBD 4 32 not enough negative samples
IBD 0 32 not enough negative samples




IBD 2 32 not enough negative samples
IBD 3 32 not enough negative samples
Cancer: Other 4 32 not enough negative samples
IBD 4 32 not enough negative samples
IBD 0 32 not enough negative samples




IBD 2 32 not enough negative samples
IBD 3 32 not enough negative samples
Cancer: Other 4 32 not enough negative samples
IBD 4 32 not enough negative samples
IBD 0 32 not enough negative samples




IBD 2 32 not enough negative samples
IBD 3 32 not enough negative samples
Cancer: Other 4 32 not enough negative samples
IBD 4 32 not enough negative samples
IBD 0 32 not enough negative samples




IBD 2 32 not enough negative samples
IBD 3 32 not enough negative samples
Cancer: Other 4 32 not enough negative samples
IBD 4 32 not enough negative samples
IBD 0 32 not enough negative samples




IBD 2 32 not enough negative samples
IBD 3 32 not enough negative samples
Cancer: Other 4 32 not enough negative samples
IBD 4 32 not enough negative samples
IBD 0 33 not enough negative samples
IBD 1 33 not enough negative samples
Cancer: Other 4 33 not enough negative samples
IBD 4 33 not enough negative samples
IBD 0 33 not enough negative samples
IBD 1 33 not enough negative samples
Cancer: Other 4 33 not enough negative samples
IBD 4 33 not enough negative samples
IBD 0 33 not enough negative samples
IBD 1 33 not enough negative samples
Cancer: Other 4 33 not enough negative samples
IBD 4 33 not enough negative samples
IBD 0 33 not enough negative samples
IBD 1 33 not enough negative samples
Cancer: Other 4 33 not enough negative samples
IBD 4 33 not enough negative samples
IBD 0 33 not enough negative samples
IBD 1 33 not enough negative samples
Cancer: Other 4 33 not enough negative samples
IBD 4 33 not enough negative samples
IBD 0 33 not enough negative samples
IBD 1 33 not en

In [90]:
columns = []
for t in [0.75, 0.8, 0.85, 0.9, 0.95]:
    columns.append(f'Sensitivity at {t} Specificity')
for t in [0.75, 0.8, 0.85, 0.9, 0.95]:
    columns.append(f'Specificity at {t} Sensitivity')
columns = [
    'Low Responders Accuracy',
    'High Responders Accuracy',
    'Harmonic Mean',
    'AUC'
] + columns



records = []
for k, v in metrics.items():
    seed, fold, cohort, clf, setting, usage = k.split(', ')
    records.append({
        'seed': int(seed),
        'fold': int(fold),
        'cohort': cohort,
        'setting': setting,
        'clf': clf,
        'usage': usage,
        **{name: val for name, val in zip(columns, v)}
    })

df = pd.DataFrame(records)

group_cols = ['cohort', 'setting', 'clf', 'usage']
metric_cols = columns

agg_df = df.groupby(group_cols)[metric_cols].agg(['mean', 'std']).reset_index()
agg_df.columns = ['_'.join(col).strip('_') for col in agg_df.columns.values]
agg_df

Unnamed: 0,cohort,setting,clf,usage,Low Responders Accuracy_mean,Low Responders Accuracy_std,High Responders Accuracy_mean,High Responders Accuracy_std,Harmonic Mean_mean,Harmonic Mean_std,...,Specificity at 0.75 Sensitivity_mean,Specificity at 0.75 Sensitivity_std,Specificity at 0.8 Sensitivity_mean,Specificity at 0.8 Sensitivity_std,Specificity at 0.85 Sensitivity_mean,Specificity at 0.85 Sensitivity_std,Specificity at 0.9 Sensitivity_mean,Specificity at 0.9 Sensitivity_std,Specificity at 0.95 Sensitivity_mean,Specificity at 0.95 Sensitivity_std
0,Autoimmune: Other,binary,mlp,Ab raw,0.332560,0.263971,0.707106,0.215322,0.355744,0.229747,...,0.164627,0.204925,0.121301,0.180765,0.085408,0.155961,0.029827,0.091218,0.003152,0.035180
1,Autoimmune: Other,binary,mlp,Demographic,0.411552,0.281038,0.582601,0.244102,0.365014,0.204857,...,0.160907,0.174994,0.114341,0.155003,0.077651,0.135490,0.029032,0.084865,0.001691,0.019171
2,Autoimmune: Other,binary,nc,Ab raw,0.324891,0.208503,0.678197,0.165884,0.387952,0.197113,...,0.208600,0.208040,0.154906,0.184771,0.117507,0.162262,0.046473,0.102790,0.002454,0.021393
3,Autoimmune: Other,binary,nc,Demographic,0.316463,0.214761,0.658394,0.171643,0.369742,0.199861,...,0.196191,0.196362,0.142819,0.168677,0.109442,0.156061,0.044357,0.103235,0.002403,0.019635
4,Autoimmune: Other,binary,svm,Ab raw,0.465705,0.228333,0.583775,0.155030,0.472762,0.174827,...,0.191810,0.189244,0.151582,0.176172,0.105943,0.154581,0.043714,0.099126,0.002810,0.021536
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
79,Transplant,no-cutoff,mlp,Demographic,0.758551,0.230812,0.248275,0.235312,0.276364,0.204195,...,0.162217,0.150194,0.133221,0.136061,0.093980,0.117515,0.052908,0.096070,0.024917,0.070492
80,Transplant,no-cutoff,nc,Ab raw,0.457266,0.182022,0.486646,0.180783,0.426059,0.136331,...,0.188999,0.171921,0.147059,0.150917,0.110886,0.134572,0.063422,0.108781,0.023550,0.068592
81,Transplant,no-cutoff,nc,Demographic,0.725728,0.182564,0.252982,0.177594,0.324263,0.177518,...,0.170369,0.169583,0.130543,0.154496,0.094894,0.133636,0.054325,0.100453,0.023777,0.067905
82,Transplant,no-cutoff,svm,Ab raw,0.668722,0.179407,0.356127,0.178480,0.420354,0.153710,...,0.184902,0.173290,0.139231,0.152063,0.102060,0.135579,0.056252,0.104977,0.022163,0.064406


In [91]:
counts = {}
X6, y6 , pid6, emb_list6, real_size6, encoder6 = load_6(binary=True, cutoffs=cutoffs, include_pid=True)
    
X2, y2 , pid2, emb_list2, real_size2, encoder2 = load_2(binary=True, cutoffs=cutoffs, include_pid=True)


X_train = X6
y_train = y6
X_test = X2
y_test = y2
for cohort in COHORTS:     
    
    idx2 = X_test[:, 0] == cohort

    counts[f'{cohort},Test_N_Low'] = (y_test[idx2] == 0).astype(float).sum()
    counts[f'{cohort},Test_N_High'] = (y_test[idx2] == 1).astype(float).sum()
    counts[f'All,Train_N_Low'] = (y_train == 0).astype(float).sum()
    counts[f'All,Train_N_High'] = (y_train == 1).astype(float).sum()

counts

{'Autoimmune: Other,Test_N_Low': 29.0,
 'Autoimmune: Other,Test_N_High': 68.0,
 'All,Train_N_Low': 264.0,
 'All,Train_N_High': 998.0,
 'Cancer: Other,Test_N_Low': 6.0,
 'Cancer: Other,Test_N_High': 79.0,
 'HIV,Test_N_Low': 14.0,
 'HIV,Test_N_High': 63.0,
 'Healthy Control,Test_N_Low': 21.0,
 'Healthy Control,Test_N_High': 425.0,
 'IBD,Test_N_Low': 3.0,
 'IBD,Test_N_High': 28.0,
 'Multiple Myeloma,Test_N_Low': 42.0,
 'Multiple Myeloma,Test_N_High': 78.0,
 'Transplant,Test_N_Low': 50.0,
 'Transplant,Test_N_High': 56.0}

In [93]:
agg_df['Train_Timepoint'] = 'Post-B'
agg_df['Test_Timepoint'] = 'Pre-B'

agg_df['Test_Cohort'] = agg_df['cohort']
agg_df['Train_Cohort'] = 'All'



def update_usage(row, what):
    cohort = row['cohort']
    try:
        return counts[f'{cohort},{what}']
    except:
        return counts[f'All,{what}']

# Apply transformation
for what in ['Train_N_Low', 'Train_N_High', 'Test_N_High', 'Test_N_Low']:
    agg_df[what] = agg_df.apply(lambda z:update_usage(z, what=what), axis=1)
    
    
agg_df = agg_df.rename(columns={
    # 'cohort': 'Train_Cohort', 
    'clf': 'Classifier', 
    'AUC_mean': 'CV_AUC_Mean',
    'AUC_std': 'CV_AUC_SD',
    'Sensitivity at 0.8 Specificity_mean': 'CV_Sen_at_80_Spe_Mean',
    'Sensitivity at 0.8 Specificity_std': 'CV_Sen_at_80_Spe_SD',
    'Low Responders Accuracy_mean': 'CV_Acc_Low_Mean',
    'High Responders Accuracy_mean': 'CV_Acc_High_Mean',
    'Low Responders Accuracy_std': 'CV_Acc_Low_SD',
    'High Responders Accuracy_std': 'CV_Acc_High_SD',
})

def update_usage(row):
    # Start with base usage name
    usage = row['usage']
    setting = row['setting']
    
    
    # Case 1: if usage is "Ab raw"
    
    out = 'Ab_raw'
    
    
    # Case 2: if usage is "Demographic"
    if setting.strip().lower() == 'binary':
        out += '+ Ab_binary'
        
    if usage.strip().lower() == 'demographic':
        out += '+ demographic'
    
    # Case 3: modify further based on 'setting'
    
    if setting.strip().lower() == 'no-cutoff':
        # don't add anything
        pass
    
    return out

# Apply transformation
agg_df['Feature_Set'] = agg_df.apply(update_usage, axis=1)
agg_df[['Train_Timepoint', 'Train_Cohort', 'Train_N_Low', 'Train_N_High', 'Feature_Set', 'Test_Timepoint', 'Test_Cohort', 'Test_N_Low' ,'Test_N_High', 'Classifier', 'CV_AUC_Mean', 'CV_AUC_SD',	'CV_Sen_at_80_Spe_Mean', 'CV_Sen_at_80_Spe_SD', 'CV_Acc_Low_Mean', 'CV_Acc_Low_SD',	'CV_Acc_High_Mean',	'CV_Acc_High_SD']]



Unnamed: 0,Train_Timepoint,Train_Cohort,Train_N_Low,Train_N_High,Feature_Set,Test_Timepoint,Test_Cohort,Test_N_Low,Test_N_High,Classifier,CV_AUC_Mean,CV_AUC_SD,CV_Sen_at_80_Spe_Mean,CV_Sen_at_80_Spe_SD,CV_Acc_Low_Mean,CV_Acc_Low_SD,CV_Acc_High_Mean,CV_Acc_High_SD
0,Post-B,All,264.0,998.0,Ab_raw+ Ab_binary,Pre-B,Autoimmune: Other,29.0,68.0,mlp,0.521514,0.158315,0.353736,0.225988,0.332560,0.263971,0.707106,0.215322
1,Post-B,All,264.0,998.0,Ab_raw+ Ab_binary+ demographic,Pre-B,Autoimmune: Other,29.0,68.0,mlp,0.494844,0.155857,0.320218,0.215230,0.411552,0.281038,0.582601,0.244102
2,Post-B,All,264.0,998.0,Ab_raw+ Ab_binary,Pre-B,Autoimmune: Other,29.0,68.0,nc,0.563639,0.162460,0.411397,0.238082,0.324891,0.208503,0.678197,0.165884
3,Post-B,All,264.0,998.0,Ab_raw+ Ab_binary+ demographic,Pre-B,Autoimmune: Other,29.0,68.0,nc,0.543587,0.157837,0.386387,0.223567,0.316463,0.214761,0.658394,0.171643
4,Post-B,All,264.0,998.0,Ab_raw+ Ab_binary,Pre-B,Autoimmune: Other,29.0,68.0,svm,0.538643,0.160017,0.370269,0.221003,0.465705,0.228333,0.583775,0.155030
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
79,Post-B,All,264.0,998.0,Ab_raw+ demographic,Pre-B,Transplant,50.0,56.0,mlp,0.505808,0.120246,0.258536,0.214210,0.758551,0.230812,0.248275,0.235312
80,Post-B,All,264.0,998.0,Ab_raw,Pre-B,Transplant,50.0,56.0,nc,0.529683,0.134833,0.364987,0.191334,0.457266,0.182022,0.486646,0.180783
81,Post-B,All,264.0,998.0,Ab_raw+ demographic,Pre-B,Transplant,50.0,56.0,nc,0.510128,0.135874,0.325188,0.188201,0.725728,0.182564,0.252982,0.177594
82,Post-B,All,264.0,998.0,Ab_raw,Pre-B,Transplant,50.0,56.0,svm,0.510525,0.137290,0.329528,0.193784,0.668722,0.179407,0.356127,0.178480


In [94]:
out = agg_df[['Train_Timepoint', 'Train_Cohort', 'Train_N_Low', 'Train_N_High', 'Feature_Set', 'Test_Timepoint', 'Test_Cohort', 'Test_N_Low' ,'Test_N_High', 'Classifier',	'CV_AUC_Mean', 'CV_AUC_SD',	'CV_Sen_at_80_Spe_Mean', 'CV_Sen_at_80_Spe_SD', 'CV_Acc_Low_Mean', 'CV_Acc_Low_SD',	'CV_Acc_High_Mean',	'CV_Acc_High_SD']]
out.to_csv('final_results/model_performance_summary_v4.csv')

#### Post to Pre All

In [97]:
metrics = {}

for k, v in results_6_excel.items():
    seed, task, setting, clf, coeffs, usage = re.split(r',\s*(?![^()]*\))', k)
    seed = int(seed)
    if setting == 'binary':
        cutoffs = GLOBAL_CUTOFFS
    else:
        cutoffs = None

    if 'raw' in usage:
        use_cat = False
    else:
        use_cat = True
    if use_cat:
            start_from = len(emb_list)
    else:
        start_from = 11
        
    
    X6, y6 , pid6, emb_list6, real_size6, encoder6 = load_6(binary=True, cutoffs=cutoffs, include_pid=True)
    
    X2, y2 , pid2, emb_list2, real_size2, encoder2 = load_2(binary=True, cutoffs=cutoffs, include_pid=True)
    
    size = real_size6 if use_cat else real_size6 - 1
    
     
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=seed)
    
    
    for_stratify = y6
    for fold, (train_idx, test_idx) in enumerate(skf.split(X6, for_stratify), start=0):
        
        index = np.isin(pid2, pid6[test_idx])
        
        X_train = X6[train_idx]
        y_train = y6[train_idx]
        X_test = X2[index]
        y_test = y2[index]
        
        
        
        #################################################### Dataset
        
        X_train_tensor = torch.from_numpy(X_train[:, start_from:].astype(np.float32))
        y_train_tensor = torch.from_numpy(y_train).float()
        
        X_test_tensor = torch.from_numpy(X_test[:, start_from:].astype(np.float32))
        y_test_tensor = torch.from_numpy(y_test).float()
        
        # stratified split using sklearn
        train_idx, val_idx = train_test_split(
            np.arange(len(X_train_tensor)),
            test_size=0.2,
            stratify=y_train_tensor.numpy() ,
            random_state=seed
        )
        
        # create datasets
        train_dataset = TensorDataset(X_train_tensor[train_idx], y_train_tensor[train_idx])
        val_dataset   = TensorDataset(X_train_tensor[val_idx],  y_train_tensor[val_idx])
        
        # data loaders
        train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, generator=torch.Generator().manual_seed(seed))
        val_loader   = DataLoader(val_dataset,   batch_size=len(val_dataset), shuffle=False)
        test_loader = DataLoader(TensorDataset(X_test_tensor, y_test_tensor),
                                 batch_size=len(X_test_tensor), shuffle=False)
        
        m = BottleneckMLP(emb_list, size, start_from, 8, (128, 64, 2), p=0.2, use_cat=use_cat, binary=True).to(device)
        
        m.load_state_dict(v[0][fold])
        
        
        
        for x_batch, y_batch in test_loader:
            pass
        
        if clf == 'mlp':
            with torch.no_grad():
                m.eval()
                logits, two_d = m(x_batch.to(device))
        else:
            with torch.no_grad():
                x_for_clf = []
                y_for_clf = []
                for x_batch_train, y_batch_train in train_loader:
                    with torch.no_grad():
                        m.eval()
                        _, two_d = m(x_batch_train.to(device))
                    x_for_clf.extend(two_d)
                    y_for_clf.extend(y_batch_train)
            x_for_clf = torch.stack(x_for_clf).cpu().numpy()
            y_for_clf = torch.stack(y_for_clf).cpu().numpy()
            if clf == 'svm':
                one = (y_for_clf == 1).astype(float).sum()
                zero = (y_for_clf == 0).astype(float).sum()
                w = (one/zero)
                delta = 1
    
                classifier = SVC(kernel="rbf", class_weight={0: w, 1: 1}, probability=True)
                classifier.fit(x_for_clf, y_for_clf)
                with torch.no_grad():
                    m.eval()
                    x_in = m(x_batch.to(device))[-1].cpu().numpy()
                logits = classifier.decision_function(x_in)
            elif clf == 'nc':
                classifier = NearestCentroid(metric="euclidean")
                classifier.fit(x_for_clf, y_for_clf)
                with torch.no_grad():
                    m.eval()
                    x_in = m(x_batch.to(device))[-1].cpu().numpy()
                logits = classifier.predict(x_in)

        if 1:
            if clf == 'mlp':
                y_pred = (logits.reshape(-1) > 0.5).int().cpu()
            elif clf == 'svm':
                y_pred = torch.from_numpy(logits >= delta).int()
            elif clf == 'nc':
                
                y_pred = torch.from_numpy(logits).int()
            a_test = (y_pred[y_batch==0] == y_batch[y_batch==0]).float().mean().item() + 1e-20
            b_test = (y_pred[y_batch==1] == y_batch[y_batch==1]).float().mean().item() + 1e-20
            c_test = 2 / ((1/a_test) + 1/b_test)
            if clf == 'svm':
                logits = torch.sigmoid(torch.from_numpy(logits))
            elif clf == 'nc':
                D = pairwise_distances(x_in, classifier.centroids_, metric=classifier.metric)
                pos_idx = np.where(classifier.classes_ == 1)[0][0]
                logits = torch.from_numpy(-D[:, pos_idx])
                # logits = torch.sigmoid(torch.from_numpy(logits))

            try:
                auc_test = roc_auc_score(y_batch.cpu().numpy(), logits.cpu().numpy())
            except:
                auc_test = None
            sen_test = []
            spec_test = []
            for t in [0.75, 0.8, 0.85, 0.9, 0.95]:
                sen_, spec_ = at_95(logits.cpu().numpy(), y_batch, t=t)
                sen_test.append(sen_)
                spec_test.append(spec_)
                
        
        metrics[f'{seed}, {fold}, {clf}, {setting}, {usage}'] = ([a_test, b_test, c_test, auc_test, *spec_test, *sen_test])                        
    
        

In [98]:
columns = []
for t in [0.75, 0.8, 0.85, 0.9, 0.95]:
    columns.append(f'Sensitivity at {t} Specificity')
for t in [0.75, 0.8, 0.85, 0.9, 0.95]:
    columns.append(f'Specificity at {t} Sensitivity')
columns = [
    'Low Responders Accuracy',
    'High Responders Accuracy',
    'Harmonic Mean',
    'AUC'
] + columns


records = []
for k, v in metrics.items():
    seed, fold, clf, setting, usage = k.split(', ')
    records.append({
        'seed': int(seed),
        'fold': int(fold),
        'cohort': 'All',
        'setting': setting,
        'clf': clf,
        'usage': usage,
        **{name: val for name, val in zip(columns, v)}
    })

df = pd.DataFrame(records)

group_cols = ['cohort', 'setting', 'clf', 'usage']
metric_cols = columns

agg_df = df.groupby(group_cols)[metric_cols].agg(['mean', 'std']).reset_index()
agg_df.columns = ['_'.join(col).strip('_') for col in agg_df.columns.values]
agg_df

Unnamed: 0,cohort,setting,clf,usage,Low Responders Accuracy_mean,Low Responders Accuracy_std,High Responders Accuracy_mean,High Responders Accuracy_std,Harmonic Mean_mean,Harmonic Mean_std,...,Specificity at 0.75 Sensitivity_mean,Specificity at 0.75 Sensitivity_std,Specificity at 0.8 Sensitivity_mean,Specificity at 0.8 Sensitivity_std,Specificity at 0.85 Sensitivity_mean,Specificity at 0.85 Sensitivity_std,Specificity at 0.9 Sensitivity_mean,Specificity at 0.9 Sensitivity_std,Specificity at 0.95 Sensitivity_mean,Specificity at 0.95 Sensitivity_std
0,All,binary,mlp,Ab raw,0.546679,0.193326,0.689418,0.171713,0.55766,0.089702,...,0.466124,0.11694,0.398826,0.13488,0.297711,0.154334,0.172275,0.138274,0.086293,0.096078
1,All,binary,mlp,Demographic,0.632616,0.184636,0.699507,0.149374,0.623668,0.083498,...,0.57432,0.109623,0.50897,0.12405,0.417769,0.152308,0.289022,0.165593,0.150567,0.136864
2,All,binary,nc,Ab raw,0.518382,0.099174,0.670994,0.102654,0.572452,0.060303,...,0.29979,0.106659,0.246237,0.10023,0.190513,0.09133,0.123596,0.073609,0.069,0.056936
3,All,binary,nc,Demographic,0.568769,0.094535,0.726695,0.084805,0.629702,0.063332,...,0.368454,0.133882,0.306229,0.125295,0.238079,0.115049,0.15458,0.091888,0.081413,0.065334
4,All,binary,svm,Ab raw,0.663709,0.104688,0.568295,0.095354,0.599536,0.054096,...,0.41068,0.127076,0.335877,0.120513,0.263031,0.114042,0.172252,0.098184,0.094826,0.074898
5,All,binary,svm,Demographic,0.72334,0.092172,0.613595,0.087935,0.65479,0.051897,...,0.515991,0.13933,0.43383,0.148416,0.345146,0.148904,0.236757,0.134813,0.137043,0.105877
6,All,no-cutoff,mlp,Ab raw,0.55581,0.180396,0.678054,0.169472,0.563384,0.080274,...,0.46674,0.10946,0.402178,0.117846,0.304434,0.136694,0.17655,0.123484,0.091774,0.09122
7,All,no-cutoff,mlp,Demographic,0.634172,0.174214,0.718785,0.136188,0.639169,0.07785,...,0.612784,0.094059,0.554827,0.110685,0.473491,0.142574,0.317607,0.180307,0.164535,0.154044
8,All,no-cutoff,nc,Ab raw,0.515509,0.105706,0.660709,0.112453,0.56392,0.063254,...,0.298961,0.10473,0.241912,0.095393,0.183702,0.086119,0.115515,0.066469,0.06463,0.053535
9,All,no-cutoff,nc,Demographic,0.592777,0.106073,0.735204,0.088069,0.645536,0.064131,...,0.387195,0.138742,0.324726,0.132948,0.248718,0.119506,0.166859,0.10275,0.093044,0.076183


In [99]:
counts = {}
X6, y6 , pid6, emb_list6, real_size6, encoder6 = load_6(binary=True, cutoffs=cutoffs, include_pid=True)
    
X2, y2 , pid2, emb_list2, real_size2, encoder2 = load_2(binary=True, cutoffs=cutoffs, include_pid=True)


X_train = X6
y_train = y6
X_test = X2
y_test = y2


counts[f'Test_N_Low'] = (y_test == 0).astype(float).sum()
counts[f'Test_N_High'] = (y_test == 1).astype(float).sum()
counts[f'Train_N_Low'] = (y_train == 0).astype(float).sum()
counts[f'Train_N_High'] = (y_train == 1).astype(float).sum()

counts

{'Test_N_Low': 165.0,
 'Test_N_High': 797.0,
 'Train_N_Low': 264.0,
 'Train_N_High': 998.0}

In [100]:
agg_df['Train_Timepoint'] = 'Post-B'
agg_df['Test_Timepoint'] = 'Pre-B'

agg_df['Test_Cohort'] = agg_df['cohort']
agg_df['Train_Cohort'] = 'All'



def update_usage(row, what):
    # cohort = row['cohort']
    return counts[f'{what}']

# Apply transformation
for what in ['Train_N_Low', 'Train_N_High', 'Test_N_High', 'Test_N_Low']:
    agg_df[what] = agg_df.apply(lambda z:update_usage(z, what=what), axis=1)
    
    
agg_df = agg_df.rename(columns={
    # 'cohort': 'Train_Cohort', 
    'clf': 'Classifier', 
    'AUC_mean': 'CV_AUC_Mean',
    'AUC_std': 'CV_AUC_SD',
    'Sensitivity at 0.8 Specificity_mean': 'CV_Sen_at_80_Spe_Mean',
    'Sensitivity at 0.8 Specificity_std': 'CV_Sen_at_80_Spe_SD',
    'Low Responders Accuracy_mean': 'CV_Acc_Low_Mean',
    'High Responders Accuracy_mean': 'CV_Acc_High_Mean',
    'Low Responders Accuracy_std': 'CV_Acc_Low_SD',
    'High Responders Accuracy_std': 'CV_Acc_High_SD',
})

def update_usage(row):
    # Start with base usage name
    usage = row['usage']
    setting = row['setting']
    
    
    # Case 1: if usage is "Ab raw"
    
    out = 'Ab_raw'
    
    
    # Case 2: if usage is "Demographic"
    if setting.strip().lower() == 'binary':
        out += '+ Ab_binary'
        
    if usage.strip().lower() == 'demographic':
        out += '+ demographic'
    
    # Case 3: modify further based on 'setting'
    
    if setting.strip().lower() == 'no-cutoff':
        # don't add anything
        pass
    
    return out

# Apply transformation
agg_df['Feature_Set'] = agg_df.apply(update_usage, axis=1)
agg_df[['Train_Timepoint', 'Train_Cohort', 'Train_N_Low', 'Train_N_High', 'Feature_Set', 'Test_Timepoint', 'Test_Cohort', 'Test_N_Low' ,'Test_N_High', 'Classifier', 'CV_AUC_Mean', 'CV_AUC_SD',	'CV_Sen_at_80_Spe_Mean', 'CV_Sen_at_80_Spe_SD', 'CV_Acc_Low_Mean', 'CV_Acc_Low_SD',	'CV_Acc_High_Mean',	'CV_Acc_High_SD']]



Unnamed: 0,Train_Timepoint,Train_Cohort,Train_N_Low,Train_N_High,Feature_Set,Test_Timepoint,Test_Cohort,Test_N_Low,Test_N_High,Classifier,CV_AUC_Mean,CV_AUC_SD,CV_Sen_at_80_Spe_Mean,CV_Sen_at_80_Spe_SD,CV_Acc_Low_Mean,CV_Acc_Low_SD,CV_Acc_High_Mean,CV_Acc_High_SD
0,Post-B,All,264.0,998.0,Ab_raw+ Ab_binary,Pre-B,All,165.0,797.0,mlp,0.672716,0.051423,0.463382,0.089839,0.546679,0.193326,0.689418,0.171713
1,Post-B,All,264.0,998.0,Ab_raw+ Ab_binary+ demographic,Pre-B,All,165.0,797.0,mlp,0.732907,0.049385,0.558058,0.098004,0.632616,0.184636,0.699507,0.149374
2,Post-B,All,264.0,998.0,Ab_raw+ Ab_binary,Pre-B,All,165.0,797.0,nc,0.563766,0.074537,0.307254,0.11155,0.518382,0.099174,0.670994,0.102654
3,Post-B,All,264.0,998.0,Ab_raw+ Ab_binary+ demographic,Pre-B,All,165.0,797.0,nc,0.610702,0.083908,0.370715,0.138304,0.568769,0.094535,0.726695,0.084805
4,Post-B,All,264.0,998.0,Ab_raw+ Ab_binary,Pre-B,All,165.0,797.0,svm,0.654548,0.056933,0.451998,0.095927,0.663709,0.104688,0.568295,0.095354
5,Post-B,All,264.0,998.0,Ab_raw+ Ab_binary+ demographic,Pre-B,All,165.0,797.0,svm,0.713875,0.053128,0.55213,0.095701,0.72334,0.092172,0.613595,0.087935
6,Post-B,All,264.0,998.0,Ab_raw,Pre-B,All,165.0,797.0,mlp,0.672709,0.053971,0.463179,0.097002,0.55581,0.180396,0.678054,0.169472
7,Post-B,All,264.0,998.0,Ab_raw+ demographic,Pre-B,All,165.0,797.0,mlp,0.747714,0.044895,0.580958,0.095335,0.634172,0.174214,0.718785,0.136188
8,Post-B,All,264.0,998.0,Ab_raw,Pre-B,All,165.0,797.0,nc,0.558908,0.076168,0.30434,0.117314,0.515509,0.105706,0.660709,0.112453
9,Post-B,All,264.0,998.0,Ab_raw+ demographic,Pre-B,All,165.0,797.0,nc,0.62603,0.081634,0.396166,0.134052,0.592777,0.106073,0.735204,0.088069


In [101]:
out = agg_df[['Train_Timepoint', 'Train_Cohort', 'Train_N_Low', 'Train_N_High', 'Feature_Set', 'Test_Timepoint', 'Test_Cohort', 'Test_N_Low' ,'Test_N_High', 'Classifier',	'CV_AUC_Mean', 'CV_AUC_SD',	'CV_Sen_at_80_Spe_Mean', 'CV_Sen_at_80_Spe_SD', 'CV_Acc_Low_Mean', 'CV_Acc_Low_SD',	'CV_Acc_High_Mean',	'CV_Acc_High_SD']]
out.to_csv('final_results/model_performance_summary_v5_revised.csv')

#### Post to Pre cohort by cohort

In [111]:
counts = {}
X6, y6 , pid6, emb_list6, real_size6, encoder6 = load_6(binary=True, cutoffs=cutoffs, include_pid=True)
    
X2, y2 , pid2, emb_list2, real_size2, encoder2 = load_2(binary=True, cutoffs=cutoffs, include_pid=True)


X_train = X6
y_train = y6
X_test = X2
y_test = y2

for cohort in COHORTS:     
    
    
    # idx = X[:, 0] == 'Healthy Control'
    idx1 = X_train[:, 0] == cohort
    idx2 = X_test[:, 0] == cohort

    counts[f'{cohort},Test_N_Low'] = (y_test[idx2] == 0).astype(float).sum()
    counts[f'{cohort},Test_N_High'] = (y_test[idx2] == 1).astype(float).sum()
    counts[f'{cohort},Train_N_Low'] = (y_train[idx1] == 0).astype(float).sum()
    counts[f'{cohort},Train_N_High'] = (y_train[idx1] == 1).astype(float).sum()

counts

{'Autoimmune: Other,Test_N_Low': 29.0,
 'Autoimmune: Other,Test_N_High': 68.0,
 'Autoimmune: Other,Train_N_Low': 72.0,
 'Autoimmune: Other,Train_N_High': 163.0,
 'Cancer: Other,Test_N_Low': 6.0,
 'Cancer: Other,Test_N_High': 79.0,
 'Cancer: Other,Train_N_Low': 9.0,
 'Cancer: Other,Train_N_High': 102.0,
 'HIV,Test_N_Low': 14.0,
 'HIV,Test_N_High': 63.0,
 'HIV,Train_N_Low': 14.0,
 'HIV,Train_N_High': 70.0,
 'Healthy Control,Test_N_Low': 21.0,
 'Healthy Control,Test_N_High': 425.0,
 'Healthy Control,Train_N_Low': 24.0,
 'Healthy Control,Train_N_High': 465.0,
 'IBD,Test_N_Low': 3.0,
 'IBD,Test_N_High': 28.0,
 'IBD,Train_N_Low': 4.0,
 'IBD,Train_N_High': 33.0,
 'Multiple Myeloma,Test_N_Low': 42.0,
 'Multiple Myeloma,Test_N_High': 78.0,
 'Multiple Myeloma,Train_N_Low': 60.0,
 'Multiple Myeloma,Train_N_High': 92.0,
 'Transplant,Test_N_Low': 50.0,
 'Transplant,Test_N_High': 56.0,
 'Transplant,Train_N_Low': 81.0,
 'Transplant,Train_N_High': 73.0}

In [105]:
metrics = {}

for k, v in results_8_excel.items():
    
    seed, cohort, task, setting, clf, coeffs, usage = re.split(r',\s*(?![^()]*\))', k)
    
    seed = int(seed)
    if setting == 'binary':
        cutoffs = GLOBAL_CUTOFFS
    else:
        cutoffs = None

    if 'raw' in usage:
        use_cat = False
    else:
        use_cat = True
    if use_cat:
            start_from = len(emb_list)
    else:
        start_from = 11
        
    
    X6, y6 , pid6, emb_list6, real_size6, encoder6 = load_6(binary=True, cutoffs=cutoffs, include_pid=True)
    
    X2, y2 , pid2, emb_list2, real_size2, encoder2 = load_2(binary=True, cutoffs=cutoffs, include_pid=True)
    
    idx = X6[:, 0] == cohort
    X6 = X6[idx]
    y6 = y6[idx]
    pid6 = pid6[idx]

    idx = X2[:, 0] == cohort
    X2 = X2[idx]
    y2 = y2[idx]
    pid2 = pid2[idx]
    
    size = real_size6 if use_cat else real_size6 - 1
         
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=seed)
    
    
    for_stratify = y6
    for fold, (train_idx, test_idx) in enumerate(skf.split(X6, for_stratify), start=0):
        
        index = np.isin(pid2, pid6[test_idx])
        
        X_train = X6[train_idx]
        y_train = y6[train_idx]
        X_test = X2[index]
        y_test = y2[index]

        if (y_test == 0).sum() < 1:
            print('not enough samples for', cohort, fold, seed)
            continue
        
        
        #################################################### Dataset
        
        X_train_tensor = torch.from_numpy(X_train[:, start_from:].astype(np.float32))
        y_train_tensor = torch.from_numpy(y_train).float()
        
        X_test_tensor = torch.from_numpy(X_test[:, start_from:].astype(np.float32))
        y_test_tensor = torch.from_numpy(y_test).float()
        
        
        # stratified split using sklearn
        train_idx, val_idx = train_test_split(
            np.arange(len(X_train_tensor)),
            test_size=0.2,
            stratify=y_train_tensor.numpy() ,
            random_state=seed
        )
        
        # create datasets
        train_dataset = TensorDataset(X_train_tensor[train_idx], y_train_tensor[train_idx])
        val_dataset   = TensorDataset(X_train_tensor[val_idx],  y_train_tensor[val_idx])
        
        # data loaders
        train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, generator=torch.Generator().manual_seed(seed))
        val_loader   = DataLoader(val_dataset,   batch_size=len(val_dataset), shuffle=False)
        test_loader = DataLoader(TensorDataset(X_test_tensor, y_test_tensor),
                                 batch_size=len(X_test_tensor), shuffle=False)
        
        m = BottleneckMLP(emb_list, size, start_from, 8, (128, 64, 2), p=0.2, use_cat=use_cat, binary=True).to(device)
        
        m.load_state_dict(v[0][fold])
        
        
        # for cohort in en.keys():
        for x_batch, y_batch in test_loader:
            pass
        
        if clf == 'mlp':
            with torch.no_grad():
                m.eval()
                logits, two_d = m(x_batch.to(device))
        else:
            with torch.no_grad():
                x_for_clf = []
                y_for_clf = []
                for x_batch_train, y_batch_train in train_loader:
                    with torch.no_grad():
                        m.eval()
                        _, two_d = m(x_batch_train.to(device))
                    x_for_clf.extend(two_d)
                    y_for_clf.extend(y_batch_train)
            x_for_clf = torch.stack(x_for_clf).cpu().numpy()
            y_for_clf = torch.stack(y_for_clf).cpu().numpy()
            if clf == 'svm':
                one = (y_for_clf == 1).astype(float).sum()
                zero = (y_for_clf == 0).astype(float).sum()
                w = (one/zero)
                delta = 1
    
                classifier = SVC(kernel="rbf", class_weight={0: w, 1: 1}, probability=True)
                classifier.fit(x_for_clf, y_for_clf)
                with torch.no_grad():
                    m.eval()
                    x_in = m(x_batch.to(device))[-1].cpu().numpy()
                logits = classifier.decision_function(x_in)
            elif clf == 'nc':
                classifier = NearestCentroid(metric="euclidean")
                classifier.fit(x_for_clf, y_for_clf)
                with torch.no_grad():
                    m.eval()
                    x_in = m(x_batch.to(device))[-1].cpu().numpy()
                logits = classifier.predict(x_in)

        if 1:
            if clf == 'mlp':
                y_pred = (logits.reshape(-1) > 0.5).int().cpu()
            elif clf == 'svm':
                y_pred = torch.from_numpy(logits >= delta).int()
            elif clf == 'nc':
                
                y_pred = torch.from_numpy(logits).int()
            a_test = (y_pred[y_batch==0] == y_batch[y_batch==0]).float().mean().item() + 1e-20
            b_test = (y_pred[y_batch==1] == y_batch[y_batch==1]).float().mean().item() + 1e-20
            c_test = 2 / ((1/a_test) + 1/b_test)
            if clf == 'svm':
                logits = torch.sigmoid(torch.from_numpy(logits))
            elif clf == 'nc':
                D = pairwise_distances(x_in, classifier.centroids_, metric=classifier.metric)
                pos_idx = np.where(classifier.classes_ == 1)[0][0]
                logits = torch.from_numpy(-D[:, pos_idx])
                # logits = torch.sigmoid(torch.from_numpy(logits))

            try:
                auc_test = roc_auc_score(y_batch.cpu().numpy(), logits.cpu().numpy())
            except:
                auc_test = None
            sen_test = []
            spec_test = []
            for t in [0.75, 0.8, 0.85, 0.9, 0.95]:
                sen_, spec_ = at_95(logits.cpu().numpy(), y_batch, t=t)
                sen_test.append(sen_)
                spec_test.append(spec_)
                
        
        metrics[f'{seed}, {cohort}, {fold}, {clf}, {setting}, {usage}'] = ([a_test, b_test, c_test, auc_test, *spec_test, *sen_test])                        
        

not enough samples for Cancer: Other 4 0
not enough samples for Cancer: Other 4 0
not enough samples for Cancer: Other 4 0
not enough samples for Cancer: Other 4 0
not enough samples for Cancer: Other 4 0
not enough samples for Cancer: Other 4 0
not enough samples for Cancer: Other 4 0
not enough samples for Cancer: Other 4 0
not enough samples for Cancer: Other 4 0
not enough samples for Cancer: Other 4 0
not enough samples for Cancer: Other 4 0
not enough samples for Cancer: Other 4 0
not enough samples for Cancer: Other 1 1
not enough samples for Cancer: Other 1 1
not enough samples for Cancer: Other 1 1
not enough samples for Cancer: Other 1 1
not enough samples for Cancer: Other 1 1
not enough samples for Cancer: Other 1 1
not enough samples for Cancer: Other 1 1
not enough samples for Cancer: Other 1 1
not enough samples for Cancer: Other 1 1
not enough samples for Cancer: Other 1 1
not enough samples for Cancer: Other 1 1
not enough samples for Cancer: Other 1 1
not enough sampl

In [106]:
columns = []
for t in [0.75, 0.8, 0.85, 0.9, 0.95]:
    columns.append(f'Sensitivity at {t} Specificity')
for t in [0.75, 0.8, 0.85, 0.9, 0.95]:
    columns.append(f'Specificity at {t} Sensitivity')
columns = [
    'Low Responders Accuracy',
    'High Responders Accuracy',
    'Harmonic Mean',
    'AUC'
] + columns
for k, v in results_6_excel.items():
    seed, _, setting, clf, __, usage = re.split(r',\s*(?![^()]*\))', k)


records = []
for k, v in metrics.items():
    
    seed,  cohort, fold, clf, setting, usage = k.split(', ')
    
    records.append({
        'seed': int(seed),
        'fold': int(fold),
        'cohort': cohort,
        
        'setting': setting,
        'clf': clf,
        
        'usage': usage,
        **{name: val for name, val in zip(columns, v)}
    })

df = pd.DataFrame(records)

group_cols = ['cohort', 'setting', 'clf', 'usage']
metric_cols = columns

agg_df = df.groupby(group_cols)[metric_cols].agg(['mean', 'std']).reset_index()
agg_df.columns = ['_'.join(col).strip('_') for col in agg_df.columns.values]
agg_df

Unnamed: 0,cohort,setting,clf,usage,Low Responders Accuracy_mean,Low Responders Accuracy_std,High Responders Accuracy_mean,High Responders Accuracy_std,Harmonic Mean_mean,Harmonic Mean_std,...,Specificity at 0.75 Sensitivity_mean,Specificity at 0.75 Sensitivity_std,Specificity at 0.8 Sensitivity_mean,Specificity at 0.8 Sensitivity_std,Specificity at 0.85 Sensitivity_mean,Specificity at 0.85 Sensitivity_std,Specificity at 0.9 Sensitivity_mean,Specificity at 0.9 Sensitivity_std,Specificity at 0.95 Sensitivity_mean,Specificity at 0.95 Sensitivity_std
0,Autoimmune: Other,binary,mlp,Ab raw,0.490417,0.352998,0.512023,0.296374,0.304688,0.210677,...,0.199056,0.186359,0.153683,0.174722,0.115943,0.159368,0.045194,0.098619,0.000401,0.008953
1,Autoimmune: Other,binary,mlp,Demographic,0.536604,0.350811,0.489099,0.304677,0.315351,0.220781,...,0.211838,0.192582,0.161954,0.177472,0.121038,0.162287,0.050412,0.109071,0.001317,0.016690
2,Autoimmune: Other,binary,nc,Ab raw,0.421147,0.238259,0.586698,0.153931,0.432855,0.192064,...,0.171063,0.182876,0.128583,0.164165,0.098082,0.145560,0.046692,0.104346,0.000200,0.004477
3,Autoimmune: Other,binary,nc,Demographic,0.440633,0.236083,0.585957,0.166261,0.445324,0.175251,...,0.184069,0.180349,0.141173,0.166050,0.105621,0.150476,0.045678,0.099951,0.000655,0.011703
4,Autoimmune: Other,binary,svm,Ab raw,0.714856,0.207185,0.323402,0.151795,0.406923,0.150938,...,0.204997,0.184616,0.154670,0.175532,0.112330,0.154032,0.047457,0.097365,0.000755,0.013858
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
67,Transplant,no-cutoff,mlp,Demographic,0.495041,0.291569,0.536702,0.280428,0.367923,0.185840,...,0.214903,0.175159,0.167972,0.163703,0.109242,0.142502,0.058584,0.111620,0.023340,0.067919
68,Transplant,no-cutoff,nc,Ab raw,0.534237,0.170507,0.477528,0.168274,0.468434,0.123173,...,0.182874,0.160073,0.145437,0.144917,0.102378,0.124783,0.052470,0.095564,0.024376,0.066651
69,Transplant,no-cutoff,nc,Demographic,0.474986,0.179102,0.540976,0.166879,0.467028,0.124548,...,0.177809,0.155639,0.141918,0.141325,0.102538,0.125637,0.063104,0.108609,0.028437,0.075262
70,Transplant,no-cutoff,svm,Ab raw,0.835873,0.141411,0.172387,0.120410,0.262430,0.155338,...,0.187280,0.154676,0.145490,0.144336,0.103895,0.126659,0.054836,0.098986,0.020968,0.059021


In [112]:
agg_df['Train_Timepoint'] = 'Post-B'
agg_df['Test_Timepoint'] = 'Pre-B'

agg_df['Test_Cohort'] = agg_df['cohort']
agg_df['Train_Cohort'] = agg_df['cohort']


    
# counts

def update_usage(row, what):
    cohort = row['cohort']
    return counts[f'{cohort},{what}']

# Apply transformation
for what in ['Train_N_Low', 'Train_N_High', 'Test_N_High', 'Test_N_Low']:
    agg_df[what] = agg_df.apply(lambda z:update_usage(z, what=what), axis=1)
    
    
agg_df = agg_df.rename(columns={
    # 'cohort': 'Train_Cohort', 
    'clf': 'Classifier', 
    'AUC_mean': 'CV_AUC_Mean',
    'AUC_std': 'CV_AUC_SD',
    'Sensitivity at 0.8 Specificity_mean': 'CV_Sen_at_80_Spe_Mean',
    'Sensitivity at 0.8 Specificity_std': 'CV_Sen_at_80_Spe_SD',
    'Low Responders Accuracy_mean': 'CV_Acc_Low_Mean',
    'High Responders Accuracy_mean': 'CV_Acc_High_Mean',
    'Low Responders Accuracy_std': 'CV_Acc_Low_SD',
    'High Responders Accuracy_std': 'CV_Acc_High_SD',
})

def update_usage(row):
    # Start with base usage name
    usage = row['usage']
    setting = row['setting']
    
    
    # Case 1: if usage is "Ab raw"
    
    out = 'Ab_raw'
    
    
    # Case 2: if usage is "Demographic"
    if setting.strip().lower() == 'binary':
        out += '+ Ab_binary'
        
    if usage.strip().lower() == 'demographic':
        out += '+ demographic'
    
    # Case 3: modify further based on 'setting'
    
    if setting.strip().lower() == 'no-cutoff':
        # don't add anything
        pass
    
    return out

# Apply transformation
agg_df['Feature_Set'] = agg_df.apply(update_usage, axis=1)
agg_df[['Train_Timepoint', 'Train_Cohort', 'Train_N_Low', 'Train_N_High', 'Feature_Set', 'Test_Timepoint', 'Test_Cohort', 'Test_N_Low' ,'Test_N_High', 'Classifier', 'CV_AUC_Mean', 'CV_AUC_SD',	'CV_Sen_at_80_Spe_Mean', 'CV_Sen_at_80_Spe_SD', 'CV_Acc_Low_Mean', 'CV_Acc_Low_SD',	'CV_Acc_High_Mean',	'CV_Acc_High_SD']]



Unnamed: 0,Train_Timepoint,Train_Cohort,Train_N_Low,Train_N_High,Feature_Set,Test_Timepoint,Test_Cohort,Test_N_Low,Test_N_High,Classifier,CV_AUC_Mean,CV_AUC_SD,CV_Sen_at_80_Spe_Mean,CV_Sen_at_80_Spe_SD,CV_Acc_Low_Mean,CV_Acc_Low_SD,CV_Acc_High_Mean,CV_Acc_High_SD
0,Post-B,Autoimmune: Other,72.0,163.0,Ab_raw+ Ab_binary,Pre-B,Autoimmune: Other,29.0,68.0,mlp,0.502159,0.142798,0.283873,0.218998,0.490417,0.352998,0.512023,0.296374
1,Post-B,Autoimmune: Other,72.0,163.0,Ab_raw+ Ab_binary+ demographic,Pre-B,Autoimmune: Other,29.0,68.0,mlp,0.514993,0.147105,0.302103,0.232099,0.536604,0.350811,0.489099,0.304677
2,Post-B,Autoimmune: Other,72.0,163.0,Ab_raw+ Ab_binary,Pre-B,Autoimmune: Other,29.0,68.0,nc,0.504874,0.149896,0.328394,0.211834,0.421147,0.238259,0.586698,0.153931
3,Post-B,Autoimmune: Other,72.0,163.0,Ab_raw+ Ab_binary+ demographic,Pre-B,Autoimmune: Other,29.0,68.0,nc,0.513308,0.154704,0.346185,0.214905,0.440633,0.236083,0.585957,0.166261
4,Post-B,Autoimmune: Other,72.0,163.0,Ab_raw+ Ab_binary,Pre-B,Autoimmune: Other,29.0,68.0,svm,0.513409,0.140734,0.321438,0.204253,0.714856,0.207185,0.323402,0.151795
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
67,Post-B,Transplant,81.0,73.0,Ab_raw+ demographic,Pre-B,Transplant,50.0,56.0,mlp,0.522568,0.122775,0.320528,0.187141,0.495041,0.291569,0.536702,0.280428
68,Post-B,Transplant,81.0,73.0,Ab_raw,Pre-B,Transplant,50.0,56.0,nc,0.525922,0.122243,0.341092,0.177569,0.534237,0.170507,0.477528,0.168274
69,Post-B,Transplant,81.0,73.0,Ab_raw+ demographic,Pre-B,Transplant,50.0,56.0,nc,0.513054,0.128198,0.325303,0.168747,0.474986,0.179102,0.540976,0.166879
70,Post-B,Transplant,81.0,73.0,Ab_raw,Pre-B,Transplant,50.0,56.0,svm,0.522990,0.122754,0.343476,0.163240,0.835873,0.141411,0.172387,0.120410


In [113]:
out = agg_df[['Train_Timepoint', 'Train_Cohort', 'Train_N_Low', 'Train_N_High', 'Feature_Set', 'Test_Timepoint', 'Test_Cohort', 'Test_N_Low' ,'Test_N_High', 'Classifier',	'CV_AUC_Mean', 'CV_AUC_SD',	'CV_Sen_at_80_Spe_Mean', 'CV_Sen_at_80_Spe_SD', 'CV_Acc_Low_Mean', 'CV_Acc_Low_SD',	'CV_Acc_High_Mean',	'CV_Acc_High_SD']]
out.to_csv('final_results/model_performance_summary_v6_revised.csv')

#### All

In [909]:
pd.concat(
    [pd.read_csv('final_results/model_performance_summary_v1_revised.csv'),
     pd.read_csv('final_results/model_performance_summary_v2_revised.csv'),
     pd.read_csv('final_results/model_performance_summary_v3_revised.csv'),
     pd.read_csv('final_results/model_performance_summary_v4_revised.csv'),
     pd.read_csv('final_results/model_performance_summary_v5_revised.csv'),
     pd.read_csv('final_results/model_performance_summary_v6_revised.csv'),
    ],
     ignore_index=True
).to_csv('final_results/all_final_revised.csv', index=False)

## Anomaly

In [None]:
import torch
import numpy as np
from scipy.spatial.distance import mahalanobis
from scipy.linalg import inv

# Assume:
# mu: torch.Tensor of shape (n_samples, latent_dim)
# y: labels as torch.Tensor or NumPy array

# Convert if needed
if isinstance(y, torch.Tensor):
    y = y.cpu().numpy()

mu_np = mu.cpu().numpy()
mu_pos = mu_np[y == 1]   # Positive class
mu_neg = mu_np[y == 0]   # Negative class

# Compute mean and covariance from positives
mu_mean = mu_pos.mean(axis=0)
mu_cov = np.cov(mu_pos.T)
mu_cov_inv = inv(mu_cov)  # For Mahalanobis distance


In [None]:
# Function to compute Mahalanobis distance from the positive class distribution
def mahalanobis_batch(X, mean, cov_inv):
    return np.array([mahalanobis(x, mean, cov_inv) for x in X])

dists_pos = mahalanobis_batch(mu_pos, mu_mean, mu_cov_inv)
dists_neg = mahalanobis_batch(mu_neg, mu_mean, mu_cov_inv)


In [None]:
import matplotlib.pyplot as plt

plt.hist(dists_pos, bins=30, alpha=0.6, label="Positive", color="green")
plt.hist(dists_neg, bins=30, alpha=0.6, label="Negative", color="red")
plt.axvline(np.percentile(dists_pos, 95), color='black', linestyle='--', label="95% cutoff")
plt.title("Mahalanobis Distances from Positive Distribution")
plt.xlabel("Distance")
plt.ylabel("Frequency")
plt.legend()
plt.show()


In [None]:
threshold = np.percentile(dists_pos, 95)
anomaly_flags_neg = dists_neg > threshold
print(f"Out of {len(mu_neg)} negative samples, {anomaly_flags_neg.sum()} were flagged as anomalies.")
