In [1]:
### Imports ###
import os
from typing import List, Tuple, Union, Dict
from classes.client import Client
from classes.coordinator_utils import select_common_features_variables, \
    compute_beta, create_beta_mask

import numpy as np
import pandas as pd
import time

In [2]:
### Helper Functions ###
### Just run this, these functions are needed in various places ###

### Define the client class ###
class ClientWrapper:
    """
    Bundles everything the simulation needs to know about one client.

    Stores the client's name, the folder its input data is read from and
    whether the client additionally acts as the coordinator.
    The attributes ``client_class`` and ``data_corrected`` are only declared
    in the constructor; they are assigned later during the simulation run.
    """
    def __init__(self, id: str, input_folder: str, coordinator: bool = False):
        # Declared without a value on purpose: until the simulation assigns
        # them, hasattr() on these attributes is False.
        self.client_class: Client
        self.data_corrected: pd.DataFrame

        self.is_coordinator = coordinator
        self.input_folder = input_folder
        self.id = id

def _check_consistency_clientwrappers(clientWrappers: List[ClientWrapper]) -> None:
    """
    Validates that a list of client wrappers is usable for the simulation.
    Currently a single check is performed:
        1. Exactly one wrapper is flagged as coordinator
    Args:
        clientWrappers: The wrappers to validate
    Raises:
        ValueError: If no wrapper, or more than one wrapper, is a coordinator
    """
    coordinators_seen = 0
    for wrapper in clientWrappers:
        if not wrapper.is_coordinator:
            continue
        coordinators_seen += 1
        # fail as soon as the second coordinator shows up
        if coordinators_seen > 1:
            raise ValueError("More than one coordinator was defined, please check "
                             "the code defining the clients")
    if coordinators_seen == 0:
        raise ValueError("No client instance is a coordinator, please designate "
                         "any client as a coordinator")

def _compare_central_federated_dfs(name:str,
                                   central_df: pd.DataFrame,
                                   federated_df: pd.DataFrame,
                                   intersection_features: List[str]) -> None:
    """
    Compares two dataframes for equality. First checks that index and columns
    are the same, then analyses the element wise differences.
    See the analyse_diff_df function for more details on the difference analysis.
    If both dataframes contain a NaN value at the same position, this is considered
    as equal (0 as difference).
    Args:
        name: Name used for printing
        central_df: The central dataframe
        federated_df: The federated dataframe
        intersection_features: The features that are common to both dataframes
    """
    central_df = central_df.sort_index(axis=0).sort_index(axis=1)
    federated_df = federated_df.sort_index(axis=0).sort_index(axis=1)
    print(f"_________________________Analysing: {name}_________________________")
    ### compare columns and index ###
    failed = False
    if not central_df.columns.equals(federated_df.columns):
        print(f"Columns do not match for central_df and federated_df")
        union_cols = central_df.columns.union(federated_df.columns)
        intercept_cols = central_df.columns.intersection(federated_df.columns)
        print(f"Union-Intercept of columns: {union_cols.difference(intercept_cols)}")
        print(f"Central corrected columns: {len(central_df.columns)}")
        print(f"Federated corrected columns: {len(federated_df.columns)}")
        print(f"Columns only corrected by central: {len(central_df.columns.difference(federated_df.columns))}")
        print(f"Columns only corrected by federated: {len(federated_df.columns.difference(central_df.columns))}")
        failed = True
    if not central_df.index.equals(federated_df.index):
        print(f"Rows do not match for central_df and federated_df")
        union_rows = central_df.index.union(federated_df.index)
        intercept_rows = central_df.index.intersection(federated_df.index)
        print(f"Union-Intercept of rows: {union_rows.difference(intercept_rows)}")
        print(f"Central corrected rows: {len(central_df.index)}")
        print(f"Federated corrected rows: {len(federated_df.index)}")
        print(f"Rows only corrected by central: {len(central_df.index.difference(federated_df.index))}")
        print(f"Rows only corrected by federated: {len(federated_df.index.difference(central_df.index))}")
        failed = True
    if failed:
        print(f"_________________________FAILED: {name}_________________________")

    df_diff = (central_df - federated_df).abs()
    print(f"Max difference: {df_diff.max().max()}")
    print(f"Mean difference: {df_diff.mean().mean()}")
    print(f"Max diff at position: {df_diff.idxmax().idxmax()}")

    df_diff_intersect = df_diff.loc[intersection_features]
    print(f"Max difference in intersect: {df_diff_intersect.max().max()}")
    print(f"Mean difference in intersect: {df_diff_intersect.mean().mean()}")
    print(f"Max diff at position in intersect: {df_diff_intersect.idxmax().idxmax()}")

