# test.py
import argparse
from pathlib import Path
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image
import net
from params_position import example_all
# Fall back to CPU when no GPU is available.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
def test_transform(size, crop):
    transform_list = []
    if size != 0:
        transform_list.append(transforms.Resize(size))
    if crop:
        transform_list.append(transforms.CenterCrop(size))
    transform_list.append(transforms.ToTensor())
    transform = transforms.Compose(transform_list)
    return transform
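# For example (a usage sketch, not part of the original pipeline):
# test_transform(512, False) resizes the shorter image edge to 512 px and
# converts to a CHW float tensor in [0, 1]; test_transform(0, False) only
# performs the tensor conversion.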
# Camouflage function: encode foreground and background with VGG, fuse the
# features with PSF under the downsampled mask, then decode the result.
def camouflage(vgg, decoder, PSF, fore, back, mask):
    b, c, h, w = fore.size()
    down_sam = nn.MaxPool2d((8, 8), (8, 8), (0, 0), ceil_mode=True)
    mask = down_sam(mask)
    fore_f = vgg(fore)
    back_f = vgg(back)
    feat = PSF(fore_f, back_f, mask)
    output = decoder(feat)
    # The decoder can emit a slightly larger map; crop back to the input size.
    output = output[:, :, :h, :w]
    return output
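# Note: the 8x8 max-pooling above is assumed to match the overall stride of
# the truncated VGG encoder, so the downsampled mask aligns spatially with
# the fore/back feature maps that PSF fuses.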
def embed(fore, mask, back, x, y):
    n_b, c_b, h_b, w_b = back.size()
    n_f, c_f, h_f, w_f = fore.size()
    # Paste the foreground and its mask into full-size canvases at offset (x, y),
    # then composite: output = mask * fore + (1 - mask) * back.
    mask_b = torch.zeros([n_b, 1, h_b, w_b]).to(device)
    fore_b = torch.zeros([n_b, c_b, h_b, w_b]).to(device)
    mask_b[:, :, x:h_f + x, y:w_f + y] = mask
    fore_b[:, :, x:h_f + x, y:w_f + y] = fore
    out = torch.mul(back, 1 - mask_b)
    output = torch.mul(fore_b, mask_b) + out
    return output
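# Worked example: with a 1x3x512x512 background, a 1x3x256x256 foreground and
# offsets x = y = 128, embed() pastes the foreground into the central 256x256
# window and keeps the background wherever the binary mask is 0.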
# Return the coordinates of the upper-left corner of the camouflage region;
# by default the region is centred in the background image.
def position(fore, back):
    n_b, c_b, h_b, w_b = back.size()
    n_f, c_f, h_f, w_f = fore.size()
    x = abs((h_b - h_f) // 2)
    y = abs((w_b - w_f) // 2)
    return x, y
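# E.g. a 256x256 foreground in a 512x512 background gives x = y = 128, i.e.
# the region is centred before the Vertical/Horizontal offsets are applied.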
parser = argparse.ArgumentParser()
parser.add_argument('--use_examples', type=int, default=2,
                    help='Index (1-6) of the provided example inputs and position parameters; set to 0 to supply your own inputs.')
# Inputs supplied by the user (used when --use_examples is 0)
parser.add_argument('--fore', type=str, default='input/fore/2.jpg', help='Foreground image.')
parser.add_argument('--mask', type=str, default='input/mask/2.png', help='Mask image.')
parser.add_argument('--back', type=str, default='input/back/2.jpg', help='Background image.')
parser.add_argument('--zoomSize', type=float, default=1.5, help='Zoom factor applied to the foreground image.')
parser.add_argument('--Vertical', type=int, default=100, help='Vertical offset of the camouflage region; larger values move it lower.')
parser.add_argument('--Horizontal', type=int, default=0, help='Horizontal offset of the camouflage region; larger values move it further right.')
# Crop applied to the final output image (-1 for Right/Bottom means no crop)
parser.add_argument('--Left', type=int, default=0)
parser.add_argument('--Right', type=int, default=-1)
parser.add_argument('--Top', type=int, default=0)
parser.add_argument('--Bottom', type=int, default=-1)
parser.add_argument('--vgg', type=str, default='models/vgg_normalised.pth')
parser.add_argument('--decoder', type=str, default='models/decoder.pth')
parser.add_argument('--PSF', type=str, default='models/PSF.pth')
# Additional options
parser.add_argument('--fore_size', type=int, default=0,
help='New (minimum) size for the fore image, \
keeping the original size if set to 0')
parser.add_argument('--back_size', type=int, default=0,
help='New (minimum) size for the back image, \
keeping the original size if set to 0')
parser.add_argument('--mask_size', type=int, default=0,
help='New (minimum) size for the mask image, \
keeping the original size if set to 0')
parser.add_argument('--crop', action='store_true',
                    help='Center-crop to create a square image')
parser.add_argument('--save_ext', default='.jpg',
help='The extension name of the output image')
parser.add_argument('--output', type=str, default='output',
help='Directory to save the output image(s)')
args = parser.parse_args()
output_dir = Path(args.output)
output_dir.mkdir(exist_ok=True, parents=True)
if args.use_examples:
    assert 1 <= args.use_examples <= 6, '--use_examples must be between 1 and 6'
    example = example_all[args.use_examples - 1]
    fore_paths = [Path(example['fore_path'])]
    mask_paths = [Path(example['mask_path'])]
    back_paths = [Path(example['back_path'])]
    zoomSize = example['zoomSize']
    Vertical = example['Vertical']
    Horizontal = example['Horizontal']
    # Crop parameters
    Left = example['Left']
    Right = example['Right']
    Top = example['Top']
    Bottom = example['Bottom']
else:
    assert args.fore and args.mask and args.back
    fore_paths = [Path(args.fore)]
    mask_paths = [Path(args.mask)]
    back_paths = [Path(args.back)]
    zoomSize = args.zoomSize
    Vertical = args.Vertical
    Horizontal = args.Horizontal
    Left = args.Left
    Top = args.Top
    # Convert inclusive crop edges to exclusive slice ends; -1 means no crop
    # (the slice runs to the image border).
    Right = None if args.Right == -1 else args.Right + 1
    Bottom = None if args.Bottom == -1 else args.Bottom + 1
decoder = net.decoder
vgg = net.vgg
PSF = net.PSF(in_planes=512)
decoder.eval()
vgg.eval()
PSF.eval()
decoder.load_state_dict(torch.load(args.decoder, map_location=device))
PSF.load_state_dict(torch.load(args.PSF, map_location=device))
vgg.load_state_dict(torch.load(args.vgg, map_location=device))
# Keep only the first 31 children of VGG: the encoder layers used for feature extraction.
vgg = nn.Sequential(*list(vgg.children())[:31])
vgg.to(device)
decoder.to(device)
PSF.to(device)
fore_tf = test_transform(args.fore_size, args.crop)
back_tf = test_transform(args.back_size, args.crop)
mask_tf = test_transform(args.mask_size, args.crop)
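# Note: fore_tf and mask_tf are rebuilt for each image pair inside the loop
# below (to fit the zoomed foreground into the background), so --fore_size
# and --mask_size do not affect the final foreground/mask transforms.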
for fore_path, mask_path in zip(fore_paths, mask_paths):
    for back_path in back_paths:
        # Convert to 3-channel RGB / single-channel mask so the tensor shapes
        # match what camouflage() and embed() expect.
        fore = Image.open(str(fore_path)).convert('RGB')
        back = Image.open(str(back_path)).convert('RGB')
        # Scale the foreground by zoomSize; if the scaled size would exceed the
        # background, shrink it (preserving aspect ratio) to fit inside.
        tempSize = [fore.size[0] * zoomSize, fore.size[1] * zoomSize]
        if tempSize[0] > back.size[0]:
            tempSize[0] = back.size[0]
            tempSize[1] = int(tempSize[1] * back.size[0] / (fore.size[0] * zoomSize))
        if tempSize[1] > back.size[1]:
            temp = tempSize[1]
            tempSize[1] = back.size[1]
            tempSize[0] = int(tempSize[0] * back.size[1] / temp)
        # PIL sizes are (width, height) while transforms.Resize expects (height, width).
        fore_tf = test_transform((int(tempSize[1]), int(tempSize[0])), args.crop)
        mask_tf = test_transform((int(tempSize[1]), int(tempSize[0])), args.crop)
        fore = fore_tf(fore)
        back = back_tf(back)
        mask = mask_tf(Image.open(str(mask_path)).convert('L'))
        back = back.to(device).unsqueeze(0)
        fore = fore.to(device).unsqueeze(0)
        mask = mask.to(device).unsqueeze(0)
        mask = (mask > 0).float()  # binarize the mask
        _, _, h, w = mask.shape
        x, y = position(fore, back)
        # Clamp the offsets so the camouflage region stays inside the background.
        Vertical = Vertical if Vertical <= x else x
        Horizontal = Horizontal if Horizontal <= y else y
        x = x + Vertical
        y = y + Horizontal
        back_use = back[:, :, x:x + h, y:y + w]
        with torch.no_grad():
            output_pre = camouflage(vgg, decoder, PSF, fore, back_use, mask)
        output_pre = embed(output_pre, mask, back, x, y)[:, :, Top:Bottom, Left:Right]
        output_name = output_dir / '{:s}_{:s}{:s}'.format(back_path.stem, fore_path.stem, args.save_ext)
        save_image(output_pre, str(output_name))
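# Example invocations (a sketch; paths follow the repository's default layout):
#   python test.py --use_examples 1
#   python test.py --use_examples 0 --fore input/fore/2.jpg \
#       --mask input/mask/2.png --back input/back/2.jpg \
#       --zoomSize 1.5 --Vertical 100 --Horizontal 0 --output output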