# 4.7 Design Deep Reinforcement Learning Agent

In this notebook we will consider an alternative question. Instead of asking how we can maximize profit on a single asset, we will ask whether a machine can pick the best assets given many selections. 

We will consider only price history data. We will provide 90 differenced timesteps on the minute interval period. We will choose the 500 stocks of the S&P 500. 

If our network can outperform the S&P 500 over the given time, we will consider it successful. 

In order to do this, we will need to perform the following steps:

- download datasets for all the stocks in the S&P 500. 
- format the data to represent the simulataneous movement of 500 stocks
- Build an environment to represent this movement
- Train a DQN to learn on it

#### Download datasets for all the stocks in the S&P 500.

In [2]:
from IPython import display
!pip install gcsfs
!pip install tf-agents[reverb]
display.clear_output()

In [3]:
from google.colab import drive
from google.cloud import storage
import pandas as pd
import time
import pandas as pd 
import matplotlib.pyplot as plt
from __future__ import absolute_import, division, print_function

import abc
import tensorflow as tf
import numpy as np
import base64
import IPython
import matplotlib.pyplot as plt
import numpy as np
import os
import tempfile

from tf_agents.environments import py_environment
from tf_agents.environments import tf_environment
from tf_agents.environments import tf_py_environment
from tf_agents.environments import utils
from tf_agents.specs import array_spec
from tf_agents.environments import wrappers
from tf_agents.trajectories import time_step as ts
from tf_agents.agents.dqn import dqn_agent
from tf_agents.drivers import dynamic_step_driver
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import q_network
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.trajectories import trajectory
from tf_agents.utils import common

from tf_agents.agents.ddpg import critic_network
from tf_agents.agents.sac import sac_agent
from tf_agents.agents.sac import tanh_normal_projection_network
from tf_agents.experimental.train import actor
from tf_agents.experimental.train import learner
from tf_agents.experimental.train import triggers
from tf_agents.experimental.train.utils import spec_utils
from tf_agents.experimental.train.utils import strategy_utils
from tf_agents.experimental.train.utils import train_utils
from tf_agents.metrics import py_metrics
from tf_agents.networks import actor_distribution_network
from tf_agents.policies import greedy_policy
from tf_agents.policies import py_tf_eager_policy
from tf_agents.policies import random_py_policy
from tf_agents.replay_buffers import reverb_replay_buffer
from tf_agents.replay_buffers import reverb_utils
import reverb

tf.compat.v1.enable_v2_behavior()

In [10]:
drive.mount('/content/drive')

os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = "/content/drive/My Drive/fin-aml/fin-aml-gcskeys-337dfd8d8867.json"

storage_client = storage.Client()
display.clear_output()

In [5]:
pd.set_option('display.max_rows', 100)
pd.set_option('display.max_columns', 100)

tempdir = tempfile.gettempdir()

In [9]:
df = pd.read_csv('gs://fin-datasets-aml/data/sp500/spdfm.csv')
df.set_index('datetime', inplace=True)
df.head()

