## Benchmark of method on Synthetic Data

In [44]:
import os
import os.path as osp

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import wandb
import yaml

import conceptlab as clab

In [45]:
# The first line of wandb_info.txt is used as the W&B entity when downloading runs.
with open('wandb_info.txt', encoding='utf-8') as f:
    # strip() drops the trailing newline (and stray spaces) that would
    # otherwise become part of the entity name.
    wandb_user_name = f.readline().strip()

if not wandb_user_name:
    # Fail here with a clear message instead of later inside the W&B download.
    raise ValueError("wandb_info.txt must contain the W&B user name on its first line")

In [46]:
# Relative path of the directory that holds the cached benchmark table.
OUT_DIR = '../../../results/synthetic/'
# File name of the cached run table inside OUT_DIR.
raw_filename = 'benchmark_synthetic.csv'

# When True, the runs are downloaded again even if the cached CSV exists.
DOWNLOAD = False

In [47]:
# Load the raw run table: reuse the cached CSV when it exists (and no refresh
# is requested), otherwise download the runs from W&B and cache them.
raw_path = osp.join(OUT_DIR, raw_filename)
if osp.isfile(raw_path) and not DOWNLOAD:
    # The first CSV column is the index written by the to_csv call below.
    _df = pd.read_csv(raw_path, index_col=0, header=0)
else:
    entity, project = wandb_user_name, "clab_benchmark_2"
    _df = clab.utils.wandb.download_wandb_project(project, entity)
    # to_csv does not create missing parent directories, so make sure the
    # output directory exists before writing the cache.
    os.makedirs(OUT_DIR, exist_ok=True)
    _df.to_csv(raw_path)

In [48]:
# Work on a copy so the loaded table `_df` is not modified by the steps below.
df = _df.copy()

In [49]:
# Keep only the runs of the models compared in this benchmark.
kept_models = ["cbm", "biolord", "cinemaot", "cbmfm", "cbmfm_raw"]
# .copy() makes the filtered frame independent of the unfiltered one, so the
# label assignments via df.loc further down do not raise SettingWithCopyWarning.
df = df.loc[df["model"].isin(kept_models)].copy()

In [50]:
# Split the "cbmfm" and "cbmfm_raw" runs into separate entries depending on
# their `model.edit` flag. Rows whose flag is neither True nor False keep
# their current label.
relabel_rules = [
    ("cbmfm", True, "scCBGM-FM (edit)"),
    ("cbmfm", False, "scCBGM-FM (decode)"),
    ("cbmfm_raw", True, "Vanilla-FM (edit)"),
    ("cbmfm_raw", False, "Vanilla-FM (decode)"),
]
for raw_name, edit_flag, label in relabel_rules:
    selected = (df["model"] == raw_name) & (df["model.edit"] == edit_flag)
    df.loc[selected, "model"] = label

In [51]:
# Display names for the remaining raw model identifiers; any label that is not
# listed here is kept as it is.
display_names = {
    "cbm": "scCBGM",
    "biolord": "biolord",
    "cinemaot": "CINEMA-OT",
}
df["model"] = df["model"].map(lambda name: display_names.get(name, name))

In [52]:
# Show the filtered and relabelled runs.
df

Unnamed: 0,data,seed,model,model.edit,data/intervention_labels,_runtime,_step,_timestamp,_wandb,cell_level_mse
0,synthetic_1,69,scCBGM-FM (decode),False,synthetic/intervention_0,,,,,
1,synthetic_1,13,scCBGM,,synthetic/intervention_0,279.266422,0.0,1.758585e+09,{'runtime': 278},0.049712
2,synthetic_1,42,scCBGM,,synthetic/intervention_0,277.137271,0.0,1.758585e+09,{'runtime': 276},0.042684
3,synthetic_1,1337,scCBGM,,synthetic/intervention_0,277.107152,0.0,1.758585e+09,{'runtime': 276},0.050928
4,synthetic_1,69,scCBGM,,synthetic/intervention_0,278.937991,0.0,1.758585e+09,{'runtime': 278},0.043233
...,...,...,...,...,...,...,...,...,...,...
416,synthetic_3,69,scCBGM-FM (edit),True,synthetic/intervention_4,476.275695,0.0,1.758653e+09,{'runtime': 475},0.004026
417,synthetic_3,13,scCBGM-FM (decode),False,synthetic/intervention_4,471.325262,0.0,1.758653e+09,{'runtime': 471},0.051331
418,synthetic_3,42,scCBGM-FM (decode),False,synthetic/intervention_4,464.643063,0.0,1.758653e+09,{'runtime': 464},0.054233
419,synthetic_3,1337,scCBGM-FM (decode),False,synthetic/intervention_4,466.602726,0.0,1.758653e+09,{'runtime': 466},0.053578


