
Meta device initialization for FSDP in Fabric #18122

Merged
merged 21 commits into master from fabric/fsdp-meta-init on Aug 2, 2023

Conversation

awaelchli
Member

@awaelchli awaelchli commented Jul 20, 2023

What does this PR do?

Fixes #16448
Fixes #18008

fabric = Fabric(strategy="fsdp")

with fabric.init_module(empty_init=True):
    model = Model()  # the model is now on the meta device (no memory occupied)

...

# materialization, parameter initialization, and sharding happen here:
model = fabric.setup(model)

# you can now start training
train(model)

This allows you to instantiate very large models that wouldn't otherwise fit in memory (CPU or GPU) as quickly as possible. No memory for the weights is allocated, neither on the CPU nor on the GPU; parameters are materialized and initialized with random weights only when the model gets wrapped and sharded in Fabric.setup().
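As a quick illustration of the memory behavior described above, here is a sketch using plain PyTorch (the meta-device context that init_module builds on); the layer size is arbitrary:

import torch
import torch.nn as nn

with torch.device("meta"):
    layer = nn.Linear(10_000, 10_000)  # no storage is allocated for the weights

print(layer.weight.is_meta)  # True
print(layer.weight.device)   # meta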

Notes:

  • This new feature is made possible by [RFC] Revisiting Meta Device Initialization with reset_parameters() pytorch/pytorch#104187, available in the PyTorch 2.1 nightlies.
    Requirement: your submodules must define a reset_parameters() method that can be called to initialize the parameters. This is the case for all built-in PyTorch layers; if you have a custom layer, you'd have to add that method (see the sketch after this list).
  • I have tested this in the lit-gpt repo with the full finetuning and the pretraining scripts.
  • Documentation will be updated in a follow-up. Reason: I'd like to send a PR to lit-gpt first to apply this in all models, making sure there are no unforeseen caveats.
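For illustration, a minimal sketch of a custom layer that satisfies the reset_parameters() requirement; MyLinear is a hypothetical example, and its initialization mirrors what torch.nn.Linear uses for its weight:

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class MyLinear(nn.Module):
    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Called again after the meta tensors get materialized, so it must fully
        # (re)initialize the parameters without reading their current values.
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x, self.weight)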

Limitation:

Since the model is put on the meta device, you can't reference the parameters of that model in an optimizer like so:

with fabric.init_module(empty_init=True):
    model = ...
optimizer = Adam(model.parameters())  # references meta-device parameters
model, optimizer = fabric.setup(model, optimizer)  # error

The user has to change the code to set up the model first, then create the optimizer referencing the FSDP parameters:

with fabric.init_module(empty_init=True):
    model = ...
model = fabric.setup(model)  # set up model first
optimizer = Adam(model.parameters())  # references real parameters
optimizer = fabric.setup_optimizers(optimizer)

We have checks in place for this and explain the required change to the user in an error message. This will also be documented (see the note above), and we will look for a way to lift this limitation in the future.
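For illustration, a minimal sketch of what such a check can look like (check_no_meta_params is a hypothetical helper, not the actual Fabric implementation):

import torch

def check_no_meta_params(optimizer: torch.optim.Optimizer) -> None:
    # Fail early if the optimizer was created against meta-device parameters,
    # i.e. before the model was materialized by setup().
    for group in optimizer.param_groups:
        for param in group["params"]:
            if param.is_meta:
                raise RuntimeError(
                    "The optimizer references parameters on the meta device. "
                    "Set up the model first, then create the optimizer."
                )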

Before submitting
  • Was this discussed/agreed via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you list all the breaking changes introduced by this pull request?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or minor internal changes/refactors)

PR review

Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following checklist:

Reviewer checklist
  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

cc @Borda @carmocca @justusschock @awaelchli

@awaelchli awaelchli marked this pull request as draft July 20, 2023 01:04
@github-actions github-actions bot added the fabric (lightning.fabric.Fabric) label Jul 20, 2023
@github-actions
Contributor

github-actions bot commented Jul 20, 2023

⚡ Required checks status: All passing 🟢

Groups summary

🟢 pytorch_lightning: Tests workflow
Check ID Status
pl-cpu (macOS-11, lightning, 3.8, 1.11) success
pl-cpu (macOS-11, lightning, 3.9, 1.12) success
pl-cpu (macOS-11, lightning, 3.10, 1.13) success
pl-cpu (macOS-11, lightning, 3.10, 2.0) success
pl-cpu (macOS-11, lightning, 3.8, 1.11, oldest) success
pl-cpu (ubuntu-20.04, lightning, 3.8, 1.11) success
pl-cpu (ubuntu-20.04, lightning, 3.9, 1.12) success
pl-cpu (ubuntu-20.04, lightning, 3.10, 1.13) success
pl-cpu (ubuntu-20.04, lightning, 3.10, 2.0) success
pl-cpu (ubuntu-20.04, lightning, 3.8, 1.11, oldest) success
pl-cpu (windows-2022, lightning, 3.8, 1.11) success
pl-cpu (windows-2022, lightning, 3.9, 1.12) success
pl-cpu (windows-2022, lightning, 3.10, 1.13) success
pl-cpu (windows-2022, lightning, 3.10, 2.0) success
pl-cpu (windows-2022, lightning, 3.8, 1.11, oldest) success
pl-cpu (macOS-11, pytorch, 3.8, 1.13) success
pl-cpu (ubuntu-20.04, pytorch, 3.8, 1.13) success
pl-cpu (windows-2022, pytorch, 3.8, 1.13) success

These checks are required after the changes to src/lightning/fabric/fabric.py, src/lightning/fabric/strategies/fsdp.py.

🟢 pytorch_lightning: Azure GPU
Check ID Status
pytorch-lightning (GPUs) success

These checks are required after the changes to src/lightning/fabric/fabric.py, src/lightning/fabric/strategies/fsdp.py.

🟢 pytorch_lightning: Benchmarks
Check ID Status
lightning.Benchmarks success

These checks are required after the changes to src/lightning/fabric/fabric.py, src/lightning/fabric/strategies/fsdp.py.

🟢 fabric: Docs
Check ID Status
make-doctest (fabric) success
make-html (fabric) success

These checks are required after the changes to src/lightning/fabric/fabric.py, src/lightning/fabric/strategies/fsdp.py.

🟢 lightning_fabric: CPU workflow
Check ID Status
fabric-cpu (macOS-11, lightning, 3.8, 1.11) success
fabric-cpu (macOS-11, lightning, 3.9, 1.12) success
fabric-cpu (macOS-11, lightning, 3.10, 1.13) success
fabric-cpu (macOS-11, lightning, 3.10, 2.0) success
fabric-cpu (macOS-11, lightning, 3.8, 1.11, oldest) success
fabric-cpu (ubuntu-20.04, lightning, 3.8, 1.11) success
fabric-cpu (ubuntu-20.04, lightning, 3.9, 1.12) success
fabric-cpu (ubuntu-20.04, lightning, 3.10, 1.13) success
fabric-cpu (ubuntu-20.04, lightning, 3.10, 2.0) success
fabric-cpu (ubuntu-20.04, lightning, 3.8, 1.11, oldest) success
fabric-cpu (windows-2022, lightning, 3.8, 1.11) success
fabric-cpu (windows-2022, lightning, 3.9, 1.12) success
fabric-cpu (windows-2022, lightning, 3.10, 1.13) success
fabric-cpu (windows-2022, lightning, 3.10, 2.0) success
fabric-cpu (windows-2022, lightning, 3.8, 1.11, oldest) success
fabric-cpu (macOS-11, fabric, 3.8, 1.13) success
fabric-cpu (ubuntu-20.04, fabric, 3.8, 1.13) success
fabric-cpu (windows-2022, fabric, 3.8, 1.13) success

These checks are required after the changes to src/lightning/fabric/fabric.py, src/lightning/fabric/strategies/fsdp.py, tests/tests_fabric/strategies/test_fsdp.py, tests/tests_fabric/strategies/test_fsdp_integration.py, tests/tests_fabric/test_fabric.py.

🟢 lightning_fabric: Azure GPU
Check ID Status
lightning-fabric (GPUs) success

These checks are required after the changes to src/lightning/fabric/fabric.py, src/lightning/fabric/strategies/fsdp.py, tests/tests_fabric/strategies/test_fsdp.py, tests/tests_fabric/strategies/test_fsdp_integration.py, tests/tests_fabric/test_fabric.py.

🟢 mypy
Check ID Status
mypy success

These checks are required after the changes to src/lightning/fabric/fabric.py, src/lightning/fabric/strategies/fsdp.py.

🟢 install
Check ID Status
install-pkg (ubuntu-22.04, app, 3.8) success
install-pkg (ubuntu-22.04, app, 3.10) success
install-pkg (ubuntu-22.04, fabric, 3.8) success
install-pkg (ubuntu-22.04, fabric, 3.10) success
install-pkg (ubuntu-22.04, pytorch, 3.8) success
install-pkg (ubuntu-22.04, pytorch, 3.10) success
install-pkg (ubuntu-22.04, lightning, 3.8) success
install-pkg (ubuntu-22.04, lightning, 3.10) success
install-pkg (ubuntu-22.04, notset, 3.8) success
install-pkg (ubuntu-22.04, notset, 3.10) success
install-pkg (macOS-12, app, 3.8) success
install-pkg (macOS-12, app, 3.10) success
install-pkg (macOS-12, fabric, 3.8) success
install-pkg (macOS-12, fabric, 3.10) success
install-pkg (macOS-12, pytorch, 3.8) success
install-pkg (macOS-12, pytorch, 3.10) success
install-pkg (macOS-12, lightning, 3.8) success
install-pkg (macOS-12, lightning, 3.10) success
install-pkg (macOS-12, notset, 3.8) success
install-pkg (macOS-12, notset, 3.10) success
install-pkg (windows-2022, app, 3.8) success
install-pkg (windows-2022, app, 3.10) success
install-pkg (windows-2022, fabric, 3.8) success
install-pkg (windows-2022, fabric, 3.10) success
install-pkg (windows-2022, pytorch, 3.8) success
install-pkg (windows-2022, pytorch, 3.10) success
install-pkg (windows-2022, lightning, 3.8) success
install-pkg (windows-2022, lightning, 3.10) success
install-pkg (windows-2022, notset, 3.8) success
install-pkg (windows-2022, notset, 3.10) success

These checks are required after the changes to src/lightning/fabric/fabric.py, src/lightning/fabric/strategies/fsdp.py.

🟢 link-check
Check ID Status
check-md-links / markdown-link-check success

These checks are required after the changes to src/lightning/fabric/CHANGELOG.md.


Thank you for your contribution! 💜

Note
This comment is automatically generated and updates every 180 seconds for 60 minutes. If you have any other questions, contact carmocca for help.

@awaelchli awaelchli added the strategy: fsdp (Fully Sharded Data Parallel), feature (Is an improvement or enhancement), and fun (Staff contributions outside working hours) labels Jul 27, 2023
@awaelchli awaelchli added this to the 2.1 milestone Aug 1, 2023
@awaelchli awaelchli changed the title [WIP] Meta device init for FSDP Meta device initializatoin for FSDP in Fabric Aug 1, 2023
@awaelchli awaelchli marked this pull request as ready for review August 1, 2023 16:22
@awaelchli awaelchli changed the title Meta device initializatoin for FSDP in Fabric Meta device initialization for FSDP in Fabric Aug 1, 2023
@mergify mergify bot added the ready (PRs ready to be merged) label Aug 2, 2023
@awaelchli awaelchli merged commit 50e01c7 into master Aug 2, 2023
104 checks passed
@awaelchli awaelchli deleted the fabric/fsdp-meta-init branch August 2, 2023 11:58
Member

@carmocca carmocca left a comment


Great work here. Are you planning to port this to the Trainer? If not, we should open an issue. We should do it before 2.1

# Materialization happens in `setup`. When modules get wrapped by FSDP, the sequence of operations is:
# 1) materialize module 2) call `reset_parameters()` 3) shard the module.
# These operations are applied to each submodule 'bottom up' in the module hierarchy.
empty_init_context = torch.device("meta")

Doesn't this break the loading of checkpoints for models that haven't been FSDP wrapped with setup yet? For instance: https://github.com/Lightning-AI/lit-gpt/blob/1900b80424825cb221af0b63d19dd33b027d9aff/generate/base.py#L145-L151

Wouldn't load_state_dict need assign=True now? pytorch/pytorch#96161 (comment)
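For context, a sketch of the workaround the linked comment suggests, using plain PyTorch (assign=True requires PyTorch 2.1+; the model and state dict below are stand-ins):

import torch
import torch.nn as nn

with torch.device("meta"):
    model = nn.Linear(8, 8)  # stands in for a model built under init_module

# stand-in for `torch.load(checkpoint_path)`
state_dict = {"weight": torch.randn(8, 8), "bias": torch.randn(8)}

# assign=True adopts the checkpoint tensors instead of copying into the
# existing parameters, which would fail for meta tensors
model.load_state_dict(state_dict, assign=True)
print(model.weight.is_meta)  # False: real storage now backs the parameters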

Labels
  • fabric (lightning.fabric.Fabric)
  • feature (Is an improvement or enhancement)
  • fun (Staff contributions outside working hours - to differentiate from the "community" label)
  • ready (PRs ready to be merged)
  • strategy: fsdp (Fully Sharded Data Parallel)
Development

Successfully merging this pull request may close these issues.

  • Multi-node training with FSDP results in weird behaviour
  • Adopt FakeTensorMode for FSDP
4 participants