In [None]:
import os
import torch
import numpy as np
import torchvision.transforms.functional as TF
import matplotlib.pyplot as plt
from matplotlib import animation, rc
import seaborn as sns
from PIL import Image

from krcps import Config
from krcps import get_uq
from krcps import get_procedure

# Fixed seed so that Restart & Run All reproduces the same synthetic data,
# calibration/validation split and figures. Both numpy (split) and torch
# (data generation) draw random numbers in this notebook, so seed both.
SEED = 0
rng = np.random.default_rng(SEED)
torch.manual_seed(SEED)

# Plot styling: seaborn's "white" style with the "paper" context.
sns.set_theme(style="white", context="paper")

In [None]:
# Number of images used for calibration and for held-out validation.
n_cal, n_val = 256, 128

# Options shared by both procedures; each Config below is built from this dict.
config_dict = dict(
    uq="calibrated_quantile",
    loss="01",
    bound="hoeffding_bentkus",
    epsilon=0.1,
    delta=0.1,
    lambda_max=0.5,
    stepsize=2e-03,
)

# Baseline procedure.
rcps_config = Config(config_dict)
rcps_config.procedure = "rcps"

# K-RCPS: the shared options plus procedure-specific settings.
krcps_config = Config(config_dict)
krcps_config.procedure = "krcps"
krcps_config.n_opt = 128
krcps_config.gamma = np.linspace(0.25, 0.75, 16)  # 16 evenly spaced values in [0.25, 0.75]
krcps_config.membership = "01_loss_otsu"
krcps_config.k = 2
krcps_config.prob_size = 50

In [None]:
n = n_cal + n_val
m = 128  # number of samples drawn per image

# Load the ground-truth image as a (C, H, W) float tensor; the context manager
# releases the underlying file handle once the pixels have been read.
with Image.open(os.path.join("assets", "ground_truth.jpg")) as gt_image:
    gt = TF.to_tensor(gt_image)
x = torch.rand(n, *gt.size())  # (n, C, H, W)

# Mask of pixels whose channel-mean intensity is >= 0.5; those pixels get a
# larger perturbation (0.8) than the rest (0.2).
M = (torch.mean(gt, dim=0) >= 0.5).long()
mu = x + 0.2 * torch.randn_like(x) * (1 - M) + 0.8 * torch.randn_like(x) * M
# (n, m, C, H, W) view of mu: expand avoids materialising m copies; randn_like
# on the expanded view still returns a fresh contiguous tensor.
mu = mu.unsqueeze(1).expand(-1, m, -1, -1, -1)
y = torch.randn_like(mu).mul_(0.1).add_(mu)

# Random calibration / validation split, drawn from the notebook's generator.
perm_idx = rng.permutation(n)
cal_idx = perm_idx[:n_cal]
val_idx = perm_idx[n_cal:]

cal_x, cal_y = x[cal_idx], y[cal_idx]
val_x, val_y = x[val_idx], y[val_idx]

In [None]:
# Average over the colour channels. x is (n, C, H, W), so channels are dim=1;
# y has an extra per-image sample axis (n, m, C, H, W), so channels are dim=2.
m_cal_x = cal_x.mean(dim=1)
m_cal_y = cal_y.mean(dim=2)
m_val_x = val_x.mean(dim=1)
m_val_y = val_y.mean(dim=2)

# Instantiate both calibration procedures from their configs.
rcps, krcps = get_procedure(rcps_config), get_procedure(krcps_config)

# Presumably the quantile level of the UQ method — TODO confirm against krcps.
alpha = 0.10
# Each call gets its own uq_dict so neither procedure can see changes made by the other.
_lambda = rcps(m_cal_x, m_cal_y, uq_dict=dict(alpha=alpha, dim=1))
_lambda_k = krcps(m_cal_x, m_cal_y, uq_dict=dict(alpha=alpha, dim=1))

In [None]:
# Build the UQ on the validation samples with the SAME method the procedures
# were configured with, so the comparison cannot silently drift from the config.
val_uq = get_uq(config_dict["uq"])(m_val_y, alpha=alpha, dim=1)

# Lower/upper interval bounds for the RCPS calibration.
_lambda_l, _lambda_u = val_uq(_lambda)
rcps_mu_i = torch.mean(_lambda_u - _lambda_l)
print(f"RCPS, mean interval length: {rcps_mu_i:.4f}")

# Lower/upper interval bounds for the K-RCPS calibration.
_lambda_k_l, _lambda_k_u = val_uq(_lambda_k)
k_rcps_mu_i = torch.mean(_lambda_k_u - _lambda_k_l)
print(f"K-RCPS, mean interval length: {k_rcps_mu_i:.4f}")
print(
    f"K-RCPS reduces the mean interval length by {100 * (rcps_mu_i - k_rcps_mu_i) / rcps_mu_i:.2f}%"
)

In [None]:
rc("animation", html="html5")

fig, axes = plt.subplots(1, 2, figsize=(16 / 2, 9 / 4))

# Right panel: the K-RCPS calibration result, with a colorbar so values are readable.
ax = axes[1]
ax.axis("off")
ax.set_title(rf"$K$-RCPS calibration results ($\lambda_K,~K={krcps_config.k}$)")
lambda_im = ax.imshow(_lambda_k, cmap="jet")
fig.colorbar(lambda_im, ax=ax)

# Left panel: animation over the m samples of the first validation image.
ax = axes[0]
samples = val_y[0]
vmin, vmax = torch.quantile(samples, torch.tensor([0.01, 0.99]))
# Robust rescale to [0, 1]: the 1%/99% quantiles map to 0/1 and the tails are clipped.
samples = torch.clamp((samples - vmin) / (vmax - vmin), 0, 1)

ax.axis("off")
ax.set_title("Samples from a diffusion model")
# Frames are already rescaled to [0, 1], so display on that fixed range; the raw
# quantiles vmin/vmax no longer describe the data being shown.
im = ax.imshow(torch.zeros_like(gt).permute(1, 2, 0), cmap="gray", vmin=0, vmax=1)


def _init():
    """Reset the animated image to a blank (all-zero) frame."""
    im.set_data(torch.zeros_like(gt).permute(1, 2, 0))
    return (im,)


def _animate(i):
    """Show sample ``i`` (C, H, W -> H, W, C for imshow)."""
    im.set_data(samples[i].permute(1, 2, 0))
    return (im,)


anim = animation.FuncAnimation(fig, _animate, frames=m, init_func=_init)
# GIF frame delays have 1/100 s resolution, so 60 fps (16.7 ms) cannot be stored
# exactly and may be truncated to 10 ms, which browsers typically slow to ~100 ms.
# 50 fps (20 ms) is represented exactly.
anim.save(os.path.join("assets", "results.gif"), writer=animation.PillowWriter(fps=50))
plt.show()