Here we will use the TF Agents to train the CartPole environment with DQN.
TF Agents package makes the implementation of RL algo easier.

In [50]:
import tensorflow as tf
from tf_agents.environments import suite_gym

In [51]:
env = suite_gym.load("CartPole-v1")

In [52]:
env

<tf_agents.environments.wrappers.TimeLimit at 0x22073625650>

In [53]:
env.gym

<TimeLimit<OrderEnforcing<CartPoleEnv<CartPole-v1>>>>

In [54]:
env.reset()

TimeStep(
{'step_type': array(0),
 'reward': array(0., dtype=float32),
 'discount': array(1., dtype=float32),
 'observation': array([-0.0098784 ,  0.0001675 ,  0.01735287, -0.04730087], dtype=float32)})

In [55]:
env.step(0)

TimeStep(
{'step_type': array(1),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32),
 'observation': array([-0.00987505, -0.19519892,  0.01640685,  0.25080612], dtype=float32)})

Explore Environment Specification

In [56]:
env.observation_spec()

BoundedArraySpec(shape=(4,), dtype=dtype('float32'), name='observation', minimum=[-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], maximum=[4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38])

In [57]:
env.action_spec()

BoundedArraySpec(shape=(), dtype=dtype('int64'), name='action', minimum=0, maximum=1)

In [58]:
env.time_step_spec()

TimeStep(
{'step_type': ArraySpec(shape=(), dtype=dtype('int32'), name='step_type'),
 'reward': ArraySpec(shape=(), dtype=dtype('float32'), name='reward'),
 'discount': BoundedArraySpec(shape=(), dtype=dtype('float32'), name='discount', minimum=0.0, maximum=1.0),
 'observation': BoundedArraySpec(shape=(4,), dtype=dtype('float32'), name='observation', minimum=[-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], maximum=[4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38])})

In [59]:
env.reward_spec()

ArraySpec(shape=(), dtype=dtype('float32'), name='reward')

In [60]:
env.discount_spec()

BoundedArraySpec(shape=(), dtype=dtype('float32'), name='discount', minimum=0.0, maximum=1.0)

In [61]:
env.current_time_step()

TimeStep(
{'step_type': array(1),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32),
 'observation': array([-0.00987505, -0.19519892,  0.01640685,  0.25080612], dtype=float32)})

Wrap the environment with TFPyEnvironment which supports both py and tf environments.

In [62]:
from tf_agents.environments.tf_py_environment import TFPyEnvironment

In [63]:
env = TFPyEnvironment(env)

In [64]:
env

<tf_agents.environments.tf_py_environment.TFPyEnvironment at 0x2207372f450>

In [65]:
env.reset()

TimeStep(
{'step_type': <tf.Tensor: shape=(1,), dtype=int32, numpy=array([0])>,
 'reward': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([0.], dtype=float32)>,
 'discount': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.], dtype=float32)>,
 'observation': <tf.Tensor: shape=(1, 4), dtype=float32, numpy=
array([[-0.02148625, -0.0419181 , -0.03171623, -0.02292799]],
      dtype=float32)>})

In [66]:
env.step(0)

TimeStep(
{'step_type': <tf.Tensor: shape=(1,), dtype=int32, numpy=array([1])>,
 'reward': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.], dtype=float32)>,
 'discount': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.], dtype=float32)>,
 'observation': <tf.Tensor: shape=(1, 4), dtype=float32, numpy=
array([[-0.02232462, -0.2365712 , -0.0321748 ,  0.2595818 ]],
      dtype=float32)>})

In [67]:
env.observation_spec()

BoundedTensorSpec(shape=(4,), dtype=tf.float32, name='observation', minimum=array([-4.8000002e+00, -3.4028235e+38, -4.1887903e-01, -3.4028235e+38],
      dtype=float32), maximum=array([4.8000002e+00, 3.4028235e+38, 4.1887903e-01, 3.4028235e+38],
      dtype=float32))

