Skip to content

Commit

Permalink
store RNN hidden states in policy._state and add sample_avail in buff…
Browse files Browse the repository at this point in the history
…er (thu-ml#19)
  • Loading branch information
Trinkle23897 committed Jun 29, 2020
1 parent ef8c47b commit 0253f76
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 37 deletions.
8 changes: 7 additions & 1 deletion test/base/test_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,13 @@ def test_ignore_obs_next(size=10):

def test_stack(size=5, bufsize=9, stack_num=4):
env = MyTestEnv(size)
buf = ReplayBuffer(bufsize, stack_num)
buf = ReplayBuffer(bufsize, stack_num=stack_num)
buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True)
obs = env.reset(1)
for i in range(15):
obs_next, rew, done, info = env.step(1)
buf.add(obs, 1, rew, done, None, info)
buf2.add(obs, 1, rew, done, None, info)
obs = obs_next
if done:
obs = env.reset(1)
Expand All @@ -75,6 +77,10 @@ def test_stack(size=5, bufsize=9, stack_num=4):
[1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
[3, 3, 3, 3], [3, 3, 3, 4], [1, 1, 1, 1]]))
print(buf)
_, indice = buf2.sample(0)
assert indice == [2]
_, indice = buf2.sample(1)
assert indice.sum() == 2


def test_priortized_replaybuffer(size=32, bufsize=15):
Expand Down
28 changes: 15 additions & 13 deletions tianshou/data/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,11 @@ class Batch:
function return 4 arguments, and the last one is ``info``);
* ``policy`` the data computed by policy in step :math:`t`;
:class:`Batch` object can be initialized using wide variety of arguments,
starting with the key/value pairs or dictionary, but also list and Numpy
arrays of :class:`dict` or Batch instances. In which case, each element
is considered as an individual sample and get stacked together:
:class:`~tianshou.data.Batch` objects can be initialized using a wide
variety of arguments, starting with key/value pairs or a dictionary, but
also lists and Numpy arrays of :class:`dict` or Batch instances. In that
case, each element is considered an individual sample and all of them
are stacked together:
::
>>> import numpy as np
Expand All @@ -113,9 +114,9 @@ class Batch:
),
)
:class:`Batch` has the same API as a native Python :class:`dict`. In this
regard, one can access to stored data using string key, or iterate over
stored data:
:class:`~tianshou.data.Batch` has the same API as a native Python
:class:`dict`. In this regard, one can access stored data using a string
key, or iterate over stored data:
::
>>> from tianshou.data import Batch
Expand All @@ -128,8 +129,8 @@ class Batch:
b: [5, 5]
:class:`Batch` is also reproduce partially the Numpy API for arrays. You
can access or iterate over the individual samples, if any:
:class:`~tianshou.data.Batch` also partially reproduces the Numpy API for
arrays. You can access or iterate over the individual samples, if any:
::
>>> import numpy as np
Expand Down Expand Up @@ -219,11 +220,12 @@ class Batch:
>>> len(data[0])
TypeError: Object of type 'Batch' has no len()
Convenience helpers are available to convert in-place the
stored data into Numpy arrays or Torch tensors.
Convenience helpers are available to convert the stored data in-place into
Numpy arrays or Torch tensors.
Finally, note that Batch instance are serializable and therefore Pickle
compatible. This is especially important for distributed sampling.
Finally, note that :class:`~tianshou.data.Batch` instances are serializable
and therefore Pickle compatible. This is especially important for
distributed sampling.
"""

def __init__(self,
Expand Down
63 changes: 50 additions & 13 deletions tianshou/data/buffer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
from typing import Any, Tuple, Union, Optional

from .batch import Batch, _create_value
from tianshou.data.batch import Batch, _create_value


class ReplayBuffer:
Expand Down Expand Up @@ -91,12 +91,27 @@ class ReplayBuffer:
[12. 13. 14. 15.]
[ 7. 7. 7. 8.]
[ 7. 7. 8. 9.]]
:param int size: the size of replay buffer.
:param int stack_num: the frame-stack sampling argument, should be greater
than 1, defaults to 0 (no stacking).
:param bool ignore_obs_next: whether to store obs_next, defaults to
``False``.
:param bool sample_avail: whether to sample only from the available
indices when using the frame-stack sampling method, defaults to ``False``.
This feature is currently not supported in Prioritized Replay Buffer.
"""

