Adds fixed alpha version of Soft Actor Critic algorithm (#178)
* Fixes a bug where, in sac.py, self._alpha was not being re-computed after loading self._log_alpha from an optim_state_dict.

* Adds a fixed_alpha option for SAC, which holds alpha at a constant value instead of adapting it.
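
A minimal usage sketch of the new option (not part of the commit), assuming the SAC class defined in rlpyt/algos/qpg/sac.py; the 0.2 temperature is only an illustrative value:

from rlpyt.algos.qpg.sac import SAC

# Default: fixed_alpha=None, so alpha = exp(log_alpha) is adapted by its own
# optimizer using the target-entropy alpha loss.
adaptive_algo = SAC()

# New option: alpha is held constant (0.2 here, chosen only for illustration),
# no alpha optimizer is created, and the alpha loss is skipped.
fixed_algo = SAC(fixed_alpha=0.2)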

Co-authored-by: jordan-schneider <jordan.jack.schneider@gmail.com>
Jordan Schneider and jordan-schneider committed Sep 5, 2020
1 parent 59bc259 commit a9ac84f
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions rlpyt/algos/qpg/sac.py
@@ -43,6 +43,7 @@ def __init__(
             target_update_tau=0.005,  # tau=1 for hard update.
             target_update_interval=1,  # 1000 for hard update, 1 for soft.
             learning_rate=3e-4,
+            fixed_alpha=None,  # None for adaptive alpha, float for any fixed value
             OptimCls=torch.optim.Adam,
             optim_kwargs=None,
             initial_optim_state_dict=None,  # for all of them.
@@ -109,10 +110,15 @@ def optim_initialize(self, rank=0):
             lr=self.learning_rate, **self.optim_kwargs)
         self.q2_optimizer = self.OptimCls(self.agent.q2_parameters(),
             lr=self.learning_rate, **self.optim_kwargs)
-        self._log_alpha = torch.zeros(1, requires_grad=True)
-        self._alpha = torch.exp(self._log_alpha.detach())
-        self.alpha_optimizer = self.OptimCls((self._log_alpha,),
-            lr=self.learning_rate, **self.optim_kwargs)
+        if self.fixed_alpha is None:
+            self._log_alpha = torch.zeros(1, requires_grad=True)
+            self._alpha = torch.exp(self._log_alpha.detach())
+            self.alpha_optimizer = self.OptimCls((self._log_alpha,),
+                lr=self.learning_rate, **self.optim_kwargs)
+        else:
+            self._log_alpha = torch.tensor([np.log(self.fixed_alpha)])
+            self._alpha = torch.tensor([self.fixed_alpha])
+            self.alpha_optimizer = None
         if self.target_entropy == "auto":
             self.target_entropy = -np.prod(self.agent.env_spaces.action.shape)
         if self.initial_optim_state_dict is not None:
@@ -268,7 +274,7 @@ def loss(self, samples):
         # 0.5 * pi_mean ** 2 + 0.5 * pi_log_std ** 2, dim=-1)
         pi_loss = valid_mean(pi_losses, valid)

-        if self.target_entropy is not None:
+        if self.target_entropy is not None and self.fixed_alpha is None:
             alpha_losses = - self._log_alpha * (log_pi.detach() + self.target_entropy)
             alpha_loss = valid_mean(alpha_losses, valid)
         else:
@@ -309,15 +315,16 @@ def optim_state_dict(self):
             pi_optimizer=self.pi_optimizer.state_dict(),
             q1_optimizer=self.q1_optimizer.state_dict(),
             q2_optimizer=self.q2_optimizer.state_dict(),
-            alpha_optimizer=self.alpha_optimizer.state_dict(),
+            alpha_optimizer=self.alpha_optimizer.state_dict() if self.alpha_optimizer else None,
             log_alpha=self._log_alpha.detach().item(),
         )

     def load_optim_state_dict(self, state_dict):
         self.pi_optimizer.load_state_dict(state_dict["pi_optimizer"])
         self.q1_optimizer.load_state_dict(state_dict["q1_optimizer"])
         self.q2_optimizer.load_state_dict(state_dict["q2_optimizer"])
-        self.alpha_optimizer.load_state_dict(state_dict["alpha_optimizer"])
+        if self.alpha_optimizer is not None and state_dict["alpha_optimizer"] is not None:
+            self.alpha_optimizer.load_state_dict(state_dict["alpha_optimizer"])
         with torch.no_grad():
             self._log_alpha[:] = state_dict["log_alpha"]
         self._alpha = torch.exp(self._log_alpha.detach())
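
To make the first commit-message bullet concrete, here is a standalone sketch (plain PyTorch, not rlpyt code) of why the cached alpha must be refreshed whenever log_alpha is restored from a saved optim_state_dict:

import torch

log_alpha = torch.zeros(1, requires_grad=True)  # trainable log-temperature
alpha = torch.exp(log_alpha.detach())           # cached value used in the losses

saved = {"log_alpha": 0.7}  # e.g. pulled from an earlier optim_state_dict()

with torch.no_grad():
    log_alpha[:] = saved["log_alpha"]
# Without this refresh, alpha would still be exp(0) = 1.0 from before the load.
alpha = torch.exp(log_alpha.detach())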
