Code used to generate the data splits used for training, validation and testing. Longitudinal data are separated into training, validation and test sets at the subject level. Cross-sectional data, which are only used to pre-train the models, are divided into training and test sets **after excluding all the subjects included in the longitudinal test data**.

In [None]:
import os
import re
import pandas as pd
from pathlib import Path
from sklearn.model_selection import train_test_split
from datetime import datetime
from sklearn.cluster import KMeans
from IPython.display import display

In [None]:
# seed shared by every random split in this notebook (reproducibility)
RANDOM_SEED = 9999

# input mounts and output folder, relative to the notebook location
PATH_TO_DATA = Path('..', 'data', 'mounts', 'v1')
PATH_TO_OUTPUT = Path('..', 'data', 'splits', 'v1')

# input files; the leading YYYYMMDD prefix of each file name is reused when
# naming the exported split files
PATH_TO_LONG = PATH_TO_DATA / '20240428_longitudinal.parquet'
PATH_TO_CROSS_SECTIONAL = dict(
    mri=PATH_TO_DATA / '20240428_mri_cross_sectional.parquet',
    fdg=PATH_TO_DATA / '20240428_fdg_cross_sectional.parquet',
    amy=PATH_TO_DATA / '20240428_amy_cross_sectional.parquet',
)

In [None]:
# generation date (YYYYMMDD) appended to the names of the exported split files
curr_date = f"{datetime.now():%Y%m%d}"

# load every dataset and order its rows by index, so that row order depends
# only on the index values and not on how the parquet file was written
long_df = pd.read_parquet(PATH_TO_LONG)
long_df = long_df.sort_index()

cross_dfs = dict()
for cross_key, cross_path in PATH_TO_CROSS_SECTIONAL.items():
    cross_dfs[cross_key] = pd.read_parquet(cross_path).sort_index()

# Splits for the longitudinal data

In [None]:
# A subject can contribute more than one trajectory, so the splits are performed at
# the subject level (first index level): no subject may end up in two partitions.
long_uids = long_df.index.get_level_values(0).unique()

# Stratify on the number of affected cognitive domains, treating each count as a
# different class.
long_stratify_cols = [
    'lr_delta_memory_composite_binary',
    'lr_delta_exec_composite_binary',
    'lr_delta_language_composite_binary',
    'lr_delta_visuospatial_composite_binary',
]
# Each subject's label is taken from its first entry. `train_test_split` pairs
# `long_uids` and `long_stratify` purely by position, so the label series is indexed by
# subject id and explicitly re-ordered to match `long_uids`. Relying on the index
# returned by `groupby(...).nth(0)` is fragile: it changed in pandas 2.0.
long_first_entries = long_df[~long_df.index.get_level_values(0).duplicated(keep='first')]
long_stratify = pd.Series(
    long_first_entries[long_stratify_cols].sum(axis=1).to_numpy(),
    index=long_first_entries.index.get_level_values(0),
).loc[long_uids]

display(pd.DataFrame(long_stratify.value_counts()))

# perform the split of the longitudinal dataset for train/test
long_train, long_test, long_train_stratify, long_test_stratify = train_test_split(
    long_uids, long_stratify,
    test_size=0.3, random_state=RANDOM_SEED, shuffle=True, stratify=long_stratify
)

print('Data shape (train / test):\n\n\ttrain: {}\n\ttest: {}\n'.format(
    long_train.shape, long_test.shape
))
print('\n')

# split the remaining training subjects into train/validation
# (test_size is a fraction of the training subjects, not of the whole dataset)
long_train, long_valid, long_train_stratify, long_valid_stratify = train_test_split(
    long_train, long_train_stratify,
    test_size=0.15, random_state=RANDOM_SEED, shuffle=True, stratify=long_train_stratify
)

print('Data shape (train / valid):\n\n\ttrain: {}\n\tvalid: {}\n'.format(
    long_train.shape, long_valid.shape
))
print('\n')

# expand the selected subjects into the full row indices (all trajectories of a subject)
long_train_indices = long_df.loc[long_train].index
long_valid_indices = long_df.loc[long_valid].index
long_test_indices = long_df.loc[long_test].index

print('Final number of entries: {} (train) | {} (valid) | {} (test)'.format(
    long_train_indices.shape[0], long_valid_indices.shape[0], long_test_indices.shape[0]
))

# sanity checks: the three partitions must be pairwise subject-disjoint and, together,
# cover every entry of the longitudinal dataset
long_train_subjects = set(long_train_indices.get_level_values(0))
long_valid_subjects = set(long_valid_indices.get_level_values(0))
long_test_subjects = set(long_test_indices.get_level_values(0))
assert not long_train_subjects & long_valid_subjects
assert not long_train_subjects & long_test_subjects
assert not long_valid_subjects & long_test_subjects
assert (
    len(long_train_indices) + len(long_valid_indices) + len(long_test_indices)
    == len(long_df)
)