In [68]:
env.action_spec()

BoundedTensorSpec(shape=(), dtype=tf.int64, name='action', minimum=array(0, dtype=int64), maximum=array(1, dtype=int64))

In [69]:
env.time_step_spec()

TimeStep(
{'step_type': TensorSpec(shape=(), dtype=tf.int32, name='step_type'),
 'reward': TensorSpec(shape=(), dtype=tf.float32, name='reward'),
 'discount': BoundedTensorSpec(shape=(), dtype=tf.float32, name='discount', minimum=array(0., dtype=float32), maximum=array(1., dtype=float32)),
 'observation': BoundedTensorSpec(shape=(4,), dtype=tf.float32, name='observation', minimum=array([-4.8000002e+00, -3.4028235e+38, -4.1887903e-01, -3.4028235e+38],
      dtype=float32), maximum=array([4.8000002e+00, 3.4028235e+38, 4.1887903e-01, 3.4028235e+38],
      dtype=float32))})

In [70]:
env.reward_spec()

TensorSpec(shape=(), dtype=tf.float32, name='reward')

In [71]:
env.discount_spec()

BoundedArraySpec(shape=(), dtype=dtype('float32'), name='discount', minimum=0.0, maximum=1.0)

In [72]:
env.current_time_step()

TimeStep(
{'step_type': <tf.Tensor: shape=(1,), dtype=int32, numpy=array([1])>,
 'reward': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.], dtype=float32)>,
 'discount': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.], dtype=float32)>,
 'observation': <tf.Tensor: shape=(1, 4), dtype=float32, numpy=
array([[-0.02232462, -0.2365712 , -0.0321748 ,  0.2595818 ]],
      dtype=float32)>})

Create a Deep Q Network

In [73]:
from tf_agents.networks.q_network import QNetwork

In [74]:
help(QNetwork)

Help on class QNetwork in module tf_agents.networks.q_network:

