Skip to content

Commit

Permalink
fix rnn (thu-ml#19), add __repr__, and fix thu-ml#26
Browse files Browse the repository at this point in the history
  • Loading branch information
Trinkle23897 committed Apr 9, 2020
1 parent 85ac150 commit d54c2ab
Show file tree
Hide file tree
Showing 10 changed files with 120 additions and 24 deletions.
1 change: 0 additions & 1 deletion .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ jobs:
strategy:
matrix:
python-version: [3.6, 3.7, 3.8]

steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
Expand Down
4 changes: 2 additions & 2 deletions test/base/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ def __init__(self, size, sleep=0):
self.sleep = sleep
self.reset()

def reset(self):
def reset(self, state=0):
self.done = False
self.index = 0
self.index = state
return self.index

def step(self, action):
Expand Down
1 change: 1 addition & 0 deletions test/base/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def test_batch():
batch.obs = np.arange(5)
for i, b in enumerate(batch.split(1, permute=False)):
assert b.obs == batch[i].obs
print(batch)


if __name__ == '__main__':
Expand Down
20 changes: 20 additions & 0 deletions test/base/test_buffer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
from tianshou.data import ReplayBuffer

if __name__ == '__main__':
Expand Down Expand Up @@ -28,5 +29,24 @@ def test_replaybuffer(size=10, bufsize=20):
assert buf2[-1].obs == buf[4].obs


def test_stack(size=5, bufsize=9, stack_num=4):
env = MyTestEnv(size)
buf = ReplayBuffer(bufsize, stack_num)
obs = env.reset(1)
for i in range(15):
obs_next, rew, done, info = env.step(1)
buf.add(obs, 1, rew, done, None, info)
obs = obs_next
if done:
obs = env.reset(1)
indice = np.arange(len(buf))
assert abs(buf.get_stack(indice, 'obs') - np.array([
[1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4],
[1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
[3, 3, 3, 3], [3, 3, 3, 4], [1, 1, 1, 1]])).sum() < 1e-6
print(buf)


if __name__ == '__main__':
test_replaybuffer()
test_stack()
2 changes: 1 addition & 1 deletion test/discrete/test_drqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def get_args():
parser.add_argument('--stack-num', type=int, default=4)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--gamma', type=float, default=0.9)
parser.add_argument('--n-step', type=int, default=3)
parser.add_argument('--n-step', type=int, default=4)
parser.add_argument('--target-update-freq', type=int, default=320)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--step-per-epoch', type=int, default=1000)
Expand Down
12 changes: 7 additions & 5 deletions test/discrete/test_pg.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,27 +31,29 @@ def compute_return_base(batch, aa=None, bb=None, gamma=0.1):

def test_fn(size=2560):
policy = PGPolicy(None, None, None, discount_factor=0.1)
buf = ReplayBuffer(100)
buf.add(1, 1, 1, 1, 1)
fn = policy.process_fn
# fn = compute_return_base
batch = Batch(
done=np.array([1, 0, 0, 1, 0, 1, 0, 1.]),
rew=np.array([0, 1, 2, 3, 4, 5, 6, 7.]),
)
batch = fn(batch, None, None)
batch = fn(batch, buf, 0)
ans = np.array([0, 1.23, 2.3, 3, 4.5, 5, 6.7, 7])
assert abs(batch.returns - ans).sum() <= 1e-5
batch = Batch(
done=np.array([0, 1, 0, 1, 0, 1, 0.]),
rew=np.array([7, 6, 1, 2, 3, 4, 5.]),
)
batch = fn(batch, None, None)
batch = fn(batch, buf, 0)
ans = np.array([7.6, 6, 1.2, 2, 3.4, 4, 5])
assert abs(batch.returns - ans).sum() <= 1e-5
batch = Batch(
done=np.array([0, 1, 0, 1, 0, 0, 1.]),
rew=np.array([7, 6, 1, 2, 3, 4, 5.]),
)
batch = fn(batch, None, None)
batch = fn(batch, buf, 0)
ans = np.array([7.6, 6, 1.2, 2, 3.45, 4.5, 5])
assert abs(batch.returns - ans).sum() <= 1e-5
if __name__ == '__main__':
Expand All @@ -66,7 +68,7 @@ def test_fn(size=2560):
print(f'vanilla: {(time.time() - t) / cnt}')
t = time.time()
for _ in range(cnt):
policy.process_fn(batch, None, None)
policy.process_fn(batch, buf, 0)
print(f'policy: {(time.time() - t) / cnt}')


Expand Down Expand Up @@ -147,5 +149,5 @@ def stop_fn(x):


if __name__ == '__main__':
# test_fn()
test_fn()
test_pg()
16 changes: 16 additions & 0 deletions tianshou/data/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,22 @@ def __getitem__(self, index):
b.__dict__.update(**{k: self.__dict__[k][index]})
return b

def __repr__(self):
"""Return str(self)."""
s = self.__class__.__name__ + '(\n'
flag = False
for k in self.__dict__.keys():
if k[0] != '_' and self.__dict__[k] is not None:
rpl = '\n' + ' ' * (6 + len(k))
obj = str(self.__dict__[k]).replace('\n', rpl)
s += f' {k}: {obj},\n'
flag = True
if flag:
s += ')\n'
else:
s = self.__class__.__name__ + '()\n'
return s

def append(self, batch):
"""Append a :class:`~tianshou.data.Batch` object to current batch."""
assert isinstance(batch, Batch), 'Only append Batch is allowed!'
Expand Down
74 changes: 65 additions & 9 deletions tianshou/data/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,34 @@ class ReplayBuffer(object):
>>> batch_data, indice = buf.sample(batch_size=4)
>>> batch_data.obs == buf[indice].obs
array([ True, True, True, True])
From version v0.2.2, :class:`~tianshou.data.ReplayBuffer` supports
frame_stack sampling, typically for RNN usage:
::
>>> buf = ReplayBuffer(size=9, stack_num=4)
>>> for i in range(16):
... done = i % 5 == 0
... buf.add(obs=i, act=i, rew=i, done=done, obs_next=0, info={})
>>> print(buf.obs)
[ 9. 10. 11. 12. 13. 14. 15. 7. 8.]
>>> print(buf.done)
[0. 1. 0. 0. 0. 0. 1. 0. 0.]
>>> index = np.arange(len(buf))
>>> print(buf.get_stack(index, 'obs'))
[[ 7. 7. 8. 9.]
[ 7. 8. 9. 10.]
[11. 11. 11. 11.]
[11. 11. 11. 12.]
[11. 11. 12. 13.]
[11. 12. 13. 14.]
[12. 13. 14. 15.]
[ 7. 7. 7. 7.]
[ 7. 7. 7. 8.]]
>>> # here is another way to get the stacked data
>>> # (stack only for obs and obs_next)
>>> sum(sum(buf.get_stack(index, 'obs') - buf[index].obs))
0.0
"""

def __init__(self, size, stack_num=0):
Expand All @@ -51,8 +79,26 @@ def __len__(self):
"""Return len(self)."""
return self._size

def __repr__(self):
"""Return str(self)."""
s = self.__class__.__name__ + '(\n'
flag = False
for k in self.__dict__.keys():
if k[0] != '_' and self.__dict__[k] is not None:
rpl = '\n' + ' ' * (6 + len(k))
obj = str(self.__dict__[k]).replace('\n', rpl)
s += f' {k}: {obj},\n'
flag = True
if flag:
s += ')\n'
else:
s = self.__class__.__name__ + '()\n'
return s

def _add_to_buffer(self, name, inst):
if inst is None:
if getattr(self, name, None) is None:
self.__dict__[name] = None
return
if self.__dict__.get(name, None) is None:
if isinstance(inst, np.ndarray):
Expand All @@ -72,13 +118,14 @@ def update(self, buffer):
i = begin = buffer._index % len(buffer)
while True:
self.add(
buffer.obs[i], buffer.act[i], buffer.rew[i],
buffer.done[i], buffer.obs_next[i], buffer.info[i])
buffer.obs[i], buffer.act[i], buffer.rew[i], buffer.done[i],
None if buffer.obs_next is None else buffer.obs_next[i],
buffer.info[i])
i = (i + 1) % len(buffer)
if i == begin:
break

def add(self, obs, act, rew, done, obs_next=0, info={}, weight=None):
def add(self, obs, act, rew, done, obs_next=None, info={}, weight=None):
"""Add a batch of data into replay buffer."""
assert isinstance(info, dict), \
'You should return a dict in the last argument of env.step().'
Expand All @@ -97,7 +144,6 @@ def add(self, obs, act, rew, done, obs_next=0, info={}, weight=None):
def reset(self):
"""Clear all the data in replay buffer."""
self._index = self._size = 0
self.indice = []

def sample(self, batch_size):
"""Get a random sample from buffer with size equal to batch_size. \
Expand All @@ -114,28 +160,38 @@ def sample(self, batch_size):
])
return self[indice], indice

