Here we will use the TF Agents to train the CartPole environment with DQN.
TF Agents package makes the implementation of RL algo easier.

In [1328]:
import tensorflow as tf
from tf_agents.environments import suite_gym

In [1329]:
env = suite_gym.load("CartPole-v1")

In [1330]:
env

<tf_agents.environments.wrappers.TimeLimit at 0x1e6c0659c10>

In [1331]:
env.gym

<TimeLimit<OrderEnforcing<CartPoleEnv<CartPole-v1>>>>

In [1332]:
env.reset()

TimeStep(
{'step_type': array(0),
 'reward': array(0., dtype=float32),
 'discount': array(1., dtype=float32),
 'observation': array([ 0.01262178,  0.04025272, -0.01962253,  0.00919425], dtype=float32)})

In [1333]:
env.step(0)

TimeStep(
{'step_type': array(1),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32),
 'observation': array([ 0.01342683, -0.15458241, -0.01943865,  0.29562202], dtype=float32)})

Explore Environment Specification

In [1334]:
env.observation_spec()

BoundedArraySpec(shape=(4,), dtype=dtype('float32'), name='observation', minimum=[-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], maximum=[4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38])

In [1335]:
env.action_spec()

BoundedArraySpec(shape=(), dtype=dtype('int64'), name='action', minimum=0, maximum=1)

In [1336]:
env.time_step_spec()

TimeStep(
{'step_type': ArraySpec(shape=(), dtype=dtype('int32'), name='step_type'),
 'reward': ArraySpec(shape=(), dtype=dtype('float32'), name='reward'),
 'discount': BoundedArraySpec(shape=(), dtype=dtype('float32'), name='discount', minimum=0.0, maximum=1.0),
 'observation': BoundedArraySpec(shape=(4,), dtype=dtype('float32'), name='observation', minimum=[-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], maximum=[4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38])})

In [1337]:
env.reward_spec()

ArraySpec(shape=(), dtype=dtype('float32'), name='reward')

In [1338]:
env.discount_spec()

BoundedArraySpec(shape=(), dtype=dtype('float32'), name='discount', minimum=0.0, maximum=1.0)

In [1339]:
env.current_time_step()

TimeStep(
{'step_type': array(1),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32),
 'observation': array([ 0.01342683, -0.15458241, -0.01943865,  0.29562202], dtype=float32)})

Wrap the environment with TFPyEnvironment which supports both py and tf environments.

In [1340]:
from tf_agents.environments.tf_py_environment import TFPyEnvironment

In [1341]:
env = TFPyEnvironment(env)

In [1342]:
env

<tf_agents.environments.tf_py_environment.TFPyEnvironment at 0x1e6c22a89d0>

In [1343]:
env.reset()

TimeStep(
{'step_type': <tf.Tensor: shape=(1,), dtype=int32, numpy=array([0])>,
 'reward': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([0.], dtype=float32)>,
 'discount': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.], dtype=float32)>,
 'observation': <tf.Tensor: shape=(1, 4), dtype=float32, numpy=
array([[-0.04740351, -0.01980463,  0.03243716,  0.01009236]],
      dtype=float32)>})

In [1344]:
env.step(0)

TimeStep(
{'step_type': <tf.Tensor: shape=(1,), dtype=int32, numpy=array([1])>,
 'reward': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.], dtype=float32)>,
 'discount': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.], dtype=float32)>,
 'observation': <tf.Tensor: shape=(1, 4), dtype=float32, numpy=
array([[-0.0477996 , -0.21537639,  0.03263901,  0.31283054]],
      dtype=float32)>})

In [1345]:
env.observation_spec()

BoundedTensorSpec(shape=(4,), dtype=tf.float32, name='observation', minimum=array([-4.8000002e+00, -3.4028235e+38, -4.1887903e-01, -3.4028235e+38],
      dtype=float32), maximum=array([4.8000002e+00, 3.4028235e+38, 4.1887903e-01, 3.4028235e+38],
      dtype=float32))

In [1346]:
env.action_spec()

BoundedTensorSpec(shape=(), dtype=tf.int64, name='action', minimum=array(0, dtype=int64), maximum=array(1, dtype=int64))

In [1347]:
env.time_step_spec()

