In [None]:
import os
import sys
import torch
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from monai.data import DataLoader

root_dir = "../"
sys.path.append(root_dir)
from configs import get_config
from datasets import AbdomenAtlas

config_name = "ts"
config = get_config(config_name)

# Seed torch's global RNG before building the dataset/loader so that the
# shuffled batch (PyTorch samplers fall back to the global RNG when no
# `generator` is passed) and the Gaussian noise drawn for the denoising task
# are identical on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Request the precomputed FBP volume alongside each image; the visualization
# cell reads it as `data["fbp"]` for the reconstruction task.
with_fbp = True
dataset = AbdomenAtlas(config, with_fbp=with_fbp)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

In [None]:
def _to_slice_grid(volume: torch.Tensor, slice_idx: torch.Tensor, nrow: int) -> torch.Tensor:
    """Tile selected slices of a batch of volumes into a single image grid.

    Each grid row corresponds to one volume of the batch and each column to
    one of the selected slices along the last axis.

    Args:
        volume: 5-D batch of volumes whose last axis is the slice axis; after
            moving that axis next to the batch axis the remaining layout must
            be what ``make_grid`` expects, i.e. presumably (B, C, H, W, D) —
            TODO confirm against the dataset.
        slice_idx: indices into the last axis to keep.
        nrow: number of images per grid row (the number of selected slices).

    Returns:
        The (channels, H, W) grid tensor produced by ``make_grid``.
    """
    slices = volume[..., slice_idx]           # keep the chosen slices
    slices = slices.permute(0, 4, 1, 2, 3)    # move slice axis next to batch axis
    return make_grid(slices.flatten(0, 1), nrow=nrow)


def _show_grid(grid: torch.Tensor, title: str) -> None:
    """Display a ``make_grid`` output as a grayscale figure.

    Args:
        grid: (channels, H, W) tensor returned by ``make_grid``.
        title: figure title.
    """
    _, ax = plt.subplots(figsize=(16, 9))
    # make_grid replicates single-channel input to 3 channels. Handing that
    # RGB array to imshow would silently ignore `cmap` and clip values outside
    # [0, 1] (e.g. noisy or FBP measurements). Collapsing the channels gives a
    # 2-D array, so the gray colormap and autoscaling actually apply.
    ax.imshow(grid.mean(dim=0), cmap="gray")
    ax.set_title(title)
    ax.axis("off")
    plt.show()


data = next(iter(dataloader))
image = data["image"]

# Build the measurement for the configured task: additive Gaussian noise for
# denoising, the precomputed FBP volume for reconstruction.
if config.data.task == "denoising":
    measurement = image + config.data.sigma * torch.randn_like(image)
elif config.data.task == "reconstruction":
    measurement = data["fbp"]
else:
    raise ValueError(f"Unknown task: {config.data.task}")

# Evenly spaced slices along the last axis; never request more than exist.
total_slices = image.size(-1)
n_slices = min(8, total_slices)
slice_idx = torch.linspace(0, total_slices - 1, steps=n_slices).long()

image_grid = _to_slice_grid(image, slice_idx, n_slices)
measurement_grid = _to_slice_grid(measurement, slice_idx, n_slices)

_show_grid(image_grid, "image")
_show_grid(measurement_grid, f"measurement ({config.data.task})")