Create context managers before entering any with ExitStack #18716
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -333,6 +333,8 @@ def module_to_device(self, module: Module) -> None: | |
| pass | ||
|
|
||
| def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager: | ||
| + precision_init_ctx = self.precision.init_context() | ||
| + module_sharded_ctx = self.module_sharded_context() | ||
| stack = ExitStack() | ||
| if _TORCH_GREATER_EQUAL_2_1 and empty_init: | ||
| # Materialization happens in `setup`. When modules get wrapped by FSDP, the sequence of operations is: | ||
|
|
@@ -341,8 +343,8 @@ def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManag | |
| stack.enter_context(torch.device("meta")) | ||
| elif _TORCH_GREATER_EQUAL_1_13: | ||
| stack.enter_context(_EmptyInit(enabled=bool(empty_init))) | ||
|
Contributor: Why isn't it applied in other places here?
Contributor (Author): Can you include it in #18734?
||
| - stack.enter_context(self.precision.init_context()) | ||
| - stack.enter_context(self.module_sharded_context()) | ||
| + stack.enter_context(precision_init_ctx) | ||
| + stack.enter_context(module_sharded_ctx) | ||
| return stack | ||
|
|
||
| def module_sharded_context(self) -> ContextManager: | ||
|
|
||
For completeness, it would probably also be better to apply it to this line, right?
I don't think this one matters because in the super() call, all the context managers are instantiated before any is entered.