Meta device initialization for FSDP in Fabric #18122
Conversation
Great work here. Are you planning to port this to the Trainer? If not, we should open an issue. We should do it before 2.1.
# Materialization happens in `setup`. When modules get wrapped by FSDP, the sequence of operations is:
# 1) materialize module 2) call `reset_parameters()` 3) shard the module.
# These operations are applied to each submodule 'bottom up' in the module hierarchy.
empty_init_context = torch.device("meta")
Doesn't this break the loading of checkpoints for models that haven't been FSDP-wrapped with setup yet? For instance: https://github.com/Lightning-AI/lit-gpt/blob/1900b80424825cb221af0b63d19dd33b027d9aff/generate/base.py#L145-L151
Wouldn't load_state_dict need assign=True now? pytorch/pytorch#96161 (comment)
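A minimal sketch of the scenario in question (illustrative shapes; `assign` is available in newer PyTorch):

```python
import torch
import torch.nn as nn

# Model instantiated on the meta device but not yet FSDP-wrapped via `setup`:
with torch.device("meta"):
    model = nn.Linear(4, 4)

checkpoint = {"weight": torch.randn(4, 4), "bias": torch.randn(4)}

# Plain `load_state_dict` copies into the existing (meta) parameters, which
# cannot yield real weights; `assign=True` instead swaps the checkpoint
# tensors in as the new parameters.
model.load_state_dict(checkpoint, assign=True)
print(model.weight.device)  # cpu
```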
What does this PR do?
Fixes #16448
Fixes #18008
This allows you to instantiate very large models that wouldn't fit in memory (either CPU or GPU) as fast as possible. No memory for the weights gets allocated, neither in CPU nor GPU memory, and the parameters are materialized and initialized with random weights directly at the time the model gets wrapped and sharded in Fabric.setup().
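A minimal usage sketch (assuming Fabric's `init_module(empty_init=True)` entry point; `MyModel` is illustrative):

```python
import torch
import lightning as L

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8192, 8192)

    def forward(self, x):
        return self.layer(x)

fabric = L.Fabric(accelerator="cuda", devices=2, strategy="fsdp")
fabric.launch()

# Parameters are created on the meta device: no CPU or GPU memory is allocated.
with fabric.init_module(empty_init=True):
    model = MyModel()

# Materialize -> reset_parameters() -> shard, applied bottom-up per submodule.
model = fabric.setup(model)
```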
Notes:
reset_parameters() becomes optional with pytorch/pytorch#104187 in PyTorch 2.1 nightly.
Requirement: Your submodules define a reset_parameters() method that can be called to init the params. This is the case for all built-in PyTorch layers. If you have a custom layer, you'd have to add that method, for example:
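A hypothetical custom layer sketch:

```python
import math
import torch
import torch.nn as nn

class ScaledLinear(nn.Module):
    # Illustrative custom layer that supports meta-device materialization.

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.scale = nn.Parameter(torch.empty(1))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Called after the module is materialized off the meta device.
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        nn.init.ones_(self.scale)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return nn.functional.linear(x, self.weight) * self.scale
```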
Limitation: Since the model is put on the meta device, you can't reference the parameters of that model in an optimizer like so:
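(A hypothetical sketch of the unsupported pattern, reusing the names from the sketch above:)

```python
with fabric.init_module(empty_init=True):
    model = MyModel()  # parameters live on the meta device

# Not supported: the optimizer captures meta-device parameters that get
# replaced when the model is materialized and sharded in `setup()`.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model, optimizer = fabric.setup(model, optimizer)  # raises an error
```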
The user has to change the code to set up the model first, then create the optimizer referencing the FSDP parameters:
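(Sketch of the supported order:)

```python
with fabric.init_module(empty_init=True):
    model = MyModel()

model = fabric.setup(model)  # materialize, initialize, and shard first
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # references FSDP params
optimizer = fabric.setup_optimizers(optimizer)
```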
We have checks for this and explain it to the user in an error message. This will also be documented (see note above) and we will find a way to lift this limitation in the future.
cc @Borda @carmocca @justusschock @awaelchli