In [None]:
import numpy as np 
import pandas as pd 
from einops import rearrange, pack

import matplotlib.pyplot as plt
import matplotlib
import seaborn as sns

In [None]:
# Start from matplotlib's stock style and store fonts as Type 42 (TrueType) in
# both vector backends, so text in the exported PDF/PS figures stays editable.
plt.style.use("default")
matplotlib.rcParams.update({
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
})

In [None]:
# ---- Input / output locations (relative to the notebook) ----
result_dir = "../results/classification-datasets/"
fig_dir = "../figures/classification/"

# ---- Datasets ----
# Sub-directory names under `result_dir`, in the order they are processed.
datasets = ["german-credit", "adult", "breast-cancer"]

# Dataset key -> display label; the insertion order also fixes the facet order
# used when plotting several datasets side by side.
dataset_names = {
    "adult": "Adult",
    "breast-cancer": "Breast Cancer",
    "german-credit": "German Credit",
}
inv_dataset_names = {label: key for key, label in dataset_names.items()}
dataset_order = [*dataset_names.values()]

# ---- Metrics ----
# Column name -> axis label. Raw "accuracy" and "auc" are intentionally left
# out; their complements ("accuracy_comp", "auc_comp") are reported instead.
metric_names = {
    "brier": "Brier Score",
    "log_loss": "Cross Entropy",
    "accuracy_comp": "1 - Accuracy",
    "auc_comp": "1 - AUC",
}
metrics = metric_names.keys()

# ---- Downstream models: label used in the result files -> short plot label ----
model_names = {
    "Logistic Regression": "LogR",
    "1-NN": "1-NN",
    "5-NN": "5-NN",
    "Decision Tree": "DT",
    "Random Forest": "RF",
    "Gradient Boosting": "GB",
    "MLP": "MLP",
    "SVM": "SVM",
}

# ---- Synthetic data generators: label used in the result files -> plot label ----
method_names = {
    "ddpm": "DDPM",
    "synthpop-proper": "SP-P",
}

def _add_complement_metrics(frame):
    """Add ``accuracy_comp`` and ``auc_comp`` columns to ``frame`` in place.

    The new columns are ``1 - accuracy`` and ``1 - auc``, i.e. the metrics
    listed in ``metric_names`` under those keys. Returns ``frame`` so the call
    can be used inline.
    """
    frame["accuracy_comp"] = 1 - frame["accuracy"]
    frame["auc_comp"] = 1 - frame["auc"]
    return frame


# Load, per dataset, the synthetic-data results (`dfs`) and the real-data
# reference results (`real_data_dfs`), and derive the columns used for plotting.
dfs = {}
real_data_dfs = {}
for dataset in datasets:
    dataset_dir = "{}{}/".format(result_dir, dataset)

    syn_df = pd.read_csv(dataset_dir + "results.csv", index_col=False)
    # Plain dictionary indexing (not Series.map(dict)) so that an unexpected
    # model / method label raises KeyError instead of silently becoming NaN.
    syn_df["model_short"] = syn_df["model"].apply(lambda m: model_names[m])
    syn_df["method"] = syn_df["method"].apply(lambda m: method_names[m])
    syn_df["dataset"] = dataset_names[dataset]
    # "Primal" is shown as "Prob. Avg."; every other value as "Log Prob. Avg.".
    syn_df["primal"] = np.where(syn_df["primal"] == "Primal", "Prob. Avg.", "Log Prob. Avg.")
    syn_df["method_primal"] = syn_df["method"] + " - " + syn_df["primal"]
    dfs[dataset] = _add_complement_metrics(syn_df)

    real_df = pd.read_csv(dataset_dir + "real-data-results.csv", index_col=False)
    real_data_dfs[dataset] = _add_complement_metrics(real_df)

# One frame with all datasets stacked; the "dataset" column tells them apart.
df_all_datasets = pd.concat(dfs, ignore_index=True)

# Long ("tidy") version of every result frame: one row per
# (run, metric) pair, with the metric name in "variable" and its value in "value".
long_id_columns = [
    "repeat_ind", "model", "model_short", "n_syn_datasets",
    "method", "dataset", "primal", "method_primal",
]
long_value_columns = ["brier", "log_loss", "accuracy_comp", "auc_comp"]
long_dfs = {
    dataset_key: wide_df.melt(id_vars=long_id_columns, value_vars=long_value_columns)
    for dataset_key, wide_df in dfs.items()
}

