## CBGM vs scCBGM Ablation Study Results

In [20]:
import os
import os.path as osp

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import wandb
import yaml

import conceptlab as clab

In [21]:
# Location of the cached W&B export for this ablation.
OUT_DIR: str = '../../../results/synthetic/'
raw_filename: str = 'sccbgm_ablation.csv'
# Set to True to ignore the cached CSV and re-download the runs.
DOWNLOAD: bool = False

In [22]:
# The first line of wandb_info.txt holds the W&B entity name used for the
# download below. Strip surrounding whitespace: the raw line ends with "\n",
# which would otherwise be passed verbatim as the entity.
with open('wandb_info.txt', encoding='utf-8') as f:
    wandb_user_name = f.readline().strip()
if not wandb_user_name:
    raise ValueError(
        "wandb_info.txt is empty; expected the W&B entity name on its first line"
    )

In [23]:
# Load the cached run table when it exists (and no re-download is forced).
# Otherwise fetch the project via clab.utils.wandb.download_wandb_project and
# cache the result as CSV for later runs.
raw_path = osp.join(OUT_DIR, raw_filename)
if osp.isfile(raw_path) and not DOWNLOAD:
    _df = pd.read_csv(raw_path, index_col=0, header=0)
else:
    entity, project = wandb_user_name, "clab_performance_ablation_2"
    _df = clab.utils.wandb.download_wandb_project(project, entity)
    # DataFrame.to_csv does not create missing parent directories.
    os.makedirs(OUT_DIR, exist_ok=True)
    _df.to_csv(raw_path)

In [24]:
# Work on a deep copy so the raw loaded table is never modified below.
df = _df.copy(deep=True)

In [25]:
# Relabel models for display: runs logged as 'cem_vae' are reported as 'CBGM'.
# Series.replace leaves labels not listed here unchanged; Series.map with a
# dict would silently turn them into NaN.
df['model'] = df['model'].replace({'scCBGM': 'scCBGM', 'cem_vae': 'CBGM'})

In [26]:
# Ablation factors that identify one table row ...
grp_columns = [
    'model.orthogonality_hp',
    'model.decoder_type',
]
# ... and the metric summarised for each row.
metric_columns = ['MSE_intervened']

In [27]:
# Aggregate the metric over all runs that share an ablation setting:
# per-group mean and standard deviation.
_grouped = df[grp_columns + metric_columns].groupby(grp_columns)
df_mean = _grouped.agg("mean").reset_index()
df_std = _grouped.agg("std").reset_index()

# Both aggregates come from the same groupby, so their rows are aligned and can
# be zipped. A raw f-string keeps the LaTeX "\pm" literal instead of relying on
# the invalid "\p" escape sequence (a SyntaxWarning since Python 3.12).
df_table = df_mean.copy()
df_table["MSE"] = [
    rf"{mean:0.5f} $\pm$ {std:0.5f}"
    for mean, std in zip(df_mean["MSE_intervened"], df_std["MSE_intervened"])
]

In [28]:
# Display the per-setting mean metric values.
print(f"{df_mean}")

   model.orthogonality_hp model.decoder_type  MSE_intervened
0                     0.0           residual        0.198859
1                     0.0               skip        0.198786
2                     0.5           residual        0.198713
3                     0.5               skip        0.198635


In [29]:
# The raw mean column is redundant once "MSE" holds the formatted mean ± std.
df_table = df_table.drop(columns=['MSE_intervened'])

In [30]:
# Render each ablation factor as a LaTeX cross/check mark
# (orthogonality_hp: 0.0 -> \xmark, 0.5 -> \cmark; decoder_type:
# residual -> \xmark, skip -> \cmark; any other value becomes NaN),
# then give the columns their table headers.
df_table = df_table.assign(
    **{
        "model.orthogonality_hp": df_table["model.orthogonality_hp"].map(
            {0.0: "\\xmark", 0.5: "\\cmark"}
        ),
        "model.decoder_type": df_table["model.decoder_type"].map(
            {"residual": "\\xmark", "skip": "\\cmark"}
        ),
    }
).rename(
    columns={
        "model.orthogonality_hp": r"$\mathcal{L}_{cc}$",
        "model.decoder_type": "skip",
    }
)

In [31]:
# Order rows by mean MSE, largest first. "MSE" holds formatted strings such as
# "0.19886 $\pm$ 0.00045", so sort on the parsed leading mean rather than
# lexicographically (string order breaks once values differ in digit count).
df_table = df_table.sort_values(
    'MSE',
    ascending=False,
    key=lambda col: col.str.split(' ', n=1).str[0].astype(float),
)

In [32]:
# Preview the table before exporting it to LaTeX.
print(str(df_table))

  $\mathcal{L}_{cc}$    skip                    MSE
0             \xmark  \xmark  0.19886 $\pm$ 0.00045
1             \xmark  \cmark  0.19879 $\pm$ 0.00051
2             \cmark  \xmark  0.19871 $\pm$ 0.00276
3             \cmark  \cmark  0.19864 $\pm$ 0.00294


In [33]:
# Mark the metric column as lower-is-better in its header.
df_table = df_table.rename(columns={'MSE': r'MSE ($\downarrow$)'})

In [34]:
# Render the table as a LaTeX tabular. With index=False the tabular has exactly
# df_table.shape[1] columns, so emit one spec per column: the first
# left-aligned, the rest centred.
# escape=False keeps the LaTeX markup in headers and cells (\xmark, $\pm$, ...)
# intact regardless of the pandas version's default escaping.
latex_str = df_table.to_latex(
    index=False,
    index_names=False,
    escape=False,
    column_format="l" + "c" * (df_table.shape[1] - 1),
)

In [35]:
# If the LaTeX contains \hline rules, turn the first into \toprule and the last
# into \bottomrule. When to_latex already emits booktabs rules (as in the
# output shown), this is a no-op.
latex_str = latex_str.replace('\\hline', '\\toprule', 1)
# rsplit(..., 1) splits at the last occurrence only; joining the pieces with
# \bottomrule swaps just that one (and leaves the string unchanged if absent).
latex_str = '\\bottomrule'.join(latex_str.rsplit('\\hline', 1))



In [18]:
# Emit the final LaTeX table for pasting into the manuscript.
print(latex_str, end="\n")

\begin{tabular}{lccc}
\toprule
$\mathcal{L}_{cc}$ & skip & MSE ($\downarrow$) \\
\midrule
\xmark & \xmark & 0.19886 $\pm$ 0.00045 \\
\xmark & \cmark & 0.19879 $\pm$ 0.00051 \\
\cmark & \xmark & 0.19871 $\pm$ 0.00276 \\
\cmark & \cmark & 0.19864 $\pm$ 0.00294 \\
\bottomrule
\end{tabular}

