In [1]:
import os
import gym
from IPython import display
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import SubprocVecEnv
from stable_baselines.common.vec_env import DummyVecEnv
from utils.ppo import PPO
from utils.models import Policy, CNNPolicy

In [2]:
# Pin the CUDA-related environment variables for this kernel process:
# devices are ordered by PCI bus id and only device "0" is made visible.
# NOTE(review): this presumably has to run before the first TensorFlow/CUDA
# initialisation to take effect — confirm against the import order above.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
})

In [3]:
# Base directory for run artifacts: the current working directory of the kernel.
# A later configuration cell derives the per-environment / per-run log path from it.
LOGS = os.getcwd()

In [4]:
def makedirs(path):
    """Create directory ``path`` (and any missing parents) if it does not exist.

    Calling this on an already existing directory is a no-op, so the cell
    that uses it can be re-run safely.

    Args:
        path: Directory path to create.

    Raises:
        FileExistsError: If ``path`` exists but is not a directory.
    """
    # exist_ok=True replaces the racy "check, then create" sequence: another
    # process creating the directory between the two steps no longer raises.
    os.makedirs(path, exist_ok=True)

In [5]:
# --- Experiment configuration -------------------------------------------
env_name = 'BipedalWalker-v2'   # gym environment id passed to gym.make
run_id = 1                      # used to name the per-run output directory
n_steps = 250                   # forwarded to PPO as n_steps
total_timesteps = 10000000      # forwarded to model.learn
cnn_policy = False              # True -> CNNPolicy, False -> Policy (MLP)
n_cpu = 4                       # number of environments in the SubprocVecEnv

# --- Output directories ---------------------------------------------------
# Build the run directory from the working directory instead of from the
# previous value of LOGS: the old `LOGS = os.path.join(LOGS, ...)` appended
# another `<env>/run<N>` level every time this cell was re-executed.
LOGS = os.path.join(os.getcwd(), env_name, 'run{}'.format(run_id))
makedirs(LOGS)
tb_log = os.path.join(LOGS, 'tb')            # TensorBoard event files
makedirs(tb_log)
model_dir = os.path.join(LOGS, 'models')     # save_file target during learn()
makedirs(model_dir)
# Path handed to model.save / load; it is a file path, so no directory is created.
final_model_dir = os.path.join(LOGS, 'model')

In [None]:
# Build n_cpu copies of the environment, each created by its own factory
# callable as SubprocVecEnv expects.
env = SubprocVecEnv([lambda: gym.make(env_name) for _ in range(n_cpu)])

# Select the policy class once so the PPO constructor call is not duplicated.
if cnn_policy:
    print('Using CNN policy network')
    policy_cls = CNNPolicy
else:
    print('Using MLP policy network')
    policy_cls = Policy

try:
    model = PPO(policy_cls, env, n_steps=n_steps, tensorboard_log=tb_log, verbose=1, full_tensorboard_log=True)
    model.learn(total_timesteps, env, save_file=os.path.join(model_dir, 'model'))
    model.save(final_model_dir)
    del model # remove to demonstrate saving and loading
finally:
    # Always shut the vectorized environment down, including when training
    # fails or the kernel is interrupted; otherwise its worker processes
    # are left running after this cell ends.
    env.close()

Using MLP policy network
INFO:tensorflow:Summary name model/pi_fc0/w:0 is illegal; using model/pi_fc0/w_0 instead.
INFO:tensorflow:Summary name model/pi_fc0/b:0 is illegal; using model/pi_fc0/b_0 instead.
INFO:tensorflow:Summary name model/vf_fc0/w:0 is illegal; using model/vf_fc0/w_0 instead.
INFO:tensorflow:Summary name model/vf_fc0/b:0 is illegal; using model/vf_fc0/b_0 instead.
INFO:tensorflow:Summary name model/pi_fc1/w:0 is illegal; using model/pi_fc1/w_0 instead.
INFO:tensorflow:Summary name model/pi_fc1/b:0 is illegal; using model/pi_fc1/b_0 instead.
INFO:tensorflow:Summary name model/vf_fc1/w:0 is illegal; using model/vf_fc1/w_0 instead.
INFO:tensorflow:Summary name model/vf_fc1/b:0 is illegal; using model/vf_fc1/b_0 instead.
INFO:tensorflow:Summary name model/vf/w:0 is illegal; using model/vf/w_0 instead.
INFO:tensorflow:Summary name model/vf/b:0 is illegal; using model/vf/b_0 instead.
INFO:tensorflow:Summary name model/pi/w:0 is illegal; using model/pi/w_0 instead.
INFO:tens

