In [63]:
import logging
import sys
# for json files
import json

import pandas as pd
import numpy as np
import yaml

# Directory of the cloned FedComBat repo that is put on the import path.
# Set the FEDCOMBAT_DIR environment variable to point at your own clone;
# the default below is only valid on the original author's machine.
import os

FEDCOMBAT_DIR = os.environ.get("FEDCOMBAT_DIR", '/home/yuliya/repos/cosybio/FedComBat/fedcombat')
# Guard so that re-running this cell does not append the same entry again.
if FEDCOMBAT_DIR not in sys.path:
    sys.path.append(FEDCOMBAT_DIR)


In [64]:
from classes.client import Client
import classes.coordinator_utils as c_utils

# Configure the root logger for the whole notebook.
# force=True (Python 3.8+) removes handlers that are already attached to the
# root logger; without it basicConfig() silently does nothing when logging was
# configured earlier (e.g. on a cell re-run), and the format/level requested
# here are ignored.
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    datefmt="%d-%b-%y %H:%M:%S",
    force=True,
)

# set params

In [65]:
# Run parameters: which cohorts to process and where their files live.
input_root, output_root = "before", "after"

# One sub-folder of the data folder per cohort.
cohorts = ["GSE129508", "GSE58135", "GSE149276"]

data_dir = input_root  # path to data folder
output_path = f"{output_root}/"  # path to output folder (with trailing slash)

In [66]:
# Load every cohort as a Client and collect the cross-client information:
#   global_variables    - covariates present in ALL clients (running intersection)
#   global_batch_labels - union of the batch labels of all clients
#   clients             - cohort name -> initialised Client
#
# None means "no client processed yet". An empty set must stay distinguishable
# from that state: once the intersection becomes empty it has to remain empty
# instead of being re-seeded from the next client's variables.
global_variables = None
global_batch_labels = set()
# Labels reported by more than one client. They have to be detected while
# merging, because the merged set can no longer contain duplicates.
duplicated_batch_labels = set()

clients = {}

for cohort_name in cohorts:

    client = Client()

    client.cohort_name = cohort_name
    logging.info(f"Processing cohort {cohort_name}")

    client.config_based_init(
        client_name = cohort_name,
        input_folder = f"{data_dir}/{cohort_name}",
    )

    client_variables = set(client.variables)
    if global_variables is None:
        global_variables = client_variables
    else:
        global_variables &= client_variables

    client_batch_labels = set(client.batch_labels)
    duplicated_batch_labels |= global_batch_labels & client_batch_labels
    global_batch_labels |= client_batch_labels

    clients[cohort_name] = client

# With no cohorts at all, keep the previous result: an empty set.
if global_variables is None:
    global_variables = set()

if duplicated_batch_labels:
    raise ValueError("Batch labels are not unique across clients, please adjust them")
    

            

