# 1. Import Dependencies

In [1]:
import gym 
from stable_baselines3 import A2C
from stable_baselines3.common.vec_env import VecFrameStack
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.env_util import make_atari_env
import os

# 2. Test Environment

In [2]:
environment_name = "Breakout-v0"

In [3]:
env = gym.make(environment_name)

In [7]:
episodes = 5
for episode in range(1, episodes+1):
    state = env.reset()
    done = False
    score = 0 
    
    while not done:
        env.render(mode='human')
        action = env.action_space.sample()
        n_state, reward, done, info = env.step(action)
        score+=reward
    print('Episode:{} Score:{}'.format(episode, score))
env.close()

Episode:1 Score:3.0
Episode:2 Score:1.0
Episode:3 Score:0.0
Episode:4 Score:1.0
Episode:5 Score:0.0


In [5]:
env.action_space.sample()

2

In [6]:
env.observation_space.sample()

array([[[182, 173,  79],
        [ 20,   0, 242],
        [221, 219,  12],
        ...,
        [253, 238, 213],
        [245, 112, 176],
        [240, 127,  38]],

       [[137, 125, 238],
        [106,  59,  94],
        [ 39,  48,  98],
        ...,
        [233, 131, 221],
        [183, 126, 219],
        [237, 150, 242]],

       [[165, 253,  43],
        [  4,  58, 220],
        [ 99,  42, 216],
        ...,
        [ 38, 198, 246],
        [140, 133, 206],
        [107,  88, 138]],

       ...,

       [[126,  66, 212],
        [ 26, 184,  94],
        [ 44, 169,  48],
        ...,
        [124,  24, 204],
        [103,  24, 210],
        [163, 100, 178]],

       [[ 13, 251, 198],
        [169, 178, 150],
        [ 77, 208,  41],
        ...,
        [ 98,  46, 144],
        [ 13,  77,   2],
        [ 71, 255, 246]],

       [[ 21,  56, 158],
        [224, 209, 105],
        [117,  58, 221],
        ...,
        [213, 184,  96],
        [209, 162, 128],
        [137,  69, 128]]

# 3. Vectorise Environment and Train Model

In [8]:
env = make_atari_env('Breakout-v0', n_envs=4, seed=0)

In [9]:
env = VecFrameStack(env, n_stack=4)

In [10]:
log_path = os.path.join('Training', 'Logs')

In [14]:
print(log_path)

Training\Logs


In [11]:
model = A2C("CnnPolicy", env, verbose=1, tensorboard_log=log_path)

Using cpu device
Wrapping the env in a VecTransposeImage.


In [12]:
model.learn(total_timesteps=400000)

Logging to Training\Logs\A2C_1
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 282      |
|    ep_rew_mean        | 1.62     |
| time/                 |          |
|    fps                | 155      |
|    iterations         | 100      |
|    time_elapsed       | 12       |
|    total_timesteps    | 2000     |
| train/                |          |
|    entropy_loss       | -1.38    |
|    explained_variance | 0.0315   |
|    learning_rate      | 0.0007   |
|    n_updates          | 99       |
|    policy_loss        | -0.0635  |
|    value_loss         | 0.0081   |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 293      |
|    ep_rew_mean        | 1.86     |
| time/                 |          |
|    fps                | 157      |
|    iterations         | 200      |
|    time_elapsed       | 25       |
|    total_timesteps    | 4000     |
| train

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 389      |
|    ep_rew_mean        | 3.88     |
| time/                 |          |
|    fps                | 161      |
|    iterations         | 1400     |
|    time_elapsed       | 173      |
|    total_timesteps    | 28000    |
| train/                |          |
|    entropy_loss       | -0.616   |
|    explained_variance | 0.936    |
|    learning_rate      | 0.0007   |
|    n_updates          | 1399     |
|    policy_loss        | 0.0945   |
|    value_loss         | 0.0413   |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 400      |
|    ep_rew_mean        | 4.2      |
| time/                 |          |
|    fps                | 161      |
|    iterations         | 1500     |
|    time_elapsed       | 185      |
|    total_timesteps    | 30000    |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 466      |
|    ep_rew_mean        | 5.6      |
| time/                 |          |
|    fps                | 164      |
|    iterations         | 2800     |
|    time_elapsed       | 339      |
|    total_timesteps    | 56000    |
| train/                |          |
|    entropy_loss       | -0.468   |
|    explained_variance | 0.843    |
|    learning_rate      | 0.0007   |
|    n_updates          | 2799     |
|    policy_loss        | -0.0104  |
|    value_loss         | 0.14     |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 464      |
|    ep_rew_mean        | 5.54     |
| time/                 |          |
|    fps                | 165      |
|    iterations         | 2900     |
|    time_elapsed       | 351      |
|    total_timesteps    | 58000    |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 482      |
|    ep_rew_mean        | 5.99     |
| time/                 |          |
|    fps                | 165      |
|    iterations         | 4200     |
|    time_elapsed       | 506      |
|    total_timesteps    | 84000    |
| train/                |          |
|    entropy_loss       | -1.03    |
|    explained_variance | 0.416    |
|    learning_rate      | 0.0007   |
|    n_updates          | 4199     |
|    policy_loss        | 0.329    |
|    value_loss         | 0.361    |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 490      |
|    ep_rew_mean        | 6.04     |
| time/                 |          |
|    fps                | 165      |
|    iterations         | 4300     |
|    time_elapsed       | 518      |
|    total_timesteps    | 86000    |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 524      |
|    ep_rew_mean        | 6.79     |
| time/                 |          |
|    fps                | 166      |
|    iterations         | 5600     |
|    time_elapsed       | 672      |
|    total_timesteps    | 112000   |
| train/                |          |
|    entropy_loss       | -0.514   |
|    explained_variance | 0.59     |
|    learning_rate      | 0.0007   |
|    n_updates          | 5599     |
|    policy_loss        | -0.0379  |
|    value_loss         | 0.199    |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 526      |
|    ep_rew_mean        | 6.84     |
| time/                 |          |
|    fps                | 166      |
|    iterations         | 5700     |
|    time_elapsed       | 684      |
|    total_timesteps    | 114000   |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 544      |
|    ep_rew_mean        | 7.13     |
| time/                 |          |
|    fps                | 165      |
|    iterations         | 7000     |
|    time_elapsed       | 844      |
|    total_timesteps    | 140000   |
| train/                |          |
|    entropy_loss       | -0.358   |
|    explained_variance | 0.973    |
|    learning_rate      | 0.0007   |
|    n_updates          | 6999     |
|    policy_loss        | 0.0211   |
|    value_loss         | 0.0382   |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 547      |
|    ep_rew_mean        | 7.16     |
| time/                 |          |
|    fps                | 165      |
|    iterations         | 7100     |
|    time_elapsed       | 856      |
|    total_timesteps    | 142000   |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 548      |
|    ep_rew_mean        | 7.14     |
| time/                 |          |
|    fps                | 165      |
|    iterations         | 8400     |
|    time_elapsed       | 1012     |
|    total_timesteps    | 168000   |
| train/                |          |
|    entropy_loss       | -0.339   |
|    explained_variance | 0.526    |
|    learning_rate      | 0.0007   |
|    n_updates          | 8399     |
|    policy_loss        | -0.05    |
|    value_loss         | 0.0808   |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 561      |
|    ep_rew_mean        | 7.41     |
| time/                 |          |
|    fps                | 165      |
|    iterations         | 8500     |
|    time_elapsed       | 1024     |
|    total_timesteps    | 170000   |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 603      |
|    ep_rew_mean        | 8.48     |
| time/                 |          |
|    fps                | 166      |
|    iterations         | 9800     |
|    time_elapsed       | 1178     |
|    total_timesteps    | 196000   |
| train/                |          |
|    entropy_loss       | -0.336   |
|    explained_variance | 0.687    |
|    learning_rate      | 0.0007   |
|    n_updates          | 9799     |
|    policy_loss        | -0.0417  |
|    value_loss         | 0.191    |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 598      |
|    ep_rew_mean        | 8.28     |
| time/                 |          |
|    fps                | 166      |
|    iterations         | 9900     |
|    time_elapsed       | 1191     |
|    total_timesteps    | 198000   |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 632      |
|    ep_rew_mean        | 8.96     |
| time/                 |          |
|    fps                | 166      |
|    iterations         | 11200    |
|    time_elapsed       | 1346     |
|    total_timesteps    | 224000   |
| train/                |          |
|    entropy_loss       | -0.34    |
|    explained_variance | 0.352    |
|    learning_rate      | 0.0007   |
|    n_updates          | 11199    |
|    policy_loss        | 0.0326   |
|    value_loss         | 0.203    |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 629      |
|    ep_rew_mean        | 8.87     |
| time/                 |          |
|    fps                | 166      |
|    iterations         | 11300    |
|    time_elapsed       | 1359     |
|    total_timesteps    | 226000   |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 611      |
|    ep_rew_mean        | 8.72     |
| time/                 |          |
|    fps                | 166      |
|    iterations         | 12500    |
|    time_elapsed       | 1503     |
|    total_timesteps    | 250000   |
| train/                |          |
|    entropy_loss       | -0.637   |
|    explained_variance | 0.499    |
|    learning_rate      | 0.0007   |
|    n_updates          | 12499    |
|    policy_loss        | -0.0611  |
|    value_loss         | 0.0404   |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 622      |
|    ep_rew_mean        | 9.07     |
| time/                 |          |
|    fps                | 166      |
|    iterations         | 12600    |
|    time_elapsed       | 1515     |
|    total_timesteps    | 252000   |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 667      |
|    ep_rew_mean        | 9.86     |
| time/                 |          |
|    fps                | 166      |
|    iterations         | 13900    |
|    time_elapsed       | 1669     |
|    total_timesteps    | 278000   |
| train/                |          |
|    entropy_loss       | -0.341   |
|    explained_variance | 0.169    |
|    learning_rate      | 0.0007   |
|    n_updates          | 13899    |
|    policy_loss        | 0.00358  |
|    value_loss         | 0.366    |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 663      |
|    ep_rew_mean        | 9.9      |
| time/                 |          |
|    fps                | 166      |
|    iterations         | 14000    |
|    time_elapsed       | 1681     |
|    total_timesteps    | 280000   |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 659      |
|    ep_rew_mean        | 9.6      |
| time/                 |          |
|    fps                | 166      |
|    iterations         | 15300    |
|    time_elapsed       | 1840     |
|    total_timesteps    | 306000   |
| train/                |          |
|    entropy_loss       | -0.164   |
|    explained_variance | 0.823    |
|    learning_rate      | 0.0007   |
|    n_updates          | 15299    |
|    policy_loss        | 0.00451  |
|    value_loss         | 0.177    |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 673      |
|    ep_rew_mean        | 10       |
| time/                 |          |
|    fps                | 166      |
|    iterations         | 15400    |
|    time_elapsed       | 1852     |
|    total_timesteps    | 308000   |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 702      |
|    ep_rew_mean        | 11       |
| time/                 |          |
|    fps                | 165      |
|    iterations         | 16700    |
|    time_elapsed       | 2012     |
|    total_timesteps    | 334000   |
| train/                |          |
|    entropy_loss       | -0.165   |
|    explained_variance | 0.795    |
|    learning_rate      | 0.0007   |
|    n_updates          | 16699    |
|    policy_loss        | -0.0858  |
|    value_loss         | 0.521    |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 677      |
|    ep_rew_mean        | 10.5     |
| time/                 |          |
|    fps                | 165      |
|    iterations         | 16800    |
|    time_elapsed       | 2025     |
|    total_timesteps    | 336000   |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 697      |
|    ep_rew_mean        | 10.6     |
| time/                 |          |
|    fps                | 164      |
|    iterations         | 18100    |
|    time_elapsed       | 2195     |
|    total_timesteps    | 362000   |
| train/                |          |
|    entropy_loss       | -0.208   |
|    explained_variance | 0.603    |
|    learning_rate      | 0.0007   |
|    n_updates          | 18099    |
|    policy_loss        | -0.0672  |
|    value_loss         | 0.116    |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 691      |
|    ep_rew_mean        | 10.4     |
| time/                 |          |
|    fps                | 164      |
|    iterations         | 18200    |
|    time_elapsed       | 2207     |
|    total_timesteps    | 364000   |
| train/                |          |
|

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 709      |
|    ep_rew_mean        | 11.2     |
| time/                 |          |
|    fps                | 164      |
|    iterations         | 19500    |
|    time_elapsed       | 2370     |
|    total_timesteps    | 390000   |
| train/                |          |
|    entropy_loss       | -0.22    |
|    explained_variance | 0.897    |
|    learning_rate      | 0.0007   |
|    n_updates          | 19499    |
|    policy_loss        | 0.0215   |
|    value_loss         | 0.0531   |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 713      |
|    ep_rew_mean        | 11.2     |
| time/                 |          |
|    fps                | 164      |
|    iterations         | 19600    |
|    time_elapsed       | 2382     |
|    total_timesteps    | 392000   |
| train/                |          |
|

<stable_baselines3.a2c.a2c.A2C at 0x29d5cfebe20>

# 4. Save and Reload Model

In [13]:
a2c_path = os.path.join('Training', 'Saved Models', 'A2C_model')

In [15]:
model.save(a2c_path)

In [16]:
del model

In [17]:
env = make_atari_env('Breakout-v0', n_envs=1, seed=0)
env = VecFrameStack(env, n_stack=4)

In [18]:
model = A2C.load(a2c_path, env)

Wrapping the env in a VecTransposeImage.


# 5. Evaluate and Test

In [19]:
evaluate_policy(model, env, n_eval_episodes=10, render=True)

(10.7, 2.491987158875422)

In [None]:
obs = env.reset()
while True:
    action, _states = model.predict(obs)
    obs, rewards, dones, info = env.step(action)
    env.render()

In [None]:
env.close()