In [53]:
# Give the metric column the name used in the summary table.
df = df.rename(columns={"cell_level_mse": "MSE_intervened"})

In [54]:
# Columns that define one group of runs (table rows = model, table columns = data).
grp_columns = ['model','data']
# Metric column(s) aggregated over the runs of each group.
metric_columns = ['MSE_intervened']


In [55]:
def _formatted_pivot(stat):
    """Aggregate the metric per (model, data) group with `stat` and return a
    model x data table whose float entries are formatted to four decimals.

    The returned frame contains strings, not numbers.
    """
    per_group = df[grp_columns + metric_columns].groupby(grp_columns).agg(stat)
    wide = per_group.reset_index().pivot_table(
        index="model", columns="data", values="MSE_intervened"
    )
    return wide.map(lambda v: f"{v:.4f}" if isinstance(v, float) else str(v))


# Per-group mean and standard deviation, both already formatted as strings.
df_mean = _formatted_pivot("mean")
df_std = _formatted_pivot("std")

# Each cell becomes "<mean>$\pm$<std>".
df_table = df_mean.astype(str) + "$\\pm$" + df_std.astype(str)

In [56]:
# Show the formatted "mean$\pm$std" table (rows: model, columns: data).
df_table

data,synthetic_1,synthetic_2,synthetic_3
model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
CINEMA-OT,0.0390$\pm$0.0025,0.0393$\pm$0.0012,0.0400$\pm$0.0034
Vanilla-FM (decode),0.0756$\pm$0.0014,0.0769$\pm$0.0014,0.0773$\pm$0.0041
Vanilla-FM (edit),0.0031$\pm$0.0014,0.0049$\pm$0.0016,0.0049$\pm$0.0027
biolord,0.0402$\pm$0.0014,0.0411$\pm$0.0021,0.0428$\pm$0.0021
scCBGM,0.0388$\pm$0.0053,0.0318$\pm$0.0071,0.0407$\pm$0.0027
scCBGM-FM (decode),0.0539$\pm$0.0039,0.0444$\pm$0.0082,0.0594$\pm$0.0092
scCBGM-FM (edit),0.0025$\pm$0.0016,0.0028$\pm$0.0014,0.0039$\pm$0.0023


In [57]:
# Row order of the final table; .loc raises KeyError if a listed model is missing.
method_order = [
    "scCBGM",
    "scCBGM-FM (edit)",
    "scCBGM-FM (decode)",
    "Vanilla-FM (edit)",
    "Vanilla-FM (decode)",
    "CINEMA-OT",
    "biolord",
]
df_table = df_table.loc[method_order]

