sb3_contrib.qrdqn
Quantile Regression DQN (QR-DQN) builds on Deep Q-Network (DQN) and makes use of quantile regression to explicitly model the distribution over returns, instead of predicting only the mean return as DQN does.
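To illustrate the idea, here is a minimal NumPy sketch of the quantile Huber loss from the QR-DQN paper, which trains the predicted quantiles against target samples. The function name and signature are illustrative, not part of the SB3-Contrib API:

```python
import numpy as np

def quantile_huber_loss(pred_quantiles, target_quantiles, kappa=1.0):
    """Pairwise quantile Huber loss between predicted and target quantiles.

    pred_quantiles: shape (n,), predicted quantiles theta_i
    target_quantiles: shape (m,), samples of the target distribution
    """
    n = len(pred_quantiles)
    # Quantile midpoints tau_hat_i = (2i + 1) / (2n)
    tau = (2 * np.arange(n) + 1) / (2 * n)
    # Pairwise TD errors u_ij = target_j - pred_i, shape (n, m)
    u = target_quantiles[None, :] - pred_quantiles[:, None]
    # Element-wise Huber loss with threshold kappa
    huber = np.where(np.abs(u) <= kappa,
                     0.5 * u ** 2,
                     kappa * (np.abs(u) - 0.5 * kappa))
    # Asymmetric quantile weight |tau - 1{u < 0}| penalizes over- and
    # under-estimation differently, which is what makes this a quantile loss
    weight = np.abs(tau[:, None] - (u < 0).astype(float))
    return (weight * huber / kappa).mean()
```

Minimizing this loss drives each `pred_quantiles[i]` toward the tau_hat_i-quantile of the target distribution, so the network ends up representing the whole return distribution rather than just its mean.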
Available Policies
MlpPolicy CnnPolicy MultiInputPolicy
- Original paper: https://arxiv.org/abs/1710.10044
- Distributional RL (C51): https://arxiv.org/abs/1707.06887
- Further reference: https://github.com/amy12xx/ml_notes_and_reports/blob/master/distributional_rl/QRDQN.pdf
- Recurrent policies: ❌
- Multi processing: ✔️
- Gym spaces:
Space | Action | Observation |
---|---|---|
Discrete | ✔️ | ✔️ |
Box | ❌ | ✔️ |
MultiDiscrete | ❌ | ✔️ |
MultiBinary | ❌ | ✔️ |
Dict | ❌ | ✔️ |
import gym

from sb3_contrib import QRDQN

env = gym.make("CartPole-v1")

policy_kwargs = dict(n_quantiles=50)
model = QRDQN("MlpPolicy", env, policy_kwargs=policy_kwargs, verbose=1)
model.learn(total_timesteps=10000, log_interval=4)
model.save("qrdqn_cartpole")

del model  # remove to demonstrate saving and loading

model = QRDQN.load("qrdqn_cartpole")

obs = env.reset()
while True:
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, done, info = env.step(action)
    env.render()
    if done:
        obs = env.reset()
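At prediction time, QR-DQN still acts greedily: the distributional head is collapsed to a scalar Q-value per action by averaging its quantiles. A sketch of that action-selection step (illustrative helper, not the library's internal API):

```python
import numpy as np

def greedy_action(quantiles):
    """Select the greedy action from a (n_actions, n_quantiles) array.

    QR-DQN predicts a set of quantiles per action; averaging them
    approximates the expected return, and the greedy action is the
    argmax of those per-action means.
    """
    q_values = quantiles.mean(axis=1)  # shape (n_actions,)
    return int(np.argmax(q_values))
```

This is why `model.predict(obs, deterministic=True)` behaves like DQN's argmax despite the richer distributional output.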
Results on Atari environments (10M steps; Pong and Breakout) and on classic control tasks, using 3 and 5 seeds respectively.
The complete learning curves are available in the associated PR.
Note
The QR-DQN implementation was validated against the Intel Coach one, and its results roughly compare to those of the original paper (we trained the agent with a smaller budget).
Environments | QR-DQN | DQN |
---|---|---|
Breakout | 413 +/- 21 | ~300 |
Pong | 20 +/- 0 | ~20 |
CartPole | 386 +/- 64 | 500 +/- 0 |
MountainCar | -111 +/- 4 | -107 +/- 4 |
LunarLander | 168 +/- 39 | 195 +/- 28 |
Acrobot | -73 +/- 2 | -74 +/- 2 |
Clone the RL-Zoo fork and check out the branch feat/qrdqn:
git clone https://github.com/ku2482/rl-baselines3-zoo/
cd rl-baselines3-zoo/
git checkout feat/qrdqn
Run the benchmark (replace $ENV_ID with each of the env IDs mentioned above):
python train.py --algo qrdqn --env $ENV_ID --eval-episodes 10 --eval-freq 10000
Plot the results:
python scripts/all_plots.py -a qrdqn -e Breakout Pong -f logs/ -o logs/qrdqn_results
python scripts/plot_from_file.py -i logs/qrdqn_results.pkl -latex -l QR-DQN
QRDQN
MlpPolicy
alias of sb3_contrib.qrdqn.policies.QRDQNPolicy
CnnPolicy
MultiInputPolicy