-
Notifications
You must be signed in to change notification settings - Fork 13
/
bilateral.py
36 lines (27 loc) · 892 Bytes
/
bilateral.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import sys
import numpy as np
import torch as th
import cv2
from permutohedral.gfilt import gfilt
def gaussian_filter(ref, val, kstd):
    """Gaussian-filter *val* in the feature space given by *ref*.

    Each feature channel (dim 1) of *ref* is divided by the matching entry of
    *kstd* before calling the permutohedral ``gfilt``, so *kstd* acts as a
    per-channel bandwidth. NOTE(review): this assumes ``gfilt`` applies a
    unit-width Gaussian in the scaled space — confirm against the extension.

    Args:
        ref: guide features, presumably shaped (N, C, H, W); the broadcast
            below places *kstd* along dim 1 of a 4-D tensor.
        val: values to filter (assumed to share N, H, W with *ref* — TODO
            confirm against ``gfilt``).
        kstd: presumably a 1-D tensor holding one scale per channel of *ref*.

    Returns:
        Whatever ``gfilt`` returns for the scaled guide and *val*.
    """
    per_channel = kstd[None, :, None, None]  # (C,) -> (1, C, 1, 1)
    scaled_ref = ref / per_channel
    return gfilt(scaled_ref, val)
def usage():
    """Print the command-line usage message to stderr and exit with status 1.

    Uses ``sys.exit`` rather than the ``exit`` builtin, which is injected by
    the ``site`` module and is absent under ``python -S`` or in frozen apps.

    Raises:
        SystemExit: always, with exit code 1.
    """
    print("Usage: python bilateral.py input output sxy srgb", file=sys.stderr)
    sys.exit(1)
# Command line: input path, output path, spatial std (sxy), colour std (srgb).
if len(sys.argv) != 5:
    usage()
try:
    sxy = float(sys.argv[3])
    srgb = float(sys.argv[4])
except ValueError:
    # Only a malformed number is a usage error; a bare ``except:`` would also
    # swallow KeyboardInterrupt/SystemExit and mask unrelated bugs.
    usage()
# Both stds are used as divisors of the guide features, so they must be
# strictly positive (this also rejects NaN, for which every comparison is
# False).
if not (sxy > 0 and srgb > 0):
    usage()
# Load the input image. cv2.imread reports failure (missing file,
# unsupported format) by returning None instead of raising, so check
# explicitly rather than crashing later with an AttributeError.
bgr = cv2.imread(sys.argv[1])
if bgr is None:
    sys.exit(f"error: could not read input image {sys.argv[1]!r}")
# Scale 8-bit intensities to [0, 1]; keep at most three channels.
img = bgr.astype(np.float32)[..., :3] / 255.
# HWC -> CHW. The transpose alone is only a strided view; make it dense
# before handing it to the extension (NOTE(review): presumably gfilt's
# kernel expects contiguous memory — confirm).
img = np.ascontiguousarray(img.transpose(2, 0, 1))
# Per-pixel guide features: row index, column index, then the three colour
# channels (OpenCV loads them in BGR order) -> shape (5, H, W).
yx = np.mgrid[:img.shape[1], :img.shape[2]].astype(np.float32)
stacked = np.vstack([yx, img])
img = th.from_numpy(img).cuda()
stacked = th.from_numpy(stacked).cuda()
# One std per guide channel: sxy for both coordinates, srgb for colour.
kstd = th.tensor([sxy, sxy, srgb, srgb, srgb], dtype=th.float32).cuda()
filtered = gaussian_filter(stacked[None], img[None], kstd)[0]
# Back to 8-bit HWC. Round instead of truncating, and clamp so that any
# overshoot outside [0, 1] saturates instead of being cast unpredictably
# by the float -> uint8 conversion.
filtered = (255 * filtered).round().clamp(0, 255).permute(1, 2, 0)
filtered = filtered.byte().contiguous().cpu().numpy()
# cv2.imwrite reports failure (e.g. unknown extension, unwritable path)
# through its boolean return value.
if not cv2.imwrite(sys.argv[2], filtered):
    sys.exit(f"error: could not write output image {sys.argv[2]!r}")