In [1]:
import time
import glob
import torch
import torchvision
import numpy as np
from tqdm import tqdm
from PIL import Image
#from basicsr.data.flare7k_dataset import Flare_Image_Loader,RandomGammaCorrection
from basicsr.archs.uformer_arch import Uformer
#from basicsr.archs.unet_arch import U_Net
from basicsr.utils.flare_util import blend_light_source,get_args_from_json,save_args_to_json,mkdir,predict_flare_from_6_channel,predict_flare_from_3_channel
from torch.distributions import Normal
import torchvision.transforms as transforms
import os

from models import *
from collections import OrderedDict

def mkdir(path):
	folder = os.path.exists(path)
	if not folder:
		os.makedirs(path)


def load_checkpoint(model, weights):
    checkpoint = torch.load(
        weights, map_location=lambda storage, loc: storage.cuda(0))
    new_state_dict = OrderedDict()
    for key, value in checkpoint['state_dict'].items():
        if key.startswith('module'):
            name = key[7:]
        else:
            name = key
        new_state_dict[name] = value
    model.load_state_dict(new_state_dict)

def demo_real(images_path,output_path,model_type,output_ch,pretrain_dir):
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    test_path=sorted(glob.glob(images_path))
    result_path=output_path
    torch.cuda.empty_cache()
    if model_type=='Uformer':
        model=Uformer(img_size=512,img_ch=3,output_ch=output_ch).cuda()
        model.load_state_dict(torch.load(pretrain_dir))
    elif model_type=='U_Net' or model_type=='U-Net':
        model=U_Net(img_ch=3,output_ch=output_ch).cuda()
        model.load_state_dict(torch.load(pretrain_dir))
    elif model_type=='ours':
        model=Model().cuda()
        load_checkpoint(model,pretrain_dir)
    else:
        assert False, "This model is not supported!!"
    to_tensor=transforms.ToTensor()
    resize=transforms.Resize((512,512)) #The output should in the shape of 128X
    for i,image_path in tqdm(enumerate(test_path)):
        mkdir(result_path+"deflare/")
        mkdir(result_path+"flare/")
        mkdir(result_path+"input/")
        mkdir(result_path+"blend/")

        deflare_path = result_path+"deflare/"+str(i).zfill(5)+"_deflare.png"
        flare_path = result_path+"flare/"+str(i).zfill(5)+"_flare.png"
        merge_path = result_path+"input/"+str(i).zfill(5)+"_input.png"
        blend_path = result_path+"blend/"+str(i).zfill(5)+"_blend.png"

        merge_img = Image.open(image_path).convert("RGB")
        #merge_img = resize(to_tensor(merge_img))
        merge_img = (to_tensor(merge_img)).cuda().unsqueeze(0)

        model.eval()
        with torch.no_grad():
            output_img=model(merge_img)
            print(output_img.shape)
            #if ch is 6, first three channels are deflare image, others are flare image
            #if ch is 3, unsaturated region of output is the deflare image.
            gamma=torch.Tensor([2.2])
            if output_ch==6:
                deflare_img,flare_img_predicted,merge_img_predicted=predict_flare_from_6_channel(output_img,gamma)
            elif output_ch==3:
                flare_mask=torch.zeros_like(merge_img)
                deflare_img,flare_img_predicted=predict_flare_from_3_channel(output_img,flare_mask,output_img,merge_img,merge_img,gamma)
            else:
                assert False, "This output_ch is not supported!!"

            blend_img= blend_light_source(merge_img, deflare_img, 0.945)

            torchvision.utils.save_image(merge_img, merge_path)
            torchvision.utils.save_image(flare_img_predicted, flare_path)
            torchvision.utils.save_image(deflare_img, deflare_path)
            torchvision.utils.save_image(blend_img, blend_path)


