In [1]:
# Import dependencies
!pip install stable-baselines3[extra]

Collecting stable-baselines3[extra]
  Downloading stable_baselines3-1.6.0-py3-none-any.whl (177 kB)
Collecting gym==0.21
  Downloading gym-0.21.0.tar.gz (1.5 MB)
Collecting torch>=1.11
  Downloading torch-1.12.1-cp39-cp39-win_amd64.whl (161.8 MB)
Collecting ale-py==0.7.4
  Downloading ale_py-0.7.4-cp39-cp39-win_amd64.whl (904 kB)
Collecting autorom[accept-rom-license]~=0.4.2
  Downloading AutoROM-0.4.2-py3-none-any.whl (16 kB)
Collecting importlib-metadata>=4.10.0
  Downloading importlib_metadata-4.12.0-py3-none-any.whl (21 kB)
Collecting importlib-resources
  Downloading importlib_resources-5.9.0-py3-none-any.whl (33 kB)
Collecting AutoROM.accept-rom-license
  Downloading AutoROM.accept-rom-license-0.4.2.tar.gz (9.8 kB)
  Installing build dependencies: started
  Installing build dependencies: finished with status 'done'
  Getting requirements to build wheel: started
  Getting requirements to build wheel: finished with status 'done'
    Preparing wheel metadata: started
    Preparing w

In [2]:
import gym 
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.evaluation import evaluate_policy

In [3]:
# Load Environment
environment_name = "CartPole-v0"

env = gym.make(environment_name)

In [6]:
episodes = 5
for episode in range(1, episodes+1):
    state = env.reset()
    done = False
    score = 0 
    
    while not done:
        env.render()
        action = env.action_space.sample()
        n_state, reward, done, info = env.step(action)
        score+=reward
    print('Episode:{} Score:{}'.format(episode, score))
env.close()

Episode:1 Score:65.0
Episode:2 Score:46.0
Episode:3 Score:45.0
Episode:4 Score:27.0
Episode:5 Score:48.0


In [7]:
# Understanding The Environment
env.action_space.sample()
# 0-push cart to left, 1-push cart to the right

1

In [8]:
env.observation_space.sample()
# [cart position, cart velocity, pole angle, pole angular velocity]

array([ 3.9290626e+00, -9.1089731e+37, -3.1479254e-02, -1.1330018e+38],
      dtype=float32)

In [9]:
# Train an RL Model
env = gym.make(environment_name)
env = DummyVecEnv([lambda: env])
model = PPO('MlpPolicy', env, verbose = 1)

Using cpu device


In [10]:
model.learn(total_timesteps=20000)

-----------------------------
| time/              |      |
|    fps             | 1217 |
|    iterations      | 1    |
|    time_elapsed    | 1    |
|    total_timesteps | 2048 |
-----------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 805         |
|    iterations           | 2           |
|    time_elapsed         | 5           |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.008730016 |
|    clip_fraction        | 0.0964      |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.686      |
|    explained_variance   | -0.00319    |
|    learning_rate        | 0.0003      |
|    loss                 | 5.79        |
|    n_updates            | 10          |
|    policy_gradient_loss | -0.0154     |
|    value_loss           | 48.3        |
-----------------------------------------
----------------------------------

<stable_baselines3.ppo.ppo.PPO at 0x1e08c45fd60>

In [11]:
# Save and Reload Model
import os
PPO_path = os.path.join('Training', 'Saved Models', 'PPO_model')

In [12]:
model.save(PPO_path)



In [13]:
del model

In [15]:
model = PPO.load('Training/Saved Models/PPO_model', env=env)

In [16]:
# Evaluation
evaluate_policy(model, env, n_eval_episodes=10, render=True)



(200.0, 0.0)

In [17]:
env.close()

In [18]:
# Test Model
obs = env.reset()
while True:
    action, _states = model.predict(obs)
    obs, rewards, done, info = env.step(action)
    env.render()
    if done: 
        print('info', info)
        break

