In [1]:
import numpy as np
import tensorflow as tf
from tf_agents.networks.q_network import QNetwork
import tf_agents.networks.network as network
from tf_agents.specs import tensor_spec
from snake_game import SnakeGame
from scene import Scene
from scene_longer_snake import SceneLongerSnake
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Flatten, Conv2D, Input

scene = SceneLongerSnake(init_randomly=True, snake_longer_prob=0.92, initial_length=10)
episodes_count = 80000
random_episodes = 20000


In [2]:
from tf_agents.environments.tf_py_environment import TFPyEnvironment

env = SnakeGame(scene)
env = TFPyEnvironment(env)


In [3]:
class MyQNetwork(network.Network):
    def __init__(self,
                 input_tensor_spec,
                 action_spec,
                 name='MyQNetwork'):
        super(MyQNetwork, self).__init__(
            input_tensor_spec=input_tensor_spec,
            state_spec=(),
            name=name)
        
        self._action_spec = action_spec
        matrix_shape = input_tensor_spec[0].shape
        array_shape = input_tensor_spec[1].shape
        
        input1 = Input(shape=matrix_shape)
        input2 = Input(shape=array_shape) # Each value represents if there is a wall or the snake body in the direction, 0 - no wall, 1 - wall
        
        conv = Conv2D(32, (2, 2), 1, activation='relu', kernel_initializer='he_normal')(input1)
        flat = Flatten()(conv)
        dense1 = Dense(128, activation='relu', kernel_initializer='he_normal')(flat)
        dense2 = Dense(32, activation='relu', kernel_initializer='he_normal')(dense1)
        
        concat = tf.keras.layers.concatenate([dense2, input2])
        output = Dense(4, activation='linear')(concat)
        
        self._model = tf.keras.Model(inputs=[input1, input2], outputs=output)
        
        self._model.summary()

    def call(self, observations, step_type=None, network_state=(), training=False):
        output = self._model([observations[0], observations[1]])
        return output, network_state

q_net = MyQNetwork(
    env.observation_spec(),
    env.action_spec())


Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_1 (InputLayer)        [(None, 5, 5, 4)]            0         []                            
                                                                                                  
 conv2d (Conv2D)             (None, 4, 4, 32)             544       ['input_1[0][0]']             
                                                                                                  
 flatten (Flatten)           (None, 512)                  0         ['conv2d[0][0]']              
                                                                                                  
 dense (Dense)               (None, 128)                  65664     ['flatten[0][0]']             
                                                                                              

In [4]:
# Create the agent
from tf_agents.agents.dqn.dqn_agent import DdqnAgent
from tensorflow import keras
from keras.optimizers.legacy import Adam
from keras.losses import Huber
from tf_agents.trajectories import TimeStep

train_step_counter = tf.Variable(0)
optimizer = Adam(learning_rate=0.003)

epsilon = keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=0.1,
    decay_steps=episodes_count,
    end_learning_rate=0.01
)

agent = DdqnAgent(
    time_step_spec=env.time_step_spec(),
    action_spec=env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=Huber(reduction="none"),
    gamma=tf.constant(0.95, dtype=tf.float32),
    train_step_counter=train_step_counter,
    epsilon_greedy=lambda: epsilon(train_step_counter)
)


Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_3 (InputLayer)        [(None, 5, 5, 4)]            0         []                            
                                                                                                  
 conv2d_1 (Conv2D)           (None, 4, 4, 32)             544       ['input_3[0][0]']             
                                                                                                  
 flatten_1 (Flatten)         (None, 512)                  0         ['conv2d_1[0][0]']            
                                                                                                  
 dense_3 (Dense)             (None, 128)                  65664     ['flatten_1[0][0]']           
                                                                                            

In [5]:
# Create the replay buffer
from tf_agents.replay_buffers import tf_uniform_replay_buffer

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.collect_data_spec,
    batch_size=env.batch_size,
    max_length=100000
)


In [6]:
# Create the training metrics
from tf_agents.metrics import tf_metrics

