Documentation update + Nan wrapper (#358)
* improved vec_env doc

* updated dummy_vec_env doc + improved vec_env doc

* added VecCheckNan

* added checking nan guide

* added test

* added hyperparam warning to doc

* clean up and typos

* codacy fixes + cleanup + changelog

* hotfix

* fix test

* Update docs/guide/checking_nan.rst

Co-Authored-By: Antonin RAFFIN <antonin.raffin@ensta.org>

* fixed VecCheckNan exception only called once

* add tf NaN debugging options to the NaN guide
hill-a authored and araffin committed Jun 11, 2019
1 parent 72dab6a commit 4db0868
Showing 9 changed files with 433 additions and 5 deletions.
253 changes: 253 additions & 0 deletions docs/guide/checking_nan.rst
@@ -0,0 +1,253 @@
Dealing with NaNs and infs
==========================

While training a model on a given environment, the RL model can become completely
corrupted when a NaN or an inf is given to or returned from the model.

How and why?
------------

The issue arises when NaNs or infs do not crash the program, but simply get propagated through the training
until all the floating-point numbers become NaN or inf. This is in line with the
`IEEE Standard for Floating-Point Arithmetic (IEEE 754) <https://ieeexplore.ieee.org/document/4610935>`_, which describes the following exceptions:

.. note::

    Five possible exceptions can occur:

    - Invalid operation (:math:`\sqrt{-1}`, :math:`\infty \times 0`, :math:`\text{NaN}\ \mathrm{mod}\ 1`, ...) returns NaN
    - Division by zero:

      - if the dividend is not zero (:math:`1/0`, :math:`-2/0`, ...) returns :math:`\pm\infty`
      - if the dividend is zero (:math:`0/0`) returns NaN (IEEE 754 classifies this case as an invalid operation)

    - Overflow (exponent too high to represent) returns :math:`\pm\infty`
    - Underflow (exponent too low to represent) returns :math:`0` (or a subnormal number)
    - Inexact (not representable exactly in base 2, e.g. :math:`1/5`) returns the rounded value (e.g. :code:`assert (1/5) * 3 == 0.6000000000000001`)

Of these, only ``Division by zero`` will signal an exception; the rest will quietly propagate invalid values.

In Python, dividing by zero does raise an exception (``ZeroDivisionError: float division by zero``),
but most of the other cases pass silently.
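As a minimal sketch of the behaviour described above (pure Python, no third-party library; the variable names are only illustrative), the snippet below shows an overflow and an invalid operation passing silently while a division by zero raises:

```python
import math

inf = float("inf")

# Overflow: the product is too large for a float and quietly becomes inf.
overflow = 1e308 * 10.0

# Invalid operation: inf - inf has no meaningful value and quietly becomes NaN.
invalid = inf - inf

# Propagation: any arithmetic involving a NaN yields a NaN again.
propagated = invalid + 1.0

# Division by zero is the case that Python reports loudly.
try:
    1.0 / 0.0
    raised = False
except ZeroDivisionError:
    raised = True

print(overflow, invalid, propagated, raised)
```

Note that ``math.isnan`` and ``math.isinf`` are the reliable way to test for these values, since a NaN never compares equal to anything, including itself.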

By default, numpy will emit a warning (``RuntimeWarning: invalid value encountered``)
but will not halt the code.

Worst of all, Tensorflow will not signal anything:

.. code-block:: python

    import tensorflow as tf
    import numpy as np

    print("tensorflow test:")

    a = tf.constant(1.0)
    b = tf.constant(0.0)
    c = a / b

    sess = tf.Session()
    val = sess.run(c)  # this will be quiet
    print(val)
    sess.close()

    print("\r\nnumpy test:")

    a = np.float64(1.0)
    b = np.float64(0.0)
    val = a / b  # this will warn
    print(val)

    print("\r\npure python test:")

    a = 1.0
    b = 0.0
    val = a / b  # this will raise an exception and halt.
    print(val)

Unfortunately, most floating-point operations are handled by Tensorflow and numpy,
meaning you might get little to no warning when an invalid value occurs.

Numpy parameters
----------------

Numpy has a convenient way of dealing with invalid values: `numpy.seterr <https://docs.scipy.org/doc/numpy/reference/generated/numpy.seterr.html>`_,
which defines how the Python process should handle floating-point errors.

.. code-block:: python

    import numpy as np

    np.seterr(all='raise')  # define before your code.

    print("numpy test:")

    a = np.float64(1.0)
    b = np.float64(0.0)
    val = a / b  # this will now raise an exception instead of a warning.
    print(val)

This also catches overflow issues on floating-point numbers:

.. code-block:: python

    import numpy as np

    np.seterr(all='raise')  # define before your code.

    print("numpy overflow test:")

    a = np.float64(10)
    b = np.float64(1000)
    val = a ** b  # this will now raise an exception
    print(val)

However, it does not catch propagation issues:

.. code-block:: python

    import numpy as np

    np.seterr(all='raise')  # define before your code.

    print("numpy propagation test:")

    a = np.float64('NaN')
    b = np.float64(1.0)
    val = a + b  # this will neither warn nor raise anything
    print(val)

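Since ``numpy.seterr`` does not catch an already existing NaN, one possible complement is an explicit finiteness check on the values you care about. The sketch below is only an illustration: ``check_finite`` is a hypothetical helper (it is not part of numpy or stable-baselines), and ``np.errstate`` is used as a scoped alternative to the global ``np.seterr``.

```python
import numpy as np


def check_finite(values, name="values"):
    """Hypothetical helper: raise if any entry of `values` is NaN or inf."""
    arr = np.asarray(values, dtype=np.float64)
    finite = np.isfinite(arr)
    if not finite.all():
        bad = np.flatnonzero(~finite).tolist()
        raise FloatingPointError(
            "{} contains non-finite entries at flat indices {}".format(name, bad)
        )
    return arr


# Same settings as np.seterr(all='raise'), but limited to this block.
with np.errstate(all="raise"):
    a = np.float64("nan")
    b = np.float64(1.0)
    val = a + b  # still silent: the NaN is simply propagated

# The explicit check is what actually detects the propagated NaN.
try:
    check_finite([val], name="val")
    caught = False
except FloatingPointError as err:
    caught = True
    print(err)
```

Calling such a check at the boundaries of your code (for example on observations and rewards) keeps the overhead small while still locating where the invalid value first appears.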
Tensorflow parameters
---------------------

Tensorflow can add checks for detecting and dealing with invalid values: `tf.add_check_numerics_ops <https://www.tensorflow.org/api_docs/python/tf/add_check_numerics_ops>`_ and `tf.check_numerics <https://www.tensorflow.org/api_docs/python/tf/debugging/check_numerics>`_,
however, they add operations to the Tensorflow graph and increase the computation time.

.. code-block:: python

    import tensorflow as tf

    print("tensorflow test:")

    a = tf.constant(1.0)
    b = tf.constant(0.0)
    c = a / b

    check_nan = tf.add_check_numerics_ops()  # add after your graph definition.

    sess = tf.Session()
    val, _ = sess.run([c, check_nan])  # this will now raise an exception
    print(val)
    sess.close()

This also catches overflow issues on floating-point numbers:

.. code-block:: python

    import tensorflow as tf

    print("tensorflow overflow test:")

    check_nan = []  # the list of check_numerics operations

    # use floats: tf.check_numerics only accepts floating-point tensors
    a = tf.constant(10.0)
    b = tf.constant(1000.0)
    c = a ** b

    check_nan.append(tf.check_numerics(c, ""))  # check the 'c' operations

    sess = tf.Session()
    val, _ = sess.run([c] + check_nan)  # this will now raise an exception
    print(val)
    sess.close()

It also catches propagation issues:

.. code-block:: python

    import tensorflow as tf

    print("tensorflow propagation test:")

    check_nan = []  # the list of check_numerics operations

    a = tf.constant(float('NaN'))  # a float NaN, not the string 'NaN'
    b = tf.constant(1.0)
    c = a + b

    check_nan.append(tf.check_numerics(c, ""))  # check the 'c' operations

    sess = tf.Session()
    val, _ = sess.run([c] + check_nan)  # this will now raise an exception
    print(val)
    sess.close()

VecCheckNan Wrapper
-------------------

To find out when and where an invalid value originated, stable-baselines comes with a ``VecCheckNan`` wrapper.

It monitors the actions, observations, and rewards, and reports which action or observation caused the invalid value and where it came from.

.. code-block:: python

    import gym
    from gym import spaces
    import numpy as np

    from stable_baselines import PPO2
    from stable_baselines.common.vec_env import DummyVecEnv, VecCheckNan

    class NanAndInfEnv(gym.Env):
        """Custom Environment that raises NaNs and Infs"""
        metadata = {'render.modes': ['human']}

        def __init__(self):
            super(NanAndInfEnv, self).__init__()
            self.action_space = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float64)
            self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float64)

        def step(self, _action):
            randf = np.random.rand()
            if randf > 0.99:
                obs = float('NaN')
            elif randf > 0.98:
                obs = float('inf')
            else:
                obs = randf
            return [obs], 0.0, False, {}

        def reset(self):
            return [0.0]

        def render(self, mode='human', close=False):
            pass

    # Create environment
    env = DummyVecEnv([lambda: NanAndInfEnv()])
    env = VecCheckNan(env, raise_exception=True)

    # Instantiate the agent
    model = PPO2('MlpPolicy', env)

    # Train the agent
    model.learn(total_timesteps=int(2e5))  # this will crash explaining that the invalid value originated from the environment.

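To illustrate the idea behind such a wrapper without depending on gym or stable-baselines, here is a minimal sketch. ``ToyEnv`` and ``FiniteCheck`` are hypothetical names invented for this example; this is not the actual ``VecCheckNan`` implementation, only an approximation of the principle (check what goes in, check what comes out, and name the origin).

```python
import numpy as np


class ToyEnv:
    """Hypothetical stand-in environment: returns a NaN observation on its third step."""

    def __init__(self):
        self.steps = 0

    def reset(self):
        self.steps = 0
        return np.array([0.0])

    def step(self, action):
        self.steps += 1
        obs = np.array([np.nan if self.steps == 3 else 0.5])
        return obs, 0.0, False, {}


class FiniteCheck:
    """Hypothetical wrapper: validates actions going in and observations/rewards coming out."""

    def __init__(self, env):
        self.env = env

    @staticmethod
    def _check(origin, **named):
        for name, value in named.items():
            arr = np.asarray(value, dtype=np.float64)
            if np.isnan(arr).any():
                kind = "NaN"
            elif np.isinf(arr).any():
                kind = "inf"
            else:
                continue
            raise ValueError("{} found in {} (originated from {})".format(kind, name, origin))

    def reset(self):
        obs = self.env.reset()
        self._check("the environment", observation=obs)
        return obs

    def step(self, action):
        # An invalid action can only come from the model.
        self._check("the model", action=action)
        obs, reward, done, info = self.env.step(action)
        # An invalid observation or reward can only come from the environment.
        self._check("the environment", observation=obs, reward=reward)
        return obs, reward, done, info


env = FiniteCheck(ToyEnv())
env.reset()
env.step([0.1])
env.step([0.1])
try:
    env.step([0.1])  # third step: the toy environment returns a NaN
    message = None
except ValueError as err:
    message = str(err)
print(message)
```

Checking the action before the step and the observation after it is what makes it possible to say which side produced the invalid value.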
RL Model hyperparameters
------------------------

Depending on your hyperparameters, NaNs can occur much more often.
A great example of this: https://github.com/hill-a/stable-baselines/issues/340

Be aware that the default hyperparameters seem to work in most cases,
but your environment might not play nicely with them.
If this is the case, try to read up on the effect each hyperparameter has on the model,
so that you can tune them to get a stable model. Alternatively, you can try automatic hyperparameter tuning (included in the rl zoo).

Missing values from datasets
----------------------------

If your environment is generated from an external dataset, make sure the dataset does not contain NaNs,
as some datasets fill missing values with NaNs as a surrogate value.

Here is some reading material about finding NaNs: https://pandas.pydata.org/pandas-docs/stable/user_guide/missing_data.html

And filling the missing values with something else (imputation): https://towardsdatascience.com/how-to-handle-missing-data-8646b18db0d4
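As a small illustration of both steps (finding NaNs, then imputing them), the sketch below uses plain numpy on a made-up array; the data and the choice of mean imputation are assumptions for the example only. The pandas guide linked above covers the equivalent operations on a DataFrame.

```python
import numpy as np

# Hypothetical dataset: rows are samples, columns are features, NaN marks a missing value.
data = np.array([
    [1.0, 10.0],
    [np.nan, 20.0],
    [3.0, np.nan],
    [5.0, 30.0],
])

# Find the NaNs: count missing entries per column.
missing_per_column = np.isnan(data).sum(axis=0)

# Impute: replace each NaN with the mean of its column, computed while ignoring NaNs.
column_means = np.nanmean(data, axis=0)
imputed = np.where(np.isnan(data), column_means, data)

print(missing_per_column)
print(imputed)
```

Mean imputation is only one option; whether it is appropriate depends on the dataset and on how the environment uses the values.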

11 changes: 9 additions & 2 deletions docs/guide/vec_envs.rst
@@ -5,8 +5,8 @@
Vectorized Environments
=======================

Vectorized Environments are a method for multiprocess training. Instead of training an RL agent
on 1 environment, it allows us to train it on `n` environments using `n` processes.
Vectorized Environments are a method for stacking multiple independent environments into a single environment.
Instead of training an RL agent on 1 environment per step, it allows us to train it on `n` environments per step.
Because of this, `actions` passed to the environment are now a vector (of dimension `n`).
It is the same for `observations`, `rewards` and end of episode signals (`dones`).
In the case of non-array observation spaces such as `Dict` or `Tuple`, where different sub-spaces
@@ -69,3 +69,10 @@ VecVideoRecorder

.. autoclass:: VecVideoRecorder
:members:


VecCheckNan
~~~~~~~~~~~~~~~~

.. autoclass:: VecCheckNan
:members:
1 change: 1 addition & 0 deletions docs/index.rst
@@ -47,6 +47,7 @@ This toolset is a fork of OpenAI Baselines, with a major structural refactoring,
guide/tensorboard
guide/rl_zoo
guide/pretrain
guide/checking_nan


.. toctree::
5 changes: 4 additions & 1 deletion docs/misc/changelog.rst
@@ -6,7 +6,7 @@ Changelog
For download links, please look at `Github release page <https://github.com/hill-a/stable-baselines/releases>`_.


Pre-Release 2.6.0a0 (WIP)
Pre-Release 2.6.0a1 (WIP)
-------------------------

**Hindsight Experience Replay (HER) - Reloaded | get/load parameters**
@@ -35,6 +35,9 @@ Pre-Release 2.6.0a0 (WIP)
``find_trainable_params`` was returning all trainable variables, discarding the scope argument.
This bug was causing the model to save duplicated parameters (for DDPG and SAC)
but did not affect the performance.
- added guide for managing ``NaN`` and ``inf``
- added ``VecCheckNan`` wrapper
- updated vec_env doc

**Breaking Change:** DDPG replay buffer was unified with DQN/SAC replay buffer. As a result,
when loading a DDPG model trained with stable_baselines<2.6.0, it throws an import error.
1 change: 1 addition & 0 deletions stable_baselines/common/vec_env/__init__.py
@@ -6,3 +6,4 @@
from stable_baselines.common.vec_env.vec_frame_stack import VecFrameStack
from stable_baselines.common.vec_env.vec_normalize import VecNormalize
from stable_baselines.common.vec_env.vec_video_recorder import VecVideoRecorder
from stable_baselines.common.vec_env.vec_check_nan import VecCheckNan
5 changes: 4 additions & 1 deletion stable_baselines/common/vec_env/dummy_vec_env.py
@@ -7,7 +7,10 @@

class DummyVecEnv(VecEnv):
"""
Creates a simple vectorized wrapper for multiple environments
Creates a simple vectorized wrapper for multiple environments, calling each environment in sequence on the current
Python process. This is useful for computationally simple environments such as ``cartpole-v1``, as the overhead of
multiprocessing or multithreading outweighs the environment computation time. This can also be used for RL methods that
require a vectorized environment when you want to train with a single environment.
:param env_fns: ([Gym Environment]) the list of environments to vectorize
"""
6 changes: 5 additions & 1 deletion stable_baselines/common/vec_env/subproc_vec_env.py
@@ -44,7 +44,11 @@ def _worker(remote, parent_remote, env_fn_wrapper):

class SubprocVecEnv(VecEnv):
"""
Creates a multiprocess vectorized wrapper for multiple environments
Creates a multiprocess vectorized wrapper for multiple environments, distributing each environment to its own
process, allowing significant speed up when the environment is computationally complex.
For performance reasons, if your environment is not IO bound, the number of environments should not exceed the
number of logical cores on your CPU.
.. warning::