### Import Dependencies

In [212]:
!pip install stable-baselines3[extra]




In [213]:
!pip install gym





In [214]:
!pip3 install torch torchvision torchaudio



In [215]:
import os
import gym
from stable_baselines3 import PPO 
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.evaluation import evaluate_policy
import time



### Load Environment 

In [216]:
env = gym.make("CartPole-v1", render_mode="human")
env.action_space.seed(82)
observation, info = env.reset(seed=82)

In [287]:
env = gym.make("CartPole-v1", render_mode="human")

episodes = 5
for episode in range(1, episodes + 1):    
    state = env.reset()
    done = False
    score = 0

    while not done:
        env.render()
        action = env.action_space.sample()
        n_state, reward, done, done2, info = env.step(action)
        score += reward
        #time.sleep(0.1)
    print('Episode: {} Score {}'.format(episode, score))
#env.close()


Episode: 1 Score 20.0
Episode: 2 Score 13.0
Episode: 3 Score 18.0
Episode: 4 Score 77.0
Episode: 5 Score 19.0


In [288]:
env.close()

### Train RL Model

In [219]:
log_path = os.path.join('Training','Logs')

In [220]:
log_path

'Training\\Logs'

In [154]:
#env = gym.make("CartPole-v1", render_mode="human")
#env = DummyVecEnv([lambda: env])
#model = PPO('MlpPolicy', env, verbose=1, tensorboard_log=log_path)

In [289]:
def make_env():
    return gym.make("CartPole-v1", render_mode="human")

env = [make_env]
env = DummyVecEnv(env)  # Create the vectorized environment
model = PPO('MlpPolicy', env, verbose=1, tensorboard_log=log_path)

Using cpu device


In [93]:
PPO?

In [290]:
model.learn(total_timesteps=20000)

Logging to Training\Logs\PPO_14


KeyboardInterrupt: 

### Save and Reload

In [None]:
PPO_path = os.path.join('Training', 'Saved Models', 'PPO_Model_cartpole')

In [229]:
model.save(PPO_path)

In [136]:
del model

In [144]:
model.learn(total_timesteps=1000)

Logging to Training\Logs\PPO_11
-----------------------------
| time/              |      |
|    fps             | 46   |
|    iterations      | 1    |
|    time_elapsed    | 43   |
|    total_timesteps | 2048 |
-----------------------------


<stable_baselines3.ppo.ppo.PPO at 0x207262266e0>

In [230]:
model = PPO.load(PPO_path, env = env)

### Evaluation 

In [231]:
evaluate_policy(model, env, n_eval_episodes=3, render=True)

(500.0, 0.0)

In [179]:
env.close()

### Testing Agent

In [300]:
env = gym.make("CartPole-v1", render_mode="human")
env.action_space.seed(82)
observation, info = env.reset(seed=82)
env = DummyVecEnv([lambda: env])

episodes = 5
for episode in range(1, episodes + 1):    
    obs = env.reset()
    done = False
    score = 0

    while not done:
        env.render()
        action, _ = model.predict(obs)
        obs, reward, done, info = env.step(action)
        score += reward
        #time.sleep(0.1)
    print('Episode: {} Score {}'.format(episode, score))
#env.close()


Episode: 1 Score [285.]
Episode: 2 Score [363.]
Episode: 3 Score [500.]
Episode: 4 Score [115.]
Episode: 5 Score [110.]


In [258]:
env.close()

###  Tensorboard Visualization

In [265]:
training_log_path = os.path.join(log_path, 'PPO_1')

In [266]:
training_log_path

'Training\\Logs\\PPO_1'

In [273]:
!tensorboard --logdir={training_log_path}

^C


### Call back to the Training Stage

In [277]:
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold

In [279]:
save_path = os.path.join('Training', 'Saved Model')

In [297]:
stop_callback = StopTrainingOnRewardThreshold(reward_threshold=500, verbose=1)
eval_callback = EvalCallback(env,
                            callback_on_new_best=stop_callback,
                            eval_freq=3000,
                            best_model_save_path=save_path,
                            verbose=1,
                            render=False
                            )

In [298]:
model = PPO('MlpPolicy', env, verbose=1, tensorboard_log=log_path)

Using cpu device


In [299]:
model.learn(total_timesteps=20000, callback=eval_callback)

Logging to Training\Logs\PPO_16
-----------------------------
| time/              |      |
|    fps             | 46   |
|    iterations      | 1    |
|    time_elapsed    | 43   |
|    total_timesteps | 2048 |
-----------------------------
Eval num_timesteps=3000, episode_reward=198.80 +/- 13.91
Episode length: 198.80 +/- 13.91
-----------------------------------------
| eval/                   |             |
|    mean_ep_length       | 199         |
|    mean_reward          | 199         |
| time/                   |             |
|    total_timesteps      | 3000        |
| train/                  |             |
|    approx_kl            | 0.008206619 |
|    clip_fraction        | 0.124       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.686      |
|    explained_variance   | 0.0147      |
|    learning_rate        | 0.0003      |
|    loss                 | 10.2        |
|    n_updates            | 10          |
|    policy_gradient_loss | -0.0194   

------------------------------------------
| time/                   |              |
|    fps                  | 30           |
|    iterations           | 10           |
|    time_elapsed         | 665          |
|    total_timesteps      | 20480        |
| train/                  |              |
|    approx_kl            | 0.0043251826 |
|    clip_fraction        | 0.0298       |
|    clip_range           | 0.2          |
|    entropy_loss         | -0.542       |
|    explained_variance   | 0.744        |
|    learning_rate        | 0.0003       |
|    loss                 | 6.49         |
|    n_updates            | 90           |
|    policy_gradient_loss | -0.00423     |
|    value_loss           | 29.7         |
------------------------------------------


