# Simulations #1

Regular Lasso:
-  ✓ Finds optimal balance between fit and regularization
-  ✓ Works even when the true β is far from 0 (just needs a small α)
-  ✓ Fast, deterministic, always converges

Bayesian Lasso:
-  ✗ Prior assumption: β should be near 0 (generative model)
-  ✗ When data contradicts prior, MCMC has convergence issues
-  ✗ Slow, stochastic, may not converge
-  ✗ Fundamentally assumes sparse β near 0 (wrong for this problem)

## Read CSVs

In [None]:
# Notebook-wide dependencies. No import of pandas, numpy or pyplot is visible
# in any other cell, so they are loaded here, in the first cell that needs one;
# without them a fresh kernel stops with a NameError on `pd`.
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Load the Rashomon-set results of the worst-case simulation.
rashomon_csv = "../Results/worst_case/worst_case_rashomon.csv"
# rashomon_csv = "../Results/worst_case/worst_case_rashomon_test.csv"
rashomon_raw_df = pd.read_csv(rashomon_csv)
# "Unnamed: 0" is the name pandas gives an unlabeled first column
# (presumably the index written out with the CSV); it is not used below.
rashomon_raw_df = rashomon_raw_df.drop("Unnamed: 0", axis=1)
rashomon_raw_df.head()

In [None]:
# Load the LASSO results of the worst-case simulation and discard the
# unlabeled leading column that read_csv exposes as "Unnamed: 0".
lasso_csv = "../Results/worst_case/worst_case_lasso.csv"
# Alternate input, switched by hand:
# lasso_csv = "../Results/worst_case/worst_case_lasso_test.csv"
lasso_raw_df = pd.read_csv(lasso_csv).drop(columns=["Unnamed: 0"])
lasso_raw_df.head()

In [None]:
# Load the TVA results of the worst-case simulation and discard the
# unlabeled leading column that read_csv exposes as "Unnamed: 0".
tva_csv = "../Results/worst_case/worst_case_tva.csv"
# Alternate input, switched by hand:
# tva_csv = "../Results/worst_case/worst_case_tva_test.csv"
tva_raw_df = pd.read_csv(tva_csv).drop(columns=["Unnamed: 0"])
tva_raw_df.head()

In [None]:
# ct_csv = "../Results/worst_case_causal_trees.csv"
# ct_raw_df = pd.read_csv(ct_csv)
# ct_raw_df = ct_raw_df.drop("Unnamed: 0", axis=1)
# ct_raw_df.head()

## Summarize Rashomon set results

In [None]:
# Collapse the raw Rashomon results to one row per (n_per_pol, sim_num),
# keeping the min / mean / max / variance of every metric within that group.
rashomon_df = rashomon_raw_df.copy()
rashomon_df["best_pol_MSE"] = rashomon_df["best_pol_diff"] ** 2

group_by_cols = ["n_per_pol", "sim_num"]
result_cols = ["num_pools", "MSE", "IOU", "min_dosage", "best_pol_MSE"]
for result_col in result_cols:
    # Group this metric once and reuse the grouping for all four statistics.
    grouped_metric = rashomon_df.groupby(group_by_cols)[result_col]
    for stat in ("min", "mean", "max", "var"):
        rashomon_df[f"{result_col}_{stat}"] = grouped_metric.transform(stat)

# transform() repeats each statistic on every row of its group, so keeping the
# first row per group is enough; the raw per-row metrics are then dropped.
rashomon_df = (
    rashomon_df
    .drop_duplicates(group_by_cols)
    .drop(columns=result_cols + ["best_pol_diff"])
)
rashomon_df.head(10)

## Average over simulations

In [None]:
# Average every per-simulation summary over the simulations that share a
# sample size, leaving one row per n_per_pol.
sum_cols = [
    f"{metric}_{stat}"
    for metric in ("num_pools", "MSE", "IOU", "min_dosage", "best_pol_MSE")
    for stat in ("min", "mean", "max", "var")
]

# One grouped transform over all summary columns; the result has the same row
# index as rashomon_df, so it can be written straight back.
simulation_means = rashomon_df.groupby("n_per_pol")[sum_cols].transform("mean")
rashomon_df[sum_cols] = simulation_means

rashomon_df = rashomon_df.drop_duplicates(subset="n_per_pol")
rashomon_df.head()

In [None]:
# Summarise the LASSO results per sample size: min / mean / max / variance of
# every metric across all rows that share an n_per_pol.
lasso_df = lasso_raw_df.copy()
lasso_df["best_pol_MSE"] = lasso_df["best_pol_diff"] ** 2