def demo_syn(images_path, output_path, model_type, output_ch, pretrain_dir):
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    test_path = sorted(glob.glob(images_path))
    result_path = output_path
    torch.cuda.empty_cache()
    if model_type == 'Uformer':
        model = Uformer(img_size=512, img_ch=3, output_ch=output_ch).cuda()
        model.load_state_dict(torch.load(pretrain_dir))
    elif model_type == 'U_Net' or model_type == 'U-Net':
        model = U_Net(img_ch=3, output_ch=output_ch).cuda()
        model.load_state_dict(torch.load(pretrain_dir))
    elif model_type == 'ours':
        model = Model().cuda()
        load_checkpoint(model, pretrain_dir)
    else:
        assert False, "This model is not supported!!"
    to_tensor = transforms.ToTensor()
    # The output should in the shape of 128X
    resize = transforms.Resize((512, 512))
    for i, image_path in tqdm(enumerate(test_path)):
        mkdir(result_path+"deflare/")
        mkdir(result_path+"flare/")
        mkdir(result_path+"input/")
        mkdir(result_path+"blend/")

        deflare_path = result_path+"deflare/"+str(i).zfill(5)+"_deflare.png"
        flare_path = result_path+"flare/"+str(i).zfill(5)+"_flare.png"
        merge_path = result_path+"input/"+str(i).zfill(5)+"_input.png"
        blend_path = result_path+"blend/"+str(i).zfill(5)+"_blend.png"

        merge_img = Image.open(image_path).convert("RGB")
        #merge_img = resize(to_tensor(merge_img))
        merge_img = (to_tensor(merge_img)).cuda().unsqueeze(0)

        model.eval()
        with torch.no_grad():
            output_img = model(merge_img)
            print(output_img.shape)
            #if ch is 6, first three channels are deflare image, others are flare image
            #if ch is 3, unsaturated region of output is the deflare image.
            gamma = torch.Tensor([2.2])
            if output_ch == 6:
                deflare_img, flare_img_predicted, merge_img_predicted = predict_flare_from_6_channel(
                    output_img, gamma)
            elif output_ch == 3:
                flare_mask = torch.zeros_like(merge_img)
                deflare_img, flare_img_predicted = predict_flare_from_3_channel(
                    output_img, flare_mask, output_img, merge_img, merge_img, gamma)
            else:
                assert False, "This output_ch is not supported!!"

            blend_img = blend_light_source(merge_img, deflare_img, 0.999)

            torchvision.utils.save_image(merge_img, merge_path)
            torchvision.utils.save_image(flare_img_predicted, flare_path)
            torchvision.utils.save_image(deflare_img, deflare_path)
            torchvision.utils.save_image(blend_img, blend_path)


#model_type="Uformer"
model_type="ours"
images_path_real="datasets/Flare7k/test_data/real/input/*.*"
images_path_syn = "datasets/Flare7k/test_data/synthetic/input/*.*"
#images_path = "datasets/4ks/*.*"
#images_path="/home/ipprlab/Projects/flare_removal/test/real-20220929T124032Z-001/real/input/*.*"

#result_path="result/real/Uformer_noreflection/"
#result_path = "result/synthetic/Uformer/"
result_path_syn = "result/synthetic/our/noanew_115"
result_path_real = "result/real/our/noanew_115"
#result_path="result/4ks/"
#pretrain_dir='experiments/pretrained_models/uformer/net_g_last.pth'
pretrain_dir = 'experiments/pretrained_models/our/all_16_abnoa_epoch_115_iter_1_psnr_26.42.pth'
output_ch=3
mask_flag=False


time_start = time.time()
#在这里运行模型
demo_real(images_path_real, result_path_real, model_type, output_ch, pretrain_dir)
demo_syn(images_path_syn, result_path_syn,
          model_type, output_ch, pretrain_dir)
time_end = time.time()
print('totally cost', time_end-time_start)




0it [00:00, ?it/s]

torch.Size([1, 3, 512, 512])


1it [00:01,  1.27s/it]

torch.Size([1, 3, 512, 512])


2it [00:01,  1.38it/s]

torch.Size([1, 3, 512, 512])


3it [00:01,  1.76it/s]

torch.Size([1, 3, 512, 512])


4it [00:02,  2.09it/s]

torch.Size([1, 3, 512, 512])


5it [00:02,  2.32it/s]

torch.Size([1, 3, 512, 512])


6it [00:03,  2.46it/s]

torch.Size([1, 3, 512, 512])


7it [00:03,  2.52it/s]

torch.Size([1, 3, 512, 512])


8it [00:03,  2.69it/s]

torch.Size([1, 3, 512, 512])


9it [00:04,  2.75it/s]

torch.Size([1, 3, 512, 512])


