In [201]:
### Imports ###
import os
from typing import List, Tuple
from classes.client import Client
from classes.coordinator_utils import select_common_features_variables, \
    compute_beta, reorder_matrix, create_beta_mask

import numpy as np
import pandas as pd
import time

In [202]:
### Helper Functions ###
### Just run this, these functions are needed in various places ###

### Define the client class ###
class ClientWrapper:
    """
    Holds all information necessary for the simulation run to work.
    Defines the input data of the client, it's name and if it should be
    considered as the coordinator
    """
    def __init__(self, id: str, input_folder: str, coordinator: bool = False):
        self.id = id
        self.input_folder = input_folder
        self.is_coordinator = coordinator
        self.client_class = None

def _check_consistency_clientwrappers(clientWrappers: List[ClientWrapper]) -> None:
    """
    Checks for a list of clients if they were created correctly
    Raises a ValueError in case of inconsistencies
    Checks:
        1. If exactly one coordinator exists
    """
    coord = False
    for clientWrapper in clientWrappers:
        if coord and clientWrapper.is_coordinator:
            raise ValueError("More than one coordinator was defined, please check "+\
                            "the code defining the clients")
        if clientWrapper.is_coordinator:
            coord = True
    if not coord:
        raise ValueError("No client instance is a coordinator, please designate "+\
                        "any client as a coordinator")

def _compare_central_federated_dfs(name:str,
                                   central_df: pd.DataFrame,
                                   federated_df: pd.DataFrame,
                                   intersection_features: List[str]) -> None:
    """
    Compares two dataframes for equality. First checks that index and columns
    are the same, then analyses the element wise differences.
    See the analyse_diff_df function for more details on the difference analysis.
    If both dataframes contain a NaN value at the same position, this is considered
    as equal (0 as difference).
    Args:
        name: Name used for printing
        central_df: The central dataframe
        federated_df: The federated dataframe
        intersection_features: The features that are common to both dataframes
    """
    central_df = central_df.sort_index(axis=0).sort_index(axis=1)
    federated_df = federated_df.sort_index(axis=0).sort_index(axis=1)
    print(f"_________________________Analysing: {name}_________________________")
    ### compare columns and index ###
    failed = False
    if not central_df.columns.equals(federated_df.columns):
        print(f"Columns do not match for central_df and federated_df")
        union_cols = central_df.columns.union(federated_df.columns)
        intercept_cols = central_df.columns.intersection(federated_df.columns)
        print(f"Union-Intercept of columns: {union_cols.difference(intercept_cols)}")
        failed = True
    if not central_df.index.equals(federated_df.index):
        print(f"Rows do not match for central_df and federated_df")
        union_rows = central_df.index.union(federated_df.index)
        intercept_rows = central_df.index.intersection(federated_df.index)
        print(f"Union-Intercept of rows: {union_rows.difference(intercept_rows)}")
        failed = True
    if failed:
        print(f"_________________________FAILED: {name}_________________________")

    df_diff = (central_df - federated_df).abs()
    print(f"Max difference: {df_diff.max().max()}")
    print(f"Mean difference: {df_diff.mean().mean()}")
    print(f"Max diff at position: {df_diff.idxmax().idxmax()}")

    df_diff_intersect = df_diff.loc[intersection_features]
    print(f"Max difference in intersect: {df_diff_intersect.max().max()}")
    print(f"Mean difference in intersect: {df_diff_intersect.mean().mean()}")
    print(f"Max diff at position in intersect: {df_diff_intersect.idxmax().idxmax()}")

def _concat_federated_results(clientWrappers: List[ClientWrapper],
                              samples_in_columns=True) -> Tuple[pd.DataFrame, List[str]]:
    """
    Concatenates the results of the federated clients into one dataframe
    Also checks which features are common to all clients
    and returns them
    Args:
        clientWrappers: List of ClientWrapper instances, containing the data
            in the data_corrected attribute
        samples_in_columns: If True, the samples are in the columns, if False
            they are in the rows. For expression files this is true.
            This decides how to aggregate the dataframes
    Returns:
        merged_df: The merged dataframe containing the data of all clients
        intersection_features: The features that are common to all clients
    """
    merged_df = None
    intersection_features = set()
    for clientWrapper in clientWrappers:
        # get the data in the correct format
        if not hasattr(clientWrapper, "data_corrected") or \
            clientWrapper.data_corrected is None:
            raise ValueError("No data was found in the clientWrappers")
        corrected_data = clientWrapper.data_corrected
        if not samples_in_columns:
            corrected_data = corrected_data.T

        cleaned_corrected_features = set(corrected_data.dropna().index)
        # initialize the merged_df
        if merged_df is None:
            merged_df = corrected_data
            intersection_features = cleaned_corrected_features
            continue

        # merge the data
        merged_df = pd.concat([merged_df, corrected_data], axis=1)
        intersection_features = intersection_features.intersection(cleaned_corrected_features)

    # final check
    if merged_df is None:
        raise ValueError("No data was found in the clientWrappers")
    # reverse the Transpose if necessary
    if not samples_in_columns:
        merged_df = merged_df.T
    return merged_df, list(intersection_features)