TimeStep(
{'step_type': TensorSpec(shape=(), dtype=tf.int32, name='step_type'),
 'reward': TensorSpec(shape=(), dtype=tf.float32, name='reward'),
 'discount': BoundedTensorSpec(shape=(), dtype=tf.float32, name='discount', minimum=array(0., dtype=float32), maximum=array(1., dtype=float32)),
 'observation': BoundedTensorSpec(shape=(4,), dtype=tf.float32, name='observation', minimum=array([-4.8000002e+00, -3.4028235e+38, -4.1887903e-01, -3.4028235e+38],
      dtype=float32), maximum=array([4.8000002e+00, 3.4028235e+38, 4.1887903e-01, 3.4028235e+38],
      dtype=float32))})

In [1348]:
env.reward_spec()

TensorSpec(shape=(), dtype=tf.float32, name='reward')

In [1349]:
env.discount_spec()

BoundedArraySpec(shape=(), dtype=dtype('float32'), name='discount', minimum=0.0, maximum=1.0)

In [1350]:
env.current_time_step()

TimeStep(
{'step_type': <tf.Tensor: shape=(1,), dtype=int32, numpy=array([1])>,
 'reward': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.], dtype=float32)>,
 'discount': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.], dtype=float32)>,
 'observation': <tf.Tensor: shape=(1, 4), dtype=float32, numpy=
array([[-0.0477996 , -0.21537639,  0.03263901,  0.31283054]],
      dtype=float32)>})

Create a Deep Q Network

In [1351]:
from tf_agents.networks.q_network import QNetwork

In [1352]:
help(QNetwork)

Help on class QNetwork in module tf_agents.networks.q_network:

