In [None]:
import config as CONFIG
import os
import sys
sys.path.append(CONFIG.GLOBAL_MODEL_PATH)
from models.mwunet import mwunet

from csbdeep.io import load_training_data
from csbdeep.models import IsotropicCARE
from csbdeep.utils import plot_some

import tensorflow as tf
from matplotlib import pyplot as plt
import numpy as np

In [None]:
# Load the CSBDeep training archive; validation_split=0.1 holds out 10% of the samples
# as (X_val, Y_val). `axes` returned here describes the layout of the loaded arrays.
(X, Y), (X_val, Y_val), axes = load_training_data(
    CONFIG.TRAIN_DATASET_PATH,
    validation_split=0.1,
    axes='SCYX',
    verbose=True,
)

In [None]:
# Forward network FX (project-local mwunet): maps x to an estimate of y. Its input/output
# shapes are taken from the training arrays with the leading sample axis dropped.
FX = mwunet(
    input_shape=X.shape[1:],
    output_shape=Y.shape[1:],
    conv_kernel_size=5,
    n_filters_per_scale=[32, 64, 128],
)
# Load the saved FX model from CHECKPOINT_PATH/FX_MODEL_PATH.
# NOTE(review): this goes through `FX.model`, whereas other cells call `FX.keras_model`
# and `FX.predict`; confirm which attribute mwunet exposes and what `load_model` does.
fx_checkpoint = os.path.join(CONFIG.CHECKPOINT_PATH, CONFIG.FX_MODEL_PATH)
FX.model.load_model(fx_checkpoint)

In [None]:
# Backward network DX (y -> x). Passing config=None makes CSBDeep read the saved model
# folder basedir/name (here CHECKPOINT_PATH/DX) instead of building a fresh config.
DX = IsotropicCARE(
    config=None,
    name='DX',
    basedir=CONFIG.CHECKPOINT_PATH,
)

In [None]:
# Take a single validation pair; slicing with [:1] (not [0]) keeps the leading batch axis.
x, y = X_val[:1], Y_val[:1]
print(x.shape, y.shape)

In [None]:
# Run both directions on the single pair:
#   FX: x -> y_est  (forward estimate of y)
#   DX: y -> x_back (reconstruction of x from y)
y_est = FX.keras_model.predict(x)
x_back = DX.keras_model.predict(y)
# The part of x that DX does not reproduce; resampled as noise in the bootstrap cell.
residual = np.subtract(x, x_back)

In [None]:
# Five panels side by side: input, target, forward estimate, backward reconstruction, residual.
panels = np.squeeze([x, y, y_est, x_back, residual], axis=1)  # drop the size-1 batch axis
plt.figure(figsize=(16, 4))
plt.suptitle('x,  y,  y_est, x_back, residual')
plot_some(panels)
plt.show()

In [None]:
# Residual bootstrap: build N synthetic inputs x_i = x_back + (spatially shuffled residual).
# Permuting pixel positions keeps the residual's values but breaks their spatial layout.
N = 1000  # number of bootstrap samples
BOOTSTRAP_SEED = 0  # fixed so Restart & Run All reproduces the same x_i
rng = np.random.default_rng(BOOTSTRAP_SEED)

nbatch, height, width, channel = residual.shape
# One row per pixel position: (height*width, nbatch, channel). Loop-invariant, so built once.
residual_pixels = np.transpose(residual, (1, 2, 0, 3)).reshape(height * width, nbatch, channel)

x_i = []
for _ in range(N):
    # np.random.shuffle works in place and returns None, so chaining .reshape on it fails;
    # Generator.permutation returns a shuffled copy with rows (pixel positions) permuted.
    shuffled = rng.permutation(residual_pixels)
    noise = np.transpose(shuffled.reshape(height, width, nbatch, channel), (2, 0, 1, 3))
    # Drop the batch axis; assumes a single validation sample (np.squeeze raises if nbatch != 1).
    x_i.append(np.squeeze(x_back + noise, axis=0))
x_i = np.asarray(x_i)

In [None]:
# Push every bootstrapped input through the forward model.
# NOTE(review): other cells call FX.keras_model.predict; confirm FX.predict is the intended API here.
y_i = FX.predict(x_i)

# Compare the first few bootstrap samples with the unperturbed estimate y_est.
n_show = min(5, len(x_i))  # never request more rows than there are samples
plt.figure(figsize=(20, 13))
plt.suptitle('row1: x_i;  row2: y_i;  row3: y_est')
# np.repeat gives y_est the same (n_show, H, W, C) layout as the other two rows,
# instead of a list of (1, H, W, C) arrays with an extra singleton axis.
plot_some(x_i[:n_show], y_i[:n_show], np.repeat(y_est, n_show, axis=0))
plt.show()

In [None]:
def SSIM(y_true, y_pred, max_val=1.0):
    """Mean structural similarity between two image batches.

    Args:
        y_true, y_pred: image batches accepted by ``tf.image.ssim``
            (last three axes: height, width, channels).
        max_val: dynamic range of the images. The default 1.0 presumes
            intensities normalised to [0, 1] -- TODO confirm for this data.

    Returns:
        Scalar tensor: per-image SSIM averaged over the batch.
    """
    return tf.reduce_mean(tf.image.ssim(y_true, y_pred, max_val=max_val))

def PSNR(y_true, y_pred, max_val=1.0):
    """Mean peak signal-to-noise ratio (dB) between two image batches.

    Args:
        y_true, y_pred: image batches accepted by ``tf.image.psnr``
            (last three axes: height, width, channels).
        max_val: dynamic range of the images. The default 1.0 presumes
            intensities normalised to [0, 1] -- TODO confirm for this data.

    Returns:
        Scalar tensor: per-image PSNR averaged over the batch.
    """
    return tf.reduce_mean(tf.image.psnr(y_true, y_pred, max_val=max_val))

In [None]:
# Point-estimate quality: ground truth y vs the forward prediction y_est.
SSIM_ygt_yest = SSIM(y, y_est)
PSNR_ygt_yest = PSNR(y, y_est)

# Agreement between y_est and each bootstrap prediction y_i[k], averaged over the
# predictions actually produced (len(y_i)) rather than the separately defined N.
ssim_scores, psnr_scores = [], []
for k in range(len(y_i)):
    y_k = y_i[k:k + 1]  # slice keeps the batch axis expected by tf.image
    ssim_scores.append(SSIM(y_est, y_k))
    psnr_scores.append(PSNR(y_est, y_k))
SSIM_yest_yi = tf.reduce_mean(tf.stack(ssim_scores))
PSNR_yest_yi = tf.reduce_mean(tf.stack(psnr_scores))

# Show the four summary numbers as this cell's output.
{
    'SSIM(y, y_est)': float(SSIM_ygt_yest),
    'PSNR(y, y_est)': float(PSNR_ygt_yest),
    'mean SSIM(y_est, y_i)': float(SSIM_yest_yi),
    'mean PSNR(y_est, y_i)': float(PSNR_yest_yi),
}