In [None]:
# file: envs/traffic_env.py

import gym
from gym import spaces
import numpy as np

from integration_pems_ems_sumo import SUTrafficEnv  # the class we wrote before


class GymTrafficEnv(gym.Env):
    """
    Gym-style wrapper around SUTrafficEnv.

    Observations: continuous float32 vector
        [time_of_day, hours_to_next_EV, avg_flow, avg_speed, avg_occ,
         phase_tls_0, ..., phase_tls_{num_tls-1}]
    Actions: discrete phase index per TLS (Discrete for a single TLS,
    MultiDiscrete for multiple).

    Observations returned by reset() and step() are clipped element-wise to
    observation_space, so they always lie inside the declared bounds.
    """

    metadata = {"render_modes": ["human"], "render_fps": 10}

    def __init__(
        self,
        sumo_cfg: str,
        ems_day,
        pems_day_rl,
        meta_rl,
        tls_ids,
        use_gui: bool = False,
        sim_duration_s: int = 3600,
        max_phases: int = 4,
    ):
        """
        Build the wrapped SUTrafficEnv and the Gym action/observation spaces.

        Args:
            sumo_cfg, ems_day, pems_day_rl, meta_rl, tls_ids, use_gui,
            sim_duration_s: forwarded unchanged to SUTrafficEnv.
            max_phases: number of selectable phases per TLS (default 4). It sets
                the size of each action dimension and the upper bound of the
                phase entries in the observation.

        Raises:
            ValueError: if max_phases < 1.
        """
        super().__init__()

        if max_phases < 1:
            raise ValueError(f"max_phases must be >= 1, got {max_phases}")

        self.tls_ids = tls_ids
        self.num_tls = len(tls_ids)

        # underlying SUMO env
        self._env = SUTrafficEnv(
            sumo_cfg=sumo_cfg,
            ems_day=ems_day,
            pems_day_rl=pems_day_rl,
            meta_rl=meta_rl,
            tls_ids=tls_ids,
            use_gui=use_gui,
            sim_duration_s=sim_duration_s,
        )

        # ---------- ACTION SPACE ----------
        # every TLS is assumed to expose at most `max_phases` selectable phases
        self.max_phases = max_phases
        if self.num_tls == 1:
            self.action_space = spaces.Discrete(self.max_phases)
        else:
            self.action_space = spaces.MultiDiscrete([self.max_phases] * self.num_tls)

        # ---------- OBSERVATION SPACE ----------
        # obs = [time_of_day, hours_to_next_EV, avg_flow, avg_speed, avg_occ, phases...]
        # Rough bounds (also enforced on every observation by _clip_obs):
        # time_of_day ∈ [0, 1]
        # hours_to_next_EV ∈ [0, 24] (clip)
        # avg_flow ∈ [0, 5000] (vehicles / 5 min)
        # avg_speed ∈ [0, 120] (mph)
        # avg_occ ∈ [0, 1]
        # phases ∈ [0, max_phases]
        low = np.array(
            [0.0, 0.0, 0.0, 0.0, 0.0] + [0.0] * self.num_tls,
            dtype=np.float32,
        )
        high = np.array(
            [1.0, 24.0, 5000.0, 120.0, 1.0] + [float(self.max_phases)] * self.num_tls,
            dtype=np.float32,
        )

        self.observation_space = spaces.Box(low=low, high=high, dtype=np.float32)

    def reset(self, *, seed=None, options=None):
        """
        Reset the wrapped env and return ``(observation, info)``.

        ``seed`` is passed to ``gym.Env.reset``. ``options`` is accepted for API
        compatibility but is not used. ``info`` is always an empty dict.
        """
        super().reset(seed=seed)
        obs = self._clip_obs(self._env.reset())
        return obs, {}

    def step(self, action):
        """
        Apply one action and return ``(obs, reward, terminated, truncated, info)``.

        For a single TLS the action is converted to a plain ``int``; for several
        TLS it is converted to a flat ``list`` of ints before being forwarded.
        ``terminated`` mirrors the wrapped env's ``done`` flag. ``truncated`` is
        always False, because no time-based truncation is implemented here.
        """
        if self.num_tls == 1:
            # .item() accepts Python ints, NumPy scalars and size-1 arrays alike
            # (int() on a 1-element ndarray is deprecated in recent NumPy).
            a = int(np.asarray(action).item())
        else:
            a = np.asarray(action, dtype=int).reshape(-1).tolist()

        obs, reward, done, info = self._env.step(a)
        obs = self._clip_obs(obs)

        # gym step signature: obs, reward, terminated, truncated, info
        terminated = bool(done)
        truncated = False  # time-based truncation could be implemented here

        return obs, float(reward), terminated, truncated, info

    def _clip_obs(self, obs):
        """
        Return ``obs`` as a new 1-D float32 array clipped to observation_space.

        The first ``min(len(obs), observation_space.shape[0])`` entries are
        clipped against the matching bounds. This includes the per-TLS phase
        entries, so an out-of-range phase index cannot leave the declared space.
        Any extra trailing entries pass through unchanged, and an empty vector
        is returned as-is.
        """
        obs = np.array(obs, dtype=np.float32).reshape(-1)
        low = self.observation_space.low
        high = self.observation_space.high
        n = min(obs.shape[0], low.shape[0])
        obs[:n] = np.clip(obs[:n], low[:n], high[:n])
        return obs

    def render(self):
        # if SUMO was started with the GUI it is already visible; nothing extra to draw
        pass

    def close(self):
        """Close the wrapped SUTrafficEnv."""
        self._env.close()