Unnamed: 0_level_0,MMM,ABT,ABBV,ABMD,ACN,ATVI,ADBE,AMD,AAP,AES,AFL,A,APD,AKAM,ALK,ALB,ARE,ALXN,ALGN,ALLE,LNT,ALL,GOOGL,GOOG,MO,AMZN,AMCR,AEE,AAL,AEP,AXP,AIG,AMT,AWK,AMP,ABC,AME,AMGN,APH,ADI,ANSS,ANTM,AON,AOS,APA,AIV,AAPL,AMAT,APTV,ADM,...,USB,UAA,UA,UNP,UAL,UNH,UPS,URI,UHS,UNM,VFC,VLO,VAR,VTR,VRSN,VRSK,VZ,VRTX,VIAC,V,VNO,VMC,WRB,WAB,WMT,WBA,DIS,WM,WAT,WEC,WFC,WELL,WST,WDC,WU,WRK,WY,WHR,WMB,WLTW,WYNN,XEL,XRX,XLNX,XYL,YUM,ZBRA,ZBH,ZION,ZTS
datetime,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1,Unnamed: 42_level_1,Unnamed: 43_level_1,Unnamed: 44_level_1,Unnamed: 45_level_1,Unnamed: 46_level_1,Unnamed: 47_level_1,Unnamed: 48_level_1,Unnamed: 49_level_1,Unnamed: 50_level_1,Unnamed: 51_level_1,Unnamed: 52_level_1,Unnamed: 53_level_1,Unnamed: 54_level_1,Unnamed: 55_level_1,Unnamed: 56_level_1,Unnamed: 57_level_1,Unnamed: 58_level_1,Unnamed: 59_level_1,Unnamed: 60_level_1,Unnamed: 61_level_1,Unnamed: 62_level_1,Unnamed: 63_level_1,Unnamed: 64_level_1,Unnamed: 65_level_1,Unnamed: 66_level_1,Unnamed: 67_level_1,Unnamed: 68_level_1,Unnamed: 69_level_1,Unnamed: 70_level_1,Unnamed: 71_level_1,Unnamed: 72_level_1,Unnamed: 73_level_1,Unnamed: 74_level_1,Unnamed: 75_level_1,Unnamed: 76_level_1,Unnamed: 77_level_1,Unnamed: 78_level_1,Unnamed: 79_level_1,Unnamed: 80_level_1,Unnamed: 81_level_1,Unnamed: 82_level_1,Unnamed: 83_level_1,Unnamed: 84_level_1,Unnamed: 85_level_1,Unnamed: 86_level_1,Unnamed: 87_level_1,Unnamed: 88_level_1,Unnamed: 89_level_1,Unnamed: 90_level_1,Unnamed: 91_level_1,Unnamed: 92_level_1,Unnamed: 93_level_1,Unnamed: 94_level_1,Unnamed: 95_level_1,Unnamed: 96_level_1,Unnamed: 97_level_1,Unnamed: 98_level_1,Unnamed: 99_level_1,Unnamed: 100_level_1,Unnamed: 101_level_1
2020-09-04 19:59:00,0.000724,-9.6e-05,0.000871,-0.000834,-0.001144,-0.001139,-0.001586,-0.001829,0.000325,-0.001139,0.001189,-0.000103,0.000101,-0.000623,-0.002205,-0.000634,-0.000962,-0.000443,-0.000939,-0.001104,-0.000375,-0.000426,1.3e-05,-0.000811,-0.000115,-0.000401,0.002269,-0.000507,-0.001395,-0.00038,0.000568,0.001174,-0.00096,-7e-05,-0.000447,-0.000518,-0.000302,-0.000805,-0.00083,0.000681,-0.000405,0.000253,-0.001036,0.000309,0.000342,-0.000534,-0.001861,-0.001148,0.000232,0.000538,...,0.000522,0.001437,0.0,-0.000105,-0.001701,-0.00109,-0.000156,-0.000397,-0.000528,-0.001257,-0.000147,-0.000293,-0.000748,-0.001122,-0.000737,-0.000109,0.000909,-0.000413,-0.000347,-0.000244,-0.000544,-0.00032,0.0,0.000435,-0.00042,-0.001084,0.000303,0.00045,-0.000874,0.000207,0.001614,0.001226,-0.001214,-0.000524,-0.002171,-0.001844,-0.001377,-0.000476,-0.001209,-0.000624,0.000684,-0.000571,0.000268,-0.000689,-0.000733,0.001155,0.000902,-0.00204,0.000148,-0.000636
2020-09-04 19:58:00,-0.000664,-0.000576,-0.000763,-0.000435,-0.000719,-0.001201,-0.001157,-0.000974,-6.5e-05,0.000284,-0.000529,-0.000516,3.4e-05,-2.4e-05,0.000367,-0.001584,-0.000901,-0.000265,9.7e-05,-0.000401,-0.000563,-0.000959,-0.000835,-0.001093,-0.00115,-0.000488,-0.001819,-0.000759,0.000293,-0.000759,-0.001515,-0.000672,-0.001159,-0.00014,-0.00083,-0.000724,-0.000402,-0.000724,-0.001013,-0.000341,-0.000467,-0.000652,-0.00069,-0.001132,-0.000685,-0.000534,-0.002931,-0.000819,-0.00116,-0.000215,...,-0.000783,-0.00048,0.0,-0.001098,0.000131,-0.001793,-0.001089,-0.00136,-0.00176,-0.001506,-0.001612,0.00039,0.0,-0.001569,0.0,-0.000985,-0.001489,-0.000188,0.0,-0.001026,-0.000815,-0.002159,-0.001335,0.000145,-0.000385,-0.000271,-0.001516,-0.00072,0.00034,-0.000726,-0.000808,-0.002374,-0.001818,-0.001834,0.0,-0.001688,-0.00086,-0.001637,0.000241,-0.001007,-0.000114,-0.000856,-0.000536,-0.000393,-0.000305,-0.001104,-0.000691,-0.00075,-0.000594,-0.000381
2020-09-04 19:57:00,-0.000664,0.000192,-0.000327,3.6e-05,8.5e-05,0.000379,0.001014,0.00073,0.00065,0.0,0.000529,-0.000323,-0.000168,-0.001013,-0.000245,5.3e-05,0.00048,0.0,0.0,-0.000902,-0.000469,0.0,0.000246,0.000737,0.000115,0.000773,0.0,0.000443,-0.000733,-0.000632,-9.5e-05,0.0,0.000479,-0.000384,-0.001148,-0.000931,0.0,-0.000161,-0.000276,-0.00017,0.000684,0.00029,0.000788,-0.000822,0.000685,-0.000133,-0.001152,-0.000164,0.000348,-0.000323,...,0.000261,-0.000479,0.0,2.6e-05,0.000431,0.000735,0.0,-0.000226,0.001054,0.001003,0.0,-0.000293,5.7e-05,-0.000224,0.000295,-0.000383,0.0,-0.000712,0.0,9.8e-05,-0.000272,-0.000479,-0.000254,-0.00029,-0.000385,-0.000541,0.000378,0.000405,-0.000777,-0.000207,0.000606,0.00049,3.8e-05,0.000262,-0.000433,-0.000919,0.000344,0.000208,-0.000483,0.000479,0.001482,-7.1e-05,0.000804,0.000393,-0.000671,-0.000263,0.0,0.000357,-0.000297,6.4e-05
2020-09-04 19:56:00,-0.000241,0.000192,0.0,5.4e-05,-0.000127,-0.00019,0.000548,-0.002799,-0.000488,0.000853,-0.000529,-0.000554,0.000201,-0.00129,-0.000611,0.000369,-0.0006,0.000177,0.000776,0.0001,9.2e-05,0.000106,3e-06,-0.000555,0.000574,-0.001423,0.0,-0.000569,0.001099,-0.000253,0.0,-0.000168,-0.000958,-0.000244,0.000446,0.0,-0.000402,0.000141,0.00046,-0.000936,-0.001082,-0.000907,-0.00069,0.000205,0.0,0.000133,0.000247,-0.001064,-0.00058,0.000108,...,0.000653,0.000479,0.000536,0.000575,0.000254,-0.000256,6.2e-05,0.000792,-0.000264,-0.001004,0.000878,0.00039,-2.9e-05,0.000783,-0.001031,-0.000219,0.000165,-0.000487,0.00151,-4.9e-05,0.000543,0.000479,-0.000476,0.000724,-0.00014,0.000541,-0.000151,-0.000315,0.000727,0.0,-0.000404,0.0,-0.000681,-0.001046,0.000217,0.000459,-0.000172,-0.000565,0.001449,9.6e-05,0.0,-0.000285,-0.000268,-0.001278,-0.000488,-0.00021,-0.00046,-0.000928,-0.000445,-3.2e-05
2020-09-04 19:55:00,-0.000362,0.0,-0.000326,-5.4e-05,0.001057,0.000253,-0.000183,0.001335,-0.000553,0.000569,-0.000793,0.00018,0.000469,0.000736,0.000489,0.000357,0.00054,0.000796,-0.000453,0.0002,0.000377,0.000852,-0.002481,-0.001005,0.002414,0.000671,0.001362,0.000822,0.000367,0.000694,0.001135,-0.000504,8e-05,3.5e-05,-0.000191,0.0,0.000351,0.001226,-0.000184,0.0,0.000661,0.000327,0.000788,-0.000411,0.000343,0.0,-0.001152,0.001227,-0.001043,0.000753,...,-0.000653,-0.001918,-0.00161,0.000209,-0.001077,0.001215,-0.00059,-0.000679,-0.000308,0.000501,0.0,-0.000488,0.000891,-0.000336,-0.000932,0.0,0.000587,0.001461,0.000748,0.000293,-0.000407,-0.000399,0.000634,-0.000797,0.001258,-0.000135,8e-06,0.00018,-0.000194,0.000829,0.000404,8.2e-05,0.000453,-0.001254,-0.000217,0.000612,-0.000515,-0.000505,-0.000484,0.001174,0.000228,0.000784,0.0,-0.000295,0.0,0.0,0.000153,-0.000285,0.000148,0.000477


