# 1. Import Dependencies

In [1]:
import gym 
from stable_baselines3 import A2C
from stable_baselines3.common.vec_env import VecFrameStack
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.env_util import make_atari_env
import os

# 2. Test Environment

In [2]:
# Gym id of the Atari game; used below to build the plain test environment.
environment_name = "Breakout-v0"

In [3]:
# Plain Gym environment (no Atari preprocessing or vectorisation yet), used
# only for the random-play sanity check in this section.
env = gym.make(environment_name)

In [4]:
# Sanity check: play a few episodes with uniformly random actions and print
# the total reward collected in each one.
episodes = 5
try:
    for episode in range(1, episodes + 1):
        # The initial observation is not needed: actions are sampled at random.
        env.reset()
        done = False
        score = 0

        while not done:
            env.render()
            action = env.action_space.sample()
            # This env uses the 4-tuple step API (observation, reward, done, info).
            n_state, reward, done, info = env.step(action)
            score += reward
        print('Episode:{} Score:{}'.format(episode, score))
finally:
    # Always release the environment (and anything render() opened), even if
    # the loop is interrupted or a step raises.
    env.close()



Episode:1 Score:1.0
Episode:2 Score:1.0
Episode:3 Score:1.0
Episode:4 Score:2.0
Episode:5 Score:1.0


In [5]:
# Draw one random action to see what an action looks like; the output below
# is a single integer action id.
env.action_space.sample()

0

In [6]:
# Draw one random observation to inspect the observation format. The output
# below is a 3-D integer array with 3 values along the last axis
# (looks like an RGB frame -- confirm with env.observation_space.shape).
env.observation_space.sample()

