In [41]:
import os
import sys
import numpy as np
import pandas as pd
import pickle
from typing import Union, List, Tuple

import matplotlib.pyplot as plt
import seaborn as sns

In [42]:
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit

In [43]:
# CONSTANTS

PD_FILES = "%s_t1mri_participants.csv"
STUDIES = ("abide1", "abide2", "biobd", "bsnip1", 
           "cnp", "candi", "schizconnect-vip-prague")
UNIQUE_KEYS = ["participant_id", "session", "run", "study"]
ID_TYPES = {"participant_id": str,
            "session": int,
            "acq": int,
            "run": int}
EXTERNAL_SITES = (('UM', 'ABIDE2'), ('GU', 'ABIDE2'), 
                  ('mannheim', 'BIOBD'), ('geneve', 'BIOBD'), 
                  ('Hartford', 'BSNIP1'), ('Detroit', 'BSNIP1'))

In [44]:
# Functions

def discretize_continous_label(labels: str, bins: Union[str, int] = "sturges"):
    # Get an estimation of the best bin edges. 'Sturges' is conservative for pretty large datasets (N>1000).
    bin_edges = np.histogram_bin_edges(labels, bins=bins)
    # Discretizes the values according to these bins
    discretization = np.digitize(labels, bin_edges[1:], right=True)
    return discretization
    
def get_mask_from_df(source_df: pd.DataFrame, target_df: pd.DataFrame, keys: List):
    source_keys = source_df[keys].apply(lambda row: '_'.join(row.values.astype(str)), axis=1)
    assert len(set(source_keys)) == len(source_keys), f"Multiple identique identifiers found"
    target_keys = target_df[keys].apply(lambda row: '_'.join(row.values.astype(str)), axis=1)
    mask = source_keys.isin(target_keys).values.astype(bool)
    return mask

In [45]:
# Parameters 
path_to_analyse = "/home/pa267054/neurospin/psy_sbox/analyses/2023_pauriau_sepmod"
raw = os.path.join(path_to_analyse, "data", "raw")
processed = os.path.join(path_to_analyse, "data", "processed")
root = os.path.join(path_to_analyse, "data", "root")
val_size = 0.1
stratify = ["age", "sex", "diagnosis", "site"]
random_state = 0
nb_folds = 1

In [46]:
# Get subject metadata for all the studies
metadata = pd.concat([pd.read_csv(os.path.join(processed, PD_FILES % db), dtype=ID_TYPES) for db in STUDIES], ignore_index=True, sort=False)
print(f"Nb of sbj with metadata: {len(metadata)} | {len(metadata.drop_duplicates(subset=UNIQUE_KEYS))}")
# Create new scheme
scheme = metadata[UNIQUE_KEYS].copy(deep=True)
print(scheme.head())

Nb of sbj with metadata: 4410 | 4410
  participant_id  session  run   study
0          51051        1    1  ABIDE1
1          50794        1    1  ABIDE1
2          50697        1    1  ABIDE1
3          50628        1    1  ABIDE1
4          51463        1    1  ABIDE1


In [47]:
# Load previous schemes to keep internal and external test sets

mask_train = False
mask_val = False
mask_test_intra = False
mask_test = False

for target in ("asd", "bd", "scz"):
    
    pck = pickle.load(open(os.path.join(root, f"train_val_test_test-intra_{target}_stratified.pkl"), "rb"))
    df_train = pck["train"]
    df_val = pck["validation"]
    df_test = pck["test"]
    df_test_intra = pck["test_intra"]
    if "run" not in df_test.columns:
        df_train["run"] = 1
        df_val["run"] = 1
        df_test_intra["run"] = 1
        df_test["run"] = 1
    
    mask = get_mask_from_df(source_df=scheme, target_df=df_train, keys=UNIQUE_KEYS)
    mask_train |= mask
    scheme.loc[mask, target] = "train"
    
    mask = get_mask_from_df(source_df=scheme, target_df=df_val, keys=UNIQUE_KEYS)
    mask_val |= mask
    scheme.loc[mask, target] = "validation"
    
    mask = get_mask_from_df(source_df=scheme, target_df=df_test_intra, keys=UNIQUE_KEYS)
    mask_test_intra |= mask
    scheme.loc[mask, target] = "test_intra"
    
    mask = get_mask_from_df(source_df=scheme, target_df=df_test, keys=UNIQUE_KEYS)
    mask_test |= mask
    scheme.loc[mask, target] = "test"
    
print(f"Nb of subjects for training: {mask_train.sum()}")
print(f"Nb of subjects for validation: {mask_val.sum()}")
print(f"Nb of subjects for internal tests: {mask_test_intra.sum()}")
print(f"Nb of subjects for external tests: {mask_test.sum()}")

