In [3]:
import numpy as np
import tensorflow as tf
from tf_agents.networks.q_network import QNetwork
import tf_agents.networks.network as network
from tf_agents.specs import tensor_spec
from snake_game import SnakeGame
from scene import Scene
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Flatten, Conv2D, Input

scene = Scene(init_randomly=True)
episodes_count = 50000
random_episodes = 20000


In [4]:
from tf_agents.environments.tf_py_environment import TFPyEnvironment

env = SnakeGame(scene)
env = TFPyEnvironment(env)


In [5]:
class MyQNetwork(network.Network):
    def __init__(self,
                 input_tensor_spec,
                 action_spec,
                 name='MyQNetwork'):
        super(MyQNetwork, self).__init__(
            input_tensor_spec=input_tensor_spec,
            state_spec=(),
            name=name)
        
        self._action_spec = action_spec

        input_shape = input_tensor_spec.shape
        
        self._model = Sequential([
            Input(shape=input_shape),
            Conv2D(32, (2, 2), 1, activation='relu', kernel_initializer='he_normal'),
            Conv2D(64, (2, 2), 1, activation='relu', kernel_initializer='he_normal'),
            Flatten(),
            Dense(128, activation='relu', kernel_initializer='he_normal'),
            Dense(16, activation='relu', kernel_initializer='he_normal'),
            Dense(4, activation='linear')
        ])

    def call(self, observations, step_type=None, network_state=(), training=False):
        output = self._model(observations)
        return output, network_state

q_net = MyQNetwork(
    env.observation_spec(),
    env.action_spec())


In [6]:
# Create the agent
from tf_agents.agents.dqn.dqn_agent import DdqnAgent
from tensorflow import keras
from keras.optimizers import Adam
from keras.losses import Huber
from tf_agents.trajectories import TimeStep

train_step_counter = tf.Variable(0)
optimizer = Adam(learning_rate=0.003)

# epsilon = lambda train_step: (1 / (tf.cast(tf.linspace(1, 8, episodes_count)**3, tf.float32)))[train_step]
epsilon = keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=1.0,
    decay_steps=episodes_count,
    end_learning_rate=0.01
)

agent = DdqnAgent(
    time_step_spec=env.time_step_spec(),
    action_spec=env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=Huber(reduction="none"),
    gamma=tf.constant(0.95, dtype=tf.float32),
    train_step_counter=train_step_counter,
    epsilon_greedy=lambda: epsilon(train_step_counter)
)




In [7]:
# Create the replay buffer
from tf_agents.replay_buffers import tf_uniform_replay_buffer

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.collect_data_spec,
    batch_size=env.batch_size,
    max_length=100000
)


In [8]:
# Create the training metrics
from tf_agents.metrics import tf_metrics

train_metrics = [
    tf_metrics.NumberOfEpisodes(),
    tf_metrics.EnvironmentSteps(),
    tf_metrics.AverageReturnMetric(),
    tf_metrics.AverageEpisodeLengthMetric()
]


In [9]:
# Create the driver
from tf_agents.drivers import dynamic_step_driver

collect_driver = dynamic_step_driver.DynamicStepDriver(
    env,
    policy=agent.collect_policy,
    observers=[replay_buffer.add_batch] + train_metrics,
    num_steps=4
)

# Run a random policy to fill the replay buffer
from tf_agents.policies.random_tf_policy import RandomTFPolicy

initial_collect_policy = RandomTFPolicy(env.time_step_spec(), env.action_spec())
init_driver = dynamic_step_driver.DynamicStepDriver(
    env,
    policy=initial_collect_policy,
    observers=[replay_buffer.add_batch],
    num_steps=random_episodes
)

final_time_step, final_policy_state = init_driver.run()


In [10]:
# Create the dataset
dataset = replay_buffer.as_dataset(
    sample_batch_size=64,
    num_steps=2, # Capaz cambiar esto
    num_parallel_calls=3
).prefetch(3)


Instructions for updating:
Use `as_dataset(..., single_deterministic_pass=False) instead.


Instructions for updating:
Use `as_dataset(..., single_deterministic_pass=False) instead.