---------------------------------------
| approxkl           | 0.00016729426  |
| clipfrac           | 0.0            |
| explained_variance | -0.0184        |
| fps                | 745            |
| nupdates           | 12             |
| policy_entropy     | 5.6810856      |
| policy_loss        | -0.00082313636 |
| serial_timesteps   | 3000           |
| time_elapsed       | 14.9           |
| total_timesteps    | 12096          |
| value_loss         | 0.08053862     |
---------------------------------------
--------------------------------------
| approxkl           | 0.00068820996 |
| clipfrac           | 0.00075       |
| explained_variance | 0.000791      |
| fps                | 746           |
| nupdates           | 13            |
| policy_entropy     | 5.6729407     |
| policy_loss        | -0.0021678142 |
| serial_timesteps   | 3250          |
| time_elapsed       | 16.3          |
| total_timesteps    | 13104         |
| value_loss         | 0.08885931    |
------------

---------------------------------------
| approxkl           | 0.0014348213   |
| clipfrac           | 0.0022500001   |
| explained_variance | 0.0662         |
| fps                | 775            |
| nupdates           | 29             |
| policy_entropy     | 5.659297       |
| policy_loss        | -0.00046511763 |
| serial_timesteps   | 7250           |
| time_elapsed       | 36.9           |
| total_timesteps    | 29232          |
| value_loss         | 38.23678       |
---------------------------------------
-------------------------------------
| approxkl           | 0.0001404727 |
| clipfrac           | 0.0          |
| explained_variance | 0.0516       |
| fps                | 774          |
| nupdates           | 30           |
| policy_entropy     | 5.659305     |
| policy_loss        | 6.357266e-05 |
| serial_timesteps   | 7500         |
| time_elapsed       | 38.2         |
| total_timesteps    | 30240        |
| value_loss         | 38.58839     |
------------------------

--------------------------------------
| approxkl           | 0.00029951174 |
| clipfrac           | 0.0           |
| explained_variance | 0.0576        |
| fps                | 782           |
| nupdates           | 46            |
| policy_entropy     | 5.6507587     |
| policy_loss        | -0.00101497   |
| serial_timesteps   | 11500         |
| time_elapsed       | 58            |
| total_timesteps    | 46368         |
| value_loss         | 76.03463      |
--------------------------------------
--------------------------------------
| approxkl           | 0.00012220336 |
| clipfrac           | 0.0           |
| explained_variance | 0.0972        |
| fps                | 793           |
| nupdates           | 47            |
| policy_entropy     | 5.651762      |
| policy_loss        | 0.00023262878 |
| serial_timesteps   | 11750         |
| time_elapsed       | 59.3          |
| total_timesteps    | 47376         |
| value_loss         | 40.462116     |
-------------------------

--------------------------------------
| approxkl           | 0.00010679554 |
| clipfrac           | 0.0           |
| explained_variance | 0.218         |
| fps                | 812           |
| nupdates           | 63            |
| policy_entropy     | 5.6615148     |
| policy_loss        | -0.0004710285 |
| serial_timesteps   | 15750         |
| time_elapsed       | 79.8          |
| total_timesteps    | 63504         |
| value_loss         | 33.90655      |
--------------------------------------
--------------------------------------
| approxkl           | 5.731637e-05  |
| clipfrac           | 0.0           |
| explained_variance | 0.0549        |
| fps                | 804           |
| nupdates           | 64            |
| policy_entropy     | 5.6630874     |
| policy_loss        | -0.0013361012 |
| serial_timesteps   | 16000         |
| time_elapsed       | 81            |
| total_timesteps    | 64512         |
| value_loss         | 39.330708     |
-------------------------

