In [1]:
import numpy as np
from scene import Scene
import tensorflow as tf
from tf_agents.networks.q_network import QNetwork
from tf_agents.specs import tensor_spec
from snake_game import SnakeGame

scene = Scene(using_cnn=True, init_randomly=True)
episodes_count = 100000

pygame 2.1.3 (SDL 2.0.22, Python 3.11.4)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
from tf_agents.environments.tf_py_environment import TFPyEnvironment

env = SnakeGame(scene)
env = TFPyEnvironment(env)

In [3]:
conv_layer_params = [(32, (2, 2), 1), (32, (2, 2), 1)]
fc_layer_params = [64]

q_net = QNetwork(
    env.observation_spec(),
    env.action_spec(),
    conv_layer_params=conv_layer_params,
    fc_layer_params=fc_layer_params)

In [4]:
# Create the agent
from tf_agents.agents.dqn.dqn_agent import DqnAgent
from tensorflow import keras
from keras.optimizers.legacy import Adam
from keras.losses import Huber
from tf_agents.trajectories import TimeStep

train_step_counter = tf.Variable(0)
optimizer = Adam(learning_rate=0.001)

epsilon = lambda train_step: (1 / (tf.cast(tf.linspace(1, 8, episodes_count)**3, tf.float32)))[train_step]
# epsilon = keras.optimizers.schedules.PolynomialDecay(
#     initial_learning_rate=1.0,
#     decay_steps=episodes_count,
#     end_learning_rate=0.01
# )

agent = DqnAgent(
    time_step_spec=env.time_step_spec(),
    action_spec=env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=Huber(reduction="none"),
    gamma=tf.constant(0.95, dtype=tf.float32),
    train_step_counter=train_step_counter,
    epsilon_greedy=lambda: epsilon(train_step_counter)
)

In [5]:
# Create the replay buffer
from tf_agents.replay_buffers import tf_uniform_replay_buffer

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.collect_data_spec,
    batch_size=env.batch_size,
    max_length=100000
)

In [6]:
# Create the training metrics
from tf_agents.metrics import tf_metrics

train_metrics = [
    tf_metrics.NumberOfEpisodes(),
    tf_metrics.EnvironmentSteps(),
    tf_metrics.AverageReturnMetric(),
    tf_metrics.AverageEpisodeLengthMetric()
]

In [7]:
# Create the driver
from tf_agents.drivers import dynamic_step_driver

collect_driver = dynamic_step_driver.DynamicStepDriver(
    env,
    policy=agent.collect_policy,
    observers=[replay_buffer.add_batch] + train_metrics,
    num_steps=1
)

# Run a random policy to fill the replay buffer
from tf_agents.policies.random_tf_policy import RandomTFPolicy

initial_collect_policy = RandomTFPolicy(env.time_step_spec(), env.action_spec())
init_driver = dynamic_step_driver.DynamicStepDriver(
    env,
    policy=initial_collect_policy,
    observers=[replay_buffer.add_batch] + train_metrics,
    num_steps=1
)

final_time_step, final_policy_state = init_driver.run()

In [8]:
# Create the dataset
dataset = replay_buffer.as_dataset(
    sample_batch_size=64,
    num_steps=2, # Capaz cambiar esto
    num_parallel_calls=3
).prefetch(3)