sum_cols_lasso = ["MSE", "L1_loss", "IOU", "min_dosage", "best_pol_MSE"]
for sum_col in sum_cols_lasso:
    # Group this metric once and reuse the grouping for all four statistics.
    grouped_metric = lasso_df.groupby("n_per_pol")[sum_col]
    for stat in ("min", "mean", "max", "var"):
        lasso_df[f"{sum_col}_{stat}"] = grouped_metric.transform(stat)

# Each statistic is repeated on every row of its group: keep one row per
# sample size and drop the raw per-row metrics.
lasso_df = (
    lasso_df
    .drop_duplicates(subset="n_per_pol")
    .drop(columns=sum_cols_lasso + ["best_pol_diff"])
)
lasso_df.head(10)

In [None]:
# Summarise the TVA results per sample size: min / mean / max / variance of
# every metric across all rows that share an n_per_pol.
tva_df = tva_raw_df.copy()
tva_df["best_pol_MSE"] = tva_df["best_pol_diff"] ** 2

sum_cols_tva = ["MSE", "TVA_loss", "IOU", "min_dosage", "best_pol_MSE"]
for sum_col in sum_cols_tva:
    # Group this metric once and reuse the grouping for all four statistics.
    grouped_metric = tva_df.groupby("n_per_pol")[sum_col]
    for stat in ("min", "mean", "max", "var"):
        tva_df[f"{sum_col}_{stat}"] = grouped_metric.transform(stat)

# Each statistic is repeated on every row of its group: keep one row per
# sample size and drop the raw per-row metrics.
tva_df = (
    tva_df
    .drop_duplicates(subset="n_per_pol")
    .drop(columns=sum_cols_tva + ["best_pol_diff"])
)
tva_df.head(10)

In [None]:
# ct_df = ct_raw_df.copy()

# ct_df["best_pol_MSE"] = ct_df["best_pol_diff"]**2
# sum_cols_ct = ["MSE", "IOU", "min_dosage", "best_pol_MSE"]

# for sum_col in sum_cols_ct:
#     result_min_col = sum_col + "_min"
#     result_avg_col = sum_col + "_mean"
#     result_max_col = sum_col + "_max"
#     result_var_col = sum_col + "_var"
    
#     ct_df[result_min_col] = ct_df.groupby("n_per_pol")[sum_col].transform("min")
#     ct_df[result_avg_col] = ct_df.groupby("n_per_pol")[sum_col].transform("mean")
#     ct_df[result_max_col] = ct_df.groupby("n_per_pol")[sum_col].transform("max")
#     ct_df[result_var_col] = ct_df.groupby("n_per_pol")[sum_col].transform("var")

# ct_df = ct_df.drop_duplicates("n_per_pol")
# ct_df = ct_df.drop(sum_cols_ct, axis=1)
# ct_df = ct_df.drop(["best_pol_diff"], axis=1)
# ct_df.head(n=10)

## Plots

In [None]:
# Mean MSE against sample size, one line per method.
fig, ax = plt.subplots(figsize=(5, 5))
ax.spines[["right", "top"]].set_visible(False)

marker_style = dict(marker="o", markeredgecolor="black", markersize=7)
series = [
    (rashomon_df, "Rashomon Set", dict(color="dodgerblue", zorder=1)),
    (lasso_df, "LASSO", dict(color="indianred", zorder=3, clip_on=False)),
    # A "Causal Tree" series is disabled: the ct_df cells are commented out.
    (tva_df, "TVA", dict(color="seagreen", zorder=3.5, clip_on=False)),
]
for method_df, method_label, line_style in series:
    ax.plot(method_df["n_per_pol"], method_df["MSE_mean"],
            label=method_label, **line_style, **marker_style)

ax.set_xscale("log")
ax.set_ylim(0.8, 1.1)
ax.set_xlabel("Samples per feature", fontsize=12)
ax.set_ylabel("MSE", fontsize=12)

# Legend sits outside the axes, centred on the right edge.
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))

plt.savefig("../Figures/worst_case/MSE.png", dpi=300, bbox_inches="tight")
plt.show()

In [None]:
# Mean IOU (labelled "best feature set coverage") against sample size,
# one line per method.
fig, ax = plt.subplots(figsize=(5, 5))
ax.spines[["right", "top"]].set_visible(False)

marker_style = dict(marker="o", markeredgecolor="black", markersize=7)
series = [
    (rashomon_df, "Rashomon Set", dict(color="dodgerblue", zorder=1)),
    (lasso_df, "LASSO", dict(color="indianred", zorder=3, clip_on=False)),
    # A "Causal Tree" series is disabled: the ct_df cells are commented out.
    (tva_df, "TVA", dict(color="seagreen", zorder=3.5, clip_on=False)),
]
for method_df, method_label, line_style in series:
    ax.plot(method_df["n_per_pol"], method_df["IOU_mean"],
            label=method_label, **line_style, **marker_style)

