# In [None]:
import time
import glob
import torch
import torchvision
import numpy as np
from tqdm import tqdm
from PIL import Image
#from basicsr.data.flare7k_dataset import Flare_Image_Loader,RandomGammaCorrection
from basicsr.archs.uformer_arch import Uformer
#from basicsr.archs.unet_arch import U_Net
from basicsr.utils.flare_util import blend_light_source,get_args_from_json,save_args_to_json,mkdir,predict_flare_from_6_channel,predict_flare_from_3_channel
from torch.distributions import Normal
import torchvision.transforms as transforms
import os

from models import *
from collections import OrderedDict

def mkdir(path):
    """Create the directory ``path`` (and any missing parents) if needed.

    Calling this for a directory that already exists is a no-op.

    Args:
        path: Directory path to create.

    Raises:
        FileExistsError: If ``path`` exists but is not a directory.
    """
    # exist_ok=True replaces the separate "does it exist?" check, which could
    # race with another process creating the same directory in between.
    os.makedirs(path, exist_ok=True)


def load_checkpoint(model, weights):
    """Load the weights stored in a checkpoint file into ``model`` in place.

    The checkpoint must be a mapping with a ``'state_dict'`` entry. Keys that
    start with ``'module.'`` have that prefix removed before loading
    (presumably it was added by a DataParallel-style wrapper at save time);
    every other key is used unchanged.

    Args:
        model: Module whose ``load_state_dict`` receives the cleaned mapping.
        weights: Path (or file-like object) accepted by ``torch.load``.

    Raises:
        KeyError: If the checkpoint has no ``'state_dict'`` entry.
        RuntimeError: If ``load_state_dict`` rejects the keys or shapes.
    """
    prefix = 'module.'
    # Deserialize onto the CPU: load_state_dict copies the values into the
    # model's own parameters, so this works with or without a GPU and does
    # not hold a second copy of the weights in GPU memory.
    checkpoint = torch.load(weights, map_location='cpu')
    # Match the full 'module.' prefix (with the dot) so that unrelated names
    # such as 'module_a.weight' or 'modules.0.bias' are left intact.
    new_state_dict = OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in checkpoint['state_dict'].items()
    )
    model.load_state_dict(new_state_dict)