In [11]:
# Create the training loop

from tf_agents.utils.common import function
from tf_agents.trajectories import trajectory
from tf_agents.eval.metric_utils import log_metrics
import logging

logging.getLogger().setLevel(logging.INFO)

collect_driver.run = function(collect_driver.run)
agent.train = function(agent.train)

def train_agent(n_iterations):
    time_step = None
    policy_state = agent.collect_policy.get_initial_state(env.batch_size)
    iterator = iter(dataset)
    for iteration in range(n_iterations):
        time_step, policy_state = collect_driver.run(time_step, policy_state)
        trajectories, buffer_info = next(iterator)
        train_loss = agent.train(trajectories)
        if iteration % 1000 == 0:
            print("Iteration: ", iteration)
            log_metrics(train_metrics)


In [12]:
train_agent(episodes_count)


Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldr(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))


Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldr(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))
INFO:absl: 
		 NumberOfEpisodes = 3
		 EnvironmentSteps = 4
		 AverageReturn = -31.66666603088379
		 AverageEpisodeLength = 1.0


Iteration:  0


INFO:absl: 
		 NumberOfEpisodes = 2752
		 EnvironmentSteps = 4004
		 AverageReturn = -40109.1015625
		 AverageEpisodeLength = 1.0


Iteration:  1000


INFO:absl: 
		 NumberOfEpisodes = 5533
		 EnvironmentSteps = 8004
		 AverageReturn = -80904.3984375
		 AverageEpisodeLength = 2.0999999046325684


Iteration:  2000


INFO:absl: 
		 NumberOfEpisodes = 8275
		 EnvironmentSteps = 12004
		 AverageReturn = -121046.1015625
		 AverageEpisodeLength = 1.100000023841858


Iteration:  3000


INFO:absl: 
		 NumberOfEpisodes = 10878
		 EnvironmentSteps = 16004
		 AverageReturn = -158638.90625
		 AverageEpisodeLength = 0.800000011920929


Iteration:  4000


INFO:absl: 
		 NumberOfEpisodes = 13364
		 EnvironmentSteps = 20004
		 AverageReturn = -194800.09375
		 AverageEpisodeLength = 3.0


Iteration:  5000


INFO:absl: 
		 NumberOfEpisodes = 15758
		 EnvironmentSteps = 24004
		 AverageReturn = -228973.296875
		 AverageEpisodeLength = 0.5


Iteration:  6000


INFO:absl: 
		 NumberOfEpisodes = 17934
		 EnvironmentSteps = 28004
		 AverageReturn = -259571.5
		 AverageEpisodeLength = 4.5


Iteration:  7000


INFO:absl: 
		 NumberOfEpisodes = 20059
		 EnvironmentSteps = 32004
		 AverageReturn = -288391.90625
		 AverageEpisodeLength = 2.700000047683716


Iteration:  8000


INFO:absl: 
		 NumberOfEpisodes = 22130
		 EnvironmentSteps = 36004
		 AverageReturn = -316464.09375
		 AverageEpisodeLength = 2.4000000953674316


Iteration:  9000


INFO:absl: 
		 NumberOfEpisodes = 24019
		 EnvironmentSteps = 40004
		 AverageReturn = -341380.0
		 AverageEpisodeLength = 5.5


Iteration:  10000


INFO:absl: 
		 NumberOfEpisodes = 25966
		 EnvironmentSteps = 44004
		 AverageReturn = -366688.59375
		 AverageEpisodeLength = 2.4000000953674316


Iteration:  11000


INFO:absl: 
		 NumberOfEpisodes = 27900
		 EnvironmentSteps = 48004
		 AverageReturn = -391426.1875
		 AverageEpisodeLength = 2.5999999046325684


Iteration:  12000


INFO:absl: 
		 NumberOfEpisodes = 29678
		 EnvironmentSteps = 52004
		 AverageReturn = -412885.59375
		 AverageEpisodeLength = 1.399999976158142


Iteration:  13000


