In [14]:
import logging
import sys
# for json files
import json

import pandas as pd
import numpy as np
import yaml

# Make the FedComBat package (`classes.*`) importable.
# CHANGE THIS TO THE DIRECTORY WHERE YOU HAVE THE REPO CLONED,
# or set the FEDCOMBAT_SRC environment variable to override it without editing the notebook.
import os

FEDCOMBAT_SRC_DIR = os.environ.get("FEDCOMBAT_SRC", "/home/yuliya/repos/cosybio/FedComBat/fedcombat")
# Guard against duplicate sys.path entries when this cell is re-run.
if FEDCOMBAT_SRC_DIR not in sys.path:
    sys.path.append(FEDCOMBAT_SRC_DIR)


In [15]:
from classes.client import Client
import classes.coordinator_utils as c_utils

logging.basicConfig(
    level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", datefmt="%d-%b-%y %H:%M:%S"
)

# Set parameters

In [16]:
# Root of the dataset to harmonize. Each cohort's input lives in
# f"{data_dir}/{cohort}"; corrected results are written under output_path.
# Alternative (small test dataset): "/home/yuliya/repos/cosybio/FedComBat/datasets/small_test"
DATASET_ROOT = "/home/yuliya/repos/cosybio/FedComBat/datasets/Breast_cancer_RNASeq"

data_dir = f"{DATASET_ROOT}/before"  # path to data folder
# Trailing slash is required: later cells append file names directly to output_path.
output_path = f"{DATASET_ROOT}/after/"  # path to output folder

cohorts = ["GSE129508", "GSE58135", "GSE149276"]

In [17]:
# Create one Client per cohort, collecting the covariates shared by all clients
# and verifying that batch labels do not collide across clients.
global_variables = None  # intersection of all clients' covariates; None = no client seen yet
global_batch_labels = set()

clients = {}

for cohort_name in cohorts:

    client = Client()

    client.cohort_name = cohort_name
    logging.info(f"Processing cohort {cohort_name}")

    client.config_based_init(
        client_name = cohort_name,
        input_folder = f"{data_dir}/{cohort_name}",
    )

    # Keep only covariates every client has. Use an explicit None sentinel, not
    # truthiness: an intersection that became empty must stay empty instead of
    # being reset to the next client's variables.
    client_variables = set(client.variables)
    if global_variables is None:
        global_variables = client_variables
    else:
        global_variables &= client_variables

    # Batch labels must be unique across clients. Checking the already-merged
    # set afterwards could never detect a collision, so check per client here.
    client_batch_labels = set(client.batch_labels)
    duplicated_labels = global_batch_labels & client_batch_labels
    if duplicated_labels:
        raise ValueError(
            "Batch labels are not unique across clients, please adjust them: "
            f"{sorted(map(str, duplicated_labels))}"
        )
    global_batch_labels.update(client_batch_labels)

    clients[cohort_name] = client

if global_variables is None:
    global_variables = set()
            

INFO:root:Processing cohort GSE129508
INFO:classes.client:Got the following config:
{'FedComBat': {'data_filename': 'expr_for_correction.tsv', 'data_separator': '\t', 'min_samples': 5, 'covariates': ['lum'], 'smpc': True, 'design_filename': 'design.tsv', 'design_separator': '\t', 'rows_as_features': True, 'index_col': 0, 'position': 0, 'batch_col_name': 'batch'}}
INFO:classes.client:min_samples set to 5
INFO:classes.client:Opening dataset /home/yuliya/repos/cosybio/FedComBat/datasets/Breast_cancer_RNASeq/before/GSE129508/expr_for_correction.tsv
INFO:classes.client:Shape of rawdata: (30174, 25)
INFO:classes.client:Cleaning up data, removing all-NaN rows and columns, removing all-zero rows
INFO:classes.client:Shape of data before cleanup: (30174, 25)
INFO:classes.client:Shape of data after cleanup: (30174, 25)
INFO:classes.client:Finished loading data, shape: (30174, 25), num_features: 30174, num_samples: 25
INFO:root:Processing cohort GSE58135
INFO:classes.client:Got the following confi