def _concat_federated_results(clientWrappers: List[ClientWrapper],
                              samples_in_columns=True) -> Tuple[pd.DataFrame, List[str]]:
    """
    Joins the corrected data of all clients into a single dataframe and
    determines the features shared by every client.
    A feature counts for a client if its row (after bringing the data into
    the samples-in-columns layout) contains no NaN value.
    Args:
        clientWrappers: List of ClientWrapper instances, each holding its
            result in the data_corrected attribute
        samples_in_columns: True if samples are the columns of the client
            dataframes (the case for expression files), False if samples are
            the rows. Controls along which axis the data is joined.
    Returns:
        merged_df: The data of all clients joined sample wise, in the same
            orientation as the input
        intersection_features: The features found (NaN free) in all clients
    Raises:
        ValueError: If a wrapper has no corrected data or the list is empty
    """
    combined = None
    shared_features = set()
    for wrapper in clientWrappers:
        # a missing attribute and an explicit None are treated alike
        client_frame = getattr(wrapper, "data_corrected", None)
        if client_frame is None:
            raise ValueError("No data was found in the clientWrappers")

        # work in the samples-in-columns layout internally
        if not samples_in_columns:
            client_frame = client_frame.T
        complete_features = set(client_frame.dropna().index)

        if combined is None:
            # first client: nothing to join with yet
            combined, shared_features = client_frame, complete_features
        else:
            combined = pd.concat([combined, client_frame], axis=1)
            shared_features = shared_features & complete_features

    # nothing was collected at all
    if combined is None:
        raise ValueError("No data was found in the clientWrappers")

    # restore the caller's orientation
    if not samples_in_columns:
        combined = combined.T
    return combined, list(shared_features)


## DATA

In [3]:
### This part defines the data used. A ClientWrapper class is used to      ###
### describe all cohorts.                                                  ###
### Uncomment the wanted data block (and comment out the others) or add a  ###
### new block for new data. No other changes are necessary.                ###

In [4]:
# #################### Microarray Data ####################
# # First define the basefolder where all files are located
# base_dir = os.path.join("..")
# # Go back to the git repos root dir
# datafolder = "microarray"
# base_dir = os.path.join(base_dir, "evaluation_data", "microarray", "before")

# # List of cohort names
# cohort_names = [
#     'GSE38666',  # Client 1 (Coordinator)
#     'GSE14407',  # Client 2
#     'GSE6008',   # Client 3
#     'GSE40595',  # Client 4
#     'GSE26712',  # Client 5
#     'GSE69428',  # Client 6
# ]
# output_name = "microarray"
# central_filename = "central_corrected_UNION.tsv"

In [5]:
#################### Proteomics Data ####################

# Name of the dataset folder below "evaluation_data" and name used for output
datafolder = "proteomics"
output_name = "proteomics"
# File holding the centrally corrected data used for the later comparison
central_filename = "intensities_log_Rcorrected_UNION.tsv"

# Folder containing one input subfolder per cohort, reached by going one
# level up first
base_dir = os.path.join("..", "evaluation_data", datafolder, "before")

# List of cohort names, the first one acts as the coordinator
cohort_names = [
    'lab_A',  # Client 1 (Coordinator)
    'lab_B',  # Client 2
    'lab_C',  # Client 3
    'lab_D',  # Client 4
    'lab_E',  # Client 5
]

In [6]:
# #################### Proteomics Multibatch Data ####################

# # First define the basefolder where all files are located
# base_dir = os.path.join("..")
# # Go back to the git repos root dir
# datafolder = "proteomics_multibatch"
# base_dir = os.path.join(base_dir, "evaluation_data", datafolder, "before")

