In [1]:
import numpy as np
import tensorflow as tf
from tf_agents.networks.q_network import QNetwork
import tf_agents.networks.network as network
from tf_agents.specs import tensor_spec
from snake_game import SnakeGame
from scene import Scene
from scene_longer_snake import SceneLongerSnake
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Flatten, Conv2D, Input

scene = SceneLongerSnake(init_randomly=True, snake_longer_prob=0.95, length_mean=10, length_std=2)
episodes_count = 100000
random_episodes = 20000


In [2]:
from tf_agents.environments.tf_py_environment import TFPyEnvironment

env = SnakeGame(scene)
env = TFPyEnvironment(env)


In [3]:
class MyQNetwork(network.Network):
    def __init__(self,
                 input_tensor_spec,
                 action_spec,
                 name='MyQNetwork'):
        super(MyQNetwork, self).__init__(
            input_tensor_spec=input_tensor_spec,
            state_spec=(),
            name=name)
        
        self._action_spec = action_spec
        matrix_shape = input_tensor_spec[0].shape
        obstacles_shape = input_tensor_spec[1].shape
        no_body_blocks_shape = input_tensor_spec[2].shape
        
        input0 = Input(shape=matrix_shape)
        input1 = Input(shape=obstacles_shape) # Each value represents if there is a wall or the snake body in the direction, 0 - no wall, 1 - wall
        input2 = Input(shape=no_body_blocks_shape) # Each value represents how many no body blocks are in the direction
        
        conv = Conv2D(32, (2, 2), 1, activation='relu', kernel_initializer='he_normal')(input0)
        flat = Flatten()(conv)
        dense1 = Dense(128, activation='relu', kernel_initializer='he_normal')(flat)
        dense2 = Dense(32, activation='relu', kernel_initializer='he_normal')(dense1)
        
        concat = tf.keras.layers.concatenate([dense2, input1, input2])
        output = Dense(4, activation='linear')(concat)
        
        self._model = tf.keras.Model(inputs=[input0, input1, input2], outputs=output)
        
        self._model.summary()

    def call(self, observations, step_type=None, network_state=(), training=False):
        output = self._model([observations[0], observations[1], observations[2]])
        return output, network_state

q_net = MyQNetwork(
    env.observation_spec(),
    env.action_spec())


Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_1 (InputLayer)        [(None, 5, 5, 4)]            0         []                            
                                                                                                  
 conv2d (Conv2D)             (None, 4, 4, 32)             544       ['input_1[0][0]']             
                                                                                                  
 flatten (Flatten)           (None, 512)                  0         ['conv2d[0][0]']              
                                                                                                  
 dense (Dense)               (None, 128)                  65664     ['flatten[0][0]']             
                                                                                              

In [4]:
# Create the agent
from tf_agents.agents.dqn.dqn_agent import DdqnAgent
from tensorflow import keras
from keras.optimizers.legacy import Adam
from keras.losses import Huber
from tf_agents.trajectories import TimeStep

train_step_counter = tf.Variable(0)
optimizer = Adam(learning_rate=0.002)

epsilon = keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=0.1,
    decay_steps=episodes_count,
    end_learning_rate=0.01
)

agent = DdqnAgent(
    time_step_spec=env.time_step_spec(),
    action_spec=env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=Huber(reduction="none"),
    gamma=tf.constant(0.95, dtype=tf.float32),
    train_step_counter=train_step_counter,
    epsilon_greedy=lambda: epsilon(train_step_counter)
)


Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_4 (InputLayer)        [(None, 5, 5, 4)]            0         []                            
                                                                                                  
 conv2d_1 (Conv2D)           (None, 4, 4, 32)             544       ['input_4[0][0]']             
                                                                                                  
 flatten_1 (Flatten)         (None, 512)                  0         ['conv2d_1[0][0]']            
                                                                                                  
 dense_3 (Dense)             (None, 128)                  65664     ['flatten_1[0][0]']           
                                                                                            

In [5]:
# Create the replay buffer
from tf_agents.replay_buffers import tf_uniform_replay_buffer

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.collect_data_spec,
    batch_size=env.batch_size,
    max_length=100000
)


In [6]:
# Create the training metrics
from tf_agents.metrics import tf_metrics

