In [None]:
import os
import sys
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from torchvision.utils import make_grid
from torchvision.transforms.functional import rotate
from matplotlib.gridspec import GridSpec

# Add the experiment directory (one level up) to the import path *before*
# importing from it; the imports below presumably resolve to modules that
# live there -- TODO confirm the layout.
experiment_dir = "../"
sys.path.append(experiment_dir)
# NOTE(review): celeba_ncsnpp / abdomen_ncsnpp are not referenced by name in
# the visible code -- presumably imported for a side effect such as making
# the configs known to get_config; confirm before removing.
from configs import celeba_ncsnpp, abdomen_ncsnpp
from configs.utils import get_config
from dataset import get_dataset
from utils import denoising_results

# Add the directory two levels up to the import path so that the krcps
# package can be imported.
root_dir = "../../"
sys.path.append(root_dir)
from krcps.utils import get_uq

# Global seaborn styling applied to every figure created afterwards.
sns.set_theme()
sns.set_context("paper")

# Directory that receives the saved figures; created on first use
# (exist_ok=True makes re-running this cell harmless).
fig_dir = os.path.join(experiment_dir, "figures", "denoise")
os.makedirs(fig_dir, exist_ok=True)

In [None]:
# For every configuration, draw and save two figures into fig_dir:
#   1. "<config_name>_uq.{pdf,png}": one row per image with five columns --
#      original, perturbed, lower bound, upper bound, and interval length
#      (upper - lower) as returned by the uncertainty-quantification function.
#   2. "<config_name>.{png,pdf}": one row per image with the original, the
#      perturbed image, and the first m denoised samples side by side.
uq_fn = get_uq("calibrated_quantile", alpha=0.20, dim=1)

n = 3  # number of images requested from denoising_results (rows per figure)
m = 5  # number of denoised samples shown per image in the second figure

configs = ["abdomen_ncsnpp"]
for config_name in configs:
    config = get_config(config_name)

    # Evaluate the dataset check once so that every panel agrees on it.
    is_abdomen_ct = config.data.dataset == "AbdomenCT-1K"
    cmap = "gray" if is_abdomen_ct else None

    _, dataset = get_dataset(config)
    _, original, perturbed, denoised = denoising_results(
        dataset, config, shuffle=True, n=n
    )

    if is_abdomen_ct:
        # Rotate the CT images by 90 degrees for display.
        original = rotate(original, 90)
        perturbed = rotate(perturbed, 90)
        # denoised is rotated element by element along its first axis --
        # presumably because it carries an extra sample axis; TODO confirm.
        denoised = torch.stack([rotate(x, 90) for x in denoised])

    # Lower / upper bounds and the resulting interval length.
    # NOTE(review): the argument 0 is presumably the calibration parameter of
    # the interval function -- confirm against krcps.utils.get_uq.
    uq_I = uq_fn(denoised)
    l, u = uq_I(0)
    i = u - l

    # --- Figure 1: uncertainty grid -------------------------------------
    _, ax = plt.subplots(figsize=(16, 9))
    # Interleave the five image kinds per sample, then flatten so that
    # make_grid lays them out as one row of five tiles per sample.
    data = torch.stack([original, perturbed, l, u, i], dim=1)
    data = torch.flatten(data, start_dim=0, end_dim=1)
    im = make_grid(data, nrow=5, normalize=False)
    ax.imshow(im.permute(1, 2, 0), cmap="gray", vmin=0, vmax=1)
    ax.axis("off")
    plt.savefig(os.path.join(fig_dir, f"{config_name}_uq.pdf"), bbox_inches="tight")
    plt.savefig(os.path.join(fig_dir, f"{config_name}_uq.png"), bbox_inches="tight")
    plt.show()
    # Release the figure, mirroring the handling of the second figure below.
    plt.close()

    # --- Figure 2: original / perturbed / samples ------------------------
    fig = plt.figure(figsize=(16, 9))
    # The samples column is m times wider because it holds m tiles.
    gs = GridSpec(n, 3, width_ratios=[1, 1, m], wspace=0.05, hspace=0)
    # The loop index is named `row` so it does not overwrite the interval
    # tensor `i` computed above.
    for row in range(n):
        ax = fig.add_subplot(gs[row, 0])
        ax.imshow(original[row].permute(1, 2, 0), cmap=cmap)
        ax.axis("off")
        if row == 0:
            ax.set_title("Original")

        ax = fig.add_subplot(gs[row, 1])
        ax.imshow(perturbed[row].permute(1, 2, 0), cmap=cmap)
        ax.axis("off")
        if row == 0:
            ax.set_title("Perturbed")

        ax = fig.add_subplot(gs[row, 2])
        # First m samples of this image, tiled horizontally without padding.
        sample = denoised[row, :m]
        sample = make_grid(sample, nrow=m, padding=0)
        ax.imshow(sample.permute(1, 2, 0), cmap=cmap)
        ax.axis("off")
        if row == 0:
            ax.set_title("Samples")

        # NOTE: disabled export of each original and each sample as an
        # individual PNG. It refers to the loop index as `i` and would have
        # to use `row` if re-enabled.
        # individual_fig_dir = os.path.join(fig_dir, "individual", str(i))
        # os.makedirs(individual_fig_dir, exist_ok=True)

        # _, ax = plt.subplots(figsize=(6, 6))
        # ax.imshow(
        #     original[i].permute(1, 2, 0),
        #     cmap="gray" if config.data.dataset == "AbdomenCT-1K" else None,
        # )
        # ax.axis("off")
        # plt.savefig(
        #     os.path.join(individual_fig_dir, f"{config_name}_original.png"),
        #     bbox_inches="tight",
        #     dpi=300,
        # )

        # for j in range(m):
        #     _, ax = plt.subplots(figsize=(6, 6))
        #     ax.imshow(
        #         denoised[i, j].permute(1, 2, 0),
        #         cmap="gray" if config.data.dataset == "AbdomenCT-1K" else None,
        #     )
        #     ax.axis("off")
        #     plt.savefig(
        #         os.path.join(individual_fig_dir, f"{config_name}_sample_{j}.png"),
        #         bbox_inches="tight",
        #         dpi=300,
        #     )
        #     plt.close()

    plt.savefig(os.path.join(fig_dir, f"{config_name}.png"), bbox_inches="tight")
    plt.savefig(os.path.join(fig_dir, f"{config_name}.pdf"), bbox_inches="tight")
    plt.show()
    plt.close()