In [18]:
# Collect, per client, what the coordinator needs to select common features:
# [cohort name, position, reference batch, batch/feature presence info].
# NOTE: an earlier draft used a design-size based threshold,
# max(len(global_batch_labels) + len(global_variables) + 1, client.min_samples);
# the client's own configured min_samples is used instead.
feature_information = []
for cohort_name in cohorts:
    client = clients[cohort_name]
    batch_feature_presence = client.get_batch_feature_presence_info(min_samples=client.min_samples)

    feature_information.append(
        [client.cohort_name,
         client.position,
         client.reference_batch,
         batch_feature_presence]
    )


INFO:classes.client:Feature count: 30174
INFO:classes.client:Dropped 0 features completely empty after processing batch 'GSE129508|0'.
INFO:classes.client:Checking feature presence in batch 'GSE129508|0' with 25 samples.
INFO:classes.client:Feature count: 34675
INFO:classes.client:Dropped 0 features completely empty after processing batch 'GSE58135|2'.
INFO:classes.client:Checking feature presence in batch 'GSE58135|2' with 75 samples.
INFO:classes.client:Feature count: 31377
INFO:classes.client:Dropped 0 features completely empty after processing batch 'GSE149276|1'.
INFO:classes.client:Checking feature presence in batch 'GSE149276|1' with 31 samples.


In [19]:
# Coordinator step: keep features present in at least `min_clients` clients and
# obtain the batch order used for every client's design matrix. Requiring all
# clients keeps only features that every cohort has.
# NOTE(review): default_order is assumed to be client positions 0..n-1
# (here 0, 1, 2 for three cohorts) — confirm against select_common_features_variables.
global_feature_names, feature_presence_matrix, cohorts_order = \
    c_utils.select_common_features_variables(feature_information,
                                            default_order=list(range(len(cohorts))),
                                            min_clients=len(cohorts))


INFO:classes.coordinator_utils:Found 28823 features present in (at least) 3 clients
INFO:classes.coordinator_utils:Total number of unique features: 34818
INFO:classes.coordinator_utils:Using specified client order: ['GSE129508', 'GSE149276', 'GSE58135']


INFO:classes.coordinator_utils:Cohorts order: ['GSE129508|0', 'GSE149276|1', 'GSE58135|2']


In [20]:
# Validate every client against the shared covariates, align its data to the
# globally selected features, and build its design matrix.
for cohort_name in cohorts:
    logging.info(f"\nProcessing cohort {cohort_name}")
    client = clients[cohort_name]
    logging.info("[validate] waiting for common features and covariates")

    # Check the client's inputs against the covariates common to all clients.
    client.validate_inputs(global_variables)
    logging.info("[validate] Inputs have been validated")
    # Restrict the client's data to the global feature set (the client logs show
    # extra local features being dropped here).
    client.set_data(global_feature_names)
    logging.info("[validate] Data has been set to contain all global features")

    # NOTE(review): TEST_MODE is a Client flag whose effect is not visible in this
    # notebook — confirm what it changes before relying on the results.
    client.TEST_MODE = True
    # get all client names to generate design matrix
    # (cohorts_order comes from the coordinator, presumably so that all clients
    # lay out their batch columns in the same order — TODO confirm in Client.create_design)
    client.create_design(cohorts_order)
    logging.info(f"[validate] Design matrix has been created with shape {client.design.shape}")
    logging.info("[validate] design has been created")

INFO:root:
Processing cohort GSE129508
INFO:root:[validate] waiting for common features and covariates
INFO:classes.client:Client GSE129508: Data validated
INFO:classes.client:Client GSE129508: Inputs validated.
INFO:root:[validate] Inputs have been validated
INFO:classes.client:Local features: 30174; Global features: 28823
INFO:classes.client:Dropping 1351 extra local features/rows.
INFO:root:[validate] Data has been set to contain all global features
INFO:classes.client:Client GSE129508 has only one batch
INFO:root:[validate] Design matrix has been created with shape (25, 4)
INFO:root:[validate] design has been created
INFO:root:
Processing cohort GSE58135
INFO:root:[validate] waiting for common features and covariates
INFO:classes.client:Client GSE58135: Data validated
INFO:classes.client:Client GSE58135: Inputs validated.
INFO:root:[validate] Inputs have been validated
INFO:classes.client:Local features: 34675; Global features: 28823
INFO:classes.client:Dropping 5852 extra local fe

