fix historical issues
Trinkle23897 committed Apr 26, 2020
Parent: 1176dfd · Commit: cf4a2ee
Showing 6 changed files with 20 additions and 9 deletions.
README.md (4 changes: 2 additions & 2 deletions)

@@ -211,7 +211,7 @@ Setup policy and collectors:

```python
policy = ts.policy.DQNPolicy(net, optim, gamma, n_step,
- use_target_network=True, target_update_freq=target_freq)
+ target_update_freq=target_freq)
train_collector = ts.data.Collector(policy, train_envs, ts.data.ReplayBuffer(buffer_size))
test_collector = ts.data.Collector(policy, test_envs)
```
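The example above drops the old `use_target_network=True` argument. A minimal sketch of the presumed behavior after this change (an assumption, not stated in the diff: a positive `target_update_freq` enables the target network, while 0 would disable it), reusing the names from the README snippet:

```python
# Assumption: the target network is now controlled solely by target_update_freq.
policy = ts.policy.DQNPolicy(net, optim, gamma, n_step,
                             target_update_freq=target_freq)  # sync every target_freq steps
# policy = ts.policy.DQNPolicy(net, optim, gamma, n_step, target_update_freq=0)  # no target network
```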
@@ -242,7 +242,7 @@ collector.collect(n_episode=1, render=1 / 35)
collector.close()
```

- Look at the result saved in tensorboard: (on bash script)
+ Look at the result saved in tensorboard: (with bash script in your terminal)

```bash
tensorboard --logdir log/dqn
test/discrete/test_a2c_with_il.py (3 changes: 2 additions & 1 deletion)

@@ -43,6 +43,7 @@ def get_args():
parser.add_argument('--ent-coef', type=float, default=0.001)
parser.add_argument('--max-grad-norm', type=float, default=None)
parser.add_argument('--gae-lambda', type=float, default=1.)
+ parser.add_argument('--rew-norm', type=bool, default=False)
args = parser.parse_known_args()[0]
return args

@@ -74,7 +75,7 @@ def test_a2c(args=get_args()):
policy = A2CPolicy(
actor, critic, optim, dist, args.gamma, gae_lambda=args.gae_lambda,
vf_coef=args.vf_coef, ent_coef=args.ent_coef,
-         max_grad_norm=args.max_grad_norm)
+         max_grad_norm=args.max_grad_norm, reward_normalization=args.rew_norm)
# collector
train_collector = Collector(
policy, train_envs, ReplayBuffer(args.buffer_size))
test/discrete/test_pg.py (4 changes: 3 additions & 1 deletion)

@@ -102,6 +102,7 @@ def get_args():
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
+ parser.add_argument('--rew-norm', type=bool, default=True)
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
@@ -132,7 +133,8 @@ def test_pg(args=get_args()):
net = net.to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
dist = torch.distributions.Categorical
- policy = PGPolicy(net, optim, dist, args.gamma)
+ policy = PGPolicy(net, optim, dist, args.gamma,
+                   reward_normalization=args.rew_norm)
# collector
train_collector = Collector(
policy, train_envs, ReplayBuffer(args.buffer_size))
tianshou/policy/modelfree/a2c.py (8 changes: 7 additions & 1 deletion)

@@ -34,7 +34,8 @@ class A2CPolicy(PGPolicy):
def __init__(self, actor, critic, optim,
dist_fn=torch.distributions.Categorical,
discount_factor=0.99, vf_coef=.5, ent_coef=.01,
- max_grad_norm=None, gae_lambda=0.95, **kwargs):
+ max_grad_norm=None, gae_lambda=0.95,
+ reward_normalization=False, **kwargs):
super().__init__(None, optim, dist_fn, discount_factor, **kwargs)
self.actor = actor
self.critic = critic
@@ -44,6 +45,8 @@ def __init__(self, actor, critic, optim,
self._w_ent = ent_coef
self._grad_norm = max_grad_norm
self._batch = 64
+ self._rew_norm = reward_normalization
+ self.__eps = np.finfo(np.float32).eps.item()

def process_fn(self, batch, buffer, indice):
if self._lambda in [0, 1]:
@@ -82,6 +85,9 @@ def forward(self, batch, state=None, **kwargs):

def learn(self, batch, batch_size=None, repeat=1, **kwargs):
self._batch = batch_size
+ r = batch.returns
+ if self._rew_norm and r.std() > self.__eps:
+     batch.returns = (r - r.mean()) / r.std()
losses, actor_losses, vf_losses, ent_losses = [], [], [], []
for _ in range(repeat):
for b in batch.split(batch_size):
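For reference, the guard added to `A2CPolicy.learn()` above rescales the batch returns to zero mean and unit standard deviation, and skips batches whose returns are (nearly) constant; the threshold is float32 machine epsilon (about 1.19e-7). A standalone NumPy sketch of the same arithmetic:

```python
import numpy as np

def normalize_returns(returns):
    """Zero-mean / unit-std normalization, skipped for ~constant returns."""
    eps = np.finfo(np.float32).eps.item()  # ~1.1920929e-07
    r = np.asarray(returns, dtype=np.float32)
    if r.std() > eps:
        return (r - r.mean()) / r.std()
    return r

print(normalize_returns([1.0, 2.0, 3.0]))  # ~[-1.2247, 0.0, 1.2247]
print(normalize_returns([5.0, 5.0, 5.0]))  # unchanged: std is zero
```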
tianshou/policy/modelfree/dqn.py (2 changes: 1 addition & 1 deletion)

@@ -3,8 +3,8 @@
from copy import deepcopy
import torch.nn.functional as F

- from tianshou.data import Batch, PrioritizedReplayBuffer
from tianshou.policy import BasePolicy
+ from tianshou.data import Batch, PrioritizedReplayBuffer


class DQNPolicy(BasePolicy):
tianshou/policy/modelfree/pg.py (8 changes: 5 additions & 3 deletions)

@@ -21,14 +21,15 @@ class PGPolicy(BasePolicy):
"""

def __init__(self, model, optim, dist_fn=torch.distributions.Categorical,
- discount_factor=0.99, **kwargs):
+ discount_factor=0.99, reward_normalization=False, **kwargs):
super().__init__(**kwargs)
self.model = model
self.optim = optim
self.dist_fn = dist_fn
- self._eps = np.finfo(np.float32).eps.item()
assert 0 <= discount_factor <= 1, 'discount factor should in [0, 1]'
self._gamma = discount_factor
+ self._rew_norm = reward_normalization
+ self.__eps = np.finfo(np.float32).eps.item()

def process_fn(self, batch, buffer, indice):
r"""Compute the discounted returns for each frame:
@@ -71,7 +72,8 @@ def forward(self, batch, state=None, **kwargs):
def learn(self, batch, batch_size=None, repeat=1, **kwargs):
losses = []
r = batch.returns
- batch.returns = (r - r.mean()) / (r.std() + self._eps)
+ if self._rew_norm and r.std() > self.__eps:
+     batch.returns = (r - r.mean()) / r.std()
for _ in range(repeat):
for b in batch.split(batch_size):
self.optim.zero_grad()
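Taken together, the `PGPolicy.learn()` change makes return normalization opt-in: previously it always ran, with epsilon padding the denominator; now it runs only when `reward_normalization` is set and the standard deviation is meaningfully non-zero. A small before/after sketch of just that step:

```python
import numpy as np

eps = np.finfo(np.float32).eps.item()
r = np.array([2.0, 4.0, 6.0], dtype=np.float32)

# Before this commit: always normalize, with eps padding the denominator.
returns_old = (r - r.mean()) / (r.std() + eps)

# After this commit: only normalize when the flag is on and std is non-trivial.
rew_norm = True
returns_new = (r - r.mean()) / r.std() if rew_norm and r.std() > eps else r

print(returns_old, returns_new)  # both ~[-1.2247, 0.0, 1.2247] for this batch
```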
