In [None]:
import time
import argparse
import os

# Choose which GPU index this process is allowed to see. The assignment is
# placed ahead of the `import torch` that follows.
# NOTE(review): the hard-coded index "1" presumably matches the author's
# machine; confirm it is valid before running elsewhere.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import torch
from skimage.metrics import structural_similarity

import matplotlib.pyplot as plt
import utils
import dataset

def str2bool(v):
    """Convert a command-line style truth value to a ``bool``.

    Intended for use as an ``argparse`` ``type=`` callable.

    Args:
        v: Either an actual ``bool`` (returned unchanged) or a string.
            Strings are matched case-insensitively against
            ``yes/true/t/y/1`` and ``no/false/f/n/0``.

    Returns:
        ``True`` or ``False`` according to the value given.

    Raises:
        argparse.ArgumentTypeError: If a string matches neither set.
    """
    # A real bool has no .lower(); pass it through so the function is safe
    # to call directly with already-parsed values.
    if isinstance(v, bool):
        return v

    lowered = v.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Unsupported value encountered.')

if __name__ == "__main__":
    # Evaluation script: loads a pre-trained generator, runs it over the test
    # set, saves the input / prediction / ground-truth images, and prints the
    # average PSNR, SSIM and per-image inference time.

    # ----------------------------------------
    #        Initialize the parameters
    # ----------------------------------------
    parser = argparse.ArgumentParser()
    #GPU parameters
    # Parsed with str2bool like the other boolean flags; without a type, any
    # non-empty string (including 'False') would be treated as truthy.
    parser.add_argument('--no_gpu', type = str2bool, default = False, help = 'True for CPU')
    # Saving, and loading parameters

    #models_TinyDerainGAN_raindrop
    parser.add_argument('--save_name', type = str, default = '../results/models_SAUnet_SASkipnet_SAPerceptuals8_raindrop', help = 'save the generated with certain epoch')
    parser.add_argument('--load_gname', type = str, default = '../models/models_SAUnet_SASkipnet_SAPerceptuals8_raindrop/KPN_rainy_image_epoch200_bs8_generator.pth', help = 'load the pre-trained model with certain epoch')

    parser.add_argument('--baseroot', type = str, default = '../rainy_image_dataset/raindrop/test', help = 'images baseroot')

    parser.add_argument('--test_batch_size', type = int, default = 1, help = 'size of the batches')
    parser.add_argument('--num_workers', type = int, default = 1, help = 'number of workers')
    # Initialization parameters
    parser.add_argument('--init_type', type = str, default = 'xavier', help = 'initialization type of generator')
    parser.add_argument('--init_gain', type = float, default = 0.02, help = 'initialization gain of generator')
    # Dataset parameters
    parser.add_argument('--crop', type = str2bool, default = False, help = 'whether to crop input images')
    parser.add_argument('--crop_size', type = int, default = 512, help = 'single patch size')
    parser.add_argument('--resize', type = str2bool, default = True, help = 'whether to resize input images')
    parser.add_argument('--scale_size', type = int, default = 1024, help = 'single patch size')
    parser.add_argument('--geometry_aug', type = str2bool, default = False, help = 'geometry augmentation (scaling)')
    parser.add_argument('--angle_aug', type = str2bool, default = False, help = 'geometry augmentation (rotation, flipping)')
    parser.add_argument('--scale_min', type = float, default = 1, help = 'min scaling factor')
    parser.add_argument('--scale_max', type = float, default = 1, help = 'max scaling factor')
    parser.add_argument('--add_noise', type = str2bool, default = False, help = 'whether to add noise to input images')
    parser.add_argument('--mu', type = int, default = 0, help = 'Gaussian noise mean')
    parser.add_argument('--sigma', type = int, default = 30, help = 'Gaussian noise variance: 30 | 50 | 70')
    # An explicit empty argument list is parsed, so only the defaults above
    # take effect (sys.argv is ignored).
    opt = parser.parse_args(args=[])
    print(opt)

    # ----------------------------------------
    #                   Test
    # ----------------------------------------
    use_gpu = not opt.no_gpu

    # Initialize
    # NOTE(review): the generator is never switched to eval mode here;
    # confirm whether utils.create_generator already does so.
    generator = utils.create_generator(opt)
    if use_gpu:
        generator = generator.cuda()

    test_dataset = dataset.DenoisingValDataset(opt)
    # Pinned host memory is only requested when batches are moved to a GPU.
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size = opt.test_batch_size, shuffle = False, num_workers = opt.num_workers, pin_memory = use_gpu)
    sample_folder = opt.save_name
    utils.check_path(sample_folder)

    # At most this many batches are evaluated.
    max_eval_samples = 100
    name_list = ['in', 'pred', 'gt']
    psnr_sum, ssim_sum, time_sum, eval_cnt = 0, 0, 0, 0

    # forward
    for i, (true_input, true_target, height_origin, width_origin) in enumerate(test_loader):
        if i == max_eval_samples:
            break

        # Count start time (the measured span covers the device transfer
        # and the forward pass)
        start_time = time.time()

        # To device
        if use_gpu:
            true_input = true_input.cuda()
            true_target = true_target.cuda()

        # Forward propagation
        with torch.no_grad():
            fake_target, _ = generator(true_input, true_input)

        if use_gpu:
            # CUDA kernels are launched asynchronously; wait for them to
            # finish so the measured time covers the whole forward pass.
            torch.cuda.synchronize()
        time_sum += time.time() - start_time

        # Save
        img_list = [true_input, fake_target, true_target]
        sample_name = '%d' % (i + 1)
        # .item() needs single-element values, so this loop presumably
        # relies on test_batch_size == 1.
        height_origin = height_origin.item()
        width_origin = width_origin.item()

        utils.save_sample_png(sample_folder = sample_folder, sample_name = sample_name, img_list = img_list, name_list = name_list, pixel_max_cnt = 255, height = height_origin, width = width_origin)

        # Evaluation
        img_pred_recover = utils.recover_process(fake_target, height = height_origin, width = width_origin)
        img_gt_recover = utils.recover_process(true_target, height = height_origin, width = width_origin)
        psnr_sum += utils.psnr(img_pred_recover, img_gt_recover)
        ssim_sum += structural_similarity(img_gt_recover, img_pred_recover, channel_axis = 2, data_range = 255)

        eval_cnt += 1

    # Fail with a clear message instead of a bare ZeroDivisionError when the
    # loader produced no batches.
    if eval_cnt == 0:
        raise RuntimeError(f"No test samples were evaluated; check --baseroot ({opt.baseroot!r}).")

    psnr_ave = psnr_sum / eval_cnt
    ssim_ave = ssim_sum / eval_cnt
    time_ave = time_sum / eval_cnt

    print(f"psnr: {psnr_ave}, ssim: {ssim_ave}, time_ave: {time_ave}")


Namespace(no_gpu=False, save_name='../results/models_SAUnet_SASkipnet_SAPerceptuals8_raindrop', load_gname='../models/models_SAUnet_SASkipnet_SAPerceptuals8_raindrop/KPN_rainy_image_epoch200_bs8_generator.pth', baseroot='../rainy_image_dataset/raindrop/test', test_batch_size=1, num_workers=1, init_type='xavier', init_gain=0.02, crop=False, crop_size=512, resize=True, scale_size=1024, geometry_aug=False, angle_aug=False, scale_min=1, scale_max=1, add_noise=False, mu=0, sigma=30)
Generator are loaded!
psnr: 24.798089227143954, ssim: 0.9010393023490906, time_ave: 0.029586889743804932