In [21]:
# ComBat first step: every client computes X^T X, X^T Y and its per-batch sample
# counts; the coordinator sums them element-wise across clients.
XtX_global = None
XtY_global = None
ref_size_global = None

for cohort_name in cohorts:
    logging.info(f"\nProcessing cohort {cohort_name}")
    client = clients[cohort_name]

    logging.info("[ComBat-first_step:] Starting the first step of ComBat")
    logging.info(f"[ComBat-first_step:] Adjusting for {len(client.variables)} covariate(s) or covariate level(s)")
    if client.mean_only:
        logging.info("[ComBat-first_step:] Performing ComBat with mean only.")

    # Local sufficient statistics for the regression.
    XtX, XtY = client.compute_XtX_XtY()

    # Per-batch sample counts: column sums over all design columns except the
    # trailing covariate columns (assumes batch indicators come first in the
    # design — TODO confirm in Client.create_design).
    design_cols = client.design.shape[1]
    ref_size = [sum(client.design.iloc[:, i]) for i in range(design_cols - len(client.variables))]

    if XtY_global is None:
        # Copy the first client's arrays so later in-place sums never touch client state.
        XtX_global = XtX.copy()
        XtY_global = XtY.copy()
        ref_size_global = np.array(ref_size)
    else:
        # The globals are already private copies, so no extra copy is needed here.
        XtX_global += XtX
        XtY_global += XtY
        # Out-of-place add: avoids a casting error when this client's counts are
        # float while the running total is an integer array.
        ref_size_global = ref_size_global + np.asarray(ref_size)

    logging.info("[ComBat-first_step:] Computation done, sending data to coordinator")
    logging.info(f"[ComBat-first_step:] XtX of shape {XtX.shape}, X of shape {client.design.shape}, XtY of shape {XtY.shape}")


logging.info("[ComBat-first_step:] All clients have finished the first step of ComBat")
logging.info(f"[ComBat-first_step:] Ref size: {ref_size_global}")
logging.info(f"[ComBat-first_step:] XtX shape: {XtX_global.shape}, XtY shape: {XtY_global.shape}")

ref_size = ref_size_global.copy()


INFO:root:
Processing cohort GSE129508
INFO:root:[ComBat-first_step:] Starting the first step of ComBat
INFO:root:[ComBat-first_step:] Adjusting for 1 covariate(s) or covariate level(s)
INFO:root:[ComBat-first_step:] Computation done, sending data to coordinator
INFO:root:[ComBat-first_step:] XtX of shape (28823, 4, 4), X of shape (25, 4), XtY of shape (28823, 4)
INFO:root:
Processing cohort GSE58135
INFO:root:[ComBat-first_step:] Starting the first step of ComBat
INFO:root:[ComBat-first_step:] Adjusting for 1 covariate(s) or covariate level(s)
INFO:root:[ComBat-first_step:] Computation done, sending data to coordinator
INFO:root:[ComBat-first_step:] XtX of shape (28823, 4, 4), X of shape (75, 4), XtY of shape (28823, 4)
INFO:root:
Processing cohort GSE149276
INFO:root:[ComBat-first_step:] Starting the first step of ComBat
INFO:root:[ComBat-first_step:] Adjusting for 1 covariate(s) or covariate level(s)
INFO:root:[ComBat-first_step:] Computation done, sending data to coordinator
INFO:r

In [22]:
# Coordinator step: compute the coefficient matrix B_hat from the aggregated
# XtX / XtY (presumably by solving the pooled normal equations), then derive
# the grand mean and the standardized mean.
# The former unused `n`/`k` were removed: they read whichever `client` was
# left over from an earlier loop (hidden state), and `n` was a feature count.
B_hat = c_utils.compute_B_hat(XtX_global, XtY_global)
logging.info("[Compute_b_hat:] B_hat has been computed.")
grand_mean, stand_mean = c_utils.compute_mean(XtX_global, XtY_global, B_hat, ref_size)
logging.info("[Compute_b_hat:] Grand mean and stand mean have been computed.")

