* This note provides snippets showing how to use the replay buffers of TF-Agents.
    * See the [replay buffers tutorial](https://github.com/tensorflow/agents/blob/master/docs/tutorials/5_replay_buffers_tutorial.ipynb)

In [None]:
import tensorflow as tf
from tf_agents.replay_buffers import tf_uniform_replay_buffer

In [None]:
def createInstanceOfReplayBuffer(nPv=3, nMv=2, batch_size=1, max_length=2**10):
    """Build a uniform replay buffer for 4-tuple transition items.

    Each stored item is ``(observation, action, next_observation, reward)``.

    Args:
        nPv: Length of the observation and next-observation vectors.
        nMv: Length of the action vector.
        batch_size: Forwarded to ``TFUniformReplayBuffer`` as ``batch_size``.
        max_length: Forwarded to ``TFUniformReplayBuffer`` as ``max_length``.

    Returns:
        The newly created ``TFUniformReplayBuffer``.
    """
    observation_spec = tf.TensorSpec([nPv,], tf.float32, 'observation')
    action_spec = tf.TensorSpec([nMv,], tf.float32, 'action')
    next_observation_spec = tf.TensorSpec([nPv,], tf.float32, 'next_observation')
    reward_spec = tf.TensorSpec([], tf.float32, 'reward')  # scalar per item

    return tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=(
            observation_spec,
            action_spec,
            next_observation_spec,
            reward_spec,
        ),
        batch_size=batch_size,
        max_length=max_length,
    )

In [None]:
def collectData(nMv=2, nPv=3, batch_size=1, nSample=2**7):
    """Yield ``nSample`` batches of random transition data.

    Every yielded value is a 4-tuple
    ``(observation, action, next_observation, reward)`` of tensors drawn with
    ``tf.random.normal``, shaped ``[batch_size, nPv]``, ``[batch_size, nMv]``,
    ``[batch_size, nPv]`` and ``[batch_size]`` respectively.
    """
    observation_shape = [batch_size, nPv]
    action_shape = [batch_size, nMv]
    reward_shape = [batch_size,]

    for _ in range(nSample):
        # Tuple elements are evaluated left to right, so the draws happen in
        # the order observation, action, next observation, reward.
        yield (
            tf.random.normal(observation_shape),
            tf.random.normal(action_shape),
            tf.random.normal(observation_shape),
            tf.random.normal(reward_shape),
        )

In [None]:
def collecDataAidedByDynamicStepDriver():
    """Placeholder that is not implemented; always raises.

    Raises:
        NotImplementedError: On every call.
    """
    raise NotImplementedError

Create an instance of the replay buffer:

In [None]:
# Build the buffer with the default sizes from createInstanceOfReplayBuffer's signature.
replay_buffer = createInstanceOfReplayBuffer()

Add batches of items to the replay buffer:

In [None]:
# Empty the buffer first, then add every batch produced by collectData()
# (called here with its default arguments).
replay_buffer.clear()
for aBatch in collectData():
    replay_buffer.add_batch(aBatch)

Read items from the buffer:

In [None]:
# Expose the buffer as a dataset that yields `sample_batch_size` items at a
# time, each `num_steps` long.
sample_batch_size = 2**5
num_steps = 1
dataset = replay_buffer.as_dataset(
    sample_batch_size=sample_batch_size,
    num_steps=num_steps,
)

# Fetch the first element with the built-in next() rather than calling the
# __next__ dunder directly; the second tuple member is not used here.
trajectories, _ = next(iter(dataset))

In [None]:
# Shapes of the sampled items, i.e. the output of:
# >> for trj in trajectories:
# >>     print(trj.shape)
#
# (32, 1, 3)
# (32, 1, 2)
# (32, 1, 3)
# (32, 1)

## Create a trajectory by running a closed-loop simulation of an environment and a policy:

