add spmd rule for amp ops #64202

Merged
merged 10 commits on May 24, 2024
Changes from all commits
77 changes: 77 additions & 0 deletions paddle/phi/infermeta/spmd_rules/amp_ops.cc
@@ -0,0 +1,77 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/infermeta/spmd_rules/amp_ops.h"

#include <vector>
#include "glog/logging.h"

#include "paddle/phi/infermeta/spmd_rules/utils.h"

namespace phi {
namespace distributed {
// TODO(zhiqiu): support xs on different mesh.
SpmdInfo CheckFiniteAndUnscaleSpmd(const std::vector<DistMetaTensor>& xs,
const DistMetaTensor& scale) {
std::vector<TensorDistAttr> xs_attrs;
paddle::flat_hash_map<int64_t, ReduceType> partial_on_dims;
for (auto& x : xs) {
auto dist_attr = x.dist_attr();
dist_attr.clean_partial_status();

Review comment (Contributor), on clean_partial_status:
The partial status could be allowed to pass through here for better performance. If partial is not allowed to pass through, an allreduce is performed before this op, no matter whether check_finite returns False or True. But if partial is allowed to pass through, the allreduce is instead triggered by a later op (e.g. adam), and if check_finite returns True, adam is not called and therefore no allreduce is performed. A sketch of this variant follows this file's diff.

xs_attrs.emplace_back(dist_attr);
auto dims_mapping = dist_attr.dims_mapping();
for (auto& m : dims_mapping) {
if (m != -1 && partial_on_dims.count(m) == 0) {
partial_on_dims[m] = ReduceType::kRedMax;
}
}
}
TensorDistAttr found_infinite_attr =
CopyTensorDistAttrForOutput(scale.dist_attr());
found_infinite_attr.set_partial_status(partial_on_dims);
found_infinite_attr.set_dims_mapping({-1});
return {{xs_attrs, scale.dist_attr()}, {xs_attrs, found_infinite_attr}};
}

SpmdInfo UpdateLossScalingSpmd(const std::vector<DistMetaTensor>& xs,
const DistMetaTensor& found_infinite,
const DistMetaTensor& prev_loss_scaling,
const DistMetaTensor& in_good_steps,
const DistMetaTensor& in_bad_steps,
int incr_every_n_steps,
int decr_every_n_nan_or_inf,
float incr_ratio,
float decr_ratio,
Scalar stop_update) {
std::vector<TensorDistAttr> xs_attrs;
for (auto& x : xs) {
auto dist_attr = x.dist_attr();
dist_attr.clean_partial_status();

Review comment (Contributor), on clean_partial_status:
Same as above: the partial status could be allowed to pass through here as well.

xs_attrs.emplace_back(dist_attr);
}
TensorDistAttr found_infinite_attr =
CopyTensorDistAttrForOutput(found_infinite.dist_attr());
return {{xs_attrs,
found_infinite_attr,
prev_loss_scaling.dist_attr(),
in_good_steps.dist_attr(),
in_bad_steps.dist_attr()},
{xs_attrs,
found_infinite_attr,
in_good_steps.dist_attr(),
in_bad_steps.dist_attr()}};
}

} // namespace distributed
} // namespace phi
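
For reference, a minimal sketch of the pass-through variant raised in the review comments above. This is not the merged code, only an illustration under the assumption that the dist-attr API used in this file (dist_attr(), dims_mapping(), set_partial_status(), CopyTensorDistAttrForOutput()) behaves as shown; the function name is hypothetical. Keeping each input's partial status instead of clearing it defers the allreduce to whichever later op (e.g. adam) consumes the unscaled tensors.

// Hypothetical sketch, not part of this PR: keep the inputs' partial status
// so the allreduce is deferred to a later consumer such as adam.
SpmdInfo CheckFiniteAndUnscaleSpmdKeepPartial(
    const std::vector<DistMetaTensor>& xs, const DistMetaTensor& scale) {
  std::vector<TensorDistAttr> xs_attrs;
  paddle::flat_hash_map<int64_t, ReduceType> partial_on_dims;
  for (auto& x : xs) {
    auto dist_attr = x.dist_attr();
    // No clean_partial_status() here: each input's partial status is
    // propagated unchanged, so no allreduce is inserted before this op.
    xs_attrs.emplace_back(dist_attr);
    auto dims_mapping = dist_attr.dims_mapping();
    for (auto& m : dims_mapping) {
      if (m != -1 && partial_on_dims.count(m) == 0) {
        partial_on_dims[m] = ReduceType::kRedMax;
      }
    }
  }
  // found_infinite is still marked partial (max-reduce) on the sharded mesh
  // dims; whether it must also cover the inputs' partial mesh dims is left
  // open in this sketch.
  TensorDistAttr found_infinite_attr =
      CopyTensorDistAttrForOutput(scale.dist_attr());
  found_infinite_attr.set_partial_status(partial_on_dims);
  found_infinite_attr.set_dims_mapping({-1});
  return {{xs_attrs, scale.dist_attr()}, {xs_attrs, found_infinite_attr}};
}
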
39 changes: 39 additions & 0 deletions paddle/phi/infermeta/spmd_rules/amp_ops.h
@@ -0,0 +1,39 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <vector>

#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
#include "paddle/phi/core/distributed/type_defs.h"

namespace phi {
namespace distributed {
SpmdInfo CheckFiniteAndUnscaleSpmd(const std::vector<DistMetaTensor>& xs,
const DistMetaTensor& scale);

SpmdInfo UpdateLossScalingSpmd(const std::vector<DistMetaTensor>& xs,
const DistMetaTensor& found_infinite,
const DistMetaTensor& prev_loss_scaling,
const DistMetaTensor& in_good_steps,
const DistMetaTensor& in_bad_steps,
int incr_every_n_steps,
int decr_every_n_nan_or_inf,
float incr_ratio,
float decr_ratio,
Scalar stop_update = false);
} // namespace distributed
} // namespace phi
30 changes: 28 additions & 2 deletions paddle/phi/infermeta/spmd_rules/layer_norm.cc
@@ -383,12 +383,38 @@ SpmdInfo LayerNormGradInferSpmd(const DistMetaTensor& x,
std::string align_annotation;
std::tie(annotations, align_annotation) =
BuildLayerNormGradEinsum(x_shape.size(), begin_norm_axis);
AlignDimsSharding(
&dist_attrs, shapes, annotations, {}, align_annotation, false);

// Sharding Propagation
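// The normalized axes (indices >= begin_norm_axis) of x and out_grad are
// reset to -1 (replicated) before the shardings of x, mean, variance and
// out_grad are merged into a single axis-to-mesh-dim map below.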
std::vector<std::pair<std::string, std::vector<int64_t>>>
axes_sharding_info;
auto x_dims_mapping = dist_attrs[0].dims_mapping();
auto out_grad_dims_mapping = dist_attrs[3].dims_mapping();
std::fill(
x_dims_mapping.begin() + begin_norm_axis, x_dims_mapping.end(), -1);
std::fill(out_grad_dims_mapping.begin() + begin_norm_axis,
out_grad_dims_mapping.end(),
-1);
axes_sharding_info.emplace_back(annotations[0], x_dims_mapping);
axes_sharding_info.emplace_back(annotations[1],
dist_attrs[1].dims_mapping());
axes_sharding_info.emplace_back(annotations[2],
dist_attrs[2].dims_mapping());
axes_sharding_info.emplace_back(annotations[3], out_grad_dims_mapping);
std::unordered_map<std::string, int64_t> axis_to_dim_map =
ShardingMergeForTensors(axes_sharding_info);

x_dist_attr = std::move(dist_attrs[0]);
x_dist_attr.set_dims_mapping(
GetDimsMappingForAxes(annotations[0], axis_to_dim_map));
mean_dist_attr = std::move(dist_attrs[1]);
mean_dist_attr.set_dims_mapping(
GetDimsMappingForAxes(annotations[1], axis_to_dim_map));
variance_dist_attr = std::move(dist_attrs[2]);
variance_dist_attr.set_dims_mapping(
GetDimsMappingForAxes(annotations[2], axis_to_dim_map));
out_grad_dist_attr = std::move(dist_attrs[3]);
out_grad_dist_attr.set_dims_mapping(
GetDimsMappingForAxes(annotations[3], axis_to_dim_map));
} else {
x_dist_attr = GetReplicatedDistAttr(dist_attrs[0]);
mean_dist_attr = GetReplicatedDistAttr(dist_attrs[1]);
1 change: 1 addition & 0 deletions paddle/phi/infermeta/spmd_rules/rules.h
@@ -14,6 +14,7 @@ limitations under the License. */

#pragma once
#include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h"
#include "paddle/phi/infermeta/spmd_rules/amp_ops.h"
#include "paddle/phi/infermeta/spmd_rules/argmax.h"
#include "paddle/phi/infermeta/spmd_rules/cast.h"
#include "paddle/phi/infermeta/spmd_rules/concat.h"
3 changes: 3 additions & 0 deletions paddle/phi/ops/yaml/ops.yaml
@@ -89,6 +89,7 @@
output : Tensor(param_out), Tensor(moment1_out), Tensor(moment2_out), Tensor(beta1_pow_out), Tensor(beta2_pow_out), Tensor(master_param_out)
infer_meta :
func : AdamInferMeta
spmd_rule : AdamInferSpmdDynamic
kernel :
func : adam {dense, dense, dense, dense, dense, dense, dense, dense, dense -> dense, dense, dense, dense, dense, dense},
adam_dense_param_sparse_grad {dense, selected_rows, dense, dense, dense, dense, dense, dense, dense -> dense, dense, dense, dense, dense, dense}
@@ -526,6 +527,7 @@
infer_meta :
func : CheckFiniteAndUnscaleInferMeta
param : [x, scale]
spmd_rule : CheckFiniteAndUnscaleSpmd
kernel :
func : check_finite_and_unscale
param : [x, scale]
@@ -3332,6 +3334,7 @@
infer_meta :
func : UpdateLossScalingInferMeta
param : [x, found_infinite, prev_loss_scaling, in_good_steps, in_bad_steps]
spmd_rule : UpdateLossScalingSpmd
kernel :
func : update_loss_scaling
data_type : x
@@ -57,16 +57,26 @@ def dp_mp_pp_shard_fn(self, layer_name, layer, process_mesh):
layer.weight = dist.shard_tensor(
layer.weight, self._pp_mesh0, [Replicate(), Shard(1)]
)
layer.bias = dist.shard_tensor(
layer.bias, self._pp_mesh0, [Replicate(), Replicate()]
)
if layer.bias is not None:
layer.bias = dist.shard_tensor(
layer.bias, self._pp_mesh0, [Replicate(), Replicate()]
)
elif layer_name == 'linear_1':
layer.weight = dist.shard_tensor(
layer.weight, self._pp_mesh1, [Replicate(), Shard(0)]
)
layer.bias = dist.shard_tensor(
layer.bias, self._pp_mesh1, [Replicate(), Replicate()]
if layer.bias is not None:
layer.bias = dist.shard_tensor(
layer.bias, self._pp_mesh1, [Replicate(), Replicate()]
)
elif layer_name == 'norm':
layer.weight = dist.shard_tensor(
layer.weight, self._pp_mesh1, [Replicate(), Replicate()]
)
if layer.bias is not None:
layer.bias = dist.shard_tensor(
layer.bias, self._pp_mesh1, [Replicate(), Replicate()]
)

def test_dp_mp_pp_demo_net(self):
self.set_random_seed(self._seed)
@@ -92,19 +102,19 @@ def test_dp_mp_pp_demo_net(self):
if rank in [0, 1, 2, 3]:
# linear_0 weight and bias
self.check_tensor_eq(
self.dp_mp_pp_parameters[0], self.base_parameters[0]
self.dp_mp_pp_parameters[0], self.base_parameters[0], rtol=2e-4
)

else:
self.check_tensor_eq(self.dp_mp_pp_loss, self.base_loss, rtol=1e-4)
self.check_tensor_eq(
self.dp_mp_pp_parameters[1], self.base_parameters[1]
self.dp_mp_pp_parameters[1], self.base_parameters[1], rtol=1e-4
)
else:
self.check_tensor_eq(self.dp_mp_pp_loss, self.base_loss)
# linear_1 weight and bias
self.check_tensor_eq(
self.dp_mp_pp_parameters[2], self.base_parameters[2]
self.dp_mp_pp_parameters[2], self.base_parameters[2], rtol=2e-5
)
self.check_tensor_eq(
self.dp_mp_pp_parameters[3], self.base_parameters[3]
self.dp_mp_pp_parameters[3], self.base_parameters[3], rtol=2e-4
)

# save load
@@ -179,7 +179,7 @@ def run_dynamic(self, layer, is_sp=False):
loss_fn = nn.MSELoss()
# run forward and backward
opt = paddle.optimizer.AdamW(
learning_rate=0.1, parameters=layer.parameters()
learning_rate=0.001, parameters=layer.parameters()
)
for _ in range(5):
image, label = self.init_input_data()
@@ -192,7 +192,6 @@

loss = loss_fn(out, label)
loss.backward()

opt.step()
return loss, layer.parameters()

@@ -204,14 +203,8 @@ def test_dp_mp_sp_demo_net(self):
self.dp_mp_sp_loss,
self.dp_mp_sp_parameters,
) = self.run_dynamic(model, is_sp=True)

self.check_tensor_eq(self.dp_mp_sp_loss, self.base_loss)
for param, param_base in zip(
self.dp_mp_sp_parameters, self.base_parameters
):
if param.grad._is_initialized():
self.check_tensor_eq(param, param_base)
self.check_tensor_eq(param.grad, param_base.grad)
if dist.get_rank() in model.pp1_mesh.process_ids:
self.check_tensor_eq(self.dp_mp_sp_loss, self.base_loss, rtol=1e-3)

def run_test_case(self):
self.test_dp_mp_sp_demo_net()
42 changes: 28 additions & 14 deletions test/auto_parallel/semi_auto_parallel_simple_net.py
@@ -25,7 +25,7 @@

BATCH_SIZE = 16
BATCH_NUM = 4
IMAGE_SIZE = 784
IMAGE_SIZE = 128
CLASS_NUM = 10


@@ -46,20 +46,26 @@ def __init__(
super().__init__()
weight_attr_0 = create_numpy_like_random(param_prefix + "_0")
weight_attr_1 = create_numpy_like_random(param_prefix + "_1")
weight_attr_2 = create_numpy_like_random(param_prefix + "_2")

self.is_pp = is_pp
self.is_recompute = is_recompute
self.pp_reshard_dist_attr = pp_reshard_dist_attr
self.linear_0 = nn.Linear(IMAGE_SIZE, IMAGE_SIZE, weight_attr_0)
self.linear_1 = nn.Linear(IMAGE_SIZE, CLASS_NUM, weight_attr_1)
self.relu = nn.ReLU()
self.linear_0 = nn.Linear(
IMAGE_SIZE, IMAGE_SIZE, weight_attr_0, bias_attr=False
)
self.linear_1 = nn.Linear(
IMAGE_SIZE, CLASS_NUM, weight_attr_1, bias_attr=False
)
self.norm = nn.LayerNorm([IMAGE_SIZE], weight_attr=weight_attr_2)

def _inner_forward_fn(self, x):
out = self.linear_0(x)
out = self.relu(out)
if self.is_pp:
out = dist.reshard(out, *self.pp_reshard_dist_attr)
out = self.norm(out)
out = self.linear_1(out)
out = paddle.abs(out)
return out

def forward(self, x):
@@ -101,12 +107,20 @@ def pp_shard_fn(self, layer_name, layer, process_mesh):
bias_dist_attr = (self._pp_mesh0, [Replicate()])

layer.weight = dist.shard_tensor(layer.weight, *weight_dist_attr)
layer.bias = dist.shard_tensor(layer.bias, *bias_dist_attr)
if layer.bias is not None:
layer.bias = dist.shard_tensor(layer.bias, *bias_dist_attr)
elif layer_name == 'linear_1':
weight_dist_attr = (self._pp_mesh1, [Replicate()])
bias_dist_attr = (self._pp_mesh1, [Replicate()])
layer.weight = dist.shard_tensor(layer.weight, *weight_dist_attr)
layer.bias = dist.shard_tensor(layer.bias, *bias_dist_attr)
if layer.bias is not None:
layer.bias = dist.shard_tensor(layer.bias, *bias_dist_attr)
elif layer_name == 'norm':
weight_dist_attr = (self._pp_mesh1, [Replicate()])
bias_dist_attr = (self._pp_mesh1, [Replicate()])
layer.weight = dist.shard_tensor(layer.weight, *weight_dist_attr)
if layer.bias is not None:
layer.bias = dist.shard_tensor(layer.bias, *bias_dist_attr)

def set_random_seed(self, seed):
random.seed(seed)
@@ -128,10 +142,10 @@ def run_dynamic(self, layer, shard_input=False, is_pp=False):
input_dist_attr = (self._mesh, [Shard(0)])

opt = paddle.optimizer.SGD(
learning_rate=0.1, parameters=layer.parameters()
learning_rate=0.001, parameters=layer.parameters()
)
opt = dist.shard_optimizer(opt)
for _ in range(5):
for _ in range(3):
image, label = self.init_input_data()
if shard_input:
image = dist.shard_tensor(image, *input_dist_attr)
@@ -164,9 +178,9 @@ def test_dp_demo_net(self):
DemoNet("dp_demo_weight"),
shard_input=True,
)
self.check_tensor_eq(self.dp_loss, self.base_loss)
self.check_tensor_eq(self.dp_loss, self.base_loss, rtol=1e-4)
for param, param_base in zip(self.dp_parameters, self.base_parameters):
self.check_tensor_eq(param, param_base)
self.check_tensor_eq(param, param_base, rtol=2e-4)
self.check_tensor_eq(param.grad, param_base.grad)

def test_mp_demo_net(self):
@@ -209,12 +223,12 @@ def test_pp_demo_net(self):
# cross-mesh now, ReshardXToReplicated function in eager_method
# needs to be fixed later.
if rank == 0:
# linear_0 weight and bias
# linear_0 weight
self.check_tensor_eq(self.pp_parameters[0], self.base_parameters[0])
self.check_tensor_eq(self.pp_parameters[1], self.base_parameters[1])
else:
self.check_tensor_eq(self.pp_loss, self.base_loss)
# linear_1 weight and bias
# linear_1 weight, norm.weight, norm.bias
self.check_tensor_eq(self.pp_parameters[1], self.base_parameters[1])
self.check_tensor_eq(self.pp_parameters[2], self.base_parameters[2])
self.check_tensor_eq(self.pp_parameters[3], self.base_parameters[3])
