# Testing PDHM-Net

This notebook tests the performance of the different PDHM-Net training runs.
It computes the SSIM and PSNR scores of denoised simulated images, and denoises experimental images using either evenly or randomly distributed patches.

In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

from tensorflow.keras.layers import Input
from tensorflow.keras.optimizers import Adam

### Enter the location of your aquaDenoising folder containing the general_functions folder
sys.path.append("path/to/aquaDenoising/")

from general_functions.neural_networks.norm_patch import robustnorm, evenpatch, evenreconstruct, randompatch, randomreconstruct
from general_functions.neural_networks.architectures import model_pdhmnet
from general_functions.neural_networks.metrics import ssimsmape_loss, ssim, psnr

## Load Neural Network Model

In [None]:
### Choose Training

TRAINING_NB = "01"
PATCH_SIZE = 128

### - - - - - - - - 

# Single-channel square input patches of side PATCH_SIZE.
input_img = Input(shape=(PATCH_SIZE, PATCH_SIZE, 1), name='img')

# Build PDHM-Net (five stages of 64 filters, light dropout).
model = model_pdhmnet(input_img, 1, [64] * 5, dropout=0.05)

# Keras matches this loss dict against the model output named "output_denoised".
loss = {"output_denoised": ssimsmape_loss}

# NOTE(review): "accuracy" carries little meaning for continuous image outputs;
# it only affects reported metrics, not inference.
model.compile(
    optimizer=Adam(),
    loss=loss,
    metrics=["accuracy", ssim, psnr],
)

# Presumably the checkpoint with the minimum validation loss (per its file name).
model.load_weights(
    f"DATA/PDHMNet/saved_models/Training{TRAINING_NB}/model_vloss_min.h5"
)

In [None]:
# ## To generate PlugIM files: export the loaded Keras `model` to ONNX (uncomment to run).
# ## Requires the optional `tf2onnx` and `onnx` packages, which are not imported elsewhere.
# import tf2onnx
# import onnx

# onnx_model, _ = tf2onnx.convert.from_keras(model)
# onnx.save(onnx_model, f"DATA/PDHMNet/saved_models/Training{TRAINING_NB}/model_vloss_min.onnx")

# # NOTE(review): the lines below write the serialized model to the same path again;
# # they appear to duplicate `onnx.save` above and can likely be dropped.
# file = open(f"DATA/PDHMNet/saved_models/Training{TRAINING_NB}/model_vloss_min.onnx", "wb")
# file.write(onnx_model.SerializeToString())
# file.close()

## Denoising images

In [None]:
%%time

### Denoising Parameters

BATCH_SIZE = 8

# Evenly Distributed Patches 
STEP = 16
BORDER = 30

# Randomly Distributed Patches
NB_PATCH = 1024*4

### - - - - - - - - - - -

for IMG_NB in range(1,100)
    
    ### Simulated images

    # loc_img = f"location/simulated/images_{IMG_NB}"
    
    ### Experimental images
    
    loc_img = f"location/experimental/images_{IMG_NB}"


    try:
        img_in = Image.open(loc_img)
    except FileNotFoundError:
        break        
    img_in = np.array(img_in, dtype="float32")

    img_norm = robustnorm(img_in, 0.01)
    
    # patches, x_patches, y_patches = randompatch(img_norm, NB_PATCH, patch_size=PATCH_SIZE, seed=0)
    patches = evenpatch(img_norm, patch_size=PATCH_SIZE, step=STEP)
    
    pred_patch = model.predict(patches, batch_size=BATCH_SIZE)
    
    # img_out = randomreconstruct(img_in, pred_patch, x_patches, y_patches, border=BORDER)
    img_out = evenreconstruct(img_in, pred_patch, step=STEP, border=BORDER)


    tif_img = Image.fromarray(img_out)
    tif_img.save(f"location/denoised/images_{IMG_NB}.tif")

    tif_crop = Image.fromarray(img_out[BORDER:-BORDER, BORDER:-BORDER])
    tif_crop.save(f"location/denoised/cropped/images_{IMG_NB}.tif")

In [None]:
IMG_ID = 0

### Experimental Images
loc_img = f"location/experimental/images_{IMG_ID}"

# Image.open raises FileNotFoundError if this index does not exist.
# There is no enclosing loop here, so `break` is not an option.
with Image.open(loc_img) as pil_img:
    img_in = np.array(pil_img, dtype="float32")

### Denoised Images

# The denoising loop writes TIFFs via PIL, so read them back with PIL.
# np.load only understands .npy/.npz/pickle files.
with Image.open(f"location/denoised/images_{IMG_ID}.tif") as pil_img:
    img_denoised = np.array(pil_img, dtype="float32")

### - - - - - - - - 

fig, ax = plt.subplots(1, 2, dpi=300)

# IMG_ID selects the file, not a row: plot the whole arrays.
# Assumes single-channel images; TODO confirm.
ax[0].imshow(img_in)
ax[1].imshow(img_denoised)

ax[0].set_title("Experimental Noisy Image")
ax[1].set_title("PDHM-Net Denoised Image")
plt.show()

## Evaluate the PSNR and SSIM of the denoised simulated images

In [None]:
### Choose Training

# NOTE(review): not used below; the denoised TIFFs come from whichever model
# ran the denoising cell. Kept so the notebook namespace is unchanged.
TRAINING_NB = "01"
BORDER = 30
NB_EVAL_IMG = 3

### - - - - - - - - 

img_list = []
ref_list = []

for IMG_NB in range(NB_EVAL_IMG):
    # The denoising cell saves its outputs with a ".tif" extension.
    with Image.open(f"location/denoised/images_{IMG_NB}.tif") as img:
        img_den = np.array(img, dtype="float32")
    h, w = img_den.shape[:2]
    # Explicit end indices keep the crop non-empty when BORDER == 0.
    img_list.append(img_den[BORDER:h - BORDER, BORDER:w - BORDER])

    with Image.open(f"location/simulated/noiseless/images_{IMG_NB}") as ref:
        ref_norm = robustnorm(np.array(ref, dtype="float32"), 0)
    h, w = ref_norm.shape[:2]
    ref_list.append(ref_norm[BORDER:h - BORDER, BORDER:w - BORDER])

# np.stack fails loudly on mismatched image sizes instead of building a ragged array.
# The trailing axis adds a channel dimension: (N, H, W, 1).
img_arr = np.expand_dims(np.stack(img_list), axis=-1)
ref_arr = np.expand_dims(np.stack(ref_list), axis=-1)

assert img_arr.shape == ref_arr.shape, (
    f"denoised {img_arr.shape} and reference {ref_arr.shape} stacks differ in shape"
)

In [None]:
IMG_ID = 0

# Side-by-side view of the cropped denoised image and its noiseless reference.
panels = (
    (img_list[IMG_ID], "PDHM-Net Denoised Image"),
    (ref_list[IMG_ID], "Simulated Noiseless Image"),
)

fig, ax = plt.subplots(1, 2, dpi=300)

for axis, (image, title) in zip(ax, panels):
    axis.imshow(image)
    axis.set_title(title)

# Render the figure without echoing the last set_title's Text repr.
plt.show()