### Once you have installed Conda, also run the following command in your Anaconda terminal:

#### `conda install -c conda-forge pyglet`

### Import dependencies

In [None]:
!pip install stable-baselines3[extra]

In [1]:
import os
import gym
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.evaluation import evaluate_policy

### Load Environment

In [2]:
# Gym environment ID; passed to gym.make() here and again when the training env is built
environment_name = "CartPole-v0"
# Environment instance used by the random-action demo in the following cell
env = gym.make(environment_name)

In [None]:
# Sanity-check the environment by playing a few rendered episodes with random actions.
episodes = 5
try:
    for episode in range(1, episodes + 1):
        env.reset()
        done = False
        score = 0

        while not done:
            env.render()  # renders environment
            action = env.action_space.sample()  # random pick from the discrete space (move left / right)
            # Only the reward and the termination flag are needed here
            _, reward, done, _ = env.step(action)
            score += reward
        print('Episode:{} Score:{}'.format(episode, score))
finally:
    # Always close the environment, even if the loop is interrupted part-way
    env.close()

### Train RL Model

In [3]:
# Make Directories To Save Logs (and later Models)
# Relative folder that receives the TensorBoard training logs.
# Only the path string is assembled here; this cell creates nothing on disk.
log_path_parts = ('training', 'logs')
log_path = os.path.join(*log_path_parts)

In [4]:
# Wrap a fresh environment in a single-env DummyVecEnv for Stable-Baselines3.
# The factory builds its own environment instead of closing over `env`, the name
# this very statement rebinds to the vectorized wrapper.
env = DummyVecEnv([lambda: gym.make(environment_name)])
# PPO with the multilayer-perceptron policy; TensorBoard logs go under log_path
model = PPO('MlpPolicy', env, verbose=1, tensorboard_log=log_path)

Using cpu device


In [None]:
# Train the PPO model for 20,000 timesteps
model.learn(total_timesteps=20000)

### Save and Reload Model

In [5]:
# Location the trained PPO model is saved to and later reloaded from
saved_models_dir = os.path.join('training', 'saved_models')
PPO_Path = os.path.join(saved_models_dir, 'PPO_Model_Cartpole')

In [None]:
# Persist the current model at PPO_Path
model.save(PPO_Path)

In [7]:
# Drop the in-memory model so that the following cell has to reload it from disk
del model

In [8]:
# Restore the saved PPO model from PPO_Path and attach the existing env to it
model = PPO.load(PPO_Path, env=env)

In [None]:
# Continue training the reloaded model for a further 1,000 timesteps
model.learn(total_timesteps=1000)

### Evaluation

In [None]:
# Evaluate the model on env over 10 episodes with rendering enabled;
# the call is the cell's last expression, so its return value becomes the cell output
evaluate_policy(model, env, n_eval_episodes=10, render=True)

In [None]:
# Close the environment once evaluation is finished.
# NOTE(review): later cells call env.reset()/env.step() on this same object — confirm it stays usable after close().
env.close()

### Test Model

In [None]:
# Roll out the trained policy for a few rendered episodes and report each score.
# `env` is the single-environment DummyVecEnv built for training, so step() returns
# batched arrays; index 0 selects the only sub-environment so that the score is a
# plain number rather than a one-element array.
episodes = 5
try:
    for episode in range(1, episodes + 1):
        obs = env.reset()
        done = False
        score = 0.0

        while not done:
            env.render()  # renders environment
            action, _ = model.predict(obs)  # predict returns two values, but only the action is required
            obs, rewards, dones, _ = env.step(action)
            score += float(rewards[0])
            done = bool(dones[0])
        print('Episode:{} Score:{}'.format(episode, score))
finally:
    # Always close the environment, even if the loop is interrupted part-way
    env.close()

### Viewing Logs with Tensorboard
#### Note: Running TensorBoard from a notebook cell blocks the kernel until the process is stopped, so it is better launched from a separate terminal.
#### The command is included below, commented out, purely for demonstration purposes.

In [9]:
# Log folder of one specific training run; the run name is hard-coded,
# so adjust it to the run you want to inspect
tensorboard_run_name = 'PPO_2'
training_log_path = os.path.join(log_path, tensorboard_run_name)

In [None]:
#!tensorboard --logdir={training_log_path}
# If you run this from Jupyter, open http://localhost:6006 in a browser to view TensorBoard

### Adding a Callback to the Training Stage

