In [1]:
import os
import gym
from IPython import display
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import SubprocVecEnv
from stable_baselines.common.vec_env import DummyVecEnv
from utils.ppo import PPO
from utils.models import Policy, CNNPolicy

In [2]:
# Order GPUs by PCI bus ID so device indices are stable across tools, and
# expose only device 0 to TensorFlow. CUDA reads these when it initialises,
# so this cell must run before any model/session is built.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
})

In [3]:
# Root directory for all run artefacts: the notebook's working directory.
LOGS = os.path.abspath(os.curdir)

In [4]:
def makedirs(path):
    """Create directory ``path`` (and any missing parents) if it does not exist.

    Relies on ``os.makedirs(..., exist_ok=True)`` instead of a separate
    ``os.path.exists`` check. This removes the check-then-create race.

    If ``path`` already exists but is *not* a directory, ``FileExistsError``
    is raised. The caller fails here, at the cause, instead of silently
    continuing and failing later when something is written under it.
    """
    os.makedirs(path, exist_ok=True)

In [5]:
# Experiment configuration.
env_name = 'BipedalWalkerHardcore-v2'
run_id = 1
n_steps = 250               # passed to PPO as n_steps (serial_timesteps grows by this per update in the log)
total_timesteps = 10000000
cnn_policy = False          # True -> CNNPolicy, False -> Policy (MLP)
n_cpu = 4                   # number of parallel environment workers

# Output layout: <LOGS>/<env_name>/run<run_id>/{tb/, models/, model*}.
# Build into a new name rather than rebinding LOGS. Rebinding made this cell
# non-idempotent: each re-run nested the path one more level deep.
run_dir = os.path.join(LOGS, env_name, 'run{}'.format(run_id))
makedirs(run_dir)
tb_log = os.path.join(run_dir, 'tb')
makedirs(tb_log)
model_dir = os.path.join(run_dir, 'models')
makedirs(model_dir)
# Path prefix handed to model.save / PPO.load; not created as a directory.
final_model_dir = os.path.join(run_dir, 'model')

In [None]:
# Train PPO on n_cpu parallel copies of the environment, then save the final
# model to final_model_dir.
env = SubprocVecEnv([lambda: gym.make(env_name) for _ in range(n_cpu)])

# The two original branches differed only in the policy class; choose it once.
if cnn_policy:
    print('Using CNN policy network')
    policy_cls = CNNPolicy
else:
    print('Using MLP policy network')
    policy_cls = Policy

try:
    model = PPO(policy_cls, env, n_steps=n_steps, tensorboard_log=tb_log,
                verbose=1, full_tensorboard_log=True)
    # save_file is presumably the checkpoint prefix used by utils.ppo.PPO
    # during training — confirm against that module.
    model.learn(total_timesteps, env, save_file=os.path.join(model_dir, 'model'))
    model.save(final_model_dir)
    del model  # remove to demonstrate saving and loading
finally:
    # SubprocVecEnv runs each environment in a worker process. Shut them down
    # even if training is interrupted; otherwise they outlive this cell.
    env.close()

Using MLP policy network
INFO:tensorflow:Summary name model/pi_fc0/w:0 is illegal; using model/pi_fc0/w_0 instead.
INFO:tensorflow:Summary name model/pi_fc0/b:0 is illegal; using model/pi_fc0/b_0 instead.
INFO:tensorflow:Summary name model/vf_fc0/w:0 is illegal; using model/vf_fc0/w_0 instead.
INFO:tensorflow:Summary name model/vf_fc0/b:0 is illegal; using model/vf_fc0/b_0 instead.
INFO:tensorflow:Summary name model/pi_fc1/w:0 is illegal; using model/pi_fc1/w_0 instead.
INFO:tensorflow:Summary name model/pi_fc1/b:0 is illegal; using model/pi_fc1/b_0 instead.
INFO:tensorflow:Summary name model/vf_fc1/w:0 is illegal; using model/vf_fc1/w_0 instead.
INFO:tensorflow:Summary name model/vf_fc1/b:0 is illegal; using model/vf_fc1/b_0 instead.
INFO:tensorflow:Summary name model/vf/w:0 is illegal; using model/vf/w_0 instead.
INFO:tensorflow:Summary name model/vf/b:0 is illegal; using model/vf/b_0 instead.
INFO:tensorflow:Summary name model/pi/w:0 is illegal; using model/pi/w_0 instead.
INFO:tens

