[AutoParallel] Verify auto parallel in amp mode (#58172)
* add amp test

* support amp O1 mode

* add share_data_with method for dist tensor
chenwhql committed Oct 19, 2023
1 parent b1fc0ee commit 648f3c4
Showing 8 changed files with 171 additions and 17 deletions.
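
Before the per-file diffs, here is a minimal, self-contained sketch of the dynamic-graph AMP pattern that the new test (semi_auto_parallel_simple_net_amp.py, added below) drives for both the single-card baseline and the DP/MP distributed nets. The plain nn.Linear layer and random data are stand-ins for illustration only and are not part of this change; the test itself uses the DemoNet variants and a 1x2 process mesh.

import numpy as np

import paddle
from paddle import nn

layer = nn.Linear(16, 8)
loss_fn = nn.MSELoss()
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

image = paddle.to_tensor(np.random.rand(4, 16).astype("float32"))
label = paddle.to_tensor(np.random.rand(4, 8).astype("float32"))

# O1: selected ops run in float16 inside the auto_cast region; for O2 the test
# additionally wraps the model via paddle.amp.decorate(models=layer, level='O2').
with paddle.amp.auto_cast(level='O1'):
    out = layer(image)
    loss = loss_fn(out, label)

scaled = scaler.scale(loss)  # scale the loss before backward, as in run_dynamic_amp
scaled.backward()
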
@@ -74,6 +74,12 @@ MultiplyGradNode::operator()(
// Runtime check if we need next grad
bool trace_backward = egr::Controller::Instance().HasGrad() && create_graph;

// Set DistAttr of Out Tensor for semi-auto parallel
if (IsRunAutoParallel()) {
egr::EagerUtils::SetGradOutputDistAttr(
out_metas, {0, 1}, api_output_0, api_output_1);
}

// Inplace Check

// Inplace Strategy
5 changes: 5 additions & 0 deletions paddle/fluid/pybind/eager_math_op_patch.cc
@@ -579,6 +579,11 @@ static PyObject* tensor__mul__method(TensorObject* self,
}
}

const phi::distributed::ProcessMesh* mesh = nullptr;
if (InputsContainDistTensor(&mesh, self_tensor, other_tensor)) {
ConvertAllInputsToDistTensor(mesh, self_tensor, other_tensor);
}

// 4. calculation
VLOG(6) << "Calling multiply_ad_func in tensor__mul__method";
{
8 changes: 7 additions & 1 deletion paddle/fluid/pybind/tensor.cc
@@ -1038,7 +1038,13 @@ void BindTensor(pybind11::module &m) { // NOLINT
[](DistTensor &self) { return self.value(); },
py::return_value_policy::reference)
.def("numel",
[](DistTensor &self) -> int64_t { return self.value().numel(); });
[](DistTensor &self) -> int64_t { return self.value().numel(); })
.def("_share_data_with", [](DistTensor &self, const DistTensor &src) {
self.unsafe_set_dims(src.dims());
self.unsafe_set_dist_attr(src.dist_attr());
self.unsafe_mutable_value()->ShareDataWith(src.value());
return self;
});
#endif

py::class_<phi::SelectedRows>(m, "SelectedRows")
3 changes: 2 additions & 1 deletion paddle/phi/core/distributed/auto_parallel/dist_attr.cc
@@ -349,7 +349,8 @@ std::string TensorDistAttr::partial_status_string() const {
}

bool TensorDistAttr::empty() const {
return process_mesh_.empty() || dims_mapping_.empty();
// dims_mapping is empty for a 0-D tensor, but the dist_attr is still valid.
return process_mesh_.empty();
}

std::vector<std::shared_ptr<PlacementStatus>> TensorDistAttr::to_placement()
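
The relaxed empty() check matters because 0-D tensors show up naturally in the AMP flow (e.g., a scalar loss): their dims_mapping is empty even though the dist_attr is still meaningful. A small illustration using only standard Paddle eager APIs:

import paddle

scalar = paddle.to_tensor(1024.0)  # a 0-D tensor, such as a scalar loss value
print(scalar.shape)                # [] -> zero dimensions
# A TensorDistAttr keeps one dims_mapping entry per tensor dimension, so for a
# 0-D tensor dims_mapping is empty while process_mesh (and thus the dist_attr)
# can still be valid -- which is why empty() now checks only process_mesh_.
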
10 changes: 5 additions & 5 deletions paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.cc
@@ -22,11 +22,11 @@ namespace distributed {
phi::DDim DistMetaTensor::dims() const {
// member values in tensor_ have higher priority than those in DistMetaTensor
if (tensor_ != nullptr) {
PADDLE_ENFORCE_EQ(this->is_dist(),
true,
phi::errors::InvalidArgument(
"The current MetaTensor doesn't contains "
"DistTensor when call `dist_attr` method."));
PADDLE_ENFORCE_EQ(
this->is_dist(),
true,
phi::errors::InvalidArgument("The current MetaTensor doesn't contains "
"DistTensor when call `dims` method."));
return MetaTensor::dims();
} else {
return dims_;
18 changes: 9 additions & 9 deletions test/auto_parallel/semi_auto_parallel_simple_net.py
@@ -28,19 +28,19 @@

# TODO(chenweihang): update to MLP Layer later
class DemoNet(nn.Layer):
def __init__(self, np_w0, np_w1):
def __init__(self, np_w0, np_w1, param_suffix=""):
super().__init__()
self.w0 = self.create_parameter(
shape=[IMAGE_SIZE, IMAGE_SIZE],
attr=paddle.framework.ParamAttr(
name="demo_weight_1",
name="demo_weight_1" + param_suffix,
initializer=paddle.nn.initializer.Assign(np_w0),
),
)
self.w1 = self.create_parameter(
shape=[IMAGE_SIZE, CLASS_NUM],
attr=paddle.framework.ParamAttr(
name="nemo_weight_2",
name="nemo_weight_2" + param_suffix,
initializer=paddle.nn.initializer.Assign(np_w1),
),
)
@@ -52,20 +52,20 @@ def forward(self, x):


class DPDemoNet(nn.Layer):
def __init__(self, np_w0, np_w1, mesh):
def __init__(self, np_w0, np_w1, mesh, param_suffix=""):
super().__init__()
self.mesh = mesh
self.w0 = self.create_parameter(
shape=[IMAGE_SIZE, IMAGE_SIZE],
attr=paddle.framework.ParamAttr(
name="dp_demo_weight_1",
name="dp_demo_weight_1" + param_suffix,
initializer=paddle.nn.initializer.Assign(np_w0),
),
)
self.w1 = self.create_parameter(
shape=[IMAGE_SIZE, CLASS_NUM],
attr=paddle.framework.ParamAttr(
name="dp_nemo_weight_2",
name="dp_nemo_weight_2" + param_suffix,
initializer=paddle.nn.initializer.Assign(np_w1),
),
)
@@ -85,13 +85,13 @@ def forward(self, x):


class MPDemoNet(nn.Layer):
def __init__(self, np_w0, np_w1, mesh):
def __init__(self, np_w0, np_w1, mesh, param_suffix=""):
super().__init__()
self.w0 = dist.shard_tensor(
self.create_parameter(
shape=[IMAGE_SIZE, IMAGE_SIZE],
attr=paddle.framework.ParamAttr(
name="mp_demo_weight_1",
name="mp_demo_weight_1" + param_suffix,
initializer=paddle.nn.initializer.Assign(np_w0),
),
),
@@ -101,7 +101,7 @@ def __init__(self, np_w0, np_w1, mesh):
self.create_parameter(
shape=[IMAGE_SIZE, CLASS_NUM],
attr=paddle.framework.ParamAttr(
name="mp_nemo_weight_2",
name="mp_nemo_weight_2" + param_suffix,
initializer=paddle.nn.initializer.Assign(np_w1),
),
),
122 changes: 122 additions & 0 deletions test/auto_parallel/semi_auto_parallel_simple_net_amp.py
@@ -0,0 +1,122 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from semi_auto_parallel_simple_net import (
DemoNet,
DPDemoNet,
MPDemoNet,
TestSimpleNetForSemiAutoParallel,
)

import paddle
import paddle.distributed as dist
from paddle import nn


class TestSimpleNetWithAmpForSemiAutoParallel(TestSimpleNetForSemiAutoParallel):
def __init__(self):
self._dtype = os.getenv("dtype")
self._backend = os.getenv("backend")
self._seed = eval(os.getenv("seed"))
self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

paddle.set_device(self._backend)
self.init_input_data()
self.init_single_card_net_result()

def run_dynamic_amp(self, layer, level='O1'):
if level == 'O2':
layer = paddle.amp.decorate(models=layer, level='O2')
# create loss
loss_fn = nn.MSELoss()
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
# run forward and backward
image = paddle.to_tensor(self.image)

with paddle.amp.auto_cast(level=level):
out = layer(image)
label = paddle.to_tensor(self.label)
loss = loss_fn(out, label)

scaled = scaler.scale(loss)
scaled.backward()
return loss, layer.w0.grad, layer.w1.grad

def init_single_card_net_result(self):
(
self.base_loss_o1,
self.base_w0_grad_o1,
self.base_w1_grad_o1,
) = self.run_dynamic_amp(DemoNet(self.w0, self.w1, 'O1'), 'O1')
(
self.base_loss_o2,
self.base_w0_grad_o2,
self.base_w1_grad_o2,
) = self.run_dynamic_amp(DemoNet(self.w0, self.w1, 'O2'), 'O2')

def test_dp_demo_net(self):
(
self.dp_loss_o1,
self.dp_w0_grad_o1,
self.dp_w1_grad_o1,
) = self.run_dynamic_amp(
DPDemoNet(self.w0, self.w1, self._mesh, 'O1'), 'O1'
)
self.check_tensor_eq(self.dp_loss_o1, self.base_loss_o1)
self.check_tensor_eq(self.dp_w0_grad_o1, self.base_w0_grad_o1)
self.check_tensor_eq(self.dp_w1_grad_o1, self.base_w1_grad_o1)

(
self.dp_loss_o2,
self.dp_w0_grad_o2,
self.dp_w1_grad_o2,
) = self.run_dynamic_amp(
DPDemoNet(self.w0, self.w1, self._mesh, 'O2'), 'O2'
)
self.check_tensor_eq(self.dp_loss_o2, self.base_loss_o2)
self.check_tensor_eq(self.dp_w0_grad_o2, self.base_w0_grad_o2)
self.check_tensor_eq(self.dp_w1_grad_o2, self.base_w1_grad_o2)

def test_mp_demo_net(self):
(
self.mp_loss_o1,
self.mp_w0_grad_o1,
self.mp_w1_grad_o1,
) = self.run_dynamic_amp(
MPDemoNet(self.w0, self.w1, self._mesh, 'O1'), 'O1'
)
self.check_tensor_eq(self.mp_loss_o1, self.base_loss_o1)
self.check_tensor_eq(self.mp_w0_grad_o1, self.base_w0_grad_o1)
self.check_tensor_eq(self.mp_w1_grad_o1, self.base_w1_grad_o1)

(
self.mp_loss_o2,
self.mp_w0_grad_o2,
self.mp_w1_grad_o2,
) = self.run_dynamic_amp(
MPDemoNet(self.w0, self.w1, self._mesh, 'O2'), 'O2'
)
self.check_tensor_eq(self.mp_loss_o2, self.base_loss_o2)
self.check_tensor_eq(self.mp_w0_grad_o2, self.base_w0_grad_o2)
self.check_tensor_eq(self.mp_w1_grad_o2, self.base_w1_grad_o2)

def run_test_case(self):
self.test_dp_demo_net()
self.test_mp_demo_net()


if __name__ == '__main__':
TestSimpleNetWithAmpForSemiAutoParallel().run_test_case()
16 changes: 15 additions & 1 deletion test/auto_parallel/test_semi_auto_parallel_single_strategy.py
@@ -19,7 +19,10 @@

class TestSemiAutoParallelSingleStrategy(test_base.CommunicationTestDistBase):
def setUp(self):
super().setUp(num_of_devices=2, timeout=120)
super().setUp(
num_of_devices=2,
timeout=120,
)
self._default_envs = {
"dtype": "float32",
"seed": "2023",
@@ -36,6 +39,17 @@ def test_simple_net_single_strategy(self):
user_defined_envs=envs,
)

def test_simple_net_single_strategy_with_amp(self):
self._changeable_envs = {"backend": ["gpu"]}
envs_list = test_base.gen_product_envs_list(
self._default_envs, self._changeable_envs
)
for envs in envs_list:
self.run_test_case(
"semi_auto_parallel_simple_net_amp.py",
user_defined_envs=envs,
)


if __name__ == "__main__":
unittest.main()
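
For context, gen_product_envs_list presumably expands the fixed and changeable env dicts into one environment per combination; each resulting dict is then passed to run_test_case as user_defined_envs. The helper's exact behavior is an assumption here, illustrated with plain Python:

import itertools

default_envs = {"dtype": "float32", "seed": "2023"}
changeable_envs = {"backend": ["gpu"]}

# cartesian product over the changeable values, merged onto the defaults
envs_list = [
    {**default_envs, **dict(zip(changeable_envs, combo))}
    for combo in itertools.product(*changeable_envs.values())
]
print(envs_list)  # [{'dtype': 'float32', 'seed': '2023', 'backend': 'gpu'}]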
