In [None]:
import sys
import os
import numpy as np
import torch
import matplotlib.pyplot as plt
from monai.data import DataLoader
from torchvision.utils import make_grid

root_dir = "../"
# Make the project root importable so the local `configs` / `datasets` packages resolve.
# Register it as an absolute path (a relative sys.path entry silently breaks if the
# working directory changes later) and only once, so re-running this cell does not
# keep appending duplicate entries.
_root_abs = os.path.abspath(root_dir)
if _root_abs not in sys.path:
    sys.path.append(_root_abs)
from configs import get_config
from datasets import get_dataset

# Seed torch's global RNG so the shuffled scan order and the synthetic noise drawn
# with torch in the preview loop are reproducible on a fresh kernel.
SEED = 0
torch.manual_seed(SEED)

config_name = "flare"
config = get_config(config_name)

# The FBP input is only requested from the dataset for the reconstruction task.
with_fbp = config.data.task == "reconstruction"
dataset = get_dataset(config, with_fbp=with_fbp, with_prediction_results=True)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

In [None]:
# Preview a few scans. For each one, build a grid with N_SLICES evenly spaced slices
# along the last spatial axis, one row per volume: clean image, measurement, and
# MMSE estimate. Volumes are concatenated along dim 0, so with batch_size=1 this
# gives exactly three rows.
# NOTE(review): assumes batches are 5-D (B, C, H, W, D), as implied by the permute
# below -- TODO confirm against the dataset definition.
N_PREVIEW_SCANS = 10  # stop after this many scans
N_SLICES = 8  # slices shown per row

for i, batch in enumerate(dataloader):
    scan_name = batch["name_img"][0]

    image = batch["image"]
    mmse = batch["mmse"]

    if config.data.task == "denoising":
        # Synthesize the noisy measurement on the fly at the configured noise level.
        measurement = image + config.data.sigma * torch.randn_like(image)
    elif config.data.task == "reconstruction":
        # The dataset only provides "fbp" when built with with_fbp=True (reconstruction).
        measurement = batch["fbp"]
    else:
        # Fail loudly instead of hitting a NameError on `measurement` further down.
        raise ValueError(f"Unsupported task for preview: {config.data.task!r}")

    # Stack along dim 0 so each volume becomes its own row of the grid, then rotate
    # the (H, W) plane by 90 degrees for display.
    volumes = torch.cat([image, measurement, mmse], dim=0)
    volumes = torch.rot90(volumes, k=1, dims=(2, 3))

    total_slices = volumes.size(-1)
    slice_idx = torch.linspace(0, total_slices - 1, steps=N_SLICES).long()
    slices = volumes[..., slice_idx]        # (V, C, H, W, N_SLICES)
    slices = slices.permute(0, 4, 1, 2, 3)  # (V, N_SLICES, C, H, W)
    slices = slices.flatten(0, 1)           # (V * N_SLICES, C, H, W)
    grid = make_grid(slices, nrow=N_SLICES)  # channels-first image grid

    fig, ax = plt.subplots(figsize=(16, 9))
    # NOTE(review): imshow clips float RGB data to [0, 1] and ignores `cmap` for RGB
    # input; if intensities are not already in [0, 1], consider make_grid(..., normalize=True).
    ax.imshow(grid.permute(1, 2, 0), cmap="gray")
    ax.set_title(f"{scan_name}: image / measurement / MMSE")
    ax.axis("off")
    plt.show()
    plt.close(fig)  # release the figure; one is created per loop iteration

    if i + 1 >= N_PREVIEW_SCANS:
        break