[FSDP] Introduce ModuleWrapPolicy for simplicity
ghstack-source-id: f4f7e36fcf30155219d94f14ea6bf185fdb9ac41
Pull Request resolved: pytorch#88450
awgu committed Nov 12, 2022
1 parent 7aa144a commit 7534e8c
Showing 15 changed files with 244 additions and 356 deletions.
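The change in a nutshell: everywhere the tests previously built an `auto_wrap_policy` with `functools.partial(transformer_auto_wrap_policy, transformer_layer_cls=...)`, they now construct a `ModuleWrapPolicy` from the same set of module classes. A minimal before/after sketch follows; the FSDP call is commented out because it assumes an initialized process group, and `nn.Transformer` is only a stand-in model:

```python
import functools

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import ModuleWrapPolicy, transformer_auto_wrap_policy
from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer

# Old style: partially apply the callable policy with the target layer classes.
old_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={TransformerEncoderLayer, TransformerDecoderLayer},
)

# New style: a declarative policy object built from the same module classes.
new_policy = ModuleWrapPolicy({TransformerEncoderLayer, TransformerDecoderLayer})

# Either form is accepted as `auto_wrap_policy` by the FSDP wrapper class
# (both are exercised by the tests below). Requires an initialized process group:
# model = nn.Transformer(d_model=32, nhead=4)
# fsdp_model = FSDP(model, auto_wrap_policy=new_policy, use_orig_params=True)
```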
27 changes: 12 additions & 15 deletions test/distributed/_composable/test_fully_shard.py
@@ -1,7 +1,6 @@
# Owner(s): ["oncall: distributed"]

import copy
import functools
import sys
from typing import Any, Tuple

@@ -12,7 +11,7 @@
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._common_utils import _is_fsdp_flattened
from torch.distributed.fsdp._runtime_utils import _root_pre_forward
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import (
@@ -62,10 +61,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return z

@staticmethod
def auto_wrap_policy():
return functools.partial(
transformer_auto_wrap_policy, transformer_layer_cls={SubModel}
)
def policy():
return ModuleWrapPolicy({SubModel})

def get_input(self, device=torch.device) -> Tuple[Any, ...]:
return (torch.randn((8, 5), device=device),)
@@ -85,13 +82,13 @@ def test_auto_wrap_policy(self):
local_model = Model(device=torch.device("cuda"))
fsdp_wrapped_model = FSDP(
copy.deepcopy(local_model),
auto_wrap_policy=Model.auto_wrap_policy(),
auto_wrap_policy=Model.policy(),
use_orig_params=True,
)
composable_module = copy.deepcopy(local_model)
fully_shard(
composable_module,
auto_wrap_policy=Model.auto_wrap_policy(),
policy=Model.policy(),
)

# Check that the composable module has the same names as the local
@@ -138,7 +135,7 @@ def test_device_id(self):
assert param.device == cpu_device
fully_shard(
composable_module,
auto_wrap_policy=Model.auto_wrap_policy(),
policy=Model.policy(),
device_id=self.rank,
)
for param in composable_module.parameters():
@@ -157,12 +154,12 @@ def test_sync_module_states(self):
param.zero_()
fsdp_wrapped_model = FSDP(
copy.deepcopy(local_model),
auto_wrap_policy=Model.auto_wrap_policy(),
auto_wrap_policy=Model.policy(),
use_orig_params=True,
)
fully_shard(
composable_module,
auto_wrap_policy=Model.auto_wrap_policy(),
policy=Model.policy(),
sync_module_states=True,
)
for (composable_param, fsdp_wrapped_param) in zip(
@@ -197,13 +194,13 @@ def _param_init_fn(module: nn.Module):
composable_module = Model(device="meta")
fsdp_wrapped_model = FSDP(
Model(device="meta"),
auto_wrap_policy=Model.auto_wrap_policy(),
auto_wrap_policy=Model.policy(),
param_init_fn=_param_init_fn,
use_orig_params=True,
)
fully_shard(
composable_module,
auto_wrap_policy=Model.auto_wrap_policy(),
policy=Model.policy(),
param_init_fn=_param_init_fn,
)
for (composable_param, fsdp_wrapped_param) in zip(
@@ -227,13 +224,13 @@ def test_training(self):
local_model = Model(device=device)
fsdp_wrapped_model = FSDP(
copy.deepcopy(local_model),
auto_wrap_policy=Model.auto_wrap_policy(),
auto_wrap_policy=Model.policy(),
use_orig_params=True,
)
composable_module = copy.deepcopy(local_model)
fully_shard(
composable_module,
auto_wrap_policy=Model.auto_wrap_policy(),
policy=Model.policy(),
)
del local_model # not needed anymore
LR = 1e-2
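For the composable API, the keyword changes as well: the wrapper class `FSDP(...)` keeps `auto_wrap_policy=`, while `fully_shard(...)` now takes `policy=`. A condensed sketch of the pattern these tests follow; `Block` is a hypothetical stand-in for the test's `Model`/`SubModel`, the `fully_shard` import path is taken from the file changed further below, and the sharding calls are commented out since they need an initialized process group and CUDA:

```python
import copy

import torch
import torch.nn as nn
from torch.distributed._composable.fully_shard import fully_shard
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import ModuleWrapPolicy


class Block(nn.Module):  # hypothetical stand-in for the test's SubModel
    def __init__(self) -> None:
        super().__init__()
        self.lin = nn.Linear(5, 5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.lin(x)


local_model = nn.Sequential(Block(), Block())
policy = ModuleWrapPolicy({Block})

# Wrapper API: unchanged keyword.
# fsdp_wrapped_model = FSDP(
#     copy.deepcopy(local_model), auto_wrap_policy=policy, use_orig_params=True
# )
# Composable API: `auto_wrap_policy=` becomes `policy=`.
# composable_module = copy.deepcopy(local_model)
# fully_shard(composable_module, policy=policy)
```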
10 changes: 4 additions & 6 deletions test/distributed/fsdp/test_fsdp_clip_grad_norm.py
@@ -1,6 +1,5 @@
# Owner(s): ["oncall: distributed"]

import functools
import itertools
import sys
from typing import Union
@@ -12,7 +11,7 @@
CPUOffload,
FullyShardedDataParallel as FSDP,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
@@ -102,12 +101,11 @@ def _test_ddp_parity(
)
ddp_model = DDP(local_model, device_ids=[self.rank])
fsdp_kwargs = {
"auto_wrap_policy": functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls={
"auto_wrap_policy": ModuleWrapPolicy(
{
TransformerEncoderLayer,
TransformerDecoderLayer,
},
}
),
"cpu_offload": CPUOffload(offload_params=offload_params),
"use_orig_params": use_orig_params,
22 changes: 18 additions & 4 deletions test/distributed/fsdp/test_fsdp_misc.py
@@ -15,7 +15,11 @@
FullyShardedDataParallel as FSDP,
ShardingStrategy,
)
from torch.distributed.fsdp.wrap import always_wrap_policy, transformer_auto_wrap_policy
from torch.distributed.fsdp.wrap import (
always_wrap_policy,
ModuleWrapPolicy,
transformer_auto_wrap_policy,
)
from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
@@ -211,10 +215,20 @@ def forward(self, x, y):
def test_device_id_auto_wrap(self):
"""Tests that ``auto_wrap_policy`` propagates ``device_id`` to all
nested FSDP instances."""
auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls={TransformerEncoderLayer, TransformerDecoderLayer},
self.run_subtests(
{"use_callable": [False, True]},
self._test_device_id_auto_wrap,
)

def _test_device_id_auto_wrap(self, use_callable: bool):
module_classes = {TransformerEncoderLayer, TransformerDecoderLayer}
if use_callable:
auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls=module_classes,
)
else:
auto_wrap_policy = ModuleWrapPolicy(module_classes)
fsdp_kwargs = {
"auto_wrap_policy": auto_wrap_policy,
"device_id": torch.cuda.current_device(),
12 changes: 5 additions & 7 deletions test/distributed/fsdp/test_fsdp_state_dict.py
@@ -26,7 +26,7 @@
)
from torch.distributed.fsdp._shard_utils import _gather_state_dict
from torch.distributed.fsdp._unshard_param_utils import FLAT_PARAM
from torch.distributed.fsdp.wrap import enable_wrap, transformer_auto_wrap_policy, wrap
from torch.distributed.fsdp.wrap import enable_wrap, ModuleWrapPolicy, wrap
from torch.nn import Linear, Module, TransformerDecoderLayer, TransformerEncoderLayer
from torch.nn.parallel import DistributedDataParallel
from torch.optim import SGD
@@ -350,9 +350,8 @@ def test_state_dict_with_manual_ac_wrapper(
@skip_if_lt_x_gpu(2)
@parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS)
def test_state_dict_with_shared_parameters(self, state_dict_type):
auto_wrap_policy = partial(
transformer_auto_wrap_policy,
transformer_layer_cls={TransformerEncoderLayer, TransformerDecoderLayer},
auto_wrap_policy = ModuleWrapPolicy(
{TransformerEncoderLayer, TransformerDecoderLayer}
)
model_creator = partial(
TransformerWithSharedParams.init,
@@ -377,9 +376,8 @@ def test_state_dict_rank0_offload_save_load_flow(self, use_orig_params: bool):
"""Tests saving a model checkpoint only on rank 0 and loading it only
on rank 0 with ``sync_module_states=True`` to emulate the workflow to
avoid redundant CPU memory usage."""
auto_wrap_policy = partial(
transformer_auto_wrap_policy,
transformer_layer_cls={TransformerEncoderLayer, TransformerDecoderLayer},
auto_wrap_policy = ModuleWrapPolicy(
{TransformerEncoderLayer, TransformerDecoderLayer}
)
fsdp_kwargs = {
"auto_wrap_policy": auto_wrap_policy,
9 changes: 4 additions & 5 deletions test/distributed/fsdp/test_fsdp_use_orig_params.py
@@ -15,7 +15,7 @@
ShardingStrategy,
)
from torch.distributed.fsdp._common_utils import clean_tensor_name
from torch.distributed.fsdp.wrap import always_wrap_policy, transformer_auto_wrap_policy
from torch.distributed.fsdp.wrap import always_wrap_policy, ModuleWrapPolicy
from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
@@ -117,12 +117,11 @@ def _get_fsdp_transformer_and_optim(
# combination with the parameter group construction, ensures different
# hyperparameter settings within one `FlatParameter`
fsdp_kwargs = {
"auto_wrap_policy": functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls={
"auto_wrap_policy": ModuleWrapPolicy(
{
TransformerEncoderLayer,
TransformerDecoderLayer,
},
}
),
"use_orig_params": True,
"sharding_strategy": sharding_strategy,
7 changes: 2 additions & 5 deletions test/distributed/fsdp/test_utils.py
@@ -1,6 +1,5 @@
# Owner(s): ["oncall: distributed"]

import functools
import random
import sys
import unittest
@@ -14,7 +13,7 @@
from torch import distributed as dist
from torch.distributed.fsdp._utils import _apply_to_tensors
from torch.distributed.fsdp._wrap_utils import _get_submodule_to_states
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.distributed.utils import _replace_by_prefix
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
@@ -173,9 +172,7 @@ def test_module_wrap_policy(self):
# Compute the mapping from submodule to states according to a logical
# module wrap policy
module_classes = (nn.Sequential,)
auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy, transformer_layer_cls=set(module_classes)
)
auto_wrap_policy = ModuleWrapPolicy(set(module_classes))
submodule_to_states = _get_submodule_to_states(
model, auto_wrap_policy, set(), set()
)
16 changes: 16 additions & 0 deletions test/distributed/fsdp/test_wrap.py
@@ -5,6 +5,7 @@
import tempfile
import unittest
from enum import auto, Enum
from typing import Callable, Union

import torch
import torch.nn as nn
@@ -15,10 +16,12 @@
FullyShardedDataParallel as FSDP,
)
from torch.distributed.fsdp.wrap import (
_FSDPPolicy,
_or_policy,
_wrap_batchnorm_individually,
always_wrap_policy,
enable_wrap,
ModuleWrapPolicy,
size_based_auto_wrap_policy,
transformer_auto_wrap_policy,
wrap,
@@ -373,6 +376,19 @@ def test_transformer_auto_wrap_policy(self):
transformer_auto_wrap_policy,
transformer_layer_cls={TransformerEncoderLayer, TransformerDecoderLayer},
)
self._test_transformer_wrapping(auto_wrap_policy)

@unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs")
def test_module_wrap_policy(self):
"""Tests the ``ModuleWrapPolicy``."""
auto_wrap_policy = ModuleWrapPolicy(
{TransformerEncoderLayer, TransformerDecoderLayer}
)
self._test_transformer_wrapping(auto_wrap_policy)

def _test_transformer_wrapping(
self, auto_wrap_policy: Union[Callable, _FSDPPolicy]
):
fsdp_kwargs = {"auto_wrap_policy": auto_wrap_policy}
fsdp_model = TransformerWithSharedParams.init(
self.process_group,
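The new `test_module_wrap_policy` passes a `ModuleWrapPolicy` everywhere a bare callable used to go, and `_test_transformer_wrapping` accordingly accepts `Union[Callable, _FSDPPolicy]`. The actual definitions of `_FSDPPolicy` and `ModuleWrapPolicy` live in `torch/distributed/fsdp/wrap.py`, whose diff is not expanded on this page, so the following is only a rough sketch of the shape such a policy class plausibly takes — a thin declarative wrapper that stores the module classes and exposes an equivalent callable; the names inside are illustrative assumptions, not the real implementation:

```python
import functools
from typing import Callable, Set, Type

import torch.nn as nn
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


class SketchModuleWrapPolicy:  # illustrative only; not the real class
    """Wraps a set of module classes and exposes a callable wrap policy."""

    def __init__(self, module_classes: Set[Type[nn.Module]]) -> None:
        self._module_classes = module_classes
        # Reuse the existing callable-style policy under the hood.
        self._policy: Callable = functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls=module_classes,
        )

    @property
    def policy(self) -> Callable:
        # FSDP internals would pull the underlying callable from here.
        return self._policy
```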
8 changes: 6 additions & 2 deletions torch/distributed/_composable/fully_shard.py
@@ -24,6 +24,7 @@
MixedPrecision,
ShardingStrategy,
)
from torch.distributed.fsdp.wrap import _FSDPPolicy


@contract
@@ -32,7 +33,7 @@ def fully_shard(
process_group: Optional[dist.ProcessGroup] = None,
mixed_precision: Optional[MixedPrecision] = None,
cpu_offload: Optional[CPUOffload] = None,
auto_wrap_policy: Optional[Callable] = None,
policy: Optional[_FSDPPolicy] = None,
ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
device_id: Optional[Union[int, torch.device]] = None,
param_init_fn: Optional[Callable[[nn.Module], None]] = None,
@@ -41,6 +42,9 @@
"""
Applies ``FullyShardedDataParallel`` (FSDP) semantics to ``module``.
"""
# Enforce the new auto wrap policy
if policy is not None and not isinstance(policy, _FSDPPolicy):
raise ValueError(f"Expects an `_FSDPPolicy` but got {policy}")
state = fully_shard.state(module)
state = _init_ignored_module_states(state, module, ignored_modules)
state = _init_process_group_state(state, process_group)
@@ -64,7 +68,7 @@
state = _init_param_handles_from_module(
state,
module,
auto_wrap_policy,
policy,
device_id,
param_init_fn,
sync_module_states,
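With the rename, the composable `fully_shard` also validates its argument: a legacy callable policy is rejected up front with the `ValueError` added above, and only `_FSDPPolicy` instances (such as `ModuleWrapPolicy`) are forwarded to `_init_param_handles_from_module`. A small sketch of what that looks like at the call site; `module` and the process group are assumed to be set up already:

```python
import functools

from torch.distributed._composable.fully_shard import fully_shard
from torch.distributed.fsdp.wrap import ModuleWrapPolicy, transformer_auto_wrap_policy
from torch.nn import TransformerEncoderLayer

legacy_callable = functools.partial(
    transformer_auto_wrap_policy, transformer_layer_cls={TransformerEncoderLayer}
)

# fully_shard(module, policy=legacy_callable)
# -> ValueError: Expects an `_FSDPPolicy` but got <functools.partial ...>

# fully_shard(module, policy=ModuleWrapPolicy({TransformerEncoderLayer}))
# -> OK: the policy object is passed through to `_init_param_handles_from_module`.
```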
1 change: 0 additions & 1 deletion torch/distributed/fsdp/__init__.py
@@ -11,4 +11,3 @@
ShardingStrategy,
StateDictType,
)
from .wrap import ParamExecOrderWrapPolicy
5 changes: 3 additions & 2 deletions torch/distributed/fsdp/_init_utils.py
@@ -47,6 +47,7 @@
HandleConfig,
HandleShardingStrategy,
)
from torch.distributed.fsdp.wrap import _FSDPPolicy
from torch.distributed.utils import _sync_params_and_buffers
from torch.utils.hooks import RemovableHandle

@@ -262,7 +263,7 @@ def _init_param_handle_from_module(
def _init_param_handles_from_module(
state: _FSDPState,
root_module: nn.Module,
auto_wrap_policy: Callable,
policy: _FSDPPolicy,
device_id: Optional[Union[int, torch.device]],
param_init_fn: Optional[Callable[[nn.Module], None]],
sync_module_states: bool,
@@ -273,7 +274,7 @@ def _init_param_handles_from_module(
"""
submodule_to_states = _get_submodule_to_states(
root_module,
auto_wrap_policy,
policy,
state._ignored_modules,
state._ignored_params,
)
(Diffs for the remaining 5 changed files, including torch/distributed/fsdp/wrap.py, are not shown here.)
