In [None]:

import matplotlib.pyplot as plt
import numpy as np
import skimage.data as skd
import skimage.transform as skt
from copy import deepcopy
from numpy.typing import NDArray
from tqdm.auto import tqdm
import autoden as ad


# Reload imported modules before each cell executes, so edits to their source
# (e.g. a locally developed package) take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Interactive matplotlib figures (requires the `ipympl` package).
%matplotlib widget

In [None]:
NUM_IMGS_TRN = 4
NUM_IMGS_TST = 2
NUM_IMGS_TOT = NUM_IMGS_TRN + NUM_IMGS_TST

EPOCHS = 1024
REG_TV_VAL = 1e-7

NOISE_STD = 0.2  # std of the additive Gaussian noise (the clean volume is normalized to [0, 1] below)
SEED = 0  # fixes the noise realizations so Restart & Run All reproduces the same data and metrics

# Index 1 along axis 1 of skimage's `cells3d` sample, downscaled by (2, 4, 4) and min-max normalized to [0, 1].
vol_orig = skd.cells3d()[:, 1, ...]
vol_orig = skt.downscale_local_mean(vol_orig, (2, 4, 4))
vol_orig = (vol_orig - vol_orig.min()) / (vol_orig.max() - vol_orig.min())

# Independent noisy realizations of the same clean volume, drawn from a seeded Generator
# (the unseeded global `np.random.randn` made every run produce different data).
rng = np.random.default_rng(SEED)
vols_noisy: NDArray = np.stack(
    [
        vol_orig + NOISE_STD * rng.standard_normal(vol_orig.shape)
        for _ in tqdm(range(NUM_IMGS_TOT), desc="Create noisy images")
    ],
    axis=0,
)
assert vols_noisy.shape == (NUM_IMGS_TOT, *vol_orig.shape)

# The last NUM_IMGS_TST realizations; passed as `tst_inds` to the trainers below.
tst_inds = np.arange(NUM_IMGS_TRN, NUM_IMGS_TOT)

print(f"Img orig -> [{vol_orig.min()}, {vol_orig.max()}], Img noisy -> [{vols_noisy[0].min()}, {vols_noisy[0].max()}]")
print(f"Img shape: {vol_orig.shape}")

central_slice = vol_orig.shape[0] // 2
fig, axs = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(8, 4))
axs[0].imshow(vol_orig[central_slice])
axs[0].set_title("Original image")
axs[1].imshow(vols_noisy[0][central_slice])
axs[1].set_title("Noisy image")
fig.tight_layout()
plt.show(block=False)

## Performing training and prediction

### Creating the model

In [None]:
# 3-D U-Net configuration (n_features=16 — presumably the first-level width; confirm in autoden docs).
net_params = ad.NetworkParamsUNet(n_features=16, n_dims=3)
# Shared initial model: each algorithm below trains its own `deepcopy`, so all start from identical weights.
model = net_params.get_model()

### Training the same initial model with different algorithms

In [None]:
# Supervised training from the noisy realizations and the clean `vol_orig` (presumably input/target).
# `tst_inds` selects the last NUM_IMGS_TST realizations (presumably held out for testing — confirm in autoden).
denoiser_sup = ad.Supervised(model=deepcopy(model), reg_val=REG_TV_VAL)
denoiser_sup.train(vols_noisy, vol_orig, epochs=EPOCHS, tst_inds=tst_inds)

In [None]:
# Noise2Void: trained on the noisy realizations only — no clean target is passed.
denoiser_n2v = ad.N2V(model=deepcopy(model), reg_val=REG_TV_VAL)
denoiser_n2v.train(vols_noisy, epochs=EPOCHS, tst_inds=tst_inds)

In [None]:
# Noise2Noise: `prepare_data` builds the training arguments from the noisy realizations.
# The returned sequence is unpacked into `train`; its first item is reused later for inference.
denoiser_n2n = ad.N2N(model=deepcopy(model), reg_val=REG_TV_VAL)
n2n_data = denoiser_n2n.prepare_data(vols_noisy)
denoiser_n2n.train(*n2n_data, epochs=EPOCHS)

In [None]:
# Deep Image Prior, with a TV weight 5x larger than the other methods.
# NOTE(review): the factor 5 is a magic number — consider a named constant in the config cell.
denoiser_dip = ad.DIP(model=deepcopy(model), reg_val=REG_TV_VAL * 5)
# `prepare_data` output is unpacked into `train`; its first item is reused later for inference.
dip_data = denoiser_dip.prepare_data(vols_noisy)
denoiser_dip.train(*dip_data, epochs=EPOCHS)