INFO:absl: 
		 NumberOfEpisodes = 31417
		 EnvironmentSteps = 56004
		 AverageReturn = -434034.5
		 AverageEpisodeLength = 1.2999999523162842


Iteration:  14000


INFO:absl: 
		 NumberOfEpisodes = 33133
		 EnvironmentSteps = 60004
		 AverageReturn = -453831.1875
		 AverageEpisodeLength = 2.4000000953674316


Iteration:  15000


INFO:absl: 
		 NumberOfEpisodes = 34675
		 EnvironmentSteps = 64004
		 AverageReturn = -470226.0
		 AverageEpisodeLength = 1.399999976158142


Iteration:  16000


INFO:absl: 
		 NumberOfEpisodes = 36348
		 EnvironmentSteps = 68004
		 AverageReturn = -488173.3125
		 AverageEpisodeLength = 2.0999999046325684


Iteration:  17000


INFO:absl: 
		 NumberOfEpisodes = 37781
		 EnvironmentSteps = 72004
		 AverageReturn = -501976.40625
		 AverageEpisodeLength = 3.9000000953674316


Iteration:  18000


INFO:absl: 
		 NumberOfEpisodes = 39137
		 EnvironmentSteps = 76004
		 AverageReturn = -513631.3125
		 AverageEpisodeLength = 3.799999952316284


Iteration:  19000


INFO:absl: 
		 NumberOfEpisodes = 40461
		 EnvironmentSteps = 80004
		 AverageReturn = -525122.0
		 AverageEpisodeLength = 2.200000047683716


Iteration:  20000


INFO:absl: 
		 NumberOfEpisodes = 41723
		 EnvironmentSteps = 84004
		 AverageReturn = -535672.8125
		 AverageEpisodeLength = 4.900000095367432


Iteration:  21000


INFO:absl: 
		 NumberOfEpisodes = 42929
		 EnvironmentSteps = 88004
		 AverageReturn = -544278.0
		 AverageEpisodeLength = 2.5


Iteration:  22000


INFO:absl: 
		 NumberOfEpisodes = 44116
		 EnvironmentSteps = 92004
		 AverageReturn = -553169.8125
		 AverageEpisodeLength = 3.200000047683716


Iteration:  23000


INFO:absl: 
		 NumberOfEpisodes = 45191
		 EnvironmentSteps = 96004
		 AverageReturn = -558899.125
		 AverageEpisodeLength = 2.700000047683716


Iteration:  24000


INFO:absl: 
		 NumberOfEpisodes = 46294
		 EnvironmentSteps = 100004
		 AverageReturn = -563664.125
		 AverageEpisodeLength = 3.4000000953674316


Iteration:  25000


INFO:absl: 
		 NumberOfEpisodes = 47304
		 EnvironmentSteps = 104004
		 AverageReturn = -567380.0
		 AverageEpisodeLength = 4.599999904632568


Iteration:  26000


INFO:absl: 
		 NumberOfEpisodes = 48228
		 EnvironmentSteps = 108004
		 AverageReturn = -569308.375
		 AverageEpisodeLength = 5.099999904632568


Iteration:  27000


INFO:absl: 
		 NumberOfEpisodes = 49168
		 EnvironmentSteps = 112004
		 AverageReturn = -569971.625
		 AverageEpisodeLength = 3.0999999046325684


Iteration:  28000


INFO:absl: 
		 NumberOfEpisodes = 50083
		 EnvironmentSteps = 116004
		 AverageReturn = -570232.0
		 AverageEpisodeLength = 6.400000095367432


Iteration:  29000


INFO:absl: 
		 NumberOfEpisodes = 51026
		 EnvironmentSteps = 120004
		 AverageReturn = -570091.375
		 AverageEpisodeLength = 6.099999904632568


Iteration:  30000


INFO:absl: 
		 NumberOfEpisodes = 51868
		 EnvironmentSteps = 124004
		 AverageReturn = -568590.875
		 AverageEpisodeLength = 5.900000095367432


Iteration:  31000


INFO:absl: 
		 NumberOfEpisodes = 52622
		 EnvironmentSteps = 128004
		 AverageReturn = -565556.625
		 AverageEpisodeLength = 5.0