## DATA

In [203]:
### This part defines the data used. A ClientWrapper class is used to      ###
### describe all cohorts. If other data should be tested, this part should ###
### be changed                                                             ###
### Define the different clients ###
    # we use a helper class for each client, see the helper function
    # code block or the later definitions here for more info

In [204]:
# #################### Microarray Data ####################
# # First define the basefolder where all files are located
# base_dir = os.path.join("..")
# # Go back to the git repos root dir
# base_dir = os.path.join(base_dir, "evaluation_data", "microarray", "before")

# # List of cohort names
# cohort_names = [
#     'GSE38666',  # Client 1 (Coordinator)
#     'GSE14407',  # Client 2
#     'GSE6008',   # Client 3
#     'GSE40595',  # Client 4
#     'GSE26712',  # Client 5
#     'GSE69428',  # Client 6
# ]

In [205]:
#################### Proteomics Data ####################

# First define the basefolder where all files are located
base_dir = os.path.join("..")
# Go back to the git repos root dir
base_dir = os.path.join(base_dir, "evaluation_data", "proteomics", "before")

# List of cohort names
cohort_names = [
    'lab_A',  # Client 1 (Coordinator)
    'lab_B',  # Client 2
    'lab_C',  # Client 3
    'lab_D',  # Client 4
    'lab_E',  # Client 5
]

In [206]:
# ################# Microbiome Data ####################
# # First define the basefolder where all files are located
# base_dir = os.path.join("..")
# # Go back to the git repos root dir
# base_dir = os.path.join(base_dir, "evaluation_data", "microbiome", "before")

# # List of cohort names
# cohort_names = [
#     'PRJEB27928',  # Client 1 (Coordinator)
#     'PRJEB6070',   # Client 2
#     'PRJNA429097', # Client 3
#     'PRJEB10878',  # Client 4
#     'PRJNA731589', # Client 5
# ]

In [207]:
# ################### Simulation Data ###################

# # First define the basefolder where all files are located
# base_dir = os.path.join("..")
# # Go back to the git repos root dir
# # base_dir = os.path.join(base_dir, "evaluation_data", "simulated", "balanced", "before")
# # base_dir = os.path.join(base_dir, "evaluation_data", "simulated", "mild_imbalanced", "before")
# base_dir = os.path.join(base_dir, "evaluation_data", "simulated", "strong_imbalanced", "before")

# # List of cohort names
# cohort_names = [
#     'lab1',  # Client 1 (Coordinator)
#     'lab2',  # Client 2
#     'lab3',  # Client 3
# ]

In [208]:

# Initialize clientWrappers list
clientWrappers: List[ClientWrapper] = []

# Iterate over cohort names and create ClientWrapper instances
for i, cohortname in enumerate(cohort_names):
    clientWrappers.append(ClientWrapper(
        id=cohortname,
        input_folder=os.path.join(base_dir, cohortname),
        coordinator=(i == 0)  # Set the first client as coordinator
    ))

# Double check that we only have one coordinator
_check_consistency_clientwrappers(clientWrappers)


## Analysis

In [209]:
# measure time for all clients
time_tracker = {}

In [210]:
###                                  INFO                                  ###
### The following code blocks run the simulation. They are divided into    ###
### multiple logical blocks to ease the use                                ###

In [211]:
### SIMULATION: all: initial ###
### Initial reading of the input folder

send_features_variables = list()
for clientWrapper in clientWrappers:
    # define the client class
    cohort_name = clientWrapper.id
    client = Client()
    client.config_based_init(clientname = cohort_name,
                             input_folder = clientWrapper.input_folder,
                             use_hashing = False)
    clientWrapper.client_class = client
    send_features_variables.append((cohort_name,   # for mask creation - to track the cohort
                                    list(client.hash2feature.keys()),
                                    list(client.hash2variable.keys())))