---------------------------------------
| approxkl           | 0.0002024498   |
| clipfrac           | 0.0            |
| explained_variance | 0.0195         |
| fps                | 884            |
| nupdates           | 12             |
| policy_entropy     | 5.6747303      |
| policy_loss        | -0.00074379315 |
| serial_timesteps   | 3000           |
| time_elapsed       | 14.2           |
| total_timesteps    | 12096          |
| value_loss         | 43.431786      |
---------------------------------------
--------------------------------------
| approxkl           | 0.0010845546  |
| clipfrac           | 0.007         |
| explained_variance | -0.0109       |
| fps                | 839           |
| nupdates           | 13            |
| policy_entropy     | 5.6725097     |
| policy_loss        | -0.0031518368 |
| serial_timesteps   | 3250          |
| time_elapsed       | 15.3          |
| total_timesteps    | 13104         |
| value_loss         | 0.17734638    |
------------

---------------------------------------
| approxkl           | 0.00019148899  |
| clipfrac           | 0.0            |
| explained_variance | -0.00583       |
| fps                | 782            |
| nupdates           | 29             |
| policy_entropy     | 5.6696377      |
| policy_loss        | -0.00073264155 |
| serial_timesteps   | 7250           |
| time_elapsed       | 34.6           |
| total_timesteps    | 29232          |
| value_loss         | 78.4402        |
---------------------------------------
---------------------------------------
| approxkl           | 0.00015706167  |
| clipfrac           | 0.0            |
| explained_variance | 0.0145         |
| fps                | 779            |
| nupdates           | 30             |
| policy_entropy     | 5.6693597      |
| policy_loss        | -0.00039197446 |
| serial_timesteps   | 7500           |
| time_elapsed       | 35.9           |
| total_timesteps    | 30240          |
| value_loss         | 158.92212      |


---------------------------------------
| approxkl           | 0.00027153007  |
| clipfrac           | 0.0            |
| explained_variance | 0.366          |
| fps                | 786            |
| nupdates           | 46             |
| policy_entropy     | 5.6813197      |
| policy_loss        | -0.00042371312 |
| serial_timesteps   | 11500          |
| time_elapsed       | 56.7           |
| total_timesteps    | 46368          |
| value_loss         | 0.14923888     |
---------------------------------------
--------------------------------------
| approxkl           | 0.00053265644 |
| clipfrac           | 0.0           |
| explained_variance | 0.0778        |
| fps                | 760           |
| nupdates           | 47            |
| policy_entropy     | 5.6836615     |
| policy_loss        | -4.312117e-05 |
| serial_timesteps   | 11750         |
| time_elapsed       | 57.9          |
| total_timesteps    | 47376         |
| value_loss         | 73.99603      |
------------

-------------------------------------
| approxkl           | 0.0025882586 |
| clipfrac           | 0.024        |
| explained_variance | 0.226        |
| fps                | 849          |
| nupdates           | 62           |
| policy_entropy     | 5.6921334    |
| policy_loss        | -0.004202494 |
| serial_timesteps   | 15500        |
| time_elapsed       | 77.3         |
| total_timesteps    | 62496        |
| value_loss         | 0.25780633   |
-------------------------------------
--------------------------------------
| approxkl           | 0.0026397307  |
| clipfrac           | 0.022         |
| explained_variance | 0.749         |
| fps                | 810           |
| nupdates           | 63            |
| policy_entropy     | 5.689882      |
| policy_loss        | -0.0049543804 |
| serial_timesteps   | 15750         |
| time_elapsed       | 78.4          |
| total_timesteps    | 63504         |
| value_loss         | 0.07321389    |
--------------------------------------

--------------------------------------
| approxkl           | 0.0018244595  |
| clipfrac           | 0.010749999   |
| explained_variance | 0.123         |
| fps                | 831           |
| nupdates           | 79            |
| policy_entropy     | 5.652226      |
| policy_loss        | -0.0031659324 |
| serial_timesteps   | 19750         |
| time_elapsed       | 98.4          |
| total_timesteps    | 79632         |
| value_loss         | 0.032866474   |
--------------------------------------
-------------------------------------
| approxkl           | 0.0032063571 |
| clipfrac           | 0.02975      |
| explained_variance | -0.0421      |
| fps                | 771          |
| nupdates           | 80           |
| policy_entropy     | 5.642788     |
| policy_loss        | -0.004558313 |
| serial_timesteps   | 20000        |
| time_elapsed       | 99.6         |
| total_timesteps    | 80640        |
| value_loss         | 0.03936811   |
-------------------------------------