Iteration:  32000


INFO:absl: 
		 NumberOfEpisodes = 53333
		 EnvironmentSteps = 132004
		 AverageReturn = -560909.125
		 AverageEpisodeLength = 8.300000190734863


Iteration:  33000


INFO:absl: 
		 NumberOfEpisodes = 54056
		 EnvironmentSteps = 136004
		 AverageReturn = -555503.125
		 AverageEpisodeLength = 5.300000190734863


Iteration:  34000


INFO:absl: 
		 NumberOfEpisodes = 54726
		 EnvironmentSteps = 140004
		 AverageReturn = -548867.0
		 AverageEpisodeLength = 7.199999809265137


Iteration:  35000


INFO:absl: 
		 NumberOfEpisodes = 55330
		 EnvironmentSteps = 144004
		 AverageReturn = -540113.5
		 AverageEpisodeLength = 5.199999809265137


Iteration:  36000


INFO:absl: 
		 NumberOfEpisodes = 55888
		 EnvironmentSteps = 148004
		 AverageReturn = -531148.6875
		 AverageEpisodeLength = 6.5


Iteration:  37000


INFO:absl: 
		 NumberOfEpisodes = 56460
		 EnvironmentSteps = 152004
		 AverageReturn = -521412.09375
		 AverageEpisodeLength = 9.699999809265137


Iteration:  38000


INFO:absl: 
		 NumberOfEpisodes = 56978
		 EnvironmentSteps = 156004
		 AverageReturn = -510390.40625
		 AverageEpisodeLength = 7.099999904632568


Iteration:  39000


INFO:absl: 
		 NumberOfEpisodes = 57453
		 EnvironmentSteps = 160004
		 AverageReturn = -497757.5
		 AverageEpisodeLength = 8.100000381469727


Iteration:  40000


INFO:absl: 
		 NumberOfEpisodes = 57884
		 EnvironmentSteps = 164004
		 AverageReturn = -484058.1875
		 AverageEpisodeLength = 6.699999809265137


Iteration:  41000


INFO:absl: 
		 NumberOfEpisodes = 58253
		 EnvironmentSteps = 168004
		 AverageReturn = -468653.8125
		 AverageEpisodeLength = 15.800000190734863


Iteration:  42000


INFO:absl: 
		 NumberOfEpisodes = 58659
		 EnvironmentSteps = 172004
		 AverageReturn = -452686.3125
		 AverageEpisodeLength = 11.899999618530273


Iteration:  43000


INFO:absl: 
		 NumberOfEpisodes = 59046
		 EnvironmentSteps = 176004
		 AverageReturn = -436541.90625
		 AverageEpisodeLength = 9.100000381469727


Iteration:  44000


INFO:absl: 
		 NumberOfEpisodes = 59385
		 EnvironmentSteps = 180004
		 AverageReturn = -419169.40625
		 AverageEpisodeLength = 15.399999618530273


Iteration:  45000


INFO:absl: 
		 NumberOfEpisodes = 59708
		 EnvironmentSteps = 184004
		 AverageReturn = -400627.1875
		 AverageEpisodeLength = 6.300000190734863


Iteration:  46000


INFO:absl: 
		 NumberOfEpisodes = 59999
		 EnvironmentSteps = 188004
		 AverageReturn = -381125.0
		 AverageEpisodeLength = 11.699999809265137


Iteration:  47000


INFO:absl: 
		 NumberOfEpisodes = 60249
		 EnvironmentSteps = 192004
		 AverageReturn = -360503.40625
		 AverageEpisodeLength = 11.5


Iteration:  48000


INFO:absl: 
		 NumberOfEpisodes = 60481
		 EnvironmentSteps = 196004
		 AverageReturn = -338690.1875
		 AverageEpisodeLength = 15.899999618530273


Iteration:  49000


In [13]:
# Evaluate the agent

time_step = env.reset()
rewards = []
steps = 0