class QNetwork(tf_agents.networks.network.Network)
 |  QNetwork(input_tensor_spec, action_spec, preprocessing_layers=None, preprocessing_combiner=None, conv_layer_params=None, fc_layer_params=(75, 40), dropout_layer_params=None, activation_fn=<function relu at 0x000002206F11D300>, kernel_initializer=None, batch_squash=True, dtype=tf.float32, q_layer_activation_fn=None, name='QNetwork')
 |  
 |  Feed Forward network.
 |  
 |  Method resolution order:
 |      QNetwork
 |      tf_agents.networks.network.Network
 |      keras.src.engine.base_layer.Layer
 |      tensorflow.python.module.module.Module
 |      tensorflow.python.trackable.autotrackable.AutoTrackable
 |      tensorflow.python.trackable.base.Trackable
 |      keras.src.utils.version_utils.LayerVersionSelector
 |      builtins.object
 |  
 |  Methods defined here:
 |  
 |  __init__(self, input_tensor_spec, action_spec, preprocessing_layers=None, preprocessing_combine

In [75]:
q_net = QNetwork(env.observation_spec(), env.action_spec())

Create a DQN agent

In [76]:
from tf_agents.agents.dqn.dqn_agent import DqnAgent
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import MeanSquaredError
from tensorflow.keras.optimizers.schedules import PolynomialDecay

In [77]:
help(DqnAgent)

Help on class DqnAgent in module tf_agents.agents.dqn.dqn_agent:

class DqnAgent(tf_agents.agents.tf_agent.TFAgent)
 |  DqnAgent(time_step_spec: tf_agents.trajectories.time_step.TimeStep, action_spec: Union[tensorflow.python.framework.type_spec.TypeSpec, tensorflow.python.framework.tensor.TensorSpec, tensorflow.python.ops.ragged.ragged_tensor.RaggedTensorSpec, tensorflow.python.framework.sparse_tensor.SparseTensorSpec, ForwardRef('tf_agents.distributions.utils.DistributionSpecV2'), Iterable[ForwardRef('NestedTensorSpec')], Mapping[str, ForwardRef('NestedTensorSpec')]], q_network: tf_agents.networks.network.Network, optimizer: Union[keras.src.optimizers.optimizer.Optimizer, tensorflow.python.training.optimizer.Optimizer], observation_and_action_constraint_splitter: Optional[Callable[[Union[tensorflow.python.framework.type_spec.TypeSpec, tensorflow.python.framework.tensor.TensorSpec, tensorflow.python.ops.ragged.ragged_tensor.RaggedTensorSpec, tensorflow.python.framework.sparse_tensor.Sp

In [78]:
optimizer = Adam(learning_rate=0.001)
loss = MeanSquaredError('none', 'mean_squared_error')
loss_fn = loss.call
discount_factor = 0.95
epsilon_fn = PolynomialDecay(
    initial_learning_rate=1.0,
    decay_steps=1600,
    end_learning_rate=0.0,
    power=1
    )
target_model_update = 50
train_step = tf.Variable(0)

agent = DqnAgent(
    env.time_step_spec(),
    env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    gamma=discount_factor,
    td_errors_loss_fn=loss_fn,
    epsilon_greedy=lambda:epsilon_fn(train_step),
    train_step_counter=train_step,
    target_update_period=target_model_update,
)

In [79]:
env.time_step_spec()

TimeStep(
{'step_type': TensorSpec(shape=(), dtype=tf.int32, name='step_type'),
 'reward': TensorSpec(shape=(), dtype=tf.float32, name='reward'),
 'discount': BoundedTensorSpec(shape=(), dtype=tf.float32, name='discount', minimum=array(0., dtype=float32), maximum=array(1., dtype=float32)),
 'observation': BoundedTensorSpec(shape=(4,), dtype=tf.float32, name='observation', minimum=array([-4.8000002e+00, -3.4028235e+38, -4.1887903e-01, -3.4028235e+38],
      dtype=float32), maximum=array([4.8000002e+00, 3.4028235e+38, 4.1887903e-01, 3.4028235e+38],
      dtype=float32))})

In [80]:
agent.initialize()

Create a Replay Buffer to store experiences

In [81]:
from tf_agents.replay_buffers.tf_uniform_replay_buffer import TFUniformReplayBuffer

In [82]:
agent.collect_data_spec

Trajectory(
{'step_type': TensorSpec(shape=(), dtype=tf.int32, name='step_type'),
 'observation': BoundedTensorSpec(shape=(4,), dtype=tf.float32, name='observation', minimum=array([-4.8000002e+00, -3.4028235e+38, -4.1887903e-01, -3.4028235e+38],
      dtype=float32), maximum=array([4.8000002e+00, 3.4028235e+38, 4.1887903e-01, 3.4028235e+38],
      dtype=float32)),
 'action': BoundedTensorSpec(shape=(), dtype=tf.int64, name='action', minimum=array(0, dtype=int64), maximum=array(1, dtype=int64)),
 'policy_info': (),
 'next_step_type': TensorSpec(shape=(), dtype=tf.int32, name='step_type'),
 'reward': TensorSpec(shape=(), dtype=tf.float32, name='reward'),
 'discount': BoundedTensorSpec(shape=(), dtype=tf.float32, name='discount', minimum=array(0., dtype=float32), maximum=array(1., dtype=float32))})

In [83]:
replay_buffer = TFUniformReplayBuffer(
    data_spec= agent.collect_data_spec,
    batch_size= env.batch_size,
    max_length= 10000
)

Create an observer to write into the replay buffer

In [84]:
replay_buffer_observer = replay_buffer.add_batch

In [85]:
replay_buffer

<tf_agents.replay_buffers.tf_uniform_replay_buffer.TFUniformReplayBuffer at 0x22073767e50>

In [86]:
replay_buffer_observer

<bound method ReplayBuffer.add_batch of <tf_agents.replay_buffers.tf_uniform_replay_buffer.TFUniformReplayBuffer object at 0x0000022073767E50>>

Create a Driver that explores environment using a given policy, collects experience and broadcast them to observer.

In [87]:
from tf_agents.drivers.dynamic_step_driver import DynamicStepDriver

In [88]:
agent.policy

<tf_agents.policies.greedy_policy.GreedyPolicy at 0x2207372d350>

In [89]:
collect_driver = DynamicStepDriver(
    env=env,
    policy= agent.collect_policy,
    observers= [replay_buffer_observer],
    num_steps=200
)

Create a driver to just fill the replay buffer with some experiences with a random policy

In [90]:
from tf_agents.policies.random_tf_policy import RandomTFPolicy

In [91]:
initial_collect_policy = RandomTFPolicy(env.time_step_spec(), env.action_spec())

In [92]:
initial_collect_policy

<tf_agents.policies.random_tf_policy.RandomTFPolicy at 0x220735d0750>

In [93]:
initial_driver = DynamicStepDriver(
    env,
    initial_collect_policy,
    [replay_buffer_observer],
    num_steps=1000
)

In [94]:
final_time_step, final_policy_state = initial_driver.run()

In [95]:
final_time_step

TimeStep(
{'step_type': <tf.Tensor: shape=(1,), dtype=int32, numpy=array([2])>,
 'reward': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.], dtype=float32)>,
 'discount': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([0.], dtype=float32)>,
 'observation': <tf.Tensor: shape=(1, 4), dtype=float32, numpy=
array([[ 0.09330807,  0.6483104 , -0.22590569, -1.6200706 ]],
      dtype=float32)>})

In [96]:
final_policy_state

()

Create a Dataset of sample a batch of trajectories for agent to train.

In [97]:
from tf_agents.trajectories.trajectory import to_transition

In [98]:
trajectories, buffer_info = replay_buffer.get_next(sample_batch_size=2, num_steps=3)

Instructions for updating:
Use `as_dataset(..., single_deterministic_pass=False) instead.


In [99]:
trajectories

Trajectory(
{'step_type': <tf.Tensor: shape=(2, 3), dtype=int32, numpy=
array([[1, 1, 1],
       [1, 1, 1]])>,
 'observation': <tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy=
array([[[-0.01349626, -0.03294312,  0.01875259,  0.08401202],
        [-0.01415512, -0.2283288 ,  0.02043283,  0.38255194],
        [-0.0187217 , -0.4237348 ,  0.02808387,  0.6816066 ]],

       [[ 0.04898272,  0.23138125, -0.05394045, -0.4107656 ],
        [ 0.05361035,  0.03706384, -0.06215576, -0.1355642 ],
        [ 0.05435162,  0.23301847, -0.06486705, -0.44719058]]],
      dtype=float32)>,
 'action': <tf.Tensor: shape=(2, 3), dtype=int64, numpy=
array([[0, 0, 1],
       [0, 1, 1]], dtype=int64)>,
 'policy_info': (),
 'next_step_type': <tf.Tensor: shape=(2, 3), dtype=int32, numpy=
array([[1, 1, 1],
       [1, 1, 1]])>,
 'reward': <tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[1., 1., 1.],
       [1., 1., 1.]], dtype=float32)>,
 'discount': <tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[1.

In [100]:
buffer_info

BufferInfo(ids=<tf.Tensor: shape=(2, 3), dtype=int64, numpy=
array([[  46,   47,   48],
       [1025, 1026, 1027]], dtype=int64)>, probabilities=<tf.Tensor: shape=(2,), dtype=float32, numpy=array([0.00096061, 0.00096061], dtype=float32)>)

In [101]:
time_steps, action_steps, next_time_steps = to_transition(trajectories)

In [102]:
time_steps

TimeStep(
{'step_type': <tf.Tensor: shape=(2, 2), dtype=int32, numpy=
array([[1, 1],
       [1, 1]])>,
 'reward': <tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[0., 0.],
       [0., 0.]], dtype=float32)>,
 'discount': <tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[0., 0.],
       [0., 0.]], dtype=float32)>,
 'observation': <tf.Tensor: shape=(2, 2, 4), dtype=float32, numpy=
array([[[-0.01349626, -0.03294312,  0.01875259,  0.08401202],
        [-0.01415512, -0.2283288 ,  0.02043283,  0.38255194]],

       [[ 0.04898272,  0.23138125, -0.05394045, -0.4107656 ],
        [ 0.05361035,  0.03706384, -0.06215576, -0.1355642 ]]],
      dtype=float32)>})

In [103]:
action_steps

PolicyStep(action=<tf.Tensor: shape=(2, 2), dtype=int64, numpy=
array([[0, 0],
       [0, 1]], dtype=int64)>, state=(), info=())

In [104]:
dataset = replay_buffer.as_dataset(
    sample_batch_size=128,
    num_steps=2,
)

In [105]:
dataset

<_MapDataset element_spec=(Trajectory(
{'step_type': TensorSpec(shape=(128, 2), dtype=tf.int32, name=None),
 'observation': TensorSpec(shape=(128, 2, 4), dtype=tf.float32, name=None),
 'action': TensorSpec(shape=(128, 2), dtype=tf.int64, name=None),
 'policy_info': (),
 'next_step_type': TensorSpec(shape=(128, 2), dtype=tf.int32, name=None),
 'reward': TensorSpec(shape=(128, 2), dtype=tf.float32, name=None),
 'discount': TensorSpec(shape=(128, 2), dtype=tf.float32, name=None)}), BufferInfo(ids=TensorSpec(shape=(128, 2), dtype=tf.int64, name=None), probabilities=TensorSpec(shape=(128,), dtype=tf.float32, name=None)))>

In [106]:
it = iter(dataset)

In [107]:
it

<tensorflow.python.data.ops.iterator_ops.OwnedIterator at 0x22073627c10>

In [108]:
next(it)

(Trajectory(
 {'step_type': <tf.Tensor: shape=(128, 2), dtype=int32, numpy=
 array([[1, 1],
        [0, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [0, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 2],
        [1, 1],
        [1, 2],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [2, 0],
        [1, 2],
        [1, 1],
        [1, 1],
        [0, 1],
        [1, 1],
        [1, 1],
        [0, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 

In [109]:
next(it)[0]

Trajectory(
{'step_type': <tf.Tensor: shape=(128, 2), dtype=int32, numpy=
array([[1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [0, 1],
       [1, 2],
       [1, 1],
       [0, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 2],
       [1, 2],
       [1, 1],
       [1, 1],
       [1, 1],
       [0, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [2, 0],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [2, 

Create a training loop

In [110]:
from tf_agents.utils.common import function

In [111]:
collect_driver.run

<bound method DynamicStepDriver.run of <tf_agents.drivers.dynamic_step_driver.DynamicStepDriver object at 0x0000022073364B50>>

In [112]:
agent.train

<bound method TFAgent.train of <tf_agents.agents.dqn.dqn_agent.DqnAgent object at 0x00000220735E8F10>>

In [113]:
collect_driver.run = function(collect_driver.run)
agent.train = function(agent.train)

In [114]:
collect_driver.run

<tensorflow.python.eager.polymorphic_function.polymorphic_function.Function at 0x220738525d0>

In [115]:
agent.train

<tensorflow.python.eager.polymorphic_function.polymorphic_function.Function at 0x22074c90790>

In [116]:
agent.collect_policy.get_initial_state(env.batch_size)

()

In [117]:
collect_driver.run(None, ())

(TimeStep(
 {'step_type': <tf.Tensor: shape=(1,), dtype=int32, numpy=array([1])>,
  'reward': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.], dtype=float32)>,
  'discount': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.], dtype=float32)>,
  'observation': <tf.Tensor: shape=(1, 4), dtype=float32, numpy=
 array([[ 0.06335325,  0.544458  , -0.0807394 , -0.9116811 ]],
       dtype=float32)>}),
 ())

In [118]:
ts, ps = collect_driver.run(None,())

In [119]:
trajectories, buffer_info = next(it)

In [120]:
trajectories

Trajectory(
{'step_type': <tf.Tensor: shape=(128, 2), dtype=int32, numpy=
array([[1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [2, 0],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [2, 0],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [2, 0],
       [1, 1],
       [0, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 1],
       [2, 0],
       [1, 1],
       [1, 1],
       [1, 1],
       [1, 2],
       [1, 1],
       [1, 1],
       [1, 

In [121]:
agent.training_data_spec

Trajectory(
{'step_type': TensorSpec(shape=(), dtype=tf.int32, name='step_type'),
 'observation': BoundedTensorSpec(shape=(4,), dtype=tf.float32, name='observation', minimum=array([-4.8000002e+00, -3.4028235e+38, -4.1887903e-01, -3.4028235e+38],
      dtype=float32), maximum=array([4.8000002e+00, 3.4028235e+38, 4.1887903e-01, 3.4028235e+38],
      dtype=float32)),
 'action': BoundedTensorSpec(shape=(), dtype=tf.int64, name='action', minimum=array(0, dtype=int64), maximum=array(1, dtype=int64)),
 'policy_info': (),
 'next_step_type': TensorSpec(shape=(), dtype=tf.int32, name='step_type'),
 'reward': TensorSpec(shape=(), dtype=tf.float32, name='reward'),
 'discount': BoundedTensorSpec(shape=(), dtype=tf.float32, name='discount', minimum=array(0., dtype=float32), maximum=array(1., dtype=float32))})

In [122]:
agent.train_sequence_length

2

In [123]:
agent.train(trajectories)

Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldr(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))


LossInfo(loss=<tf.Tensor: shape=(), dtype=float32, numpy=0.9125452>, extra=DqnLossInfo(td_loss=<tf.Tensor: shape=(128,), dtype=float32, numpy=
array([0.9815613, 0.9815613, 0.9815613, 0.9815613, 0.9815613, 0.9815613,
       0.9815613, 0.       , 0.9815613, 0.9815613, 0.9815613, 0.9815613,
       0.9815613, 0.9815613, 0.9815613, 0.9815613, 0.9815613, 0.9815613,
       0.9815613, 0.9815613, 0.9815613, 0.9815613, 0.9815613, 0.9815613,
       0.9815613, 0.9815613, 0.9815613, 0.9815613, 0.9815613, 0.9815613,
       0.9815613, 0.9815613, 0.9815613, 0.9815613, 0.9815613, 0.       ,
       0.9815613, 0.9815613, 0.9815613, 0.9815613, 0.9815613, 0.9815613,
       0.9815613, 0.9815613, 0.9815613, 0.9815613, 0.9815613, 0.       ,
       0.9815613, 0.9815613, 0.9815613, 0.9815613, 0.9815613, 0.9815613,
       0.       , 0.9815613, 0.9815613, 0.9815613, 0.9815613, 0.9815613,
       0.9815613, 0.9815613, 0.9815613, 0.9815613, 0.       , 0.9815613,
       0.9815613, 0.9815613, 0.9815613, 0.       , 0.9

In [131]:
def train_agent(n_iteration):
    time_step = None
    policy_state = agent.collect_policy.get_initial_state(env.batch_size)
    iterator = iter(dataset)

    for iteration in range(n_iteration):
        time_step, policy_state = collect_driver.run(time_step, policy_state)
        trajectories, buffer_info = next(iterator)
        train_loss = agent.train(trajectories)
        print(train_loss.loss.numpy())

In [132]:
train_agent(5)

0.61250025
0.56943023
0.5018613
0.5089667
0.46563017
