In [None]:
### Imports ###
import os
from typing import List
from classes.client import Client
from classes.coordinator_utils import select_common_features_variables, \
    compute_beta

import numpy as np
import pandas as pd

In [None]:
### Helper Functions ###
### Just run this cell; these functions are needed in various places ###

### Define the client class ###
class ClientWrapper:
    """
    Holds all information necessary for the simulation run to work.
    Defines the input data of the client, its name and whether it should be
    considered as the coordinator. The attributes are public and are read
    and written directly by the simulation code.
    """
    def __init__(self, id: str, input_folder: str, coordinator: bool = False):
        """
        Args:
            id: Name of the client.
            input_folder: Folder defining the input data of the client.
            coordinator: True if this client should be considered as the
                coordinator.
        """
        # NOTE: `id` shadows the builtin, but it is passed by keyword by the
        # callers, so the parameter name is part of the interface.
        self.id = id
        self.input_folder = input_folder
        self.is_coordinator = coordinator
        # Both attributes are filled in later by the simulation. They are
        # declared here so that every instance always has them.
        self.client_class = None
        self.data_corrected = None

def _check_consistency_clientwrappers(clientWrappers: List[ClientWrapper]) -> None:
    """
    Validates a list of client wrappers.

    The only check performed is that exactly one of the wrappers is flagged
    as coordinator.

    Raises:
        ValueError: if more than one wrapper is a coordinator, or if none is.
    """
    n_coordinators = sum(1 for wrapper in clientWrappers if wrapper.is_coordinator)
    if n_coordinators > 1:
        raise ValueError("More than one coordinator was defined, please check "
                         "the code defining the clients")
    if n_coordinators == 0:
        raise ValueError("No client instance is a coordinator, please designate "
                         "any client as a coordinator")
def _compare_central_federated_dfs(name:str,
                                   central_df: pd.DataFrame,
                                   federated_df: pd.DataFrame) -> None:
    """
    Compares two dataframes and prints a report of the result.

    Both dataframes are sorted by their row and column labels first. Then the
    labels are compared: if the columns or the rows differ, the labels that
    occur in only one of the dataframes are printed and the comparison is
    reported as FAILED. Afterwards the element wise absolute differences are
    summarised (max, mean and the position of the max).

    If both dataframes contain a NaN value at the same position, this is
    considered as equal (0 as difference). Positions that cannot be compared
    (NaN in only one dataframe, or a label missing in one dataframe) are
    excluded from max/mean and their number is printed.

    Args:
        name: Label used in the printed report.
        central_df: The centrally computed dataframe.
        federated_df: The federated dataframe to compare against.
    """
    central_df = central_df.sort_index(axis=0).sort_index(axis=1)
    federated_df = federated_df.sort_index(axis=0).sort_index(axis=1)
    print(f"_________________________Analysing: {name}_________________________")
    ### compare columns and index ###
    failed = False
    if not central_df.columns.equals(federated_df.columns):
        print("Columns do not match for central_df and federated_df")
        union_cols = central_df.columns.union(federated_df.columns)
        intercept_cols = central_df.columns.intersection(federated_df.columns)
        print(f"Union-Intercept of columns: {union_cols.difference(intercept_cols)}")
        failed = True
    if not central_df.index.equals(federated_df.index):
        print("Rows do not match for central_df and federated_df")
        union_rows = central_df.index.union(federated_df.index)
        intercept_rows = central_df.index.intersection(federated_df.index)
        print(f"Union-Intercept of rows: {union_rows.difference(intercept_rows)}")
        failed = True
    if failed:
        print(f"_________________________FAILED: {name}_________________________")

    ### element wise differences ###
    # The subtraction aligns on the labels, a label that exists in only one
    # dataframe results in NaN.
    df_diff = (central_df - federated_df).abs()

    # NaN in both dataframes at the same position counts as equal.
    central_nan = central_df.isna()
    federated_nan = federated_df.isna()
    if failed:
        # Labels differ, so bring the masks to the aligned shape of df_diff.
        # A label missing in a dataframe is not treated as NaN of that
        # dataframe, so such positions stay NaN in df_diff.
        central_nan = central_nan.reindex(index=df_diff.index,
                                          columns=df_diff.columns,
                                          fill_value=False)
        federated_nan = federated_nan.reindex(index=df_diff.index,
                                              columns=df_diff.columns,
                                              fill_value=False)
    both_nan = (central_nan & federated_nan).astype(bool)
    df_diff = df_diff.mask(both_nan, 0.0)

    # Everything that is still NaN could not be compared. max/mean skip these
    # positions, so report them instead of ignoring them silently.
    n_not_comparable = int(df_diff.isna().sum().sum())
    if n_not_comparable > 0:
        print(f"Positions that could not be compared: {n_not_comparable}")

    print(f"Max difference: {df_diff.max().max()}")
    print(f"Mean difference: {df_diff.mean().mean()}")

    # Locate the maximum on the raw values. idxmax().idxmax() would compare
    # the row labels with each other instead of the differences.
    diff_values = df_diff.to_numpy(dtype=float)
    if diff_values.size == 0 or np.isnan(diff_values).all():
        max_position = None
    else:
        row_pos, col_pos = np.unravel_index(np.nanargmax(diff_values),
                                            diff_values.shape)
        max_position = (df_diff.index[row_pos], df_diff.columns[col_pos])
    print(f"Max diff at position: {max_position}")

