Description & Motivation
The `FSDPStrategy` can use a hybrid sharding strategy to shard across smaller sets of ranks within the global distributed group. However, it is not flexible enough to let the user easily specify the sharding scale.
Pitch
The `FSDPStrategy` can use a hybrid sharding strategy to shard across smaller sets of ranks in the global distributed group. Currently there are two paths to use it in Lightning:

- Specify `sharding_strategy` as one of the hybrid sharding strategies. This shards within one node and replicates across nodes.
- Specify `sharding_strategy` as one of the hybrid sharding strategies, and provide `process_group` as kwargs to `FSDPStrategy` (see the sketch after this list). This lets the user control how large the sharding scale is. However, it is not easy for the user to insert torch distributed group creation code and prepare the `process_group` ahead of time, because Lightning handles the torch distributed `init_process_group` call automatically in the trainer or the Fabric launcher.
So I'm looking forward to an easier way to use HSDP within Lightning, like:

`FSDPStrategy(..., sharding_strategy="HYBRID_SHARD", fsdp_size=16)`

to easily shard at a specified scale, and let Lightning handle the `process_group` preparation for the PyTorch FSDP wrapper.
Alternatives
No response
Additional context
No response