while not time_step.is_last() and steps < 200:
    steps += 1
    print("Map:\n", np.argmax(time_step.observation.numpy(), axis=3))
    action_step = agent.policy.action(time_step)
    print("Action: ", action_step.action.numpy())
    time_step = env.step(action_step.action)
    rewards.append(time_step.reward.numpy()[0])
    
print("Total reward: ", sum(rewards))
print("Total steps: ", steps)


Map:
 [[[0 0 0 0 0]
  [0 0 0 3 0]
  [0 0 0 0 0]
  [1 2 0 0 0]
  [0 0 0 0 0]]]
Action:  [1]
Map:
 [[[0 0 0 0 0]
  [0 0 0 3 0]
  [1 0 0 0 0]
  [2 0 0 0 0]
  [0 0 0 0 0]]]
Action:  [2]
Map:
 [[[0 0 0 0 0]
  [0 0 0 3 0]
  [2 1 0 0 0]
  [0 0 0 0 0]
  [0 0 0 0 0]]]
Action:  [1]
Map:
 [[[0 0 0 0 0]
  [0 1 0 3 0]
  [0 2 0 0 0]
  [0 0 0 0 0]
  [0 0 0 0 0]]]
Action:  [2]
Map:
 [[[0 0 0 0 0]
  [0 2 1 3 0]
  [0 0 0 0 0]
  [0 0 0 0 0]
  [0 0 0 0 0]]]
Action:  [2]
Map:
 [[[0 0 0 0 0]
  [0 2 2 1 0]
  [0 0 0 0 0]
  [0 0 0 0 0]
  [0 0 3 0 0]]]
Action:  [3]
Map:
 [[[0 0 0 0 0]
  [0 0 2 2 0]
  [0 0 0 1 0]
  [0 0 0 0 0]
  [0 0 3 0 0]]]
Action:  [3]
Map:
 [[[0 0 0 0 0]
  [0 0 0 2 0]
  [0 0 0 2 0]
  [0 0 0 1 0]
  [0 0 3 0 0]]]
Action:  [3]
Map:
 [[[0 0 0 0 0]
  [0 0 0 0 0]
  [0 0 0 2 0]
  [0 0 0 2 0]
  [0 0 3 1 0]]]
Action:  [0]
Map:
 [[[0 0 0 0 0]
  [0 0 0 0 0]
  [0 0 0 2 0]
  [3 0 0 2 0]
  [0 0 1 2 0]]]
Action:  [0]
Map:
 [[[0 0 0 0 0]
  [0 0 0 0 0]
  [0 0 0 0 0]
  [3 0 0 2 0]
  [0 1 2 2 0]]]
Action:  [0]

In [14]:
# Count the average steps per episode

n = 200
total_steps = 0
total_reward = 0
steps = 0
total_500 = 0

for i in range(n):
    if i % 10 == 0:
        print("Episode: ", i)
        
    time_step = env.reset()
    steps = 0
    while not time_step.is_last() and steps < 500:
        steps += 1
        action_step = agent.policy.action(time_step)
        time_step = env.step(action_step.action)
        reward = time_step.reward.numpy()[0]
        total_reward += reward
    
    if steps == 500:
        total_500 += 1
    else:
        total_steps += steps
        
if steps == 500:
    print("500 episodes reached")
    
print("Average steps per episode: ", total_steps / (n - total_500))
print("Average reward per episode: ", total_reward / n)
print("Total 500 episodes: ", total_500)


Episode:  0
Episode:  10
Episode:  20
Episode:  30
Episode:  40
Episode:  50
Episode:  60
Episode:  70
Episode:  80
Episode:  90
Episode:  100
Episode:  110
Episode:  120
Episode:  130
Episode:  140
Episode:  150
Episode:  160
Episode:  170
Episode:  180
Episode:  190
Average steps per episode:  23.31
Average reward per episode:  130.65
Total 500 episodes:  0


In [15]:
# Save the agent
import os
from tf_agents.policies.policy_saver import PolicySaver

saver = PolicySaver(agent.policy, batch_size=None)
saver.save("policy")


TypeError: this __dict__ descriptor does not support '_DictWrapper' objects