In [None]:
import os
import sys
import matplotlib.pyplot as plt
import seaborn as sns

# Put the parent directory (the experiment root) on the import path so the
# project-local modules imported right after this resolve. The path is
# relative, so it is interpreted against the kernel's working directory.
experiment_dir = "../"
# Guard the append so re-running this cell does not pile up duplicate entries.
if experiment_dir not in sys.path:
    sys.path.append(experiment_dir)
from configs import celeba_ncsnpp_conffusion, abdomen_ncsnpp_conffusion
from configs.utils import get_config
from dataset import get_dataset
from utils import denoising_results

# Plot styling: seaborn's default theme, but with the "paper" context and fonts
# scaled by 1.5. A single set_theme call applies the same context settings that
# a bare set_theme() followed by set_context("paper", font_scale=1.5) would.
sns.set_theme(context="paper", font_scale=1.5)

# Figures produced by this notebook go to <experiment_dir>/figures/denoise;
# create the directory up front (no error if it already exists).
figure_subdirs = ("figures", "denoise")
fig_dir = os.path.join(experiment_dir, *figure_subdirs)
os.makedirs(fig_dir, exist_ok=True)

In [None]:
# Qualitative denoising figure. For every experiment config, fetch `n`
# examples from denoising_results (called with shuffle=True) and draw them as
# rows of a grid with four columns: the original image, its perturbed version,
# and two slices of `denoised` titled as the 0.05 and 0.95 quantile estimates.
# Each grid is written to `fig_dir` as both PNG and PDF, then displayed.
configs = ["celeba_ncsnpp_conffusion", "abdomen_ncsnpp_conffusion"]
n = 4  # number of example rows per figure
for config_name in configs:
    config = get_config(config_name)

    _, dataset = get_dataset(config)
    _, original, perturbed, denoised = denoising_results(
        dataset, config, shuffle=True, n=n
    )

    # The AbdomenCT-1K dataset is drawn with a gray colormap; any other
    # dataset leaves the colormap at imshow's default.
    cmap = "gray" if config.data.dataset == "AbdomenCT-1K" else None

    # squeeze=False keeps `axes` two-dimensional even when n == 1, so the
    # axes[i, j] lookups below work for any positive n.
    fig, axes = plt.subplots(n, 4, figsize=(16, 9), squeeze=False)
    for i in range(n):
        # One (column title, image) pair per panel, in left-to-right order.
        # NOTE(review): entries 0 and 2 along the second axis of `denoised`
        # are taken to be the lower/upper quantile images, as the titles
        # suggest; entry 1 is not plotted — confirm against denoising_results.
        panels = [
            ("Original", original[i]),
            ("Perturbed", perturbed[i]),
            (r"$\hat{q}_{0.05}$", denoised[i, 0]),
            (r"$\hat{q}_{0.95}$", denoised[i, 2]),
        ]
        for j, (title, image) in enumerate(panels):
            ax = axes[i, j]
            # Move the first axis last before plotting; assumes images are
            # channel-first (C, H, W) tensors — TODO confirm.
            ax.imshow(image.permute(1, 2, 0), cmap=cmap)
            ax.axis("off")
            if i == 0:
                # Column headers only on the top row.
                ax.set_title(title)

    # Save through the explicit figure handle instead of pyplot's implicit
    # "current figure", in a raster and a vector format.
    for ext in ("png", "pdf"):
        fig.savefig(
            os.path.join(fig_dir, f"{config_name}.{ext}"), bbox_inches="tight"
        )
    plt.show()
    # Release the figure so repeated iterations do not accumulate open figures.
    plt.close(fig)