In [2]:
import time
import torch
import argparse
import torch.nn as nn
import torch.nn.functional as F
#import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from datasets.pretrain_datasets import TrainData, ValData, ValData_AECR, edge_compute
from models.GCA import GCANet
from models.FFA import FFANet
from models.MSBDN import MSBDNNet
from models.UNet import UNet
from models.Vanilla_AECRNet import Dehaze as AECRNet
from models.AECRNet import Dehaze as PSD_AECRNet
from utils import to_psnr, print_log, validation, adjust_learning_rate, ssim
from torchvision.models import vgg16
import math
from pdb import set_trace as bp
import numpy as np
import os
from PIL import Image
import matplotlib.pyplot as plt
from skimage.transform import resize

#from perceptual import LossNetwork
def lr_schedule_cosdecay(t, T, init_lr=1e-4):
    """Return the cosine-decayed learning rate for step ``t`` of ``T``.

    The rate follows half a cosine period: ``init_lr`` at ``t == 0`` and
    0 at ``t == T``.
    """
    angle = t * math.pi / T
    cosine_factor = 0.5 * (1 + math.cos(angle))
    return cosine_factor * init_lr

def tensor2numpy(tensor):
    """Convert a channels-first tensor into a channels-last NumPy array on the CPU.

    (N, C, H, W) -> (N, H, W, C). The tensor is detached first, so tensors
    that require grad can be converted too.
    """
    channels_last = tensor.movedim(1, 3)
    return channels_last.detach().cpu().numpy()

# load parameters
def load_params(model, filename, map_location=None, strict=True):
    """Load a saved state_dict from ``filename`` into ``model`` in place.

    Args:
        model: module whose parameters are overwritten.
        filename: path of a checkpoint produced by ``torch.save(state_dict)``.
        map_location: forwarded to ``torch.load``; pass e.g. ``"cpu"`` or a
            ``torch.device`` to load a GPU-saved checkpoint on another device.
            ``None`` (default) keeps the previous behaviour.
        strict: forwarded to ``model.load_state_dict``.
    """
    params = torch.load(filename, map_location=map_location)
    model.load_state_dict(params, strict=strict)

