# Visualizations

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pickle

# Directory that holds the pickled experiment results.
path_base = "results/reproducible_results"

# NOTE(security): pickle.load can execute arbitrary code while unpickling.
# Only load result files that you produced yourself or otherwise trust.
with open(path_base + "/results.pickle", "rb") as f:
    smaller_results = pickle.load(f)


data_names = ["credit", "adult", "gmsc"]

# Label used throughout this notebook -> key under which the method's results
# are stored in the pickle. Insertion order is kept, so the per-dataset
# dictionaries below list the methods in exactly this order.
# NOTE(review): the label "MIO + SPN" is backed by the key "MIO_no_spn";
# confirm that this pairing is intended.
method_keys = {
    "LiCE_optimize": "LiCE_optimize",
    "LiCE_median": "LiCE_median",
    "LiCE_quartile": "LiCE_quartile",
    "LiCE_sample": "LiCE_sample",
    "MIO + SPN": "MIO_no_spn",
    "VAE + SPN": "VAE",
    "DiCE + SPN": "DiCE",
    "CH-CVAE ": "CVAE",
    "FACE_knn": "FACE_knn",
    "FACE_eps": "FACE_eps",
    "PROPLACE": "PROPLACE",
}

# results[dataset][label] -> the results object stored for that method on that
# dataset (built in one step; no throwaway empty dict is assigned first).
results = {}
for data_name in data_names:
    results[data_name] = {
        label: smaller_results[data_name][key]
        for label, key in method_keys.items()
    }

In [2]:
# Which entry of each run to summarise; "actionable" is the alternative key.
res_type = "valid"
method = "LiCE_optimize"

# Column order of the printed table (also rebinds the global `data_names`).
data_names = ["gmsc", "adult", "credit"]


def _values(data_name, key):
    """Collect `key` from every run of `method` on `data_name` that has a `res_type` entry."""
    runs = results[data_name][method].values()
    return [run[res_type][key] for run in runs if res_type in run]


def _ll_gap(data_name):
    """Element-wise difference between the "ll" and "comp_ll" values of the kept runs."""
    return np.array(_values(data_name, "ll")) - np.array(_values(data_name, "comp_ll"))


def _mean_std_cell(values):
    """Render mean and standard deviation of `values` as one LaTeX table cell."""
    return f"& \\scinum{{{np.mean(values)}}} $\\pm$ \\scinum{{{np.std(values)}}} "


print("LL value difference")
print("\t" + "\t\t".join(data_names))

# One printed row per entry: (row label, function producing the values for a dataset).
table_rows = (
    ("true ", lambda name: _values(name, "ll")),
    ("mio ", lambda name: _values(name, "comp_ll")),
    ("diff ", _ll_gap),
)
for row_label, values_for in table_rows:
    print(row_label, end="")
    for data_name in data_names:
        print(_mean_std_cell(values_for(data_name)), end="")
    print()

LL value difference
	gmsc		adult		credit
true & \scinum{-25.6180484331279} $\pm$ \scinum{4.642900031468152} & \scinum{-18.14528819185729} $\pm$ \scinum{3.8878962458968647} & \scinum{-28.785663175658833} $\pm$ \scinum{3.2769836465496667} 
mio & \scinum{-25.709163152902526} $\pm$ \scinum{4.706936630938196} & \scinum{-18.67192853775819} $\pm$ \scinum{4.10411216765172} & \scinum{-29.013268255882007} $\pm$ \scinum{3.3415703166261443} 
diff & \scinum{0.09111471977462675} $\pm$ \scinum{0.5700239585170246} & \scinum{0.5266403459009045} $\pm$ \scinum{0.4698927020652466} & \scinum{0.22760508022317458} $\pm$ \scinum{0.24096795119989892} 


In [3]:
# Row order of the tables printed by this cell.
method_order = [
    "DiCE + SPN",
    "VAE + SPN",
    "CH-CVAE ",
    "FACE_eps",
    "FACE_knn",
    "PROPLACE",
    "MIO + SPN",
    "LiCE_optimize",
    "LiCE_quartile",
    "LiCE_sample",
    "LiCE_median",
]

