<a href="https://colab.research.google.com/github/PeterWANGHK/semi_trailer_1/blob/main/scripts/sb3_highway_dqn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Highway with SB3's DQN

##  Warming up
We start with a few useful installs and imports:

In [None]:
# Setup cell: installs the simulator and the RL library, then imports everything
# the later cells use (gym, DQN, trange, record_videos, show_videos).

# Install environment and agent
!pip install highway-env
# TODO: we use the bleeding edge version because the current stable version does not support the latest gym>=0.21 versions. Revert back to stable at the next SB3 release.
# NOTE(review): the TODO above mentions `gym`, while this notebook imports `gymnasium`;
# it may be stale -- confirm whether a pinned stable SB3 release is now sufficient.
!pip install git+https://github.com/DLR-RM/stable-baselines3

# Environment
import gymnasium as gym
import highway_env

# Make the highway_env environment ids available to gym.make().
gym.register_envs(highway_env)

# Agent
from stable_baselines3 import DQN

# Visualization utils
%load_ext tensorboard
import sys
from tqdm.notebook import trange
# NOTE(review): `gym` below is the legacy package, while the code in this notebook
# uses `gymnasium` -- confirm whether it is still required here.
!pip install tensorboardx gym pyvirtualdisplay
# xvfb / ffmpeg: presumably needed for headless rendering and video encoding -- TODO confirm.
!apt-get install -y xvfb ffmpeg
# Clone the HighwayEnv repository to reuse its scripts/utils.py helpers;
# stderr is discarded, so a failed clone (e.g. directory already present) is silent.
!git clone https://github.com/Farama-Foundation/HighwayEnv.git 2> /dev/null
# Absolute path: assumes the clone landed under /content (the Colab working directory).
sys.path.insert(0, '/content/HighwayEnv/scripts/')
from utils import record_videos, show_videos

Collecting highway-env
  Downloading highway_env-1.10.1-py3-none-any.whl.metadata (16 kB)
Downloading highway_env-1.10.1-py3-none-any.whl (104 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m105.0/105.0 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: highway-env
Successfully installed highway-env-1.10.1
Collecting git+https://github.com/DLR-RM/stable-baselines3
  Cloning https://github.com/DLR-RM/stable-baselines3 to /tmp/pip-req-build-uxk2b0o4
  Running command git clone --filter=blob:none --quiet https://github.com/DLR-RM/stable-baselines3 /tmp/pip-req-build-uxk2b0o4
  Resolved https://github.com/DLR-RM/stable-baselines3 to commit d487f2d2355a6cf81ea26a0bbbdf1a727ca2a886
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: stable_baselines3
  Building wheel for stable_baseline

Gym has been unmaintained since 2022 and does not support NumPy 2.0 amongst other critical functionality.
Please upgrade to Gymnasium, the maintained drop-in replacement of Gym, or contact the authors of your software and request that they upgrade.
See the migration guide at https://gymnasium.farama.org/introduction/migration_guide/ for additional information.
  return datetime.utcnow().replace(tzinfo=utc)


Collecting tensorboardx
  Downloading tensorboardx-2.6.4-py3-none-any.whl.metadata (6.2 kB)
Collecting pyvirtualdisplay
  Downloading PyVirtualDisplay-3.0-py3-none-any.whl.metadata (943 bytes)
Downloading tensorboardx-2.6.4-py3-none-any.whl (87 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m87.2/87.2 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading PyVirtualDisplay-3.0-py3-none-any.whl (15 kB)
Installing collected packages: pyvirtualdisplay, tensorboardx
Successfully installed pyvirtualdisplay-3.0 tensorboardx-2.6.4
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
ffmpeg is already the newest version (7:4.4.2-0ubuntu0.22.04.1).
xvfb is already the newest version (2:21.1.4-2ubuntu1.7~22.04.15).
0 upgraded, 0 newly installed, 0 to remove and 38 not upgraded.


## Training
Run tensorboard locally to visualize training.

In [None]:
# Open the TensorBoard UI inline on the "highway_dqn" log directory, i.e. the
# directory passed as `tensorboard_log` when the DQN model is created.
# Requires the extension loaded by `%load_ext tensorboard` in the setup cell.
%tensorboard --logdir "highway_dqn"

In [None]:
# Hyperparameters of the DQN agent, gathered in one mapping so they are easy to
# read and tweak before the model is built.
dqn_settings = {
    'policy_kwargs': {'net_arch': [256, 256]},
    'learning_rate': 5e-4,
    'buffer_size': 15000,
    'learning_starts': 200,
    'batch_size': 32,
    'gamma': 0.8,
    'train_freq': 1,
    'gradient_steps': 1,
    'target_update_interval': 50,
    'exploration_fraction': 0.7,
    'verbose': 1,
    # Same directory the %tensorboard cell points at.
    'tensorboard_log': 'highway_dqn/',
}

# Build the agent on the 'highway-fast-v0' environment with an MLP policy.
model = DQN('MlpPolicy', 'highway-fast-v0', **dqn_settings)

# Train for 20,000 environment steps.
model.learn(20_000)


## Testing

Visualize a few episodes of the trained agent, recorded as videos.

In [None]:
# Roll out the trained agent for a few episodes and record each one as a video.
env = gym.make('highway-fast-v0', render_mode='rgb_array')
env = record_videos(env)
for episode in trange(3, desc='Test episodes'):
    (obs, info), done = env.reset(), False
    while not done:
        # Act greedily with respect to the learned policy (no exploration).
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, info = env.step(int(action))
        # Gymnasium reports episode end through two separate flags. The episode
        # is over when EITHER is set; checking only `terminated` would keep
        # stepping an already-truncated episode (e.g. one that hit a time limit).
        done = terminated or truncated
env.close()
show_videos()