Nb of subjects for training: 2769
Nb of subjects for validation: 362
Nb of subjects for internal tests: 372
Nb of subjects for external tests: 377


In [48]:
mask_ext_sites = (metadata["site"] + metadata["study"]).isin([s[0] + s[1] for s in EXTERNAL_SITES])
print(f"Nb of sbj in external sites : {mask_ext_sites.sum()}")
# Sanity checks
print(f"Sanity check : {((mask_ext_sites & mask_test) == mask_test).all()}")

Nb of sbj in external sites : 636
Sanity check : True


In [49]:
# Keep only subjects of training sets
mask_train_only = mask_train.astype(int)  - mask_val.astype(int) - mask_test_intra.astype(int) - mask_test.astype(int) - mask_ext_sites.astype(int)
mask_train_only = (mask_train_only == 1)
print(f"Nb of subjects for training: {mask_train_only.sum()}")

Nb of subjects for training: 2590


In [50]:
# Get all subjects
# 1. Remove subjects from external sites (keep the same as previous studies)
# 2. Remove subjects from internal test sets and validation sets
mask_train_full = ~(mask_ext_sites | mask_test_intra | mask_val)
print(f"Number of subjects for the full training {mask_train_full.sum()}")

Number of subjects for the full training 3059


In [51]:
# Save in scheme
scheme.loc[mask_train_only, "unsupervised"] = "train"
scheme.loc[mask_train_full, "unsupervised_full"] = "train"

In [52]:
mask_control = metadata["diagnosis"] == "control"
print(f"Number of controls : {mask.sum()}")
print(f"Number of controls in training sets : {(mask_control & mask_train_only).sum()}")
print(f"Number of controls in validation sets : {(mask_control & mask_val).sum()}")
print(f"Number of controls in internal test sets : {(mask_control & mask_test_intra).sum()}")
print(f"Number of controls in external test sets : {(mask_control & mask_test).sum()}")
control_data = metadata[UNIQUE_KEYS + stratify].copy(deep=True)
control_data.loc[(mask_control & mask_train_only), "set"] = "train"
control_data.loc[(mask_control & mask_val), "set"] = "validation"
control_data.loc[(mask_control & mask_test_intra), "set"] = "test_intra"
control_data.loc[(mask_control & mask_test), "set"] = "test"
control_data = control_data[mask_control]

Number of controls : 130
Number of controls in training sets : 1397
Number of controls in validation sets : 205
Number of controls in internal test sets : 198
Number of controls in external test sets : 201


In [53]:
# save in scheme
scheme.loc[(mask_control & mask_train_only), ["age", "sex"]] = "train"
scheme.loc[(mask_control & mask_val), ["age", "sex"]] = "validation"
scheme.loc[(mask_control & mask_test_intra), ["age", "sex"]] = "test_intra"
scheme.loc[(mask_control & mask_test), ["age", "sex"]] = "test"

In [54]:
scheme

Unnamed: 0,participant_id,session,run,study,asd,bd,scz,unsupervised,unsupervised_full,age,sex
0,51051,1,1,ABIDE1,train,,,train,train,train,train
1,50794,1,1,ABIDE1,train,,,train,train,,
2,50697,1,1,ABIDE1,train,,,train,train,,
3,50628,1,1,ABIDE1,test_intra,,,,,,
4,51463,1,1,ABIDE1,train,,,train,train,,
...,...,...,...,...,...,...,...,...,...,...,...
4405,CH7546a,1,1,SCHIZCONNECT-VIP,,,train,train,train,train,train
4406,ESOC10076,1,1,PRAGUE,,,train,train,train,train,train
4407,NM2082,1,1,SCHIZCONNECT-VIP,,,test_intra,,,test_intra,test_intra
4408,fg130137,1,1,SCHIZCONNECT-VIP,,,train,train,train,,


In [60]:
scheme.to_csv(os.path.join(path_to_analyse, "data", "root", "train_val_test_test-intra.csv"), sep=",", index=False)

In [37]:
control_data

Unnamed: 0,participant_id,session,run,study,age,sex,diagnosis,site,set
0,51051,1,1,ABIDE1,14.0600,F,control,NYU,train
6,51139,1,1,ABIDE1,19.5000,M,control,TRINITY,train
8,51343,1,1,ABIDE1,30.0000,M,control,MAX_MUN,train
9,50467,1,1,ABIDE1,19.7591,M,control,USM,train
10,51062,1,1,ABIDE1,27.7600,F,control,NYU,train
...,...,...,...,...,...,...,...,...,...
4401,A00007409,1,1,SCHIZCONNECT-VIP,48.0000,M,control,MRN,train
4403,CC4094,1,1,SCHIZCONNECT-VIP,24.0000,F,control,WUSTL,train
4405,CH7546a,1,1,SCHIZCONNECT-VIP,23.0000,F,control,NU,train
4406,ESOC10076,1,1,PRAGUE,40.0000,F,control,PRAGUE,train