# Fixed display orders (instead of the order of appearance in the data).
# `model_order_short` and `model_order` list the same models in the same order.
model_order_short = ["1-NN", "5-NN", "DT", "RF", "MLP", "GB", "SVM", "LogR"]
model_order = [
    "1-NN",
    "5-NN",
    "Decision Tree",
    "Random Forest",
    "MLP",
    "Gradient Boosting",
    "SVM",
    "Logistic Regression",
]

# Metric order and number of repeats are both read off the German Credit results.
metric_order = long_dfs["german-credit"]["variable"].unique().tolist()
n_repeats = len(dfs["german-credit"]["repeat_ind"].unique())

# Reference values computed on the real data, indexed as [metric][dataset]:
#   real_data_metrics     -> per-model mean of the metric, rows ordered like `model_order`
#   min_real_data_metrics -> the smallest of those per-model means
real_data_metrics = {}
min_real_data_metrics = {}
for metric in metrics:
    real_data_metrics[metric] = {}
    min_real_data_metrics[metric] = {}

for dataset in datasets:
    real_df = real_data_dfs[dataset]
    for metric in metrics:
        metric_df = real_df.groupby("model")[[metric]].mean()
        real_data_metrics[metric][dataset] = metric_df.reindex(model_order)
        model_means = metric_df[metric]
        min_real_data_metrics[metric][dataset] = model_means.iloc[model_means.argmin()]

In [None]:
# For every dataset and metric, print the (model, method, m, primal)
# configuration whose mean over repeats is lowest.
config_columns = ["model", "method", "n_syn_datasets", "primal"]
report_columns = ["model", "method", "primal", "n_syn_datasets", "brier", "log_loss", "accuracy", "auc"]
for dataset in datasets:
    group_df = dfs[dataset].groupby(config_columns, as_index=False).mean(numeric_only=True)
    for metric in metrics:
        best_position = group_df[metric].argmin()
        print(f"Lowest {metric_names[metric]} for {dataset_names[dataset]}")
        print(group_df.iloc[best_position][report_columns])
        print()

In [None]:
def _first_or_nan(values):
    """Return the first element of ``values``, or NaN when ``values`` is empty."""
    return values.iloc[0] if len(values) > 0 else np.nan


# For every (model, method, primal, repeat) combination, record the Brier score
# measured with one synthetic dataset ("mse_1") and the reduction term estimated
# from the scores at m = 1 and m = 2:  red_estimate = 2 * (mse_1 - mse_2).
# `estimate_mse` later extrapolates from these two numbers to other values of m.
# Combinations with no matching rows get NaN.
estimation_dfs = {}
for dataset in datasets:
    records = []
    df = dfs[dataset]
    # Each filter is applied at the loop level where its key changes, so the
    # innermost loop only has to select by repeat index.
    for model in model_names.keys():
        model_df = df[df.model == model]
        for method in method_names.values():
            method_df = model_df[model_df.method == method]
            for primal in ["Prob. Avg.", "Log Prob. Avg."]:
                primal_df = method_df[method_df.primal == primal]
                for repeat_ind in range(n_repeats):
                    sdf = primal_df[primal_df.repeat_ind == repeat_ind]

                    mse_1 = _first_or_nan(sdf[sdf.n_syn_datasets == 1]["brier"])
                    mse_2 = _first_or_nan(sdf[sdf.n_syn_datasets == 2]["brier"])
                    records.append({
                        "model": model,
                        "method": method,
                        "primal": primal,
                        "repeat_ind": repeat_ind,
                        "mse_1": mse_1,
                        "red_estimate": 2 * (mse_1 - mse_2)
                    })
    estimation_dfs[dataset] = pd.DataFrame.from_records(records)

def estimate_mse(row, dataset):
    """Predict the Brier score of one result row from its m = 1 and m = 2 measurements.

    Looks up the entry of ``estimation_dfs[dataset]`` that has the same model,
    method, primal and repeat index as ``row`` and returns

        mse_1 - (1 - 1 / m) * red_estimate,    with m = row.n_syn_datasets.

    Parameters
    ----------
    row : object exposing ``model``, ``method``, ``primal``, ``repeat_ind`` and
        ``n_syn_datasets`` as attributes (e.g. a row of a results frame).
    dataset : key into ``estimation_dfs``.

    Raises
    ------
    IndexError
        If ``estimation_dfs[dataset]`` has no entry matching ``row``.
    """
    lookup = estimation_dfs[dataset]
    same_model = lookup.model == row.model
    same_method = lookup.method == row.method
    same_primal = lookup.primal == row.primal
    same_repeat = lookup.repeat_ind == row.repeat_ind
    matches = lookup[same_model & same_method & same_primal & same_repeat]

    baseline_mse = matches.mse_1.iloc[0]
    reduction = matches.red_estimate.iloc[0]
    # Fraction of the reduction term realised with m synthetic datasets (0 for m = 1).
    realised_fraction = 1 - 1 / row.n_syn_datasets
    return baseline_mse - realised_fraction * reduction

