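## Create the Training Dataset

Generate 300 random trip sets and route them with SUMO's `duarouter` into route files under `train/emergency/`; these route files are the training scenarios for the agent.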

In [1]:
import os
import subprocess

# Work from the project root (two levels above the kernel's starting directory,
# normally this notebook's folder) so the relative paths used below
# (trips/, network_details/, train/, outputs/, agents/) resolve.
# PROJECT_ROOT is computed only once per kernel, so re-running this cell does not
# climb two more directories each time as a bare os.chdir('../..') would.
if "PROJECT_ROOT" not in globals():
    PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), "..", ".."))
os.chdir(PROJECT_ROOT)

# Explicit names instead of `import *`: these are exactly the config values the
# dataset-generation cell passes to `main`.
from trips.config import (
    dst_nodes,
    emergency_probs,
    shape,
    src_nodes,
    src_prob,
    total_timesteps,
    turn_prob,
)
from trips.weibull_trips import main

In [None]:
# Generate NUM_ROUTE_FILES random traffic scenarios as SUMO route files.
# Each iteration calls `main` (presumably it writes TRIPS_FILE, since that is what
# duarouter reads right afterwards; confirm in trips.weibull_trips) and then routes
# those trips over the intersection network with SUMO's `duarouter`.
# Only missing route files are generated. An interrupted run therefore resumes
# instead of being silently skipped, and a complete dataset is left untouched.
number_cars = 400          # passed to `main` for every scenario
NUM_ROUTE_FILES = 300      # number of route files (training scenarios) to produce
ROUTE_DIR = "train/emergency"
NET_FILE = "./network_details/intersection.net.xml"
TRIPS_FILE = "trips.trips.xml"

route_files = [f"{ROUTE_DIR}/intersection_{i}.rou.xml" for i in range(NUM_ROUTE_FILES)]
missing_route_files = [path for path in route_files if not os.path.exists(path)]

if missing_route_files:
    os.makedirs(ROUTE_DIR, exist_ok=True)
    for route_path in missing_route_files:
        main(
            src_nodes,
            dst_nodes,
            src_prob,
            turn_prob,
            emergency_probs,
            shape,
            number_cars,
            total_timesteps,
        )
        # check=True: a duarouter failure raises here instead of leaving a
        # missing route file that only surfaces later when the env loads it.
        subprocess.run(
            ["duarouter", "-n", NET_FILE, "-t", TRIPS_FILE, "-o", route_path],
            check=True,
        )
    # Intermediate trips file is no longer needed once all routes are written.
    os.remove(TRIPS_FILE)

## Create the environment

In [3]:
from environment.environment import EmergencySumoEnvironment
from environment.observation import EmergencyObservationFunction
from environment.reward import emergency_reward

In [4]:
# Build the training environment on the single-intersection network, fed with the
# route files generated above. Headless (use_gui=False) for training speed.
env = EmergencySumoEnvironment(
    net_file="network_details/intersection.net.xml",
    route_files=route_files,
    out_csv_name="outputs/train/Emergency/DQN",
    min_green=5,
    yellow_time=5,
    # NOTE(review): presumably simulated seconds per agent step. With
    # num_seconds=5400 that gives 540 steps per episode, the figure the training
    # cell uses for its timestep budget. Keep them in sync.
    delta_time=10,
    use_gui=False,
    num_seconds=5400,
    observation_class=EmergencyObservationFunction,
    reward_fn=emergency_reward
)

 Retrying in 1 seconds
Step #0.00 (0ms ?*RT. ?UPS, TraCI: 9ms, vehicles TOT 0 ACT 0 BUF 0)                      


## Train the policy

In [None]:
from stable_baselines3 import DQN
from stable_baselines3.common.logger import configure

# Training budget in agent steps. 540 presumably comes from num_seconds=5400 /
# delta_time=10 in the environment cell, and 300 from the number of generated
# route files. Update these if either of those changes.
STEPS_PER_EPISODE = 540
NUM_EPISODES = 300
TOTAL_TIMESTEPS = STEPS_PER_EPISODE * NUM_EPISODES
SEED = None  # set to an int for a reproducible run; None keeps the current behavior

# Log to stdout and CSV under the same directory the environment writes its CSVs to.
out_path = "outputs/train/Emergency/"
new_logger = configure(out_path, ["stdout", "csv"])

model = DQN(
    env=env,
    policy="MlpPolicy",
    # NOTE(review): 1e-4, train_freq=4, batch_size=32 and gradient_steps=1 appear to
    # match SB3's DQN defaults. They are not deliberately "conservative"; verify
    # against the installed SB3 version.
    learning_rate=0.0001,
    train_freq=4,
    learning_starts=2000,        # random-policy steps collected before learning begins
    # NOTE(review): with tau < 1 the target network is Polyak-averaged every
    # target_update_interval steps. tau=0.01 every 200 steps makes the target track
    # the online network very slowly. Confirm this is intended.
    target_update_interval=200,
    exploration_fraction=0.35,   # epsilon decays 1.0 -> 0.05 over the first 35% of training
    exploration_initial_eps=1.0,
    exploration_final_eps=0.05,
    tau=0.01,
    buffer_size=50000,
    batch_size=32,
    gamma=0.95,                  # shorter horizon than the common 0.99
    verbose=1,
    gradient_steps=1,
    # DQN always clips gradients. 0.5 is a tighter clip than SB3's default
    # (believed to be 10; verify).
    max_grad_norm=0.5,
    seed=SEED,
)

model.set_logger(new_logger)
model.learn(TOTAL_TIMESTEPS, log_interval=1)
model.save('agents/dqn_emergency')