10it [00:04,  2.83it/s]

torch.Size([1, 3, 512, 512])


11it [00:04,  2.83it/s]

torch.Size([1, 3, 512, 512])


12it [00:05,  2.82it/s]

torch.Size([1, 3, 512, 512])


13it [00:05,  2.88it/s]

torch.Size([1, 3, 512, 512])


14it [00:05,  2.88it/s]

torch.Size([1, 3, 512, 512])


15it [00:06,  2.84it/s]

torch.Size([1, 3, 512, 512])


16it [00:06,  2.90it/s]

torch.Size([1, 3, 512, 512])


17it [00:06,  2.88it/s]

torch.Size([1, 3, 512, 512])


18it [00:07,  2.89it/s]

torch.Size([1, 3, 512, 512])


19it [00:07,  2.95it/s]

torch.Size([1, 3, 512, 512])


20it [00:07,  3.02it/s]

torch.Size([1, 3, 512, 512])


21it [00:08,  3.02it/s]

torch.Size([1, 3, 512, 512])


22it [00:08,  3.04it/s]

torch.Size([1, 3, 512, 512])


23it [00:08,  3.05it/s]

torch.Size([1, 3, 512, 512])


24it [00:09,  3.04it/s]

torch.Size([1, 3, 512, 512])


25it [00:09,  2.95it/s]

torch.Size([1, 3, 512, 512])


26it [00:09,  2.92it/s]

torch.Size([1, 3, 512, 512])


27it [00:10,  2.93it/s]

torch.Size([1, 3, 512, 512])


28it [00:10,  2.94it/s]

torch.Size([1, 3, 512, 512])


29it [00:10,  2.96it/s]

torch.Size([1, 3, 512, 512])


30it [00:11,  2.96it/s]

torch.Size([1, 3, 512, 512])


31it [00:11,  2.95it/s]

torch.Size([1, 3, 512, 512])


32it [00:11,  2.89it/s]

torch.Size([1, 3, 512, 512])


33it [00:12,  2.89it/s]

torch.Size([1, 3, 512, 512])


34it [00:12,  2.91it/s]

torch.Size([1, 3, 512, 512])


35it [00:12,  2.99it/s]

torch.Size([1, 3, 512, 512])


36it [00:13,  2.99it/s]

torch.Size([1, 3, 512, 512])


37it [00:13,  3.02it/s]

torch.Size([1, 3, 512, 512])


38it [00:13,  2.97it/s]

torch.Size([1, 3, 512, 512])


39it [00:14,  2.86it/s]

torch.Size([1, 3, 512, 512])


40it [00:14,  2.82it/s]

torch.Size([1, 3, 512, 512])


41it [00:15,  2.75it/s]

torch.Size([1, 3, 512, 512])


42it [00:15,  2.87it/s]

torch.Size([1, 3, 512, 512])


43it [00:15,  2.99it/s]

torch.Size([1, 3, 512, 512])


44it [00:15,  3.05it/s]

torch.Size([1, 3, 512, 512])


45it [00:16,  3.03it/s]

torch.Size([1, 3, 512, 512])


46it [00:16,  3.03it/s]

torch.Size([1, 3, 512, 512])


47it [00:16,  3.01it/s]

torch.Size([1, 3, 512, 512])


48it [00:17,  3.01it/s]

torch.Size([1, 3, 512, 512])


49it [00:17,  3.06it/s]

torch.Size([1, 3, 512, 512])


50it [00:17,  3.09it/s]

torch.Size([1, 3, 512, 512])


51it [00:18,  3.12it/s]

torch.Size([1, 3, 512, 512])


52it [00:18,  3.00it/s]

torch.Size([1, 3, 512, 512])


53it [00:18,  2.95it/s]

torch.Size([1, 3, 512, 512])


54it [00:19,  2.85it/s]

torch.Size([1, 3, 512, 512])


55it [00:19,  2.78it/s]

torch.Size([1, 3, 512, 512])


56it [00:20,  2.73it/s]

torch.Size([1, 3, 512, 512])


57it [00:20,  2.68it/s]

torch.Size([1, 3, 512, 512])


58it [00:20,  2.66it/s]

torch.Size([1, 3, 512, 512])


59it [00:21,  2.64it/s]

