-
Notifications
You must be signed in to change notification settings - Fork 0
/
renderer.py
127 lines (102 loc) · 3.5 KB
/
renderer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import imageio
import os
import numpy as np
from image_viewer import SimpleImageViewer
from heatmap import features_heatmap
_colorbits = {
0:0x00000000 ,
1:0x00D70000 ,
2:0x000000D7 ,
3:0x00D700D7 ,
4:0x0000D700 ,
5:0x00D7D700 ,
6:0x0000D7D7 ,
7:0x00D7D7D7 ,
8:0x00000000 ,
9:0x00FF0000 ,
10:0x000000FF ,
11:0x00FF00FF ,
12:0x0000FF00 ,
13:0x00FFFF00 ,
14:0x0000FFFF ,
15:0x00FFFFFF }
_palette = np.vectorize(_colorbits.get, otypes=[np.uint32])
####################################
class Renderer:
    """Base renderer interface; on its own it renders nothing.

    Subclasses override render() and/or reset(). get_renderer() returns a plain
    instance of this class for render_mode 0, i.e. as a "no rendering" option.
    """

    def render(self, state, frame):
        """Handle one step; this base implementation ignores both arguments."""
        return None

    def reset(self):
        """Handle the start of a new episode; this base implementation is a no-op."""
        return None
####################################
class CompositeRender(Renderer):
    """Forward every render()/reset() call to each child renderer, in order.

    Args (constructor):
        inner: iterable of renderers. It is stored as given, not copied, and
            iterated on every call, so it should be a re-iterable container
            such as a list.
    """

    def __init__(self, inner):
        self.inner = inner

    def render(self, state, frame):
        """Pass (state, frame) to every child's render()."""
        for child in self.inner:
            child.render(state, frame)

    def reset(self):
        """Call reset() on every child."""
        for child in self.inner:
            child.reset()
####################################
class FrameRender(Renderer):
    """Show raw frames of palette colour codes in an image window.

    Args (constructor):
        height, width: frame size in pixels. The palette-mapped frame, which
            must hold height * width codes, is reshaped to (height, width, 4).
            The defaults are the previously hard-coded 312 x 352.
    """

    def __init__(self, height=312, width=352):
        self.viewer = SimpleImageViewer()
        self.height = height
        self.width = width

    def render(self, state, frame):
        """Display *frame*; does nothing when there is no frame."""
        arr = self.prepare(state, frame)
        if arr is not None:
            self.viewer.imshow(arr)

    def prepare(self, state, frame):
        """Convert *frame* to a (height, width, 4) uint8 array, or return None.

        *state* is ignored. It is accepted so this method matches the
        prepare(state, frame) call that GifRender makes on its inner renderer.
        """
        if frame is None:
            return None
        # tobytes() dumps memory in native byte order, but the palette constants
        # only give the intended per-channel bytes when stored little-endian.
        # Force '<u4' (a no-op on little-endian hosts) before splitting.
        packed = _palette(frame).astype("<u4", copy=False)
        screen = np.frombuffer(packed.tobytes(), dtype=np.uint8)
        return np.reshape(screen, (self.height, self.width, 4))
####################################
class StateRender(Renderer):
    """Show the *state* array in an image window.

    NOTE(review): prepare() swaps axes 0 and 2, so an (a, b, c) array is shown
    as (c, b, a). That is only upright if *state* is laid out as
    (channels, width, height); a (channels, height, width) state would appear
    transposed. Confirm against whatever produces *state*.
    """

    def __init__(self):
        self.viewer = SimpleImageViewer()

    def render(self, state, frame):
        """Display *state*; skipped when prepare() produces nothing."""
        image = self.prepare(state, frame)
        if image is None:
            return
        self.viewer.imshow(image)

    def prepare(self, state, frame):
        """Return *state* as uint8 with axes 0 and 2 swapped, or None.

        *frame* is ignored. The values are cast with astype(np.uint8); no
        rescaling or clipping is applied.
        """
        if state is None:
            return None
        return np.asarray(state).swapaxes(0, 2).astype(np.uint8)