array([[[  7,  79,   6],
        [ 16, 151, 185],
        [ 99, 194, 105],
        ...,
        [ 50,  23, 191],
        [140, 130,  27],
        [115,  39, 191]],

       [[136,  75, 186],
        [ 85, 206, 146],
        [240,  13,  89],
        ...,
        [191, 210,  32],
        [126, 130,  10],
        [ 32, 154, 135]],

       [[215, 157,  45],
        [223,  15, 249],
        [239,   2,  87],
        ...,
        [ 23, 132, 187],
        [ 41, 145,  23],
        [143, 212, 251]],

       ...,

       [[ 23,  92, 187],
        [ 11, 254, 116],
        [ 89,  79,   1],
        ...,
        [243,  24,  45],
        [  0, 134, 223],
        [ 25, 229, 224]],

       [[ 88,  48, 152],
        [237, 188,  29],
        [144, 226, 229],
        ...,
        [ 77, 178, 120],
        [ 43,  62, 162],
        [194, 212, 130]],

       [[135, 251, 110],
        [116,  85,  44],
        [214, 221, 129],
        ...,
        [255,  46, 171],
        [244,   7,  51],
        [ 37,  83, 106]]

# 3. Vectorise Environment and Train Model

In [7]:
# Build 4 copies of the game as one vectorised, Atari-preprocessed environment.
# The id comes from `environment_name` so the training env always matches the
# game tested above; seed=0 fixes the environment seeding.
# NOTE(review): Stable-Baselines3's Atari examples pass "...NoFrameskip-v4" ids
# to make_atari_env because its wrapper applies its own frame skipping --
# confirm that using the "-v0" id here is intended.
env = make_atari_env(environment_name, n_envs=4, seed=0)

In [8]:
# Stack the 4 most recent frames of every sub-environment into one observation,
# giving the policy some temporal context instead of a single still image.
env = VecFrameStack(env, n_stack=4)

In [9]:
# Relative directory passed to the model as its TensorBoard log location;
# the training output below shows runs landing under it (Training/Logs/A2C_1).
log_dir_parts = ('Training', 'Logs')
log_path = os.path.join(*log_dir_parts)

In [10]:
# A2C agent with a CNN policy for the image observations. verbose=1 prints the
# progress tables shown during training; tensorboard_log sets where runs are logged.
model = A2C("CnnPolicy", env, verbose=1, tensorboard_log=log_path)

Using cpu device
Wrapping the env in a VecTransposeImage.


In [11]:
# Shell check (needs a `which` command, i.e. a POSIX-like shell): shows which
# tensorboard executable is on PATH for viewing the logs under log_path.
!which tensorboard


/usr/local/anaconda3/envs/bmrl/bin/tensorboard


In [12]:
# Train for 400k environment steps. In the CPU run recorded below this took
# roughly 27 minutes (time_elapsed 1621 s at 384k steps), so expect a long cell.
model.learn(total_timesteps=400000)

Logging to Training/Logs/A2C_1
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 266      |
|    ep_rew_mean        | 1.19     |
| time/                 |          |
|    fps                | 267      |
|    iterations         | 100      |
|    time_elapsed       | 7        |
|    total_timesteps    | 2000     |
| train/                |          |
|    entropy_loss       | -1.39    |
|    explained_variance | 0.106    |
|    learning_rate      | 0.0007   |
|    n_updates          | 99       |
|    policy_loss        | 0.195    |
|    value_loss         | 0.19     |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 277      |
|    ep_rew_mean        | 1.42     |
| time/                 |          |
|    fps                | 265      |
|    iterations         | 200      |
|    time_elapsed       | 15       |
|    total_timesteps    | 4000     |
| train

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 316      |
|    ep_rew_mean        | 2.25     |
| time/                 |          |
|    fps                | 248      |
|    iterations         | 1400     |
|    time_elapsed       | 112      |
|    total_timesteps    | 28000    |
| train/                |          |
|    entropy_loss       | -0.806   |
|    explained_variance | 0.88     |
|    learning_rate      | 0.0007   |
|    n_updates          | 1399     |
|    policy_loss        | -0.0139  |
|    value_loss         | 0.027    |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 313      |
|    ep_rew_mean        | 2.16     |
| time/                 |          |
|    fps                | 248      |
|    iterations         | 1500     |
|    time_elapsed       | 120      |
|    total_timesteps    | 30000    |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 413      |
|    ep_rew_mean        | 4.55     |
| time/                 |          |
|    fps                | 238      |
|    iterations         | 2800     |
|    time_elapsed       | 234      |
|    total_timesteps    | 56000    |
| train/                |          |
|    entropy_loss       | -0.587   |
|    explained_variance | 0.795    |
|    learning_rate      | 0.0007   |
|    n_updates          | 2799     |
|    policy_loss        | 0.102    |
|    value_loss         | 0.166    |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 410      |
|    ep_rew_mean        | 4.55     |
| time/                 |          |
|    fps                | 238      |
|    iterations         | 2900     |
|    time_elapsed       | 243      |
|    total_timesteps    | 58000    |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 451      |
|    ep_rew_mean        | 5.23     |
| time/                 |          |
|    fps                | 243      |
|    iterations         | 4200     |
|    time_elapsed       | 345      |
|    total_timesteps    | 84000    |
| train/                |          |
|    entropy_loss       | -0.44    |
|    explained_variance | 0.963    |
|    learning_rate      | 0.0007   |
|    n_updates          | 4199     |
|    policy_loss        | -0.0283  |
|    value_loss         | 0.045    |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 455      |
|    ep_rew_mean        | 5.31     |
| time/                 |          |
|    fps                | 243      |
|    iterations         | 4300     |
|    time_elapsed       | 352      |
|    total_timesteps    | 86000    |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 524      |
|    ep_rew_mean        | 6.78     |
| time/                 |          |
|    fps                | 244      |
|    iterations         | 5600     |
|    time_elapsed       | 458      |
|    total_timesteps    | 112000   |
| train/                |          |
|    entropy_loss       | -0.369   |
|    explained_variance | 0.881    |
|    learning_rate      | 0.0007   |
|    n_updates          | 5599     |
|    policy_loss        | -0.0614  |
|    value_loss         | 0.0255   |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 527      |
|    ep_rew_mean        | 6.8      |
| time/                 |          |
|    fps                | 244      |
|    iterations         | 5700     |
|    time_elapsed       | 466      |
|    total_timesteps    | 114000   |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 576      |
|    ep_rew_mean        | 7.76     |
| time/                 |          |
|    fps                | 247      |
|    iterations         | 6900     |
|    time_elapsed       | 557      |
|    total_timesteps    | 138000   |
| train/                |          |
|    entropy_loss       | -0.188   |
|    explained_variance | 0.882    |
|    learning_rate      | 0.0007   |
|    n_updates          | 6899     |
|    policy_loss        | 0.0463   |
|    value_loss         | 0.118    |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 579      |
|    ep_rew_mean        | 7.75     |
| time/                 |          |
|    fps                | 247      |
|    iterations         | 7000     |
|    time_elapsed       | 565      |
|    total_timesteps    | 140000   |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 563      |
|    ep_rew_mean        | 7.61     |
| time/                 |          |
|    fps                | 248      |
|    iterations         | 8200     |
|    time_elapsed       | 659      |
|    total_timesteps    | 164000   |
| train/                |          |
|    entropy_loss       | -0.106   |
|    explained_variance | 0.932    |
|    learning_rate      | 0.0007   |
|    n_updates          | 8199     |
|    policy_loss        | 0.149    |
|    value_loss         | 0.0536   |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 560      |
|    ep_rew_mean        | 7.46     |
| time/                 |          |
|    fps                | 248      |
|    iterations         | 8300     |
|    time_elapsed       | 667      |
|    total_timesteps    | 166000   |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 583      |
|    ep_rew_mean        | 8.04     |
| time/                 |          |
|    fps                | 246      |
|    iterations         | 9600     |
|    time_elapsed       | 780      |
|    total_timesteps    | 192000   |
| train/                |          |
|    entropy_loss       | -0.0941  |
|    explained_variance | 0.738    |
|    learning_rate      | 0.0007   |
|    n_updates          | 9599     |
|    policy_loss        | -0.272   |
|    value_loss         | 0.208    |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 577      |
|    ep_rew_mean        | 7.95     |
| time/                 |          |
|    fps                | 244      |
|    iterations         | 9700     |
|    time_elapsed       | 791      |
|    total_timesteps    | 194000   |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 631      |
|    ep_rew_mean        | 9.26     |
| time/                 |          |
|    fps                | 239      |
|    iterations         | 10900    |
|    time_elapsed       | 908      |
|    total_timesteps    | 218000   |
| train/                |          |
|    entropy_loss       | -0.115   |
|    explained_variance | 0.468    |
|    learning_rate      | 0.0007   |
|    n_updates          | 10899    |
|    policy_loss        | 0.0119   |
|    value_loss         | 0.108    |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 631      |
|    ep_rew_mean        | 9.15     |
| time/                 |          |
|    fps                | 239      |
|    iterations         | 11000    |
|    time_elapsed       | 918      |
|    total_timesteps    | 220000   |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 641      |
|    ep_rew_mean        | 9.23     |
| time/                 |          |
|    fps                | 237      |
|    iterations         | 12300    |
|    time_elapsed       | 1035     |
|    total_timesteps    | 246000   |
| train/                |          |
|    entropy_loss       | -0.00049 |
|    explained_variance | 0.835    |
|    learning_rate      | 0.0007   |
|    n_updates          | 12299    |
|    policy_loss        | 1e-05    |
|    value_loss         | 0.17     |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 630      |
|    ep_rew_mean        | 9.02     |
| time/                 |          |
|    fps                | 237      |
|    iterations         | 12400    |
|    time_elapsed       | 1043     |
|    total_timesteps    | 248000   |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 604      |
|    ep_rew_mean        | 8.63     |
| time/                 |          |
|    fps                | 236      |
|    iterations         | 13700    |
|    time_elapsed       | 1160     |
|    total_timesteps    | 274000   |
| train/                |          |
|    entropy_loss       | -0.0645  |
|    explained_variance | 0.703    |
|    learning_rate      | 0.0007   |
|    n_updates          | 13699    |
|    policy_loss        | 0.00036  |
|    value_loss         | 0.155    |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 608      |
|    ep_rew_mean        | 8.63     |
| time/                 |          |
|    fps                | 236      |
|    iterations         | 13800    |
|    time_elapsed       | 1169     |
|    total_timesteps    | 276000   |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 697      |
|    ep_rew_mean        | 10.7     |
| time/                 |          |
|    fps                | 235      |
|    iterations         | 15100    |
|    time_elapsed       | 1281     |
|    total_timesteps    | 302000   |
| train/                |          |
|    entropy_loss       | -0.131   |
|    explained_variance | 0.826    |
|    learning_rate      | 0.0007   |
|    n_updates          | 15099    |
|    policy_loss        | -0.0718  |
|    value_loss         | 0.0829   |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 702      |
|    ep_rew_mean        | 11.1     |
| time/                 |          |
|    fps                | 235      |
|    iterations         | 15200    |
|    time_elapsed       | 1289     |
|    total_timesteps    | 304000   |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 657      |
|    ep_rew_mean        | 10       |
| time/                 |          |
|    fps                | 236      |
|    iterations         | 16400    |
|    time_elapsed       | 1389     |
|    total_timesteps    | 328000   |
| train/                |          |
|    entropy_loss       | -0.047   |
|    explained_variance | 0.642    |
|    learning_rate      | 0.0007   |
|    n_updates          | 16399    |
|    policy_loss        | 0.0014   |
|    value_loss         | 0.0888   |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 651      |
|    ep_rew_mean        | 9.91     |
| time/                 |          |
|    fps                | 236      |
|    iterations         | 16500    |
|    time_elapsed       | 1397     |
|    total_timesteps    | 330000   |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 694      |
|    ep_rew_mean        | 10.7     |
| time/                 |          |
|    fps                | 236      |
|    iterations         | 17700    |
|    time_elapsed       | 1497     |
|    total_timesteps    | 354000   |
| train/                |          |
|    entropy_loss       | -0.006   |
|    explained_variance | 0.923    |
|    learning_rate      | 0.0007   |
|    n_updates          | 17699    |
|    policy_loss        | 0.000261 |
|    value_loss         | 0.0594   |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 688      |
|    ep_rew_mean        | 10.6     |
| time/                 |          |
|    fps                | 236      |
|    iterations         | 17800    |
|    time_elapsed       | 1505     |
|    total_timesteps    | 356000   |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 685      |
|    ep_rew_mean        | 10.7     |
| time/                 |          |
|    fps                | 236      |
|    iterations         | 19100    |
|    time_elapsed       | 1612     |
|    total_timesteps    | 382000   |
| train/                |          |
|    entropy_loss       | -0.178   |
|    explained_variance | 0.95     |
|    learning_rate      | 0.0007   |
|    n_updates          | 19099    |
|    policy_loss        | 0.000358 |
|    value_loss         | 0.0526   |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 701      |
|    ep_rew_mean        | 11       |
| time/                 |          |
|    fps                | 236      |
|    iterations         | 19200    |
|    time_elapsed       | 1621     |
|    total_timesteps    | 384000   |
| train/                |          |
|

<stable_baselines3.a2c.a2c.A2C at 0x7fb80b43edd0>

# 4. Save and Reload Model

In [13]:
# Relative location the trained model is saved to and later reloaded from.
model_path_parts = ('Training', 'Saved Models', 'A2C_model')
a2c_path = os.path.join(*model_path_parts)

In [14]:
# Persist the trained agent to a2c_path.
model.save(a2c_path)

In [15]:
# Drop the in-memory agent; presumably so the load further down demonstrably
# restores it from the saved file rather than reusing the trained object.
del model

In [16]:
# Release the 4-env training environment before rebinding `env`; otherwise its
# sub-environments stay open for the rest of the kernel session.
env.close()
# Rebuild a single-environment version for evaluation, reusing the env id the
# notebook defines in `environment_name`. n_stack stays at 4, the value used
# for training, so observations keep the shape the saved model was trained on.
env = make_atari_env(environment_name, n_envs=1, seed=0)
env = VecFrameStack(env, n_stack=4)

In [17]:
# Restore the saved A2C agent from a2c_path and attach the freshly built
# single-environment evaluation env to it.
model = A2C.load(a2c_path, env)

Wrapping the env in a VecTransposeImage.


# 5. Evaluate and Test

In [18]:
# Run 10 evaluation episodes with the reloaded agent, rendering as it plays.
# Per the Stable-Baselines3 docs the displayed tuple is
# (mean reward per episode, standard deviation of reward per episode).
evaluate_policy(model, env, n_eval_episodes=10, render=True)



(12.9, 5.127377497317708)

In [None]:
# Watch the trained agent play in the evaluation environment.
# The rollout is capped at `max_render_steps` so the cell ends on its own and a
# top-to-bottom run can reach the cleanup cell; raise the cap to watch longer.
max_render_steps = 5000
obs = env.reset()
for _step in range(max_render_steps):
    # predict() also returns a recurrent state, which is not used here.
    action, _states = model.predict(obs)
    # No manual reset on `dones`: per the Stable-Baselines3 docs, vectorised
    # environments reset a finished sub-environment automatically.
    obs, rewards, dones, info = env.step(action)
    env.render()

In [None]:
# Shut down the evaluation environment and anything render() opened.
env.close()