INFO:classes.coordinator_utils:Computing B_hat
INFO:classes.coordinator_utils:B_hat has been computed
INFO:root:[Compute_b_hat:] B_hat has been computed.
INFO:classes.coordinator_utils:Grand mean and stand mean have been computed
INFO:classes.coordinator_utils:Grand mean shape: (28823,), stand mean shape: (28823, 131)
INFO:classes.coordinator_utils:XtX_global shape: (28823, 4, 4), XtY_global shape: (28823, 4)
INFO:root:[Compute_b_hat:] Grand mean and stand mean have been computed.


In [23]:
# Inspect the standardized mean: rows are features, columns are the samples of
# all cohorts together (logged shape (28823, 131), matching the merged output).
pd.DataFrame(stand_mean)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,121,122,123,124,125,126,127,128,129,130
0,8.883213,8.883213,8.883213,8.883213,8.883213,8.883213,8.883213,8.883213,8.883213,8.883213,...,8.883213,8.883213,8.883213,8.883213,8.883213,8.883213,8.883213,8.883213,8.883213,8.883213
1,4.229735,4.229735,4.229735,4.229735,4.229735,4.229735,4.229735,4.229735,4.229735,4.229735,...,4.229735,4.229735,4.229735,4.229735,4.229735,4.229735,4.229735,4.229735,4.229735,4.229735
2,13.534036,13.534036,13.534036,13.534036,13.534036,13.534036,13.534036,13.534036,13.534036,13.534036,...,13.534036,13.534036,13.534036,13.534036,13.534036,13.534036,13.534036,13.534036,13.534036,13.534036
3,9.716086,9.716086,9.716086,9.716086,9.716086,9.716086,9.716086,9.716086,9.716086,9.716086,...,9.716086,9.716086,9.716086,9.716086,9.716086,9.716086,9.716086,9.716086,9.716086,9.716086
4,4.190896,4.190896,4.190896,4.190896,4.190896,4.190896,4.190896,4.190896,4.190896,4.190896,...,4.190896,4.190896,4.190896,4.190896,4.190896,4.190896,4.190896,4.190896,4.190896,4.190896
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
28818,8.066962,8.066962,8.066962,8.066962,8.066962,8.066962,8.066962,8.066962,8.066962,8.066962,...,8.066962,8.066962,8.066962,8.066962,8.066962,8.066962,8.066962,8.066962,8.066962,8.066962
28819,10.392183,10.392183,10.392183,10.392183,10.392183,10.392183,10.392183,10.392183,10.392183,10.392183,...,10.392183,10.392183,10.392183,10.392183,10.392183,10.392183,10.392183,10.392183,10.392183,10.392183
28820,12.564155,12.564155,12.564155,12.564155,12.564155,12.564155,12.564155,12.564155,12.564155,12.564155,...,12.564155,12.564155,12.564155,12.564155,12.564155,12.564155,12.564155,12.564155,12.564155,12.564155
28821,10.985053,10.985053,10.985053,10.985053,10.985053,10.985053,10.985053,10.985053,10.985053,10.985053,...,10.985053,10.985053,10.985053,10.985053,10.985053,10.985053,10.985053,10.985053,10.985053,10.985053


In [24]:
var_list = []

for cohort_name in cohorts:
    client = clients[cohort_name]
    # Give the client the coordinator's coefficients before it summarizes its residual variance.
    client.B_hat = B_hat
    sigma_site = client.get_sigma_summary(B_hat, ref_size)
    logging.info(f"[Compute_sigma_site:] Sigma site has been computed for {cohort_name}")
    logging.info(f"[Compute_sigma_site:] Sigma site shape: {sigma_site.shape}")

    # Weight the site summary by its sample count (data columns are samples),
    # presumably so get_pooled_variance can form a sample-weighted pool.
    var_list.append(client.data.shape[1] * sigma_site)


