In [None]:
import os
import sys
import matplotlib.pyplot as plt
import seaborn as sns
from torchvision.utils import make_grid
from matplotlib.gridspec import GridSpec

# Experiment root, relative to the kernel's working directory (the notebook
# is presumably launched from a subdirectory of it — TODO confirm).
experiment_dir = "../"
# Expose the experiment root on the import path so the project-local
# packages imported right below (configs, dataset, utils) can be found.
sys.path.extend([experiment_dir])
from configs import celeba_ncsnpp, abdomen_ncsnpp
from configs.utils import get_config
from dataset import get_dataset
from utils import denoising_results

# Destination for the exported figures; created up front (no error if it
# already exists) so the savefig calls later in the notebook cannot fail on
# a missing directory.
fig_dir = os.path.join(experiment_dir, "figures", "denoise")
os.makedirs(fig_dir, exist_ok=True)

# Global plot styling: seaborn's default theme, then the "paper" context
# (scales fonts/line widths for print-sized figures).
sns.set_theme()
sns.set_context("paper")

In [None]:
# Render one denoising figure per config. Each row shows a dataset image,
# its perturbed version, and the first `m` denoised samples for that image;
# the figure is written to `fig_dir` as both PNG and PDF.
configs = ["celeba_ncsnpp", "abdomen_ncsnpp"]
for config_name in configs:
    config = get_config(config_name)

    n = 4  # number of dataset images, i.e. figure rows
    _, dataset = get_dataset(config)
    _, original, perturbed, denoised = denoising_results(
        dataset, config, shuffle=True, n=n
    )

    # Grayscale colormap for the CT dataset. Computed once so that all three
    # columns use exactly the same dataset-name check.
    cmap = "gray" if config.data.dataset == "AbdomenCT-1K" else None

    m = 5  # denoised samples shown per row (also the width ratio of that column)
    fig = plt.figure(figsize=(16, 9))
    gs = GridSpec(n, 3, width_ratios=[1, 1, m], wspace=0.05, hspace=0)
    for i in range(n):
        # Tile the first m samples for image i into a single horizontal strip.
        # presumably denoised[i] is a stack of channel-first images — TODO confirm
        sample = make_grid(denoised[i, :m], nrow=m, padding=0)
        panels = (
            (original[i], "Original"),
            (perturbed[i], "Perturbed"),
            (sample, "Samples"),
        )
        for col, (image, title) in enumerate(panels):
            ax = fig.add_subplot(gs[i, col])
            # permute(1, 2, 0): channel-first tensor -> channel-last layout for imshow
            ax.imshow(image.permute(1, 2, 0), cmap=cmap)
            ax.axis("off")
            if i == 0:
                # Column headers only on the top row.
                ax.set_title(title)
    fig.savefig(os.path.join(fig_dir, f"{config_name}.png"), bbox_inches="tight")
    fig.savefig(os.path.join(fig_dir, f"{config_name}.pdf"), bbox_inches="tight")
    plt.show()
    # Release the figure explicitly; this cell creates one figure per config.
    plt.close(fig)