In [None]:
import os
import sys
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from functools import reduce
from torchvision.utils import make_grid

root_dir = "../"
sys.path.append(root_dir)
from configs import get_config
from datasets import get_dataset
from calibrate import gather_data
from krcps import get_uq
from utils import organ_idx, organ_names

# Experiment selection: the "ts" configuration, its dataset, and the saved
# calibration results (one entry per calibration procedure).
config_name = "ts"
config = get_config(config_name)
dataset = get_dataset(config)
results = config.get_calibration_results()

# idx[rep] holds the `val_idx` values shared by every entry of `results` at
# calibration repetition `rep`, obtained by folding np.intersect1d over them.
# NOTE(review): presumably this is used so all procedures are compared on the
# same samples — confirm.
idx = [
    reduce(
        np.intersect1d,
        [procedure_results[rep].val_idx for procedure_results in results],
    )
    for rep in range(config.calibration.r)
]
print(idx)

# Seaborn styling for the figures below.
sns.set_style("white")
sns.set_context("paper")

In [None]:
# For every calibration procedure in `results`: print the mean +- std of its
# loss and of its mean interval length, collect the per-entry values into a
# table, and show both quantities as box plots. The red dashed line marks the
# calibration epsilon (presumably the target risk level — confirm).
print("Risk control:")
loss_records = {"procedure": [], "loss": [], "i_mean": []}
for _results in results:
    # Materialize once so the printed statistics and the table use the same values.
    loss_values = list(_results.get("loss"))
    i_mean_values = list(_results.get("i_mean"))
    loss = np.array(loss_values)
    i_mean = np.array(i_mean_values)
    print(
        f"\t{_results.name():<68}:"
        f"\tloss = {loss.mean():.3f} +- {loss.std():.3f},"
        f"\tmean interval length = {100* i_mean.mean():.2f} +- {100*i_mean.std():.2f}"
    )

    # One row per loss value, labelled with the procedure it came from. Sizing the
    # label column from the loss list keeps the columns the same length.
    loss_records["procedure"].extend([_results.procedure_name()] * len(loss_values))
    loss_records["loss"].extend(loss_values)
    loss_records["i_mean"].extend(i_mean_values)
loss_df = pd.DataFrame(loss_records)

fig, axes = plt.subplots(1, 2, figsize=(16 / 2, 9 / 4), gridspec_kw={"wspace": 0.3})

ax = axes[0]
sns.boxplot(data=loss_df, x="procedure", y="loss", ax=ax)
# NOTE(review): epsilon is read from the first procedure only; this assumes all
# procedures share the same calibration epsilon — confirm.
ax.axhline(
    results[0].config.calibration.epsilon, color="red", linestyle="--", label="epsilon"
)
ax.set_ylabel("Risk")
# The axhline label is only rendered once a legend is drawn.
ax.legend()

ax = axes[1]
sns.boxplot(data=loss_df, x="procedure", y="i_mean", ax=ax)
ax.set_ylabel("Mean interval length")
plt.show()

In [None]:
# figure_dir = os.path.join(
#     root_dir, "figures", config.data.dataset, config.data.task, "calibration"
# )
# os.makedirs(figure_dir, exist_ok=True)

# for r, r_idx in enumerate(idx):
#     _, ground_truth, reconstruction, segmentation = gather_data(config, r_idx)

#     for i in range(ground_truth.size(0)):
#         _reconstruction, _segmentation = (reconstruction[[i]], segmentation[[i]])

#         for _results in results:
#             _config = _results.config
#             _lambda = _results[r]._lambda

#             uq = get_uq(_config.calibration.uq)
#             if _config.calibration.procedure == "semrcps":
#                 uq_fn = uq(_reconstruction, _segmentation)
#             else:
#                 uq_fn = uq(_reconstruction)

#             l, u = uq_fn(_lambda)
#             center = (u + l) / 2
#             interval = u - l

#             center = center.permute(1, 0, 2, 3)
#             interval = interval.permute(1, 0, 2, 3)

#             center = torch.rot90(center, 1, [2, 3])
#             interval = torch.rot90(interval, 1, [2, 3])

#             if _config.calibration.procedure == "krcps":
#                 _lambda = torch.rot90(_lambda, 1, (0, 1))

#                 _, ax = plt.subplots(figsize=(3, 3))
#                 img = ax.imshow(_lambda, cmap="bone", vmin=0, vmax=0.03)
#                 plt.colorbar(img, ax=ax)
#                 ax.axis("off")
#                 ax.set_title(r"$\hat{\lambda}_{K}$")
#                 plt.savefig(
#                     os.path.join(figure_dir, f"{r}_{i}_krcps_lambda.pdf"),
#                     bbox_inches="tight",
#                 )
#                 plt.savefig(
#                     os.path.join(figure_dir, f"{r}_{i}_krcps_lambda.png"),
#                     bbox_inches="tight",
#                 )
#                 plt.close()

#             if _config.calibration.procedure == "semrcps":
#                 s, pad = 54, 2
#                 midpoints = s / 2 + pad + (s + pad) * np.arange(len(organ_idx))

#                 _lambda = _lambda[organ_idx]
#                 sorted_idx = torch.argsort(_lambda, descending=True)
#                 sorted_lambda = _lambda[sorted_idx]
#                 sorted_organ_names = [organ_names[i] for i in sorted_idx]

#                 _, ax = plt.subplots(figsize=(3, 3))
#                 image_lambda = sorted_lambda[:, None, None, None].expand(-1, 3, s, s)
#                 image_lambda = make_grid(image_lambda, nrow=1, padding=pad)
#                 img = ax.imshow(image_lambda[0], cmap="bone", vmin=0, vmax=0.03)
#                 plt.colorbar(img, ax=ax)
#                 ax.set_xticks([])
#                 ax.set_yticks(midpoints)
#                 ax.set_yticklabels(sorted_organ_names)
#                 ax.set_title(r"$\hat{\lambda}_{\text{sem}}$")
#                 plt.savefig(
#                     os.path.join(figure_dir, f"{r}_{i}_semrcps_lambda.pdf"),
#                     bbox_inches="tight",
#                 )
#                 plt.savefig(
#                     os.path.join(figure_dir, f"{r}_{i}_semrcps_lambda.png"),
#                     bbox_inches="tight",
#                 )
#                 plt.close()

#             _, ax = plt.subplots(figsize=(16, 9 / 4))
#             title = _results.procedure_name()
#             if _config.calibration.procedure == "rcps":
#                 title += r" $(\lambda = %.2f)$" % _lambda
#             uncertainty = interval
#             # uncertainty[uncertainty <= 0.06] = 0
#             # uncertainty = 2 * uncertainty / (center + 1)
#             grid = make_grid(uncertainty, nrow=_config.calibration.window_slices)
#             img = ax.imshow(grid[0], cmap="bone", vmin=0, vmax=0.5)
#             plt.colorbar(img, ax=ax)
#             ax.axis("off")
#             ax.set_title(title)
#             plt.savefig(
#                 os.path.join(
#                     figure_dir, f"{r}_{i}_{_config.calibration.procedure}.pdf"
#                 ),
#                 bbox_inches="tight",
#             )
#             plt.savefig(
#                 os.path.join(
#                     figure_dir, f"{r}_{i}_{_config.calibration.procedure}.png"
#                 ),
#                 bbox_inches="tight",
#             )
#             plt.close()