In [None]:
# rl_training.py

import gym
from gym import spaces
import numpy as np

from integration_pems_ems_sumo import SUTrafficEnv


class SB3TrafficEnv(gym.Env):
    """
    Gym-compatible wrapper for Stable-Baselines3 (old Gym API).
    Wraps SUTrafficEnv.

    The wrapper exposes:
    - an action space with one discrete phase choice per traffic light
      (``Discrete`` for a single TLS, ``MultiDiscrete`` otherwise);
    - a flat float32 ``Box`` observation laid out as
      ``[time_of_day, hours_to_next_EV, avg_flow, avg_speed, avg_occ, phases...]``.

    ``reset`` returns only the observation and ``step`` returns the 4-tuple
    ``(obs, reward, done, info)``, i.e. the pre-0.26 Gym calling convention.
    """

    # Bounds of the five leading scalar features, in observation order.
    # Kept in one place so the declared Box space and the clipping applied to
    # incoming observations cannot drift apart.
    _FEATURE_LOW = np.zeros(5, dtype=np.float32)
    _FEATURE_HIGH = np.array([1.0, 24.0, 5000.0, 120.0, 1.0], dtype=np.float32)

    def __init__(
        self,
        sumo_cfg: str,
        ems_day,
        pems_day_rl,
        meta_rl,
        tls_ids,
        use_gui: bool = False,
        sim_duration_s: int = 3600,
        max_phases: int = 4,
    ):
        """
        Build the wrapped SUTrafficEnv and declare the Gym spaces.

        Args:
            sumo_cfg: Forwarded unchanged to SUTrafficEnv.
            ems_day: Forwarded unchanged to SUTrafficEnv.
            pems_day_rl: Forwarded unchanged to SUTrafficEnv.
            meta_rl: Forwarded unchanged to SUTrafficEnv.
            tls_ids: Sequence of traffic-light ids; its length fixes the size
                of the action space and of the phase part of the observation.
            use_gui: Forwarded unchanged to SUTrafficEnv.
            sim_duration_s: Forwarded unchanged to SUTrafficEnv.
            max_phases: Number of selectable phases per traffic light.
                Defaults to 4; raise it if a TLS has more phases.

        Raises:
            ValueError: If ``max_phases`` is smaller than 1.
        """
        super().__init__()

        if max_phases < 1:
            raise ValueError(f"max_phases must be >= 1, got {max_phases!r}")

        self._env = SUTrafficEnv(
            sumo_cfg=sumo_cfg,
            ems_day=ems_day,
            pems_day_rl=pems_day_rl,
            meta_rl=meta_rl,
            tls_ids=tls_ids,
            use_gui=use_gui,
            sim_duration_s=sim_duration_s,
        )
        self.tls_ids = tls_ids
        self.num_tls = len(tls_ids)
        self.max_phases = int(max_phases)

        # Action space: discrete phase selection per TLS
        if self.num_tls == 1:
            self.action_space = spaces.Discrete(self.max_phases)
        else:
            self.action_space = spaces.MultiDiscrete([self.max_phases] * self.num_tls)

        # Observation space: the five scalar features followed by one phase
        # entry per TLS, each bounded by [0, max_phases].
        low = np.concatenate(
            [self._FEATURE_LOW, np.zeros(self.num_tls, dtype=np.float32)]
        )
        high = np.concatenate(
            [
                self._FEATURE_HIGH,
                np.full(self.num_tls, float(self.max_phases), dtype=np.float32),
            ]
        )
        self.observation_space = spaces.Box(low=low, high=high, dtype=np.float32)

    def _clip_obs(self, obs):
        """
        Return ``obs`` as a new float32 array with its leading scalar
        features clipped to the declared bounds.

        Only the first five entries are clipped; the per-TLS phase entries
        that follow are passed through unchanged. Observations shorter than
        five entries are tolerated (only the entries present are clipped).
        """
        obs = np.array(obs, dtype=np.float32)
        n = min(len(obs), len(self._FEATURE_HIGH))
        if n:
            obs[:n] = np.clip(obs[:n], self._FEATURE_LOW[:n], self._FEATURE_HIGH[:n])
        return obs

    def reset(self):
        """Reset the wrapped environment and return the clipped float32 observation."""
        # _clip_obs already returns a fresh float32 array, so no extra cast.
        return self._clip_obs(self._env.reset())

    def step(self, action):
        """
        Apply ``action`` to the wrapped environment.

        Args:
            action: A single phase index when there is one TLS, otherwise a
                sequence with one phase index per TLS.

        Returns:
            Tuple ``(obs, reward, done, info)`` with ``obs`` clipped to
            float32, ``reward`` as ``float`` and ``done`` as ``bool``.
        """
        if self.num_tls == 1:
            # .item() accepts Python ints, NumPy scalars and size-1 arrays of
            # any shape; int() on an array with ndim > 0 is deprecated in
            # recent NumPy releases.
            a = int(np.asarray(action).item())
        else:
            a = np.array(action, dtype=int).tolist()

        obs, reward, done, info = self._env.step(a)
        return self._clip_obs(obs), float(reward), bool(done), info

    def close(self):
        """Close the wrapped environment."""
        self._env.close()


In [None]:
!pip install "gym==0.21.0" stable-baselines3


In [None]:
# rl_training.py (continued)

from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv

