In [None]:
import numpy as np 
import pandas as pd 
from einops import rearrange, pack

import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Start from matplotlib's default style so the settings below are applied on
# top of a known baseline.
plt.style.use("default")
# Use font type 42 for both vector backends (PDF and PostScript).
# NOTE(review): 42 is presumably the TrueType embedding option from
# matplotlib's fonttype settings -- confirm against the matplotlibrc reference.
matplotlib.rcParams.update({"pdf.fonttype": 42, "ps.fonttype": 42})

In [None]:
# Input directory holding the result CSVs and output directory for figures
# and tables; both are relative to the notebook's working directory.
result_dir = "../results/preliminary-nondp/"
fig_dir = "../figures/synthetic-data-algo-evaluation/"

# Results for the synthetic-data generators. Later cells read the
# "model", "method", "n_syn_datasets" and "mse" columns.
df = pd.read_csv(result_dir + "results.csv", index_col=False)

# Display order of the downstream models. `full_model_order` keeps every
# model, while `model_order` leaves out "Linear Regression". The two lists
# are independent objects, so changing one never affects the other.
full_model_order = [
    "1-NN", "5-NN", "Decision Tree", "Random Forest", "MLP", "Gradient Boosting",
    "SVM", "Ridge Regression", "Linear Regression"
]
model_order = [name for name in full_model_order if name != "Linear Regression"]

# Mapping from the generator identifiers stored in the CSV to display names.
method_names = {
    "ddpm": "DDPM",
    "ddpm-kl": "DDPM-KL",
    "tvae": "TVAE",
    "ctgan": "CTGAN",
    "synthpop-proper": "SP-P",
    "synthpop-improper": "SP-IP"
}
# The dictionary lookup is deliberate: an identifier missing from
# `method_names` raises KeyError instead of silently turning into NaN.
df["method"] = df["method"].apply(lambda m: method_names[m])

# Baseline results obtained on real data, averaged per downstream model.
real_data_df = pd.read_csv(result_dir + "real-data-results.csv", index_col=False)
# numeric_only=True keeps the aggregation working when the CSV carries
# non-numeric columns: pandas >= 2.0 raises TypeError for those, while older
# versions dropped them silently.
real_data_mses = real_data_df.groupby(["model"]).mean(numeric_only=True)
# Row order follows `model_order`; models without real-data results become NaN rows.
real_data_mses = real_data_mses.reindex(model_order)
real_data_mses

In [None]:
# Lowest mean real-data MSE over the models listed in `real_data_mses`
# (NaN entries are skipped).
min_real_data_mse = real_data_mses["mse"].min()

In [None]:
# Average the numeric columns for every (model, method, n_syn_datasets)
# combination. numeric_only=True keeps this working if `df` contains other
# non-numeric columns: pandas >= 2.0 raises TypeError for those, while older
# versions dropped them silently.
group_df = df.groupby(
    ["model", "method", "n_syn_datasets"], as_index=False
).mean(numeric_only=True)
# Display the combination with the lowest mean MSE.
group_df.iloc[group_df["mse"].argmin()]

In [None]:
# One panel per downstream model (in `model_order`), four panels per row.
# Each panel shows the MSE per generator as bars, coloured by n_syn_datasets.
g = sns.FacetGrid(
    data=df, col="model", col_order=model_order, col_wrap=4, height=2.5, aspect=1.2
)
g.map_dataframe(sns.barplot, x="method", y="mse", hue="n_syn_datasets", palette="flare")

# Reference lines: the panel's own real-data MSE (dashed) and the best
# real-data MSE over all models (solid). `real_data_mses` is indexed by
# `model_order`, so its rows line up with the panels one-to-one.
for panel, model_mse in zip(g.axes, real_data_mses.mse):
    panel.axhline(model_mse, linestyle="dashed", color="black", label="Real Data")
    panel.axhline(min_real_data_mse, color="black", label="Best Real Data")

# Draw the grid behind the bars on every panel.
for panel in g.axes:
    panel.set_axisbelow(True)
    panel.grid()

g.set_xlabels("")
g.set_ylabels("MSE")
g.set_titles("{col_name}", fontweight="bold")
# The lower limit sits slightly below the best real-data MSE; the upper limit is fixed.
y_limits = (min_real_data_mse * 0.9, 0.2)
g.set(ylim=y_limits)
g.tick_params("x", labelrotation=45)

# Build the legend from the first panel's handles. Hue labels (presumably the
# n_syn_datasets values) are rewritten as "m = <value>"; the reference-line
# labels are kept unchanged.
m_values = ["1", "2", "5", "10"]
handles, labels = g.axes[0].get_legend_handles_labels()
full_legend_data = {}
for handle, label in zip(handles, labels):
    legend_key = "m = {}".format(label) if label in m_values else label
    full_legend_data[legend_key] = handle
legend_order = ["m = {}".format(m) for m in m_values] + ["Real Data", "Best Real Data"]
g.add_legend(
    full_legend_data,
    label_order=legend_order,
    loc="upper center",
    bbox_to_anchor=(0.32, -0.02),
    ncol=6,
)

plt.savefig(fig_dir + "generator-comparison.pdf", bbox_inches="tight")
plt.show()

In [None]:
# Transposed view of the comparison: one panel per generator, with the
# downstream models on the x-axis and bars coloured by n_syn_datasets.
barplot_kwargs = dict(x="model", y="mse", hue="n_syn_datasets")
g = sns.FacetGrid(data=df, col="method")
g.map_dataframe(sns.barplot, **barplot_kwargs)
# Rotate the model names so they do not overlap.
g.tick_params("x", labelrotation=90)
g.add_legend()
plt.show()

In [None]:
# Mean MSE for every (model, method, n_syn_datasets) combination, displayed
# with n_syn_datasets moved out of the index into a regular column.
(
    df.groupby(["model", "method", "n_syn_datasets"])["mse"]
    .mean()
    .reset_index("n_syn_datasets")
)

In [None]:
# Mean and standard deviation of the MSE for every
# (model, method, n_syn_datasets) combination.
table = df.groupby(["model", "method", "n_syn_datasets"])["mse"].aggregate(["mean", "std"])
# Format each cell as "mean $\pm$ std" for LaTeX. The raw string keeps the
# backslash literal: in a normal string "\p" is an invalid escape sequence,
# which newer Python versions warn about and plan to reject.
table["formatted"] = table.apply(
    lambda row: r"{:.4f} $\pm$ {:.4f}".format(row["mean"], row["std"]), axis=1
)
# Spread n_syn_datasets over the columns; rows stay indexed by (model, method).
table = table.reset_index("n_syn_datasets").pivot(columns="n_syn_datasets", values="formatted")
# Axis names used as headers in the rendered table.
table.index.names = ["Downstream", "Generator"]
table.columns.name = "m"
# Order the downstream models as in `full_model_order`.
table = table.reindex(full_model_order, level="Downstream", axis="index")
table

In [None]:
# Write the formatted comparison table to a LaTeX file in the figure directory.
latex_table_path = fig_dir + "generator-comparison-table.tex"
table.style.to_latex(latex_table_path, hrules=True, clines="skip-last;data")