# Training a DQN with social attention on `intersection-v0`

## Import requirements

```python
# Environment
!pip install git+https://github.com/eleurent/highway-env#egg=highway-env
import gymnasium as gym
import highway_env
highway_env.register_highway_envs()  # register intersection-v0 and the other highway-env environments

# Agent
!pip install git+https://github.com/eleurent/rl-agents#egg=rl-agents

# Visualisation utils
import sys
!pip install moviepy imageio_ffmpeg
!pip install tensorboardx gym pyvirtualdisplay
!apt-get install -y xvfb python3-opengl ffmpeg
%load_ext tensorboard
# The highway-env repository provides a scripts/utils.py helper module with show_videos()
!git clone https://github.com/eleurent/highway-env.git 2> /dev/null
sys.path.insert(0, '/content/highway-env/scripts/')
from utils import show_videos
```

## Training

Prepare the environment, the agent, and the evaluation process.

We use a policy architecture based on social attention; see [[Leurent and Mercat, 2019]](https://arxiv.org/abs/1911.12250).
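In the ego-attention architecture described in that paper, each vehicle is first embedded, and the ego vehicle's embedding alone produces the attention query, while every vehicle (ego included) produces a key and a value. The output is therefore a weighted average of the vehicles' values, and the weights show which vehicles the agent attends to. The following is a minimal single-head NumPy sketch of that step; the function name, the dimensions and the random projection matrices are illustrative assumptions, and it omits the per-vehicle encoder, the multiple heads and the Q-value decoder of the full network.

```python
import numpy as np

def ego_attention(features, w_q, w_k, w_v):
    """One attention head in which only the ego vehicle issues a query.

    features: (n_vehicles, d_in) per-vehicle embeddings; row 0 is assumed to be the ego vehicle.
    w_q, w_k, w_v: (d_in, d_k) projection matrices (hypothetical, random here).
    Returns the ego output vector (d_k,) and the attention weights (n_vehicles,).
    """
    query = features[0] @ w_q                          # (d_k,) single query from the ego vehicle
    keys = features @ w_k                              # (n, d_k) one key per vehicle, ego included
    values = features @ w_v                            # (n, d_k) one value per vehicle
    scores = keys @ query / np.sqrt(w_k.shape[1])      # scaled dot-product scores, (n,)
    scores = scores - scores.max()                     # numerical stability before exponentiating
    weights = np.exp(scores) / np.exp(scores).sum()    # softmax over vehicles
    return weights @ values, weights

rng = np.random.default_rng(0)
n_vehicles, d_in, d_k = 5, 7, 4                        # hypothetical sizes
features = rng.normal(size=(n_vehicles, d_in))
w_q, w_k, w_v = (rng.normal(size=(d_in, d_k)) for _ in range(3))
output, weights = ego_attention(features, w_q, w_k, w_v)
print(output.shape, weights.round(3))
```

Because the softmax is taken over vehicles, the weights always sum to one, which is what makes them readable as "how much the ego vehicle looks at each other vehicle".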

```python
from rl_agents.trainer.evaluation import Evaluation
from rl_agents.agents.common.factory import load_agent, load_environment

# Get the environment and agent configurations from the rl-agents repository
!git clone https://github.com/eleurent/rl-agents.git 2> /dev/null
%cd /content/rl-agents/scripts/
env_config = 'configs/IntersectionEnv/env.json'
# DQN agent with an ego-attention policy network (two attention heads)
agent_config = 'configs/IntersectionEnv/agents/DQNAgent/ego_attention_2h.json'

env = load_environment(env_config)
agent = load_agent(agent_config, env)
# Number of training episodes; rendering is disabled to keep training fast
evaluation = Evaluation(env, agent, num_episodes=3000, display_env=False, display_agent=False)
print(f"Ready to train {agent} on {env}")
```

Launch TensorBoard inside the notebook to visualise training.

```python
%tensorboard --logdir "{evaluation.directory}"
```

Start training. This should take about an hour.

```python
evaluation.train()
```

Progress can be visualised in the TensorBoard cell above, which refreshes every 30 s and can also be reloaded manually. You may need to click the *Fit domain to data* button below each graph.
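For reference, DQN learns Q(s, a) by regressing it toward bootstrapped targets r + γ max_a' Q_target(s', a'), where Q_target is a periodically updated copy of the network and no bootstrapping is done past a terminal state. The following is a minimal NumPy sketch of the target and the squared TD error for one batch; the function names, batch layout and numbers are hypothetical, and rl-agents implements its own PyTorch version with its own configuration options.

```python
import numpy as np

def dqn_targets(rewards, next_q_target, dones, gamma=0.99):
    """Bootstrapped DQN regression targets for a batch of transitions.

    rewards:       (batch,) immediate rewards r
    next_q_target: (batch, n_actions) target-network values Q_target(s', .)
    dones:         (batch,) 1.0 if s' is terminal, else 0.0
    """
    # y = r + gamma * max_a' Q_target(s', a'), with no bootstrap after a terminal state
    return rewards + gamma * (1.0 - dones) * next_q_target.max(axis=1)

def td_loss(q_values, actions, targets):
    """Mean squared error between Q(s, a) for the actions taken and the targets."""
    q_taken = q_values[np.arange(len(actions)), actions]
    return np.mean((q_taken - targets) ** 2)

rewards = np.array([1.0, 0.0])
next_q = np.array([[0.5, 2.0, 1.0],
                   [3.0, 0.0, 1.0]])
dones = np.array([0.0, 1.0])                   # second transition ends the episode
targets = dqn_targets(rewards, next_q, dones, gamma=0.9)   # targets ≈ [2.8, 0.0]

q_values = np.array([[1.0, 2.0, 0.0],
                     [0.5, 0.5, 0.5]])
actions = np.array([1, 0])
loss = td_loss(q_values, actions, targets)     # ((2.0 - 2.8)**2 + (0.5 - 0.0)**2) / 2 = 0.445
```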

## Testing

Run the learned policy for a few episodes.
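During training, a DQN agent usually explores with an epsilon-greedy rule whose epsilon is gradually decreased; when running a learned policy, one typically acts greedily, which corresponds to epsilon = 0. A minimal sketch of that rule follows; the function name and the three Q-values are hypothetical, and this is not the rl-agents exploration API.

```python
import numpy as np

def epsilon_greedy(q_values, epsilon, rng):
    """With probability epsilon pick a uniformly random action, else the highest-valued one."""
    if rng.random() < epsilon:
        return int(rng.integers(len(q_values)))
    return int(np.argmax(q_values))

rng = np.random.default_rng(0)
q = np.array([0.1, 0.7, 0.3])                                # hypothetical Q-values, 3 actions
greedy_action = epsilon_greedy(q, epsilon=0.0, rng=rng)      # always argmax, i.e. action 1
exploring_action = epsilon_greedy(q, epsilon=1.0, rng=rng)   # uniformly random action
```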

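```python
# Reload the environment with offscreen rendering, so that videos can be recorded without a display
env = load_environment(env_config)
env.unwrapped.configure({"offscreen_rendering": True})
agent = load_agent(agent_config, env)
# Recover the latest saved model and evaluate it without further training
evaluation = Evaluation(env, agent, num_episodes=3, training=False, recover=True)
evaluation.test()
show_videos(evaluation.run_directory)
```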