## CBGM vs. scCBGM Performance Comparison

In [7]:
import wandb
import pandas as pd
import seaborn as sns
import conceptlab as clab

import matplotlib.pyplot as plt
import yaml

import os.path as osp

In [None]:
# Directory holding the cached sweep results; joined with `raw_filename` below.
OUT_DIR = "../../../results/synthetic/"
# File name of the cached CSV export of the sweep inside OUT_DIR.
raw_filename = "cbgm_vs_sccbgm_performance_sweep.csv"
# Set to True to fetch the runs again even when the cached CSV already exists.
DOWNLOAD = False

In [9]:
# The first line of wandb_info.txt holds the W&B user name, which is later
# passed as the `entity` when downloading the project.
with open('wandb_info.txt') as f:
    # Strip the line terminator (and stray whitespace) so the name is not sent
    # to W&B with a trailing "\n". An empty file still raises IndexError.
    wandb_user_name = f.readlines()[0].strip()

In [10]:
# Load the sweep results: reuse the cached CSV when it exists and no refresh
# was requested, otherwise download the runs and write the cache file.
raw_path = osp.join(OUT_DIR, raw_filename)
use_cached_csv = osp.isfile(raw_path) and not DOWNLOAD

if not use_cached_csv:
    project = "clab_performance_sweep_3"
    entity = wandb_user_name
    _df = clab.utils.wandb.download_wandb_project(project, entity)
    # Persist the download so later runs can take the cached branch.
    _df.to_csv(raw_path)
else:
    # The first CSV column is the index written by `to_csv` above.
    _df = pd.read_csv(raw_path, index_col=0, header=0)

In [11]:
# Work on a copy so the column rewrites below do not modify the loaded `_df`.
df = _df.copy()

In [12]:
# Translate the logged model identifiers into the display names used in the
# table. `Series.map` turns any identifier missing from the mapping into NaN.
_model_display_names = {
    'scCBGM': 'scCBGM',
    'cem_vae': 'CBGM',
}
df['model'] = df['model'].map(_model_display_names)

In [14]:
# Keys that define one table cell: the model and the concept-set modification.
grp_columns = ["model", "modify"]
# Metric column(s) that are summarised per group.
metric_columns = ["MSE_intervened"]

In [27]:
# Summarise the metric per (model, modify) group. The mean and the standard
# deviation come from the same groupby on the same keys, so the rows of
# `df_mean` and `df_std` line up one-to-one.
_grouped = df[grp_columns + metric_columns].groupby(grp_columns)
df_mean = _grouped.agg("mean").reset_index()
df_std = _grouped.agg("std").reset_index()

# Build the display column: "mean $\pm$ std" with three decimals each.
# The backslash is written as "\\" because "\p" is not a valid escape sequence
# in a regular (f-)string literal; the resulting text is unchanged.
df_table = df_mean.copy()
df_table["MSE"] = [
    f"{x:0.3f} $\\pm$ {y:0.3f}"
    for x, y in zip(df_mean["MSE_intervened"], df_std["MSE_intervened"])
]

In [16]:
# Print the per-group mean of MSE_intervened (one row per model/modify pair).
print(df_mean)

    model     modify  MSE_intervened
0    CBGM        add        0.248551
1    CBGM    default        0.206041
2    CBGM       drop        0.246806
3    CBGM  duplicate        0.227106
4    CBGM      noise        0.228912
5  scCBGM        add        0.199264
6  scCBGM    default        0.199276
7  scCBGM       drop        0.199873
8  scCBGM  duplicate        0.199299
9  scCBGM      noise        0.199500


In [17]:
# Only the formatted "MSE" strings go into the table; discard the raw numbers.
df_table = df_table.drop(columns='MSE_intervened')

In [18]:
# Reshape to one row per model and one column per modification. Each
# (model, modify) pair occurs exactly once after the groupby, so this is a pure
# reshape: `pivot` needs no aggregation function and raises if a pair is
# duplicated instead of silently combining the formatted strings.
df_table = df_table.pivot(index='model', columns='modify', values='MSE')

In [28]:
# Display label for each `modify` key of the sweep.
_condition_labels = {
    "add": "Irrelevant",
    "default": "Clean",
    "drop": "Missing",
    "duplicate": "Duplicated",
    "noise": "Incorrect",
}
# Left-to-right order of the columns in the final table.
_condition_order = ["Clean", "Incorrect", "Irrelevant", "Missing", "Duplicated"]

# Relabel first, then select by the new labels to fix the column order.
df_table = df_table.rename(columns=_condition_labels)[_condition_order]

In [20]:
# Wrap every cell of the scCBGM row in \textbf{...} so it renders bold in LaTeX.
_as_bold = '\\textbf{{{}}}'.format
df_table.loc['scCBGM'] = [_as_bold(cell) for cell in df_table.loc['scCBGM'].values]

In [21]:
# Clear the axis names left over from the pivot ("modify" / "model") on both
# the column and the row index.
for _axis in (df_table.columns, df_table.index):
    _axis.name = None

In [29]:
# Render the table as LaTeX with one left-aligned label column followed by one
# centred column per condition. The cells already contain LaTeX markup
# ($\pm$, \textbf{...}), so escaping is switched off explicitly instead of
# relying on the pandas default, which would otherwise be free to escape the
# backslashes, dollar signs and braces.
latex_str = df_table.to_latex(
    column_format="l" + "c" * len(df_table.columns),
    escape=False,
)

# Replace first \hline with \toprule
latex_str = latex_str.replace("\\hline", "\\toprule", 1)

# Replace last \hline with \bottomrule (no-op when no \hline is left)
_head, _rule, _tail = latex_str.rpartition("\\hline")
if _rule:
    latex_str = _head + "\\bottomrule" + _tail

In [23]:
# Print the finished LaTeX table source.
print(latex_str)

\begin{tabular}{lccccc}
\toprule
 & Clean & Incorrect & Irrelevant & Missing & Duplicated \\
\midrule
CBGM & 0.206 $\pm$ 0.048 & 0.229 $\pm$ 0.112 & 0.249 $\pm$ 0.271 & 0.247 $\pm$ 0.123 & 0.227 $\pm$ 0.120 \\
scCBGM & \textbf{0.199 $\pm$ 0.002} & \textbf{0.199 $\pm$ 0.002} & \textbf{0.199 $\pm$ 0.002} & \textbf{0.200 $\pm$ 0.002} & \textbf{0.199 $\pm$ 0.002} \\
\bottomrule
\end{tabular}