In [None]:
from concurrent.futures import Executor
from collections.abc import Sequence
from typing import Callable

import corrct as cct

# Iteration budget intended for the TV-min (PDHG) baseline. It only takes effect when passed as
# `iterations=NUM_ITERS` to `fit_variational_denoising`, whose own default is 2_000.
NUM_ITERS = 5_000


def _plot_lambda_sweep(lams: NDArray, recs_avg: NDArray) -> None:
    """Show the central slice of the reconstruction obtained for each tested lambda, side by side.

    `recs_avg[ii]` must be the volume reconstructed with `lams[ii]`; the slice is taken along its first axis.
    """
    # squeeze=False keeps `axs` 2-D even when a single lambda was tested, so indexing never fails.
    fig, axs = plt.subplots(1, len(lams), sharex=True, sharey=True, figsize=(len(lams) * 1, 2), squeeze=False)
    slice_ind = recs_avg[0].shape[0] // 2
    fig.suptitle(f"Slice = {slice_ind}", fontsize=13)
    for ax, lam, rec in zip(axs[0], lams, recs_avg):
        ax.imshow(rec[slice_ind])
        ax.set_title(f"l = {lam:.3e}", fontsize=13)
    fig.tight_layout()


def fit_variational_denoising(
    volume: NDArray,
    reg: (
        Callable[[float], cct.regularizers.BaseRegularizer] | Callable[[float], Sequence[cct.regularizers.BaseRegularizer]]
    ) = cct.regularizers.Regularizer_TV2D,
    lambda_range: tuple[float, float] = (1e-3, 1e2),
    iterations: int = 2_000,
    lower_limit: float | None = None,
    num_averages: int = 1,
    parallel_eval: bool | int | Executor = False,
) -> tuple[float, NDArray]:
    """Denoise `volume` with a regularized PDHG solve, choosing the regularization weight by cross-validation.

    The forward operator is the identity (pure denoising). corrct's `CrossValidation` evaluates a PDHG solve
    for every candidate lambda, the lambda minimizing the fitted loss curve is selected, and a final solve
    on the whole volume is performed with it. The per-lambda reconstructions are also plotted.

    Parameters
    ----------
    volume : NDArray
        The noisy volume to denoise.
    reg : callable
        Factory mapping a regularization weight to a corrct regularizer (or a sequence of them).
        NOTE(review): the default is a 2-D TV regularizer; for 3-D volumes check whether a 3-D TV is intended.
    lambda_range : tuple[float, float]
        Bounds of the lambda sweep, forwarded to `CrossValidation.get_lambda_range` (2 values per order).
    iterations : int
        PDHG iterations used for every solve (sweep and final).
    lower_limit : float | None
        Forwarded to the solver as `lower_limit`.
    num_averages : int
        Forwarded to `CrossValidation` as `num_averages`.
    parallel_eval : bool | int | Executor
        Forwarded to `CrossValidation` as `parallel_eval`.

    Returns
    -------
    lam_min : float
        The selected regularization weight.
    rec_reg : NDArray
        The reconstruction of the full volume with `lam_min`.
    """

    def solver_spawn(lam: float):
        return cct.solvers.PDHG(regularizer=reg(lam), verbose=True, leave_progress=False)

    def solver_call(solver: cct.solvers.Solver, vol_mask: NDArray | None = None):
        # Identity forward model: the "measured data" is the noisy volume itself.
        op = cct.operators.TransformIdentity(volume.shape)
        # `b_test_mask` is presumably supplied by CrossValidation for the held-out voxels; the final solve passes None.
        rec, info = solver(op, volume, iterations=iterations, lower_limit=lower_limit, b_test_mask=vol_mask)
        return rec, info

    cv = cct.param_tuning.CrossValidation(
        volume.shape, num_averages=num_averages, verbose=True, plot_result=True, parallel_eval=parallel_eval
    )
    cv.solver_spawning_function = solver_spawn
    cv.solver_calling_function = solver_call

    lams = cv.get_lambda_range(*lambda_range, num_per_order=2)
    f_avgs, _, _, recs = cv.compute_loss_values(lams, return_recs=True)
    lam_min, _ = cv.fit_loss_min(lams, f_avgs)

    # Average over the first axis of `recs` (presumably the CV repetitions) before displaying.
    _plot_lambda_sweep(lams, np.mean(recs, axis=0))
    print(f"Selected lambda: {lam_min:.3e}")

    # Final solve on the whole volume (no test mask) with the selected weight.
    rec_reg, _ = solver_call(solver_spawn(lam_min))

    return lam_min, rec_reg