Got the following config:
{'flimmaBatchCorrection': {'min_samples': 0, 'data_filename': 'intensities.tsv', 'separator': '\t', 'covariates': ['A'], 'design_filename': 'design.tsv', 'index_col': 'rowname', 'expression_file_flag': True}}
Opening dataset ../evaluation_data/simulated/strong_imbalanced/before/lab1/intensities.tsv
Shape of rawdata(expr_file): (6000, 40)
finished loading data, shape of data: (6000, 40), num_features: 6000, num_samples: 40
Got the following config:
{'flimmaBatchCorrection': {'min_samples': 0, 'data_filename': 'intensities.tsv', 'separator': '\t', 'covariates': ['A'], 'design_filename': 'design.tsv', 'index_col': 'rowname', 'expression_file_flag': True}}
Opening dataset ../evaluation_data/simulated/strong_imbalanced/before/lab2/intensities.tsv


Shape of rawdata(expr_file): (6000, 80)
finished loading data, shape of data: (6000, 80), num_features: 6000, num_samples: 80
Got the following config:
{'flimmaBatchCorrection': {'min_samples': 0, 'data_filename': 'intensities.tsv', 'separator': '\t', 'covariates': ['A'], 'design_filename': 'design.tsv', 'index_col': 'rowname', 'expression_file_flag': True}}
Opening dataset ../evaluation_data/simulated/strong_imbalanced/before/lab3/intensities.tsv
Shape of rawdata(expr_file): (6000, 480)
finished loading data, shape of data: (6000, 480), num_features: 6000, num_samples: 480


In [212]:
### SIMULATION: Coordinator: global_feature_selection ###
### Aggregate the features and variables

# obtain and safe common genes and indices of design matrix
# wait for each client to send the list of genes they have
# also memo the feature presence matrix and feature_to_cohorts

broadcast_features_variables = tuple()
for clientWrapper in clientWrappers:
    if clientWrapper.is_coordinator:

        time_tracker["Coordinator"] = time.time()

        global_feature_names, global_variables, feature_presence_matrix, cohorts_order = \
            select_common_features_variables(
                lists_of_features_and_variables=send_features_variables,
                min_clients=1      # minimum number of clients that need to have the feature
            )
        # memo the feature presence matrix and feature_to_cohorts
        broadcast_features_variables = global_feature_names, global_variables

        end_time = time.time()
        time_tracker["Coordinator"] = end_time - time_tracker["Coordinator"]


In [213]:
### SIMULATION: Coordinator: feature presence matrix ###
### Compute the feature presence matrix that will be used for the mask creation

for clientWrapper in clientWrappers:
    if clientWrapper.is_coordinator:
        start_time = time.time()
        all_client_names = [cw.id for cw in clientWrappers]
        feature_presence_matrix = reorder_matrix(feature_presence_matrix,
                                          all_client_names,
                                          cohorts_order)
        # memo the feature presence matrix
        end_time = time.time()
        time_tracker["Coordinator"] += end_time - start_time

In [214]:
### SIMULATION: All: validate ###
### Expand data to fullfill the global format. Also performs consistency checks

for clientWrapper in clientWrappers:

    time_tracker[clientWrapper.id] = time.time()

    global_feauture_names_hashed, global_variables_hashed = \
        broadcast_features_variables
    client = clientWrapper.client_class
    client.validate_inputs(global_variables_hashed)
    client.set_data(global_feauture_names_hashed)
    # get all client names to generate design matrix
    all_client_names = [cw.id for cw in clientWrappers]
    err = client.create_design(all_client_names[:-1])
    if err:
        raise ValueError(err)

    end_time = time.time()
    time_tracker[clientWrapper.id] = end_time - time_tracker[clientWrapper.id]

Client lab1: Inputs validated.
feature names: 6000
global features: 6000
Extra local features: 0
Extra global features: 0
Adding 0 extra global features
Got 6000 global features and 6000 features in the data matrix
Before reindexing got this data: (6000, 40)
After reindexing got this data: (6000, 40)
design was finally created:        intercept  A  lab1  lab2
file                           
s.6          1.0  0     1     0
s.29         1.0  0     1     0
s.39         1.0  0     1     0
s.56         1.0  0     1     0
s.57         1.0  0     1     0
s.58         1.0  0     1     0
s.73         1.0  0     1     0
s.97         1.0  0     1     0
s.99         1.0  0     1     0
s.107        1.0  0     1     0
s.125        1.0  0     1     0
s.146        1.0  0     1     0
s.158        1.0  0     1     0
s.167        1.0  0     1     0
s.204        1.0  0     1     0
s.205        1.0  0     1     0
s.212        1.0  0     1     0
s.223        1.0  0     1     0
s.225        1.0  0     1     

