In [None]:
import os
import sys
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

root_dir = "../"
sys.path.append(root_dir)
from configs import get_config
from utils import organ_idx, organ_names

# Experiment configuration used by every cell below.
config_name = "ts"
config = get_config(config_name)

# Calibration runs to load: k fixed at 4 and calibration.sem_control given
# both values (presumably each list entry selects a separate set of runs —
# confirm against get_calibration_results).
calibration_results_kwargs = {
    "calibration.k": 4,
    "calibration.sem_control": [False, True],
}
results = config.get_calibration_results(results_kwargs=calibration_results_kwargs)

# Shared seaborn styling for all figures in this notebook.
sns.set_style("white")
sns.set_context("paper")

In [None]:
# Per-organ mean interval length and risk for every loaded procedure,
# drawn as two horizontal bar charts sharing the same organ ordering.
figure_dir = os.path.join(root_dir, "figures", "organ_loss")
os.makedirs(figure_dir, exist_ok=True)

# Hard-coded upper x-limits for the two panels.
# NOTE(review): these may clip bars for other configs/datasets — revisit if reused.
I_MEAN_XMAX = 0.28
RISK_XMAX = 0.32

# Map each tracked class index to its display name once. setdefault keeps the
# first occurrence, matching organ_names[organ_idx.index(idx)].
organ_name_by_idx = {}
for _position, _tracked_idx in enumerate(organ_idx):
    organ_name_by_idx.setdefault(_tracked_idx, organ_names[_position])

sem_loss_data = {"procedure": [], "organ": [], "loss": [], "i_mean": []}
for _results in results:
    procedure_name = _results.procedure_name()
    organ_loss = _results.get("organ_loss")
    organ_i_mean = _results.get("organ_i_mean")
    # Outer zip walks matching entries of the two stored sequences; the inner
    # enumerate walks per-class values (index = class index).
    for _organ_loss, _organ_i_mean in zip(organ_loss, organ_i_mean):
        for _idx, (_loss, _i_mean) in enumerate(zip(_organ_loss, _organ_i_mean)):
            if _idx not in organ_name_by_idx:
                continue
            sem_loss_data["procedure"].append(procedure_name)
            sem_loss_data["organ"].append(organ_name_by_idx[_idx])
            # .item() presumably unwraps tensor/array scalars — confirm type.
            sem_loss_data["loss"].append(_loss.item())
            sem_loss_data["i_mean"].append(_i_mean.item())

sem_loss_data = pd.DataFrame(sem_loss_data)
# Mean interval length per organ (ascending); used only to order the bars.
organ_loss_data = sem_loss_data.groupby(["organ"])["i_mean"].mean().sort_values()
organ_order = organ_loss_data.index.tolist()[::-1]

_, axes = plt.subplots(1, 2, figsize=(16 / 3, 9 / 3), gridspec_kw={"wspace": 0.6})

# Left panel: mean interval length. Bars are horizontal, so the metric lives
# on the x-axis and the organ names on the y-axis.
ax = axes[0]
sns.barplot(
    data=sem_loss_data,
    y="organ",
    x="i_mean",
    hue="procedure",
    ax=ax,
    order=organ_order,
)
ax.set_xlabel("Mean interval length")
ax.set_ylabel("")
ax.set_xlim(None, I_MEAN_XMAX)
ax.legend()

# Right panel: risk with the calibration tolerance as a dashed reference line.
ax = axes[1]
sns.barplot(
    data=sem_loss_data,
    y="organ",
    x="loss",
    hue="procedure",
    ax=ax,
    order=organ_order,
)
ax.axvline(config.calibration.epsilon, color="red", linestyle="--", label="tolerance")
ax.set_xlabel("Risk")
ax.set_ylabel("")
ax.set_xlim(None, RISK_XMAX)
# Same organ order as the left panel, so its tick labels would be redundant.
ax.set_yticklabels([])
ax.legend()

dataset_name = config.data.dataset.lower()
for _extension in ("pdf", "png"):
    plt.savefig(
        os.path.join(figure_dir, f"{dataset_name}.{_extension}"), bbox_inches="tight"
    )
plt.show()

In [None]:
# Per-organ lambda values, restricted to runs whose calibration procedure is
# "semrcps", plotted as a vertical bar chart.
sem_lambda_records = []
for _results in results:
    if _results.config.calibration.procedure != "semrcps":
        continue
    procedure_name = _results.procedure_name()
    for sem_lambda in _results.get("_lambda"):
        for _class_idx, _lambda in enumerate(sem_lambda):
            # Index 0 is always excluded, as is any class not listed in organ_idx.
            if _class_idx == 0 or _class_idx not in organ_idx:
                continue
            sem_lambda_records.append(
                {
                    "procedure": procedure_name,
                    "lambda": _lambda.item(),
                    "organ": organ_names[organ_idx.index(_class_idx)],
                }
            )

# Explicit columns keep the schema (and column order) even with no records.
sem_lambda_data = pd.DataFrame(sem_lambda_records, columns=["procedure", "organ", "lambda"])
organ_lambda_data = sem_lambda_data.groupby(["organ"])["lambda"].mean().sort_values()
# Bars ordered from largest to smallest mean lambda.
lambda_order = organ_lambda_data.index.tolist()[::-1]

_, ax = plt.subplots(figsize=(16 / 4, 9 / 4))
sns.barplot(
    data=sem_lambda_data,
    x="organ",
    y="lambda",
    hue="procedure",
    ax=ax,
    order=lambda_order,
)
ax.set_xlabel("")
ax.set_ylabel(r"$\lambda_k$")
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
ax.legend(loc="upper left", bbox_to_anchor=(1, 1))
for _file_name in ("organ_lambda.pdf", "organ_lambda.png"):
    plt.savefig(os.path.join(figure_dir, _file_name), bbox_inches="tight")
plt.show()