ax.set_xscale("log")
ax.set_ylim(0, 1.01)
ax.set_xlabel("Samples per feature", fontsize=12)
ax.set_ylabel("Best feature set coverage", fontsize=12)

# Legend sits outside the axes, centred on the right edge.
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))

plt.savefig("../Figures/worst_case/feature_coverage.png", dpi=300, bbox_inches="tight")
plt.show()

In [None]:
# Mean minimum-dosage inclusion against sample size, one line per method.
fig, ax = plt.subplots(figsize=(5, 5))
ax.spines[["right", "top"]].set_visible(False)

marker_style = dict(marker="o", markeredgecolor="black", markersize=7)
series = [
    (rashomon_df, "Rashomon Set", dict(color="dodgerblue", zorder=1)),
    (lasso_df, "LASSO", dict(color="indianred", zorder=3, clip_on=False)),
    # A "Causal Tree" series is disabled: the ct_df cells are commented out.
    (tva_df, "TVA", dict(color="seagreen", zorder=3.5, clip_on=False)),
]
for method_df, method_label, line_style in series:
    ax.plot(method_df["n_per_pol"], method_df["min_dosage_mean"],
            label=method_label, **line_style, **marker_style)

ax.set_xscale("log")
ax.set_ylim(0, 1.01)
ax.set_xlabel("Samples per feature", fontsize=12)
ax.set_ylabel("Minimum dosage inclusion", fontsize=12)

# Legend sits outside the axes, centred on the right edge.
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))

plt.savefig("../Figures/worst_case/min_dosage_inclusion_ct.png", dpi=300, bbox_inches="tight")
plt.show()

In [None]:
# Mean squared best-policy difference against sample size, one line per method.
fig, ax = plt.subplots(figsize=(5, 5))
ax.spines[["right", "top"]].set_visible(False)

marker_style = dict(marker="o", markeredgecolor="black", markersize=7)
series = [
    (rashomon_df, "Rashomon Set", dict(color="dodgerblue", zorder=1)),
    (lasso_df, "LASSO", dict(color="indianred", zorder=3, clip_on=False)),
    # A "Causal Tree" series is disabled: the ct_df cells are commented out.
    (tva_df, "TVA", dict(color="seagreen", zorder=3.5, clip_on=False)),
]
for method_df, method_label, line_style in series:
    ax.plot(method_df["n_per_pol"], method_df["best_pol_MSE_mean"],
            label=method_label, **line_style, **marker_style)

# No explicit y-limits here: the y-axis is left to matplotlib's autoscaling.
ax.set_xscale("log")
ax.set_xlabel("Samples per feature", fontsize=12)
ax.set_ylabel("MSE for best feature outcome", fontsize=12)

# Legend sits outside the axes, centred on the right edge.
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))

plt.savefig("../Figures/worst_case/best_feature_MSE.png", dpi=300, bbox_inches="tight")
plt.show()

In [None]:
# Three-panel summary figure: mean MSE, mean IOU ("best feature set coverage")
# and mean best-policy MSE against sample size, one line per method.
fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(18, 5))

# (summary column, y-axis label, y-limits or None for autoscaling)
panels = [
    ("MSE_mean", "MSE", (0.8, 1.1)),
    ("IOU_mean", "Best feature set coverage", (0, 1.01)),
    ("best_pol_MSE_mean", "MSE for best feature outcome", None),
]

# Per-method line styling, shared by all three panels. TVA previously used
# zorder=3 in the first panel and 3.5 in the others; 3.5 is now used
# everywhere, matching the single-metric figures. TVA is drawn after LASSO in
# both cases, so the stacking order is unchanged.
marker_style = dict(marker="o", markeredgecolor="black", markersize=7)
methods = [
    (rashomon_df, "Rashomon Set", dict(color="dodgerblue", zorder=1)),
    (lasso_df, "LASSO", dict(color="indianred", zorder=3, clip_on=False)),
    (tva_df, "TVA", dict(color="seagreen", zorder=3.5, clip_on=False)),
]

for panel_ax, (metric_col, y_label, y_limits) in zip(ax, panels):
    panel_ax.spines[["right", "top"]].set_visible(False)
    panel_ax.set_xscale("log")
    panel_ax.set_xlabel("Samples per feature", fontsize=12)

    for method_df, method_label, line_style in methods:
        panel_ax.plot(method_df["n_per_pol"], method_df[metric_col],
                      label=method_label, **line_style, **marker_style)

    if y_limits is not None:
        panel_ax.set_ylim(*y_limits)
    panel_ax.set_ylabel(y_label, fontsize=12)

# A single legend, on the last panel, serves the whole figure.
ax[2].legend(loc="upper right")