In [10]:
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold

In [11]:
# Folder handed to the evaluation callback as the place to store the best model
training_root = 'training'
save_path = os.path.join(training_root, 'saved_models')

In [14]:
# Callback with a reward threshold of 200; it is triggered whenever the evaluation
# callback below records a new best result
stop_callback = StopTrainingOnRewardThreshold(reward_threshold=200, verbose=1)
# Evaluate on a dedicated environment so that periodic evaluations do not reset or
# step the environment the model is being trained on
eval_env = DummyVecEnv([lambda: gym.make(environment_name)])
# Evaluate every 10,000 calls and keep the best model under save_path
eval_callback = EvalCallback(eval_env,
                             callback_on_new_best=stop_callback,
                             eval_freq=10000,
                             best_model_save_path=save_path,
                             verbose=1)

In [16]:
# Fresh PPO model with the MLP policy; TensorBoard logs go under log_path
model = PPO('MlpPolicy', env, verbose=1, tensorboard_log=log_path)

Using cpu device


In [17]:
# Train for up to 20,000 timesteps with eval_callback attached
# (it carries the reward-threshold stop callback configured above)
model.learn(total_timesteps=20000, callback=eval_callback)

Logging to training\logs\PPO_4
-----------------------------
| time/              |      |
|    fps             | 3051 |
|    iterations      | 1    |
|    time_elapsed    | 0    |
|    total_timesteps | 2048 |
-----------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 1992        |
|    iterations           | 2           |
|    time_elapsed         | 2           |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.009683084 |
|    clip_fraction        | 0.104       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.686      |
|    explained_variance   | -0.000819   |
|    learning_rate        | 0.0003      |
|    loss                 | 6.81        |
|    n_updates            | 10          |
|    policy_gradient_loss | -0.016      |
|    value_loss           | 51.1        |
-----------------------------------------
---



Eval num_timesteps=10000, episode_reward=182.80 +/- 21.82
Episode length: 182.80 +/- 21.82
------------------------------------------
| eval/                   |              |
|    mean_ep_length       | 183          |
|    mean_reward          | 183          |
| time/                   |              |
|    total_timesteps      | 10000        |
| train/                  |              |
|    approx_kl            | 0.0052961265 |
|    clip_fraction        | 0.0364       |
|    clip_range           | 0.2          |
|    entropy_loss         | -0.621       |
|    explained_variance   | 0.241        |
|    learning_rate        | 0.0003       |
|    loss                 | 27.8         |
|    n_updates            | 40           |
|    policy_gradient_loss | -0.0103      |
|    value_loss           | 60.3         |
------------------------------------------
New best mean reward!
------------------------------
| time/              |       |
|    fps             | 1592  |
|    iterations     

<stable_baselines3.ppo.ppo.PPO at 0x2577fb3b4c0>

### Changing Policies

In [19]:
# Custom network layout: four hidden layers of 128 units for each of the
# "pi" and "vf" branches (policy / value networks in SB3's net_arch convention)
hidden_layer_sizes = [128] * 4
net_arch = [{'pi': list(hidden_layer_sizes), 'vf': list(hidden_layer_sizes)}]

In [22]:
# Same PPO setup as before, with the custom layer sizes forwarded to the policy
custom_policy_kwargs = dict(net_arch=net_arch)
model = PPO(
    'MlpPolicy',
    env,
    verbose=1,
    tensorboard_log=log_path,
    policy_kwargs=custom_policy_kwargs,
)

Using cpu device


In [23]:
# Train the custom-architecture model for up to 20,000 timesteps with eval_callback attached.
# NOTE(review): eval_callback is the same instance already passed to the earlier learn() call;
# confirm that any state it kept from that run (e.g. best reward so far) should carry over —
# otherwise build a new callback for this model.
model.learn(total_timesteps=20000, callback=eval_callback)

Logging to training\logs\PPO_5
-----------------------------
| time/              |      |
|    fps             | 2601 |
|    iterations      | 1    |
|    time_elapsed    | 0    |
|    total_timesteps | 2048 |
-----------------------------
----------------------------------------
| time/                   |            |
|    fps                  | 1359       |
|    iterations           | 2          |
|    time_elapsed         | 3          |
|    total_timesteps      | 4096       |
| train/                  |            |
|    approx_kl            | 0.01528943 |
|    clip_fraction        | 0.212      |
|    clip_range           | 0.2        |
|    entropy_loss         | -0.681     |
|    explained_variance   | -0.00374   |
|    learning_rate        | 0.0003     |
|    loss                 | 3          |
|    n_updates            | 10         |
|    policy_gradient_loss | -0.0258    |
|    value_loss           | 19.4       |
----------------------------------------
---------------------