INFO:root:[Compute_sigma_site:] Sigma site has been computed for GSE129508
INFO:root:[Compute_sigma_site:] Sigma site shape: (28823,)
INFO:root:[Compute_sigma_site:] Sigma site has been computed for GSE58135
INFO:root:[Compute_sigma_site:] Sigma site shape: (28823,)
INFO:root:[Compute_sigma_site:] Sigma site has been computed for GSE149276
INFO:root:[Compute_sigma_site:] Sigma site shape: (28823,)


In [25]:
# Coordinator step: pool the sample-weighted per-site variance summaries into a
# single variance per feature (logged shape (28823,)).
pooled_variance = c_utils.get_pooled_variance(var_list, ref_size)
logging.info("[Compute_pooled_variance:] Pooled variance has been computed.")
logging.info(f"[Compute_pooled_variance:] Pooled variance shape: {pooled_variance.shape}")


INFO:root:[Compute_pooled_variance:] Pooled variance has been computed.
INFO:root:[Compute_pooled_variance:] Pooled variance shape: (28823,)


In [26]:
# Disabled: optional view of pooled_variance as a one-column DataFrame indexed by feature name.
# NOTE(review): kept commented out. Enabling it would rebind `pooled_variance` to a
# DataFrame, while the following cells pass it to client methods that were run with
# the 1-D array. Confirm those methods accept a DataFrame before re-enabling.
# # pooled_variance - reshape and add columns - features
# pooled_variance = pd.DataFrame(pooled_variance.reshape(-1, 1), index=global_feature_names, columns=["pooled_variance"])
# pooled_variance

In [27]:
for cohort_name in cohorts:
    client = clients[cohort_name]

    # 1) Standardize the client's data with the global coefficients, mean and pooled variance.
    logging.info("[calculate_estimates] Getting standardized data...")
    client.get_standardized_data(B_hat, stand_mean, pooled_variance, ref_size)
    logging.info("[calculate_estimates] Standardized data has been computed.")

    # 2) Naive batch-effect estimates (gamma_hat / delta_hat).
    client.get_naive_estimates()

    # 3) Either shrink them with Empirical Bayes, or use the naive ones unchanged.
    if not client.eb_param:
        client.gamma_star = client.gamma_hat.copy()
        client.delta_star = client.delta_hat.copy()
        logging.info("[calculate_estimates] Non-Empirical Bayes estimates have been computed.")
    else:
        logging.info(
            "[calculate_estimates] Getting parametric Empirical Bayes estimates..."
            if client.parametric
            else "[calculate_estimates] Getting non-parametric Empirical Bayes estimates..."
        )
        client.get_eb_estimators()
        logging.info("[calculate_estimates] Empirical Bayes estimates have been computed.")

    # 4) Apply the correction and keep the result on the client.
    client.corrected_data = client.get_corrected_data(pooled_variance)
    logging.info("[calculate_estimates] Corrected data has been computed.")



INFO:root:[calculate_estimates] Getting standardized data...
INFO:classes.client:mod_mean shape: (28823, 25)
INFO:classes.client:Standardizing data...
INFO:classes.client:Data standardized, shape: (28823, 25)
INFO:root:[calculate_estimates] Standardized data has been computed.
INFO:classes.client:Computed gamma_hat, shape: (1, 28823)
INFO:classes.client:Computed delta_hat, shape: (1, 28823)
INFO:classes.client:Computed naive estimates.
INFO:root:[calculate_estimates] Getting parametric Empirical Bayes estimates...
INFO:classes.client:Computed a_prior and b_prior.
INFO:classes.client:Computed EB estimators.
INFO:root:[calculate_estimates] Empirical Bayes estimates have been computed.
INFO:classes.client:Correcting data...
INFO:classes.client:Rescaling and adding back mean adjustments...
INFO:classes.client:Data shape: (28823, 25)
INFO:classes.client:var_pooled shape: (28823,)
INFO:classes.client:stand_mean shape: (28823, 25)
INFO:classes.client:mod_mean shape: (28823, 25)
INFO:classes.c

In [32]:
# Inspect the covariate contribution (mod_mean) of the last cohort. Referenced
# explicitly instead of through the `client` variable leaked from a previous loop.
clients[cohorts[-1]].mod_mean