plt.savefig("../Figures/worst_case/mse_coverage_best_eff.png", dpi=300, bbox_inches="tight")

plt.show()

### Heat map

In [None]:
# Heat-map inputs. Every raw Rashomon row gets a penalised loss
# (MSE + reg * num_pools) and the weight exp(-loss), stored as "posterior".
reg = 1e-2
heatmap_df = rashomon_raw_df.assign(
    loss=lambda frame: frame["MSE"] + reg * frame["num_pools"],
    posterior=lambda frame: np.exp(-frame["loss"]),
)

# One frame per plotted sample size, in plotting order. Within each frame the
# posterior is rescaled relative to that frame's maximum, (p - max) / max, so
# the best row sits at 0 and every other row is negative.
# (Min-max scaling to [0, 1] was the alternative tried here.)
heatmap_dfs = []
for sample_size in (10, 50, 100, 1000):
    subset = heatmap_df[heatmap_df["n_per_pol"] == sample_size].copy()
    peak = subset["posterior"].max()
    subset["posterior"] = (subset["posterior"] - peak) / peak
    heatmap_dfs.append(subset)

In [None]:
# Preview the combined frame. Its "posterior" column is still exp(-loss): the
# max-relative rescaling is applied to the per-sample-size copies in
# heatmap_dfs only, not to heatmap_df itself.
heatmap_df.head()

In [None]:
# 2x2 grid of 2D histograms (model size vs. rescaled posterior), one panel per
# sample size, with log-scaled bin colours.
from matplotlib import colors

fig, ax = plt.subplots(2, 2, figsize=(12, 10))

# heatmap_dfs holds one frame per sample size in this order; the panels are
# filled row by row (ax.flat walks [0, 0], [0, 1], [1, 0], [1, 1]).
sample_sizes = (10, 50, 100, 1000)
colorbars = []
for panel_ax, panel_df, sample_size in zip(ax.flat, heatmap_dfs, sample_sizes):
    posterior = panel_df["posterior"]

    h = panel_ax.hist2d(
        panel_df["num_pools"], posterior,
        norm=colors.LogNorm(), cmap="OrRd",
        # Every row contributes 0.01 to its bin instead of 1.
        weights=[1e-2] * len(panel_df),
    )
    panel_ax.set_title(f"Samples per feature = {sample_size}", fontsize=12)

    # The colorbar takes its norm and colormap from the hist2d image (h[3]).
    # The former `norm=colors.NoNorm` argument passed the class rather than an
    # instance and is ignored whenever a mappable is supplied, so it is gone.
    # Tick locations are left to matplotlib's defaults.
    colorbars.append(fig.colorbar(h[3], ax=panel_ax))

    # The y-range hugs the data of this panel; the x-range is shared.
    panel_ax.set_ylim(posterior.min(), posterior.max())
    panel_ax.set_xlim(2, 18)

fig.supylabel("Relative posterior probability ratio", fontsize=14)
fig.supxlabel("Model size", fontsize=14)

plt.savefig("../Figures/worst_case/rset_2d_hist.png", dpi=300, bbox_inches="tight")
plt.show()

### Sample heatmap

In [None]:
# Example 2D histogram (model size vs. posterior) over all rows of heatmap_df,
# i.e. every sample size together, with two dashed reference lines.
# Imported here as well so the cell does not depend on the previous plotting
# cell having been run.
from matplotlib import colors

fig, ax = plt.subplots(1, 1, figsize=(6, 5))

h = ax.hist2d(heatmap_df["num_pools"], heatmap_df["posterior"],
              norm=colors.LogNorm(), cmap="OrRd")
# The colorbar takes its norm and colormap from the hist2d image (h[3]). The
# former `norm=colors.NoNorm` argument passed the class rather than an
# instance and is ignored whenever a mappable is supplied, so it is gone.
cb = fig.colorbar(h[3], ax=ax)
ticks = [1, 10, 50, 100, 200, 500]
cb.set_ticks(ticks)
cb.set_ticklabels(ticks)

# NOTE(review): these reference lines sit at negative y values, but
# heatmap_df["posterior"] is exp(-loss) and therefore positive; the
# coordinates look like they were chosen for a max-relative rescaled
# posterior. Confirm which column/scaling this figure is meant to show.
ax.plot([1, 11], [-0.24, -0.24], color="black", linestyle="--", linewidth=2)
ax.plot([6.2, 6.2], [-0.6, 0.1], color="black", linestyle="--", linewidth=2)

fig.supylabel("Relative posterior probability ratio", fontsize=14)
fig.supxlabel("Model size", fontsize=14)

# The figure is only shown; saving to "../Figures/ex_2d_hist.png" is disabled.
# plt.savefig("../Figures/ex_2d_hist.png", dpi=300, bbox_inches="tight")
plt.show()