class QNetwork(tf_agents.networks.network.Network)
 |  QNetwork(input_tensor_spec, action_spec, preprocessing_layers=None, preprocessing_combiner=None, conv_layer_params=None, fc_layer_params=(75, 40), dropout_layer_params=None, activation_fn=<function relu at 0x000001E6AA0D1260>, kernel_initializer=None, batch_squash=True, dtype=tf.float32, q_layer_activation_fn=None, name='QNetwork')
 |  
 |  Feed Forward network.
 |  
 |  Method resolution order:
 |      QNetwork
 |      tf_agents.networks.network.Network
 |      keras.src.engine.base_layer.Layer
 |      tensorflow.python.module.module.Module
 |      tensorflow.python.trackable.autotrackable.AutoTrackable
 |      tensorflow.python.trackable.base.Trackable
 |      keras.src.utils.version_utils.LayerVersionSelector
 |      builtins.object
 |  
 |  Methods defined here:
 |  
 |  __init__(self, input_tensor_spec, action_spec, preprocessing_layers=None, preprocessing_combine

In [1353]:
q_net = QNetwork(env.observation_spec(), env.action_spec())

Create a DQN agent

In [1354]:
from tf_agents.agents.dqn.dqn_agent import DqnAgent
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import MeanSquaredError
from tensorflow.keras.optimizers.schedules import PolynomialDecay

In [1355]:
help(DqnAgent)

Help on class DqnAgent in module tf_agents.agents.dqn.dqn_agent:

class DqnAgent(tf_agents.agents.tf_agent.TFAgent)
 |  DqnAgent(time_step_spec: tf_agents.trajectories.time_step.TimeStep, action_spec: Union[tensorflow.python.framework.type_spec.TypeSpec, tensorflow.python.framework.tensor.TensorSpec, tensorflow.python.ops.ragged.ragged_tensor.RaggedTensorSpec, tensorflow.python.framework.sparse_tensor.SparseTensorSpec, ForwardRef('tf_agents.distributions.utils.DistributionSpecV2'), Iterable[ForwardRef('NestedTensorSpec')], Mapping[str, ForwardRef('NestedTensorSpec')]], q_network: tf_agents.networks.network.Network, optimizer: Union[keras.src.optimizers.optimizer.Optimizer, tensorflow.python.training.optimizer.Optimizer], observation_and_action_constraint_splitter: Optional[Callable[[Union[tensorflow.python.framework.type_spec.TypeSpec, tensorflow.python.framework.tensor.TensorSpec, tensorflow.python.ops.ragged.ragged_tensor.RaggedTensorSpec, tensorflow.python.framework.sparse_tensor.Sp

In [1356]:
optimizer = Adam(learning_rate=0.001)
loss = MeanSquaredError('none', 'mean_squared_error')
loss_fn = loss.call
discount_factor = 0.95
epsilon_fn = PolynomialDecay(
    initial_learning_rate=1.0,
    decay_steps=1600,
    end_learning_rate=0.0,
    power=1
    )
target_model_update = 50
train_step = tf.Variable(0)

agent = DqnAgent(
    env.time_step_spec(),
    env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    gamma=discount_factor,
    td_errors_loss_fn=loss_fn,
    epsilon_greedy=lambda:epsilon_fn(train_step),
    train_step_counter=train_step,
    target_update_period=target_model_update,
)

In [1357]:
env.time_step_spec()

TimeStep(
{'step_type': TensorSpec(shape=(), dtype=tf.int32, name='step_type'),
 'reward': TensorSpec(shape=(), dtype=tf.float32, name='reward'),
 'discount': BoundedTensorSpec(shape=(), dtype=tf.float32, name='discount', minimum=array(0., dtype=float32), maximum=array(1., dtype=float32)),
 'observation': BoundedTensorSpec(shape=(4,), dtype=tf.float32, name='observation', minimum=array([-4.8000002e+00, -3.4028235e+38, -4.1887903e-01, -3.4028235e+38],
      dtype=float32), maximum=array([4.8000002e+00, 3.4028235e+38, 4.1887903e-01, 3.4028235e+38],
      dtype=float32))})

In [1358]:
agent.initialize()

Create a Replay Buffer to store experiences

In [1359]:
from tf_agents.replay_buffers.tf_uniform_replay_buffer import TFUniformReplayBuffer

In [1360]:
agent.collect_data_spec

Trajectory(
{'step_type': TensorSpec(shape=(), dtype=tf.int32, name='step_type'),
 'observation': BoundedTensorSpec(shape=(4,), dtype=tf.float32, name='observation', minimum=array([-4.8000002e+00, -3.4028235e+38, -4.1887903e-01, -3.4028235e+38],
      dtype=float32), maximum=array([4.8000002e+00, 3.4028235e+38, 4.1887903e-01, 3.4028235e+38],
      dtype=float32)),
 'action': BoundedTensorSpec(shape=(), dtype=tf.int64, name='action', minimum=array(0, dtype=int64), maximum=array(1, dtype=int64)),
 'policy_info': (),
 'next_step_type': TensorSpec(shape=(), dtype=tf.int32, name='step_type'),
 'reward': TensorSpec(shape=(), dtype=tf.float32, name='reward'),
 'discount': BoundedTensorSpec(shape=(), dtype=tf.float32, name='discount', minimum=array(0., dtype=float32), maximum=array(1., dtype=float32))})

In [1361]:
replay_buffer = TFUniformReplayBuffer(
    data_spec= agent.collect_data_spec,
    batch_size= env.batch_size,
    max_length= 10000
)

Create an observer to write into the replay buffer

In [1362]:
replay_buffer_observer = replay_buffer.add_batch

In [1363]:
replay_buffer

<tf_agents.replay_buffers.tf_uniform_replay_buffer.TFUniformReplayBuffer at 0x1e6c1c5d590>

In [1364]:
replay_buffer_observer

<bound method ReplayBuffer.add_batch of <tf_agents.replay_buffers.tf_uniform_replay_buffer.TFUniformReplayBuffer object at 0x000001E6C1C5D590>>

Create a Driver that explores environment using a given policy, collects experience and broadcast them to observer.

In [1365]:
from tf_agents.drivers.dynamic_step_driver import DynamicStepDriver

In [1366]:
agent.policy

<tf_agents.policies.greedy_policy.GreedyPolicy at 0x1e6c2280390>

In [1367]:
collect_driver = DynamicStepDriver(
    env=env,
    policy= agent.collect_policy,
    observers= [replay_buffer_observer],
    num_steps=200
)

Create a driver to just fill the replay buffer with some experiences with a random policy

In [1368]:
from tf_agents.policies.random_tf_policy import RandomTFPolicy

In [1369]:
initial_collect_policy = RandomTFPolicy(env.time_step_spec(), env.action_spec())

In [1370]:
initial_collect_policy

<tf_agents.policies.random_tf_policy.RandomTFPolicy at 0x1e6c1c73510>

In [1371]:
initial_driver = DynamicStepDriver(
    env,
    initial_collect_policy,
    [replay_buffer_observer],
    num_steps=1000
)

In [1372]:
final_time_step, final_policy_state = initial_driver.run()

In [1373]:
final_time_step

TimeStep(
{'step_type': <tf.Tensor: shape=(1,), dtype=int32, numpy=array([1])>,
 'reward': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.], dtype=float32)>,
 'discount': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.], dtype=float32)>,
 'observation': <tf.Tensor: shape=(1, 4), dtype=float32, numpy=
array([[ 0.05271898,  0.20753057, -0.17364648, -0.65162116]],
      dtype=float32)>})

In [1374]:
final_policy_state

()

Create a Dataset of sample a batch of trajectories for agent to train.

In [1375]:
from tf_agents.trajectories.trajectory import to_transition

In [1376]:
trajectories, buffer_info = replay_buffer.get_next(sample_batch_size=2, num_steps=3)

In [1377]:
trajectories

Trajectory(
{'step_type': <tf.Tensor: shape=(2, 3), dtype=int32, numpy=
array([[1, 1, 1],
       [1, 1, 1]])>,
 'observation': <tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy=
array([[[-0.09331799, -0.4235372 ,  0.04876883,  0.3151456 ],
        [-0.10178874, -0.22914264,  0.05507175,  0.03823283],
        [-0.10637159, -0.03485189,  0.0558364 , -0.23657854]],

       [[ 0.01602175,  0.17876667,  0.00213026, -0.19927166],
        [ 0.01959709,  0.3738581 , -0.00185518, -0.49128184],
        [ 0.02707425,  0.56900615, -0.01168081, -0.7845489 ]]],
      dtype=float32)>,
 'action': <tf.Tensor: shape=(2, 3), dtype=int64, numpy=
array([[1, 1, 1],
       [1, 1, 0]], dtype=int64)>,
 'policy_info': (),
 'next_step_type': <tf.Tensor: shape=(2, 3), dtype=int32, numpy=
array([[1, 1, 1],
       [1, 1, 1]])>,
 'reward': <tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[1., 1., 1.],
       [1., 1., 1.]], dtype=float32)>,
 'discount': <tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[1.

In [1378]:
buffer_info

BufferInfo(ids=<tf.Tensor: shape=(2, 3), dtype=int64, numpy=
array([[148, 149, 150],
       [272, 273, 274]], dtype=int64)>, probabilities=<tf.Tensor: shape=(2,), dtype=float32, numpy=array([0.00096061, 0.00096061], dtype=float32)>)

In [1379]:
time_steps, action_steps, next_time_steps = to_transition(trajectories)

In [1380]:
time_steps

TimeStep(
{'step_type': <tf.Tensor: shape=(2, 2), dtype=int32, numpy=
array([[1, 1],
       [1, 1]])>,
 'reward': <tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[0., 0.],
       [0., 0.]], dtype=float32)>,
 'discount': <tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[0., 0.],
       [0., 0.]], dtype=float32)>,
 'observation': <tf.Tensor: shape=(2, 2, 4), dtype=float32, numpy=
array([[[-0.09331799, -0.4235372 ,  0.04876883,  0.3151456 ],
        [-0.10178874, -0.22914264,  0.05507175,  0.03823283]],

       [[ 0.01602175,  0.17876667,  0.00213026, -0.19927166],
        [ 0.01959709,  0.3738581 , -0.00185518, -0.49128184]]],
      dtype=float32)>})

