# 1. Import Dependencies

In [1]:
import gym 
from stable_baselines3 import A2C
from stable_baselines3.common.vec_env import VecFrameStack
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.env_util import make_atari_env
import os

# 2. Test Environment

In [2]:
environment_name = "Breakout-v0"

In [3]:
env = gym.make(environment_name)

In [4]:
episodes = 5
for episode in range(1, episodes+1):
    state = env.reset()
    done = False
    score = 0 
    
    while not done:
        env.render()
        action = env.action_space.sample()
        n_state, reward, done, info = env.step(action)
        score+=reward
    print('Episode:{} Score:{}'.format(episode, score))
env.close()

  logger.warn(


Episode:1 Score:1.0
Episode:2 Score:1.0
Episode:3 Score:3.0
Episode:4 Score:2.0
Episode:5 Score:1.0


In [5]:
env.action_space.sample()

2

In [6]:
env.observation_space.sample()

array([[[ 42, 188,  59],
        [ 14, 144, 232],
        [132,  54,  62],
        ...,
        [154, 165, 136],
        [236, 174,  87],
        [228,   5, 128]],

       [[ 16, 217,  38],
        [  0, 237,  19],
        [189,  41, 223],
        ...,
        [236,  39,  56],
        [ 49,   7, 250],
        [ 86, 232,  95]],

       [[ 45, 157, 195],
        [ 56,   6, 160],
        [  4,  48, 109],
        ...,
        [ 79, 124,  40],
        [  4, 201, 195],
        [129, 169,  52]],

       ...,

       [[ 99, 225, 234],
        [ 77,  74, 197],
        [149, 166, 146],
        ...,
        [ 19, 123,  33],
        [ 29,   8,  74],
        [ 17,   5, 212]],

       [[ 55, 251,   3],
        [141, 214, 102],
        [214,  52, 136],
        ...,
        [126,  93, 125],
        [ 34, 194, 233],
        [ 93,  49,  53]],

       [[211,   3,  81],
        [159,  89,   5],
        [184, 251, 139],
        ...,
        [ 67, 255, 166],
        [248, 187, 218],
        [170, 122, 154]]

# 3. Vectorise Environment and Train Model

In [7]:
env = make_atari_env('Breakout-v0', n_envs=4, seed=0)

In [8]:
env = VecFrameStack(env, n_stack=4)

In [9]:
log_path = os.path.join('Training', 'Logs')

In [10]:
model = A2C("CnnPolicy", env, verbose=1, tensorboard_log=log_path)

Using cpu device
Wrapping the env in a VecTransposeImage.


In [11]:
model.learn(total_timesteps=400000)

Logging to Training\Logs\A2C_1
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 299      |
|    ep_rew_mean        | 1.94     |
| time/                 |          |
|    fps                | 217      |
|    iterations         | 100      |
|    time_elapsed       | 9        |
|    total_timesteps    | 2000     |
| train/                |          |
|    entropy_loss       | -1.37    |
|    explained_variance | 0.0273   |
|    learning_rate      | 0.0007   |
|    n_updates          | 99       |
|    policy_loss        | 0.238    |
|    value_loss         | 0.256    |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 291      |
|    ep_rew_mean        | 1.77     |
| time/                 |          |
|    fps                | 215      |
|    iterations         | 200      |
|    time_elapsed       | 18       |
|    total_timesteps    | 4000     |
| train

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 329      |
|    ep_rew_mean        | 2.61     |
| time/                 |          |
|    fps                | 213      |
|    iterations         | 1400     |
|    time_elapsed       | 131      |
|    total_timesteps    | 28000    |
| train/                |          |
|    entropy_loss       | -0.424   |
|    explained_variance | 0.717    |
|    learning_rate      | 0.0007   |
|    n_updates          | 1399     |
|    policy_loss        | 0.0219   |
|    value_loss         | 0.0463   |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 358      |
|    ep_rew_mean        | 3.17     |
| time/                 |          |
|    fps                | 213      |
|    iterations         | 1500     |
|    time_elapsed       | 140      |
|    total_timesteps    | 30000    |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 484      |
|    ep_rew_mean        | 5.74     |
| time/                 |          |
|    fps                | 221      |
|    iterations         | 2800     |
|    time_elapsed       | 252      |
|    total_timesteps    | 56000    |
| train/                |          |
|    entropy_loss       | -0.106   |
|    explained_variance | 0.883    |
|    learning_rate      | 0.0007   |
|    n_updates          | 2799     |
|    policy_loss        | 0.0113   |
|    value_loss         | 0.0637   |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 475      |
|    ep_rew_mean        | 5.6      |
| time/                 |          |
|    fps                | 221      |
|    iterations         | 2900     |
|    time_elapsed       | 261      |
|    total_timesteps    | 58000    |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 458      |
|    ep_rew_mean        | 5.41     |
| time/                 |          |
|    fps                | 224      |
|    iterations         | 4100     |
|    time_elapsed       | 364      |
|    total_timesteps    | 82000    |
| train/                |          |
|    entropy_loss       | -0.0756  |
|    explained_variance | 0.901    |
|    learning_rate      | 0.0007   |
|    n_updates          | 4099     |
|    policy_loss        | 0.00834  |
|    value_loss         | 0.135    |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 473      |
|    ep_rew_mean        | 5.57     |
| time/                 |          |
|    fps                | 224      |
|    iterations         | 4200     |
|    time_elapsed       | 373      |
|    total_timesteps    | 84000    |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 537      |
|    ep_rew_mean        | 6.86     |
| time/                 |          |
|    fps                | 223      |
|    iterations         | 5500     |
|    time_elapsed       | 491      |
|    total_timesteps    | 110000   |
| train/                |          |
|    entropy_loss       | -0.206   |
|    explained_variance | 0.249    |
|    learning_rate      | 0.0007   |
|    n_updates          | 5499     |
|    policy_loss        | 0.0741   |
|    value_loss         | 0.247    |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 535      |
|    ep_rew_mean        | 6.86     |
| time/                 |          |
|    fps                | 224      |
|    iterations         | 5600     |
|    time_elapsed       | 499      |
|    total_timesteps    | 112000   |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 552      |
|    ep_rew_mean        | 7.15     |
| time/                 |          |
|    fps                | 226      |
|    iterations         | 6800     |
|    time_elapsed       | 600      |
|    total_timesteps    | 136000   |
| train/                |          |
|    entropy_loss       | -0.16    |
|    explained_variance | 0.936    |
|    learning_rate      | 0.0007   |
|    n_updates          | 6799     |
|    policy_loss        | -0.035   |
|    value_loss         | 0.0957   |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 555      |
|    ep_rew_mean        | 7.18     |
| time/                 |          |
|    fps                | 226      |
|    iterations         | 6900     |
|    time_elapsed       | 608      |
|    total_timesteps    | 138000   |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 582      |
|    ep_rew_mean        | 7.83     |
| time/                 |          |
|    fps                | 227      |
|    iterations         | 8100     |
|    time_elapsed       | 711      |
|    total_timesteps    | 162000   |
| train/                |          |
|    entropy_loss       | -0.00552 |
|    explained_variance | 0.86     |
|    learning_rate      | 0.0007   |
|    n_updates          | 8099     |
|    policy_loss        | 9.88e-05 |
|    value_loss         | 0.0554   |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 583      |
|    ep_rew_mean        | 7.8      |
| time/                 |          |
|    fps                | 227      |
|    iterations         | 8200     |
|    time_elapsed       | 720      |
|    total_timesteps    | 164000   |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 622      |
|    ep_rew_mean        | 8.65     |
| time/                 |          |
|    fps                | 230      |
|    iterations         | 9500     |
|    time_elapsed       | 825      |
|    total_timesteps    | 190000   |
| train/                |          |
|    entropy_loss       | -0.259   |
|    explained_variance | 0.708    |
|    learning_rate      | 0.0007   |
|    n_updates          | 9499     |
|    policy_loss        | -0.0366  |
|    value_loss         | 0.126    |
------------------------------------
-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 581       |
|    ep_rew_mean        | 7.74      |
| time/                 |           |
|    fps                | 230       |
|    iterations         | 9600      |
|    time_elapsed       | 833       |
|    total_timesteps    | 192000    |
| train/                |    

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 602      |
|    ep_rew_mean        | 8.47     |
| time/                 |          |
|    fps                | 231      |
|    iterations         | 10800    |
|    time_elapsed       | 931      |
|    total_timesteps    | 216000   |
| train/                |          |
|    entropy_loss       | -0.0787  |
|    explained_variance | 0.86     |
|    learning_rate      | 0.0007   |
|    n_updates          | 10799    |
|    policy_loss        | -0.0341  |
|    value_loss         | 0.117    |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 607      |
|    ep_rew_mean        | 8.51     |
| time/                 |          |
|    fps                | 232      |
|    iterations         | 10900    |
|    time_elapsed       | 939      |
|    total_timesteps    | 218000   |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 647      |
|    ep_rew_mean        | 9.15     |
| time/                 |          |
|    fps                | 233      |
|    iterations         | 12200    |
|    time_elapsed       | 1044     |
|    total_timesteps    | 244000   |
| train/                |          |
|    entropy_loss       | -0.109   |
|    explained_variance | 0.819    |
|    learning_rate      | 0.0007   |
|    n_updates          | 12199    |
|    policy_loss        | -0.0116  |
|    value_loss         | 0.0644   |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 636      |
|    ep_rew_mean        | 8.9      |
| time/                 |          |
|    fps                | 233      |
|    iterations         | 12300    |
|    time_elapsed       | 1054     |
|    total_timesteps    | 246000   |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 630      |
|    ep_rew_mean        | 9.18     |
| time/                 |          |
|    fps                | 232      |
|    iterations         | 13500    |
|    time_elapsed       | 1161     |
|    total_timesteps    | 270000   |
| train/                |          |
|    entropy_loss       | -0.0667  |
|    explained_variance | 0.773    |
|    learning_rate      | 0.0007   |
|    n_updates          | 13499    |
|    policy_loss        | -0.00404 |
|    value_loss         | 0.14     |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 629      |
|    ep_rew_mean        | 9.25     |
| time/                 |          |
|    fps                | 232      |
|    iterations         | 13600    |
|    time_elapsed       | 1170     |
|    total_timesteps    | 272000   |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 679      |
|    ep_rew_mean        | 10.4     |
| time/                 |          |
|    fps                | 231      |
|    iterations         | 14800    |
|    time_elapsed       | 1278     |
|    total_timesteps    | 296000   |
| train/                |          |
|    entropy_loss       | -0.0933  |
|    explained_variance | 0.941    |
|    learning_rate      | 0.0007   |
|    n_updates          | 14799    |
|    policy_loss        | -0.0462  |
|    value_loss         | 0.0377   |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 689      |
|    ep_rew_mean        | 10.8     |
| time/                 |          |
|    fps                | 231      |
|    iterations         | 14900    |
|    time_elapsed       | 1287     |
|    total_timesteps    | 298000   |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 684      |
|    ep_rew_mean        | 10.8     |
| time/                 |          |
|    fps                | 231      |
|    iterations         | 16100    |
|    time_elapsed       | 1393     |
|    total_timesteps    | 322000   |
| train/                |          |
|    entropy_loss       | -0.0799  |
|    explained_variance | 0.602    |
|    learning_rate      | 0.0007   |
|    n_updates          | 16099    |
|    policy_loss        | -0.128   |
|    value_loss         | 0.198    |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 690      |
|    ep_rew_mean        | 11       |
| time/                 |          |
|    fps                | 231      |
|    iterations         | 16200    |
|    time_elapsed       | 1402     |
|    total_timesteps    | 324000   |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 691      |
|    ep_rew_mean        | 10.9     |
| time/                 |          |
|    fps                | 230      |
|    iterations         | 17400    |
|    time_elapsed       | 1509     |
|    total_timesteps    | 348000   |
| train/                |          |
|    entropy_loss       | -0.127   |
|    explained_variance | 0.814    |
|    learning_rate      | 0.0007   |
|    n_updates          | 17399    |
|    policy_loss        | 0.00712  |
|    value_loss         | 0.0254   |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 681      |
|    ep_rew_mean        | 10.6     |
| time/                 |          |
|    fps                | 230      |
|    iterations         | 17500    |
|    time_elapsed       | 1519     |
|    total_timesteps    | 350000   |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 689      |
|    ep_rew_mean        | 10.8     |
| time/                 |          |
|    fps                | 230      |
|    iterations         | 18800    |
|    time_elapsed       | 1629     |
|    total_timesteps    | 376000   |
| train/                |          |
|    entropy_loss       | -0.0444  |
|    explained_variance | 0.922    |
|    learning_rate      | 0.0007   |
|    n_updates          | 18799    |
|    policy_loss        | -0.00658 |
|    value_loss         | 0.13     |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 697      |
|    ep_rew_mean        | 11.1     |
| time/                 |          |
|    fps                | 230      |
|    iterations         | 18900    |
|    time_elapsed       | 1637     |
|    total_timesteps    | 378000   |
| train/                |          |
|

<stable_baselines3.a2c.a2c.A2C at 0x22dba26b5b0>

# 4. Save and Reload Model

In [12]:
a2c_path = os.path.join('Training', 'Saved Models', 'A2C_model')

In [13]:
model.save(a2c_path)



In [14]:
del model

In [15]:
env = make_atari_env('Breakout-v0', n_envs=1, seed=0)
env = VecFrameStack(env, n_stack=4)

In [16]:
model = A2C.load(a2c_path, env)

Wrapping the env in a VecTransposeImage.


# 5. Evaluate and Test

In [17]:
evaluate_policy(model, env, n_eval_episodes=10, render=True)

(9.9, 3.176476034853718)

In [18]:
obs = env.reset()
while True:
    action, _states = model.predict(obs)
    obs, rewards, dones, info = env.step(action)
    env.render()

KeyboardInterrupt: 

In [None]:
env.close()