In [13]:
X = df.to_numpy()

In [14]:
X.shape

(7792, 505)

#### Build an environment to represent this movement

In [17]:
class StockMarketEnv(py_environment.PyEnvironment):
    '''
    Observation: The observation should be a (90,505) matrix
    Action: A (505) vector with probabilties from 0 1, max 10 are encoded as 1, all others are 0
    Reward: dot product of the (505,1) top 10 choices with the next (1,505) returns
    '''
    def __init__(self, X):
        self._action_spec = array_spec.BoundedArraySpec(
                                    shape=(1020,), dtype=np.float32, minimum=0, maximum=1, name='action')
        self._observation_spec = array_spec.BoundedArraySpec(
                                    shape=(90,505), dtype=np.float32, minimum=-10, maximum=10 ,name='observation')
        self._X = X
        self._state = np.array(self._X[:90], dtype=np.float32)
        self._i = 0
        self._episode_ended = False
        
#         self._step_type_spec = array_spec.BoundedArraySpec(
#                                     shape=(1,), dtype=np.int32, name='step_type')
#         self._reward_spec = array_spec.BoundedArraySpec(
#                                     shape=(1,), dtype=np.float32, name='reward')
#         self._discount_spec = array_spec.BoundedArraySpec(
#                                     shape=(1,), dtype=np.float32, name='discount')

    def action_spec(self):
        return self._action_spec

    def observation_spec(self):
        return self._observation_spec

    def _reset(self):
        self._state = np.array(self._X[:90], dtype=np.float32) ## input array
        self._i = 0
        self._episode_ended = False
        return ts.restart(self._state)

    def _step(self, action):
        '''
        Given a state array:
            - Choose the top ten stocks
            - Compute the reward by taking the dot product
            - Update the new state by taking the next timestep
            - Return the ts.transition(new_state, reward, discount=1)
        '''
        if self._episode_ended:
            # The last action ended the episode. Ignore the current action and start
            # a new episode.
            return self.reset()
        
        action_state = action.copy()
        action_state[action_state.argsort()[-10:]] = 1
        
        mask = np.ones(action_state.shape, bool)
        mask[action_state.argsort()[-10:]] = False
        action_state[mask] = 0
        
        reward_state = np.array(self._X[91+self._i], dtype=np.float64)
        
        ## modify for hold and sell
        # reward = np.dot(reward_state, action_state)[0]
        reward_pos = np.dot(reward_state, action_state[:505])
        reward_neg = np.dot(-reward_state, action_state[515:])
        reward = reward_pos + reward_neg
        # 10 possible action states for holding cash, idx [505:515] -> zero reward
        
        self._i += 1
        
        if self._i + 91 >= self._X.shape[0]:
            self._episode_ended = True
        
        self._state = np.array(self._X[self._i:90+self._i], dtype=np.float64)
        
        if self._episode_ended:
            return ts.termination(np.array(self._state, dtype=np.float32), reward=np.array(reward, dtype=np.float32))
        else:
            return ts.transition(np.array(self._state, dtype=np.float32), reward=np.array(reward, dtype=np.float32), discount=1.0)