# Add the predicted Brier score ("est_mse") to every result frame. `assign`
# returns a new frame, so frames derived earlier from `dfs` do not get the column.
for dataset in datasets:
    est_mse_column = dfs[dataset].apply(estimate_mse, axis=1, args=(dataset,))
    dfs[dataset] = dfs[dataset].assign(est_mse=est_mse_column)

In [None]:
def plot_by_model(df, dataset, metric, save=False):
    """Bar plot of ``metric`` with one panel per downstream model.

    Within a panel the bars are grouped by generator / averaging combination
    (``method_primal``) and coloured by the number of synthetic datasets.
    Each panel also gets two horizontal reference lines taken from the
    real-data results: a grey dashed line at that model's mean and a black line
    at the smallest mean over all models.

    Parameters
    ----------
    df : results frame for ``dataset``.
    dataset : dataset key (used for the title, the reference lines and the file name).
    metric : name of the column to plot.
    save : when True, also write ``<fig_dir><dataset>-<metric>-by-model.pdf``.
    """
    grid = sns.FacetGrid(df, col="model_short", col_order=model_order_short, col_wrap=5)
    grid.figure.suptitle(dataset_names[dataset])
    grid.map_dataframe(sns.barplot, x="method_primal", y=metric, hue="n_syn_datasets")

    # Panels follow `model_order_short` and the reference values follow
    # `model_order`; both list the same models in the same order.
    best_real_value = min_real_data_metrics[metric][dataset]
    per_model_real_values = real_data_metrics[metric][dataset][metric]
    for panel_idx, real_value in enumerate(per_model_real_values):
        panel = grid.axes[panel_idx]
        panel.axhline(real_value, linestyle="dashed", color="grey")
        panel.axhline(best_real_value, color="black")

    for panel in grid.axes:
        panel.set_axisbelow(True)  # draw the grid behind the bars
        panel.grid()

    grid.set_xlabels("")
    grid.set_ylabels(metric_names[metric])
    grid.tick_params("x", labelrotation=90)
    grid.add_legend()

    if save:
        out_path = "{}{}-{}-by-model.pdf".format(fig_dir, dataset, metric)
        plt.savefig(out_path, bbox_inches="tight")
    plt.show()

def plot_by_method(df, dataset, metric, save=False, selected_primal=None, selected_method=None):
    """Bar plot of ``metric`` with one panel per generator / averaging combination.

    Within a panel the bars are grouped by downstream model and coloured by the
    number of synthetic datasets (legend "m"). A black horizontal line marks the
    smallest real-data mean of ``metric`` over all models, and the y-axis starts
    at 90 % of that value.

    Parameters
    ----------
    df : results frame for ``dataset``.
    dataset : dataset key (used for the reference line and the file name).
    metric : name of the column to plot.
    save : when True, also write ``<fig_dir><dataset>-<metric>-by-method.pdf``.
    selected_primal : optional container of ``primal`` values to keep.
    selected_method : optional container of ``method`` values to keep.
    """
    # Apply the optional row filters; a filter set to None keeps every row.
    for column, allowed in (("primal", selected_primal), ("method", selected_method)):
        if allowed is not None:
            df = df[df[column].apply(lambda val: val in allowed)]

    reference = min_real_data_metrics[metric][dataset]

    grid = sns.FacetGrid(df, col="method_primal", height=2.2, aspect=1.2)
    grid.map_dataframe(
        sns.barplot,
        x="model_short",
        y=metric,
        order=model_order_short,
        hue="n_syn_datasets",
        palette="flare",
        errwidth=0.7,
    )
    for panel in grid.axes.flatten():
        panel.axhline(reference, color="black")
        panel.set_axisbelow(True)  # draw the grid behind the bars
        panel.grid()

    grid.set(ylim=(reference * 0.9, None))
    grid.set_xlabels("")
    grid.set_ylabels(metric_names[metric])
    grid.set_titles("{col_name}", fontweight="bold")
    grid.tick_params("x", labelrotation=45)
    grid.add_legend(title="m")

    if save:
        out_path = "{}{}-{}-by-method.pdf".format(fig_dir, dataset, metric)
        plt.savefig(out_path, bbox_inches="tight")
    plt.show()

