In [None]:
import os 
import pathlib

import traci
import sumo_rl
from pettingzoo.utils.conversions import parallel_wrapper_fn
from environment.reward_functions import combined_reward

from environment.observation import Grid2x2ObservationFunction

# Fall back to the Homebrew SUMO location only when SUMO_HOME is not already
# set, so an existing environment configuration is never overwritten by this
# machine-specific path.
# NOTE(review): this runs after `import traci` / `import sumo_rl` above; if
# those packages read SUMO_HOME at import time, it must be exported before the
# kernel starts - confirm.
os.environ.setdefault('SUMO_HOME', '/opt/homebrew/opt/sumo/share/sumo')

## Create the parallel environment API using SUMO-RL + RLlib

In [None]:
from sumo_rl.environment.env import env, parallel_env
from ray.tune import register_env
from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
from ray.rllib.env.wrappers.multi_agent_env_compatibility import MultiAgentEnvCompatibility

from environment.envs import multi_agent_env

env_folder = "data/2x2grid"


def make_sumo_parallel_env():
    """Build a new SUMO-RL parallel (PettingZoo-style) env for the 2x2 grid.

    Reads the network and route files from ``env_folder`` and writes CSV
    output under the ``outputs/2x2grid/ppo`` prefix.
    """
    return parallel_env(
        net_file=os.path.join(env_folder, "2x2.net.xml"),
        route_file=os.path.join(env_folder, "2x2.rou.xml"),
        reward_fn=combined_reward,
        observation_class=Grid2x2ObservationFunction,
        out_csv_name="outputs/2x2grid/ppo",
        num_seconds=1000,
        add_per_agent_info=True,
        add_system_info=True,
    )


def make_rllib_env(_env_config=None):
    """Env creator for RLlib: returns a freshly wrapped environment per call.

    The argument is whatever ``register_env`` passes to the creator; it is
    not used here.
    """
    # ParallelPettingZooEnv is RLlib's wrapper that adapts a PettingZoo
    # parallel env to the multi-agent API RLlib expects.
    return ParallelPettingZooEnv(make_sumo_parallel_env())


# Driver-side instances, kept under their original names for interactive
# inspection in the cells below.
# NOTE: this assignment shadows the `multi_agent_env` name imported from
# `environment.envs` above.
multi_agent_env = make_sumo_parallel_env()
parallel_petting_env = ParallelPettingZooEnv(multi_agent_env)

env_name = "Multi-agent-2x2"
# Register a factory instead of a lambda that returns the single instance
# built above, so every caller of the creator gets its own environment
# rather than sharing one object.
register_env(env_name, make_rllib_env)

In [None]:
# Quick sanity check of what the wrapped environment exposes.
# First the two space attributes, in the same order as before.
for space_attr in ("action_space", "observation_space"):
    print(getattr(parallel_petting_env, space_attr))

# Then the agent IDs and one sampled action.
agent_ids = parallel_petting_env.get_agent_ids()
print(agent_ids)
sampled_action = parallel_petting_env.action_space_sample()
print(sampled_action)

# `get_sub_environments` is read as an attribute here (it is not called).
print(parallel_petting_env.get_sub_environments)

# Leave the same attribute as the cell's last expression so the notebook
# also renders it as the cell output.
parallel_petting_env.get_sub_environments


## Create algorithm config

In [None]:

from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.logger import pretty_print

# PPO configuration for the registered multi-agent environment.
# Each builder call below returns the config object, so they can be chained;
# the individual settings are independent of one another.
ppo_config = (
    PPOConfig()
    # Train on the environment registered under `env_name`, with RLlib's
    # environment checks left enabled.
    .environment(env=env_name, disable_env_checking=False)
    # One rollout worker collects experience.
    .rollouts(num_rollout_workers=1)
    # CPU-only training.
    .resources(num_gpus=0)
    # Size of each training batch.
    .training(train_batch_size=4000)
)

In [None]:
# Build the algorithm object from the PPO configuration defined above.
algo = ppo_config.build()

In [None]:
# Display the weights of the policy returned by get_policy() with no arguments.
# NOTE(review): presumably this resolves to the default policy - confirm which
# policy ID is used in this multi-agent setup.
algo.get_policy().get_weights()

In [None]:
from ray import tune
# Launch the same PPO setup as a Ray Tune experiment.
# - "PPO" is the registered algorithm name; the previous value 'run' is not an
#   algorithm, and the previous `config=trainer` referenced an undefined name.
# - The config is passed as a plain dict built from `ppo_config`.
# - The stop criterion bounds the experiment to one training iteration so the
#   cell terminates; raise it for a real training run.
tune.run(
    run_or_experiment="PPO",
    name=env_name,
    config=ppo_config.to_dict(),
    stop={"training_iteration": 1},
)

In [None]:
# `result` is only assigned by the training cell further down, so on a fresh
# kernel it does not exist yet when this cell runs. Guard the lookup so a
# top-to-bottom run does not stop here with a NameError.
if "result" in globals():
    print(pretty_print(result))
else:
    print("No training result available yet - run the training cell first.")

In [None]:
# Run one training iteration and show its metrics.
result = algo.train()
# pretty_print() returns a formatted string. It must be printed explicitly:
# a bare expression is only displayed when it is the last statement of a
# cell, which is not the case here.
print(pretty_print(result))

# Save a checkpoint of the trained algorithm and report where it was written.
checkpoint_dir = algo.save().checkpoint.path
print(f"Checkpoint saved in directory {checkpoint_dir}")

In [None]:
# Run an evaluation pass with the current algorithm state and display the result.
# NOTE(review): `ppo_config` in this notebook sets no `.evaluation(...)` options -
# confirm that evaluate() works without them on the installed RLlib version.
algo.evaluate()

In [None]:
import ray
# Shut down the Ray runtime at the end of the notebook.
ray.shutdown()