def __init__(self, size: int, stack_num: Optional[int] = 0,
ignore_obs_next: bool = False, **kwargs) -> None:
ignore_obs_next: bool = False,
sample_avail: bool = False, **kwargs) -> None:
super().__init__()
self._maxsize = size
self._stack = stack_num
assert stack_num != 1, \
'stack_num should greater than 1'
self._avail = sample_avail and stack_num > 1
self._avail_index = []
self._save_s_ = not ignore_obs_next
self._index = 0
self._size = 0
Expand Down Expand Up @@ -146,7 +161,7 @@ def update(self, buffer: 'ReplayBuffer') -> None:
def add(self,
obs: Union[dict, Batch, np.ndarray],
act: Union[np.ndarray, float],
rew: float,
rew: Union[int, float],
done: bool,
obs_next: Optional[Union[dict, Batch, np.ndarray]] = None,
info: dict = {},
Expand All @@ -165,6 +180,23 @@ def add(self,
self._add_to_buffer('obs_next', obs_next)
self._add_to_buffer('info', info)
self._add_to_buffer('policy', policy)

# maintain available index for frame-stack sampling
if self._avail:
# update current frame
avail = sum(self.done[i] for i in range(
self._index - self._stack + 1, self._index)) == 0
if self._size < self._stack - 1:
avail = False
if avail and self._index not in self._avail_index:
self._avail_index.append(self._index)
elif not avail and self._index in self._avail_index:
self._avail_index.remove(self._index)
# remove the later available frame because of broken storage
t = (self._index + self._stack - 1) % self._maxsize
if t in self._avail_index:
self._avail_index.remove(t)

if self._maxsize > 0:
self._size = min(self._size + 1, self._maxsize)
self._index = (self._index + 1) % self._maxsize
Expand All @@ -175,6 +207,7 @@ def reset(self) -> None:
"""Clear all the data in replay buffer."""
self._index = 0
self._size = 0
self._avail_index = []

def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
"""Get a random sample from buffer with size equal to batch_size. \
Expand All @@ -183,12 +216,17 @@ def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
:return: Sample data and its corresponding index inside the buffer.
"""
if batch_size > 0:
indice = np.random.choice(self._size, batch_size)
_all = self._avail_index if self._avail else self._size
indice = np.random.choice(_all, batch_size)
else:
indice = np.concatenate([
np.arange(self._index, self._size),
np.arange(0, self._index),
])
if self._avail:
indice = np.array(self._avail_index)
else:
indice = np.concatenate([
np.arange(self._index, self._size),
np.arange(0, self._index),
])
assert len(indice) > 0, 'No available indice can be sampled.'
return self[indice], indice

def get(self, indice: Union[slice, int, np.integer, np.ndarray], key: str,
Expand Down Expand Up @@ -247,11 +285,10 @@ def __getitem__(self, index: Union[
return Batch(
obs=self.get(index, 'obs'),
act=self.act[index],
# act_=self.get(index, 'act'), # stacked action, for RNN
rew=self.rew[index],
done=self.done[index],
obs_next=self.get(index, 'obs_next'),
info=self.get(index, 'info', stack_num=0),
info=self.get(index, 'info'),
policy=self.get(index, 'policy')
)

Expand Down Expand Up @@ -317,7 +354,7 @@ def __init__(self, size: int, alpha: float, beta: float,
def add(self,
obs: Union[dict, np.ndarray],
act: Union[np.ndarray, float],
rew: float,
rew: Union[int, float],
done: bool,
obs_next: Optional[Union[dict, np.ndarray]] = None,
info: dict = {},
Expand Down Expand Up @@ -401,11 +438,11 @@ def update_weight(self, indice: Union[slice, np.ndarray],
- self.weight[indice].sum()
self.weight[indice] = np.power(np.abs(new_weight), self._alpha)

def __getitem__(self, index: Union[slice, np.ndarray]) -> Batch:
def __getitem__(self, index: Union[
slice, int, np.integer, np.ndarray]) -> Batch:
return Batch(
obs=self.get(index, 'obs'),
act=self.act[index],
# act_=self.get(index, 'act'), # stacked action, for RNN
rew=self.rew[index],
done=self.done[index],
obs_next=self.get(index, 'obs_next'),
Expand Down
23 changes: 13 additions & 10 deletions tianshou/data/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,14 +200,8 @@ def _reset_state(self, id: Union[int, List[int]]) -> None:
return
if isinstance(self.state, list):
self.state[id] = None
elif isinstance(self.state, (dict, Batch)):
for k in self.state.keys():
if isinstance(self.state[k], list):
self.state[k][id] = None
elif isinstance(self.state[k], (torch.Tensor, np.ndarray)):
self.state[k][id] = 0
elif isinstance(self.state, (torch.Tensor, np.ndarray)):
self.state[id] = 0
elif isinstance(self.state, (Batch, torch.Tensor, np.ndarray)):
self.state[id] *= 0

def collect(self,
n_step: int = 0,
Expand Down Expand Up @@ -272,9 +266,18 @@ def collect(self,
else:
with torch.no_grad():
result = self.policy(batch, self.state)

# save hidden state to policy._state, in order to save into buffer
self.state = result.get('state', None)
self._policy = to_numpy(result.policy) \
if hasattr(result, 'policy') else [{}] * self.env_num
if hasattr(result, 'policy'):
self._policy = to_numpy(result.policy)
if self.state is not None:
self._policy._state = self.state
elif self.state is not None:
self._policy = Batch(_state=self.state)
else:
self._policy = [{}] * self.env_num

self._act = to_numpy(result.act)
if self._action_noise is not None:
self._act += self._action_noise(self._act.shape)
Expand Down

0 comments on commit 0253f76

Please sign in to comment.