# Full display names (used for the TIME table).
method_map = {
    "DiCE + SPN": "DiCE",
    "VAE + SPN": "VAE",
    "CH-CVAE ": "C-CHVAE",
    "FACE_eps": "FACE ($\\epsilon$)",
    "FACE_knn": "FACE (knn)",
    "PROPLACE": "PROPLACE",
    "LiCE_optimize": "LiCE (optimize)",
    "LiCE_quartile": "LiCE (quartile)",
    "LiCE_sample": "LiCE (sample)",
    "LiCE_median": "LiCE (median)",
    "MIO + SPN": "MIO",
}

# Shortened display names (used for the first two tables): a copy of
# `method_map` with some entries overridden.
method_map_short = dict(method_map)
method_map_short.update(
    {
        "LiCE_optimize": "LiCE (optim.)",
        "LiCE_quartile": "LiCE (quart.)",
        "LiCE_median": "LiCE (median)",
        "MIO + SPN": "MIO (\\uline{+spn})",
        "DiCE + SPN": "DiCE (\\uline{+spn})",
        "VAE + SPN": "VAE (\\uline{+spn})",
    }
)

res_type = "valid"

# Column-group order shared by the three tables.
table_datasets = ["gmsc", "adult", "credit"]


def _mean_pm_std(values):
    """Render mean and standard deviation of `values` as one LaTeX cell with a leading column separator."""
    return f" & \\scinumone{{{np.mean(values)}}} $\\pm$ \\scinumone{{{np.std(values)}}}"


# Table 1: negated "ll", "distance" and "sparsity" over runs that have a `res_type` entry.
for method in method_order:
    row = "& " + method_map_short[method]
    for data_name in table_datasets:
        kept = [
            res[res_type]
            for res in results[data_name][method].values()
            if res_type in res
        ]
        nll = [-entry["ll"] for entry in kept]
        dist = [entry["distance"] for entry in kept]
        spar = [entry["sparsity"] for entry in kept]
        row += "".join(_mean_pm_std(column) for column in (nll, dist, spar))
    print(row + " \\\\")

# Table 2: percentage of runs containing a "valid" / an "actionable" entry.
# Rows start with "&" without a following space, unlike the first table.
print("\n\nvalid actionable")
for method in method_order:
    row = "&" + method_map_short[method]
    for data_name in table_datasets:
        runs = list(results[data_name][method].values())
        for flag in ("valid", "actionable"):
            share = np.mean([flag in res for res in runs]) * 100
            row += f" & \\scinumone{{{share}}}\\%"
    print(row + " \\\\")


# Table 3: median of the "time" field over all runs; rows use the full
# display names and carry no leading "&".
print("\n\nTIME")
for method in method_order:
    row = method_map[method]
    for data_name in table_datasets:
        runtimes = [res["time"] for res in results[data_name][method].values()]
        row += f" & \\scinum{{{np.median(runtimes)}}}s"
    print(row + " \\\\")