torch.Size([1, 3, 512, 512])


60it [00:21,  2.68it/s]

torch.Size([1, 3, 512, 512])


61it [00:21,  2.71it/s]

torch.Size([1, 3, 512, 512])


62it [00:22,  2.73it/s]

torch.Size([1, 3, 512, 512])


63it [00:22,  2.79it/s]

torch.Size([1, 3, 512, 512])


64it [00:22,  2.86it/s]

torch.Size([1, 3, 512, 512])


65it [00:23,  2.90it/s]

torch.Size([1, 3, 512, 512])


66it [00:23,  2.90it/s]

torch.Size([1, 3, 512, 512])


67it [00:24,  2.88it/s]

torch.Size([1, 3, 512, 512])


68it [00:24,  2.85it/s]

torch.Size([1, 3, 512, 512])


69it [00:24,  2.83it/s]

torch.Size([1, 3, 512, 512])


70it [00:25,  2.87it/s]

torch.Size([1, 3, 512, 512])


71it [00:25,  2.96it/s]

torch.Size([1, 3, 512, 512])


72it [00:25,  3.02it/s]

torch.Size([1, 3, 512, 512])


73it [00:26,  3.09it/s]

torch.Size([1, 3, 512, 512])


74it [00:26,  3.08it/s]

torch.Size([1, 3, 512, 512])


75it [00:26,  3.06it/s]

torch.Size([1, 3, 512, 512])


76it [00:26,  3.09it/s]

torch.Size([1, 3, 512, 512])


77it [00:27,  2.96it/s]

torch.Size([1, 3, 512, 512])


78it [00:27,  2.85it/s]

torch.Size([1, 3, 512, 512])


79it [00:28,  2.79it/s]

torch.Size([1, 3, 512, 512])


80it [00:28,  2.78it/s]

torch.Size([1, 3, 512, 512])


81it [00:28,  2.74it/s]

torch.Size([1, 3, 512, 512])


82it [00:29,  2.73it/s]

torch.Size([1, 3, 512, 512])


83it [00:29,  2.72it/s]

torch.Size([1, 3, 512, 512])


84it [00:29,  2.77it/s]

torch.Size([1, 3, 512, 512])


85it [00:30,  2.81it/s]

torch.Size([1, 3, 512, 512])


86it [00:30,  2.85it/s]

torch.Size([1, 3, 512, 512])


87it [00:30,  2.89it/s]

torch.Size([1, 3, 512, 512])


88it [00:31,  2.92it/s]

torch.Size([1, 3, 512, 512])


89it [00:31,  2.94it/s]

torch.Size([1, 3, 512, 512])


90it [00:31,  2.95it/s]

torch.Size([1, 3, 512, 512])


91it [00:32,  2.93it/s]

torch.Size([1, 3, 512, 512])


92it [00:32,  2.95it/s]

torch.Size([1, 3, 512, 512])


93it [00:32,  3.01it/s]

torch.Size([1, 3, 512, 512])


94it [00:33,  3.05it/s]

torch.Size([1, 3, 512, 512])


95it [00:33,  3.02it/s]

torch.Size([1, 3, 512, 512])


96it [00:33,  3.07it/s]

torch.Size([1, 3, 512, 512])


97it [00:34,  3.01it/s]

torch.Size([1, 3, 512, 512])


98it [00:34,  2.91it/s]

torch.Size([1, 3, 512, 512])


99it [00:35,  2.83it/s]

torch.Size([1, 3, 512, 512])


100it [00:35,  2.83it/s]
0it [00:00, ?it/s]

torch.Size([1, 3, 512, 512])


1it [00:00,  2.93it/s]

torch.Size([1, 3, 512, 512])


2it [00:00,  3.10it/s]

torch.Size([1, 3, 512, 512])


3it [00:00,  3.13it/s]

torch.Size([1, 3, 512, 512])


4it [00:01,  3.11it/s]

torch.Size([1, 3, 512, 512])


5it [00:01,  3.07it/s]

torch.Size([1, 3, 512, 512])


6it [00:01,  3.10it/s]

torch.Size([1, 3, 512, 512])


7it [00:02,  3.02it/s]

torch.Size([1, 3, 512, 512])


8it [00:02,  3.02it/s]

torch.Size([1, 3, 512, 512])


