Skip to content

Commit

Permalink
Decoupled observation retrieval
Browse files Browse the repository at this point in the history
  • Loading branch information
jerryz123 committed Jul 7, 2018
1 parent f945fb5 commit 31640e4
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 20 deletions.
3 changes: 2 additions & 1 deletion examples/fluids_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,5 @@
# actions = {k:fluids.VelocityAction(1) for k in controlled_keys}
# actions = {k:fluids.SteeringAction(0, 1) for k in controlled_keys}
# actions = {k:fluids.KeyboardAction() for k in controlled_keys}
obs, rew = simulator.step(actions)
rew = simulator.step(actions)
obs = simulator.get_observations(controlled_keys)
3 changes: 2 additions & 1 deletion fluids/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,5 @@
background_control =fluids.BACKGROUND_CSP)
while True:
actions = {k: fluids.KeyboardAction() for k in simulator.get_control_keys()}
obs, rew = simulator.step(actions)
rew = simulator.step(actions)
obs = simulator.get_observations(simulator.get_control_keys())
4 changes: 0 additions & 4 deletions fluids/assets/crosswalk.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,3 @@ def __init__(self, **kwargs):
self.start_waypoints[0].nxt = [self.end_waypoints[0]]
self.start_waypoints[1].nxt = [self.end_waypoints[1]]


def render(self, surface, **kwargs):
if self.vis_level > 3:
super(CrossWalk, self).render(surface, **kwargs)
23 changes: 17 additions & 6 deletions fluids/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,10 @@ def render(self):
self.clock.tick(self.fps)
self.surface.blit(pygame.transform.scale(self.state.get_static_surface(),
self.screen_dim), (0, 0))

if self.vis_level > 2:
self.surface.blit(pygame.transform.scale(self.state.get_static_debug_surface(),
self.screen_dim), (0, 0))

self.surface.blit(pygame.transform.scale(self.state.get_dynamic_surface(),
self.screen_dim), (0, 0))

Expand All @@ -105,6 +108,14 @@ def render(self):
pygame.display.flip()
pygame.event.pump()
self.last_keys_pressed = pygame.key.get_pressed()
if self.last_keys_pressed[pygame.K_PERIOD]:
self.vis_level += 1
self.state.update_vis_level(self.vis_level)
fluids_print("New visualization level: " + str(self.vis_level))
elif self.last_keys_pressed[pygame.K_COMMA] and self.vis_level > 1:
self.vis_level -= 1
self.state.update_vis_level(self.vis_level)
fluids_print("New visualization level: " + str(self.vis_level))
else:
self.clock.tick(0)
if not self.state.time % 60:
Expand Down Expand Up @@ -157,15 +168,15 @@ def step(self, actions={}):

self.state.time += 1



observations = {k:c.make_observation(self.obs_space, **self.obs_args)
for k, c in iteritems(self.state.controlled_cars)}
reward_step = self.reward_fn(self.state)
#print(reward_step)
if self.render_on:
self.render()
return observations, reward_step
return reward_step
def get_observations(self, keys={}):
observations = {k:self.state.objects[k].make_observation(self.obs_space, **self.obs_args)
for k in keys}
return observations

def get_background_actions(self):
if self.background_control == BACKGROUND_NULL or len(self.state.background_cars) == 0:
Expand Down
28 changes: 20 additions & 8 deletions fluids/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,25 +175,31 @@ def __init__(self,

fluids_print("State creation complete")
if vis_level:
self.static_surface = pygame.Surface(self.dimensions)
self.static_surface = pygame.Surface(self.dimensions)
self.static_debug_surface = pygame.Surface(self.dimensions, pygame.SRCALPHA)
for k, obj in iteritems(self.static_objects):
obj.render(self.static_surface)
if vis_level > 2:
for waypoint in self.waypoints:
waypoint.render(self.static_surface)
for waypoint in self.ped_waypoints:
waypoint.render(self.static_surface, color=(255, 255, 0))
if type(obj) != CrossWalk:
obj.render(self.static_surface)
else:
obj.render(self.static_debug_surface)
for waypoint in self.waypoints:
waypoint.render(self.static_debug_surface)
for waypoint in self.ped_waypoints:
waypoint.render(self.static_debug_surface, color=(255, 255, 0))


def get_static_surface(self):
return self.static_surface

def get_static_debug_surface(self):
return self.static_debug_surface

def get_dynamic_surface(self):
dynamic_surface = pygame.Surface(self.dimensions, pygame.SRCALPHA)
for typ in [Pedestrian, TrafficLight, CrossWalkLight]:
for k, obj in iteritems(self.type_map[typ]):
obj.render(dynamic_surface)
for k, car in iteritems(self.background_cars):
for k, car in iteritems(self.background_cars):
car.render(dynamic_surface)
for k, car in iteritems(self.controlled_cars):
car.render(dynamic_surface)
Expand Down Expand Up @@ -226,5 +232,11 @@ def min_distance_to_collision(self, obj):
mind = d
return mind

def update_vis_level(self, new_vis_level):
self.vis_level = new_vis_level
for k, obj in iteritems(self.objects):
obj.vis_level = new_vis_level


def get_controlled_collisions(self):
return

0 comments on commit 31640e4

Please sign in to comment.