-
Notifications
You must be signed in to change notification settings - Fork 20
/
visualize.py
136 lines (116 loc) · 4.83 KB
/
visualize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import os
import numpy as np
from PIL import Image
# Names exported by ``from visualize import *``.
__all__ = [
    'get_color_pallete',
    'set_img_color',
    'show_prediction',
    'show_colorful_images',
    'save_colorful_images',
]
def set_img_color(img, label, colors, background=0, show255=False):
    """Paint per-class colors onto ``img`` in place wherever ``label`` matches.

    Parameters
    ----------
    img : numpy.ndarray
        Image to draw on; modified in place. Its leading dimensions are
        presumably the same as ``label``'s (e.g. ``(H, W, 3)`` image with an
        ``(H, W)`` label map) -- TODO confirm against callers.
    label : numpy.ndarray
        Per-pixel class indices.
    colors : sequence
        ``colors[i]`` is the value written to every pixel of class ``i``.
    background : int, default: 0
        Class index that is skipped, leaving those pixels unchanged.
    show255 : bool, default: False
        If True, pixels labeled 255 are set to 255 in every channel.

    Returns
    -------
    numpy.ndarray
        The same ``img`` object, for call chaining.
    """
    for cls, color in enumerate(colors):
        if cls == background:
            continue
        # A boolean mask selects the same pixels as np.where(label == cls).
        img[label == cls] = color
    if show255:
        img[label == 255] = 255
    return img
def show_prediction(img, pred, colors, background=0, show255=False):
    """Return a uint8 copy of ``img`` with the class map ``pred`` painted on it.

    Parameters
    ----------
    img : array_like
        Base image; converted to a new uint8 array, so the caller's object
        is never modified.
    pred : numpy.ndarray
        Per-pixel class indices, passed to ``set_img_color`` as the label map.
    colors : sequence
        ``colors[i]`` is the color used for class ``i``.
    background : int, default: 0
        Class index left unpainted.
    show255 : bool, default: False
        Forwarded to ``set_img_color``; when True, pixels labeled 255 are
        painted with 255.

    Returns
    -------
    numpy.ndarray
        The painted uint8 image.
    """
    # np.array(..., np.uint8) already allocates a fresh array, so it can be
    # painted in place and returned without a second copy.
    canvas = np.array(img, np.uint8)
    set_img_color(canvas, pred, colors, background, show255)
    return canvas
def show_colorful_images(prediction, palettes):
    """Colorize a class-index map with ``palettes`` and display it.

    ``prediction`` is cast to uint8 and squeezed, then used to index
    ``palettes``. ``palettes`` is presumably an ``(N, 3)`` uint8 lookup
    array -- TODO confirm against callers. The result is opened with
    ``PIL.Image.show``.
    """
    class_ids = prediction.astype('uint8').squeeze()
    colored = palettes[class_ids]
    Image.fromarray(colored).show()
def save_colorful_images(prediction, filename, output_dir, palettes):
    '''
    Colorize a class-index map with ``palettes`` and save it under ``output_dir``.

    :param prediction: array of class indices. It is cast to uint8 and
        squeezed before the palette lookup, so it is presumably ``(H, W)``
        or ``(1, H, W)`` (the previous docstring said ``[B, H, W, C]``;
        with B > 1 the lookup would not form a single image -- confirm).
    :param filename: path of the output file relative to ``output_dir``;
        it may contain sub-directories, which are created as needed.
    :param output_dir: base directory for the output file.
    :param palettes: lookup table indexed by class id (presumably an
        ``(N, 3)`` uint8 array -- TODO confirm against callers).
    '''
    im = Image.fromarray(palettes[prediction.astype('uint8').squeeze()])
    fn = os.path.join(output_dir, filename)
    out_dir = os.path.dirname(fn)
    # An empty out_dir means "current directory"; os.mkdir('') would raise.
    # makedirs also creates missing parents and tolerates a directory that
    # already exists or is created concurrently.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    im.save(fn)
def get_color_pallete(npimg, dataset='pascal_voc'):
    """Convert a class-index map into a palettized PIL image.

    Parameters
    ----------
    npimg : numpy.ndarray
        Single-channel array of per-pixel class indices (documented as shape
        `H, W, 1`). It is copied first; the caller's array is never modified.
    dataset : str, default: 'pascal_voc'
        Selects the colormap: 'ade20k' and 'citys' use their own palettes;
        any other value (including 'pascal_voc' and 'pascal_aug') uses the
        PASCAL VOC palette.

    Returns
    -------
    out_img : PIL.Image
        Image with the dataset's color palette attached.
    """
    # Work on a copy so the boundary fix-up below cannot mutate the input.
    npimg = np.array(npimg, copy=True)
    # Recover boundary pixels: map -1 back to 255.
    if dataset in ('pascal_voc', 'pascal_aug'):
        npimg[npimg == -1] = 255
    if dataset == 'ade20k':
        # Shift labels by one so class k uses palette entry k + 1; a -1 label
        # lands on entry 0, which is black in adepallete.
        npimg = npimg + 1
        palette = adepallete
    elif dataset == 'citys':
        palette = cityspallete
    else:
        palette = vocpallete
    out_img = Image.fromarray(npimg.astype('uint8'))
    out_img.putpalette(palette)
    return out_img