<stable_baselines3.ppo.ppo.PPO at 0x20727513010>

### Changing Policies

In [302]:
net_arch = [dict(pi=[128,128,128,128], vf=[128,128,128,128])]

In [303]:
model = PPO('MlpPolicy', env, verbose=1, tensorboard_log=log_path, policy_kwargs={'net_arch':net_arch})

Using cpu device




In [304]:
model.learn(total_timesteps=20000, callback=eval_callback)

Logging to Training\Logs\PPO_17
Eval num_timesteps=520, episode_reward=9.80 +/- 0.40
Episode length: 9.80 +/- 0.40
---------------------------------
| eval/              |          |
|    mean_ep_length  | 9.8      |
|    mean_reward     | 9.8      |
| time/              |          |
|    total_timesteps | 520      |
---------------------------------
-----------------------------
| time/              |      |
|    fps             | 45   |
|    iterations      | 1    |
|    time_elapsed    | 45   |
|    total_timesteps | 2048 |
-----------------------------
Eval num_timesteps=3520, episode_reward=338.20 +/- 134.25
Episode length: 338.20 +/- 134.25
-----------------------------------------
| eval/                   |             |
|    mean_ep_length       | 338         |
|    mean_reward          | 338         |
| time/                   |             |
|    total_timesteps      | 3520        |
| train/                  |             |
|    approx_kl            | 0.014095473 |
|    clip

<stable_baselines3.ppo.ppo.PPO at 0x20727623a00>

### Using Alternate Algorithm

In [305]:
from stable_baselines3 import DQN

In [306]:
model = DQN('MlpPolicy', env, verbose=1, tensorboard_log=log_path)

Using cpu device


In [307]:
model.learn(total_timesteps=20000)

Logging to Training\Logs\DQN_1
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.964    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 46       |
|    time_elapsed     | 1        |
|    total_timesteps  | 75       |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.911    |
| time/               |          |
|    episodes         | 8        |
|    fps              | 46       |
|    time_elapsed     | 4        |
|    total_timesteps  | 188      |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.855    |
| time/               |          |
|    episodes         | 12       |
|    fps              | 46       |
|    time_elapsed     | 6        |
|    total_timesteps  | 306      |
----------------------------------
------------------------

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 108      |
|    fps              | 46       |
|    time_elapsed     | 47       |
|    total_timesteps  | 2196     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 112      |
|    fps              | 46       |
|    time_elapsed     | 48       |
|    total_timesteps  | 2258     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 116      |
|    fps              | 46       |
|    time_elapsed     | 50       |
|    total_timesteps  | 2317     |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 216      |
|    fps              | 46       |
|    time_elapsed     | 102      |
|    total_timesteps  | 4775     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 220      |
|    fps              | 46       |
|    time_elapsed     | 104      |
|    total_timesteps  | 4880     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 224      |
|    fps              | 46       |
|    time_elapsed     | 106      |
|    total_timesteps  | 4975     |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 324      |
|    fps              | 46       |
|    time_elapsed     | 155      |
|    total_timesteps  | 7211     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 328      |
|    fps              | 46       |
|    time_elapsed     | 157      |
|    total_timesteps  | 7328     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 332      |
|    fps              | 46       |
|    time_elapsed     | 159      |
|    total_timesteps  | 7396     |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 432      |
|    fps              | 46       |
|    time_elapsed     | 208      |
|    total_timesteps  | 9683     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 436      |
|    fps              | 46       |
|    time_elapsed     | 211      |
|    total_timesteps  | 9831     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 440      |
|    fps              | 46       |
|    time_elapsed     | 212      |
|    total_timesteps  | 9910     |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 540      |
|    fps              | 46       |
|    time_elapsed     | 258      |
|    total_timesteps  | 12015    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 544      |
|    fps              | 46       |
|    time_elapsed     | 259      |
|    total_timesteps  | 12080    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 548      |
|    fps              | 46       |
|    time_elapsed     | 262      |
|    total_timesteps  | 12204    |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 648      |
|    fps              | 46       |
|    time_elapsed     | 313      |
|    total_timesteps  | 14593    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 652      |
|    fps              | 46       |
|    time_elapsed     | 314      |
|    total_timesteps  | 14665    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 656      |
|    fps              | 46       |
|    time_elapsed     | 316      |
|    total_timesteps  | 14731    |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 756      |
|    fps              | 46       |
|    time_elapsed     | 362      |
|    total_timesteps  | 16861    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 760      |
|    fps              | 46       |
|    time_elapsed     | 364      |
|    total_timesteps  | 16975    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 764      |
|    fps              | 46       |
|    time_elapsed     | 365      |
|    total_timesteps  | 17036    |
----------------------------------
----------------------------------
| rollout/          

----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 864      |
|    fps              | 46       |
|    time_elapsed     | 412      |
|    total_timesteps  | 19212    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 868      |
|    fps              | 46       |
|    time_elapsed     | 414      |
|    total_timesteps  | 19321    |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 872      |
|    fps              | 46       |
|    time_elapsed     | 416      |
|    total_timesteps  | 19395    |
----------------------------------
----------------------------------
| rollout/          

<stable_baselines3.dqn.dqn.DQN at 0x20727621330>