def make_env(sumo_cfg, ems_day, pems_day_rl, meta_rl, tls_ids):
    """
    Return a zero-argument factory that builds an SB3TrafficEnv.

    The environment itself is only constructed when the returned callable is
    invoked. It is always headless (``use_gui=False``) and runs 3600
    simulated seconds (1 simulated hour) per episode.
    """
    env_kwargs = {
        "sumo_cfg": sumo_cfg,
        "ems_day": ems_day,
        "pems_day_rl": pems_day_rl,
        "meta_rl": meta_rl,
        "tls_ids": tls_ids,
        "use_gui": False,
        "sim_duration_s": 3600,  # 1 simulated hour per episode
    }

    def _init():
        return SB3TrafficEnv(**env_kwargs)

    return _init


def train_rl_agent(
    sumo_cfg,
    ems_day,
    pems_day_rl,
    meta_rl,
    tls_ids,
    total_timesteps=200_000,
    model_path="models/ppo_traffic.zip",
):
    """
    Train PPO agent on one (or a set of) EMS+PeMS day(s).
    For multiple days, you could randomize which day/env each episode uses.

    Args:
        sumo_cfg, ems_day, pems_day_rl, meta_rl, tls_ids: Forwarded to
            ``make_env`` to build the training environment.
        total_timesteps: Passed to ``PPO.learn``.
        model_path: Where the trained model is saved. Its parent directory is
            created if it does not exist.

    Returns:
        ``model_path``, unchanged.
    """
    # Imported locally because the file does not import `os` at module level.
    import os

    # A bare filename has an empty dirname, and os.makedirs("") raises.
    model_dir = os.path.dirname(model_path)
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)

    env_fn = make_env(sumo_cfg, ems_day, pems_day_rl, meta_rl, tls_ids)
    vec_env = DummyVecEnv([env_fn])

    # Close the vectorised env even if training or saving fails, so the
    # underlying simulator is not left running.
    try:
        model = PPO(
            "MlpPolicy",
            vec_env,
            verbose=1,
            tensorboard_log="./tb_logs/",
            n_steps=1024,
            batch_size=256,
            gamma=0.99,
            learning_rate=3e-4,
        )

        model.learn(total_timesteps=total_timesteps)
        model.save(model_path)
    finally:
        vec_env.close()

    print("Model saved to:", model_path)
    return model_path


In [None]:
# rl_eval.py

import numpy as np
from stable_baselines3 import PPO

from rl_training import SB3TrafficEnv
from baselines.controllers import FixedTimeController, GreedyEVPreemptionController


def evaluate_policy(env, policy, episodes=5, is_sb3=False):
    """
    Evaluate a policy on an env:
    - policy: either 'sb3_model' or an object with .select_action(obs)
    - is_sb3: True if using Stable-Baselines3 model
    Returns list of episode rewards.

    Each episode runs from ``env.reset()`` until ``env.step`` reports
    ``done``, summing the per-step rewards. ``env.close()`` is called after
    every episode, including when the episode aborts with an exception.
    """
    ep_rewards = []

    for _ in range(episodes):
        # NOTE(review): close() runs after every episode, as before; this
        # assumes env.reset() can start a fresh run after close() - confirm
        # against the wrapped environment.
        try:
            obs = env.reset()
            done = False
            total_r = 0.0

            while not done:
                if is_sb3:
                    # SB3 models return (action, state); the state is unused.
                    action, _state = policy.predict(obs, deterministic=True)
                else:
                    action = policy.select_action(obs)

                obs, reward, done, _info = env.step(action)
                total_r += reward

            ep_rewards.append(total_r)
        finally:
            # Guarantees cleanup even if reset/step/policy raised.
            env.close()

    return ep_rewards


def run_evaluation(
    sumo_cfg,
    ems_day,
    pems_day_rl,
    meta_rl,
    tls_ids,
    model_path="models/ppo_traffic_day1.zip",
):
    """
    Compare two baseline controllers and a saved PPO model on one scenario.

    Runs, in order, a fixed-time controller, a greedy EV-preemption
    controller and the PPO model loaded from ``model_path``. Each one is
    evaluated for 3 episodes on its own freshly built SB3TrafficEnv, and the
    episode rewards and their mean are printed. Nothing is returned.
    """

    def _new_env():
        # Every evaluation gets its own headless, 1-simulated-hour env.
        return SB3TrafficEnv(
            sumo_cfg=sumo_cfg,
            ems_day=ems_day,
            pems_day_rl=pems_day_rl,
            meta_rl=meta_rl,
            tls_ids=tls_ids,
            use_gui=False,
            sim_duration_s=3600,
        )

    # Baseline 1: fixed-time
    fixed_env = _new_env()
    fixed_ctrl = FixedTimeController(
        tls_ids=tls_ids,
        phase_duration_steps=20,
        max_phases=4,
    )
    fixed_rewards = evaluate_policy(fixed_env, fixed_ctrl, episodes=3, is_sb3=False)
    print("Fixed-time episode rewards:", fixed_rewards, "mean:", np.mean(fixed_rewards))

    # Baseline 2: greedy preemption
    greedy_ctrl = GreedyEVPreemptionController(
        tls_ids=tls_ids,
        phase_duration_steps=20,
        max_phases=4,
        tls_phase_map=None,  # you can pass real mappings
    )
    greedy_env = _new_env()
    greedy_rewards = evaluate_policy(greedy_env, greedy_ctrl, episodes=3, is_sb3=False)
    print("Greedy preemption episode rewards:", greedy_rewards, "mean:", np.mean(greedy_rewards))

    # RL agent
    ppo_model = PPO.load(model_path)
    rl_env = _new_env()
    rl_rewards = evaluate_policy(rl_env, ppo_model, episodes=3, is_sb3=True)
    print("RL (PPO) episode rewards:", rl_rewards, "mean:", np.mean(rl_rewards))
