In [1]:
import numpy as np
import tensorflow as tf
from tf_agents.networks.q_network import QNetwork
import tf_agents.networks.network as network
from tf_agents.specs import tensor_spec
from snake_game import SnakeGame
from scene import Scene
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Flatten, Conv2D

scene = Scene(init_randomly=True)
episodes_count = 5000


In [2]:
from tf_agents.environments.tf_py_environment import TFPyEnvironment

env = SnakeGame(scene)
env = TFPyEnvironment(env)


In [3]:
conv_layer_params = [(32, (2, 2), 1), (64, (2, 2), 1)]
fc_layer_params = [128, 16]

q_net = QNetwork(
    env.observation_spec(),
    env.action_spec(),
    conv_layer_params=conv_layer_params,
    fc_layer_params=fc_layer_params)


In [3]:
class MyQNetwork(network.Network):
    def __init__(self,
                 input_tensor_spec,
                 action_spec,
                 name='MyQNetwork'):
        super(MyQNetwork, self).__init__(
            input_tensor_spec=input_tensor_spec,
            state_spec=(),
            name=name)
        
        self._action_spec = action_spec

        input_shape = tf.expand_dims(input_tensor_spec, axis=0)
        # Create your custom model
        self._model = Sequential([
            Conv2D(32, (2, 2), 1, input_shape=input_shape, activation='relu', kernel_initializer='he_normal'),
            Conv2D(64, (2, 2), 1, activation='relu', kernel_initializer='he_normal'),
            Flatten(),
            Dense(128, activation='relu', kernel_initializer='he_normal'),
            Dense(16, activation='relu', kernel_initializer='he_normal'),
            Dense(4)
        ])

    def call(self, observations, step_type=None, network_state=(), training=False):
        output = self._model(observations)
        return output, network_state

    def create_variables(self, input_tensor_spec, **kwargs):
        self(tf.zeros(input_tensor_spec.shape, input_tensor_spec.dtype))

    def copy(self, name=''):
        pass

q_net = MyQNetwork(
    env.observation_spec(),
    env.action_spec())


ValueError: Attempt to convert a value (BoundedTensorSpec(shape=(5, 5, 4), dtype=tf.float32, name='observation', minimum=array([0., 0., 0., 0.], dtype=float32), maximum=array([1., 1., 1., 1.], dtype=float32))) with an unsupported type (<class 'tensorflow.python.framework.tensor.BoundedTensorSpec'>) to a Tensor.

In [32]:
# Create the agent
from tf_agents.agents.dqn.dqn_agent import DqnAgent
from tensorflow import keras
from keras.optimizers.legacy import Adam
from keras.losses import Huber
from tf_agents.trajectories import TimeStep

train_step_counter = tf.Variable(0)
optimizer = Adam(learning_rate=0.003)

# epsilon = lambda train_step: (1 / (tf.cast(tf.linspace(1, 8, episodes_count)**3, tf.float32)))[train_step]
epsilon = keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=1.0,
    decay_steps=episodes_count,
    end_learning_rate=0.01
)

agent = DqnAgent(
    time_step_spec=env.time_step_spec(),
    action_spec=env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=Huber(reduction="none"),
    gamma=tf.constant(0.95, dtype=tf.float32),
    train_step_counter=train_step_counter,
    epsilon_greedy=lambda: epsilon(train_step_counter)
)


ValueError: Exception encountered when calling layer 'MyQNetwork' (type MyQNetwork).

Input 0 of layer "sequential" is incompatible with the layer: expected shape=(None, 5, 5, 4), found shape=(5, 5, 4)

Call arguments received by layer 'MyQNetwork' (type MyQNetwork):
  • observations=tf.Tensor(shape=(5, 5, 4), dtype=float32)
  • step_type=None
  • network_state=()
  • training=False
  In call to configurable 'DqnAgent' (<class 'tf_agents.agents.dqn.dqn_agent.DqnAgent'>)

In [5]:
# Create the replay buffer
from tf_agents.replay_buffers import tf_uniform_replay_buffer

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.collect_data_spec,
    batch_size=env.batch_size,
    max_length=100000
)


In [6]:
# Create the training metrics
from tf_agents.metrics import tf_metrics

train_metrics = [
    tf_metrics.NumberOfEpisodes(),
    tf_metrics.EnvironmentSteps(),
    tf_metrics.AverageReturnMetric(),
    tf_metrics.AverageEpisodeLengthMetric()
]


In [7]:
# Create the driver
from tf_agents.drivers import dynamic_step_driver

collect_driver = dynamic_step_driver.DynamicStepDriver(
    env,
    policy=agent.collect_policy,
    observers=[replay_buffer.add_batch] + train_metrics,
    num_steps=4
)

# Run a random policy to fill the replay buffer
from tf_agents.policies.random_tf_policy import RandomTFPolicy

initial_collect_policy = RandomTFPolicy(env.time_step_spec(), env.action_spec())
init_driver = dynamic_step_driver.DynamicStepDriver(
    env,
    policy=initial_collect_policy,
    observers=[replay_buffer.add_batch],
    num_steps=20000
)

final_time_step, final_policy_state = init_driver.run()


In [8]:
# Create the dataset
dataset = replay_buffer.as_dataset(
    sample_batch_size=64,
    num_steps=2, # Capaz cambiar esto
    num_parallel_calls=3
).prefetch(3)