In [18]:
environment = StockMarketEnv(X)
utils.validate_py_environment(environment, episodes=5)

##### Train a DQN to learn on it

In [49]:
# Important interval values for training speed
# Set lower to increase speed of training 

num_iterations = 10000 # @param {type:"integer"}
log_interval = 500 # @param {type:"integer"}=
num_eval_episodes = 1 # @param {type:"integer"}
eval_interval = 1000 # @param {type:"integer"}

In [22]:
initial_collect_steps = 10000 # @param {type:"integer"}
collect_steps_per_iteration = 1 # @param {type:"integer"}
replay_buffer_capacity = 10000 # @param {type:"integer"}

batch_size = 256 # @param {type:"integer"}

critic_learning_rate = 3e-4 # @param {type:"number"}
actor_learning_rate = 3e-4 # @param {type:"number"}
alpha_learning_rate = 3e-4 # @param {type:"number"}
target_update_tau = 0.005 # @param {type:"number"}
target_update_period = 1 # @param {type:"number"}
gamma = 0.99 # @param {type:"number"}
reward_scale_factor = 1.0 # @param {type:"number"}

actor_fc_layer_params = (505, 1020)
critic_joint_fc_layer_params = (505, 1020)



policy_save_interval = 5000 # @param {type:"integer"}

