## Create the Training Dataset for the Model

In [1]:
import os
# Move two directories up from the notebook's working directory (presumably the
# repository root — confirm) so the `trips` / `environment` packages import and the
# relative paths used below (network_details/, train/, outputs/) resolve.
os.chdir('../..')
# Star import: presumably supplies src_nodes, dst_nodes, turn_prob, emergency_probs,
# shape and total_timesteps, which the dataset cell uses without defining them.
from trips.config import *
from trips.weibull_trips import main
import subprocess

In [2]:
# Traffic scenarios used to build the training set. Each entry is
# (source-node probabilities passed to main(), number of cars passed to main()).
# Presumably the probability list has one weight per entry of src_nodes — confirm.
_scene_table = {
    "High Traffic Scenerio": ([0.25, 0.25, 0.25, 0.25], 1000),
    "Low Traffic Scenerio": ([0.25, 0.25, 0.25, 0.25], 150),
    "NS-Traffic Scenerio": ([0.45, 0.05, 0.45, 0.05], 500),
    "EW-Traffic Scenerio": ([0.05, 0.45, 0.05, 0.45], 500),
}

# Split the table into the two lookups consumed by the generation loop
# (insertion order of the scenarios is preserved).
scene_src_probabilities = {
    name: probs for name, (probs, _cars) in _scene_table.items()
}
scene_number_cars = {
    name: cars for name, (_probs, cars) in _scene_table.items()
}
del _scene_table


# Number of route files sampled per scenario, and where they are written.
ROUTES_PER_SCENE = 75
ROUTE_DIR = "train/A2C/"

# One entry per route file, naming the scenario it is sampled from. Route index i
# belongs to the (i // ROUTES_PER_SCENE)-th scenario in scene_src_probabilities order.
scene_schedule = [
    scene for scene in scene_src_probabilities for _ in range(ROUTES_PER_SCENE)
]

# Derived from the schedule rather than hard-coding 300, so adding/removing a
# scenario or changing ROUTES_PER_SCENE keeps this list in sync with the files.
route_files = [
    f"{ROUTE_DIR}intersection_{i}.rou.xml" for i in range(len(scene_schedule))
]

# Rebuild every route file that is missing. Checking the files themselves (not just
# the directory) recovers from a previous run that was interrupted midway.
missing_routes = [
    (scene, route_file)
    for scene, route_file in zip(scene_schedule, route_files)
    if not os.path.exists(route_file)
]

if missing_routes:
    os.makedirs(ROUTE_DIR, exist_ok=True)

    for scene, route_file in missing_routes:
        # Sample trips for this scenario. main() presumably writes trips.trips.xml,
        # which duarouter reads below — confirm in trips.weibull_trips.
        main(
            src_nodes,
            dst_nodes,
            scene_src_probabilities[scene],
            turn_prob,
            emergency_probs,
            shape,
            scene_number_cars[scene],
            total_timesteps,
        )
        # Convert the trips into routes. check=True aborts on a duarouter failure
        # instead of silently leaving a missing route file for training to trip over.
        subprocess.run(
            [
                "duarouter",
                "-n",
                "./network_details/intersection.net.xml",  # Input network file
                "-t",
                "trips.trips.xml",  # Input trips file
                "-o",
                route_file,  # Output routes file
            ],
            check=True,
        )

    os.remove("trips.trips.xml")
        

## Create the environment

In [3]:
from environment.environment import MultiRouteSumoEnvironment
def make_env():
    """Build a new MultiRouteSumoEnvironment on the intersection network.

    Serves as the zero-argument factory given to DummyVecEnv, which calls it once
    per parallel environment. Every instance receives the full module-level
    ``route_files`` list and writes its CSV output under outputs/train/A2C/.
    """
    settings = {
        "net_file": "network_details/intersection.net.xml",
        "route_files": route_files,
        "out_csv_name": 'outputs/train/A2C/A2C',
        "min_green": 5,
        "yellow_time": 5,
        "delta_time": 10,
        "use_gui": False,
        "num_seconds": 5400,
    }
    return MultiRouteSumoEnvironment(**settings)

## Train the policy

In [None]:
from stable_baselines3 import A2C
from stable_baselines3.common.logger import configure
from stable_baselines3.common.vec_env import DummyVecEnv, VecMonitor

num_envs = 8
env = DummyVecEnv([make_env for _ in range(num_envs)])
env = VecMonitor(env)  # records per-episode reward/length for the logger

out_path = "outputs/train/A2C/"
new_logger = configure(out_path, ["stdout", "csv"])
model = A2C("MlpPolicy", env, verbose=1)

# Agent steps per episode: presumably num_seconds / delta_time = 5400 / 10 from
# make_env — keep in sync if those settings change.
steps_per_episode = 540
# SB3 counts total_timesteps summed over all envs in the VecEnv, so this budget is
# 450 episodes' worth of steps in total (spread across the num_envs environments).
num_episodes = 450
# Each A2C iteration advances every env by model.n_steps steps, so this logs once
# per full episode per env (540 // 5 = 108 with SB3's default n_steps).
log_interval = steps_per_episode // model.n_steps

# Set logger and start training
model.set_logger(new_logger)
try:
    model.learn(
        total_timesteps=steps_per_episode * num_episodes,
        log_interval=log_interval,
    )
    model.save('a2c_multi_scene')
finally:
    # Close every sub-environment (presumably their SUMO/TraCI connections) even if
    # training raises, so simulator processes are not left running.
    env.close()