-
Notifications
You must be signed in to change notification settings - Fork 2
/
utils.py
executable file
·274 lines (218 loc) · 10.2 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib import colors
import gym
from gym import spaces
from gym.utils import seeding
def simulate(env, init_state = None, f_name='value_iter', max_iter=1000):
env.reset()
if init_state != None:
env.init_state = init_state
state = env.init_state
indx = 0
while state not in env.goal_states:
env.step(env.policy[tuple(state)])
fig = env.render()
state = env.state
if indx > max_iter:
print("Max Iteration Limit Hit!")
return
indx += 1
return env._get_video(interval=200, gif_path=f_name+'.mp4').to_html5_video()
def draw_policy(env, p, dynamic=False):
    """Render a numeric policy array as a grid of human-readable glyphs.

    Each action index maps to an arrow (plus 'W' = wait when *dynamic*);
    special maze cells then override the arrow: '\u00B7' for cell values
    7/8, a filled circle for goal states, 'X' for value 4 and 'U' for
    value 5 (later overrides win).

    Args:
        env: environment providing ``maze`` (2-D array) and ``goal_states``.
        p: array of action indices, one per maze cell.
        dynamic: whether action 4 ('wait') is a legal policy entry.

    Returns:
        numpy array of single-character strings, same shape as *p*.
    """
    glyphs = {
        0: '\u2191',  # up
        1: '\u2193',  # down
        2: '\u2190',  # left
        3: '\u2192',  # right
    }
    if dynamic:
        glyphs[4] = 'W'  # Wait or Do Nothing
    grid = np.empty(p.shape, dtype=str)
    for idx, act in np.ndenumerate(p):
        # Always resolve the arrow first, mirroring the original lookup
        # (raises KeyError on an action outside the mapping).
        symbol = glyphs[int(act)]
        cell = env.maze[idx]
        if cell == 5:
            symbol = 'U'
        elif cell == 4:
            symbol = 'X'
        elif list(idx) in env.goal_states:
            symbol = '\u2B24'
        elif cell in (7, 8):
            symbol = '\u00B7'
        grid[idx] = symbol
    return grid
