In [1]:
# To plot pretty figures
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.rc('axes', labelsize=14)
mpl.rc('xtick', labelsize=12)
mpl.rc('ytick', labelsize=12)

# To get smooth animations
import matplotlib.animation as animation
mpl.rc('animation', html='jshtml')

In [2]:
from tqdm import tqdm

In [3]:
import tensorflow as tf
from tensorflow import keras

# The TF-Agents Library

In [4]:
from tf_agents.environments import suite_gym

env = suite_gym.load("Breakout-v4")

In [5]:
env

<tf_agents.environments.wrappers.TimeLimit at 0x137f3526ee0>

In [6]:
env.gym

<TimeLimit<AtariEnv<Breakout-v4>>>

In [7]:
# Reset
env.reset()

TimeStep(
{'discount': array(1., dtype=float32),
 'observation': array([[[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        ...,
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]],

       [[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        ...,
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]],

       [[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        ...,
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]],

       ...,

       [[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        ...,
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]],

       [[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        ...,
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]],

       [[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        ...,
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]]], dtype=uint8),
 'reward': array(0., dtype=float32),
 'step_type': array(0)})

In [8]:
env.step(1) # Fire

TimeStep(
{'discount': array(1., dtype=float32),
 'observation': array([[[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        ...,
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]],

       [[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        ...,
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]],

       [[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        ...,
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]],

       ...,

       [[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        ...,
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]],

       [[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        ...,
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]],

       [[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        ...,
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]]], dtype=uint8),
 'reward': array(0., dtype=float32),
 'step_type': array(1)})

## Environment Specification

In [9]:
env.observation_spec()

BoundedArraySpec(shape=(210, 160, 3), dtype=dtype('uint8'), name='observation', minimum=0, maximum=255)

In [10]:
env.action_spec()

BoundedArraySpec(shape=(), dtype=dtype('int64'), name='action', minimum=0, maximum=3)

In [11]:
env.time_step_spec()

TimeStep(
{'discount': BoundedArraySpec(shape=(), dtype=dtype('float32'), name='discount', minimum=0.0, maximum=1.0),
 'observation': BoundedArraySpec(shape=(210, 160, 3), dtype=dtype('uint8'), name='observation', minimum=0, maximum=255),
 'reward': ArraySpec(shape=(), dtype=dtype('float32'), name='reward'),
 'step_type': ArraySpec(shape=(), dtype=dtype('int32'), name='step_type')})

In [12]:
env.gym.get_action_meanings()

['NOOP', 'FIRE', 'RIGHT', 'LEFT']

## Environment Wrappers and Preprocessing

In [13]:
from tf_agents.environments.wrappers import ActionRepeat

repeating_env = ActionRepeat(env, times=4)

In [14]:
from gym.wrappers import TimeLimit

# Note: The following funtion, suite_gym.load(), is the same function we used 
# at the begining of this notebook but with extra arguments
limited_repeating_env = suite_gym.load(
    "Breakout-v4",
    gym_env_wrappers=[lambda env: TimeLimit(env, max_episode_steps=10000)],
    env_wrappers=[lambda env: ActionRepeat(env, times=4)]) 

In [15]:
# Atari wrappers
from tf_agents.environments import suite_atari
from tf_agents.environments.atari_preprocessing import AtariPreprocessing
from tf_agents.environments.atari_wrappers import FrameStack4

max_episode_steps = 27000 # <=> 108k frames since 1 step = 4 frames
environment_name = "BreakoutNoFrameskip-v4"

env = suite_atari.load(
    environment_name,
    max_episode_steps=max_episode_steps,
    gym_env_wrappers=[AtariPreprocessing, FrameStack4])

In [16]:
from tf_agents.environments.tf_py_environment import TFPyEnvironment

tf_env = TFPyEnvironment(env)

## Creating the Deep Q-Network

In [17]:
from tf_agents.networks.q_network import QNetwork

preprocessing_layer = keras.layers.Lambda(
                                    lambda obs: tf.cast(obs, np.float32) / 255.)
conv_layer_params = [(32, (8,8), 4), (64, (4,4), 2), (64, (3,3), 1)]
fc_layer_params = [512]

q_net = QNetwork(
            tf_env.observation_spec(),
            tf_env.action_spec(),
            preprocessing_layers=preprocessing_layer,
            conv_layer_params=conv_layer_params,
            fc_layer_params=fc_layer_params)

## Creating the DQN Agent

In [18]:
from tf_agents.agents.dqn.dqn_agent import DqnAgent

train_step = tf.Variable(0)
update_period = 4 # train the model every 4 steps
optimizer = keras.optimizers.RMSprop(lr=2.5e-4, rho=0.95, momentum=0.0,
                                     epsilon=0.00001, centered=True)
epsilon_fn = keras.optimizers.schedules.PolynomialDecay(
                initial_learning_rate=1.0, # initial epsilon
                decay_steps=250000 // update_period, # <=> 1,000,000 ALE frames
                end_learning_rate=0.01) # final epsilon
agent = DqnAgent(tf_env.time_step_spec(),
                 tf_env.action_spec(),
                 q_network=q_net,
                 optimizer=optimizer,
                 target_update_period=2000, # <=> 32.000 ALE frames
                 td_errors_loss_fn=keras.losses.Huber(reduction="none"),
                 gamma=0.99, # discount factor
                 train_step_counter=train_step,
                 epsilon_greedy=lambda: epsilon_fn(train_step))
agent.initialize()

  super(RMSprop, self).__init__(name, **kwargs)


## Creating the Replay Buffer and the Corresponding Observer

