Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
LucasAlegre committed May 7, 2024
2 parents a45f2fc + c8ec95d commit ae9c293
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 8 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
<img src="docs/_static/logo.png" align="right" width="30%"/>

[![DOI](https://zenodo.org/badge/161216111.svg)](https://zenodo.org/doi/10.5281/zenodo.10869789)
[![tests](https://github.com/LucasAlegre/sumo-rl/actions/workflows/linux-test.yml/badge.svg)](https://github.com/LucasAlegre/sumo-rl/actions/workflows/linux-test.yml)
[![PyPI version](https://badge.fury.io/py/sumo-rl.svg)](https://badge.fury.io/py/sumo-rl)
[![pre-commit](https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit&logoColor=white)](https://pre-commit.com/)
Expand Down
2 changes: 1 addition & 1 deletion sumo_rl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@
)


__version__ = "1.4.3"
__version__ = "1.4.4"
32 changes: 25 additions & 7 deletions sumo_rl/environment/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def step(self, action: Union[dict, int]):
If single_agent is True, action is an int, otherwise it expects a dict with keys corresponding to traffic signal ids.
"""
# No action, follow fixed TL defined in self.phases
if action is None or action == {}:
if self.fixed_ts or action is None or action == {}:
for _ in range(self.delta_time):
self._sumo_step()
else:
Expand Down Expand Up @@ -364,15 +364,27 @@ def _compute_info(self):

def _compute_observations(self):
self.observations.update(
{ts: self.traffic_signals[ts].compute_observation() for ts in self.ts_ids if self.traffic_signals[ts].time_to_act}
{
ts: self.traffic_signals[ts].compute_observation()
for ts in self.ts_ids
if self.traffic_signals[ts].time_to_act or self.fixed_ts
}
)
return {ts: self.observations[ts].copy() for ts in self.observations.keys() if self.traffic_signals[ts].time_to_act}
return {
ts: self.observations[ts].copy()
for ts in self.observations.keys()
if self.traffic_signals[ts].time_to_act or self.fixed_ts
}

def _compute_rewards(self):
self.rewards.update(
{ts: self.traffic_signals[ts].compute_reward() for ts in self.ts_ids if self.traffic_signals[ts].time_to_act}
{
ts: self.traffic_signals[ts].compute_reward()
for ts in self.ts_ids
if self.traffic_signals[ts].time_to_act or self.fixed_ts
}
)
return {ts: self.rewards[ts] for ts in self.rewards.keys() if self.traffic_signals[ts].time_to_act}
return {ts: self.rewards[ts] for ts in self.rewards.keys() if self.traffic_signals[ts].time_to_act or self.fixed_ts}

@property
def observation_space(self):
Expand Down Expand Up @@ -580,10 +592,16 @@ def step(self, action):
"It is currently {}".format(agent, self.action_spaces[agent].n, action)
)

self.env._apply_actions({agent: action})
if not self.env.fixed_ts:
self.env._apply_actions({agent: action})

if self._agent_selector.is_last():
self.env._run_steps()
if not self.env.fixed_ts:
self.env._run_steps()
else:
for _ in range(self.env.delta_time):
self.env._sumo_step()

self.env._compute_observations()
self.rewards = self.env._compute_rewards()
self.compute_info()
Expand Down

0 comments on commit ae9c293

Please sign in to comment.