In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.utils as vutils
from datasets.pretrain_datasets import TrainData, ValData, TestData, TestData2, TestData_GCA, TestData_FFA, ValData_AECR
from models.GCA import GCANet
from models.FFA import FFANet
from models.MSBDN import MSBDNNet
from models.UNet import UNet
from models.Vanilla_AECRNet import Dehaze as AECRNet
from models.AECRNet import Dehaze as PSD_AECRNet
from utils import to_psnr, print_log, validation, adjust_learning_rate
import numpy as np
import os
from PIL import Image
import matplotlib.pyplot as plt
from skimage.transform import resize
  
def tensor2numpy(tensor):
    """Convert a channels-first tensor to a channels-last numpy array on the CPU.

    For the usual N x C x H x W input the result is N x H x W x C. The tensor is
    detached from the autograd graph before conversion, so it may require grad
    and may live on a GPU.
    """
    # Move the channel axis (dim 1) to position 3.
    channels_last = tensor.detach().movedim(1, 3)
    return channels_last.to("cpu").numpy()

# load paramerters
def load_params(model, filename, map_location="cpu"):
    """Load a saved state dict from ``filename`` into ``model`` in place.

    Args:
        model: module whose parameters and buffers are overwritten.
        filename: path of a checkpoint holding a state dict whose keys match
            ``model`` (assumed to come from ``torch.save(model.state_dict(), ...)``).
        map_location: forwarded to ``torch.load``. It defaults to ``"cpu"`` so that
            checkpoints saved from a GPU can also be read on CPU-only machines.
            ``load_state_dict`` then copies the values onto whatever device
            ``model`` already lives on.
    """
    params = torch.load(filename, map_location=map_location)
    model.load_state_dict(params)

def resize_to_multiple(tensor, multiple=16):
    """Bilinearly resize an N x C x H x W tensor so that H and W are multiples of ``multiple``.

    Each spatial side that is not already a multiple is rounded up to the next
    multiple. Sides that are already aligned keep their size. The tensor is
    returned unchanged (same object) when both sides are already aligned.
    The models below are fed 16-aligned inputs; presumably their internal
    down-/up-sampling needs this (assumption, not visible here).
    """
    height, width = tensor.shape[2], tensor.shape[3]
    if height % multiple == 0 and width % multiple == 0:
        return tensor
    # Ceiling division: an already-aligned side keeps its size.
    target_size = [-(-height // multiple) * multiple, -(-width // multiple) * multiple]
    # F.interpolate replaces the deprecated F.upsample; align_corners=False is
    # the value F.upsample used implicitly for bilinear mode.
    return F.interpolate(tensor, size=target_size, mode='bilinear', align_corners=False)


def load_eval_model(net, filename, device, device_ids=None):
    """Move ``net`` to ``device``, load weights from ``filename`` and switch to eval mode.

    When ``device_ids`` is given, the model is wrapped in ``nn.DataParallel``
    before loading. The PSD checkpoints are loaded through such a wrapper.
    Returns the (possibly wrapped) model.
    """
    net = net.to(device)
    if device_ids is not None:
        # DataParallel requires the wrapped module's parameters and buffers to be
        # on device_ids[0] already, hence the .to(device) above.
        net = nn.DataParallel(net, device_ids=device_ids)
    # map_location lets GPU-saved checkpoints load on a CPU-only machine.
    net.load_state_dict(torch.load(filename, map_location=device))
    net.eval()
    return net


if __name__ == '__main__':

    # Get devices
    device_ids = list(range(torch.cuda.device_count()))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load data
    test_data_dir = 'images/SOTS/outdoor/'
    test_dataset = ValData_AECR(test_data_dir)
    test_data_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=8)  # For FFA and MSBDN
    num_images = 51  # only the first 51 test images are visualised

    # Build and load every model once, instead of once per image.
    aecr_net = load_eval_model(AECRNet(3, 3), "pre-trained/AECRNET.pth", device)
    psd_aecr_net = load_eval_model(PSD_AECRNet(3, 3), 'pre-trained/PSD-AECRNET-13', device, device_ids)
    unet = load_eval_model(UNet(), "pre-trained/UNET.pth", device)
    psd_ffa_net = load_eval_model(FFANet(3, 19), 'pre-trained/PSD-FFANET', device, device_ids)
    psd_gca_net = load_eval_model(GCANet(in_c=4, out_c=3, only_residual=True),
                                  'pre-trained/PSD-GCANET', device, device_ids)
    psd_msbdn_net = load_eval_model(MSBDNNet(), 'pre-trained/PSB-MSBDN', device, device_ids)

    with torch.no_grad():
        for batch_id, val_data in enumerate(test_data_loader):
            if batch_id >= num_images:
                break

            # Get images for PSD input
            haze, haze_A, gt, image_name, gca_haze, input_img = val_data

            # Get image for non-PSD input: H x W x C uint8 -> 1 x C x H x W float32 in [0, 1].
            # convert('RGB') guards against palette / RGBA / grayscale files.
            image_path = os.path.join(test_dataset.haze_dir, image_name[0])
            img = np.array(Image.open(image_path).convert('RGB')).astype(np.float32) / 255
            img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)

            # Resize model inputs to 16-aligned sizes and move them to the device
            # (Tensor.to is not in-place, so the result must be reassigned).
            haze, haze_A, gca_haze, img = (
                resize_to_multiple(t).to(device) for t in (haze, haze_A, gca_haze, img))
            gt = resize_to_multiple(gt)  # only displayed, so it stays where it is

            pred_aecr = aecr_net(img)
            _, pred_psd_aecr, _, _, _ = psd_aecr_net(haze, haze_A, True)
            pred_unet = unet(img)
            _, pred_psd_ffa, _, _, _ = psd_ffa_net(haze, haze_A, True)
            pred_psd_gca = psd_gca_net(gca_haze, 0, True, False)
            _, pred_psd_msbdn, _, _, _ = psd_msbdn_net(haze, haze_A, True)

            # PSD-GCANet output is rounded, clamped to 0..255 and shown as a uint8 image.
            gca_dehaze = pred_psd_gca.float().round().clamp(0, 255)
            gca_image = Image.fromarray(gca_dehaze[0].cpu().numpy().astype(np.uint8).transpose(1, 2, 0))

            panels = [
                ("input", tensor2numpy(input_img)[0]),
                ("ground truth", tensor2numpy(gt)[0]),
                ("AECRNET", tensor2numpy(pred_aecr)[0]),
                ("PSD-AECRNET", tensor2numpy(pred_psd_aecr)[0]),
                ("UNET", tensor2numpy(pred_unet)[0]),
                ("PSD-FFANET", tensor2numpy(pred_psd_ffa)[0]),
                ("PSD-GCANET", gca_image),
                ("PSD-MSBDNNET", tensor2numpy(pred_psd_msbdn)[0]),
            ]

            # Draw figures: one 8x8-inch panel per image, side by side.
            fig = plt.figure(figsize=(8 * len(panels), 8))
            axes = fig.subplots(1, len(panels))
            for axis, (title, image) in zip(axes, panels):
                axis.set_axis_off()
                axis.imshow(image)
                axis.set_title(title)
            plt.show()
            # Release the figure; otherwise one large figure accumulates per image.
            plt.close(fig)

            torch.cuda.empty_cache()
