In [None]:

from copy import deepcopy

import matplotlib.pyplot as plt
import numpy as np
import skimage.color as skc
import skimage.data as skd
import skimage.transform as skt
from corrct.processing.post import plot_frcs
from numpy.typing import NDArray
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from tqdm.auto import tqdm

import autoden as ad


%load_ext autoreload
%autoreload 2

%matplotlib widget

In [None]:
# Experiment configuration.
USE_CAMERA_MAN = True  # True: downscaled skimage "camera" image; False: grayscale skimage "cat" image
NUM_IMGS_TRN = 4
NUM_IMGS_TST = 2
NUM_IMGS_TOT = NUM_IMGS_TRN + NUM_IMGS_TST

EPOCHS = 1024 * 3
REG_TV_VAL = 1e-6  # base regularization weight; each denoiser below scales it

NOISE_STD = 20.0  # std of the additive Gaussian noise, in the same intensity units as img_orig
SEED = 0  # seeds the noise realizations so Restart & Run All reproduces the same noisy stack
# NOTE(review): model initialization / training randomness is not seeded here — if autoden is
# torch-based, consider seeding torch as well before building the model.

if USE_CAMERA_MAN:
    img_orig = skd.camera()
    img_orig = skt.downscale_local_mean(img_orig, 4)  # block-average by 4 along each axis
else:
    img_orig = skd.cat()
    img_orig = skc.rgb2gray(img_orig)
    img_orig *= 255 / img_orig.max()  # rescale so the maximum intensity is 255

# NUM_IMGS_TOT independent noise realizations of the same clean image, stacked on axis 0.
rng = np.random.default_rng(SEED)
imgs_noisy: NDArray = np.stack(
    [
        img_orig + NOISE_STD * rng.standard_normal(img_orig.shape)
        for _ in tqdm(range(NUM_IMGS_TOT), desc="Create noisy images")
    ],
    axis=0,
)
# NOTE(review): not referenced by any cell visible here — confirm whether the data preparation
# step is meant to receive these held-out indices.
tst_inds = np.arange(NUM_IMGS_TRN, NUM_IMGS_TOT)

print(f"Img orig -> [{img_orig.min()}, {img_orig.max()}], Img noisy -> [{imgs_noisy[0].min()}, {imgs_noisy[0].max()}]")
print(f"Img shape: {img_orig.shape}")

## Performing training and prediction

### Creating the model

In [None]:
# A single UNet is instantiated here and acts as the shared starting point:
# every DIP denoiser below trains its own deepcopy of it, so all runs begin
# from identical weights and differ only in the regularizer.
unet_features = 16
net_params = ad.NetworkParamsUNet(n_features=unet_features)
model = net_params.get_model()

### Training the same initial model with different regularization losses

In [None]:
# DIP with a scalar regularization weight. NOTE(review): presumably ad.DIP
# interprets a plain float `reg_val` as a TV penalty weight (the result is
# labelled "DIP-TV" later) — confirm against the autoden API.
tv_weight = REG_TV_VAL * 5
denoiser_dip_tv = ad.DIP(model=deepcopy(model), reg_val=tv_weight)

# The prepared data is built once here and reused unchanged by the TGV and
# SWT denoisers and by the inference cell.
dip_data = denoiser_dip_tv.prepare_data(imgs_noisy)
denoiser_dip_tv.train(*dip_data, epochs=EPOCHS)

In [None]:
# DIP regularized with an explicit LossTGV object; it trains on the same
# prepared data and from the same initial weights as the TV run.
tgv_weight = REG_TV_VAL * 3.5
reg_tgv = ad.losses.LossTGV(tgv_weight)

denoiser_dip_tgv = ad.DIP(model=deepcopy(model), reg_val=reg_tgv)
denoiser_dip_tgv.train(*dip_data, epochs=EPOCHS)

In [None]:
# DIP regularized with ad.losses.LossSWTN (presumably a stationary-wavelet
# based penalty). NOTE(review): the meaning of `min_approx=False` is defined
# by autoden — confirm there.
swt_weight = REG_TV_VAL * 1.5
reg_swt = ad.losses.LossSWTN(lambda_val=swt_weight, min_approx=False)

denoiser_dip_swt = ad.DIP(model=deepcopy(model), reg_val=reg_swt)
denoiser_dip_swt.train(*dip_data, epochs=EPOCHS)

### Getting the predictions

In [None]:
# Run every trained denoiser on the same input: the first element of the
# tuple returned by prepare_data. Results are unpacked in TV, TGV, SWT order.
dip_inputs = dip_data[0]
den_dip_tv, den_dip_tgv, den_dip_swt = (
    denoiser.infer(dip_inputs) for denoiser in (denoiser_dip_tv, denoiser_dip_tgv, denoiser_dip_swt)
)

In [None]:
# Visual comparison: reference, one noisy realization, the average of all
# realizations, and the three DIP reconstructions (filled row by row).
# Every panel uses the gray colormap and the same intensity window (the
# range of the original image). Without a shared window, each panel would
# autoscale independently and contrast differences would be artifacts of
# display. Noisy values outside the window are clipped for display only.
panels = [
    (img_orig, "Original image"),
    (imgs_noisy[0], "Noisy image"),
    (np.mean(imgs_noisy, axis=0), "Averaged image"),
    (den_dip_tv, "Denoised DIP-TV"),
    (den_dip_tgv, "Denoised DIP-TGV"),
    (den_dip_swt, "Denoised DIP-SWT"),
]
vmin, vmax = img_orig.min(), img_orig.max()

fig, axs = plt.subplots(2, 3, sharex=True, sharey=True, figsize=(3.5 * 3, 3.5 * 2 + 0.5))
for ax, (img, title) in zip(axs.flatten(), panels):
    ax.imshow(img, cmap="gray", vmin=vmin, vmax=vmax)
    ax.set_title(title, fontsize=13)
    ax.tick_params(labelsize=12)
fig.tight_layout()
plt.show(block=False)

In [None]:
# Quantitative comparison of the reconstructions against the noise-free image.
all_recs = [den_dip_tv, den_dip_tgv, den_dip_swt]
all_labs = ["Deep Image Prior - TV", "Deep Image Prior - TGV", "Deep Image Prior - SWT"]

# Use one float32 reference (the FRC call already cast the reference to
# float32) and float32 reconstructions for every metric. PSNR, SSIM and FRC
# then all compare identically-typed arrays.
# NOTE(review): older scikit-image releases may reject SSIM inputs whose
# dtypes differ; the reconstruction dtype is not visible here.
img_ref = img_orig.astype(np.float32)
recs_ref = [np.asarray(rec, dtype=np.float32) for rec in all_recs]
data_range = float(img_ref.max() - img_ref.min())

print("PSNR:")
for rec, lab in zip(recs_ref, all_labs):
    print(f"- {lab}: {psnr(img_ref, rec, data_range=data_range):.3}")
print("SSIM:")
for rec, lab in zip(recs_ref, all_labs):
    print(f"- {lab}: {ssim(img_ref, rec, data_range=data_range):.3}")

plot_frcs([(img_ref, rec) for rec in recs_ref], all_labs)