INFO:root:Processing cohort GSE129508
INFO:classes.client:Got the following config:
{'FedComBat': {'data_filename': 'expr_for_correction.tsv', 'data_separator': '\t', 'min_samples': 5, 'covariates': ['lum'], 'smpc': True, 'design_filename': 'design.tsv', 'design_separator': '\t', 'rows_as_features': True, 'index_col': 0, 'position': 0, 'batch_col_name': 'batch'}}
INFO:classes.client:min_samples set to 5
INFO:classes.client:Opening dataset before/GSE129508/expr_for_correction.tsv
INFO:classes.client:Shape of rawdata: (5, 6)
INFO:classes.client:Cleaning up data, removing all-NaN rows and columns, removing all-zero rows
INFO:classes.client:Shape of data before cleanup: (5, 6)
INFO:classes.client:Shape of data after cleanup: (5, 6)
INFO:classes.client:Finished loading data, shape: (5, 6), num_features: 5, num_samples: 6
INFO:root:Processing cohort GSE58135
INFO:classes.client:Got the following config:
{'FedComBat': {'data_filename': 'expr_for_correction.tsv', 'data_separator': '\t', 'min_s

In [67]:
# Collect, for every client, the record the coordinator uses to pick common
# features: [cohort name, position, reference batch, per-batch feature presence].

# Lower bound on the sample count that is independent of the client:
# number of batch labels + number of shared covariates + 1.
# NOTE(review): presumably the number of model terms to estimate - confirm.
min_samples_floor = len(global_batch_labels) + len(global_variables) + 1

feature_information = []
for cohort_name in cohorts:
    client = clients[cohort_name]

    # A client may ask for a stricter threshold than the global floor.
    min_samples = max(min_samples_floor, client.min_samples)
    batch_feature_presence = client.get_batch_feature_presence_info(min_samples=min_samples)

    client_record = [
        client.cohort_name,
        client.position,
        client.reference_batch,
        batch_feature_presence,
    ]
    feature_information.append(client_record)


INFO:classes.client:Feature count: 5
INFO:classes.client:Dropped 0 features completely empty after processing batch 'GSE129508|0'.
INFO:classes.client:Checking feature presence in batch 'GSE129508|0' with 6 samples.
INFO:classes.client:Feature count: 5
INFO:classes.client:Dropped 0 features completely empty after processing batch 'GSE58135|2'.
INFO:classes.client:Checking feature presence in batch 'GSE58135|2' with 6 samples.
INFO:classes.client:Feature count: 5
INFO:classes.client:Dropped 0 features completely empty after processing batch 'GSE149276|1'.
INFO:classes.client:Checking feature presence in batch 'GSE149276|1' with 6 samples.


In [68]:
# Coordinator step: derive the global feature list, the feature presence
# matrix and the cohort (batch) order from the per-client records.
# NOTE(review): default_order and min_clients are hard-coded for exactly three
# cohorts - presumably they must be kept in sync with `cohorts`; confirm.
(
    global_feature_names,
    feature_presence_matrix,
    cohorts_order,
) = c_utils.select_common_features_variables(
    feature_information,
    default_order=[0, 1, 2],
    min_clients=3,
)


INFO:classes.coordinator_utils:Found 5 features present in (at least) 3 clients
INFO:classes.coordinator_utils:Total number of unique features: 5
INFO:classes.coordinator_utils:Using specified client order: ['GSE129508', 'GSE149276', 'GSE58135']
INFO:classes.coordinator_utils:Cohorts order: ['GSE129508|0', 'GSE149276|1', 'GSE58135|2']


In [69]:
def _validate_and_build_design(client, variables, feature_names, batch_order):
    """Run the per-client validation stage.

    Validates the client's inputs against the shared covariates, sets its
    data using the global feature list and creates its design matrix from
    the global cohort order, logging progress after each step.
    """
    logging.info("[validate] waiting for common features and covariates")

    client.validate_inputs(variables)
    logging.info("[validate] Inputs have been validated")

    client.set_data(feature_names)
    logging.info("[validate] Data has been set to contain all global features")

    # The design matrix is built from the order of all cohorts, not only this one.
    client.create_design(batch_order)
    logging.info(f"[validate] Design matrix has been created with shape {client.design.shape}")
    logging.info("[validate] design has been created")


for cohort_name in cohorts:
    logging.info(f"\nProcessing cohort {cohort_name}")
    client = clients[cohort_name]
    _validate_and_build_design(client, global_variables, global_feature_names, cohorts_order)

INFO:root:
Processing cohort GSE129508
INFO:root:[validate] waiting for common features and covariates
INFO:classes.client:Client GSE129508: Data validated
INFO:classes.client:Client GSE129508: Inputs validated.
INFO:root:[validate] Inputs have been validated
INFO:classes.client:Local features: 5; Global features: 5
INFO:classes.client:Dropping 0 extra local features/rows.
INFO:root:[validate] Data has been set to contain all global features
INFO:classes.client:Client GSE129508 has only one batch
INFO:root:[validate] Design matrix has been created with shape (6, 4)
INFO:root:[validate] design has been created
INFO:root:
Processing cohort GSE58135
INFO:root:[validate] waiting for common features and covariates
INFO:classes.client:Client GSE58135: Data validated
INFO:classes.client:Client GSE58135: Inputs validated.
INFO:root:[validate] Inputs have been validated
INFO:classes.client:Local features: 5; Global features: 5
INFO:classes.client:Dropping 0 extra local features/rows.
INFO:root:

In [70]:
# ComBat first step: every client computes its local XtX / XtY, and the
# notebook (acting as coordinator) sums them over all clients together with
# the per-client column sums of the leading design columns.
XtX_global = None
XtY_global = None
ref_size_global = None

for cohort_name in cohorts:
    logging.info(f"\nProcessing cohort {cohort_name}")
    client = clients[cohort_name]

    logging.info("[ComBat-first_step:] Starting the first step of ComBat")
    logging.info(f"[ComBat-first_step:] Adjusting for {len(client.variables)} covariate(s) or covariate level(s)")
    if client.mean_only:
        logging.info("[ComBat-first_step:] Performing ComBat with mean only.")

    # getting XtX and Xty
    XtX, XtY = client.compute_XtX_XtY()
    design_cols = client.design.shape[1]
    # Column sums of the design columns that precede the covariate columns.
    # NOTE(review): presumably these are one-hot batch indicators, making each
    # sum the number of samples in that batch - confirm.
    ref_size = [sum(client.design.iloc[:, i]) for i in range(design_cols - len(client.variables))]

    if XtY_global is None:
        # Copy instead of aliasing: the in-place "+=" below would otherwise
        # write the running totals into the first client's own arrays.
        XtX_global = XtX.copy()
        XtY_global = XtY.copy()
        ref_size_global = np.array(ref_size)
    else:
        XtX_global += XtX
        XtY_global += XtY
        ref_size_global += ref_size

    logging.info("[ComBat-first_step:] Computation done, sending data to coordinator")
    logging.info(f"[ComBat-first_step:] XtX of shape {XtX.shape}, X of shape {client.design.shape}, XtY of shape {XtY.shape}")


# Fail with a clear message rather than an AttributeError on None.shape below.
if XtX_global is None:
    raise ValueError("No cohorts were processed, cannot aggregate XtX and XtY")

logging.info("[ComBat-first_step:] All clients have finished the first step of ComBat")
logging.info(f"[ComBat-first_step:] Ref size: {ref_size_global}")
logging.info(f"[ComBat-first_step:] XtX shape: {XtX_global.shape}, XtY shape: {XtY_global.shape}")

ref_size = ref_size_global


INFO:root:
Processing cohort GSE129508
INFO:root:[ComBat-first_step:] Starting the first step of ComBat
INFO:root:[ComBat-first_step:] Adjusting for 1 covariate(s) or covariate level(s)
INFO:root:[ComBat-first_step:] Computation done, sending data to coordinator
INFO:root:[ComBat-first_step:] XtX of shape (5, 4, 4), X of shape (6, 4), XtY of shape (5, 4)
INFO:root:
Processing cohort GSE58135
INFO:root:[ComBat-first_step:] Starting the first step of ComBat
INFO:root:[ComBat-first_step:] Adjusting for 1 covariate(s) or covariate level(s)
INFO:root:[ComBat-first_step:] Computation done, sending data to coordinator
INFO:root:[ComBat-first_step:] XtX of shape (5, 4, 4), X of shape (6, 4), XtY of shape (5, 4)
INFO:root:
Processing cohort GSE149276
INFO:root:[ComBat-first_step:] Starting the first step of ComBat
INFO:root:[ComBat-first_step:] Adjusting for 1 covariate(s) or covariate level(s)
INFO:root:[ComBat-first_step:] Computation done, sending data to coordinator
INFO:root:[ComBat-first_

In [71]:
# Coordinator step: estimate B_hat from the aggregated XtX / XtY, then the
# grand mean and the standardisation mean.

# Look the client up explicitly instead of relying on whatever object the
# variable `client` was left pointing at by the last loop that ran.
# The last cohort is used, which is the client the preceding loops end on.
reference_client = clients[cohorts[-1]]
n = reference_client.data.values.shape[0]  # first dimension of that client's data
k = reference_client.design.shape[1]  # number of design matrix columns

B_hat = c_utils.compute_B_hat(XtX_global, XtY_global)
logging.info("[Compute_b_hat:] B_hat has been computed.")
grand_mean, stand_mean = c_utils.compute_mean(XtX_global, XtY_global, B_hat, ref_size)
logging.info("[Compute_b_hat:] Grand mean and stand mean have been computed.")

INFO:classes.coordinator_utils:Computing B_hat
INFO:classes.coordinator_utils:B_hat has been computed
INFO:root:[Compute_b_hat:] B_hat has been computed.
INFO:classes.coordinator_utils:Grand mean and stand mean have been computed
INFO:classes.coordinator_utils:Grand mean shape: (5,), stand mean shape: (5, 18)
INFO:classes.coordinator_utils:XtX_global shape: (5, 4, 4), XtY_global shape: (5, 4)
INFO:root:[Compute_b_hat:] Grand mean and stand mean have been computed.


# CORRECT  ^^

In [72]:
def _sigma_summary_for(cohort_name, client, b_hat, standardisation_mean):
    """Return one client's sigma summary and log its shape."""
    sigma = client.get_sigma_summary(b_hat, standardisation_mean)
    logging.info(f"[Compute_sigma_site:] Sigma site has been computed for {cohort_name}")
    logging.info(f"[Compute_sigma_site:] Sigma site shape: {sigma.shape}")
    return sigma


# Gather the per-client summaries in cohort order ...
var_list = []
for cohort_name in cohorts:
    client = clients[cohort_name]
    sigma_site = _sigma_summary_for(cohort_name, client, B_hat, stand_mean)
    var_list.append(sigma_site)

# ... and let the coordinator pool them.
pooled_variance = c_utils.get_pooled_variance(var_list, ref_size)
logging.info("[Compute_pooled_variance:] Pooled variance has been computed.")
logging.info(f"[Compute_pooled_variance:] Pooled variance shape: {pooled_variance.shape}")


INFO:root:[Compute_sigma_site:] Sigma site has been computed for GSE129508
INFO:root:[Compute_sigma_site:] Sigma site shape: (5,)
INFO:root:[Compute_sigma_site:] Sigma site has been computed for GSE58135
INFO:root:[Compute_sigma_site:] Sigma site shape: (5,)
INFO:root:[Compute_sigma_site:] Sigma site has been computed for GSE149276
INFO:root:[Compute_sigma_site:] Sigma site shape: (5,)
INFO:root:[Compute_pooled_variance:] Pooled variance has been computed.
INFO:root:[Compute_pooled_variance:] Pooled variance shape: (5,)


In [73]:
# pooled_variance - reshape to a single column and label the rows with the feature names.
# np.asarray() makes this cell safe to re-run: on a second run pooled_variance
# is already a DataFrame, which has no .reshape() method; converting first
# gives the same one-column frame in both cases.
pooled_variance = pd.DataFrame(
    np.asarray(pooled_variance).reshape(-1, 1),
    index=global_feature_names,
    columns=["pooled_variance"],
)
pooled_variance

Unnamed: 0,pooled_variance
BTLA,2.661572
RP1-209B5.2,0.893536
RP1-40G4P.1,0.922773
TAL2,0.797096
TTC7A,0.57092


# INCORRECT ^^^

In [11]:

# Final step per client: standardise the data, compute the naive batch-effect
# estimates, optionally refine them with Empirical Bayes, and correct the data.
#
# NOTE(review): the saved output of this cell ends with an IndexError right
# after the "Getting parametric Empirical Bayes estimates..." message. The
# failing code is inside the client library and is not visible here.

# Keep every cohort's result; `corrected_data` alone is overwritten on each
# iteration and would only hold the last cohort's data after the loop.
corrected_data_by_cohort = {}

for cohort_name in cohorts:
    client = clients[cohort_name]
    logging.info("[calculate_estimates] Getting standardized data...")
    client.get_standardized_data(
        B_hat,
        stand_mean,
        pooled_variance,
        ref_size
    )
    logging.info("[calculate_estimates] Standardized data has been computed.")

    # Get naive estimators
    client.get_naive_estimates()

    if client.eb_param:
        if client.parametric:
            logging.info("[calculate_estimates] Getting parametric Empirical Bayes estimates...")
        else:
            logging.info("[calculate_estimates] Getting non-parametric Empirical Bayes estimates...")
        client.get_eb_estimators()
        logging.info("[calculate_estimates] Empirical Bayes estimates have been computed.")
    else:
        # Without Empirical Bayes the naive estimates are used as the final ones.
        client.gamma_star = client.gamma_hat.copy()
        client.delta_star = client.delta_hat.copy()
        logging.info("[calculate_estimates] Non-Empirical Bayes estimates have been computed.")

    corrected_data = client.get_corrected_data(pooled_variance)
    corrected_data_by_cohort[cohort_name] = corrected_data
    logging.info("[calculate_estimates] Corrected data has been computed.")




INFO:root:[calculate_estimates] Getting standardized data...
INFO:classes.client:mod_mean shape: (5, 6)
INFO:classes.client:Standardizing data...
INFO:classes.client:Data standardized, shape: (5, 6)
INFO:root:[calculate_estimates] Standardized data has been computed.
INFO:classes.client:Computed gamma_hat, shape: (1, 5)
INFO:classes.client:Computed delta_hat, shape: (1, 5)
INFO:classes.client:Computed naive estimates.
INFO:root:[calculate_estimates] Getting parametric Empirical Bayes estimates...


indices [0 1 2 3 4 5]


IndexError: index 5 is out of bounds for axis 1 with size 5

In [None]:
# Debug inspection: display delta_hat of whichever client the variable
# `client` currently refers to (it is left over from the last loop that ran).
client.delta_hat

array([[0.65821301, 0.29199562, 0.56371845, 0.17261438, 2.04382362]])