[FSDP][2/N][Easy] Prepare _auto_wrap for fully_shard
ghstack-source-id: d5085f211f31ab81ed7351a037f3c59b7caff515
Pull Request resolved: pytorch#104407
awgu committed Jun 29, 2023
1 parent 8836487 commit 68fe084
Showing 3 changed files with 59 additions and 46 deletions.
9 changes: 7 additions & 2 deletions test/distributed/fsdp/test_wrap.py
@@ -151,8 +151,13 @@ def test_error_already_wrapped(self, nested, cuda_init_mode):
if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
wrapped_fsdp = wrapped_fsdp.cuda()

with self.assertRaisesRegex(ValueError, "to NOT be FullyShardedDataParallel"):
mod = FSDP(wrapped_fsdp, auto_wrap_policy=size_based_auto_wrap_policy)
wrapped_module_name = "lin1.1" if nested else "lin1"
with self.assertRaisesRegex(
ValueError,
"FSDP auto wrapping requires modules to not already have FSDP "
f"applied but found {wrapped_module_name} in",
):
FSDP(wrapped_fsdp, auto_wrap_policy=size_based_auto_wrap_policy)

@skip_if_lt_x_gpu(2)
@parametrize("use_or_policy", [True, False])
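For context, here is a minimal sketch (not part of the diff) of the behavior the updated test asserts: applying FSDP with an auto_wrap_policy to a model that already contains an FSDP instance should now fail with the more descriptive error naming the offending submodule. The toy model is mine, and the snippet assumes an initialized process group (e.g. via torchrun) and an available GPU.

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

# Hypothetical toy model: manually wrap one submodule first.
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
model[1] = FSDP(model[1])

try:
    # Auto wrapping over an already-wrapped submodule is rejected up front.
    FSDP(model, auto_wrap_policy=size_based_auto_wrap_policy)
except ValueError as e:
    # Expected to read roughly: "FSDP auto wrapping requires modules to not
    # already have FSDP applied but found 1 in ..."
    print(e)
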
79 changes: 44 additions & 35 deletions torch/distributed/fsdp/_wrap_utils.py
@@ -1,12 +1,16 @@
import collections
import functools
import inspect
import warnings
from functools import partial
from typing import Any, Deque, Dict, List, NamedTuple, Set, Tuple, Type
from typing import Any, Callable, Deque, Dict, List, NamedTuple, Set, Tuple, Type, Union

import torch
import torch.nn as nn
from torch.distributed.fsdp._common_utils import _is_fsdp_flattened
from torch.distributed.fsdp._common_utils import (
_get_module_fsdp_state,
_is_fsdp_flattened,
)
from torch.distributed.fsdp._utils import _override_module_mixed_precision

from torch.distributed.fsdp.wrap import (
@@ -32,29 +36,31 @@ class FullyShardedModuleState(NamedTuple):


def _auto_wrap(
auto_wrap_kwargs: Dict[str, Any],
root_module: nn.Module,
policy: Union[Callable, _FSDPPolicy],
ignored_modules: Set[nn.Module],
ignored_params: Set[nn.Parameter],
fsdp_kwargs: Dict[str, Any],
module_wrapper_cls: Any, # e.g. `FullyShardedDataParallel`
) -> None:
fsdp_fn: Callable, # `FullyShardedDataParallel` or `fully_shard`
):
"""
Recursively auto wraps the root module given by the key "module" in
``auto_wrap_kwargs`` with the arguments in ``auto_wrap_kwargs`` and
``fsdp_kwargs``.
Auto wraps modules in ``root_module`` 's tree according to ``policy``
following a post-order traversal.
Precondition: ``auto_wrap_policy`` contains the arguments expected by
``_recursive_wrap()``, where ``auto_wrap_policy`` is not ``None``.
``fsdp_kwargs`` contains all FSDP arguments except ``module``.
Precondition: ``fsdp_kwargs`` should contain all FSDP arguments except
``module``. This function accepts the kwargs dict directly since it gets
forwarded into the post-order traversal function.
"""
root_module = auto_wrap_kwargs["module"]
auto_wrap_policy = auto_wrap_kwargs["auto_wrap_policy"]
ignored_modules = auto_wrap_kwargs["ignored_modules"]
mixed_precision = fsdp_kwargs["mixed_precision"]
_check_nested_wrapping(root_module, module_wrapper_cls)
is_wrapper = inspect.isclass(fsdp_fn)
# TODO: We may relax this no-nested-wrapping constraint to support manual
# wrapping followed by auto wrapping.
_check_nested_wrapping(root_module)

# TODO: Start migration to refactored auto wrapping with `ModuleWrapPolicy`
if isinstance(auto_wrap_policy, ModuleWrapPolicy):
module_classes = auto_wrap_policy._module_classes
fsdp_kwargs["auto_wrap_policy"] = None
if isinstance(policy, ModuleWrapPolicy):
module_classes = policy._module_classes
fsdp_kwargs["auto_wrap_policy" if is_wrapper else "policy"] = None
target_module_to_kwargs = _run_module_wrap_policy(
root_module, module_classes, ignored_modules, fsdp_kwargs
)
@@ -70,45 +76,48 @@ def _auto_wrap(
root_module, mixed_precision._module_classes_to_ignore
)
_warn_on_overridden_mixed_precision(overridden_module_classes)
_post_order_apply(root_module, target_module_to_kwargs, module_wrapper_cls)
_post_order_apply(root_module, target_module_to_kwargs, fsdp_fn)
return

# Support new way to pass an auto wrap policy
if isinstance(auto_wrap_policy, _FSDPPolicy):
auto_wrap_policy = auto_wrap_policy.policy
assert auto_wrap_policy is not None
if isinstance(policy, _FSDPPolicy):
policy = policy.policy
assert policy is not None
recursive_wrap_kwargs = {
"module": root_module,
"auto_wrap_policy": policy,
"wrapper_cls": fsdp_fn,
"ignored_modules": ignored_modules,
"ignored_params": ignored_params,
"only_wrap_children": True,
}
if mixed_precision is not None:
# Wrap modules of the ignored types separately and register forward
# hooks to cast to fp32 and back to the original dtype, respectively
overridden_module_classes = _override_module_mixed_precision(
root_module, mixed_precision._module_classes_to_ignore
)
auto_wrap_policy = functools.partial(
policy = functools.partial(
_or_policy,
policies=[
auto_wrap_policy,
policy,
partial(
_wrap_module_cls_individually,
module_classes=mixed_precision._module_classes_to_ignore,
),
],
)
auto_wrap_kwargs["auto_wrap_policy"] = auto_wrap_policy
recursive_wrap_kwargs["auto_wrap_policy"] = policy
_warn_on_overridden_mixed_precision(overridden_module_classes)
_recursive_wrap(**auto_wrap_kwargs, **fsdp_kwargs)
_recursive_wrap(**recursive_wrap_kwargs, **fsdp_kwargs)


def _check_nested_wrapping(
root_module: nn.Module,
wrapper_cls: Any, # e.g. `FullyShardedDataParallel`
):
# For auto wrapping, submodules should not already be wrapped with FSDP
# since double wrapping is not supported
def _check_nested_wrapping(root_module: nn.Module):
for module_name, module in root_module.named_modules():
if isinstance(module, wrapper_cls):
if _get_module_fsdp_state(module) is not None:
raise ValueError(
f"Expected {module_name} to NOT be FullyShardedDataParallel "
"if using an `auto_wrap_policy`"
"FSDP auto wrapping requires modules to not already have "
f"FSDP applied but found {module_name} in\n{root_module}"
)


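A quick sketch (not from the diff) of why the new is_wrapper = inspect.isclass(fsdp_fn) flag is enough to tell the two entry points apart: FullyShardedDataParallel is a wrapper class, while the composable fully_shard API is a plain function, and the two spell the policy keyword differently. The kwarg-name choice comes straight from the diff's "auto_wrap_policy" if is_wrapper else "policy" line; the import path for fully_shard is my assumption.

import inspect

from torch.distributed._composable import fully_shard  # assumed import path
from torch.distributed.fsdp import FullyShardedDataParallel

for fsdp_fn in (FullyShardedDataParallel, fully_shard):
    is_wrapper = inspect.isclass(fsdp_fn)
    # Mirrors the diff: the wrapper class takes `auto_wrap_policy`, while the
    # composable API takes `policy`.
    policy_kwarg = "auto_wrap_policy" if is_wrapper else "policy"
    name = getattr(fsdp_fn, "__name__", repr(fsdp_fn))
    print(f"{name}: is_wrapper={is_wrapper}, policy kwarg={policy_kwarg!r}")
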
17 changes: 8 additions & 9 deletions torch/distributed/fsdp/fully_sharded_data_parallel.py
@@ -406,14 +406,6 @@ def __init__(
self, process_group, sharding_strategy, auto_wrap_policy
)
if auto_wrap_policy is not None:
auto_wrap_kwargs = {
"module": module,
"auto_wrap_policy": auto_wrap_policy,
"wrapper_cls": FullyShardedDataParallel,
"ignored_modules": self._ignored_modules,
"ignored_params": self._ignored_params,
"only_wrap_children": True, # avoid double wrapping the root
}
fsdp_kwargs = {
"process_group": process_group,
"sharding_strategy": sharding_strategy,
@@ -433,7 +425,14 @@ def __init__(
# process groups.
fsdp_kwargs["process_group"] = (self.process_group, self._inter_node_pg)

_auto_wrap(auto_wrap_kwargs, fsdp_kwargs, FullyShardedDataParallel)
_auto_wrap(
module,
auto_wrap_policy,
self._ignored_modules,
self._ignored_params,
fsdp_kwargs,
FullyShardedDataParallel,
)

backward_prefetch_limit = 1
forward_prefetch_limit = 1
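From the public API, nothing changes for users: the constructor still accepts auto_wrap_policy and now forwards the module, policy, ignored modules/params, and remaining kwargs to _auto_wrap as explicit arguments. A usage sketch follows (my example model and policy, not from the PR; assumes an initialized process group, e.g. via torchrun, and a GPU per rank).

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

# Hypothetical example model; ModuleWrapPolicy wraps each listed layer class
# as its own FSDP unit via the `_run_module_wrap_policy` path shown above.
model = nn.Transformer(d_model=64, nhead=4)
sharded = FSDP(
    model,
    auto_wrap_policy=ModuleWrapPolicy(
        {nn.TransformerEncoderLayer, nn.TransformerDecoderLayer}
    ),
)
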