In [215]:
### Simulatuion: Coordinator: create design mask based on feature presence matrix ###
### Create the mask for the design matrix based on the feature presence matrix
### that will be used for the beta computation
for clientWrapper in clientWrappers:
    if clientWrapper.is_coordinator:
        start_time = time.time()
        client = clientWrapper.client_class

        n=len(client.feature_names)
        k=client.design.shape[1]

        global_mask = create_beta_mask(feature_presence_matrix, n, k)
        # memo the global mask

        end_time = time.time()
        time_tracker["Coordinator"] += end_time - start_time


In [216]:
### SIMULATION: All: prepare for compute_XtX_XtY ###

for clientWrapper in clientWrappers:
    start_time = time.time()

    client = clientWrapper.client_class
    client.sample_names = client.design.index.values

    # Error check if the design index and the data index are the same
    # we check by comparing the sorted indexes
    client._check_consistency_designfile()

    # Extract only relevant (the global) features and samples
    client.data = client.data.loc[client.feature_names, client.sample_names]
    client.n_samples = len(client.sample_names)

    end_time = time.time()
    time_tracker[clientWrapper.id] += end_time - start_time

In [217]:
### SIMULATION: All: compute_XtX_XtY ###
### Compute XtX and XtY and share it
send_XtX_XtY_list: List[List[np.ndarray]] = list()
for clientWrapper in clientWrappers:
    start_time = time.time()

    client = clientWrapper.client_class

    # compute XtX and XtY
    XtX, XtY, err = client.compute_XtX_XtY()
    if err != None:
        raise ValueError(err)

    # send XtX and XtY
    send_XtX_XtY_list.append([XtX, XtY])

    end_time = time.time()
    time_tracker[clientWrapper.id] += end_time - start_time

final vectors to be sent: XtX shape: (6000, 4, 4), XtY shape: (6000, 4)
final vectors to be sent: XtX shape: (6000, 4, 4), XtY shape: (6000, 4)
final vectors to be sent: XtX shape: (6000, 4, 4), XtY shape: (6000, 4)


In [218]:
### SIMULATION: Coordinator: compute_beta
### Compute the beta values and broadcast them to the others
broadcast_betas = None # np.ndarray of shape num_features x design_columns

for clientWrapper in clientWrappers:
    if clientWrapper.is_coordinator:

        start_time = time.time()

        client = clientWrapper.client_class
        beta = compute_beta(XtX_XtY_list=send_XtX_XtY_list,
                            n=len(client.feature_names),
                            k=client.design.shape[1],
                            global_mask=global_mask)

        # send beta to clients so they can correct their data
        broadcast_betas = beta

        end_time = time.time()
        time_tracker["Coordinator"] += end_time - start_time

INFO: Shape of beta: (6000, 4)
INFO: Number of pseudo inverses: 0


In [219]:
### SIMULATION: All: include_correction
### Corrects the individual data
for clientWrapper in clientWrappers:

    start_time = time.time()

    client = clientWrapper.client_class

    # remove the batch effects in own data and safe the results
    client.remove_batch_effects(beta)
    print(f"DEBUG: Shape of corrected data: {client.data_corrected.shape}")

    end_time = time.time()
    time_tracker[clientWrapper.id] += end_time - start_time

    # As this is a simulation we don't save the corrected data to csv, instead
    # we save it as a variable to the clientwrapper
    clientWrapper.data_corrected = client.data_corrected
    # client.data_corrected.to_csv(os.path.join(os.getcwd(), "mnt", "output", "only_batch_corrected_data.csv"),
    #                                 sep=self.load("separator"))
    # client.data_corrected_and_raw.to_csv(os.path.join(os.getcwd(), "mnt", "output", "all_data.csv"),
    #                              sep=self.load("separator"))
    # with open(os.path.join(os.getcwd(), "mnt", "output", "report.txt"), "w") as f:
    #     f.write(client.report)