# # List of cohort names
# cohort_names = [
#     'center1', # Client 1 (Coordinator)
#     'center2', # Client 2
#     'center3' # Client 3
# ]
# output_name = "proteomics_multibatch"
# central_filename = "intensities_log_Rcorrected_UNION.tsv"

In [7]:
# ################# Microbiome Data ####################
# # First define the basefolder where all files are located
# base_dir = os.path.join("..")
# # Go back to the git repos root dir
# datafolder = "microbiome"
# base_dir = os.path.join(base_dir, "evaluation_data", datafolder, "before")

# # List of cohort names
# cohort_names = [
#     'PRJEB27928',  # Client 1 (Coordinator)
#     'PRJEB6070',   # Client 2
#     'PRJNA429097', # Client 3
#     'PRJEB10878',  # Client 4
#     'PRJNA731589', # Client 5
# ]
# output_name = "microbiome"
# central_filename = "normalized_logmin_counts_5centers_Rcorrected.tsv"

In [8]:
# ################### Simulation Data ###################

# # First define the basefolder where all files are located
# base_dir = os.path.join("..")
# # Go back to the git repos root dir

# # There are three different setups, activate the wanted one here accordingly, deactivate the others:
# # BALANCED
# datafolder = os.path.join("simulated", "balanced")
# base_dir = os.path.join(base_dir, "evaluation_data", "simulated", "balanced", "before")
# output_name = "simulated_balanced"
# # # MILD IMBALANCED
# # datafolder = os.path.join("simulated", "mild_imbalanced")
# # base_dir = os.path.join(base_dir, "evaluation_data", "simulated", "mild_imbalanced", "before")
# # output_name = "simulated_mild_imbalanced"
# # # STRONG IMBALANCED
# # datafolder = os.path.join("simulated", "strong_imbalanced")
# # base_dir = os.path.join(base_dir, "evaluation_data", "simulated", "strong_imbalanced", "before")
# # output_name = "simulated_strong_imbalanced"

# # List of cohort names
# cohort_names = [
#     'lab1',  # Client 1 (Coordinator)
#     'lab2',  # Client 2
#     'lab3',  # Client 3
# ]
# central_filename = "intensities_R_corrected.tsv" # is the same for all simulated setups

In [9]:

# Build one ClientWrapper per cohort; the cohort listed first becomes the
# coordinator and every cohort reads from its own subfolder of base_dir
clientWrappers: List[ClientWrapper] = [
    ClientWrapper(
        id=name,
        input_folder=os.path.join(base_dir, name),
        coordinator=(position == 0),
    )
    for position, name in enumerate(cohort_names)
]

# Double check that we only have one coordinator
_check_consistency_clientwrappers(clientWrappers)


## Analysis

In [10]:
# Accumulated runtime in seconds, keyed by client id (plus "Coordinator")
time_tracker: Dict[str, float] = dict()

In [11]:
###                                  INFO                                  ###
### The following code blocks run the simulation. They are divided into    ###
### multiple logical blocks to ease the use                                ###

In [12]:
### SIMULATION: all: initial ###
### Initial reading of the input folder

# Feature that is traced through the run with debug prints. The captured
# output of the compute_beta cell shows the hash resolving to "P39360".
DEBUG_FEATURE = "P39360"
DEBUG_FEATURE_HASH = "3ea5bc157a12312b21d02ab33f31b3a8bb666f6bdce412a922defef32d7fb12e"

send_batch_labels_covariates = list()
send_feature_information = list()
for clientWrapper in clientWrappers:
    # define the client class
    cohort_name = clientWrapper.id
    client = Client()
    client.config_based_init(clientname = cohort_name,
                             input_folder = clientWrapper.input_folder,
                             use_hashing = True)
    # DEBUG: does this client know the traced feature?
    print(f"client has {DEBUG_FEATURE}?: {DEBUG_FEATURE in client.feature2hash}")
    print(f"client has {DEBUG_FEATURE_HASH}?: {DEBUG_FEATURE_HASH in client.hash2feature}")
    assert isinstance(client.hash2feature, dict)
    assert isinstance(client.hash2variable, dict)
    clientWrapper.client_class = client
    # send the batch labels and covariates
    send_batch_labels_covariates.append((client.batch_labels, client.hash2variable.keys()))