sample_id,GSM4495391,GSM4495392,GSM4495393,GSM4495394,GSM4495395,GSM4495396,GSM4495397,GSM4495398,GSM4495399,GSM4495400,...,GSM4495414,GSM4495415,GSM4495416,GSM4495417,GSM4495418,GSM4495419,GSM4495420,GSM4495422,GSM4495423,GSM4495424
gene_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
A1BG,0.406062,0.406062,0.406062,0.406062,0.406062,0.406062,0.0,0.406062,0.0,0.0,...,0.406062,0.0,0.0,0.0,0.0,0.0,0.0,0.406062,0.0,0.0
A1CF,-0.222285,-0.222285,-0.222285,-0.222285,-0.222285,-0.222285,0.0,-0.222285,0.0,0.0,...,-0.222285,0.0,0.0,0.0,0.0,0.0,0.0,-0.222285,0.0,0.0
A2M,-0.012262,-0.012262,-0.012262,-0.012262,-0.012262,-0.012262,0.0,-0.012262,0.0,0.0,...,-0.012262,0.0,0.0,0.0,0.0,0.0,0.0,-0.012262,0.0,0.0
A2ML1,-3.718805,-3.718805,-3.718805,-3.718805,-3.718805,-3.718805,0.0,-3.718805,0.0,0.0,...,-3.718805,0.0,0.0,0.0,0.0,0.0,0.0,-3.718805,0.0,0.0
A2MP1,-0.233740,-0.233740,-0.233740,-0.233740,-0.233740,-0.233740,0.0,-0.233740,0.0,0.0,...,-0.233740,0.0,0.0,0.0,0.0,0.0,0.0,-0.233740,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
ZYG11A,-0.614374,-0.614374,-0.614374,-0.614374,-0.614374,-0.614374,0.0,-0.614374,0.0,0.0,...,-0.614374,0.0,0.0,0.0,0.0,0.0,0.0,-0.614374,0.0,0.0
ZYG11B,-0.181551,-0.181551,-0.181551,-0.181551,-0.181551,-0.181551,0.0,-0.181551,0.0,0.0,...,-0.181551,0.0,0.0,0.0,0.0,0.0,0.0,-0.181551,0.0,0.0
ZYX,-0.167689,-0.167689,-0.167689,-0.167689,-0.167689,-0.167689,0.0,-0.167689,0.0,0.0,...,-0.167689,0.0,0.0,0.0,0.0,0.0,0.0,-0.167689,0.0,0.0
ZZEF1,0.359137,0.359137,0.359137,0.359137,0.359137,0.359137,0.0,0.359137,0.0,0.0,...,0.359137,0.0,0.0,0.0,0.0,0.0,0.0,0.359137,0.0,0.0


In [30]:
# Inspect B_hat: one column per global feature, one row per design column
# (4 here, matching the 4-column design). Row 3 holds the covariate
# coefficients: its values match mod_mean above. Rows 0-2 are presumably the
# batch terms.
pd.DataFrame(B_hat)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,28813,28814,28815,28816,28817,28818,28819,28820,28821,28822
0,7.9549,2.167818,14.042447,9.238513,0.536794,2.092622,8.323541,2.259626,10.484661,9.755008,...,9.484056,8.522177,8.983341,9.291917,9.474096,7.812072,10.050047,10.900205,11.525582,9.845318
1,7.328667,2.146903,13.772212,8.476767,5.41232,1.479288,7.768384,0.942619,8.995859,9.634247,...,9.759435,10.697336,6.983636,8.756266,9.354254,7.276928,10.303749,10.967788,9.749539,10.818301
2,9.835197,5.777945,13.26612,10.387528,4.904075,4.067713,9.052282,2.076185,10.830289,11.467162,...,9.411593,9.97818,6.265033,7.909051,10.402089,8.478472,10.54278,13.778637,11.315555,10.333807
3,0.406062,-0.222285,-0.012262,-3.718805,-0.23374,-0.887525,0.055727,0.211751,0.681317,0.054578,...,-0.504963,-0.527216,0.422732,0.514321,0.087984,-0.614374,-0.181551,-0.167689,0.359137,-0.026562


