Save and load sharded checkpoints with FSDP in Fabric #17323

Merged
merged 46 commits into from Apr 16, 2023

@awaelchli awaelchli commented Apr 11, 2023

What does this PR do?

Fixes #14816

This PR enables the following:

from lightning.fabric import Fabric

fabric = Fabric(strategy="fsdp", devices=2)

# this works now:
# (key names can be chosen freely by the user)
checkpoint = {"model": model, "optimizer": optimizer, "other": "anything"}
fabric.save(path, checkpoint)

# this works now:
fabric.load(path, checkpoint)

The checkpoint file structure looks like this (if devices=2):

os.listdir(path)
["meta.pt", ".metadata", "__0_0.distcp", "__1_0.distcp"]

The ".metadata" file is written by the FSDP file writer, the "*.distcp" files are the distributed checkpoint files holding the tensors, and "meta.pt" is a file in which Fabric's FSDPStrategy saves all user dict data other than the model and optimizer (in the example above: {"other": "anything"}).
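
To illustrate the layout described above, here is a minimal standard-library sketch. It is not Fabric's implementation: the helper names `split_checkpoint`, `expected_files`, and `missing_files` are hypothetical, the duck-typed `state_dict` check stands in for whatever type checks the strategy really performs, and the one-shard-file-per-rank naming is inferred from the devices=2 listing.

```python
import os
from typing import Any, Dict, List, Tuple


def split_checkpoint(state: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Separate stateful objects (model, optimizer) from plain user data."""
    stateful: Dict[str, Any] = {}
    metadata: Dict[str, Any] = {}
    for key, value in state.items():
        # Assumption: anything exposing a callable ``state_dict`` is treated
        # as a model or optimizer; everything else would go into "meta.pt".
        if callable(getattr(value, "state_dict", None)):
            stateful[key] = value
        else:
            metadata[key] = value
    return stateful, metadata


def expected_files(world_size: int) -> List[str]:
    """File names expected in the checkpoint directory."""
    # Assumption: one shard file per rank, named as in the listing above.
    shards = [f"__{rank}_0.distcp" for rank in range(world_size)]
    return ["meta.pt", ".metadata", *shards]


def missing_files(path: str, world_size: int) -> List[str]:
    """Return the expected files that are absent from ``path``."""
    present = set(os.listdir(path))
    return [name for name in expected_files(world_size) if name not in present]
```

A check like `missing_files(path, world_size=2)` could be run after `fabric.save(path, checkpoint)` to confirm that every rank wrote its shard before the directory is copied elsewhere.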

Future Work

This is a minimal implementation of sharded checkpoint saving and loading. It is the best choice for large models and the most memory-efficient option FSDP offers right now (offload to CPU, sharded state dict, chunk-wise file writer). In the future, we need to:

  1. Support saving and loading a full state dict as well (through a flag). This is important for use cases such as loading a pretrained model from a single file into an FSDP model.
  2. Enable the feature for torch < 2.0. This requires additional testing since some APIs/imports have changed slightly.

While testing, I stumbled upon this bug in PyTorch: pytorch/pytorch#99079

cc @Borda @awaelchli @carmocca @justusschock

@github-actions github-actions bot added the fabric lightning.fabric.Fabric label Apr 11, 2023
@awaelchli awaelchli added feature Is an improvement or enhancement checkpointing Related to checkpointing strategy: fsdp Fully Sharded Data Parallel labels Apr 11, 2023
@awaelchli awaelchli added this to the 2.1 milestone Apr 11, 2023
@awaelchli awaelchli changed the title WIP: Save and load sharded checkpoints with FSDP in Fabric Save and load sharded checkpoints with FSDP in Fabric Apr 14, 2023
@awaelchli awaelchli marked this pull request as ready for review April 14, 2023 03:40

github-actions bot commented Apr 14, 2023

⚡ Required checks status: All passing 🟢

Groups summary

🟢 pytorch_lightning: Tests workflow
Check ID Status
pl-cpu (macOS-11, lightning, 3.8, 1.11) success
pl-cpu (macOS-11, lightning, 3.9, 1.12) success
pl-cpu (macOS-11, lightning, 3.10, 1.13) success
pl-cpu (macOS-11, lightning, 3.10, 2.0) success
pl-cpu (macOS-11, lightning, 3.8, 1.11, oldest) success
pl-cpu (ubuntu-20.04, lightning, 3.8, 1.11) success
pl-cpu (ubuntu-20.04, lightning, 3.9, 1.12) success
pl-cpu (ubuntu-20.04, lightning, 3.10, 1.13) success
pl-cpu (ubuntu-20.04, lightning, 3.10, 2.0) success
pl-cpu (ubuntu-20.04, lightning, 3.8, 1.11, oldest) success
pl-cpu (windows-2022, lightning, 3.8, 1.11) success
pl-cpu (windows-2022, lightning, 3.9, 1.12) success
pl-cpu (windows-2022, lightning, 3.10, 1.13) success
pl-cpu (windows-2022, lightning, 3.10, 2.0) success
pl-cpu (windows-2022, lightning, 3.8, 1.11, oldest) success
pl-cpu (macOS-11, pytorch, 3.8, 1.13) success
pl-cpu (ubuntu-20.04, pytorch, 3.8, 1.13) success
pl-cpu (windows-2022, pytorch, 3.8, 1.13) success

These checks are required after the changes to .github/workflows/ci-tests-pytorch.yml, src/lightning/fabric/fabric.py, src/lightning/fabric/strategies/deepspeed.py, src/lightning/fabric/strategies/fsdp.py.

🟢 pytorch_lightning: Azure GPU
Check ID Status
pytorch-lightning (GPUs) success

These checks are required after the changes to src/lightning/fabric/fabric.py, src/lightning/fabric/strategies/deepspeed.py, src/lightning/fabric/strategies/fsdp.py.

🟢 fabric: Docs
Check ID Status
make-doctest (fabric) success
make-html (fabric) success

These checks are required after the changes to src/lightning/fabric/fabric.py, src/lightning/fabric/strategies/deepspeed.py, src/lightning/fabric/strategies/fsdp.py.

🟢 lightning_fabric: CPU workflow
Check ID Status
fabric-cpu (macOS-11, lightning, 3.8, 1.11) success
fabric-cpu (macOS-11, lightning, 3.9, 1.12) success
fabric-cpu (macOS-11, lightning, 3.10, 1.13) success
fabric-cpu (macOS-11, lightning, 3.10, 2.0) success
fabric-cpu (macOS-11, lightning, 3.8, 1.11, oldest) success
fabric-cpu (ubuntu-20.04, lightning, 3.8, 1.11) success
fabric-cpu (ubuntu-20.04, lightning, 3.9, 1.12) success
fabric-cpu (ubuntu-20.04, lightning, 3.10, 1.13) success
fabric-cpu (ubuntu-20.04, lightning, 3.10, 2.0) success
fabric-cpu (ubuntu-20.04, lightning, 3.8, 1.11, oldest) success
fabric-cpu (windows-2022, lightning, 3.8, 1.11) success
fabric-cpu (windows-2022, lightning, 3.9, 1.12) success
fabric-cpu (windows-2022, lightning, 3.10, 1.13) success
fabric-cpu (windows-2022, lightning, 3.10, 2.0) success
fabric-cpu (windows-2022, lightning, 3.8, 1.11, oldest) success
fabric-cpu (macOS-11, fabric, 3.8, 1.13) success
fabric-cpu (ubuntu-20.04, fabric, 3.8, 1.13) success
fabric-cpu (windows-2022, fabric, 3.8, 1.13) success

These checks are required after the changes to src/lightning/fabric/fabric.py, src/lightning/fabric/strategies/deepspeed.py, src/lightning/fabric/strategies/fsdp.py, tests/tests_fabric/helpers/models.py, tests/tests_fabric/strategies/test_deepspeed_integration.py, tests/tests_fabric/strategies/test_fsdp.py, tests/tests_fabric/strategies/test_fsdp_integration.py, tests/tests_fabric/test_fabric.py.

🟢 lightning_fabric: Azure GPU
Check ID Status
lightning-fabric (GPUs) success

These checks are required after the changes to src/lightning/fabric/fabric.py, src/lightning/fabric/strategies/deepspeed.py, src/lightning/fabric/strategies/fsdp.py, tests/tests_fabric/helpers/models.py, tests/tests_fabric/strategies/test_deepspeed_integration.py, tests/tests_fabric/strategies/test_fsdp.py, tests/tests_fabric/strategies/test_fsdp_integration.py, tests/tests_fabric/test_fabric.py.

🟢 mypy
Check ID Status
mypy success

These checks are required after the changes to src/lightning/fabric/fabric.py, src/lightning/fabric/strategies/deepspeed.py, src/lightning/fabric/strategies/fsdp.py.

🟢 install
Check ID Status
install-pkg (ubuntu-22.04, app, 3.8) success
install-pkg (ubuntu-22.04, app, 3.10) success
install-pkg (ubuntu-22.04, fabric, 3.8) success
install-pkg (ubuntu-22.04, fabric, 3.10) success
install-pkg (ubuntu-22.04, pytorch, 3.8) success
install-pkg (ubuntu-22.04, pytorch, 3.10) success
install-pkg (ubuntu-22.04, lightning, 3.8) success
install-pkg (ubuntu-22.04, lightning, 3.10) success
install-pkg (ubuntu-22.04, notset, 3.8) success
install-pkg (ubuntu-22.04, notset, 3.10) success
install-pkg (macOS-12, app, 3.8) success
install-pkg (macOS-12, app, 3.10) success
install-pkg (macOS-12, fabric, 3.8) success
install-pkg (macOS-12, fabric, 3.10) success
install-pkg (macOS-12, pytorch, 3.8) success
install-pkg (macOS-12, pytorch, 3.10) success
install-pkg (macOS-12, lightning, 3.8) success
install-pkg (macOS-12, lightning, 3.10) success
install-pkg (macOS-12, notset, 3.8) success
install-pkg (macOS-12, notset, 3.10) success
install-pkg (windows-2022, app, 3.8) success
install-pkg (windows-2022, app, 3.10) success
install-pkg (windows-2022, fabric, 3.8) success
install-pkg (windows-2022, fabric, 3.10) success
install-pkg (windows-2022, pytorch, 3.8) success
install-pkg (windows-2022, pytorch, 3.10) success
install-pkg (windows-2022, lightning, 3.8) success
install-pkg (windows-2022, lightning, 3.10) success
install-pkg (windows-2022, notset, 3.8) success
install-pkg (windows-2022, notset, 3.10) success

These checks are required after the changes to src/lightning/fabric/fabric.py, src/lightning/fabric/strategies/deepspeed.py, src/lightning/fabric/strategies/fsdp.py.

🟢 link-check
Check ID Status
check-md-links / markdown-link-check success

These checks are required after the changes to src/lightning/fabric/CHANGELOG.md.


Thank you for your contribution! 💜

Note
This comment is automatically generated and updates for 60 minutes every 180 seconds. If you have any other questions, contact carmocca for help.

@lantiga lantiga left a comment

Awesome!

@mergify mergify bot added the ready PRs ready to be merged label Apr 14, 2023
awaelchli and others added 3 commits April 14, 2023 05:25
Co-authored-by: Luca Antiga <luca.antiga@gmail.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>

@carmocca carmocca left a comment

I don't think porting this to support torch < 2.0 is important. If you are worried about silent errors, we can raise an error at the start of FSDP when the torch version is lower than 2.0, suggesting an upgrade.
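
The guard suggested here could look roughly like the following sketch. It is an illustration only, not the check that was merged: `parse_major_minor` and `require_torch_2` are hypothetical names, and the version string is passed in as an argument so the snippet runs without torch installed.

```python
import re
from typing import Tuple


def parse_major_minor(version: str) -> Tuple[int, int]:
    """Extract (major, minor) from strings such as '2.0.0+cu117'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def require_torch_2(version: str) -> None:
    """Fail early instead of risking silent errors on older torch versions."""
    if parse_major_minor(version) < (2, 0):
        raise RuntimeError(
            f"Sharded FSDP checkpointing requires torch >= 2.0, found {version}."
            " Please upgrade."
        )
```

In a real setup the argument would be `torch.__version__`, and the call would sit wherever the strategy first needs the checkpointing APIs.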

codecov bot commented Apr 14, 2023

Codecov Report

Merging #17323 (bb2e0db) into master (8e7b949) will decrease coverage by 24%.
The diff coverage is 46%.

Additional details and impacted files
@@            Coverage Diff            @@
##           master   #17323     +/-   ##
=========================================
- Coverage      83%      59%    -24%     
=========================================
  Files         415      410      -5     
  Lines       31437    31427     -10     
=========================================
- Hits        26048    18596   -7452     
- Misses       5389    12831   +7442     
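
The percentages in the report follow directly from the hit and line counts. A small sketch of the arithmetic, using only the numbers in the table above (the function name `coverage_percent` is hypothetical):

```python
def coverage_percent(hits: int, lines: int) -> int:
    """Coverage as a whole-number percentage, rounded like the report."""
    return round(100 * hits / lines)


# Hits and misses add up to the total line count on both sides of the diff.
assert 26048 + 5389 == 31437   # master
assert 18596 + 12831 == 31427  # #17323

master = coverage_percent(26048, 31437)  # 83
pr = coverage_percent(18596, 31427)      # 59
print(master, pr, pr - master)           # prints: 83 59 -24
```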

@github-actions github-actions bot added the ci Continuous Integration label Apr 16, 2023
@awaelchli awaelchli merged commit 0dc42f5 into master Apr 16, 2023
100 of 101 checks passed
@awaelchli awaelchli deleted the fsdp-checkpoint branch April 16, 2023 18:11
Development

Successfully merging this pull request may close these issues.

Checkpointing primitives for Fabric
4 participants