diff --git a/paddle/fluid/eager/api/manual/eager_manual/nodes/multiply_node.cc b/paddle/fluid/eager/api/manual/eager_manual/nodes/multiply_node.cc
index ed83bb29714ff..2e4489fdcc12e 100644
--- a/paddle/fluid/eager/api/manual/eager_manual/nodes/multiply_node.cc
+++ b/paddle/fluid/eager/api/manual/eager_manual/nodes/multiply_node.cc
@@ -74,6 +74,12 @@ MultiplyGradNode::operator()(
   // Runtime check if we need next grad
   bool trace_backward = egr::Controller::Instance().HasGrad() && create_graph;
 
+  // Set DistAttr of Out Tensor for semi-auto parallel
+  if (IsRunAutoParallel()) {
+    egr::EagerUtils::SetGradOutputDistAttr(
+        out_metas, {0, 1}, api_output_0, api_output_1);
+  }
+
   // Inplace Check
 
   // Inplace Strategy
diff --git a/paddle/fluid/pybind/eager_math_op_patch.cc b/paddle/fluid/pybind/eager_math_op_patch.cc
index ecae39fb43a49..21578110323ab 100644
--- a/paddle/fluid/pybind/eager_math_op_patch.cc
+++ b/paddle/fluid/pybind/eager_math_op_patch.cc
@@ -579,6 +579,11 @@ static PyObject* tensor__mul__method(TensorObject* self,
     }
   }
 
+  const phi::distributed::ProcessMesh* mesh = nullptr;
+  if (InputsContainDistTensor(&mesh, self_tensor, other_tensor)) {
+    ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor);
+  }
+
   // 4. calculation
   VLOG(6) << "Calling multiply_ad_func in tensor__mul__method";
   {
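Taken together, the two C++ changes above are what allow a plain dense tensor to be multiplied with a DistTensor in dynamic semi-auto parallel mode: `tensor__mul__method` promotes the dense operand onto the DistTensor's mesh, and `MultiplyGradNode` propagates DistAttr to both gradient outputs. A minimal Python sketch of the behaviour this enables follows; it is illustrative only: it assumes a two-card launch and reuses the `dist.shard_tensor` / `dist.DistAttr` call style that MPDemoNet uses later in this diff, with arbitrary shapes and sharding_specs.

```python
import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

# A sharded weight (DistTensor) and a plain dense tensor of the same shape.
w = dist.shard_tensor(
    paddle.randn([8, 8]),
    dist_attr=dist.DistAttr(mesh=mesh, sharding_specs=[None, "x"]),
)
w.stop_gradient = False
x = paddle.randn([8, 8])

# tensor__mul__method detects that one operand is a DistTensor and converts
# the dense operand to a DistTensor on the same mesh before multiply_ad_func.
out = x * w

# The 0-dim loss has an empty dims_mapping but still carries a process mesh,
# which is why TensorDistAttr::empty() is relaxed further down in this diff.
loss = out.mean()

# In auto-parallel mode MultiplyGradNode::operator() now sets the DistAttr of
# both gradient outputs before computing them.
loss.backward()
```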
diff --git a/paddle/fluid/pybind/tensor.cc b/paddle/fluid/pybind/tensor.cc
index 5b6efa9e1dba9..7205333bb688c 100644
--- a/paddle/fluid/pybind/tensor.cc
+++ b/paddle/fluid/pybind/tensor.cc
@@ -1038,7 +1038,13 @@ void BindTensor(pybind11::module &m) {  // NOLINT
            [](DistTensor &self) { return self.value(); },
            py::return_value_policy::reference)
       .def("numel",
-           [](DistTensor &self) -> int64_t { return self.value().numel(); });
+           [](DistTensor &self) -> int64_t { return self.value().numel(); })
+      .def("_share_data_with", [](DistTensor &self, const DistTensor &src) {
+        self.unsafe_set_dims(src.dims());
+        self.unsafe_set_dist_attr(src.dist_attr());
+        self.unsafe_mutable_value()->ShareDataWith(src.value());
+        return self;
+      });
 #endif
 
   py::class_<phi::SelectedRows>(m, "SelectedRows")
diff --git a/paddle/phi/core/distributed/auto_parallel/dist_attr.cc b/paddle/phi/core/distributed/auto_parallel/dist_attr.cc
index 46e58cc9b373e..3c95f2c3ff66f 100644
--- a/paddle/phi/core/distributed/auto_parallel/dist_attr.cc
+++ b/paddle/phi/core/distributed/auto_parallel/dist_attr.cc
@@ -349,7 +349,8 @@ std::string TensorDistAttr::partial_status_string() const {
 }
 
 bool TensorDistAttr::empty() const {
-  return process_mesh_.empty() || dims_mapping_.empty();
+  // dims_mapping is empty when the tensor is 0-dim, but it is still valid.
+  return process_mesh_.empty();
 }
 
 std::vector<std::shared_ptr<PlacementStatus>> TensorDistAttr::to_placement()
diff --git a/paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.cc b/paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.cc
index dc5d6c20e62b3..1e3164de81865 100644
--- a/paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.cc
+++ b/paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.cc
@@ -22,11 +22,11 @@ namespace distributed {
 phi::DDim DistMetaTensor::dims() const {
   // member values in tensor_ have higher priority than those in DistMetaTensor
   if (tensor_ != nullptr) {
-    PADDLE_ENFORCE_EQ(this->is_dist(),
-                      true,
-                      phi::errors::InvalidArgument(
-                          "The current MetaTensor doesn't contains "
-                          "DistTensor when call `dist_attr` method."));
+    PADDLE_ENFORCE_EQ(
+        this->is_dist(),
+        true,
+        phi::errors::InvalidArgument("The current MetaTensor doesn't contain "
+                                     "DistTensor when calling `dims` method."));
     return MetaTensor::dims();
   } else {
     return dims_;
diff --git a/test/auto_parallel/semi_auto_parallel_simple_net.py b/test/auto_parallel/semi_auto_parallel_simple_net.py
index 62fec8c906336..3187dc75ad993 100644
--- a/test/auto_parallel/semi_auto_parallel_simple_net.py
+++ b/test/auto_parallel/semi_auto_parallel_simple_net.py
@@ -28,19 +28,19 @@
 
 # TODO(chenweihang): update to MLP Layer later
 class DemoNet(nn.Layer):
-    def __init__(self, np_w0, np_w1):
+    def __init__(self, np_w0, np_w1, param_suffix=""):
         super().__init__()
         self.w0 = self.create_parameter(
             shape=[IMAGE_SIZE, IMAGE_SIZE],
             attr=paddle.framework.ParamAttr(
-                name="demo_weight_1",
+                name="demo_weight_1" + param_suffix,
                 initializer=paddle.nn.initializer.Assign(np_w0),
             ),
         )
         self.w1 = self.create_parameter(
             shape=[IMAGE_SIZE, CLASS_NUM],
             attr=paddle.framework.ParamAttr(
-                name="nemo_weight_2",
+                name="nemo_weight_2" + param_suffix,
                 initializer=paddle.nn.initializer.Assign(np_w1),
             ),
         )
@@ -52,20 +52,20 @@ def forward(self, x):
 
 
 class DPDemoNet(nn.Layer):
-    def __init__(self, np_w0, np_w1, mesh):
+    def __init__(self, np_w0, np_w1, mesh, param_suffix=""):
         super().__init__()
         self.mesh = mesh
         self.w0 = self.create_parameter(
             shape=[IMAGE_SIZE, IMAGE_SIZE],
             attr=paddle.framework.ParamAttr(
-                name="dp_demo_weight_1",
+                name="dp_demo_weight_1" + param_suffix,
                 initializer=paddle.nn.initializer.Assign(np_w0),
             ),
         )
         self.w1 = self.create_parameter(
             shape=[IMAGE_SIZE, CLASS_NUM],
             attr=paddle.framework.ParamAttr(
-                name="dp_nemo_weight_2",
+                name="dp_nemo_weight_2" + param_suffix,
                 initializer=paddle.nn.initializer.Assign(np_w1),
             ),
         )
@@ -85,13 +85,13 @@ def forward(self, x):
 
 
 class MPDemoNet(nn.Layer):
-    def __init__(self, np_w0, np_w1, mesh):
+    def __init__(self, np_w0, np_w1, mesh, param_suffix=""):
         super().__init__()
         self.w0 = dist.shard_tensor(
             self.create_parameter(
                 shape=[IMAGE_SIZE, IMAGE_SIZE],
                 attr=paddle.framework.ParamAttr(
-                    name="mp_demo_weight_1",
+                    name="mp_demo_weight_1" + param_suffix,
                     initializer=paddle.nn.initializer.Assign(np_w0),
                 ),
             ),
@@ -101,7 +101,7 @@ def __init__(self, np_w0, np_w1, mesh):
             self.create_parameter(
                 shape=[IMAGE_SIZE, CLASS_NUM],
                 attr=paddle.framework.ParamAttr(
-                    name="mp_nemo_weight_2",
+                    name="mp_nemo_weight_2" + param_suffix,
                     initializer=paddle.nn.initializer.Assign(np_w1),
                 ),
             ),
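A note on the new `param_suffix` argument: the AMP test added next constructs each demo network twice in the same process (once per AMP level), so the fixed ParamAttr names would otherwise collide; the suffix keeps every instance's parameter names distinct. A small hypothetical sketch, assuming the `IMAGE_SIZE` / `CLASS_NUM` constants exported by `semi_auto_parallel_simple_net`:

```python
import numpy as np

from semi_auto_parallel_simple_net import CLASS_NUM, IMAGE_SIZE, DemoNet

np_w0 = np.random.rand(IMAGE_SIZE, IMAGE_SIZE).astype("float32")
np_w1 = np.random.rand(IMAGE_SIZE, CLASS_NUM).astype("float32")

# Two instances in one process now get distinct parameter names:
# "demo_weight_1O1" / "nemo_weight_2O1" vs "demo_weight_1O2" / "nemo_weight_2O2".
net_o1 = DemoNet(np_w0, np_w1, param_suffix="O1")
net_o2 = DemoNet(np_w0, np_w1, param_suffix="O2")
```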
diff --git a/test/auto_parallel/semi_auto_parallel_simple_net_amp.py b/test/auto_parallel/semi_auto_parallel_simple_net_amp.py
new file mode 100644
index 0000000000000..3a17024063162
--- /dev/null
+++ b/test/auto_parallel/semi_auto_parallel_simple_net_amp.py
@@ -0,0 +1,122 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+from semi_auto_parallel_simple_net import (
+    DemoNet,
+    DPDemoNet,
+    MPDemoNet,
+    TestSimpleNetForSemiAutoParallel,
+)
+
+import paddle
+import paddle.distributed as dist
+from paddle import nn
+
+
+class TestSimpleNetWithAmpForSemiAutoParallel(TestSimpleNetForSemiAutoParallel):
+    def __init__(self):
+        self._dtype = os.getenv("dtype")
+        self._backend = os.getenv("backend")
+        self._seed = eval(os.getenv("seed"))
+        self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
+
+        paddle.set_device(self._backend)
+        self.init_input_data()
+        self.init_single_card_net_result()
+
+    def run_dynamic_amp(self, layer, level='O1'):
+        if level == 'O2':
+            layer = paddle.amp.decorate(models=layer, level='O2')
+        # create loss
+        loss_fn = nn.MSELoss()
+        scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
+        # run forward and backward
+        image = paddle.to_tensor(self.image)
+
+        with paddle.amp.auto_cast(level=level):
+            out = layer(image)
+            label = paddle.to_tensor(self.label)
+            loss = loss_fn(out, label)
+
+        scaled = scaler.scale(loss)
+        scaled.backward()
+        return loss, layer.w0.grad, layer.w1.grad
+
+    def init_single_card_net_result(self):
+        (
+            self.base_loss_o1,
+            self.base_w0_grad_o1,
+            self.base_w1_grad_o1,
+        ) = self.run_dynamic_amp(DemoNet(self.w0, self.w1, 'O1'), 'O1')
+        (
+            self.base_loss_o2,
+            self.base_w0_grad_o2,
+            self.base_w1_grad_o2,
+        ) = self.run_dynamic_amp(DemoNet(self.w0, self.w1, 'O2'), 'O2')
+
+    def test_dp_demo_net(self):
+        (
+            self.dp_loss_o1,
+            self.dp_w0_grad_o1,
+            self.dp_w1_grad_o1,
+        ) = self.run_dynamic_amp(
+            DPDemoNet(self.w0, self.w1, self._mesh, 'O1'), 'O1'
+        )
+        self.check_tensor_eq(self.dp_loss_o1, self.base_loss_o1)
+        self.check_tensor_eq(self.dp_w0_grad_o1, self.base_w0_grad_o1)
+        self.check_tensor_eq(self.dp_w1_grad_o1, self.base_w1_grad_o1)
+
+        (
+            self.dp_loss_o2,
+            self.dp_w0_grad_o2,
+            self.dp_w1_grad_o2,
+        ) = self.run_dynamic_amp(
+            DPDemoNet(self.w0, self.w1, self._mesh, 'O2'), 'O2'
+        )
+        self.check_tensor_eq(self.dp_loss_o2, self.base_loss_o2)
+        self.check_tensor_eq(self.dp_w0_grad_o2, self.base_w0_grad_o2)
+        self.check_tensor_eq(self.dp_w1_grad_o2, self.base_w1_grad_o2)
+
+    def test_mp_demo_net(self):
+        (
+            self.mp_loss_o1,
+            self.mp_w0_grad_o1,
+            self.mp_w1_grad_o1,
+        ) = self.run_dynamic_amp(
+            MPDemoNet(self.w0, self.w1, self._mesh, 'O1'), 'O1'
+        )
+        self.check_tensor_eq(self.mp_loss_o1, self.base_loss_o1)
+        self.check_tensor_eq(self.mp_w0_grad_o1, self.base_w0_grad_o1)
+        self.check_tensor_eq(self.mp_w1_grad_o1, self.base_w1_grad_o1)
+
+        (
+            self.mp_loss_o2,
+            self.mp_w0_grad_o2,
+            self.mp_w1_grad_o2,
+        ) = self.run_dynamic_amp(
+            MPDemoNet(self.w0, self.w1, self._mesh, 'O2'), 'O2'
+        )
+        self.check_tensor_eq(self.mp_loss_o2, self.base_loss_o2)
+        self.check_tensor_eq(self.mp_w0_grad_o2, self.base_w0_grad_o2)
+        self.check_tensor_eq(self.mp_w1_grad_o2, self.base_w1_grad_o2)
+
+    def run_test_case(self):
+        self.test_dp_demo_net()
+        self.test_mp_demo_net()
+
+
+if __name__ == '__main__':
+    TestSimpleNetWithAmpForSemiAutoParallel().run_test_case()
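For context on the two AMP levels exercised above: O1 only autocasts whitelisted ops inside `paddle.amp.auto_cast`, while O2 additionally casts the layer's parameters through `paddle.amp.decorate`, which is why only the O2 branch decorates the model. Also, `scaler.scale(loss).backward()` multiplies every gradient by the same `init_loss_scaling` in the single-card baseline and in the DP/MP runs alike, so comparing `w0.grad` / `w1.grad` without unscaling is still a like-for-like check. A stripped-down, non-distributed sketch (hypothetical `nn.Linear` and shapes, GPU assumed):

```python
import paddle
from paddle import nn

layer = nn.Linear(8, 8)
# O2: the parameters themselves are cast (float16 by default); O1 skips this.
layer = paddle.amp.decorate(models=layer, level="O2")
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

with paddle.amp.auto_cast(level="O2"):
    loss = layer(paddle.randn([4, 8])).mean()

# Gradients here are d(1024 * loss)/dw; the factor is identical across runs,
# so element-wise comparisons of the resulting grads remain valid.
scaler.scale(loss).backward()
```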
diff --git a/test/auto_parallel/test_semi_auto_parallel_single_strategy.py b/test/auto_parallel/test_semi_auto_parallel_single_strategy.py
index 89ef4ac6a1a10..03b31f70a9e9b 100644
--- a/test/auto_parallel/test_semi_auto_parallel_single_strategy.py
+++ b/test/auto_parallel/test_semi_auto_parallel_single_strategy.py
@@ -19,7 +19,10 @@
 
 class TestSemiAutoParallelSingleStrategy(test_base.CommunicationTestDistBase):
     def setUp(self):
-        super().setUp(num_of_devices=2, timeout=120)
+        super().setUp(
+            num_of_devices=2,
+            timeout=120,
+        )
         self._default_envs = {
             "dtype": "float32",
             "seed": "2023",
@@ -36,6 +39,17 @@ def test_simple_net_single_strategy(self):
                 user_defined_envs=envs,
             )
 
+    def test_simple_net_single_strategy_with_amp(self):
+        self._changeable_envs = {"backend": ["gpu"]}
+        envs_list = test_base.gen_product_envs_list(
+            self._default_envs, self._changeable_envs
+        )
+        for envs in envs_list:
+            self.run_test_case(
+                "semi_auto_parallel_simple_net_amp.py",
+                user_defined_envs=envs,
+            )
+
 
 if __name__ == "__main__":
     unittest.main()