Instructions for updating:
Use `tf.data.Dataset.counter(...)` instead.
Instructions for updating:
Use `as_dataset(..., single_deterministic_pass=False) instead.


In [9]:
# Create the training loop

from tf_agents.utils.common import function
from tf_agents.trajectories import trajectory
from tf_agents.eval.metric_utils import log_metrics
import logging

logging.getLogger().setLevel(logging.INFO)

collect_driver.run = function(collect_driver.run)
agent.train = function(agent.train)

def train_agent(n_iterations):
    time_step = None
    policy_state = agent.collect_policy.get_initial_state(env.batch_size)
    iterator = iter(dataset)
    for iteration in range(n_iterations):
        time_step, policy_state = collect_driver.run(time_step, policy_state)
        trajectories, buffer_info = next(iterator)
        train_loss = agent.train(trajectories)
        if iteration % 1000 == 0:
            print("Iteration: ", iteration)
            log_metrics(train_metrics)

In [10]:
train_agent(episodes_count)

Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldr(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))


INFO:absl: 
		 NumberOfEpisodes = 1
		 EnvironmentSteps = 2
		 AverageReturn = -30.200000762939453
		 AverageEpisodeLength = 1.0


Iteration:  0


INFO:absl: 
		 NumberOfEpisodes = 184
		 EnvironmentSteps = 1002
		 AverageReturn = -4686.7568359375
		 AverageEpisodeLength = 6.199999809265137


Iteration:  1000


INFO:absl: 
		 NumberOfEpisodes = 343
		 EnvironmentSteps = 2002
		 AverageReturn = -8553.1728515625
		 AverageEpisodeLength = 5.599999904632568


Iteration:  2000


INFO:absl: 
		 NumberOfEpisodes = 495
		 EnvironmentSteps = 3002
		 AverageReturn = -11915.912109375
		 AverageEpisodeLength = 7.699999809265137


Iteration:  3000


INFO:absl: 
		 NumberOfEpisodes = 631
		 EnvironmentSteps = 4002
		 AverageReturn = -14436.4814453125
		 AverageEpisodeLength = 10.899999618530273


Iteration:  4000


INFO:absl: 
		 NumberOfEpisodes = 753
		 EnvironmentSteps = 5002
		 AverageReturn = -16295.8154296875
		 AverageEpisodeLength = 9.399999618530273


Iteration:  5000


INFO:absl: 
		 NumberOfEpisodes = 867
		 EnvironmentSteps = 6002
		 AverageReturn = -17650.33203125
		 AverageEpisodeLength = 6.199999809265137


Iteration:  6000


INFO:absl: 
		 NumberOfEpisodes = 994
		 EnvironmentSteps = 7002
		 AverageReturn = -19241.796875
		 AverageEpisodeLength = 7.800000190734863


Iteration:  7000


INFO:absl: 
		 NumberOfEpisodes = 1104
		 EnvironmentSteps = 8002
		 AverageReturn = -20280.974609375
		 AverageEpisodeLength = 9.199999809265137


Iteration:  8000


INFO:absl: 
		 NumberOfEpisodes = 1191
		 EnvironmentSteps = 9002
		 AverageReturn = -20543.14453125
		 AverageEpisodeLength = 13.0


Iteration:  9000


INFO:absl: 
		 NumberOfEpisodes = 1269
		 EnvironmentSteps = 10002
		 AverageReturn = -20628.91796875
		 AverageEpisodeLength = 16.799999237060547


Iteration:  10000


INFO:absl: 
		 NumberOfEpisodes = 1331
		 EnvironmentSteps = 11002
		 AverageReturn = -20168.978515625
		 AverageEpisodeLength = 14.699999809265137


Iteration:  11000


INFO:absl: 
		 NumberOfEpisodes = 1408
		 EnvironmentSteps = 12002
		 AverageReturn = -19942.52734375
		 AverageEpisodeLength = 9.600000381469727


Iteration:  12000


INFO:absl: 
		 NumberOfEpisodes = 1478
		 EnvironmentSteps = 13002
		 AverageReturn = -19667.28515625
		 AverageEpisodeLength = 9.899999618530273


Iteration:  13000


INFO:absl: 
		 NumberOfEpisodes = 1532
		 EnvironmentSteps = 14002
		 AverageReturn = -18897.1875
		 AverageEpisodeLength = 22.0


Iteration:  14000


INFO:absl: 
		 NumberOfEpisodes = 1581
		 EnvironmentSteps = 15002
		 AverageReturn = -17762.05859375
		 AverageEpisodeLength = 25.700000762939453


Iteration:  15000


INFO:absl: 
		 NumberOfEpisodes = 1636
		 EnvironmentSteps = 16002
		 AverageReturn = -16855.458984375
		 AverageEpisodeLength = 22.899999618530273


Iteration:  16000


INFO:absl: 
		 NumberOfEpisodes = 1685
		 EnvironmentSteps = 17002
		 AverageReturn = -15813.263671875
		 AverageEpisodeLength = 16.200000762939453


Iteration:  17000


INFO:absl: 
		 NumberOfEpisodes = 1732
		 EnvironmentSteps = 18002
		 AverageReturn = -14633.025390625
		 AverageEpisodeLength = 23.700000762939453


Iteration:  18000


INFO:absl: 
		 NumberOfEpisodes = 1771
		 EnvironmentSteps = 19002
		 AverageReturn = -13383.142578125
		 AverageEpisodeLength = 29.200000762939453


Iteration:  19000


INFO:absl: 
		 NumberOfEpisodes = 1816
		 EnvironmentSteps = 20002
		 AverageReturn = -11857.025390625
		 AverageEpisodeLength = 23.5


Iteration:  20000


INFO:absl: 
		 NumberOfEpisodes = 1850
		 EnvironmentSteps = 21002
		 AverageReturn = -10371.1025390625
		 AverageEpisodeLength = 31.700000762939453


Iteration:  21000


INFO:absl: 
		 NumberOfEpisodes = 1884
		 EnvironmentSteps = 22002
		 AverageReturn = -8520.533203125
		 AverageEpisodeLength = 33.900001525878906


Iteration:  22000


INFO:absl: 
		 NumberOfEpisodes = 1920
		 EnvironmentSteps = 23002
		 AverageReturn = -6758.41259765625
		 AverageEpisodeLength = 22.200000762939453


Iteration:  23000


INFO:absl: 
		 NumberOfEpisodes = 1953
		 EnvironmentSteps = 24002
		 AverageReturn = -5102.58837890625
		 AverageEpisodeLength = 31.600000381469727


Iteration:  24000


INFO:absl: 
		 NumberOfEpisodes = 1986
		 EnvironmentSteps = 25002
		 AverageReturn = -3487.24609375
		 AverageEpisodeLength = 36.400001525878906


Iteration:  25000


INFO:absl: 
		 NumberOfEpisodes = 2013
		 EnvironmentSteps = 26002
		 AverageReturn = -1563.176513671875
		 AverageEpisodeLength = 35.29999923706055


Iteration:  26000


INFO:absl: 
		 NumberOfEpisodes = 2052
		 EnvironmentSteps = 27002
		 AverageReturn = 234.68557739257812
		 AverageEpisodeLength = 25.0


Iteration:  27000


INFO:absl: 
		 NumberOfEpisodes = 2083
		 EnvironmentSteps = 28002
		 AverageReturn = 2240.649658203125
		 AverageEpisodeLength = 22.700000762939453


Iteration:  28000


INFO:absl: 
		 NumberOfEpisodes = 2114
		 EnvironmentSteps = 29002
		 AverageReturn = 4267.9306640625
		 AverageEpisodeLength = 36.29999923706055


Iteration:  29000


INFO:absl: 
		 NumberOfEpisodes = 2144
		 EnvironmentSteps = 30002
		 AverageReturn = 6377.77978515625
		 AverageEpisodeLength = 37.0


Iteration:  30000


INFO:absl: 
		 NumberOfEpisodes = 2172
		 EnvironmentSteps = 31002
		 AverageReturn = 8620.970703125
		 AverageEpisodeLength = 31.799999237060547


Iteration:  31000


INFO:absl: 
		 NumberOfEpisodes = 2202
		 EnvironmentSteps = 32002
		 AverageReturn = 10706.6005859375
		 AverageEpisodeLength = 35.099998474121094


Iteration:  32000


INFO:absl: 
		 NumberOfEpisodes = 2228
		 EnvironmentSteps = 33002
		 AverageReturn = 12533.013671875
		 AverageEpisodeLength = 46.20000076293945


Iteration:  33000


INFO:absl: 
		 NumberOfEpisodes = 2251
		 EnvironmentSteps = 34002
		 AverageReturn = 14930.2421875
		 AverageEpisodeLength = 42.29999923706055


Iteration:  34000


INFO:absl: 
		 NumberOfEpisodes = 2279
		 EnvironmentSteps = 35002
		 AverageReturn = 17315.53125
		 AverageEpisodeLength = 32.79999923706055


Iteration:  35000


INFO:absl: 
		 NumberOfEpisodes = 2305
		 EnvironmentSteps = 36002
		 AverageReturn = 19751.904296875
		 AverageEpisodeLength = 34.099998474121094


Iteration:  36000


INFO:absl: 
		 NumberOfEpisodes = 2331
		 EnvironmentSteps = 37002
		 AverageReturn = 21905.31640625
		 AverageEpisodeLength = 40.0


Iteration:  37000


INFO:absl: 
		 NumberOfEpisodes = 2356
		 EnvironmentSteps = 38002
		 AverageReturn = 24062.05078125
		 AverageEpisodeLength = 47.599998474121094


Iteration:  38000


INFO:absl: 
		 NumberOfEpisodes = 2381
		 EnvironmentSteps = 39002
		 AverageReturn = 26481.509765625
		 AverageEpisodeLength = 41.20000076293945


Iteration:  39000


INFO:absl: 
		 NumberOfEpisodes = 2410
		 EnvironmentSteps = 40002
		 AverageReturn = 28950.146484375
		 AverageEpisodeLength = 34.79999923706055


Iteration:  40000


INFO:absl: 
		 NumberOfEpisodes = 2435
		 EnvironmentSteps = 41002
		 AverageReturn = 31223.49609375
		 AverageEpisodeLength = 39.400001525878906


Iteration:  41000


INFO:absl: 
		 NumberOfEpisodes = 2458
		 EnvironmentSteps = 42002
		 AverageReturn = 33832.70703125
		 AverageEpisodeLength = 46.70000076293945


Iteration:  42000


INFO:absl: 
		 NumberOfEpisodes = 2482
		 EnvironmentSteps = 43002
		 AverageReturn = 36160.5078125
		 AverageEpisodeLength = 38.79999923706055


Iteration:  43000


INFO:absl: 
		 NumberOfEpisodes = 2507
		 EnvironmentSteps = 44002
		 AverageReturn = 38734.7265625
		 AverageEpisodeLength = 40.900001525878906


Iteration:  44000


INFO:absl: 
		 NumberOfEpisodes = 2531
		 EnvironmentSteps = 45002
		 AverageReturn = 41269.79296875
		 AverageEpisodeLength = 39.599998474121094


Iteration:  45000


INFO:absl: 
		 NumberOfEpisodes = 2554
		 EnvironmentSteps = 46002
		 AverageReturn = 43605.03125
		 AverageEpisodeLength = 45.5


Iteration:  46000


INFO:absl: 
		 NumberOfEpisodes = 2573
		 EnvironmentSteps = 47002
		 AverageReturn = 46004.1015625
		 AverageEpisodeLength = 55.599998474121094


Iteration:  47000


INFO:absl: 
		 NumberOfEpisodes = 2595
		 EnvironmentSteps = 48002
		 AverageReturn = 48806.64453125
		 AverageEpisodeLength = 47.70000076293945


Iteration:  48000


INFO:absl: 
		 NumberOfEpisodes = 2620
		 EnvironmentSteps = 49002
		 AverageReturn = 51382.2421875
		 AverageEpisodeLength = 40.79999923706055


Iteration:  49000


INFO:absl: 
		 NumberOfEpisodes = 2639
		 EnvironmentSteps = 50002
		 AverageReturn = 53883.625
		 AverageEpisodeLength = 46.70000076293945


Iteration:  50000


INFO:absl: 
		 NumberOfEpisodes = 2658
		 EnvironmentSteps = 51002
		 AverageReturn = 56371.85546875
		 AverageEpisodeLength = 58.29999923706055


Iteration:  51000


INFO:absl: 
		 NumberOfEpisodes = 2680
		 EnvironmentSteps = 52002
		 AverageReturn = 58842.3671875
		 AverageEpisodeLength = 43.0


Iteration:  52000


INFO:absl: 
		 NumberOfEpisodes = 2705
		 EnvironmentSteps = 53002
		 AverageReturn = 61652.90625
		 AverageEpisodeLength = 38.400001525878906


Iteration:  53000


INFO:absl: 
		 NumberOfEpisodes = 2728
		 EnvironmentSteps = 54002
		 AverageReturn = 64060.26953125
		 AverageEpisodeLength = 47.70000076293945


Iteration:  54000


INFO:absl: 
		 NumberOfEpisodes = 2752
		 EnvironmentSteps = 55002
		 AverageReturn = 66762.234375
		 AverageEpisodeLength = 44.099998474121094


Iteration:  55000


INFO:absl: 
		 NumberOfEpisodes = 2774
		 EnvironmentSteps = 56002
		 AverageReturn = 69440.953125
		 AverageEpisodeLength = 45.5


Iteration:  56000


INFO:absl: 
		 NumberOfEpisodes = 2796
		 EnvironmentSteps = 57002
		 AverageReturn = 72337.1484375
		 AverageEpisodeLength = 42.70000076293945


Iteration:  57000


INFO:absl: 
		 NumberOfEpisodes = 2824
		 EnvironmentSteps = 58002
		 AverageReturn = 75018.890625
		 AverageEpisodeLength = 33.5


Iteration:  58000


INFO:absl: 
		 NumberOfEpisodes = 2847
		 EnvironmentSteps = 59002
		 AverageReturn = 77438.484375
		 AverageEpisodeLength = 38.79999923706055


Iteration:  59000


INFO:absl: 
		 NumberOfEpisodes = 2870
		 EnvironmentSteps = 60002
		 AverageReturn = 80233.921875
		 AverageEpisodeLength = 41.79999923706055


Iteration:  60000


INFO:absl: 
		 NumberOfEpisodes = 2892
		 EnvironmentSteps = 61002
		 AverageReturn = 83038.9609375
		 AverageEpisodeLength = 40.5


Iteration:  61000


INFO:absl: 
		 NumberOfEpisodes = 2913
		 EnvironmentSteps = 62002
		 AverageReturn = 85717.6953125
		 AverageEpisodeLength = 38.5


Iteration:  62000


INFO:absl: 
		 NumberOfEpisodes = 2932
		 EnvironmentSteps = 63002
		 AverageReturn = 88148.359375
		 AverageEpisodeLength = 52.0


Iteration:  63000


INFO:absl: 
		 NumberOfEpisodes = 2953
		 EnvironmentSteps = 64002
		 AverageReturn = 90996.140625
		 AverageEpisodeLength = 47.29999923706055


Iteration:  64000


INFO:absl: 
		 NumberOfEpisodes = 2975
		 EnvironmentSteps = 65002
		 AverageReturn = 93564.6484375
		 AverageEpisodeLength = 48.599998474121094


Iteration:  65000


INFO:absl: 
		 NumberOfEpisodes = 2998
		 EnvironmentSteps = 66002
		 AverageReturn = 96401.6640625
		 AverageEpisodeLength = 36.900001525878906


Iteration:  66000


INFO:absl: 
		 NumberOfEpisodes = 3022
		 EnvironmentSteps = 67002
		 AverageReturn = 99177.2734375
		 AverageEpisodeLength = 45.20000076293945


Iteration:  67000


INFO:absl: 
		 NumberOfEpisodes = 3042
		 EnvironmentSteps = 68002
		 AverageReturn = 101829.328125
		 AverageEpisodeLength = 49.20000076293945


Iteration:  68000


INFO:absl: 
		 NumberOfEpisodes = 3067
		 EnvironmentSteps = 69002
		 AverageReturn = 104566.328125
		 AverageEpisodeLength = 40.70000076293945


Iteration:  69000


INFO:absl: 
		 NumberOfEpisodes = 3089
		 EnvironmentSteps = 70002
		 AverageReturn = 107354.2890625
		 AverageEpisodeLength = 41.0


Iteration:  70000


INFO:absl: 
		 NumberOfEpisodes = 3112
		 EnvironmentSteps = 71002
		 AverageReturn = 110053.3125
		 AverageEpisodeLength = 43.0


Iteration:  71000


INFO:absl: 
		 NumberOfEpisodes = 3133
		 EnvironmentSteps = 72002
		 AverageReturn = 112680.1640625
		 AverageEpisodeLength = 51.5


Iteration:  72000


INFO:absl: 
		 NumberOfEpisodes = 3154
		 EnvironmentSteps = 73002
		 AverageReturn = 115463.703125
		 AverageEpisodeLength = 56.400001525878906


Iteration:  73000


INFO:absl: 
		 NumberOfEpisodes = 3177
		 EnvironmentSteps = 74002
		 AverageReturn = 118411.671875
		 AverageEpisodeLength = 49.599998474121094


Iteration:  74000


INFO:absl: 
		 NumberOfEpisodes = 3202
		 EnvironmentSteps = 75002
		 AverageReturn = 121392.5
		 AverageEpisodeLength = 37.599998474121094


Iteration:  75000


INFO:absl: 
		 NumberOfEpisodes = 3225
		 EnvironmentSteps = 76002
		 AverageReturn = 123994.7890625
		 AverageEpisodeLength = 41.900001525878906


Iteration:  76000


INFO:absl: 
		 NumberOfEpisodes = 3245
		 EnvironmentSteps = 77002
		 AverageReturn = 126611.046875
		 AverageEpisodeLength = 51.20000076293945


Iteration:  77000


INFO:absl: 
		 NumberOfEpisodes = 3265
		 EnvironmentSteps = 78002
		 AverageReturn = 129263.546875
		 AverageEpisodeLength = 43.0


Iteration:  78000


INFO:absl: 
		 NumberOfEpisodes = 3287
		 EnvironmentSteps = 79002
		 AverageReturn = 132093.53125
		 AverageEpisodeLength = 51.29999923706055


Iteration:  79000


INFO:absl: 
		 NumberOfEpisodes = 3310
		 EnvironmentSteps = 80002
		 AverageReturn = 134894.9375
		 AverageEpisodeLength = 48.79999923706055


Iteration:  80000


INFO:absl: 
		 NumberOfEpisodes = 3330
		 EnvironmentSteps = 81002
		 AverageReturn = 137871.296875
		 AverageEpisodeLength = 50.5


Iteration:  81000


INFO:absl: 
		 NumberOfEpisodes = 3350
		 EnvironmentSteps = 82002
		 AverageReturn = 140460.140625
		 AverageEpisodeLength = 51.900001525878906


Iteration:  82000


INFO:absl: 
		 NumberOfEpisodes = 3370
		 EnvironmentSteps = 83002
		 AverageReturn = 143446.90625
		 AverageEpisodeLength = 44.0


Iteration:  83000


INFO:absl: 
		 NumberOfEpisodes = 3391
		 EnvironmentSteps = 84002
		 AverageReturn = 146277.265625
		 AverageEpisodeLength = 55.0


Iteration:  84000


INFO:absl: 
		 NumberOfEpisodes = 3412
		 EnvironmentSteps = 85002
		 AverageReturn = 149194.484375
		 AverageEpisodeLength = 49.70000076293945


Iteration:  85000


INFO:absl: 
		 NumberOfEpisodes = 3431
		 EnvironmentSteps = 86002
		 AverageReturn = 151952.15625
		 AverageEpisodeLength = 52.70000076293945


Iteration:  86000


INFO:absl: 
		 NumberOfEpisodes = 3452
		 EnvironmentSteps = 87002
		 AverageReturn = 154928.59375
		 AverageEpisodeLength = 52.900001525878906


Iteration:  87000


INFO:absl: 
		 NumberOfEpisodes = 3472
		 EnvironmentSteps = 88002
		 AverageReturn = 157401.609375
		 AverageEpisodeLength = 51.900001525878906


Iteration:  88000


INFO:absl: 
		 NumberOfEpisodes = 3493
		 EnvironmentSteps = 89002
		 AverageReturn = 160158.046875
		 AverageEpisodeLength = 53.79999923706055


Iteration:  89000


INFO:absl: 
		 NumberOfEpisodes = 3513
		 EnvironmentSteps = 90002
		 AverageReturn = 162853.0625
		 AverageEpisodeLength = 47.599998474121094


Iteration:  90000


INFO:absl: 
		 NumberOfEpisodes = 3535
		 EnvironmentSteps = 91002
		 AverageReturn = 165777.546875
		 AverageEpisodeLength = 49.400001525878906


Iteration:  91000


INFO:absl: 
		 NumberOfEpisodes = 3556
		 EnvironmentSteps = 92002
		 AverageReturn = 168608.734375
		 AverageEpisodeLength = 41.599998474121094


Iteration:  92000


INFO:absl: 
		 NumberOfEpisodes = 3576
		 EnvironmentSteps = 93002
		 AverageReturn = 171374.875
		 AverageEpisodeLength = 45.099998474121094


Iteration:  93000


INFO:absl: 
		 NumberOfEpisodes = 3597
		 EnvironmentSteps = 94002
		 AverageReturn = 174225.296875
		 AverageEpisodeLength = 41.900001525878906


Iteration:  94000


INFO:absl: 
		 NumberOfEpisodes = 3617
		 EnvironmentSteps = 95002
		 AverageReturn = 176812.34375
		 AverageEpisodeLength = 46.70000076293945


Iteration:  95000


INFO:absl: 
		 NumberOfEpisodes = 3638
		 EnvironmentSteps = 96002
		 AverageReturn = 179692.171875
		 AverageEpisodeLength = 50.599998474121094


Iteration:  96000


INFO:absl: 
		 NumberOfEpisodes = 3660
		 EnvironmentSteps = 97002
		 AverageReturn = 182547.859375
		 AverageEpisodeLength = 42.599998474121094


Iteration:  97000


INFO:absl: 
		 NumberOfEpisodes = 3683
		 EnvironmentSteps = 98002
		 AverageReturn = 185322.71875
		 AverageEpisodeLength = 44.29999923706055


Iteration:  98000


INFO:absl: 
		 NumberOfEpisodes = 3708
		 EnvironmentSteps = 99002
		 AverageReturn = 188117.53125
		 AverageEpisodeLength = 45.20000076293945


Iteration:  99000


In [11]:
# Evaluate the agent

time_step = env.reset()
rewards = []
steps = 0

while not time_step.is_last() and steps < 200:
    steps += 1
    print("Map:\n", np.argmax(time_step.observation.numpy(), axis=3))
    action_step = agent.policy.action(time_step)
    print("Action: ", action_step.action.numpy())
    time_step = env.step(action_step.action)
    rewards.append(time_step.reward.numpy()[0])
    
print("Total reward: ", sum(rewards))
print("Total steps: ", steps)

Map:
 [[[0 0 0 0 0]
  [0 0 0 0 0]
  [0 0 0 0 0]
  [0 3 0 1 2]
  [0 0 0 0 0]]]
Action:  [0]
Map:
 [[[0 0 0 0 0]
  [0 0 0 0 0]
  [0 0 0 0 0]
  [0 3 1 2 0]
  [0 0 0 0 0]]]
Action:  [0]
Map:
 [[[3 0 0 0 0]
  [0 0 0 0 0]
  [0 0 0 0 0]
  [0 1 2 2 0]
  [0 0 0 0 0]]]
Action:  [0]
Map:
 [[[3 0 0 0 0]
  [0 0 0 0 0]
  [0 0 0 0 0]
  [1 2 2 0 0]
  [0 0 0 0 0]]]
Action:  [1]
Map:
 [[[3 0 0 0 0]
  [0 0 0 0 0]
  [1 0 0 0 0]
  [2 2 0 0 0]
  [0 0 0 0 0]]]
Action:  [1]
Map:
 [[[3 0 0 0 0]
  [1 0 0 0 0]
  [2 0 0 0 0]
  [2 0 0 0 0]
  [0 0 0 0 0]]]
Action:  [1]
Map:
 [[[1 0 0 3 0]
  [2 0 0 0 0]
  [2 0 0 0 0]
  [2 0 0 0 0]
  [0 0 0 0 0]]]
Action:  [2]
Map:
 [[[2 1 0 3 0]
  [2 0 0 0 0]
  [2 0 0 0 0]
  [0 0 0 0 0]
  [0 0 0 0 0]]]
Action:  [2]
Map:
 [[[2 2 1 3 0]
  [2 0 0 0 0]
  [0 0 0 0 0]
  [0 0 0 0 0]
  [0 0 0 0 0]]]
Action:  [2]
Map:
 [[[2 2 2 1 0]
  [2 0 0 0 0]
  [0 0 0 3 0]
  [0 0 0 0 0]
  [0 0 0 0 0]]]
Action:  [3]
Map:
 [[[2 2 2 2 0]
  [0 0 0 1 0]
  [0 0 0 3 0]
  [0 0 0 0 0]
  [0 0 0 0 0]]]
Action:  [3]

In [12]:
# Count the average steps per episode

n = 500
total_steps = 0
steps = 0
total_500 = 0

for _ in range(n):
    time_step = env.reset()
    steps = 0
    while not time_step.is_last() and steps < 500:
        steps += 1
        action_step = agent.policy.action(time_step)
        time_step = env.step(action_step.action)
    
    if steps == 500:
        total_500 += 1
    else:
        total_steps += steps
        
if steps == 500:
    print("500 episodes reached")
    
print("Average steps per episode: ", total_steps / (n - total_500))
print("Total 500 episodes: ", total_500)

Average steps per episode:  51.354
Total 500 episodes:  0


In [13]:
print("Map:\n", np.argmax(time_step.observation.numpy(), axis=3))

Map:
 [[[0 0 0 0 0]
  [2 0 3 0 0]
  [1 0 0 0 0]
  [0 0 0 0 0]
  [0 0 0 0 0]]]


In [14]:
action_step = agent.policy.action(time_step)
time_step = env.step(action_step.action)
print("Map:\n", np.argmax(time_step.observation.numpy(), axis=3))

Map:
 [[[0 0 0 0 0]
  [1 0 3 0 0]
  [2 0 0 0 0]
  [0 0 0 0 0]
  [0 0 0 0 0]]]


In [15]:
# Save the agent
import os
from tf_agents.policies.policy_saver import PolicySaver

saver = PolicySaver(agent.policy, batch_size=None)
saver.save("policy")

INFO:absl:Function `function_with_signature` contains input name(s) 0/step_type, 0/reward, 0/discount, 0/observation, 3755814, 3755816, 3755818, 3755820, 3755822, 3755824, 3755826, 3755828 with unsupported characters which will be renamed to step_type, reward, discount, observation, unknown, unknown_0, unknown_1, unknown_2, unknown_3, unknown_4, unknown_5, unknown_6 in the SavedModel.
INFO:absl:Function `function_with_signature` contains input name(s) 3755869 with unsupported characters which will be renamed to unknown in the SavedModel.


INFO:tensorflow:Assets written to: policy/assets


INFO:tensorflow:Assets written to: policy/assets
