In [None]:
import os
import sys
from pathlib import Path

from bm3d import bm3d
from skimage.restoration import denoise_nl_means
from tqdm import tqdm

path_to_nb = Path(os.path.abspath(""))
sys.path.append(str(path_to_nb.parent.absolute()))
import matplotlib.pyplot as plt
import numpy as np

from utils.gridplot import GridPlot
from utils.plotting import _select_matrix_from_box, _trace_region_box, find_clip_indices
from utils.statistics import (
    T1_VFA_NLLS_estimator_parallel,
    get_experiment_noise,
    masked_nrmse,
    masked_ssim,
)

In [None]:
# --- Load data -------------------------------------------------------------
# Parsed T1-VFA dataset of one patient plus the index of the parameter map
# (leading axis of the p_* arrays) that all metrics and figures evaluate.
path_data_storage = path_to_nb.parent / "data" / "data_parsed" / "T1_VFA"
dataset_name = "T1_VFA__Pat_04_shape_222_185_48_fa_2_4_11_13_15_noise_0.02"
idx_p = 1

path_source = path_data_storage / dataset_name / "source"

y = np.load(path_source / "y.npy")
y_ref = np.load(path_source / "y_ref.npy")
p_ref = np.load(path_source / "p_ref.npy")
p_NONE_nlls = np.load(path_source / "p_nlls.npy")
mask = np.load(path_source / "mask.npy")
nz = mask.shape[0]  # number of slices, taken from the first axis of the mask


def _load_npz_eagerly(path):
    """Read every array of an ``.npz`` archive into a dict and close the file.

    ``np.load`` on an ``.npz`` returns a lazy ``NpzFile`` that keeps the file
    open and decompresses a member on every access; materialising it once
    avoids leaking one open handle per slice.
    """
    with np.load(path) as archive:
        return {key: archive[key] for key in archive.files}


# One metadata archive per slice (keys read later: "fa", "tr", "b1_corr",
# "bounds", "noise_std"). Load exactly ``nz`` of them so the list always
# matches the mask instead of a hard-coded slice count.
data = [
    _load_npz_eagerly(path_data_storage / dataset_name / f"dataset_idx_s_{i:03d}.npz")
    for i in range(nz)
]

# Baseline "NONE": per-slice metrics of the stored p_nlls maps against the
# reference maps, restricted to the mask.
ssim_NONE_nlls = [
    masked_ssim(
        ref=p_ref[idx_p, i, ...],
        est=p_NONE_nlls[idx_p, i, ...],
        mask=mask[i, ...],
    ).item()
    for i in range(nz)
]

nrmse_NONE_nlls = [
    masked_nrmse(
        ref=p_ref[idx_p, i, ...],
        est=p_NONE_nlls[idx_p, i, ...],
        mask=mask[i, ...],
    )
    for i in range(nz)
]


In [None]:
# --- NLM reference ---------------------------------------------------------
# Denoise every (flip angle, slice) image of the noisy data ``y`` with
# non-local means, then fit the parameter maps slice by slice with NLLS.
n_flip_angles = y.shape[0]

y_est = np.asarray(
    [
        [
            denoise_nl_means(
                image=y[idx_fa, idx_z, ...], patch_size=7, patch_distance=11, h=0.02
            )
            for idx_fa in range(n_flip_angles)
        ]
        for idx_z in tqdm(range(nz), desc='denoising with NLM')
    ]
).swapaxes(0, 1)  # built as (slice, fa, ...); swapped to y's (fa, slice, ...) layout

# Per-slice NLLS fits; stacked then swapped so the slice index is the second
# axis, matching how p_ref is indexed (p_ref[idx_p, slice]).
p_NLM_nlls = np.asarray(
    [
        T1_VFA_NLLS_estimator_parallel(
            y=y_est[:, z, ...],
            FA_values=data[z]["fa"],
            TR=data[z]["tr"],
            mask=mask[z, ...],
            B1_corr=data[z]["b1_corr"],
            bounds=data[z]["bounds"],
        )
        for z in tqdm(range(nz), desc='estimating parameter maps with NLLS')
    ]
).swapaxes(0, 1)

# Masked per-slice quality metrics against the reference maps.
ssim_NLM_nlls = [
    masked_ssim(ref=p_ref[idx_p, z, ...], est=p_NLM_nlls[idx_p, z, ...], mask=mask[z, ...]).item()
    for z in range(nz)
]

nrmse_NLM_nlls = [
    masked_nrmse(ref=p_ref[idx_p, z, ...], est=p_NLM_nlls[idx_p, z, ...], mask=mask[z, ...])
    for z in range(nz)
]


In [None]:
# --- BM3D reference --------------------------------------------------------
# Denoise every (flip angle, slice) image of the measured noisy data ``y`` with
# BM3D, then fit parameter maps slice by slice with NLLS. Like the NLM cell,
# this works on ``y`` itself, so every method sees the same noise realisation.
# Adding a fresh, unseeded noise draw to ``y_ref`` instead would make the BM3D
# numbers neither reproducible nor comparable to NONE / NLM / OURS.


