In [None]:
import os
import sys
from pathlib import Path

# Resolve project paths relative to the notebook's working directory.
# `parents[1]` is treated as the repository root, i.e. the notebook is
# assumed to live two levels below it — TODO confirm against repo layout.
path_working_dir = Path.cwd()
path_repo = path_working_dir.parents[1]

# Put the repo first on sys.path so the local `utils` package wins over any
# other `utils` on the path. Drop any existing copies first so that
# re-running this cell does not keep growing sys.path with duplicates.
str_path_repo = str(path_repo)
while str_path_repo in sys.path:
    sys.path.remove(str_path_repo)
sys.path.insert(0, str_path_repo)

path_data_raw = path_working_dir.parent / "data_raw"
path_data_parsed = path_working_dir.parent / "data_parsed"

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

from utils.gridplot import GridPlot
from utils.physics import simulate_complex_noise, spgr_signal
from utils.plotting import (
    plot_estimate_b1_corr,
    plot_estimate_mask,
    plot_estimate_param,
    plot_estimate_signal,
)
from utils.statistics import (
    T1_VFA_NLLS_estimator_parallel,
    estimate_sigma2_from_residuals,
)

# Global matplotlib defaults for every figure produced below: disable
# pixel interpolation in image plots and use a compact 9 pt font.
plt.rcParams.update({"image.interpolation": "None", "font.size": 9.0})


def get_dataset_name(str_application, str_shape, str_setting, str_noise, str_pat_name):
    """Build the canonical dataset identifier.

    The result is ``<application>__<pat_name>_<shape>_<setting>_<noise>``.
    The double underscore separates the application prefix, which callers
    split off to choose the parent output directory.

    Args:
        str_application: application name, e.g. ``"T1_VFA"``.
        str_shape: shape descriptor, e.g. ``"shape_256_256_30"``.
        str_setting: acquisition-setting descriptor.
        str_noise: noise descriptor, e.g. ``"noise_0.02"``.
        str_pat_name: patient / phantom identifier.

    Returns:
        The combined dataset name as a string.
    """
    suffix_parts = (str_pat_name, str_shape, str_setting, str_noise)
    return "{}__{}".format(str_application, "_".join(map(str, suffix_parts)))

In [None]:
pat_names = ["Pat_04"]

