import os
import random

import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.functional import interpolate

from kernel_encoding import reconstruct_from_cov, decode_to_cov
from models import preprocess, postprocess


def multiple_downsample(hr, kernels, scale):
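    """Downsample a batch of HR images with each of the given blur kernels.

    hr: (b, c, h, w) HR batch; kernels: a batch of n single-channel
    convolution kernels. Returns an (n, b, c, h', w') tensor holding one
    downsampled copy of the batch per kernel.
    """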
    with torch.no_grad():
        hr = preprocess(hr)
        # downsample with each kernel
        b, c, h, w = hr.shape
        hr = hr.view((b * c, 1, h, w))
        fake_lr = batch_forward(hr, kernels, scale)
        n, _, _, h, w = fake_lr.shape
        fake_lr = fake_lr.view((n, b, c, h, w))
        fake_lr = postprocess(fake_lr)
        fake_lr = fake_lr.detach().clone()
        return fake_lr


def conv_downsample(img, kernel, scale):
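    """Blur with `kernel` and downsample by `scale` via a strided convolution."""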
    return F.conv2d(img, kernel, bias=None, stride=scale)


def batch_forward(batch_img, batch_kernel, scale):
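    """Apply each kernel in `batch_kernel` to `batch_img` and stack the
    results along a new leading dimension."""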
    batch = []
    batch_size = batch_kernel.shape[0]
    for b in range(batch_size):
        fake_lr = conv_downsample(batch_img, batch_kernel[b], scale)
        batch.append(fake_lr)
    batch = torch.stack(batch)
    return batch


def kernel_collage(lr_imgs, kernels=None, ratio=(0.33, 0.8)):
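    """CutMix-style augmentation over LR images produced by different kernels.

    Random rectangles cut from lr_imgs[1:] are pasted onto lr_imgs[0]. If
    `kernels` is given, a per-pixel kernel map is composed the same way and
    returned alongside the collaged image.
    """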
    with torch.no_grad():
        n = len(lr_imgs)
        base = lr_imgs[0]
        batch_size = base.shape[0]
        n_pixels = base.shape[-2:]
        _min, _max = int(ratio[0] * 1000), int(ratio[1] * 1000)
        ratios_h = [random.choice(range(_min, _max)) / 1000 for _ in range(n - 1)]
        ratios_w = [random.choice(range(_min, _max)) / 1000 for _ in range(n - 1)]
        # sorted ascending so that patches pasted later (higher priority) are smaller
        ratios_h = sorted(ratios_h)
        ratios_w = sorted(ratios_w)
        kernel_map = None
        if kernels is not None:
            kernel_map = kernels[0].expand((batch_size, n_pixels[0], n_pixels[1], kernels.shape[1]))
            kernels = kernels.view(n, 1, 1, 1, kernels.shape[1]).expand((n, batch_size, n_pixels[0], n_pixels[1], kernels.shape[1]))
        # compose random locations, then crop & paste onto the base image from
        # the lowest priority upward; higher-priority patches may overwrite
        # lower-priority ones
        for i in range(n - 1, 0, -1):
            cur_crop_h = int(ratios_h[i - 1] * min(n_pixels))
            cur_crop_w = int(ratios_w[i - 1] * min(n_pixels))
            _pos_h = [random.randint(0, n_pixels[0] - cur_crop_h) for _ in range(batch_size)]
            _pos_w = [random.randint(0, n_pixels[1] - cur_crop_w) for _ in range(batch_size)]
            crop_mask = torch.zeros_like(base)
            kmap_mask = torch.zeros_like(kernel_map) if kernels is not None else None
            for b in range(batch_size):
                crop_mask[b, :, _pos_h[b]: _pos_h[b] + cur_crop_h, _pos_w[b]: _pos_w[b] + cur_crop_w] = 1
                if kmap_mask is not None:
                    kmap_mask[b, _pos_h[b]: _pos_h[b] + cur_crop_h, _pos_w[b]: _pos_w[b] + cur_crop_w, :] = 1
            crop = crop_mask * lr_imgs[i]
            base_mask = torch.abs(crop_mask - 1)
            base = base_mask * base
            base = base + crop
            if kernels is not None:
                kmap_crop = kmap_mask * kernels[i]
                kmap_mask = torch.abs(kmap_mask - 1)
                kernel_map = kernel_map * kmap_mask
                kernel_map = kernel_map + kmap_crop
        if kernels is not None:
            return base, kernel_map.permute(0, 3, 1, 2)
        else:
            return base


def pixel_mix(lr_imgs, kernels):
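    """Per-pixel mix: for every pixel, pick one of the n LR candidates at
    random, and build the matching per-pixel kernel map."""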
    with torch.no_grad():
        fake_lr = lr_imgs
        n, b, c, h, w = fake_lr.shape
        kernels = kernels.view(n, 1, 1, 1, kernels.shape[1]).expand((n, b, h, w, kernels.shape[1]))  # n, b, h, w, c
        fake_lr = fake_lr.permute((1, 3, 4, 0, 2))  # b, h, w, n, c
        kernels = kernels.permute((1, 2, 3, 0, 4))
        # index every (batch, row, col) position once, with a random source k
        _b = np.repeat(range(0, b), h * w)
        i = np.tile(np.repeat(range(0, h), w), b)
        j = np.tile(np.tile(range(0, w), h), b)
        k = np.random.randint(n, size=b * h * w)
        fake_lr = fake_lr[_b, i, j, k, :].view(b, h, w, c).permute(0, 3, 1, 2).detach().clone()
        kernels = kernels[_b, i, j, k, :].view(b, h, w, c).permute(0, 3, 1, 2).detach().clone()
        return fake_lr, kernels


def mask_mix(lr_imgs, kernels, mask):
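    """Compose a single LR image and kernel map from per-region masks: each
    mask selects where lr_imgs[i] (and its kernel) contributes."""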
    mask = mask.permute(1, 0, 2, 3)
    # mask = mask.permute(1, 0, 3, 2)
    n_mask = mask.shape[0]
    final_lr = torch.zeros([1, 3, mask.shape[2], mask.shape[3]]).cuda()
    final_kernel_map = torch.zeros([1, 3, mask.shape[2], mask.shape[3]]).cuda()
    with torch.no_grad():
        for i in range(n_mask):
            cur_mask = mask[i]
            cur_lr = lr_imgs[i][0]
            cur_kernel = kernels[i]
            cur_kernel = cur_kernel.view(-1, 3, 1, 1).expand_as(final_kernel_map)[0]
            masked_lr = cur_mask * cur_lr
            masked_kmap = cur_mask * cur_kernel
            masked_lr = masked_lr.unsqueeze(0)
            masked_kmap = masked_kmap.unsqueeze(0)
            final_lr = final_lr + masked_lr
            final_kernel_map = final_kernel_map + masked_kmap
    return final_lr, final_kernel_map


# per-channel clamp bounds for interpolated kernel codes; the three channels
# appear to match the (norm, w, v) code used in init_Kmap below
_max_bound = torch.FloatTensor([[[[50]], [[10]], [[1]]]])
_min_bound = torch.FloatTensor([[[[2.5]], [[0.1]], [[1e-4]]]])


def make_kmap(k_code, size=(64, 64)):
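    """Interpolate an (n*n, 3) grid of kernel codes into a dense per-pixel
    kernel map of the given size, clamped to the code bounds above."""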
    # interpolate to make the kernel map; size is in (H, W)
    n = int(np.sqrt(len(k_code)))
    k_code = k_code.permute(1, 0)
    k_code = k_code.view(1, 3, n, n)
    k_map = interpolate(k_code, size, mode='bicubic', align_corners=True)
    k_map = torch.max(torch.min(k_map, _max_bound), _min_bound)
    return k_code, k_map[0]


def downsample_via_kcode(hr, kmap, scale=4):
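    """Spatially-variant downsampling driven by a kernel map: for each output
    pixel, decode the local code into a 49x49 kernel and take the weighted
    sum over the corresponding HR patch."""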
    ksize = (49, 49)
    kmap = kmap.permute(2, 1, 0)
    # kmap = kmap.permute(1, 2, 0)
    cols = []
    for h in range(0, hr.shape[1] - ksize[0] + 1, scale):
        row = []
        for w in range(0, hr.shape[2] - ksize[1] + 1, scale):
            kernel = reconstruct_from_cov(decode_to_cov(kmap[h // scale, w // scale]), mean=(24, 24), size=ksize)
            patch = hr[:, h: h + ksize[0], w: w + ksize[1]]
            pixel = patch * kernel
            pixel = pixel.sum(dim=(1, 2))
            row.append(pixel)
        row = torch.stack(row)
        cols.append(row)
    cols = torch.stack(cols)
    cols = torch.clamp(cols.permute(2, 0, 1), 0, 1)
    return cols


def visualize_kmap(kmap, s=1, out_dir='./', tag=None, ksize=49, each=False):
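    """Render the kernel at every s-th kernel-map position to PNG, either as
    one tiled image or (with each=True) as one file per position."""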
    import cv2
    out_dir = os.path.join(out_dir, 'visualized_kmap')
    os.makedirs(out_dir, exist_ok=True)
    if each:
        out_dir = os.path.join(out_dir, tag)
        os.makedirs(out_dir, exist_ok=True)
    ksize = (ksize, ksize)
    mean = (ksize[0] // 2, ksize[1] // 2)
    cols = []
    for i in range(kmap.shape[0]):
        # parameter s is the stride of the kernel visualization;
        # if s = 1, kernels are visualized for every pixel
        if i % s != 0:
            continue
        row = []
        for j in range(kmap.shape[1]):
            if j % s != 0:
                continue
            kernel = reconstruct_from_cov(decode_to_cov(kmap[i, j]), mean=mean, size=ksize).astype(np.float32)
            if kernel.max() > 0:
                kernel /= kernel.max()
            kernel *= 255
            if each:
                cv2.imwrite(os.path.join(out_dir, '{}_{}.png'.format(i, j)), kernel)
            else:
                row.append(kernel)
        if not each:
            row = cv2.hconcat(row)
            cols.append(row)
    if not each:
        cols = cv2.vconcat(cols)
        if tag is None:
            cv2.imwrite(os.path.join(out_dir, '{}.png'.format(s)), cols)
        else:
            cv2.imwrite(os.path.join(out_dir, '{}_{}.png'.format(tag, s)), cols)


def init_Kmap(img, random=False):
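    """Build an initial (b, 3, h, w) kernel-code map for `img`: either a fixed
    default (norm, w, v) code, or one random code shared across all pixels."""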
    _size = img[:, 0, :, :].shape
    if random:
        rand_n = torch.clamp(torch.empty(1).normal_(25, 8), 2.5, 47.5)[0]
        rand_w = torch.clamp(torch.empty(1).normal_(0, 0.5), -1, 1)[0]
        rand_w = 10 ** rand_w
        rand_v = torch.clamp(torch.rand(1), 1e-3, 1)[0]
        norm = torch.zeros(_size).fill_(rand_n)
        w = torch.zeros(_size).fill_(rand_w)
        v = torch.zeros(_size).fill_(rand_v)
    else:
        norm = torch.zeros(_size).fill_(45.0)
        w = torch.zeros(_size).fill_(1.0)
        v = torch.zeros(_size).fill_(0.5)
    init_kmap = torch.stack([norm, w, v], dim=1)
    return init_kmap