def _get_stack(self, indice, key):
def get_stack(self, indice, key):
"""Return the stacked result, e.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t],
where s is self.key, t is indice. The stack_num (here equals to 4) is
given from buffer initialization procedure.
"""
if self.__dict__.get(key, None) is None:
return None
if self._stack == 0:
return self.__dict__[key][indice]
stack = []
# set last frame done to True
last_index = (self._index - 1 + self._size) % self._size
last_done, self.done[last_index] = self.done[last_index], True
for i in range(self._stack):
stack = [self.__dict__[key][indice]] + stack
indice = indice - 1 + self.done[indice - 1].astype(np.int)
indice[indice == -1] = self._size - 1
pre_indice = indice - 1
pre_indice[pre_indice == -1] = self._size - 1
indice = pre_indice + self.done[pre_indice].astype(np.int)
indice[indice == self._size] = 0
self.done[last_index] = last_done
return np.stack(stack, axis=1)

def __getitem__(self, index):
"""Return a data batch: self[index]. If stack_num is set to be > 0,
return the stacked obs and obs_next with shape [batch, len, ...].
"""
return Batch(
obs=self._get_stack(index, 'obs'),
obs=self.get_stack(index, 'obs'),
act=self.act[index],
rew=self.rew[index],
done=self.done[index],
obs_next=self._get_stack(index, 'obs_next'),
obs_next=self.get_stack(index, 'obs_next'),
info=self.info[index]
)

