In [None]:
import os
import pickle

import numpy as np
import pandas as pd
from einops import rearrange, pack

import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Reset to matplotlib's default style, then select Type 42 (TrueType) font
# embedding for PDF and PostScript output (see matplotlib rcParams docs).
plt.style.use("default")
matplotlib.rcParams.update({
    "pdf.fonttype": 42,
    "ps.fonttype": 42,
})

In [None]:
# Model key (as it appears in result file names) -> short label used in the
# "model" column and on plot axes.
model_names = {
    "Logistic Regression": "LogR",
    "Linear Regression": "LR",
    "Ridge Regression": "RR",
    "1-NN": "1-NN",
    "5-NN": "5-NN",
    "Decision Tree": "DT",
    "Random Forest": "RF",
    "Gradient Boosting": "GB",
    "MLP": "MLP",
    "SVM": "SVM",
}
# Synthetic-data method key (as it appears in result file names) -> short
# label used in the "method" column and as facet titles.
method_names = {
    "ddpm": "DDPM",
    "synthpop-proper": "SP-P",
}
classification_datasets = [
    "german-credit",
    "adult",
    "breast-cancer",
]
regression_datasets = [
    "abalone",
    "insurance",
    "california-housing",
    "ACS2018",
]
datasets = regression_datasets + classification_datasets


def dataset_type(dataset):
    """Return the task type of a dataset: "classification" or "regression".

    The task type selects the results sub-directory and which models were
    evaluated, so an unknown name raises ValueError instead of silently being
    treated as a regression dataset.

    Raises:
        ValueError: if ``dataset`` is in neither dataset list.
    """
    if dataset in classification_datasets:
        return "classification"
    if dataset in regression_datasets:
        return "regression"
    raise ValueError("Unknown dataset: {!r}".format(dataset))

# Number of experiment repeats per (dataset, method, model); repeat indices
# 0..n_repeats-1 are part of the result file names.
n_repeats = 3

# Models that were not run for a given task type, so have no result files.
_SKIPPED_MODELS = {
    "classification": {"Linear Regression", "Ridge Regression"},
    "regression": {"Logistic Regression"},
}
# Key in each result pickle -> short label stored in the "variance_type" column.
_VARIANCE_FIELDS = {
    "model_variances": "MV",
    "synthetic_data_variances": "SDV",
}


def _result_path(dataset, method, model_name, repeat_ind):
    """Return the path of the pickled variance-estimation result for one run."""
    return "../results/{}-datasets/variance-estimation/{}/{}_{}_{}.p".format(
        dataset_type(dataset), dataset, method, model_name, repeat_ind
    )


def _load_variance_records(dataset, method, model_name, repeat_ind):
    """Load one result pickle and flatten it into long-format records.

    Returns one dict per individual variance value: all "MV" values first,
    then all "SDV" values, each tagged with dataset/method/model/repeat.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only point this at
    # result files produced by this project's own experiment runs.
    with open(_result_path(dataset, method, model_name, repeat_ind), "rb") as file:
        obj = pickle.load(file)
    return [
        {
            "dataset": dataset,
            "method": method_names[method],
            "repeat_ind": repeat_ind,
            "model": model_names[model_name],
            "variance_type": variance_type,
            "value": variance,
        }
        for field, variance_type in _VARIANCE_FIELDS.items()
        for variance in obj[field]
    ]


# dataset name -> long-format DataFrame with one row per variance estimate.
dfs = {}
for dataset in datasets:
    skipped_models = _SKIPPED_MODELS[dataset_type(dataset)]
    records = []
    for repeat_ind in range(n_repeats):
        for method in method_names:
            for model_name in model_names:
                if model_name in skipped_models:
                    continue
                records.extend(_load_variance_records(dataset, method, model_name, repeat_ind))
    # Fail here, with the dataset name, rather than later in groupby/plotting.
    assert records, "no variance results loaded for {}".format(dataset)
    dfs[dataset] = pd.DataFrame.from_records(records)

# x-axis order of model labels in the plots (values of `model_names`).
classification_model_order = [
    "DT", "1-NN", "5-NN", "RF", "MLP", "GB", "SVM", "LogR"
]
regression_model_order = [
    "DT", "1-NN", "5-NN", "RF", "MLP", "GB", "SVM", "LR", "RR"
]

In [None]:
# Remove extremely large variances for linear regression on ACS2018: keep every
# non-LR row, and keep LR rows only when their value is below the cutoff.
LR_VARIANCE_CUTOFF = 1e6
acs_df = dfs["ACS2018"]
keep_rows = acs_df["model"].ne("LR") | acs_df["value"].lt(LR_VARIANCE_CUTOFF)
dfs["ACS2018"] = acs_df.loc[keep_rows]

In [None]:
# Peek at the long-format variance table for one dataset (first rows only).
dfs["german-credit"].head()

In [None]:
figdir = "../figures/variance-estimation/"
# Long legend labels for the short "variance_type" codes stored in `dfs`.
legend_names = {
    "MV": "Model Variance (MV)",
    "SDV": "Synthetic Data Variance (SDV)",
}
# Colour per long label; its key order also fixes the hue/legend order.
variance_palette = {legend_names["MV"]: "C0", legend_names["SDV"]: "C1"}
os.makedirs(figdir, exist_ok=True)

for dataset in datasets:
    df = dfs[dataset]
    # Average the individual variance values within each repeat, then relabel
    # the hue values so the legend shows the long names directly (instead of
    # rewriting seaborn's private FacetGrid._legend_data).
    mean_df = (
        df.groupby(["dataset", "method", "model", "repeat_ind", "variance_type"], as_index=False)
        .mean()
        .assign(variance_type=lambda d: d["variance_type"].map(legend_names))
    )
    print(dataset)

    g = sns.FacetGrid(mean_df, col="method", aspect=1.2, height=2.5)
    g.map_dataframe(
        sns.stripplot, x="model", y="value",
        order=classification_model_order if dataset in classification_datasets else regression_model_order,
        hue="variance_type", hue_order=list(variance_palette), palette=variance_palette,
    )
    g.tick_params("x", labelrotation=45)
    g.set_ylabels("Variance")
    g.set_xlabels("")
    g.set_titles("{col_name}")
    g.add_legend(loc="upper right", bbox_to_anchor=(0.53, 0.05), ncol=2)
    for ax in g.axes.flat:
        ax.grid()

    # Save the grid's own figure rather than whatever pyplot considers current.
    g.savefig(os.path.join(figdir, "{}.pdf".format(dataset)), bbox_inches="tight")
    plt.show()
    # Release the figure explicitly; this loop creates one per dataset.
    plt.close(g.figure)