# In [1]:
# Shut down any Ray session already running in this kernel (presumably one left
# over from an earlier run of this notebook) before the training cell below
# calls ray.init() again.
import ray
ray.shutdown()

# In [None]:
# --- Credit: https://github.com/LucasAlegre/sumo-rl ---

import os
import sys
# * --- SETUP SUMO VARIABLES ---
# Make SUMO's bundled Python tools ($SUMO_HOME/tools) importable before the
# dependency imports below; abort with a clear message if SUMO_HOME is unset.
if "SUMO_HOME" in os.environ:
    tools = os.path.join(os.environ["SUMO_HOME"], "tools")
    # Avoid piling up duplicate sys.path entries when this cell is re-run.
    if tools not in sys.path:
        sys.path.append(tools)
else:
    sys.exit("Please declare the environment variable 'SUMO_HOME'")

# * --- IMPORT DEPENDENCIES ---
from ray import tune
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
from ray.tune.registry import register_env
import numpy as np
import pandas as pd
import ray
import sumo_rl


def _absolute_path(path):
    """Return *path* with ``~`` expanded, made absolute against the current directory."""
    return os.path.abspath(os.path.expanduser(path))


def main():
    """Register the SUMO network as an RLlib environment and train PPO on it with Ray Tune.

    Raises:
        FileNotFoundError: if the SUMO network or route file does not exist.
    """
    env_name = "miami-grid"

    # * ADJUST NETWORK AND ROUTE FILES
    # * --------------------------------
    # Paths are resolved to absolute form here, in the driver: Ray Tune may run
    # a trial with a different working directory, where relative paths would
    # no longer point at these files.
    net_file = _absolute_path("path/to/your/network/osm.net.xml")
    route_file = _absolute_path("path/to/your/routes/osm.rou.xml")
    # sumo-rl CSV output prefix, kept under a directory named after this env.
    out_csv_name = _absolute_path(os.path.join("outputs", env_name, "ppo"))
    # Tune results root: $RAY_RESULTS_DIR if set, otherwise ./ray_results.
    storage_path = os.path.join(
        _absolute_path(os.environ.get("RAY_RESULTS_DIR", "ray_results")),
        env_name,
    )

    # Fail fast here rather than with an obscure error inside every rollout worker.
    for kind, path in (("network", net_file), ("route", route_file)):
        if not os.path.isfile(path):
            raise FileNotFoundError(f"SUMO {kind} file not found: {path}")

    def make_env(_env_config):
        """Build one sumo-rl parallel env wrapped for RLlib (the env config is unused)."""
        return ParallelPettingZooEnv(
            sumo_rl.parallel_env(
                net_file=net_file,
                route_file=route_file,
                out_csv_name=out_csv_name,
                use_gui=False,
                num_seconds=80000,
            )
        )

    ray.init(ignore_reinit_error=True)
    try:
        # * --- REGISTER SUMO ENVIRONMENT ---
        register_env(env_name, make_env)

        # * --- PPO CONFIGURATION ---
        # train_batch_size (512) equals 4 rollout workers x 128-step fragments.
        # NOTE(review): no .multi_agent() policy mapping is configured, so all
        # traffic signals presumably share one default policy, which requires
        # identical observation/action spaces across signals — confirm for
        # this network.
        config = (
            PPOConfig()
            .environment(env=env_name, disable_env_checking=True)
            .rollouts(num_rollout_workers=4, rollout_fragment_length=128)
            .training(
                train_batch_size=512,
                lr=2e-5,
                gamma=0.95,
                lambda_=0.9,
                use_gae=True,
                clip_param=0.4,
                grad_clip=None,
                entropy_coeff=0.1,
                vf_loss_coeff=0.25,
                sgd_minibatch_size=64,
                num_sgd_iter=10,
            )
            .debugging(log_level="ERROR")
            .framework(framework="torch")
            .resources(num_gpus=0)
        )

        # * --- TRAINING ---
        tune.run(
            "PPO",
            name="PPO",
            stop={"timesteps_total": 100000},
            checkpoint_freq=10,
            storage_path=storage_path,
            config=config.to_dict(),
        )
    finally:
        # Release the Ray session and its workers even if training fails, so
        # re-running this cell starts from a clean state.
        ray.shutdown()


if __name__ == "__main__":
    main()