In [None]:
import os
import sys
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from torchvision.utils import make_grid
from torch.utils.data import DataLoader

experiment_dir = "../"
sys.path.append(experiment_dir)
from configs import default_celeba, default_abdomen
from configs.utils import get_config
from dataset import get_dataset

# Global seaborn look for every figure in this notebook: default theme
# (style/palette) with the compact "paper" context for publication sizing.
sns.set_theme(context="paper")

# All dataset preview figures are written under <experiment_dir>/figures/dataset.
fig_dir = os.path.join(experiment_dir, "figures", "dataset")
os.makedirs(fig_dir, exist_ok=True)

In [None]:
# Render a grid of random samples from each dataset and save it as PNG and PDF.
configs = ["default_celeba", "default_abdomen"]

# Grid layout: `n` sampled images arranged in rows of `nrow` images.
n = 15
nrow = 5
# Fixed seed so Restart & Run All reproduces the same sample grid (and thus
# the same saved figures) instead of a different random batch each time.
seed = 0

for config_name in configs:
    config = get_config(name=config_name)
    # Request images only (no image ids), so a batch is a single tensor.
    config.data.return_img_id = False

    # Only the second dataset returned by get_dataset is visualised.
    # NOTE(review): presumably (train, eval) splits — confirm which one this is.
    _, dataset = get_dataset(config)

    # A seeded generator drives the RandomSampler, making shuffle deterministic.
    generator = torch.Generator().manual_seed(seed)
    loader = DataLoader(
        dataset, batch_size=n, shuffle=True, num_workers=1, generator=generator
    )
    data = next(iter(loader))

    if config.data.dataset == "CelebA":
        # Undo a per-channel Normalize(mean=0.5, std=0.5).
        # NOTE(review): assumes this is the transform get_dataset applies to CelebA.
        mu = std = torch.tensor(3 * [0.5])
        denorm = lambda x: (x * std[:, None, None]) + mu[:, None, None]
        # Clamp float round-off into [0, 1]; imshow would clip anyway, but warns.
        data = denorm(data).clamp(0.0, 1.0)
    im = make_grid(data, nrow=nrow)

    fig, ax = plt.subplots(figsize=(16, 9))
    # make_grid expands single-channel batches to 3 channels, so the resulting
    # image is RGB and imshow ignores `cmap`; kept for intent/readability.
    ax.imshow(
        im.permute(1, 2, 0),
        cmap="gray" if config.data.dataset == "AbdomenCT-1K" else None,
    )
    ax.axis("off")
    for ext in ("png", "pdf"):
        fig.savefig(
            os.path.join(fig_dir, f"{config.data.name}.{ext}"), bbox_inches="tight"
        )
    plt.show()
    # Close this specific figure to avoid accumulating figures across the loop.
    plt.close(fig)