--------------------------------------
| approxkl           | 0.00071773847 |
| clipfrac           | 0.0005        |
| explained_variance | 0.218         |
| fps                | 797           |
| nupdates           | 96            |
| policy_entropy     | 5.596282      |
| policy_loss        | -0.0006662999 |
| serial_timesteps   | 24000         |
| time_elapsed       | 119           |
| total_timesteps    | 96768         |
| value_loss         | 0.052824542   |
--------------------------------------
-------------------------------------
| approxkl           | 0.0034904624 |
| clipfrac           | 0.04125      |
| explained_variance | 0.435        |
| fps                | 810          |
| nupdates           | 97           |
| policy_entropy     | 5.59761      |
| policy_loss        | -0.003650211 |
| serial_timesteps   | 24250        |
| time_elapsed       | 121          |
| total_timesteps    | 97776        |
| value_loss         | 0.033023007  |
-------------------------------------

---------------------------------------
| approxkl           | 0.0002962334   |
| clipfrac           | 0.0            |
| explained_variance | 0.133          |
| fps                | 794            |
| nupdates           | 113            |
| policy_entropy     | 5.5791764      |
| policy_loss        | -0.00086736376 |
| serial_timesteps   | 28250          |
| time_elapsed       | 141            |
| total_timesteps    | 113904         |
| value_loss         | 34.49879       |
---------------------------------------
--------------------------------------
| approxkl           | 0.00024945647 |
| clipfrac           | 0.0           |
| explained_variance | 0.0768        |
| fps                | 787           |
| nupdates           | 114           |
| policy_entropy     | 5.5795326     |
| policy_loss        | 0.00089018117 |
| serial_timesteps   | 28500         |
| time_elapsed       | 142           |
| total_timesteps    | 114912        |
| value_loss         | 71.11712      |
------------

--------------------------------------
| approxkl           | 0.00079272594 |
| clipfrac           | 0.001         |
| explained_variance | 0.673         |
| fps                | 809           |
| nupdates           | 130           |
| policy_entropy     | 5.589503      |
| policy_loss        | -0.0016231511 |
| serial_timesteps   | 32500         |
| time_elapsed       | 162           |
| total_timesteps    | 131040        |
| value_loss         | 0.27895746    |
--------------------------------------
--------------------------------------
| approxkl           | 0.001393703   |
| clipfrac           | 0.0037500001  |
| explained_variance | 0.543         |
| fps                | 786           |
| nupdates           | 131           |
| policy_entropy     | 5.587889      |
| policy_loss        | -0.0011804511 |
| serial_timesteps   | 32750         |
| time_elapsed       | 163           |
| total_timesteps    | 132048        |
| value_loss         | 0.63443613    |
-------------------------

--------------------------------------
| approxkl           | 0.005012266   |
| clipfrac           | 0.066         |
| explained_variance | 0.75          |
| fps                | 864           |
| nupdates           | 147           |
| policy_entropy     | 5.5821033     |
| policy_loss        | -0.0046379385 |
| serial_timesteps   | 36750         |
| time_elapsed       | 183           |
| total_timesteps    | 148176        |
| value_loss         | 0.14455383    |
--------------------------------------
--------------------------------------
| approxkl           | 0.0015966641  |
| clipfrac           | 0.00625       |
| explained_variance | 0.34          |
| fps                | 799           |
| nupdates           | 148           |
| policy_entropy     | 5.578298      |
| policy_loss        | -0.0020052928 |
| serial_timesteps   | 37000         |
| time_elapsed       | 184           |
| total_timesteps    | 149184        |
| value_loss         | 0.060422286   |
-------------------------

In [None]:
from stable_baselines.common.policies import FeedForwardPolicy

In [None]:
n_eval_episodes = 5  # bounded so Restart & Run All terminates; raise to watch longer

# Load with the same class that saved the model (PPO from utils.ppo);
# PPO2 was never imported in this notebook.
model = PPO.load(final_model_dir)

# Enjoy trained agent: render a few episodes, resetting between them.
# Without a reset, stepping continues on an already-finished episode.
env = gym.make(env_name)
try:
    for episode in range(n_eval_episodes):
        obs = env.reset()
        done = False
        episode_reward = 0.0
        while not done:
            action, _states = model.predict(obs)
            obs, reward, done, info = env.step(action)
            episode_reward += reward
            env.render()
        print('Episode {}: reward = {:.2f}'.format(episode + 1, episode_reward))
finally:
    env.close()  # release the environment (and its render window)