# TV-min baseline: denoise the average of all noisy realizations, sweeping lambda over [1e-2, 1e0].
# NUM_ITERS is passed explicitly; otherwise it is unused and the function default (2_000 iterations) applies.
lam_tv, den_tv = fit_variational_denoising(
    vols_noisy.mean(axis=0).astype(np.float32), lambda_range=(1e-2, 1e0), iterations=NUM_ITERS
)

### Getting the predictions

In [None]:
# Supervised and N2V: denoise every noisy realization, then average the predictions along the first axis.
den_sup, den_n2v = (denoiser.infer(vols_noisy).mean(0) for denoiser in (denoiser_sup, denoiser_n2v))
# N2N and DIP: predict from the first item of the data each denoiser prepared for training.
den_n2n, den_dip = (
    denoiser.infer(prepared[0]) for denoiser, prepared in ((denoiser_n2n, n2n_data), (denoiser_dip, dip_data))
)

In [None]:
fig, axs = plt.subplots(2, 3, sharex=True, sharey=True)
axs[0, 0].imshow(vol_orig[central_slice])
axs[0, 0].set_title("Original image")
axs[0, 1].imshow(vols_noisy[0][central_slice])
axs[0, 1].set_title("Noisy image")
axs[0, 2].imshow(den_sup[central_slice])
axs[0, 2].set_title("Denoised supervised")
axs[1, 0].imshow(den_n2v[central_slice])
axs[1, 0].set_title("Denoised N2V")
axs[1, 1].imshow(den_n2n[central_slice])
axs[1, 1].set_title("Denoised N2N")
axs[1, 2].imshow(den_dip[central_slice])
axs[1, 2].set_title("Denoised DIP")
fig.tight_layout()
plt.show(block=False)

In [None]:
fontsize = 14

fig, axs = plt.subplots(2, 3, sharex=True, sharey=True, figsize=(8.5, 6))
axs[0, 0].imshow(vol_orig[central_slice])
axs[0, 0].set_title("Original image", fontsize=fontsize)
axs[0, 1].imshow(vols_noisy[0][central_slice])
axs[0, 1].set_title("Noisy image", fontsize=fontsize)
axs[0, 2].imshow(den_sup[central_slice])
axs[0, 2].set_title("Denoised supervised", fontsize=fontsize)
axs[1, 0].imshow(den_n2v[central_slice])
axs[1, 0].set_title("Denoised N2V", fontsize=fontsize)
axs[1, 1].imshow(den_n2n[central_slice])
axs[1, 1].set_title("Denoised N2N", fontsize=fontsize)
axs[1, 2].imshow(den_tv[central_slice])
axs[1, 2].set_title("Denoised TV", fontsize=fontsize)
for ax in axs.flatten():
    ax.tick_params(labelsize=fontsize)
fig.tight_layout()
plt.show(block=False)

In [None]:
from corrct.processing.post import plot_frcs
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

all_recs = [den_sup, den_n2v, den_n2n, den_dip]
all_labs = ["Supervised", "Noise2Void", "Noise2Noise", "Deep Image Prior"]

data_range = vol_orig.max() - vol_orig.min()
print("PSNR:")
for rec, lab in zip(all_recs, all_labs):
    print(f"- {lab}: {psnr(vol_orig, rec, data_range=data_range):.3}")
print("SSIM:")
for rec, lab in zip(all_recs, all_labs):
    print(f"- {lab}: {ssim(vol_orig, rec, data_range=data_range):.3}")

plot_frcs([(vol_orig.astype(np.float32), rec) for rec in all_recs], all_labs)

In [None]:
from corrct.processing.post import plot_frcs
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

all_recs = [den_sup, den_n2v, den_n2n, den_tv]
all_labs = ["Supervised", "Noise2Void", "Noise2Noise", "TV-min"]

data_range = vol_orig.max() - vol_orig.min()
print("PSNR:")
for rec, lab in zip(all_recs, all_labs):
    print(f"- {lab}: {psnr(vol_orig, rec, data_range=data_range):.3}")
print("SSIM:")
for rec, lab in zip(all_recs, all_labs):
    print(f"- {lab}: {ssim(vol_orig, rec, data_range=data_range):.3}")

plot_frcs([(vol_orig.astype(np.float32), rec) for rec in all_recs], all_labs)