def _concat_federated_results(clientWrappers: List["ClientWrapper"],
                              samples_in_columns=True) -> pd.DataFrame:
    """
    Concatenates the results of the federated clients into one dataframe
    Args:
        clientWrappers: List of ClientWrapper instances, containing the data
            in the data_corrected attribute
        samples_in_columns: If True, the samples are in the columns, if False
            they are in the rows. For expression files this is true.
            This decides along which axis the dataframes are concatenated
    Returns:
        One dataframe holding the samples of all clients.
    Raises:
        ValueError: if the list is empty or if any client has no
            data_corrected (attribute missing or None).
    """
    frames = []
    for clientWrapper in clientWrappers:
        # a missing attribute and an unset (None) attribute both mean that
        # the client was not corrected yet
        corrected_data = getattr(clientWrapper, "data_corrected", None)
        if corrected_data is None:
            raise ValueError("No data was found in the clientWrappers")
        frames.append(corrected_data)

    # final check
    if not frames:
        raise ValueError("No data was found in the clientWrappers")

    # The clients are joined along the sample axis. Concatenating once, along
    # the matching axis, avoids re-copying the growing dataframe per client
    # and avoids transposing the data back and forth.
    sample_axis = 1 if samples_in_columns else 0
    return pd.concat(frames, axis=sample_axis)


In [None]:
###                                  INFO                                  ###
### this part defines the data used. A ClientWrapper class is used to      ###
### describe all cohorts. If other data should be tested, this part should ###
### be changed                                                             ###
### Define the different clients ###
# We use a helper class for each client, see the helper function code block
# for more info.
clientWrappers: List[ClientWrapper] = []

# Basefolder where all files are located: go back to the git repos root dir
# and from there to the location of the microarray data.
base_dir = os.path.join("..", "evaluation_data", "microarray", "before")

# One client per cohort (Client 1 to Client 6, in this order). Every cohort
# has its own input folder below base_dir.
cohortnames = ['GSE38666', 'GSE14407', 'GSE6008',
               'GSE40595', 'GSE26712', 'GSE69428']
# The cohort that acts as the coordinator
coordinator_cohortname = 'GSE38666'

for cohortname in cohortnames:
    clientWrappers.append(
        ClientWrapper(id=cohortname,
                      input_folder=os.path.join(base_dir, cohortname),
                      coordinator=(cohortname == coordinator_cohortname))
    )

# Double check that we only have one coordinator
_check_consistency_clientwrappers(clientWrappers)

In [None]:
###                                  INFO                                  ###
### The following code blocks run the simulation. They are divided into    ###
### multiple logical blocks to ease the use                                ###

In [None]:
### SIMULATION: all: initial ###
### Initial reading of the input folder

# One entry per client, in the order of clientWrappers. Each entry is a tuple
# (keys of client.hash2feature, keys of client.hash2variable).
send_features_variables = []
for clientWrapper in clientWrappers:
    # create the client instance and initialise it from its input folder
    cohort_name = clientWrapper.id
    client = Client()
    client.config_based_init(clientname=cohort_name,
                             input_folder=clientWrapper.input_folder)
    # keep the instance, the later simulation steps work on it
    clientWrapper.client_class = client

    # collect what this client shares for the global feature selection
    feature_keys = list(client.hash2feature.keys())
    variable_keys = list(client.hash2variable.keys())
    send_features_variables.append((feature_keys, variable_keys))

In [None]:
### SIMULATION: Coordinator: global_feature_selection ###
### Aggregate the features and variables

# Obtain and save the common features and variables. The input is the list
# collected from all clients (send_features_variables).
broadcast_features_variables = ()
for clientWrapper in clientWrappers:
    # only the coordinator aggregates
    if not clientWrapper.is_coordinator:
        continue
    global_feature_names, global_variables = select_common_features_variables(
        lists_of_features_and_variables=send_features_variables
    )
    # this tuple is what every client receives in the validate step
    broadcast_features_variables = (global_feature_names, global_variables)