# receive the batch labels and covariates
# intersect covariates, union labels
for clientWrapper in clientWrappers:
    if clientWrapper.is_coordinator:
        # intersect the variables of ALL clients. An empty set must not be
        # used as "not initialised yet" marker: an empty intermediate result
        # (or a first client without covariates) has to stay empty instead of
        # being replaced by the variables of the next client.
        variable_sets = [set(variables) for _, variables in send_batch_labels_covariates]
        global_variables_hashed = set.intersection(*variable_sets) if variable_sets else set()
        # collect the batch labels of all clients
        global_batch_labels = list()
        for labels, _ in send_batch_labels_covariates:
            global_batch_labels.extend(labels)
        # ensure the batch_labels are unique
        if len(global_batch_labels) != len(set(global_batch_labels)):
            raise ValueError("Batch labels are not unique")
        num_batches = len(global_batch_labels)

# get the relevant and privacy preserving features
for clientWrapper in clientWrappers:
    cohort_name = clientWrapper.id
    client = clientWrapper.client_class
    # use the stricter of the client's own setting and
    # (number of batches + number of shared covariates + 1)
    min_samples = max(num_batches+len(global_variables_hashed)+1, client.min_samples)
    batch_feature_presence_info: Dict[str, List[str]] = client.get_batch_feature_presence_info(min_samples=min_samples)
    # DEBUG: per batch, is the traced feature reported as present?
    # NOTE(review): the message says 624 while the compute_beta cell refers
    # to this hash as feature 654 - confirm which number is intended.
    print(f"Client has feature 624?: {[DEBUG_FEATURE_HASH in batch_features for batch_features in batch_feature_presence_info.values()]}")
    send_feature_information.append((cohort_name,
                                client.position,
                                client.reference_batch,
                                batch_feature_presence_info))