In [26]:
# Combine all cohorts' corrected data side by side (features x samples) and save it.
corrected_frames = []
for cohort_name in cohorts:
    client = clients[cohort_name]
    logging.info(f"[save_data] Saving data for {cohort_name}")
    corrected_frames.append(client.corrected_data)

# A single concat instead of growing the frame inside the loop (repeated copying).
merged_data = pd.concat(corrected_frames, axis=1) if corrected_frames else pd.DataFrame()

print(merged_data.shape)
# NOTE: the file is tab-separated despite its .csv extension; the name is kept
# unchanged so the output path stays the same.
merged_data.to_csv(f"{output_path}script_fed_data.csv", index=True, sep="\t")

INFO:root:[save_data] Saving data for GSE129508
INFO:root:[save_data] Saving data for GSE58135
INFO:root:[save_data] Saving data for GSE149276


(28823, 131)


# JSON check: export per-client intermediate results

In [27]:
import numpy as np
import pandas as pd
import json

def convert(obj):
    """Fallback serializer for ``json.dump(..., default=convert)``.

    Converts numpy / pandas objects that the ``json`` module cannot encode
    natively into plain Python structures. The argument is never modified.

    Parameters
    ----------
    obj : object
        The object ``json`` failed to serialize.

    Returns
    -------
    list | int | float | bool
        - ``np.ndarray`` -> (nested) list
        - ``pd.DataFrame`` -> list of row records, with the index included as a column
        - ``pd.Series`` -> list of values (the index is dropped)
        - numpy integer / floating / bool scalar -> the equivalent Python scalar

    Raises
    ------
    TypeError
        For any other type, as ``json`` expects from a ``default`` hook.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, pd.DataFrame):
        # Include the index as a column. reset_index() returns a new frame, so the
        # caller's DataFrame (e.g. client.corrected_data) is not mutated. The former
        # inplace=True version altered the client objects on every dump.
        return obj.reset_index().to_dict(orient="records")
    if isinstance(obj, pd.Series):
        return obj.tolist()
    if isinstance(obj, (np.integer, np.floating, np.bool_)):
        return obj.item()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


In [28]:
# Export per-client intermediate ComBat quantities to JSON, one file per cohort.
# NOTE: absolute, machine-specific directory — adjust before running elsewhere.
JSON_OUT_DIR = "/home/yuliya/repos/cosybio/FedComBat/evaluation/d_combat/json"

for cohort_name in cohorts:
    client = clients[cohort_name]

    check_dict = {
        # Only the first entry of client.XtX (presumably the first feature's X^T X).
        "xtx": client.XtX[0],
        "xty": pd.DataFrame(client.XtY, index=client.data.index).T,
        # Defensive copies: a serializer that resets a DataFrame's index in place
        # must not be able to alter the client's own state.
        "corrected_data": client.corrected_data.copy(),
        "gamma_hat": pd.DataFrame(client.gamma_hat, columns=client.data.index),
        "gamma_bar": float(client.gamma_bar[0]),
        "gamma_star": pd.DataFrame(client.gamma_star, columns=client.data.index),
        "delta_hat": pd.DataFrame(client.delta_hat, columns=client.data.index),
        "delta_star": pd.DataFrame(client.delta_star, columns=client.data.index),
        "t2": float(client.t2[0]),
        "a_prior": float(client.a_prior[0]),
        "b_prior": float(client.b_prior[0]),
        "stand_mean": pd.DataFrame(client.stand_mean, columns=client.data.columns, index=client.data.index),
        "mod_mean": client.mod_mean.copy(),
        "pooled_variance": pd.DataFrame(client.pooled_var.reshape(1,-1), columns=client.data.index),
        "sigma": pd.DataFrame(client.sigma.reshape(1,-1), columns=client.data.index),
        "B_hat": pd.DataFrame(client.B_hat, columns=client.data.index)
    }

    with open(f"{JSON_OUT_DIR}/{cohort_name}_Py_out.json", "w", encoding="utf-8") as f:
        json.dump(check_dict, f, indent=4, ensure_ascii=False, default=convert)