def demo_real(images_path, output_path, model_type, output_ch, pretrain_dir):
    """Run the flare-removal model on every image matching ``images_path``.

    For each input, four PNG files are written below ``output_path``:
    ``input/`` (the image as fed to the model), ``flare/`` (predicted flare),
    ``deflare/`` (flare-removed image) and ``blend/`` (the result of
    ``blend_light_source`` called with 0.945 as its third argument). Files
    are named after the zero-padded index of the image in the sorted list of
    matches. The four sub-folders are created even if nothing matches.

    Args:
        images_path: Glob pattern selecting the input images.
        output_path: Root folder for the results; with or without a trailing
            separator. An empty string means the current working directory.
        model_type: ``'Uformer'``, ``'U_Net'``/``'U-Net'`` or ``'ours'``.
        output_ch: Number of model output channels; 3 and 6 are handled.
        pretrain_dir: Path of the weights file to load.

    Raises:
        AssertionError: If ``model_type`` or ``output_ch`` is not supported.
    """
    if output_path:
        os.makedirs(output_path, exist_ok=True)
    test_path = sorted(glob.glob(images_path))
    torch.cuda.empty_cache()

    if model_type == 'Uformer':
        model = Uformer(img_size=512, img_ch=3, output_ch=output_ch).cuda()
        model.load_state_dict(torch.load(pretrain_dir))
    elif model_type in ('U_Net', 'U-Net'):
        model = U_Net(img_ch=3, output_ch=output_ch).cuda()
        model.load_state_dict(torch.load(pretrain_dir))
    elif model_type == 'ours':
        model = Model().cuda()
        load_checkpoint(model, pretrain_dir)
    else:
        # Explicit raise instead of `assert False` so the check is not
        # stripped when Python runs with -O.
        raise AssertionError("This model is not supported!!")
    model.eval()

    to_tensor = transforms.ToTensor()

    # Build the output folders with os.path.join so that output_path works
    # whether or not it ends with a separator, and create them once up front
    # rather than once per image.
    out_dirs = {
        kind: os.path.join(output_path, kind)
        for kind in ("deflare", "flare", "input", "blend")
    }
    for folder in out_dirs.values():
        os.makedirs(folder, exist_ok=True)

    with torch.no_grad():
        # tqdm wraps the list itself so the progress bar knows the total.
        for i, image_path in enumerate(tqdm(test_path)):
            stem = str(i).zfill(5)
            deflare_path = os.path.join(out_dirs["deflare"], stem + "_deflare.png")
            flare_path = os.path.join(out_dirs["flare"], stem + "_flare.png")
            merge_path = os.path.join(out_dirs["input"], stem + "_input.png")
            blend_path = os.path.join(out_dirs["blend"], stem + "_blend.png")

            merge_img = Image.open(image_path).convert("RGB")
            # The image is fed at its native resolution, with a batch axis
            # added in front.
            merge_img = to_tensor(merge_img).cuda().unsqueeze(0)

            output_img = model(merge_img)
            print(output_img.shape)
            # if ch is 6, first three channels are deflare image, others are flare image
            # if ch is 3, unsaturated region of output is the deflare image.
            gamma = torch.Tensor([2.2])
            if output_ch == 6:
                deflare_img, flare_img_predicted, _ = predict_flare_from_6_channel(
                    output_img, gamma)
            elif output_ch == 3:
                # An all-zero mask is passed as the flare mask.
                flare_mask = torch.zeros_like(merge_img)
                deflare_img, flare_img_predicted = predict_flare_from_3_channel(
                    output_img, flare_mask, output_img, merge_img, merge_img, gamma)
            else:
                raise AssertionError("This output_ch is not supported!!")

            blend_img = blend_light_source(merge_img, deflare_img, 0.945)

            torchvision.utils.save_image(merge_img, merge_path)
            torchvision.utils.save_image(flare_img_predicted, flare_path)
            torchvision.utils.save_image(deflare_img, deflare_path)
            torchvision.utils.save_image(blend_img, blend_path)