Expand Down
12 changes: 6 additions & 6 deletions tianshou/data/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
import torch
import warnings
import numpy as np
from tianshou.env import BaseVectorEnv
from tianshou.data import Batch, ReplayBuffer, \
ListReplayBuffer

from tianshou.utils import MovAvg
from tianshou.env import BaseVectorEnv
from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer


class Collector(object):
Expand All @@ -22,8 +22,8 @@ class Collector(object):
:class:`~tianshou.data.ReplayBuffer`.
:param int stat_size: for the moving average of recording speed, defaults
to 100.
:param bool store_obs_next: whether to store the obs_next to replay
buffer, defaults to ``True``.
:param bool store_obs_next: store the next observation to replay buffer or
not, defaults to ``True``.
Example:
::
Expand Down Expand Up @@ -302,7 +302,7 @@ def collect(self, n_step=0, n_episode=0, render=None):
self._obs = obs_next
if self._multi_env:
cur_episode = sum(cur_episode)
duration = time.time() - start_time
duration = max(time.time() - start_time, 1e-9)
self.step_speed.add(cur_step / duration)
self.episode_speed.add(cur_episode / duration)
self.collect_step += cur_step
Expand Down
2 changes: 2 additions & 0 deletions tianshou/policy/pg.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ def process_fn(self, batch, buffer, indice):
discount factor, :math:`\gamma \in [0, 1]`.
"""
batch.returns = self._vanilla_returns(batch)
if getattr(batch, 'obs_next', None) is None:
batch.obs_next = buffer[(indice + 1) % len(buffer)].obs
# batch.returns = self._vectorized_returns(batch)
return batch

Expand Down

0 comments on commit d54c2ab

Please sign in to comment.