[FSDP][3/N] Unify fully_shard auto wrap
ghstack-source-id: 141ba57ed15da65b7052e81f706b82925a376d33
Pull Request resolved: pytorch#104408
awgu committed Jun 30, 2023
1 parent 5c9a865 commit 27e4cf6
Showing 7 changed files with 67 additions and 129 deletions.
15 changes: 12 additions & 3 deletions test/distributed/_composable/fully_shard/test_fully_shard_init.py
@@ -251,11 +251,20 @@ def test_nested_fully_shard_shared_state(self):
Tests that nested applications of ``fully_shard`` share the expected
data structure state.
"""
self.run_subtests(
{"use_policy": [False, True]},
self._test_nested_fully_shard_shared_state,
)

def _test_nested_fully_shard_shared_state(self, use_policy: bool):
device = torch.device("cuda")
composable_module = CompositeParamModel(device=device)
fully_shard(composable_module.u1)
fully_shard(composable_module.u2)
fully_shard(composable_module)
if use_policy:
fully_shard(composable_module, policy=ModuleWrapPolicy({UnitModule}))
else:
fully_shard(composable_module.u1)
fully_shard(composable_module.u2)
fully_shard(composable_module)

# Run a forward pass to trigger lazy initialization
inp = torch.randn((2, 100), device=device)
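For reference, a minimal standalone sketch of the two composition styles this subtest toggles between. It is an illustration only, assuming an initialized default process group, an available CUDA device, and the `CompositeParamModel` / `UnitModule` helper modules that this test file imports from the shared FSDP test utilities:

import torch
from torch.distributed._composable import fully_shard
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

# Assumes torch.distributed.init_process_group(...) has already been called.
device = torch.device("cuda")
model = CompositeParamModel(device=device)  # has two UnitModule children: u1, u2

use_policy = True
if use_policy:
    # Policy-driven composition: one call shards each matching UnitModule and the root.
    fully_shard(model, policy=ModuleWrapPolicy({UnitModule}))
else:
    # Equivalent manual composition, applied bottom-up.
    fully_shard(model.u1)
    fully_shard(model.u2)
    fully_shard(model)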
52 changes: 36 additions & 16 deletions torch/distributed/_composable/fully_shard.py
@@ -13,18 +13,20 @@
_init_core_state,
_init_device_handle,
_init_ignored_module_states,
_init_param_handles_from_module,
_init_param_handle_from_module,
_init_prefetching_state,
_init_process_group_state,
_init_runtime_state,
_init_state_dict_state,
HYBRID_SHARDING_STRATEGIES,
)
from torch.distributed.fsdp._runtime_utils import (
_register_post_forward_hooks,
_register_pre_forward_hooks,
_register_root_pre_forward_hook,
)
from torch.distributed.fsdp._state_dict_utils import _register_all_state_dict_hooks
from torch.distributed.fsdp._wrap_utils import _auto_wrap
from torch.distributed.fsdp.api import (
BackwardPrefetch,
CPUOffload,
@@ -66,39 +66,57 @@ def fully_shard(
state = _init_process_group_state(
state, process_group, ShardingStrategy.FULL_SHARD, policy
)
limit_all_gathers = True
use_orig_params = True
backward_prefetch_limit = 1
forward_prefetch_limit = 1
if policy is not None:
fsdp_kwargs = {
"process_group": process_group,
"strategy": strategy,
"mixed_precision": mixed_precision,
"cpu_offload": cpu_offload,
"ignored_modules": ignored_modules,
"device_id": device_id,
"param_init_fn": param_init_fn,
"sync_module_states": sync_module_states,
"forward_prefetch": forward_prefetch,
"ignored_states": ignored_states,
}
if strategy in HYBRID_SHARDING_STRATEGIES:
fsdp_kwargs["process_group"] = (state.process_group, state._inter_node_pg)
_auto_wrap(
module,
policy,
state._ignored_modules,
state._ignored_params,
fsdp_kwargs,
fully_shard,
)
state = _init_core_state(
state,
strategy or ShardingStrategy.FULL_SHARD,
mixed_precision,
cpu_offload,
limit_all_gathers,
use_orig_params,
backward_prefetch_limit,
forward_prefetch_limit,
limit_all_gathers=True,
use_orig_params=True,
backward_prefetch_limit=1,
forward_prefetch_limit=1,
)
state = _init_runtime_state(state)
state = _init_prefetching_state(
state, BackwardPrefetch.BACKWARD_PRE, forward_prefetch=forward_prefetch
)
state = _init_buffer_state(state, module)
state = _init_param_handles_from_module(
state,
module,
policy,
device_id,
param_init_fn,
sync_module_states,
state = _init_param_handle_from_module(
state, module, device_id, param_init_fn, sync_module_states
)
state = _init_state_dict_state(state)
_register_all_state_dict_hooks(state)
modules = list(module.modules())
_register_pre_forward_hooks(state, modules)
_register_post_forward_hooks(state, modules)
_register_root_pre_forward_hook(state, module) # prepend last
# Always insert the state for the passed-in module even if it has no
# managed parameters, in which case it has no handles and does not appear
# in `_fully_sharded_module_to_handles`
_insert_module_state(module, state)
for submodule in module.modules():
if (
submodule in state._fully_sharded_module_to_handles
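To make the new branch above easier to follow, here is a rough, hedged sketch of what policy-driven wrapping amounts to on the composable path: the policy picks out submodules, and `fully_shard` itself is re-applied to each match (children before parents) before the root's own state is initialized. This is an illustration only, not the real `_auto_wrap`, which additionally threads through the ignored modules/params and the process-group pair built above for hybrid sharding:

import torch.nn as nn

def sketch_policy_wrap(root: nn.Module, target_cls: type, apply_fn) -> None:
    # Illustrative stand-in for _auto_wrap with ModuleWrapPolicy({target_cls}):
    # visit descendants in reverse .modules() order so nested matches are
    # handled before their parents, and leave the root to the caller.
    for submodule in reversed(list(root.modules())):
        if submodule is not root and isinstance(submodule, target_cls):
            apply_fn(submodule)

# e.g. sketch_policy_wrap(model, UnitModule, fully_shard) followed by
# fully_shard(model) mirrors the manual composition in the test above.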
4 changes: 4 additions & 0 deletions torch/distributed/fsdp/_common_utils.py
@@ -177,6 +177,10 @@ def _module_handles(state: _FSDPState, module: nn.Module) -> List:
the handles that contain some parameter in ``module``.
"""
if _is_composable(state):
# A valid FSDP state may have no managed parameters and hence no
# handles, meaning no entry in `_fully_sharded_module_to_handles`
if len(state._handles) == 0:
return []
assert (
module in state._fully_sharded_module_to_handles
), f"Expects a fully sharded module but got {module} on rank {state.rank}"
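A small hedged illustration of the case this guard covers: if every parameter under a composable root is already claimed by nested `fully_shard` calls, the root's state is valid but owns no handles, so `_module_handles` now returns an empty list instead of failing the assert. The toy model is an assumption for illustration and presumes an initialized process group plus a CUDA device:

import torch.nn as nn
from torch.distributed._composable import fully_shard

# Assumes torch.distributed.init_process_group(...) has already been called.
outer = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8)).cuda()
fully_shard(outer[0])
fully_shard(outer[1])
fully_shard(outer)  # nn.Sequential owns no direct parameters -> state with zero handles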
90 changes: 1 addition & 89 deletions torch/distributed/fsdp/_init_utils.py
@@ -13,7 +13,6 @@
Optional,
Set,
Tuple,
Type,
Union,
)

@@ -35,7 +34,6 @@
TrainingState,
)
from torch.distributed.fsdp._limiter_utils import _FreeEventQueue
from torch.distributed.fsdp._wrap_utils import _get_fully_sharded_module_to_states
from torch.distributed.fsdp.api import (
BackwardPrefetch,
CPUOffload,
@@ -479,11 +477,9 @@ def _init_param_handle_from_module(
device_id: Optional[Union[int, torch.device]],
param_init_fn: Optional[Callable[[nn.Module], None]],
sync_module_states: bool,
module_wrapper_cls: Type,
) -> _FSDPState:
"""
Initializes a ``FlatParamHandle`` from a module ``fully_sharded_module``.
This is the module wrapper code path.
"""
_check_single_device_module(fully_sharded_module, state._ignored_params)
device_from_device_id = _get_device_from_device_id(device_id, state.rank)
@@ -498,7 +494,7 @@
elif is_torchdistX_deferred_init:
deferred_init.materialize_module(
fully_sharded_module,
check_fn=lambda k: not isinstance(k, module_wrapper_cls),
check_fn=lambda k: _get_module_fsdp_state(k) is None,
)
_move_module_to_device(
fully_sharded_module, state._ignored_params, device_from_device_id
@@ -519,90 +515,6 @@
return state


@no_type_check
def _init_param_handles_from_module(
state: _FSDPState,
root_module: nn.Module,
policy: _FSDPPolicy,
device_id: Optional[Union[int, torch.device]],
param_init_fn: Optional[Callable[[nn.Module], None]],
sync_module_states: bool,
) -> _FSDPState:
"""
Initializes all ``FlatParamHandle`` s from a module ``root_module``. This
is the non-module-wrapper code path. ``root_module`` is guaranteed to be
a fully sharded module, and some of its submodules may be as well,
depending on ``policy``. See [Note: Fully Sharded Module].
"""
fully_sharded_module_to_states = _get_fully_sharded_module_to_states(
root_module,
policy,
state._ignored_modules,
state._ignored_params,
)
_check_single_device_module(root_module, state._ignored_params)
device_from_device_id = _get_device_from_device_id(device_id, state.rank)
# Initialize and shard `FlatParamHandle`s one by one following reverse
# depth-first order (i.e. reverse `.modules()` order), which represents a
# reverse topological sort order. This avoids increasing peak GPU memory
# usage when the unsharded model exists on CPU or meta device.
# NOTE: This order differs from that followed by the wrapper path when
# using auto wrapping, which also represents a valid reverse topological
# sort order, but the difference does not matter.
materialized_module = False
for fully_sharded_module, (params, buffers) in reversed(
fully_sharded_module_to_states.items()
):
# Materialize the module if needed
is_meta_module, is_torchdistX_deferred_init = _need_to_materialize_module(
fully_sharded_module, state._ignored_params
)
if is_meta_module or is_torchdistX_deferred_init:
materialized_module = True
# Save the parameter and buffer names to reacquire references after
# materialization since their variables may change
param_names, buffer_names = _get_state_names_for_states(
fully_sharded_module, params, buffers
)
if (
is_meta_module or is_torchdistX_deferred_init
) and param_init_fn is not None:
_materialize_with_param_init_fn(fully_sharded_module, param_init_fn)
elif is_meta_module:
_materialize_meta_module(fully_sharded_module, device_id)
elif is_torchdistX_deferred_init:
deferred_init.materialize_module(
root_module,
check_fn=lambda _: True,
)
if materialized_module:
# Reacquire references using the pre-computed state names
params = [
fully_sharded_module.get_parameter(param_name)
for param_name in param_names
]
buffers = [
fully_sharded_module.get_buffer(buffer_name)
for buffer_name in buffer_names
]
_move_states_to_device(params, buffers, device_from_device_id)
if state.compute_device is None: # only need to set once
state.compute_device = _get_compute_device(
fully_sharded_module,
state._ignored_params,
device_from_device_id,
state.rank,
)
if sync_module_states:
_sync_module_states(params, buffers, state.process_group)
_init_param_handle_from_params(state, params, fully_sharded_module)
# Reverse `_handles` to preserve depth-first `.modules()` order for
# consistency with the wrapper path (namely, so that `_get_fsdp_handles()`
# returns the same ordering for both paths).
state._handles.reverse()
return state
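A tiny self-contained check of the ordering claim in the comments above: `.modules()` is a pre-order (parent-first) traversal, so reversing it yields every descendant before its parent, i.e. a valid reverse topological order. The toy model is only for illustration:

import torch.nn as nn

root = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.Linear(4, 4)))
order = [type(m).__name__ for m in reversed(list(root.modules()))]
print(order)
# ['Linear', 'Sequential', 'Linear', 'Sequential'] -- descendants first, root last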


@no_type_check
def _init_param_handle_from_params(
state: _FSDPState,
31 changes: 13 additions & 18 deletions torch/distributed/fsdp/_runtime_utils.py
@@ -1475,24 +1475,19 @@ def _get_buffers_and_dtypes_for_computation(
_p_assert(state._is_root, "Expects the root to cast buffers")
buffers: List[torch.Tensor] = []
buffer_dtypes: List[Optional[torch.dtype]] = []
if _is_composable(state):
buffers = [
buffer for module in root_module.modules() for buffer in module.buffers()
]
buffer_dtypes = [
state.mixed_precision.buffer_dtype for _ in range(len(buffers))
]
else:
visited_buffers = set()
# Traverse the FSDP instances bottom-up so that we prefer the owning
# FSDP instance's mixed precision setting for each buffer
for fsdp_module in reversed(traversal_utils._get_fsdp_states(root_module)):
for buffer in fsdp_module.buffers():
if buffer in visited_buffers:
continue
visited_buffers.add(buffer)
buffers.append(buffer)
buffer_dtypes.append(fsdp_module.mixed_precision.buffer_dtype)
visited_buffers: Set[torch.Tensor] = set()
# Traverse the FSDP states bottom-up so that we prefer the owning FSDP
# instance's mixed precision setting for each buffer
fsdp_states, fsdp_modules = traversal_utils._get_fsdp_states_with_modules(
root_module
)
for fsdp_state, fsdp_module in zip(reversed(fsdp_states), reversed(fsdp_modules)):
for buffer in fsdp_module.buffers():
if buffer in visited_buffers:
continue
visited_buffers.add(buffer)
buffers.append(buffer)
buffer_dtypes.append(fsdp_state.mixed_precision.buffer_dtype)
assert len(buffers) == len(buffer_dtypes), f"{len(buffers)} {len(buffer_dtypes)}"
return buffers, buffer_dtypes

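A hedged, runnable sketch of the "owning instance wins" rule that both the old and new loops implement: iterating innermost-first and recording each buffer only once means a buffer's dtype comes from its closest enclosing FSDP state. The modules and dtypes below are stand-ins, not the library's state objects:

import torch
import torch.nn as nn

inner = nn.BatchNorm1d(4)      # owns running_mean / running_var buffers
outer = nn.Sequential(inner)   # an enclosing "state" sees the same buffers

# Outermost-to-innermost, mirroring the _get_fsdp_states_with_modules() order.
states_and_modules = [(torch.float32, outer), (torch.float16, inner)]
chosen, visited = {}, set()
for dtype, module in reversed(states_and_modules):  # innermost first
    for buf in module.buffers():
        if buf in visited:
            continue  # a more deeply nested owner already decided
        visited.add(buf)
        chosen[buf] = dtype
assert all(d == torch.float16 for d in chosen.values())  # inner (owning) dtype wins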
3 changes: 1 addition & 2 deletions torch/distributed/fsdp/_traversal_utils.py
@@ -74,8 +74,7 @@ def _get_fsdp_states_with_modules(
# Perform depth-first search from `module` to ensure that we do not
# traverse into an incompatible API's subtree (use DFS instead of BFS to
# match `.modules()` order)
deque: Deque[nn.Module] = collections.deque()
deque.append(module)
deque: Deque[nn.Module] = collections.deque([module])
while deque:
submodule = deque.popleft()
visited_modules.add(submodule)
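The one-line change above is purely cosmetic: seeding the deque at construction is equivalent to constructing it empty and appending the root module. A trivial check with a throwaway module, for illustration only:

import collections
import torch.nn as nn

m = nn.Linear(2, 2)
d1 = collections.deque()
d1.append(m)
d2 = collections.deque([m])
assert list(d1) == list(d2)  # identical single-element worklist either way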
1 change: 0 additions & 1 deletion torch/distributed/fsdp/fully_sharded_data_parallel.py
@@ -455,7 +455,6 @@ def __init__(
device_id,
param_init_fn,
sync_module_states,
FullyShardedDataParallel,
)
self._fsdp_wrapped_module = module
if not use_orig_params:
