Flexible and easy to use HSDP setting #19504
base: master
Conversation
for more information, see https://pre-commit.ci
Codecov Report
Attention: Patch coverage is
Additional details and impacted files

@@            Coverage Diff            @@
##           master   #19504      +/-  ##
=========================================
- Coverage      84%      58%     -26%
=========================================
  Files         425      420       -5
  Lines       35023    34945      -78
=========================================
- Hits        29371    20116    -9255
- Misses       5652    14829    +9177
Great job. Looking forward to this feature.
```python
if isinstance(self._fsdp_kwargs.get("device_mesh"), tuple):
    from torch.distributed.device_mesh import init_device_mesh

    self._fsdp_kwargs["device_mesh"] = init_device_mesh("cuda", self._fsdp_kwargs["device_mesh"])
```
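The conversion can be exercised in isolation; a minimal sketch, with a stub standing in for torch's `init_device_mesh` (assumption: the real call takes a device type and the mesh shape):

```python
# Stub standing in for torch.distributed.device_mesh.init_device_mesh;
# the real function builds a DeviceMesh, here we just record the arguments.
def init_device_mesh_stub(device_type, mesh_shape):
    return {"device_type": device_type, "mesh_shape": mesh_shape}

def normalize_device_mesh(fsdp_kwargs, init_fn=init_device_mesh_stub):
    # A plain tuple like (2, 4) is replaced with the materialized mesh;
    # anything else (or a missing key) is passed through untouched.
    if isinstance(fsdp_kwargs.get("device_mesh"), tuple):
        fsdp_kwargs["device_mesh"] = init_fn("cuda", fsdp_kwargs["device_mesh"])
    return fsdp_kwargs

kwargs = normalize_device_mesh({"device_mesh": (2, 4)})
print(kwargs["device_mesh"])  # {'device_type': 'cuda', 'mesh_shape': (2, 4)}
```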
Since the tuple specification is a feature of Lightning, we should list the `device_mesh` parameter explicitly in the init args (see the docstring I added). The kwargs are for things that we pass down to FSDP directly.
So I suggest setting `self.device_mesh` and updating this attribute here to the actual DeviceMesh 😃
If `device_mesh` is separated from `kwargs`, we will fail the check in `_init_sharding_strategy` in `self.__init__`.
Changed now. @awaelchli could you take a look again?
for more information, see https://pre-commit.ci
Any updates on this?
Hi @awaelchli, I just fixed a merge conflict with the master and pushed the change. Could you take another look? Let me know if you have any concerns.
What does this PR do?
Fixes #19502
It allows users to provide `device_mesh` as a tuple to the `FSDPStrategy`. The test case for the hybrid FSDP strategy would be broken, because `_init_sharding_strategy` is moved out of the `__init__` function.
Before submitting
PR review
Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet list:
Reviewer checklist
📚 Documentation preview 📚: https://pytorch-lightning--19504.org.readthedocs.build/en/19504/