fix(log_prob): fixed bugs when calculating log_prob while using function `squash_action`. (#34)

thanks to @BlueFisher
StepNeverStop committed Aug 30, 2021
1 parent 23f910c commit 078b523
Showing 3 changed files with 14 additions and 12 deletions.
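
The gist of the fix: `td.Normal(mu, std).log_prob(x)` returns per-dimension log-probabilities of shape [T, B, A], but `squash_action` expects the joint log-probability of the whole action vector with a trailing singleton axis, [T, B, 1]. Wrapping the distribution in `td.Independent(..., 1)` makes `log_prob` sum over the last (event) dimension, and `.unsqueeze(-1)` restores the trailing axis. A minimal shape check (the T, B, A values here are illustrative, not from the repo):

    import torch
    import torch.distributions as td

    T, B, A = 4, 8, 3  # time steps, batch size, action dims (illustrative)
    mu, log_std = torch.zeros(T, B, A), torch.zeros(T, B, A)

    per_dim = td.Normal(mu, log_std.exp())
    pi = per_dim.rsample()
    per_dim.log_prob(pi).shape              # [4, 8, 3] -- per-dimension (the old, buggy input)

    joint = td.Independent(td.Normal(mu, log_std.exp()), 1)
    joint.log_prob(pi).shape                # [4, 8]    -- summed over the action dimension
    joint.log_prob(pi).unsqueeze(-1).shape  # [4, 8, 1] -- the shape squash_action expects
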
rls/algorithms/single/sac.py (9 changes: 5 additions & 4 deletions)
@@ -137,10 +137,11 @@ def _train_continuous(self, BATCH):
         if self.is_continuous:
             target_mu, target_log_std = self.actor(
                 BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
-            dist = td.Normal(target_mu, target_log_std.exp())
+            dist = td.Independent(
+                td.Normal(target_mu, target_log_std.exp()), 1)
             target_pi = dist.sample()  # [T, B, A]
             target_pi, target_log_pi = squash_action(
-                target_pi, dist.log_prob(target_pi))  # [T, B, A], [T, B, 1]
+                target_pi, dist.log_prob(target_pi).unsqueeze(-1))  # [T, B, A], [T, B, 1]
         else:
             target_logits = self.actor(
                 BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
@@ -170,10 +171,10 @@ def _train_continuous(self, BATCH):
         if self.is_continuous:
             mu, log_std = self.actor(
                 BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
-            dist = td.Normal(mu, log_std.exp())
+            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
             pi = dist.rsample()  # [T, B, A]
             pi, log_pi = squash_action(
-                pi, dist.log_prob(pi))  # [T, B, A], [T, B, 1]
+                pi, dist.log_prob(pi).unsqueeze(-1))  # [T, B, A], [T, B, 1]
             entropy = dist.entropy().mean()  # 1
         else:
             logits = self.actor(
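
A side effect of the `td.Independent` wrapper worth noting: `entropy()` also reduces over the event dimension, so `dist.entropy()` is now [T, B] instead of [T, B, A]. The `entropy = dist.entropy().mean()` line above still yields a scalar, but it is now the mean of per-action-vector entropies (a sum over the A dimensions) rather than the mean of per-dimension entropies. Continuing the shapes from the sketch above:

    td.Normal(mu, log_std.exp()).entropy().shape                     # [4, 8, 3]
    td.Independent(td.Normal(mu, log_std.exp()), 1).entropy().shape  # [4, 8] -- summed over A
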
rls/algorithms/single/sac_v.py (8 changes: 4 additions & 4 deletions)
@@ -145,10 +145,10 @@ def _train_continuous(self, BATCH):
         if self.is_continuous:
             mu, log_std = self.actor(
                 BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
-            dist = td.Normal(mu, log_std.exp())
+            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
             pi = dist.rsample()  # [T, B, A]
             pi, log_pi = squash_action(
-                pi, dist.log_prob(pi))  # [T, B, A], [T, B, 1]
+                pi, dist.log_prob(pi).unsqueeze(-1))  # [T, B, A], [T, B, 1]
         else:
             logits = self.actor(
                 BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
@@ -189,10 +189,10 @@ def _train_continuous(self, BATCH):
         if self.is_continuous:
             mu, log_std = self.actor(
                 BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
-            dist = td.Normal(mu, log_std.exp())
+            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
             pi = dist.rsample()  # [T, B, A]
             pi, log_pi = squash_action(
-                pi, dist.log_prob(pi))  # [T, B, A], [T, B, 1]
+                pi, dist.log_prob(pi).unsqueeze(-1))  # [T, B, A], [T, B, 1]
             entropy = dist.entropy().mean()  # 1
         else:
             logits = self.actor(
rls/algorithms/single/tac.py (9 changes: 5 additions & 4 deletions)
@@ -119,10 +119,11 @@ def _train(self, BATCH):
         if self.is_continuous:
             target_mu, target_log_std = self.actor(
                 BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
-            dist = td.Normal(target_mu, target_log_std.exp())
+            dist = td.Independent(
+                td.Normal(target_mu, target_log_std.exp()), 1)
             target_pi = dist.sample()  # [T, B, A]
             target_pi, target_log_pi = squash_action(
-                target_pi, dist.log_prob(target_pi), is_independent=False)  # [T, B, A]
+                target_pi, dist.log_prob(target_pi).unsqueeze(-1), is_independent=False)  # [T, B, A]
             target_log_pi = tsallis_entropy_log_q(
                 target_log_pi, self.entropic_index)  # [T, B, 1]
         else:
@@ -160,10 +161,10 @@ def _train(self, BATCH):
         if self.is_continuous:
             mu, log_std = self.actor(
                 BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
-            dist = td.Normal(mu, log_std.exp())
+            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
             pi = dist.rsample()  # [T, B, A]
             pi, log_pi = squash_action(pi, dist.log_prob(
-                pi), is_independent=False)  # [T, B, A]
+                pi).unsqueeze(-1), is_independent=False)  # [T, B, A]
             log_pi = tsallis_entropy_log_q(
                 log_pi, self.entropic_index)  # [T, B, 1]
             entropy = dist.entropy().mean()  # 1
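
For context on why the trailing dimension matters, tanh squashing needs a change-of-variables correction: log π(a) = log μ(u) − Σᵢ log(1 − tanh(uᵢ)²). The sketch below is a plausible reconstruction of `squash_action` under that standard formula, not the repo's actual implementation; in particular, the semantics of `is_independent` are assumed from the shape comments in the diff (the `False` path used by tac.py broadcasts the [T, B, 1] input back to [T, B, A]).

    import torch

    def squash_action_sketch(pi, log_pi, is_independent=True):
        # Hypothetical reconstruction; the real rls.squash_action may differ.
        squashed = torch.tanh(pi)                            # a = tanh(u), bounded in (-1, 1)
        # Per-dimension log|det J| of tanh, with a small epsilon for stability.
        correction = torch.log(1. - squashed.pow(2) + 1e-8)  # [T, B, A]
        if is_independent:
            # Fold the summed correction into the joint log-prob: [T, B, 1].
            log_pi = log_pi - correction.sum(-1, keepdim=True)
        else:
            # Keep the correction per dimension; [T, B, 1] broadcasts to [T, B, A].
            log_pi = log_pi - correction
        return squashed, log_pi
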
