[feature] Fix TF tests, add --torch CLI option, allow running TF without torch installed #4305
Changes from all commits: 4a1fbb1, 419bec9, 6c98f6f, a521ab1, 00f2137, e3e5869, 0ea29c9, bc03a01
```diff
@@ -524,6 +524,11 @@ def to_settings(self) -> type:
         return _mapping[self]
 
 
+class FrameworkType(Enum):
+    TENSORFLOW: str = "tensorflow"
+    PYTORCH: str = "pytorch"
+
+
 @attr.s(auto_attribs=True)
 class TrainerSettings(ExportableSettings):
     trainer_type: TrainerType = TrainerType.PPO
```

```diff
@@ -546,6 +551,7 @@ def _set_default_hyperparameters(self):
     threaded: bool = True
     self_play: Optional[SelfPlaySettings] = None
     behavioral_cloning: Optional[BehavioralCloningSettings] = None
+    framework: FrameworkType = FrameworkType.TENSORFLOW
 
     cattr.register_structure_hook(
         Dict[RewardSignalType, RewardSignalSettings], RewardSignalSettings.structure
```
Contributor
There are 2 ways to have a pytorch trainer then, with the new `framework` setting in the trainer config and with the `--torch` CLI option.

Contributor (Author)
Yep, that's true of every CLI option - it goes into the YAML somehow. That way a run can be specified entirely from a YAML file - this is for cloud training and for reproducibility.
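Both routes land in the same place: the `framework` field of each behavior's `TrainerSettings`. As a rough standalone sketch (stripped-down stand-ins for the real classes, not the ml-agents code itself), this is how a `framework: pytorch` value coming out of the YAML would be structured into the enum by attrs + cattr, and why omitting both the key and `--torch` leaves the TensorFlow default in place:

```python
from enum import Enum
from typing import Any, Dict

import attr
import cattr


class FrameworkType(Enum):
    TENSORFLOW = "tensorflow"
    PYTORCH = "pytorch"


@attr.s(auto_attribs=True)
class TrainerSettings:
    # Default stays TensorFlow, so omitting the key (and --torch) changes nothing.
    framework: FrameworkType = FrameworkType.TENSORFLOW


# A behavior section parsed from YAML arrives as a plain dict of strings.
yaml_section: Dict[str, Any] = {"framework": "pytorch"}

# cattr converts the string to the enum member by value.
settings = cattr.structure(yaml_section, TrainerSettings)
assert settings.framework is FrameworkType.PYTORCH

# Omitting the key falls back to the default, which --torch can then override.
assert cattr.structure({}, TrainerSettings).framework is FrameworkType.TENSORFLOW
```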
```diff
@@ -713,7 +719,13 @@ def from_argparse(args: argparse.Namespace) -> "RunOptions":
                 configured_dict["engine_settings"][key] = val
             else:  # Base options
                 configured_dict[key] = val
-        return RunOptions.from_dict(configured_dict)
+
+        # Apply --torch retroactively
+        final_runoptions = RunOptions.from_dict(configured_dict)
+        if "torch" in DetectDefault.non_default_args:
+            for trainer_set in final_runoptions.behaviors.values():
+                trainer_set.framework = FrameworkType.PYTORCH
+        return final_runoptions
 
     @staticmethod
     def from_dict(options_dict: Dict[str, Any]) -> "RunOptions":
```
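The order of operations here matters: `RunOptions.from_dict` builds the per-behavior settings from the merged dict first, and only an explicitly passed `--torch` then overwrites `framework` for every behavior. A rough standalone sketch of that flow, using simplified stand-ins for the real classes, a plain boolean in place of the `DetectDefault.non_default_args` check, and a made-up behavior name:

```python
import argparse
from enum import Enum
from typing import Dict


class FrameworkType(Enum):
    TENSORFLOW = "tensorflow"
    PYTORCH = "pytorch"


class TrainerSettings:
    """Stripped-down stand-in for the real TrainerSettings."""

    def __init__(self) -> None:
        self.framework = FrameworkType.TENSORFLOW


def apply_torch_flag(behaviors: Dict[str, TrainerSettings], torch_passed: bool) -> None:
    # Stand-in for `if "torch" in DetectDefault.non_default_args:` - only
    # override the per-behavior framework when --torch was given explicitly.
    if torch_passed:
        for trainer_settings in behaviors.values():
            trainer_settings.framework = FrameworkType.PYTORCH


parser = argparse.ArgumentParser()
parser.add_argument("--torch", action="store_true")
args = parser.parse_args(["--torch"])

# Settings are built from the YAML-derived dict first; the CLI flag is then
# applied on top of whatever the config said ("SomeBehavior" is made up).
behaviors = {"SomeBehavior": TrainerSettings()}
apply_torch_flag(behaviors, torch_passed=args.torch)
assert behaviors["SomeBehavior"].framework is FrameworkType.PYTORCH
```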
Contributor
Is there a way to hide the help, or do we not care if the users see this?

Contributor (Author)
Yes - `argparse.SUPPRESS`. I can hide it. Should we? I guess it depends on whether we want people to try it or not.

Contributor
Let's not hide it then
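For reference, the suggestion in this thread amounts to passing `argparse.SUPPRESS` as the help string, which keeps the flag functional but drops it from the `--help` listing; per the last comment, the option was left visible. A minimal sketch, not the actual ml-agents argument definition (which uses its own default-detection action):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--torch",
    default=False,
    action="store_true",
    help=argparse.SUPPRESS,  # the option still parses, but --help won't list it
)

args = parser.parse_args(["--torch"])
assert args.torch is True
# parser.format_help() contains no mention of --torch.
```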