def _getvocpallete(num_cls):
    """Build the flat PASCAL VOC colormap for labels ``0 .. num_cls - 1``.

    Each label's bits are dealt out three at a time: bit 0 to red, bit 1 to
    green and bit 2 to blue. Each group is placed one bit lower than the
    previous one, starting at bit 7. The result is a flat list
    ``[r0, g0, b0, r1, g1, b1, ...]`` of length ``3 * num_cls``.
    """
    flat = []
    for label in range(num_cls):
        rgb = [0, 0, 0]
        code = label
        shift = 7
        while code > 0:
            for channel in range(3):
                rgb[channel] |= ((code >> channel) & 1) << shift
            shift -= 1
            code >>= 3
        flat.extend(rgb)
    return flat
# PASCAL VOC colormap with 256 entries (768 flat RGB values); this is the
# fallback palette used by get_color_pallete.
vocpallete = _getvocpallete(num_cls=256)
# ADE20K colormap as a flat list of RGB triplets; entry 0 is black (0, 0, 0).
# get_color_pallete adds 1 to ADE20K labels before applying this palette,
# so class k is drawn with entry k + 1.
adepallete = [
0, 0, 0, 120, 120, 120, 180, 120, 120, 6, 230, 230, 80, 50, 50, 4, 200, 3, 120, 120, 80, 140, 140, 140, 204,
5, 255, 230, 230, 230, 4, 250, 7, 224, 5, 255, 235, 255, 7, 150, 5, 61, 120, 120, 70, 8, 255, 51, 255, 6, 82,
143, 255, 140, 204, 255, 4, 255, 51, 7, 204, 70, 3, 0, 102, 200, 61, 230, 250, 255, 6, 51, 11, 102, 255, 255,
7, 71, 255, 9, 224, 9, 7, 230, 220, 220, 220, 255, 9, 92, 112, 9, 255, 8, 255, 214, 7, 255, 224, 255, 184, 6,
10, 255, 71, 255, 41, 10, 7, 255, 255, 224, 255, 8, 102, 8, 255, 255, 61, 6, 255, 194, 7, 255, 122, 8, 0, 255,
20, 255, 8, 41, 255, 5, 153, 6, 51, 255, 235, 12, 255, 160, 150, 20, 0, 163, 255, 140, 140, 140, 250, 10, 15,
20, 255, 0, 31, 255, 0, 255, 31, 0, 255, 224, 0, 153, 255, 0, 0, 0, 255, 255, 71, 0, 0, 235, 255, 0, 173, 255,
31, 0, 255, 11, 200, 200, 255, 82, 0, 0, 255, 245, 0, 61, 255, 0, 255, 112, 0, 255, 133, 255, 0, 0, 255, 163,
0, 255, 102, 0, 194, 255, 0, 0, 143, 255, 51, 255, 0, 0, 82, 255, 0, 255, 41, 0, 255, 173, 10, 0, 255, 173, 255,
0, 0, 255, 153, 255, 92, 0, 255, 0, 255, 255, 0, 245, 255, 0, 102, 255, 173, 0, 255, 0, 20, 255, 184, 184, 0,
31, 255, 0, 255, 61, 0, 71, 255, 255, 0, 204, 0, 255, 194, 0, 255, 82, 0, 10, 255, 0, 112, 255, 51, 0, 255, 0,
194, 255, 0, 122, 255, 0, 255, 163, 255, 153, 0, 0, 255, 10, 255, 112, 0, 143, 255, 0, 82, 0, 255, 163, 255,
0, 255, 235, 0, 8, 184, 170, 133, 0, 255, 0, 255, 92, 184, 0, 255, 255, 0, 31, 0, 184, 255, 0, 214, 255, 255,
0, 112, 92, 255, 0, 0, 224, 255, 112, 224, 255, 70, 184, 160, 163, 0, 255, 153, 0, 255, 71, 255, 0, 255, 0,
163, 255, 204, 0, 255, 0, 143, 0, 255, 235, 133, 255, 0, 255, 0, 235, 245, 0, 255, 255, 0, 122, 255, 245, 0,
10, 190, 212, 214, 255, 0, 0, 204, 255, 20, 0, 255, 255, 255, 0, 0, 153, 255, 0, 41, 255, 0, 255, 204, 41, 0,
255, 41, 255, 0, 173, 0, 255, 0, 245, 255, 71, 0, 255, 122, 0, 255, 0, 255, 184, 0, 92, 255, 184, 255, 0, 0,
133, 255, 255, 214, 0, 25, 194, 194, 102, 255, 0, 92, 0, 255]
# Cityscapes colormap: one RGB triplet per line for label ids 0-18
# (19 entries). get_color_pallete applies it directly, with no label shift,
# when dataset == 'citys'.
# NOTE(review): there is no entry for an ignore label such as 255; how PIL
# colors indices past the end of a short palette is not established here.
cityspallete = [
128, 64, 128,
244, 35, 232,
70, 70, 70,
102, 102, 156,
190, 153, 153,
153, 153, 153,
250, 170, 30,
220, 220, 0,
107, 142, 35,
152, 251, 152,
0, 130, 180,
220, 20, 60,
255, 0, 0,
0, 0, 142,
0, 0, 70,
0, 60, 100,
0, 80, 100,
0, 0, 230,
119, 11, 32,
]