In [23]:
### Imports ###
import os
from typing import List, Tuple
from classes.client import Client
from classes.coordinator_utils import select_common_features_variables, \
    compute_beta, reorder_matrix, create_beta_mask

import numpy as np
import pandas as pd

In [24]:
### Helper Functions ###
### Just run this, these functions are needed in various places ###

### Define the client class ###
class ClientWrapper:
    """
    Holds all information necessary for the simulation run to work.
    Defines the input data of the client, its name and whether it should be
    considered as the coordinator.

    Attributes:
        id: Name of the client (cohort).
        input_folder: Folder the input files of the client are read from.
        is_coordinator: True if this client also acts as the coordinator.
        client_class: Slot for the client object of this cohort, None until
            one is assigned from outside.
        data_corrected: Slot for the corrected data of this client, None until
            a result is assigned from outside.
    """
    def __init__(self, id: str, input_folder: str, coordinator: bool = False):
        # `id` shadows the builtin, but callers pass it by keyword, so the
        # parameter name has to stay as it is.
        self.id = id
        self.input_folder = input_folder
        self.is_coordinator = coordinator
        self.client_class = None
        # Declared here so that reading the attribute before a result was
        # stored yields None instead of raising an AttributeError.
        self.data_corrected = None

def _check_consistency_clientwrappers(clientWrappers: List[ClientWrapper]) -> None:
    """
    Validates the coordinator setup of a list of clients.
    Exactly one of the given clients has to be flagged as coordinator.
    Args:
        clientWrappers: The ClientWrapper instances to check
    Raises:
        ValueError: If more than one client is a coordinator or if no client
            is a coordinator
    """
    coordinators_seen = 0
    for wrapper in clientWrappers:
        if not wrapper.is_coordinator:
            continue
        coordinators_seen += 1
        # fail as soon as the second coordinator shows up
        if coordinators_seen > 1:
            raise ValueError("More than one coordinator was defined, please check "
                             "the code defining the clients")
    if coordinators_seen == 0:
        raise ValueError("No client instance is a coordinator, please designate "
                         "any client as a coordinator")
    
def _compare_central_federated_dfs(name:str,
                                   central_df: pd.DataFrame,
                                   federated_df: pd.DataFrame,
                                   intersection_features: List[str]) -> None:
    """
    Compares two dataframes for equality. First checks that index and columns
    are the same, then analyses the element wise differences.
    See the analyse_diff_df function for more details on the difference analysis.
    If both dataframes contain a NaN value at the same position, this is considered
    as equal (0 as difference).
    Args:
        name: Name used for printing
        central_df: The central dataframe
        federated_df: The federated dataframe
        intersection_features: The features that are common to both dataframes
    """
    central_df = central_df.sort_index(axis=0).sort_index(axis=1)
    federated_df = federated_df.sort_index(axis=0).sort_index(axis=1)
    print(f"_________________________Analysing: {name}_________________________")
    ### compare columns and index ###
    failed = False
    if not central_df.columns.equals(federated_df.columns):
        print(f"Columns do not match for central_df and federated_df")
        union_cols = central_df.columns.union(federated_df.columns)
        intercept_cols = central_df.columns.intersection(federated_df.columns)
        print(f"Union-Intercept of columns: {union_cols.difference(intercept_cols)}")
        failed = True
    if not central_df.index.equals(federated_df.index):
        print(f"Rows do not match for central_df and federated_df")
        union_rows = central_df.index.union(federated_df.index)
        intercept_rows = central_df.index.intersection(federated_df.index)
        print(f"Union-Intercept of rows: {union_rows.difference(intercept_rows)}")
        failed = True
    if failed:
        print(f"_________________________FAILED: {name}_________________________")

    df_diff = (central_df - federated_df).abs()
    print(f"Max difference: {df_diff.max().max()}")
    print(f"Mean difference: {df_diff.mean().mean()}")
    print(f"Max diff at position: {df_diff.idxmax().idxmax()}")

    df_diff_intersect = df_diff.loc[intersection_features]
    print(f"Max difference in intersect: {df_diff_intersect.max().max()}")
    print(f"Mean difference in intersect: {df_diff_intersect.mean().mean()}")
    print(f"Max diff at position in intersect: {df_diff_intersect.idxmax().idxmax()}")

def _concat_federated_results(clientWrappers: List[ClientWrapper],
                              samples_in_columns: bool = True) -> Tuple[pd.DataFrame, List[str]]:
    """
    Concatenates the results of the federated clients into one dataframe.
    Also determines which features are common to all clients and returns them.
    A feature only counts for a client if it has no NaN value in any sample
    of that client.
    Args:
        clientWrappers: List of ClientWrapper instances, containing the data
            in the data_corrected attribute
        samples_in_columns: If True, the samples are in the columns, if False
            they are in the rows. For expression files this is true.
            This decides how to aggregate the dataframes
    Returns:
        merged_df: The merged dataframe containing the data of all clients,
            in the same orientation as the input data
        intersection_features: The features that are common to all clients
            (in no particular order)
    Raises:
        ValueError: If no clients are given or if a client holds no
            corrected data
    """
    client_frames = []
    intersection_features = None
    for clientWrapper in clientWrappers:
        # a missing attribute and an attribute set to None both mean "no data"
        corrected_data = getattr(clientWrapper, "data_corrected", None)
        if corrected_data is None:
            raise ValueError("No data was found in the clientWrappers")
        # work with features in rows and samples in columns
        if not samples_in_columns:
            corrected_data = corrected_data.T

        # features without any NaN value in this client
        complete_features = set(corrected_data.dropna().index)
        if intersection_features is None:
            intersection_features = complete_features
        else:
            intersection_features &= complete_features
        client_frames.append(corrected_data)

    # final check
    if not client_frames:
        raise ValueError("No data was found in the clientWrappers")

    # One concat for all clients: concatenating inside the loop would copy the
    # already merged data again for every additional client.
    merged_df = pd.concat(client_frames, axis=1)
    # reverse the Transpose if necessary
    if not samples_in_columns:
        merged_df = merged_df.T
    return merged_df, list(intersection_features)


In [25]:
### Data definition: every cohort is described by one ClientWrapper.       ###
### Change this cell if other data should be tested.                       ###

# Basefolder where all files are located (location of the microarray data).
# The leading ".." goes back to the git repos root dir.
base_dir = os.path.join("..", "evaluation_data", "microarray", "before")

# One client per cohort, each reads from the folder named like the cohort.
# The first cohort of the list is the coordinator.
microarray_cohorts = ['GSE38666', 'GSE14407', 'GSE6008',
                      'GSE40595', 'GSE26712', 'GSE69428']

# we use a helper class for each client, see the helper function code block
clientWrappers: List[ClientWrapper] = []
for position, cohortname in enumerate(microarray_cohorts):
    clientWrappers.append(ClientWrapper(id=cohortname,
                                        input_folder=os.path.join(base_dir, cohortname),
                                        coordinator=(position == 0)))

# Double check that we only have one coordinator
_check_consistency_clientwrappers(clientWrappers)

In [26]:
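# ### ALTERNATIVE DATA (proteomics): this cell is disabled, all of its code  ###
# ### is commented out. To simulate the proteomics data, uncomment it and    ###
# ### disable the microarray client definitions instead.                     ###
# ### This part defines the data used. A ClientWrapper class is used to      ###
# ### describe all cohorts. If other data should be tested, this part should ###
# ### be changed                                                             ###
# ### Define the different clients ###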
# clientWrappers: List[ClientWrapper] = list()
# # we use a helper class for each client, see the helper function
# # code block or the later definitions here for more info
# # First define the basefolder where all files are located
# base_dir = os.path.join("..")

# # go back to the git repos root dir
# base_dir = os.path.join(base_dir, "evaluation_data", "proteomics", "before", "balanced")

# # location of the proteomic data
# # Client 1
# cohortname = 'lab_A'
# clientWrappers.append(ClientWrapper(id=cohortname,
#                                     input_folder=os.path.join(base_dir, cohortname),
#                                     coordinator=True))
# # Client 2
# cohortname = 'lab_B'
# clientWrappers.append(ClientWrapper(id=cohortname,
#                                     input_folder=os.path.join(base_dir, cohortname)))
# # Client 3
# cohortname = 'lab_C'
# clientWrappers.append(ClientWrapper(id=cohortname,
#                                     input_folder=os.path.join(base_dir, cohortname)))
# # Client 4
# cohortname = 'lab_D'
# clientWrappers.append(ClientWrapper(id=cohortname,
#                                     input_folder=os.path.join(base_dir, cohortname)))
# # Client 5
# cohortname = 'lab_E'
# clientWrappers.append(ClientWrapper(id=cohortname,
#                                     input_folder=os.path.join(base_dir, cohortname)))

# # Double check that we only have one coordinator
# _check_consistency_clientwrappers(clientWrappers)

In [27]:
###                                  INFO                                  ###
### The following code blocks run the simulation. They are divided into    ###
### multiple logical blocks to ease the use                                ###

In [28]:
### SIMULATION: all: initial ###
### Initial reading of the input folder; every client then reports its
### features and variables

send_features_variables = []
for clientWrapper in clientWrappers:
    cohort_name = clientWrapper.id

    # define the client class and keep it in the wrapper for the later steps
    client = Client()
    client.config_based_init(clientname=cohort_name,
                             input_folder=clientWrapper.input_folder,
                             use_hashing=False)
    clientWrapper.client_class = client

    # what the client sends: its name (for mask creation - to track the
    # cohort) and the keys of its hash2feature and hash2variable mappings
    feature_keys = list(client.hash2feature.keys())
    variable_keys = list(client.hash2variable.keys())
    send_features_variables.append((cohort_name, feature_keys, variable_keys))

Got the following config:
{'flimmaBatchCorrection': {'min_samples': 2, 'data_filename': 'expr_for_correction_UNION.tsv', 'separator': '\t', 'covariates': ['HGSC'], 'design_filename': 'design.tsv', 'index_col': 'Gene', 'expression_file_flag': True}}
Opening dataset ../evaluation_data/microarray/before/GSE38666/expr_for_correction_UNION.tsv
Shape of rawdata(expr_file): (51276, 30)
finished loading data, shape of data: (51276, 30), num_features: 51276, num_samples: 30
Got the following config:
{'flimmaBatchCorrection': {'min_samples': 2, 'data_filename': 'expr_for_correction_UNION.tsv', 'separator': '\t', 'covariates': ['HGSC'], 'design_filename': 'design.tsv', 'index_col': 'Gene', 'expression_file_flag': True}}
Opening dataset ../evaluation_data/microarray/before/GSE14407/expr_for_correction_UNION.tsv
Shape of rawdata(expr_file): (51276, 24)
finished loading data, shape of data: (51276, 24), num_features: 51276, num_samples: 24
Got the following config:
{'flimmaBatchCorrection': {'min_sa

In [29]:
### SIMULATION: Coordinator: global_feature_selection ###
### Aggregate the features and variables

# The coordinator combines the lists sent by all clients into the global
# feature names and global variables. Only these two are broadcast.
# feature_presence_matrix and cohorts_order stay behind as notebook variables
# for the later coordinator cells.

broadcast_features_variables = ()
for clientWrapper in clientWrappers:
    if not clientWrapper.is_coordinator:
        continue
    selection = select_common_features_variables(
        lists_of_features_and_variables=send_features_variables,
        # minimum number of clients that need to have the feature
        min_clients=1,
    )
    global_feature_names, global_variables, feature_presence_matrix, cohorts_order = selection
    broadcast_features_variables = (global_feature_names, global_variables)
        

In [30]:
### SIMULATION: Coordinator: feature presence matrix ###
### Compute the feature presence matrix that will be used for the mask creation

# NOTE(review): the matrix is replaced by its reordered version under the same
# name, so running this cell twice passes an already reordered matrix to
# reorder_matrix again - confirm that this is harmless.
for clientWrapper in clientWrappers:
    if not clientWrapper.is_coordinator:
        continue
    # names of all clients, in the order in which the clients were defined
    all_client_names = [wrapper.id for wrapper in clientWrappers]
    feature_presence_matrix = reorder_matrix(feature_presence_matrix,
                                             all_client_names,
                                             cohorts_order)


In [31]:
### SIMULATION: All: validate ###
### Expand data to fulfill the global format. Also performs consistency checks

for clientWrapper in clientWrappers:
    client = clientWrapper.client_class

    # every client works with the broadcast global features and variables
    global_feauture_names_hashed, global_variables_hashed = broadcast_features_variables
    client.validate_inputs(global_variables_hashed)
    client.set_data(global_feauture_names_hashed)

    # The design matrix is generated from the client names. All names except
    # the last one are handed over.
    # NOTE(review): presumably the last client serves as the reference cohort
    # without its own column - confirm in Client.create_design
    all_client_names = [wrapper.id for wrapper in clientWrappers]
    design_cohorts = all_client_names[:-1]
    err = client.create_design(design_cohorts)
    if err:
        raise ValueError(err)

Client GSE38666: Inputs validated.
feature names: 51276
global features: 51276
Extra local features: 0
Extra global features: 0
Adding 0 extra global features
Got 51276 global features and 51276 features in the data matrix
Before reindexing got this data: (51276, 30)
After reindexing got this data: (51276, 30)
design was finally created:            intercept  HGSC  GSE38666  GSE14407  GSE6008  GSE40595  GSE26712
file                                                                       
GSM947277        1.0     0         1         0        0         0         0
GSM947278        1.0     0         1         0        0         0         0
GSM947279        1.0     0         1         0        0         0         0
GSM947280        1.0     0         1         0        0         0         0
GSM947281        1.0     0         1         0        0         0         0
GSM947282        1.0     0         1         0        0         0         0
GSM947283        1.0     0         1         0      

In [32]:
### SIMULATION: Coordinator: create design mask based on feature presence matrix ###
### Create the mask for the design matrix based on the feature presence matrix
### that will be used for the beta computation
for clientWrapper in clientWrappers:
    if not clientWrapper.is_coordinator:
        continue
    client = clientWrapper.client_class

    # n: number of features of the coordinator's client object
    # k: number of columns of its design matrix
    n, k = len(client.feature_names), client.design.shape[1]

    # global_mask stays available as a notebook variable for the beta computation
    global_mask = create_beta_mask(feature_presence_matrix, n, k)


In [33]:
### SIMULATION: All: prepare for compute_XtX_XtY ###
### Every client takes its sample names from its design matrix and reduces
### its data to its feature_names (rows) and these samples (columns).

for clientWrapper in clientWrappers:
    client = clientWrapper.client_class
    # the row labels of the design matrix define the samples of the client
    client.sample_names = client.design.index.values

    # Error check if the design index and the data index are the same
    # we check by comparing the sorted indexes
    # NOTE(review): this describes Client._check_consistency_designfile, which
    # is defined outside of this notebook - confirm against its implementation
    client._check_consistency_designfile()

    # Extract only relevant (the global) features and samples
    # (client.data is overwritten by the selection)
    client.data = client.data.loc[client.feature_names, client.sample_names]
    client.n_samples = len(client.sample_names)

In [34]:
### SIMULATION: All: compute_XtX_XtY ###
### Compute XtX and XtY and share it
# one [XtX, XtY] entry per client, in the order of clientWrappers
send_XtX_XtY_list: List[List[np.ndarray]] = list()
for clientWrapper in clientWrappers:
    client = clientWrapper.client_class

    # compute XtX and XtY; the third return value is None unless the client
    # reports an error
    XtX, XtY, err = client.compute_XtX_XtY()
    # identity check: `!= None` would invoke the (possibly overloaded)
    # comparison operator of the returned object
    if err is not None:
        raise ValueError(err)

    # send XtX and XtY
    send_XtX_XtY_list.append([XtX, XtY])

final vectors to be sent: XtX shape: (51276, 7, 7), XtY shape: (51276, 7)
final vectors to be sent: XtX shape: (51276, 7, 7), XtY shape: (51276, 7)
final vectors to be sent: XtX shape: (51276, 7, 7), XtY shape: (51276, 7)
final vectors to be sent: XtX shape: (51276, 7, 7), XtY shape: (51276, 7)
final vectors to be sent: XtX shape: (51276, 7, 7), XtY shape: (51276, 7)
final vectors to be sent: XtX shape: (51276, 7, 7), XtY shape: (51276, 7)


In [35]:
### SIMULATION: Coordinator: compute_beta
### Compute the beta values and broadcast them to the others

# np.ndarray of shape num_features x design_columns, stays None as long as
# no coordinator has computed the betas
broadcast_betas = None

for clientWrapper in clientWrappers:
    if not clientWrapper.is_coordinator:
        continue
    client = clientWrapper.client_class

    # the dimensions are taken from the coordinator's own client object
    num_features = len(client.feature_names)
    num_design_columns = client.design.shape[1]
    beta = compute_beta(XtX_XtY_list=send_XtX_XtY_list,
                        n=num_features,
                        k=num_design_columns,
                        global_mask=global_mask)

    # send beta to clients so they can correct their data
    broadcast_betas = beta

INFO: Shape of beta: (51276, 7)
INFO: Number of pseudo inverses: 0


In [36]:
### SIMULATION: All: include_correction
### Corrects the individual data

# The clients work with the broadcast betas and not with the loop-local `beta`
# variable of the coordinator cell: only the broadcast value is what a client
# would receive.
if broadcast_betas is None:
    raise ValueError("No beta values were broadcast, the coordinator has to "
                     "compute the betas before the data can be corrected")

for clientWrapper in clientWrappers:
    client = clientWrapper.client_class

    # remove the batch effects in own data and save the results
    client.remove_batch_effects(broadcast_betas)
    print(f"DEBUG: Shape of corrected data: {client.data_corrected.shape}")
    # As this is a simulation we don't save the corrected data to csv, instead
    # we save it as a variable to the clientwrapper
    clientWrapper.data_corrected = client.data_corrected
    # client.data_corrected.to_csv(os.path.join(os.getcwd(), "mnt", "output", "only_batch_corrected_data.csv"),
    #                                 sep=self.load("separator"))
    # client.data_corrected_and_raw.to_csv(os.path.join(os.getcwd(), "mnt", "output", "all_data.csv"),
    #                              sep=self.load("separator"))
    # with open(os.path.join(os.getcwd(), "mnt", "output", "report.txt"), "w") as f:
    #     f.write(client.report)

start remove_batch_effects
Shape of data: (51276, 30)
Shape of beta:  (51276, 7)
Beta_reduced contains 0 Nan values
Shape of corrected data after correction: (51276, 30)
index is Index(['AA001021', 'AA001052', 'AA001150', 'AA001203', 'AA001287', 'AA001364',
       'AA001375', 'AA001390', 'AA001400', 'AA001414',
       ...
       'Z97832', 'Z98200', 'Z98443', 'Z98745', 'Z98749', 'Z98751', 'Z98752',
       'Z98884', 'Z98950', 'Z99714'],
      dtype='object', name='Gene', length=51276)
Amount of index found in hash2feature: 51276/51276
After renaming got this data_corrected:           GSM947277  GSM947278  GSM947279  GSM947280  GSM947281  GSM947282  \
Gene                                                                         
AA001021   3.925542   4.470440   4.372540   3.962911   3.775610   4.983943   
AA001052   9.849833   8.828600   9.088667   8.774817   9.475751   9.435066   
AA001150   4.592693   4.428854   4.192204   4.764856   5.720164   3.780583   
AA001203   9.373274   9.732365 

In [37]:
###                                  INFO                                  ###
###                            SIMULATION IS DONE                          ###
### The simulation is done. The corrected data is saved in the             ###
### clientWrapper instances. Now we analyse the data by comparing to the   ###
### calculated centralized corrected data.                                 ###


In [39]:
### Concat the federated data and read in the centralized data ###
### (microarray data: the centralized result is expected in the "after"
### folder next to base_dir)
central_df_path = os.path.join(os.path.dirname(base_dir), "after", "central_corrected_UNION.tsv")
# The path only exists if the microarray clients were simulated. Skip with a
# clear message instead of failing on the read, so that the notebook can be
# run from top to bottom with either data set.
if os.path.isfile(central_df_path):
    central_df = pd.read_csv(central_df_path, sep="\t", index_col=0)
    federated_df, intersect_features = _concat_federated_results(clientWrappers, samples_in_columns=True)
    _compare_central_federated_dfs("microarray", central_df, federated_df, intersect_features)
else:
    print(f"Skipping the microarray comparison, no centralized result at: {central_df_path}")

_________________________Analysing: microarray_________________________
Max difference: 8.171241461241152e-14
Mean difference: 4.357000735613189e-15
Max diff at position: GSM1701030
Max difference in intersect: 8.171241461241152e-14
Mean difference in intersect: 4.515621419262236e-15
Max diff at position in intersect: GSM1701030


In [22]:
### Concat the federated data and read in the centralized data ###
### (proteomics data: the centralized result is expected in "after/balanced",
### two folder levels above base_dir)
central_df_path = os.path.join(os.path.dirname(os.path.dirname(base_dir)), "after", "balanced", "central_intensities_log_corrected_UNION.tsv")
# The path only exists if the proteomics clients were simulated. Skip with a
# clear message instead of failing on the read, so that the notebook can be
# run from top to bottom with either data set.
if os.path.isfile(central_df_path):
    central_df = pd.read_csv(central_df_path, sep="\t", index_col=0)
    federated_df, intersect_features = _concat_federated_results(clientWrappers, samples_in_columns=True)
    _compare_central_federated_dfs("proteomic data", central_df, federated_df, intersect_features)
else:
    print(f"Skipping the proteomics comparison, no centralized result at: {central_df_path}")

_________________________Analysing: proteomic data_________________________
Max difference: 1.9184653865522705e-13
Mean difference: 3.1758692034102934e-14
Max diff at position: Clinspect_E_coli_A_S42_Slot1-21_1_8670
Max difference in intersect: 1.1368683772161603e-13
Mean difference in intersect: 3.338152465379177e-14
Max diff at position in intersect: Ref8537_S11_20230414
