In [None]:
%cd ..
%config Completer.use_jedi = False
%load_ext autoreload
%autoreload 2

In [None]:
import os
import imageio
import tensorflow as tf
import tensorflow_datasets as tfds

import tensorflow_datasets_bw as datasets
from tensorflow_datasets_bw import visualize
import dppp

# Load an example image and its inpainting mask

In [None]:
examples = ['kate', 'library', 'vase', 'vase2']
example_id = 2
data_path = os.path.join('notebooks', 'data', 'inpainting')

# File stem of the selected example; the image and its mask share it.
example_name = examples[example_id]
image_file = os.path.join(data_path, f'{example_name}.png')
mask_file = os.path.join(data_path, f'{example_name}_mask.png')


def load_unit_range(path):
    """Read an image file into a float32 tensor rescaled from [0, 255] to [0, 1]."""
    pixels = tf.constant(imageio.imread(path))
    return datasets.from_255_to_1_range(datasets.to_float32(pixels))


# Prepend a batch axis to the (H, W, C) image -> (1, H, W, C).
image = load_unit_range(image_file)[None, ...]

# The mask is expected to be read as a 2-D (H, W) array: it gets a batch and a
# channel axis and is then broadcast across all channels of `image`.
mask_unit = load_unit_range(mask_file)
mask = tf.broadcast_to(mask_unit[None, ..., None], shape=tf.shape(image))

# Simulated observation: pixels where the mask is 0 are zeroed out.
degraded = mask * image

In [None]:
import numpy as np
import skimage.restoration

# Classical baselines to compare the learned priors against.
#
# `degraded = mask * image` keeps pixels where `mask` is nonzero, so the holes
# to fill are where it is exactly 0. inpaint_biharmonic expects the opposite
# convention: True marks pixels to be inpainted.
observed = degraded[0].numpy()
hole = mask[0].numpy() == 0

# Inpaint each colour channel as its own 2-D image. Passing the (H, W, C)
# array directly (without a channel axis argument) makes skimage treat it as
# a 3-D grayscale volume, which couples neighbouring channels through the
# biharmonic operator. The per-channel loop avoids that on every skimage
# version (`multichannel` vs. `channel_axis` API differences).
inpainted_skimage = np.stack(
    [
        skimage.restoration.inpaint_biharmonic(observed[..., c], hole[..., c])
        for c in range(observed.shape[-1])
    ],
    axis=-1,
)[None, ...]  # restore the leading batch axis

# Baseline provided by dppp (border-based inpainting).
inpainted_border = dppp.inpaint_border(degraded, mask)

# Run different methods and models

In [None]:
# Denoiser checkpoints to compare (loaded from models/<name>.h5).
denoiser_models = ('drunet+_0.0-0.2', 'drugan+_0.0-0.2')
# (method, denoiser stddev) pairs; None means "use the model's maximum".
method_settings = (('dmsp', 0.1), ('hqs', None))

# Full grid: every method is run with every denoiser, method-major order.
model_methods = [
    (model_name, method, denoiser_stddev)
    for method, denoiser_stddev in method_settings
    for model_name in denoiser_models
]

In [None]:
# One reconstruction per entry of `model_methods`, in the same order.
all_reconstructed = []

for model_name, method, denoiser_stddev in model_methods:
    model_path = os.path.join('models', f'{model_name}.h5')
    denoiser, (_, max_denoiser_stddev) = dppp.load_denoiser(model_path)
    # No stddev configured: fall back to the maximum stddev reported by
    # dppp.load_denoiser for this model.
    if denoiser_stddev is None:
        denoiser_stddev = max_denoiser_stddev

    print(f"Running {method} for {model_name}...")
    if method == 'dmsp':
        rec = dppp.dmsp_inpaint(degraded, mask, denoiser, denoiser_stddev)
    elif method == 'hqs':
        # Use the resolved stddev (defaults to the model's maximum) so a value
        # configured in `model_methods` is honoured instead of ignored.
        rec = dppp.hqs_inpaint(degraded, mask, denoiser, denoiser_stddev)
    else:
        # Without this, `rec` would be undefined or silently hold the
        # previous model's reconstruction.
        raise ValueError(f"Unknown method {method!r}; expected 'dmsp' or 'hqs'")

    all_reconstructed.append(rec)

# Visualize

In [None]:
# Reference panels first (ground truth, observation, the two classical
# baselines), then one panel per reconstruction in `model_methods` order.
# Each entry is batched, so index 0 selects the single image.
baselines = [image, degraded, inpainted_border, inpainted_skimage]
panels = [batch[0] for batch in baselines + all_reconstructed]
visualize.draw_images(panels, ncol=4, figsize=(50, 40))

# Export

In [None]:
# Name the output directory after the selected example so changing
# `example_id` does not write into (or overwrite) another example's folder.
export_dir = os.path.join('visualize', f'inpaint_{examples[example_id]}')
# exist_ok: re-running the notebook must not fail with FileExistsError.
os.makedirs(export_dir, exist_ok=True)


def export_path(x):
    """Return the path of file name `x` inside `export_dir`."""
    return os.path.join(export_dir, x)


def to_uint8(x):
    """Convert a float image in [0, 1] to uint8 in [0, 255].

    Values are clipped to [0, 1] and then rounded (not truncated) after
    scaling, so float round-off such as 254.9999 maps to 255 rather than 254.
    """
    return tf.cast(tf.round(tf.clip_by_value(x, 0, 1) * 255), tf.uint8)


def write_to(file_name, img):
    """Write the first image of the batch `img` to `<export_dir>/<file_name>.png`."""
    imageio.imwrite(export_path(f'{file_name}.png'), to_uint8(img[0]).numpy())

#### Images

# Original
write_to('original', image)

# Degraded
write_to('degraded', degraded)

# Border
write_to('border', inpainted_border)

# Skimage
write_to('skimage', inpainted_skimage)

# Reconstructed (file name: <method>-<model_name>.png)
for (model_name, method, _), rec in zip(model_methods, all_reconstructed):
    write_to(f'{method}-{model_name}', rec)