In [221]:
import os
import json
import numpy as np
import pandas as pd

from pathlib import Path

In [222]:
# Path setup. The JSON directory defaults to the author's checkout but can be
# overridden with the FEDCOMBAT_JSON_DIR environment variable, so the notebook
# can be re-run on another machine without editing this cell.
base_dir = Path(
    os.environ.get(
        "FEDCOMBAT_JSON_DIR",
        "/home/yuliya/repos/cosybio/FedComBat/evaluation/d_combat/json",
    )
)
cohorts = ["GSE129508", "GSE149276", "GSE58135"]

# Comparison parameters: relative / absolute tolerances passed to np.allclose
# by the comparison helpers below.
RTOL = 1e-6
ATOL = 1e-8


In [223]:
def load_json(path):
    """Read the UTF-8 encoded JSON file at ``path`` and return the decoded object."""
    with open(path, encoding="utf-8") as handle:
        payload = handle.read()
    return json.loads(payload)


def to_df(obj):
    """Convert a decoded JSON value to a DataFrame.

    - list of dicts  -> one row per dict (records)
    - list of lists  -> one row per inner list
    - flat list      -> a single row
    - empty list     -> an empty DataFrame
    - dict           -> ``pd.DataFrame.from_dict`` (default orient: keys become columns)
    - anything else  -> passed straight to ``pd.DataFrame``
    """
    if isinstance(obj, list):
        if not obj:
            # Inspecting obj[0] below would raise IndexError on an empty list.
            return pd.DataFrame()
        if isinstance(obj[0], (dict, list)):
            return pd.DataFrame(obj)
        # A flat list of scalars is treated as a single row.
        return pd.DataFrame([obj])
    if isinstance(obj, dict):
        return pd.DataFrame.from_dict(obj)
    return pd.DataFrame(obj)


In [224]:
def compare_arrays(name, val1, val2, rtol=None, atol=None):
    """Compare two array-like values element-wise and return a one-line verdict.

    Args:
        name: label used in the returned message.
        val1, val2: anything ``np.array`` accepts (scalars, nested lists, ...).
        rtol, atol: tolerances forwarded to ``np.allclose``; when omitted they
            fall back to the notebook-level ``RTOL`` / ``ATOL`` constants.

    Returns:
        A string starting with ✅ (within tolerance), ⚠️ (values differ) or
        ❌ (shape mismatch, or the comparison itself raised). Exceptions are
        reported in the string rather than raised.
    """
    try:
        rtol = RTOL if rtol is None else rtol
        atol = ATOL if atol is None else atol
        arr1 = np.array(val1)
        arr2 = np.array(val2)
        if arr1.shape != arr2.shape:
            return f"❌ {name}: shape mismatch {arr1.shape} vs {arr2.shape}"
        if np.allclose(arr1, arr2, rtol=rtol, atol=atol, equal_nan=True):
            return f"✅ {name}: match"
        diff = np.abs(arr1 - arr2)
        # equal_nan=True lets shared NaNs pass allclose, but a plain np.max over
        # the difference would then be nan; report the max over non-NaN entries.
        max_diff = np.nan if np.all(np.isnan(diff)) else np.nanmax(diff)
        return f"⚠️  {name}: values differ (max abs diff = {max_diff:.2e})"
    except Exception as e:
        return f"❌ {name}: error comparing - {e}"


In [225]:
def get_max_digits(arr):
    """Return the largest number of digits after the decimal point in ``arr``.

    Each value is written in its shortest round-trip *positional* form (never
    scientific notation), so ``1e-05`` counts as 5 digits and ``2.5`` as 1.
    NaNs are skipped; an empty or all-NaN input yields 0.

    Args:
        arr: a DataFrame or anything ``np.asarray`` accepts (numeric values).
    """
    if isinstance(arr, pd.DataFrame):
        arr = arr.values
    values = np.asarray(arr).ravel()

    def _decimals(x):
        # trim="-" drops trailing zeros and a bare trailing ".", so 3.0 -> "3".
        text = np.format_float_positional(float(x), unique=True, trim="-")
        return len(text.split(".")[1]) if "." in text else 0

    return max((_decimals(x) for x in values if not np.isnan(x)), default=0)