train_metrics = [
    tf_metrics.NumberOfEpisodes(),
    tf_metrics.EnvironmentSteps(),
    tf_metrics.AverageReturnMetric(),
    tf_metrics.AverageEpisodeLengthMetric()
]


In [7]:
# Create the driver
from tf_agents.drivers import dynamic_step_driver

collect_driver = dynamic_step_driver.DynamicStepDriver(
    env,
    policy=agent.collect_policy,
    observers=[replay_buffer.add_batch] + train_metrics,
    num_steps=4
)

# Run a random policy to fill the replay buffer
from tf_agents.policies.random_tf_policy import RandomTFPolicy

initial_collect_policy = RandomTFPolicy(env.time_step_spec(), env.action_spec())
init_driver = dynamic_step_driver.DynamicStepDriver(
    env,
    policy=initial_collect_policy,
    observers=[replay_buffer.add_batch],
    num_steps=random_episodes
)

final_time_step, final_policy_state = init_driver.run()


In [8]:
# Create the dataset
dataset = replay_buffer.as_dataset(
    sample_batch_size=64,
    num_steps=2, # Capaz cambiar esto
    num_parallel_calls=3
).prefetch(3)


Instructions for updating:
Use `as_dataset(..., single_deterministic_pass=False) instead.


In [9]:
# Create the training loop

from tf_agents.utils.common import function
from tf_agents.trajectories import trajectory
from tf_agents.eval.metric_utils import log_metrics
import logging

logging.getLogger().setLevel(logging.INFO)

collect_driver.run = function(collect_driver.run)
agent.train = function(agent.train)

def train_agent(n_iterations):
    time_step = None
    policy_state = agent.collect_policy.get_initial_state(env.batch_size)
    iterator = iter(dataset)
    for iteration in range(n_iterations):
        time_step, policy_state = collect_driver.run(time_step, policy_state)
        trajectories, buffer_info = next(iterator)
        train_loss = agent.train(trajectories)
        if iteration % 1000 == 0:
            print("Iteration: ", iteration)
            log_metrics(train_metrics)


In [10]:
train_agent(episodes_count)


Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldr(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))


INFO:absl: 
		 NumberOfEpisodes = 4
		 EnvironmentSteps = 4
		 AverageReturn = -50.0
		 AverageEpisodeLength = 1.0


Iteration:  0


INFO:absl: 
		 NumberOfEpisodes = 557
		 EnvironmentSteps = 4004
		 AverageReturn = -10085.5
		 AverageEpisodeLength = 12.600000381469727


Iteration:  1000


INFO:absl: 
		 NumberOfEpisodes = 850
		 EnvironmentSteps = 8004
		 AverageReturn = -14390.0
		 AverageEpisodeLength = 15.699999809265137


Iteration:  2000


INFO:absl: 
		 NumberOfEpisodes = 1111
		 EnvironmentSteps = 12004
		 AverageReturn = -17475.5
		 AverageEpisodeLength = 17.600000381469727


Iteration:  3000


INFO:absl: 
		 NumberOfEpisodes = 1384
		 EnvironmentSteps = 16004
		 AverageReturn = -19942.5
		 AverageEpisodeLength = 15.899999618530273


Iteration:  4000


INFO:absl: 
		 NumberOfEpisodes = 1668
		 EnvironmentSteps = 20004
		 AverageReturn = -22070.0
		 AverageEpisodeLength = 12.5


Iteration:  5000


INFO:absl: 
		 NumberOfEpisodes = 1982
		 EnvironmentSteps = 24004
		 AverageReturn = -24521.5
		 AverageEpisodeLength = 8.600000381469727


Iteration:  6000


INFO:absl: 
		 NumberOfEpisodes = 2279
		 EnvironmentSteps = 28004
		 AverageReturn = -26881.0
		 AverageEpisodeLength = 14.699999809265137


Iteration:  7000


INFO:absl: 
		 NumberOfEpisodes = 2559
		 EnvironmentSteps = 32004
		 AverageReturn = -28803.5
		 AverageEpisodeLength = 14.100000381469727


Iteration:  8000


INFO:absl: 
		 NumberOfEpisodes = 2834
		 EnvironmentSteps = 36004
		 AverageReturn = -30537.5
		 AverageEpisodeLength = 14.399999618530273


Iteration:  9000


INFO:absl: 
		 NumberOfEpisodes = 3104
		 EnvironmentSteps = 40004
		 AverageReturn = -32125.5
		 AverageEpisodeLength = 18.399999618530273


Iteration:  10000


INFO:absl: 
		 NumberOfEpisodes = 3361
		 EnvironmentSteps = 44004
		 AverageReturn = -33600.0
		 AverageEpisodeLength = 11.0


Iteration:  11000


INFO:absl: 
		 NumberOfEpisodes = 3602
		 EnvironmentSteps = 48004
		 AverageReturn = -34643.0
		 AverageEpisodeLength = 20.200000762939453


Iteration:  12000


INFO:absl: 
		 NumberOfEpisodes = 3868
		 EnvironmentSteps = 52004
		 AverageReturn = -36238.0
		 AverageEpisodeLength = 16.0


Iteration:  13000


INFO:absl: 
		 NumberOfEpisodes = 4122
		 EnvironmentSteps = 56004
		 AverageReturn = -37651.0
		 AverageEpisodeLength = 14.899999618530273


Iteration:  14000


INFO:absl: 
		 NumberOfEpisodes = 4379
		 EnvironmentSteps = 60004
		 AverageReturn = -38942.5
		 AverageEpisodeLength = 16.5


Iteration:  15000


INFO:absl: 
		 NumberOfEpisodes = 4652
		 EnvironmentSteps = 64004
		 AverageReturn = -40761.5
		 AverageEpisodeLength = 6.900000095367432


Iteration:  16000


INFO:absl: 
		 NumberOfEpisodes = 4892
		 EnvironmentSteps = 68004
		 AverageReturn = -41780.0
		 AverageEpisodeLength = 16.700000762939453


Iteration:  17000


INFO:absl: 
		 NumberOfEpisodes = 5151
		 EnvironmentSteps = 72004
		 AverageReturn = -43156.0
		 AverageEpisodeLength = 13.600000381469727


Iteration:  18000


INFO:absl: 
		 NumberOfEpisodes = 5414
		 EnvironmentSteps = 76004
		 AverageReturn = -44396.0
		 AverageEpisodeLength = 22.600000381469727


Iteration:  19000


INFO:absl: 
		 NumberOfEpisodes = 5642
		 EnvironmentSteps = 80004
		 AverageReturn = -45070.0
		 AverageEpisodeLength = 19.5


Iteration:  20000


INFO:absl: 
		 NumberOfEpisodes = 5887
		 EnvironmentSteps = 84004
		 AverageReturn = -46021.0
		 AverageEpisodeLength = 19.899999618530273


Iteration:  21000


INFO:absl: 
		 NumberOfEpisodes = 6119
		 EnvironmentSteps = 88004
		 AverageReturn = -46576.5
		 AverageEpisodeLength = 17.799999237060547


Iteration:  22000


INFO:absl: 
		 NumberOfEpisodes = 6366
		 EnvironmentSteps = 92004
		 AverageReturn = -47547.5
		 AverageEpisodeLength = 11.399999618530273


Iteration:  23000


INFO:absl: 
		 NumberOfEpisodes = 6599
		 EnvironmentSteps = 96004
		 AverageReturn = -48128.5
		 AverageEpisodeLength = 14.199999809265137


Iteration:  24000


INFO:absl: 
		 NumberOfEpisodes = 6802
		 EnvironmentSteps = 100004
		 AverageReturn = -48178.5
		 AverageEpisodeLength = 20.5


Iteration:  25000


INFO:absl: 
		 NumberOfEpisodes = 7014
		 EnvironmentSteps = 104004
		 AverageReturn = -48408.0
		 AverageEpisodeLength = 22.600000381469727


Iteration:  26000


INFO:absl: 
		 NumberOfEpisodes = 7237
		 EnvironmentSteps = 108004
		 AverageReturn = -48853.5
		 AverageEpisodeLength = 21.200000762939453


Iteration:  27000


INFO:absl: 
		 NumberOfEpisodes = 7463
		 EnvironmentSteps = 112004
		 AverageReturn = -49179.5
		 AverageEpisodeLength = 8.800000190734863


Iteration:  28000


INFO:absl: 
		 NumberOfEpisodes = 7708
		 EnvironmentSteps = 116004
		 AverageReturn = -49969.5
		 AverageEpisodeLength = 15.100000381469727


Iteration:  29000


INFO:absl: 
		 NumberOfEpisodes = 7911
		 EnvironmentSteps = 120004
		 AverageReturn = -49968.5
		 AverageEpisodeLength = 20.700000762939453


Iteration:  30000


INFO:absl: 
		 NumberOfEpisodes = 8122
		 EnvironmentSteps = 124004
		 AverageReturn = -50168.0
		 AverageEpisodeLength = 20.5


Iteration:  31000


INFO:absl: 
		 NumberOfEpisodes = 8316
		 EnvironmentSteps = 128004
		 AverageReturn = -50005.5
		 AverageEpisodeLength = 24.700000762939453


Iteration:  32000


INFO:absl: 
		 NumberOfEpisodes = 8506
		 EnvironmentSteps = 132004
		 AverageReturn = -49697.5
		 AverageEpisodeLength = 31.200000762939453


Iteration:  33000


INFO:absl: 
		 NumberOfEpisodes = 8738
		 EnvironmentSteps = 136004
		 AverageReturn = -50110.0
		 AverageEpisodeLength = 25.0


Iteration:  34000


INFO:absl: 
		 NumberOfEpisodes = 8950
		 EnvironmentSteps = 140004
		 AverageReturn = -50021.0
		 AverageEpisodeLength = 22.799999237060547


Iteration:  35000


INFO:absl: 
		 NumberOfEpisodes = 9154
		 EnvironmentSteps = 144004
		 AverageReturn = -49863.0
		 AverageEpisodeLength = 8.399999618530273


Iteration:  36000


INFO:absl: 
		 NumberOfEpisodes = 9376
		 EnvironmentSteps = 148004
		 AverageReturn = -49827.5
		 AverageEpisodeLength = 25.200000762939453


Iteration:  37000


INFO:absl: 
		 NumberOfEpisodes = 9571
		 EnvironmentSteps = 152004
		 AverageReturn = -49660.5
		 AverageEpisodeLength = 22.100000381469727


Iteration:  38000


INFO:absl: 
		 NumberOfEpisodes = 9775
		 EnvironmentSteps = 156004
		 AverageReturn = -49651.0
		 AverageEpisodeLength = 22.299999237060547


Iteration:  39000


INFO:absl: 
		 NumberOfEpisodes = 9974
		 EnvironmentSteps = 160004
		 AverageReturn = -49537.5
		 AverageEpisodeLength = 15.5


Iteration:  40000


INFO:absl: 
		 NumberOfEpisodes = 10180
		 EnvironmentSteps = 164004
		 AverageReturn = -49328.5
		 AverageEpisodeLength = 25.200000762939453


Iteration:  41000


INFO:absl: 
		 NumberOfEpisodes = 10387
		 EnvironmentSteps = 168004
		 AverageReturn = -49252.5
		 AverageEpisodeLength = 32.29999923706055


Iteration:  42000


INFO:absl: 
		 NumberOfEpisodes = 10594
		 EnvironmentSteps = 172004
		 AverageReturn = -49350.0
		 AverageEpisodeLength = 14.100000381469727


Iteration:  43000


INFO:absl: 
		 NumberOfEpisodes = 10786
		 EnvironmentSteps = 176004
		 AverageReturn = -48794.0
		 AverageEpisodeLength = 20.600000381469727


Iteration:  44000


INFO:absl: 
		 NumberOfEpisodes = 10963
		 EnvironmentSteps = 180004
		 AverageReturn = -48199.0
		 AverageEpisodeLength = 16.200000762939453


Iteration:  45000


INFO:absl: 
		 NumberOfEpisodes = 11161
		 EnvironmentSteps = 184004
		 AverageReturn = -48006.5
		 AverageEpisodeLength = 8.800000190734863


Iteration:  46000


INFO:absl: 
		 NumberOfEpisodes = 11332
		 EnvironmentSteps = 188004
		 AverageReturn = -47311.0
		 AverageEpisodeLength = 14.600000381469727


Iteration:  47000


INFO:absl: 
		 NumberOfEpisodes = 11524
		 EnvironmentSteps = 192004
		 AverageReturn = -46992.0
		 AverageEpisodeLength = 19.100000381469727


Iteration:  48000


INFO:absl: 
		 NumberOfEpisodes = 11716
		 EnvironmentSteps = 196004
		 AverageReturn = -46766.0
		 AverageEpisodeLength = 13.899999618530273


Iteration:  49000


INFO:absl: 
		 NumberOfEpisodes = 11879
		 EnvironmentSteps = 200004
		 AverageReturn = -45710.5
		 AverageEpisodeLength = 26.299999237060547


Iteration:  50000


INFO:absl: 
		 NumberOfEpisodes = 12059
		 EnvironmentSteps = 204004
		 AverageReturn = -44859.0
		 AverageEpisodeLength = 28.0


Iteration:  51000


INFO:absl: 
		 NumberOfEpisodes = 12218
		 EnvironmentSteps = 208004
		 AverageReturn = -43849.0
		 AverageEpisodeLength = 24.799999237060547


Iteration:  52000


INFO:absl: 
		 NumberOfEpisodes = 12389
		 EnvironmentSteps = 212004
		 AverageReturn = -42824.0
		 AverageEpisodeLength = 33.79999923706055


Iteration:  53000


INFO:absl: 
		 NumberOfEpisodes = 12540
		 EnvironmentSteps = 216004
		 AverageReturn = -41538.5
		 AverageEpisodeLength = 31.700000762939453


Iteration:  54000


INFO:absl: 
		 NumberOfEpisodes = 12698
		 EnvironmentSteps = 220004
		 AverageReturn = -40140.5
		 AverageEpisodeLength = 22.600000381469727


Iteration:  55000


INFO:absl: 
		 NumberOfEpisodes = 12860
		 EnvironmentSteps = 224004
		 AverageReturn = -38972.0
		 AverageEpisodeLength = 37.400001525878906


Iteration:  56000


INFO:absl: 
		 NumberOfEpisodes = 13010
		 EnvironmentSteps = 228004
		 AverageReturn = -37684.5
		 AverageEpisodeLength = 29.799999237060547


Iteration:  57000


INFO:absl: 
		 NumberOfEpisodes = 13158
		 EnvironmentSteps = 232004
		 AverageReturn = -36017.0
		 AverageEpisodeLength = 16.700000762939453


Iteration:  58000


INFO:absl: 
		 NumberOfEpisodes = 13314
		 EnvironmentSteps = 236004
		 AverageReturn = -34913.0
		 AverageEpisodeLength = 32.900001525878906


Iteration:  59000


INFO:absl: 
		 NumberOfEpisodes = 13443
		 EnvironmentSteps = 240004
		 AverageReturn = -33237.0
		 AverageEpisodeLength = 17.5


Iteration:  60000


INFO:absl: 
		 NumberOfEpisodes = 13614
		 EnvironmentSteps = 244004
		 AverageReturn = -32205.0
		 AverageEpisodeLength = 25.299999237060547


Iteration:  61000


INFO:absl: 
		 NumberOfEpisodes = 13759
		 EnvironmentSteps = 248004
		 AverageReturn = -30882.0
		 AverageEpisodeLength = 25.0


Iteration:  62000


INFO:absl: 
		 NumberOfEpisodes = 13916
		 EnvironmentSteps = 252004
		 AverageReturn = -29482.5
		 AverageEpisodeLength = 15.0


Iteration:  63000


INFO:absl: 
		 NumberOfEpisodes = 14078
		 EnvironmentSteps = 256004
		 AverageReturn = -28567.0
		 AverageEpisodeLength = 16.200000762939453


Iteration:  64000


INFO:absl: 
		 NumberOfEpisodes = 14236
		 EnvironmentSteps = 260004
		 AverageReturn = -27366.0
		 AverageEpisodeLength = 24.399999618530273


Iteration:  65000


INFO:absl: 
		 NumberOfEpisodes = 14390
		 EnvironmentSteps = 264004
		 AverageReturn = -26085.5
		 AverageEpisodeLength = 13.800000190734863


Iteration:  66000


INFO:absl: 
		 NumberOfEpisodes = 14518
		 EnvironmentSteps = 268004
		 AverageReturn = -24277.5
		 AverageEpisodeLength = 33.29999923706055


Iteration:  67000


INFO:absl: 
		 NumberOfEpisodes = 14644
		 EnvironmentSteps = 272004
		 AverageReturn = -22253.5
		 AverageEpisodeLength = 30.899999618530273


Iteration:  68000


INFO:absl: 
		 NumberOfEpisodes = 14781
		 EnvironmentSteps = 276004
		 AverageReturn = -20411.0
		 AverageEpisodeLength = 34.20000076293945


Iteration:  69000


INFO:absl: 
		 NumberOfEpisodes = 14921
		 EnvironmentSteps = 280004
		 AverageReturn = -18524.0
		 AverageEpisodeLength = 25.299999237060547


Iteration:  70000


INFO:absl: 
		 NumberOfEpisodes = 15069
		 EnvironmentSteps = 284004
		 AverageReturn = -17149.5
		 AverageEpisodeLength = 26.299999237060547


Iteration:  71000


INFO:absl: 
		 NumberOfEpisodes = 15198
		 EnvironmentSteps = 288004
		 AverageReturn = -15171.0
		 AverageEpisodeLength = 33.400001525878906


Iteration:  72000


INFO:absl: 
		 NumberOfEpisodes = 15326
		 EnvironmentSteps = 292004
		 AverageReturn = -13044.0
		 AverageEpisodeLength = 34.900001525878906


Iteration:  73000


INFO:absl: 
		 NumberOfEpisodes = 15457
		 EnvironmentSteps = 296004
		 AverageReturn = -10743.5
		 AverageEpisodeLength = 32.400001525878906


Iteration:  74000


INFO:absl: 
		 NumberOfEpisodes = 15578
		 EnvironmentSteps = 300004
		 AverageReturn = -8229.0
		 AverageEpisodeLength = 29.899999618530273


Iteration:  75000


INFO:absl: 
		 NumberOfEpisodes = 15704
		 EnvironmentSteps = 304004
		 AverageReturn = -5992.0
		 AverageEpisodeLength = 43.29999923706055


Iteration:  76000


INFO:absl: 
		 NumberOfEpisodes = 15822
		 EnvironmentSteps = 308004
		 AverageReturn = -3403.0
		 AverageEpisodeLength = 27.5


Iteration:  77000


INFO:absl: 
		 NumberOfEpisodes = 15939
		 EnvironmentSteps = 312004
		 AverageReturn = -970.5
		 AverageEpisodeLength = 28.700000762939453


Iteration:  78000


INFO:absl: 
		 NumberOfEpisodes = 16053
		 EnvironmentSteps = 316004
		 AverageReturn = 1274.5
		 AverageEpisodeLength = 41.400001525878906


Iteration:  79000


In [11]:
# Evaluate the agent

test_scene = Scene(init_randomly=True)
test_env = SnakeGame(test_scene)
test_env = TFPyEnvironment(test_env)

time_step = test_env.reset()
rewards = []
steps = 0

while not time_step.is_last() and steps < 200:
    steps += 1
    print("Map:\n", np.argmax(time_step.observation[0], axis=3))
    print("Direction:\n", time_step.observation[1])
    action_step = agent.policy.action(time_step)
    print("Action: ", action_step.action.numpy())
    time_step = test_env.step(action_step.action)
    rewards.append(time_step.reward.numpy()[0])
    
print("Total reward: ", sum(rewards))
print("Total steps: ", steps)


Map:
 [[[0 0 0 0 0]
  [0 0 0 0 0]
  [0 0 0 0 0]
  [0 0 0 0 3]
  [1 2 0 0 0]]]
Direction:
 tf.Tensor([[ True False  True  True]], shape=(1, 4), dtype=bool)
Action:  [1]
Map:
 [[[0 0 0 0 0]
  [0 0 0 0 0]
  [0 0 0 0 0]
  [1 0 0 0 3]
  [2 0 0 0 0]]]
Direction:
 tf.Tensor([[ True False False  True]], shape=(1, 4), dtype=bool)
Action:  [2]
Map:
 [[[0 0 0 0 0]
  [0 0 0 0 0]
  [0 0 0 0 0]
  [2 1 0 0 3]
  [0 0 0 0 0]]]
Direction:
 tf.Tensor([[ True False False False]], shape=(1, 4), dtype=bool)
Action:  [2]
Map:
 [[[0 0 0 0 0]
  [0 0 0 0 0]
  [0 0 0 0 0]
  [0 2 1 0 3]
  [0 0 0 0 0]]]
Direction:
 tf.Tensor([[ True False False False]], shape=(1, 4), dtype=bool)
Action:  [2]
Map:
 [[[0 0 0 0 0]
  [0 0 0 0 0]
  [0 0 0 0 0]
  [0 0 2 1 3]
  [0 0 0 0 0]]]
Direction:
 tf.Tensor([[ True False False False]], shape=(1, 4), dtype=bool)
Action:  [2]
Map:
 [[[0 0 0 0 0]
  [0 0 0 0 0]
  [3 0 0 0 0]
  [0 0 2 2 1]
  [0 0 0 0 0]]]
Direction:
 tf.Tensor([[ True False  True False]], shape=(1, 4), dtype=bool)
Actio

In [12]:
# Count the average steps per episode

n = 500
total_steps = 0
total_reward = 0
steps = 0
total_500 = 0

for i in range(n):
    if i % 10 == 0 and i != 0:
        print("Episode: ", i)
        print("Avg reward: ", total_reward / (i + 1))
        
    time_step = test_env.reset()
    steps = 0
    while not time_step.is_last() and steps < 500:
        steps += 1
        action_step = agent.policy.action(time_step)
        time_step = test_env.step(action_step.action)
        reward = time_step.reward.numpy()[0]
        total_reward += reward
    
    if steps == 500:
        total_500 += 1
    else:
        total_steps += steps
        
if steps == 500:
    print("500 episodes reached")
    
print("Average steps per episode: ", total_steps / (n - total_500))
print("Average reward per episode: ", total_reward / n)
print("Total 500 episodes: ", total_500)


Episode:  10
Avg reward:  50.0
Episode:  20
Avg reward:  51.666666666666664
Episode:  30
Avg reward:  61.12903225806452
Episode:  40
Avg reward:  60.24390243902439
Episode:  50
Avg reward:  63.13725490196079
Episode:  60
Avg reward:  65.08196721311475
Episode:  70
Avg reward:  63.87323943661972
Episode:  80
Avg reward:  63.333333333333336
Episode:  90
Avg reward:  62.582417582417584
Episode:  100
Avg reward:  61.48514851485149
Episode:  110
Avg reward:  60.67567567567568
Episode:  120
Avg reward:  61.611570247933884
Episode:  130
Avg reward:  61.412213740458014
Episode:  140
Avg reward:  63.40425531914894
Episode:  150
Avg reward:  63.311258278145694
Episode:  160
Avg reward:  62.91925465838509
Episode:  170
Avg reward:  63.421052631578945
Episode:  180
Avg reward:  63.20441988950276
Episode:  190
Avg reward:  62.853403141361255
Episode:  200
Avg reward:  63.00995024875622
Episode:  210
Avg reward:  63.056872037914694
Episode:  220
Avg reward:  62.828054298642535
Episode:  230
Avg rewa

In [13]:
# Save the q network weights
tf.keras.models.save_model(q_net._model, "q_network.h5")



  tf.keras.models.save_model(q_net._model, "q_network.h5")
