Update PyBullet example
araffin committed May 9, 2020
1 parent b1f5db1 commit a06c4a7
Showing 1 changed file with 20 additions and 4 deletions.
24 changes: 20 additions & 4 deletions docs/guide/examples.rst
@@ -299,7 +299,7 @@ PyBullet: Normalizing input features

Normalizing input features may be essential to successful training of an RL agent
(by default, images are scaled but not other types of input),
-for instance when training on `PyBullet <https://github.com/bulletphysics/bullet3/>`_. For that, a wrapper exists and
+for instance when training on `PyBullet <https://github.com/bulletphysics/bullet3/>`_ environments. For that, a wrapper exists and
will compute a running average and standard deviation of input features (it can do the same for rewards).
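
To make the idea of "a running average and standard deviation" concrete, here is a minimal pure-Python sketch of running normalization for a single scalar feature. It is an illustration only and not the library's implementation: the names `RunningStat` and `normalize` are hypothetical, and the update rule used here (Welford's algorithm) is an assumption chosen for simplicity.

```python
import math


class RunningStat:
    """Tracks a running mean and variance of scalar values (Welford's algorithm).

    Hypothetical helper for illustration; not part of stable_baselines3.
    """

    def __init__(self):
        self.count = 0
        self.mean = 0.0
        self.m2 = 0.0  # sum of squared deviations from the current mean

    def update(self, x):
        # Incorporate one new sample without storing past samples
        self.count += 1
        delta = x - self.mean
        self.mean += delta / self.count
        self.m2 += delta * (x - self.mean)

    @property
    def var(self):
        # Population variance; 0.0 until at least one sample has been seen
        return self.m2 / self.count if self.count > 0 else 0.0


def normalize(x, stat, clip=10.0, eps=1e-8):
    """Center and scale x with the running statistics, then clip to [-clip, clip]."""
    z = (x - stat.mean) / math.sqrt(stat.var + eps)
    return max(-clip, min(clip, z))


stat = RunningStat()
for value in [1.0, 2.0, 3.0, 4.0]:
    stat.update(value)
```

The `clip` argument plays the same role as `clip_obs=10.` in the example that follows: extreme values are bounded after scaling so that a single outlier cannot dominate the input.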


@@ -311,12 +311,13 @@ will compute a running average and standard deviation of input features (it can
.. code-block:: python
import gym
import pybullet_envs
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
from stable_baselines3 import PPO
env = DummyVecEnv([lambda: gym.make("HalfCheetahBulletEnv-v0")])
-# Automatically normalize the input features
+# Automatically normalize the input features and reward
env = VecNormalize(env, norm_obs=True, norm_reward=True,
clip_obs=10.)
@@ -325,8 +326,23 @@ will compute a running average and standard deviation of input features (it can
# Don't forget to save the VecNormalize statistics when saving the agent
log_dir = "/tmp/"
-model.save(log_dir + "ppo_reacher")
-env.save(os.path.join(log_dir, "vec_normalize.pkl"))
+model.save(log_dir + "ppo_halfcheetah")
+stats_path = os.path.join(log_dir, "vec_normalize.pkl")
+env.save(stats_path)
+# To demonstrate loading
+del model, env
+# Load the agent
+model = PPO.load(log_dir + "ppo_halfcheetah")
+# Load the saved statistics
+env = DummyVecEnv([lambda: gym.make("HalfCheetahBulletEnv-v0")])
+env = VecNormalize.load(stats_path, env)
+# do not update them at test time
+env.training = False
+# reward normalization is not needed at test time
+env.norm_reward = False
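
Why the statistics are saved, reloaded, and then frozen can be seen in a small standalone sketch. Everything here is hypothetical and independent of stable_baselines3: `ToyNormalizer` is an illustrative stand-in with a `training` flag, and the statistics are stored as a plain dictionary with `pickle`, which is an assumption about storage format made only for this example.

```python
import math
import os
import pickle
import tempfile


class ToyNormalizer:
    """Hypothetical stand-in for a normalizing wrapper; not the stable_baselines3 class."""

    def __init__(self, count=0, mean=0.0, m2=0.0):
        self.count, self.mean, self.m2 = count, mean, m2
        self.training = True  # when False, the statistics are frozen

    def __call__(self, x):
        if self.training:
            # Update the running mean and squared deviations with the new sample
            self.count += 1
            delta = x - self.mean
            self.mean += delta / self.count
            self.m2 += delta * (x - self.mean)
        var = self.m2 / self.count if self.count else 1.0
        return (x - self.mean) / math.sqrt(var + 1e-8)


# "Training": the statistics adapt to the observations seen so far
normalizer = ToyNormalizer()
for obs in [0.0, 2.0, 4.0]:
    normalizer(obs)

# Save the statistics, as one would alongside the agent
stats_path = os.path.join(tempfile.mkdtemp(), "toy_normalize.pkl")
with open(stats_path, "wb") as f:
    pickle.dump({"count": normalizer.count, "mean": normalizer.mean, "m2": normalizer.m2}, f)

# Reload them into a fresh normalizer and freeze them for evaluation
with open(stats_path, "rb") as f:
    restored = ToyNormalizer(**pickle.load(f))
restored.training = False

mean_before = restored.mean
restored(100.0)  # an outlier seen at test time leaves the statistics untouched
```

Without the saved statistics, a reloaded agent would receive observations on a different scale from the one it was trained on. Without freezing, evaluation data would keep shifting that scale.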
Record a Video