Skip to content

Commit

Permalink
[FSDP][Easy] Allow ModuleWrapPolicy to take Iterable
Browse files Browse the repository at this point in the history
ghstack-source-id: 6e56f5c32da86336e4a943a49f237d448c532cad
Pull Request resolved: pytorch#104999
  • Loading branch information
awgu committed Jul 21, 2023
1 parent 76ebc08 commit 0159d77
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions torch/distributed/fsdp/wrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,13 +198,14 @@ def _module_wrap_policy(
class ModuleWrapPolicy(_FSDPPolicy):
"""This is a wrapper around :func:`_module_wrap_policy`."""

def __init__(self, module_classes: Set[Type[nn.Module]]):
def __init__(self, module_classes: Iterable[Type[nn.Module]]):
module_classes_set = set(module_classes)
self._policy: Callable = functools.partial(
_module_wrap_policy,
module_classes=module_classes,
module_classes=module_classes_set,
)
self._module_classes = module_classes
self._module_classes_str = str(module_classes)
self._module_classes = module_classes_set
self._module_classes_str = str(module_classes_set)

@property
def policy(self):
Expand Down

0 comments on commit 0159d77

Please sign in to comment.