In [58]:
def bolden_by_table(
    table_to_bolden: pd.DataFrame, values_to_bolden_by: pd.DataFrame
) -> pd.DataFrame:
    r"""
    Wrap in ``\textbf{...}`` the entry of each column of `table_to_bolden` whose
    corresponding value in `values_to_bolden_by` is the column minimum.

    Parameters
    ----------
    table_to_bolden
        Table whose cells are wrapped; it is not modified.
    values_to_bolden_by
        Values used for ranking. It is reindexed to the labels of
        `table_to_bolden`, so row/column order may differ. Entries are
        converted to numbers before comparison (numeric strings such as
        ``"0.0025"`` are parsed); anything that cannot be parsed is treated
        as missing.

    Returns
    -------
    pd.DataFrame
        A copy of `table_to_bolden` with at most one cell per column wrapped.
        Columns without any usable value, and empty tables, are returned
        unchanged. On ties the first minimum in row order is chosen.
    """
    # Align shapes/labels to avoid misalignment surprises
    vals = values_to_bolden_by.reindex_like(table_to_bolden)
    # Compare numerically: on strings the minimum would be lexicographic
    # (e.g. "10.0000" < "2.0000"), and unparsable cells become NaN.
    numeric = vals.apply(pd.to_numeric, errors="coerce").to_numpy(dtype=float)
    new_table = table_to_bolden.copy()

    for col_i in range(new_table.shape[1]):
        column = numeric[:, col_i]
        # Nothing to rank in an empty or all-missing column.
        if column.size == 0 or np.isnan(column).all():
            continue

        # nanargmin ignores missing values instead of letting NaN win.
        r = int(np.nanargmin(column))
        s = str(new_table.iat[r, col_i])

        # Avoid double-wrapping if already bold
        if not s.startswith("\\textbf{"):
            new_table.iat[r, col_i] = f"\\textbf{{{s}}}"

    return new_table

In [59]:
# Bold, per data column, the entry that bolden_by_table ranks lowest by df_mean.
# NOTE(review): df_mean was converted to 4-decimal strings when the table was
# built, so the ranking input here is strings rather than floats - confirm
# this is intended.
df_table = bolden_by_table(df_table, df_mean)

In [60]:
# Derive the header labels from the dataset identifiers
# ("synthetic_1" -> "Synthetic 1") instead of hard-coding them, so the labels
# cannot get out of sync with the number or order of the columns.
pretty = [str(col).replace("_", " ") for col in df_table.columns]
df_table.columns = [label[:1].upper() + label[1:] for label in pretty]
# Drop the index name so no extra header cell is produced for it.
df_table.index.name = None

In [61]:
# One left-aligned label column followed by one centred column per dataset.
# The cells already contain raw LaTeX (`$\pm$`, `\textbf{...}`), so escaping is
# switched off explicitly rather than relying on the pandas default/config.
latex_str = df_table.to_latex(
    column_format="l" + "c" * df_table.shape[1],
    escape=False,
)

In [62]:
# Turn the outermost \hline rules into \toprule / \bottomrule. Text that
# contains no \hline is left unchanged by both steps.

# First \hline -> \toprule
latex_str = latex_str.replace("\\hline", "\\toprule", 1)

# Last remaining \hline -> \bottomrule (rsplit cuts at the right-most match only)
latex_str = "\\bottomrule".join(latex_str.rsplit("\\hline", 1))

In [63]:
# Print the LaTeX source of the table.
print(latex_str)

\begin{tabular}{lccc}
\toprule
 & Synthetic 1 & Synthetic 2 & Synthetic 3 \\
\midrule
scCBGM & 0.0388$\pm$0.0053 & 0.0318$\pm$0.0071 & 0.0407$\pm$0.0027 \\
scCBGM-FM (edit) & \textbf{0.0025$\pm$0.0016} & \textbf{0.0028$\pm$0.0014} & \textbf{0.0039$\pm$0.0023} \\
scCBGM-FM (decode) & 0.0539$\pm$0.0039 & 0.0444$\pm$0.0082 & 0.0594$\pm$0.0092 \\
Vanilla-FM (edit) & 0.0031$\pm$0.0014 & 0.0049$\pm$0.0016 & 0.0049$\pm$0.0027 \\
Vanilla-FM (decode) & 0.0756$\pm$0.0014 & 0.0769$\pm$0.0014 & 0.0773$\pm$0.0041 \\
CINEMA-OT & 0.0390$\pm$0.0025 & 0.0393$\pm$0.0012 & 0.0400$\pm$0.0034 \\
biolord & 0.0402$\pm$0.0014 & 0.0411$\pm$0.0021 & 0.0428$\pm$0.0021 \\
\bottomrule
\end{tabular}