def plot_by_dataset(df, selected_method, selected_primal, metric, save=False, file_suffix=""):
    """Bar plot of ``metric`` with one panel per dataset.

    Only rows whose ``method`` is in ``selected_method`` and whose ``primal`` is
    in ``selected_primal`` are plotted. Within a panel the bars are grouped by
    downstream model and coloured by the number of synthetic datasets (legend
    "m"). Every panel has its own y-axis, a black line at the smallest real-data
    mean of ``metric`` for that dataset, and a lower y-limit at 90 % of that value.

    Parameters
    ----------
    df : results frame with a "dataset" column holding display names.
    selected_method : container of ``method`` values to keep.
    selected_primal : container of ``primal`` values to keep.
    metric : name of the column to plot.
    save : when True, also write ``<fig_dir><metric>-by-dataset<file_suffix>.pdf``.
    file_suffix : text appended to the file name before the extension.
    """
    for column, allowed in (("method", selected_method), ("primal", selected_primal)):
        df = df[df[column].apply(lambda val: val in allowed)]

    grid = sns.FacetGrid(df, col="dataset", height=2.2, aspect=1.2, sharey=False, col_order=dataset_order)
    grid.map_dataframe(
        sns.barplot,
        x="model_short",
        y=metric,
        order=model_order_short,
        hue="n_syn_datasets",
        palette="flare",
        errwidth=1.2,
    )

    # Panels are keyed by display name; map back to the dataset key used by
    # the reference-value dictionaries.
    for display_name, panel in grid.axes_dict.items():
        reference = min_real_data_metrics[metric][inv_dataset_names[display_name]]
        panel.axhline(reference, color="black")
        panel.set_ylim((reference * 0.9, None))

    for panel in grid.axes.flatten():
        panel.set_axisbelow(True)  # draw the grid behind the bars
        panel.grid()

    grid.tick_params("x", labelrotation=45)
    grid.set_ylabels(metric_names[metric])
    grid.set_xlabels("")
    grid.add_legend(title="m")
    grid.set_titles("{col_name}", fontweight="bold")

    if save:
        out_path = "{}{}-by-dataset{}.pdf".format(fig_dir, metric, file_suffix)
        plt.savefig(out_path, bbox_inches="tight")

    plt.show()

def plot_all_metrics(long_df, dataset, save=False):
    """Grid of bar plots: one row per metric, one column per generator / averaging combination.

    Within a panel the bars are grouped by downstream model and coloured by the
    number of synthetic datasets (legend "m"). Every panel gets a black line at
    the smallest real-data mean of its row's metric and a lower y-limit at 90 %
    of that value; the "log_loss" row uses a logarithmic y-axis.

    Parameters
    ----------
    long_df : long-format results frame (columns "variable" and "value").
    dataset : dataset key (used for the reference lines and the file name).
    save : when True, also write ``<fig_dir><dataset>-all-metrics.pdf``.
    """
    grid = sns.FacetGrid(long_df, col="method_primal", row="variable", sharey="row", aspect=1.2)
    grid.map_dataframe(
        sns.barplot,
        x="model_short",
        y="value",
        order=model_order_short,
        hue="n_syn_datasets",
        palette="flare",
        errwidth=1.9,
    )

    # NOTE(review): rows are matched to metrics by position, which assumes the
    # facet rows come out in the same order as `metric_order` - confirm.
    for row_idx, panel_row in enumerate(grid.axes):
        reference = min_real_data_metrics[metric_order[row_idx]][dataset]
        for panel in panel_row:
            panel.axhline(reference, color="black")
            panel.set_ylim((reference * 0.9, None))

    for panel in grid.axes.flatten():
        panel.set_axisbelow(True)  # draw the grid behind the bars
        panel.grid()
    grid.set_xlabels("")

    # Facet keys are (row value, column value) pairs, i.e. (metric, method_primal).
    for facet_key, panel in grid.axes_dict.items():
        if facet_key[0] == "log_loss":
            panel.set_yscale("log")

    # Label only the first column with the metric's display name.
    for row_idx, panel in enumerate(grid.axes[:, 0]):
        panel.set_ylabel(metric_names[metric_order[row_idx]])

    grid.set_titles("{col_name}", fontweight="bold")
    grid.tick_params("x", labelrotation=45)
    grid.add_legend(title="m")

    if save:
        out_path = "{}{}-all-metrics.pdf".format(fig_dir, dataset)
        plt.savefig(out_path, bbox_inches="tight")
    plt.show()

