From add1f013741cc9576084d23100b0a68c0096223f Mon Sep 17 00:00:00 2001 From: liuzhenhai93 Date: Tue, 24 Oct 2023 11:34:45 +0000 Subject: [PATCH 01/25] polish --- paddle/phi/infermeta/spmd_rules/rules.h | 11 ++ .../semi_auto_parallel_for_matmul.py | 183 ++++++------------ .../semi_auto_parallel_for_reduction.py | 56 +++--- test/auto_parallel/semi_auto_parallel_util.py | 124 ++++++++++++ 4 files changed, 226 insertions(+), 148 deletions(-) create mode 100644 test/auto_parallel/semi_auto_parallel_util.py diff --git a/paddle/phi/infermeta/spmd_rules/rules.h b/paddle/phi/infermeta/spmd_rules/rules.h index eda61be1f2284..673808ef4b64e 100644 --- a/paddle/phi/infermeta/spmd_rules/rules.h +++ b/paddle/phi/infermeta/spmd_rules/rules.h @@ -435,6 +435,11 @@ PD_REGISTER_SPMD_RULE( PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); +PD_REGISTER_SPMD_RULE( + not_equal, + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); + // TODO(pkuzyc): add multiary elementwise rule // reduction rule @@ -462,6 +467,12 @@ PD_REGISTER_SPMD_RULE( max, PD_INFER_SPMD(phi::distributed::ReductionInferSpmd), PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse)); + +PD_REGISTER_SPMD_RULE( + reduce_max, + PD_INFER_SPMD(phi::distributed::ReductionInferSpmd), + PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse)); + PD_REGISTER_SPMD_RULE( min, PD_INFER_SPMD(phi::distributed::ReductionInferSpmd), diff --git a/test/auto_parallel/semi_auto_parallel_for_matmul.py b/test/auto_parallel/semi_auto_parallel_for_matmul.py index 279062f483058..2ff10c69b1302 100644 --- a/test/auto_parallel/semi_auto_parallel_for_matmul.py +++ b/test/auto_parallel/semi_auto_parallel_for_matmul.py @@ -12,68 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. 
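A sketch for illustration only, not part of the patch: the two registrations added to rules.h above simply alias not_equal and reduce_max onto the existing elementwise-binary and reduction SPMD rules. Written in the style of the tests below, and assuming the usual two-rank launch these tests use, the expected propagation for a reduction over an unsharded axis looks like this:

import numpy as np
import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
x_np = np.random.random(size=[4, 8, 6]).astype("float32")
dist_x = dist.shard_tensor(
    paddle.to_tensor(x_np),
    dist_attr=dist.DistAttr(mesh=mesh, sharding_specs=["x", None, None]),
)
# reduce_max now resolves to the generic reduction rule; the expectation
# (hedged, mirroring the reduction tests in this series) is an output of
# shape [4, 6] whose first dimension stays sharded along "x".
out = paddle.max(dist_x, axis=1)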
-import os import numpy as np import paddle import paddle.distributed as dist +from .semi_auto_parallel_util import SemiAutoParallelTestBase -class TestMatmulApiForSemiAutoParallel: + +class TestMatmulApiForSemiAutoParallel(SemiAutoParallelTestBase): def __init__(self): - self._dtype = os.getenv("dtype") - self._backend = os.getenv("backend") - self._seed = eval(os.getenv("seed")) - self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) - - def check_tensor_eq(self, a, b): - np1 = a.numpy() - np2 = b.numpy() - np.testing.assert_allclose(np1, np2, rtol=1e-05, verbose=True) - - def test_body( - self, x_shape, y_shape, x_specs, y_specs, trans_x=False, trans_y=False - ): - paddle.seed(self._seed) - np.random.seed(self._seed) - - x_np = np.random.random(size=x_shape).astype(self._dtype) - y_np = np.random.random(size=y_shape).astype(self._dtype) - x = paddle.to_tensor(x_np) - y = paddle.to_tensor(y_np) - x.stop_gradient = False - y.stop_gradient = False - - x_dist_attr = dist.DistAttr(mesh=self._mesh, sharding_specs=x_specs) - y_dist_attr = dist.DistAttr(mesh=self._mesh, sharding_specs=y_specs) - - dist_x = dist.shard_tensor(x_np, dist_attr=x_dist_attr) - dist_y = dist.shard_tensor(y_np, dist_attr=y_dist_attr) - dist_x.stop_gradient = False - dist_y.stop_gradient = False - - out = paddle.matmul(x, y, transpose_x=trans_x, transpose_y=trans_y) - dist_out = paddle.matmul( - dist_x, dist_y, transpose_x=trans_x, transpose_y=trans_y - ) - self.check_tensor_eq(out, dist_out) - - out.backward() - dist_out.backward() - self.check_tensor_eq(x.grad, dist_x.grad) - self.check_tensor_eq(y.grad, dist_y.grad) - - return dist_out, dist_x.grad, dist_y.grad + super().__init__() def test_matmul_x_row_shard(self): # case1: mk[0,-1],kn[-1,-1] -> mk[0,-1],kn[-1,-1] = mn[0,-1] partial[] - dist_out, dist_x_grad, dist_y_grad = self.test_body( - x_shape=[64, 32], - y_shape=[32, 48], - x_specs=['x', None], - y_specs=[None, None], + inputs_shape = ([64, 32], [32, 48]) + inputs_specs = (['x', None], [None, None]) + dist_input, dist_out = self.test_body( + inputs_shape, inputs_specs, paddle.matmul ) + x, y = dist_input # verify output local shape and dist attr np.testing.assert_equal(dist_out._local_shape, [32, 48], verbose=True) np.testing.assert_equal( @@ -81,30 +40,27 @@ def test_matmul_x_row_shard(self): ) assert dist_out.dist_attr._is_partial() is False # verify x_grad local shape and dist attr + np.testing.assert_equal(x.grad._local_shape, [32, 32], verbose=True) np.testing.assert_equal( - dist_x_grad._local_shape, [32, 32], verbose=True - ) - np.testing.assert_equal( - dist_x_grad.dist_attr.dims_mapping, [0, -1], verbose=True + x.grad.dist_attr.dims_mapping, [0, -1], verbose=True ) - assert dist_x_grad.dist_attr._is_partial() is False + assert x.grad.dist_attr._is_partial() is False # verify y_grad local shape and dist attr + np.testing.assert_equal(y.grad._local_shape, [32, 48], verbose=True) np.testing.assert_equal( - dist_y_grad._local_shape, [32, 48], verbose=True + y.grad.dist_attr.dims_mapping, [-1, -1], verbose=True ) - np.testing.assert_equal( - dist_y_grad.dist_attr.dims_mapping, [-1, -1], verbose=True - ) - assert dist_y_grad.dist_attr._is_partial() is False + assert y.grad.dist_attr._is_partial() is False def test_matmul_x_column_shard(self): # case2: mk[-1, 0],kn[-1,-1] --> mk[-1, 0],kn[0, -1] = nm[-1, -1] partial[0] - dist_out, dist_x_grad, dist_y_grad = self.test_body( - x_shape=[64, 32], - y_shape=[32, 48], - x_specs=[None, 'x'], - y_specs=[None, None], + + inputs_shape = ([64, 32], [32, 48]) + 
inputs_specs = ([None, 'x'], [None, None]) + dist_input, dist_out = self.test_body( + inputs_shape, inputs_specs, paddle.matmul ) + x, y = dist_input # verify local shape np.testing.assert_equal(dist_out._local_shape, [64, 48], verbose=True) np.testing.assert_equal( @@ -112,32 +68,29 @@ def test_matmul_x_column_shard(self): ) assert dist_out.dist_attr._is_partial() is False # verify x_grad local shape and dist attr + np.testing.assert_equal(x.grad._local_shape, [64, 16], verbose=True) np.testing.assert_equal( - dist_x_grad._local_shape, [64, 16], verbose=True + x.grad.dist_attr.dims_mapping, [-1, 0], verbose=True ) - np.testing.assert_equal( - dist_x_grad.dist_attr.dims_mapping, [-1, 0], verbose=True - ) - assert dist_x_grad.dist_attr._is_partial() is False + assert x.grad.dist_attr._is_partial() is False # verify y_grad local shape and dist attr + np.testing.assert_equal(y.grad._local_shape, [32, 48], verbose=True) np.testing.assert_equal( - dist_y_grad._local_shape, [32, 48], verbose=True + y.grad.dist_attr.dims_mapping, [-1, -1], verbose=True ) - np.testing.assert_equal( - dist_y_grad.dist_attr.dims_mapping, [-1, -1], verbose=True - ) - assert dist_y_grad.dist_attr._is_partial() is False + assert y.grad.dist_attr._is_partial() is False def test_matmul_x_column_shard_trans_x_y(self): # case1: mk[-1,0],kn[-1,-1] -> mk[0,-1],kn[-1,-1] = mn[0,-1] partial[], trans x, trans y - dist_out, dist_x_grad, dist_y_grad = self.test_body( - x_shape=[32, 64], - y_shape=[48, 32], - x_specs=[None, 'x'], - y_specs=[None, None], + inputs_shape = ([32, 64], [48, 32]) + inputs_specs = ([None, 'x'], [None, None]) + dist_input, dist_out = self.test_body( + inputs_shape, + inputs_specs, trans_x=True, trans_y=True, ) + x, y = dist_input # verify output local shape and dist attr np.testing.assert_equal(dist_out._local_shape, [32, 48], verbose=True) np.testing.assert_equal( @@ -145,32 +98,29 @@ def test_matmul_x_column_shard_trans_x_y(self): ) assert dist_out.dist_attr._is_partial() is False # verify x_grad local shape and dist attr + np.testing.assert_equal(x.grad._local_shape, [32, 32], verbose=True) np.testing.assert_equal( - dist_x_grad._local_shape, [32, 32], verbose=True - ) - np.testing.assert_equal( - dist_x_grad.dist_attr.dims_mapping, [-1, 0], verbose=True + x.grad.dist_attr.dims_mapping, [-1, 0], verbose=True ) - assert dist_x_grad.dist_attr._is_partial() is False + assert x.grad.dist_attr._is_partial() is False # verify y_grad local shape and dist attr + np.testing.assert_equal(y.grad._local_shape, [48, 32], verbose=True) np.testing.assert_equal( - dist_y_grad._local_shape, [48, 32], verbose=True - ) - np.testing.assert_equal( - dist_y_grad.dist_attr.dims_mapping, [-1, -1], verbose=True + y.grad.dist_attr.dims_mapping, [-1, -1], verbose=True ) - assert dist_y_grad.dist_attr._is_partial() is False + assert y.grad.dist_attr._is_partial() is False def test_matmul_x_column_shard_trans_x(self): # case1: mk[-1,0],kn[-1,-1] -> mk[0,-1],kn[-1,-1] = mn[0,-1] partial[], trans x - dist_out, dist_x_grad, dist_y_grad = self.test_body( - x_shape=[32, 64], - y_shape=[32, 48], - x_specs=[None, 'x'], - y_specs=[None, None], + inputs_shape = ([32, 64], [32, 48]) + inputs_specs = ([None, 'x'], [None, None]) + dist_input, dist_out = self.test_body( + inputs_shape, + inputs_specs, trans_x=True, trans_y=False, ) + x, y = dist_input # verify output local shape and dist attr np.testing.assert_equal(dist_out._local_shape, [32, 48], verbose=True) np.testing.assert_equal( @@ -178,32 +128,29 @@ def 
test_matmul_x_column_shard_trans_x(self): ) assert dist_out.dist_attr._is_partial() is False # verify x_grad local shape and dist attr + np.testing.assert_equal(x.grad._local_shape, [32, 32], verbose=True) np.testing.assert_equal( - dist_x_grad._local_shape, [32, 32], verbose=True + x.grad.dist_attr.dims_mapping, [-1, 0], verbose=True ) - np.testing.assert_equal( - dist_x_grad.dist_attr.dims_mapping, [-1, 0], verbose=True - ) - assert dist_x_grad.dist_attr._is_partial() is False + assert x.grad.dist_attr._is_partial() is False # verify y_grad local shape and dist attr + np.testing.assert_equal(y.grad._local_shape, [32, 48], verbose=True) np.testing.assert_equal( - dist_y_grad._local_shape, [32, 48], verbose=True + y.grad.dist_attr.dims_mapping, [-1, -1], verbose=True ) - np.testing.assert_equal( - dist_y_grad.dist_attr.dims_mapping, [-1, -1], verbose=True - ) - assert dist_y_grad.dist_attr._is_partial() is False + assert y.grad.dist_attr._is_partial() is False def test_matmul_x_row_shard_trans_y(self): # case1: mk[0,-1],kn[-1,-1] -> mk[0,-1],kn[-1,-1] = mn[0,-1] partial[], trans y - dist_out, dist_x_grad, dist_y_grad = self.test_body( - x_shape=[64, 32], - y_shape=[48, 32], - x_specs=['x', None], - y_specs=[None, None], + inputs_shape = ([64, 32], [48, 32]) + inputs_specs = (['x', None], [None, None]) + dist_input, dist_out = self.test_body( + inputs_shape, + inputs_specs, trans_x=False, trans_y=True, ) + x, y = dist_input # verify output local shape and dist attr np.testing.assert_equal(dist_out._local_shape, [32, 48], verbose=True) np.testing.assert_equal( @@ -211,21 +158,17 @@ def test_matmul_x_row_shard_trans_y(self): ) assert dist_out.dist_attr._is_partial() is False # verify x_grad local shape and dist attr + np.testing.assert_equal(x.grad._local_shape, [32, 32], verbose=True) np.testing.assert_equal( - dist_x_grad._local_shape, [32, 32], verbose=True - ) - np.testing.assert_equal( - dist_x_grad.dist_attr.dims_mapping, [0, -1], verbose=True + x.grad.dist_attr.dims_mapping, [0, -1], verbose=True ) - assert dist_x_grad.dist_attr._is_partial() is False + assert x.grad.dist_attr._is_partial() is False # verify y_grad local shape and dist attr + np.testing.assert_equal(y.grad._local_shape, [48, 32], verbose=True) np.testing.assert_equal( - dist_y_grad._local_shape, [48, 32], verbose=True - ) - np.testing.assert_equal( - dist_y_grad.dist_attr.dims_mapping, [-1, -1], verbose=True + y.grad.dist_attr.dims_mapping, [-1, -1], verbose=True ) - assert dist_y_grad.dist_attr._is_partial() is False + assert y.grad.dist_attr._is_partial() is False def run_test_case(self): if self._backend == "cpu": diff --git a/test/auto_parallel/semi_auto_parallel_for_reduction.py b/test/auto_parallel/semi_auto_parallel_for_reduction.py index 4b2e7d4bb026b..80b6e562d79df 100644 --- a/test/auto_parallel/semi_auto_parallel_for_reduction.py +++ b/test/auto_parallel/semi_auto_parallel_for_reduction.py @@ -19,8 +19,10 @@ import paddle import paddle.distributed as dist +from .semi_auto_parallel_util import SemiAutoParallelTestBase -class TestReductionApiForSemiAutoParallel: + +class TestReductionApiForSemiAutoParallel(SemiAutoParallelTestBase): def __init__(self): self._dtype = os.getenv("dtype") self._backend = os.getenv("backend") @@ -33,35 +35,33 @@ def check_tensor_eq(self, a, b): np.testing.assert_allclose(np1, np2, rtol=1e-05, verbose=True) def test_body(self, x_shape, out_shape, x_specs, axis, keepdim, op_func): - paddle.seed(self._seed) - np.random.seed(self._seed) - - x = paddle.randn(x_shape, self._dtype) - 
x.stop_gradient = False - - x_dist_attr = dist.DistAttr(mesh=self._mesh, sharding_specs=x_specs) - - dist_x = dist.shard_tensor(x, dist_attr=x_dist_attr) - dist_x.stop_gradient = False - - dist_out = op_func(dist_x, axis=axis, keepdim=keepdim) - out = op_func(x, axis=axis, keepdim=keepdim) - self.check_tensor_eq(out, dist_out) + dist_input, dist_out = super().test_body( + x_shape, + x_specs, + op_func, + axis=axis, + keepdim=keepdim, + ) np.testing.assert_equal(dist_out.shape, out_shape, verbose=True) - dist_out.backward() - out.backward() - self.check_tensor_eq(x.grad, dist_x.grad) - - def test_sum_x_shard(self): - self.test_body( - x_shape=[4, 8, 6], - out_shape=[4, 6], - x_specs=['x', None, None], - axis=1, - keepdim=False, - op_func=paddle.sum, - ) + def test_reduce_x_shard(self): + for op_func in [paddle.sum, paddle.mean]: + self.test_body( + x_shape=[4, 8, 6], + out_shape=[4, 6], + x_specs=['x', None, None], + axis=1, + keepdim=False, + op_func=op_func, + ) + self.test_body( + x_shape=[4, 8, 6], + out_shape=[8, 6], + x_specs=['x', None, None], + axis=-3, + keepdim=False, + op_func=op_func, + ) def test_sum_x_shard_on_axis(self): self.test_body( diff --git a/test/auto_parallel/semi_auto_parallel_util.py b/test/auto_parallel/semi_auto_parallel_util.py new file mode 100644 index 0000000000000..dc8930eeb14b7 --- /dev/null +++ b/test/auto_parallel/semi_auto_parallel_util.py @@ -0,0 +1,124 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
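The base class introduced below flattens nested tuples/lists of shapes and sharding specs into flat lists plus a structure token, builds the plain and distributed tensors, runs the op on both, and compares outputs and gradients. The following simplified, self-contained round trip (illustration only, not part of the patch; the real helper also threads a terminal condition through and distinguishes lists from tuples) shows the flatten/unflatten idea:

def flatten(obj, is_leaf):
    # collapse a nested structure into (leaves, structure token)
    if is_leaf(obj):
        return [obj], "i"
    flat, structure = [], []
    for item in obj:
        leaves, token = flatten(item, is_leaf)
        flat.extend(leaves)
        structure.append(token)
    return flat, structure


def unflatten(flat, structure, offset=0):
    # rebuild the nested structure from the flat leaf list
    if structure == "i":
        return flat[offset], offset + 1
    rebuilt = []
    for token in structure:
        node, offset = unflatten(flat, token, offset)
        rebuilt.append(node)
    return rebuilt, offset


leaves, token = flatten(([64, 32], [32, 48]), lambda x: isinstance(x, list))
assert leaves == [[64, 32], [32, 48]] and token == ["i", "i"]
assert unflatten(leaves, token)[0] == [[64, 32], [32, 48]]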
+ +import os + +import numpy as np + +import paddle +import paddle.distributed as dist + + +class SemiAutoParallelTestBase: + def __init__(self): + self._dtype = os.getenv("dtype") + self._backend = os.getenv("backend") + self._seed = eval(os.getenv("seed")) + self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) + + def check_tensor_eq(self, a, b): + np1 = a.numpy() + np2 = b.numpy() + np.testing.assert_allclose(np1, np2, rtol=1e-05, verbose=True) + + def flatten(self, inputs, terminal_cond): + """ + inputs may be single tensor、tuple + """ + + if terminal_cond(inputs): + return [inputs], "i" + + assert isinstance(inputs, (tuple, list)) + flattened = [] + structure = [] + for i in range(len(inputs)): + tmp, tmp_structure = self.flatten(inputs[i]) + flattened.extend(tmp) + structure.append(tmp_structure) + + if isinstance(inputs, list): + structure = tuple(structure) + return flattened, structure + + def unflatten(self, inputs, structure, offset=0): + """ + inputs may be single tensor + """ + assert isinstance(inputs, list) + assert offset < len(inputs) + if structure == "i": + assert len(inputs) == 1 + offset = offset + 1 + # return a list + return inputs, offset + assert isinstance(structure, (tuple, list)) + unflattened = [] + for i in range(len(structure)): + tmp, offset = self.unflatten(inputs, structure[i], offset) + unflattened.append(tmp) + if isinstance(inputs, tuple): + unflattened = tuple(unflattened) + return unflattened, offset + + def test_body(self, inputs_shape, inputs_specs, op_func, **kwargs): + paddle.seed(self._seed) + np.random.seed(self._seed) + + flat_inputs = [] + flat_dist_inputs = [] + + def terminal_cond(x): + return isinstance(x, list) and all( + not isinstance(e, (list, tuple)) for e in x + ) + + flat_inputs_specs, inputs_structure = self.flatten( + inputs_specs, terminal_cond + ) + flat_inputs_shape, _ = self.flatten(inputs_shape, terminal_cond) + assert len(flat_inputs_specs) == len(flat_inputs_shape) + + for shape, spec in zip(flat_inputs_shape, flat_inputs_specs): + input_np = np.random.random(size=shape).astype(self._dtype) + input = paddle.to_tensor(input_np) + input.stop_gradient = False + input_dist_attr = dist.DistAttr( + mesh=self._mesh, sharding_specs=spec + ) + dist_input = dist.shard_tensor(input, dist_attr=input_dist_attr) + dist_input.stop_gradient = False + flat_inputs.append(input) + flat_dist_inputs.append(dist_input) + + inputs = self.unflatten(flat_inputs, inputs_structure) + dist_inputs = self.unflatten(flat_dist_inputs, inputs_structure) + out = op_func(**inputs, **kwargs) + dist_out = op_func(**dist_inputs, **kwargs) + + def terminal_cond2(x): + return not isinstance(x, (list, tuple)) + + flat_out = self.flatten(out, terminal_cond2) + flat_dist_out = self.flatten(dist_out, terminal_cond2) + assert len(flat_out) == len(flat_dist_out) + for output, dist_output in zip(flat_out, flat_dist_out): + self.check_tensor_eq(out, dist_out) + output.backward() + dist_output.backward() + + for x, dist_x in zip(flat_inputs, flat_dist_inputs): + self.check_tensor_eq(x.grad, dist_x.grad) + + return dist_inputs, dist_out From 80582e7980ff7b9274283316ee5ecf42c94e6df5 Mon Sep 17 00:00:00 2001 From: liuzhenhai93 Date: Sat, 4 Nov 2023 06:41:07 +0000 Subject: [PATCH 02/25] polish --- paddle/phi/infermeta/spmd_rules/concat.cc | 32 ++- paddle/phi/infermeta/spmd_rules/layer_norm.cc | 14 ++ paddle/phi/infermeta/spmd_rules/layer_norm.h | 9 + paddle/phi/infermeta/spmd_rules/utils.cc | 194 +++++++++++++++++- paddle/phi/infermeta/spmd_rules/utils.h | 12 ++ 5 files 
changed, 256 insertions(+), 5 deletions(-) diff --git a/paddle/phi/infermeta/spmd_rules/concat.cc b/paddle/phi/infermeta/spmd_rules/concat.cc index fd036cfad603a..edf6f5e012815 100644 --- a/paddle/phi/infermeta/spmd_rules/concat.cc +++ b/paddle/phi/infermeta/spmd_rules/concat.cc @@ -25,8 +25,19 @@ namespace distributed { using phi::distributed::auto_parallel::str_join; -static bool IsEmpty(const std::vector& shape) { - return shape.empty() || shape.at(0) == 0; +std::tuple FillConcatNotation(int64_t n_axis, + int64_t concat_axis) { + PADDLE_ENFORCE_EQ( + n_axis > concat_axis, true, phi::errors::InvalidArgument("")); + static const std::string alphabet = "abcdefghijlopqrstuvwxyz"; + PADDLE_ENFORCE_EQ(alphabet.size() > static_cast(n_axis), + true, + phi::errors::InvalidArgument("")); + std::string all_axis = alphabet.substr(0, n_axis); + std::string align_axis = + std::string(all_axis.begin(), all_axis.begin() + concat_axis) + + std::string(all_axis.begin() + concat_axis + 1, all_axis.end()); + return {all_axis, align_axis}; } SpmdInfo ConcatInferSpmd(const std::vector& x, int axis) { @@ -58,10 +69,15 @@ SpmdInfo ConcatInferSpmd(const std::vector& x, int axis) { auto non_empty_index = non_empty_iter - tensor_shapes.begin(); int64_t ndim = static_cast(tensor_shapes[non_empty_index].size()); // normlize dim - int64_t dim = axis; - dim = dim < 0 ? dim + ndim : dim; + auto dim = axis < 0 ? ndim + axis : axis; std::vector input_attrs; + std::transform( + x.begin(), x.end(), std::back_inserter(input_attrs), [](auto& meta) { + return meta.dist_attr(); + }); + /* + // 2、make sure all tensors replicated on concat dim auto n_inputs = x.size(); for (size_t i = 0; i < n_inputs; ++i) { @@ -170,6 +186,14 @@ SpmdInfo ConcatInferSpmd(const std::vector& x, int axis) { } std::swap(input_attrs, new_input_attrs); } + + */ + std::string all_aixs; + std::string align_axis; + std::tie(all_aixs, align_axis) = FillConcatNotation(axis, dim); + std::vector axis_names(input_attrs.size(), all_aixs); + AlignDimsSharding( + &input_attrs, tensor_shapes, axis_names, {}, align_axis, true); return {{input_attrs}, {input_attrs[non_empty_index]}}; } diff --git a/paddle/phi/infermeta/spmd_rules/layer_norm.cc b/paddle/phi/infermeta/spmd_rules/layer_norm.cc index 1dfe8bf19c296..9f7dcf9e47394 100644 --- a/paddle/phi/infermeta/spmd_rules/layer_norm.cc +++ b/paddle/phi/infermeta/spmd_rules/layer_norm.cc @@ -278,5 +278,19 @@ SpmdInfo LayerNormInferSpmdReverse(const DistMetaTensor& x, return {ToArgDistAttr(input_dist_attrs), ToArgDistAttr(output_dist_attrs)}; } +std::tuple, std::string> BuildLayerNormGradEinsum( + int64_t input_rank, int64_t begin_norm_axis) {} + +SpmdInfo LayerNormGradInferSpmd(const DistMetaTensor& x, + const DistMetaTensor& scale, + const DistMetaTensor& bias, + const DistMetaTensor& mean, + const DistMetaTensor& variance, + const DistMetaTensor out_grad, + float epsilon, + int begin_norm_axis) { + return SpmdInfo(); +} + } // namespace distributed } // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/layer_norm.h b/paddle/phi/infermeta/spmd_rules/layer_norm.h index c33b58a51bc20..195618168cefe 100644 --- a/paddle/phi/infermeta/spmd_rules/layer_norm.h +++ b/paddle/phi/infermeta/spmd_rules/layer_norm.h @@ -26,6 +26,15 @@ SpmdInfo LayerNormInferSpmd(const DistMetaTensor& x, float epsilon, int begin_norm_axis); +SpmdInfo LayerNormGradInferSpmd(const DistMetaTensor& x, + const DistMetaTensor& scale, + const DistMetaTensor& bias, + const DistMetaTensor& mean, + const DistMetaTensor& variance, + const 
DistMetaTensor out_grad, + float epsilon = 1e-5, + int begin_norm_axis = 1); + SpmdInfo LayerNormInferSpmdReverse(const DistMetaTensor& x, const DistMetaTensor& scale, const DistMetaTensor& bias, diff --git a/paddle/phi/infermeta/spmd_rules/utils.cc b/paddle/phi/infermeta/spmd_rules/utils.cc index 42bbc659b2f2b..8273f7749a765 100644 --- a/paddle/phi/infermeta/spmd_rules/utils.cc +++ b/paddle/phi/infermeta/spmd_rules/utils.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/phi/infermeta/spmd_rules/utils.h" - +#include #include "glog/logging.h" #include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" @@ -201,6 +201,198 @@ bool PlacementEqual(const std::shared_ptr& a, return a_shard->get_axis() == b_shard->get_axis(); } +bool AlignDimsSharding(std::vector* input_attrs_ptr, + const std::vector>& tensor_shapes, + const std::vector& axis_names, + const std::set& skip_mesh_dims, + const std::string& align_axis, + bool allow_partial) { + auto& input_attrs = *input_attrs_ptr; + size_t n_inputs = input_attrs.size(); + PADDLE_ENFORCE_EQ( + n_inputs, tensor_shapes.size(), phi::errors::InvalidArgument("")); + PADDLE_ENFORCE_EQ( + n_inputs, axis_names.size(), phi::errors::InvalidArgument("")); + + PADDLE_ENFORCE_EQ( + !align_axis.empty(), true, phi::errors::InvalidArgument("")); + + std::map, int64_t> axis_name_to_dim; + + for (size_t i = 0; i < n_inputs; i++) { + // 1、check all inputs have the align_axis + for (char axi : align_axis) { + if (axis_names[i].find(axi) == std::string::npos) { + PADDLE_THROW(phi::errors::PreconditionNotMet( + "[%s] some axis not in input [%d],[%s]", + align_axis, + i, + axis_names[i])); + } + } + // 2、build axis map + for (size_t j = 0; j < axis_names[i].size(); j++) { + auto axi = axis_names[i][j]; + axis_name_to_dim[{i, axi}] = j; + } + } + // 3、check all inputs have the same align_axis + auto non_empty_iter = + std::find_if(tensor_shapes.begin(), tensor_shapes.end(), [](auto& shape) { + return !IsEmpty(shape); + }); + auto non_empty_index = non_empty_iter - tensor_shapes.begin(); + + // 3、align non-concat dimensions according to cost + std::vector>> inputs_placements; + std::transform( + input_attrs.begin(), + input_attrs.end(), + std::back_inserter(inputs_placements), + [](const TensorDistAttr& attr) { return attr.to_placement(); }); + + const auto& process_mess = input_attrs[non_empty_index].process_mesh(); + auto has_mismatch = [&](int32_t mesh_dim) { + bool mismatch = false; + for (size_t i = 0; i < n_inputs; i++) { + if (IsEmpty(tensor_shapes[i])) { + continue; + } + auto& p_a = inputs_placements[non_empty_index][mesh_dim]; + auto& p_b = inputs_placements[i][mesh_dim]; + if (!p_a->is_shard()) { + if (!PlacementEqual(p_a, p_b)) { + mismatch = true; + break; + } + } + if (!p_a->is_shard()) { + mismatch = true; + break; + } + auto a_shard = std::dynamic_pointer_cast(p_a); + auto b_shard = std::dynamic_pointer_cast(p_b); + auto a_axis = axis_names[non_empty_index][a_shard->get_axis()]; + auto b_axis = axis_names[non_empty_index][b_shard->get_axis()]; + if (a_axis != b_axis) { + mismatch = true; + break; + } + } + return mismatch; + }; + + // a dim can not be sharded twice along diffrent mesh_dim + std::set sharded_axis; + std::map partial_dim_to_type; + std::map mesh_dim_to_axis; + + // 4、find already shard axis + for (int32_t mesh_dim = 0; mesh_dim < process_mess.ndim(); ++mesh_dim) { + if (!has_mismatch(mesh_dim)) { + auto& old = 
inputs_placements[non_empty_index][mesh_dim]; + if (old->is_shard()) { + auto shard_placement = std::dynamic_pointer_cast(old); + auto axis_name = + axis_names[non_empty_index][shard_placement->get_axis()]; + if (align_axis.find(axis_name) == std::string::npos) { + continue; + } + sharded_axis.insert(axis_name); + mesh_dim_to_axis[mesh_dim] = axis_name; + } else if (old->is_partial()) { + auto partial_placement = std::dynamic_pointer_cast(old); + auto reduce_type = partial_placement->get_reduce_type(); + if (allow_partial && (reduce_type == ReduceType::kRedSum || + reduce_type == ReduceType::kRedAvg)) { + partial_dim_to_type[mesh_dim] = reduce_type; + } + } + } + } + // 4、align axis + for (int32_t mesh_dim = 0; mesh_dim < process_mess.ndim(); ++mesh_dim) { + if (!has_mismatch(mesh_dim)) { + continue; + } + if (skip_mesh_dims.count(mesh_dim)) { + continue; + } + if (partial_dim_to_type.count(mesh_dim)) { + continue; + } + std::priority_queue, + std::vector>, + std::greater>> + cost_queue; + + for (auto axis_name : align_axis) { + double cost = std::numeric_limits::infinity(); + if (!sharded_axis.count(axis_name)) { + cost = 0.0; + for (size_t i = 0; i < n_inputs; i++) { + auto& tensor_shape = tensor_shapes[i]; + auto& tensor_dist_attr = input_attrs[i]; + if (IsEmpty(tensor_shape)) { + continue; + } + auto shard_dim = axis_name_to_dim[{i, axis_name}]; + if (tensor_shape[shard_dim] < process_mess.dim_size(mesh_dim)) { + // should not be selected + cost += std::numeric_limits::infinity(); + continue; + } + if (IsDimSharded(tensor_dist_attr, shard_dim)) { + continue; + } + int64_t num = std::accumulate(tensor_shape.begin(), + tensor_shape.end(), + 1, + std::multiplies()); + if (num == static_cast(0)) { + continue; + } + std::vector local_shape = + GetLocalShape(tensor_shape, process_mess, inputs_placements[i]); + cost += std::accumulate(local_shape.begin(), + local_shape.end(), + 1, + std::multiplies()) * + process_mess.dim_size(mesh_dim); + } + } + cost_queue.push(std::make_pair(cost, axis_name)); + } + while (!cost_queue.empty()) { + auto cost_axis = cost_queue.top(); + cost_queue.pop(); + if (sharded_axis.count(cost_axis.second)) { + continue; + } + if (cost_axis.first == std::numeric_limits::infinity()) { + continue; + } + sharded_axis.insert(cost_axis.second); + mesh_dim_to_axis[mesh_dim] = cost_axis.second; + } + } + std::vector new_input_attrs; + for (size_t i; i < n_inputs; i++) { + auto& e = input_attrs[i]; + std::vector> placements( + process_mess.ndim(), std::make_shared()); + for (auto pair : mesh_dim_to_axis) { + auto shard_dim = axis_name_to_dim[{i, pair.second}]; + placements[pair.first] = std::make_shared(shard_dim); + } + for (auto pair : partial_dim_to_type) { + placements[pair.first] = std::make_shared(pair.second); + } + new_input_attrs.emplace_back(FromPlacements(e, placements)); + } + std::swap(input_attrs, new_input_attrs); +} + TensorDistAttr FromPlacements( const TensorDistAttr& dist_attr, const std::vector>& placements) { diff --git a/paddle/phi/infermeta/spmd_rules/utils.h b/paddle/phi/infermeta/spmd_rules/utils.h index b5b5e207a0ee6..4bd3ca4851a32 100644 --- a/paddle/phi/infermeta/spmd_rules/utils.h +++ b/paddle/phi/infermeta/spmd_rules/utils.h @@ -14,6 +14,7 @@ limitations under the License. 
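AlignDimsSharding above works on einsum-style axis names: every input carries a notation string, and only the axes listed in align_axis take part in the per-mesh-dimension sharding decision, so axes left out (such as the concat axis) end up replicated. For concat, FillConcatNotation builds those strings by dropping the concat axis; the Python mirror below is illustration only, not part of the patch:

def fill_concat_notation(n_axis, concat_axis):
    # mirrors FillConcatNotation in concat.cc: every tensor gets the same
    # axis names, and the concat axis is excluded from the aligned set
    alphabet = "abcdefghijlopqrstuvwxyz"
    all_axis = alphabet[:n_axis]
    align_axis = all_axis[:concat_axis] + all_axis[concat_axis + 1:]
    return all_axis, align_axis


# concat of rank-3 tensors along dim 0: only "b" and "c" must be aligned,
# so the concat dimension itself is left replicated
assert fill_concat_notation(3, 0) == ("abc", "bc")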
*/ #pragma once +#include #include #include #include @@ -27,6 +28,10 @@ namespace phi { namespace distributed { class TensorDistAttr; +bool IsEmpty(const std::vector& shape) { + return shape.empty() || shape.at(0) == 0; +} + // Generate the axis notation of tensor for the einsum notation of a broadcast // operation(alignment star from the rightmost axis). tenosr_ndim: the size of // the tensor. broadcast_ndim: the maxium size of tensors in this broadcast @@ -88,6 +93,13 @@ TensorDistAttr ReplicateTensorDim(const TensorDistAttr& dist_attr, int dim); bool PlacementEqual(const std::shared_ptr& a, const std::shared_ptr& b); +bool AlignDimsSharding(std::vector* input_attrs, + const std::vector>& tensor_shapes, + const std::vector& axis_names, + const std::set& skip_mesh_dims, + const std::string& align_axis, + bool allow_partial); + // Adaptor for variadic arguments template struct ArgsIterator { From 775df693c1281931332a2bad7b3cded119bd60a3 Mon Sep 17 00:00:00 2001 From: liuzhenhai93 Date: Mon, 6 Nov 2023 11:13:01 +0000 Subject: [PATCH 03/25] polish --- paddle/phi/infermeta/spmd_rules/layer_norm.cc | 4 +- paddle/phi/infermeta/spmd_rules/utils.cc | 2 +- .../semi_auto_parallel_for_concat.py | 11 +++ .../semi_auto_parallel_for_layernorm.py | 73 +++++++++++++++++++ 4 files changed, 88 insertions(+), 2 deletions(-) create mode 100644 test/auto_parallel/semi_auto_parallel_for_layernorm.py diff --git a/paddle/phi/infermeta/spmd_rules/layer_norm.cc b/paddle/phi/infermeta/spmd_rules/layer_norm.cc index 9f7dcf9e47394..919b182f623bb 100644 --- a/paddle/phi/infermeta/spmd_rules/layer_norm.cc +++ b/paddle/phi/infermeta/spmd_rules/layer_norm.cc @@ -279,7 +279,9 @@ SpmdInfo LayerNormInferSpmdReverse(const DistMetaTensor& x, } std::tuple, std::string> BuildLayerNormGradEinsum( - int64_t input_rank, int64_t begin_norm_axis) {} + int64_t input_rank, int64_t begin_norm_axis) { + return {{}, ""}; +} SpmdInfo LayerNormGradInferSpmd(const DistMetaTensor& x, const DistMetaTensor& scale, diff --git a/paddle/phi/infermeta/spmd_rules/utils.cc b/paddle/phi/infermeta/spmd_rules/utils.cc index 8273f7749a765..8d017f77432af 100644 --- a/paddle/phi/infermeta/spmd_rules/utils.cc +++ b/paddle/phi/infermeta/spmd_rules/utils.cc @@ -377,7 +377,7 @@ bool AlignDimsSharding(std::vector* input_attrs_ptr, } } std::vector new_input_attrs; - for (size_t i; i < n_inputs; i++) { + for (size_t i = 0; i < n_inputs; i++) { auto& e = input_attrs[i]; std::vector> placements( process_mess.ndim(), std::make_shared()); diff --git a/test/auto_parallel/semi_auto_parallel_for_concat.py b/test/auto_parallel/semi_auto_parallel_for_concat.py index 24605825d5f15..29198ae74c4eb 100644 --- a/test/auto_parallel/semi_auto_parallel_for_concat.py +++ b/test/auto_parallel/semi_auto_parallel_for_concat.py @@ -22,6 +22,15 @@ class TestSplitAndConcatSemiAutoParallel(SemiAutoParallelTestBase): def __init__(self): super().__init__() + def check_dim_mapping(self, inputs, output, expected_dim_mapping): + for t in inputs: + assert ( + t.dist_attr.dim_mapping == expected_dim_mapping + ), f"{t.dist_attr.dim_mapping} vs {expected_dim_mapping}" + assert ( + output.dist_attr.dim_mapping == expected_dim_mapping + ), f"{output.dist_attr.dim_mapping} vs {expected_dim_mapping}" + def test_concat_forward(self): shapes = [[16, 4, 4], [64, 4, 4]] specs = [[None, None, 'x'], [None, None, 'x']] @@ -32,6 +41,7 @@ def test_concat_forward(self): with_backward=False, axis=0, ) + self.check_dim_mapping(inputs, outputs, [-1, -1, 0]) def test_concat_forward_reshard(self): shapes = [[16, 
4, 4], [64, 4, 4]] @@ -43,6 +53,7 @@ def test_concat_forward_reshard(self): with_backward=False, axis=0, ) + self.check_dim_mapping(inputs, outputs, [-1, -1, 0]) def run_test_case(self): if self._backend == "cpu": diff --git a/test/auto_parallel/semi_auto_parallel_for_layernorm.py b/test/auto_parallel/semi_auto_parallel_for_layernorm.py new file mode 100644 index 0000000000000..29198ae74c4eb --- /dev/null +++ b/test/auto_parallel/semi_auto_parallel_for_layernorm.py @@ -0,0 +1,73 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from semi_auto_parallel_util import SemiAutoParallelTestBase + +import paddle +import paddle.distributed as dist + + +class TestSplitAndConcatSemiAutoParallel(SemiAutoParallelTestBase): + def __init__(self): + super().__init__() + + def check_dim_mapping(self, inputs, output, expected_dim_mapping): + for t in inputs: + assert ( + t.dist_attr.dim_mapping == expected_dim_mapping + ), f"{t.dist_attr.dim_mapping} vs {expected_dim_mapping}" + assert ( + output.dist_attr.dim_mapping == expected_dim_mapping + ), f"{output.dist_attr.dim_mapping} vs {expected_dim_mapping}" + + def test_concat_forward(self): + shapes = [[16, 4, 4], [64, 4, 4]] + specs = [[None, None, 'x'], [None, None, 'x']] + inputs, outputs = self.runfunc_and_check( + inputs_shape=shapes, + inputs_specs=specs, + op_func=paddle.concat, + with_backward=False, + axis=0, + ) + self.check_dim_mapping(inputs, outputs, [-1, -1, 0]) + + def test_concat_forward_reshard(self): + shapes = [[16, 4, 4], [64, 4, 4]] + specs = [['x', None, None], [None, None, 'x']] + inputs, outputs = self.runfunc_and_check( + inputs_shape=shapes, + inputs_specs=specs, + op_func=paddle.concat, + with_backward=False, + axis=0, + ) + self.check_dim_mapping(inputs, outputs, [-1, -1, 0]) + + def run_test_case(self): + if self._backend == "cpu": + paddle.set_device("cpu") + elif self._backend == "gpu": + paddle.set_device("gpu:" + str(dist.get_rank())) + else: + raise ValueError("Only support cpu or gpu backend.") + + self.test_concat_forward() + # all to all is not supported yet for cpu + if self._backend == "gpu": + self.test_concat_forward_reshard() + + +if __name__ == '__main__': + TestSplitAndConcatSemiAutoParallel().run_test_case() From ad22cccaa5cb9712c6981af7f510cd5cb3ff9ab5 Mon Sep 17 00:00:00 2001 From: liuzhenhai93 Date: Mon, 6 Nov 2023 11:26:50 +0000 Subject: [PATCH 04/25] polish --- paddle/phi/infermeta/spmd_rules/layer_norm.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/phi/infermeta/spmd_rules/layer_norm.cc b/paddle/phi/infermeta/spmd_rules/layer_norm.cc index 919b182f623bb..28d5f8b9be0ae 100644 --- a/paddle/phi/infermeta/spmd_rules/layer_norm.cc +++ b/paddle/phi/infermeta/spmd_rules/layer_norm.cc @@ -280,7 +280,7 @@ SpmdInfo LayerNormInferSpmdReverse(const DistMetaTensor& x, std::tuple, std::string> BuildLayerNormGradEinsum( int64_t input_rank, int64_t begin_norm_axis) { - return {{}, ""}; + return {std::vector(), ""}; } SpmdInfo 
LayerNormGradInferSpmd(const DistMetaTensor& x, From 5bcaa2f9aa0fa85ca048ecc3adc1a762d51a6536 Mon Sep 17 00:00:00 2001 From: liuzhenhai93 Date: Mon, 6 Nov 2023 11:30:23 +0000 Subject: [PATCH 05/25] polish --- paddle/phi/infermeta/spmd_rules/utils.cc | 2 +- paddle/phi/infermeta/spmd_rules/utils.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/phi/infermeta/spmd_rules/utils.cc b/paddle/phi/infermeta/spmd_rules/utils.cc index 8d017f77432af..a38b4646eaf39 100644 --- a/paddle/phi/infermeta/spmd_rules/utils.cc +++ b/paddle/phi/infermeta/spmd_rules/utils.cc @@ -201,7 +201,7 @@ bool PlacementEqual(const std::shared_ptr& a, return a_shard->get_axis() == b_shard->get_axis(); } -bool AlignDimsSharding(std::vector* input_attrs_ptr, +void AlignDimsSharding(std::vector* input_attrs_ptr, const std::vector>& tensor_shapes, const std::vector& axis_names, const std::set& skip_mesh_dims, diff --git a/paddle/phi/infermeta/spmd_rules/utils.h b/paddle/phi/infermeta/spmd_rules/utils.h index 4bd3ca4851a32..90d605edceba5 100644 --- a/paddle/phi/infermeta/spmd_rules/utils.h +++ b/paddle/phi/infermeta/spmd_rules/utils.h @@ -93,7 +93,7 @@ TensorDistAttr ReplicateTensorDim(const TensorDistAttr& dist_attr, int dim); bool PlacementEqual(const std::shared_ptr& a, const std::shared_ptr& b); -bool AlignDimsSharding(std::vector* input_attrs, +void AlignDimsSharding(std::vector* input_attrs, const std::vector>& tensor_shapes, const std::vector& axis_names, const std::set& skip_mesh_dims, From 8b82133e6f09d0c7f5ab4218f6f6927796de761c Mon Sep 17 00:00:00 2001 From: liuzhenhai93 Date: Tue, 7 Nov 2023 03:32:09 +0000 Subject: [PATCH 06/25] layer_norm_backward --- paddle/phi/infermeta/spmd_rules/layer_norm.cc | 129 +++++++++++++++++- paddle/phi/infermeta/spmd_rules/utils.h | 2 +- 2 files changed, 128 insertions(+), 3 deletions(-) diff --git a/paddle/phi/infermeta/spmd_rules/layer_norm.cc b/paddle/phi/infermeta/spmd_rules/layer_norm.cc index 28d5f8b9be0ae..b4e074b14852b 100644 --- a/paddle/phi/infermeta/spmd_rules/layer_norm.cc +++ b/paddle/phi/infermeta/spmd_rules/layer_norm.cc @@ -280,7 +280,13 @@ SpmdInfo LayerNormInferSpmdReverse(const DistMetaTensor& x, std::tuple, std::string> BuildLayerNormGradEinsum( int64_t input_rank, int64_t begin_norm_axis) { - return {std::vector(), ""}; + std::string alphabet = "ijklmnopqrstuvwxyz"; + std::string x_notation = alphabet.substr(0, input_rank); + std::string mean_variance_notation = x_notation.substr(begin_norm_axis); + std::string align_notation = x_notation.substr(0, begin_norm_axis); + return { + {x_notation, mean_variance_notation, mean_variance_notation, x_notation}, + align_notation}; } SpmdInfo LayerNormGradInferSpmd(const DistMetaTensor& x, @@ -291,7 +297,126 @@ SpmdInfo LayerNormGradInferSpmd(const DistMetaTensor& x, const DistMetaTensor out_grad, float epsilon, int begin_norm_axis) { - return SpmdInfo(); + auto get_shape = [](const auto& meta) { + return phi::vectorize(meta.dims()); + }; + // 1、check tensors shapes + auto x_shape = get_shape(x); + auto scale_shape = get_shape(scale); + auto bias_shape = get_shape(bias); + auto mean_shape = get_shape(mean); + auto variance_shape = get_shape(variance); + auto out_grad_shape = get_shape(out_grad); + PADDLE_ENFORCE_GE( + x_shape.size(), + begin_norm_axis, + phi::errors::InvalidArgument( + "The Tensor x's rank [%d] and begin_norm_axis [%d] are not matched.", + x_shape.size(), + begin_norm_axis)); + PADDLE_ENFORCE_EQ( + x_shape.size(), + out_grad_shape.size(), + phi::errors::InvalidArgument("The Tensor x's 
rank [%d] and Tensor " + "out_grad's rank [%d] are not matched.", + x_shape.size(), + out_grad_shape.size())); + + PADDLE_ENFORCE_EQ( + scale_shape.size(), + bias_shape.size(), + phi::errors::InvalidArgument("The Tensor scale's rank [%d] and Tensor " + "bias's rank [%d] are not matched.", + scale_shape.size(), + bias_shape.size())); + + PADDLE_ENFORCE_EQ( + mean_shape.size(), + variance_shape.size(), + phi::errors::InvalidArgument("The Tensor mean's rank [%d] and Tensor " + "variance's rank [%d] are not matched.", + mean_shape.size(), + variance_shape.size())); + + PADDLE_ENFORCE_EQ( + scale_shape.size() + begin_norm_axis, + x_shape.size(), + phi::errors::InvalidArgument("The Tensor scale's rank [%d] and Tensor " + "x's rank [%d] are not matched.", + scale_shape.size(), + x_shape.size())); + + if (begin_norm_axis > 0) { + PADDLE_ENFORCE_EQ( + scale_shape.size() + mean_shape.size(), + x_shape.size(), + phi::errors::InvalidArgument("The Tensor scale's rank [%d] and Tensor " + "x's rank [%d] are not matched.", + scale_shape.size(), + x_shape.size())); + } else { + PADDLE_ENFORCE_EQ( + scale_shape.size(), + x_shape.size(), + phi::errors::InvalidArgument("The Tensor scale's rank [%d] and Tensor " + "x's rank [%d] are not matched.", + scale_shape.size(), + x_shape.size())); + } + // 2、align sharding + TensorDistAttr x_dist_attr; + TensorDistAttr mean_dist_attr; + TensorDistAttr variance_dist_attr; + TensorDistAttr grad_dist_attr; + std::vector dist_attrs; + dist_attrs.push_back(x.dist_attr()); + dist_attrs.push_back(mean.dist_attr()); + dist_attrs.push_back(variance.dist_attr()); + dist_attrs.push_back(out_grad.dist_attr()); + if (begin_norm_axis > 0) { + std::vector> shapes = { + x_shape, mean_shape, variance_shape, x_shape}; + std::vector anotations; + std::string align_anotation; + std::tie(anotations, align_anotation) = + BuildLayerNormGradEinsum(x_shape.size(), begin_norm_axis); + AlignDimsSharding( + &dist_attrs, shapes, anotations, {}, align_anotation, false); + x_dist_attr = std::move(dist_attrs[0]); + mean_dist_attr = std::move(dist_attrs[1]); + variance_dist_attr = std::move(dist_attrs[2]); + grad_dist_attr = std::move(dist_attrs[3]); + } else { + x_dist_attr = GetReplicatedDistAttr(dist_attrs[0]); + mean_dist_attr = GetReplicatedDistAttr(dist_attrs[1]); + variance_dist_attr = GetReplicatedDistAttr(dist_attrs[2]); + grad_dist_attr = GetReplicatedDistAttr(dist_attrs[3]); + } + // TODO(liuzhenhai): support sharded scale and bias + TensorDistAttr scale_dist_attr = GetReplicatedDistAttr(scale.dist_attr()); + TensorDistAttr bias_dist_attr = GetReplicatedDistAttr(bias.dist_attr()); + TensorDistAttr scale_grad_dist_attr = + GetReplicatedDistAttr(scale.dist_attr()); + TensorDistAttr bias_grad_dist_attr = GetReplicatedDistAttr(bias.dist_attr()); + // partial grad dim + std::vector partial_on_dims; + const auto& dim_mapping = x_dist_attr.dims_mapping(); + for (int i = 0; i < begin_norm_axis; ++i) { + auto mapping = dim_mapping[i]; + if (mapping != -1) { + partial_on_dims.push_back(i); + } + } + scale_grad_dist_attr.set_partial_status(partial_on_dims); + bias_grad_dist_attr.set_partial_status(partial_on_dims); + + return SpmdInfo({x_dist_attr, + scale_dist_attr, + bias_dist_attr, + mean_dist_attr, + variance_dist_attr, + grad_dist_attr}, + {grad_dist_attr, scale_grad_dist_attr, bias_grad_dist_attr}); } } // namespace distributed diff --git a/paddle/phi/infermeta/spmd_rules/utils.h b/paddle/phi/infermeta/spmd_rules/utils.h index 90d605edceba5..a9e991f6334a4 100644 --- 
a/paddle/phi/infermeta/spmd_rules/utils.h +++ b/paddle/phi/infermeta/spmd_rules/utils.h @@ -28,7 +28,7 @@ namespace phi { namespace distributed { class TensorDistAttr; -bool IsEmpty(const std::vector& shape) { +inline bool IsEmpty(const std::vector& shape) { return shape.empty() || shape.at(0) == 0; } From b24f85c19370b88cc1c9e56beb8fe5d5e54dd259 Mon Sep 17 00:00:00 2001 From: liuzhenhai93 Date: Tue, 7 Nov 2023 03:54:42 +0000 Subject: [PATCH 07/25] layer_norm_backward --- test/cpp/auto_parallel/spmd_rule_test.cc | 40 ++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/test/cpp/auto_parallel/spmd_rule_test.cc b/test/cpp/auto_parallel/spmd_rule_test.cc index eb6d08542b04a..c4c467fb92c8a 100644 --- a/test/cpp/auto_parallel/spmd_rule_test.cc +++ b/test/cpp/auto_parallel/spmd_rule_test.cc @@ -788,6 +788,46 @@ TEST(ConcatRule, Ctor) { check_dim_mapping(infered_dist_attrs.second[0], {1, -1, 0}); check_partial_dims(infered_dist_attrs.second[0], {}); } + +TEST(LayerNorm, Ctor) { + using phi::distributed::PartialStatus; + std::vector mesh_shape = {2, 2}; + std::vector process_ids = {0, 1, 2, 3}; + std::vector dim_names = {"x", "y"}; + ProcessMesh process_mesh(mesh_shape, process_ids, dim_names); + + std::vector x_shapes = {16, 32, 32}; + + auto build_input = [&](const std::vector& shape, + const std::vector& dim_mapping) { + auto t_dist_attr = TensorDistAttr(); + t_dist_attr.set_process_mesh(process_mesh); + t_dist_attr.set_dims_mapping(dim_mapping); + t_dist_attr.set_dynamic_dims({false, false, false}); + auto input = + phi::distributed::DistMetaTensor(phi::make_ddim(shape), t_dist_attr); + return input; + }; + // test 1 + auto x = build_input(x_shapes, {0, 1, -1}); + auto out_grad = build_input(x_shapes, {0, 1, -1}); + auto mean = build_input({16, 32}, {0, 1}); + auto variance = build_input({16, 32}, {0, 1}); + auto scale = build_input({32}, {0}); + auto bias = build_input({32}, {0}); + + auto spmd1 = + LayerNormGradInferSpmd(x, mean, variance, scale, bias, out_grad, 1.0, 2); + + // test 2 + mean = build_input({16}, {0}); + variance = build_input({16}, {0}); + scale = build_input({32, 32}, {0, 1}); + bias = build_input({32, 32}, {0, 1}); + auto spmd2 = + LayerNormGradInferSpmd(x, mean, variance, scale, bias, out_grad, 1.0, 1); +} + TEST(Util, Ctor) { // test equal test not equal using phi::distributed::PartialStatus; From 758149bc93c7887db518d0f1ce196d1a838409e7 Mon Sep 17 00:00:00 2001 From: liuzhenhai93 Date: Tue, 7 Nov 2023 04:13:48 +0000 Subject: [PATCH 08/25] polish --- paddle/phi/infermeta/spmd_rules/concat.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/phi/infermeta/spmd_rules/concat.cc b/paddle/phi/infermeta/spmd_rules/concat.cc index edf6f5e012815..31b69a49236eb 100644 --- a/paddle/phi/infermeta/spmd_rules/concat.cc +++ b/paddle/phi/infermeta/spmd_rules/concat.cc @@ -190,7 +190,7 @@ SpmdInfo ConcatInferSpmd(const std::vector& x, int axis) { */ std::string all_aixs; std::string align_axis; - std::tie(all_aixs, align_axis) = FillConcatNotation(axis, dim); + std::tie(all_aixs, align_axis) = FillConcatNotation(ndim, dim); std::vector axis_names(input_attrs.size(), all_aixs); AlignDimsSharding( &input_attrs, tensor_shapes, axis_names, {}, align_axis, true); From 02ed2c54c24ca0b342f9c2e6a83b999a65e27258 Mon Sep 17 00:00:00 2001 From: liuzhenhai93 Date: Tue, 7 Nov 2023 04:46:46 +0000 Subject: [PATCH 09/25] add test --- test/cpp/auto_parallel/spmd_rule_test.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/test/cpp/auto_parallel/spmd_rule_test.cc b/test/cpp/auto_parallel/spmd_rule_test.cc index c4c467fb92c8a..12e80f7059491 100644 --- a/test/cpp/auto_parallel/spmd_rule_test.cc +++ b/test/cpp/auto_parallel/spmd_rule_test.cc @@ -817,7 +817,7 @@ TEST(LayerNorm, Ctor) { auto bias = build_input({32}, {0}); auto spmd1 = - LayerNormGradInferSpmd(x, mean, variance, scale, bias, out_grad, 1.0, 2); + LayerNormGradInferSpmd(x, scale, bias, mean, variance, out_grad, 1.0, 2); // test 2 mean = build_input({16}, {0}); @@ -825,7 +825,7 @@ TEST(LayerNorm, Ctor) { scale = build_input({32, 32}, {0, 1}); bias = build_input({32, 32}, {0, 1}); auto spmd2 = - LayerNormGradInferSpmd(x, mean, variance, scale, bias, out_grad, 1.0, 1); + LayerNormGradInferSpmd(x, scale, bias, mean, variance, out_grad, 1.0, 1); } TEST(Util, Ctor) { From 85a2ecf312cbbdd14e6250feb6dd311aa4323a82 Mon Sep 17 00:00:00 2001 From: liuzhenhai93 Date: Tue, 7 Nov 2023 06:52:57 +0000 Subject: [PATCH 10/25] polish --- paddle/phi/infermeta/spmd_rules/concat.cc | 111 ------------------ paddle/phi/infermeta/spmd_rules/layer_norm.cc | 2 +- paddle/phi/infermeta/spmd_rules/utils.cc | 4 +- 3 files changed, 3 insertions(+), 114 deletions(-) diff --git a/paddle/phi/infermeta/spmd_rules/concat.cc b/paddle/phi/infermeta/spmd_rules/concat.cc index 31b69a49236eb..01ba776039483 100644 --- a/paddle/phi/infermeta/spmd_rules/concat.cc +++ b/paddle/phi/infermeta/spmd_rules/concat.cc @@ -76,118 +76,7 @@ SpmdInfo ConcatInferSpmd(const std::vector& x, int axis) { x.begin(), x.end(), std::back_inserter(input_attrs), [](auto& meta) { return meta.dist_attr(); }); - /* - - // 2、make sure all tensors replicated on concat dim - auto n_inputs = x.size(); - for (size_t i = 0; i < n_inputs; ++i) { - const auto& dist_attr = x[i].dist_attr(); - if ((!IsEmpty(tensor_shapes[i])) && IsDimSharded(dist_attr, dim)) { - auto sharded_dist_attr = ReplicateTensorDim(dist_attr, dim); - input_attrs.emplace_back(sharded_dist_attr); - } else { - input_attrs.emplace_back(dist_attr); - } - } - // 3、align non-concat dimensions according to cost - std::vector>> inputs_placements; - std::transform( - input_attrs.begin(), - input_attrs.end(), - std::back_inserter(inputs_placements), - [](const TensorDistAttr& attr) { return attr.to_placement(); }); - const auto& process_mess = input_attrs[non_empty_index].process_mesh(); - auto has_mismatch = [&](int32_t mesh_dim) { - bool mismatch = false; - for (size_t i = 0; i < n_inputs; i++) { - if ((!IsEmpty(tensor_shapes[i])) && - !PlacementEqual(inputs_placements[non_empty_index][mesh_dim], - inputs_placements[i][mesh_dim])) { - mismatch = true; - break; - } - } - return mismatch; - }; - bool need_reshard = false; - int32_t n_mesh_dim = process_mess.ndim(); - std::vector> best_placements( - n_mesh_dim, std::make_shared()); - // a dim can not be sharded twice along diffrent mesh_dim - std::set sharded_dims = {dim}; - - for (int32_t mesh_dim = 0; mesh_dim < process_mess.ndim(); ++mesh_dim) { - if (!has_mismatch(mesh_dim)) { - // use the old placement - auto& best = inputs_placements[non_empty_index][mesh_dim]; - if (best->is_shard()) { - auto shard_placement = std::dynamic_pointer_cast(best); - sharded_dims.insert(shard_placement->get_axis()); - } - best_placements[mesh_dim] = best; - } - } - - for (int32_t mesh_dim = 0; mesh_dim < process_mess.ndim(); ++mesh_dim) { - if (!has_mismatch(mesh_dim)) { - continue; - } - need_reshard = true; - std::vector costs; - for (int32_t shard_dim = 0; shard_dim < ndim; shard_dim++) { - double cost = 
std::numeric_limits::infinity(); - if (!sharded_dims.count(shard_dim)) { - cost = 0.0; - for (size_t i = 0; i < n_inputs; i++) { - auto& tensor_shape = tensor_shapes[i]; - auto& tensor_dist_attr = input_attrs[i]; - if (IsEmpty(tensor_shape)) { - continue; - } - - if (tensor_shape[shard_dim] < process_mess.dim_size(mesh_dim)) { - // should not be selected - cost += std::numeric_limits::infinity(); - continue; - } - if (IsDimSharded(tensor_dist_attr, shard_dim)) { - continue; - } - int64_t num = std::accumulate(tensor_shape.begin(), - tensor_shape.end(), - 1, - std::multiplies()); - if (num == static_cast(0)) { - continue; - } - std::vector local_shape = - GetLocalShape(tensor_shape, process_mess, inputs_placements[i]); - cost += std::accumulate(local_shape.begin(), - local_shape.end(), - 1, - std::multiplies()) * - process_mess.dim_size(mesh_dim); - } - } - costs.push_back(cost); - } - auto min_itr = std::min_element(costs.begin(), costs.end()); - auto min_dim = min_itr - costs.begin(); - if (!sharded_dims.count(min_dim)) { - best_placements[mesh_dim] = std::make_shared(min_dim); - sharded_dims.insert(min_dim); - } - } - // set placement to the best placements - if (need_reshard) { - std::vector new_input_attrs; - for (auto& e : input_attrs) { - new_input_attrs.emplace_back(FromPlacements(e, best_placements)); - } - std::swap(input_attrs, new_input_attrs); - } - */ std::string all_aixs; std::string align_axis; std::tie(all_aixs, align_axis) = FillConcatNotation(ndim, dim); diff --git a/paddle/phi/infermeta/spmd_rules/layer_norm.cc b/paddle/phi/infermeta/spmd_rules/layer_norm.cc index b4e074b14852b..89f3f96145459 100644 --- a/paddle/phi/infermeta/spmd_rules/layer_norm.cc +++ b/paddle/phi/infermeta/spmd_rules/layer_norm.cc @@ -282,7 +282,7 @@ std::tuple, std::string> BuildLayerNormGradEinsum( int64_t input_rank, int64_t begin_norm_axis) { std::string alphabet = "ijklmnopqrstuvwxyz"; std::string x_notation = alphabet.substr(0, input_rank); - std::string mean_variance_notation = x_notation.substr(begin_norm_axis); + std::string mean_variance_notation = x_notation.substr(0, begin_norm_axis); std::string align_notation = x_notation.substr(0, begin_norm_axis); return { {x_notation, mean_variance_notation, mean_variance_notation, x_notation}, diff --git a/paddle/phi/infermeta/spmd_rules/utils.cc b/paddle/phi/infermeta/spmd_rules/utils.cc index a38b4646eaf39..b87f1343e863f 100644 --- a/paddle/phi/infermeta/spmd_rules/utils.cc +++ b/paddle/phi/infermeta/spmd_rules/utils.cc @@ -266,14 +266,14 @@ void AlignDimsSharding(std::vector* input_attrs_ptr, break; } } - if (!p_a->is_shard()) { + if (!p_b->is_shard()) { mismatch = true; break; } auto a_shard = std::dynamic_pointer_cast(p_a); auto b_shard = std::dynamic_pointer_cast(p_b); auto a_axis = axis_names[non_empty_index][a_shard->get_axis()]; - auto b_axis = axis_names[non_empty_index][b_shard->get_axis()]; + auto b_axis = axis_names[i][b_shard->get_axis()]; if (a_axis != b_axis) { mismatch = true; break; From f0a47aed04ef89836b9cd84a54f8df413b0176ed Mon Sep 17 00:00:00 2001 From: liuzhenhai93 Date: Tue, 7 Nov 2023 07:28:12 +0000 Subject: [PATCH 11/25] polish --- test/cpp/auto_parallel/spmd_rule_test.cc | 27 ++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/test/cpp/auto_parallel/spmd_rule_test.cc b/test/cpp/auto_parallel/spmd_rule_test.cc index 12e80f7059491..a2cc7e04f77a8 100644 --- a/test/cpp/auto_parallel/spmd_rule_test.cc +++ b/test/cpp/auto_parallel/spmd_rule_test.cc @@ -819,6 +819,20 @@ TEST(LayerNorm, Ctor) { auto spmd1 = 
LayerNormGradInferSpmd(x, scale, bias, mean, variance, out_grad, 1.0, 2); + EXPECT_EQ(spmd1.first.size(), static_cast(5)); + EXPECT_EQ(spmd1.second.size(), static_cast(3)); + + check_dim_mapping(spmd1.first[0], {0, 1, -1}); + check_dim_mapping(spmd1.first[1], {-1}); + check_dim_mapping(spmd1.first[2], {-1}); + check_dim_mapping(spmd1.first[3], {0, 1}); + check_dim_mapping(spmd1.first[4], {0, 1}); + check_dim_mapping(spmd1.first[5], {0, 1, -1}); + check_dim_mapping(spmd1.second[0], {0, 1, -1}); + check_dim_mapping(spmd1.second[1], {-1}); + check_dim_mapping(spmd1.second[2], {-1}); + check_partial_dims(spmd1.second[1], {0, 1}); + check_partial_dims(spmd1.second[2], {0, 1}); // test 2 mean = build_input({16}, {0}); variance = build_input({16}, {0}); @@ -826,6 +840,19 @@ TEST(LayerNorm, Ctor) { bias = build_input({32, 32}, {0, 1}); auto spmd2 = LayerNormGradInferSpmd(x, scale, bias, mean, variance, out_grad, 1.0, 1); + EXPECT_EQ(spmd2.first.size(), static_cast(5)); + EXPECT_EQ(spmd2.second.size(), static_cast(3)); + check_dim_mapping(spmd2.first[0], {0, -1, -1}); + check_dim_mapping(spmd2.first[1], {-1, -1}); + check_dim_mapping(spmd2.first[2], {-1, -1}); + check_dim_mapping(spmd2.first[3], {0}); + check_dim_mapping(spmd2.first[4], {0}); + check_dim_mapping(spmd2.first[5], {0, -1, -1}); + check_dim_mapping(spmd1.second[0], {0, -1, -1}); + check_dim_mapping(spmd1.second[1], {-1, -1}); + check_dim_mapping(spmd1.second[2], {-1, -1}); + check_partial_dims(spmd1.second[1], {0}); + check_partial_dims(spmd1.second[2], {0}); } TEST(Util, Ctor) { From 56bc5e9017c707d4c70c7dad463105f287402ae3 Mon Sep 17 00:00:00 2001 From: liuzhenhai93 Date: Tue, 7 Nov 2023 08:41:15 +0000 Subject: [PATCH 12/25] polish --- paddle/phi/api/yaml/backward.yaml | 1 + paddle/phi/api/yaml/ops.yaml | 1 + paddle/phi/infermeta/spmd_rules/layer_norm.cc | 25 ------- paddle/phi/infermeta/spmd_rules/utils.cc | 1 + .../semi_auto_parallel_for_layernorm.py | 73 ------------------- test/cpp/auto_parallel/spmd_rule_test.cc | 14 ++-- 6 files changed, 10 insertions(+), 105 deletions(-) delete mode 100644 test/auto_parallel/semi_auto_parallel_for_layernorm.py diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 157d34e28aaca..391151abe85b5 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -1251,6 +1251,7 @@ infer_meta : func : LayerNormGradInferMeta param : [x, scale, bias] + spmd_rule : LayerNormGradInferSpmd kernel : func : layer_norm_grad data_type : x diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index f22f5be8ec028..ef1a56e493cab 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -1410,6 +1410,7 @@ output : Tensor(out), Tensor(mean), Tensor(variance) infer_meta : func : LayerNormInferMeta + spmd_rule : LayerNormInferSpmd kernel : func : layer_norm data_type : x diff --git a/paddle/phi/infermeta/spmd_rules/layer_norm.cc b/paddle/phi/infermeta/spmd_rules/layer_norm.cc index 89f3f96145459..ff9bc9bfe12ab 100644 --- a/paddle/phi/infermeta/spmd_rules/layer_norm.cc +++ b/paddle/phi/infermeta/spmd_rules/layer_norm.cc @@ -338,31 +338,6 @@ SpmdInfo LayerNormGradInferSpmd(const DistMetaTensor& x, mean_shape.size(), variance_shape.size())); - PADDLE_ENFORCE_EQ( - scale_shape.size() + begin_norm_axis, - x_shape.size(), - phi::errors::InvalidArgument("The Tensor scale's rank [%d] and Tensor " - "x's rank [%d] are not matched.", - scale_shape.size(), - x_shape.size())); - - if (begin_norm_axis > 0) { - PADDLE_ENFORCE_EQ( - 
scale_shape.size() + mean_shape.size(), - x_shape.size(), - phi::errors::InvalidArgument("The Tensor scale's rank [%d] and Tensor " - "x's rank [%d] are not matched.", - scale_shape.size(), - x_shape.size())); - } else { - PADDLE_ENFORCE_EQ( - scale_shape.size(), - x_shape.size(), - phi::errors::InvalidArgument("The Tensor scale's rank [%d] and Tensor " - "x's rank [%d] are not matched.", - scale_shape.size(), - x_shape.size())); - } // 2、align sharding TensorDistAttr x_dist_attr; TensorDistAttr mean_dist_attr; diff --git a/paddle/phi/infermeta/spmd_rules/utils.cc b/paddle/phi/infermeta/spmd_rules/utils.cc index b87f1343e863f..cd1f5740766a9 100644 --- a/paddle/phi/infermeta/spmd_rules/utils.cc +++ b/paddle/phi/infermeta/spmd_rules/utils.cc @@ -374,6 +374,7 @@ void AlignDimsSharding(std::vector* input_attrs_ptr, } sharded_axis.insert(cost_axis.second); mesh_dim_to_axis[mesh_dim] = cost_axis.second; + break; } } std::vector new_input_attrs; diff --git a/test/auto_parallel/semi_auto_parallel_for_layernorm.py b/test/auto_parallel/semi_auto_parallel_for_layernorm.py deleted file mode 100644 index 29198ae74c4eb..0000000000000 --- a/test/auto_parallel/semi_auto_parallel_for_layernorm.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
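For reference while the layer_norm Python test is being replaced below: BuildLayerNormGradEinsum above aligns only the axes in front of begin_norm_axis, so x and out_grad share those leading axes with mean and variance while the normalized axes, scale, and bias stay replicated. A Python mirror of the notation it builds (illustration only, not part of the patch), consistent with the C++ test above expecting mean and variance mapped to {0, 1} when begin_norm_axis is 2:

def build_layer_norm_grad_einsum(input_rank, begin_norm_axis):
    # mirrors BuildLayerNormGradEinsum in layer_norm.cc: notations for
    # x, mean, variance and out_grad, plus the axes that must be aligned
    alphabet = "ijklmnopqrstuvwxyz"
    x_notation = alphabet[:input_rank]
    mean_variance_notation = x_notation[:begin_norm_axis]
    align_notation = x_notation[:begin_norm_axis]
    return (
        [x_notation, mean_variance_notation, mean_variance_notation, x_notation],
        align_notation,
    )


assert build_layer_norm_grad_einsum(3, 2) == (["ijk", "ij", "ij", "ijk"], "ij")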
- -from semi_auto_parallel_util import SemiAutoParallelTestBase - -import paddle -import paddle.distributed as dist - - -class TestSplitAndConcatSemiAutoParallel(SemiAutoParallelTestBase): - def __init__(self): - super().__init__() - - def check_dim_mapping(self, inputs, output, expected_dim_mapping): - for t in inputs: - assert ( - t.dist_attr.dim_mapping == expected_dim_mapping - ), f"{t.dist_attr.dim_mapping} vs {expected_dim_mapping}" - assert ( - output.dist_attr.dim_mapping == expected_dim_mapping - ), f"{output.dist_attr.dim_mapping} vs {expected_dim_mapping}" - - def test_concat_forward(self): - shapes = [[16, 4, 4], [64, 4, 4]] - specs = [[None, None, 'x'], [None, None, 'x']] - inputs, outputs = self.runfunc_and_check( - inputs_shape=shapes, - inputs_specs=specs, - op_func=paddle.concat, - with_backward=False, - axis=0, - ) - self.check_dim_mapping(inputs, outputs, [-1, -1, 0]) - - def test_concat_forward_reshard(self): - shapes = [[16, 4, 4], [64, 4, 4]] - specs = [['x', None, None], [None, None, 'x']] - inputs, outputs = self.runfunc_and_check( - inputs_shape=shapes, - inputs_specs=specs, - op_func=paddle.concat, - with_backward=False, - axis=0, - ) - self.check_dim_mapping(inputs, outputs, [-1, -1, 0]) - - def run_test_case(self): - if self._backend == "cpu": - paddle.set_device("cpu") - elif self._backend == "gpu": - paddle.set_device("gpu:" + str(dist.get_rank())) - else: - raise ValueError("Only support cpu or gpu backend.") - - self.test_concat_forward() - # all to all is not supported yet for cpu - if self._backend == "gpu": - self.test_concat_forward_reshard() - - -if __name__ == '__main__': - TestSplitAndConcatSemiAutoParallel().run_test_case()
diff --git a/test/cpp/auto_parallel/spmd_rule_test.cc b/test/cpp/auto_parallel/spmd_rule_test.cc index a2cc7e04f77a8..bc64aa68faf25 100644 --- a/test/cpp/auto_parallel/spmd_rule_test.cc +++ b/test/cpp/auto_parallel/spmd_rule_test.cc @@ -819,7 +819,7 @@ TEST(LayerNorm, Ctor) { auto spmd1 = LayerNormGradInferSpmd(x, scale, bias, mean, variance, out_grad, 1.0, 2); - EXPECT_EQ(spmd1.first.size(), static_cast<size_t>(5)); + EXPECT_EQ(spmd1.first.size(), static_cast<size_t>(6)); EXPECT_EQ(spmd1.second.size(), static_cast<size_t>(3)); check_dim_mapping(spmd1.first[0], {0, 1, -1}); @@ -840,7 +840,7 @@ TEST(LayerNorm, Ctor) { bias = build_input({32, 32}, {0, 1}); auto spmd2 = LayerNormGradInferSpmd(x, scale, bias, mean, variance, out_grad, 1.0, 1); - EXPECT_EQ(spmd2.first.size(), static_cast<size_t>(5)); + EXPECT_EQ(spmd2.first.size(), static_cast<size_t>(6)); EXPECT_EQ(spmd2.second.size(), static_cast<size_t>(3)); check_dim_mapping(spmd2.first[0], {0, -1, -1}); check_dim_mapping(spmd2.first[1], {-1, -1}); @@ -848,11 +848,11 @@ TEST(LayerNorm, Ctor) { check_dim_mapping(spmd2.first[3], {0}); check_dim_mapping(spmd2.first[4], {0}); check_dim_mapping(spmd2.first[5], {0, -1, -1}); - check_dim_mapping(spmd1.second[0], {0, -1, -1}); - check_dim_mapping(spmd1.second[1], {-1, -1}); - check_dim_mapping(spmd1.second[2], {-1, -1}); - check_partial_dims(spmd1.second[1], {0}); - check_partial_dims(spmd1.second[2], {0}); + check_dim_mapping(spmd2.second[0], {0, -1, -1}); + check_dim_mapping(spmd2.second[1], {-1, -1}); + check_dim_mapping(spmd2.second[2], {-1, -1}); + check_partial_dims(spmd2.second[1], {0}); + check_partial_dims(spmd2.second[2], {0}); } TEST(Util, Ctor) {
From d0fdeecfbfcde672249de4b07d49a7a5439a6f91 Mon Sep 17 00:00:00 2001 From: liuzhenhai93 Date: Tue, 7 Nov 2023 08:59:30 +0000 Subject: [PATCH 13/25] add test --- .../semi_auto_parallel_for_layernorm.py | 67 +++++++++++++++++++
.../test_semi_auto_parallel_basic.py | 10 +++ 2 files changed, 77 insertions(+) create mode 100644 test/auto_parallel/semi_auto_parallel_for_layernorm.py diff --git a/test/auto_parallel/semi_auto_parallel_for_layernorm.py b/test/auto_parallel/semi_auto_parallel_for_layernorm.py new file mode 100644 index 0000000000000..0e37a1a434396 --- /dev/null +++ b/test/auto_parallel/semi_auto_parallel_for_layernorm.py @@ -0,0 +1,67 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from semi_auto_parallel_util import SemiAutoParallelTestBase + +import paddle +import paddle.distributed as dist + + +def layer_norm(x, weight, bias, normalized_shape=None): + return paddle.nn.functional.layer_norm( + x, normalized_shape, weight=weight, bias=bias + ) + + +class TestLayerNormSemiAutoParallel(SemiAutoParallelTestBase): + def __init__(self): + super().__init__() + + def test_layernorm_forward(self): + shapes = ([16, 4, 4], [16], [16]) + specs = (["x", None, None], [None], [None]) + inputs, outputs = self.runfunc_and_check( + inputs_shape=shapes, + inputs_specs=specs, + op_func=layer_norm, + with_backward=True, + normalized_shape=[4, 4], + ) + + def test_layernorm_reshard(self): + shapes = ([16, 4, 4], [16], [16]) + specs = ([None, None, 'x'], [None], [None]) + inputs, outputs = self.runfunc_and_check( + inputs_shape=shapes, + inputs_specs=specs, + op_func=layer_norm, + with_backward=True, + normalized_shape=[4, 4], + ) + + def run_test_case(self): + if self._backend == "cpu": + paddle.set_device("cpu") + elif self._backend == "gpu": + paddle.set_device("gpu:" + str(dist.get_rank())) + else: + raise ValueError("Only support cpu or gpu backend.") + + self.test_layernorm_forward() + # all to all is not supported yet for cpu + self.test_layernorm_reshard() + + +if __name__ == '__main__': + TestLayerNormSemiAutoParallel().run_test_case() diff --git a/test/auto_parallel/test_semi_auto_parallel_basic.py b/test/auto_parallel/test_semi_auto_parallel_basic.py index 2589566cb670e..df722cf7dfe74 100644 --- a/test/auto_parallel/test_semi_auto_parallel_basic.py +++ b/test/auto_parallel/test_semi_auto_parallel_basic.py @@ -56,6 +56,16 @@ def test_concat_api(self): user_defined_envs=envs, ) + def test_layernorm_api(self): + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + self.run_test_case( + "semi_auto_parallel_for_layernorm.py", + user_defined_envs=envs, + ) + def test_reduction_api(self): envs_list = test_base.gen_product_envs_list( self._default_envs, self._changeable_envs From c634811c260f9bc534e4141895ed7c7b1000b397 Mon Sep 17 00:00:00 2001 From: liuzhenhai93 Date: Tue, 7 Nov 2023 09:38:52 +0000 Subject: [PATCH 14/25] polish --- paddle/phi/api/yaml/ops.yaml | 1 - test/auto_parallel/semi_auto_parallel_for_layernorm.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index ef1a56e493cab..f22f5be8ec028 100644 
--- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -1410,7 +1410,6 @@ output : Tensor(out), Tensor(mean), Tensor(variance) infer_meta : func : LayerNormInferMeta - spmd_rule : LayerNormInferSpmd kernel : func : layer_norm data_type : x
diff --git a/test/auto_parallel/semi_auto_parallel_for_layernorm.py b/test/auto_parallel/semi_auto_parallel_for_layernorm.py index 0e37a1a434396..71c96175e02fa 100644 --- a/test/auto_parallel/semi_auto_parallel_for_layernorm.py +++ b/test/auto_parallel/semi_auto_parallel_for_layernorm.py @@ -35,7 +35,7 @@ def test_layernorm_forward(self): inputs_shape=shapes, inputs_specs=specs, op_func=layer_norm, - with_backward=True, + with_backward=False, normalized_shape=[4, 4], ) @@ -46,7 +46,7 @@ def test_layernorm_reshard(self): inputs_shape=shapes, inputs_specs=specs, op_func=layer_norm, - with_backward=True, + with_backward=False, normalized_shape=[4, 4], )
From dc0fd8f133bacd51c310113473c56ffa75cbc2b8 Mon Sep 17 00:00:00 2001 From: liuzhenhai93 Date: Tue, 7 Nov 2023 17:41:54 +0800 Subject: [PATCH 15/25] polish --- paddle/phi/api/yaml/backward.yaml | 1 - paddle/phi/api/yaml/ops.yaml | 1 + 2 files changed, 1 insertion(+), 1 deletion(-)
diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 391151abe85b5..157d34e28aaca 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -1251,7 +1251,6 @@ infer_meta : func : LayerNormGradInferMeta param : [x, scale, bias] - spmd_rule : LayerNormGradInferSpmd kernel : func : layer_norm_grad data_type : x
diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index f22f5be8ec028..ef1a56e493cab 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -1410,6 +1410,7 @@ output : Tensor(out), Tensor(mean), Tensor(variance) infer_meta : func : LayerNormInferMeta + spmd_rule : LayerNormInferSpmd kernel : func : layer_norm data_type : x
From 62b87f47f7c257da8c03f9057a3e1ed1eb185df2 Mon Sep 17 00:00:00 2001 From: liuzhenhai93 Date: Tue, 7 Nov 2023 11:12:16 +0000 Subject: [PATCH 16/25] format --- paddle/phi/api/yaml/generator/dist_api_gen.py | 8 ++++++++ 1 file changed, 8 insertions(+)
diff --git a/paddle/phi/api/yaml/generator/dist_api_gen.py b/paddle/phi/api/yaml/generator/dist_api_gen.py index f20dd50e61099..e1f50192b09e2 100644 --- a/paddle/phi/api/yaml/generator/dist_api_gen.py +++ b/paddle/phi/api/yaml/generator/dist_api_gen.py @@ -721,6 +721,14 @@ def generate_specialized_infer_spmd_code(self) -> str: name=param ) input_args_code += "meta_dist_input_" + param + ", " + elif ( + self.inputs['input_info'][param] + == "const paddle::optional<Tensor>&" + ): + input_decl_code += ( + OPTIONAL_SINGLE_DIST_META_IN_TEMPLATE.format(name=param) + ) + input_args_code += "meta_dist_input_" + param + ", " else: raise ValueError(
From 29b86b77745bb77f08cb523016bf9b5459bc8c39 Mon Sep 17 00:00:00 2001 From: liuzhenhai93 Date: Tue, 7 Nov 2023 12:38:45 +0000 Subject: [PATCH 17/25] code gen not supported yet --- paddle/phi/api/yaml/ops.yaml | 1 - .../semi_auto_parallel_for_layernorm.py | 67 ------------------- .../test_semi_auto_parallel_basic.py | 10 --- 3 files changed, 78 deletions(-) delete mode 100644 test/auto_parallel/semi_auto_parallel_for_layernorm.py
diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index ef1a56e493cab..f22f5be8ec028 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -1410,7 +1410,6 @@ output : Tensor(out), Tensor(mean), Tensor(variance) infer_meta : func
: LayerNormInferMeta - spmd_rule : LayerNormInferSpmd kernel : func : layer_norm data_type : x diff --git a/test/auto_parallel/semi_auto_parallel_for_layernorm.py b/test/auto_parallel/semi_auto_parallel_for_layernorm.py deleted file mode 100644 index 71c96175e02fa..0000000000000 --- a/test/auto_parallel/semi_auto_parallel_for_layernorm.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from semi_auto_parallel_util import SemiAutoParallelTestBase - -import paddle -import paddle.distributed as dist - - -def layer_norm(x, weight, bias, normalized_shape=None): - return paddle.nn.functional.layer_norm( - x, normalized_shape, weight=weight, bias=bias - ) - - -class TestLayerNormSemiAutoParallel(SemiAutoParallelTestBase): - def __init__(self): - super().__init__() - - def test_layernorm_forward(self): - shapes = ([16, 4, 4], [16], [16]) - specs = (["x", None, None], [None], [None]) - inputs, outputs = self.runfunc_and_check( - inputs_shape=shapes, - inputs_specs=specs, - op_func=layer_norm, - with_backward=False, - normalized_shape=[4, 4], - ) - - def test_layernorm_reshard(self): - shapes = ([16, 4, 4], [16], [16]) - specs = ([None, None, 'x'], [None], [None]) - inputs, outputs = self.runfunc_and_check( - inputs_shape=shapes, - inputs_specs=specs, - op_func=layer_norm, - with_backward=False, - normalized_shape=[4, 4], - ) - - def run_test_case(self): - if self._backend == "cpu": - paddle.set_device("cpu") - elif self._backend == "gpu": - paddle.set_device("gpu:" + str(dist.get_rank())) - else: - raise ValueError("Only support cpu or gpu backend.") - - self.test_layernorm_forward() - # all to all is not supported yet for cpu - self.test_layernorm_reshard() - - -if __name__ == '__main__': - TestLayerNormSemiAutoParallel().run_test_case() diff --git a/test/auto_parallel/test_semi_auto_parallel_basic.py b/test/auto_parallel/test_semi_auto_parallel_basic.py index df722cf7dfe74..2589566cb670e 100644 --- a/test/auto_parallel/test_semi_auto_parallel_basic.py +++ b/test/auto_parallel/test_semi_auto_parallel_basic.py @@ -56,16 +56,6 @@ def test_concat_api(self): user_defined_envs=envs, ) - def test_layernorm_api(self): - envs_list = test_base.gen_product_envs_list( - self._default_envs, self._changeable_envs - ) - for envs in envs_list: - self.run_test_case( - "semi_auto_parallel_for_layernorm.py", - user_defined_envs=envs, - ) - def test_reduction_api(self): envs_list = test_base.gen_product_envs_list( self._default_envs, self._changeable_envs From 66ec1be0ab265cf291eea26d1fc6daa224424e88 Mon Sep 17 00:00:00 2001 From: liuzhenhai93 Date: Mon, 13 Nov 2023 06:43:21 +0000 Subject: [PATCH 18/25] polish --- paddle/phi/api/yaml/backward.yaml | 1 + paddle/phi/api/yaml/ops.yaml | 1 + .../semi_auto_parallel_for_concat.py | 11 --- .../semi_auto_parallel_for_layernorm.py | 75 +++++++++++++++++++ .../test_semi_auto_parallel_basic.py | 14 ++++ 5 files changed, 91 insertions(+), 11 deletions(-) create mode 100644 
test/auto_parallel/semi_auto_parallel_for_layernorm.py diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 157d34e28aaca..fb450795beadf 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -1250,6 +1250,7 @@ output : Tensor(x_grad), Tensor(scale_grad), Tensor(bias_grad) infer_meta : func : LayerNormGradInferMeta + spmd_rule : LayerNormGradInferSpmd param : [x, scale, bias] kernel : func : layer_norm_grad diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index dc8477c06c8e5..199998e1135f5 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -1410,6 +1410,7 @@ output : Tensor(out), Tensor(mean), Tensor(variance) infer_meta : func : LayerNormInferMeta + spmd_rule : LayerNormInferSpmd kernel : func : layer_norm data_type : x diff --git a/test/auto_parallel/semi_auto_parallel_for_concat.py b/test/auto_parallel/semi_auto_parallel_for_concat.py index 29198ae74c4eb..24605825d5f15 100644 --- a/test/auto_parallel/semi_auto_parallel_for_concat.py +++ b/test/auto_parallel/semi_auto_parallel_for_concat.py @@ -22,15 +22,6 @@ class TestSplitAndConcatSemiAutoParallel(SemiAutoParallelTestBase): def __init__(self): super().__init__() - def check_dim_mapping(self, inputs, output, expected_dim_mapping): - for t in inputs: - assert ( - t.dist_attr.dim_mapping == expected_dim_mapping - ), f"{t.dist_attr.dim_mapping} vs {expected_dim_mapping}" - assert ( - output.dist_attr.dim_mapping == expected_dim_mapping - ), f"{output.dist_attr.dim_mapping} vs {expected_dim_mapping}" - def test_concat_forward(self): shapes = [[16, 4, 4], [64, 4, 4]] specs = [[None, None, 'x'], [None, None, 'x']] @@ -41,7 +32,6 @@ def test_concat_forward(self): with_backward=False, axis=0, ) - self.check_dim_mapping(inputs, outputs, [-1, -1, 0]) def test_concat_forward_reshard(self): shapes = [[16, 4, 4], [64, 4, 4]] @@ -53,7 +43,6 @@ def test_concat_forward_reshard(self): with_backward=False, axis=0, ) - self.check_dim_mapping(inputs, outputs, [-1, -1, 0]) def run_test_case(self): if self._backend == "cpu": diff --git a/test/auto_parallel/semi_auto_parallel_for_layernorm.py b/test/auto_parallel/semi_auto_parallel_for_layernorm.py new file mode 100644 index 0000000000000..ed04612aa2017 --- /dev/null +++ b/test/auto_parallel/semi_auto_parallel_for_layernorm.py @@ -0,0 +1,75 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from semi_auto_parallel_util import SemiAutoParallelTestBase + +import paddle +import paddle.distributed as dist + + +def layer_norm(input, weights, bias, normalized_shape): + return paddle.nn.functional.layer_norm( + input, normalized_shape, weight=weights, bias=bias + ) + + +class TestLayerNormSemiAutoParallel(SemiAutoParallelTestBase): + def __init__(self): + super().__init__() + + def check_dim_mapping(self, output, expected_dim_mapping): + assert ( + output.dist_attr.dims_mapping == expected_dim_mapping + ), f"{output.dist_attr.dims_mapping} vs {expected_dim_mapping}" + + def test_layernorm_forward(self): + shapes = [[16, 4, 4], [16]] + specs = [['x', None, None], [None]] + inputs, outputs = self.runfunc_and_check( + inputs_shape=shapes, + inputs_specs=specs, + op_func=layer_norm, + with_backward=True, + normalized_shape=[16], + ) + self.check_dim_mapping(outputs, [-1, -1, 0]) + + def test_layernorm_reshard(self): + shapes = [[16, 4, 4], [16]] + specs = [[None, None, 'x'], [None]] + inputs, outputs = self.runfunc_and_check( + inputs_shape=shapes, + inputs_specs=specs, + op_func=layer_norm, + with_backward=True, + normalized_shape=[16], + ) + self.check_dim_mapping(inputs, outputs, [-1, -1, 0]) + + def run_test_case(self): + if self._backend == "cpu": + paddle.set_device("cpu") + elif self._backend == "gpu": + paddle.set_device("gpu:" + str(dist.get_rank())) + else: + raise ValueError("Only support cpu or gpu backend.") + + self.test_layernorm_forward() + # all to all is not supported yet for cpu + if self._backend == "gpu": + self.test_layernorm_forward_reshard() + + +if __name__ == '__main__': + TestLayerNormSemiAutoParallel().run_test_case() diff --git a/test/auto_parallel/test_semi_auto_parallel_basic.py b/test/auto_parallel/test_semi_auto_parallel_basic.py index 2589566cb670e..474842ec4f119 100644 --- a/test/auto_parallel/test_semi_auto_parallel_basic.py +++ b/test/auto_parallel/test_semi_auto_parallel_basic.py @@ -26,6 +26,7 @@ def setUp(self): self._default_envs = {"dtype": "float32", "seed": "2023"} self._changeable_envs = {"backend": ["cpu", "gpu"]} + """ def test_matmul_api(self): envs_list = test_base.gen_product_envs_list( self._default_envs, self._changeable_envs @@ -55,7 +56,19 @@ def test_concat_api(self): "semi_auto_parallel_for_concat.py", user_defined_envs=envs, ) + """ + def test_layernorm_api(self): + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + self.run_test_case( + "semi_auto_parallel_for_layernorm.py", + user_defined_envs=envs, + ) + + """ def test_reduction_api(self): envs_list = test_base.gen_product_envs_list( self._default_envs, self._changeable_envs @@ -105,6 +118,7 @@ def test_custom_relu_api(self): "semi_auto_parallel_for_custom_relu.py", user_defined_envs=envs, ) + """ if __name__ == "__main__": From 986d72589b591da5bebbe837da9bb0db24114b0c Mon Sep 17 00:00:00 2001 From: liuzhenhai93 Date: Mon, 13 Nov 2023 07:27:12 +0000 Subject: [PATCH 19/25] polish --- test/auto_parallel/semi_auto_parallel_for_layernorm.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/auto_parallel/semi_auto_parallel_for_layernorm.py b/test/auto_parallel/semi_auto_parallel_for_layernorm.py index ed04612aa2017..8242c95cdf04b 100644 --- a/test/auto_parallel/semi_auto_parallel_for_layernorm.py +++ b/test/auto_parallel/semi_auto_parallel_for_layernorm.py @@ -34,8 +34,8 @@ def check_dim_mapping(self, output, expected_dim_mapping): ), f"{output.dist_attr.dims_mapping} vs 
{expected_dim_mapping}" def test_layernorm_forward(self): - shapes = [[16, 4, 4], [16]] - specs = [['x', None, None], [None]] + shapes = ([16, 4, 4], [16], [16]) + specs = (['x', None, None], [None], [None]) inputs, outputs = self.runfunc_and_check( inputs_shape=shapes, inputs_specs=specs, @@ -46,8 +46,8 @@ def test_layernorm_forward(self): self.check_dim_mapping(outputs, [-1, -1, 0]) def test_layernorm_reshard(self): - shapes = [[16, 4, 4], [16]] - specs = [[None, None, 'x'], [None]] + shapes = ([16, 4, 4], [16], [16]) + specs = ([None, None, 'x'], [None], [None]) inputs, outputs = self.runfunc_and_check( inputs_shape=shapes, inputs_specs=specs, From c364e85b4186ac2bb92cdb2f036e988f429f08e2 Mon Sep 17 00:00:00 2001 From: liuzhenhai93 Date: Mon, 13 Nov 2023 07:36:54 +0000 Subject: [PATCH 20/25] polish --- test/auto_parallel/semi_auto_parallel_for_layernorm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/auto_parallel/semi_auto_parallel_for_layernorm.py b/test/auto_parallel/semi_auto_parallel_for_layernorm.py index 8242c95cdf04b..bfc8c1f227994 100644 --- a/test/auto_parallel/semi_auto_parallel_for_layernorm.py +++ b/test/auto_parallel/semi_auto_parallel_for_layernorm.py @@ -41,7 +41,7 @@ def test_layernorm_forward(self): inputs_specs=specs, op_func=layer_norm, with_backward=True, - normalized_shape=[16], + normalized_shape=[4, 4], ) self.check_dim_mapping(outputs, [-1, -1, 0]) @@ -53,7 +53,7 @@ def test_layernorm_reshard(self): inputs_specs=specs, op_func=layer_norm, with_backward=True, - normalized_shape=[16], + normalized_shape=[4, 4], ) self.check_dim_mapping(inputs, outputs, [-1, -1, 0]) From d76e7e0fe754cbc25a5b92128a0aa34c7c5bee5d Mon Sep 17 00:00:00 2001 From: liuzhenhai93 Date: Mon, 13 Nov 2023 07:50:11 +0000 Subject: [PATCH 21/25] add test --- .../semi_auto_parallel_for_layernorm.py | 12 +++++++++--- test/auto_parallel/test_semi_auto_parallel_basic.py | 4 ---- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/test/auto_parallel/semi_auto_parallel_for_layernorm.py b/test/auto_parallel/semi_auto_parallel_for_layernorm.py index bfc8c1f227994..047cd6cbb79db 100644 --- a/test/auto_parallel/semi_auto_parallel_for_layernorm.py +++ b/test/auto_parallel/semi_auto_parallel_for_layernorm.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import numpy as np from semi_auto_parallel_util import SemiAutoParallelTestBase import paddle @@ -28,6 +29,11 @@ class TestLayerNormSemiAutoParallel(SemiAutoParallelTestBase): def __init__(self): super().__init__() + def check_tensor_eq(self, a, b): + np1 = a.numpy() + np2 = b.numpy() + np.testing.assert_allclose(np1, np2, rtol=1e-04, verbose=True) + def check_dim_mapping(self, output, expected_dim_mapping): assert ( output.dist_attr.dims_mapping == expected_dim_mapping @@ -43,7 +49,7 @@ def test_layernorm_forward(self): with_backward=True, normalized_shape=[4, 4], ) - self.check_dim_mapping(outputs, [-1, -1, 0]) + self.check_dim_mapping(outputs, [0, -1, -1]) def test_layernorm_reshard(self): shapes = ([16, 4, 4], [16], [16]) @@ -55,7 +61,7 @@ def test_layernorm_reshard(self): with_backward=True, normalized_shape=[4, 4], ) - self.check_dim_mapping(inputs, outputs, [-1, -1, 0]) + self.check_dim_mapping(outputs, [-1, -1, -1]) def run_test_case(self): @@ -68,7 +74,7 @@ def run_test_case(self): self.test_layernorm_forward() # all to all is not supported yet for cpu if self._backend == "gpu": - self.test_layernorm_forward_reshard() + self.test_layernorm_reshard() if __name__ == '__main__':
diff --git a/test/auto_parallel/test_semi_auto_parallel_basic.py b/test/auto_parallel/test_semi_auto_parallel_basic.py index 474842ec4f119..df722cf7dfe74 100644 --- a/test/auto_parallel/test_semi_auto_parallel_basic.py +++ b/test/auto_parallel/test_semi_auto_parallel_basic.py @@ -26,7 +26,6 @@ def setUp(self): self._default_envs = {"dtype": "float32", "seed": "2023"} self._changeable_envs = {"backend": ["cpu", "gpu"]} - """ def test_matmul_api(self): envs_list = test_base.gen_product_envs_list( self._default_envs, self._changeable_envs @@ -55,7 +56,6 @@ def test_concat_api(self): "semi_auto_parallel_for_concat.py", user_defined_envs=envs, ) - """ def test_layernorm_api(self): envs_list = test_base.gen_product_envs_list( self._default_envs, self._changeable_envs @@ -66,7 +66,6 @@ def test_layernorm_api(self): user_defined_envs=envs, ) - """ def test_reduction_api(self): envs_list = test_base.gen_product_envs_list( self._default_envs, self._changeable_envs @@ -115,7 +118,6 @@ def test_custom_relu_api(self): "semi_auto_parallel_for_custom_relu.py", user_defined_envs=envs, ) - """ if __name__ == "__main__":
From de22227aa64369fc2ccfed1e662589e24799f11c Mon Sep 17 00:00:00 2001 From: liuzhenhai93 Date: Tue, 14 Nov 2023 06:29:54 +0000 Subject: [PATCH 22/25] polish --- paddle/phi/infermeta/spmd_rules/concat.cc | 5 ----- paddle/phi/infermeta/spmd_rules/utils.cc | 19 +++++++++++++++---- 2 files changed, 15 insertions(+), 9 deletions(-)
diff --git a/paddle/phi/infermeta/spmd_rules/concat.cc b/paddle/phi/infermeta/spmd_rules/concat.cc index 01ba776039483..29f90e9b58530 100644 --- a/paddle/phi/infermeta/spmd_rules/concat.cc +++ b/paddle/phi/infermeta/spmd_rules/concat.cc @@ -27,12 +27,7 @@ using phi::distributed::auto_parallel::str_join; std::tuple<std::string, std::string> FillConcatNotation(int64_t n_axis, int64_t concat_axis) { - PADDLE_ENFORCE_EQ( - n_axis > concat_axis, true, phi::errors::InvalidArgument("")); static const std::string alphabet = "abcdefghijlopqrstuvwxyz"; - PADDLE_ENFORCE_EQ(alphabet.size() > static_cast<size_t>(n_axis), - true, - phi::errors::InvalidArgument("")); std::string all_axis = alphabet.substr(0, n_axis); std::string align_axis = std::string(all_axis.begin(), all_axis.begin() + concat_axis) +
diff --git a/paddle/phi/infermeta/spmd_rules/utils.cc b/paddle/phi/infermeta/spmd_rules/utils.cc index 4997d416b6436..9ca39e481738c 100644 ---
a/paddle/phi/infermeta/spmd_rules/utils.cc +++ b/paddle/phi/infermeta/spmd_rules/utils.cc @@ -210,12 +210,23 @@ void AlignDimsSharding(std::vector<TensorDistAttr>* input_attrs_ptr, auto& input_attrs = *input_attrs_ptr; size_t n_inputs = input_attrs.size(); PADDLE_ENFORCE_EQ( - n_inputs, tensor_shapes.size(), phi::errors::InvalidArgument("")); - PADDLE_ENFORCE_EQ( - n_inputs, axis_names.size(), phi::errors::InvalidArgument("")); + n_inputs, + tensor_shapes.size(), + phi::errors::InvalidArgument( + "n_inputs [%d] and tensor_shapes.size() [%d] not match", + n_inputs, + tensor_shapes.size())); + PADDLE_ENFORCE_EQ(n_inputs, + axis_names.size(), + phi::errors::InvalidArgument( + "n_inputs [%d] and axis_names.size() [%d] not match", + n_inputs, + axis_names.size())); PADDLE_ENFORCE_EQ( - !align_axis.empty(), true, phi::errors::InvalidArgument("")); + !align_axis.empty(), + true, + phi::errors::InvalidArgument("align_axis should be not empty")); std::map, int64_t> axis_name_to_dim;
From f02351e830c19d7847e5961946e3c1a407f69dba Mon Sep 17 00:00:00 2001 From: liuzhenhai93 Date: Tue, 14 Nov 2023 06:42:47 +0000 Subject: [PATCH 23/25] polish --- paddle/phi/infermeta/spmd_rules/utils.cc | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-)
diff --git a/paddle/phi/infermeta/spmd_rules/utils.cc b/paddle/phi/infermeta/spmd_rules/utils.cc index 9ca39e481738c..4e2afefa79372 100644 --- a/paddle/phi/infermeta/spmd_rules/utils.cc +++ b/paddle/phi/infermeta/spmd_rules/utils.cc @@ -223,11 +223,6 @@ void AlignDimsSharding(std::vector<TensorDistAttr>* input_attrs_ptr, n_inputs, axis_names.size())); - PADDLE_ENFORCE_EQ( - !align_axis.empty(), - true, - phi::errors::InvalidArgument("align_axis should be not empty")); - std::map, int64_t> axis_name_to_dim; for (size_t i = 0; i < n_inputs; i++) { @@ -235,7 +230,7 @@ void AlignDimsSharding(std::vector<TensorDistAttr>* input_attrs_ptr, for (char axi : align_axis) { if (axis_names[i].find(axi) == std::string::npos) { PADDLE_THROW(phi::errors::PreconditionNotMet( - "[%s] some axis not in input [%d],[%s]", + "[%s] some axis not in input [%d],[%s]", align_axis, i, axis_names[i]));
From cf2ebd6cc6a42b47179a8308cbb7c10870ed1d36 Mon Sep 17 00:00:00 2001 From: liuzhenhai93 Date: Tue, 14 Nov 2023 06:56:25 +0000 Subject: [PATCH 24/25] polish --- paddle/phi/infermeta/spmd_rules/utils.cc | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-)
diff --git a/paddle/phi/infermeta/spmd_rules/utils.cc b/paddle/phi/infermeta/spmd_rules/utils.cc index 4e2afefa79372..c857af80083a7 100644 --- a/paddle/phi/infermeta/spmd_rules/utils.cc +++ b/paddle/phi/infermeta/spmd_rules/utils.cc @@ -38,7 +38,7 @@ std::string GetBroadcastAxes(const int64_t& tenosr_ndim, PADDLE_ENFORCE_GE(broadcast_ndim, tenosr_ndim, phi::errors::InvalidArgument( - "The broadcast ndim [%d] is less than tenosr ndim [%d]", + "The broadcast ndim [%d] is less than tensor ndim [%d]", broadcast_ndim, tenosr_ndim)); if (tenosr_ndim <= 0) { @@ -228,13 +228,13 @@ void AlignDimsSharding(std::vector<TensorDistAttr>* input_attrs_ptr, for (size_t i = 0; i < n_inputs; i++) { // 1、check all inputs have the align_axis for (char axi : align_axis) { - if (axis_names[i].find(axi) == std::string::npos) { - PADDLE_THROW(phi::errors::PreconditionNotMet( - "[%s] some axis not in input [%d],[%s]", - align_axis, - i, - axis_names[i])); - } + PADDLE_ENFORCE_EQ(axis_names[i].find(axi) == std::string::npos, + true, + phi::errors::PreconditionNotMet( + "align_axis[%s]; some axis not in input [%d],[%s]", + align_axis, + i, + axis_names[i])); } // 2、build axis map for (size_t j = 0; j <
axis_names[i].size(); j++) { From 6806fb5f7dbae2e6d1872afa681719ffc3f934fb Mon Sep 17 00:00:00 2001 From: liuzhenhai93 Date: Wed, 15 Nov 2023 12:55:38 +0000 Subject: [PATCH 25/25] polish --- test/auto_parallel/semi_auto_parallel_util.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/auto_parallel/semi_auto_parallel_util.py b/test/auto_parallel/semi_auto_parallel_util.py index 99c295c6ad870..cfb905e8382a2 100644 --- a/test/auto_parallel/semi_auto_parallel_util.py +++ b/test/auto_parallel/semi_auto_parallel_util.py @@ -102,7 +102,6 @@ def terminal_cond(x): dist_input.stop_gradient = False flat_inputs.append(input) flat_dist_inputs.append(dist_input) - inputs, _ = self.unflatten(flat_inputs, inputs_structure) dist_inputs, _ = self.unflatten(flat_dist_inputs, inputs_structure)