In [9]:
import gymnasium
import highway_env
from stable_baselines3 import DQN

from config.config_sb import config_dict

In [10]:
env = gymnasium.make("highway-fast-v0", render_mode="rgb_array")
env.unwrapped.configure(config_dict)


obs_space = env.observation_space
print(f"Observation space: {obs_space}")


obs, info = env.reset()
print(f"Observation shape: {obs.shape}")

Observation space: Box(-inf, inf, (5, 5), float32)
Observation shape: (3, 3, 20)


In [7]:
model = DQN('MlpPolicy', env,
          policy_kwargs=dict(net_arch=[256, 256]),
          learning_rate=5e-4,
          buffer_size=15000,
          learning_starts=200,
          batch_size=32,
          gamma=0.8,
          train_freq=1,
          gradient_steps=1,
          target_update_interval=50,
          verbose=1,
          tensorboard_log="highway_dqn/")

model.learn(total_timesteps=2e4)
model.save("highway_dqn/model")

print("Model saved to highway_dqn/model")

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Logging to highway_dqn/DQN_1
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 22       |
|    ep_rew_mean      | 18.4     |
|    exploration_rate | 0.958    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 146      |
|    time_elapsed     | 0        |
|    total_timesteps  | 88       |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 17.6     |
|    ep_rew_mean      | 14.8     |
|    exploration_rate | 0.933    |
| time/               |          |
|    episodes         | 8        |
|    fps              | 146      |
|    time_elapsed     | 0        |
|    total_timesteps  | 141      |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 19.2     |
| 

In [8]:
model = DQN.load("highway_dqn/model")

episodes = 100
for episode in range(episodes):
    done = truncated = False
    obs, info = env.reset()
    episode_reward = 0
    
    while not (done or truncated):
        action, _states = model.predict(obs, deterministic=True)
        obs, reward, done, truncated, info = env.step(action)
        episode_reward += reward
        env.render()
    
    print(f"Episode {episode+1}: Reward = {episode_reward}")

Episode 1: Reward = 95.12864105740587
Episode 2: Reward = 16.649436275087915
Episode 3: Reward = 14.934616967000284
Episode 4: Reward = 3.079821521900746
Episode 5: Reward = 99.07750561487074
Episode 6: Reward = 97.18092789170863
Episode 7: Reward = 99.44901352761181
Episode 8: Reward = 6.004236555965546
Episode 9: Reward = 96.46436048654517
Episode 10: Reward = 5.059160714323233
Episode 11: Reward = 30.24890606698977
Episode 12: Reward = 96.2437311132718


AttributeError: 'NoneType' object has no attribute 'get_image'