In [1]:
!pip install stable-baselines3[extra]



In [29]:
import gym 
import os
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.evaluation import evaluate_policy

# Load Environment

In [30]:
# Gym environment id used for every environment created in this notebook
environment_name = "CartPole-v0"

In [31]:
# Raw (non-vectorized) gym environment, used for the manual exploration below
env = gym.make(environment_name)

In [32]:
# Baseline: play a few episodes with a random policy, so the trained agent's
# scores later in the notebook have something to be compared against.
episodes = 5
try:
    for episode in range(1, episodes + 1):
        state = env.reset()
        done = False
        score = 0

        while not done:
            env.render()
            action = env.action_space.sample()  # random valid action
            state, reward, done, info = env.step(action)
            score += reward
        print('Episode:{} Score:{}'.format(episode, score))
finally:
    # Close the render frame even if the loop raises or is interrupted,
    # so no orphaned viewer window is left behind.
    env.close()

Episode:1 Score:23.0
Episode:2 Score:21.0
Episode:3 Score:22.0
Episode:4 Score:48.0
Episode:5 Score:38.0


In [33]:
env.reset() # reset the environment and obtain initial observations

array([-0.01236899, -0.037256  ,  0.0390576 ,  0.0159925 ])

In [34]:
env.render() # Used to visualize the environment; call env.close() to dismiss the render frame

True

In [35]:
env.action_space 
# Output: Discrete(2), i.e. each action is either 0 or 1

Discrete(2)

In [36]:
env.action_space.sample()  # Randomly picks a valid action, either 0 or 1

1

In [37]:
# Box of shape (4,): each observation is a 4-element float32 vector
env.observation_space

Box(-3.4028234663852886e+38, 3.4028234663852886e+38, (4,), float32)

In [38]:
# Random point in the observation Box; the bounds are ±3.4e38, so samples can be huge
env.observation_space.sample()

array([-8.4233397e-01,  2.9248454e+38, -3.8003543e-01, -1.0028530e+37],
      dtype=float32)

In [39]:
env.step(1) # apply an action to the environment
# env.step returns four values: observation, reward, done and info.

(array([-0.01311411,  0.15728468,  0.03937745, -0.26411597]), 1.0, False, {})

# Train an RL model

In [40]:
# TensorBoard log root; each learn() call writes a new run subfolder (PPO_1, PPO_2, ...)
log_path = os.path.join('Training', 'Logs')

In [41]:
# Relative path; shown with Windows separators because this run was on Windows
log_path

'Training\\Logs'

In [43]:
# Stable-Baselines3 trains on a vectorized env; DummyVecEnv calls each factory
# in-process. Create the gym env inside the factory rather than writing
# `lambda: env`: that lambda closes over the global name `env`, which this very
# line rebinds to the DummyVecEnv, so any later call would return the wrapper itself.
env = DummyVecEnv([lambda: gym.make(environment_name)])
model = PPO('MlpPolicy', env, verbose=1, tensorboard_log=log_path)

Using cuda device


In [44]:
# Train PPO for 20000 environment steps; TensorBoard logs go under log_path (run PPO_1 below)
model.learn(total_timesteps=20000)

Exception ignored in: <function Viewer.__del__ at 0x00000163F1379820>
Traceback (most recent call last):
  File "c:\users\ankit\appdata\local\programs\python\python39\lib\site-packages\gym\envs\classic_control\rendering.py", line 165, in __del__
    self.close()
  File "c:\users\ankit\appdata\local\programs\python\python39\lib\site-packages\gym\envs\classic_control\rendering.py", line 83, in close
    self.window.close()
  File "c:\users\ankit\appdata\local\programs\python\python39\lib\site-packages\pyglet\window\win32\__init__.py", line 319, in close
    super(Win32Window, self).close()
  File "c:\users\ankit\appdata\local\programs\python\python39\lib\site-packages\pyglet\window\__init__.py", line 838, in close
    app.windows.remove(self)
  File "c:\users\ankit\appdata\local\programs\python\python39\lib\_weakrefset.py", line 110, in remove
    self.data.remove(ref(item))
KeyError: <weakref at 0x00000163F74D1D10; to 'Win32Window' at 0x00000163F63D7E80>


INFO:tensorflow:Enabling eager execution
INFO:tensorflow:Enabling v2 tensorshape
INFO:tensorflow:Enabling resource variables
INFO:tensorflow:Enabling tensor equality
INFO:tensorflow:Enabling control flow v2
Logging to Training\Logs\PPO_1
-----------------------------
| time/              |      |
|    fps             | 144  |
|    iterations      | 1    |
|    time_elapsed    | 14   |
|    total_timesteps | 2048 |
-----------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 189         |
|    iterations           | 2           |
|    time_elapsed         | 21          |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.009104459 |
|    clip_fraction        | 0.125       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.686      |
|    explained_variance   | -0.00483    |
|    learning_rate        | 0.0003      |
|    l

<stable_baselines3.ppo.ppo.PPO at 0x163f1328cd0>

# Save and reload the model

In [45]:
# Where the trained PPO CartPole model is saved and reloaded from
# (file name fixed from the typo 'PP_Model_Cartpole').
PPO_Path = os.path.join('Training', 'Saved Models', 'PPO_Model_Cartpole')

In [46]:
# Persist the trained model to disk at PPO_Path
model.save(PPO_Path)

In [47]:
# Drop the in-memory model to show that the next cell restores it from disk
del model

In [48]:
# Reload the saved model and reattach the env so it can keep training below
model = PPO.load(PPO_Path, env=env)

In [49]:
# Short extra training run on the reloaded model; it logs to a new run folder (PPO_2).
# The output reports 2048 steps, presumably because PPO always finishes a whole
# rollout of n_steps before checking the budget — confirm against SB3's n_steps default.
model.learn(total_timesteps=1000)

Logging to Training\Logs\PPO_2
-----------------------------
| time/              |      |
|    fps             | 508  |
|    iterations      | 1    |
|    time_elapsed    | 4    |
|    total_timesteps | 2048 |
-----------------------------


<stable_baselines3.ppo.ppo.PPO at 0x163f1328d90>

# Evaluation

In [50]:
# Evaluate over 10 rendered episodes; returns (mean episode reward, std of episode reward).
# (200.0, 0.0) below means every evaluation episode scored 200.
evaluate_policy(model, env, n_eval_episodes=10, render=True)



(200.0, 0.0)

In [51]:
# Close the render frame opened by the rendered evaluation
env.close()

# Test model

In [54]:
# `obs` is otherwise not defined until the test loop further down, so on a fresh
# kernel this cell would raise NameError; take a fresh observation first.
obs = env.reset()
model.predict(obs)

(array([0], dtype=int64), None)

In [55]:
# predict returns a pair; only the action is needed (the second item is None here)
action , _ = model.predict(obs)

In [56]:
# May differ from the previous predict output for the same obs — presumably because
# predict() samples from the policy unless deterministic=True is passed
action

array([1], dtype=int64)

In [57]:
# Watch the trained agent play. `env` is a DummyVecEnv, so reset/step return
# batched length-1 arrays; reward[0] is accumulated so the score prints as a
# plain number instead of an array such as `[200.]`.
episodes = 5
try:
    for episode in range(1, episodes + 1):
        obs = env.reset()
        done = False
        score = 0

        while not done:  # `done` is a length-1 array after the first step
            env.render()
            # deterministic=True matches evaluate_policy's default, so these
            # scores are comparable with the evaluation above.
            action, _ = model.predict(obs, deterministic=True)
            obs, reward, done, info = env.step(action)
            score += float(reward[0])
        print('Episode:{} Score:{}'.format(episode, score))
finally:
    env.close()  # To close down the render frame, even on error/interrupt

Episode:1 Score:[200.]
Episode:2 Score:[200.]
Episode:3 Score:[76.]
Episode:4 Score:[200.]
Episode:5 Score:[178.]


In [58]:
# Fresh observation for the single manual step below
obs = env.reset()

In [59]:
# Batched observation from the VecEnv: note the extra leading dimension compared
# with the raw gym env's reset() output earlier
obs

array([[-0.00392499,  0.03335889, -0.00611641,  0.03700055]],
      dtype=float32)

In [60]:
# Policy action for `obs`; applied with env.step(action) below
action, _ = model.predict(obs)

In [67]:
# Random action, for comparison only; the step cell below applies the policy's `action`
env.action_space.sample()

1

In [68]:
env.step(action) # Returns four values, i.e. observations, reward, done, info — batched (length-1 arrays / list) because env is a DummyVecEnv

(array([[-0.00325781, -0.16167483, -0.00537639,  0.32774743]],
       dtype=float32),
 array([1.], dtype=float32),
 array([False]),
 [{}])

# Viewing logs in TensorBoard

In [69]:
# Run folder of the post-reload training ("Logging to Training\Logs\PPO_2" above).
# The run index depends on how many runs already exist under log_path — update it if you re-run.
training_log_path = os.path.join(log_path, 'PPO_2')

In [70]:
# Directory passed to TensorBoard in the next cell
training_log_path

'Training\\Logs\\PPO_2'

In [76]:
!tensorboard --logdir= {training_log_path}

2021-10-30 23:23:56.339638: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library cudart64_110.dll
usage: tensorboard [-h] [--helpfull] [--logdir PATH] [--logdir_spec PATH_SPEC]
                   [--host ADDR] [--bind_all] [--port PORT]
                   [--reuse_port BOOL] [--load_fast {false,auto,true}]
                   [--extra_data_server_flags EXTRA_DATA_SERVER_FLAGS]
                   [--grpc_creds_type {local,ssl,ssl_dev}]
                   [--grpc_data_provider PORT] [--purge_orphaned_data BOOL]
                   [--db URI] [--db_import] [--inspect] [--version_tb]
                   [--tag TAG] [--event_file PATH] [--path_prefix PATH]
                   [--window_title TEXT] [--max_reload_threads COUNT]
                   [--reload_interval SECONDS] [--reload_task TYPE]
                   [--reload_multifile BOOL]
                   [--reload_multifile_inactive_secs SECONDS]
                   [--generic_data TYPE]
          

# Adding a callback to the training stage

In [78]:
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold

In [79]:
# Directory where EvalCallback saves the best model found during training
save_path = os.path.join('Training', 'Saved Models')

In [80]:
# Evaluate on a separate environment instance. EvalCallback resets and steps the
# env it is given in the middle of training; if that is the training env, the
# algorithm's stored last observation no longer matches the env's real state.
eval_env = DummyVecEnv([lambda: gym.make(environment_name)])
# Stop training once an evaluation reaches a mean reward of 200.
stop_callback = StopTrainingOnRewardThreshold(reward_threshold=200, verbose=1)
# Every 10000 calls, evaluate; on a new best mean reward, save the model to
# save_path and consult stop_callback.
eval_callback = EvalCallback(eval_env, callback_on_new_best=stop_callback, eval_freq=10000, best_model_save_path=save_path, verbose=1)

In [81]:
# Fresh PPO model, trained below with the evaluation / early-stopping callback
model = PPO('MlpPolicy', env, verbose=1, tensorboard_log=log_path)

Using cuda device


In [82]:
# Train for at most 20000 steps; eval_callback may stop training earlier
model.learn(total_timesteps=20000, callback=eval_callback)

Logging to Training\Logs\PPO_3
-----------------------------
| time/              |      |
|    fps             | 536  |
|    iterations      | 1    |
|    time_elapsed    | 3    |
|    total_timesteps | 2048 |
-----------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 408         |
|    iterations           | 2           |
|    time_elapsed         | 10          |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.009245167 |
|    clip_fraction        | 0.106       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.686      |
|    explained_variance   | 0.00204     |
|    learning_rate        | 0.0003      |
|    loss                 | 6.44        |
|    n_updates            | 10          |
|    policy_gradient_loss | -0.0159     |
|    value_loss           | 53.3        |
-----------------------------------------
---

<stable_baselines3.ppo.ppo.PPO at 0x163f6042940>

# Changing the policy network architecture

In [84]:
# Custom hidden-layer sizes: five 128-unit layers for the policy network (pi)
# and four for the value network (vf).
# NOTE(review): newer Stable-Baselines3 releases expect a plain dict here rather
# than a one-element list — confirm against the installed version.
net_arch = [{'pi': [128] * 5, 'vf': [128] * 4}]

In [86]:
# PPO with the custom network sizes defined above
model = PPO(
    'MlpPolicy',
    env,
    policy_kwargs=dict(net_arch=net_arch),
    tensorboard_log=log_path,
    verbose=1,
)

Using cuda device


In [87]:
# Build a fresh EvalCallback for this model rather than reusing the earlier one.
# EvalCallback keeps best_mean_reward (and its call counter) across learn() calls.
# If the previous run already reached the 200 cap, this model could never record
# a "new best": no best-model save and no early stop. A separate eval env
# also keeps evaluation from stepping the training env.
eval_callback = EvalCallback(
    DummyVecEnv([lambda: gym.make(environment_name)]),
    callback_on_new_best=StopTrainingOnRewardThreshold(reward_threshold=200, verbose=1),
    eval_freq=10000,
    best_model_save_path=save_path,
    verbose=1,
)
model.learn(total_timesteps=20000, callback=eval_callback)

Logging to Training\Logs\PPO_4
-----------------------------
| time/              |      |
|    fps             | 415  |
|    iterations      | 1    |
|    time_elapsed    | 4    |
|    total_timesteps | 2048 |
-----------------------------
----------------------------------------
| time/                   |            |
|    fps                  | 308        |
|    iterations           | 2          |
|    time_elapsed         | 13         |
|    total_timesteps      | 4096       |
| train/                  |            |
|    approx_kl            | 0.01600798 |
|    clip_fraction        | 0.248      |
|    clip_range           | 0.2        |
|    entropy_loss         | -0.681     |
|    explained_variance   | 0.0037     |
|    learning_rate        | 0.0003     |
|    loss                 | 2.96       |
|    n_updates            | 10         |
|    policy_gradient_loss | -0.0299    |
|    value_loss           | 20.4       |
----------------------------------------
---------------------

<stable_baselines3.ppo.ppo.PPO at 0x163f60782b0>

# Using an alternate Algorithm

In [88]:
from stable_baselines3 import DQN

In [89]:
# Same env and logging setup, with DQN instead of PPO
model = DQN('MlpPolicy', env, verbose=1, tensorboard_log=log_path)

Using cuda device


In [90]:
# Train DQN for 20000 steps (logs to a DQN_<n> run folder under log_path)
model.learn(total_timesteps=20000)

Logging to Training\Logs\DQN_1
----------------------------------
| rollout/            |          |
|    exploration rate | 0.97     |
| time/               |          |
|    episodes         | 4        |
|    fps              | 2919     |
|    time_elapsed     | 0        |
|    total timesteps  | 64       |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.933    |
| time/               |          |
|    episodes         | 8        |
|    fps              | 4072     |
|    time_elapsed     | 0        |
|    total timesteps  | 142      |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.896    |
| time/               |          |
|    episodes         | 12       |
|    fps              | 4884     |
|    time_elapsed     | 0        |
|    total timesteps  | 219      |
----------------------------------
------------------------

----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 108      |
|    fps              | 8361     |
|    time_elapsed     | 0        |
|    total timesteps  | 2393     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 112      |
|    fps              | 8386     |
|    time_elapsed     | 0        |
|    total timesteps  | 2459     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 116      |
|    fps              | 8361     |
|    time_elapsed     | 0        |
|    total timesteps  | 2543     |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 216      |
|    fps              | 9093     |
|    time_elapsed     | 0        |
|    total timesteps  | 5124     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 220      |
|    fps              | 9094     |
|    time_elapsed     | 0        |
|    total timesteps  | 5188     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 224      |
|    fps              | 9091     |
|    time_elapsed     | 0        |
|    total timesteps  | 5259     |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 324      |
|    fps              | 9038     |
|    time_elapsed     | 0        |
|    total timesteps  | 7626     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 328      |
|    fps              | 9027     |
|    time_elapsed     | 0        |
|    total timesteps  | 7698     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 332      |
|    fps              | 9122     |
|    time_elapsed     | 0        |
|    total timesteps  | 7943     |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 432      |
|    fps              | 9055     |
|    time_elapsed     | 1        |
|    total timesteps  | 9970     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 436      |
|    fps              | 7013     |
|    time_elapsed     | 1        |
|    total timesteps  | 10038    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 440      |
|    fps              | 7015     |
|    time_elapsed     | 1        |
|    total timesteps  | 10118    |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 540      |
|    fps              | 7330     |
|    time_elapsed     | 1        |
|    total timesteps  | 12253    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 544      |
|    fps              | 7343     |
|    time_elapsed     | 1        |
|    total timesteps  | 12348    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 548      |
|    fps              | 7352     |
|    time_elapsed     | 1        |
|    total timesteps  | 12408    |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 648      |
|    fps              | 7645     |
|    time_elapsed     | 1        |
|    total timesteps  | 14543    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 652      |
|    fps              | 7655     |
|    time_elapsed     | 1        |
|    total timesteps  | 14623    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 656      |
|    fps              | 7649     |
|    time_elapsed     | 1        |
|    total timesteps  | 14695    |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 756      |
|    fps              | 7777     |
|    time_elapsed     | 2        |
|    total timesteps  | 16957    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 760      |
|    fps              | 7785     |
|    time_elapsed     | 2        |
|    total timesteps  | 17029    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 764      |
|    fps              | 7810     |
|    time_elapsed     | 2        |
|    total timesteps  | 17200    |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 864      |
|    fps              | 8020     |
|    time_elapsed     | 2        |
|    total timesteps  | 19339    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 868      |
|    fps              | 8024     |
|    time_elapsed     | 2        |
|    total timesteps  | 19420    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration rate | 0.05     |
| time/               |          |
|    episodes         | 872      |
|    fps              | 8030     |
|    time_elapsed     | 2        |
|    total timesteps  | 19483    |
----------------------------------
----------------------------------
| rollout/          

<stable_baselines3.dqn.dqn.DQN at 0x163f6042400>

In [91]:
from stable_baselines3 import A2C

In [92]:
# Same env and logging setup, with A2C instead of PPO
model = A2C('MlpPolicy', env, verbose=1, tensorboard_log=log_path)

Using cuda device


In [93]:
# Train A2C for 20000 steps (logs to an A2C_<n> run folder under log_path)
model.learn(total_timesteps=20000)

Logging to Training\Logs\A2C_1
------------------------------------
| time/                 |          |
|    fps                | 275      |
|    iterations         | 100      |
|    time_elapsed       | 1        |
|    total_timesteps    | 500      |
| train/                |          |
|    entropy_loss       | -0.67    |
|    explained_variance | 0.381    |
|    learning_rate      | 0.0007   |
|    n_updates          | 99       |
|    policy_loss        | 1.05     |
|    value_loss         | 2.93     |
------------------------------------
------------------------------------
| time/                 |          |
|    fps                | 259      |
|    iterations         | 200      |
|    time_elapsed       | 3        |
|    total_timesteps    | 1000     |
| train/                |          |
|    entropy_loss       | -0.68    |
|    explained_variance | -0.0576  |
|    learning_rate      | 0.0007   |
|    n_updates          | 199      |
|    policy_loss        | 1.82     |
|    va

-------------------------------------
| time/                 |           |
|    fps                | 264       |
|    iterations         | 1700      |
|    time_elapsed       | 32        |
|    total_timesteps    | 8500      |
| train/                |           |
|    entropy_loss       | -0.677    |
|    explained_variance | -8.76e-05 |
|    learning_rate      | 0.0007    |
|    n_updates          | 1699      |
|    policy_loss        | 0.451     |
|    value_loss         | 0.575     |
-------------------------------------
------------------------------------
| time/                 |          |
|    fps                | 263      |
|    iterations         | 1800     |
|    time_elapsed       | 34       |
|    total_timesteps    | 9000     |
| train/                |          |
|    entropy_loss       | -0.634   |
|    explained_variance | 2.2e-05  |
|    learning_rate      | 0.0007   |
|    n_updates          | 1799     |
|    policy_loss        | -4.84    |
|    value_loss         

------------------------------------
| time/                 |          |
|    fps                | 171      |
|    iterations         | 3300     |
|    time_elapsed       | 96       |
|    total_timesteps    | 16500    |
| train/                |          |
|    entropy_loss       | -0.524   |
|    explained_variance | -0.0604  |
|    learning_rate      | 0.0007   |
|    n_updates          | 3299     |
|    policy_loss        | 0.000704 |
|    value_loss         | 1.77e-06 |
------------------------------------
------------------------------------
| time/                 |          |
|    fps                | 169      |
|    iterations         | 3400     |
|    time_elapsed       | 100      |
|    total_timesteps    | 17000    |
| train/                |          |
|    entropy_loss       | -0.494   |
|    explained_variance | -40.4    |
|    learning_rate      | 0.0007   |
|    n_updates          | 3399     |
|    policy_loss        | 5.32e-05 |
|    value_loss         | 4.34e-07 |
-

<stable_baselines3.a2c.a2c.A2C at 0x163f60a5460>

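Other Stable Baselines3 algorithms such as DDPG, SAC and TD3 use the same `Model('MlpPolicy', env, ...)` / `learn()` API. However, they only support continuous (`Box`) action spaces, so they cannot be trained on CartPole's `Discrete(2)` action space. HER is not a standalone algorithm: it is a replay-buffer strategy for goal-conditioned environments, used together with an off-policy algorithm (e.g. DQN, SAC, TD3).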