In [None]:
### SIMULATION: All: validate ###
### Expand data to fulfill the global format. Also performs consistency checks

# The broadcast is the same for every client, so it is unpacked once.
global_feauture_names_hashed, global_variables_hashed = broadcast_features_variables
# get all client names to generate design matrix
all_client_names = [cw.id for cw in clientWrappers]

for clientWrapper in clientWrappers:
    client = clientWrapper.client_class
    client.validate_inputs(global_variables_hashed)
    client.set_data(global_feauture_names_hashed)
    # NOTE(review): the name of the last client is left out here. Presumably
    # the last cohort serves as the reference in the design matrix; confirm
    # against Client.create_design.
    err = client.create_design(all_client_names[:-1])
    if err:
        raise ValueError(err)

In [None]:
### SIMULATION: All: compute_XtX_XtY ###
### Compute XtX and XtY and share it
# One [XtX, XtY] entry per client, in the order of clientWrappers
send_XtX_XtY_list: List[List[np.ndarray]] = list()
for clientWrapper in clientWrappers:
    client = clientWrapper.client_class
    # the samples of this client are the rows of its design matrix
    client.sample_names = client.design.index.values

    # Error check if the design index and the data index are the same
    # we check by comparing the sorted indexes
    client._check_consistency_designfile()

    # Extract only relevant (the global) features and samples
    client.data = client.data.loc[client.feature_names, client.sample_names]
    client.n_samples = len(client.sample_names)

    # compute XtX and XtY
    XtX, XtY, err = client.compute_XtX_XtY()
    # None signals success, compare by identity (PEP 8)
    if err is not None:
        raise ValueError(err)

    # send XtX and XtY
    send_XtX_XtY_list.append([XtX, XtY])

In [None]:
### SIMULATION: Coordinator: compute_beta
### Compute the beta values and broadcast them to the others
broadcast_betas = None # np.ndarray of shape num_features x design_columns

for clientWrapper in clientWrappers:
    # only the coordinator computes the betas
    if not clientWrapper.is_coordinator:
        continue
    client = clientWrapper.client_class
    n_features = len(client.feature_names)
    n_design_columns = client.design.shape[1]
    # the XtX/XtY pairs of all clients go into the computation
    beta = compute_beta(XtX_XtY_list=send_XtX_XtY_list,
                        n=n_features,
                        k=n_design_columns)

    # send beta to clients so they can correct their data
    broadcast_betas = beta

In [None]:
### SIMULATION: All: include_correction
### Corrects the individual data

# The clients only know what the coordinator broadcast. broadcast_betas is
# still None if no coordinator computed the betas.
if broadcast_betas is None:
    raise ValueError("No betas were broadcast, run the compute_beta step first")

for clientWrapper in clientWrappers:
    client = clientWrapper.client_class

    # remove the batch effects in own data and save the results
    # (use the broadcast value, not the coordinator's local variable)
    client.remove_batch_effects(broadcast_betas)
    print(f"DEBUG: Shape of corrected data: {client.data_corrected.shape}")
    # As this is a simulation we don't save the corrected data to csv, instead
    # we save it as a variable to the clientwrapper
    clientWrapper.data_corrected = client.data_corrected
    # client.data_corrected.to_csv(os.path.join(os.getcwd(), "mnt", "output", "only_batch_corrected_data.csv"),
    #                                 sep=self.load("separator"))
    # client.data_corrected_and_raw.to_csv(os.path.join(os.getcwd(), "mnt", "output", "all_data.csv"),
    #                              sep=self.load("separator"))
    # with open(os.path.join(os.getcwd(), "mnt", "output", "report.txt"), "w") as f:
    #     f.write(client.report)

In [None]:
###                                  INFO                                  ###
###                            SIMULATION IS DONE                          ###
### The simulation is done. The corrected data is saved in the             ###
### clientWrapper instances. Now we analyse the data by comparing to the   ###
### calculated centralized corrected data.                                 ###


In [None]:
### Concat the federated data and read in the centralized data ###
# The centrally corrected data is read from the "after" folder, which is
# located next to base_dir. The file is tab separated and its first column
# is used as the index.
after_dir = os.path.join(os.path.dirname(base_dir), "after")
central_df_path = os.path.join(after_dir, "central_corrected_UNION.tsv")
central_df = pd.read_csv(central_df_path, sep="\t", index_col=0)

# join the corrected data of all clients, the samples are in the columns
federated_df = _concat_federated_results(clientWrappers,
                                         samples_in_columns=True)

# print the comparison report
_compare_central_federated_dfs(name="microarray",
                               central_df=central_df,
                               federated_df=federated_df)