In [None]:
import numpy as np
import json
import os
from sim_intervals import METHODS
from simulate_data import NS, TAUS, RHOS
from tqdm import tqdm
import matplotlib.pyplot as plt

# Typeset all figure text through LaTeX and load the cmbright package in the
# LaTeX preamble (a sans-serif font choice).
plt.rcParams.update(
    {
        "text.usetex": True,
        "text.latex.preamble": r"\usepackage{cmbright}",
    }
)
# Alternative that needs no LaTeX install: plt.rcParams['mathtext.fontset'] = 'stixsans'

In [None]:
RESULT_DIR = "results/simulation"
# Only load JSON result files: any other directory entry (e.g. a hidden/system
# file) would make json.load fail. Sorting makes the load order reproducible.
fnames = sorted(f for f in os.listdir(RESULT_DIR) if f.endswith(".json"))
results = {}
for fname in tqdm(fnames):
    # Key each result by its file stem; later cells index `results` with keys
    # such as f"{setting}_n={n}_tau={tau}".
    key = os.path.splitext(fname)[0]
    with open(os.path.join(RESULT_DIR, fname), "r") as f:
        result = json.load(f)
    # Drop the "intervals" and "widths" entries of every method; later cells
    # only read summary fields such as "coverage" and "nominal_cov".
    # pop(..., None) tolerates result files that never stored these entries.
    for r in result.values():
        r.pop("intervals", None)
        r.pop("widths", None)
    results[key] = result

In [None]:
# Spot-check: display the nominal (target) coverage stored for the "re_dl" method
# in one loaded result file.
# NOTE(review): the plotting cell builds keys as f"{setting}_n={n}_tau={tau}";
# confirm a file "fe_unif_n=3.json" (no tau suffix) exists, otherwise this raises KeyError.
results["fe_unif_n=3"]["re_dl"]["nominal_cov"]

In [None]:
# Display names (LaTeX markup) for each interval method, used as panel titles.
nice_names = {
    "binom": "ST",
    "binom-corr": r"ST$_{\rho=0.1}$",
    "signrank": "SRT",
    "fe": "FE",
    "birge": "Birge",
    "birge-t": r"Birge$_t$",
    "birge-mle": r"Birge$_{\rm{MLE}}$",
    "pdg": r"Birge$_{\rm{PDG}}$",
    "codata": r"Birge$_{\rm{CODATA}}$",
    "re_hksj": r"RE$_{\rm{HKSJ}}$",
    "re_mhksj": r"RE$_{\rm{mHKSJ}}$",
    "re_mmhksj": r"RE$_{\rm{mmHKSJ}}$",
    "re_dl": r"RE$_{\rm{DL}}$",
    "re_mle": r"RE$_{\rm{MLE}}$",
    "re_pm": r"RE$_{\rm{PM}}$",
}

# Simulation setting to visualise; result keys are f"{setting}_n={n}_tau={tau}".
setting = "re-outliers_unif"
ncol = 6
# Symmetric colour-scale limit for (achieved - target) coverage; 0 maps to the
# centre of the diverging colormap.
COV_GAP_LIM = 0.4


def _coverage_gap_grid(results, setting, method, ns, taus):
    """Return a (len(ns), len(taus)) array of achieved minus nominal coverage.

    Rows follow `ns`, columns follow `taus`. Cells whose result key or method
    entry is missing stay NaN, so imshow leaves them blank.
    """
    grid = np.full((len(ns), len(taus)), np.nan)
    for j, n in enumerate(ns):
        for k, tau in enumerate(taus):
            res = results.get(f"{setting}_n={n}_tau={tau}", {}).get(method)
            if res is not None:
                grid[j, k] = res["coverage"] - res["nominal_cov"]
    return grid


# Enough rows for every method (ceil division); height scaled from the
# original 3-row, 4-inch-tall layout.
nrows = max(1, -(-len(METHODS) // ncol))
fig, axs = plt.subplots(
    nrows,
    ncol,
    figsize=(8, 4 * nrows / 3),
    sharex=True,
    sharey=True,
    squeeze=False,
    gridspec_kw={"wspace": 0, "hspace": 0.25},
)
flat_axs = axs.ravel()
sc = None
for i, method in enumerate(METHODS):
    ax = flat_axs[i]
    ax.tick_params(direction="in", top=True, right=True)
    ax.set_xticks(np.arange(len(TAUS)))
    ax.set_yticks(np.arange(len(NS)))
    ax.set_title(nice_names.get(method, method), pad=2)
    ax.set_yticklabels([str(n) for n in NS])
    ax.set_xticklabels([str(t) for t in TAUS])

    # The last `ncol` methods are the lowest visible panel in each column
    # (unused axes below them are hidden). sharex=True suppresses tick labels on
    # every grid row but the last, so re-enable them where the x label goes.
    if i >= len(METHODS) - ncol:
        ax.set_xlabel(r"$\tau$")
        ax.tick_params(labelbottom=True)
    if i % ncol == 0:
        ax.set_ylabel(r"$n$")

    img = _coverage_gap_grid(results, setting, method, NS, TAUS)
    sc = ax.imshow(img, cmap="bwr_r", vmin=-COV_GAP_LIM, vmax=COV_GAP_LIM)
    # Average gap over the settings that are present; plain np.mean would be
    # NaN as soon as a single (n, tau) result is missing.
    mean_gap = np.nan if np.isnan(img).all() else np.nanmean(img)
    print(method, mean_gap)

    ax.set_aspect(0.6)

for ax in flat_axs[len(METHODS):]:
    ax.set_visible(False)

if sc is not None:
    cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
    fig.colorbar(sc, cax=cbar_ax, label=r"Achieved coverage $-$ Target coverage")
plt.show()

In [None]:
# Inspect the loaded results for the "re-outliers" setting.
# NOTE(review): `results` is keyed by file stem, and the plotting cell uses keys like
# f"{setting}_n={n}_tau={tau}" (e.g. "re-outliers_unif_n=..._tau=..."). Unless a file
# "re-outliers.json" exists, this bare key raises KeyError; a prefix filter such as
# {k: v for k, v in results.items() if k.startswith("re-outliers")} may be what was intended.
results["re-outliers"]