def compare_dataframes(name, val1, val2, return_df=False, rtol=None, atol=None,
                       diff_threshold=9e-5):
    """Compare a Python-side result with the matching R-side result as DataFrames.

    Both values are converted with ``to_df``, row/column labels are normalised
    to strings, ``df1`` is reordered to ``df2``'s labels, both frames are rounded
    to the smaller of their decimal precisions (``get_max_digits``) and the
    values are compared with ``np.allclose``.

    Args:
        name: result key. For ``"sigma"`` and ``"pooled_variance"`` the R frame
            is transposed and its ``_row`` entry promoted to column labels.
        val1: Python-side value (dict or list, as decoded from JSON).
        val2: R-side value; the code below expects its labels in a ``_row``
            field/column.
        return_df: if True, return the processed ``(df1, df2)`` frames instead
            of a verdict string (useful for inspecting a mismatch).
        rtol, atol: ``np.allclose`` tolerances; default to ``RTOL`` / ``ATOL``.
        diff_threshold: when ``np.allclose`` fails, a max absolute difference at
            or below this value is still reported as a (qualified) match.

    Returns:
        A verdict string (✅ / ⚠️ / ❌), or ``(df1, df2)`` when ``return_df`` is
        True and no exception occurred. Exceptions are reported as a ❌ string.
    """
    try:
        rtol = RTOL if rtol is None else rtol
        atol = ATOL if atol is None else atol

        df1 = to_df(val1)
        df2 = to_df(val2)

        if name in ("sigma", "pooled_variance"):
            # After transposing, the "_row" row holds the column labels:
            # promote it to the header and keep the remaining row(s) as floats.
            df2 = df2.T
            df2.columns = df2.loc["_row"]
            df2 = df2.drop("_row").reset_index(drop=True).astype(float)

        # Normalise labels to strings so Python and R labels compare equal
        # regardless of the type they were decoded as.
        df1.index = df1.index.astype(str)
        df2.index = df2.index.astype(str)
        df1.columns = df1.columns.astype(str)
        df2.columns = df2.columns.astype(str)

        if "gene_id" in df1.columns:
            # Gene-indexed results: Python labels rows with "gene_id", R with "_row".
            df1 = df1.set_index("gene_id")
            df2 = df2.set_index("_row")
        else:
            # Otherwise drop the label bookkeeping columns before comparing.
            df1 = df1.loc[:, ~df1.columns.str.contains("index", regex=False)]
            df2 = df2.loc[:, ~df2.columns.str.contains("_row", regex=False)]

        labels_match = (set(df1.columns) == set(df2.columns)
                        and set(df1.index) == set(df2.index))
        if not labels_match:
            if return_df:
                print(f"❌ {name}: row/col names mismatch")
                return df1, df2
            return f"❌ {name}: row/col names mismatch"

        # Reorder df1's rows/columns to df2's order so values line up positionally.
        df1 = df1.loc[df2.index, df2.columns]

        # Round both frames to the coarser of the two precisions so that the
        # side written with fewer decimals does not show up as a difference.
        digits_py = get_max_digits(df1)
        digits_r = get_max_digits(df2)
        min_digits = min(digits_py, digits_r)
        df1 = df1.round(min_digits)
        df2 = df2.round(min_digits)

        if return_df:
            return df1, df2

        values1, values2 = df1.values, df2.values
        if np.allclose(values1, values2, rtol=rtol, atol=atol, equal_nan=True):
            return f"✅ {name}: match"

        diff = np.abs(values1 - values2)
        # equal_nan=True only forgives NaNs present on BOTH sides; a NaN on one
        # side is a real mismatch that nanmax would silently skip. Likewise a
        # plain np.max would return nan as soon as any NaN is present, and
        # `nan > threshold` is False, which used to report differences as matches.
        nan_mismatches = int(np.count_nonzero(np.isnan(values1) != np.isnan(values2)))
        max_diff = np.nan if np.isnan(diff).all() else np.nanmax(diff)
        digits = (digits_py, digits_r)
        if nan_mismatches or np.isnan(max_diff) or max_diff > diff_threshold:
            nan_note = f", NaN mismatches = {nan_mismatches}" if nan_mismatches else ""
            return (f"⚠️  {name}: values differ (max abs diff = {max_diff:.2e}, "
                    f"min_digits = {min_digits}, {digits}{nan_note})")
        return f"✅ {name}: match, (max diff = {max_diff:.2e}, min_digits = {min_digits}, {digits})"
    except Exception as e:
        return f"❌ {name}: error comparing DataFrames - {e}"
    


In [226]:
def compare_dicts(dict1, dict2):
    """Compare every entry of two result dicts and return one message per finding.

    Args:
        dict1: Python-side results (the notebook passes ``py_data``).
        dict2: R-side results (the notebook passes ``r_data``).

    Returns:
        A list of strings: first the keys present on only one side (sorted),
        then one verdict per shared key in sorted key order. Numbers and plain
        lists go to ``compare_arrays``; dicts and lists of records (dicts) go
        to ``compare_dataframes``.
    """
    def _is_table(value):
        # dicts and non-empty lists of dicts are tabular -> DataFrame comparison
        return isinstance(value, dict) or (
            isinstance(value, list) and len(value) > 0 and isinstance(value[0], dict)
        )

    keys_py = set(dict1)
    keys_r = set(dict2)
    # A key that only the R side has is missing from Python, and vice versa.
    missing_in_py = keys_r - keys_py
    missing_in_r = keys_py - keys_r

    results = []
    if missing_in_py:
        results.append(f"❌ Missing in Python: {sorted(missing_in_py)}")
    if missing_in_r:
        results.append(f"❌ Missing in R: {sorted(missing_in_r)}")

    array_like = (float, int, list)
    for key in sorted(keys_py & keys_r):
        v1, v2 = dict1[key], dict2[key]
        if _is_table(v1) or _is_table(v2):
            results.append(compare_dataframes(key, v1, v2))
        elif isinstance(v1, array_like) and isinstance(v2, array_like):
            results.append(compare_arrays(key, v1, v2))
        elif isinstance(v1, list) or isinstance(v2, list):
            # A list paired with a non-numeric value: let compare_dataframes
            # attempt it and report any error in its message.
            results.append(compare_dataframes(key, v1, v2))
        else:
            results.append(f"❓ {key}: unrecognized types ({type(v1)}, {type(v2)})")

    return results