--------------------------------------
| approxkl           | 0.0030506265  |
| clipfrac           | 0.03125       |
| explained_variance | -0.385        |
| fps                | 797           |
| nupdates           | 80            |
| policy_entropy     | 5.641277      |
| policy_loss        | -0.0035208256 |
| serial_timesteps   | 20000         |
| time_elapsed       | 101           |
| total_timesteps    | 80640         |
| value_loss         | 0.27508724    |
--------------------------------------
--------------------------------------
| approxkl           | 0.0021580895  |
| clipfrac           | 0.0165        |
| explained_variance | 0.429         |
| fps                | 793           |
| nupdates           | 81            |
| policy_entropy     | 5.6350284     |
| policy_loss        | -0.0040757367 |
| serial_timesteps   | 20250         |
| time_elapsed       | 102           |
| total_timesteps    | 81648         |
| value_loss         | 0.18411785    |
-------------------------

-------------------------------------
| approxkl           | 0.0052765436 |
| clipfrac           | 0.070250005  |
| explained_variance | 0.204        |
| fps                | 814          |
| nupdates           | 97           |
| policy_entropy     | 5.5493717    |
| policy_loss        | -0.0033391   |
| serial_timesteps   | 24250        |
| time_elapsed       | 122          |
| total_timesteps    | 97776        |
| value_loss         | 0.0315604    |
-------------------------------------
---------------------------------------
| approxkl           | 0.0017424407   |
| clipfrac           | 0.00275        |
| explained_variance | 0.0199         |
| fps                | 803            |
| nupdates           | 98             |
| policy_entropy     | 5.540776       |
| policy_loss        | -6.0184204e-05 |
| serial_timesteps   | 24500          |
| time_elapsed       | 124            |
| total_timesteps    | 98784          |
| value_loss         | 0.018127859    |
--------------------------

--------------------------------------
| approxkl           | 0.0024882406  |
| clipfrac           | 0.02425       |
| explained_variance | 0.446         |
| fps                | 796           |
| nupdates           | 114           |
| policy_entropy     | 5.459044      |
| policy_loss        | -0.0033968776 |
| serial_timesteps   | 28500         |
| time_elapsed       | 144           |
| total_timesteps    | 114912        |
| value_loss         | 0.025531346   |
--------------------------------------
--------------------------------------
| approxkl           | 0.001844948   |
| clipfrac           | 0.01          |
| explained_variance | 0.164         |
| fps                | 847           |
| nupdates           | 115           |
| policy_entropy     | 5.450992      |
| policy_loss        | -0.0023449783 |
| serial_timesteps   | 28750         |
| time_elapsed       | 145           |
| total_timesteps    | 115920        |
| value_loss         | 0.20830977    |
-------------------------

--------------------------------------
| approxkl           | 0.001988983   |
| clipfrac           | 0.01675       |
| explained_variance | 0.319         |
| fps                | 812           |
| nupdates           | 131           |
| policy_entropy     | 5.376182      |
| policy_loss        | -0.0017768061 |
| serial_timesteps   | 32750         |
| time_elapsed       | 165           |
| total_timesteps    | 132048        |
| value_loss         | 0.039794628   |
--------------------------------------
--------------------------------------
| approxkl           | 0.0020914925  |
| clipfrac           | 0.0105        |
| explained_variance | 0.178         |
| fps                | 817           |
| nupdates           | 132           |
| policy_entropy     | 5.377905      |
| policy_loss        | -0.0028885778 |
| serial_timesteps   | 33000         |
| time_elapsed       | 166           |
| total_timesteps    | 133056        |
| value_loss         | 0.27457625    |
-------------------------

In [None]:
from stable_baselines.common.policies import FeedForwardPolicy

In [None]:
# Load with the same class that trained and saved the model. The original
# `PPO2.load(...)` raised NameError because PPO2 is never imported in this
# notebook.
# NOTE(review): assumes utils.ppo.PPO exposes a `load` classmethod matching
# its `save` — confirm against utils/ppo.py.
model = PPO.load(final_model_dir)

# Enjoy trained agent
env = gym.make(env_name)
try:
    obs = env.reset()
    while True:
        action, _states = model.predict(obs)
        obs, rewards, dones, info = env.step(action)
        env.render()
        if dones:
            # Start a fresh episode instead of stepping a finished one.
            obs = env.reset()
except KeyboardInterrupt:
    # The loop is intentionally endless; interrupting the kernel is the
    # expected way to stop the demo, so exit quietly.
    pass
finally:
    # Release the environment (and its render window) on any exit path.
    env.close()