In [23]:
X.shape[0], X.shape[0]//2

(7792, 3896)

In [24]:
# Code to produce tf environments
# train_py_env = StockMarketEnv(X[:X.shape[0]//2])
# eval_py_env = StockMarketEnv(X[X.shape[0]//2:])
# train_env = tf_py_environment.TFPyEnvironment(train_py_env)
# eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

# Code to produce py environments 
train_env = StockMarketEnv(X[:X.shape[0]//2])
eval_env = StockMarketEnv(X[X.shape[0]//2:])

In [25]:
print('Observation Spec:')
print(train_env.time_step_spec().observation)
print('Action Spec:')
print(train_env.action_spec())

Observation Spec:
BoundedArraySpec(shape=(90, 505), dtype=dtype('float32'), name='observation', minimum=-10.0, maximum=10.0)
Action Spec:
BoundedArraySpec(shape=(1020,), dtype=dtype('float32'), name='action', minimum=0.0, maximum=1.0)


In [27]:
use_gpu = False 

strategy = strategy_utils.get_strategy(tpu=False, use_gpu=use_gpu)

In [28]:
observation_spec, action_spec, time_step_spec = (spec_utils.get_tensor_specs(train_env))

with strategy.scope():
    critic_net = critic_network.CriticNetwork((observation_spec, action_spec),
                                            observation_fc_layer_params=None,
                                            action_fc_layer_params=None,
                                            joint_fc_layer_params=critic_joint_fc_layer_params,
                                            kernel_initializer='glorot_uniform',
                                            last_kernel_initializer='glorot_uniform')

In [29]:
action_spec

BoundedTensorSpec(shape=(1020,), dtype=tf.float32, name='action', minimum=array(0., dtype=float32), maximum=array(1., dtype=float32))

In [30]:
observation_spec

BoundedTensorSpec(shape=(90, 505), dtype=tf.float32, name='observation', minimum=array(-10., dtype=float32), maximum=array(10., dtype=float32))

In [31]:
actor_fc_layer_params

(505, 1020)

In [32]:
with strategy.scope():
    actor_net = actor_distribution_network.ActorDistributionNetwork(observation_spec,
                                                                    action_spec,
                                                                    fc_layer_params=actor_fc_layer_params,
                                                                    continuous_projection_net=(
                                              tanh_normal_projection_network.TanhNormalProjectionNetwork))

In [33]:
time_step_spec

TimeStep(step_type=TensorSpec(shape=(), dtype=tf.int32, name='step_type'), reward=TensorSpec(shape=(), dtype=tf.float32, name='reward'), discount=BoundedTensorSpec(shape=(), dtype=tf.float32, name='discount', minimum=array(0., dtype=float32), maximum=array(1., dtype=float32)), observation=BoundedTensorSpec(shape=(90, 505), dtype=tf.float32, name='observation', minimum=array(-10., dtype=float32), maximum=array(10., dtype=float32)))

In [34]:
with strategy.scope():
    train_step = train_utils.create_train_step()

    tf_agent = sac_agent.SacAgent(
                                time_step_spec,
                                action_spec,
                                actor_network=actor_net,
                                critic_network=critic_net,
                                actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                                                                learning_rate=actor_learning_rate),
                                critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                                                                learning_rate=critic_learning_rate),
                                alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                                                                learning_rate=alpha_learning_rate),
                                target_update_tau=target_update_tau,
                                target_update_period=target_update_period,
                                td_errors_loss_fn=tf.math.squared_difference,
                                gamma=gamma,
                                reward_scale_factor=reward_scale_factor,
                                train_step_counter=train_step)

    tf_agent.initialize()

In [35]:
table_name = 'uniform_table'
table = reverb.Table(
    table_name,
    max_size=replay_buffer_capacity,
    sampler=reverb.selectors.Uniform(),
    remover=reverb.selectors.Fifo(),
    rate_limiter=reverb.rate_limiters.MinSize(1))

reverb_server = reverb.Server([table])

In [36]:
reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
    tf_agent.collect_data_spec,
    sequence_length=2,
    table_name=table_name,
    local_server=reverb_server)

In [37]:
dataset = reverb_replay.as_dataset(
      sample_batch_size=batch_size, num_steps=2).prefetch(50)
experience_dataset_fn = lambda: dataset

In [38]:
tf_eval_policy = tf_agent.policy
eval_policy = py_tf_eager_policy.PyTFEagerPolicy(
  tf_eval_policy, use_tf_function=True)