def plot_mse_est(df, dataset, plot_order=model_order, save=False, file_suffix=""):
    """Line plots of measured vs. predicted Brier score against the number of synthetic datasets.

    Only rows with ``primal == "Prob. Avg."`` are used. There is one panel per
    downstream model; each panel shows, per generator, the measured score
    ("brier") and the predicted score ("est_mse", dashed).

    Parameters
    ----------
    df : results frame that already contains the "est_mse" column.
    dataset : dataset key (only used for the file name).
    plot_order : models to show, in panel order. More than five models are
        wrapped into rows of four panels; otherwise all panels share one row.
    save : when True, also write ``<fig_dir><dataset>-mse-est<file_suffix>.pdf``.
    file_suffix : text appended to the file name before the extension.

    Raises
    ------
    KeyError
        If the plotted data does not produce legend entries for both "DDPM"
        and "SP-P" (the legend labels are hard-coded).
    """
    prob_avg_df = df[df.primal == "Prob. Avg."]
    if len(plot_order) > 5:
        num_cols = 4
    else:
        num_cols = len(plot_order)
    grid = sns.FacetGrid(prob_avg_df, col="model", col_wrap=num_cols, col_order=plot_order, height=2.2, aspect=1.4)

    # First layer: measured scores. The legend handles are copied right away
    # because the grid's (private) `_legend_data` is read again after the
    # second layer has been drawn.
    grid.map_dataframe(
        sns.lineplot,
        x="n_syn_datasets",
        y="brier",
        hue="method",
        style="method",
        markers=True,
    )
    legend_entries = {f"{name} Measured": handle for name, handle in grid._legend_data.items()}

    # Second layer: predicted scores, dashed and with their own colours / markers.
    grid.map_dataframe(
        sns.lineplot,
        x="n_syn_datasets",
        y="est_mse",
        hue="method",
        style="method",
        linestyle="dashed",
        palette=["C2", "C3"],
        err_style="band",
        markers=["^", "v"],
    )
    for name, handle in grid._legend_data.items():
        legend_entries[f"{name} Predicted"] = handle
    # Make the legend handles of the predicted lines dashed as well.
    for label in ("DDPM Predicted", "SP-P Predicted"):
        legend_entries[label].set_linestyle("dashed")

    grid.add_legend(
        legend_entries,
        label_order=["DDPM Measured", "DDPM Predicted", "SP-P Measured", "SP-P Predicted"],
        ncol=6,
        loc="upper right",
        bbox_to_anchor=(0.52, 0),
    )
    for panel in grid.axes.flatten():
        panel.grid()

    grid.set_ylabels("Brier Score")
    grid.set_xlabels("m (# Synthetic Datasets)")
    grid.set_titles("{col_name}", fontweight="bold")

    if save:
        out_path = "{}{}-mse-est{}.pdf".format(fig_dir, dataset, file_suffix)
        plt.savefig(out_path, bbox_inches="tight")
    plt.show()

In [None]:
# Measured vs. predicted Brier score for all downstream models (display only).
dataset = "german-credit"
plot_mse_est(df=dfs[dataset], dataset=dataset)

In [None]:
# Compact four-model version of the same figure, saved with the "-small" suffix.
plot_mse_est(
    df=dfs["german-credit"],
    dataset="german-credit",
    plot_order=["5-NN", "Decision Tree", "Gradient Boosting", "MLP"],
    save=True,
    file_suffix="-small",
)

In [None]:
# All four metrics for German Credit (display only).
dataset = "german-credit"
plot_all_metrics(long_df=long_dfs[dataset], dataset=dataset)

In [None]:
# Brier score per downstream model for German Credit (display only).
dataset = "german-credit"
plot_by_model(df=dfs[dataset], dataset=dataset, metric="brier")

In [None]:
# Brier score per generator for German Credit, probability averaging only (display only).
dataset = "german-credit"
plot_by_method(
    df=dfs[dataset],
    dataset=dataset,
    metric="brier",
    selected_primal=["Prob. Avg."],
)

