In [None]:
from pathlib import Path

import numpy as np
import cv2
import skimage
import scipy
from tensorflow import keras
from PIL import Image
from matplotlib import pyplot as plt

import noise_bae

In [None]:
def _ssim(reference, candidate, data_range):
    """Return the SSIM of two same-shaped images, working across scikit-image versions.

    3-D inputs are treated as H x W x C with channels on the last axis; 2-D inputs
    are compared as single-channel images.
    """
    import inspect

    try:
        from skimage.metrics import structural_similarity
    except ImportError:
        # scikit-image < 0.16 only ships the older name (which was later removed).
        from skimage.measure import compare_ssim as structural_similarity

    options = {'data_range': data_range}
    if reference.ndim == 3:
        # `channel_axis` superseded the deprecated `multichannel` keyword in
        # scikit-image 0.19, so pass whichever this installation accepts.
        if 'channel_axis' in inspect.signature(structural_similarity).parameters:
            options['channel_axis'] = -1
        else:
            options['multichannel'] = True
    return structural_similarity(reference, candidate, **options)


def evaluate(original_image, denoised_image, data_range=None):
    """Print and return similarity metrics between a reference image and a candidate.

    Parameters
    ----------
    original_image : np.ndarray
        Clean reference image (2-D grayscale or H x W x C).
    denoised_image : np.ndarray
        Image to score; must have the same shape as ``original_image``.
    data_range : float, optional
        Dynamic range handed to SSIM. Defaults to ``max - min`` of the
        reference (1.0 if the reference is constant), so every candidate
        compared against the same reference is normalised identically.

    Returns
    -------
    dict
        Unrounded ``l2_norm`` (Euclidean distance), ``cosine_similarity``
        and ``ssim_similarity`` values; rounded versions are printed.

    Raises
    ------
    ValueError
        If the two images have different shapes.
    """
    from scipy.spatial.distance import cosine as cosine_distance

    if original_image.shape != denoised_image.shape:
        raise ValueError(
            f'image shapes differ: {original_image.shape} vs {denoised_image.shape}'
        )

    # Work in float64 so integer inputs (e.g. uint8) cannot wrap around on subtraction.
    reference = np.asarray(original_image, dtype=np.float64)
    candidate = np.asarray(denoised_image, dtype=np.float64)

    l2_norm = float(np.linalg.norm(reference - candidate))
    # scipy's `cosine` is a *distance* (1 - cos θ); convert it to the similarity that is reported.
    cosine_similarity = 1.0 - float(cosine_distance(reference.ravel(), candidate.ravel()))

    if data_range is None:
        # Taken from the reference, not the candidate, so scores are comparable across denoisers;
        # a constant reference would give 0, which zeroes SSIM's stabilising constants.
        data_range = float(reference.max() - reference.min()) or 1.0
    ssim_similarity = float(_ssim(reference, candidate, data_range))

    print('l2 norm:', round(l2_norm, 5))
    print('cosine similarity:', round(cosine_similarity, 5))
    print('ssim similarity:', round(ssim_similarity, 5))
    return {
        'l2_norm': l2_norm,
        'cosine_similarity': cosine_similarity,
        'ssim_similarity': ssim_similarity,
    }

In [None]:
# Pre-trained Keras model consumed by `clean_image_resnet` further down.
# Resolved relative to the home directory instead of a hard-coded /Users/<name> path.
MODEL_PATH = Path.home() / 'Desktop' / 'denoising_model.hdf5'
denoising_model = keras.models.load_model(str(MODEL_PATH))

In [None]:
# Load the clean reference image as float32 RGB scaled to [0, 1].
IMAGE_PATH = Path.home() / 'Desktop' / 'ai_sample_data' / 'xihu.png'
with Image.open(IMAGE_PATH) as source:
    # convert('RGB') drops an alpha channel exactly like a [:, :, :3] slice would, and also
    # handles palette / grayscale PNGs that have no third axis to slice.
    image = np.asarray(source.convert('RGB'), dtype=np.float32) / 255.0
# Sanity check: comparing the image with itself gives the best attainable scores.
evaluate(image, image)
plt.imshow(image)
plt.title('Reference image');

In [None]:
# Corrupt the reference with Gaussian noise, then pepper noise (parameter 0.4 for both).
# NOTE(review): presumably noise_bae draws from NumPy's global RNG; seeding makes reruns
# produce the same noisy image and hence comparable scores below (TODO confirm).
np.random.seed(0)
noisy_image = noise_bae.bae.add_pepper(noise_bae.bae.add_gaussian(image, 0.4), 0.4)
# Alternative input: load a pre-generated noisy image instead of synthesising one.
# noisy_image = np.array(Image.open('/Users/Frost/Desktop/ai_sample_data/xihu_random_noise.png')).astype(np.float32)
# noisy_image /= 255.0
plt.imshow(noisy_image)
plt.title('Noisy input')
# Clip before the uint8 cast: values outside [0, 1] would otherwise overflow/wrap.
# Passing a path (not an open handle) lets PIL close the file once it is written.
NOISY_OUTPUT_PATH = Path.home() / 'Desktop' / 'noisy.jpg'
Image.fromarray((np.clip(noisy_image, 0.0, 1.0) * 255).astype('uint8')).save(NOISY_OUTPUT_PATH)

In [None]:
# Median-filter baseline; 15 is passed straight to noise_bae (presumably the window size — TODO confirm).
denoised_image = noise_bae.cleaner.clean_image_median(noisy_image, 15)
evaluate(image, denoised_image)
plt.imshow(denoised_image)
plt.title('Median filter');

In [None]:
# NL-means cleaner from noise_bae (presumably non-local means), run with its default settings.
denoised_image = noise_bae.cleaner.clean_image_nlmeans(noisy_image)
evaluate(image, denoised_image)
plt.imshow(denoised_image)
plt.title('NL-means');

In [None]:
# TV cleaner from noise_bae (presumably total-variation denoising); `weight` and `addition`
# are forwarded as-is — their exact meaning lives in noise_bae.
denoised_image = noise_bae.cleaner.clean_image_tv(noisy_image, weight=0.4, addition=0.1)
evaluate(image, denoised_image)
plt.imshow(denoised_image)
plt.title('Total variation (weight=0.4, addition=0.1)');

In [None]:
# Learned denoiser: runs the Keras model loaded earlier through noise_bae's ResNet cleaner.
denoised_image = noise_bae.cleaner.clean_image_resnet(noisy_image, denoising_model)
evaluate(image, denoised_image)
plt.imshow(denoised_image)
plt.title('ResNet denoiser');