In [None]:

from copy import deepcopy

import autoden as ad
import matplotlib.pyplot as plt
import numpy as np
import skimage.data as skd
import skimage.transform as skt
from corrct.processing.post import plot_frcs
from numpy.typing import NDArray
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from tqdm.auto import tqdm


%load_ext autoreload
%autoreload 2

%matplotlib widget

In [None]:
# --- Experiment configuration ---
# NOTE(review): USE_CAMERA_MAN is not read by any cell visible in this notebook; the test image
# below is always skimage's "cat" picture. Kept in case other code relies on it — confirm and
# remove if it is dead.
USE_CAMERA_MAN = True
NUM_IMGS_TRN = 4
NUM_IMGS_TST = 2
NUM_IMGS_TOT = NUM_IMGS_TRN + NUM_IMGS_TST

EPOCHS = 1024
REG_TV_VAL = 1e-7

# Standard deviation of the additive Gaussian noise, on the 0-255 intensity scale used below.
NOISE_STD = 20
# Fixed seed so that "Restart & Run All" regenerates the same noisy images.
SEED = 42
rng = np.random.default_rng(SEED)

# --- Reference image ---
# Load the cat image as float32, average over 4x4 pixel blocks (channel axis left untouched),
# then rescale in place so that the maximum value is exactly 255.
img_orig = skd.cat().astype(np.float32)
img_orig = skt.downscale_local_mean(img_orig, (4, 4, 1))
img_orig *= 255 / img_orig.max()

# --- Noisy realizations ---
# One independent Gaussian noise draw per image, stacked along a new leading axis.
imgs_noisy: NDArray = np.stack(
    [
        img_orig + NOISE_STD * rng.standard_normal(img_orig.shape)
        for _ in tqdm(range(NUM_IMGS_TOT), desc="Create noisy images")
    ],
    axis=0,
)

print(f"Img orig -> [{img_orig.min()}, {img_orig.max()}], Img noisy -> [{imgs_noisy[0].min()}, {imgs_noisy[0].max()}]")
print(f"Img shape: {img_orig.shape}")

## Performing training and prediction

### Creating the model

In [None]:
# Describe a U-Net with 3 input channels, 3 output channels and 16 features, then build it.
# The resulting `model` is the common starting point that every denoiser below deep-copies.
net_params = ad.NetworkParamsUNet(
    n_channels_in=3,
    n_channels_out=3,
    n_features=16,
)
model = net_params.get_model()

### Training the same initial model with different algorithms

In [None]:
# Supervised denoiser: works on its own copy of the initial model, so the shared `model`
# object is not the one handed to this trainer.
denoiser_sup = ad.Supervised(model=deepcopy(model), reg_val=REG_TV_VAL)

# The noisy stack is paired with the clean reference image; the test fraction is
# NUM_IMGS_TST out of NUM_IMGS_TOT images.
sup_data = denoiser_sup.prepare_data(
    imgs_noisy,
    img_orig,
    num_tst_ratio=NUM_IMGS_TST / NUM_IMGS_TOT,
    channel_axis=-1,
)
denoiser_sup.train(*sup_data, epochs=EPOCHS)

In [None]:
# Noise2Void denoiser, started from a fresh copy of the same initial model.
denoiser_n2v = ad.N2V(model=deepcopy(model), reg_val=REG_TV_VAL)

# Only the noisy stack is passed (no clean reference); the test fraction is
# NUM_IMGS_TST out of NUM_IMGS_TOT images.
n2v_data = denoiser_n2v.prepare_data(
    imgs_noisy,
    num_tst_ratio=NUM_IMGS_TST / NUM_IMGS_TOT,
    channel_axis=-1,
)
denoiser_n2v.train(*n2v_data, epochs=EPOCHS)

In [None]:
# Noise2Noise denoiser, started from a fresh copy of the same initial model.
denoiser_n2n = ad.N2N(model=deepcopy(model), reg_val=REG_TV_VAL)