9it [00:02,  3.02it/s]

torch.Size([1, 3, 512, 512])


10it [00:03,  3.06it/s]

torch.Size([1, 3, 512, 512])


11it [00:03,  3.07it/s]

torch.Size([1, 3, 512, 512])


12it [00:03,  3.08it/s]

torch.Size([1, 3, 512, 512])


13it [00:04,  3.06it/s]

torch.Size([1, 3, 512, 512])


14it [00:04,  3.03it/s]

torch.Size([1, 3, 512, 512])


15it [00:04,  2.97it/s]

torch.Size([1, 3, 512, 512])


16it [00:05,  2.99it/s]

torch.Size([1, 3, 512, 512])


17it [00:05,  2.98it/s]

torch.Size([1, 3, 512, 512])


18it [00:05,  3.04it/s]

torch.Size([1, 3, 512, 512])


19it [00:06,  3.00it/s]

torch.Size([1, 3, 512, 512])


20it [00:06,  3.11it/s]

torch.Size([1, 3, 512, 512])


21it [00:06,  3.12it/s]

torch.Size([1, 3, 512, 512])


22it [00:07,  3.06it/s]

torch.Size([1, 3, 512, 512])


23it [00:07,  3.03it/s]

torch.Size([1, 3, 512, 512])


24it [00:07,  2.95it/s]

torch.Size([1, 3, 512, 512])


25it [00:08,  2.94it/s]

torch.Size([1, 3, 512, 512])


26it [00:08,  2.95it/s]

torch.Size([1, 3, 512, 512])


27it [00:08,  2.93it/s]

torch.Size([1, 3, 512, 512])


28it [00:09,  2.95it/s]

torch.Size([1, 3, 512, 512])


29it [00:09,  2.92it/s]

torch.Size([1, 3, 512, 512])


30it [00:09,  2.91it/s]

torch.Size([1, 3, 512, 512])


31it [00:10,  2.87it/s]

torch.Size([1, 3, 512, 512])


32it [00:10,  2.93it/s]

torch.Size([1, 3, 512, 512])


33it [00:10,  3.00it/s]

torch.Size([1, 3, 512, 512])


34it [00:11,  3.20it/s]

torch.Size([1, 3, 512, 512])


35it [00:11,  3.09it/s]

torch.Size([1, 3, 512, 512])


36it [00:11,  3.00it/s]

torch.Size([1, 3, 512, 512])


37it [00:12,  2.92it/s]

torch.Size([1, 3, 512, 512])


38it [00:12,  2.93it/s]

torch.Size([1, 3, 512, 512])


39it [00:12,  2.98it/s]

torch.Size([1, 3, 512, 512])


40it [00:13,  2.96it/s]

torch.Size([1, 3, 512, 512])


41it [00:13,  2.92it/s]

torch.Size([1, 3, 512, 512])


42it [00:13,  3.02it/s]

torch.Size([1, 3, 512, 512])


43it [00:14,  3.08it/s]

torch.Size([1, 3, 512, 512])


44it [00:14,  3.13it/s]

torch.Size([1, 3, 512, 512])


45it [00:14,  3.13it/s]

torch.Size([1, 3, 512, 512])


46it [00:15,  3.04it/s]

torch.Size([1, 3, 512, 512])


47it [00:15,  3.00it/s]

torch.Size([1, 3, 512, 512])


48it [00:15,  3.11it/s]

torch.Size([1, 3, 512, 512])


49it [00:16,  3.15it/s]

torch.Size([1, 3, 512, 512])


50it [00:16,  3.04it/s]

torch.Size([1, 3, 512, 512])


51it [00:16,  3.12it/s]

torch.Size([1, 3, 512, 512])


52it [00:17,  3.10it/s]

torch.Size([1, 3, 512, 512])


53it [00:17,  3.10it/s]

torch.Size([1, 3, 512, 512])


54it [00:17,  3.01it/s]

torch.Size([1, 3, 512, 512])


55it [00:18,  2.99it/s]

torch.Size([1, 3, 512, 512])


56it [00:18,  2.93it/s]

torch.Size([1, 3, 512, 512])


57it [00:18,  2.86it/s]

torch.Size([1, 3, 512, 512])


58it [00:19,  2.82it/s]

