diff --git a/README.md b/README.md index 08a00019c..053e0f271 100644 --- a/README.md +++ b/README.md @@ -5,9 +5,9 @@ --- [![PyPI](https://img.shields.io/pypi/v/tianshou)](https://pypi.org/project/tianshou/) +[![Documentation Status](https://readthedocs.org/projects/tianshou/badge/?version=latest)](https://tianshou.readthedocs.io) [![Unittest](https://github.com/thu-ml/tianshou/workflows/Unittest/badge.svg?branch=master)](https://github.com/thu-ml/tianshou/actions) [![codecov](https://img.shields.io/codecov/c/gh/thu-ml/tianshou)](https://codecov.io/gh/thu-ml/tianshou) -[![Documentation Status](https://readthedocs.org/projects/tianshou/badge/?version=latest)](https://tianshou.readthedocs.io) [![GitHub issues](https://img.shields.io/github/issues/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/issues) [![GitHub stars](https://img.shields.io/github/stars/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/stargazers) [![GitHub forks](https://img.shields.io/github/forks/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/network) @@ -40,7 +40,7 @@ In Chinese, Tianshou means the innate talent, not taught by others. Tianshou is Tianshou is currently hosted on [PyPI](https://pypi.org/project/tianshou/). It requires Python >= 3.6. You can simply install Tianshou with the following command: ```bash -pip3 install tianshou -U +pip3 install tianshou ``` You can also install with the newest version through GitHub: @@ -49,6 +49,17 @@ You can also install with the newest version through GitHub: pip3 install git+https://github.com/thu-ml/tianshou.git@master ``` +If you use Anaconda or Miniconda, you can install Tianshou through the following command lines: + +```bash +# create a new virtualenv and install pip, change the env name if you like +conda create -n myenv pip +# activate the environment +conda activate myenv +# install tianshou +pip install tianshou +``` + After installation, open your python console and type ```python diff --git a/docs/index.rst b/docs/index.rst index 312ec79db..949763f41 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -30,13 +30,23 @@ Installation Tianshou is currently hosted on `PyPI `_. You can simply install Tianshou with the following command: :: - pip3 install tianshou -U + pip3 install tianshou You can also install with the newest version through GitHub: :: pip3 install git+https://github.com/thu-ml/tianshou.git@master +If you use Anaconda or Miniconda, you can install Tianshou through the following command lines: +:: + + # create a new virtualenv and install pip, change the env name if you like + conda create -n myenv pip + # activate the environment + conda activate myenv + # install tianshou + pip install tianshou + After installation, open your python console and type :: diff --git a/test/base/env.py b/test/base/env.py index ec0ef618a..fbab50052 100644 --- a/test/base/env.py +++ b/test/base/env.py @@ -3,15 +3,16 @@ class MyTestEnv(gym.Env): - def __init__(self, size, sleep=0): + def __init__(self, size, sleep=0, dict_state=False): self.size = size self.sleep = sleep + self.dict_state = dict_state self.reset() def reset(self, state=0): self.done = False self.index = state - return self.index + return {'index': self.index} if self.dict_state else self.index def step(self, action): if self.done: @@ -20,11 +21,21 @@ def step(self, action): time.sleep(self.sleep) if self.index == self.size: self.done = True - return self.index, 0, True, {} + if self.dict_state: + return {'index': self.index}, 0, True, {} + else: + return self.index, 0, True, {} if action == 0: self.index = max(self.index - 1, 0) - return self.index, 0, False, {} + if self.dict_state: + return {'index': self.index}, 0, False, {} + else: + return self.index, 0, False, {} elif action == 1: self.index += 1 self.done = self.index == self.size - return self.index, int(self.done), self.done, {'key': 1} + if self.dict_state: + return {'index': self.index}, int(self.done), self.done, \ + {'key': 1} + else: + return self.index, int(self.done), self.done, {'key': 1} diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 53a2cdc0a..6ef1e956f 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -15,7 +15,7 @@ def test_batch(): with pytest.raises(IndexError): batch[2] batch.obs = np.arange(5) - for i, b in enumerate(batch.split(1, permute=False)): + for i, b in enumerate(batch.split(1, shuffle=False)): assert b.obs == batch[i].obs print(batch) diff --git a/test/base/test_collector.py b/test/base/test_collector.py index 2b20756f0..a7ed671b3 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -2,7 +2,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import BasePolicy -from tianshou.env import SubprocVectorEnv +from tianshou.env import VectorEnv, SubprocVectorEnv from tianshou.data import Collector, Batch, ReplayBuffer if __name__ == '__main__': @@ -12,10 +12,13 @@ class MyPolicy(BasePolicy): - def __init__(self): + def __init__(self, dict_state=False): super().__init__() + self.dict_state = dict_state def forward(self, batch, state=None): + if self.dict_state: + return Batch(act=np.ones(batch.obs['index'].shape[0])) return Batch(act=np.ones(batch.obs.shape[0])) def learn(self): @@ -75,5 +78,24 @@ def test_collector(): 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5]) +def test_collector_with_dict_state(): + env = MyTestEnv(size=5, sleep=0, dict_state=True) + policy = MyPolicy(dict_state=True) + c0 = Collector(policy, env, ReplayBuffer(size=100)) + c0.collect(n_step=3) + c0.collect(n_episode=3) + env_fns = [ + lambda: MyTestEnv(size=2, sleep=0, dict_state=True), + lambda: MyTestEnv(size=3, sleep=0, dict_state=True), + lambda: MyTestEnv(size=4, sleep=0, dict_state=True), + lambda: MyTestEnv(size=5, sleep=0, dict_state=True), + ] + envs = VectorEnv(env_fns) + c1 = Collector(policy, envs, ReplayBuffer(size=100)) + c1.collect(n_step=10) + c1.collect(n_episode=[2, 1, 1, 2]) + + if __name__ == '__main__': test_collector() + test_collector_with_dict_state() diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 6858773f3..616c5311f 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -1,4 +1,5 @@ import torch +import pprint import numpy as np @@ -23,7 +24,7 @@ class Batch(object): ) In short, you can define a :class:`Batch` with any key-value pair. The - current implementation of Tianshou typically use 6 keys in + current implementation of Tianshou typically use 6 reserved keys in :class:`~tianshou.data.Batch`: * ``obs`` the observation of step :math:`t` ; @@ -56,7 +57,7 @@ class Batch(object): array([0, 11, 22, 0, 11, 22]) >>> # split whole data into multiple small batch - >>> for d in data.split(size=2, permute=False): + >>> for d in data.split(size=2, shuffle=False): ... print(d.obs, d.rew) [ 0 11] [6 6] [22 0] [6 6] @@ -65,24 +66,56 @@ class Batch(object): def __init__(self, **kwargs): super().__init__() - self.__dict__.update(kwargs) + self._meta = {} + for k, v in kwargs.items(): + if (isinstance(v, list) or isinstance(v, np.ndarray)) \ + and len(v) > 0 and isinstance(v[0], dict) and k != 'info': + self._meta[k] = list(v[0].keys()) + for k_ in v[0].keys(): + k__ = '_' + k + '@' + k_ + self.__dict__[k__] = np.array([ + v[i][k_] for i in range(len(v)) + ]) + elif isinstance(v, dict): + self._meta[k] = list(v.keys()) + for k_ in v.keys(): + k__ = '_' + k + '@' + k_ + self.__dict__[k__] = v[k_] + else: + self.__dict__[k] = kwargs[k] def __getitem__(self, index): """Return self[index].""" + if isinstance(index, str): + return self.__getattr__(index) b = Batch() for k in self.__dict__.keys(): - if self.__dict__[k] is not None: + if k != '_meta' and self.__dict__[k] is not None: b.__dict__.update(**{k: self.__dict__[k][index]}) + b._meta = self._meta return b + def __getattr__(self, key): + """Return self.key""" + if key not in self._meta.keys(): + if key not in self.__dict__.keys(): + raise AttributeError(key) + return self.__dict__[key] + d = {} + for k_ in self._meta[key]: + k__ = '_' + key + '@' + k_ + d[k_] = self.__dict__[k__] + return d + def __repr__(self): """Return str(self).""" s = self.__class__.__name__ + '(\n' flag = False - for k in sorted(self.__dict__.keys()): - if k[0] != '_' and self.__dict__[k] is not None: + for k in sorted(list(self.__dict__.keys()) + list(self._meta.keys())): + if k[0] != '_' and (self.__dict__.get(k, None) is not None or + k in self._meta.keys()): rpl = '\n' + ' ' * (6 + len(k)) - obj = str(self.__dict__[k]).replace('\n', rpl) + obj = pprint.pformat(self.__getattr__(k)).replace('\n', rpl) s += f' {k}: {obj},\n' flag = True if flag: @@ -91,10 +124,18 @@ def __repr__(self): s = self.__class__.__name__ + '()\n' return s + def keys(self): + """Return self.keys().""" + return sorted([i for i in self.__dict__.keys() if i[0] != '_'] + + list(self._meta.keys())) + def append(self, batch): """Append a :class:`~tianshou.data.Batch` object to current batch.""" assert isinstance(batch, Batch), 'Only append Batch is allowed!' for k in batch.__dict__.keys(): + if k == '_meta': + self._meta.update(batch.__dict__[k]) + continue if batch.__dict__[k] is None: continue if not hasattr(self, k) or self.__dict__[k] is None: @@ -117,22 +158,22 @@ def __len__(self): """Return len(self).""" return min([ len(self.__dict__[k]) for k in self.__dict__.keys() - if self.__dict__[k] is not None]) + if k != '_meta' and self.__dict__[k] is not None]) - def split(self, size=None, permute=True): + def split(self, size=None, shuffle=True): """Split whole data into multiple small batch. :param int size: if it is ``None``, it does not split the data batch; otherwise it will divide the data batch with the given size. Default to ``None``. - :param bool permute: randomly shuffle the entire data batch if it is + :param bool shuffle: randomly shuffle the entire data batch if it is ``True``, otherwise remain in the same. Default to ``True``. """ length = len(self) if size is None: size = length temp = 0 - if permute: + if shuffle: index = np.random.permutation(length) else: index = np.arange(length) diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 76c0df0db..db025ff42 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -1,3 +1,4 @@ +import pprint import numpy as np from tianshou.data.batch import Batch @@ -92,6 +93,7 @@ def __init__(self, size, stack_num=0, ignore_obs_next=False, **kwargs): self._maxsize = size self._stack = stack_num self._save_s_ = not ignore_obs_next + self._meta = {} self.reset() def __len__(self): @@ -102,10 +104,11 @@ def __repr__(self): """Return str(self).""" s = self.__class__.__name__ + '(\n' flag = False - for k in self.__dict__.keys(): - if k[0] != '_' and self.__dict__[k] is not None: + for k in sorted(list(self.__dict__.keys()) + list(self._meta.keys())): + if k[0] != '_' and (self.__dict__.get(k, None) is not None or + k in self._meta.keys()): rpl = '\n' + ' ' * (6 + len(k)) - obj = str(self.__dict__[k]).replace('\n', rpl) + obj = pprint.pformat(self.__getattr__(k)).replace('\n', rpl) s += f' {k}: {obj},\n' flag = True if flag: @@ -114,23 +117,51 @@ def __repr__(self): s = self.__class__.__name__ + '()\n' return s + def __getattr__(self, key): + """Return self.key""" + if key not in self._meta.keys(): + if key not in self.__dict__.keys(): + raise AttributeError(key) + return self.__dict__[key] + d = {} + for k_ in self._meta[key]: + k__ = '_' + key + '@' + k_ + d[k_] = self.__dict__[k__] + return d + def _add_to_buffer(self, name, inst): if inst is None: if getattr(self, name, None) is None: self.__dict__[name] = None return + if name in self._meta.keys(): + for k in inst.keys(): + self._add_to_buffer('_' + name + '@' + k, inst[k]) + return if self.__dict__.get(name, None) is None: if isinstance(inst, np.ndarray): self.__dict__[name] = np.zeros([self._maxsize, *inst.shape]) elif isinstance(inst, dict): - self.__dict__[name] = np.array( - [{} for _ in range(self._maxsize)]) + if name == 'info': + self.__dict__[name] = np.array( + [{} for _ in range(self._maxsize)]) + else: + if self._meta.get(name, None) is None: + self._meta[name] = [ + '_' + name + '@' + k for k in inst.keys()] + for k in inst.keys(): + k_ = '_' + name + '@' + k + self._add_to_buffer(k_, inst[k]) else: # assume `inst` is a number self.__dict__[name] = np.zeros([self._maxsize]) if isinstance(inst, np.ndarray) and \ self.__dict__[name].shape[1:] != inst.shape: - self.__dict__[name] = np.zeros([self._maxsize, *inst.shape]) - self.__dict__[name][self._index] = inst + raise ValueError( + "Cannot add data to a buffer with different shape, " + f"key: {name}, expect shape: {self.__dict__[name].shape[1:]}, " + f"given shape: {inst.shape}.") + if name not in self._meta.keys(): + self.__dict__[name][self._index] = inst def update(self, buffer): """Move the data from the given buffer to self.""" @@ -144,7 +175,8 @@ def update(self, buffer): if i == begin: break - def add(self, obs, act, rew, done, obs_next=None, info={}, weight=None): + def add(self, obs, act, rew, done, obs_next=None, info={}, policy={}, + **kwargs): """Add a batch of data into replay buffer.""" assert isinstance(info, dict), \ 'You should return a dict in the last argument of env.step().' @@ -155,6 +187,7 @@ def add(self, obs, act, rew, done, obs_next=None, info={}, weight=None): if self._save_s_: self._add_to_buffer('obs_next', obs_next) self._add_to_buffer('info', info) + self._add_to_buffer('policy', policy) if self._maxsize > 0: self._size = min(self._size + 1, self._maxsize) self._index = (self._index + 1) % self._maxsize @@ -180,11 +213,13 @@ def sample(self, batch_size): ]) return self[indice], indice - def get(self, indice, key): + def get(self, indice, key, stack_num=None): """Return the stacked result, e.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t], where s is self.key, t is indice. The stack_num (here equals to 4) is given from buffer initialization procedure. """ + if stack_num is None: + stack_num = self._stack if not isinstance(indice, np.ndarray): if np.isscalar(indice): indice = np.array(indice) @@ -200,18 +235,37 @@ def get(self, indice, key): indice += 1 - self.done[indice].astype(np.int) indice[indice == self._size] = 0 key = 'obs' - if self._stack == 0: + if stack_num == 0: self.done[last_index] = last_done - return self.__dict__[key][indice] - stack = [] - for i in range(self._stack): - stack = [self.__dict__[key][indice]] + stack + if key in self._meta: + return {k.split('@')[-1]: self.__dict__[k][indice] + for k in self._meta[key]} + else: + return self.__dict__[key][indice] + if key in self._meta: + many_keys = self._meta[key] + stack = {k.split('@')[-1]: [] for k in self._meta[key]} + else: + stack = [] + many_keys = None + for i in range(stack_num): + if many_keys is not None: + for k_ in many_keys: + k = k_.split('@')[-1] + stack[k] = [self.__dict__[k_][indice]] + stack[k] + else: + stack = [self.__dict__[key][indice]] + stack pre_indice = indice - 1 pre_indice[pre_indice == -1] = self._size - 1 indice = pre_indice + self.done[pre_indice].astype(np.int) indice[indice == self._size] = 0 self.done[last_index] = last_done - return np.stack(stack, axis=1) + if many_keys is not None: + for k in stack: + stack[k] = np.stack(stack[k], axis=1) + else: + stack = np.stack(stack, axis=1) + return stack def __getitem__(self, index): """Return a data batch: self[index]. If stack_num is set to be > 0, @@ -223,7 +277,8 @@ def __getitem__(self, index): rew=self.rew[index], done=self.done[index], obs_next=self.get(index, 'obs_next'), - info=self.info[index] + info=self.info[index], + policy=self.get(index, 'policy'), ) @@ -234,7 +289,7 @@ class ListReplayBuffer(ReplayBuffer): .. seealso:: - Please refer to :class:`~tianshou.data.ListReplayBuffer` for more + Please refer to :class:`~tianshou.data.ReplayBuffer` for more detailed explanation. """ @@ -256,7 +311,13 @@ def reset(self): class PrioritizedReplayBuffer(ReplayBuffer): - """docstring for PrioritizedReplayBuffer""" + """Prioritized replay buffer implementation. + + .. seealso:: + + Please refer to :class:`~tianshou.data.ReplayBuffer` for more + detailed explanation. + """ def __init__(self, size, alpha: float, beta: float, mode: str = 'weight', **kwargs): @@ -270,17 +331,18 @@ def __init__(self, size, alpha: float, beta: float, self._amortization_freq = 50 self._amortization_counter = 0 - def add(self, obs, act, rew, done, obs_next=0, info={}, weight=1.0): + def add(self, obs, act, rew, done, obs_next=0, info={}, policy={}, + weight=1.0): """Add a batch of data into replay buffer.""" self._weight_sum += np.abs(weight)**self._alpha - \ self.weight[self._index] # we have to sacrifice some convenience for speed :( - self._add_to_buffer('weight', np.abs(weight)**self._alpha) - super().add(obs, act, rew, done, obs_next, info) + self._add_to_buffer('weight', np.abs(weight) ** self._alpha) + super().add(obs, act, rew, done, obs_next, info, policy) self._check_weight_sum() def sample(self, batch_size: int = 0, importance_sample: bool = True): - """ Get a random sample from buffer with priority probability. \ + """Get a random sample from buffer with priority probability. \ Return all the data in the buffer if batch_size is ``0``. :return: Sample data and its corresponding index inside the buffer. @@ -290,7 +352,8 @@ def sample(self, batch_size: int = 0, importance_sample: bool = True): # will cause weight update conflict indice = np.random.choice( self._size, batch_size, - p=(self.weight/self.weight.sum())[:self._size], replace=False) + p=(self.weight / self.weight.sum())[:self._size], + replace=False) # self._weight_sum is not work for the accuracy issue # p=(self.weight/self._weight_sum)[:self._size], replace=False) elif batch_size == 0: @@ -305,8 +368,9 @@ def sample(self, batch_size: int = 0, importance_sample: bool = True): batch = self[indice] if importance_sample: impt_weight = Batch( - impt_weight=1/np.power( - self._size*(batch.weight/self._weight_sum), self._beta)) + impt_weight=1 / np.power( + self._size * (batch.weight / self._weight_sum), + self._beta)) batch.append(impt_weight) self._check_weight_sum() return batch, indice @@ -316,7 +380,7 @@ def reset(self): super().reset() def update_weight(self, indice, new_weight: np.ndarray): - """update priority weight by indice in this buffer + """Update priority weight by indice in this buffer. :param indice: indice you want to update weight :param new_weight: new priority weight you wangt to update @@ -333,7 +397,8 @@ def __getitem__(self, index): done=self.done[index], obs_next=self.get(index, 'obs_next'), info=self.info[index], - weight=self.weight[index] + weight=self.weight[index], + policy=self.get(index, 'policy'), ) def _check_weight_sum(self): diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index e1f511678..9a50712ab 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -54,7 +54,7 @@ def process_fn(self, batch, buffer, indice): batch, None, gamma=self._gamma, gae_lambda=self._lambda) v_ = [] with torch.no_grad(): - for b in batch.split(self._batch, permute=False): + for b in batch.split(self._batch, shuffle=False): v_.append(self.critic(b.obs_next).detach().cpu().numpy()) v_ = np.concatenate(v_, axis=0) return self.compute_episodic_return( diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index c1a045eab..25cfb28be 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -74,7 +74,7 @@ def process_fn(self, batch, buffer, indice): batch, None, gamma=self._gamma, gae_lambda=self._lambda) v_ = [] with torch.no_grad(): - for b in batch.split(self._batch, permute=False): + for b in batch.split(self._batch, shuffle=False): v_.append(self.critic(b.obs_next)) v_ = torch.cat(v_, dim=0).cpu().numpy() return self.compute_episodic_return( @@ -111,7 +111,7 @@ def learn(self, batch, batch_size=None, repeat=1, **kwargs): v = [] old_log_prob = [] with torch.no_grad(): - for b in batch.split(batch_size, permute=False): + for b in batch.split(batch_size, shuffle=False): v.append(self.critic(b.obs)) old_log_prob.append(self(b).dist.log_prob( torch.tensor(b.act, device=v[0].device)))