for pat_name in pat_names:
    # ------------------------------------------------------------------
    # Simulation settings for the T1 variable-flip-angle (VFA) dataset.
    # ------------------------------------------------------------------
    # Flip angles in degrees. `fa` (radians) is derived from this single list
    # so the two entries stored in `params` can never drift apart.
    fa_deg = [2, 4, 11, 13, 15]
    params = dict(
        name_application="T1_VFA",
        setting_name="fa",
        pat_name=pat_name,
        param_names=["S0", "T1"],
        param_units=["a.u.", "seconds"],
        fa_deg=fa_deg,
        fa=np.deg2rad(fa_deg),
        tr=6.8 * 1e-3,  # repetition time, same time unit as T1 (seconds)
        p_scales=[9.8, 1],  # multipliers applied to the loaded S0 / T1 maps
        lims_p=((0, 15), (0, 6)),
        lims_diff_p=((-1.5, +1.5), (-0.6, +0.6)),
        lims_y=(0, 1),
        lims_diff_y=(-0.05, +0.05),
        lims_b1=(0, 2),
        noise_std=0.02,
        init_guess=[8, 4],
        bounds=((0, 0), (100, 6)),
        mask_bound=0.3,  # voxels with T1 <= mask_bound are excluded from the mask
    )
    dp = path_data_raw / "brainweb" / params["pat_name"]

    # ------------------------------------------------------------------
    # Reference parameter maps and mask
    # ------------------------------------------------------------------
    # Stack the scaled maps along a new leading parameter axis. After
    # swapaxes(1, 3), axis 1 is used as the slice (z) axis everywhere below.
    # This presumes the raw .npy files are stored with z last — TODO confirm.
    p_ref = np.stack(
        [
            scale * np.load(dp / f"{name}.npy")
            for scale, name in zip(params["p_scales"], params["param_names"])
        ],
        axis=0,
    ).swapaxes(1, 3)

    print(f"loaded p_ref with shape: {p_ref.shape}")

    # Exclude voxels whose reference T1 (parameter index 1) is at or below the threshold.
    mask_bound = p_ref[1, ...] > params["mask_bound"]

    # Prefer a manually corrected tissue mask when one has been provided.
    path_to_corrected_mask = dp / "mask_corrected.npy"
    path_mask_file = (
        path_to_corrected_mask
        if path_to_corrected_mask.exists()
        else dp / "mask.npy"
    )
    mask_tissue = np.load(path_mask_file).swapaxes(0, 2)

    # Fail loudly on a layout mismatch instead of silently broadcasting.
    assert mask_bound.shape == mask_tissue.shape, (mask_bound.shape, mask_tissue.shape)
    mask = np.logical_and(mask_bound, mask_tissue)

    # No B1 inhomogeneity is simulated: unit correction inside the mask, 0 outside.
    b1_corr = mask.astype(float)

    print(f"loaded mask with shape: {mask.shape}")

    # NOTE(review): not referenced anywhere in this cell; plots are always exported.
    export_plots = True

    dataset_name = get_dataset_name(
        str_application=params["name_application"],
        str_shape=f"shape_{mask.shape[1]}_{mask.shape[2]}_{mask.shape[0]}",
        # e.g. "fa" + "_2_4_11_13_15"
        str_setting=params["setting_name"] + "".join(f"_{deg}" for deg in params["fa_deg"]),
        str_noise="noise_" + str(params["noise_std"]),
        str_pat_name=params["pat_name"],
    )

    # Output layout: <data_parsed>/<application>/<dataset_name>/{source,plots}
    path_dataset = path_data_parsed / dataset_name.split("__")[0] / dataset_name
    for subdir in ("source", "plots"):
        (path_dataset / subdir).mkdir(parents=True, exist_ok=True)

    # ------------------------------------------------------------------
    # Noise-free and noisy signals
    # ------------------------------------------------------------------
    # Simulated per slice, then swapped so the layout is (signal, z, ...).
    # Assumes spgr_signal returns (n_flip_angles, y, x) per slice — TODO confirm.
    y_ref = np.asarray(
        [
            spgr_signal(
                S0=p_ref[0, idx_z, :, :],
                T1=p_ref[1, idx_z, :, :],
                FA=params["fa"],
                TR=params["tr"],
                B1_corr=b1_corr[idx_z, :, :],
                mask=mask[idx_z, :, :],
            )
            for idx_z in range(mask.shape[0])
        ]
    ).swapaxes(0, 1)

    print(f"calculated y_ref with shape: {y_ref.shape}")

    # Per-(signal, slice) maximum of the noise-free signal, taken over all
    # remaining axes. This replaces the original nested Python loop.
    y_scales = np.max(y_ref, axis=tuple(range(2, y_ref.ndim))).astype(float)

    # NOTE(review): no RNG seed is set here, so `y` (and everything derived
    # from it) differs between runs unless simulate_complex_noise seeds
    # internally. Consider seeding for reproducible datasets.
    y = np.asarray(
        [
            simulate_complex_noise(
                image=y_ref[:, idx_z, :, :], noise_std=params["noise_std"]
            )
            for idx_z in range(y_ref.shape[1])
        ]
    ).swapaxes(0, 1)

    # Fixed: previously reported y_ref.shape here.
    print(f"calculated y with shape: {y.shape}")

    for idx_z in tqdm(range(mask.shape[0]), desc="plotting dataset"):
        plot_estimate_b1_corr(
            b1_corr=b1_corr[idx_z, :, :],
            mask=mask[idx_z, :, :],
            path_export=path_dataset / "plots" / f"b1_corr_z_{idx_z}.png",
        )
        plot_estimate_mask(
            mask=mask[idx_z, :, :],
            path_export=path_dataset / "plots" / f"mask_z_{idx_z}.png",
        )

        plot_estimate_signal(
            y=y[:, idx_z, :, :],
            y_ref=y_ref[:, idx_z, :, :],
            mask=mask[idx_z, ...],
            lims=params["lims_y"],
            lims_diff=params["lims_diff_y"],
            path_export=path_dataset / "plots" / f"y_versus_y_ref_z_{idx_z}.png",
            name_est="y",
            name_ref="y_ref",
        )

    # ------------------------------------------------------------------
    # Voxel-wise NLLS fit and re-synthesised signal, slice by slice
    # ------------------------------------------------------------------
    p_nlls = []
    y_nlls = []

    for idx_z in tqdm(
        range(mask.shape[0]), desc="calculating and plotting NLLS estimates"
    ):
        # NOTE(review): params["init_guess"] is saved with the dataset but is
        # not passed to the estimator here. Confirm that the estimator's own
        # default initial guess is what is intended.
        p_nlls_z = T1_VFA_NLLS_estimator_parallel(
            y=y[:, idx_z, :, :],
            FA_values=params["fa"],
            TR=params["tr"],
            B1_corr=b1_corr[idx_z, :, :],
            mask=mask[idx_z, :, :],
            bounds=params["bounds"],
        )

        p_nlls.append(p_nlls_z)

        plot_estimate_param(
            p_est=p_nlls_z,
            p_ref=p_ref[:, idx_z, :, :],
            lims_p=params["lims_p"],
            lims_diff=params["lims_diff_p"],
            lims_hist=params["lims_p"],
            name_est="nlls",
            name_ref="ref",
            param_names=params["param_names"],
            param_units=params["param_units"],
            path_export=path_dataset / "plots" / f"p_nlls_versus_ref_z_{idx_z}.png",
            mask=mask[idx_z],
        )

        # Signal predicted by the fitted parameters; used for residuals below.
        y_nlls_z = spgr_signal(
            S0=p_nlls_z[0, :, :],
            T1=p_nlls_z[1, :, :],
            FA=params["fa"],
            TR=params["tr"],
            mask=mask[idx_z, :, :],
            B1_corr=b1_corr[idx_z, :, :],
        )

        y_nlls.append(y_nlls_z)

        plot_estimate_signal(
            y=y_nlls_z,
            y_ref=y[:, idx_z, :, :],
            mask=mask[idx_z, ...],
            lims=params["lims_y"],
            lims_diff=params["lims_diff_y"],
            path_export=path_dataset / "plots" / f"y_nlls_versus_y_z_{idx_z}.png",
            name_est="nlls",
            name_ref="y",
        )

    # Bring the slice axis back to position 1, matching p_ref / y.
    p_nlls = np.asarray(p_nlls).swapaxes(0, 1)
    y_nlls = np.asarray(y_nlls).swapaxes(0, 1)
    r_nlls = y_nlls - y  # fit residuals against the noisy signal
    np.save(path_dataset / "source" / "p_nlls.npy", p_nlls)
    np.save(path_dataset / "source" / "y_nlls.npy", y_nlls)
    np.save(path_dataset / "source" / "r_nlls.npy", r_nlls)
    np.save(path_dataset / "source" / "p_ref.npy", p_ref)
    np.save(path_dataset / "source" / "mask.npy", mask)
    np.save(path_dataset / "source" / "y.npy", y)
    np.save(path_dataset / "source" / "y_ref.npy", y_ref)
    np.save(path_dataset / "source" / "b1_corr.npy", b1_corr)
    np.savez(path_dataset / "source" / "params.npz", **params)

    # ------------------------------------------------------------------
    # Noise-level estimation from NLLS residuals
    # ------------------------------------------------------------------
    print("estimating noise level...")
    # Mask repeated along the signal axis (read-only broadcast view, no copy).
    mask_expanded = np.broadcast_to(mask, r_nlls.shape)
    sigma_2_whole_volume = estimate_sigma2_from_residuals(r=r_nlls[:, mask])

    sigma_2_per_slice = [
        estimate_sigma2_from_residuals(r=r_nlls[:, idx_z, :, :][:, mask[idx_z, :, :]])
        for idx_z in range(mask.shape[0])
    ]

    # Collapse each slice's sigma^2 estimate(s) into one standard deviation.
    sigma_per_slice_mean = [np.sqrt(np.mean(item)) for item in sigma_2_per_slice]

    plot = GridPlot(ncols=3, nrows=1, size=(12, 4))
    plot.axs[0, 0].hist(
        r_nlls[mask_expanded], bins=100, label="residuals (whole volume)"
    )
    plot.axs[0, 0].set(title=dataset_name, xlabel="residual", ylabel="count")

    plot.axs[0, 1].hist(sigma_2_whole_volume, bins=100, label="sigma_2 (whole volume)")
    plot.axs[0, 1].set(xlabel="sigma^2 estimate", ylabel="count")

    plot.axs[0, 2].plot(sigma_per_slice_mean, "*", label="sigma_per_slice_mean")
    plot.axs[0, 2].plot(
        [0, mask.shape[0]], 2 * [params["noise_std"]], "-k", label="noise std"
    )
    plot.axs[0, 2].set(xlabel="slice index z", ylabel="sigma")

    # Every panel carries labelled artists, so every panel gets a legend.
    for idx_row in range(plot.nrows):
        for idx_col in range(plot.ncols):
            plot.axs[idx_row, idx_col].legend()

    plot.export(path_dataset / "plots" / "_noise_analysis.png")
    plt.close()

    # ------------------------------------------------------------------
    # One self-contained .npz per slice
    # ------------------------------------------------------------------
    for idx_z in tqdm(range(mask.shape[0]), desc="saving dataset"):
        np.savez(
            path_dataset / f"dataset_idx_s_{idx_z:03d}.npz",
            y=y[:, idx_z, :, :],
            y_ref=y_ref[:, idx_z, :, :],
            mask=mask[idx_z, :, :],
            p_nlls=p_nlls[:, idx_z, :, :],
            p_ref=p_ref[:, idx_z, :, :],
            fa=params["fa"],
            tr=params["tr"],
            b1_corr=b1_corr[idx_z, :, :],
            init_guess=params["init_guess"],
            bounds=params["bounds"],
            param_names=params["param_names"],
            param_units=params["param_units"],
            name_application=params["name_application"],
            num_slices=mask.shape[0],
            image_height=mask.shape[1],
            image_width=mask.shape[2],
            noise_std=params["noise_std"],
            lims_p=params["lims_p"],
            lims_y=params["lims_y"],
            estimated_noise_std=sigma_per_slice_mean[idx_z],  # used for training (\ell)
        )