torch.Size([1, 3, 512, 512])


59it [00:19,  2.76it/s]

torch.Size([1, 3, 512, 512])


60it [00:19,  2.84it/s]

torch.Size([1, 3, 512, 512])


61it [00:20,  2.84it/s]

torch.Size([1, 3, 512, 512])


62it [00:20,  2.83it/s]

torch.Size([1, 3, 512, 512])


63it [00:21,  2.89it/s]

torch.Size([1, 3, 512, 512])


64it [00:21,  3.10it/s]

torch.Size([1, 3, 512, 512])


65it [00:21,  3.08it/s]

torch.Size([1, 3, 512, 512])


66it [00:21,  3.11it/s]

torch.Size([1, 3, 512, 512])


67it [00:22,  3.22it/s]

torch.Size([1, 3, 512, 512])


68it [00:22,  3.23it/s]

torch.Size([1, 3, 512, 512])


69it [00:22,  3.24it/s]

torch.Size([1, 3, 512, 512])


70it [00:23,  3.24it/s]

torch.Size([1, 3, 512, 512])


71it [00:23,  3.31it/s]

torch.Size([1, 3, 512, 512])


72it [00:23,  3.36it/s]

torch.Size([1, 3, 512, 512])


73it [00:24,  3.34it/s]

torch.Size([1, 3, 512, 512])


74it [00:24,  3.28it/s]

torch.Size([1, 3, 512, 512])


75it [00:24,  3.35it/s]

torch.Size([1, 3, 512, 512])


76it [00:24,  3.43it/s]

torch.Size([1, 3, 512, 512])


77it [00:25,  3.47it/s]

torch.Size([1, 3, 512, 512])


78it [00:25,  3.44it/s]

torch.Size([1, 3, 512, 512])


79it [00:25,  3.32it/s]

torch.Size([1, 3, 512, 512])


80it [00:26,  3.37it/s]

torch.Size([1, 3, 512, 512])


81it [00:26,  3.26it/s]

torch.Size([1, 3, 512, 512])


82it [00:26,  3.30it/s]

torch.Size([1, 3, 512, 512])


83it [00:27,  3.19it/s]

torch.Size([1, 3, 512, 512])


84it [00:27,  3.13it/s]

torch.Size([1, 3, 512, 512])


85it [00:27,  3.13it/s]

torch.Size([1, 3, 512, 512])


86it [00:28,  3.10it/s]

torch.Size([1, 3, 512, 512])


87it [00:28,  3.11it/s]

torch.Size([1, 3, 512, 512])


88it [00:28,  3.07it/s]

torch.Size([1, 3, 512, 512])


89it [00:29,  3.05it/s]

torch.Size([1, 3, 512, 512])


90it [00:29,  3.03it/s]

torch.Size([1, 3, 512, 512])


91it [00:29,  2.93it/s]

torch.Size([1, 3, 512, 512])


92it [00:30,  2.90it/s]

torch.Size([1, 3, 512, 512])


93it [00:30,  2.92it/s]

torch.Size([1, 3, 512, 512])


94it [00:30,  2.99it/s]

torch.Size([1, 3, 512, 512])


95it [00:31,  2.96it/s]

torch.Size([1, 3, 512, 512])


96it [00:31,  2.93it/s]

torch.Size([1, 3, 512, 512])


97it [00:31,  2.98it/s]

torch.Size([1, 3, 512, 512])


98it [00:32,  3.07it/s]

torch.Size([1, 3, 512, 512])


99it [00:32,  3.04it/s]

torch.Size([1, 3, 512, 512])


100it [00:32,  3.06it/s]

totally cost 69.70273470878601





In [18]:
import numpy as np
import cv2 as cv

img = cv.imread("/home/ipprlab/Projects/flare/Flare7K/datasets/synthetic/10.jpg")

# 缩放图像，后面的其他程序都是在这一行上改动
dst = cv.resize(img, (3840,2160), interpolation=cv.INTER_AREA)
cv.imwrite("/home/ipprlab/Projects/flare/Flare7K/datasets/synthetic/10_2k.jpg",dst)
# 显示图像
# cv.imshow("dst: %d x %d" % (dst.shape[0], dst.shape[1]), dst)
# cv.waitKey(0)
# cv.destroyAllWindows()


True