def demo_syn(images_path, output_path, model_type, output_ch, pretrain_dir):
    """Run the flare-removal model on every image matching ``images_path``.

    For each input, four PNG files are written below ``output_path``:
    ``input/`` (the image as fed to the model), ``flare/`` (predicted flare),
    ``deflare/`` (flare-removed image) and ``blend/`` (the result of
    ``blend_light_source`` called with 0.999 as its third argument). Files
    are named after the zero-padded index of the image in the sorted list of
    matches. The four sub-folders are created even if nothing matches.

    Args:
        images_path: Glob pattern selecting the input images.
        output_path: Root folder for the results; with or without a trailing
            separator. An empty string means the current working directory.
        model_type: ``'Uformer'``, ``'U_Net'``/``'U-Net'`` or ``'ours'``.
        output_ch: Number of model output channels; 3 and 6 are handled.
        pretrain_dir: Path of the weights file to load.

    Raises:
        AssertionError: If ``model_type`` or ``output_ch`` is not supported.
    """
    if output_path:
        os.makedirs(output_path, exist_ok=True)
    test_path = sorted(glob.glob(images_path))
    torch.cuda.empty_cache()

    if model_type == 'Uformer':
        model = Uformer(img_size=512, img_ch=3, output_ch=output_ch).cuda()
        model.load_state_dict(torch.load(pretrain_dir))
    elif model_type in ('U_Net', 'U-Net'):
        model = U_Net(img_ch=3, output_ch=output_ch).cuda()
        model.load_state_dict(torch.load(pretrain_dir))
    elif model_type == 'ours':
        model = Model().cuda()
        load_checkpoint(model, pretrain_dir)
    else:
        # Explicit raise instead of `assert False` so the check is not
        # stripped when Python runs with -O.
        raise AssertionError("This model is not supported!!")
    model.eval()

    to_tensor = transforms.ToTensor()

    # Build the output folders with os.path.join so that output_path works
    # whether or not it ends with a separator, and create them once up front
    # rather than once per image.
    out_dirs = {
        kind: os.path.join(output_path, kind)
        for kind in ("deflare", "flare", "input", "blend")
    }
    for folder in out_dirs.values():
        os.makedirs(folder, exist_ok=True)

    with torch.no_grad():
        # tqdm wraps the list itself so the progress bar knows the total.
        for i, image_path in enumerate(tqdm(test_path)):
            stem = str(i).zfill(5)
            deflare_path = os.path.join(out_dirs["deflare"], stem + "_deflare.png")
            flare_path = os.path.join(out_dirs["flare"], stem + "_flare.png")
            merge_path = os.path.join(out_dirs["input"], stem + "_input.png")
            blend_path = os.path.join(out_dirs["blend"], stem + "_blend.png")

            merge_img = Image.open(image_path).convert("RGB")
            # The image is fed at its native resolution, with a batch axis
            # added in front.
            merge_img = to_tensor(merge_img).cuda().unsqueeze(0)

            output_img = model(merge_img)
            print(output_img.shape)
            # if ch is 6, first three channels are deflare image, others are flare image
            # if ch is 3, unsaturated region of output is the deflare image.
            gamma = torch.Tensor([2.2])
            if output_ch == 6:
                deflare_img, flare_img_predicted, _ = predict_flare_from_6_channel(
                    output_img, gamma)
            elif output_ch == 3:
                # An all-zero mask is passed as the flare mask.
                flare_mask = torch.zeros_like(merge_img)
                deflare_img, flare_img_predicted = predict_flare_from_3_channel(
                    output_img, flare_mask, output_img, merge_img, merge_img, gamma)
            else:
                raise AssertionError("This output_ch is not supported!!")

            blend_img = blend_light_source(merge_img, deflare_img, 0.999)

            torchvision.utils.save_image(merge_img, merge_path)
            torchvision.utils.save_image(flare_img_predicted, flare_path)
            torchvision.utils.save_image(deflare_img, deflare_path)
            torchvision.utils.save_image(blend_img, blend_path)


# ---------------------------------------------------------------------------
# Evaluation settings and driver calls
# ---------------------------------------------------------------------------

# Network to evaluate. The demo functions accept "Uformer", "U_Net"/"U-Net"
# and "ours" (alternative used before: "Uformer").
model_type = "ours"

# Glob patterns selecting the test inputs
# (another pattern used before: "datasets/4ks/*.*").
images_path_syn = "datasets/Flare7k/test_data/synthetic/input/*.*"
images_path_real = "datasets/Flare7k/test_data/real/input/*.*"

# Output roots, left blank here -- fill in before running. Values used before:
#   "result/real/Uformer_noreflection/", "result/synthetic/Uformer/",
#   "result/4ks/"
result_path_real = ""
result_path_syn = ""

# Weights file, left blank here -- fill in before running. Value used before:
#   'experiments/pretrained_models/uformer/net_g_last.pth'
pretrain_dir = ''

# Number of model output channels; the demo functions handle 3 and 6.
output_ch = 3
# NOTE(review): not read by any code visible in this section.
mask_flag = False

demo_real(
    images_path=images_path_real,
    output_path=result_path_real,
    model_type=model_type,
    output_ch=output_ch,
    pretrain_dir=pretrain_dir,
)
demo_syn(
    images_path=images_path_syn,
    output_path=result_path_syn,
    model_type=model_type,
    output_ch=output_ch,
    pretrain_dir=pretrain_dir,
)



