In [None]:
%cd ..
%config Completer.use_jedi = False
%load_ext autoreload
%autoreload 2

In [None]:
import os
import tensorflow as tf
import tensorflow_datasets as tfds

import tensorflow_datasets_bw as datasets
from tensorflow_datasets_bw import visualize
import dppp

# Load an example image dataset and blur kernels

In [None]:
# Ground-truth images: the BSDS500 validation split, converted to float32 and
# (judging by the helper name) rescaled from the 0-255 range to 0-1 — TODO confirm.
images = (
    tfds.load(name="bsds500", split="validation")
    .map(datasets.get_image)
    .map(datasets.to_float32)
    .map(datasets.from_255_to_1_range)
)

# Blur kernels: the schelten_kernels/dmsp test split, cropped, reduced to the 'kernel'
# value, cast to float32 and passed through dppp.conv2D_filter_rgb (presumably turning
# each kernel into a filter for RGB convolution — TODO confirm).
kernels = (
    tfds.load(name='schelten_kernels/dmsp', split='test')
    .map(datasets.crop_kernel_to_size)
    .map(datasets.get_value('kernel'))
    .map(datasets.to_float32)
    .map(dppp.conv2D_filter_rgb)
)

In [None]:
# Reproducibility: fix TensorFlow's global seed so that Restart & Run All yields the same
# degraded image and reconstruction each time. NOTE(review): this only helps if
# dppp.blur / dppp.dmsp_deblur_nb draw from TensorFlow's RNG — TODO confirm.
SEED = 0
tf.random.set_seed(SEED)

# Which example to use from each dataset
IMAGE_INDEX = 67   # index into the BSDS500 validation images
KERNEL_INDEX = 0   # index into the schelten_kernels/dmsp test kernels

mode = 'wrap'  # Border mode ('constant' or 'wrap')
assert mode in ('constant', 'wrap'), f"Unsupported border mode: {mode!r}"
noise_stddev = 0.04  # Noise stddev passed to dppp.blur and to the reconstruction

# `[None, ...]` prepends a leading (batch) axis of size 1 to the image
image = datasets.get_one_example(images, index=IMAGE_INDEX)[None, ...]
kernel = datasets.get_one_example(kernels, index=KERNEL_INDEX)

# Degrade the image (blur with the kernel and add noise)

In [None]:
# Blur `image` with `kernel` using the configured border mode and noise level.
# clip_final=False presumably leaves the result unclipped — TODO confirm in dppp.blur.
degraded = dppp.blur(
    image,
    kernel,
    noise_stddev,
    clip_final=False,
    mode=mode,
)

# Reconstruct the image

In [None]:
# Load the pretrained denoiser along with the (min, max) noise stddev range reported for it
denoiser, (denoiser_min_stddev, denoiser_max_stddev) = dppp.load_denoiser('models/drugan+_0.0-0.2.h5')

# Noise stddev for the stochastic evaluation of the prior: 0.1 when it lies within the
# denoiser's reported range, otherwise fall back to the denoiser's maximum stddev
denoiser_stddev = (0.1
                   if denoiser_min_stddev <= 0.1 <= denoiser_max_stddev
                   else denoiser_max_stddev)

# Progress callbacks: PSNR printed every 20th step, SSIM logged to TensorBoard under log_dir
log_dir = os.path.join('logs', 'dmsp_nb')
callbacks = [dppp.callback_print_psnr('psnr', 20, image),
             dppp.callback_tb_ssim(log_dir, 'ssim', image)]

# Run the reconstruction for a fixed number of steps
reconstructed = dppp.dmsp_deblur_nb(
    degraded, kernel, noise_stddev, denoiser, denoiser_stddev,
    num_steps=300, mode=mode, callbacks=callbacks,
)

In [None]:
# Score the reconstruction against the ground truth. The metrics presumably return one
# value per batch element; [0] selects the single image in the batch.
psnr, fsim = (dppp.psnr(image, reconstructed).numpy()[0],
              dppp.fsim(image, reconstructed).numpy()[0])
print(f"Reconstructed PSNR: {psnr:0.2f}, FSIM: {fsim:0.4f}")

# Show ground truth, degraded input and reconstruction (first batch element of each)
visualize.draw_images([batch[0] for batch in (image, degraded, reconstructed)])