def bm3d_denoise(image, noise_std):
    """Denoise one image with BM3D and clip the estimate to [0, 1].

    Parameters
    ----------
    image : np.ndarray
        Noisy 2-D image; promoted with ``np.atleast_3d`` before calling ``bm3d``.
    noise_std : float or size-1 array
        Standard deviation of the additive white Gaussian noise in ``image``
        (``bm3d`` accepts a scalar noise std in place of a PSD).

    Returns
    -------
    np.ndarray
        The BM3D estimate, clipped to [0, 1].
    """
    estimate = bm3d(np.atleast_3d(image), float(noise_std))
    # NOTE(review): the clip assumes intensities are normalised to [0, 1] — confirm.
    return np.clip(estimate, 0, 1)


y_BM3D = []
for idx_z in tqdm(range(nz), desc='denoising with BM3D'):
    noise_std = data[idx_z]["noise_std"]
    y_BM3D.append(
        np.stack(
            [bm3d_denoise(y[idx_fa, idx_z, :, :], noise_std) for idx_fa in range(y.shape[0])]
        )
    )

# Built as (slice, fa, ...); swap to the (fa, slice, ...) layout of ``y``.
y_BM3D = np.swapaxes(np.asarray(y_BM3D), 1, 0)

p_BM3D_nlls = [
    T1_VFA_NLLS_estimator_parallel(
        y=y_BM3D[:, idx_z, ...],
        FA_values=data[idx_z]["fa"],
        TR=data[idx_z]["tr"],
        mask=mask[idx_z, ...],
        B1_corr=data[idx_z]["b1_corr"],
        bounds=data[idx_z]["bounds"],
    )
    for idx_z in tqdm(range(nz), desc='estimating parameter maps with NLLS')
]
# Stack per-slice fits and put the slice index on axis 1, as in p_ref.
p_BM3D_nlls = np.asarray(p_BM3D_nlls).swapaxes(0, 1)

ssim_BM3D_nlls = [
    masked_ssim(
        ref=p_ref[idx_p, i, ...],
        est=p_BM3D_nlls[idx_p, i, ...],
        mask=mask[i, ...],
    ).item()
    for i in range(nz)
]

nrmse_BM3D_nlls = [
    masked_nrmse(
        ref=p_ref[idx_p, i, ...],
        est=p_BM3D_nlls[idx_p, i, ...],
        mask=mask[i, ...],
    )
    for i in range(nz)
]


In [None]:
# --- DIP ("OURS") results --------------------------------------------------
# Add the path to your estimation results here, or set the DIP_RESULTS_DIR
# environment variable. The directory must contain one sub-folder per slice
# ("000", "001", ...), each holding that slice's "p_est.npy".
path_to_est = Path(
    os.environ.get(
        "DIP_RESULTS_DIR",
        r'E:\paper3_local_storage\TMP\T1_VFA\T1_VFA__Pat_04_shape_222_185_48_fa_2_4_11_13_15_noise_0.02\250514154937',
    )
)
if not path_to_est.is_dir():
    raise FileNotFoundError(
        f"DIP results directory not found: {path_to_est} "
        "(set DIP_RESULTS_DIR or edit the default path in this cell)"
    )

# Stack the per-slice estimates and put the slice index on axis 1, as in p_ref.
p_DIP_ours = [np.load(path_to_est / f'{idx_z:03d}' / "p_est.npy") for idx_z in range(nz)]
p_DIP_ours = np.asarray(p_DIP_ours).swapaxes(0, 1)

ssim_DIP_ours = [
    masked_ssim(
        ref=p_ref[idx_p, i, ...],
        est=p_DIP_ours[idx_p, i, ...],
        mask=mask[i, ...],
    ).item()
    for i in range(nz)
]

nrmse_DIP_ours = [
    masked_nrmse(
        ref=p_ref[idx_p, i, ...],
        est=p_DIP_ours[idx_p, i, ...],
        mask=mask[i, ...],
    )
    for i in range(nz)
]



In [None]:
# --- Per-slice metrics -----------------------------------------------------
# SSIM (left) and NRMSE (right) of parameter map ``idx_p`` for every method;
# each list holds one value per slice, so the x-axis is the slice index.
methods = (
    # (legend label, marker, colour, SSIM per slice, NRMSE per slice)
    ('NONE', '.', 'C0', ssim_NONE_nlls, nrmse_NONE_nlls),
    ('NLM', '.', 'C1', ssim_NLM_nlls, nrmse_NLM_nlls),
    ('BM3D', '.', 'C2', ssim_BM3D_nlls, nrmse_BM3D_nlls),
    ('OURS', '+', 'C7', ssim_DIP_ours, nrmse_DIP_ours),
)