Got the following config:
{'flimmaBatchCorrection': {'batch_col': None, 'covariates': ['Pyr'], 'data_filename': 'intensities_log_UNION.tsv', 'design_filename': 'design.tsv', 'design_separator': '\t', 'expression_file_flag': True, 'index_col': 'rowname', 'min_samples': 2, 'normalizationMethod': None, 'position': 0, 'separator': '\t', 'smpc': True}}
Opening dataset ../evaluation_data/proteomics/before/lab_A/intensities_log_UNION.tsv
Shape of rawdata(expr_file): (2549, 24)
I got feature P39360?: False
finished loading data, shape of data: (2549, 24), num_features: 2549, num_samples: 24
client has P39360?: False
client has 3ea5bc157a12312b21d02ab33f31b3a8bb666f6bdce412a922defef32d7fb12e?: False
Got the following config:
{'flimmaBatchCorrection': {'batch_col': None, 'covariates': ['Pyr'], 'data_filename': 'intensities_log_UNION.tsv', 'design_filename': 'design.tsv', 'design_separator': '\t', 'expression_file_flag': True, 'index_col': 'rowname', 'min_samples': 2, 'normalizationMethod': None,

In [13]:
### SIMULATION: Coordinator: global_feature_selection ###
### Aggregate the features and variables

# The coordinator combines the feature information sent by every client into
# the global feature list, the feature presence matrix and the cohort order.
# Feature list and cohort order are broadcast; the presence matrix is kept
# for later use.

broadcast_features_variables = tuple()
for clientWrapper in clientWrappers:
    # only the coordinator aggregates
    if not clientWrapper.is_coordinator:
        continue

    # store the start timestamp, replaced by the elapsed time below
    time_tracker["Coordinator"] = time.time()

    global_feature_names, feature_presence_matrix, cohorts_order = select_common_features_variables(
        feature_batch_info=send_feature_information,
        min_clients=3,
        default_order=cohort_names,
    )
    broadcast_features_variables = (global_feature_names, cohorts_order)

    end_time = time.time()
    time_tracker["Coordinator"] = end_time - time_tracker["Coordinator"]


INFO: Found 2702 features present in at least 3 clients
INFO: Total number of features shared: 3047
INFO: Using given specific client order: ['lab_A', 'lab_B', 'lab_C', 'lab_D', 'lab_E']
INFO: Cohorts order: ['lab_A', 'lab_B', 'lab_C', 'lab_D', 'lab_E']


In [14]:
### SIMULATION: All: validate ###
### Expand data to fulfill the global format. Also performs consistency checks
for clientWrapper in clientWrappers:

    # store the start timestamp, replaced by the elapsed time below
    time_tracker[clientWrapper.id] = time.time()

    client = clientWrapper.client_class
    # receive the broadcast of the coordinator
    global_feauture_names_hashed, cohorts_order = broadcast_features_variables

    client.validate_inputs(global_variables_hashed)
    client.set_data(global_feauture_names_hashed)

    # create_design reports problems via its return value
    err = client.create_design(cohorts_order)
    if err:
        raise ValueError(err)

    end_time = time.time()
    time_tracker[clientWrapper.id] = end_time - time_tracker[clientWrapper.id]

INFO: Client lab_A: Inputs validated.
Number of features available in this client: 2549
Number of features given globally: 2702
Number of features only available on this client: 54
Number of features available in other clients but not this client: 207
Adding 207 extra global features
INFO: Client lab_A has only one batch
INFO: Client lab_A is not the reference batch
INFO: Client lab_B: Inputs validated.
Number of features available in this client: 2846
Number of features given globally: 2702
Number of features only available on this client: 181
Number of features available in other clients but not this client: 37
Adding 37 extra global features
INFO: Client lab_B has only one batch
INFO: Client lab_B is not the reference batch
INFO: Client lab_C: Inputs validated.
Number of features available in this client: 2820
Number of features given globally: 2702
Number of features only available on this client: 142
Number of features available in other clients but not this client: 24
Adding 24 e

In [15]:
### SIMULATION: Coordinator: create design mask based on feature presence matrix ###
### The mask is built from the feature presence matrix and is used later
### for the beta computation
for clientWrapper in clientWrappers:
    # only the coordinator creates the mask
    if not clientWrapper.is_coordinator:
        continue

    start_time = time.time()
    client = clientWrapper.client_class

    # number of features and number of design matrix columns
    n = len(client.feature_names)
    k = client.design.shape[1]

    # kept in global_mask for the compute_beta step
    global_mask = create_beta_mask(feature_presence_matrix, n, k)
    print(f"Shape of global mask: {global_mask.shape}")
    print(f"Head of mask: {global_mask[:5]}")

    end_time = time.time()
    time_tracker["Coordinator"] += end_time - start_time


Shape of global mask: (2702, 6)
Head of mask: [[False False False False False False]
 [False False False False False False]
 [False False False False False False]
 [False False False False False False]
 [False False False False False False]]


In [16]:
### SIMULATION: All: prepare for compute_XtX_XtY ###
### Restrict every client's data to the global features and design samples

for clientWrapper in clientWrappers:
    start_time = time.time()
    client = clientWrapper.client_class

    # the index of the design matrix defines the samples
    client.sample_names = client.design.index.values

    # consistency check between design file and data, implemented in the
    # client class
    client._check_consistency_designfile()

    # keep only the relevant (global) features and the design samples
    client.data = client.data.loc[client.feature_names, client.sample_names]
    client.n_samples = len(client.sample_names)

    end_time = time.time()
    time_tracker[clientWrapper.id] += end_time - start_time

In [17]:
### SIMULATION: All: compute_XtX_XtY ###
### Every client computes XtX and XtY and shares both with the coordinator
send_XtX_XtY_list: List[List[np.ndarray]] = []
for clientWrapper in clientWrappers:
    start_time = time.time()
    client = clientWrapper.client_class

    # an error is reported as a non-empty string in the third return value
    XtX, XtY, err = client.compute_XtX_XtY()
    if err != "":
        raise ValueError(err)

    # "send" the pair by collecting it for the coordinator
    send_XtX_XtY_list.append([XtX, XtY])

    end_time = time.time()
    time_tracker[clientWrapper.id] += end_time - start_time

In [18]:
### SIMULATION: Coordinator: compute_beta
### Compute the beta values and broadcast them to the others
broadcast_betas = None # np.ndarray of shape num_features x design_columns

# Feature traced by the debug output below (hash and position in the global
# feature list as observed for the proteomics data)
DEBUG_FEATURE_HASH = "3ea5bc157a12312b21d02ab33f31b3a8bb666f6bdce412a922defef32d7fb12e"
DEBUG_FEATURE_INDEX = 654

# TODO: RMV (debug output): which clients can resolve the traced hash?
for clientWrapper in clientWrappers:
    print(f"{clientWrapper.client_class.hash2feature.get(DEBUG_FEATURE_HASH, None)}")

for clientWrapper in clientWrappers:
    if clientWrapper.is_coordinator:
        client = clientWrapper.client_class

        # TODO: RMV (debug output)
        # Guarded so that datasets with fewer features do not fail with an
        # IndexError; done before the timer starts to keep the debug output
        # out of the measured time.
        if DEBUG_FEATURE_INDEX < min(len(client.feature_names), feature_presence_matrix.shape[0]):
            debug_feature = client.feature_names[DEBUG_FEATURE_INDEX]
            print(f"{debug_feature in client.hash2feature}")
            print(f"FeaturePresenceMatrix of feature {DEBUG_FEATURE_INDEX} ({debug_feature}):")
            print(f"{feature_presence_matrix[DEBUG_FEATURE_INDEX, :]}")

        start_time = time.time()

        # aggregate the XtX/XtY pairs of all clients into the betas
        beta = compute_beta(XtX_XtY_list=send_XtX_XtY_list,
                            n=len(client.feature_names),
                            k=client.design.shape[1],
                            global_mask=global_mask)

        # send beta to clients so they can correct their data
        broadcast_betas = beta

        end_time = time.time()
        time_tracker["Coordinator"] += end_time - start_time

None
P39360
P39360
P39360
None
False
FeaturePresenceMatrix of feature 654 (3ea5bc157a12312b21d02ab33f31b3a8bb666f6bdce412a922defef32d7fb12e):
[0 1 1 1 0]
INFO: Error at feature 654
Mask: [False False  True False False  True]
submatrix: [[35. 35. 12. 11.]
 [35. 35. 12. 11.]
 [12. 12. 12.  0.]
 [11. 11.  0. 11.]]
full XTX: [[35. 35.  0. 12. 11. 12.]
 [35. 35.  0. 12. 11. 12.]
 [ 0.  0.  0.  0.  0.  0.]
 [12. 12.  0. 12.  0.  0.]
 [11. 11.  0.  0. 11.  0.]
 [12. 12.  0.  0.  0. 12.]]


In [19]:
### SIMULATION: All: include_correction
### Corrects the individual data

# The clients must work with the betas broadcast by the coordinator, not with
# a coordinator-local variable that may be missing or stale from an earlier
# run of the notebook.
if broadcast_betas is None:
    raise ValueError("No betas were broadcast, please run the compute_beta "
                     "step before correcting the data")

for clientWrapper in clientWrappers:

    start_time = time.time()

    client = clientWrapper.client_class

    # remove the batch effects in own data and save the results
    client.remove_batch_effects(broadcast_betas)
    print(f"DEBUG: Shape of corrected data: {client.data_corrected.shape}")

    end_time = time.time()
    time_tracker[clientWrapper.id] += end_time - start_time

    # As this is a simulation we don't save the corrected data to csv, instead
    # we save it as a variable to the clientwrapper
    clientWrapper.data_corrected = client.data_corrected
    # For reference, file output that is skipped in the simulation:
    # client.data_corrected.to_csv(os.path.join(os.getcwd(), "mnt", "output", "only_batch_corrected_data.csv"),
    #                                 sep=self.load("separator"))
    # client.data_corrected_and_raw.to_csv(os.path.join(os.getcwd(), "mnt", "output", "all_data.csv"),
    #                              sep=self.load("separator"))
    # with open(os.path.join(os.getcwd(), "mnt", "output", "report.txt"), "w") as f:
    #     f.write(client.report)


INFO: Removing the batch effects with the trained betas
INFO: Shape of corrected data after correction: (2495, 24)
DEBUG: Shape of corrected data: (2495, 24)
INFO: Removing the batch effects with the trained betas
INFO: Shape of corrected data after correction: (2665, 23)
DEBUG: Shape of corrected data: (2665, 23)
INFO: Removing the batch effects with the trained betas
INFO: Shape of corrected data after correction: (2678, 23)
DEBUG: Shape of corrected data: (2678, 23)
INFO: Removing the batch effects with the trained betas
INFO: Shape of corrected data after correction: (2686, 24)
DEBUG: Shape of corrected data: (2686, 24)
INFO: Removing the batch effects with the trained betas
INFO: Shape of corrected data after correction: (2376, 24)
DEBUG: Shape of corrected data: (2376, 24)


In [20]:
# Report the accumulated runtimes, converted from seconds to milliseconds
# coordinator first
coordinator_ms = round(time_tracker['Coordinator']*1000, 2)
print(f"Time tracker for coordinator, ms: {coordinator_ms}")

# then every client in the order of clientWrappers
for clientWrapper in clientWrappers:
    client_ms = round(time_tracker[clientWrapper.id]*1000, 2)
    print(f"Time tracker for {clientWrapper.id}, ms: {client_ms}")

Time tracker for coordinator, ms: 73.34
Time tracker for lab_A, ms: 135.22
Time tracker for lab_B, ms: 82.88
Time tracker for lab_C, ms: 148.13
Time tracker for lab_D, ms: 143.03
Time tracker for lab_E, ms: 92.08


In [21]:
###                                  INFO                                  ###
###                            SIMULATION IS DONE                          ###
### The simulation is done. The corrected data is saved in the             ###
### clientWrapper instances. Now we analyse the data by comparing to the   ###
### calculated centralized corrected data.                                 ###


In [22]:
# Join the corrected data of all clients (samples are the columns) and get
# the features shared by all of them
federated_df, intersect_features = _concat_federated_results(
    clientWrappers,
    samples_in_columns=True,
)

In [23]:
### SAVE THE RESULTS ###
#federated_df.to_csv(os.path.join("..", "evaluation_data", datafolder, "after", "FedSim_corrected_data.tsv"), sep="\t")

In [24]:
### Read in the centralized data and compare it to the federated data ###
# NOTE: base_dir is rebound here to the evaluation data root
base_dir = os.path.join("..", "evaluation_data")
central_df_path = os.path.join(base_dir, datafolder, "after", central_filename)
# tab separated file, the first column holds the row labels
central_df = pd.read_csv(
    central_df_path,
    sep="\t",
    index_col=0,
)
_compare_central_federated_dfs(
    output_name,
    central_df,
    federated_df,
    intersect_features,
)

_________________________Analysing: proteomics_________________________
Rows do not match for central_df and federated_df
Union-Intercept of rows: Index(['A0A385XJE6;P0CE49;P0CE50;P0CE51;P0CE52;P0CE53;P0CE54;P0CE55;P0CE56;P0CE57;P0CE58',
       'A5A621', 'P00722', 'P02931;P02932;P21420', 'P02932', 'P03007',
       'P03030', 'P05050', 'P05052', 'P05100',
       ...
       'Q46811', 'Q46812', 'Q46814', 'Q46865', 'Q46938', 'Q47156', 'Q47157',
       'Q47537', 'Q47702', 'Q6BEX0'],
      dtype='object', name='rowname', length=430)
Central corrected rows: 2272
Federated corrected rows: 2702
Rows only corrected by central: 0
Rows only corrected by federated: 430
_________________________FAILED: proteomics_________________________
Max difference: 1.1368683772161603e-13
Mean difference: 3.327937958514573e-14
Max diff at position: Clinspect_E_coli_A_S5_Slot1-7_1_8641
Max difference in intersect: 1.1368683772161603e-13
Mean difference in intersect: 3.338782576384476e-14
Max diff at position in in