In [1]:
from multi_modality_fl.utils.data_management import GlobalExperimentsConfiguration, write_json, read_json
import os
import pandas as pd

In [2]:
current_experiment = GlobalExperimentsConfiguration(
    base_path=os.path.join(os.getcwd(), os.path.join('multi_modality_fl', 'experiments')),
    experiment_name='testing',
    random_seed=0
)

current_experiment.create_experiment(
    dataset_folder='/Users/benjamindanek/Code/federated_learning_multi_modality_ancestry/data',
    dataset=GlobalExperimentsConfiguration.MULTIMODALITY,
    split_method=GlobalExperimentsConfiguration.SKLEARN
)

internal:  (597, 714)
external:  (1116, 674)
shared columns ['ENSG00000143947', 'ENSG00000112335', 'rs11743963', 'ENSG00000135218', 'ENSG00000077549', 'ENSG00000146833', 'ENSG00000110066', 'ENSG00000116667', 'rs1871900', 'ENSG00000183150', 'ENSG00000138131', 'ENSG00000140939', 'ENSG00000149573', 'ENSG00000159164', 'ENSG00000111596', 'ENSG00000259417', 'ENSG00000136541', 'ENSG00000135164', 'ENSG00000175718', 'ENSG00000122224', 'ENSG00000121851', 'ENSG00000124782', 'ENSG00000176410', 'rs4853705', 'ENSG00000116691', 'ENSG00000157060', 'ENSG00000162889', 'ENSG00000089060', 'rs4440018', 'ENSG00000204070', 'ENSG00000232040', 'ENSG00000141127', 'ENSG00000182504', 'ENSG00000120008', 'ENSG00000165521', 'ENSG00000213390', 'ENSG00000114745', 'ENSG00000118596', 'ENSG00000159713', 'ENSG00000250510', 'ENSG00000109686', 'ENSG00000167880', 'ENSG00000136643', 'ENSG00000047188', 'ENSG00000165168', 'ENSG00000102401', 'ENSG00000109956', 'ENSG00000110871', 'ENSG00000120457', 'ENSG00000110442', 'ENSG0000016

<multi_modality_fl.utils.data_management.GlobalExperimentsConfiguration at 0x104436260>

In [3]:
from typing import Dict
import numpy as np

num_folds = 5
def _generate_stratified_k_folds(df: pd.DataFrame) -> Dict[int, pd.DataFrame]:
    k_fold_indeces = dict()

    # shuffle the dataframe
    df = df.sample(frac=1, replace=False, random_state=0)

    for _, group in df.groupby('PHENO'):

        fold_len = len(group) // num_folds
        start = 0
        for fold in range(0, num_folds):
            end = start + fold_len if fold != num_folds - 1 else len(group)
            
            fold_data = group.iloc[start:end]
            if fold not in k_fold_indeces:
                k_fold_indeces[fold] = fold_data
            else:
                k_fold_indeces[fold] = pd.concat([k_fold_indeces[fold], fold_data])
            
            start = end

    # sanity check, since this is such a crucial part of the experimental design
    for i, subset_i in k_fold_indeces.items():
        for j, subset_j in k_fold_indeces.items():
            if i == j: continue
            assert set(subset_i.index) & set(subset_j.index) == set(), "folds have overlapping indeces"

    # all partitions must have approximately similar startification
    assert np.std([fold_values['PHENO'].value_counts()[0] / fold_values['PHENO'].value_counts()[1] for fold_values in k_fold_indeces.values()]) < 0.03

    return k_fold_indeces

folds_for_experiment = _generate_stratified_k_folds(current_experiment.full_internal_dataset)

def set_fold(fold_idx: int):
    """take the base dataset and split it into k folds; """

    holdout_idx = fold_idx
    holdout_dataset = folds_for_experiment[holdout_idx]
    
    trainig_folds = []
    for fold_idx, fold in folds_for_experiment.items():
        if fold_idx != holdout_idx:
            trainig_folds.append(fold)
    
    training_dataset = pd.concat(trainig_folds)

    return holdout_dataset, training_dataset


In [4]:
folds = _generate_stratified_k_folds(current_experiment.full_internal_dataset)
[fold_values['PHENO'].value_counts()[0] / fold_values['PHENO'].value_counts()[1] for fold_values in folds.values()]

[0.4, 0.4, 0.4, 0.4, 0.4069767441860465]

In [5]:
holdout, training = set_fold(1)
display(holdout)
display(training)

Unnamed: 0,ENSG00000143947,ENSG00000112335,rs11743963,ENSG00000135218,ENSG00000077549,ENSG00000146833,ENSG00000110066,ENSG00000116667,rs1871900,ENSG00000183150,...,ENSG00000138640,rs74937936,ENSG00000114796,ENSG00000213218,ENSG00000172086,ENSG00000204634,ENSG00000180644,ENSG00000104886,rs4238361,ENSG00000062716
235,1.941928,2.541150,-1.096533,-0.173336,-0.224394,0.252583,0.095131,0.054646,0.035481,0.298046,...,-0.473754,0.027070,-1.076048,-0.036456,-0.604653,0.617455,-2.444062,1.495267,0.165449,0.527596
441,-1.444026,-0.453831,-0.057941,0.966664,1.262213,0.779189,0.683614,-1.706572,0.127744,-0.232316,...,-0.202740,-1.247137,-1.398003,-0.835565,0.340334,1.841908,-1.124382,0.439134,-1.504982,0.362247
517,-0.504293,-1.198754,1.360592,-1.029127,1.262458,1.037780,1.088592,-1.086045,0.155946,1.614636,...,-1.321213,0.097291,0.541736,3.646601,1.080827,-1.894456,0.037211,-0.232391,0.116295,-1.446085
187,-2.437781,0.993589,1.722793,-0.691267,3.267951,2.064082,1.637317,-0.427401,1.637235,0.140375,...,-1.641736,0.020081,0.377022,1.325358,0.860319,-0.581323,1.141866,-1.676477,1.508977,-1.817062
89,-0.966113,1.093264,-1.507379,-0.634860,-1.607851,0.521608,-0.834465,-0.360187,-0.045767,-0.578437,...,-0.211003,-1.178397,-0.948626,0.577559,0.331569,-0.413669,-1.038602,0.583838,0.078781,0.970103
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
155,0.863093,0.515552,0.001404,-1.423934,-1.072093,-0.364358,0.301050,1.526252,0.132565,-1.082894,...,-0.889081,0.154653,-0.183341,0.100704,1.702016,-1.194779,-0.448557,-0.530518,-0.142227,-0.398865
246,0.292720,0.333473,0.001244,-1.174569,0.346775,-0.329268,-0.001718,-0.555570,-0.166756,1.318800,...,0.191262,-1.232917,0.173628,0.056411,-0.561928,1.050621,0.422901,1.272804,1.438240,0.213223
97,0.230051,1.614090,1.256492,0.129016,0.672171,0.761801,-0.439372,0.214596,-1.359608,-0.276759,...,0.442868,-1.393661,-0.957982,-2.899066,0.268086,-1.326342,0.358763,0.240276,1.641227,-0.710741
20,-1.254920,-0.823544,1.515762,0.581803,-0.132323,0.015882,0.193642,-0.278819,-1.014840,0.926331,...,0.517389,0.106874,1.045398,0.025759,-0.304774,-0.334474,0.367090,-0.928616,-1.242909,-0.785081


Unnamed: 0,ENSG00000143947,ENSG00000112335,rs11743963,ENSG00000135218,ENSG00000077549,ENSG00000146833,ENSG00000110066,ENSG00000116667,rs1871900,ENSG00000183150,...,ENSG00000138640,rs74937936,ENSG00000114796,ENSG00000213218,ENSG00000172086,ENSG00000204634,ENSG00000180644,ENSG00000104886,rs4238361,ENSG00000062716
211,0.584381,-0.653081,-1.387847,0.492597,0.072761,-0.252337,-1.141181,1.021276,-1.320942,-1.002912,...,-0.453372,1.535075,0.326735,-0.370246,0.221064,0.501961,1.659376,-2.156769,1.266749,-0.666706
366,-0.774045,-1.040139,1.267276,1.469246,0.056192,-1.506444,1.617581,-0.233028,-0.019448,1.779834,...,0.470556,0.288685,-0.020729,-1.353670,0.883810,0.593783,-2.279908,-2.885459,1.537927,1.667364
268,-0.893757,-0.479804,-1.571986,-2.781721,-0.582138,0.472947,0.060433,0.225956,1.428102,0.163833,...,-0.348801,-1.031241,0.767742,0.698490,-0.575211,-0.692992,0.884273,-1.190333,-0.027106,-0.587897
367,-0.596641,0.811078,1.389775,0.347388,0.726279,-0.384490,-0.629164,-0.115778,-1.370668,-1.109284,...,0.140485,1.668941,-0.866272,0.117831,0.092333,-0.218097,-0.076786,-0.015228,-0.096427,-0.213659
225,1.284047,1.899512,-0.147398,1.739520,-1.167515,-0.279737,-0.005071,1.029465,-0.026980,-0.249559,...,-0.977905,0.201109,-0.706540,2.437364,0.154597,0.041285,-1.187675,-0.827964,0.117897,0.678767
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
472,0.898414,1.334805,1.266228,-0.623296,-2.489481,-1.245682,-3.285510,1.093602,0.109881,-0.742016,...,-0.155769,1.572897,1.078700,0.245530,-0.290995,1.401802,2.286711,0.992320,1.435009,1.396531
70,0.698702,0.780560,1.295275,0.202047,0.470757,-0.477733,-0.061133,1.760197,0.318812,-0.872899,...,-0.203095,0.126979,-1.180912,-1.297723,-1.228108,0.807004,-0.027940,-0.479249,0.016441,-0.354795
277,0.527510,-0.796497,1.355750,-1.043445,-1.083388,-0.287780,0.917695,1.675362,0.109195,0.263268,...,1.796977,0.101348,-0.348784,-0.774381,-0.082046,-0.913148,0.843578,-0.219970,-0.061138,-0.111780
359,-0.043172,-0.897722,-1.345545,-0.071066,-0.906294,-0.123266,0.110271,0.232275,0.204597,0.045100,...,0.307526,-1.246423,0.807929,-0.009993,0.348830,0.552139,-0.178482,0.375650,0.166049,-0.422884


In [6]:
current_experiment.set_fold(fold_idx=1)
    
# use a validation dataset
current_experiment.set_validation_dataset()

num_clients, site_prefix, split_method = 5, "site", "linear"

# 1. split the data for the experiment so each client has its own dataframe of stratified data
client_dataframes = current_experiment.get_stratified_client_subsets(
    dataset=current_experiment.training_dataset,
    num_clients=num_clients,
    method=split_method
)

# translate data frames into client splits & write the data
client_splits = []
training_data_paths = []
validaiton_data_paths = []
for i, df in enumerate(client_dataframes):
    start, end = 0, len(df)
    client_splits.append((start, end))
    
    training_subset_path = os.path.join(current_experiment.experiment_path, f'client_training_{i}_stratified.h5')
    df.to_hdf(training_subset_path, key='df', mode='w')
    training_data_paths.append(training_subset_path)

    validation_subset_path = os.path.join(current_experiment.experiment_path, f'client_validation_{i}_stratified.h5')
    df.to_hdf(validation_subset_path, key='df', mode='w')
    validaiton_data_paths.append(validation_subset_path)

0 [(0, 114), (114, 143)]
1 [(0, 284), (284, 355)]
0 [(0, 7), (7, 22), (22, 44), (44, 74), (74, 114)]
1 [(0, 18), (18, 55), (55, 111), (111, 186), (186, 284)]


In [7]:
current_experiment.internal_test_dataset['PHENO'].value_counts()[0] / current_experiment.internal_test_dataset['PHENO'].value_counts()[1]

0.39436619718309857

In [8]:
current_experiment.validation_dataset['PHENO'].value_counts()[0] / current_experiment.training_dataset['PHENO'].value_counts()[1]

0.10211267605633803

In [9]:
t, v = current_experiment.stratified_split(df=current_experiment.training_dataset, column='PHENO', ratios=[0.8, 0.2])
v['PHENO'].value_counts()

0 [(0, 91), (91, 114)]
1 [(0, 227), (227, 284)]


PHENO
1    57
0    23
Name: count, dtype: int64

In [10]:
import numpy as np

def stratified_split(df: pd.DataFrame, column: str, ratios: float):
    samples = []
    groups = df.groupby(column)
    for _, group in groups:
        shuffled_group = group.sample(frac=1, replace=False, random_state=0)
        n = len(shuffled_group)
        offset = 0

        indeces = []
        for ri, ratio in enumerate(ratios):

            start = offset
            end = start + int(n * ratio) if ri != len(ratios) - 1 else n
            offset = end

            indeces.append((start, end))
        
        print(_, indeces)

        for i, (start, end) in enumerate(indeces):
            stratified_subset = group.iloc[start: end]
            
            if len(samples) == len(ratios):
                samples[i].append(stratified_subset)
            else:
                samples.append([stratified_subset])
    
    flattened = []
    for sample in samples:
        
        flattened.append(pd.concat([group for group in sample]))

    value_proportions = [subset['PHENO'].value_counts()[0] / subset['PHENO'].value_counts()[1] for subset in flattened]
    print(value_proportions)
    assert np.std(value_proportions) < 0.03, f"Value counts of stratified dataset inconsistnet. {value_proportions}"

    return flattened

r = stratified_split(current_experiment.training_dataset, column='PHENO', ratios=[0.8, 0.2])

0 [(0, 91), (91, 114)]
1 [(0, 227), (227, 284)]
[0.4008810572687225, 0.40350877192982454]


In [11]:
import pandas as pd

pd.read_hdf('/Users/benjamindanek/Code/federated_learning_multi_modality_ancestry/multi_modality_fl/experiments/federated_random_forest_xgboost/client_training_0_stratified.h5')

Unnamed: 0,ENSG00000106853,ENSG00000164985,ENSG00000110077,ENSG00000135378,ENSG00000266094,ENSG00000188655,ENSG00000240682,ENSG00000150637,ENSG00000174937,ENSG00000181481,...,ENSG00000038427,ENSG00000144893,ENSG00000149582,ENSG00000172037,ENSG00000141127,ENSG00000157322,ENSG00000183150,ENSG00000143537,ENSG00000089060,ENSG00000093010
173,-0.531211,-0.903189,-0.890909,1.513053,-0.518242,-0.310894,0.102622,0.812651,-0.098505,0.101966,...,-0.07676,0.359953,-0.971485,-0.215898,-0.602565,-0.75789,0.712387,-0.091809,-0.12683,0.562271
218,0.74557,-0.274998,0.167433,-1.01804,1.161092,-0.052596,-1.601408,-1.057788,1.51577,-0.021842,...,0.270662,0.072317,-2.196739,-1.047918,0.823857,-0.398433,-0.6675,0.739976,-1.179362,-0.459578
0,-0.142423,0.400908,-1.193848,-0.192021,0.854485,-1.668472,-0.599969,1.348376,2.650046,-1.199505,...,-0.766175,-0.815271,0.144658,0.581931,-0.371057,-0.053879,1.101786,-0.376567,2.398742,-0.848698
293,0.229119,-0.35373,0.371532,-1.606966,0.609142,0.06161,-0.421741,-0.144672,0.18452,0.106564,...,0.707488,-0.391105,0.338611,0.269864,0.779065,0.108572,-0.270206,-0.112388,0.837289,0.093903
487,-0.346805,-4.438834,-1.254933,-0.472472,1.200714,0.042502,1.178473,-1.934166,-0.033673,-0.449893,...,-2.281724,0.122011,-0.049236,-1.08306,-0.374013,0.936211,2.110685,-3.258715,-0.1987,-0.102338
13,0.359403,1.105021,0.191982,0.267977,1.472481,-0.481043,-0.190513,-0.134845,0.681554,1.706106,...,-0.398243,0.044271,-0.426124,-0.447694,0.262716,0.694344,1.144851,-0.417672,0.623655,0.187537
224,0.316619,0.006353,-0.787599,-0.1772,-0.847227,2.416022,-0.55126,0.565413,-0.002655,-0.208785,...,-0.23916,-0.549222,-0.432838,-0.299146,-0.939355,-0.589958,1.426951,0.762336,0.368313,0.139786
192,-0.046175,-0.434034,0.615262,0.616905,-3.211917,0.151242,0.2043,-0.402382,-0.466064,1.398314,...,-0.059045,1.824425,2.841932,0.540625,-0.010288,1.278227,2.968023,0.42513,-0.613594,-0.299148
440,1.956391,0.132396,2.077022,-2.316414,0.3431,0.536442,-5.959353,0.764644,0.041144,4.030832,...,2.823267,2.311684,0.404828,1.313662,1.723946,-0.309621,0.791965,1.062309,-0.792483,-1.582662
291,-0.991765,0.011345,-0.167049,-0.312135,0.162716,-1.47002,-0.521469,-1.441038,3.763387,-1.203022,...,-0.119678,0.395342,-2.08518,0.999136,0.067032,0.018596,-0.7021,-0.552763,-0.120842,1.467044
