In [None]:
import glob
from itertools import product
import os
import random
import warnings

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Render figures inline in the notebook output.
%matplotlib inline
# Automatically reload imported modules before each cell is executed.
%load_ext autoreload
%autoreload 2
# Typeset all figure text with LaTeX; labels must therefore be valid LaTeX.
plt.rcParams['text.usetex'] = True
plt.rcParams['font.family'] = 'sans-serif'

# Show the seaborn version as the cell output.
sns.__version__

In [None]:
# Root directory of the fully synthetic sweep.
MASTER_PATH = "/data4/username/disparate_censorship_mitigation/sweep_20230704_v6/"

# The first six CSV columns form the row MultiIndex.
results_csv = os.path.join(MASTER_PATH, "results.csv")
results = pd.read_csv(results_csv, index_col=np.arange(6))
results.head()

In [None]:
# Load one results.csv per sepsis run directory and stack them under an outer key.
SEPSIS_GLOB = "/data4/username/disparate_censorship_mitigation_sepsis/sepsis_20230724_*alpha*/"

# Outer index keys: 0.0, 0.1, ..., 1.0 repeated twice, i.e. 22 directories are expected.
# NOTE(review): presumably these are the alpha values of the two run families
# ("testedonly" and the rest) in sorted-path order -- confirm against the directory names.
sepsis_keys = np.tile(np.linspace(0, 1, 11), 2)

sepsis_frames = []
for f in sorted(glob.glob(SEPSIS_GLOB)):
    df = pd.read_csv(os.path.join(f, "results.csv"), index_col=np.arange(4))
    # DCEM rows are kept only from directories whose path contains "testedonly".
    if "testedonly" not in f:
        df = df[df.index.get_level_values("model") != "DCEMModel"]
    sepsis_frames.append(df)

# Fail loudly instead of mislabeling frames when directories are missing or extra.
if len(sepsis_frames) != len(sepsis_keys):
    raise ValueError(
        f"Expected {len(sepsis_keys)} sepsis result directories matching "
        f"{SEPSIS_GLOB!r}, found {len(sepsis_frames)}"
    )

sepsis_results = pd.concat(sepsis_frames, keys=sepsis_keys)
sepsis_results

In [None]:
# Phase 0 is the base sweep; only this frame has its index levels reordered before stacking.
phase_results = pd.read_csv(
    "/data4/username/disparate_censorship_mitigation/sweep_20230704_v6/results.csv",
    index_col=np.arange(6),
)
base_level_order = ["k", "prevalence_disparity", "target_prevalence", "testing_disparity", "model", "metric"]
phase_offsets = list(range(30, 331, 30))  # 30, 60, ..., 330

all_results = [phase_results.reorder_levels(base_level_order)]
all_results += [
    pd.read_csv(
        os.path.join(f"/data4/username/disparate_censorship_mitigation/sweep_20230704_v6_phase{int(p)}", "results.csv"),
        index_col=np.arange(6),
    )
    for p in phase_offsets
]

# The phase becomes a new, unnamed outer index level (hence the ``None`` below).
all_results = pd.concat(all_results, keys=[0, *phase_offsets])
final_level_order = ["target_prevalence", "testing_disparity", "prevalence_disparity", "k", "model", None, "metric"]
all_results = all_results.reorder_levels(final_level_order)
all_results

In [None]:
# Same collation as the main sweep, but for the "tarreg" runs.
tarreg_phase_results = pd.read_csv(
    "/data4/username/disparate_censorship_mitigation/sweep_20230704_v6_tarreg/results.csv",
    index_col=np.arange(6),
)

# The phase-0 entry must come from the tarreg sweep itself; the previous version
# seeded this list with `phase_results` (the non-tarreg base sweep) by mistake,
# leaving `tarreg_phase_results` unused.
tarreg_all_results = [
    tarreg_phase_results.reorder_levels(
        ["k", "prevalence_disparity", "target_prevalence", "testing_disparity", "model", "metric"]
    )
]
for p in [30, 60, 90, 120, 150, 180, 210, 240, 270, 300, 330]:
    pt = f"/data4/username/disparate_censorship_mitigation/sweep_20230704_v6_tarreg_phase{int(p)}"
    res = pd.read_csv(os.path.join(pt, "results.csv"), index_col=np.arange(6))
    tarreg_all_results.append(res)

# The phase becomes a new, unnamed outer index level (hence the ``None`` below).
tarreg_all_results = pd.concat(tarreg_all_results, keys=[0, 30, 60, 90, 120, 150, 180, 210, 240, 270, 300, 330])
tarreg_all_results = tarreg_all_results.reorder_levels(
    ["target_prevalence", "testing_disparity", "prevalence_disparity", "k", "model", None, "metric"]
)
tarreg_all_results

In [None]:
# Write the collated result tables next to the notebook's parent directory.
for out_path, frame in (
    ("../all_fully_synthetic_results.csv", all_results),
    ("../all_sepsis_results.csv", sepsis_results),
):
    frame.to_csv(out_path)

# "Main result" template

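Defines `plot_fuzzy_barplot`, which overlays a seaborn strip plot of the individual results with point plots marking each model's median and worst-case value (as in Figs. 2 and 4).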

In [None]:
# NOTE(review): METRIC, COL, xname and yname are not referenced by any code visible
# in this part of the notebook (plot_fuzzy_barplot takes its own `xname` parameter);
# confirm whether they are still needed.
METRIC = "AUC"
COL = "overall"

xname = "Testing disparity"
yname = "overall AUC"

# Per-model plot styling, keyed by the model name used in the results index.
# plot_fuzzy_barplot only reads the "color" entry; "linestyle" is not used there.
MODEL_KWARGS = {
    "YModel": {"color": "green", "linestyle": "dashed"},
    "YObsModel": {"color": "red", "linestyle": "dotted"},
    "TestedOnlyModel": {"color": "darkseagreen"},
    "RecensoringModel": {"color": "purple"},
    "Group0BaselineModel": {"color": "tab:blue"},
    "Group1BaselineModel": {"color": "tab:orange"},
    "PeerLossModel": {"color": "goldenrod"},
    "GroupPeerLossModel": {"color": "gold"},
    "JSModel": {"color": "steelblue"},
    "TruncatedLQModel": {"color": "slategray"},
    "DivideMixBasedModel": {"color": "skyblue"},
    "ITECorrectedModel": {"color": "chocolate"},
    "SELFModel": {"color": "midnightblue"},
    "SAREMModel": {"color": "purple"},
    "DCEMModel": {"color": "magenta"},
}

# Legend labels, keyed by model name. Models without an entry (e.g.
# "DivideMixBasedModel") are shown under their raw name by plot_fuzzy_barplot.
NAME_REPLACEMENT = {
    "YModel": "$y$-model (oracle)",
    "YObsModel": "$y$-obs model",
    "TestedOnlyModel": "Tested-only",
    "Group0BaselineModel": "Group 0 only",
    "Group1BaselineModel": "Group 1 only",
    "PeerLossModel": "Peer loss",
    "GroupPeerLossModel": "Group peer loss",
    "ITECorrectedModel": "DragonNet",
    "JSModel": "Generalized JS",
    "TruncatedLQModel": "Truncated LQ",
    "SELFModel": "SELF",
    "SAREMModel": "SAREM",
    "DCEMModel": "DCEM (ours)"
}


def get_ci(vals):
    """Return the 2.5% and 97.5% quantiles of ``vals`` as a one-row, two-column DataFrame."""
    lower = np.quantile(vals, 0.025)
    upper = np.quantile(vals, 0.975)
    return pd.DataFrame([[lower, upper]])

def _order_by_model(frame, model_rank):
    """Return ``frame`` sorted by its "model" column according to ``model_rank``."""
    return frame.sort_values(by="model", key=lambda column: column.map(model_rank))


def plot_fuzzy_barplot(xname, ynames, metrics, cols,
                       fixed_vars=[0.5, 1], fixed_levels=["prevalence_disparity", "k"],
                        latex_vars=["q_y", "k"], extra_latex="",
                        df=all_results, dodge=0.72, x='testing_disparity',
                 include_models=None, logy=False, strip_dodge=0.1, figsize=(9, 2.1),
                      bbox_to_anchor=(0.5, -0.2), title_y=1.):
    """Plot per-model results for one or more metrics, one subplot per metric.

    Each subplot overlays a strip plot of every individual result, a "_" marker
    at the per-model median, and a hollow triangle at the per-model extreme
    (maximum for the "ROCGap" metric, minimum for every other metric).

    Parameters
    ----------
    xname : str
        x-axis label, applied when the plotted slice has more than one x value.
    ynames : list of str
        Display name of each metric (y-axis label); one subplot per entry.
        A name containing "ROC" gets the "Bias mitigation" subplot title.
    metrics : list of str
        Value of the "metric" index level to select for each subplot.
    cols : list of str
        Column of ``df`` to plot on the y axis of each subplot.
    fixed_vars, fixed_levels : list
        Index values / level names that are fixed via ``df.xs`` before plotting.
    latex_vars : list of str
        LaTeX symbols paired with ``fixed_vars`` in the figure title.
    extra_latex : str
        Extra LaTeX appended inside the math part of the title.
    df : pandas.DataFrame
        Results with a MultiIndex containing ``fixed_levels``, "metric", "model"
        and ``x``. The outermost level remaining after the ``xs`` is dropped.
    dodge, strip_dodge
        ``dodge`` arguments of the point plots and the strip plot respectively.
    x : str
        Index level shown on the x axis.
    include_models : list of str, optional
        Models to plot, in legend/plot order. Defaults to every model in ``df``
        in order of first appearance.
    logy : bool
        Use a logarithmic y axis.
    figsize, bbox_to_anchor, title_y
        Figure size, legend anchor and suptitle height.

    Returns
    -------
    (fig, lgd, suptitle)
        The figure, its legend and its suptitle (the latter two are meant to be
        passed as ``bbox_extra_artists`` to ``fig.savefig``).
    """
    if include_models is None:
        # Previously this crashed in ``isin(None)``; fall back to all models.
        include_models = list(df.index.get_level_values("model").unique())
    else:
        include_models = list(include_models)
    model_rank = {name: include_models.index(name) for name in include_models}
    palette = {name: style["color"] for name, style in MODEL_KWARGS.items()}

    # squeeze=False keeps a 2-D array even for a single metric, so indexing works.
    fig, axs = plt.subplots(1, len(ynames), figsize=figsize, squeeze=False)
    axs = axs.ravel()
    latex_expr = ", ".join([f"{l}={f}" for l, f in zip(latex_vars, fixed_vars)])
    var_strs = ' and '.join(ynames)
    suptitle = fig.suptitle(f"Comparison of {var_strs} across models, ${latex_expr + extra_latex}$",  y=title_y)
    for i in range(len(ynames)):
        ax = axs[i]
        ax.set_title(r"$\downarrow$ Bias mitigation" if "ROC" in ynames[i] else r"$\uparrow$ Discriminative performance")
        if logy:
            ax.set_yscale('log')

        # reset index so that testing disparities and model are named columns
        result_slice = df.xs(
            (*fixed_vars, metrics[i]), level=(*fixed_levels, "metric")
        ).droplevel(0).reset_index()
        result_slice = result_slice[result_slice["model"].isin(include_models)]

        # Individual results.
        sns.stripplot(data=_order_by_model(result_slice, model_rank),
                      x=x, y=cols[i], hue="model", ax=ax,
                      jitter=False, dodge=strip_dodge, alpha=0.3, palette=palette, legend=False)

        # Per-model medians.
        medians = result_slice.groupby([x, "model"]).median().reset_index()
        g = sns.pointplot(data=_order_by_model(medians, model_rank),
                      x=x, y=cols[i], hue="model", ax=ax, dodge=dodge,
                      linestyles="none", markers="_",scale=1.1,
                    palette=palette)

        # Per-model extremes: largest gap for ROCGap, smallest value otherwise.
        if metrics[i] == "ROCGap":
            extrema = result_slice.groupby([x, "model"]).max().reset_index()
            extrema_marker = "^"
        else:
            extrema = result_slice.groupby([x, "model"]).min().reset_index()
            extrema_marker = "v"
        g_ax = sns.pointplot(data=_order_by_model(extrema, model_rank),
                  x=x, y=cols[i], hue="model", ax=ax, dodge=dodge,
                  linestyles="none", markers=extrema_marker,
                    alpha=0.7, palette=palette,
                     markeredgecolor="#444444", markerfacecolor='none', markersize=7.,
                    markeredgewidth=1., normalize_markers=False)

        # Raise every artist currently attached to the returned axes.
        for artist in list(g_ax.get_children()):
            artist.set_zorder(100)

        # Hide the per-axis legend; a single figure-level legend is added below.
        g.legend([],[], frameon=False)

        ax.set_ylabel(ynames[i])
        ax.grid(visible=True, axis='y')

        if len(result_slice[x].unique()) <= 1:
            ax.set_xlabel("")
            ax.set_xticklabels([])
        else:
            # Use the caller-supplied label rather than the raw column name
            # (which may contain characters LaTeX cannot render, e.g. "_").
            ax.set_xlabel(xname)
            tick_labels = ["1/4", "1/3", "1/2", "1", "2", "3", "4", "5", "(all)"]
            # Only relabel when the fixed labels actually match the tick count;
            # otherwise keep the default tick labels.
            if len(ax.get_xticks()) == len(tick_labels):
                ax.set_xticklabels(tick_labels)
    fig.tight_layout()
    fig.subplots_adjust(bottom=0.25)
    legend_labels = [NAME_REPLACEMENT.get(name, name) for name in include_models]
    lgd = fig.legend(labels=legend_labels, loc="lower center", ncols=5, bbox_to_anchor=bbox_to_anchor)
    return fig, lgd, suptitle



# Canonical ordering of model names (the keys used in the results index and in
# MODEL_KWARGS / NAME_REPLACEMENT). The last entry used to be "DCEMModel (ours)",
# which mixed the model key with its display label and matched no actual key.
ORDERING = ["YModel", "YObsModel", "Group0BaselineModel", "Group1BaselineModel",
            "PeerLossModel", "GroupPeerLossModel", "TruncatedLQModel", "JSModel",
            "ITECorrectedModel",
            "DivideMixBasedModel", "SELFModel",
            "TestedOnlyModel", "SAREMModel", "XGBoostModel", "DCEMModel"]

## Figure 2 (results on synthetic data)

In [None]:
# Summary statistics for the Figure 2 setting (q_t=2, q_y=0.5, k=1),
# grouped by the first and third remaining index levels.
fig2_slice = all_results.xs(
    (2, 0.5, 1), level=("testing_disparity", "prevalence_disparity", "k")
).droplevel(0)
fig2_slice.groupby(level=(0, 2)).describe()[["overall", "diff"]]

In [None]:
# FIGURE 2
# Extend fig2_models to show more models; the dodge parameter may then need adjusting.
q_t = 2
fig2_models = [
    "YModel", "YObsModel", "Group0BaselineModel", "Group1BaselineModel",
    "TestedOnlyModel", "SELFModel", "ITECorrectedModel", "GroupPeerLossModel",
    "SAREMModel", "DCEMModel",
]

with warnings.catch_warnings():
    warnings.filterwarnings('ignore')
    fig, lgd, title = plot_fuzzy_barplot(
        "Testing disparity ($q_t$)",
        ["ROC Gap", "AUC"],   # metric names on plot
        ["ROCGap", "AUC"],    # metric names in DF index
        ["diff", "overall"],  # columns in DF
        include_models=fig2_models,
        df=all_results.xs(q_t, level="testing_disparity", drop_level=False),
        figsize=(9, 1.75),
        extra_latex=f", q_t={q_t}",
        bbox_to_anchor=(0.5, -0.22),
        title_y=0.9,
    )
fig.tight_layout()
fig.savefig("performance_v4.pdf", bbox_extra_artists=(lgd, title), bbox_inches='tight')

## Figure 3 (DCEM has a better tradeoff than tested-only)

In [None]:
# FIGURE 3
# You can change the models analyzed to compare different models as well

model_1 = "DCEMModel"
model_2 = "TestedOnlyModel"
metric_1 = "AUC"
metric_2 = "ROCGap"

# Only results whose "k" index value is in this list enter the comparison.
k_values = [1/4, 1/3, 0.5, 1, 2, 3]


def _select_results(model, metric, column):
    """Return ``column`` of ``all_results`` for one model/metric, restricted to ``k_values``.

    Each series is masked with its *own* index (the previous code masked the
    tested-only AUCs with the index of the tested-only gaps).
    """
    series = all_results.xs((model, metric), level=("model", "metric"))[column]
    # np.isin replaces the deprecated np.in1d.
    return series[np.isin(series.index.get_level_values("k"), k_values)]


def filter_val(a, target, within):
    """Boolean mask of the entries of ``a`` strictly inside ``target`` +/- ``within``."""
    return (a > target - within) & (a < target + within)


overall_aucs_dcem = _select_results(model_1, metric_1, "overall")
overall_aucs_tested = _select_results(model_2, metric_1, "overall")
overall_gaps_dcem = _select_results(model_1, metric_2, "diff")
overall_gaps_tested = _select_results(model_2, metric_2, "diff")

targets = [0.7, 0.75, 0.8, 0.85, 0.9]
limit = 0.025
bins = np.linspace(0, 0.23, 20)  # shared histogram bins for all panels

fig, ax = plt.subplots(1, len(targets), figsize=(12, 3))
suptitle = fig.suptitle("Proportion of DCEM and tested-only models by ROC gap achieved, controlling for AUC")

for i, target in enumerate(targets):
    ax[i].set_title(f"All models with\nAUC={target} +/- {limit}")

    if i == 0:
        ax[i].set_ylabel("Frequency")
    ax[i].set_xlabel("ROC gap")

    # Keep only the runs whose overall AUC falls inside the target band.
    dcem_mask = filter_val(overall_aucs_dcem, target, limit)
    tested_mask = filter_val(overall_aucs_tested, target, limit)
    dcem_gaps = overall_gaps_dcem[dcem_mask]
    tested_gaps = overall_gaps_tested[tested_mask]

    # Weights of 1/n turn the counts into proportions within each group.
    ax[i].hist(dcem_gaps, color="magenta", alpha=0.6, bins=bins,
               weights=np.ones_like(dcem_gaps) / dcem_mask.sum(),
               label=f"DCEM (n={dcem_mask.sum()})")
    ax[i].hist(tested_gaps, color="darkseagreen", alpha=0.6, bins=bins,
               weights=np.ones_like(tested_gaps) / tested_mask.sum(),
               label=f"tested-only (n={tested_mask.sum()})")

    print(f"Mean(DCEM) @ {target} +/- {limit}:", dcem_gaps.mean())
    print(f"Mean(tested-only) @ {target} +/- {limit}:", tested_gaps.mean())

    # Dashed vertical lines mark the mean ROC gap of each group.
    ax[i].vlines([dcem_gaps.mean()], ymin=0, ymax=0.45, color="magenta", linestyle="dashed", alpha=0.6)
    ax[i].vlines([tested_gaps.mean()], ymin=0, ymax=0.45, color="darkseagreen", linestyle="dashed", alpha=0.6)

    ax[i].set_ylim((0, 0.45))
    lgd = ax[i].legend(loc="lower center", bbox_to_anchor=(0.5, -1.6))

fig.tight_layout()
fig.subplots_adjust(top=0.7, bottom=0.5, wspace=0.3)

fig.savefig("no_tradeoff.pdf", bbox_extra_artists=(lgd, suptitle), bbox_inches="tight", pad_inches=0.1)


## Figure 4 (Results on sepsis classification)

In [None]:
# Summary statistics for the Figure 4 setting (k=4, q_t=1.5),
# grouped by the second and third remaining index levels.
sepsis_fig4_slice = sepsis_results.xs((4., 1.5), level=("k", "testing_disparity"))
sepsis_fig4_slice.groupby(level=(1, 2)).describe()[["overall", "diff"]]

In [None]:
# FIGURE 4
fig4_models = [
    "YModel", "YObsModel", "TestedOnlyModel", 'Group0BaselineModel',
    'Group1BaselineModel', "GroupPeerLossModel", "SELFModel",
    "ITECorrectedModel", "SAREMModel", "DCEMModel",
]

with warnings.catch_warnings():
    warnings.filterwarnings('ignore')
    fig, lgd, title = plot_fuzzy_barplot(
        "Testing disparity ($q_t$)",
        ["ROC Gap", "AUC"],   # metric names on plot
        ["ROCGap", "AUC"],    # metric names in DF index
        ["diff", "overall"],  # columns in DF
        include_models=fig4_models,
        fixed_levels=["testing_disparity"],
        fixed_vars=[1.5],
        latex_vars=["q_t"],
        extra_latex=", k=4",
        x="k",
        dodge=0.72,
        strip_dodge=0.01,
        figsize=(9, 1.6),
        bbox_to_anchor=(0.5, -0.25),
        df=sepsis_results.xs(4., level="k", drop_level=False),
        title_y=0.9,
    )

fig.show()
fig.tight_layout()
fig.savefig("sepsis_performance.pdf", bbox_extra_artists=(lgd, title), bbox_inches='tight')

# Appendix

The cells below contain boilerplate for the remaining figures and some extra analyses.

In [None]:
# Synthetic data, all models, default q_y=0.5 and k=1, q_t=0.5.
appendix_models = [
    "YModel", "YObsModel", "Group0BaselineModel", "Group1BaselineModel",
    "TestedOnlyModel", "SELFModel", "DivideMixBasedModel", "JSModel",
    "TruncatedLQModel", "ITECorrectedModel", "PeerLossModel",
    "GroupPeerLossModel", "SAREMModel", "DCEMModel",
]
fig, lgd, title = plot_fuzzy_barplot(
    "Testing disparity ($q_t$)",
    ["ROC Gap", "AUC"],
    ["ROCGap", "AUC"],
    ["diff", "overall"],
    include_models=appendix_models,
    df=all_results.xs(0.5, level="testing_disparity", drop_level=False),
    figsize=(9, 2.),
    extra_latex=", q_t=0.5",
    dodge=0.74,
    bbox_to_anchor=(0.5, -0.25),
)
fig.savefig("performance_qy_0.5_k_1_qt_0.5.pdf", bbox_extra_artists=(lgd, title), bbox_inches='tight')

In [None]:
# Synthetic data, all models, default q_y=0.5 and k=1, q_t=2.
appendix_models = [
    "YModel", "YObsModel", "Group0BaselineModel", "Group1BaselineModel",
    "TestedOnlyModel", "SELFModel", "DivideMixBasedModel", "JSModel",
    "TruncatedLQModel", "ITECorrectedModel", "PeerLossModel",
    "GroupPeerLossModel", "SAREMModel", "DCEMModel",
]
fig, lgd, title = plot_fuzzy_barplot(
    "Testing disparity ($q_t$)",
    ["ROC Gap", "AUC"],
    ["ROCGap", "AUC"],
    ["diff", "overall"],
    include_models=appendix_models,
    df=all_results.xs(2, level="testing_disparity", drop_level=False),
    figsize=(9, 2.),
    extra_latex=", q_t=2",
    dodge=0.74,
    bbox_to_anchor=(0.5, -0.25),
)
fig.savefig("performance_qy_0.5_k_1_qt_2.pdf", bbox_extra_artists=(lgd, title), bbox_inches='tight')

In [None]:
# Synthetic data, all models, default q_y=0.5 and k=1, q_t=1.
appendix_models = [
    "YModel", "YObsModel", "Group0BaselineModel", "Group1BaselineModel",
    "TestedOnlyModel", "SELFModel", "DivideMixBasedModel", "JSModel",
    "TruncatedLQModel", "ITECorrectedModel", "PeerLossModel",
    "GroupPeerLossModel", "SAREMModel", "DCEMModel",
]
fig, lgd, title = plot_fuzzy_barplot(
    "Testing disparity ($q_t$)",
    ["ROC Gap", "AUC"],
    ["ROCGap", "AUC"],
    ["diff", "overall"],
    include_models=appendix_models,
    df=all_results.xs(1, level="testing_disparity", drop_level=False),
    figsize=(9, 2.),
    extra_latex=", q_t=1",
    dodge=0.74,
    bbox_to_anchor=(0.5, -0.25),
)

fig.savefig("performance_qy_0.5_k_1_qt_1.pdf", bbox_extra_artists=(lgd, title), bbox_inches='tight')

In [None]:
# Synthetic data, all models, q_y=0.5 and k=1/3: one figure per testing disparity.
appendix_models = [
    "YModel", "YObsModel", "Group0BaselineModel", "Group1BaselineModel",
    "TestedOnlyModel", "SELFModel", "DivideMixBasedModel", "JSModel",
    "TruncatedLQModel", "ITECorrectedModel", "PeerLossModel",
    "GroupPeerLossModel", "SAREMModel", "DCEMModel",
]
appendix_k = 1/3

for appendix_qt in (2, 1, 0.5):
    fig, lgd, title = plot_fuzzy_barplot(
        "Testing disparity ($q_t$)",
        ["ROC Gap", "AUC"],
        ["ROCGap", "AUC"],
        ["diff", "overall"],
        fixed_vars=[0.5, appendix_k],
        include_models=appendix_models,
        df=all_results.xs(appendix_qt, level="testing_disparity", drop_level=False),
        figsize=(9, 2.),
        extra_latex=f", q_t={appendix_qt}",
        dodge=0.74,
        bbox_to_anchor=(0.5, -0.25),
    )
    # File names embed k and q_t exactly as Python prints them (1/3 -> 0.3333333333333333).
    fig.savefig(f"performance_qy_0.5_k_{appendix_k}_qt_{appendix_qt}.pdf",
                bbox_extra_artists=(lgd, title), bbox_inches='tight')

In [None]:
# Synthetic data, all models, q_y=0.5 and k=1/2: one figure per testing disparity.
appendix_models = [
    "YModel", "YObsModel", "Group0BaselineModel", "Group1BaselineModel",
    "TestedOnlyModel", "SELFModel", "DivideMixBasedModel", "JSModel",
    "TruncatedLQModel", "ITECorrectedModel", "PeerLossModel",
    "GroupPeerLossModel", "SAREMModel", "DCEMModel",
]
appendix_k = 1/2

for appendix_qt in (2, 1, 0.5):
    fig, lgd, title = plot_fuzzy_barplot(
        "Testing disparity ($q_t$)",
        ["ROC Gap", "AUC"],
        ["ROCGap", "AUC"],
        ["diff", "overall"],
        fixed_vars=[0.5, appendix_k],
        include_models=appendix_models,
        df=all_results.xs(appendix_qt, level="testing_disparity", drop_level=False),
        figsize=(9, 2.),
        extra_latex=f", q_t={appendix_qt}",
        dodge=0.74,
        bbox_to_anchor=(0.5, -0.25),
    )
    # File names embed k and q_t exactly as Python prints them (1/2 -> 0.5).
    fig.savefig(f"performance_qy_0.5_k_{appendix_k}_qt_{appendix_qt}.pdf",
                bbox_extra_artists=(lgd, title), bbox_inches='tight')

In [None]:
# Synthetic data, all models, q_y=0.5 and k=2: one figure per testing disparity.
appendix_models = [
    "YModel", "YObsModel", "Group0BaselineModel", "Group1BaselineModel",
    "TestedOnlyModel", "SELFModel", "DivideMixBasedModel", "JSModel",
    "TruncatedLQModel", "ITECorrectedModel", "PeerLossModel",
    "GroupPeerLossModel", "SAREMModel", "DCEMModel",
]
appendix_k = 2

for appendix_qt in (2, 1, 0.5):
    fig, lgd, title = plot_fuzzy_barplot(
        "Testing disparity ($q_t$)",
        ["ROC Gap", "AUC"],
        ["ROCGap", "AUC"],
        ["diff", "overall"],
        fixed_vars=[0.5, appendix_k],
        include_models=appendix_models,
        df=all_results.xs(appendix_qt, level="testing_disparity", drop_level=False),
        figsize=(9, 2.),
        extra_latex=f", q_t={appendix_qt}",
        dodge=0.74,
        bbox_to_anchor=(0.5, -0.25),
    )
    # File names embed k and q_t exactly as Python prints them.
    fig.savefig(f"performance_qy_0.5_k_{appendix_k}_qt_{appendix_qt}.pdf",
                bbox_extra_artists=(lgd, title), bbox_inches='tight')

In [None]:
# Synthetic data, all models, q_y=0.5 and k=3: one figure per testing disparity.
appendix_models = [
    "YModel", "YObsModel", "Group0BaselineModel", "Group1BaselineModel",
    "TestedOnlyModel", "SELFModel", "DivideMixBasedModel", "JSModel",
    "TruncatedLQModel", "ITECorrectedModel", "PeerLossModel",
    "GroupPeerLossModel", "SAREMModel", "DCEMModel",
]
appendix_k = 3

for appendix_qt in (2, 1, 0.5):
    fig, lgd, title = plot_fuzzy_barplot(
        "Testing disparity ($q_t$)",
        ["ROC Gap", "AUC"],
        ["ROCGap", "AUC"],
        ["diff", "overall"],
        fixed_vars=[0.5, appendix_k],
        include_models=appendix_models,
        df=all_results.xs(appendix_qt, level="testing_disparity", drop_level=False),
        figsize=(9, 2.),
        extra_latex=f", q_t={appendix_qt}",
        dodge=0.74,
        bbox_to_anchor=(0.5, -0.25),
    )
    # File names embed k and q_t exactly as Python prints them.
    fig.savefig(f"performance_qy_0.5_k_{appendix_k}_qt_{appendix_qt}.pdf",
                bbox_extra_artists=(lgd, title), bbox_inches='tight')

In [None]:

# Sepsis data at q_t=1.5: one figure per value of k.
sepsis_appendix_models = [
    "YModel", "YObsModel", "TestedOnlyModel", 'Group0BaselineModel',
    'Group1BaselineModel', "GroupPeerLossModel", "SELFModel",
    "ITECorrectedModel", "SAREMModel", "DCEMModel",
]
# (k value used for slicing and in the file name, k as written in the title)
sepsis_k_settings = [(0.25, "0.25"), (1/3, "1/3"), (0.5, "0.5"), (1, "1"), (2, "2")]

for sepsis_k, sepsis_k_label in sepsis_k_settings:
    fig, lgd, title = plot_fuzzy_barplot(
        "Testing disparity ($q_t$)",
        ["ROC Gap", "AUC"],
        ["ROCGap", "AUC"],
        ["diff", "overall"],
        include_models=sepsis_appendix_models,
        fixed_levels=["testing_disparity"],
        fixed_vars=[1.5],
        latex_vars=["q_t"],
        extra_latex=f", k={sepsis_k_label}",
        x="k",
        dodge=0.72,
        strip_dodge=0.01,
        figsize=(9, 1.7),
        bbox_to_anchor=(0.5, -0.35),
        df=sepsis_results.xs(sepsis_k, level="k", drop_level=False),
    )

    fig.show()
    # File names embed k exactly as Python prints it (1/3 -> 0.3333333333333333).
    fig.savefig(f"sepsis_performance_k{sepsis_k}.pdf", bbox_extra_artists=(lgd, title), bbox_inches='tight')

fig, lgd, title = plot_fuzzy_barplot("Testing disparity ($q_t$)", 
                       ["ROC Gap", "AUC"],# "AUC Gap"],
                       ["ROCGap", "AUC"],# "AUC"],
                      ["diff",
                       "overall"],# "diff"], 
                     include_models=["YModel", "YObsModel",
                                     "TestedOnlyModel", 'Group0BaselineModel',
                                     'Group1BaselineModel', "GroupPeerLossModel",
                                     "SELFModel","ITECorrectedModel", 
                                     "SAREMModel", "DCEMModel"],
                        fixed_levels=["testing_disparity"],
                        fixed_vars=[1.5],
                        latex_vars=["q_t"], extra_latex=", k=3",
                        x="k", dodge=0.72, strip_dodge=0.01,
                                     figsize=(9, 1.7), bbox_to_anchor=(0.5, -0.35),
                      df=sepsis_results.xs(3, level="k", drop_level=False))

fig.show()
fig.savefig("sepsis_performance_k3.pdf", bbox_extra_artists=(lgd, title), bbox_inches='tight',)

fig, lgd, title = plot_fuzzy_barplot("Testing disparity ($q_t$)", 
                       ["ROC Gap", "AUC"],# "AUC Gap"],
                       ["ROCGap", "AUC"],# "AUC"],
                      ["diff",
                       "overall"],# "diff"], 
                     include_models=["YModel", "YObsModel",
                                     "TestedOnlyModel", 'Group0BaselineModel',
                                     'Group1BaselineModel', "GroupPeerLossModel",
                                     "SELFModel","ITECorrectedModel", 
                                     "SAREMModel", "DCEMModel"],
                        fixed_levels=["testing_disparity"],
                        fixed_vars=[1.5],
                        latex_vars=["q_t"], extra_latex=", k=4",
                        x="k", dodge=0.72, strip_dodge=0.01,
                                     figsize=(9, 1.7), bbox_to_anchor=(0.5, -0.35),
                      df=sepsis_results.xs(4, level="k", drop_level=False))

fig.show()
fig.savefig("sepsis_performance_k4.pdf", bbox_extra_artists=(lgd, title), bbox_inches='tight',)

fig, lgd, title = plot_fuzzy_barplot("Testing disparity ($q_t$)", 
                       ["ROC Gap", "AUC"],# "AUC Gap"],
                       ["ROCGap", "AUC"],# "AUC"],
                      ["diff",
                       "overall"],# "diff"], 
                     include_models=["YModel", "YObsModel",
                                     "TestedOnlyModel", 'Group0BaselineModel',
                                     'Group1BaselineModel', "GroupPeerLossModel",
                                     "SELFModel","ITECorrectedModel", 
                                     "SAREMModel", "DCEMModel"],
                        fixed_levels=["testing_disparity"],
                        fixed_vars=[1.5],
                        latex_vars=["q_t"], extra_latex=", k=5",
                        x="k", dodge=0.72, strip_dodge=0.01,
                                     figsize=(9, 1.7), bbox_to_anchor=(0.5, -0.35),
                      df=sepsis_results.xs(5, level="k", drop_level=False))

fig.show()
fig.savefig("sepsis_performance_k5.pdf", bbox_extra_artists=(lgd, title), bbox_inches='tight',)

In [None]:
def filter_val(a, target, within):
    """Return a boolean mask of the entries of `a` strictly inside
    the open interval (target - within, target + within)."""
    return (a > target - within) & (a < target + within)


# SELF-model overall AUC ("overall" column) and ROC gap ("diff" column)
# for every row of `all_results`.
overall_aucs_self = all_results.xs(("SELFModel", "AUC"), level=("model", "metric"))["overall"]
overall_gaps_self = all_results.xs(("SELFModel", "ROCGap"), level=("model", "metric"))["diff"]

# Keep only a subset of k values. Each series is masked using its OWN index
# (previously the AUC series was masked with the gap series' index, which is
# only correct when both happen to share the same row order).
# np.isin replaces the deprecated np.in1d.
K_SUBSET = [1/4, 1/3, 0.5, 1, 2, 3]
overall_aucs_self = overall_aucs_self[np.isin(overall_aucs_self.index.get_level_values("k"), K_SUBSET)]
overall_gaps_self = overall_gaps_self[np.isin(overall_gaps_self.index.get_level_values("k"), K_SUBSET)]

# NOTE(review): `overall_aucs_dcem` / `overall_gaps_dcem` are not defined in
# this cell; they are expected to come from another cell -- confirm the
# notebook still runs top-to-bottom on a fresh kernel.

targets = [0.01, 0.03, 0.05, 0.07]   # ROC-gap centres, one panel each
limit = 0.01                         # half-width of the ROC-gap window
bins = np.linspace(0.5, 1., 20)      # shared AUC bins (loop-invariant)
y_max = 0.45                         # common y-limit and height of the mean markers

fig, ax = plt.subplots(1, len(targets), figsize=(12, 3))
suptitle = fig.suptitle("Proportion of DCEM and SELF models by AUC achieved, controlling for ROC gap")

legends = []
for i, target in enumerate(targets):
    ax[i].set_title(f"Models with ROC gap\n{target} +/- {limit}")
    if i == 0:
        ax[i].set_ylabel("Frequency")
    ax[i].set_xlabel("AUC")

    # Select the models whose ROC gap falls inside the current window.
    dcem_mask = filter_val(overall_gaps_dcem, target, limit)
    self_mask = filter_val(overall_gaps_self, target, limit)
    dcem_aucs = overall_aucs_dcem[dcem_mask]
    self_aucs = overall_aucs_self[self_mask]

    # Weights of 1/n turn the counts into proportions of the selected models.
    ax[i].hist(dcem_aucs, color="magenta", alpha=0.6, bins=bins, weights=np.ones_like(dcem_aucs) / dcem_mask.sum(), label=f"DCEM (n={dcem_mask.sum()})")
    ax[i].hist(self_aucs, color="blue", alpha=0.6, bins=bins, weights=np.ones_like(self_aucs) / self_mask.sum(), label=f"SELF (n={self_mask.sum()})")

    print(f"Mean(DCEM) @ {target} +/- {limit}:", dcem_aucs.mean())
    print(f"Mean(SELF) @ {target} +/- {limit}:", self_aucs.mean())

    # Dashed vertical lines mark the mean AUC of each selected group.
    ax[i].vlines([dcem_aucs.mean()], ymin=0, ymax=y_max, color="magenta", linestyle="dashed", alpha=0.6)
    ax[i].vlines([self_aucs.mean()], ymin=0, ymax=y_max, color="blue", linestyle="dashed", alpha=0.6)

    ax[i].set_ylim((0, y_max))
    lgd = ax[i].legend(loc="lower center", bbox_to_anchor=(0.5, -1.6))
    legends.append(lgd)

fig.tight_layout()
fig.subplots_adjust(top=0.7, bottom=0.5, wspace=0.3)

# Pass every panel's legend (not just the last one) so none of them can be
# clipped by the tight bounding box.
fig.savefig("no_tradeoff_self.pdf", bbox_extra_artists=(*legends, suptitle), bbox_inches="tight", pad_inches=0.1)


# Sensitivity analyses

In [None]:
import re


def _parse_phase_and_temperature(results_path):
    """Extract ``(phase, t)`` from the path of an ablation ``results.csv``.

    The directory name is split on ``"_"``. When the path contains no
    ``"phase"`` substring the phase is 0 and the temperature token is the
    second-to-last one; otherwise the phase is the first integer in the last
    token and the temperature token is the third-to-last one.

    Returns
    -------
    (int, str)
        The phase as an int (so it has the same type as the default 0) and
        the temperature exactly as it is spelled in the directory name.

    Raises
    ------
    AttributeError
        If the expected number cannot be found in the relevant token.
    """
    tokens = os.path.dirname(results_path).split("_")
    if "phase" not in results_path:
        phase = 0
        temp_idx = -2
    else:
        phase = int(re.search(r"\d+", tokens[-1]).group())
        temp_idx = -3
    # Integer with an optional fractional part. The previous pattern needed at
    # least two digits when there was no decimal point, so e.g. "t5" failed.
    temp = re.search(r"\d+(?:\.\d+)?", tokens[temp_idx]).group()
    return phase, temp


# Sorted so that the order of the concatenated frames does not depend on the
# filesystem's directory listing order.
t_ablations = sorted(glob.glob("/data4/username/disparate_censorship_mitigation_ablations/20231215_t*_testedonly*/results.csv"))
ablation_dfs = []
keys = []
for f in t_ablations:
    try:
        phase, temp = _parse_phase_and_temperature(f)
        df = pd.read_csv(f, index_col=np.arange(6))
    except Exception:
        # Report which file was being processed, then re-raise with the
        # original traceback intact.
        print("Failed on", f)
        raise
    keys.append((phase, temp))
    ablation_dfs.append(df)
    print(phase, temp, f)

# Index levels after concat: 0 = phase, 1 = t, 2-7 = the six CSV index levels.
ablation_results = pd.concat(ablation_dfs, keys=keys, names=["phase", "t"])
# Grouping on every level except level 0 (phase) yields summary statistics
# across phases for each (t, CSV-index) combination.
ablation_summaries = ablation_results.groupby(level=(1, 2, 3, 4, 5, 6, 7)).describe()


In [None]:
# describe() statistics of the "overall" column for the AUC metric, cut down
# to one cross-section by fixing index levels 1, 2 and 4 (positions counted
# after the "metric" level has been dropped) to 1.0, 0.5 and 2.0.
# NOTE(review): the levels are addressed by position; confirm which
# parameters they correspond to in the ablation CSV layout.
(
    ablation_summaries["overall"]
    .xs("AUC", level="metric")
    .xs((1.0, 0.5, 2.0), level=(1, 2, 4))
)


In [None]:
# describe() statistics of the "diff" column for the ROCGap metric, cut down
# to one cross-section by fixing index levels 1, 2 and 4 (positions counted
# after the "metric" level has been dropped) to 1.0, 0.5 and 2.0.
# NOTE(review): the levels are addressed by position; confirm which
# parameters they correspond to in the ablation CSV layout.
(
    ablation_summaries["diff"]
    .xs("ROCGap", level="metric")
    .xs((1.0, 0.5, 2.0), level=(1, 2, 4))
)


In [None]:
import io
import pickle
import sys

import torch
import torch.nn.functional as F
from sklearn.calibration import CalibrationDisplay
from sklearn.metrics import roc_auc_score

if sys.path[0] != "..":
    sys.path.insert(0, "..")
from nn_modules import DisparateCensorshipEstimator

class CPU_Unpickler(pickle.Unpickler):
    """Unpickler that loads pickled torch storages onto the CPU.

    Only use this on trusted files: unpickling can execute arbitrary code.
    """

    def find_class(self, module, name):
        """Resolve `module.name`, replacing torch's storage loader.

        For ``torch.storage._load_from_bytes`` a substitute is returned that
        calls ``torch.load`` with ``map_location='cpu'``; every other lookup
        is delegated to ``pickle.Unpickler.find_class``.
        """
        if (module, name) != ('torch.storage', '_load_from_bytes'):
            return super().find_class(module, name)

        def _load_on_cpu(raw_bytes):
            return torch.load(io.BytesIO(raw_bytes), map_location='cpu')

        return _load_on_cpu
        
def get_split_from_cache(prior, t, k, qy, qt, split="test", base_results_path="/data4/username/disparate_censorship_mitigation",):
    """Return one data split for a parameter setting, loading the pickle at most once.

    The pickled data dictionary of the run directory
    ``20231215_prior{prior}_t{t}`` is memoised in the module-level
    ``DATA_CACHE`` under the key ``(prior, t)``. The requested entry is then
    looked up by the tuple built from `k`, `qy` and `qt` (target prevalence is
    fixed at 0.25) and finally by `split`.

    NOTE(review): ``DATA_CACHE`` is not defined in this cell; it must already
    exist as a dict-like object when this function is called.
    """
    cache_key = (prior, t)
    if cache_key in DATA_CACHE:
        loaded = DATA_CACHE[cache_key]
    else:
        pickle_path = os.path.join(base_results_path, f"20231215_prior{prior}_t{t}", "data_dict.pkl")
        with open(pickle_path, "rb") as handle:
            print("Reloading data...")
            loaded = CPU_Unpickler(handle).load()
        DATA_CACHE[cache_key] = loaded

    setting = (
        f'k_{k}',
        f'prevalence_disparity_{qy}',
        'target_prevalence_0.25',
        f'testing_disparity_{qt}',
    )
    return loaded[setting][split]
    
def load_dcem_model(prior, t, k, qy, qt, base_results_path="/data4/username/disparate_censorship_mitigation", verify=None):
    """Rebuild a saved DCEM model from its state dict, on the CPU.

    Parameters
    ----------
    prior, t : values identifying the run directory ``20231215_prior{prior}_t{t}``.
    k, qy, qt : values identifying the parameter sub-directory (target
        prevalence is fixed at 0.25 in the directory name).
    base_results_path : root directory that contains the run directories.
    verify : if truthy, print the ROC AUC of the loaded model on the cached
        "test" split as a sanity check. ``None`` (default) and ``False`` both
        skip the check; previously ``False`` still triggered it.

    Returns
    -------
    The ``DisparateCensorshipEstimator`` with the saved weights loaded.
    """
    prior_t_dir = f"20231215_prior{prior}_t{t}"
    prior_param_dir = f"k_{k}_prevalence_disparity_{qy}_target_prevalence_0.25_testing_disparity_{qt}"
    model_skeleton = DisparateCensorshipEstimator(2)
    state_dict_path = os.path.join(base_results_path, prior_t_dir, prior_param_dir, "DCEMModel_model_info/model_state.pth")
    state_dict = torch.load(state_dict_path, map_location=torch.device('cpu'))
    model_skeleton.load_state_dict(state_dict)
    if verify:
        # Read the data from the same results root as the weights; previously
        # the default root was always used, whatever `base_results_path` said.
        split = get_split_from_cache(prior, t, k, qy, qt, base_results_path=base_results_path)

        X_ts = split["X"]
        Y_ts = split["Y"]
        A_ts = split["A"]
        preds = model_skeleton.model.predict_proba(X_ts, A=A_ts)[:, 1]
        auc = roc_auc_score(Y_ts, preds)
        print(auc)
    return model_skeleton

def get_t_hat(prior, t, k, qy, qt, base_results_path="/data4/username/disparate_censorship_mitigation", dcem_model=None):
    """Compute propensity-model logits on the cached "test" split.

    Parameters
    ----------
    prior, t, k, qy, qt : identify the cached data split (see
        ``get_split_from_cache``).
    base_results_path : root directory of the cached data; it is now
        forwarded to the cache lookup (it used to be silently ignored).
    dcem_model : object exposing ``uni_propensity_model``. When ``None``
        (default) the notebook-global ``model`` is used, which preserves the
        previous behaviour.

    Returns
    -------
    (logits, T_ts)
        The detached output of ``uni_propensity_model`` and the split's "T"
        entry.
    """
    if dcem_model is None:
        dcem_model = model  # fall back to the notebook-level global

    split = get_split_from_cache(prior, t, k, qy, qt, base_results_path=base_results_path)
    X_ts = split["X"]
    A_ts = split["A"]
    T_ts = split["T"]
    # Feature layout fed to the propensity model: [X, X * A, A].
    XA = torch.from_numpy(np.concatenate([X_ts, X_ts * A_ts[:, None], A_ts[:, None]], axis=1)).float()
    logits = dcem_model.uni_propensity_model(XA).detach()
    return logits, T_ts



In [None]:
# Load one trained DCEM model, printing its test AUC as a sanity check, and
# fetch the propensity-model logits with the matching "T" values.
prior, t, k, qy, qt = (0.5, 10.0, 0.25, 0.25, 0.5)
model = load_dcem_model(prior, t, k, qy, qt, verify=True)
t_logits, t_true = get_t_hat(prior, t, k, qy, qt)

print(t_logits)

# One calibration curve per softmax temperature tau: the logits are divided
# by tau before the softmax, and column 1 is used as the predicted
# probability. (A stray `tau = 0.1` assignment that the loop immediately
# overwrote has been removed.)
fig, ax = plt.subplots(1, 3, figsize=(10, 3.))
for i, tau in enumerate([0.1, 1.0, 10.0]):
    with torch.no_grad():
        CalibrationDisplay.from_predictions(t_true, F.softmax(t_logits / tau, dim=-1)[:, 1], n_bins=20, ax=ax[i])
        ax[i].set_title(r"Calibration plot of $\hat{t}$" + r", $\tau={}$".format(tau))
        ax[i].set_xlabel("Mean predicted probability")
        ax[i].set_ylabel("Fraction positive")
fig.tight_layout()
fig.savefig("supp_calibration.pdf", bbox_inches="tight")
fig.show()

# Letting filtering methods "cheat"

When a filtering method filters out an example that is a true positive (i.e., its label is known with certainty), that mistake is corrected.

In [None]:
# Results of the "cheat" sweep: point estimates (6 index levels) and
# bootstrap replicates (7 index levels).
FILTER_CHEAT_PATH = "/data4/username/disparate_censorship_mitigation/sweep_20230620_cheat/"
cheat_results_csv = os.path.join(FILTER_CHEAT_PATH, "results.csv")
cheat_bootstrap_csv = os.path.join(FILTER_CHEAT_PATH, "bootstrap_results.csv")

results_cheat = pd.read_csv(cheat_results_csv, index_col=np.arange(6))
bootstrap_results_cheat = pd.read_csv(cheat_bootstrap_csv, index_col=np.arange(7))

display(results_cheat, bootstrap_results_cheat)

In [None]:
# Relabel the "model" index level of both cheat frames so the cheat variants
# can be told apart from the regular models once the frames are combined.
# set_levels replaces the level's labels positionally.
# NOTE(review): this assumes the level holds exactly two models, in the order
# DivideMix-based then SELF -- confirm against the CSV.
CHEAT_MODEL_NAMES = ["DivideMixCheatModel", "SELFCheatModel"]
for cheat_frame in (results_cheat, bootstrap_results_cheat):
    cheat_frame.index = cheat_frame.index.set_levels(CHEAT_MODEL_NAMES, level="model")

print(results_cheat.index.get_level_values("model").unique())

In [None]:
# Take the rows of the main sweep for the comparison models and stack them
# underneath the cheat results (same for the bootstrap frames, whose index
# has one extra level).
COMPARISON_MODELS = ["DivideMixBasedModel", "SELFModel", "YModel", "YObsModel"]
idx = pd.IndexSlice

results_subset = results.loc[idx[:, :, :, :, COMPARISON_MODELS, :], :]
bootstrap_subset = bootstrap_results.loc[idx[:, :, :, :, COMPARISON_MODELS, :, :], :]

results_cheat_final = pd.concat([results_cheat, results_subset])
bootstrap_cheat_final = pd.concat([bootstrap_results_cheat, bootstrap_subset])

display(results_cheat_final)
print(bootstrap_cheat_final.index.get_level_values("model").unique())

In [None]:
# One overall-AUC figure for every (prevalence_disparity, k) combination
# present in the combined cheat results.
fixed = ("prevalence_disparity", "k")
level_values = [results_cheat_final.index.get_level_values(name).unique() for name in fixed]
all_pairs = sorted(product(*level_values))

for vals in all_pairs:
    fig = plot_results(
        "Testing disparity", "overall AUC", "AUC", "overall", vals, fixed,
        df=results_cheat_final, bs_df=bootstrap_cheat_final,
    )
    fig.show()

# Optimal M-step contour plot

In [None]:
def solution(x, y):
    """Return the smaller root z of the quadratic

        (x*y + x) * z**2 - (1 + 2*x*y) * z + y = 0.

    Works elementwise on scalars or NumPy arrays. `x` must be non-zero
    (and `y != -1`), otherwise the leading coefficient vanishes.
    """
    lead = x * y + x              # quadratic coefficient
    neg_linear = 1 + 2 * x * y    # minus the linear coefficient
    disc = neg_linear ** 2 - 4 * lead * y
    return (neg_linear - np.sqrt(disc)) / (2 * lead)
            
# Evaluate `solution` on an N x N grid over the open unit square. The end
# points 0 and 1 are excluded; x = 0 would make the leading coefficient of
# the quadratic zero.
N = 100
x = np.linspace(0.001, 0.999, N)
y = np.linspace(0.001, 0.999, N)

xx, yy = np.meshgrid(x, y)
# `solution` uses only elementwise NumPy operations, so it can be applied to
# the whole grid directly; the np.vectorize wrapper was unnecessary.
z = solution(xx, yy)
fig, ax = plt.subplots(figsize=(4, 3))

# Raw strings: "\h" in a normal string literal is an invalid escape sequence
# (a warning today, an error in a future Python). The text is unchanged.
ax.set_title(r"Optimal M-step $\hat{y}^{(i)}$ given $\hat{t}^{(i)}$, $Q(y^{(i)})$; $t^{(i)} = 0$")
ax.set_xlabel(r"$\hat{t}^{(i)}$")
ax.set_ylabel(r"$Q(y^{(i)})$")

# Filled contours with the colour scale pinned to [0, 1], overlaid with
# dashed white contour lines carrying inline labels.
cs = ax.contourf(xx, yy, z, vmin=0, vmax=1, levels=10)
cs.set_clim(0, 1)
cslines = ax.contour(cs, colors='white', linestyles="dashed", linewidths=1)
ax.clabel(cslines, fmt="%2.1f", use_clabeltext=True, colors="white")
cbar = fig.colorbar(cs, label=r"Optimal $\hat{y}^{(i)}$")
fig.tight_layout()
fig.savefig("optim_mstep_v2.pdf")