info [{'TimeLimit.truncated': True, 'terminal_observation': array([-0.03061121, -0.41353926,  0.05644975,  0.06191196], dtype=float32)}]


In [19]:
env.close()

In [22]:
# Viewing Logs in Tensorboard
training_log_path = os.path.join(log_path, 'PPO_3')

In [24]:
# Adding a callback to the training Stage
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold

save_path = os.path.join('Training', 'Saved Models')
log_path = os.path.join('Training', 'Logs')

In [26]:
env = gym.make(environment_name)
env = DummyVecEnv([lambda: env])

In [27]:
stop_callback = StopTrainingOnRewardThreshold(reward_threshold=190, verbose=1)
eval_callback = EvalCallback(env, 
                             callback_on_new_best=stop_callback, 
                             eval_freq=10000, 
                             best_model_save_path=save_path, 
                             verbose=1)

In [28]:
model = PPO('MlpPolicy', env, verbose = 1, tensorboard_log=log_path)

Using cpu device


In [29]:
model.learn(total_timesteps=20000, callback=eval_callback)

Logging to Training\Logs\PPO_1
-----------------------------
| time/              |      |
|    fps             | 386  |
|    iterations      | 1    |
|    time_elapsed    | 5    |
|    total_timesteps | 2048 |
-----------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 480         |
|    iterations           | 2           |
|    time_elapsed         | 8           |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.008157235 |
|    clip_fraction        | 0.0861      |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.687      |
|    explained_variance   | -0.0104     |
|    learning_rate        | 0.0003      |
|    loss                 | 7.28        |
|    n_updates            | 10          |
|    policy_gradient_loss | -0.0136     |
|    value_loss           | 55.9        |
-----------------------------------------
---



Eval num_timesteps=10000, episode_reward=197.40 +/- 5.20
Episode length: 197.40 +/- 5.20
-----------------------------------------
| eval/                   |             |
|    mean_ep_length       | 197         |
|    mean_reward          | 197         |
| time/                   |             |
|    total_timesteps      | 10000       |
| train/                  |             |
|    approx_kl            | 0.006334926 |
|    clip_fraction        | 0.0471      |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.62       |
|    explained_variance   | 0.302       |
|    learning_rate        | 0.0003      |
|    loss                 | 27.4        |
|    n_updates            | 40          |
|    policy_gradient_loss | -0.0131     |
|    value_loss           | 62.1        |
-----------------------------------------
New best mean reward!
Stopping training because the mean reward 197.40  is above the threshold 190


<stable_baselines3.ppo.ppo.PPO at 0x1e08c455df0>

In [30]:
model_path = os.path.join('Training', 'Saved Models', 'best_model')
model = PPO.load(model_path, env=env)

In [31]:
evaluate_policy(model, env, n_eval_episodes=10, render=True)

(199.1, 2.7)

In [32]:
env.close()

In [33]:
# Changing Policies
net_arch=[dict(pi=[128, 128, 128, 128], vf=[128, 128, 128, 128])]

In [34]:
model = PPO('MlpPolicy', env, verbose = 1, policy_kwargs={'net_arch': net_arch})

Using cpu device


In [35]:
model.learn(total_timesteps=20000, callback=eval_callback)

-----------------------------
| time/              |      |
|    fps             | 470  |
|    iterations      | 1    |
|    time_elapsed    | 4    |
|    total_timesteps | 2048 |
-----------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 323         |
|    iterations           | 2           |
|    time_elapsed         | 12          |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.014374632 |
|    clip_fraction        | 0.173       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.682      |
|    explained_variance   | 0.000391    |
|    learning_rate        | 0.0003      |
|    loss                 | 3.39        |
|    n_updates            | 10          |
|    policy_gradient_loss | -0.0202     |
|    value_loss           | 22.3        |
-----------------------------------------
----------------------------------

<stable_baselines3.ppo.ppo.PPO at 0x1e0a08450a0>

In [36]:
# Using an Alternate Algorithm
from stable_baselines3 import DQN

In [37]:
model = DQN('MlpPolicy', env, verbose = 1, tensorboard_log=log_path)