In [None]:
from envs.traffic_env import GymTrafficEnv

env = GymTrafficEnv(
    sumo_cfg="config.sumo.cfg",
    ems_day=ems_day_df,          # filtered EMS for that day with nearest_station, etc.
    pems_day_rl=pems_day_rl_df,  # transformed PeMS day
    meta_rl=meta_rl_df,          # station metadata
    tls_ids=["TL_1"],            # actual tls ID from SUMO
    use_gui=False,
    sim_duration_s=3600,
)

# Smoke test: roll out one episode with uniformly sampled random actions.
# try/finally ensures the underlying SUMO env is closed even if reset/step raises.
try:
    obs, _ = env.reset()
    done = False
    while not done:
        action = env.action_space.sample()
        obs, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
finally:
    env.close()


In [None]:
# file: baselines/controllers.py

import numpy as np
import traci


class FixedTimeController:
    """
    Baseline 1: fixed-time traffic signal.

    Ignores EMS data and observations. All TLS cycle in lockstep through
    phases 0, 1, ..., max_phases - 1, advancing once every
    ``phase_duration_steps`` calls to ``select_action``.
    Call ``reset()`` between episodes so each episode starts from phase 0.
    """

    def __init__(self, tls_ids, phase_duration_steps=20, max_phases=4):
        """
        Args:
            tls_ids: sequence of TLS IDs to control.
            phase_duration_steps: number of select_action calls each phase is held.
            max_phases: number of phases to cycle through per TLS.

        Raises:
            ValueError: if phase_duration_steps or max_phases is < 1 (either
                would otherwise cause a modulo-by-zero or a meaningless cycle).
        """
        if phase_duration_steps < 1:
            raise ValueError(
                f"phase_duration_steps must be >= 1, got {phase_duration_steps}"
            )
        if max_phases < 1:
            raise ValueError(f"max_phases must be >= 1, got {max_phases}")

        self.tls_ids = tls_ids
        self.phase_duration_steps = phase_duration_steps
        self.max_phases = max_phases
        self.reset()

    def reset(self):
        """Restart the cycle: every TLS back to phase 0 and the step counter to 0."""
        self.current_phase_idx = {tls_id: 0 for tls_id in self.tls_ids}
        self.step_counter = 0

    def select_action(self, obs=None):
        """
        Return an action compatible with GymTrafficEnv.

        Returns a scalar phase index when there is one TLS, otherwise a list
        with one phase index per TLS in ``tls_ids`` order. ``obs`` is ignored.
        """
        # advance every phase_duration_steps calls (never on the very first call)
        if self.step_counter > 0 and self.step_counter % self.phase_duration_steps == 0:
            for tls_id in self.tls_ids:
                self.current_phase_idx[tls_id] = (
                    self.current_phase_idx[tls_id] + 1
                ) % self.max_phases

        self.step_counter += 1

        if len(self.tls_ids) == 1:
            return self.current_phase_idx[self.tls_ids[0]]
        return [self.current_phase_idx[tls_id] for tls_id in self.tls_ids]