In [1381]:
action_steps

PolicyStep(action=<tf.Tensor: shape=(2, 2), dtype=int64, numpy=
array([[1, 1],
       [1, 1]], dtype=int64)>, state=(), info=())

In [1382]:
dataset = replay_buffer.as_dataset(
    sample_batch_size=128,
    num_steps=2,
)

In [1383]:
dataset

<_MapDataset element_spec=(Trajectory(
{'step_type': TensorSpec(shape=(128, 2), dtype=tf.int32, name=None),
 'observation': TensorSpec(shape=(128, 2, 4), dtype=tf.float32, name=None),
 'action': TensorSpec(shape=(128, 2), dtype=tf.int64, name=None),
 'policy_info': (),
 'next_step_type': TensorSpec(shape=(128, 2), dtype=tf.int32, name=None),
 'reward': TensorSpec(shape=(128, 2), dtype=tf.float32, name=None),
 'discount': TensorSpec(shape=(128, 2), dtype=tf.float32, name=None)}), BufferInfo(ids=TensorSpec(shape=(128, 2), dtype=tf.int64, name=None), probabilities=TensorSpec(shape=(128,), dtype=tf.float32, name=None)))>

In [1384]:
it = iter(dataset)

In [1385]:
it

<tensorflow.python.data.ops.iterator_ops.OwnedIterator at 0x1e6b02ea410>

In [1386]:
next(it)

