forked from zzr-idam/4KDehazing
/
test_model.py
70 lines (50 loc) · 1.61 KB
/
test_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch
import torch.nn as nn
import torchvision
import torch.backends.cudnn as cudnn
import torch.optim
import os
import sys
import argparse
import time
import network
import numpy as np
from torchvision import transforms
from tqdm import tqdm
from PIL import Image
import torch
import numpy as np
from skimage.measure import compare_ssim
from skimage.measure import compare_psnr
from tqdm import tqdm
import kornia
import dataset
from torch.nn import functional as F
from torchvision.utils import save_image
import network
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Build the dehazing network on the chosen device and switch it to inference
# mode (affects layers such as dropout / batch-norm, if the network has any).
my_model = network.B_transformer().to(device)
my_model.eval()
# map_location remaps the checkpoint's tensors onto `device`, so a checkpoint
# saved on a GPU can still be loaded when the script falls back to the CPU.
my_model.load_state_dict(
    torch.load("/home/dell/4Kdehaze/model/4K_ohaze.pth", map_location=device)
)

to_pil_image = transforms.ToPILImage()
# Input pipeline: images are fed at their original resolution. Insert e.g.
# transforms.Resize(1080) before ToTensor() to downscale oversized inputs.
tfs_full = transforms.Compose([
    transforms.ToTensor(),
])
def load_simple_list(src_path, exts=('.jpg',)):
    """Return the sorted paths of matching image files directly inside *src_path*.

    Args:
        src_path: Directory to scan (non-recursive).
        exts: File-name suffix, or tuple of suffixes, to keep. Matching is
            case-sensitive against the file name. Defaults to ``('.jpg',)``.

    Returns:
        A lexicographically sorted list of ``os.path.join(src_path, name)``
        paths that refer to regular files.
    """
    # A bare string would otherwise be split into single characters by tuple().
    if isinstance(exts, str):
        exts = (exts,)
    exts = tuple(exts)

    paths = [
        os.path.join(src_path, name)
        for name in os.listdir(src_path)
        # Test the suffix of the file name itself. A substring test on the
        # joined path would also accept "x.jpg.bak", or every file when the
        # directory path happens to contain ".jpg".
        if name.endswith(exts)
    ]
    # Skip sub-directories whose names happen to end in an image suffix.
    return sorted(p for p in paths if os.path.isfile(p))
list_s = load_simple_list('/home/dell/4Kdehaze/OHAZE_test')

# Create the output directory up front so saving the first result does not
# fail when it is missing.
os.makedirs('test_result', exist_ok=True)

# Inference only: no_grad stops autograd from recording the forward pass,
# which for full-resolution inputs would otherwise hold a lot of memory.
with torch.no_grad():
    for path in list_s:
        with Image.open(path) as img:
            image_in = img.convert('RGB')
        full = tfs_full(image_in).unsqueeze(0).to(device)  # add batch dim
        output = my_model(full)
        # Name the result after the input file, e.g. 27_outdoor_hazy.jpg.
        stem = os.path.splitext(os.path.basename(path))[0]
        save_image(output[0], 'test_result/{}.jpg'.format(stem))