<stable_baselines3.ppo.ppo.PPO at 0x2577fb402b0>

### Using an Alternative Algorithm

In [24]:
from stable_baselines3 import DQN

In [25]:
# Swap in DQN, keeping the same policy type, env and TensorBoard log folder
model = DQN('MlpPolicy', env, verbose=1, tensorboard_log=log_path)

Using cpu device


In [26]:
# Train the DQN model for 20,000 timesteps (no callback is attached here)
model.learn(total_timesteps=20000)

Logging to training\logs\DQN_1
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.959    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 9663     |
|    time_elapsed     | 0        |
|    total_timesteps  | 87       |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.925    |
| time/               |          |
|    episodes         | 8        |
|    fps              | 10463    |
|    time_elapsed     | 0        |
|    total_timesteps  | 157      |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.871    |
| time/               |          |
|    episodes         | 12       |
|    fps              | 12314    |
|    time_elapsed     | 0        |
|    total_timesteps  | 271      |
----------------------------------
------------------------

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 108      |
|    fps              | 13880    |
|    time_elapsed     | 0        |
|    total_timesteps  | 2513     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 112      |
|    fps              | 13853    |
|    time_elapsed     | 0        |
|    total_timesteps  | 2605     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 116      |
|    fps              | 13867    |
|    time_elapsed     | 0        |
|    total_timesteps  | 2677     |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 216      |
|    fps              | 14008    |
|    time_elapsed     | 0        |
|    total_timesteps  | 4890     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 220      |
|    fps              | 14027    |
|    time_elapsed     | 0        |
|    total_timesteps  | 4995     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 224      |
|    fps              | 14070    |
|    time_elapsed     | 0        |
|    total_timesteps  | 5123     |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 324      |
|    fps              | 14143    |
|    time_elapsed     | 0        |
|    total_timesteps  | 7427     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 328      |
|    fps              | 14145    |
|    time_elapsed     | 0        |
|    total_timesteps  | 7499     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 332      |
|    fps              | 14119    |
|    time_elapsed     | 0        |
|    total_timesteps  | 7570     |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 432      |
|    fps              | 14037    |
|    time_elapsed     | 0        |
|    total_timesteps  | 9688     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 436      |
|    fps              | 14025    |
|    time_elapsed     | 0        |
|    total_timesteps  | 9764     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 440      |
|    fps              | 14039    |
|    time_elapsed     | 0        |
|    total_timesteps  | 9858     |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 540      |
|    fps              | 14025    |
|    time_elapsed     | 0        |
|    total_timesteps  | 12163    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 544      |
|    fps              | 14034    |
|    time_elapsed     | 0        |
|    total_timesteps  | 12255    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 548      |
|    fps              | 14028    |
|    time_elapsed     | 0        |
|    total_timesteps  | 12334    |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 648      |
|    fps              | 13986    |
|    time_elapsed     | 1        |
|    total_timesteps  | 14395    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 652      |
|    fps              | 13993    |
|    time_elapsed     | 1        |
|    total_timesteps  | 14529    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 656      |
|    fps              | 13991    |
|    time_elapsed     | 1        |
|    total_timesteps  | 14596    |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 756      |
|    fps              | 14014    |
|    time_elapsed     | 1        |
|    total_timesteps  | 16863    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 760      |
|    fps              | 14024    |
|    time_elapsed     | 1        |
|    total_timesteps  | 16988    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 764      |
|    fps              | 14023    |
|    time_elapsed     | 1        |
|    total_timesteps  | 17070    |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 864      |
|    fps              | 14053    |
|    time_elapsed     | 1        |
|    total_timesteps  | 19243    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 868      |
|    fps              | 14044    |
|    time_elapsed     | 1        |
|    total_timesteps  | 19315    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 872      |
|    fps              | 14039    |
|    time_elapsed     | 1        |
|    total_timesteps  | 19365    |
----------------------------------
----------------------------------
| rollout/          

<stable_baselines3.dqn.dqn.DQN at 0x257001fa520>

In [27]:
# To load DQN, do DQN.load instead of PPO.load