# The whole noisy stack is passed and no test ratio is given here.
n2n_data = denoiser_n2n.prepare_data(
    imgs_noisy,
    channel_axis=-1,
)
denoiser_n2n.train(*n2n_data, epochs=EPOCHS)

In [None]:
# Deep Image Prior denoiser, started from a fresh copy of the same initial model.
# Compared with the other denoisers it uses a 25x larger regularization value
# and is trained for three times as many epochs.
denoiser_dip = ad.DIP(
    model=deepcopy(model),
    reg_val=REG_TV_VAL * 2.5e1,
)
dip_data = denoiser_dip.prepare_data(
    imgs_noisy,
    channel_axis=-1,
)
denoiser_dip.train(*dip_data, epochs=EPOCHS * 3)

### Getting the predictions

In [None]:
# Run every trained denoiser on the first element returned by its own `prepare_data` call,
# asking for the channel axis to be placed last in the output.
preds_sup = denoiser_sup.infer(sup_data[0], channel_axis_dst=-1)
preds_n2v = denoiser_n2v.infer(n2v_data[0], channel_axis_dst=-1)

# Supervised and N2V predictions are averaged over the leading axis
# (presumably one prediction per input image — confirm against autoden).
den_sup = preds_sup.mean(axis=0)
den_n2v = preds_n2v.mean(axis=0)

# N2N and DIP predictions are used exactly as returned.
den_n2n = denoiser_n2n.infer(n2n_data[0], channel_axis_dst=-1)
den_dip = denoiser_dip.infer(dip_data[0], channel_axis_dst=-1)

In [None]:
# Visual comparison: clean reference, one noisy realization and the four denoised results.
fontsize = 14

# (title, image) pairs in row-major order, matching the flattened 2x3 grid of axes.
panels = [
    ("Original image", img_orig),
    ("Noisy image", imgs_noisy[0]),
    ("Denoised supervised", den_sup),
    ("Denoised N2V", den_n2v),
    ("Denoised N2N", den_n2n),
    ("Denoised DIP", den_dip),
]

fig, axs = plt.subplots(2, 3, sharex=True, sharey=True, figsize=(12, 6))
for ax, (panel_title, panel_img) in zip(axs.flatten(), panels):
    # imshow expects float RGB data in [0, 1]. The noisy image (and possibly the network
    # outputs) exceed that range after dividing by 255, so clip explicitly instead of
    # relying on matplotlib's implicit clipping, which emits a warning for each panel.
    ax.imshow(np.clip(panel_img / 255, 0, 1))
    ax.set_title(panel_title, fontsize=fontsize)
    ax.tick_params(labelsize=fontsize)
fig.tight_layout()
plt.show(block=False)

In [None]:
# Quantitative comparison of every denoised image against the clean reference `img_orig`.
# `plot_frcs`, `psnr` and `ssim` are imported in the notebook's import cell at the top.
all_recs = [den_sup, den_n2v, den_n2n, den_dip]
all_labs = ["Supervised", "Noise2Void", "Noise2Noise", "Deep Image Prior"]

# Both metrics use the dynamic range of the reference image.
data_range = img_orig.max() - img_orig.min()

# Metric name -> callable scoring one reconstruction; insertion order fixes the print order.
metrics = {
    "PSNR": lambda rec: psnr(img_orig, rec, data_range=data_range),
    "SSIM": lambda rec: ssim(img_orig, rec, data_range=data_range, channel_axis=-1),
}
for metric_name, metric_fn in metrics.items():
    print(f"{metric_name}:")
    for rec, lab in zip(all_recs, all_labs):
        # `.3` keeps three significant digits.
        print(f"- {lab}: {metric_fn(rec):.3}")

# Fourier ring correlation of each reconstruction against the reference, computed over the
# two axes preceding the channel axis.
plot_frcs([(img_orig.astype(np.float32), rec) for rec in all_recs], all_labs, axes=(-3, -2))