Instructions for updating:
Use `tf.data.Dataset.counter(...)` instead.
Instructions for updating:
Use `as_dataset(..., single_deterministic_pass=False) instead.


In [9]:
# Create the training loop

from tf_agents.utils.common import function
from tf_agents.trajectories import trajectory
from tf_agents.eval.metric_utils import log_metrics
import logging

logging.getLogger().setLevel(logging.INFO)

collect_driver.run = function(collect_driver.run)
agent.train = function(agent.train)

def train_agent(n_iterations):
    time_step = None
    policy_state = agent.collect_policy.get_initial_state(env.batch_size)
    iterator = iter(dataset)
    for iteration in range(n_iterations):
        time_step, policy_state = collect_driver.run(time_step, policy_state)
        trajectories, buffer_info = next(iterator)
        train_loss = agent.train(trajectories)
        if iteration % 1000 == 0:
            print("Iteration: ", iteration)
            log_metrics(train_metrics)


In [10]:
train_agent(episodes_count)


Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldr(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))


INFO:absl: 
		 NumberOfEpisodes = 0
		 EnvironmentSteps = 4
		 AverageReturn = 0.0
		 AverageEpisodeLength = 0.0


Iteration:  0


INFO:absl: 
		 NumberOfEpisodes = 3037
		 EnvironmentSteps = 4004
		 AverageReturn = -44988.69921875
		 AverageEpisodeLength = 0.699999988079071


Iteration:  1000


INFO:absl: 
		 NumberOfEpisodes = 5839
		 EnvironmentSteps = 8004
		 AverageReturn = -86282.5
		 AverageEpisodeLength = 1.5


Iteration:  2000


INFO:absl: 
		 NumberOfEpisodes = 8616
		 EnvironmentSteps = 12004
		 AverageReturn = -127687.1015625
		 AverageEpisodeLength = 1.399999976158142


Iteration:  3000


INFO:absl: 
		 NumberOfEpisodes = 11441
		 EnvironmentSteps = 16004
		 AverageReturn = -169235.59375
		 AverageEpisodeLength = 2.200000047683716


Iteration:  4000


INFO:absl: 
		 NumberOfEpisodes = 14158
		 EnvironmentSteps = 20004
		 AverageReturn = -209288.203125
		 AverageEpisodeLength = 0.800000011920929


Iteration:  5000


INFO:absl: 
		 NumberOfEpisodes = 16864
		 EnvironmentSteps = 24004
		 AverageReturn = -248414.59375
		 AverageEpisodeLength = 1.899999976158142


Iteration:  6000


INFO:absl: 
		 NumberOfEpisodes = 19566
		 EnvironmentSteps = 28004
		 AverageReturn = -288202.40625
		 AverageEpisodeLength = 1.600000023841858


Iteration:  7000


INFO:absl: 
		 NumberOfEpisodes = 22147
		 EnvironmentSteps = 32004
		 AverageReturn = -325780.5
		 AverageEpisodeLength = 2.0


Iteration:  8000


INFO:absl: 
		 NumberOfEpisodes = 24923
		 EnvironmentSteps = 36004
		 AverageReturn = -367001.8125
		 AverageEpisodeLength = 2.0


Iteration:  9000


INFO:absl: 
		 NumberOfEpisodes = 27579
		 EnvironmentSteps = 40004
		 AverageReturn = -404921.0
		 AverageEpisodeLength = 1.100000023841858


Iteration:  10000


INFO:absl: 
		 NumberOfEpisodes = 30111
		 EnvironmentSteps = 44004
		 AverageReturn = -441529.09375
		 AverageEpisodeLength = 2.299999952316284


Iteration:  11000


INFO:absl: 
		 NumberOfEpisodes = 32546
		 EnvironmentSteps = 48004
		 AverageReturn = -476605.59375
		 AverageEpisodeLength = 2.9000000953674316


Iteration:  12000


INFO:absl: 
		 NumberOfEpisodes = 35101
		 EnvironmentSteps = 52004
		 AverageReturn = -512360.09375
		 AverageEpisodeLength = 1.899999976158142


Iteration:  13000


INFO:absl: 
		 NumberOfEpisodes = 37502
		 EnvironmentSteps = 56004
		 AverageReturn = -546205.875
		 AverageEpisodeLength = 1.2999999523162842


Iteration:  14000


INFO:absl: 
		 NumberOfEpisodes = 39831
		 EnvironmentSteps = 60004
		 AverageReturn = -579262.3125
		 AverageEpisodeLength = 1.2999999523162842


Iteration:  15000


INFO:absl: 
		 NumberOfEpisodes = 42265
		 EnvironmentSteps = 64004
		 AverageReturn = -612963.625
		 AverageEpisodeLength = 0.699999988079071


Iteration:  16000


INFO:absl: 
		 NumberOfEpisodes = 44591
		 EnvironmentSteps = 68004
		 AverageReturn = -645632.5
		 AverageEpisodeLength = 2.299999952316284


Iteration:  17000


INFO:absl: 
		 NumberOfEpisodes = 46943
		 EnvironmentSteps = 72004
		 AverageReturn = -678299.125
		 AverageEpisodeLength = 2.9000000953674316


Iteration:  18000


INFO:absl: 
		 NumberOfEpisodes = 49277
		 EnvironmentSteps = 76004
		 AverageReturn = -710666.0
		 AverageEpisodeLength = 1.2999999523162842


Iteration:  19000


INFO:absl: 
		 NumberOfEpisodes = 51669
		 EnvironmentSteps = 80004
		 AverageReturn = -743722.875
		 AverageEpisodeLength = 1.600000023841858


Iteration:  20000


INFO:absl: 
		 NumberOfEpisodes = 53912
		 EnvironmentSteps = 84004
		 AverageReturn = -774612.6875
		 AverageEpisodeLength = 3.5


Iteration:  21000


INFO:absl: 
		 NumberOfEpisodes = 56063
		 EnvironmentSteps = 88004
		 AverageReturn = -804175.875
		 AverageEpisodeLength = 1.2000000476837158


Iteration:  22000


INFO:absl: 
		 NumberOfEpisodes = 58262
		 EnvironmentSteps = 92004
		 AverageReturn = -833970.875
		 AverageEpisodeLength = 2.9000000953674316


Iteration:  23000


INFO:absl: 
		 NumberOfEpisodes = 60392
		 EnvironmentSteps = 96004
		 AverageReturn = -862628.5
		 AverageEpisodeLength = 2.200000047683716


Iteration:  24000


INFO:absl: 
		 NumberOfEpisodes = 62469
		 EnvironmentSteps = 100004
		 AverageReturn = -889886.5
		 AverageEpisodeLength = 2.5


Iteration:  25000


INFO:absl: 
		 NumberOfEpisodes = 64514
		 EnvironmentSteps = 104004
		 AverageReturn = -917420.8125
		 AverageEpisodeLength = 2.9000000953674316


Iteration:  26000


INFO:absl: 
		 NumberOfEpisodes = 66622
		 EnvironmentSteps = 108004
		 AverageReturn = -946087.5
		 AverageEpisodeLength = 2.799999952316284


Iteration:  27000


INFO:absl: 
		 NumberOfEpisodes = 68707
		 EnvironmentSteps = 112004
		 AverageReturn = -973576.625
		 AverageEpisodeLength = 1.5


Iteration:  28000


INFO:absl: 
		 NumberOfEpisodes = 70794
		 EnvironmentSteps = 116004
		 AverageReturn = -1000757.125
		 AverageEpisodeLength = 1.399999976158142


Iteration:  29000


INFO:absl: 
		 NumberOfEpisodes = 72787
		 EnvironmentSteps = 120004
		 AverageReturn = -1026820.8125
		 AverageEpisodeLength = 0.699999988079071


Iteration:  30000


INFO:absl: 
		 NumberOfEpisodes = 74791
		 EnvironmentSteps = 124004
		 AverageReturn = -1052687.25
		 AverageEpisodeLength = 1.7999999523162842


Iteration:  31000


INFO:absl: 
		 NumberOfEpisodes = 76742
		 EnvironmentSteps = 128004
		 AverageReturn = -1077964.375
		 AverageEpisodeLength = 4.0


Iteration:  32000


INFO:absl: 
		 NumberOfEpisodes = 78661
		 EnvironmentSteps = 132004
		 AverageReturn = -1102964.875
		 AverageEpisodeLength = 2.299999952316284


Iteration:  33000


INFO:absl: 
		 NumberOfEpisodes = 80564
		 EnvironmentSteps = 136004
		 AverageReturn = -1126286.0
		 AverageEpisodeLength = 2.5999999046325684


Iteration:  34000


INFO:absl: 
		 NumberOfEpisodes = 82479
		 EnvironmentSteps = 140004
		 AverageReturn = -1150202.625
		 AverageEpisodeLength = 2.5


Iteration:  35000


INFO:absl: 
		 NumberOfEpisodes = 84321
		 EnvironmentSteps = 144004
		 AverageReturn = -1172522.125
		 AverageEpisodeLength = 1.7999999523162842


Iteration:  36000


INFO:absl: 
		 NumberOfEpisodes = 86170
		 EnvironmentSteps = 148004
		 AverageReturn = -1194815.25
		 AverageEpisodeLength = 2.4000000953674316


Iteration:  37000


INFO:absl: 
		 NumberOfEpisodes = 87871
		 EnvironmentSteps = 152004
		 AverageReturn = -1215441.375
		 AverageEpisodeLength = 1.2999999523162842


Iteration:  38000


INFO:absl: 
		 NumberOfEpisodes = 89706
		 EnvironmentSteps = 156004
		 AverageReturn = -1237552.25
		 AverageEpisodeLength = 1.899999976158142


Iteration:  39000


INFO:absl: 
		 NumberOfEpisodes = 91352
		 EnvironmentSteps = 160004
		 AverageReturn = -1256838.625
		 AverageEpisodeLength = 4.5


Iteration:  40000


INFO:absl: 
		 NumberOfEpisodes = 93076
		 EnvironmentSteps = 164004
		 AverageReturn = -1276237.75
		 AverageEpisodeLength = 2.299999952316284


Iteration:  41000


INFO:absl: 
		 NumberOfEpisodes = 94772
		 EnvironmentSteps = 168004
		 AverageReturn = -1295487.875
		 AverageEpisodeLength = 3.5999999046325684


Iteration:  42000


INFO:absl: 
		 NumberOfEpisodes = 96412
		 EnvironmentSteps = 172004
		 AverageReturn = -1314830.625
		 AverageEpisodeLength = 2.200000047683716


Iteration:  43000


INFO:absl: 
		 NumberOfEpisodes = 98069
		 EnvironmentSteps = 176004
		 AverageReturn = -1333590.0
		 AverageEpisodeLength = 2.4000000953674316


Iteration:  44000


INFO:absl: 
		 NumberOfEpisodes = 99720
		 EnvironmentSteps = 180004
		 AverageReturn = -1352698.875
		 AverageEpisodeLength = 2.799999952316284


Iteration:  45000


INFO:absl: 
		 NumberOfEpisodes = 101444
		 EnvironmentSteps = 184004
		 AverageReturn = -1373294.125
		 AverageEpisodeLength = 3.4000000953674316


Iteration:  46000


INFO:absl: 
		 NumberOfEpisodes = 103050
		 EnvironmentSteps = 188004
		 AverageReturn = -1390870.625
		 AverageEpisodeLength = 2.799999952316284


Iteration:  47000


INFO:absl: 
		 NumberOfEpisodes = 104615
		 EnvironmentSteps = 192004
		 AverageReturn = -1407203.875
		 AverageEpisodeLength = 2.200000047683716


Iteration:  48000


INFO:absl: 
		 NumberOfEpisodes = 106251
		 EnvironmentSteps = 196004
		 AverageReturn = -1425923.0
		 AverageEpisodeLength = 4.599999904632568


Iteration:  49000


INFO:absl: 
		 NumberOfEpisodes = 107855
		 EnvironmentSteps = 200004
		 AverageReturn = -1443390.125
		 AverageEpisodeLength = 2.4000000953674316


Iteration:  50000


INFO:absl: 
		 NumberOfEpisodes = 109418
		 EnvironmentSteps = 204004
		 AverageReturn = -1460138.625
		 AverageEpisodeLength = 2.0


Iteration:  51000


INFO:absl: 
		 NumberOfEpisodes = 110864
		 EnvironmentSteps = 208004
		 AverageReturn = -1473475.5
		 AverageEpisodeLength = 2.4000000953674316


Iteration:  52000


INFO:absl: 
		 NumberOfEpisodes = 112312
		 EnvironmentSteps = 212004
		 AverageReturn = -1487683.625
		 AverageEpisodeLength = 3.799999952316284


Iteration:  53000


INFO:absl: 
		 NumberOfEpisodes = 113746
		 EnvironmentSteps = 216004
		 AverageReturn = -1502065.875
		 AverageEpisodeLength = 3.9000000953674316


Iteration:  54000


INFO:absl: 
		 NumberOfEpisodes = 115112
		 EnvironmentSteps = 220004
		 AverageReturn = -1514533.5
		 AverageEpisodeLength = 2.0999999046325684


Iteration:  55000


INFO:absl: 
		 NumberOfEpisodes = 116609
		 EnvironmentSteps = 224004
		 AverageReturn = -1528838.25
		 AverageEpisodeLength = 3.0


Iteration:  56000


INFO:absl: 
		 NumberOfEpisodes = 117879
		 EnvironmentSteps = 228004
		 AverageReturn = -1539139.625
		 AverageEpisodeLength = 3.0


Iteration:  57000


INFO:absl: 
		 NumberOfEpisodes = 119184
		 EnvironmentSteps = 232004
		 AverageReturn = -1550377.125
		 AverageEpisodeLength = 3.200000047683716


Iteration:  58000


INFO:absl: 
		 NumberOfEpisodes = 120566
		 EnvironmentSteps = 236004
		 AverageReturn = -1562471.875
		 AverageEpisodeLength = 3.0999999046325684


Iteration:  59000


INFO:absl: 
		 NumberOfEpisodes = 121903
		 EnvironmentSteps = 240004
		 AverageReturn = -1574319.0
		 AverageEpisodeLength = 2.5


Iteration:  60000


INFO:absl: 
		 NumberOfEpisodes = 123219
		 EnvironmentSteps = 244004
		 AverageReturn = -1585471.25
		 AverageEpisodeLength = 1.2999999523162842


Iteration:  61000


INFO:absl: 
		 NumberOfEpisodes = 124498
		 EnvironmentSteps = 248004
		 AverageReturn = -1595160.0
		 AverageEpisodeLength = 4.400000095367432


Iteration:  62000


INFO:absl: 
		 NumberOfEpisodes = 125762
		 EnvironmentSteps = 252004
		 AverageReturn = -1604482.0
		 AverageEpisodeLength = 2.0999999046325684


Iteration:  63000


INFO:absl: 
		 NumberOfEpisodes = 126997
		 EnvironmentSteps = 256004
		 AverageReturn = -1613894.125
		 AverageEpisodeLength = 3.5999999046325684


Iteration:  64000


INFO:absl: 
		 NumberOfEpisodes = 128164
		 EnvironmentSteps = 260004
		 AverageReturn = -1621898.875
		 AverageEpisodeLength = 3.299999952316284


Iteration:  65000


INFO:absl: 
		 NumberOfEpisodes = 129354
		 EnvironmentSteps = 264004
		 AverageReturn = -1630125.75
		 AverageEpisodeLength = 3.700000047683716


Iteration:  66000


INFO:absl: 
		 NumberOfEpisodes = 130518
		 EnvironmentSteps = 268004
		 AverageReturn = -1637001.375
		 AverageEpisodeLength = 3.700000047683716


Iteration:  67000


INFO:absl: 
		 NumberOfEpisodes = 131718
		 EnvironmentSteps = 272004
		 AverageReturn = -1644326.375
		 AverageEpisodeLength = 2.799999952316284


Iteration:  68000


INFO:absl: 
		 NumberOfEpisodes = 132932
		 EnvironmentSteps = 276004
		 AverageReturn = -1652353.75
		 AverageEpisodeLength = 3.799999952316284


Iteration:  69000


INFO:absl: 
		 NumberOfEpisodes = 134143
		 EnvironmentSteps = 280004
		 AverageReturn = -1660930.75
		 AverageEpisodeLength = 1.899999976158142


Iteration:  70000


INFO:absl: 
		 NumberOfEpisodes = 135319
		 EnvironmentSteps = 284004
		 AverageReturn = -1667831.5
		 AverageEpisodeLength = 0.699999988079071


Iteration:  71000


INFO:absl: 
		 NumberOfEpisodes = 136456
		 EnvironmentSteps = 288004
		 AverageReturn = -1674437.375
		 AverageEpisodeLength = 4.699999809265137


Iteration:  72000


INFO:absl: 
		 NumberOfEpisodes = 137520
		 EnvironmentSteps = 292004
		 AverageReturn = -1678238.0
		 AverageEpisodeLength = 2.200000047683716


Iteration:  73000


INFO:absl: 
		 NumberOfEpisodes = 138652
		 EnvironmentSteps = 296004
		 AverageReturn = -1683921.625
		 AverageEpisodeLength = 1.899999976158142


Iteration:  74000


INFO:absl: 
		 NumberOfEpisodes = 139757
		 EnvironmentSteps = 300004
		 AverageReturn = -1689264.625
		 AverageEpisodeLength = 6.099999904632568


Iteration:  75000


INFO:absl: 
		 NumberOfEpisodes = 140815
		 EnvironmentSteps = 304004
		 AverageReturn = -1693901.375
		 AverageEpisodeLength = 2.9000000953674316


Iteration:  76000


INFO:absl: 
		 NumberOfEpisodes = 141904
		 EnvironmentSteps = 308004
		 AverageReturn = -1698926.25
		 AverageEpisodeLength = 3.0


Iteration:  77000


INFO:absl: 
		 NumberOfEpisodes = 142903
		 EnvironmentSteps = 312004
		 AverageReturn = -1702033.375
		 AverageEpisodeLength = 3.5


Iteration:  78000


INFO:absl: 
		 NumberOfEpisodes = 143990
		 EnvironmentSteps = 316004
		 AverageReturn = -1705870.375
		 AverageEpisodeLength = 3.200000047683716


Iteration:  79000


INFO:absl: 
		 NumberOfEpisodes = 144943
		 EnvironmentSteps = 320004
		 AverageReturn = -1707876.75
		 AverageEpisodeLength = 6.699999809265137


Iteration:  80000


INFO:absl: 
		 NumberOfEpisodes = 145929
		 EnvironmentSteps = 324004
		 AverageReturn = -1710443.625
		 AverageEpisodeLength = 5.800000190734863


Iteration:  81000


INFO:absl: 
		 NumberOfEpisodes = 146882
		 EnvironmentSteps = 328004
		 AverageReturn = -1711753.25
		 AverageEpisodeLength = 5.0


Iteration:  82000


INFO:absl: 
		 NumberOfEpisodes = 147783
		 EnvironmentSteps = 332004
		 AverageReturn = -1712159.25
		 AverageEpisodeLength = 4.0


Iteration:  83000


INFO:absl: 
		 NumberOfEpisodes = 148749
		 EnvironmentSteps = 336004
		 AverageReturn = -1714451.0
		 AverageEpisodeLength = 1.7999999523162842


Iteration:  84000


INFO:absl: 
		 NumberOfEpisodes = 149685
		 EnvironmentSteps = 340004
		 AverageReturn = -1714158.75
		 AverageEpisodeLength = 3.200000047683716


Iteration:  85000


INFO:absl: 
		 NumberOfEpisodes = 150571
		 EnvironmentSteps = 344004
		 AverageReturn = -1713880.0
		 AverageEpisodeLength = 2.700000047683716


Iteration:  86000


INFO:absl: 
		 NumberOfEpisodes = 151455
		 EnvironmentSteps = 348004
		 AverageReturn = -1713812.375
		 AverageEpisodeLength = 5.099999904632568


Iteration:  87000


INFO:absl: 
		 NumberOfEpisodes = 152319
		 EnvironmentSteps = 352004
		 AverageReturn = -1712930.0
		 AverageEpisodeLength = 4.300000190734863


Iteration:  88000


INFO:absl: 
		 NumberOfEpisodes = 153149
		 EnvironmentSteps = 356004
		 AverageReturn = -1711618.75
		 AverageEpisodeLength = 3.9000000953674316


Iteration:  89000


INFO:absl: 
		 NumberOfEpisodes = 154031
		 EnvironmentSteps = 360004
		 AverageReturn = -1711307.625
		 AverageEpisodeLength = 4.0


Iteration:  90000


INFO:absl: 
		 NumberOfEpisodes = 154868
		 EnvironmentSteps = 364004
		 AverageReturn = -1710002.75
		 AverageEpisodeLength = 3.0999999046325684


Iteration:  91000


INFO:absl: 
		 NumberOfEpisodes = 155716
		 EnvironmentSteps = 368004
		 AverageReturn = -1708198.375
		 AverageEpisodeLength = 4.599999904632568


Iteration:  92000


INFO:absl: 
		 NumberOfEpisodes = 156475
		 EnvironmentSteps = 372004
		 AverageReturn = -1705163.75
		 AverageEpisodeLength = 5.099999904632568


Iteration:  93000


INFO:absl: 
		 NumberOfEpisodes = 157282
		 EnvironmentSteps = 376004
		 AverageReturn = -1701850.25
		 AverageEpisodeLength = 2.9000000953674316


Iteration:  94000


INFO:absl: 
		 NumberOfEpisodes = 158045
		 EnvironmentSteps = 380004
		 AverageReturn = -1698748.0
		 AverageEpisodeLength = 3.200000047683716


Iteration:  95000


INFO:absl: 
		 NumberOfEpisodes = 158788
		 EnvironmentSteps = 384004
		 AverageReturn = -1694690.25
		 AverageEpisodeLength = 5.199999809265137


Iteration:  96000


INFO:absl: 
		 NumberOfEpisodes = 159568
		 EnvironmentSteps = 388004
		 AverageReturn = -1690518.625
		 AverageEpisodeLength = 4.599999904632568


Iteration:  97000


INFO:absl: 
		 NumberOfEpisodes = 160346
		 EnvironmentSteps = 392004
		 AverageReturn = -1685576.375
		 AverageEpisodeLength = 5.300000190734863


Iteration:  98000


INFO:absl: 
		 NumberOfEpisodes = 161036
		 EnvironmentSteps = 396004
		 AverageReturn = -1680241.25
		 AverageEpisodeLength = 8.699999809265137


Iteration:  99000


INFO:absl: 
		 NumberOfEpisodes = 161680
		 EnvironmentSteps = 400004
		 AverageReturn = -1673990.5
		 AverageEpisodeLength = 3.200000047683716


Iteration:  100000


INFO:absl: 
		 NumberOfEpisodes = 162382
		 EnvironmentSteps = 404004
		 AverageReturn = -1668822.75
		 AverageEpisodeLength = 5.5


Iteration:  101000


INFO:absl: 
		 NumberOfEpisodes = 163074
		 EnvironmentSteps = 408004
		 AverageReturn = -1663724.25
		 AverageEpisodeLength = 8.5


Iteration:  102000


INFO:absl: 
		 NumberOfEpisodes = 163748
		 EnvironmentSteps = 412004
		 AverageReturn = -1657042.0
		 AverageEpisodeLength = 4.099999904632568


Iteration:  103000


INFO:absl: 
		 NumberOfEpisodes = 164411
		 EnvironmentSteps = 416004
		 AverageReturn = -1650671.625
		 AverageEpisodeLength = 8.5


Iteration:  104000


INFO:absl: 
		 NumberOfEpisodes = 165066
		 EnvironmentSteps = 420004
		 AverageReturn = -1643545.25
		 AverageEpisodeLength = 6.0


Iteration:  105000


INFO:absl: 
		 NumberOfEpisodes = 165679
		 EnvironmentSteps = 424004
		 AverageReturn = -1636097.875
		 AverageEpisodeLength = 5.400000095367432


Iteration:  106000


INFO:absl: 
		 NumberOfEpisodes = 166329
		 EnvironmentSteps = 428004
		 AverageReturn = -1628713.375
		 AverageEpisodeLength = 5.099999904632568


Iteration:  107000


INFO:absl: 
		 NumberOfEpisodes = 166896
		 EnvironmentSteps = 432004
		 AverageReturn = -1620088.75
		 AverageEpisodeLength = 7.900000095367432


Iteration:  108000


INFO:absl: 
		 NumberOfEpisodes = 167492
		 EnvironmentSteps = 436004
		 AverageReturn = -1611151.375
		 AverageEpisodeLength = 7.400000095367432


Iteration:  109000


INFO:absl: 
		 NumberOfEpisodes = 168056
		 EnvironmentSteps = 440004
		 AverageReturn = -1601086.5
		 AverageEpisodeLength = 3.9000000953674316


Iteration:  110000


INFO:absl: 
		 NumberOfEpisodes = 168595
		 EnvironmentSteps = 444004
		 AverageReturn = -1592347.25
		 AverageEpisodeLength = 10.699999809265137


Iteration:  111000


INFO:absl: 
		 NumberOfEpisodes = 169125
		 EnvironmentSteps = 448004
		 AverageReturn = -1582666.25
		 AverageEpisodeLength = 7.599999904632568


Iteration:  112000


INFO:absl: 
		 NumberOfEpisodes = 169655
		 EnvironmentSteps = 452004
		 AverageReturn = -1572401.75
		 AverageEpisodeLength = 3.4000000953674316


Iteration:  113000


INFO:absl: 
		 NumberOfEpisodes = 170167
		 EnvironmentSteps = 456004
		 AverageReturn = -1561965.125
		 AverageEpisodeLength = 10.600000381469727


Iteration:  114000


INFO:absl: 
		 NumberOfEpisodes = 170666
		 EnvironmentSteps = 460004
		 AverageReturn = -1550092.125
		 AverageEpisodeLength = 7.800000190734863


Iteration:  115000


INFO:absl: 
		 NumberOfEpisodes = 171174
		 EnvironmentSteps = 464004
		 AverageReturn = -1539083.5
		 AverageEpisodeLength = 7.0


Iteration:  116000


INFO:absl: 
		 NumberOfEpisodes = 171676
		 EnvironmentSteps = 468004
		 AverageReturn = -1527729.625
		 AverageEpisodeLength = 7.900000095367432


Iteration:  117000


INFO:absl: 
		 NumberOfEpisodes = 172184
		 EnvironmentSteps = 472004
		 AverageReturn = -1516035.625
		 AverageEpisodeLength = 7.400000095367432


Iteration:  118000


INFO:absl: 
		 NumberOfEpisodes = 172617
		 EnvironmentSteps = 476004
		 AverageReturn = -1503827.75
		 AverageEpisodeLength = 13.100000381469727


Iteration:  119000


INFO:absl: 
		 NumberOfEpisodes = 173086
		 EnvironmentSteps = 480004
		 AverageReturn = -1491523.625
		 AverageEpisodeLength = 10.300000190734863


Iteration:  120000


INFO:absl: 
		 NumberOfEpisodes = 173535
		 EnvironmentSteps = 484004
		 AverageReturn = -1478417.25
		 AverageEpisodeLength = 3.5999999046325684


Iteration:  121000


INFO:absl: 
		 NumberOfEpisodes = 173944
		 EnvironmentSteps = 488004
		 AverageReturn = -1464421.0
		 AverageEpisodeLength = 10.399999618530273


Iteration:  122000


INFO:absl: 
		 NumberOfEpisodes = 174394
		 EnvironmentSteps = 492004
		 AverageReturn = -1451302.875
		 AverageEpisodeLength = 6.5


Iteration:  123000


INFO:absl: 
		 NumberOfEpisodes = 174819
		 EnvironmentSteps = 496004
		 AverageReturn = -1437237.75
		 AverageEpisodeLength = 10.600000381469727


Iteration:  124000


INFO:absl: 
		 NumberOfEpisodes = 175241
		 EnvironmentSteps = 500004
		 AverageReturn = -1423088.5
		 AverageEpisodeLength = 10.300000190734863


Iteration:  125000


INFO:absl: 
		 NumberOfEpisodes = 175642
		 EnvironmentSteps = 504004
		 AverageReturn = -1408773.625
		 AverageEpisodeLength = 7.5


Iteration:  126000


INFO:absl: 
		 NumberOfEpisodes = 176021
		 EnvironmentSteps = 508004
		 AverageReturn = -1393307.75
		 AverageEpisodeLength = 14.100000381469727


Iteration:  127000


INFO:absl: 
		 NumberOfEpisodes = 176399
		 EnvironmentSteps = 512004
		 AverageReturn = -1378383.5
		 AverageEpisodeLength = 12.800000190734863


Iteration:  128000


INFO:absl: 
		 NumberOfEpisodes = 176780
		 EnvironmentSteps = 516004
		 AverageReturn = -1363275.75
		 AverageEpisodeLength = 12.300000190734863


Iteration:  129000


INFO:absl: 
		 NumberOfEpisodes = 177146
		 EnvironmentSteps = 520004
		 AverageReturn = -1347462.0
		 AverageEpisodeLength = 11.5


Iteration:  130000


INFO:absl: 
		 NumberOfEpisodes = 177526
		 EnvironmentSteps = 524004
		 AverageReturn = -1331898.25
		 AverageEpisodeLength = 11.100000381469727


Iteration:  131000


INFO:absl: 
		 NumberOfEpisodes = 177858
		 EnvironmentSteps = 528004
		 AverageReturn = -1314383.5
		 AverageEpisodeLength = 9.0


Iteration:  132000


INFO:absl: 
		 NumberOfEpisodes = 178195
		 EnvironmentSteps = 532004
		 AverageReturn = -1296662.375
		 AverageEpisodeLength = 10.399999618530273


Iteration:  133000


INFO:absl: 
		 NumberOfEpisodes = 178526
		 EnvironmentSteps = 536004
		 AverageReturn = -1279974.75
		 AverageEpisodeLength = 9.699999809265137


Iteration:  134000


INFO:absl: 
		 NumberOfEpisodes = 178819
		 EnvironmentSteps = 540004
		 AverageReturn = -1262129.5
		 AverageEpisodeLength = 12.800000190734863


Iteration:  135000


INFO:absl: 
		 NumberOfEpisodes = 179105
		 EnvironmentSteps = 544004
		 AverageReturn = -1243828.875
		 AverageEpisodeLength = 10.600000381469727


Iteration:  136000


INFO:absl: 
		 NumberOfEpisodes = 179400
		 EnvironmentSteps = 548004
		 AverageReturn = -1225585.375
		 AverageEpisodeLength = 11.0


Iteration:  137000


INFO:absl: 
		 NumberOfEpisodes = 179665
		 EnvironmentSteps = 552004
		 AverageReturn = -1206392.0
		 AverageEpisodeLength = 13.100000381469727


Iteration:  138000


INFO:absl: 
		 NumberOfEpisodes = 179943
		 EnvironmentSteps = 556004
		 AverageReturn = -1187348.875
		 AverageEpisodeLength = 11.5


Iteration:  139000


INFO:absl: 
		 NumberOfEpisodes = 180214
		 EnvironmentSteps = 560004
		 AverageReturn = -1168566.0
		 AverageEpisodeLength = 17.600000381469727


Iteration:  140000


INFO:absl: 
		 NumberOfEpisodes = 180487
		 EnvironmentSteps = 564004
		 AverageReturn = -1148358.875
		 AverageEpisodeLength = 13.699999809265137


Iteration:  141000


INFO:absl: 
		 NumberOfEpisodes = 180719
		 EnvironmentSteps = 568004
		 AverageReturn = -1128867.75
		 AverageEpisodeLength = 16.200000762939453


Iteration:  142000


INFO:absl: 
		 NumberOfEpisodes = 180931
		 EnvironmentSteps = 572004
		 AverageReturn = -1108257.75
		 AverageEpisodeLength = 17.700000762939453


Iteration:  143000


INFO:absl: 
		 NumberOfEpisodes = 181157
		 EnvironmentSteps = 576004
		 AverageReturn = -1086800.25
		 AverageEpisodeLength = 18.899999618530273


Iteration:  144000


INFO:absl: 
		 NumberOfEpisodes = 181386
		 EnvironmentSteps = 580004
		 AverageReturn = -1066007.5
		 AverageEpisodeLength = 16.899999618530273


Iteration:  145000


INFO:absl: 
		 NumberOfEpisodes = 181607
		 EnvironmentSteps = 584004
		 AverageReturn = -1044900.125
		 AverageEpisodeLength = 16.5


Iteration:  146000


INFO:absl: 
		 NumberOfEpisodes = 181797
		 EnvironmentSteps = 588004
		 AverageReturn = -1023086.0
		 AverageEpisodeLength = 23.799999237060547


Iteration:  147000


INFO:absl: 
		 NumberOfEpisodes = 181980
		 EnvironmentSteps = 592004
		 AverageReturn = -1001019.875
		 AverageEpisodeLength = 25.5


Iteration:  148000


INFO:absl: 
		 NumberOfEpisodes = 182163
		 EnvironmentSteps = 596004
		 AverageReturn = -978691.1875
		 AverageEpisodeLength = 29.399999618530273


Iteration:  149000


In [11]:
# Evaluate the agent

time_step = env.reset()
rewards = []
steps = 0

while not time_step.is_last() and steps < 200:
    steps += 1
    print("Map:\n", np.argmax(time_step.observation.numpy(), axis=3))
    action_step = agent.policy.action(time_step)
    print("Action: ", action_step.action.numpy())
    time_step = env.step(action_step.action)
    rewards.append(time_step.reward.numpy()[0])
    
print("Total reward: ", sum(rewards))
print("Total steps: ", steps)


Map:
 [[[0 0 0 0 0]
  [0 1 0 0 0]
  [0 2 0 0 0]
  [0 0 0 3 0]
  [0 0 0 0 0]]]
Action:  [2]
Map:
 [[[0 0 0 0 0]
  [0 2 1 0 0]
  [0 0 0 0 0]
  [0 0 0 3 0]
  [0 0 0 0 0]]]
Action:  [2]
Map:
 [[[0 0 0 0 0]
  [0 0 2 1 0]
  [0 0 0 0 0]
  [0 0 0 3 0]
  [0 0 0 0 0]]]
Action:  [3]
Map:
 [[[0 0 0 0 0]
  [0 0 0 2 0]
  [0 0 0 1 0]
  [0 0 0 3 0]
  [0 0 0 0 0]]]
Action:  [3]
Map:
 [[[0 0 0 0 0]
  [0 0 0 2 3]
  [0 0 0 2 0]
  [0 0 0 1 0]
  [0 0 0 0 0]]]
Action:  [2]
Map:
 [[[0 0 0 0 0]
  [0 0 0 0 3]
  [0 0 0 2 0]
  [0 0 0 2 1]
  [0 0 0 0 0]]]
Action:  [1]
Map:
 [[[0 0 0 0 0]
  [0 0 0 0 3]
  [0 0 0 0 1]
  [0 0 0 2 2]
  [0 0 0 0 0]]]
Action:  [1]
Map:
 [[[0 0 0 0 0]
  [0 0 0 0 1]
  [0 0 0 0 2]
  [0 0 0 2 2]
  [0 0 0 0 3]]]
Action:  [0]
Map:
 [[[0 0 0 0 0]
  [0 0 0 1 2]
  [0 0 0 0 2]
  [0 0 0 0 2]
  [0 0 0 0 3]]]
Action:  [3]
Map:
 [[[0 0 0 0 0]
  [0 0 0 2 2]
  [0 0 0 1 2]
  [0 0 0 0 0]
  [0 0 0 0 3]]]
Action:  [3]
Map:
 [[[0 0 0 0 0]
  [0 0 0 2 2]
  [0 0 0 2 0]
  [0 0 0 1 0]
  [0 0 0 0 3]]]
Action:  [2]

In [12]:
# Count the average steps per episode

n = 200
total_steps = 0
total_reward = 0
steps = 0
total_500 = 0

for i in range(n):
    print("Episode: ", i)
    time_step = env.reset()
    steps = 0
    while not time_step.is_last() and steps < 500:
        steps += 1
        action_step = agent.policy.action(time_step)
        time_step = env.step(action_step.action)
        reward = time_step.reward.numpy()[0]
        total_reward += reward
    
    if steps == 500:
        total_500 += 1
    else:
        total_steps += steps
        
if steps == 500:
    print("500 episodes reached")
    
print("Average steps per episode: ", total_steps / (n - total_500))
print("Average reward per episode: ", total_reward / n)
print("Total 500 episodes: ", total_500)


Episode:  0
Episode:  1
Episode:  2
Episode:  3
Episode:  4
Episode:  5
Episode:  6
Episode:  7
Episode:  8
Episode:  9
Episode:  10
Episode:  11
Episode:  12
Episode:  13
Episode:  14
Episode:  15
Episode:  16
Episode:  17
Episode:  18
Episode:  19
Episode:  20
Episode:  21
Episode:  22
Episode:  23
Episode:  24
Episode:  25
Episode:  26
Episode:  27
Episode:  28
Episode:  29
Episode:  30
Episode:  31
Episode:  32
Episode:  33
Episode:  34
Episode:  35
Episode:  36
Episode:  37
Episode:  38
Episode:  39
Episode:  40
Episode:  41
Episode:  42
Episode:  43
Episode:  44
Episode:  45
Episode:  46
Episode:  47
Episode:  48
Episode:  49
Episode:  50
Episode:  51
Episode:  52
Episode:  53
Episode:  54
Episode:  55
Episode:  56
Episode:  57
Episode:  58
Episode:  59
Episode:  60
Episode:  61
Episode:  62
Episode:  63
Episode:  64
Episode:  65
Episode:  66
Episode:  67
Episode:  68
Episode:  69
Episode:  70
Episode:  71
Episode:  72
Episode:  73
Episode:  74
Episode:  75
Episode:  76
Episode: 

In [13]:
print("Map:\n", np.argmax(time_step.observation.numpy(), axis=3))


Map:
 [[[0 0 0 0 0]
  [0 0 0 0 0]
  [0 0 0 0 0]
  [3 0 0 0 0]
  [1 2 0 0 0]]]


In [14]:
action_step = agent.policy.action(time_step)
time_step = env.step(action_step.action)
print("Map:\n", np.argmax(time_step.observation.numpy(), axis=3))


Map:
 [[[0 0 0 0 0]
  [0 0 0 0 0]
  [0 0 0 0 3]
  [1 0 0 0 0]
  [2 2 0 0 0]]]


In [15]:
# Save the agent
import os
from tf_agents.policies.policy_saver import PolicySaver

saver = PolicySaver(agent.policy, batch_size=None)
saver.save("policy")


INFO:absl:Function `function_with_signature` contains input name(s) 0/step_type, 0/reward, 0/discount, 0/observation, 10066885, 10066887, 10066889, 10066891, 10066893, 10066895, 10066897, 10066899, 10066901, 10066903 with unsupported characters which will be renamed to step_type, reward, discount, observation, unknown, unknown_0, unknown_1, unknown_2, unknown_3, unknown_4, unknown_5, unknown_6, unknown_7, unknown_8 in the SavedModel.
INFO:absl:Function `function_with_signature` contains input name(s) 10066948 with unsupported characters which will be renamed to unknown in the SavedModel.


INFO:tensorflow:Assets written to: policy/assets


INFO:tensorflow:Assets written to: policy/assets
