[FSDP][1/N] Move wrapper ModuleWrapPolicy to new path
ghstack-source-id: 6600d7ad0d44834537abefee871178fd3cdd6ff7
Pull Request resolved: pytorch#104346
awgu committed Jul 5, 2023
1 parent 20b39f4 commit 600332f
Showing 4 changed files with 181 additions and 37 deletions.
2 changes: 1 addition & 1 deletion test/distributed/fsdp/test_fsdp_mixed_precision.py
@@ -731,7 +731,7 @@ def never_wrap_policy(*args, **kwargs):
)
with self.assertWarnsRegex(
expected_warning=UserWarning,
expected_regex="batch norm submodules will be wrapped as separate",
expected_regex="These modules will be wrapped as separate FSDP",
):
model = FSDP(
net,
46 changes: 34 additions & 12 deletions torch/distributed/fsdp/_utils.py
@@ -1,47 +1,69 @@
import weakref
from functools import partial
from typing import Any, Dict, Type
from typing import Any, Dict, Iterable, Set, Type

import torch
import torch.nn as nn

from torch.distributed.utils import _apply_to_tensors
from torch.utils._mode_utils import no_dispatch


# Save a global mapping from module to its input tensor dtype to be populated
# during the forward pre-hook and consumed in the forward post-hook when
# overriding a module's mixed precision
# NOTE: We currently take the last input tensor's dtype in the case of multiple
# floating-point input tensors, which may be incorrect. However, since there is
# not a 1:1 correspondence between input and output tensors, we must use *some*
# heuristic like this to predict the desired output dtype.
_MODULE_TO_INP_DTYPE: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()


def _override_module_mixed_precision(
root: torch.nn.Module,
module_cls_to_override: Type[torch.nn.Module],
module_classes_to_override: Iterable[Type[nn.Module]],
wrap_override_dict: Dict[str, Any] = {"mixed_precision": None}, # noqa: B006
):
) -> Set[Type[nn.Module]]:
module_classes_to_override = tuple(set(module_classes_to_override))
# Return a set of the actually overridden module classes
overridden_module_classes: Set[Type[nn.Module]] = set()
for mod in root.modules():
if isinstance(mod, module_cls_to_override):
if isinstance(mod, module_classes_to_override):
overridden_module_classes.add(type(mod))
mod._wrap_overrides = wrap_override_dict # type: ignore[assignment]
# TODO: We need to run this mixed precision ignored module in fp32,
# but ensure subsequent modules, that may possibly be running with
# mixed precision, still receive the appropriate precision inputs
# without user having to adjust mixed precision config too much.
# As a result, we attach pre and post forward hooks to up / down
# cast. We should revisit this design.
old_dtype = None

def cast_fn(dtype, x: torch.Tensor) -> torch.Tensor:
def cast_fn(
dtype: torch.dtype, module: nn.Module, x: torch.Tensor
) -> torch.Tensor:
if not torch.is_floating_point(x) or x.dtype == dtype:
return x
nonlocal old_dtype
old_dtype = x.dtype
_MODULE_TO_INP_DTYPE[module] = x.dtype
return x.to(dtype)

def forward_pre_hook(module, args):
return _apply_to_tensors(partial(cast_fn, torch.float32), args)
return _apply_to_tensors(partial(cast_fn, torch.float32, module), args)

def forward_post_hook(module, args, output):
nonlocal old_dtype
if old_dtype is not None:
return _apply_to_tensors(partial(cast_fn, old_dtype), output)
# NOTE: If the forward did not have any floating-point tensors,
# then the dtype will not be set for this module, and we do not
# upcast the dtype.
if module in _MODULE_TO_INP_DTYPE:
old_dtype = _MODULE_TO_INP_DTYPE[module]
return _apply_to_tensors(
partial(cast_fn, old_dtype, module), output
)

# We intentionally append both of these hooks so that they run after
# all other hooks.
mod.register_forward_pre_hook(forward_pre_hook, prepend=False)
mod.register_forward_hook(forward_post_hook, prepend=False)
return overridden_module_classes


def _same_storage(x: torch.Tensor, y: torch.Tensor) -> bool:
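For readers unfamiliar with the hook pattern above, the sketch below re-creates the idea outside of FSDP: a forward pre-hook upcasts floating-point inputs to fp32 and records their original dtype per module, and a forward hook casts the output back afterwards. This is a minimal standalone illustration, not the FSDP code itself; the `_inp_dtype`, `_pre_hook`, and `_post_hook` names and the `BatchNorm1d` usage are assumptions made for the example.

import weakref

import torch
import torch.nn as nn

# Weak-keyed per-module record of the last floating-point input dtype,
# mirroring the _MODULE_TO_INP_DTYPE mapping above.
_inp_dtype: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()


def _pre_hook(module: nn.Module, args):
    # Upcast floating-point inputs to fp32 and remember their original dtype.
    def cast(x):
        if torch.is_tensor(x) and torch.is_floating_point(x):
            _inp_dtype[module] = x.dtype
            return x.to(torch.float32)
        return x

    return tuple(cast(a) for a in args)


def _post_hook(module: nn.Module, args, output):
    # Downcast the output back to the recorded input dtype, if one was seen.
    if module in _inp_dtype and torch.is_tensor(output):
        return output.to(_inp_dtype[module])
    return output


# Keep a BatchNorm1d running in fp32 while its callers pass fp16 activations.
bn = nn.BatchNorm1d(8)
bn.register_forward_pre_hook(_pre_hook, prepend=False)
bn.register_forward_hook(_post_hook, prepend=False)

x = torch.randn(4, 8, dtype=torch.float16)
out = bn(x)  # the forward itself runs in fp32
assert out.dtype == torch.float16
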
90 changes: 66 additions & 24 deletions torch/distributed/fsdp/_wrap_utils.py
@@ -2,7 +2,7 @@
import functools
import warnings
from functools import partial
from typing import Any, Deque, Dict, List, NamedTuple, Set, Tuple
from typing import Any, Deque, Dict, List, NamedTuple, Set, Tuple, Type

import torch
import torch.nn as nn
@@ -12,8 +12,12 @@
from torch.distributed.fsdp.wrap import (
_FSDPPolicy,
_or_policy,
_post_order_apply,
_recursive_wrap,
_run_mixed_precision_override_policy,
_run_module_wrap_policy,
_wrap_module_cls_individually,
ModuleWrapPolicy,
)


@@ -41,27 +45,44 @@ def _auto_wrap(
``_recursive_wrap()``, where ``auto_wrap_policy`` is not ``None``.
``fsdp_kwargs`` contains all FSDP arguments except ``module``.
"""
root_module = auto_wrap_kwargs["module"]
auto_wrap_policy = auto_wrap_kwargs["auto_wrap_policy"]
ignored_modules = auto_wrap_kwargs["ignored_modules"]
mixed_precision = fsdp_kwargs["mixed_precision"]
_check_nested_wrapping(root_module, module_wrapper_cls)

# TODO: Start migration to refactored auto wrapping with `ModuleWrapPolicy`
if isinstance(auto_wrap_policy, ModuleWrapPolicy):
module_classes = auto_wrap_policy._module_classes
fsdp_kwargs["auto_wrap_policy"] = None
target_module_to_kwargs = _run_module_wrap_policy(
root_module, module_classes, ignored_modules, fsdp_kwargs
)
if mixed_precision is not None:
target_module_to_kwargs = _run_mixed_precision_override_policy(
root_module,
mixed_precision._module_classes_to_ignore,
ignored_modules,
fsdp_kwargs,
target_module_to_kwargs,
)
overridden_module_classes = _override_module_mixed_precision(
root_module, mixed_precision._module_classes_to_ignore
)
_warn_on_overridden_mixed_precision(overridden_module_classes)
_post_order_apply(root_module, target_module_to_kwargs, module_wrapper_cls)
return

# Support new way to pass an auto wrap policy
if isinstance(auto_wrap_policy, _FSDPPolicy):
auto_wrap_policy = auto_wrap_policy.policy
root_module = auto_wrap_kwargs["module"]
assert auto_wrap_policy is not None
# For auto wrapping, submodules should not already be wrapped with FSDP
# since double wrapping is not supported
for module_name, module in root_module.named_modules():
if isinstance(module, module_wrapper_cls):
raise ValueError(
f"Expected {module_name} to NOT be FullyShardedDataParallel "
"if using an `auto_wrap_policy`"
)
mixed_precision = fsdp_kwargs["mixed_precision"]
if mixed_precision is not None:
for mp_module_to_override in mixed_precision._module_classes_to_ignore:
# Make modules of this particular type run in fp32 by wrapping them in their own
# FSDP unit.
_override_module_mixed_precision(root_module, mp_module_to_override)

# Wrap modules of the ignored types separately and register forward
# hooks to cast to fp32 and back to the original dtype, respectively
overridden_module_classes = _override_module_mixed_precision(
root_module, mixed_precision._module_classes_to_ignore
)
auto_wrap_policy = functools.partial(
_or_policy,
policies=[
@@ -72,17 +93,38 @@ def _auto_wrap(
),
],
)
warnings.warn(
"Both mixed precision and an `auto_wrap_policy` were specified "
"for FSDP, where the wrapped module has batch norm submodules. "
"The batch norm submodules will be wrapped as separate FSDP "
"instances with mixed precision disabled since some batch norm "
"kernels do not support low precision."
)
auto_wrap_kwargs["auto_wrap_policy"] = auto_wrap_policy
auto_wrap_kwargs["auto_wrap_policy"] = auto_wrap_policy
_warn_on_overridden_mixed_precision(overridden_module_classes)
_recursive_wrap(**auto_wrap_kwargs, **fsdp_kwargs)


def _check_nested_wrapping(
root_module: nn.Module,
wrapper_cls: Any, # e.g. `FullyShardedDataParallel`
):
# For auto wrapping, submodules should not already be wrapped with FSDP
# since double wrapping is not supported
for module_name, module in root_module.named_modules():
if isinstance(module, wrapper_cls):
raise ValueError(
f"Expected {module_name} to NOT be FullyShardedDataParallel "
"if using an `auto_wrap_policy`"
)


def _warn_on_overridden_mixed_precision(
overridden_module_classes: Set[Type[nn.Module]],
):
if len(overridden_module_classes) == 0:
return
warnings.warn(
"Both mixed precision and an auto_wrap_policy were specified to FSDP, "
f"where the wrapped module has submodules of type:\n{overridden_module_classes}\n"
"These modules will be wrapped as separate FSDP instacnes with mixed "
"precision disabled."
)


def _get_fully_sharded_module_to_states(
root_module: nn.Module,
auto_wrap_policy: _FSDPPolicy,
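As a usage-level illustration of the new code path in `_auto_wrap`, passing a `ModuleWrapPolicy` together with `MixedPrecision` now routes through `_run_module_wrap_policy`, `_run_mixed_precision_override_policy`, and `_post_order_apply` instead of `_recursive_wrap`. The sketch below is hypothetical: the model is made up, and it assumes a default process group has already been initialized (e.g. via torchrun).

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

# Assumes the default process group is already initialized (e.g. torchrun);
# this is a sketch rather than a runnable single-file script.
assert dist.is_initialized()

# Hypothetical model mixing Linear layers with a BatchNorm, whose kernels
# may not support reduced precision.
model = nn.Sequential(nn.Linear(16, 16), nn.BatchNorm1d(16), nn.Linear(16, 16))

# Wrap each nn.Linear as its own FSDP instance and run parameters in fp16.
policy = ModuleWrapPolicy({nn.Linear})
mp = MixedPrecision(param_dtype=torch.float16)

# With this PR, the constructor takes the refactored path: a module -> kwargs
# mapping is built, fp32 overrides are added for the mixed precision ignored
# module classes (batch norm by default), a warning notes that those modules
# are wrapped separately with mixed precision disabled, and the wrapping is
# applied bottom-up by _post_order_apply.
sharded = FSDP(model, auto_wrap_policy=policy, mixed_precision=mp)
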
80 changes: 80 additions & 0 deletions torch/distributed/fsdp/wrap.py
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

import contextlib
import copy
import functools
from abc import ABC, abstractmethod
from typing import (
@@ -12,6 +13,7 @@
cast,
Dict,
Generator,
Iterable,
Optional,
Sequence,
Set,
@@ -32,6 +34,83 @@
]


def _post_order_apply(
root_module: nn.Module,
target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]],
fn_to_apply: Callable[[Any], nn.Module],
) -> None:
"""
This applies a function ``fn_to_apply`` to some target modules with
specified kwargs per target module (via ``target_module_to_kwargs``)
following a post-order traversal. This reduces the problem to constructing
``target_module_to_kwargs``.
NOTE: Since the auto wrap policy is an arg to FSDP, this does not apply the
function to the root module; the caller is expected to handle the root
module itself.
"""
# Track visited modules to avoid visiting shared modules multiple times
visited_modules: Set[nn.Module] = {root_module}

def _post_order_apply_inner(
module: nn.Module,
module_name: str,
parent_module: Optional[nn.Module],
):
for child_module_name, child_module in module.named_children():
if child_module not in visited_modules:
visited_modules.add(child_module)
_post_order_apply_inner(child_module, child_module_name, module)
if module in target_module_to_kwargs and module is not root_module:
assert module_name != "", (
"Non-root modules should have their module name set but got "
f"an empty module name for {module}"
)
kwargs = target_module_to_kwargs[module]
new_module = fn_to_apply(module, **kwargs)
setattr(parent_module, module_name, new_module)

_post_order_apply_inner(root_module, "", None)


def _run_module_wrap_policy(
root_module: nn.Module,
module_classes: Iterable[Type[nn.Module]],
ignored_modules: Set[nn.Module],
fsdp_kwargs: Dict[str, Any],
) -> Dict[nn.Module, Dict[str, Any]]:
"""
TODO: To match the existing ``ModuleWrapPolicy`` behavior, every wrapped
module shares the same FSDP kwargs.
"""
module_classes_tuple = tuple(set(module_classes))
target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]] = {}
for module in root_module.modules():
if module in ignored_modules:
continue
elif isinstance(module, module_classes_tuple):
# Shallow copy to avoid coupling changes across modules
target_module_to_kwargs[module] = copy.copy(fsdp_kwargs)
return target_module_to_kwargs


def _run_mixed_precision_override_policy(
root_module: nn.Module,
module_classes: Iterable[Type[nn.Module]],
ignored_modules: Set[nn.Module],
fsdp_kwargs: Dict[str, Any],
target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]],
):
module_classes_tuple = tuple(set(module_classes))
for module in root_module.modules():
if module in ignored_modules:
continue
elif isinstance(module, module_classes_tuple):
# This policy overrides any existing policy
target_module_to_kwargs[module] = fsdp_kwargs
target_module_to_kwargs[module]["mixed_precision"] = None
return target_module_to_kwargs


def always_wrap_policy(*args, **kwargs) -> bool:
"""
A simple recursive wrap policy that always returns ``True``. This means
@@ -96,6 +175,7 @@ def __init__(self, module_classes: Set[Type[nn.Module]]):
_module_wrap_policy,
module_classes=module_classes,
)
self._module_classes = module_classes
self._module_classes_str = str(module_classes)

@property
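To make the traversal in `_post_order_apply` concrete, here is a simplified standalone re-implementation of the same idea (the `Wrapper` class and toy model are made up for illustration): children are visited before their parents, shared modules are visited once, the root is skipped, and each targeted module is swapped out on its parent via `setattr`.

from typing import Any, Callable, Dict, Optional, Set

import torch.nn as nn


def post_order_apply(
    root: nn.Module,
    targets: Dict[nn.Module, Dict[str, Any]],
    fn: Callable[..., nn.Module],
) -> None:
    # Track visited modules so shared submodules are only processed once.
    visited: Set[nn.Module] = {root}

    def inner(module: nn.Module, name: str, parent: Optional[nn.Module]) -> None:
        # Recurse into children first (post-order).
        for child_name, child in module.named_children():
            if child not in visited:
                visited.add(child)
                inner(child, child_name, module)
        # Skip the root: the caller is expected to handle it directly.
        if module in targets and module is not root:
            setattr(parent, name, fn(module, **targets[module]))

    inner(root, "", None)


class Wrapper(nn.Module):
    """Stand-in for an FSDP-style wrapper."""

    def __init__(self, module: nn.Module, tag: str = "") -> None:
        super().__init__()
        self.module = module
        self.tag = tag

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)


model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 4))
# Analogous to _run_module_wrap_policy: every nn.Linear gets the same kwargs.
targets = {m: {"tag": "linear"} for m in model.modules() if isinstance(m, nn.Linear)}
post_order_apply(model, targets, Wrapper)
assert all(isinstance(model[i], Wrapper) for i in (0, 2))
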