In [39]:
tf_collect_policy = tf_agent.collect_policy
collect_policy = py_tf_eager_policy.PyTFEagerPolicy(
  tf_collect_policy, use_tf_function=True)

In [40]:
rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
  reverb_replay.py_client,
  table_name,
  sequence_length=2,
  stride_length=1)

In [41]:
random_policy = random_py_policy.RandomPyPolicy(
  train_env.time_step_spec(), train_env.action_spec())

In [42]:
initial_collect_actor = actor.Actor(
              train_env,
              random_policy,
              train_step,
              steps_per_run=initial_collect_steps,
              observers=[rb_observer])
initial_collect_actor.run()

In [43]:
env_step_metric = py_metrics.EnvironmentSteps()
collect_actor = actor.Actor(
                  train_env,
                  collect_policy,
                  train_step,
                  steps_per_run=1,
                  metrics=actor.collect_metrics(10),
                  summary_dir=os.path.join(tempdir, learner.TRAIN_DIR),
                  observers=[rb_observer, env_step_metric])

In [44]:
eval_actor = actor.Actor(
  eval_env,
  eval_policy,
  train_step,
  episodes_per_run=num_eval_episodes,
  metrics=actor.eval_metrics(num_eval_episodes),
  summary_dir=os.path.join(tempdir, 'eval'),
)

In [45]:
saved_model_dir = os.path.join(tempdir, learner.POLICY_SAVED_MODEL_DIR)

# Triggers to save the agent's policy checkpoints.
learning_triggers = [
    triggers.PolicySavedModelTrigger(
        saved_model_dir,
        tf_agent,
        train_step,
        interval=policy_save_interval),
    triggers.StepPerSecondLogTrigger(train_step, interval=1000),
]

agent_learner = learner.Learner(
  tempdir,
  train_step,
  tf_agent,
  experience_dataset_fn,
  triggers=learning_triggers)

Instructions for updating:
Do not pass `graph_parents`.  They will  no longer be used.




Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.


Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.


INFO:tensorflow:Assets written to: /tmp/policies/collect_policy/assets


INFO:tensorflow:Assets written to: /tmp/policies/collect_policy/assets


INFO:tensorflow:Assets written to: /tmp/policies/greedy_policy/assets


INFO:tensorflow:Assets written to: /tmp/policies/greedy_policy/assets


INFO:tensorflow:Assets written to: /tmp/policies/policy/assets


INFO:tensorflow:Assets written to: /tmp/policies/policy/assets


In [46]:
%%time
def get_eval_metrics():
    eval_actor.run()
    results = {}
    for metric in eval_actor.metrics:
        results[metric.name] = metric.result()
    return results

metrics = get_eval_metrics()

In [47]:
%%time
def log_eval_metrics(step, metrics):
    eval_results = (', ').join(
              '{} = {:.6f}'.format(name, result) for name, result in metrics.items())
    print('step = {0}: {1}'.format(step, eval_results))

log_eval_metrics(0, metrics)

step = 0: AverageReturn = 0.135403, AverageEpisodeLength = 3805.000000


In [None]:
#@test {"skip": true}
try:
    %%time
except:
    pass

start = time.time()
# Reset the train step
tf_agent.train_step_counter.assign(0)

# Evaluate the agent's policy once before training.
avg_return = metrics["AverageReturn"]
returns = [avg_return]

for _ in range(num_iterations):
    # Training.
    collect_actor.run()
    loss_info = agent_learner.run(iterations=1)

    # Evaluating.
    step = agent_learner.train_step_numpy

    if eval_interval and step % eval_interval == 0:
        metrics = get_eval_metrics()
        log_eval_metrics(step, metrics)
        returns.append(metrics["AverageReturn"])

    if log_interval and step % log_interval == 0:
        print('step = {0}: loss = {1}'.format(step, loss_info.loss.numpy()))
        minsec = divmod((time.time() - start), 60)
        hourmin = divmod(minsec, 60)
        print(f"Time elasped: {hourmin[0]}:{hourmin[1]}:{minsec[1]}")

rb_observer.close()
reverb_server.stop()

CPU times: user 3 µs, sys: 0 ns, total: 3 µs
Wall time: 6.44 µs


In [None]:
#@test {"skip": true}

steps = range(0, num_iterations + 1, eval_interval)
plt.plot(steps, returns)
plt.ylabel('Average Return')
plt.xlabel('Step')
plt.ylim()