In [227]:
# Only the first cohort is compared (cohorts[:1]); py_data / r_data stay bound
# to that cohort's results and are reused by the inspection cells below.
for cohort in cohorts[:1]:
    py_file = base_dir / f"{cohort}_Py_out.json"
    r_file = base_dir / f"{cohort}_D_out.json"

    print(f"\n🔍 Comparing cohort: {cohort}")
    if py_file.exists() and r_file.exists():
        py_data = load_json(py_file)
        r_data = load_json(r_file)

        result_lines = compare_dicts(py_data, r_data)
        for line in result_lines:
            print(line)
    else:
        print(f"❌ Missing file for {cohort}")


🔍 Comparing cohort: GSE129508
✅ B_hat: match, (max diff = 5.00e-05, min_digits = 8, (20, 8))
⚠️  a_prior: values differ (max abs diff = 8.04e-06)
⚠️  b_prior: values differ (max abs diff = 1.57e-05)
✅ corrected_data: match, (max diff = 5.00e-05, min_digits = 8, (20, 8))
✅ delta_hat: match
✅ delta_star: match
⚠️  gamma_bar: values differ (max abs diff = 3.31e-05)
✅ gamma_hat: match
✅ gamma_star: match
✅ mod_mean: match, (max diff = 5.00e-05, min_digits = 8, (20, 8))
✅ pooled_variance: match
✅ sigma: match
✅ stand_mean: match
⚠️  t2: values differ (max abs diff = 2.04e-05)
✅ xtx: match
✅ xty: match


In [218]:
# NOTE(review): despite the xty_* names, these hold the processed
# "pooled_variance" frames; return_df=True makes compare_dataframes return
# the two frames instead of a verdict string.
xty_py, xty_r = compare_dataframes(
    "pooled_variance",
    py_data["pooled_variance"],
    r_data["pooled_variance"],
    return_df=True,
)

here


In [219]:
# Python-side frame returned by compare_dataframes(..., return_df=True) above
# (holds "pooled_variance" despite the xty_ name).
xty_py

_row,A1BG,A1CF,A2M,A2ML1,A2MP1,A3GALT2,A4GALT,A4GNT,AAAS,AACS,...,ZZEF1,ZZZ3,BP-21201H5.1,BP-21264C1.1,BP-2171C21.2,BP-2171C21.4,BP-2171C21.5,BP-2171C21.6,BP-2189O9.2,YR211F11.2
0,1.351,1.0412,1.6807,3.9378,1.7469,1.146,1.9965,0.9262,0.1614,0.3173,...,0.5698,0.222,0.9341,1.553,0.901,1.8032,0.8627,1.2274,1.9905,0.6619


In [220]:
# R-side counterpart of xty_py from the same compare_dataframes call.
xty_r

_row,A1BG,A1CF,A2M,A2ML1,A2MP1,A3GALT2,A4GALT,A4GNT,AAAS,AACS,...,ZZEF1,ZZZ3,BP-21201H5.1,BP-21264C1.1,BP-2171C21.2,BP-2171C21.4,BP-2171C21.5,BP-2171C21.6,BP-2189O9.2,YR211F11.2
0,0.8816,1.0835,2.0873,2.6537,1.5471,1.1657,1.7701,1.1715,0.149,0.271,...,0.3227,0.1831,0.8433,1.6968,0.9011,2.1716,1.0627,1.4699,2.438,1.3389


In [215]:
# Inspect column dtypes of both frames; np.allclose below needs numeric data.
print(xty_py.dtypes, xty_r.dtypes, sep="\n")


_row
A1BG            float64
A1CF            float64
A2M             float64
A2ML1           float64
A2MP1           float64
                 ...   
BP-2171C21.4    float64
BP-2171C21.5    float64
BP-2171C21.6    float64
BP-2189O9.2     float64
YR211F11.2      float64
Length: 28823, dtype: object
_row
A1BG            float64
A1CF            float64
A2M             float64
A2ML1           float64
A2MP1           float64
                 ...   
BP-2171C21.4    float64
BP-2171C21.5    float64
BP-2171C21.6    float64
BP-2189O9.2     float64
YR211F11.2      float64
Length: 28823, dtype: object


In [216]:
# Re-run the tolerance check directly on the inspected frames.
np.allclose(
    xty_py.to_numpy(),
    xty_r.to_numpy(),
    rtol=RTOL,
    atol=ATOL,
    equal_nan=True,
)

False

In [217]:
# Row labels of the R frame; the pooled_variance reshape resets the index,
# so in the run shown this is a single '0' row.
xty_r.index

Index(['0'], dtype='object')