Refactor ACKTR continuous (#476)
* Add ACKTR support for continuous actions

* Renamed acktr file

* Add test for ACKTR with continuous actions

* Remove unused acktr utils

* Enable GAE for ACKTR

* Add ACKTR box pretraining test

* Remove unused code

* Rename gae lambda

* Change default gae_lambda to None
araffin committed Sep 28, 2019
1 parent 4929e54 commit 46d13b4
Showing 24 changed files with 91 additions and 1,496 deletions.
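
Taken together, the changes in this commit let ACKTR run on continuous (`Box`) action spaces and optionally use GAE. A minimal usage sketch against the post-commit API shown in the diff below; the `Pendulum-v0` environment, `MlpPolicy` string and hyperparameter values are illustrative, not part of the commit:

```python
import gym

from stable_baselines import ACKTR

env = gym.make("Pendulum-v0")  # continuous (Box) action space

# gae_lambda=None (the default) keeps the classic advantage estimate;
# a float in (0, 1] makes learn() collect rollouts with the PPO2 GAE runner.
model = ACKTR("MlpPolicy", env, n_steps=16, gae_lambda=0.95, kfac_update=1, verbose=1)
model.learn(total_timesteps=25000)

obs = env.reset()
action, _states = model.predict(obs)
```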
2 changes: 1 addition & 1 deletion README.md
@@ -153,7 +153,7 @@ All the following examples can be executed online using Google colab notebooks:
| ------------------- | ---------------------------- | ------------------ | ------------------ | ------------------ | ------------------- | ------------------ | --------------------------------- |
| A2C | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| ACER | :heavy_check_mark: | :heavy_check_mark: | :x: <sup>(5)</sup> | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
| ACKTR | :heavy_check_mark: | :heavy_check_mark: | :x: <sup>(5)</sup> | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
| ACKTR | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
| DDPG | :heavy_check_mark: | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: <sup>(4)</sup>|
| DQN | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: | :x: | :x: | :x: |
| GAIL <sup>(2)</sup> | :heavy_check_mark: | :x: | :heavy_check_mark: |:heavy_check_mark:| :x: | :x: | :heavy_check_mark: <sup>(4)</sup> |
2 changes: 1 addition & 1 deletion docs/guide/algos.rst
@@ -17,7 +17,7 @@ Name Refactored [#f1]_ Recurrent ``Box`` ``Discrete`` Multi P
============ ======================== ========= =========== ============ ================
A2C ✔️ ✔️ ✔️ ✔️ ✔️
ACER ✔️ ✔️ ❌ [#f4]_ ✔️ ✔️
ACKTR ✔️ ✔️ ❌ [#f4]_ ✔️ ✔️
ACKTR ✔️ ✔️ ✔️ ✔️ ✔️
DDPG ✔️ ❌ ✔️ ❌ ✔️ [#f3]_
DQN ✔️ ❌ ❌ ✔️ ❌
HER ✔️ ❌ ✔️ ✔️ ❌
11 changes: 9 additions & 2 deletions docs/misc/changelog.rst
@@ -22,12 +22,16 @@ Breaking Changes:
by explicitly setting `start_method = 'fork'`. See
`PR #428 <https://github.com/hill-a/stable-baselines/pull/428>`_.
- Updated dependencies: tensorflow v1.8.0 is now required
- Remove `checkpoint_path` and `checkpoint_freq` argument from `DQN` that were not used
- Removed `checkpoint_path` and `checkpoint_freq` argument from `DQN` that were not used
- Removed `bench/benchmark.py` that was not used
- Removed several functions from `common/tf_util.py` that were not used
- Removed `ppo1/run_humanoid.py`

New Features:
^^^^^^^^^^^^^
- **important change** Switch to using zip-archived JSON and Numpy `savez` for
storing models for better support across library/Python versions. (@Miffyli)
- ACKTR now supports continuous actions
- Add `double_q` argument to `DQN` constructor

Bug Fixes:
@@ -52,7 +56,10 @@ Others:
to `stable_baselines.common.noise`. The API remains backward-compatible;
for example `from stable_baselines.ddpg.noise import NormalActionNoise` is still
okay. (@shwang)
- docker images were updated
- Docker images were updated
- Cleaned up files in `common/` folder and in `acktr/` folder that were only used by old ACKTR version
(e.g. `filter.py`)
- Renamed `acktr_disc.py` to `acktr.py`

Documentation:
^^^^^^^^^^^^^^
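
The changelog hunk above also records the switch to zip-archived JSON + NumPy `savez` storage for models. A save/load round trip, sketched under the assumption of the standard `save`/`load` API; the file and environment names are illustrative:

```python
from stable_baselines import ACKTR

model = ACKTR("MlpPolicy", "Pendulum-v0", gae_lambda=0.95)
model.save("acktr_pendulum")           # writes a single zip archive (JSON + savez inside)
loaded = ACKTR.load("acktr_pendulum")  # parameters and constructor arguments restored
```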
2 changes: 1 addition & 1 deletion docs/modules/acktr.rst
@@ -30,7 +30,7 @@ Can I use?
Space Action Observation
============= ====== ===========
Discrete ✔️ ✔️
Box ✔️
Box ✔️ ✔️
MultiDiscrete ❌ ✔️
MultiBinary ❌ ✔️
============= ====== ===========
2 changes: 1 addition & 1 deletion stable_baselines/acktr/__init__.py
@@ -1 +1 @@
from stable_baselines.acktr.acktr_disc import ACKTR
from stable_baselines.acktr.acktr import ACKTR
stable_baselines/acktr/{acktr_disc.py → acktr.py} (file renamed)
@@ -1,20 +1,17 @@
"""
Discrete acktr
"""

import time
from collections import deque

import tensorflow as tf
import numpy as np
import tensorflow as tf
from gym.spaces import Box, Discrete

from stable_baselines import logger
from stable_baselines.common import explained_variance, ActorCriticRLModel, tf_util, SetVerbosity, TensorboardWriter
from stable_baselines.a2c.a2c import A2CRunner
from stable_baselines.a2c.utils import Scheduler, calc_entropy, mse, \
from stable_baselines.ppo2.ppo2 import Runner as PPO2Runner
from stable_baselines.a2c.utils import Scheduler, mse, \
total_episode_reward_logger
from stable_baselines.acktr import kfac
from stable_baselines.common import explained_variance, ActorCriticRLModel, tf_util, SetVerbosity, TensorboardWriter
from stable_baselines.common.policies import ActorCriticPolicy, RecurrentActorCriticPolicy
from stable_baselines.ppo2.ppo2 import safe_mean

@@ -40,15 +37,18 @@ class ACKTR(ActorCriticRLModel):
:param tensorboard_log: (str) the log location for tensorboard (if None, no logging)
:param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
:param async_eigen_decomp: (bool) Use async eigen decomposition
:param kfac_update: (int) update kfac after kfac_update steps
:param policy_kwargs: (dict) additional arguments to be passed to the policy on creation
:param gae_lambda: (float) Factor for trade-off of bias vs variance for Generalized Advantage Estimator
If None (default), then the classic advantage will be used instead of GAE
:param full_tensorboard_log: (bool) enable additional logging when using tensorboard
WARNING: this logging can take a lot of space quickly
"""

def __init__(self, policy, env, gamma=0.99, nprocs=1, n_steps=20, ent_coef=0.01, vf_coef=0.25, vf_fisher_coef=1.0,
learning_rate=0.25, max_grad_norm=0.5, kfac_clip=0.001, lr_schedule='linear', verbose=0,
tensorboard_log=None, _init_setup_model=True, async_eigen_decomp=False,
policy_kwargs=None, full_tensorboard_log=False):
tensorboard_log=None, _init_setup_model=True, async_eigen_decomp=False, kfac_update=1,
gae_lambda=None, policy_kwargs=None, full_tensorboard_log=False):

super(ACKTR, self).__init__(policy=policy, env=env, verbose=verbose, requires_vec_env=True,
_init_setup_model=_init_setup_model, policy_kwargs=policy_kwargs)
@@ -66,16 +66,17 @@ def __init__(self, policy, env, gamma=0.99, nprocs=1, n_steps=20, ent_coef=0.01,
self.tensorboard_log = tensorboard_log
self.async_eigen_decomp = async_eigen_decomp
self.full_tensorboard_log = full_tensorboard_log
self.kfac_update = kfac_update
self.gae_lambda = gae_lambda

self.graph = None
self.sess = None
self.action_ph = None
self.actions_ph = None
self.advs_ph = None
self.rewards_ph = None
self.pg_lr_ph = None
self.model = None
self.model2 = None
self.logits = None
self.learning_rate_ph = None
self.step_model = None
self.train_model = None
self.entropy = None
self.pg_loss = None
self.vf_loss = None
@@ -88,8 +89,6 @@ def __init__(self, policy, env, gamma=0.99, nprocs=1, n_steps=20, ent_coef=0.01,
self.train_op = None
self.q_runner = None
self.learning_rate_schedule = None
self.train_model = None
self.step_model = None
self.step = None
self.proba_step = None
self.value = None
@@ -98,24 +97,25 @@ def __init__(self, policy, env, gamma=0.99, nprocs=1, n_steps=20, ent_coef=0.01,
self.summary = None
self.episode_reward = None
self.trained = False
self.continuous_actions = False

if _init_setup_model:
self.setup_model()

def _get_pretrain_placeholders(self):
policy = self.train_model
if isinstance(self.action_space, Discrete):
return policy.obs_ph, self.action_ph, policy.policy
raise NotImplementedError("WIP: ACKTR does not support Continuous actions yet.")
return policy.obs_ph, self.actions_ph, policy.policy
return policy.obs_ph, self.actions_ph, policy.deterministic_action

def setup_model(self):
with SetVerbosity(self.verbose):

assert issubclass(self.policy, ActorCriticPolicy), "Error: the input policy for the ACKTR model must be " \
"an instance of common.policies.ActorCriticPolicy."

if isinstance(self.action_space, Box):
raise NotImplementedError("WIP: ACKTR does not support Continuous actions yet.")
# Enable continuous actions tricks (normalized advantage)
self.continuous_actions = isinstance(self.action_space, Box)

self.graph = tf.Graph()
with self.graph.as_default():
@@ -127,35 +127,34 @@ def setup_model(self):
n_batch_step = self.n_envs
n_batch_train = self.n_envs * self.n_steps

self.model = step_model = self.policy(self.sess, self.observation_space, self.action_space, self.n_envs,
1, n_batch_step, reuse=False, **self.policy_kwargs)
step_model = self.policy(self.sess, self.observation_space, self.action_space, self.n_envs,
1, n_batch_step, reuse=False, **self.policy_kwargs)

self.params = params = tf_util.get_trainable_vars("model")

with tf.variable_scope("train_model", reuse=True,
custom_getter=tf_util.outer_scope_getter("train_model")):
self.model2 = train_model = self.policy(self.sess, self.observation_space, self.action_space,
self.n_envs, self.n_steps, n_batch_train,
reuse=True, **self.policy_kwargs)
train_model = self.policy(self.sess, self.observation_space, self.action_space,
self.n_envs, self.n_steps, n_batch_train,
reuse=True, **self.policy_kwargs)

with tf.variable_scope("loss", reuse=False, custom_getter=tf_util.outer_scope_getter("loss")):
self.advs_ph = advs_ph = tf.placeholder(tf.float32, [None])
self.rewards_ph = rewards_ph = tf.placeholder(tf.float32, [None])
self.pg_lr_ph = pg_lr_ph = tf.placeholder(tf.float32, [])
self.action_ph = action_ph = train_model.pdtype.sample_placeholder([None])
self.learning_rate_ph = learning_rate_ph = tf.placeholder(tf.float32, [])
self.actions_ph = train_model.pdtype.sample_placeholder([None])

logpac = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=train_model.policy, labels=action_ph)
self.logits = train_model.policy
neg_log_prob = train_model.proba_distribution.neglogp(self.actions_ph)

# training loss
pg_loss = tf.reduce_mean(advs_ph * logpac)
self.entropy = entropy = tf.reduce_mean(calc_entropy(train_model.policy))
pg_loss = tf.reduce_mean(advs_ph * neg_log_prob)
self.entropy = entropy = tf.reduce_mean(train_model.proba_distribution.entropy())
self.pg_loss = pg_loss = pg_loss - self.ent_coef * entropy
self.vf_loss = vf_loss = mse(tf.squeeze(train_model.value_fn), rewards_ph)
train_loss = pg_loss + self.vf_coef * vf_loss

# Fisher loss construction
self.pg_fisher = pg_fisher_loss = -tf.reduce_mean(logpac)
self.pg_fisher = pg_fisher_loss = -tf.reduce_mean(neg_log_prob)
sample_net = train_model.value_fn + tf.random_normal(tf.shape(train_model.value_fn))
self.vf_fisher = vf_fisher_loss = - self.vf_fisher_coef * tf.reduce_mean(
tf.pow(train_model.value_fn - tf.stop_gradient(sample_net), 2))
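
The hunk above replaces the discrete-only `sparse_softmax_cross_entropy_with_logits` term with `proba_distribution.neglogp`, which is what lets the same loss cover both `Discrete` and `Box` policies: for a categorical distribution the negative log-probability is exactly the softmax cross-entropy, and for a diagonal Gaussian it is the usual Gaussian negative log-likelihood. A standalone NumPy sketch of the two quantities (illustrative, not the library's distribution code):

```python
import numpy as np

def categorical_neglogp(logits, action):
    # equivalent to tf.nn.sparse_softmax_cross_entropy_with_logits for one sample
    z = logits - logits.max()
    log_probs = z - np.log(np.exp(z).sum())
    return -log_probs[action]

def diag_gaussian_neglogp(mean, log_std, action):
    # negative log-likelihood of a diagonal Gaussian policy
    std = np.exp(log_std)
    return 0.5 * np.sum(((action - mean) / std) ** 2) \
        + 0.5 * np.log(2.0 * np.pi) * action.size \
        + np.sum(log_std)

print(categorical_neglogp(np.array([2.0, 0.5, -1.0]), action=0))
print(diag_gaussian_neglogp(np.zeros(2), np.zeros(2), np.array([0.1, -0.2])))
```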
@@ -172,12 +171,12 @@ def setup_model(self):

with tf.variable_scope("input_info", reuse=False):
tf.summary.scalar('discounted_rewards', tf.reduce_mean(self.rewards_ph))
tf.summary.scalar('learning_rate', tf.reduce_mean(self.pg_lr_ph))
tf.summary.scalar('learning_rate', tf.reduce_mean(self.learning_rate_ph))
tf.summary.scalar('advantage', tf.reduce_mean(self.advs_ph))

if self.full_tensorboard_log:
tf.summary.histogram('discounted_rewards', self.rewards_ph)
tf.summary.histogram('learning_rate', self.pg_lr_ph)
tf.summary.histogram('learning_rate', self.learning_rate_ph)
tf.summary.histogram('advantage', self.advs_ph)
if tf_util.is_image(self.observation_space):
tf.summary.image('observation', train_model.obs_ph)
@@ -186,8 +185,8 @@ def setup_model(self):

with tf.variable_scope("kfac", reuse=False, custom_getter=tf_util.outer_scope_getter("kfac")):
with tf.device('/gpu:0'):
self.optim = optim = kfac.KfacOptimizer(learning_rate=pg_lr_ph, clip_kl=self.kfac_clip,
momentum=0.9, kfac_update=1,
self.optim = optim = kfac.KfacOptimizer(learning_rate=learning_rate_ph, clip_kl=self.kfac_clip,
momentum=0.9, kfac_update=self.kfac_update,
epsilon=0.01, stats_decay=0.99,
async_eigen_decomp=self.async_eigen_decomp,
cold_iter=10,
@@ -220,13 +219,28 @@ def _train_step(self, obs, states, rewards, masks, actions, values, update, writ
:return: (float, float, float) policy loss, value loss, policy entropy
"""
advs = rewards - values
cur_lr = None
# Normalize advantage (used in the original continuous version)
if self.continuous_actions:
advs = (advs - advs.mean()) / (advs.std() + 1e-8)

current_lr = None

assert len(obs) > 0, "Error: the observation input array cannot be empty"

# Note: in the original continuous version,
# the stepsize was automatically tuned computing the kl div
# and comparing it to the desired one
for _ in range(len(obs)):
cur_lr = self.learning_rate_schedule.value()
assert cur_lr is not None, "Error: the observation input array cannot be empty"
current_lr = self.learning_rate_schedule.value()

td_map = {
self.train_model.obs_ph: obs,
self.actions_ph: actions,
self.advs_ph: advs,
self.rewards_ph: rewards,
self.learning_rate_ph: current_lr
}

td_map = {self.train_model.obs_ph: obs, self.action_ph: actions, self.advs_ph: advs, self.rewards_ph: rewards,
self.pg_lr_ph: cur_lr}
if states is not None:
td_map[self.train_model.states_ph] = states
td_map[self.train_model.dones_ph] = masks
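
The `_train_step` hunk above normalizes the advantages when actions are continuous, mirroring the original continuous ACKTR implementation. The trick in isolation (NumPy sketch; the values are illustrative and `returns` stands for the discounted returns fed to `_train_step`):

```python
import numpy as np

returns = np.array([1.0, 0.5, 2.0, 1.5])
values = np.array([0.8, 0.7, 1.6, 1.4])

advs = returns - values
# rescale to zero mean and unit variance; 1e-8 avoids division by zero
advs = (advs - advs.mean()) / (advs.std() + 1e-8)
```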
@@ -287,7 +301,12 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_

self.trained = True

runner = A2CRunner(self.env, self, n_steps=self.n_steps, gamma=self.gamma)
# Use GAE
if self.gae_lambda is not None:
runner = PPO2Runner(env=self.env, model=self, n_steps=self.n_steps, gamma=self.gamma, lam=self.gae_lambda)
else:
runner = A2CRunner(self.env, self, n_steps=self.n_steps, gamma=self.gamma)

self.episode_reward = np.zeros((self.n_envs,))

t_start = time.time()
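
When `gae_lambda` is set, the hunk above swaps the A2C runner for the PPO2 runner, i.e. the training targets are built with Generalized Advantage Estimation instead of plain discounted returns. A self-contained NumPy sketch of GAE(gamma, lambda) over one rollout (illustrative; not the PPO2 runner code, and the `dones` convention here marks the step at which the episode ended):

```python
import numpy as np

def gae(rewards, values, last_value, dones, gamma=0.99, lam=0.95):
    advantages = np.zeros_like(rewards)
    last_gae = 0.0
    for t in reversed(range(len(rewards))):
        next_value = last_value if t == len(rewards) - 1 else values[t + 1]
        next_non_terminal = 1.0 - dones[t]
        # TD residual, then the exponentially-weighted recursion
        delta = rewards[t] + gamma * next_value * next_non_terminal - values[t]
        last_gae = delta + gamma * lam * next_non_terminal * last_gae
        advantages[t] = last_gae
    return advantages, advantages + values  # (advantages, returns)

rew = np.array([0.0, 0.0, 1.0])
val = np.array([0.1, 0.2, 0.5])
dns = np.array([0.0, 0.0, 1.0])
print(gae(rew, val, last_value=0.0, dones=dns))
```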
@@ -302,9 +321,14 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_

for update in range(1, total_timesteps // self.n_batch + 1):
# true_reward is the reward without discount
obs, states, rewards, masks, actions, values, ep_infos, true_reward = runner.run()
if isinstance(runner, PPO2Runner):
# We are using GAE
obs, returns, masks, actions, values, _, states, ep_infos, true_reward = runner.run()
else:
obs, states, returns, masks, actions, values, ep_infos, true_reward = runner.run()

ep_info_buf.extend(ep_infos)
policy_loss, value_loss, policy_entropy = self._train_step(obs, states, rewards, masks, actions, values,
policy_loss, value_loss, policy_entropy = self._train_step(obs, states, returns, masks, actions, values,
self.num_timesteps // (self.n_batch + 1),
writer)
n_seconds = time.time() - t_start
Expand All @@ -323,7 +347,7 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_
break

if self.verbose >= 1 and (update % log_interval == 0 or update == 1):
explained_var = explained_variance(values, rewards)
explained_var = explained_variance(values, returns)
logger.record_tabular("nupdates", update)
logger.record_tabular("total_timesteps", self.num_timesteps)
logger.record_tabular("fps", fps)
@@ -346,6 +370,7 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_
def save(self, save_path, cloudpickle=False):
data = {
"gamma": self.gamma,
"gae_lambda": self.gae_lambda,
"nprocs": self.nprocs,
"n_steps": self.n_steps,
"vf_coef": self.vf_coef,
@@ -360,6 +385,7 @@ def save(self, save_path, cloudpickle=False):
"observation_space": self.observation_space,
"action_space": self.action_space,
"n_envs": self.n_envs,
"kfac_update": self.kfac_update,
"_vectorize_action": self._vectorize_action,
"policy_kwargs": self.policy_kwargs
}
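
Since `_get_pretrain_placeholders` now returns `deterministic_action` for `Box` spaces (and the commit adds a Box pretraining test), behaviour-cloning pretraining should work with continuous actions as well. A sketch assuming the `ExpertDataset`/`pretrain` API and a pre-recorded `expert_pendulum.npz` trajectory file, both of which are outside this commit:

```python
from stable_baselines import ACKTR
from stable_baselines.gail import ExpertDataset

# expert_pendulum.npz is assumed to exist (e.g. recorded beforehand with generate_expert_traj)
dataset = ExpertDataset(expert_path="expert_pendulum.npz", traj_limitation=10, verbose=0)

model = ACKTR("MlpPolicy", "Pendulum-v0")
model.pretrain(dataset, n_epochs=10)  # supervised warm-start via behaviour cloning
model.learn(total_timesteps=10000)    # then continue with ACKTR updates
```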