# create the final dataframe: one 'split' label per longitudinal entry
long_indices_df = pd.concat([
    pd.DataFrame(['train']*len(long_train_indices), index=long_train_indices, columns=['split']),
    pd.DataFrame(['valid']*len(long_valid_indices), index=long_valid_indices, columns=['split']),
    pd.DataFrame(['test']*len(long_test_indices), index=long_test_indices, columns=['split']),
], axis=0)

display(pd.DataFrame(long_indices_df['split'].value_counts()))


In [None]:
# Compare the target distributions across the train / valid / test partitions
long_df_with_split = long_df.join(long_indices_df).copy()
entries_by_split = long_df_with_split.groupby('split')

continuous_targets = [
    'lr_delta_memory_composite',
    'lr_delta_exec_composite',
    'lr_delta_language_composite',
    'lr_delta_visuospatial_composite',
]
binary_targets = [
    'lr_delta_memory_composite_binary',
    'lr_delta_exec_composite_binary',
    'lr_delta_language_composite_binary',
    'lr_delta_visuospatial_composite_binary',
]

print('Average values')
display(entries_by_split[continuous_targets].mean().round(3))

print('SD values')
display(entries_by_split[continuous_targets].std().round(3))

print('Binary version')
binary_percentages = entries_by_split[binary_targets].mean().mul(100).astype(float)
display(binary_percentages.round(1))

# per-split relative frequency (%) of each diagnosis at baseline, 2 and 4 years
print('Diagnosis information')
for diagnosis_col in ('baseline_diagnosis', 'diagnosis_2Y', 'diagnosis_4Y'):
    diagnosis_freq = entries_by_split[[diagnosis_col]].value_counts(normalize=True)
    display((pd.DataFrame(diagnosis_freq) * 100).round(2))

In [None]:
# Export the longitudinal split assignment. The output file name reuses the YYYYMMDD
# prefix of the input file and appends the generation date (`curr_date`).
in_long_date_match = re.match(r'\d{8}', PATH_TO_LONG.name)  # raw string: '\d' is not a str escape
if in_long_date_match is None:
    raise ValueError(
        'Expected the longitudinal file name to start with a YYYYMMDD date: %s' % PATH_TO_LONG.name
    )
in_long_date = in_long_date_match.group(0)

# the output folder may not exist yet on a fresh checkout
PATH_TO_OUTPUT.mkdir(parents=True, exist_ok=True)
long_indices_df.to_parquet(
    PATH_TO_OUTPUT / ('%s_longitudinal_generated%s.parquet' % (in_long_date, curr_date))
)

# Splits for cross-sectional data

In [None]:
# modality key -> DataFrame with the 'split' label ('train' / 'test') of each entry
cross_data_splits = {}

# Important!! Subjects from the longitudinal TEST set must not leak into the
# pre-training data, so they are removed from every cross-sectional table.
# Only test subjects are removed; longitudinal train/validation subjects are kept.
long_test_subjects = set(long_test_indices.get_level_values(0).unique())

# the output folder may not exist yet on a fresh checkout
PATH_TO_OUTPUT.mkdir(parents=True, exist_ok=True)

for key, mod_df in cross_dfs.items():

    mod_df = mod_df.loc[
        ~mod_df.index.get_level_values('subject_id').isin(long_test_subjects)
    ].copy()

    # generate the splits
    # NOTE(review): the split is done per index entry, not per subject. If a subject has
    # several entries in this table, they can fall in both train and test; confirm this
    # is intended.
    mod_df_train_index, mod_df_test_index = train_test_split(
        mod_df.index,
        test_size=0.15, random_state=RANDOM_SEED, shuffle=True
    )
    print('Split performed for "{}", entries: {} (train) | {} (test)'.format(
        key, mod_df_train_index.shape[0], mod_df_test_index.shape[0]))

    # the output file name reuses the YYYYMMDD prefix of the input file
    in_cross_date_match = re.match(r'\d{8}', PATH_TO_CROSS_SECTIONAL[key].name)
    if in_cross_date_match is None:
        raise ValueError(
            'Expected the "%s" cross-sectional file name to start with a YYYYMMDD date: %s'
            % (key, PATH_TO_CROSS_SECTIONAL[key].name)
        )
    in_cross_date = in_cross_date_match.group(0)

    cross_df_indices_df = pd.concat([
        pd.DataFrame(['train']*len(mod_df_train_index), index=mod_df_train_index, columns=['split']),
        pd.DataFrame(['test']*len(mod_df_test_index), index=mod_df_test_index, columns=['split'])
    ], axis=0)
    cross_data_splits[key] = cross_df_indices_df

    # export the generated splits
    cross_df_indices_df.to_parquet(
        PATH_TO_OUTPUT / (
            '%s_%s_cross_sectional_generated%s.parquet' % (in_cross_date, key, curr_date)
        )
    )