Skip to content

Commit

Permalink
Multimodal obs (thu-ml#38, thu-ml#27, thu-ml#25)
Browse files Browse the repository at this point in the history
  • Loading branch information
Trinkle23897 committed Apr 28, 2020
1 parent cf4a2ee commit 5b69975
Show file tree
Hide file tree
Showing 9 changed files with 212 additions and 52 deletions.
15 changes: 13 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
---

[![PyPI](https://img.shields.io/pypi/v/tianshou)](https://pypi.org/project/tianshou/)
[![Documentation Status](https://readthedocs.org/projects/tianshou/badge/?version=latest)](https://tianshou.readthedocs.io)
[![Unittest](https://github.com/thu-ml/tianshou/workflows/Unittest/badge.svg?branch=master)](https://github.com/thu-ml/tianshou/actions)
[![codecov](https://img.shields.io/codecov/c/gh/thu-ml/tianshou)](https://codecov.io/gh/thu-ml/tianshou)
[![Documentation Status](https://readthedocs.org/projects/tianshou/badge/?version=latest)](https://tianshou.readthedocs.io)
[![GitHub issues](https://img.shields.io/github/issues/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/issues)
[![GitHub stars](https://img.shields.io/github/stars/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/stargazers)
[![GitHub forks](https://img.shields.io/github/forks/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/network)
Expand Down Expand Up @@ -40,7 +40,7 @@ In Chinese, Tianshou means the innate talent, not taught by others. Tianshou is
Tianshou is currently hosted on [PyPI](https://pypi.org/project/tianshou/). It requires Python >= 3.6. You can simply install Tianshou with the following command:

```bash
pip3 install tianshou -U
pip3 install tianshou
```

You can also install the newest version directly from GitHub:
Expand All @@ -49,6 +49,17 @@ You can also install with the newest version through GitHub:
pip3 install git+https://github.com/thu-ml/tianshou.git@master
```

If you use Anaconda or Miniconda, you can install Tianshou with the following commands:

```bash
# create a new conda environment with pip installed; change the env name if you like
conda create -n myenv pip
# activate the environment
conda activate myenv
# install tianshou
pip install tianshou
```

After installation, open your python console and type

```python
Expand Down
12 changes: 11 additions & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,23 @@ Installation
Tianshou is currently hosted on `PyPI <https://pypi.org/project/tianshou/>`_. You can simply install Tianshou with the following command:
::

pip3 install tianshou -U
pip3 install tianshou

You can also install the newest version directly from GitHub:
::

pip3 install git+https://github.com/thu-ml/tianshou.git@master

If you use Anaconda or Miniconda, you can install Tianshou with the following commands:
::

# create a new conda environment with pip installed; change the env name if you like
conda create -n myenv pip
# activate the environment
conda activate myenv
# install tianshou
pip install tianshou

After installation, open your python console and type
::

Expand Down
21 changes: 16 additions & 5 deletions test/base/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@


class MyTestEnv(gym.Env):
def __init__(self, size, sleep=0):
def __init__(self, size, sleep=0, dict_state=False):
self.size = size
self.sleep = sleep
self.dict_state = dict_state
self.reset()

def reset(self, state=0):
self.done = False
self.index = state
return self.index
return {'index': self.index} if self.dict_state else self.index

def step(self, action):
if self.done:
Expand All @@ -20,11 +21,21 @@ def step(self, action):
time.sleep(self.sleep)
if self.index == self.size:
self.done = True
return self.index, 0, True, {}
if self.dict_state:
return {'index': self.index}, 0, True, {}
else:
return self.index, 0, True, {}
if action == 0:
self.index = max(self.index - 1, 0)
return self.index, 0, False, {}
if self.dict_state:
return {'index': self.index}, 0, False, {}
else:
return self.index, 0, False, {}
elif action == 1:
self.index += 1
self.done = self.index == self.size
return self.index, int(self.done), self.done, {'key': 1}
if self.dict_state:
return {'index': self.index}, int(self.done), self.done, \
{'key': 1}
else:
return self.index, int(self.done), self.done, {'key': 1}
2 changes: 1 addition & 1 deletion test/base/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def test_batch():
with pytest.raises(IndexError):
batch[2]
batch.obs = np.arange(5)
for i, b in enumerate(batch.split(1, permute=False)):
for i, b in enumerate(batch.split(1, shuffle=False)):
assert b.obs == batch[i].obs
print(batch)

Expand Down
26 changes: 24 additions & 2 deletions test/base/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import BasePolicy
from tianshou.env import SubprocVectorEnv
from tianshou.env import VectorEnv, SubprocVectorEnv
from tianshou.data import Collector, Batch, ReplayBuffer

if __name__ == '__main__':
Expand All @@ -12,10 +12,13 @@


class MyPolicy(BasePolicy):
def __init__(self):
def __init__(self, dict_state=False):
super().__init__()
self.dict_state = dict_state

def forward(self, batch, state=None):
if self.dict_state:
return Batch(act=np.ones(batch.obs['index'].shape[0]))
return Batch(act=np.ones(batch.obs.shape[0]))

def learn(self):
Expand Down Expand Up @@ -75,5 +78,24 @@ def test_collector():
1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5])


def test_collector_with_dict_state():
    """Smoke-test ``Collector`` with dict-valued observations.

    A single ``MyTestEnv`` and a ``VectorEnv`` wrapping four of them
    (sizes 2 to 5) are each collected from, first by step count and then
    by episode count. Nothing is asserted: the test passes as long as
    collection does not raise.
    """
    single_env = MyTestEnv(size=5, sleep=0, dict_state=True)
    policy = MyPolicy(dict_state=True)

    # One environment: collect by steps, then by episodes.
    single_collector = Collector(policy, single_env, ReplayBuffer(size=100))
    single_collector.collect(n_step=3)
    single_collector.collect(n_episode=3)

    # Four environments of increasing size. ``size`` is bound as a lambda
    # default so that every factory keeps its own value.
    env_fns = [
        lambda size=size: MyTestEnv(size=size, sleep=0, dict_state=True)
        for size in (2, 3, 4, 5)
    ]
    vector_env = VectorEnv(env_fns)
    vector_collector = Collector(policy, vector_env, ReplayBuffer(size=100))
    vector_collector.collect(n_step=10)
    # Per-environment episode quota, one entry per factory above.
    vector_collector.collect(n_episode=[2, 1, 1, 2])


if __name__ == '__main__':
test_collector()
test_collector_with_dict_state()
63 changes: 52 additions & 11 deletions tianshou/data/batch.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
import pprint
import numpy as np


Expand All @@ -23,7 +24,7 @@ class Batch(object):
)
In short, you can define a :class:`Batch` with any key-value pair. The
current implementation of Tianshou typically use 6 keys in
current implementation of Tianshou typically uses 6 reserved keys in
:class:`~tianshou.data.Batch`:
* ``obs`` the observation of step :math:`t` ;
Expand Down Expand Up @@ -56,7 +57,7 @@ class Batch(object):
array([0, 11, 22, 0, 11, 22])
>>> # split whole data into multiple small batch
>>> for d in data.split(size=2, permute=False):
>>> for d in data.split(size=2, shuffle=False):
... print(d.obs, d.rew)
[ 0 11] [6 6]
[22 0] [6 6]
Expand All @@ -65,24 +66,56 @@ class Batch(object):

def __init__(self, **kwargs):
super().__init__()
self.__dict__.update(kwargs)
self._meta = {}
for k, v in kwargs.items():
if (isinstance(v, list) or isinstance(v, np.ndarray)) \
and len(v) > 0 and isinstance(v[0], dict) and k != 'info':
self._meta[k] = list(v[0].keys())
for k_ in v[0].keys():
k__ = '_' + k + '@' + k_
self.__dict__[k__] = np.array([
v[i][k_] for i in range(len(v))
])
elif isinstance(v, dict):
self._meta[k] = list(v.keys())
for k_ in v.keys():
k__ = '_' + k + '@' + k_
self.__dict__[k__] = v[k_]
else:
self.__dict__[k] = kwargs[k]

def __getitem__(self, index):
"""Return self[index]."""
if isinstance(index, str):
return self.__getattr__(index)
b = Batch()
for k in self.__dict__.keys():
if self.__dict__[k] is not None:
if k != '_meta' and self.__dict__[k] is not None:
b.__dict__.update(**{k: self.__dict__[k][index]})
b._meta = self._meta
return b

def __getattr__(self, key):
"""Return self.key"""
if key not in self._meta.keys():
if key not in self.__dict__.keys():
raise AttributeError(key)
return self.__dict__[key]
d = {}
for k_ in self._meta[key]:
k__ = '_' + key + '@' + k_
d[k_] = self.__dict__[k__]
return d

def __repr__(self):
"""Return str(self)."""
s = self.__class__.__name__ + '(\n'
flag = False
for k in sorted(self.__dict__.keys()):
if k[0] != '_' and self.__dict__[k] is not None:
for k in sorted(list(self.__dict__.keys()) + list(self._meta.keys())):
if k[0] != '_' and (self.__dict__.get(k, None) is not None or
k in self._meta.keys()):
rpl = '\n' + ' ' * (6 + len(k))
obj = str(self.__dict__[k]).replace('\n', rpl)
obj = pprint.pformat(self.__getattr__(k)).replace('\n', rpl)
s += f' {k}: {obj},\n'
flag = True
if flag:
Expand All @@ -91,10 +124,18 @@ def __repr__(self):
s = self.__class__.__name__ + '()\n'
return s

def keys(self):
"""Return self.keys()."""
return sorted([i for i in self.__dict__.keys() if i[0] != '_'] +
list(self._meta.keys()))

def append(self, batch):
"""Append a :class:`~tianshou.data.Batch` object to current batch."""
assert isinstance(batch, Batch), 'Only append Batch is allowed!'
for k in batch.__dict__.keys():
if k == '_meta':
self._meta.update(batch.__dict__[k])
continue
if batch.__dict__[k] is None:
continue
if not hasattr(self, k) or self.__dict__[k] is None:
Expand All @@ -117,22 +158,22 @@ def __len__(self):
"""Return len(self)."""
return min([
len(self.__dict__[k]) for k in self.__dict__.keys()
if self.__dict__[k] is not None])
if k != '_meta' and self.__dict__[k] is not None])

def split(self, size=None, permute=True):
def split(self, size=None, shuffle=True):
"""Split whole data into multiple small batch.
:param int size: if it is ``None``, it does not split the data batch;
otherwise it will divide the data batch with the given size.
Default to ``None``.
:param bool permute: randomly shuffle the entire data batch if it is
:param bool shuffle: randomly shuffle the entire data batch if it is
``True``, otherwise keep the original order. Default to ``True``.
"""
length = len(self)
if size is None:
size = length
temp = 0
if permute:
if shuffle:
index = np.random.permutation(length)
else:
index = np.arange(length)
Expand Down
Loading

0 comments on commit 5b69975

Please sign in to comment.