[tp] improve documentation (pytorch#115880)
Improve the formatting and descriptions of the Tensor Parallel (TP) documentation.

Pull Request resolved: pytorch#115880
Approved by: https://github.com/XilunWu
wanchaol authored and ZhiweiYan-96 committed Dec 22, 2023
1 parent 0942718 commit 595e6be
Showing 3 changed files with 29 additions and 25 deletions.
2 changes: 1 addition & 1 deletion docs/source/distributed.tensor.parallel.rst
@@ -32,7 +32,7 @@ Tensor Parallelism supports the following parallel styles:
To simply configure the nn.Module's inputs and outputs with DTensor layouts
and perform necessary layout redistributions, without distributing the module
parameters to DTensors, the following classes can be used in
-the ``parallelize_plan``of ``parallelize_module``:
+the ``parallelize_plan`` of ``parallelize_module``:

.. autoclass:: torch.distributed.tensor.parallel.PrepareModuleInput
:members:
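
For illustration, a minimal sketch of the pattern this paragraph describes, under stated assumptions: a 1-D mesh of size 4 launched via torchrun, and a made-up module with "norm" and "w1" submodules. The norm's parameters stay ordinary tensors; only its input layouts are declared and redistributed, while the linear is sharded column-wise.

import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    parallelize_module,
)

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.norm = nn.LayerNorm(64)
        self.w1 = nn.Linear(64, 256)

    def forward(self, x):
        return self.w1(self.norm(x))

tp_mesh = init_device_mesh("cuda", (4,))
block = parallelize_module(
    Block().cuda(),
    tp_mesh,
    {
        # Annotate inputs only: each rank's batch shard is all-gathered to a
        # replicated local tensor before LayerNorm; no parameters are touched.
        "norm": PrepareModuleInput(
            input_layouts=Shard(0),
            desired_input_layouts=Replicate(),
            use_local_output=True,
        ),
        # Shard the linear column-wise (its output is Shard(-1) by default).
        "w1": ColwiseParallel(),
    },
)
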
18 changes: 9 additions & 9 deletions torch/distributed/tensor/parallel/api.py
@@ -36,8 +36,9 @@ def parallelize_module( # type: ignore[return]
to be parallelized.
User can also specify a different parallel style per module fully qualified name (FQN).
-The API supports 2D parallelism natively by accepting an n-dimension device_mesh
-and users just need to specify the dimension where we perform tensor parallelism on.
+Note that ``parallelize_module`` only accepts a 1-D :class:`DeviceMesh`. If you have a 2-D or N-D :class:`DeviceMesh`,
+slice the DeviceMesh into a 1-D sub DeviceMesh first, then pass it to this API (e.g. ``device_mesh[\"tp\"]``).
Args:
module (:class:`nn.Module`):
@@ -50,9 +51,10 @@ def parallelize_module( # type: ignore[return]
:class:`ParallelStyle` object which contains how
we prepare input/output for Tensor Parallelism or it can be a
dict of module FQN and its corresponding :class:`ParallelStyle` object.
-tp_mesh_dim (int):
+tp_mesh_dim (int, deprecated):
The dimension of ``device_mesh`` where we perform
-Tensor Parallelism on.
+Tensor Parallelism on. This field is deprecated and will be removed in the future.
+If you have a 2-D or N-D :class:`DeviceMesh`, consider passing in ``device_mesh[\"tp\"]``.
Return:
A :class:`nn.Module` object parallelized.
@@ -66,11 +68,9 @@ def parallelize_module( # type: ignore[return]
>>> m = parallelize_module(m, ColwiseParallel())
>>>
-.. warning::
-Currently, there are some constraints which makes it hard for complicated modules
-like ``MultiheadAttention`` to work out of box for Tensor or Sequence Parallelism.
-We recommend users to try ``ColwiseParallel`` and ``RowwiseParallel`` for each parameter
-or submodule and there might be some code changes needed now.
+.. note:: For complex module architectures such as Attention or MLP layers, we recommend composing
+different ParallelStyles together (e.g. ``ColwiseParallel`` and ``RowwiseParallel``) and passing them
+as a parallelize_plan to achieve the desired sharding computation.
"""
torch._C._log_api_usage_once("torch.distributed.tensor.parallel.parallelize_module")

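
As a concrete illustration of the 1-D mesh requirement above, a minimal sketch that is not part of the patch: the mesh shape, the dimension names ("dp", "tp"), and the linear layer are assumptions, and it would be launched with torchrun so that world_size matches the mesh (8 ranks, with each rank's CUDA device set beforehand).

import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

# 2-D mesh: 2 data-parallel groups x 4 tensor-parallel ranks
# (assumes torch.cuda.set_device(local_rank) was already called per rank).
mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))

layer = nn.Linear(256, 1024).cuda()
# parallelize_module only takes a 1-D mesh, so slice out the "tp" dimension
# instead of relying on the deprecated tp_mesh_dim argument.
layer = parallelize_module(layer, mesh_2d["tp"], ColwiseParallel())
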
34 changes: 19 additions & 15 deletions torch/distributed/tensor/parallel/style.py
@@ -43,9 +43,9 @@ class ColwiseParallel(ParallelStyle):
The DTensor layout of the output for the nn.Module, this is used to ensure the output of the nn.Module
has the user-desired layout. If not specified, the output tensor is sharded on the last dimension.
use_local_output (bool, optional):
-Whether to use local :class:`torch.Tensor` instead of `DTensor` for the module output, default: True.
+Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: True.
Returns:
-A ParallelStyle object that represents Colwise sharding of the nn.Module.
+A :class:`ParallelStyle` object that represents Colwise sharding of the nn.Module.
Example::
>>> # xdoctest: +SKIP(failing)
@@ -60,6 +60,10 @@ class ColwiseParallel(ParallelStyle):
>>> parallelize_plan={"w1": ColwiseParallel()},
>>> )
>>> ...
+.. note:: By default, the ``ColwiseParallel`` output is sharded on the last dimension if ``output_layouts`` is not
+specified; if there are operators that require a specific tensor shape (e.g. before the paired ``RowwiseParallel``),
+keep in mind that if the output is sharded, the operator might need to be adjusted to the sharded size.
"""

def __init__(
@@ -68,7 +72,7 @@ def __init__(
input_layouts: Optional[Placement] = None,
output_layouts: Optional[Placement] = None,
use_local_output: bool = True
-) -> None:
+):
super().__init__()
self.input_layouts = (input_layouts or Replicate(), )
self.output_layouts = (output_layouts or Shard(-1), )
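
The note above matters when a colwise-sharded output feeds an operator that expects the full tensor. A minimal sketch, assuming a 1-D mesh of size 4 under torchrun and a made-up output-projection layer, that overrides ``output_layouts`` so the result is gathered back to the full last dimension:

import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import Replicate
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

tp_mesh = init_device_mesh("cuda", (4,))
head = nn.Linear(256, 32000, bias=False).cuda()
# Without output_layouts the module would return each rank's Shard(-1) piece;
# Replicate() requests an all-gather so downstream code sees the full vocab dim.
head = parallelize_module(head, tp_mesh, ColwiseParallel(output_layouts=Replicate()))
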
@@ -146,9 +150,9 @@ class RowwiseParallel(ParallelStyle):
The DTensor layout of the output for the nn.Module, this is used to ensure the output of the nn.Module
has the user-desired layout. If not specified, the output tensor is replicated.
use_local_output (bool, optional):
-Whether to use local :class:`torch.Tensor` instead of `DTensor` for the module output, default: True.
+Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: True.
Returns:
-A ParallelStyle object that represents Rowwise sharding of the nn.Module.
+A :class:`ParallelStyle` object that represents Rowwise sharding of the nn.Module.
Example::
>>> # xdoctest: +SKIP(failing)
@@ -171,7 +175,7 @@ def __init__(
input_layouts: Optional[Placement] = None,
output_layouts: Optional[Placement] = None,
use_local_output: bool = True
-) -> None:
+):
super().__init__()
self.input_layouts = (input_layouts or Shard(-1), )
self.output_layouts = (output_layouts or Replicate(), )
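
``RowwiseParallel`` is most often paired with ``ColwiseParallel`` across back-to-back linears, as the note in api.py above recommends: the colwise layer leaves its output Shard(-1), which is exactly the Shard(-1) input the rowwise layer expects, so the intermediate activation is never gathered and only one all-reduce happens at the end. A hedged sketch, assuming a hypothetical two-linear feed-forward module and a 1-D mesh of size 4 under torchrun:

import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

class FeedForward(nn.Module):
    def __init__(self, dim=256, hidden=1024):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden)
        self.w2 = nn.Linear(hidden, dim)

    def forward(self, x):
        return self.w2(self.w1(x).relu())

tp_mesh = init_device_mesh("cuda", (4,))
ffn = parallelize_module(
    FeedForward().cuda(),
    tp_mesh,
    # Plan keyed by FQN: w1's output stays Shard(-1) and flows straight into w2,
    # whose partial results are all-reduced into a replicated output.
    {"w1": ColwiseParallel(), "w2": RowwiseParallel()},
)
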
@@ -225,20 +229,20 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
class PrepareModuleInput(ParallelStyle):
"""
Configure the nn.Module's inputs to convert the input tensors of the nn.Module to DTensors at runtime according to
-input_layouts, and perform layout redistribution according to the desired_input_layouts.
+``input_layouts``, and perform layout redistribution according to the ``desired_input_layouts``.
Keyword Args:
input_layouts (Union[Placement, Tuple[Placement]]):
The DTensor layouts of input tensors for the nn.Module, this is used to convert the input tensors to
-DTensors. If some inputs are not torch.Tensor or no need to convert to DTensors, `None` need to be specified
+DTensors. If some inputs are not torch.Tensor or do not need to be converted to DTensors, ``None`` needs to be specified
as a placeholder.
desired_input_layouts (Union[Placement, Tuple[Placement]]):
The desired DTensor layout of input tensors for the nn.Module, this is used to ensure the inputs of the nn.Module
-have the desired DTensor layouts. This argument needs to have the same length with `input_layouts`.
+have the desired DTensor layouts. This argument needs to have the same length as ``input_layouts``.
use_local_output (bool, optional):
-Whether to use local :class:`torch.Tensor` instead of `DTensor` for the module inputs, default: False.
+Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module inputs, default: False.
Returns:
-A ParallelStyle object that prepares the sharding layouts of the nn.Module's inputs.
+A :class:`ParallelStyle` object that prepares the sharding layouts of the nn.Module's inputs.
Example::
>>> # xdoctest: +SKIP(failing)
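
A minimal sketch of the ``None`` placeholder described above (not from the patch): the ``Attn`` module and its signature are invented, and the snippet assumes a 1-D mesh of size 4 under torchrun. The first input is converted to a DTensor and redistributed; the second is a plain float and is left untouched.

import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor.parallel import PrepareModuleInput, parallelize_module

class Attn(nn.Module):
    def forward(self, x, scale):
        # x arrives as a replicated DTensor (use_local_output defaults to False);
        # scale is an ordinary Python float.
        return x * scale

tp_mesh = init_device_mesh("cuda", (4,))
attn = parallelize_module(
    Attn(),
    tp_mesh,
    PrepareModuleInput(
        input_layouts=(Shard(0), None),            # None: leave the 2nd input alone
        desired_input_layouts=(Replicate(), None),
    ),
)
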
@@ -298,18 +302,18 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
class PrepareModuleOutput(ParallelStyle):
"""
Configure the nn.Module's outputs to convert the output tensors of the nn.Module to DTensors at runtime according to
-output_layouts, and perform layout redistribution according to the desired_output_layouts.
+``output_layouts``, and perform layout redistribution according to the ``desired_output_layouts``.
Keyword Args:
output_layouts (Union[Placement, Tuple[Placement]]):
The DTensor layouts of output tensors for the nn.Module, this is used to convert the output tensors to
-DTensors if they are `torch.Tensor`. If some outputs are not torch.Tensor or no need to convert to DTensors,
-`None` need to be specified as a placeholder.
+DTensors if they are :class:`torch.Tensor`. If some outputs are not torch.Tensor or do not need to be converted to DTensors,
+``None`` needs to be specified as a placeholder.
desired_output_layouts (Union[Placement, Tuple[Placement]]):
The desired DTensor layouts of output tensors for the nn.Module, this is used to ensure the outputs of the nn.Module
have the desired DTensor layouts.
use_local_output (bool, optional):
-Whether to use local :class:`torch.Tensor` instead of `DTensor` for the module outputs, default: False.
+Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module outputs, default: False.
Returns:
A ParallelStyle object that prepares the sharding layouts of the nn.Module's outputs.
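
To round this out, a hedged sketch of ``PrepareModuleOutput`` under stated assumptions (the module, mesh size, and layouts are made up; run under torchrun with 4 ranks): each rank's raw output is declared to be the Shard(-1) piece of a logical tensor, redistributed to Replicate(), and handed back as a plain local tensor. The layouts here only illustrate the conversion mechanics; in practice each rank's output would be a genuinely different slice.

import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor.parallel import PrepareModuleOutput, parallelize_module

tp_mesh = init_device_mesh("cuda", (4,))
proj = nn.Linear(128, 64).cuda()   # parameters are left unsharded
proj = parallelize_module(
    proj,
    tp_mesh,
    PrepareModuleOutput(
        output_layouts=Shard(-1),            # treat each rank's output as one shard
        desired_output_layouts=Replicate(),  # all-gather along the last dimension
        use_local_output=True,               # return a plain torch.Tensor
    ),
)
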