class GreedyEVPreemptionController:
    """
    Baseline 2: simple emergency-vehicle preemption.

    On each call to ``select_action``, every TLS that has an entry in
    ``tls_phase_map`` checks its mapped phases in the map's iteration order.
    The first phase with an emergency vehicle on one of its incoming edges is
    selected and held. Emergency vehicles are recognised by a vehicle ID that
    starts with ``ev_prefix``. When no mapped phase has an EV, the TLS cycles
    phases on a fixed-time schedule (same schedule as FixedTimeController).
    Call ``reset()`` between episodes.
    """

    def __init__(
        self,
        tls_ids,
        ev_prefix="EV_",
        phase_duration_steps=20,
        max_phases=4,
        tls_phase_map=None,
    ):
        """
        Args:
            tls_ids: sequence of TLS IDs to control.
            ev_prefix: vehicle-ID prefix that marks an emergency vehicle.
            phase_duration_steps: select_action calls each fixed-time phase is held.
            max_phases: number of phases in the fixed-time cycle.
            tls_phase_map: optional dict mapping
                tls_id -> {phase_index: [incoming_edge_ids_for_that_phase]}
                A TLS with no (or an empty) entry never preempts and behaves as a
                pure fixed-time controller. If the map is omitted entirely, TraCI
                is never queried.

        Raises:
            ValueError: if phase_duration_steps or max_phases is < 1.
        """
        if phase_duration_steps < 1:
            raise ValueError(
                f"phase_duration_steps must be >= 1, got {phase_duration_steps}"
            )
        if max_phases < 1:
            raise ValueError(f"max_phases must be >= 1, got {max_phases}")

        self.tls_ids = tls_ids
        self.ev_prefix = ev_prefix
        self.phase_duration_steps = phase_duration_steps
        self.max_phases = max_phases
        self.tls_phase_map = tls_phase_map or {}
        self.reset()

    def reset(self):
        """Restart the fixed-time cycle: every TLS back to phase 0, counter to 0."""
        self.current_phase_idx = {tls_id: 0 for tls_id in self.tls_ids}
        self.step_counter = 0

    def _ev_edges(self):
        """Return the set of edge IDs currently occupied by emergency vehicles."""
        return {
            traci.vehicle.getRoadID(vid)
            for vid in traci.vehicle.getIDList()
            if vid.startswith(self.ev_prefix)
        }

    def _ev_present_for_phase(self, tls_id, phase_idx, ev_edges=None):
        """
        Check whether an EV is on any incoming edge mapped to this phase.

        ``ev_edges`` may be a precomputed container of EV-occupied edge IDs
        (for example from ``_ev_edges()``), so one TraCI scan can be shared
        across many phases. When it is None, TraCI is queried here, and only if
        the phase has mapped edges.
        """
        phase_map = self.tls_phase_map.get(tls_id, {})
        incoming_edges = phase_map.get(phase_idx, [])

        if not incoming_edges:
            return False

        if ev_edges is None:
            ev_edges = self._ev_edges()
        return any(edge_id in ev_edges for edge_id in incoming_edges)

    def select_action(self, obs=None):
        """
        Return an action compatible with GymTrafficEnv.

        Returns a scalar phase index for one TLS, otherwise a list with one
        phase index per TLS in ``tls_ids`` order. ``obs`` is ignored.
        """
        # fixed-time advance happens on the same calls for every TLS
        advance = (
            self.step_counter > 0
            and self.step_counter % self.phase_duration_steps == 0
        )
        # EV-occupied edges, scanned lazily at most once per call
        ev_edges = None
        actions = []

        for tls_id in self.tls_ids:
            # 1) preempt to the first mapped phase that has an EV approaching
            preempt_phase = None
            phase_map = self.tls_phase_map.get(tls_id, {})

            if phase_map:
                if ev_edges is None:
                    ev_edges = self._ev_edges()
                for phase_idx in phase_map:
                    if self._ev_present_for_phase(tls_id, phase_idx, ev_edges):
                        preempt_phase = phase_idx
                        break

            if preempt_phase is not None:
                self.current_phase_idx[tls_id] = preempt_phase
            elif advance:
                # 2) no EV: fixed-time cycling
                self.current_phase_idx[tls_id] = (
                    self.current_phase_idx[tls_id] + 1
                ) % self.max_phases

            actions.append(self.current_phase_idx[tls_id])

        self.step_counter += 1

        if len(self.tls_ids) == 1:
            return actions[0]
        return actions


In [None]:
from envs.traffic_env import GymTrafficEnv
from baselines.controllers import FixedTimeController, GreedyEVPreemptionController

TLS_IDS = ["TL_1"]  # actual TLS ID(s) from the SUMO network

# build env (same as before)
env = GymTrafficEnv(
    sumo_cfg="config.sumo.cfg",
    ems_day=ems_day_df,
    pems_day_rl=pems_day_rl_df,
    meta_rl=meta_rl_df,
    tls_ids=TLS_IDS,
    use_gui=False,
    sim_duration_s=3600,
)

# Baseline 1: fixed-time.
# TLS list and phase count come from the env so the controller's actions
# always match env.action_space.
fixed_agent = FixedTimeController(
    tls_ids=env.tls_ids,
    phase_duration_steps=20,
    max_phases=env.max_phases,
)

total_reward_fixed = 0.0
# try/finally ensures the underlying SUMO env is closed even if reset/step raises.
try:
    obs, _ = env.reset()
    done = False
    while not done:
        action = fixed_agent.select_action(obs)
        obs, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
        total_reward_fixed += reward
finally:
    env.close()

print("Fixed-time total reward:", total_reward_fixed)