In [27]:
# Sex
control_data[["set", "sex"]].groupby("set").describe()

Unnamed: 0_level_0,sex,sex,sex,sex
Unnamed: 0_level_1,count,unique,top,freq
set,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
test,201,2,F,104
test_intra,192,2,M,127
train,1397,2,M,852
validation,196,2,M,115


In [21]:
# Age
control_data[["set", "age"]].groupby("set").describe()

Unnamed: 0_level_0,age,age,age,age,age,age,age,age
Unnamed: 0_level_1,count,mean,std,min,25%,50%,75%,max
set,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2
test,201.0,27.53697,15.644043,7.6,12.265753,25.0,39.0,66.0
test_intra,192.0,27.61158,15.338253,7.126626,12.975,24.580822,39.0,78.77
train,1397.0,26.089367,14.749832,5.2,13.25,23.0,35.0,79.22
validation,196.0,26.05056,12.939682,8.0,15.0,24.0,33.004795,66.88


In [22]:
# Site
control_data[["set", "site"]].groupby("set").describe()

Unnamed: 0_level_0,site,site,site,site
Unnamed: 0_level_1,count,unique,top,freq
set,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
test,201,6,GU,52
test_intra,192,36,CNP,26
train,1397,43,KKI,132
validation,196,37,WUSTL,19


In [15]:
# Stratification for age and sex predictions
# Discretize continuous labels
y = sbj_to_strat[stratify].copy(deep=True).values
if "age" in stratify:
    i_age = stratify.index("age")
    y[:, i_age] = discretize_continous_label(y[:, i_age].astype(np.float32))

In [17]:
sbj_to_strat["diagnosis"].unique()

array(['control', 'asd', 'bipolar disorder',
       'relative of schizoaffective disorder',
       'relative of bipolar disorder', 'schizoaffective disorder', 'scz',
       'psychotic bd', 'relative of schizophrenia', 'adhd', 'bd', 'fep'],
      dtype=object)

In [None]:
# Create arrays for splitting
dummy_x = np.zeros((len(sbj_to_strat), 1, 128, 128, 128))

In [None]:
# Stratification
print("Train - validation sets")
splitter = MultilabelStratifiedShuffleSplit(n_splits=nb_folds, test_size=val_size, random_state=random_state)
gen = splitter.split(dummy_x, y)
for f in range(nb_folds):
    train_index, val_index = next(gen)
    df_train = sbj_to_strat.iloc[train_index]
    mask_train = get_mask_from_df(source_df=scheme, target_df=df_train, keys=UNIQUE_KEYS)
    df_val = sbj_to_strat.iloc[val_index]
    mask_val = get_mask_from_df(source_df=scheme, target_df=df_val, keys=UNIQUE_KEYS)

    scheme.loc[mask_train, f"fold{f}"] = "train"
    scheme.loc[mask_val, f"fold{f}"] = "validation"
    scheme.loc[mask_test_intra, f"fold{f}"] = "test_intra"
    scheme.loc[mask_test, f"fold{f}"] = "test"

In [None]:



    
    print(scheme.head())
    # Sanity checks
    for fold in [f"fold{f}" for f in range(nb_folds)]:
        for split in ("train", "validation", "test", "test_intra"):
            print(f"Scheme: {fold} | Split {split}")
            print(f"Number of subjects {(scheme[fold] == split).sum()}")


        mask = (metadata["site"] + metadata["study"]).isin([s[0] + s[1] for s in EXTERNAL_SITES])
        print((scheme.loc[mask_ext_sites, fold].unique()))
        print(((scheme.loc[mask, fold] == "test") | (scheme.loc[mask, fold].isnull())).all())
        print("External acquisition sites are in train, validation or test_intra set !")




        fig, ax = plt.subplots(len(stratify), 1, figsize=(12, 6))
        fig.suptitle(f"{fold}")
        for i, s in enumerate(stratify):
            df = pd.concat((scheme, metadata[stratify]), axis=1)
            sns.histplot(data=df, x=s, hue=fold, stat="percent", kde=True, common_norm=False, ax=ax[i])
        fig.savefig(os.path.join(path_to_analyse, "figures", "hist_{fold}.png"))

    # Saving
    path_to_scheme = os.path.join(processed, f"train_val_test_test-intra_stratified.csv")
    if os.path.exists(path_to_scheme):
        answer = input(f"There is already a scheme at {path_to_scheme}. Do you want to replace it ? (y/n)")
        if answer == "y":
            scheme.to_csv(path_to_scheme, sep=",", index=False)