See this tutorial: [Train a Deep Q Network with TF-Agents](https://tensorflow.google.cn/agents/tutorials/1_dqn_tutorial)

In [None]:
import tensorflow as tf

from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment

from tf_agents.policies import random_tf_policy
from tf_agents.trajectories import trajectory

In [None]:
# Wrap the Gym CartPole environment as a TF environment and pair it with a
# random policy built from the environment's own specs.
env = tf_py_environment.TFPyEnvironment(
    suite_gym.load('CartPole-v0'),
)
policy = random_tf_policy.RandomTFPolicy(
    time_step_spec=env.time_step_spec(),
    action_spec=env.action_spec(),
)

In [None]:
# One closed-loop step: reset, read the current time step, let the policy
# pick an action, advance the environment, and pack the three pieces into a
# single trajectory.
env.reset()

time_step = env.current_time_step()
action_step = policy.action(time_step)
next_time_step = env.step(action_step.action)
traj = trajectory.from_transition(
    time_step=time_step,
    action_step=action_step,
    next_time_step=next_time_step,
)

## Collect trajectories by using a DynamicStepDriver instance:

See this tutorial: [Train a Deep Q Network with TF-Agents](https://tensorflow.google.cn/agents/tutorials/1_dqn_tutorial)

In [None]:
import tensorflow as tf

from tf_agents.drivers import dynamic_episode_driver
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer

Create an environment and a policy instance; the Gym environment is wrapped as a TensorFlow environment:

In [None]:
# Gym CartPole wrapped as a TF environment, plus a random policy that uses the
# environment's time-step and action specs.
env = tf_py_environment.TFPyEnvironment(
    suite_gym.load('CartPole-v0'),
)
policy = random_tf_policy.RandomTFPolicy(
    time_step_spec=env.time_step_spec(),
    action_spec=env.action_spec(),
)

Create a TensorFlow-based uniform replay buffer instance:

In [None]:
# Stores items matching the policy's collect_data_spec. The batch size is
# taken from the environment instead of a hard-coded 1, so the buffer and the
# environment that feeds it cannot drift apart.
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=policy.collect_data_spec,
    batch_size=env.batch_size,
)

Collect trajectories by running the policy in a closed loop with the environment:

In [None]:
# The driver steps `env` with `policy` and forwards what it collects to the
# observers; here the only observer is the replay buffer's add_batch.
driver = dynamic_step_driver.DynamicStepDriver(
    env,
    policy,
    observers=[replay_buffer.add_batch],
    num_steps=13,
)

# Run the driver several times, each time starting from a freshly reset
# environment.
nTrajectory = 3
for _ in range(nTrajectory):
    driver.run(env.reset())

Read a trajectory from the replay buffer:

In [None]:
# Take just the first element of the buffer's dataset. next(iter(...)) states
# that intent directly, instead of a for-loop over an explicit __iter__() call
# that breaks on the first pass. The second tuple member is not used.
trj, _ = next(iter(replay_buffer.as_dataset()))

## Collect trajectories by using a DynamicEpisodeDriver instance:

See this document: [tf_agents.drivers.dynamic_episode_driver.DynamicEpisodeDriver](https://www.tensorflow.org/agents/api_docs/python/tf_agents/drivers/dynamic_episode_driver/DynamicEpisodeDriver)

In [None]:
import tensorflow as tf

from tf_agents.drivers import dynamic_episode_driver
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.policies import random_tf_policy
#from tf_agents.replay_buffers import episodic_replay_buffer
from tf_agents.replay_buffers import tf_uniform_replay_buffer

Create an environment and a policy instance; the Gym environment is wrapped as a TensorFlow environment:

In [None]:
# TF-wrapped CartPole environment and a random policy built from its specs.
env = tf_py_environment.TFPyEnvironment(
    suite_gym.load('CartPole-v0'),
)
policy = random_tf_policy.RandomTFPolicy(
    time_step_spec=env.time_step_spec(),
    action_spec=env.action_spec(),
)

Create a TensorFlow-based uniform replay buffer instance:

In [None]:
# Buffer for items matching the policy's collect_data_spec. The batch size
# comes from the environment rather than a hard-coded 1, so it always matches
# the environment the driver will step.
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=policy.collect_data_spec,
    batch_size=env.batch_size,
)

Collect trajectories by running the policy in a closed loop with the environment:

In [None]:
# Episode driver: steps `env` with `policy` and passes what it collects to the
# replay buffer's add_batch observer.
driver = dynamic_episode_driver.DynamicEpisodeDriver(
    env,
    policy,
    observers=[replay_buffer.add_batch],
    num_episodes=1,
)

# Start from an empty buffer before collecting.
replay_buffer.clear()

nTrajectory = 1
for _ in range(nTrajectory):
    driver.run(env.reset())

Read trajectories from the replay buffer:

In [None]:
import itertools

# Print the is_last flags of the first 10 sampled windows (3 steps each,
# one window per sample). islice accepts any iterable, so the dataset is
# passed directly rather than through an explicit __iter__() call.
for trj, _ in itertools.islice(
        replay_buffer.as_dataset(sample_batch_size=1, num_steps=3), 10):
    print(trj.is_last())

## Sample episodes from episodic replay buffer without using any driver

+ [Using EpisodicReplayBuffer in TF-Agents](https://stackoverflow.com/questions/65397939/using-episodicreplaybuffer-in-tf-agents)
+ [episodic_replay_buffer.py](https://github.com/tensorflow/agents/blob/master/tf_agents/replay_buffers/episodic_replay_buffer.py)

In [None]:
import tensorflow as tf

from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import episodic_replay_buffer
from tf_agents.trajectories import trajectory

Create an environment and a policy instance; the Gym environment is wrapped as a TensorFlow environment:

In [None]:
# CartPole loaded through suite_gym and wrapped as a TF environment; the
# random policy is configured from that environment's specs.
env = tf_py_environment.TFPyEnvironment(
    suite_gym.load('CartPole-v0'),
)
policy = random_tf_policy.RandomTFPolicy(
    time_step_spec=env.time_step_spec(),
    action_spec=env.action_spec(),
)

Create an episodic replay buffer instance:

In [None]:
# Episodic buffer for items matching the policy's collect_data_spec;
# completed_only=True is passed through to the buffer as-is.
replay_buffer = episodic_replay_buffer.EpisodicReplayBuffer(
    data_spec=policy.collect_data_spec,
    capacity=1000,
    completed_only=True,
)

Collect trajectories by running the policy in a closed loop with the environment:

In [None]:
# Fill the episodic buffer by stepping the environment by hand, one episode
# per outer-loop iteration.
replay_buffer.clear()

collect_episodes_per_iteration = 3
for _ in range(collect_episodes_per_iteration):
    # Each episode starts with an id of -1; the value returned by add_batch
    # is fed back in for the following steps of the same episode.
    id_eps = tf.constant((-1,), dtype=tf.int64)

    env.reset()
    while True:
        time_step = env.current_time_step()
        if time_step.is_last():
            # NOTE(review): the loop exits before a trajectory is built from
            # this final time step, so no added item starts at a LAST step.
            # Confirm the buffer still treats the episode as completed when
            # completed_only=True.
            break

        action_step = policy.action(time_step)
        next_time_step = env.step(action_step.action)
        traj = trajectory.from_transition(
            time_step=time_step,
            action_step=action_step,
            next_time_step=next_time_step,
        )
        id_eps = replay_buffer.add_batch(traj, id_eps)

Read some trajectories from the replay buffer:

In [None]:
# Inspect the first sampled batch only (8 sequences of 16 steps each were
# requested). The dataset is iterated directly; an explicit __iter__() call
# is redundant in a for statement.
for trj, _ in replay_buffer.as_dataset(sample_batch_size=2**3, num_steps=2**4):
    print(">>is first")
    print(trj.is_first())
    print(">>is last")
    print(trj.is_last())
    break  # one sample is enough

In [None]:
# Inspect the first item of the dataset built with default arguments. The
# dataset is iterated directly; an explicit __iter__() call is redundant in a
# for statement.
for trj, _ in replay_buffer.as_dataset():
    print(">>is first")
    print(trj.is_first())
    print(">>is last")
    print(trj.is_last())
    print(">> discount")
    print(trj.discount)
    break  # one sample is enough