####################################
class GifRender(Renderer):
    """Record the images another renderer prepares into an animated GIF.

    Every reset() does three things:
    - finalises the GIF currently being written, if any;
    - archives it as ``<folder>/<n><basename of fname>``, using the first n >= 1
      whose file does not exist yet;
    - starts a new GIF at *fname*.

    Args (constructor):
        inner: object with a ``prepare(state, frame)`` method that returns an
            image array or None. Steps where it returns None are skipped.
        fname: path of the GIF currently being written.
        folder: directory that finished GIFs are moved into (created on demand).
    """

    def __init__(self, inner, fname="./frame.gif", folder="./gifs"):
        self.writer = None  # imageio writer for the current GIF; None until opened
        self.inner = inner
        self.fname = fname
        self.folder = folder

    def render(self, state, frame):
        """Append the inner renderer's image for this step, if it produced one."""
        arr = self.inner.prepare(state, frame)
        if arr is None:
            return
        if self.writer is None:
            # render() before the first reset() used to crash on
            # None.append_data; open the output lazily instead.
            self._open_writer()
        self.writer.append_data(arr)

    def reset(self):
        """Finalise and archive the current GIF, then start a fresh one."""
        self.close()
        self._open_writer()

    def close(self):
        """Finalise and archive the current GIF without starting a new one.

        Nothing else in this class closes the last GIF, so call this at
        shutdown. Safe to call more than once.
        """
        if self.writer is None:
            return
        self.writer.close()
        self.writer = None
        self.move_file(self.fname, self.folder)

    def _open_writer(self):
        # mode="I": imageio writer for a sequence of images.
        self.writer = imageio.get_writer(uri=self.fname, mode="I")

    def move_file(self, fname, folder):
        """Move *fname* into *folder* as ``<n><basename>``, first free n >= 1.

        Does nothing if *fname* is not an existing file.
        """
        if not os.path.isfile(fname):
            return
        # exist_ok avoids the check-then-create race of isdir() + makedirs().
        os.makedirs(folder, exist_ok=True)
        base = os.path.basename(fname)
        idx = 1
        while os.path.isfile(os.path.join(folder, str(idx) + base)):
            idx += 1
        os.rename(fname, os.path.join(folder, str(idx) + base))
####################################
class HeatmapRender(Renderer):
    """Show a heatmap of a model layer's features for the current *state*.

    Args (constructor):
        model: model passed through to features_heatmap.
        layer_name: name of the layer to visualise (passed through).
        stack_by: passed through to features_heatmap. NOTE(review): presumably
            controls how feature maps are tiled — confirm in heatmap.py.
    """

    def __init__(self, model, layer_name="conv2d_1", stack_by=8):
        self.model = model
        self.layer_name = layer_name
        self.stack_by = stack_by
        self.viewer = SimpleImageViewer()

    def render(self, state, frame):
        """Display the heatmap for *state*; does nothing when there is no state."""
        hm = self.prepare(state, frame)
        if hm is not None:
            self.viewer.imshow(hm)

    def prepare(self, state, frame):
        """Return the heatmap image for *state*, or None when *state* is None.

        *frame* is ignored. *state* is converted to a float64 copy and given a
        leading axis of size 1 (presumably a batch dimension) before being
        handed to features_heatmap.
        """
        if state is None:
            # Same contract as FrameRender/StateRender: None means "nothing to
            # draw", which render() and GifRender.render() both skip. Without
            # this, None was turned into a NaN array and fed to the model.
            return None
        data = np.expand_dims(np.asarray(state).astype(np.float64), axis=0)
        return features_heatmap(self.model, data, self.layer_name, self.stack_by)
def get_renderer(render_mode, model):
    """Build the renderer selected by the integer *render_mode*.

    0: no-op Renderer
    1: FrameRender window
    2: StateRender window
    3: FrameRender window + frame GIF (GifRender default fname) + "./state.gif"
    4: HeatmapRender window + "./heatmap.gif"

    *model* is only used by mode 4.

    Raises:
        ValueError: for any other *render_mode*.
    """
    if render_mode == 0:
        return Renderer()
    if render_mode == 1:
        return FrameRender()
    if render_mode == 2:
        return StateRender()
    if render_mode == 3:
        return CompositeRender([
            FrameRender(),
            GifRender(FrameRender()),
            GifRender(StateRender(), "./state.gif"),
        ])
    if render_mode == 4:
        # NOTE(review): two separate HeatmapRender instances each call
        # features_heatmap on every step (once for the window, once for the
        # GIF), so the heatmap is computed twice per step.
        return CompositeRender([
            HeatmapRender(model),
            GifRender(HeatmapRender(model), "./heatmap.gif"),
        ])
    # Previously fell through and returned None, which only failed later when
    # the caller invoked .render() on it.
    raise ValueError("unknown render_mode %r (expected 0-4)" % (render_mode,))