Using cpu device


In [38]:
model.learn(total_timesteps=20000, callback=eval_callback)

Logging to Training\Logs\DQN_1
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.958    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 967      |
|    time_elapsed     | 0        |
|    total_timesteps  | 89       |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.924    |
| time/               |          |
|    episodes         | 8        |
|    fps              | 1081     |
|    time_elapsed     | 0        |
|    total_timesteps  | 160      |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.868    |
| time/               |          |
|    episodes         | 12       |
|    fps              | 1272     |
|    time_elapsed     | 0        |
|    total_timesteps  | 277      |
----------------------------------
------------------------

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 108      |
|    fps              | 2104     |
|    time_elapsed     | 1        |
|    total_timesteps  | 2240     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 112      |
|    fps              | 2107     |
|    time_elapsed     | 1        |
|    total_timesteps  | 2300     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 116      |
|    fps              | 2109     |
|    time_elapsed     | 1        |
|    total_timesteps  | 2364     |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 216      |
|    fps              | 2300     |
|    time_elapsed     | 1        |
|    total_timesteps  | 4486     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 220      |
|    fps              | 2311     |
|    time_elapsed     | 1        |
|    total_timesteps  | 4591     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 224      |
|    fps              | 2293     |
|    time_elapsed     | 2        |
|    total_timesteps  | 4653     |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 324      |
|    fps              | 2224     |
|    time_elapsed     | 3        |
|    total_timesteps  | 6734     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 328      |
|    fps              | 2228     |
|    time_elapsed     | 3        |
|    total_timesteps  | 6838     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 332      |
|    fps              | 2239     |
|    time_elapsed     | 3        |
|    total_timesteps  | 7012     |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 432      |
|    fps              | 2229     |
|    time_elapsed     | 4        |
|    total_timesteps  | 9341     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 436      |
|    fps              | 2229     |
|    time_elapsed     | 4        |
|    total_timesteps  | 9411     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 440      |
|    fps              | 2230     |
|    time_elapsed     | 4        |
|    total_timesteps  | 9489     |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 532      |
|    fps              | 2268     |
|    time_elapsed     | 5        |
|    total_timesteps  | 11460    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 536      |
|    fps              | 2273     |
|    time_elapsed     | 5        |
|    total_timesteps  | 11550    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 540      |
|    fps              | 2275     |
|    time_elapsed     | 5        |
|    total_timesteps  | 11633    |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 640      |
|    fps              | 2334     |
|    time_elapsed     | 5        |
|    total_timesteps  | 13582    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 644      |
|    fps              | 2336     |
|    time_elapsed     | 5        |
|    total_timesteps  | 13656    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 648      |
|    fps              | 2337     |
|    time_elapsed     | 5        |
|    total_timesteps  | 13718    |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 748      |
|    fps              | 2338     |
|    time_elapsed     | 6        |
|    total_timesteps  | 16006    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 752      |
|    fps              | 2348     |
|    time_elapsed     | 6        |
|    total_timesteps  | 16160    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 756      |
|    fps              | 2345     |
|    time_elapsed     | 6        |
|    total_timesteps  | 16213    |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 856      |
|    fps              | 2389     |
|    time_elapsed     | 7        |
|    total_timesteps  | 18354    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 860      |
|    fps              | 2387     |
|    time_elapsed     | 7        |
|    total_timesteps  | 18424    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 864      |
|    fps              | 2364     |
|    time_elapsed     | 7        |
|    total_timesteps  | 18492    |
----------------------------------
----------------------------------
| rollout/          

<stable_baselines3.dqn.dqn.DQN at 0x1e0a089fdf0>

In [39]:
dqn_path = os.path.join('Training', 'Saved Models', 'DQN_model')

In [40]:
model.save(dqn_path)

In [41]:
model = DQN.load(dqn_path, env=env)

In [42]:
evaluate_policy(model, env, n_eval_episodes=10, render=True)

(9.7, 0.7810249675906654)

In [43]:
env.close()