In [5]:
import os

# Disable TensorFlow C++ log messages (level '3' filters INFO, WARNING and ERROR)
# for cleaner notebook output. This must be set BEFORE anything imports
# TensorFlow (including utils.utils below): the native runtime reads the
# variable when it sets up logging, so assigning it after the import is not
# reliable.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import time
import datetime
import numpy as np
import tensorflow as tf
from utils.utils import *


%reload_ext autoreload
%autoreload 2

In [6]:
def test(logger, test_data, mask3d_batch, mask_s, model, verbose = False):
    psnr_list, ssim_list = [], []
    test_gt = test_data
    test_PhiTy = gen_meas_tf(test_gt, mask3d_batch, mask_s, is_training = False)
    begin = time.time()
    
    model_out = model(test_PhiTy)
    end = time.time()
    for k in range(test_gt.shape[0]):
        psnr_val = tf_psnr(model_out[k,:,:,:], test_gt[k,:,:,:])
        ssim_val = tf_ssim(model_out[k,:,:,:], test_gt[k,:,:,:])
        psnr_list.append(psnr_val)
        ssim_list.append(ssim_val)
        if verbose:
            print('psnr=', psnr_val, 'ssim=', ssim_val)
    pred = np.transpose(model_out, (0, 2, 3, 1)).astype(np.float32)
    truth = np.transpose(test_gt, (0, 2, 3, 1)).astype(np.float32)
    psnr_mean = np.mean(np.asarray(psnr_list))
    ssim_mean = np.mean(np.asarray(ssim_list))
    print('===> testing psnr = {:.2f}, ssim = {:.3f}, time: {:.2f}'.format(psnr_mean, ssim_mean, (end - begin)))
    return (pred, truth, psnr_list, ssim_list, psnr_mean, ssim_mean)

# Iterative test for all masks and all models

In [7]:
# some global setting
test_path = "./Data/testing/simu/"
mask_list = ['mask1.mat', 'mask2.mat', 'mask3.mat', 'mask4.mat']
model_list = ['v3_en_mask1', 'v3_mask1', 'v3_mix', 'v3_en_mix']
batch_size = 1
patch_size = 256
logger = None  # test() accepts a logger argument but does not use it


def _is_energy_normalized(model_name):
    """Return True if ``model_name`` has an ``en`` token, e.g. ``v3_en_mix``.

    Matching on the ``_``-separated tokens replaces the fixed slice
    ``model_name[3:5] == 'en'``, which only works when the version prefix is
    exactly two characters (it fails for e.g. ``v10_en_mix``).
    """
    return 'en' in model_name.split('_')


# load test data
print("="*10, "Loading test data:", "="*10)
test_data = LoadTest(test_path, patch_size)

# (psnr_mean, ssim_mean) for every (model_name, mask_name) pair, kept so the
# numbers can be tabulated later without re-running the evaluation. Only the
# means are stored: keeping pred/truth for every pair would retain a full
# float32 copy of the test set per pair.
results = {}

# iterative test
for model_name in model_list:
    # load model
    model_path = os.path.join("./models", model_name)
    model = tf.keras.models.load_model(model_path)
    print("="*10, "Processing model:"+model_name, "="*10)

    # energy-normalization models: the flag is forwarded to generate_masks
    energy = _is_energy_normalized(model_name)

    for mask_name in mask_list:
        # load mask
        mask_path = os.path.join("./Data", mask_name)
        mask3d_batch, mask_s = generate_masks(mask_path, batch_size, energy)
        print("Result for mask:", mask_name)

        # test
        *_, psnr_mean, ssim_mean = test(logger, test_data, mask3d_batch, mask_s, model)
        results[(model_name, mask_name)] = (psnr_mean, ssim_mean)

0 (256, 256, 28) 1.0 0.0
1 (256, 256, 28) 1.0 0.0
2 (256, 256, 28) 1.0 0.0
3 (256, 256, 28) 1.0 0.0
4 (256, 256, 28) 1.0 0.0
5 (256, 256, 28) 1.0 0.0
6 (256, 256, 28) 1.0 0.0
7 (256, 256, 28) 1.0 0.0
8 (256, 256, 28) 1.0 0.0
9 (256, 256, 28) 1.0 0.0
Result for mask: mask1.mat
===> testing psnr = 30.55, ssim = 0.878, time: 0.18
Result for mask: mask2.mat
===> testing psnr = 28.50, ssim = 0.824, time: 0.17
Result for mask: mask3.mat
===> testing psnr = 28.48, ssim = 0.826, time: 0.17
Result for mask: mask4.mat
===> testing psnr = 28.59, ssim = 0.826, time: 0.10
Result for mask: mask1.mat
===> testing psnr = 31.20, ssim = 0.890, time: 0.17
Result for mask: mask2.mat
===> testing psnr = 28.36, ssim = 0.823, time: 0.17
Result for mask: mask3.mat
===> testing psnr = 28.06, ssim = 0.821, time: 0.16
Result for mask: mask4.mat
===> testing psnr = 27.77, ssim = 0.815, time: 0.10
Result for mask: mask1.mat
===> testing psnr = 30.44, ssim = 0.880, time: 0.17
Result for mask: mask2.mat
===> testing