Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for custom observations (both observation function and observation space) #133

Merged
merged 10 commits into from
Feb 24, 2023
2 changes: 1 addition & 1 deletion sumo_rl/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from sumo_rl.environment.env import SumoEnvironment, TrafficSignal
from sumo_rl.environment.env import SumoEnvironment, TrafficSignal, ObservationFunction
from sumo_rl.environment.env import env, parallel_env
from sumo_rl.environment.resco_envs import grid4x4, arterial4x4, ingolstadt1, ingolstadt7, ingolstadt21, cologne1, cologne3, cologne8
7 changes: 4 additions & 3 deletions sumo_rl/environment/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from pettingzoo.utils.conversions import parallel_wrapper_fn

from .traffic_signal import TrafficSignal
from .observations import ObservationFunction, DefaultObservationFunction

LIBSUMO = 'LIBSUMO_AS_TRACI' in os.environ

Expand Down Expand Up @@ -50,7 +51,7 @@ class SumoEnvironment(gym.Env):
:param max_green: (int) Max green time in a phase
:single_agent: (bool) If true, it behaves like a regular gym.Env. Else, it behaves like a MultiagentEnv (https://github.com/ray-project/ray/blob/master/python/ray/rllib/env/multi_agent_env.py)
:reward_fn: (str/function/dict) String with the name of the reward function used by the agents, a reward function, or dictionary with reward functions assigned to individual traffic lights by their keys
:observation_fn: (str/function) String with the name of the observation function or a callable observation function itself
:observation_class: (ObservationFunction) Inherited class which has both the observation function and observation space
:add_system_info: (bool) If true, it computes system metrics (total queue, total waiting time, average speed) in the info dictionary
:add_per_agent_info: (bool) If true, it computes per-agent (per-traffic signal) metrics (average accumulated waiting time, average queue) in the info dictionary
:sumo_seed: (int/string) Random seed for sumo. If 'random' it uses a randomly chosen seed.
Expand Down Expand Up @@ -83,7 +84,7 @@ def __init__(
max_green: int = 50,
single_agent: bool = False,
reward_fn: Union[str,Callable,dict] = 'diff-waiting-time',
observation_fn: Union[str,Callable] = 'default',
observation_class: ObservationFunction = DefaultObservationFunction,
add_system_info: bool = True,
add_per_agent_info: bool = True,
sumo_seed: Union[str,int] = 'random',
Expand Down Expand Up @@ -135,7 +136,7 @@ def __init__(
traci.start([sumolib.checkBinary('sumo'), '-n', self._net], label='init_connection'+self.label)
conn = traci.getConnection('init_connection'+self.label)
self.ts_ids = list(conn.trafficlight.getIDList())
self.observation_fn = observation_fn
self.observation_class = observation_class

if isinstance(self.reward_fn, dict):
self.traffic_signals = {ts: TrafficSignal(self,
Expand Down
45 changes: 45 additions & 0 deletions sumo_rl/environment/observations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from abc import ABC, abstractmethod

import numpy as np
from gymnasium import spaces

from .traffic_signal import TrafficSignal

class ObservationFunction(ABC):
    """Abstract base class bundling an observation function with its observation space.

    Subclasses implement both ``__call__`` (the observation computation) and
    ``observation_space`` (the matching gymnasium space) so that custom
    observations stay self-consistent.

    Inheriting from ``ABC`` makes ``@abstractmethod`` actually enforced:
    instantiating a subclass that forgets to override either method fails
    immediately instead of silently returning ``None``.
    """

    def __init__(self, ts: TrafficSignal):
        """Store the traffic signal whose state this function observes.

        :param ts: the TrafficSignal instance to read state from.
        """
        self.ts = ts

    @abstractmethod
    def __call__(self):
        """Return the current observation for ``self.ts``.

        Subclasses must override this method.
        """
        raise NotImplementedError

    @abstractmethod
    def observation_space(self):
        """Return the gymnasium space describing ``__call__``'s output.

        Subclasses must override this method.
        """
        raise NotImplementedError


class DefaultObservationFunction(ObservationFunction):
    """Default observation: green-phase one-hot, min-green flag, lane densities and queues."""

    def __init__(self, ts: TrafficSignal):
        """Initialize the default observation function for the given traffic signal."""
        super().__init__(ts)

    def __call__(self):
        """Build the observation vector from the signal's current state."""
        ts = self.ts
        # One-hot encoding of the currently active green phase.
        phase_one_hot = [int(ts.green_phase == i) for i in range(ts.num_green_phases)]
        # Binary flag: 1 once min_green + yellow_time seconds have elapsed since the last change.
        min_green_flag = [int(ts.time_since_last_phase_change >= ts.min_green + ts.yellow_time)]
        lane_density = ts.get_lanes_density()
        lane_queue = ts.get_lanes_queue()
        return np.array(phase_one_hot + min_green_flag + lane_density + lane_queue, dtype=np.float32)

    def observation_space(self):
        """Return a Box in [0, 1] whose size matches the vector built by ``__call__``."""
        dim = self.ts.num_green_phases + 1 + 2 * len(self.ts.lanes)
        return spaces.Box(
            low=np.zeros(dim, dtype=np.float32),
            high=np.ones(dim, dtype=np.float32)
        )

23 changes: 3 additions & 20 deletions sumo_rl/environment/traffic_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,7 @@ def __init__(self,
else:
raise NotImplementedError(f'Reward function {self.reward_fn} not implemented')

if isinstance(self.env.observation_fn, Callable):
self.observation_fn = self.env.observation_fn
else:
if self.env.observation_fn in TrafficSignal.observation_fns.keys():
self.observation_fn = TrafficSignal.observation_fns[self.env.observation_fn]
else:
raise NotImplementedError(f'Observation function {self.env.observation_fn} not implemented')
self.observation_fn = self.env.observation_class(self)

self.build_phases()

Expand All @@ -75,7 +69,7 @@ def __init__(self,
self.out_lanes = list(set(self.out_lanes))
self.lanes_lenght = {lane: self.sumo.lane.getLength(lane) for lane in self.lanes + self.out_lanes}

self.observation_space = spaces.Box(low=np.zeros(self.num_green_phases+1+2*len(self.lanes), dtype=np.float32), high=np.ones(self.num_green_phases+1+2*len(self.lanes), dtype=np.float32))
self.observation_space = self.observation_fn.observation_space()
self.discrete_observation_space = spaces.Tuple((
spaces.Discrete(self.num_green_phases), # Green Phase
spaces.Discrete(2), # Binary variable active if min_green seconds already elapsed
Expand Down Expand Up @@ -148,7 +142,7 @@ def set_next_phase(self, new_phase):
self.time_since_last_phase_change = 0

def compute_observation(self):
return self.observation_fn(self)
return self.observation_fn()

def compute_reward(self):
self.last_reward = self.reward_fn(self)
Expand Down Expand Up @@ -233,20 +227,9 @@ def register_reward_fn(cls, fn):

cls.reward_fns[fn.__name__] = fn

@classmethod
def register_observation_fn(cls, fn):
if fn.__name__ in cls.observation_fns.keys():
raise KeyError(f'Observation function {fn.__name__} already exists')

cls.observation_fns[fn.__name__] = fn

reward_fns = {
'diff-waiting-time': _diff_waiting_time_reward,
'average-speed': _average_speed_reward,
'queue': _queue_reward,
'pressure': _pressure_reward
}

observation_fns = {
'default': _observation_fn_default
}