plot = GridPlot(ncols=2)
ax_ssim, ax_nrmse = plot.axs[0, 0], plot.axs[0, 1]

for label, marker, color, ssim_values, nrmse_values in methods:
    ax_ssim.plot(ssim_values, marker, color=color, label=label)
    ax_nrmse.plot(nrmse_values, marker, color=color, label=label)

ax_ssim.set(xlabel='slice index', ylabel='SSIM')
ax_nrmse.set(xlabel='slice index', ylabel='NRMSE')
ax_ssim.legend()
ax_nrmse.legend()

plot.set_size(x=10, y=3)
plot.set_spacing(wspace=0.25)

plt.show()

In [None]:
# --- Zoomed qualitative comparison -----------------------------------------
# Large panel: cropped reference map of parameter ``idx_p`` in slice ``idx_z``,
# with one red box per entry of ``zoom_boxes``. Each box is zoomed into its own
# grid row, one column per method in ``zoom_columns`` (grid columns 2-6).
idx_z = 18

ref_slice = p_ref[idx_p, idx_z]
x1, x2, y1, y2 = find_clip_indices(ref_slice, slack=0)
# ``x2`` / ``y2`` are widths trimmed from the end of each axis. Use explicit stop
# indices: ``x1:-x2`` would yield an empty crop whenever that width is 0.
crop = (slice(x1, ref_slice.shape[0] - x2), slice(y1, ref_slice.shape[1] - y2))

image_center = ref_slice[crop]
assert image_center.size > 0, "crop returned by find_clip_indices is empty"

# (column title, cropped map), left to right from grid column ``first_zoom_col``.
zoom_columns = [
    ("GT", ref_slice[crop]),
    ("NONE", p_NONE_nlls[idx_p, idx_z][crop]),
    ("NLM", p_NLM_nlls[idx_p, idx_z][crop]),
    ("BM3D", p_BM3D_nlls[idx_p, idx_z][crop]),
    ("OURS", p_DIP_ours[idx_p, idx_z][crop]),
]
first_zoom_col = 2

# Zoom boxes: centre ``c`` (passed to ``_trace_region_box``) and the grid row
# that shows the zoomed panels; dict order is the drawing order.
zoom_boxes = dict(
    ne=dict(c=(88, 50), row=0),
    sw=dict(c=(110, 155), row=2),  # down
    se=dict(c=(100, 100), row=1),  # middle
)

sx, sy = 35, 30  # passed to _trace_region_box; presumably the box extent in pixels
parent_box_color = "red"
parent_box_linewidth = 1
lims = [0, 5]  # shared colour limits of every map panel

plot = GridPlot(nrows=3, ncols=7)
plot.latex = False
plot.lims = lims

# Free grid columns 0-1 in every row; the centre panel is re-added below
# spanning rows 0-1 and columns 0-1.
for col in (1, 0):
    for row in range(3):
        plot.axs[row, col].remove()

center = (0, 0)  # position of the centre panel in ``plot.axs``
gs = plot.axs[center].get_gridspec()
plot.axs[center] = plot.fig.add_subplot(gs[0:2, 0:2])

plot.add_subplot(row=center[0], col=center[1], mat=image_center, cbar_bool=False, lims=lims)
plot.lines.remove(row=center[0], col=center[1])

for box in zoom_boxes.values():
    trace_x, trace_y = _trace_region_box(c=box["c"], sx=sx, sy=sy)

    # Outline the box on the centre panel.
    plot.axs[center].plot(
        trace_x, trace_y, parent_box_color, linewidth=parent_box_linewidth
    )

    # One zoomed panel per method in this box's row.
    for offset, (_, image) in enumerate(zoom_columns):
        plot.add_subplot(
            row=box["row"],
            col=first_zoom_col + offset,
            mat=_select_matrix_from_box(image, trace_x, trace_y),
            lims=lims,
        )
    plot.lines.set_color(row=box["row"], col=first_zoom_col, color="red")

# Method names above the first row of zoomed panels.
for offset, (title, _) in enumerate(zoom_columns):
    plot.axs[0, first_zoom_col + offset].set_title(title, fontsize=8)

plot.cbar.share_col(
    idx=3,
    length=1.45,
    width=0.15,
    loc="lower left",
    anchors=(-2.85, 0.9, 1, 1),
    borderpad=0,
    ticks=(0, 2.5, 5),
)
plot.cbar.remove_cols(cols=[2, 4, 5, 6])
plot.ticks.remove_all()

plot.lines.remove_cols(cols=[0, 1])
# Unit label, presumably for the shared colour bar; the coordinates are
# hand-tuned and refer to the current pyplot axes.
plt.text(0.4, -6.5, "seconds")

plot.set_size(4, 1.7)
plot.set_spacing(wspace=0.05, hspace=0.05)

plot.export('results.png', dpi=2000)
