nstep multidim support
Trinkle23897 committed Jan 4, 2021
1 parent 44cf066 commit 3695f12
Showing 9 changed files with 43 additions and 104 deletions.
5 files renamed without changes
16 changes: 16 additions & 0 deletions test/base/test_returns.py
@@ -76,6 +76,10 @@ def target_q_fn(buffer, indice):
return torch.tensor(-buffer.rew[indice], dtype=torch.float32)


def target_q_fn_multidim(buffer, indice):
return target_q_fn(buffer, indice).unsqueeze(1).repeat(1, 51)


def compute_nstep_return_base(nstep, gamma, buffer, indice):
returns = np.zeros_like(indice, dtype=np.float)
buf_len = len(buffer)
@@ -108,20 +112,32 @@ def test_nstep_returns(size=10000):
assert np.allclose(returns, [2.6, 4, 4.4, 5.3, 6.2, 8, 8, 8.9, 9.8, 12])
r_ = compute_nstep_return_base(1, .1, buf, indice)
assert np.allclose(returns, r_), (r_, returns)
returns_multidim = to_numpy(BasePolicy.compute_nstep_return(
batch, buf, indice, target_q_fn_multidim, gamma=.1, n_step=1
).pop('returns'))
assert np.allclose(returns_multidim, returns[:, np.newaxis])
# test nstep = 2
returns = to_numpy(BasePolicy.compute_nstep_return(
batch, buf, indice, target_q_fn, gamma=.1, n_step=2).pop('returns'))
assert np.allclose(returns, [
3.4, 4, 5.53, 6.62, 7.8, 8, 9.89, 10.98, 12.2, 12])
r_ = compute_nstep_return_base(2, .1, buf, indice)
assert np.allclose(returns, r_)
returns_multidim = to_numpy(BasePolicy.compute_nstep_return(
batch, buf, indice, target_q_fn_multidim, gamma=.1, n_step=2
).pop('returns'))
assert np.allclose(returns_multidim, returns[:, np.newaxis])
# test nstep = 10
returns = to_numpy(BasePolicy.compute_nstep_return(
batch, buf, indice, target_q_fn, gamma=.1, n_step=10).pop('returns'))
assert np.allclose(returns, [
3.4, 4, 5.678, 6.78, 7.8, 8, 10.122, 11.22, 12.2, 12])
r_ = compute_nstep_return_base(10, .1, buf, indice)
assert np.allclose(returns, r_)
returns_multidim = to_numpy(BasePolicy.compute_nstep_return(
batch, buf, indice, target_q_fn_multidim, gamma=.1, n_step=10
).pop('returns'))
assert np.allclose(returns_multidim, returns[:, np.newaxis])

if __name__ == '__main__':
buf = ReplayBuffer(size)
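For reference, the quantity these assertions check is the truncated n-step return G_t = r_t + gamma * r_{t+1} + ... + gamma^(n-1) * r_{t+n-1} + gamma^n * Q_target(s_{t+n}), cut off at an episode boundary; the multidim cases only add a trailing dimension (e.g. 51 atoms) to the bootstrap term. A minimal NumPy sketch of that idea, simplified from compute_nstep_return_base above (the function name and per-sample slicing are illustrative, not the library implementation):

    import numpy as np

    def nstep_return_sketch(rew, done, target_q, gamma, n_step):
        """Illustrative only: rew/done are the n_step-long slices starting at one
        sampled index; target_q is the bootstrap value, scalar or (k,)-shaped."""
        ret, discount = 0.0, 1.0
        for i in range(n_step):
            ret = ret + discount * rew[i]
            discount *= gamma
            if done[i]:
                discount = 0.0  # episode ended inside the window: drop the bootstrap term
                break
        return ret + discount * target_q  # broadcasts when target_q is a vector

    # scalar vs. 51-atom bootstrap: same value, extra trailing dimension
    nstep_return_sketch(np.ones(3), np.zeros(3, bool), 5.0, 0.9, 3)              # -> 6.355
    nstep_return_sketch(np.ones(3), np.zeros(3, bool), np.full(51, 5.0), 0.9, 3)  # -> shape (51,)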
17 changes: 11 additions & 6 deletions tianshou/policy/base.py
@@ -245,7 +245,7 @@ def compute_nstep_return(
to False.
:return: a Batch. The result will be stored in batch.returns as a
- torch.Tensor with shape (bsz, ).
torch.Tensor with the same shape as target_q_fn's return tensor.
"""
rew = buffer.rew
if rew_norm:
@@ -257,12 +257,11 @@
mean, std = 0.0, 1.0
buf_len = len(buffer)
terminal = (indice + n_step - 1) % buf_len
- target_q_torch = target_q_fn(buffer, terminal).flatten() # (bsz, )
target_q_torch = target_q_fn(buffer, terminal) # (bsz, ?)
target_q = to_numpy(target_q_torch)

target_q = _nstep_return(rew, buffer.done, target_q, indice,
gamma, n_step, len(buffer), mean, std)

batch.returns = to_torch_as(target_q, target_q_torch)
if hasattr(batch, "weight"): # prio buffer update
batch.weight = to_torch_as(batch.weight, target_q_torch)
@@ -275,7 +274,7 @@ def _compile(self) -> None:
i64 = np.array([0, 1], dtype=np.int64)
_episodic_return(f64, f64, b, 0.1, 0.1)
_episodic_return(f32, f64, b, 0.1, 0.1)
- _nstep_return(f64, b, f32, i64, 0.1, 1, 4, 1.0, 0.0)
_nstep_return(f64, b, f32, i64, 0.1, 1, 4, 0.0, 1.0)


@njit
@@ -311,13 +310,19 @@ def _nstep_return(
std: float,
) -> np.ndarray:
"""Numba speedup: 0.3s -> 0.15s."""
- returns = np.zeros(indice.shape)
target_shape = target_q.shape
bsz = target_shape[0]
# change rew/target_q to 2d array
target_q = target_q.reshape(bsz, -1)
rew = rew.reshape(-1, 1) # assume reward is a scalar
returns = np.zeros(target_q.shape)
gammas = np.full(indice.shape, n_step)
for n in range(n_step - 1, -1, -1):
now = (indice + n) % buf_len
gammas[done[now] > 0] = n
returns[done[now] > 0] = 0.0
returns = (rew[now] - mean) / std + gamma * returns
target_q[gammas != n_step] = 0.0
gammas = gammas.reshape(-1, 1)
target_q = target_q * (gamma ** gammas) + returns
- return target_q
return target_q.reshape(target_shape)
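The reshape bookkeeping above is what lets a single scalar reward stream drive a multi-dimensional bootstrap target. As a rough sanity check (a sketch, not part of the commit: it assumes _nstep_return stays importable from tianshou.policy.base with the argument order shown in _compile above), calling it with a (bsz, 51) target keeps that shape and drops the bootstrap term for a sample whose episode ends inside the n-step window:

    import numpy as np
    from tianshou.policy.base import _nstep_return  # assumed import path

    rew = np.ones(4)                              # buffer rewards, buf_len = 4
    done = np.array([False, True, False, False])  # index 1 is terminal
    indice = np.array([0, 2], dtype=np.int64)     # sampled start indices
    target_q = np.ones((2, 51))                   # target Q at the two terminal indices
    out = _nstep_return(rew, done, target_q, indice, 0.9, 2, 4, 0.0, 1.0)
    print(out.shape)  # (2, 51): same shape as target_q
    # sample 0 hits the terminal: 1 + 0.9 * 1 = 1.9, no bootstrap
    # sample 1 bootstraps:        1 + 0.9 * 1 + 0.9**2 * 1 = 2.71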
112 changes: 15 additions & 97 deletions tianshou/policy/modelfree/c51.py
@@ -1,7 +1,6 @@
import torch
import numpy as np
- from numba import njit
- from typing import Any, Dict, Union, Optional, Tuple
from typing import Any, Dict, Union, Optional

from tianshou.policy import DQNPolicy
from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
@@ -19,7 +18,7 @@ class C51Policy(DQNPolicy):
:param float v_min: the value of the smallest atom in the support set,
defaults to -10.0.
:param float v_max: the value of the largest atom in the support set,
- defaults to -10.0.
defaults to 10.0.
:param int estimation_step: greater than 1, the number of steps to look
ahead.
:param int target_update_freq: the target network update frequency (0 if
@@ -30,7 +29,7 @@ class C51Policy(DQNPolicy):
.. seealso::
Please refer to :class:`~tianshou.policy.DQNPolicy` for more detailed
explanation.
"""

def __init__(
@@ -46,71 +45,21 @@ def __init__(
reward_normalization: bool = False,
**kwargs: Any,
) -> None:
- super().__init__(model, optim, discount_factor,
- estimation_step, target_update_freq,
- reward_normalization, **kwargs)
super().__init__(model, optim, discount_factor, estimation_step,
target_update_freq, reward_normalization, **kwargs)
assert num_atoms > 1, "num_atoms should be greater than 1"
assert v_min < v_max, "v_max should be larger than v_min"
self._num_atoms = num_atoms
self._v_min = v_min
self._v_max = v_max
self.support = torch.linspace(self._v_min, self._v_max,
self._num_atoms)
self.delta_z = (v_max - v_min) / (num_atoms - 1)

- @staticmethod
- def prepare_n_step(
- batch: Batch,
- buffer: ReplayBuffer,
- indice: np.ndarray,
- gamma: float = 0.99,
- n_step: int = 1,
- rew_norm: bool = False,
- ) -> Batch:
- """Modify the obs_next, done and rew in batch for computing n-step return.
- :param batch: a data batch, which is equal to buffer[indice].
- :type batch: :class:`~tianshou.data.Batch`
- :param buffer: a data buffer which contains several full-episode data
- chronologically.
- :type buffer: :class:`~tianshou.data.ReplayBuffer`
- :param indice: sampled timestep.
- :type indice: numpy.ndarray
- :param float gamma: the discount factor, should be in [0, 1], defaults
- to 0.99.
- :param int n_step: the number of estimation step, should be an int
- greater than 0, defaults to 1.
- :param bool rew_norm: normalize the reward to Normal(0, 1), defaults
- to False.
- :return: a Batch with modified obs_next, done and rew.
- """
- buf_len = len(buffer)
- if rew_norm:
- bfr = buffer.rew[: min(buf_len, 1000)] # avoid large buffer
- mean, std = bfr.mean(), bfr.std()
- if np.isclose(std, 0, 1e-2):
- mean, std = 0.0, 1.0
- else:
- mean, std = 0.0, 1.0
- buffer_n = buffer[(indice + n_step - 1) % buf_len]
- batch.obs_next = buffer_n.obs_next
- rew_n, done_n = _nstep_batch(buffer.rew, buffer.done,
- indice, gamma, n_step, buf_len, mean, std)
- batch.rew = rew_n
- batch.done = done_n
- return batch

- def process_fn(
- self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
- ) -> Batch:
- """Prepare the batch for calculating the n-step return.
- More details can be found at
- :meth:`~tianshou.policy.C51Policy.prepare_n_step`.
- """
- batch = self.prepare_n_step(
- batch, buffer, indice,
- self._gamma, self._n_step, self._rew_norm)
- return batch
def _target_q(
self, buffer: ReplayBuffer, indice: np.ndarray
) -> torch.Tensor:
return self.support.repeat(len(indice), 1) # shape: [bsz, num_atoms]

def forward(
self,
@@ -164,25 +113,15 @@ def _target_dist(self, batch: Batch) -> torch.Tensor:
a = next_b.act
next_dist = next_b.logits
next_dist = next_dist[np.arange(len(a)), a, :]
- device = next_dist.device
- reward = torch.from_numpy(batch.rew).to(device).unsqueeze(1)
- done = torch.from_numpy(batch.done).to(device).float().unsqueeze(1)
- support = self.support.to(device)

- # Compute the projection of bellman update Tz onto the support z.
- target_support = reward + (
- self._gamma ** self._n_step) * (1.0 - done) * support.unsqueeze(0)
- target_support = target_support.clamp(self._v_min, self._v_max)

support = self.support.to(next_dist.device)
target_support = batch.returns.clamp(
self._v_min, self._v_max).to(next_dist.device)
# An amazing trick for calculating the projection gracefully.
# ref: https://github.com/ShangtongZhang/DeepRL
target_dist = (1 - (target_support.unsqueeze(1) -
support.view(1, -1, 1)).abs() / self.delta_z
).clamp(0, 1) * next_dist.unsqueeze(1)
- target_dist = target_dist.sum(-1)
- if hasattr(batch, "weight"): # prio buffer update
- batch.weight = to_torch_as(batch.weight, target_dist)
- return target_dist
return target_dist.sum(-1)

def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
if self._target and self._cnt % self._freq == 0:
Expand All @@ -201,24 +140,3 @@ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
self.optim.step()
self._cnt += 1
return {"loss": loss.item()}


- @njit
- def _nstep_batch(
- rew: np.ndarray,
- done: np.ndarray,
- indice: np.ndarray,
- gamma: float,
- n_step: int,
- buf_len: int,
- mean: float,
- std: float,
- ) -> Tuple[np.ndarray, np.ndarray]:
- rew_n = np.zeros(indice.shape)
- done_n = done[indice]
- for n in range(n_step - 1, -1, -1):
- now = (indice + n) % buf_len
- done_t = done[now]
- done_n = np.bitwise_or(done_n, done_t)
- rew_n = (rew[now] - mean) / std + (1.0 - done_t) * gamma * rew_n
- return rew_n, done_n
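On the distribution side, _target_dist above now clamps batch.returns to [v_min, v_max] (its shape is (bsz, num_atoms), since compute_nstep_return keeps the shape returned by the new _target_q) and projects it onto the fixed support, where delta_z = (v_max - v_min) / (num_atoms - 1), e.g. (10 - (-10)) / 50 = 0.4 with the defaults. A small self-contained sketch of that projection trick with made-up numbers (three atoms instead of 51):

    import torch

    support = torch.tensor([-1.0, 0.0, 1.0])          # fixed atoms z_i
    delta_z = 1.0                                     # (1 - (-1)) / (3 - 1)
    target_support = torch.tensor([[0.3, 0.3, 0.3]])  # clamped target atoms Tz_j, shape (bsz, atoms)
    next_dist = torch.tensor([[0.2, 0.5, 0.3]])       # probabilities sitting on Tz_j

    # weight of target atom Tz_j on support atom z_i: 1 - |Tz_j - z_i| / delta_z, clamped to [0, 1]
    w = (1 - (target_support.unsqueeze(1) -
              support.view(1, -1, 1)).abs() / delta_z).clamp(0, 1)
    target_dist = (w * next_dist.unsqueeze(1)).sum(-1)
    print(target_dist)        # ~[[0.0, 0.7, 0.3]]: each Tz_j = 0.3 splits 70/30 between z=0 and z=1
    print(target_dist.sum())  # total probability mass is preserved

Because every clamped target atom lies between two neighbouring support atoms, its two interpolation weights sum to one, so the projected distribution remains normalized.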
2 changes: 1 addition & 1 deletion tianshou/utils/net/common.py
@@ -175,7 +175,7 @@ class CategoricalNet(Net):
.. seealso::
Please refer to :class:`~tianshou.utils.net.common.Net` for
more detailed explanation.
"""

def __init__(
