diff --git a/test/distributed/fsdp/test_fsdp_mixed_precision.py b/test/distributed/fsdp/test_fsdp_mixed_precision.py
index 2e43d4d05f6da..1adb87e936e5b 100644
--- a/test/distributed/fsdp/test_fsdp_mixed_precision.py
+++ b/test/distributed/fsdp/test_fsdp_mixed_precision.py
@@ -731,7 +731,7 @@ def never_wrap_policy(*args, **kwargs):
         )
         with self.assertWarnsRegex(
             expected_warning=UserWarning,
-            expected_regex="batch norm submodules will be wrapped as separate",
+            expected_regex="These modules will be wrapped as separate FSDP",
         ):
             model = FSDP(
                 net,
diff --git a/torch/distributed/fsdp/_utils.py b/torch/distributed/fsdp/_utils.py
index 5a53a62856396..af21f72a0a2a0 100644
--- a/torch/distributed/fsdp/_utils.py
+++ b/torch/distributed/fsdp/_utils.py
@@ -1,19 +1,35 @@
+import weakref
 from functools import partial
-from typing import Any, Dict, Type
+from typing import Any, Dict, Iterable, Set, Type
 
 import torch
+import torch.nn as nn
 from torch.distributed.utils import _apply_to_tensors
 from torch.utils._mode_utils import no_dispatch
 
 
+# Save a global mapping from module to its input tensor dtype to be populated
+# during the forward pre-hook and consumed in the forward post-hook when
+# overriding a module's mixed precision
+# NOTE: We currently take the last input tensor's dtype in the case of multiple
+# floating-point input tensors, which may be incorrect. However, since there is
+# not a 1:1 correspondence between input and output tensors, we must use *some*
+# heuristic like this to predict the desired output dtype.
+_MODULE_TO_INP_DTYPE: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
+
+
 def _override_module_mixed_precision(
     root: torch.nn.Module,
-    module_cls_to_override: Type[torch.nn.Module],
+    module_classes_to_override: Iterable[Type[nn.Module]],
     wrap_override_dict: Dict[str, Any] = {"mixed_precision": None},  # noqa: B006
-):
+) -> Set[Type[nn.Module]]:
+    module_classes_to_override = tuple(set(module_classes_to_override))
+    # Return a set of the actually overridden module classes
+    overridden_module_classes: Set[Type[nn.Module]] = set()
     for mod in root.modules():
-        if isinstance(mod, module_cls_to_override):
+        if isinstance(mod, module_classes_to_override):
+            overridden_module_classes.add(type(mod))
             mod._wrap_overrides = wrap_override_dict  # type: ignore[assignment]
             # TODO: We need to run this mixed precision ignored module in fp32,
             # but ensure subsequent modules, that may possibly be running with
@@ -21,27 +37,33 @@ def _override_module_mixed_precision(
             # without user having to adjust mixed precision config too much.
             # As a result, we attach pre and post forward hooks to up / down
             # cast. We should revisit this design.
-            old_dtype = None
 
-            def cast_fn(dtype, x: torch.Tensor) -> torch.Tensor:
+            def cast_fn(
+                dtype: torch.dtype, module: nn.Module, x: torch.Tensor
+            ) -> torch.Tensor:
                 if not torch.is_floating_point(x) or x.dtype == dtype:
                     return x
-                nonlocal old_dtype
-                old_dtype = x.dtype
+                _MODULE_TO_INP_DTYPE[module] = x.dtype
                 return x.to(dtype)
 
             def forward_pre_hook(module, args):
-                return _apply_to_tensors(partial(cast_fn, torch.float32), args)
+                return _apply_to_tensors(partial(cast_fn, torch.float32, module), args)
 
             def forward_post_hook(module, args, output):
-                nonlocal old_dtype
-                if old_dtype is not None:
-                    return _apply_to_tensors(partial(cast_fn, old_dtype), output)
+                # NOTE: If the forward did not have any floating-point tensors,
+                # then the dtype will not be set for this module, and we do not
+                # upcast the dtype.
+                if module in _MODULE_TO_INP_DTYPE:
+                    old_dtype = _MODULE_TO_INP_DTYPE[module]
+                    return _apply_to_tensors(
+                        partial(cast_fn, old_dtype, module), output
+                    )
 
             # We intentionally append both of these hooks so that they run after
             # all other hooks.
             mod.register_forward_pre_hook(forward_pre_hook, prepend=False)
             mod.register_forward_hook(forward_post_hook, prepend=False)
+    return overridden_module_classes
 
 
 def _same_storage(x: torch.Tensor, y: torch.Tensor) -> bool:
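
The hook mechanism above is easiest to see in isolation: a forward pre-hook records the floating-point input dtype and downcasts the inputs to fp32, and a forward post-hook casts the output back to the recorded dtype. The following standalone sketch is illustrative only (not part of the patch) and simplifies the hooks to top-level tensor arguments and a single tensor output:

    import weakref

    import torch
    import torch.nn as nn

    _inp_dtype: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()


    def _pre_hook(module, args):
        # Record the (last) floating-point input dtype and downcast to fp32
        new_args = []
        for a in args:
            if torch.is_tensor(a) and torch.is_floating_point(a):
                _inp_dtype[module] = a.dtype
                a = a.float()
            new_args.append(a)
        return tuple(new_args)


    def _post_hook(module, args, output):
        # Cast the output back to the recorded input dtype, if one was seen
        if module in _inp_dtype and torch.is_floating_point(output):
            return output.to(_inp_dtype[module])
        return output


    bn = nn.BatchNorm1d(4)
    bn.register_forward_pre_hook(_pre_hook)
    bn.register_forward_hook(_post_hook)

    out = bn(torch.randn(2, 4, dtype=torch.float16))
    print(out.dtype)  # torch.float16: the kernel ran in fp32, output cast back down
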
diff --git a/torch/distributed/fsdp/_wrap_utils.py b/torch/distributed/fsdp/_wrap_utils.py
index 71a8896d6551f..3be14ad7f565c 100644
--- a/torch/distributed/fsdp/_wrap_utils.py
+++ b/torch/distributed/fsdp/_wrap_utils.py
@@ -2,7 +2,7 @@
 import functools
 import warnings
 from functools import partial
-from typing import Any, Deque, Dict, List, NamedTuple, Set, Tuple
+from typing import Any, Deque, Dict, List, NamedTuple, Set, Tuple, Type
 
 import torch
 import torch.nn as nn
@@ -12,8 +12,12 @@
 from torch.distributed.fsdp.wrap import (
     _FSDPPolicy,
     _or_policy,
+    _post_order_apply,
     _recursive_wrap,
+    _run_mixed_precision_override_policy,
+    _run_module_wrap_policy,
     _wrap_module_cls_individually,
+    ModuleWrapPolicy,
 )
 
 
@@ -41,27 +45,44 @@ def _auto_wrap(
     ``_recursive_wrap()``, where ``auto_wrap_policy`` is not ``None``.
     ``fsdp_kwargs`` contains all FSDP arguments except ``module``.
     """
+    root_module = auto_wrap_kwargs["module"]
     auto_wrap_policy = auto_wrap_kwargs["auto_wrap_policy"]
+    ignored_modules = auto_wrap_kwargs["ignored_modules"]
+    mixed_precision = fsdp_kwargs["mixed_precision"]
+    _check_nested_wrapping(root_module, module_wrapper_cls)
+
+    # TODO: Start migration to refactored auto wrapping with `ModuleWrapPolicy`
+    if isinstance(auto_wrap_policy, ModuleWrapPolicy):
+        module_classes = auto_wrap_policy._module_classes
+        fsdp_kwargs["auto_wrap_policy"] = None
+        target_module_to_kwargs = _run_module_wrap_policy(
+            root_module, module_classes, ignored_modules, fsdp_kwargs
+        )
+        if mixed_precision is not None:
+            target_module_to_kwargs = _run_mixed_precision_override_policy(
+                root_module,
+                mixed_precision._module_classes_to_ignore,
+                ignored_modules,
+                fsdp_kwargs,
+                target_module_to_kwargs,
+            )
+            overridden_module_classes = _override_module_mixed_precision(
+                root_module, mixed_precision._module_classes_to_ignore
+            )
+            _warn_on_overridden_mixed_precision(overridden_module_classes)
+        _post_order_apply(root_module, target_module_to_kwargs, module_wrapper_cls)
+        return
+
     # Support new way to pass an auto wrap policy
     if isinstance(auto_wrap_policy, _FSDPPolicy):
         auto_wrap_policy = auto_wrap_policy.policy
-    root_module = auto_wrap_kwargs["module"]
     assert auto_wrap_policy is not None
-    # For auto wrapping, submodules should not already be wrapped with FSDP
-    # since double wrapping is not supported
-    for module_name, module in root_module.named_modules():
-        if isinstance(module, module_wrapper_cls):
-            raise ValueError(
-                f"Expected {module_name} to NOT be FullyShardedDataParallel "
-                "if using an `auto_wrap_policy`"
-            )
-    mixed_precision = fsdp_kwargs["mixed_precision"]
     if mixed_precision is not None:
-        for mp_module_to_override in mixed_precision._module_classes_to_ignore:
-            # Make modules of this particular type run in fp32 by wrapping them in their own
-            # FSDP unit.
-            _override_module_mixed_precision(root_module, mp_module_to_override)
-
+        # Wrap modules of the ignored types separately and register forward
+        # hooks to cast to fp32 and back to the original dtype, respectively
+        overridden_module_classes = _override_module_mixed_precision(
+            root_module, mixed_precision._module_classes_to_ignore
+        )
         auto_wrap_policy = functools.partial(
             _or_policy,
             policies=[
@@ -72,17 +93,38 @@ def _auto_wrap(
                 ),
             ],
         )
-    auto_wrap_kwargs["auto_wrap_policy"] = auto_wrap_policy
+        auto_wrap_kwargs["auto_wrap_policy"] = auto_wrap_policy
+        _warn_on_overridden_mixed_precision(overridden_module_classes)
     _recursive_wrap(**auto_wrap_kwargs, **fsdp_kwargs)
 
 
+def _check_nested_wrapping(
+    root_module: nn.Module,
+    wrapper_cls: Any,  # e.g. `FullyShardedDataParallel`
+):
+    # For auto wrapping, submodules should not already be wrapped with FSDP
+    # since double wrapping is not supported
+    for module_name, module in root_module.named_modules():
+        if isinstance(module, wrapper_cls):
+            raise ValueError(
+                f"Expected {module_name} to NOT be FullyShardedDataParallel "
+                "if using an `auto_wrap_policy`"
+            )
+
+
+def _warn_on_overridden_mixed_precision(
+    overridden_module_classes: Set[Type[nn.Module]],
+):
+    if len(overridden_module_classes) == 0:
+        return
+    warnings.warn(
+        "Both mixed precision and an auto_wrap_policy were specified to FSDP, "
+        f"where the wrapped module has submodules of type:\n{overridden_module_classes}\n"
+        "These modules will be wrapped as separate FSDP instances with mixed "
+        "precision disabled."
+    )
+
+
 def _get_fully_sharded_module_to_states(
     root_module: nn.Module,
     auto_wrap_policy: _FSDPPolicy,
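
For context, the `ModuleWrapPolicy` branch added above is reached through the public API roughly as follows. This is an illustrative sketch only (not part of the patch); it assumes a default process group has already been initialized with a CUDA backend (e.g. under torchrun) and that a GPU is available:

    import torch
    import torch.nn as nn
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp import MixedPrecision
    from torch.distributed.fsdp.wrap import ModuleWrapPolicy

    model = nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8), nn.Linear(8, 8)).cuda()
    fsdp_model = FSDP(
        model,
        auto_wrap_policy=ModuleWrapPolicy({nn.Linear}),
        mixed_precision=MixedPrecision(param_dtype=torch.float16),
    )
    # Since batch norm classes are in `MixedPrecision`'s default
    # `_module_classes_to_ignore`, the `nn.BatchNorm1d` becomes its own FSDP unit
    # with mixed precision disabled, and the UserWarning checked in the updated
    # test above is emitted.
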
diff --git a/torch/distributed/fsdp/wrap.py b/torch/distributed/fsdp/wrap.py
index cc30b91f07b9f..73f7289b5382e 100644
--- a/torch/distributed/fsdp/wrap.py
+++ b/torch/distributed/fsdp/wrap.py
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import contextlib
+import copy
 import functools
 from abc import ABC, abstractmethod
 from typing import (
@@ -12,6 +13,7 @@
     cast,
     Dict,
     Generator,
+    Iterable,
     Optional,
     Sequence,
     Set,
@@ -32,6 +34,83 @@
 ]
 
 
+def _post_order_apply(
+    root_module: nn.Module,
+    target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]],
+    fn_to_apply: Callable[[Any], nn.Module],
+) -> None:
+    """
+    This applies a function ``fn_to_apply`` to some target modules with
+    specified kwargs per target module (via ``target_module_to_kwargs``)
+    following a post-order traversal. This reduces the problem to constructing
+    ``target_module_to_kwargs``.
+
+    NOTE: Since the auto wrap policy is an arg to FSDP, this does not apply the
+    function to the root module; the caller is expected to wrap the root itself.
+    """
+    # Track visited modules to avoid visiting shared modules multiple times
+    visited_modules: Set[nn.Module] = {root_module}
+
+    def _post_order_apply_inner(
+        module: nn.Module,
+        module_name: str,
+        parent_module: Optional[nn.Module],
+    ):
+        for child_module_name, child_module in module.named_children():
+            if child_module not in visited_modules:
+                visited_modules.add(child_module)
+                _post_order_apply_inner(child_module, child_module_name, module)
+        if module in target_module_to_kwargs and module is not root_module:
+            assert module_name != "", (
+                "Non-root modules should have their module name set but got "
+                f"an empty module name for {module}"
+            )
+            kwargs = target_module_to_kwargs[module]
+            new_module = fn_to_apply(module, **kwargs)
+            setattr(parent_module, module_name, new_module)
+
+    _post_order_apply_inner(root_module, "", None)
+
+
+def _run_module_wrap_policy(
+    root_module: nn.Module,
+    module_classes: Iterable[Type[nn.Module]],
+    ignored_modules: Set[nn.Module],
+    fsdp_kwargs: Dict[str, Any],
+) -> Dict[nn.Module, Dict[str, Any]]:
+    """
+    TODO: To match the existing ``ModuleWrapPolicy`` behavior, every wrapped
+    module shares the same FSDP kwargs.
+    """
+    module_classes_tuple = tuple(set(module_classes))
+    target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]] = {}
+    for module in root_module.modules():
+        if module in ignored_modules:
+            continue
+        elif isinstance(module, module_classes_tuple):
+            # Shallow copy to avoid coupling changes across modules
+            target_module_to_kwargs[module] = copy.copy(fsdp_kwargs)
+    return target_module_to_kwargs
+
+
+def _run_mixed_precision_override_policy(
+    root_module: nn.Module,
+    module_classes: Iterable[Type[nn.Module]],
+    ignored_modules: Set[nn.Module],
+    fsdp_kwargs: Dict[str, Any],
+    target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]],
+):
+    module_classes_tuple = tuple(set(module_classes))
+    for module in root_module.modules():
+        if module in ignored_modules:
+            continue
+        elif isinstance(module, module_classes_tuple):
+            # This policy overrides any existing policy
+            # (shallow copy to avoid mutating the caller's kwargs)
+            target_module_to_kwargs[module] = copy.copy(fsdp_kwargs)
+            target_module_to_kwargs[module]["mixed_precision"] = None
+    return target_module_to_kwargs
+
+
 def always_wrap_policy(*args, **kwargs) -> bool:
     """
     A simple recursive wrap policy that always returns ``True``. This means
@@ -96,6 +175,7 @@ def __init__(self, module_classes: Set[Type[nn.Module]]):
             _module_wrap_policy,
             module_classes=module_classes,
         )
+        self._module_classes = module_classes
         self._module_classes_str = str(module_classes)
 
     @property
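
The division of labor introduced in wrap.py (build a module-to-kwargs mapping with the policy runners, then rewrite the tree bottom-up with `_post_order_apply`) can be illustrated with a toy wrapper standing in for `FullyShardedDataParallel`. This sketch is illustrative only and re-implements the traversal locally rather than importing the private helpers:

    import torch.nn as nn


    class Wrapper(nn.Module):
        """Toy stand-in for an FSDP wrapper class."""

        def __init__(self, module: nn.Module, mixed_precision=None):
            super().__init__()
            self.module = module
            self.mixed_precision = mixed_precision

        def forward(self, *args, **kwargs):
            return self.module(*args, **kwargs)


    root = nn.Sequential(
        nn.Linear(4, 4), nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
    )

    # Analogue of `_run_module_wrap_policy`: every `nn.Linear` gets a copy of the
    # root kwargs.
    target_module_to_kwargs = {
        m: {"mixed_precision": "fp16"}
        for m in root.modules()
        if isinstance(m, nn.Linear)
    }
    # Analogue of `_run_mixed_precision_override_policy`: batch norms are added
    # with mixed precision disabled.
    for m in root.modules():
        if isinstance(m, nn.BatchNorm1d):
            target_module_to_kwargs[m] = {"mixed_precision": None}


    # Analogue of `_post_order_apply`: replace children before their parents via
    # `setattr` on the parent, leaving the root itself to the caller.
    def post_order_apply(module, name="", parent=None):
        for child_name, child in list(module.named_children()):
            post_order_apply(child, child_name, module)
        if parent is not None and module in target_module_to_kwargs:
            setattr(parent, name, Wrapper(module, **target_module_to_kwargs[module]))


    post_order_apply(root)
    print(root)  # the Linear and BatchNorm1d submodules are now `Wrapper` instances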