Sac fix #96
Conversation
mighty/mighty_agents/base_agent.py
Outdated
# 3) optionally overwrite next_s on truncation
if self.handle_timeout_termination:
Not sure I like the naming. What does it mean to "handle_timeout_termination"? Should be more expressive. Also: when do we want this? Always? On specific envs? Specific algos? I would actually assume always, since we only want the next_s for next action prediction and always final obs in the replay. In that case we don't need a flag at all.
Removed the optional flag
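For reference, a minimal sketch of what the flag-free handling discussed above could look like, assuming a Gymnasium-style vectorized env that exposes the pre-reset observation under an infos key; the key name, buffer API, and helper names here are assumptions, not Mighty's actual interface:

```python
import numpy as np

def collect_step(env, policy, buffer, curr_s):
    """One environment step with truncation-aware bookkeeping (sketch)."""
    action = policy(curr_s)
    next_s, reward, terminated, truncated, infos = env.step(action)

    # The replay buffer should see the real final observation of a truncated
    # episode, not the auto-reset observation that follows it.
    buffer_next_s = next_s.copy()
    if np.any(truncated) and "final_observation" in infos:  # assumed info key
        for i in np.where(truncated)[0]:
            buffer_next_s[i] = infos["final_observation"][i]

    # "done" for the Bellman backup means true termination only; truncation
    # is just a time limit, so we still bootstrap from buffer_next_s.
    replay_dones = terminated

    buffer.add(curr_s, action, reward, buffer_next_s, replay_dones)  # assumed API
    return next_s  # the post-reset observation is what the policy acts on next
```

This matches the reviewer's point: the buffer always stores the real final observation, `done` means true termination only, and no flag is needed.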
next_s, reward, terminated, truncated, infos = self.env.step(action)

# 2) decide which samples are true "done"
replay_dones = terminated  # physics failure only
Comment here is env-specific. Also inconsistent: dones are always overwritten to real termination regardless of what the flag says.
Make this default
mighty/mighty_agents/sac.py
Outdated
# Pack transition
transition = TransitionBatch(curr_s, action, reward, next_s, dones)
# Pack transition
# `terminated` is used for physics failures in environments like `MightyEnv`
At least remove the weird AI comments
elif isinstance(out, tuple) and len(out) == 4:
    action = out[0]  # [batch, action_dim]

    print(f'Self Model : {self.model}')
remove print
print(f'Self Model : {self.model}')
log_prob = sample_nondeterministic_logprobs(
    z=out[1], mean=out[2], log_std=out[3], sac=self.algo == "sac"
    z=out[1], mean=out[2], log_std=out[3], sac=isinstance(self.model, SACModel)
Bad idea! What if I want to implement a different model class for SAC that, e.g., handles prediction differently? Then the policy stops functioning.
z: torch.Tensor,
mean: torch.Tensor,
log_std: torch.Tensor,
sac: bool = False
The flag is here to stay model agnostic. Now you make it impossible to add new model classes for SAC...
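To illustrate the reviewer's point, a sketch of how the boolean flag keeps the helper model-agnostic; the function body is an assumption based on the names in the diff, not the actual Mighty implementation:

```python
import torch

def sample_nondeterministic_logprobs(
    z: torch.Tensor,
    mean: torch.Tensor,
    log_std: torch.Tensor,
    sac: bool = False,
) -> torch.Tensor:
    """Log-prob of a Gaussian sample z, with optional tanh correction (sketch)."""
    dist = torch.distributions.Normal(mean, log_std.exp())
    log_prob = dist.log_prob(z).sum(dim=-1)
    if sac:
        # Change-of-variables term for a tanh-squashed action a = tanh(z).
        log_prob -= torch.log(1 - torch.tanh(z).pow(2) + 1e-6).sum(dim=-1)
    return log_prob
```

The call site can then stay model-agnostic, e.g. `sac=self.algo == "sac"` as in the original code, rather than checking `isinstance(self.model, SACModel)`.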
# 4-tuple case (Tanh squashing): (action, z, mean, log_std)
elif isinstance(model_output, tuple) and len(model_output) == 4:
    action, z, mean, log_std = model_output
    log_prob = sample_nondeterministic_logprobs(
I don't understand the reason for changing this: it's the same code, just longer, and it locks us into a specific model class?
return action.detach().cpu().numpy(), log_prob
else:
    weighted_log_prob = log_prob * self.entropy_coefficient
    weighted_log_prob = log_prob
This is strange, now both do the same?!
Reverted
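For context on why the two branches should differ, a minimal sketch of how the entropy-weighted log-prob usually enters the SAC actor objective; only `entropy_coefficient` and the weighting come from the diff, the rest of the function is an assumption:

```python
import torch

def sac_actor_loss(q1, q2, log_prob, entropy_coefficient):
    """SAC policy loss on freshly sampled actions (sketch).

    q1, q2:   critic estimates for the sampled actions, shape [batch]
    log_prob: log pi(a|s) for the same actions, shape [batch]
    """
    q_min = torch.min(q1, q2)
    weighted_log_prob = entropy_coefficient * log_prob
    # Minimizing this maximizes E[Q - alpha * log pi], i.e. return plus entropy.
    return (weighted_log_prob - q_min).mean()
```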
log_prob = sample_nondeterministic_logprobs(
    z=z, mean=mean, log_std=log_std, sac=self.algo == "sac"
)
if not isinstance(self.model, SACModel):
Same issue as above: an identical function, just longer and worse.
| """ | ||
| feats = self.feature_extractor(state) | ||
| x = self.policy_net(feats) | ||
| x = self.policy_net(state) |
Not in the mighty format. The separate feature extractor is there to have a predictable structure and access to a feature embedding. A "Mighty-er" format would be to have a feature extractor -> policy head and then a q_feature_extractor. No functional difference, but it's relevant for continuity between algos.
Updated -- performance similar
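A sketch of the structure described above: a shared feature extractor feeding a policy head, with a separate feature extractor reserved for the Q-networks. The class and layer sizes are illustrative, not the actual Mighty modules:

```python
import torch
import torch.nn as nn

class SACPolicySketch(nn.Module):
    """Feature extractor -> policy head, mirroring the described layout (sketch)."""

    def __init__(self, obs_dim: int, action_dim: int, hidden: int = 256):
        super().__init__()
        # Shared trunk keeps the feature embedding accessible.
        self.feature_extractor = nn.Sequential(
            nn.Linear(obs_dim, hidden), nn.ReLU(),
        )
        # Policy head produces mean and log_std of a Gaussian.
        self.policy_net = nn.Linear(hidden, 2 * action_dim)
        # Separate trunk for the critics, as suggested in the review.
        self.q_feature_extractor = nn.Sequential(
            nn.Linear(obs_dim + action_dim, hidden), nn.ReLU(),
        )

    def forward(self, state: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        feats = self.feature_extractor(state)
        mean, log_std = self.policy_net(feats).chunk(2, dim=-1)
        return mean, log_std
```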
Updates to SAC
- `handle_timeout_termination`: when this is set to true, we treat the final states of `terminated` differently from `truncated`. This is related to [Bug] Infinite horizon tasks are handled like episodic tasks DLR-RM/stable-baselines3#284
- We now use the `policy_log_prob()` from the SAC model exclusively for the tanh correction instead of `sample_nondeterministic_logprobs`. The latter can potentially be made just for PPO
- Added `make_policy_head()` to separate the policy head functionality
- The SAC network forward method now handles action rescaling and log_prob resampling
- SAC update uses fresh samples for the alpha update, and exponentiates `log_alpha`
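For reference, a minimal sketch of the temperature update in the last bullet, assuming a standard SAC setup where `log_alpha` is a learnable scalar with its own optimizer and the policy exposes a `sample()` method returning actions and log-probs (both assumptions):

```python
import torch

def make_alpha_update(action_dim: int, lr: float = 3e-4):
    """Build a temperature (alpha) update step for SAC (sketch)."""
    log_alpha = torch.zeros(1, requires_grad=True)
    alpha_optimizer = torch.optim.Adam([log_alpha], lr=lr)
    target_entropy = -float(action_dim)  # common heuristic: -|A|

    def update_alpha(policy, states):
        # Fresh samples from the *current* policy, not the replayed actions.
        with torch.no_grad():
            _, log_prob = policy.sample(states)  # assumed: returns (action, log_prob)
        alpha_loss = -(log_alpha * (log_prob + target_entropy)).mean()
        alpha_optimizer.zero_grad()
        alpha_loss.backward()
        alpha_optimizer.step()
        return log_alpha.exp().item()  # exponentiate log_alpha wherever alpha is used

    return update_alpha
```

Using fresh samples keeps the entropy estimate on-policy, and exponentiating `log_alpha` guarantees the coefficient stays positive.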