Import Dependencies

In [13]:
import gym
from stable_baselines3 import A2C
from stable_baselines3.common.vec_env import VecFrameStack
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.evaluation import evaluate_policy
import os

Test Environment

In [2]:
# Run atari_py's ROM importer on the ROM files under /Roms/ROMS.
# NOTE(review): /Roms/ROMS is an absolute, machine-specific path — adjust it for your setup.
!python3 -m atari_py.import_roms /Roms/ROMS

In [3]:
# Id of the Atari game used throughout the notebook.
environment_name = 'Breakout-v4'
# Single (non-vectorised) environment for manual inspection; render_mode='human'
# asks gym to display the game while it is stepped.
env = gym.make(environment_name, render_mode='human')
# Optional: uncomment to list the names of the discrete actions.
# env.unwrapped.get_action_meanings()

A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]


In [4]:
# Override the frame rate stored in the env metadata.
# NOTE(review): presumably the human renderer paces playback with this value — confirm.
env.metadata['render_fps'] = 150  # Or whichever fps value you prefer

In [5]:
# Start a fresh episode. The call returns an (observation, info) pair, which is
# what the cell displays as its output.
env.reset()

(array([[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         ...,
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],
 
        [[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         ...,
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],
 
        [[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         ...,
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],
 
        ...,
 
        [[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         ...,
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],
 
        [[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         ...,
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],
 
        [[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         ...,
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]]], dtype=uint8),
 {'lives': 5, 'episode_frame_number': 0, 'frame_number': 0})

In [6]:
# Inspect the action space (the set of actions the agent may choose from).
env.action_space

Discrete(4)

In [7]:
# Inspect the observation space (bounds, shape and dtype of each observation).
env.observation_space

Box(0, 255, (210, 160, 3), uint8)

In [8]:
# Play `episodes` games with uniformly random actions as a smoke test of the
# environment, printing the total reward collected in each game.
episodes = 1
for episode in range(1, episodes + 1):
    # reset() returns an (observation, info) pair; unpack it rather than
    # binding the whole tuple to `obs`.
    obs, info = env.reset()
    done = False
    score = 0

    while not done:
        env.render()

        action = env.action_space.sample()
        n_state, reward, terminated, truncated, info = env.step(action)
        # An episode is over when it terminates (game over) OR is truncated
        # (e.g. a time limit). Watching only the first flag can loop forever
        # on a truncated episode.
        done = terminated or truncated
        score += reward
    print('Episode:{} Score:{}'.format(episode, score))
env.close()


Episode:1 Score:2.0


In [9]:
# Release the test environment before building the vectorised training env.
# NOTE(review): the random-play cell already ends with env.close(), so this call
# looks redundant — confirm and drop one of them.
env.close()

Vectorise Environment and Train Model

In [9]:
# Training environment: 4 vectorised copies of the game created by
# make_atari_env (seed 0), wrapped so that 4 consecutive frames are stacked.
env = VecFrameStack(
    make_atari_env(environment_name, n_envs=4, seed=0),
    n_stack=4,
)

In [14]:
# TensorBoard logs are written beneath Training/Logs.
log_path = os.path.join('Training', 'Logs')

# A2C agent using the CNN policy on the vectorised env; verbose=1 prints
# training information while learning.
model = A2C(
    policy='CnnPolicy',
    env=env,
    verbose=1,
    tensorboard_log=log_path,
)

Using cuda device
Wrapping the env in a VecTransposeImage.


In [15]:
# Train for 100,000 environment timesteps. Progress tables are printed because
# the model was created with verbose=1, and logs go under log_path.
model.learn(total_timesteps=100000) # Train the model

Logging to Training/Logs/A2C_1
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 300      |
|    ep_rew_mean        | 2        |
| time/                 |          |
|    fps                | 227      |
|    iterations         | 100      |
|    time_elapsed       | 8        |
|    total_timesteps    | 2000     |
| train/                |          |
|    entropy_loss       | -1.38    |
|    explained_variance | -0.0463  |
|    learning_rate      | 0.0007   |
|    n_updates          | 99       |
|    policy_loss        | 0.257    |
|    value_loss         | 0.202    |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 323      |
|    ep_rew_mean        | 2.39     |
| time/                 |          |
|    fps                | 251      |
|    iterations         | 200      |
|    time_elapsed       | 15       |
|    total_timesteps    | 4000     |
| train

<stable_baselines3.a2c.a2c.A2C at 0x7f26ffaca080>

Save and Reload Model

In [16]:
# Save the trained model under Training/Saved Models, next to Training/Logs.
# (The directory was previously misspelled 'Trainig', which put the model in a
# stray folder away from the logs.)
A2C_path = os.path.join('Training', 'Saved Models', 'A2C_Breakout_Model')
model.save(A2C_path) # Save the model



In [17]:
# Drop the in-memory model so that the reload from disk is a genuine round trip.
del model # Delete the model

In [18]:
# Restore the model saved at A2C_path and attach the current `env` to it.
model = A2C.load(A2C_path, env) # Load the saved model

Wrapping the env in a VecTransposeImage.


Evaluate and Test

In [19]:
# Evaluation environment: a single vectorised copy of the game with 4-frame
# stacking. Reuse `environment_name` instead of repeating the literal id so
# that training and evaluation cannot drift onto different games.
env = make_atari_env(environment_name, n_envs=1, seed=0)
env = VecFrameStack(env, n_stack=4)

In [22]:
# Run the policy for 10 evaluation episodes with rendering enabled. The cell
# output is a pair of floats.
# NOTE(review): per the SB3 docs these are (mean reward, std of reward) — confirm.
evaluate_policy(model, env, n_eval_episodes=10, render=True) # Evaluate the model

(4.0, 1.7888543819998317)