class MazeEnv(gym.Env):
    """Configurable environment for maze.

    Cell-value legend (order matches ``self.cmap`` below): 0 free space,
    1 wall, 2 agent, 3 goal, 10 trajectory trace; 7 is also used as the
    out-of-bounds padding value in partial observations.
    NOTE(review): the exact semantics of values 4-9 are defined by the
    maze generator -- confirm against it.
    """
    metadata = {'render.modes': ['human', 'rgb_array']}

    def __init__(self,
                 maze_generator,
                 pob_size=1,
                 action_type='VonNeumann',
                 obs_type='full',
                 live_display=False,
                 render_trace=False,
                 dynamic=False):
        """Initialize the maze. DType: list

        Args:
            maze_generator: provider of ``get_maze()`` and
                ``init_end_states()``.
            pob_size: radius of the partial-observation window (window side
                is ``2 * pob_size + 1``).
            action_type: 'VonNeumann' (4 moves) or 'Moore' (8 moves).
            obs_type: 'full' (whole maze) or 'partial' (window around agent).
            live_display: if True, refresh the figure on every render call
                instead of buffering frames for a video.
            render_trace: if True, paint visited cells in rendered frames.
            dynamic: if True, obstacles may toggle each step and action 4
                ('wait') becomes legal in ``_next_state``.

        Raises:
            TypeError: on an unrecognized ``action_type`` or ``obs_type``.
        """
        # Maze: 0: free space, 1: wall
        self.maze_generator = maze_generator
        self.maze = np.array(self.maze_generator.get_maze())
        self.maze_size = self.maze.shape
        self.init_state, self.goal_states = self.maze_generator.init_end_states()
        # Every (row, col) index of the grid, walls included.
        self.valid_states = [state for state, _ in np.ndenumerate(self.maze)]
        self.render_trace = render_trace
        self.traces = []
        self.action_type = action_type
        self.obs_type = obs_type
        # If True, show the updated display each time render is called rather
        # than storing the frames and creating an animation at the end
        self.live_display = live_display
        self.state = None
        # Action space: 0: Up, 1: Down, 2: Left, 3: Right
        if self.action_type == 'VonNeumann':  # Von Neumann neighborhood
            self.num_actions = 4
        elif action_type == 'Moore':  # Moore neighborhood
            self.num_actions = 8
        else:
            raise TypeError('Action type must be either \'VonNeumann\' or \'Moore\'')
        self.action_space = spaces.Discrete(self.num_actions)
        self.all_actions = list(range(self.action_space.n))
        # Size of the partial observable window
        self.pob_size = pob_size
        # Observation space
        low_obs = 0  # Lowest integer in observation
        high_obs = 10  # Highest integer in observation
        if self.obs_type == 'full':
            self.observation_space = spaces.Box(low=low_obs, high=high_obs,
                                                shape=self.maze_size)
        elif self.obs_type == 'partial':
            self.observation_space = spaces.Box(low=low_obs, high=high_obs,
                                                shape=(self.pob_size*2+1, self.pob_size*2+1))
        else:
            raise TypeError('Observation type must be either \'full\' or \'partial\'')
        # Colormap: order of color is, free space, wall, agent, food, poison
        self.cmap = colors.ListedColormap(['white', 'black', 'blue', 'green', 'red', 'gray',
                                           'lightblue','steelblue','powderblue', 'lightgray','yellow'])
        self.bounds = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]  # values for each color
        self.norm = colors.BoundaryNorm(self.bounds, self.cmap.N)
        self.ax_imgs = []  # For generating videos
        self.policy = None  # Set a policy for this environment
        self.dynamic = dynamic  # Activate dynamic obstacles

    def _step(self, action):
        """Apply *action* and return the new observation.

        NOTE(review): unlike the standard gym ``step`` contract this returns
        only the observation -- no reward, done flag or info dict; confirm
        callers expect that.
        """
        old_state = self.state  # NOTE(review): assigned but never used below
        if self.dynamic:
            # With ~10% probability flip the pair of dynamic obstacle cells.
            # The second assignment reads the already-updated (2, 1) value,
            # which keeps (1, 1) and (2, 1) in complementary states (5 vs 6).
            if np.random.random() > 0.9:
                self.maze[(2,1)] = 5 if self.maze[(2,1)] == 6 else 6
                self.maze[(1,1)] = 6 if self.maze[(2,1)] == 5 else 5
        # Update current state
        self.state = self._next_state(self.state, action)
        # Footprint: Record agent trajectory
        self.traces.append(self.state)
        return self._get_obs()

    def _reset(self):
        """Restore the maze and agent to their initial configuration and
        return the first observation."""
        # Reset maze
        self.maze = np.array(self.maze_generator.get_maze())
        # Set current state be initial state
        self.state = self.init_state
        # Clean the list of ax_imgs, the buffer for generating videos
        self.ax_imgs = []
        # Clean the traces of the trajectory
        self.traces = [self.init_state]
        return self._get_obs()

    def _render(self, mode='human', close=False):
        """Draw the full and partial observations side by side.

        In live-display mode the cached figure is refreshed in place;
        otherwise each frame's AxesImages are buffered in ``self.ax_imgs``
        for later assembly by ``_get_video``.

        Returns:
            The matplotlib figure (or ``None`` when ``close`` is True).
        """
        if close:
            plt.close()
            return
        obs = self._get_full_obs()
        partial_obs = self._get_partial_obs(self.pob_size)
        # For rendering traces: Only for visualization, does not affect the observation data
        if self.render_trace:
            # Paint all previously visited cells (value 10 -> 'yellow').
            obs[list(zip(*self.traces[:-1]))] = 10
        # Create Figure for rendering
        if not hasattr(self, 'fig'):  # initialize figure and plotting axes
            self.fig, (self.ax_full, self.ax_partial) = plt.subplots(nrows=1, ncols=2)
            self.ax_full.axis('off')
            self.ax_partial.axis('off')
            self.fig.show()
        if self.live_display:
            # Only create the image the first time
            if not hasattr(self, 'ax_full_img'):
                self.ax_full_img = self.ax_full.imshow(obs, cmap=self.cmap, norm=self.norm, animated=True)
            if not hasattr(self, 'ax_partial_img'):
                self.ax_partial_img = self.ax_partial.imshow(partial_obs, cmap=self.cmap, norm=self.norm, animated=True)
            # Update the image data for efficient live video
            self.ax_full_img.set_data(obs)
            self.ax_partial_img.set_data(partial_obs)
        else:
            # Create a new image each time to allow an animation to be created
            self.ax_full_img = self.ax_full.imshow(obs, cmap=self.cmap, norm=self.norm, animated=True)
            self.ax_partial_img = self.ax_partial.imshow(partial_obs, cmap=self.cmap, norm=self.norm, animated=True)
        plt.draw()
        if self.live_display:
            # Update the figure display immediately
            self.fig.canvas.draw()
        else:
            # Put in AxesImage buffer for video generation
            self.ax_imgs.append([self.ax_full_img, self.ax_partial_img])  # List of axes to update figure frame
        self.fig.set_dpi(100)
        return self.fig

    def _goal_test(self, state):
        """Return True if current state is a goal state."""
        # goal_states may store states as lists or tuples; coerce *state*
        # to the matching type before the membership test.
        if type(self.goal_states[0]) == list:
            return list(state) in self.goal_states
        elif type(self.goal_states[0]) == tuple:
            return tuple(state) in self.goal_states

    def _next_state(self, state, action):
        """Return the next state from a given state by taking a given action.

        NOTE(review): no wall/bounds check is performed here -- the returned
        state may lie outside the maze or inside a wall; confirm callers
        validate it.
        """
        # Transition table to define movement for each action
        if self.action_type == 'VonNeumann':
            # Action 4 is 'wait' and is only legal in dynamic environments.
            if self.dynamic and action == 4:
                return state
            transitions = {0: [-1, 0], 1: [+1, 0], 2: [0, -1], 3: [0, +1]}
        elif self.action_type == 'Moore':
            transitions = {0: [-1, 0], 1: [+1, 0], 2: [0, -1], 3: [0, +1],
                           4: [-1, +1], 5: [+1, +1], 6: [-1, -1], 7: [+1, -1]}
        new_state = [state[0] + transitions[action][0], state[1] + transitions[action][1]]
        return new_state

    def _get_obs(self):
        """Return the observation in the format selected by ``obs_type``."""
        if self.obs_type == 'full':
            return self._get_full_obs()
        elif self.obs_type == 'partial':
            return self._get_partial_obs(self.pob_size)

    def _get_full_obs(self):
        """Return a 2D array representation of maze."""
        obs = np.array(self.maze)
        # Set goal positions
        for goal in self.goal_states:
            obs[goal[0]][goal[1]] = 3  # 3: goal
        # Set current position
        # Come after painting goal positions, avoid invisible within multi-goal regions
        obs[self.state[0]][self.state[1]] = 2  # 2: agent
        return obs

    def _get_partial_obs(self, size=1):
        """Get partial observable window according to Moore neighborhood"""
        # Get maze with indicated location of current position and goal positions
        maze = self._get_full_obs()
        pos = np.array(self.state)
        # Slack between the window edge and the maze border on each side;
        # a negative value means the window sticks out and padding is needed.
        # NOTE(review): len(maze) is the row count only -- for non-square
        # mazes the column bound uses the wrong dimension; confirm mazes
        # are square.
        under_offset = np.min(pos - size)
        over_offset = np.min(len(maze) - (pos + size + 1))
        offset = np.min([under_offset, over_offset])
        if offset < 0:  # Need padding
            # Pad with value 7 so out-of-maze cells render distinctly.
            maze = np.pad(maze, np.abs(offset), 'constant', constant_values=7)
            pos += np.abs(offset)
        return maze[pos[0]-size : pos[0]+size+1, pos[1]-size : pos[1]+size+1]

    def _get_video(self, interval=200, gif_path=None):
        """Assemble the buffered render frames into an animation.

        Args:
            interval: delay between frames in milliseconds.
            gif_path: if given, the animation is also saved to this path via
                the ffmpeg writer (despite the parameter name, any
                ffmpeg-supported container such as .mp4 works).

        Returns:
            matplotlib.animation.ArtistAnimation built from ``self.ax_imgs``.
        """
        if self.live_display:
            # TODO: Find a way to create animations without slowing down the live display
            print('Warning: Generating an Animation when live_display=True not yet supported.')
        anim = animation.ArtistAnimation(self.fig, self.ax_imgs, interval=interval)
        if gif_path is not None:
            anim.save(gif_path, writer='ffmpeg', fps=10)
        return anim