def _pad_to_multiple(tensor, multiple=16):
    """Bilinearly resize a 4-D (N, C, H, W) tensor so H and W are multiples of ``multiple``.

    Each spatial side is rounded *up* to the next multiple. A side that is
    already aligned keeps its size. If both are aligned, the tensor is
    returned unchanged, with no resampling.
    """
    height, width = tensor.shape[2], tensor.shape[3]
    target_h = -(-height // multiple) * multiple
    target_w = -(-width // multiple) * multiple
    if (target_h, target_w) == (height, width):
        return tensor
    # F.interpolate replaces the deprecated F.upsample; align_corners=False is
    # the value F.upsample used implicitly for mode='bilinear'.
    return F.interpolate(tensor, size=[target_h, target_w], mode='bilinear', align_corners=False)


def _load_rgb_image(path, device):
    """Read an image file as a (1, 3, H, W) float32 tensor scaled to [0, 1] on ``device``."""
    # The context manager closes the file handle; convert('RGB') forces the
    # pixel data to load before closing and normalises grayscale/RGBA/palette
    # files to 3 channels.
    with Image.open(path) as pil_img:
        pixels = np.asarray(pil_img.convert('RGB'), dtype=np.float32) / 255
    # HWC -> CHW, then add the batch axis.
    return torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0).to(device)


def get_validation(net, net_name, val_data_loader, device, dataset):
    """Run ``net`` over a validation loader and return the mean PSNR and SSIM.

    Args:
        net: model to evaluate. This function does not call ``net.eval()``;
            the callers in this file do so beforehand.
        net_name: selects the input/call convention.
            'UNet'    -> the hazy image is re-read from ``dataset.haze_dir`` and
                         passed as the only input.
            'AECRNet' -> ``net(haze, haze_A, True)``; the second returned item
                         is scored.
            otherwise -> ``net(gca_haze, 0, True, False) / 255``.
        val_data_loader: yields 6-tuples
            ``(haze, haze_A, gt, image_name, gca_haze, input_img)``.
            ``input_img`` is not used.
        device: device the inputs and ground truth are moved to.
        dataset: object exposing ``haze_dir`` (directory of the hazy files).

    Returns:
        ``(avg_psnr, avg_ssim)``, averaged over every value that
        ``to_psnr``/``ssim`` return for the batches.

    Raises:
        ValueError: if the loader produced no samples.

    NOTE(review): when an input is not 16-aligned, the ground truth is also
    bilinearly resized, so scores are computed against resampled GT rather
    than the original image.
    """
    psnr_list = []
    ssim_list = []
    haze_dir = dataset.haze_dir

    with torch.no_grad():
        for haze, haze_A, gt, image_name, gca_haze, _input_img in val_data_loader:
            # Inputs are resized to multiples of 16 (presumably a network
            # stride requirement). gt gets the same treatment so its shape
            # matches the prediction.
            gt = _pad_to_multiple(gt.to(device))

            if net_name == 'UNet':
                # Non-PSD input: the raw hazy file, not the loader's tensors.
                # NOTE(review): only image_name[0] is read, so this branch
                # assumes batch size 1.
                img = _load_rgb_image(os.path.join(haze_dir, image_name[0]), device)
                dehaze = net(_pad_to_multiple(img))
            elif net_name == 'AECRNet':
                haze = _pad_to_multiple(haze.to(device))
                haze_A = _pad_to_multiple(haze_A.to(device))
                # Only the second output is scored; the remaining outputs are ignored.
                _, dehaze, *_ = net(haze, haze_A, True)
            else:
                gca_haze = _pad_to_multiple(gca_haze.to(device))
                # Output divided by 255, presumably because this branch's
                # network predicts in the [0, 255] range.
                dehaze = net(gca_haze, 0, True, False) / 255

            # --- Accumulate per-image PSNR / SSIM --- #
            psnr_list.extend(to_psnr(dehaze, gt))
            ssim_list.extend(ssim(dehaze, gt))

    if not psnr_list or not ssim_list:
        raise ValueError('get_validation: val_data_loader yielded no samples')
    avr_psnr = sum(psnr_list) / len(psnr_list)
    avr_ssim = sum(ssim_list) / len(ssim_list)
    return avr_psnr, avr_ssim

if __name__ == '__main__':
    torch.cuda.empty_cache()
    lr = 1e-4  # not used in this cell; kept in case later notebook cells read it
    device_ids = list(range(torch.cuda.device_count()))
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    val_batch_size = 1
    category = 'outdoor'  # not used in this cell; kept in case later notebook cells read it
    torch.backends.cudnn.benchmark = True

    val_data_dir = 'images/SOTS/outdoor/'
    val_dataset = ValData_AECR(val_data_dir)
    val_data_loader = DataLoader(val_dataset, batch_size=val_batch_size, shuffle=False, num_workers=1)
    print("DATALOADER DONE!")

    def load_data_parallel(model, ckpt_path):
        """Move ``model`` to ``device``, wrap it in DataParallel, load ``ckpt_path``, set eval mode."""
        # nn.DataParallel refuses to run a module whose parameters are not on
        # device_ids[0] (cuda:0 here), so the move must happen before wrapping.
        # PSD-FFANet and PSD-MSBDN previously skipped it and failed on GPU.
        wrapped = nn.DataParallel(model.to(device), device_ids=device_ids)
        # map_location lets GPU-saved checkpoints load on a CPU-only host.
        wrapped.load_state_dict(torch.load(ckpt_path, map_location=device))
        wrapped.eval()
        return wrapped

    def evaluate(label, model, net_name):
        """Validate ``model`` on the SOTS loader, print a summary line, return (psnr, ssim)."""
        psnr, ssim_score = get_validation(model, net_name, val_data_loader, device, val_dataset)
        print('[{0}] Val_PSNR: {1:.2f}, Val_SSIM: {2:.4f}'.format(label, psnr, ssim_score))
        return psnr, ssim_score

    ### AECRNet Output #######
    net = AECRNet(3, 3)
    net.to(device)
    load_params(net, "pre-trained/AECRNET.pth")
    net.eval()
    # Routed through get_validation's 'UNet' branch, i.e. fed the raw hazy
    # image instead of the PSD inputs.
    val_psnr, val_ssim = evaluate('AECRNET', net, 'UNet')
    ##########################

    ### PSD-AECRNet Output ###
    num_epochs = 20
    for epoch in range(num_epochs):
        net = load_data_parallel(PSD_AECRNet(3, 3), 'pre-trained/PSD-AECRNET-{}'.format(epoch))
        val_psnr, val_ssim = evaluate('PSD-AECRNET-{}'.format(epoch), net, 'AECRNet')
    ##########################

    ### UNet Output ##########
    net = UNet()
    net.to(device)
    load_params(net, "pre-trained/UNET.pth")
    net.eval()
    val_psnr, val_ssim = evaluate('UNET', net, 'UNet')
    ##########################

    ### PSD-FFANet Output ####
    net = load_data_parallel(FFANet(3, 19), 'pre-trained/PSD-FFANET')
    val_psnr, val_ssim = evaluate('PSD-FFANET', net, 'AECRNet')
    ##########################

    ### PSD-GCANet Output ####
    net = load_data_parallel(GCANet(in_c=4, out_c=3, only_residual=True), 'pre-trained/PSD-GCANET')
    val_psnr, val_ssim = evaluate('PSD-GCANET', net, 'GCANet')
    ##########################

    ### PSD-MSBDNNet Output ##
    # NOTE(review): this checkpoint is named "PSB-MSBDN" while the others use
    # the "PSD-" prefix; confirm the filename on disk.
    net = load_data_parallel(MSBDNNet(), 'pre-trained/PSB-MSBDN')
    val_psnr, val_ssim = evaluate('PSD-MSBDNNET', net, 'AECRNet')
    ##########################

DATALOADER DONE!




[AECRNET] Val_PSNR: 25.28, Val_SSIM: 0.9350
[PSD-AECRNET-0] Val_PSNR: 24.18, Val_SSIM: 0.9246
[PSD-AECRNET-1] Val_PSNR: 25.11, Val_SSIM: 0.9370
[PSD-AECRNET-2] Val_PSNR: 26.68, Val_SSIM: 0.9499
[PSD-AECRNET-3] Val_PSNR: 26.34, Val_SSIM: 0.9498
[PSD-AECRNET-4] Val_PSNR: 25.08, Val_SSIM: 0.9407
[PSD-AECRNET-5] Val_PSNR: 26.65, Val_SSIM: 0.9531
[PSD-AECRNET-6] Val_PSNR: 26.97, Val_SSIM: 0.9563
[PSD-AECRNET-7] Val_PSNR: 26.32, Val_SSIM: 0.9468
[PSD-AECRNET-8] Val_PSNR: 26.59, Val_SSIM: 0.9517
[PSD-AECRNET-9] Val_PSNR: 26.40, Val_SSIM: 0.9421
[PSD-AECRNET-10] Val_PSNR: 27.03, Val_SSIM: 0.9568
[PSD-AECRNET-11] Val_PSNR: 26.29, Val_SSIM: 0.9491
[PSD-AECRNET-12] Val_PSNR: 26.46, Val_SSIM: 0.9493
[PSD-AECRNET-13] Val_PSNR: 27.31, Val_SSIM: 0.9562
[PSD-AECRNET-14] Val_PSNR: 26.42, Val_SSIM: 0.9458
[PSD-AECRNET-15] Val_PSNR: 27.14, Val_SSIM: 0.9519
[PSD-AECRNET-16] Val_PSNR: 26.92, Val_SSIM: 0.9504
[PSD-AECRNET-17] Val_PSNR: 26.57, Val_SSIM: 0.9485
[PSD-AECRNET-18] Val_PSNR: 26.78, Val_SSIM: 0.95