& DiCE (\uline{+spn}) & \scinumone{29.10625689186585} $\pm$ \scinumone{5.182592246765204} & \scinumone{27.34425644482405} $\pm$ \scinumone{6.665083118856582} & \scinumone{6.496} $\pm$ \scinumone{1.0816579866112948} & \scinumone{21.024686408227474} $\pm$ \scinumone{2.992416842221361} & \scinumone{26.695823501370704} $\pm$ \scinumone{13.122152911830844} & \scinumone{4.484969939879759} $\pm$ \scinumone{1.782028563379273} & \scinumone{50.95000214113616} $\pm$ \scinumone{17.90368498481968} & \scinumone{27.747628768160773} $\pm$ \scinumone{7.203097131076084} & \scinumone{8.699596774193548} $\pm$ \scinumone{2.128076482372827} \\
& VAE (\uline{+spn}) & \scinumone{17.95752152874899} $\pm$ \scinumone{2.205400482961012} & \scinumone{17.15942423141453} $\pm$ \scinumone{3.7009726214613} & \scinumone{8.0} $\pm$ \scinumone{1.1547005383792515} & \scinumone{18.362415899918965} $\pm$ \scinumone{3.6099620553232543} & \scinumone{37.05019368637985} $\pm$ \scinumone{13.158852411535614} & \scinumone{5.390977

In [4]:
# ONLY SAME: compare the methods only on instances for which every compared
# method has a `res_type` entry.
res_type = "valid"

# Drop the methods whose name contains "sample" or "quart". This rebinds the
# global `method_order`; re-running the cell is harmless because the filter
# yields the same list when applied again.
method_order = [m for m in method_order if "sample" not in m and "quart" not in m]

# (dataset, method) pairs that are left out when determining the shared
# instances and that are printed as "-" in the table.
excluded_pairs = {("gmsc", "VAE + SPN")}

same_is = {}
for data_name in data_names:
    methods = [m for m in method_order if (data_name, m) not in excluded_pairs]
    runs_per_method = [results[data_name][m] for m in methods]
    # Candidate instance ids come from the first method. An id is kept only if
    # every method has an entry for it AND that entry contains `res_type`.
    # The membership check prevents a KeyError when some method has no entry
    # for an id; such an instance is simply not shared by all methods.
    same_is[data_name] = [
        i
        for i in runs_per_method[0].keys()
        if all(i in runs and res_type in runs[i] for runs in runs_per_method)
    ]
print({k: len(v) for k, v in same_is.items()})

# One LaTeX row per method: negated "ll", "distance" and "sparsity"
# (mean +- std) over the shared instances of each dataset.
# Assumes `data_names` contains all three datasets listed below; otherwise
# `same_is[data_name]` has no entry for the missing one.
for method in method_order:
    print(method_map_short[method], end="")
    for data_name in ["gmsc", "adult", "credit"]:
        if (data_name, method) in excluded_pairs:
            print(" & - & - & -", end="")
            continue
        curr_res = results[data_name][method]
        shared = [curr_res[i][res_type] for i in same_is[data_name]]
        nll = [-entry["ll"] for entry in shared]
        dist = [entry["distance"] for entry in shared]
        spar = [entry["sparsity"] for entry in shared]
        for d in [nll, dist, spar]:
            print(f" & \\scinumone{{{np.mean(d)}}} $\\pm$ \\scinumone{{{np.std(d)}}}", end="")
    print(" \\\\")

{'gmsc': 254, 'adult': 55, 'credit': 56}
DiCE (\uline{+spn}) & \scinumone{28.103874532656263} $\pm$ \scinumone{5.502861442487442} & \scinumone{28.451628992662673} $\pm$ \scinumone{6.525698069047968} & \scinumone{6.582677165354331} $\pm$ \scinumone{1.097175650289786} & \scinumone{19.824268128084924} $\pm$ \scinumone{2.5968090149044296} & \scinumone{22.861493088500712} $\pm$ \scinumone{6.148794719362975} & \scinumone{4.090909090909091} $\pm$ \scinumone{1.443137078762504} & \scinumone{35.10076510249688} $\pm$ \scinumone{2.962354972612703} & \scinumone{22.093721763150263} $\pm$ \scinumone{4.619429460202573} & \scinumone{7.625} $\pm$ \scinumone{1.8761900985012914} \\
VAE (\uline{+spn}) & - & - & - & \scinumone{17.85207693716387} $\pm$ \scinumone{2.785874577263402} & \scinumone{31.91471250642849} $\pm$ \scinumone{9.734174357728024} & \scinumone{5.0} $\pm$ \scinumone{1.2358287613066494} & \scinumone{46.18310370969824} $\pm$ \scinumone{16.99755806547317} & \scinumone{27.794408407952353} $\pm$ 