fix thu-ml#98, support thu-ml#99 (thu-ml#102)
* Add auto alpha tuning and exploration noise for SAC.
Add classes BaseNoise and GaussianNoise to represent exploration noise.
Add a new test for SAC on MountainCarContinuous-v0,
which should benefit from the two new features above.

* add exploration noise to the collector, fix the examples to adapt to this modification

* fix thu-ml#98

* enable off-policy algorithms to update multiple times in one step. (thu-ml#99)
danagi committed Jun 27, 2020
1 parent 4b7ef53 commit ef8c47b
Showing 5 changed files with 13 additions and 7 deletions.
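
The commit message above describes two SAC additions: automatic entropy-coefficient (alpha) tuning and injectable exploration noise via the new GaussianNoise class. The snippet below is a minimal sketch of how they might be wired together on MountainCarContinuous-v0; the alpha-tuple form, the exploration_noise keyword, and the import paths are assumptions based on the commit message and the examples touched in this diff, not verbatim repository code.

import gym
import numpy as np
import torch

from tianshou.policy import SACPolicy
from tianshou.exploration import GaussianNoise  # class added by this commit
# ActorProb/Critic are the example networks edited in this diff; run from the
# repository root so the import resolves.
from examples.continuous_net import ActorProb, Critic

device = 'cpu'
env = gym.make('MountainCarContinuous-v0')
state_shape = env.observation_space.shape
action_shape = env.action_space.shape
max_action = env.action_space.high[0]

# unbounded=True: return the raw mean and let the policy handle squashing
actor = ActorProb(1, state_shape, action_shape, max_action, device,
                  unbounded=True).to(device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=3e-4)
critic1 = Critic(1, state_shape, action_shape, device).to(device)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=3e-4)
critic2 = Critic(1, state_shape, action_shape, device).to(device)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=3e-4)

# auto alpha tuning: pass (target_entropy, log_alpha, alpha_optim) instead of
# a fixed float so alpha is learned alongside the networks (assumed API).
target_entropy = -np.prod(env.action_space.shape)
log_alpha = torch.zeros(1, requires_grad=True, device=device)
alpha_optim = torch.optim.Adam([log_alpha], lr=3e-4)

policy = SACPolicy(
    actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
    tau=0.005, gamma=0.99,
    alpha=(target_entropy, log_alpha, alpha_optim),
    action_range=[env.action_space.low[0], env.action_space.high[0]],
    exploration_noise=GaussianNoise(sigma=0.1))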
2 changes: 1 addition & 1 deletion examples/ant_v2_sac.py
@@ -60,7 +60,7 @@ def test_sac(args=get_args()):
     # model
     actor = ActorProb(
         args.layer_num, args.state_shape, args.action_shape,
-        args.max_action, args.device
+        args.max_action, args.device, unbounded=True
     ).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
     critic1 = Critic(
6 changes: 4 additions & 2 deletions examples/continuous_net.py
@@ -28,7 +28,7 @@ def forward(self, s, **kwargs):

 class ActorProb(nn.Module):
     def __init__(self, layer_num, state_shape, action_shape,
-                 max_action, device='cpu'):
+                 max_action, device='cpu', unbounded=False):
         super().__init__()
         self.device = device
         self.model = [
@@ -40,14 +40,16 @@ def __init__(self, layer_num, state_shape, action_shape,
         self.mu = nn.Linear(128, np.prod(action_shape))
         self.sigma = nn.Linear(128, np.prod(action_shape))
         self._max = max_action
+        self._unbounded = unbounded

     def forward(self, s, **kwargs):
         if not isinstance(s, torch.Tensor):
             s = torch.tensor(s, device=self.device, dtype=torch.float)
         batch = s.shape[0]
         s = s.view(batch, -1)
         logits = self.model(s)
-        mu = self._max * torch.tanh(self.mu(logits))
+        if not self._unbounded:
+            mu = self._max * torch.tanh(self.mu(logits))
         sigma = torch.exp(self.sigma(logits))
         return (mu, sigma), None
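
The unbounded flag above matters for SAC: with unbounded=True the actor returns the raw mean and leaves the tanh squashing (and the matching log-probability correction) to the policy, which is why the SAC examples in this commit pass unbounded=True. Below is a self-contained sketch of the two behaviours; the function and variable names are illustrative only.

import torch

def mean_head(logits, mu_layer, max_action, unbounded):
    # mirrors the bounded/unbounded branch introduced in ActorProb above
    mu = mu_layer(logits)
    if not unbounded:
        # squash the mean into [-max_action, max_action]
        mu = max_action * torch.tanh(mu)
    return mu

mu_layer = torch.nn.Linear(128, 1)
logits = torch.randn(4, 128)
bounded_mu = mean_head(logits, mu_layer, max_action=1.0, unbounded=False)
raw_mu = mean_head(logits, mu_layer, max_action=1.0, unbounded=True)
print(bounded_mu.abs().max() <= 1.0, raw_mu.shape)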
2 changes: 1 addition & 1 deletion examples/halfcheetahBullet_v0_sac.py
@@ -68,7 +68,7 @@ def test_sac(args=get_args()):
     # model
     actor = ActorProb(
         args.layer_num, args.state_shape, args.action_shape,
-        args.max_action, args.device
+        args.max_action, args.device, unbounded=True
     ).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
     critic1 = Critic(
2 changes: 1 addition & 1 deletion examples/sac_mcc.py
@@ -64,7 +64,7 @@ def test_sac(args=get_args()):
     # model
     actor = ActorProb(
         args.layer_num, args.state_shape, args.action_shape,
-        args.max_action, args.device
+        args.max_action, args.device, unbounded=True
     ).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
     critic1 = Critic(
8 changes: 6 additions & 2 deletions tianshou/trainer/offpolicy.py
@@ -18,6 +18,7 @@ def offpolicy_trainer(
         collect_per_step: int,
         episode_per_test: Union[int, List[int]],
         batch_size: int,
+        update_per_step: int = 1,
         train_fn: Optional[Callable[[int], None]] = None,
         test_fn: Optional[Callable[[int], None]] = None,
         stop_fn: Optional[Callable[[float], bool]] = None,
@@ -42,10 +43,13 @@
         in one epoch.
     :param int collect_per_step: the number of frames the collector would
         collect before the network update. In other words, collect some frames
-        and do one policy network update.
+        and do some policy network update.
     :param episode_per_test: the number of episodes for one policy evaluation.
     :param int batch_size: the batch size of sample data, which is going to
         feed in the policy network.
+    :param int update_per_step: the number of times the policy network would
+        be updated after frames be collected. In other words, collect some
+        frames and do some policy network update.
     :param function train_fn: a function receives the current number of epoch
         index and performs some operations at the beginning of training in this
         epoch.
@@ -98,7 +102,7 @@
                         policy.train()
                         if train_fn:
                             train_fn(epoch)
-                for i in range(min(
+                for i in range(update_per_step * min(
                         result['n/st'] // collect_per_step, t.total - t.n)):
                     global_step += 1
                     losses = policy.learn(train_collector.sample(batch_size))
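
For clarity, the changed loop means the number of gradient updates after one collection round is now update_per_step times the previous count. A self-contained sketch of that arithmetic, with made-up numbers standing in for the trainer's runtime values:

collect_per_step = 10   # frames gathered per collection round
update_per_step = 4     # new knob introduced by this commit
frames_collected = 30   # stands in for result['n/st']
remaining_steps = 100   # stands in for t.total - t.n

# mirrors: range(update_per_step * min(result['n/st'] // collect_per_step,
#                                      t.total - t.n))
n_updates = update_per_step * min(frames_collected // collect_per_step,
                                  remaining_steps)
print(n_updates)  # 12 -> three collection "steps", four network updates each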