start remove_batch_effects
Shape of data: (6000, 40)
Shape of beta:  (6000, 4)
Beta_reduced contains 0 Nan values
Shape of corrected data after correction: (6000, 40)
index is Index(['prt1', 'prt10', 'prt100', 'prt1000', 'prt1001', 'prt1002', 'prt1003',
       'prt1004', 'prt1005', 'prt1006',
       ...
       'prt990', 'prt991', 'prt992', 'prt993', 'prt994', 'prt995', 'prt996',
       'prt997', 'prt998', 'prt999'],
      dtype='object', name='rowname', length=6000)
Amount of index found in hash2feature: 6000/6000
After renaming got this data_corrected:               s.6      s.29      s.39      s.56      s.57      s.58      s.73  \
rowname                                                                         
prt1     2.037042 -1.561229  1.319323  0.394239 -0.981225 -1.527206 -0.715851   
prt10    1.660499  1.627477  0.998509  1.562899  1.025412  2.403214  0.838157   
prt100  -0.165470  5.848844  4.550318 -1.451182 -4.480440 -0.779068  4.551928   
prt1000  0.806859  0.710102  3.5220

In [220]:
# print the time tracker for the coordinator
print(f"Time tracker for coordinator, ms: {round(time_tracker['Coordinator']*1000, 2)}")

# print the time tracker for the clients
for clientWrapper in clientWrappers:
    print(f"Time tracker for {clientWrapper.id}, ms: {round(time_tracker[clientWrapper.id]*1000, 2)}")

Time tracker for coordinator, ms: 323.4
Time tracker for lab1, ms: 410.29
Time tracker for lab2, ms: 364.62
Time tracker for lab3, ms: 639.19


In [61]:
###                                  INFO                                  ###
###                            SIMULATION IS DONE                          ###
### The simulation is done. The corrected data is saved in the             ###
### clientWrapper instances. Now we analyse the data by comparing to the   ###
### calculated centralized corrected data.                                 ###


In [62]:
federated_df, intersect_features = _concat_federated_results(clientWrappers, samples_in_columns=True)

In [63]:
### SAVE THE RESULTS ###
# Microarray data
# federated_df.to_csv(os.path.join("..", "evaluation_data", "microarray", "after", "FedSim_corrected_data.tsv"), sep="\t")

# Proteomics data
federated_df.to_csv(os.path.join("..", "evaluation_data", "proteomics", "after", "FedSim_corrected_data.tsv"), sep="\t")

# Microbiome data
# federated_df.to_csv(os.path.join("..", "evaluation_data", "microbiome", "after", "FedSim_corrected_data.tsv"), sep="\t")

# Simulation data
# federated_df.to_csv(os.path.join("..", "evaluation_data", "simulated", "balanced", "after", "FedSim_corrected_data.tsv"), sep="\t")
# federated_df.to_csv(os.path.join("..", "evaluation_data", "simulated", "mild_imbalanced", "after", "FedSim_corrected_data.tsv"), sep="\t")
# federated_df.to_csv(os.path.join("..", "evaluation_data", "simulated", "strong_imbalanced", "after", "FedSim_corrected_data.tsv"), sep="\t")


In [39]:
# ### Concat the federated data and read in the centralized data ###
# central_df_path = os.path.join(os.path.dirname(base_dir), "after", "central_corrected_UNION.tsv")
# central_df = pd.read_csv(central_df_path, sep="\t", index_col=0)
# _compare_central_federated_dfs("microarray", central_df, federated_df, intersect_features)

_________________________Analysing: microarray_________________________
Max difference: 8.171241461241152e-14
Mean difference: 4.357000735613189e-15
Max diff at position: GSM1701030
Max difference in intersect: 8.171241461241152e-14
Mean difference in intersect: 4.515621419262236e-15
Max diff at position in intersect: GSM1701030


In [88]:
### Concat the federated data and read in the centralized data ###
central_df_path = os.path.join(os.path.dirname(os.path.dirname(base_dir)), "proteomics", "after", "intensities_log_Rcorrected_UNION.tsv")
central_df = pd.read_csv(central_df_path, sep="\t", index_col=0)
_compare_central_federated_dfs("proteomic data", central_df, federated_df, intersect_features)

_________________________Analysing: proteomic data_________________________
Max difference: 1.9184653865522705e-13
Mean difference: 3.1758692034102934e-14
Max diff at position: Clinspect_E_coli_A_S42_Slot1-21_1_8670
Max difference in intersect: 1.1368683772161603e-13
Mean difference in intersect: 3.338152465379177e-14
Max diff at position in intersect: Clinspect_E_coli_A_S42_Slot1-21_1_8670
