In [None]:
import os
import sys
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
from torchvision.transforms.functional import center_crop

# Relative path to the project root. NOTE(review): assumes the notebook is run
# from a direct subdirectory of the project root — confirm.
root_dir = "../"
# Make the project root importable so the project-local `configs` and
# `datasets` modules resolve; this must run before the two imports below.
sys.path.append(root_dir)
from configs import get_config
from datasets import get_dataset

# Load the named experiment config and build its dataset from it.
config_name = "flare"
config = get_config(config_name)

dataset = get_dataset(config)
# One volume per batch, in dataset order (no shuffling), consumed by the
# plotting cell below.
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

# Seaborn theme, then the compact "paper" plotting context.
sns.set_theme()
sns.set_context("paper")

In [None]:
figure_dir = os.path.join(root_dir, "figures", "scans", config.data.dataset)
os.makedirs(figure_dir, exist_ok=True)

image_size = config.calibration.image_size

# Grid of m evenly spaced slices (columns) from each of the first n scans (rows).
n, m = 2, 8
images = torch.zeros(n, m, 1, image_size, image_size)
n_loaded = 0
# zip stops after n batches without pulling an extra volume from the loader.
for i, data in zip(range(n), dataloader):
    image = data["image"]
    print(data["name_img"])
    print(image.size())
    # Drop all singleton dims (including the batch dim of size 1) so the last
    # axis indexes slices. NOTE(review): assumes the squeezed volume is
    # (H, W, D); a volume with any other size-1 axis would break this.
    image = image.squeeze()

    # m integer slice indices spread evenly over the last axis, both ends included.
    slice_idx = np.linspace(0, image.shape[-1] - 1, m).astype(int)
    # Move the slice axis to the front and add a channel axis:
    # (m, 1, H, W) under the layout assumption above.
    image_slices = image[..., slice_idx].permute(2, 0, 1).unsqueeze(1)
    # Crop (or zero-pad, if smaller) each slice to image_size x image_size
    # around its centre.
    image_slices = center_crop(image_slices, image_size)
    images[i] = torch.rot90(image_slices, dims=(2, 3))
    n_loaded = i + 1

if n_loaded == 0:
    raise ValueError("dataset is empty; nothing to plot")
# Keep only the rows that were filled, so a dataset with fewer than n scans
# does not produce blank rows in the saved figure.
images = images[:n_loaded]

fig, ax = plt.subplots(figsize=(16, 9))
# make_grid tiles the slices m per row. It replicates single-channel input to
# 3 channels, so imshow receives an RGB image and `cmap` has no effect;
# matplotlib clips float RGB values outside [0, 1].
image_grid = make_grid(images.flatten(0, 1), nrow=m)
ax.imshow(image_grid.permute(1, 2, 0), cmap="gray")
ax.axis("off")
# Save via the explicit figure handle rather than pyplot's "current figure".
fig.savefig(os.path.join(figure_dir, "image_slices.pdf"), bbox_inches="tight")
fig.savefig(os.path.join(figure_dir, "image_slices.png"), bbox_inches="tight")
plt.show()