Implémentation d'un modèle DQN sur l'environnement merge-v0

In [None]:
import gymnasium
from gymnasium.wrappers import RecordVideo
from stable_baselines3 import DQN
import tensorboard
from RL.config_merge import config_merge 
import highway_env
highway_env.register_highway_envs()  # noqa: F401

# Set to False to skip training and only replay the checkpoint saved on disk.
TRAIN = True
# Number of evaluation episodes to roll out (each one is recorded to video).
N_EPISODES = 10

if __name__ == "__main__":
    # Everything below lives under the guard so that importing this module
    # never trains a model or starts a rollout.

    # Create the environment. "rgb_array" rendering is requested so frames
    # are available to the RecordVideo wrapper applied further down.
    env = gymnasium.make(
        "merge-v0", render_mode="rgb_array", config=config_merge
    )
    # Initial reset; its return values are not needed before training.
    env.reset()

    # NOTE(review): the "racetrack/" paths look inherited from another example
    # although the environment is merge-v0; left unchanged so that existing
    # checkpoints and logs keep resolving.
    if TRAIN:
        # Build the agent only when training: with TRAIN = False the model is
        # loaded from disk below, so constructing one here would be wasted.
        model = DQN(
            "MlpPolicy",
            env,
            policy_kwargs=dict(net_arch=[256, 256]),
            learning_rate=5e-4,
            buffer_size=15000,
            learning_starts=200,
            batch_size=32,
            gamma=0.8,
            train_freq=1,
            gradient_steps=1,
            target_update_interval=50,
            verbose=1,
            tensorboard_log="racetrack/",
        )
        model.learn(total_timesteps=int(2e4))
        model.save("racetrack/model")
        # Drop the in-memory model; it is reloaded from the checkpoint below.
        del model

    # Run the trained model and record a video of every episode.
    model = DQN.load("racetrack/model", env=env)
    env = RecordVideo(
        env, video_folder="racetrack/videos", episode_trigger=lambda e: True
    )
    env.unwrapped.set_record_video_wrapper(env)
    # Configure the base environment explicitly rather than through the
    # wrapper: Gymnasium 1.0 removed implicit attribute forwarding on
    # wrappers, so `env.configure` on RecordVideo is not portable.
    # Original note on this setting: higher FPS for rendering.
    env.unwrapped.configure({"simulation_frequency": 15})

    try:
        for episode in range(N_EPISODES):
            print(episode)
            done = truncated = False
            obs, info = env.reset()
            while not (done or truncated):
                # Greedy action from the trained policy.
                action, _states = model.predict(obs, deterministic=True)
                # Advance the simulation by one step.
                obs, reward, done, truncated, info = env.step(action)
                # Render the current frame.
                env.render()
    finally:
        # Always close the environment, even if a rollout raises.
        env.close()