In [None]:
# Brier score of SP-P with probability averaging, one panel per dataset; saved to disk.
plot_by_dataset(
    df=df_all_datasets,
    selected_method=["SP-P"],
    selected_primal=["Prob. Avg."],
    metric="brier",
    save=True,
)

In [None]:
# Per-dataset Brier score figure for SP-P with probability averaging; saved to disk.
for dataset in datasets:
    plot_by_method(
        df=dfs[dataset],
        dataset=dataset,
        metric="brier",
        save=True,
        selected_primal=["Prob. Avg."],
        selected_method=["SP-P"],
    )

In [None]:
# Save the all-metrics grid and the measured-vs-predicted figure for every dataset.
# (Per-metric `plot_by_model` / `plot_by_method` figures are not generated here.)
for dataset in datasets:
    plot_all_metrics(long_df=long_dfs[dataset], dataset=dataset, save=True)
    plot_mse_est(df=dfs[dataset], dataset=dataset, save=True)

In [None]:
# Write one LaTeX table per dataset with the Brier score as "mean +- std" over
# repeats: rows are (downstream model, generator / averaging combination),
# columns are the number of synthetic datasets m.
for dataset, df in dfs.items():
    table = df.groupby(["model", "method_primal", "n_syn_datasets"])["brier"].aggregate(["mean", "std"])
    # Raw string: "\p" is not a valid escape sequence in a normal string literal
    # (it triggers a warning and is slated to become an error); the text
    # produced is unchanged. The mean is printed with 2 decimals, the std with 3.
    table["formatted"] = table.apply(lambda row: r"{:.2f} $\pm$ {:.3f}".format(row["mean"], row["std"]), axis=1)
    # Move m out of the row index and spread it over the columns.
    table = table.reset_index("n_syn_datasets").pivot(columns="n_syn_datasets", values="formatted")
    table.index.rename(["Downstream", "Generator"], inplace=True)
    table.columns.rename("m", inplace=True)
    # Order the rows like the figures do.
    table = table.reindex(model_order, level="Downstream", axis="index")
    table.style.to_latex(fig_dir + "{}-brier-table.tex".format(dataset), hrules=True, clines="skip-last;data")

In [None]:
# Write one LaTeX table per dataset comparing the measured Brier score ("brier")
# with the predicted one ("est_mse") as "mean +- std": rows are
# (downstream model, generator, Predicted / Measured), columns are the number of
# synthetic datasets m.
pred_measured_col_name = "Predicted / Measured"
for dataset, df in dfs.items():
    # NOTE(review): unlike `plot_mse_est`, this does not filter on `primal`, so
    # each cell pools the "Prob. Avg." and "Log Prob. Avg." rows - confirm
    # that this is intended.
    df = df.melt(id_vars=["repeat_ind", "model", "n_syn_datasets", "method"], value_vars=["brier", "est_mse"], var_name=pred_measured_col_name)
    table = df.groupby(["model", "n_syn_datasets", "method", pred_measured_col_name])["value"].aggregate(["mean", "std"])
    # Raw string: "\p" is not a valid escape sequence in a normal string literal
    # (it triggers a warning and is slated to become an error); the text
    # produced is unchanged. The mean is printed with 2 decimals, the std with 3.
    table["formatted"] = table.apply(lambda row: r"{:.2f} $\pm$ {:.3f}".format(row["mean"], row["std"]), axis=1)
    # Move m out of the row index and spread it over the columns.
    table = table.reset_index("n_syn_datasets").pivot(columns="n_syn_datasets", values="formatted")
    table.index.rename(["Downstream", "Generator", pred_measured_col_name], inplace=True)
    table.columns.rename("m", inplace=True)
    # Order the rows like the figures do, then give the two value kinds readable labels.
    table = table.reindex(model_order, level="Downstream", axis="index")
    table.rename(index={"est_mse": "Predicted", "brier": "Measured"}, inplace=True)
    table.style.to_latex(fig_dir + "{}-mse-est-table.tex".format(dataset), hrules=True, clines="skip-last;data")

In [None]:
# Inspect the long-format measured / predicted Brier scores for German Credit,
# restricted to probability averaging.
df = (
    dfs["german-credit"]
    .loc[lambda frame: frame.primal == "Prob. Avg."]
    .melt(
        id_vars=["repeat_ind", "model", "n_syn_datasets", "method"],
        value_vars=["brier", "est_mse"],
        var_name=pred_measured_col_name,
    )
)
df