
Conversation

@amsks (Collaborator) commented Aug 7, 2025

Updates to SAC

  • Major: The base agent now has a handle_timeout_termination option. When it is set to true, we treat the final states of terminated episodes differently from truncated ones (see the sketch after this list). This relates to [Bug] Infinite horizon tasks are handled like episodic tasks DLR-RM/stable-baselines3#284
  • Minor: Interface changes in SAC
    - We now use policy_log_prob() from the SAC model exclusively for the tanh correction instead of sample_nondeterministic_logprobs. The latter could potentially be made PPO-only
    - Added make_policy_head() to separate the policy-head functionality
    - The SAC network forward method now handles action rescaling and log_prob resampling
    - The SAC update uses fresh samples for the alpha update and exponentiates log_alpha (sketched after this list)
  • Updated Tests
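
For the timeout handling, a minimal sketch of the intended behavior, assuming Gymnasium-style step returns; the replay_buffer.add call is illustrative, not Mighty's actual API:

import gymnasium as gym

env = gym.make("Pendulum-v1")
obs, _ = env.reset()
for _ in range(200):
    action = env.action_space.sample()
    next_obs, reward, terminated, truncated, info = env.step(action)

    # Only a true termination should stop bootstrapping; a time-limit
    # truncation of an infinite-horizon task is not stored as done.
    done_for_replay = terminated
    # replay_buffer.add(obs, action, reward, next_obs, done_for_replay)  # illustrative

    if terminated or truncated:
        obs, _ = env.reset()
    else:
        obs = next_obs

And a sketch of the alpha update on fresh samples with an exponentiated log_alpha; the variable names and the target-entropy heuristic are assumptions, not Mighty's exact code:

import torch

action_dim = 1  # illustrative
log_alpha = torch.zeros(1, requires_grad=True)
alpha_optimizer = torch.optim.Adam([log_alpha], lr=3e-4)
target_entropy = -float(action_dim)  # common SAC heuristic

def update_alpha(fresh_log_probs: torch.Tensor) -> torch.Tensor:
    # Optimize log_alpha for numerical stability; exponentiate where alpha is used.
    alpha_loss = -(log_alpha * (fresh_log_probs + target_entropy).detach()).mean()
    alpha_optimizer.zero_grad()
    alpha_loss.backward()
    alpha_optimizer.step()
    return log_alpha.exp().detach()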

@amsks requested a review from TheEimer on August 7, 2025 09:34
@amsks added the bug label on Aug 7, 2025
@amsks added this to the MLOSS milestone on Aug 7, 2025


# 3) optionally overwrite next_s on truncation
if self.handle_timeout_termination:
Contributor

Not sure I like the naming. What does it mean to "handle_timeout_termination"? It should be more expressive. Also: when do we want this? Always? On specific envs? Specific algos? I would actually assume always, since we only want next_s for next-action prediction and always the final obs in the replay. In that case we don't need a flag at all.

Collaborator Author

Removed the optional flag

next_s, reward, terminated, truncated, infos = self.env.step(action)

# 2) decide which samples are true "done"
replay_dones = terminated  # physics-failure only
Contributor

Comment here is env-specific. Also inconsistent: dones are always overwritten to real termination regardless of what the flag says.

Collaborator Author

Made this the default.

# Pack transition
transition = TransitionBatch(curr_s, action, reward, next_s, dones)
# Pack transition
# `terminated` is used for physics failures in environments like `MightyEnv`
Contributor

At least remove the weird AI comments

elif isinstance(out, tuple) and len(out) == 4:
action = out[0] # [batch, action_dim]

print(f'Self Model : {self.model}')
Contributor

remove print

print(f'Self Model : {self.model}')
log_prob = sample_nondeterministic_logprobs(
z=out[1], mean=out[2], log_std=out[3], sac=self.algo == "sac"
z=out[1], mean=out[2], log_std=out[3], sac=isinstance(self.model, SACModel)
Contributor

Bad idea! What if I want to implement a different model class for SAC that e.g. handles prediction differently? Then the policy stops functioning.

z: torch.Tensor,
mean: torch.Tensor,
log_std: torch.Tensor,
sac: bool = False
Contributor

The flag is here to stay model-agnostic. Now you make it impossible to add new model classes for SAC...
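
A minimal sketch of one way to keep this model-agnostic without an isinstance check; the capability-attribute name is hypothetical:

class CustomSACModel:
    # Any SAC-style model can opt in to the tanh correction without
    # subclassing SACModel; the attribute name is hypothetical.
    requires_tanh_correction = True

def needs_tanh_correction(model) -> bool:
    # Defaults to False for models that do not declare the attribute.
    return getattr(model, "requires_tanh_correction", False)

# log_prob = sample_nondeterministic_logprobs(
#     z=z, mean=mean, log_std=log_std, sac=needs_tanh_correction(self.model)
# )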

# 4-tuple case (Tanh squashing): (action, z, mean, log_std)
elif isinstance(model_output, tuple) and len(model_output) == 4:
action, z, mean, log_std = model_output
log_prob = sample_nondeterministic_logprobs(
Contributor

I don't understand the reason for changing this; it's the same code, but longer and locked into a specific model class?
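
For reference, a minimal sketch of the tanh change-of-variables correction this 4-tuple case computes, assuming z ~ Normal(mean, exp(log_std)) and action = tanh(z); this shows the standard SAC formula, not necessarily sample_nondeterministic_logprobs verbatim:

import torch
from torch.distributions import Normal

def tanh_corrected_log_prob(z, mean, log_std, eps=1e-6):
    # log pi(a|s) = log N(z; mean, std) - sum_i log(1 - tanh(z_i)^2)
    base_log_prob = Normal(mean, log_std.exp()).log_prob(z)
    correction = torch.log(1.0 - torch.tanh(z).pow(2) + eps)
    return (base_log_prob - correction).sum(dim=-1)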

return action.detach().cpu().numpy(), log_prob
else:
weighted_log_prob = log_prob * self.entropy_coefficient
weighted_log_prob = log_prob
Contributor

This is strange, now both do the same?!

Collaborator Author

Reverted

log_prob = sample_nondeterministic_logprobs(
z=z, mean=mean, log_std=log_std, sac=self.algo == "sac"
)
if not isinstance(self.model, SACModel):
Contributor

Same issue as above: an identical function, just longer and worse.

"""
feats = self.feature_extractor(state)
x = self.policy_net(feats)
x = self.policy_net(state)
Contributor

Not in the Mighty format. The separate feature extractor is there to have a predictable structure and access to a feature embedding. A "Mighty-er" format would be to have a feature extractor -> policy head, and then a q_feature_extractor. No functional difference, but it's relevant for continuity between algos.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated -- performance similar
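
A sketch of the structure described above, with a shared feature extractor feeding the policy head and a separate q_feature_extractor; the layer sizes and dimensions are illustrative, not Mighty's actual modules:

import torch
import torch.nn as nn

obs_dim, action_dim = 3, 1  # illustrative

feature_extractor = nn.Sequential(nn.Linear(obs_dim, 256), nn.ReLU())
policy_head = nn.Linear(256, 2 * action_dim)  # outputs mean and log_std
q_feature_extractor = nn.Sequential(nn.Linear(obs_dim + action_dim, 256), nn.ReLU())

def policy_forward(state: torch.Tensor) -> torch.Tensor:
    feats = feature_extractor(state)  # predictable, reusable feature embedding
    return policy_head(feats)         # head consumes features, not the raw state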

@TheEimer merged commit 9edb039 into main on Aug 8, 2025
2 checks passed
@TheEimer deleted the sac_fix branch on August 8, 2025 10:51