Fabric always wraps the root module with FSDP #18005
Labels: bug, fabric, strategy: fsdp, ver: 2.1.x
Bug description
When the user manually wraps specific sections of the module with FSDP, without using a wrapping policy (for example, by wrapping individual submodules directly with `FullyShardedDataParallel`), Fabric still wraps the root module in FSDP. This can defeat the memory savings the user expects from manual wrapping (see the FSDP docs).
This is not a problem with the Trainer, because there the user must do the wrapping inside the `configure_sharded_model` hook, and the current behaviour is that `FSDPStrategy.setup_module` is skipped if that hook is defined: https://github.com/Lightning-AI/lightning/blob/f4240ca42c75ad67b2655351b38830fa0ba82cba/src/lightning/pytorch/strategies/fsdp.py#L265-L271

What version are you seeing the problem on?
master
How to reproduce the bug
Error messages and logs
No response
Environment
No response
More info
Discovered during #18004
The complex part of this bugfix is that the current FSDP checkpointing logic assumes that the root module is wrapped.
I found PyTorch tests with non-root FSDP modules: https://github.com/pytorch/pytorch/blob/main/torch/testing/_internal/common_fsdp.py#L438-L446
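To support non-root wrapping, the checkpointing logic would need to locate FSDP instances anywhere in the module tree instead of assuming the root is one. A self-contained sketch of that traversal, using hypothetical stand-in classes (`Node` for `torch.nn.Module`, `ShardedStub` for `FullyShardedDataParallel`):

```python
class Node:
    """Stand-in for torch.nn.Module."""

    def __init__(self, *children):
        self.children = list(children)

    def walk(self):
        # Depth-first traversal, like torch.nn.Module.modules()
        yield self
        for child in self.children:
            yield from child.walk()


class ShardedStub(Node):
    """Stand-in for an FSDP-wrapped submodule."""


def sharded_submodules(root):
    # Collect every FSDP-wrapped module, whether or not the root is one
    return [m for m in root.walk() if isinstance(m, ShardedStub)]
```

Checkpoint save/load could then iterate over `sharded_submodules(model)` rather than treating `model` itself as the single FSDP instance.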
cc @carmocca @justusschock @awaelchli