In [None]:
import os
import sys
import torch
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from torchvision.utils import make_grid
from monai.transforms import ResizeWithPadOrCrop

root_dir = "../"
sys.path.append(root_dir)
from configs import get_config
from calibrate import gather_data
from datasets import get_dataset

# Configuration to visualise and the examples gathered from it.
config_name = "ts"
config = get_config(config_name)

# Number of examples to gather, split evenly between calibration and validation.
# NOTE(review): with an odd `n`, n_cal + n_val == n - 1 because of the integer
# division — presumably `n` should stay even; confirm against `gather_data`.
n = 4
config.calibration.n_cal = config.calibration.n_val = n // 2

# FBP inputs are only requested for the reconstruction task.
dataset = get_dataset(
    config,
    with_fbp=config.data.task == "reconstruction",
    with_prediction_results=True,
    with_segmentation_results=True,
)
idx, ground_truth, reconstruction, segmentation = gather_data(config, n=n)

# Per-scan slice windows (a comma-separated "window" column), indexed by scan
# name. Chained set_index instead of inplace=True keeps the cell idempotent and
# free of in-place mutation.
window_data = pd.read_csv(
    os.path.join(config.get_results_dir(), "window.csv")
).set_index("scan_name")

sns.set_theme()
sns.set_context("paper")

In [None]:
# For every gathered example, save (PDF + PNG) and display two figures:
#   1. a grid with rows ground truth / measurement / reconstruction, one column per slice;
#   2. a grid of the segmentation slices.
# Figures are written to figures/<dataset>/<task>/examples/.
figure_dir = os.path.join(
    root_dir, "figures", config.data.dataset, config.data.task, "examples"
)
os.makedirs(figure_dir, exist_ok=True)

# Reconstructed intensities at or below this value are zeroed for display only.
RECONSTRUCTION_DISPLAY_THRESHOLD = 0.1
# Seed the global torch RNG so the synthetic denoising noise is reproducible.
NOISE_SEED = 0
torch.manual_seed(NOISE_SEED)

n_slices = config.calibration.window_slices

# Validate the task up front: otherwise the measurement would be undefined
# (or silently stale from a previous run) for an unsupported task.
if config.data.task == "reconstruction":
    # Loop-invariant: pad/crop every FBP window to (image_size, image_size, window_slices).
    fbp_transform = ResizeWithPadOrCrop(
        (config.calibration.image_size, config.calibration.image_size, n_slices)
    )
elif config.data.task != "denoising":
    raise ValueError(f"Unsupported task for example figures: {config.data.task!r}")


def _save_and_show(fig, stem):
    """Save `fig` as `<stem>.pdf` and `<stem>.png` in `figure_dir`, display it, then close it."""
    for ext in ("pdf", "png"):
        fig.savefig(os.path.join(figure_dir, f"{stem}.{ext}"), bbox_inches="tight")
    plt.show()
    # Close explicitly so figures created in the loop do not accumulate.
    plt.close(fig)


for i, sample_idx in enumerate(idx):
    sample = dataset[sample_idx]
    scan_name = sample["name_img"]

    # Add a channel axis to each per-slice stack for make_grid.
    gt_slices = ground_truth[i].unsqueeze(1)
    seg_slices = segmentation[i].unsqueeze(1)
    # Index 1 of the reconstruction output is plotted. Presumably it is the point
    # estimate between the lower (0) and upper (2) bounds — confirm.
    # masked_fill returns a new tensor; the original in-place assignment wrote
    # through the view into the notebook-global `reconstruction` tensor.
    recon_slices = reconstruction[i, 1].unsqueeze(1)
    recon_slices = recon_slices.masked_fill(
        recon_slices <= RECONSTRUCTION_DISPLAY_THRESHOLD, 0
    )

    if config.data.task == "denoising":
        measurement = gt_slices + config.data.sigma * torch.randn_like(gt_slices)
    else:  # "reconstruction" (validated above)
        scan_window = window_data.loc[scan_name, "window"]
        slice_idx = [int(s) for s in scan_window.split(",")]
        print(scan_name, slice_idx)

        measurement = sample["fbp"][..., slice_idx]
        measurement = fbp_transform(measurement)
        # Move the slice axis (last) to the front to stack alongside the other rows.
        measurement = measurement.permute(3, 0, 1, 2)

    # Figure 1: ground truth / measurement / reconstruction, one row each.
    fig, ax = plt.subplots(figsize=(16, 9))
    panel = torch.cat([gt_slices, measurement, recon_slices], dim=0)
    panel = torch.rot90(panel, k=1, dims=(2, 3))
    grid = make_grid(panel, nrow=n_slices)
    ax.imshow(grid.permute(1, 2, 0).numpy(), cmap="gray")
    ax.axis("off")
    _save_and_show(fig, scan_name)

    # Figure 2: segmentation slices, normalised and shown with a colour map.
    fig, ax = plt.subplots(figsize=(16, 9))
    seg_panel = torch.rot90(seg_slices, k=1, dims=(2, 3))
    seg_grid = make_grid(seg_panel, nrow=n_slices, normalize=True)[0]
    ax.imshow(seg_grid.numpy(), cmap="jet")
    ax.axis("off")
    _save_and_show(fig, f"{scan_name}_seg")