In [19]:
from tf_agents.replay_buffers import tf_uniform_replay_buffer

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
                    data_spec=agent.collect_data_spec,
                    batch_size=tf_env.batch_size,
                    max_length=1000000)

In [20]:
# An observer
replay_buffer_observer = replay_buffer.add_batch

In [21]:
# Custom observer with internal state
class ShowProgress:
    def __init__(self, total):
        self.counter = 0
        self.total = total
    def __call__(self, trajectory):
        if not trajectory.is_boundary():
            self.counter += 1
        if self.counter % 100 == 0:
            print("\r{}/{}".format(self.counter, self.total), end="")

## Creating Training Metrics

In [22]:
from tf_agents.metrics import tf_metrics

train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(),
        tf_metrics.AverageEpisodeLengthMetric(),
]

In [23]:
# logging the results of metrics
from tf_agents.eval.metric_utils import log_metrics
import logging

# logging.get_logger().set_level(logging.INFO) # Old version
logging.getLogger().setLevel(logging.INFO) # New version
log_metrics(train_metrics)

INFO:absl: 
		 NumberOfEpisodes = 0
		 EnvironmentSteps = 0
		 AverageReturn = 0.0
		 AverageEpisodeLength = 0.0


## Creating the Collect Driver

In [24]:
from tf_agents.drivers.dynamic_step_driver import DynamicStepDriver

collect_driver = DynamicStepDriver(
                        tf_env,
                        agent.collect_policy,
                        observers=[replay_buffer_observer] + train_metrics,
                        num_steps=update_period) # collect 4 steps for each training iteration

Pre-fill in the replay buffer using random policy

In [25]:
from tf_agents.policies.random_tf_policy import RandomTFPolicy

initial_collect_policy = RandomTFPolicy(tf_env.time_step_spec(),
                                        tf_env.action_spec())
init_driver = DynamicStepDriver(
                        tf_env,
                        initial_collect_policy,
                        observers=[replay_buffer.add_batch, ShowProgress(20000)],
                        num_steps=20000) # <=> 80,000 ALE frames
final_time_step, final_policy_state = init_driver.run()

20000/20000

## Creating the Dataset

In [26]:
trajectories, buffer_info = replay_buffer.get_next(
                                sample_batch_size=2, num_steps=3)

Instructions for updating:
Use `as_dataset(..., single_deterministic_pass=False) instead.


Instructions for updating:
Use `as_dataset(..., single_deterministic_pass=False) instead.


In [27]:
trajectories._fields

('step_type',
 'observation',
 'action',
 'policy_info',
 'next_step_type',
 'reward',
 'discount')

In [28]:
trajectories.observation.shape

TensorShape([2, 3, 84, 84, 4])

In [29]:
trajectories.step_type.numpy()

array([[1, 1, 1],
       [1, 1, 1]])

In [30]:
# to transitions
from tf_agents.trajectories.trajectory import to_transition

time_steps, action_steps, next_time_steps = to_transition(trajectories)
time_steps.observation.shape

TensorShape([2, 2, 84, 84, 4])

In [31]:
# Convert to dataset
dataset = replay_buffer.as_dataset(
                            sample_batch_size=64,
                            num_steps=2,
                            num_parallel_calls=3).prefetch(3)

## Creating the Training Loop

In [32]:
from tf_agents.utils.common import function

collect_driver.run = function(collect_driver.run)
agent.train = function(agent.train)

In [50]:
def train_agent(n_iterations):
    time_step = None
    policy_state = agent.collect_policy.get_initial_state(tf_env.batch_size)
    iterator = iter(dataset)
#     pbar = tqdm(range(n_iterations))
    for iteration in (pbar := tqdm(range(n_iterations))): # Note: walrus operator ":=" defines the pbar variable as the tqdm(...) instance
        time_step, policy_state = collect_driver.run(time_step, policy_state)
        trajectories, buffer_info = next(iterator)
        train_loss = agent.train(trajectories)
#         print("\r{} loss:{:.5}".format(iteration, train_loss.loss.numpy()), end="")
        pbar.set_description( "loss:\t%.5f" % train_loss.loss.numpy() )
        if iteration % 1000 == 0:
            log_metrics(train_metrics)

Train!!!!

In [51]:
train_agent(10000000)

loss:	0.00007:   0%|          | 0/10000000 [00:00<?, ?it/s]INFO:absl: 
		 NumberOfEpisodes = 251
		 EnvironmentSteps = 44208
		 AverageReturn = 1.7000000476837158
		 AverageEpisodeLength = 206.6999969482422
loss:	0.00046:   0%|          | 1000/10000000 [01:15<208:13:08, 13.34it/s]INFO:absl: 
		 NumberOfEpisodes = 273
		 EnvironmentSteps = 48208
		 AverageReturn = 0.800000011920929
		 AverageEpisodeLength = 166.89999389648438
loss:	0.00025:   0%|          | 2000/10000000 [02:23<182:36:17, 15.21it/s]INFO:absl: 
		 NumberOfEpisodes = 294
		 EnvironmentSteps = 52208
		 AverageReturn = 1.100000023841858
		 AverageEpisodeLength = 188.0
loss:	0.00043:   0%|          | 3000/10000000 [03:30<185:04:40, 15.00it/s]INFO:absl: 
		 NumberOfEpisodes = 316
		 EnvironmentSteps = 56208
		 AverageReturn = 1.2000000476837158
		 AverageEpisodeLength = 188.5
loss:	0.00005:   0%|          | 4000/10000000 [04:45<200:51:56, 13.82it/s]INFO:absl: 
		 NumberOfEpisodes = 337
		 EnvironmentSteps = 60208
		 AverageRe

KeyboardInterrupt: 