In [None]:
# Install environment and agent
# highway-env comes from PyPI; stable-baselines3 is installed straight from the
# GitHub repository (unpinned, so the exact version depends on when this runs).
!pip install highway-env

!pip install git+https://github.com/DLR-RM/stable-baselines3

# Environment
import gymnasium as gym
# NOTE(review): highway_env is not referenced by name below; it is presumably
# imported for its side effect of making the "highway-*" ids available to
# gym.make -- confirm against the highway-env documentation.
import highway_env

# Agent
from stable_baselines3 import PPO


import sys
from tqdm.notebook import trange
# Extra packages and system tools; presumably needed for logging and for
# headless rendering / video encoding (xvfb, ffmpeg) -- TODO confirm.
!pip install tensorboardx gym pyvirtualdisplay
!apt-get install -y xvfb ffmpeg
# Clone the HighwayEnv repository (stderr is discarded) so that its scripts/
# directory can be put on sys.path and `utils` imported from it.
# The hard-coded /content prefix assumes the clone lands in /content
# (looks like a Colab working directory -- confirm when running elsewhere).
!git clone https://github.com/Farama-Foundation/HighwayEnv.git 2> /dev/null
sys.path.insert(0, '/content/HighwayEnv/scripts/')
from utils import record_videos, show_videos 

In [None]:
from stable_baselines3 import DQN
import pprint
from matplotlib import pyplot as plt
import numpy as np 

In [None]:
class MyHighwayEnv(gym.Env):
    """Thin ``gym.Env`` adapter around ``highway-fast-v0`` with a fixed config.

    The adapter builds the underlying highway environment, applies a
    Kinematics observation / DiscreteMetaAction configuration and exposes
    matching ``action_space`` and ``observation_space`` attributes.

    Args:
        vehicleCount: Number of vehicles included in the Kinematics
            observation. It also determines the first dimension of
            ``observation_space``.
    """

    def __init__(self, vehicleCount=10):
        super(MyHighwayEnv, self).__init__()
        # base setting
        self.vehicleCount = vehicleCount

        # One column per feature in the observation matrix.
        features = ["presence", "x", "y", "vx", "vy"]

        # environment setting
        self.config = {
            "observation": {
                "type": "Kinematics",
                "features": features,
                "absolute": True,
                "normalize": True,
                "vehicles_count": vehicleCount,
                "see_behind": True,
            },
            "action": {
                "type": "DiscreteMetaAction",
                "target_speeds": np.linspace(0, 32, 9),
            },
            "duration": 40,
            "vehicles_density": 2,
            "show_trajectories": True,
            "render_agent": True,
        }
        self.env = gym.make("highway-fast-v0")
        # gym.make returns the env wrapped in helper wrappers; `configure` is
        # defined on the base environment, so reach it through `.unwrapped`
        # instead of relying on wrapper attribute forwarding.
        self.env.unwrapped.configure(self.config)
        self.action_space = self.env.action_space
        # Shape follows the configuration (rows = vehicles, cols = features)
        # rather than a hard-coded (10, 5), so a non-default vehicleCount
        # yields a consistent space.
        self.observation_space = gym.spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=(vehicleCount, len(features)),
            dtype=np.float32,
        )

    def step(self, action):
        """Advance the wrapped environment by one step.

        Args:
            action: Action forwarded unchanged to the wrapped environment.

        Returns:
            The 5-tuple ``(obs, reward, done, truncated, info)`` exactly as
            returned by the wrapped environment.
        """
        obs, reward, done, truncated, info = self.env.step(action)
        return obs, reward, done, truncated, info

    def reset(self, **kwargs):
        """Reset the wrapped environment.

        Args:
            **kwargs: Forwarded unchanged to the wrapped environment's
                ``reset`` (e.g. ``seed`` / ``options``).

        Returns:
            Whatever the wrapped ``reset`` returns; the evaluation code in
            this notebook unpacks it as ``(obs, info)``.
        """
        return self.env.reset(**kwargs)

    def render(self):
        """Delegate rendering to the wrapped environment."""
        return self.env.render()

    def close(self):
        """Release the wrapped environment's resources."""
        self.env.close()

In [None]:
from stable_baselines3.common.vec_env import DummyVecEnv

# Training environment: the custom highway wrapper observing 10 vehicles
# (spelled out here; it is also the constructor default).
env = MyHighwayEnv(vehicleCount=10)

In [None]:
# Visualization utils
# Load the TensorBoard notebook extension and open it on the "highway_ppo"
# log directory -- the same path that is passed as `tensorboard_log` when the
# PPO model is created in the training cell.
%load_ext tensorboard
%tensorboard --logdir "highway_ppo"

In [None]:
# PPO training run. The rollout length is tied to the minibatch size so each
# update uses exactly two minibatches per epoch.
batch_size = 64
model = PPO(
    "MlpPolicy",
    env,
    # Separate 256x256 networks for the policy (pi) and value function (vf).
    # Passed as a plain dict: the list-wrapped form
    # net_arch=[dict(pi=..., vf=...)] is deprecated since SB3 v1.8.
    policy_kwargs=dict(net_arch=dict(pi=[256, 256], vf=[256, 256])),
    n_steps=batch_size * 2,
    batch_size=batch_size,
    n_epochs=10,
    learning_rate=5e-4,
    gamma=0.8,
    verbose=2,
    # Must match the --logdir given to the %tensorboard magic.
    tensorboard_log="highway_ppo/",
)
model.learn(total_timesteps=int(1e4))

In [None]:
# Persist the trained PPO agent to disk.
model_path = 'models/ppo_model'
model.save(model_path)

In [None]:
from gymnasium.wrappers import RecordVideo
# base setting
vehicleCount = 10

# environment setting (same observation/action layout the agent was trained
# with in MyHighwayEnv, so the policy input shape matches)
config = {
    "observation": {
        "type": "Kinematics",
        "features": ["presence", "x", "y", "vx", "vy"],
        "absolute": True,
        "normalize": True,
        "vehicles_count": vehicleCount,
        "see_behind": True,
    },
    "action": {
        "type": "DiscreteMetaAction",
        "target_speeds": np.linspace(0, 32, 9),
    },
    "duration": 40,
    "vehicles_density": 2,
    "show_trajectories": True,
    "render_agent": True,
}


# rgb_array rendering so frames can be captured by the video recorder.
env = gym.make('highway-v0', render_mode='rgb_array')
# gym.make returns a wrapped env; `configure` is defined on the base
# environment, so call it through `.unwrapped` rather than relying on wrapper
# attribute forwarding. The env is reset at the start of every episode below.
env.unwrapped.configure(config)
env = record_videos(env)
try:
    for episode in trange(3, desc='Test episodes'):
        (obs, info), done, truncated = env.reset(), False, False
        # Run until the episode terminates or is truncated.
        while not (done or truncated):
            action, _ = model.predict(obs, deterministic=True)
            obs, reward, done, truncated, info = env.step(int(action))
finally:
    # Always close the env, even if an episode raises.
    env.close()
show_videos()