[DTensor] min, max and prod sharding propagation rules (pytorch#112403)
* `torch/distributed/_tensor/ops/math_ops.py` and `test/distributed/_tensor/test_math_ops.py`: add min, max and prod sharding propagation rules
* `torch/distributed/_tensor/sharding_prop.py`: validate the `OutputSpec` to provide better errors when invalid specs are provided
* `torch/distributed/_tensor/op_schema.py`: import `OpOverload` directly to aid linters

Pull Request resolved: pytorch#112403
Approved by: https://github.com/wanchaol
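For illustration only (not part of this commit), a minimal sketch of what these rules enable, assuming a 4-rank process group (e.g. launched with `torchrun`) and a 1-D device mesh built with the public `distribute_tensor` API:

```python
# Illustrative sketch only (not part of this diff): exercising the newly
# supported reductions on a sharded DTensor. Assumes a 4-process group
# (e.g. launched with torchrun) and a 1-D device mesh.
import torch
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor

mesh = DeviceMesh("cpu", list(range(4)))               # 1-D mesh over 4 ranks
tensor = torch.randn(12, 8, 8)
dtensor = distribute_tensor(tensor, mesh, [Shard(0)])  # shard dim 0 across ranks

# With this change, min/max/prod propagate sharding the same way sum already did:
global_max = dtensor.max().full_tensor()               # full reduction
values, indices = dtensor.min(dim=1)                   # dim reductions return (values, indices)
prod_keepdim = dtensor.prod(dim=0, keepdim=True).full_tensor()
```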
joshlk authored and Skylion007 committed Nov 14, 2023
1 parent b1c49cd commit 77e3b5b
Showing 4 changed files with 107 additions and 57 deletions.
44 changes: 30 additions & 14 deletions test/distributed/_tensor/test_math_ops.py
@@ -16,27 +16,43 @@


 class DistMathOpsTest(DTensorTestBase):
-    @with_comms
-    def test_sum(self):
+    def linear_op_reductions(self, op_str):
         device_mesh = self.build_device_mesh()
 
         shard_spec = [Shard(0)]
 
-        tensor_to_sum = torch.randn(12, 8, 8)
+        tensor = torch.randn(12, 8, 8)
+        dtensor = distribute_tensor(tensor, device_mesh, shard_spec)
 
-        mat1 = distribute_tensor(tensor_to_sum, device_mesh, shard_spec)
+        op = getattr(tensor, op_str)
+        op_dt = getattr(dtensor, op_str)
 
         keep_dim_or_not = [True, False, None]
-        for dim in range(tensor_to_sum.ndim):
+        for dim in range(tensor.ndim):
             for keep_dim in keep_dim_or_not:
-                sum_args = (dim, keep_dim) if keep_dim is not None else (dim,)
-                dim_sumed_tensor = tensor_to_sum.sum(*sum_args)
-                dt_dim_sumed_tensor = mat1.sum(*sum_args).full_tensor()
-                self.assertEqual(dt_dim_sumed_tensor, dim_sumed_tensor)
-
-        full_sumed_tensor = tensor_to_sum.sum()
-        dt_sum = mat1.sum().full_tensor()
-        self.assertEqual(dt_sum, full_sumed_tensor)
+                args = (dim, keep_dim) if keep_dim is not None else (dim,)
+                if op_str in ("max", "min"):
+                    # min and max return a tuple when dim specified
+                    dim_reduced_tensor, _ = op(*args)
+                    dt_reduced, _ = op_dt(*args)
+                else:
+                    dim_reduced_tensor = op(*args)
+                    dt_reduced = op_dt(*args)
+                dt_dim_reduced_tensor = dt_reduced.full_tensor()
+                self.assertEqual(dt_dim_reduced_tensor, dim_reduced_tensor)
+
+        full_reduced_tensor = op()
+        dt_full_reduced = op_dt().full_tensor()
+        self.assertEqual(dt_full_reduced, full_reduced_tensor)
+
+    @with_comms
+    def test_linear_op_reductions(self):
+        for op_str in ("all", "sum", "prod", "max", "min"):
+            self.linear_op_reductions(op_str)
+
+    @with_comms
+    @skip_unless_torch_gpu
+    def test_mean(self):
+        self.linear_op_reductions("mean")
 
     # TODO: forward test can be removed once test_softmax_with_bwd passes on CPU
     @with_comms
7 changes: 4 additions & 3 deletions torch/distributed/_tensor/op_schema.py
@@ -2,6 +2,7 @@
 from typing import Dict, List, Optional, Sequence, Tuple, Union
 
 import torch
+from torch._ops import OpOverload
 from torch.distributed._tensor.device_mesh import DeviceMesh
 from torch.distributed._tensor.placement_types import DTensorSpec
 
@@ -34,14 +35,14 @@ def _rebuild_tensor_from_dtensor_meta(arg) -> object:
     )
 
 
-def _is_inplace_op(op: torch._ops.OpOverload):
+def _is_inplace_op(op: OpOverload):
     # simple analysis of function schema to determine
     # if this is an inplace variant, it might not
     # be entirely correct, but it's good enough for now.
     return op._schema.name[-1] == "_"
 
 
-def _is_out_variant_op(op: torch._ops.OpOverload):
+def _is_out_variant_op(op: OpOverload):
     # simple analysis of function schema to determine
     # if this is an out variant, it might not
     # be entirely correct, but it's good enough for now.
@@ -185,7 +186,7 @@ class OpSchema:
             with its DTensorSpec
     """
 
-    op: torch._ops.OpOverload
+    op: OpOverload
     args_schema: ArgsType
     kwargs_schema: KwargsType
 
43 changes: 22 additions & 21 deletions torch/distributed/_tensor/ops/math_ops.py
@@ -155,30 +155,30 @@ def common_reduction_strategy(
     return reduction_strategy
 
 
-@register_op_strategy(
-    [aten.all.default, aten.sum.default, aten.sum.dim_IntList],
-    schema_info=RuntimeSchemaInfo(1),
-)
-def default_reduction_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
-    args_schema = op_schema.args_schema
-    input_strategy = args_schema[0]
-    assert isinstance(input_strategy, OpStrategy)
-    dims = None
-    if len(op_schema.args_schema) > 1:
-        dims = _infer_reduction_dims(args_schema[1], input_strategy.output_ndim)
-
-    reduce_dims = list(range(input_strategy.output_ndim)) if dims is None else dims
-
-    keep_dim = len(op_schema.args_schema) > 2 and bool(op_schema.args_schema[2])
-    return common_reduction_strategy(
-        mesh, input_strategy, reduce_dims, keep_dim=keep_dim, reduction_linear=True
-    )
+LINEAR_REDUCTION_OP_MAP = {
+    aten.all.default: c10d.ReduceOp.SUM,
+    aten.all.dim: c10d.ReduceOp.SUM,
+    aten.sum.default: c10d.ReduceOp.SUM,
+    aten.sum.dim_IntList: c10d.ReduceOp.SUM,
+    aten.prod.default: c10d.ReduceOp.PRODUCT,
+    aten.prod.dim_int: c10d.ReduceOp.PRODUCT,
+    aten.prod.int_out: c10d.ReduceOp.PRODUCT,
+    aten.mean.default: c10d.ReduceOp.AVG,
+    aten.mean.dim: c10d.ReduceOp.AVG,
+    aten.mean.out: c10d.ReduceOp.AVG,
+    aten.max.default: c10d.ReduceOp.MAX,
+    aten.max.dim: c10d.ReduceOp.MAX,
+    aten.max.out: c10d.ReduceOp.MAX,
+    aten.min.default: c10d.ReduceOp.MIN,
+    aten.min.dim: c10d.ReduceOp.MIN,
+    aten.min.out: c10d.ReduceOp.MIN,
+}
 
 
 @register_op_strategy(
-    [aten.mean.default, aten.mean.dim, aten.mean.out], schema_info=RuntimeSchemaInfo(1)
+    list(LINEAR_REDUCTION_OP_MAP.keys()), schema_info=RuntimeSchemaInfo(1)
 )
-def mean_reduction_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
+def linear_reduction_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
     args_schema = op_schema.args_schema
     input_strategy = args_schema[0]
     assert isinstance(input_strategy, OpStrategy)
@@ -189,13 +189,14 @@ def mean_reduction_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy
     reduce_dims = list(range(input_strategy.output_ndim)) if dims is None else dims
 
     keep_dim = len(op_schema.args_schema) > 2 and bool(op_schema.args_schema[2])
+    reduction_op = LINEAR_REDUCTION_OP_MAP[op_schema.op]
     return common_reduction_strategy(
         mesh,
         input_strategy,
         reduce_dims,
         keep_dim=keep_dim,
         reduction_linear=True,
-        reduction_op=c10d.ReduceOp.AVG,
+        reduction_op=reduction_op,
     )
 
 
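Why a single `ReduceOp` per entry is enough: these reductions are "linear", meaning each rank can reduce its local shard first and then combine the per-rank partial results with the mapped collective, giving the same answer as reducing the unsharded tensor. A rough sketch of that equivalence using plain collectives (illustrative only; DTensor performs the equivalent internally when it resolves the resulting partial placement):

```python
# Illustrative sketch (not part of this diff): the "reduction_linear" property
# that LINEAR_REDUCTION_OP_MAP relies on, written with plain collectives.
import torch
import torch.distributed as dist

def sharded_max(local_shard: torch.Tensor) -> torch.Tensor:
    # Step 1: every rank reduces its own shard.
    partial = local_shard.max()
    # Step 2: combine the per-rank partial results with the mapped ReduceOp
    # (aten.max.* -> ReduceOp.MAX in the table above); sum and prod follow the
    # same pattern with SUM and PRODUCT.
    dist.all_reduce(partial, op=dist.ReduceOp.MAX)
    return partial  # equal to the unsharded tensor's global max on every rank
```

Collapsing the previous `default_reduction_strategy`/`mean_reduction_strategy` pair into the table-driven `linear_reduction_strategy` keeps the op-to-`ReduceOp` mapping in one place, so supporting another linear reduction only requires a new table entry.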
70 changes: 51 additions & 19 deletions torch/distributed/_tensor/sharding_prop.py
@@ -1,6 +1,6 @@
 from functools import lru_cache
 from itertools import chain
-from typing import Callable, cast, Dict, List, Optional, Sequence
+from typing import Callable, cast, Dict, List, Optional, Sequence, Tuple, Union
 
 import torch
 from torch._ops import OpOverload
@@ -23,6 +23,14 @@
 aten = torch.ops.aten
 
 
+def _length(obj) -> int:
+    if obj is None:
+        return 0
+    if not isinstance(obj, Sequence):
+        return 1
+    return len(obj)
+
+
 class ShardingPropagator:
     def __init__(self) -> None:
         self.op_to_rules: Dict[OpOverload, Callable[[OpSchema], OutputSharding]] = {}
@@ -60,7 +68,9 @@ def register_op_strategy(
         if schema_info is not None:
             self.op_to_schema_info[op_overload] = schema_info
 
-    def _propagate_tensor_meta(self, op_schema: OpSchema) -> object:
+    def _propagate_tensor_meta(
+        self, op_schema: OpSchema
+    ) -> Union[None, TensorMeta, List[TensorMeta], Tuple[TensorMeta, ...]]:
         """
         Propagate the tensor metadata, it could either return a TensorMeta
         or a list/tuple of TensorMetas
return None

def _wrap_output_spec_tensor_meta(
self, output_spec: OutputSpecType, output_tensor_meta: object
self,
op: OpOverload,
output_spec: OutputSpecType,
output_tensor_meta: Union[
None, TensorMeta, List[TensorMeta], Tuple[TensorMeta, ...]
],
) -> None:
"""
Wrap the output_spec with the tensor metadata from the output.
"""
if output_spec is not None:
if isinstance(output_spec, DTensorSpec):
assert isinstance(output_tensor_meta, TensorMeta)
output_spec.tensor_meta = output_tensor_meta
elif isinstance(output_spec, (tuple, list)):
for i, spec in enumerate(output_spec):
if isinstance(spec, DTensorSpec):
assert isinstance(output_tensor_meta, (tuple, list))
output_tensor_meta_i = output_tensor_meta[i]
assert isinstance(output_tensor_meta_i, TensorMeta)
spec.tensor_meta = output_tensor_meta_i

if isinstance(output_spec, DTensorSpec):
if not isinstance(output_tensor_meta, TensorMeta):
# Either error due to ShardingPropagator or due to incorrect OutputSpec
if not isinstance(output_tensor_meta, (tuple, list)):
raise ValueError(
"ShardingPropagator error: output does not have an associated TensorMeta"
)
raise ValueError(
f"For the op {op.name()}, `output_spec` has 1 output which does not equal the "
f"number of op outputs: {len(output_tensor_meta)}."
)
output_spec.tensor_meta = output_tensor_meta
elif isinstance(output_spec, (tuple, list)):
if not isinstance(output_tensor_meta, (tuple, list)) or len(
output_spec
) != len(output_tensor_meta):
raise ValueError(
f"For the op {op.name()}, `output_spec` has {len(output_spec)} outputs which does not equal the "
f"number of op outputs {_length(output_tensor_meta)}."
)
for i, spec in enumerate(output_spec):
if isinstance(spec, DTensorSpec):
output_tensor_meta_i = output_tensor_meta[i]
if not isinstance(output_tensor_meta_i, TensorMeta):
raise ValueError(
f"ShardingPropagator error: output {i} does not have an associated TensorMeta"
)
spec.tensor_meta = output_tensor_meta_i

def propagate(self, op_info: OpInfo) -> None:
# We cannot use an lru cache if we know that inputs will have dynamic shapes,
@@ -274,10 +307,9 @@ def spec_to_strategy(spec: object) -> object:
                 raise ValueError("Unsupported op strategy type")
 
             # associate the output sharding with the output tensor metadata
-            if out_tensor_meta is not None:
-                self._wrap_output_spec_tensor_meta(
-                    output_sharding.output_spec, out_tensor_meta
-                )
+            self._wrap_output_spec_tensor_meta(
+                op_schema.op, output_sharding.output_spec, out_tensor_meta
+            )
             return output_sharding
         elif op_schema.op in self.op_to_rules:
             # propagate the sharding with rule
@@ -322,7 +354,7 @@ def spec_to_strategy(spec: object) -> object:

             # associate the output sharding with the output tensor metadata
             self._wrap_output_spec_tensor_meta(
-                output_sharding.output_spec, out_tensor_meta
+                op_schema.op, output_sharding.output_spec, out_tensor_meta
             )
 
             return output_sharding
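For context on why the arity check matters: some of the newly covered overloads produce more than one output, so `output_tensor_meta` arrives as a tuple and a mismatched `output_spec` is now reported with a descriptive `ValueError` rather than a bare assertion failure. A quick illustration of the multi-output case (illustrative only, not part of this diff):

```python
# Illustrative only: aten.min.dim has two outputs (values, indices), so the
# sharding propagator must carry one TensorMeta per output and check that the
# number of output specs matches.
import torch

t = torch.randn(4, 3)
values, indices = torch.ops.aten.min.dim(t, 0)  # two outputs -> two TensorMetas
print(values.shape, indices.shape)              # torch.Size([3]) torch.Size([3])
```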