train_metrics = [
    tf_metrics.NumberOfEpisodes(),
    tf_metrics.EnvironmentSteps(),
    tf_metrics.AverageReturnMetric(),
    tf_metrics.AverageEpisodeLengthMetric()
]


In [7]:
# Create the driver
from tf_agents.drivers import dynamic_step_driver

collect_driver = dynamic_step_driver.DynamicStepDriver(
    env,
    policy=agent.collect_policy,
    observers=[replay_buffer.add_batch] + train_metrics,
    num_steps=2
)

# Run a random policy to fill the replay buffer
from tf_agents.policies.random_tf_policy import RandomTFPolicy

initial_collect_policy = RandomTFPolicy(env.time_step_spec(), env.action_spec())
init_driver = dynamic_step_driver.DynamicStepDriver(
    env,
    policy=initial_collect_policy,
    observers=[replay_buffer.add_batch],
    num_steps=random_episodes
)

final_time_step, final_policy_state = init_driver.run()


In [8]:
# Create the dataset
dataset = replay_buffer.as_dataset(
    sample_batch_size=64,
    num_steps=2, # Capaz cambiar esto
    num_parallel_calls=5
).prefetch(5)


Instructions for updating:
Use `as_dataset(..., single_deterministic_pass=False) instead.


In [9]:

# Create the training loop

from tf_agents.utils.common import function
from tf_agents.trajectories import trajectory
from tf_agents.eval.metric_utils import log_metrics
import logging

logging.getLogger().setLevel(logging.INFO)

collect_driver.run = function(collect_driver.run)
agent.train = function(agent.train)

def train_agent(n_iterations):
    time_step = None
    policy_state = agent.collect_policy.get_initial_state(env.batch_size)
    iterator = iter(dataset)
    for iteration in range(n_iterations):
        time_step, policy_state = collect_driver.run(time_step, policy_state)
        trajectories, buffer_info = next(iterator)
        train_loss = agent.train(trajectories)
        if iteration % 1000 == 0:
            print("Iteration: ", iteration)
            print("Replay buffer len: " + str(replay_buffer.num_frames().numpy()))
            log_metrics(train_metrics)


In [10]:
train_agent(episodes_count)


Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldr(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))


INFO:absl: 
		 NumberOfEpisodes = 11
		 EnvironmentSteps = 2
		 AverageReturn = -260.0
		 AverageEpisodeLength = 0.10000000149011612


Iteration:  0
Replay buffer len: 43665


INFO:absl: 
		 NumberOfEpisodes = 392
		 EnvironmentSteps = 2002
		 AverageReturn = -15003.5
		 AverageEpisodeLength = 18.600000381469727


Iteration:  1000
Replay buffer len: 46047


INFO:absl: 
		 NumberOfEpisodes = 534
		 EnvironmentSteps = 4002
		 AverageReturn = -20213.0
		 AverageEpisodeLength = 13.699999809265137


Iteration:  2000
Replay buffer len: 48189


INFO:absl: 
		 NumberOfEpisodes = 660
		 EnvironmentSteps = 6002
		 AverageReturn = -24396.0
		 AverageEpisodeLength = 11.800000190734863


Iteration:  3000
Replay buffer len: 50315


INFO:absl: 
		 NumberOfEpisodes = 803
		 EnvironmentSteps = 8002
		 AverageReturn = -29023.5
		 AverageEpisodeLength = 12.300000190734863


Iteration:  4000
Replay buffer len: 52458


INFO:absl: 
		 NumberOfEpisodes = 958
		 EnvironmentSteps = 10002
		 AverageReturn = -34009.0
		 AverageEpisodeLength = 5.199999809265137


Iteration:  5000
Replay buffer len: 54613


INFO:absl: 
		 NumberOfEpisodes = 1099
		 EnvironmentSteps = 12002
		 AverageReturn = -38204.5
		 AverageEpisodeLength = 21.299999237060547


Iteration:  6000
Replay buffer len: 56754


INFO:absl: 
		 NumberOfEpisodes = 1234
		 EnvironmentSteps = 14002
		 AverageReturn = -42086.5
		 AverageEpisodeLength = 10.600000381469727


Iteration:  7000
Replay buffer len: 58889


INFO:absl: 
		 NumberOfEpisodes = 1362
		 EnvironmentSteps = 16002
		 AverageReturn = -45634.5
		 AverageEpisodeLength = 13.399999618530273


Iteration:  8000
Replay buffer len: 61017


INFO:absl: 
		 NumberOfEpisodes = 1532
		 EnvironmentSteps = 18002
		 AverageReturn = -50843.5
		 AverageEpisodeLength = 8.399999618530273


Iteration:  9000
Replay buffer len: 63187


INFO:absl: 
		 NumberOfEpisodes = 1654
		 EnvironmentSteps = 20002
		 AverageReturn = -54166.0
		 AverageEpisodeLength = 11.100000381469727


Iteration:  10000
Replay buffer len: 65309


INFO:absl: 
		 NumberOfEpisodes = 1799
		 EnvironmentSteps = 22002
		 AverageReturn = -58298.0
		 AverageEpisodeLength = 14.300000190734863


Iteration:  11000
Replay buffer len: 67454


INFO:absl: 
		 NumberOfEpisodes = 1933
		 EnvironmentSteps = 24002
		 AverageReturn = -61877.0
		 AverageEpisodeLength = 7.5


Iteration:  12000
Replay buffer len: 69588


INFO:absl: 
		 NumberOfEpisodes = 2081
		 EnvironmentSteps = 26002
		 AverageReturn = -66119.5
		 AverageEpisodeLength = 10.0


Iteration:  13000
Replay buffer len: 71736


INFO:absl: 
		 NumberOfEpisodes = 2215
		 EnvironmentSteps = 28002
		 AverageReturn = -69745.0
		 AverageEpisodeLength = 15.600000381469727


Iteration:  14000
Replay buffer len: 73870


INFO:absl: 
		 NumberOfEpisodes = 2375
		 EnvironmentSteps = 30002
		 AverageReturn = -74096.5
		 AverageEpisodeLength = 10.300000190734863


Iteration:  15000
Replay buffer len: 76030


INFO:absl: 
		 NumberOfEpisodes = 2506
		 EnvironmentSteps = 32002
		 AverageReturn = -77470.0
		 AverageEpisodeLength = 20.200000762939453


Iteration:  16000
Replay buffer len: 78161


INFO:absl: 
		 NumberOfEpisodes = 2650
		 EnvironmentSteps = 34002
		 AverageReturn = -81242.0
		 AverageEpisodeLength = 10.399999618530273


Iteration:  17000
Replay buffer len: 80305


INFO:absl: 
		 NumberOfEpisodes = 2787
		 EnvironmentSteps = 36002
		 AverageReturn = -85010.5
		 AverageEpisodeLength = 14.5


Iteration:  18000
Replay buffer len: 82442


INFO:absl: 
		 NumberOfEpisodes = 2917
		 EnvironmentSteps = 38002
		 AverageReturn = -88339.0
		 AverageEpisodeLength = 14.699999809265137


Iteration:  19000
Replay buffer len: 84572


INFO:absl: 
		 NumberOfEpisodes = 3030
		 EnvironmentSteps = 40002
		 AverageReturn = -90928.0
		 AverageEpisodeLength = 14.100000381469727


Iteration:  20000
Replay buffer len: 86685


INFO:absl: 
		 NumberOfEpisodes = 3161
		 EnvironmentSteps = 42002
		 AverageReturn = -94322.5
		 AverageEpisodeLength = 10.199999809265137


Iteration:  21000
Replay buffer len: 88816


INFO:absl: 
		 NumberOfEpisodes = 3272
		 EnvironmentSteps = 44002
		 AverageReturn = -96942.5
		 AverageEpisodeLength = 19.399999618530273


Iteration:  22000
Replay buffer len: 90927


INFO:absl: 
		 NumberOfEpisodes = 3396
		 EnvironmentSteps = 46002
		 AverageReturn = -100057.5
		 AverageEpisodeLength = 17.399999618530273


Iteration:  23000
Replay buffer len: 93051


INFO:absl: 
		 NumberOfEpisodes = 3523
		 EnvironmentSteps = 48002
		 AverageReturn = -103266.5
		 AverageEpisodeLength = 14.199999809265137


Iteration:  24000
Replay buffer len: 95177


INFO:absl: 
		 NumberOfEpisodes = 3648
		 EnvironmentSteps = 50002
		 AverageReturn = -106380.5
		 AverageEpisodeLength = 21.100000381469727


Iteration:  25000
Replay buffer len: 97303


INFO:absl: 
		 NumberOfEpisodes = 3757
		 EnvironmentSteps = 52002
		 AverageReturn = -108830.5
		 AverageEpisodeLength = 12.800000190734863


Iteration:  26000
Replay buffer len: 99412


INFO:absl: 
		 NumberOfEpisodes = 3873
		 EnvironmentSteps = 54002
		 AverageReturn = -111635.0
		 AverageEpisodeLength = 12.800000190734863


Iteration:  27000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 3997
		 EnvironmentSteps = 56002
		 AverageReturn = -114789.5
		 AverageEpisodeLength = 15.100000381469727


Iteration:  28000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 4111
		 EnvironmentSteps = 58002
		 AverageReturn = -117551.5
		 AverageEpisodeLength = 25.299999237060547


Iteration:  29000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 4209
		 EnvironmentSteps = 60002
		 AverageReturn = -119689.5
		 AverageEpisodeLength = 19.700000762939453


Iteration:  30000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 4351
		 EnvironmentSteps = 62002
		 AverageReturn = -123292.0
		 AverageEpisodeLength = 14.100000381469727


Iteration:  31000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 4471
		 EnvironmentSteps = 64002
		 AverageReturn = -126129.0
		 AverageEpisodeLength = 12.300000190734863


Iteration:  32000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 4577
		 EnvironmentSteps = 66002
		 AverageReturn = -128600.0
		 AverageEpisodeLength = 10.399999618530273


Iteration:  33000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 4701
		 EnvironmentSteps = 68002
		 AverageReturn = -131698.5
		 AverageEpisodeLength = 20.600000381469727


Iteration:  34000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 4821
		 EnvironmentSteps = 70002
		 AverageReturn = -134505.0
		 AverageEpisodeLength = 17.100000381469727


Iteration:  35000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 4926
		 EnvironmentSteps = 72002
		 AverageReturn = -136719.5
		 AverageEpisodeLength = 12.600000381469727


Iteration:  36000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 5021
		 EnvironmentSteps = 74002
		 AverageReturn = -138724.5
		 AverageEpisodeLength = 11.100000381469727


Iteration:  37000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 5124
		 EnvironmentSteps = 76002
		 AverageReturn = -140961.0
		 AverageEpisodeLength = 27.399999618530273


Iteration:  38000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 5228
		 EnvironmentSteps = 78002
		 AverageReturn = -143193.5
		 AverageEpisodeLength = 26.5


Iteration:  39000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 5324
		 EnvironmentSteps = 80002
		 AverageReturn = -145023.5
		 AverageEpisodeLength = 22.399999618530273


Iteration:  40000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 5435
		 EnvironmentSteps = 82002
		 AverageReturn = -147573.0
		 AverageEpisodeLength = 20.0


Iteration:  41000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 5527
		 EnvironmentSteps = 84002
		 AverageReturn = -149309.5
		 AverageEpisodeLength = 13.699999809265137


Iteration:  42000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 5627
		 EnvironmentSteps = 86002
		 AverageReturn = -151470.5
		 AverageEpisodeLength = 25.399999618530273


Iteration:  43000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 5735
		 EnvironmentSteps = 88002
		 AverageReturn = -153906.5
		 AverageEpisodeLength = 21.700000762939453


Iteration:  44000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 5820
		 EnvironmentSteps = 90002
		 AverageReturn = -155324.0
		 AverageEpisodeLength = 28.100000381469727


Iteration:  45000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 5943
		 EnvironmentSteps = 92002
		 AverageReturn = -158318.5
		 AverageEpisodeLength = 16.299999237060547


Iteration:  46000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 6046
		 EnvironmentSteps = 94002
		 AverageReturn = -160570.0
		 AverageEpisodeLength = 19.5


Iteration:  47000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 6141
		 EnvironmentSteps = 96002
		 AverageReturn = -162446.5
		 AverageEpisodeLength = 19.5


Iteration:  48000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 6231
		 EnvironmentSteps = 98002
		 AverageReturn = -164260.5
		 AverageEpisodeLength = 25.700000762939453


Iteration:  49000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 6319
		 EnvironmentSteps = 100002
		 AverageReturn = -165775.0
		 AverageEpisodeLength = 23.100000381469727


Iteration:  50000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 6395
		 EnvironmentSteps = 102002
		 AverageReturn = -166808.0
		 AverageEpisodeLength = 22.600000381469727


Iteration:  51000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 6480
		 EnvironmentSteps = 104002
		 AverageReturn = -168301.5
		 AverageEpisodeLength = 12.699999809265137


Iteration:  52000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 6579
		 EnvironmentSteps = 106002
		 AverageReturn = -170365.5
		 AverageEpisodeLength = 27.100000381469727


Iteration:  53000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 6667
		 EnvironmentSteps = 108002
		 AverageReturn = -171885.5
		 AverageEpisodeLength = 15.699999809265137


Iteration:  54000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 6761
		 EnvironmentSteps = 110002
		 AverageReturn = -173625.5
		 AverageEpisodeLength = 24.700000762939453


Iteration:  55000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 6851
		 EnvironmentSteps = 112002
		 AverageReturn = -175308.5
		 AverageEpisodeLength = 11.199999809265137


Iteration:  56000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 6937
		 EnvironmentSteps = 114002
		 AverageReturn = -176868.0
		 AverageEpisodeLength = 22.299999237060547


Iteration:  57000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 7027
		 EnvironmentSteps = 116002
		 AverageReturn = -178337.5
		 AverageEpisodeLength = 16.5


Iteration:  58000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 7105
		 EnvironmentSteps = 118002
		 AverageReturn = -179691.0
		 AverageEpisodeLength = 21.5


Iteration:  59000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 7181
		 EnvironmentSteps = 120002
		 AverageReturn = -180782.5
		 AverageEpisodeLength = 27.100000381469727


Iteration:  60000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 7262
		 EnvironmentSteps = 122002
		 AverageReturn = -182009.0
		 AverageEpisodeLength = 14.800000190734863


Iteration:  61000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 7347
		 EnvironmentSteps = 124002
		 AverageReturn = -183339.5
		 AverageEpisodeLength = 25.200000762939453


Iteration:  62000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 7434
		 EnvironmentSteps = 126002
		 AverageReturn = -184689.5
		 AverageEpisodeLength = 15.899999618530273


Iteration:  63000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 7509
		 EnvironmentSteps = 128002
		 AverageReturn = -185849.5
		 AverageEpisodeLength = 34.400001525878906


Iteration:  64000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 7587
		 EnvironmentSteps = 130002
		 AverageReturn = -186969.0
		 AverageEpisodeLength = 23.899999618530273


Iteration:  65000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 7660
		 EnvironmentSteps = 132002
		 AverageReturn = -187716.0
		 AverageEpisodeLength = 16.100000381469727


Iteration:  66000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 7739
		 EnvironmentSteps = 134002
		 AverageReturn = -188908.0
		 AverageEpisodeLength = 15.5


Iteration:  67000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 7816
		 EnvironmentSteps = 136002
		 AverageReturn = -190001.5
		 AverageEpisodeLength = 20.0


Iteration:  68000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 7891
		 EnvironmentSteps = 138002
		 AverageReturn = -190941.0
		 AverageEpisodeLength = 30.0


Iteration:  69000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 7982
		 EnvironmentSteps = 140002
		 AverageReturn = -192425.0
		 AverageEpisodeLength = 17.5


Iteration:  70000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 8053
		 EnvironmentSteps = 142002
		 AverageReturn = -193319.5
		 AverageEpisodeLength = 30.299999237060547


Iteration:  71000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 8124
		 EnvironmentSteps = 144002
		 AverageReturn = -194067.0
		 AverageEpisodeLength = 23.0


Iteration:  72000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 8200
		 EnvironmentSteps = 146002
		 AverageReturn = -195064.0
		 AverageEpisodeLength = 45.599998474121094


Iteration:  73000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 8277
		 EnvironmentSteps = 148002
		 AverageReturn = -195939.0
		 AverageEpisodeLength = 35.79999923706055


Iteration:  74000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 8347
		 EnvironmentSteps = 150002
		 AverageReturn = -196549.0
		 AverageEpisodeLength = 30.700000762939453


Iteration:  75000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 8408
		 EnvironmentSteps = 152002
		 AverageReturn = -196833.0
		 AverageEpisodeLength = 23.100000381469727


Iteration:  76000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 8489
		 EnvironmentSteps = 154002
		 AverageReturn = -198073.0
		 AverageEpisodeLength = 26.299999237060547


Iteration:  77000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 8565
		 EnvironmentSteps = 156002
		 AverageReturn = -199257.0
		 AverageEpisodeLength = 31.399999618530273


Iteration:  78000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 8633
		 EnvironmentSteps = 158002
		 AverageReturn = -200065.0
		 AverageEpisodeLength = 18.799999237060547


Iteration:  79000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 8697
		 EnvironmentSteps = 160002
		 AverageReturn = -200473.5
		 AverageEpisodeLength = 33.099998474121094


Iteration:  80000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 8772
		 EnvironmentSteps = 162002
		 AverageReturn = -201445.0
		 AverageEpisodeLength = 39.900001525878906


Iteration:  81000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 8830
		 EnvironmentSteps = 164002
		 AverageReturn = -201611.0
		 AverageEpisodeLength = 36.400001525878906


Iteration:  82000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 8912
		 EnvironmentSteps = 166002
		 AverageReturn = -202686.0
		 AverageEpisodeLength = 26.600000381469727


Iteration:  83000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 8979
		 EnvironmentSteps = 168002
		 AverageReturn = -203526.0
		 AverageEpisodeLength = 32.29999923706055


Iteration:  84000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 9045
		 EnvironmentSteps = 170002
		 AverageReturn = -204182.5
		 AverageEpisodeLength = 32.5


Iteration:  85000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 9106
		 EnvironmentSteps = 172002
		 AverageReturn = -204293.0
		 AverageEpisodeLength = 27.0


Iteration:  86000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 9171
		 EnvironmentSteps = 174002
		 AverageReturn = -204870.0
		 AverageEpisodeLength = 27.799999237060547


Iteration:  87000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 9234
		 EnvironmentSteps = 176002
		 AverageReturn = -205332.5
		 AverageEpisodeLength = 33.400001525878906


Iteration:  88000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 9304
		 EnvironmentSteps = 178002
		 AverageReturn = -206125.5
		 AverageEpisodeLength = 43.5


Iteration:  89000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 9362
		 EnvironmentSteps = 180002
		 AverageReturn = -206227.0
		 AverageEpisodeLength = 23.0


Iteration:  90000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 9428
		 EnvironmentSteps = 182002
		 AverageReturn = -206526.5
		 AverageEpisodeLength = 23.899999618530273


Iteration:  91000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 9490
		 EnvironmentSteps = 184002
		 AverageReturn = -206850.5
		 AverageEpisodeLength = 37.79999923706055


Iteration:  92000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 9549
		 EnvironmentSteps = 186002
		 AverageReturn = -207084.0
		 AverageEpisodeLength = 39.900001525878906


Iteration:  93000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 9604
		 EnvironmentSteps = 188002
		 AverageReturn = -207096.5
		 AverageEpisodeLength = 46.0


Iteration:  94000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 9657
		 EnvironmentSteps = 190002
		 AverageReturn = -207052.5
		 AverageEpisodeLength = 42.0


Iteration:  95000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 9710
		 EnvironmentSteps = 192002
		 AverageReturn = -207019.0
		 AverageEpisodeLength = 33.20000076293945


Iteration:  96000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 9765
		 EnvironmentSteps = 194002
		 AverageReturn = -207141.0
		 AverageEpisodeLength = 40.70000076293945


Iteration:  97000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 9826
		 EnvironmentSteps = 196002
		 AverageReturn = -207325.5
		 AverageEpisodeLength = 45.20000076293945


Iteration:  98000
Replay buffer len: 100000


INFO:absl: 
		 NumberOfEpisodes = 9877
		 EnvironmentSteps = 198002
		 AverageReturn = -207287.5
		 AverageEpisodeLength = 50.79999923706055


Iteration:  99000
Replay buffer len: 100000


In [59]:
# Evaluate the agent

test_scene = Scene(init_randomly=True)
test_env = SnakeGame(test_scene)
test_env = TFPyEnvironment(test_env)

time_step = test_env.reset()
rewards = []
steps = 0

while not time_step.is_last() and steps < 200:
    steps += 1
    print("Map:\n", np.argmax(time_step.observation[0], axis=3))
    print("Direction:\n", time_step.observation[1])
    print("No body blocks:\n", time_step.observation[2])
    action_step = agent.policy.action(time_step)
    print("Action: ", action_step.action.numpy())
    time_step = test_env.step(action_step.action)
    rewards.append(time_step.reward.numpy()[0])
    
print("Total reward: ", sum(rewards))
print("Total steps: ", steps)


Map:
 [[[0 0 0 1 2]
  [0 0 0 0 0]
  [0 0 0 0 3]
  [0 0 0 0 0]
  [0 0 0 0 0]]]
Direction:
 tf.Tensor([[False  True  True False]], shape=(1, 4), dtype=bool)
No body blocks:
 tf.Tensor([[0.375 0.    0.125 0.5  ]], shape=(1, 4), dtype=float32)
Action:  [3]
Map:
 [[[0 0 0 2 0]
  [0 0 0 1 0]
  [0 0 0 0 3]
  [0 0 0 0 0]
  [0 0 0 0 0]]]
Direction:
 tf.Tensor([[False  True False False]], shape=(1, 4), dtype=bool)
No body blocks:
 tf.Tensor([[0.375 0.125 0.125 0.375]], shape=(1, 4), dtype=float32)
Action:  [3]
Map:
 [[[0 0 0 0 0]
  [0 0 0 2 0]
  [0 0 0 1 3]
  [0 0 0 0 0]
  [0 0 0 0 0]]]
Direction:
 tf.Tensor([[False  True False False]], shape=(1, 4), dtype=bool)
No body blocks:
 tf.Tensor([[0.375 0.25  0.125 0.25 ]], shape=(1, 4), dtype=float32)
Action:  [2]
Map:
 [[[0 0 0 0 0]
  [0 0 0 2 0]
  [3 0 0 2 1]
  [0 0 0 0 0]
  [0 0 0 0 0]]]
Direction:
 tf.Tensor([[ True False  True False]], shape=(1, 4), dtype=bool)
No body blocks:
 tf.Tensor([[0.4871795  0.25641027 0.         0.25641027]], shape=(1, 

In [12]:

# Count the average steps per episode

n = 500
total_steps = 0
total_reward = 0
steps = 0
total_500 = 0
wins = 0

for i in range(n):
    if i % 10 == 0 and i != 0:
        print("Episode: ", i)
        
    time_step = test_env.reset()
    steps = 0
    episode_reward = 0
    while not time_step.is_last() and steps < 500:
        steps += 1
        action_step = agent.policy.action(time_step)
        time_step = test_env.step(action_step.action)
        reward = time_step.reward.numpy()[0]
        total_reward += reward
        episode_reward += reward
        
    if episode_reward == 115:
        wins += 1
    
    if steps == 500:
        total_500 += 1
    else:
        total_steps += steps
        
    if i % 10 == 0 and i != 0:
        print("Avg reward: ", total_reward / (i + 1))
        
if steps == 500:
    print("500 episodes reached")
    
print("Average steps per episode: ", total_steps / (n - total_500))
print("Average reward per episode: ", total_reward / n)
print("Total 500 episodes: ", total_500)
print("Wins: ", wins)


Episode:  10
Avg reward:  49.54545454545455
Episode:  20
Avg reward:  43.57142857142857
Episode:  30
Avg reward:  45.645161290322584
Episode:  40
Avg reward:  48.78048780487805
Episode:  50
Avg reward:  49.411764705882355
Episode:  60
Avg reward:  49.59016393442623
Episode:  70
Avg reward:  47.46478873239437
Episode:  80
Avg reward:  46.72839506172839
Episode:  90
Avg reward:  47.637362637362635
Episode:  100
Avg reward:  49.7029702970297
Episode:  110
Avg reward:  48.153153153153156
Episode:  120
Avg reward:  48.264462809917354
Episode:  130
Avg reward:  47.70992366412214
Episode:  140
Avg reward:  47.02127659574468
Episode:  150
Avg reward:  47.913907284768214
Episode:  160
Avg reward:  47.2360248447205
Episode:  170
Avg reward:  47.046783625730995
Episode:  180
Avg reward:  47.70718232044199
Episode:  190
Avg reward:  47.40837696335078
Episode:  200
Avg reward:  47.43781094527363
Episode:  210
Avg reward:  47.48815165876777
Episode:  220
Avg reward:  46.90045248868778
Episode:  230


In [13]:
# Save the q network weights
tf.keras.models.save_model(q_net._model, "q_network_5x5.h5")



  tf.keras.models.save_model(q_net._model, "q_network_5x5.h5")
