In [None]:
import os
from multi_modality_fl.utils.data_management import GlobalExperimentsConfiguration, write_json, read_json

# Configure the experiment. base_path resolves to
# <current working directory>/multi_modality_fl/experiments, so the result
# depends on where the kernel was started.
current_experiment = GlobalExperimentsConfiguration(
    base_path=os.path.join(os.getcwd(), 'multi_modality_fl', 'experiments'),
    experiment_name='sklearn_general',
    random_seed=0,
)

# NOTE(review): dataset_folder is a machine-specific absolute path; anyone
# else running this notebook has to point it at their own data directory.
current_experiment.create_experiment(
    dataset_folder='/Users/benjamindanek/Code/federated_learning_multi_modality_ancestry/data',
    dataset=GlobalExperimentsConfiguration.MULTIMODALITY,
    split_method=GlobalExperimentsConfiguration.SKLEARN,
)

In [None]:
import json
import os

import numpy as np
import pandas as pd
from sklearn.linear_model import SGDRegressor
def read_multi_modality_dataset(data_path, start: int=None, end: int=None):
    """Load the feature columns and the ``PHENO`` label column from an HDF file.

    Args:
        data_path: Path of the HDF file; handed to ``pd.read_hdf``.
        start: Row number at which to start reading. ``None`` (the default)
            reads from the beginning of the file.
        end: Row number at which to stop reading; forwarded to ``pd.read_hdf``
            as ``stop``. ``None`` (the default) reads through the end.

    Returns:
        A tuple ``(x, y)``: ``x`` is the loaded frame without the ``PHENO``
        column (and without ``ID`` when that column is present); ``y`` is a
        single-column DataFrame holding ``PHENO``.

    Raises:
        KeyError: If the loaded frame has no ``PHENO`` column.
    """
    # Forward the bounds untouched: None is pd.read_hdf's own default for an
    # open bound. A truthiness check (`if start and end`) would mistake
    # start=0 for "no bounds" and give the first site the entire file.
    data = pd.read_hdf(data_path, start=start, stop=end)

    # 'ID' is optional in the stored frame; 'PHENO' is always split off as the label.
    non_feature_columns = ['ID', 'PHENO'] if 'ID' in data.columns else ['PHENO']
    x = data.drop(columns=non_feature_columns)

    y = data[['PHENO']].copy()
    return x, y

def load_data(
    client_id: str = "site-1",
    split_path: str = "/Users/benjamindanek/Code/federated_learning_multi_modality_ancestry/multi_modality_fl/experiments/federated_linear/uniform_sid_0_of_1.json",
) -> tuple:
    """Load one client's training slice and the shared validation set.

    The split description is a JSON file with the keys ``data_path``,
    ``data_index`` (a mapping of client id to ``{"start": ..., "end": ...}``)
    and ``valid_path``.

    Args:
        client_id: Key in ``data_index`` selecting the slice to load.
        split_path: Path of the JSON split description.
            NOTE(review): the default is a machine-specific absolute path kept
            only for backward compatibility; pass an explicit path elsewhere.

    Returns:
        ``((X_train, y_train), (X_valid, y_valid))`` as produced by
        ``read_multi_modality_dataset``.

    Raises:
        ValueError: If ``client_id`` is missing from ``data_index`` or the
            split description has no ``valid_path`` entry.
        KeyError: If ``data_path`` or ``data_index`` is missing from the file.
    """
    with open(split_path) as file:
        data_split = json.load(file)
    print(data_split)

    data_path = data_split["data_path"]
    data_index = data_split["data_index"]

    # Fail early with a readable message when the split description does not
    # cover this client or lacks a validation set.
    if client_id not in data_index:
        raise ValueError(
            f"Dict of data_index does not contain Client {client_id} split",
        )

    if "valid_path" not in data_split:
        raise ValueError(
            "Data does not contain Validation split",
        )

    valid_path = data_split["valid_path"]
    site_index = data_index[client_id]

    # Training: only the row range assigned to this client.
    X_train, y_train = read_multi_modality_dataset(
        data_path=data_path, start=site_index["start"], end=site_index["end"]
    )

    # Validation: the whole file, no row bounds.
    X_valid, y_valid = read_multi_modality_dataset(
        data_path=valid_path
    )

    return (X_train, y_train), (X_valid, y_valid)

# Load the (X, y) pairs for training and validation, then export each split as
# a header-less, index-less CSV with the label frame (y) placed before the
# feature frame (X).
train, val = load_data()

# DataFrame.to_csv cannot write into a directory that does not exist, so make
# sure the output folder is there (a no-op when it already exists).
os.makedirs("/tmp/data", exist_ok=True)

df = pd.concat([train[1], train[0]], axis=1, join='inner')
df.to_csv("/tmp/data/train", sep=",", index=False, header=False)

df = pd.concat([val[1], val[0]], axis=1, join='inner')
df.to_csv("/tmp/data/val", sep=",", index=False, header=False)