(Trajectory(
 {'step_type': <tf.Tensor: shape=(128, 2), dtype=int32, numpy=
 array([[1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 2],
        [1, 1],
        [1, 1],
        [1, 2],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 2],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 

In [1387]:
next(it)[0]

Trajectory(
{'step_type': <tf.Tensor: shape=(128, 2), dtype=int32, numpy=
array([[1, 1],
       [0, 1],
       [1, 2],
       [1, 1],
       [1, 1],
       [2, 0],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 2],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [2, 0],
       [1, 1],
       [1, 1],
       [1, 1],
       [0, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 2],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 2],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [0, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [2, 0],
       [1, 1],
       [1, 1],
       [1, 

Create a training loop

In [1388]:
from tf_agents.utils.common import function

In [1389]:
collect_driver.run

<bound method DynamicStepDriver.run of <tf_agents.drivers.dynamic_step_driver.DynamicStepDriver object at 0x000001E6C32CFF90>>

In [1390]:
agent.train

<bound method TFAgent.train of <tf_agents.agents.dqn.dqn_agent.DqnAgent object at 0x000001E6C1F69E50>>

In [1391]:
collect_driver.run = function(collect_driver.run)
agent.train = function(agent.train)

In [1392]:
collect_driver.run

<tensorflow.python.eager.polymorphic_function.polymorphic_function.Function at 0x1e6bdae5f90>

In [1393]:
agent.train

<tensorflow.python.eager.polymorphic_function.polymorphic_function.Function at 0x1e6c33148d0>

In [1394]:
agent.collect_policy.get_initial_state(env.batch_size)

()

In [1395]:
collect_driver.run(None, ())

(TimeStep(
 {'step_type': <tf.Tensor: shape=(1,), dtype=int32, numpy=array([1])>,
  'reward': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.], dtype=float32)>,
  'discount': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.], dtype=float32)>,
  'observation': <tf.Tensor: shape=(1, 4), dtype=float32, numpy=
 array([[-0.11005878, -0.7993053 ,  0.09156413,  1.2002417 ]],
       dtype=float32)>}),
 ())

In [1396]:
trajectories, buffer_info = next(it)

In [1397]:
trajectories

Trajectory(
{'step_type': <tf.Tensor: shape=(128, 2), dtype=int32, numpy=
array([[0, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [0, 1],
       [2, 0],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [0, 1],
       [1, 2],
       [1, 2],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [0, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 2],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 

In [1398]:
agent.training_data_spec

Trajectory(
{'step_type': TensorSpec(shape=(), dtype=tf.int32, name='step_type'),
 'observation': BoundedTensorSpec(shape=(4,), dtype=tf.float32, name='observation', minimum=array([-4.8000002e+00, -3.4028235e+38, -4.1887903e-01, -3.4028235e+38],
      dtype=float32), maximum=array([4.8000002e+00, 3.4028235e+38, 4.1887903e-01, 3.4028235e+38],
      dtype=float32)),
 'action': BoundedTensorSpec(shape=(), dtype=tf.int64, name='action', minimum=array(0, dtype=int64), maximum=array(1, dtype=int64)),
 'policy_info': (),
 'next_step_type': TensorSpec(shape=(), dtype=tf.int32, name='step_type'),
 'reward': TensorSpec(shape=(), dtype=tf.float32, name='reward'),
 'discount': BoundedTensorSpec(shape=(), dtype=tf.float32, name='discount', minimum=array(0., dtype=float32), maximum=array(1., dtype=float32))})

In [1399]:
agent.train_sequence_length

2

In [1400]:
agent.train(trajectories)



LossInfo(loss=<tf.Tensor: shape=(), dtype=float32, numpy=0.97520185>, extra=DqnLossInfo(td_loss=<tf.Tensor: shape=(128,), dtype=float32, numpy=
array([1.0148442, 1.0148442, 1.0148442, 1.0148442, 1.0148442, 1.0148442,
       1.0148442, 1.0148442, 1.0148442, 1.0148442, 1.0148442, 1.0148442,
       1.0148442, 1.0148442, 1.0148442, 1.0148442, 1.0148442, 1.0148442,
       1.0148442, 1.0148442, 1.0148442, 1.0148442, 0.       , 1.0148442,
       1.0148442, 1.0148442, 1.0148442, 1.0148442, 1.0148442, 1.0148442,
       1.0148442, 1.0148442, 1.0148442, 1.0148442, 1.0148442, 1.0148442,
       1.0148442, 1.0148442, 1.0148442, 1.0148442, 1.0148442, 1.0148442,
       1.0148442, 1.0148442, 1.0148442, 1.0148442, 1.0148442, 1.0148442,
       1.0148442, 1.0148442, 1.0148442, 1.0148442, 1.0148442, 1.0148442,
       1.0148442, 1.0148442, 1.0148442, 1.0148442, 1.0148442, 1.0148442,
       1.0148442, 1.0148442, 1.0148442, 1.0148442, 1.0148442, 1.0148442,
       1.0148442, 1.0148442, 1.0148442, 1.0148442, 1.