
[Semi-Auto] SPMD Parallel Rule Base #53863

Merged Jun 27, 2023 (46 commits)

Commits (changes from all commits):
a47bf99  base rule (JZ-LIANG, May 16, 2023)
21b5a75  add sharding merge (JZ-LIANG, May 18, 2023)
f7e39d7  add sharding axis merge (JZ-LIANG, May 19, 2023)
c92992d  define unified data class for inferring dist_attr (pkuzyc, May 18, 2023)
42a7b77  test wrap DistTensorSpec in dygraph mode (pkuzyc, May 19, 2023)
180edcc  matmul main logic done (JZ-LIANG, May 23, 2023)
ecbb1ae  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (JZ-LIANG, May 23, 2023)
f314b56  Merge remote-tracking branch 'zyc/develop' into semi-auto/rule-base (JZ-LIANG, May 23, 2023)
46153c7  shape int64 (JZ-LIANG, May 23, 2023)
1dcb80e  common cc (JZ-LIANG, May 23, 2023)
198bc1f  define unified data class for inferring dist_attr (pkuzyc, May 18, 2023)
09d82a5  test wrap DistTensorSpec in dygraph mode (pkuzyc, May 19, 2023)
c3ea2a6  define python api and wrap function in static mode for DistTensorSpec (pkuzyc, May 23, 2023)
4cd1a2c  revise syntax (JZ-LIANG, May 24, 2023)
3631f06  Merge remote-tracking branch 'zyc/develop' into semi-auto/rule-base (JZ-LIANG, May 24, 2023)
ed0c31e  map bugfix (JZ-LIANG, May 29, 2023)
701d3fa  broadcast func (JZ-LIANG, May 29, 2023)
c1545a4  compile 1 (JZ-LIANG, May 29, 2023)
3ca2b73  add unit test (JZ-LIANG, May 31, 2023)
747de08  add registry (JZ-LIANG, Jun 6, 2023)
968ce61  Merge branch 'semi-auto/rule-base' of https://github.com/JZ-LIANG/Pad… (JZ-LIANG, Jun 6, 2023)
7be672d  update unit test (JZ-LIANG, Jun 6, 2023)
3389b7e  bugfix (JZ-LIANG, Jun 6, 2023)
3882a2c  bugfix (JZ-LIANG, Jun 6, 2023)
3719a5a  add pybind (JZ-LIANG, Jun 6, 2023)
73f49a8  bugfix (JZ-LIANG, Jun 6, 2023)
aced5ea  bugfix macro global namespace (JZ-LIANG, Jun 6, 2023)
ef92dc4  bugfix macro global namespace (JZ-LIANG, Jun 6, 2023)
adcb470  segment fault (JZ-LIANG, Jun 8, 2023)
43df148  pybind (JZ-LIANG, Jun 8, 2023)
27803af  pybind test (JZ-LIANG, Jun 8, 2023)
5612da9  pybind bugfix 1 (JZ-LIANG, Jun 14, 2023)
f9bd281  pybind bugfix 2 (JZ-LIANG, Jun 14, 2023)
18f8d29  pybind unit test (JZ-LIANG, Jun 14, 2023)
2628043  Merge remote-tracking branch 'upstream/develop' into semi-auto/rule-base (JZ-LIANG, Jun 16, 2023)
68a512a  merge dev (JZ-LIANG, Jun 16, 2023)
f2b2edb  merge dev (JZ-LIANG, Jun 16, 2023)
132558a  merge dev (JZ-LIANG, Jun 16, 2023)
f3bc740  fixed cmake conflict (JZ-LIANG, Jun 16, 2023)
c11cdd2  fixed cmake conflict (JZ-LIANG, Jun 16, 2023)
491bf65  rename get method (JZ-LIANG, Jun 20, 2023)
041abd4  revise InferForward output type (JZ-LIANG, Jun 20, 2023)
60c90d3  revise comment (JZ-LIANG, Jun 20, 2023)
0fc8ff9  Merge remote-tracking branch 'upstream/develop' into semi-auto/rule-base (JZ-LIANG, Jun 26, 2023)
eecd184  update unit test (JZ-LIANG, Jun 26, 2023)
b70b3d5  update cmake deps (JZ-LIANG, Jun 26, 2023)
3 changes: 3 additions & 0 deletions paddle/fluid/distributed/auto_parallel/CMakeLists.txt
@@ -3,4 +3,7 @@ cc_library(
SRCS dist_attr.cc
DEPS phi auto_parallel_proto proto_desc)

cc_library(auto_parallel DEPS op_dist_attr spmd_rule)

add_subdirectory(test)
add_subdirectory(spmd_rules)
paddle/fluid/distributed/auto_parallel/spmd_rules/CMakeLists.txt (new file)
@@ -0,0 +1,4 @@
cc_library(
spmd_rule
SRCS common.cc dist_tensor_spec.cc matmul_spmd_rule.cc
DEPS phi)
213 changes: 213 additions & 0 deletions paddle/fluid/distributed/auto_parallel/spmd_rules/common.cc
@@ -0,0 +1,213 @@
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h"

#include <glog/logging.h>

#include "paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h"

namespace paddle {
namespace distributed {
namespace auto_parallel {

std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
SPMDRuleBase::InferForward(const std::vector<DistTensorSpec>& input_specs,
const paddle::framework::AttributeMap& attrs) {
PADDLE_THROW(
phi::errors::Unimplemented("InferForward should be called from a "
"derived class of SPMDRuleBase!"));
}

std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
SPMDRuleBase::InferBackward(const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs) {
PADDLE_THROW(
phi::errors::Unimplemented("InferBackward should be called from a "
"derived class of SPMDRuleBase!"));
}

std::unordered_map<std::string, int64_t> ShardingMergeForTensors(
const std::vector<std::pair<const std::string, const std::vector<int64_t>>>&
tensor_axes_to_dim_pairs) {
std::unordered_map<std::string, int64_t> axis_to_dim_map;
std::unordered_map<int64_t, std::string> dim_to_axis_map;
int64_t merge_dim;

for (auto& pair : tensor_axes_to_dim_pairs) {
for (size_t i = 0; i < pair.second.size(); ++i) {
auto tensor_axis = pair.first.substr(i, 1);
auto mesh_dim = pair.second[i];

if (axis_to_dim_map.count(tensor_axis) == 0) {
merge_dim = mesh_dim;
} else {
merge_dim = ShardingMergeForAxis(
tensor_axis, mesh_dim, axis_to_dim_map[tensor_axis]);
}
axis_to_dim_map[tensor_axis] = merge_dim;
if (merge_dim != -1) {
if (dim_to_axis_map.count(merge_dim) == 0) {
dim_to_axis_map.insert({merge_dim, tensor_axis});
} else if (dim_to_axis_map[merge_dim].find(tensor_axis) ==
std::string::npos) {
dim_to_axis_map[merge_dim] += tensor_axis;
}
}
}
}

// Resolve the "mesh_dim sharded by more than one axis" conflict.
// For now we just naively pick the first axis.
// TODO: use a local cost model to pick the axis with the lowest cost (in
// terms of memory, communication, or computation).
for (auto& it : dim_to_axis_map) {
if (it.second.size() > 1) {
VLOG(4) << "Sharding Conflict: Mesh_Dim [" << it.first
<< "] are Sharding Multiple Tensor Axis: [" << it.second
<< "]. The Axis: [" << it.second[0] << "] is Picked.";
for (size_t i = 1; i < it.second.size(); ++i) {
axis_to_dim_map[it.second.substr(i, 1)] = -1;
}
}
}

return axis_to_dim_map;
}

// Rule1: A replicated dimension can be merged with any sharded dimension.
// Rule2: A tensor axis can be sharded by at most one mesh dimension.
// TODO: trigger a heuristic cost model and reshard to handle the case of an
// axis sharded by multiple mesh dimensions.
int64_t ShardingMergeForAxis(const std::string& axis,
const int64_t& mesh_dim1,
const int64_t& mesh_dim2) {
if (mesh_dim1 != mesh_dim2) {
if (mesh_dim1 == -1) {
return mesh_dim2;
} else if (mesh_dim2 == -1) {
return mesh_dim1;
} else {
// (TODO) local cost model here.
PADDLE_THROW(
phi::errors::Unimplemented("Tensor axis [%s] is sharded by two "
"different mesh dimensions [%d] and [%d].",
axis,
mesh_dim1,
mesh_dim2));
}

} else {
return mesh_dim1;
}
}

TensorDistAttr CopyTensorDistAttrForOutput(
[Review comment, Contributor] Better add a copy constructor.
[Reply, Author] A copy constructor does not fit this case, since some of the data members are changed after the "copy".

const TensorDistAttr& src_dist_attr) {
TensorDistAttr new_dist_attr = TensorDistAttr();
new_dist_attr.set_process_mesh(src_dist_attr.process_mesh());
new_dist_attr.set_batch_dim(src_dist_attr.batch_dim());
new_dist_attr.set_dynamic_dims(src_dist_attr.dynamic_dims());
// new_dist_attr.set_annotated(false); // TODO: unset fields are false by default.
return new_dist_attr;
}

std::vector<int64_t> ResoluteOutputPartialDimension(
const std::unordered_map<std::string, int64_t>& axis_to_dim_map,
const std::string& tensor_axes) {
std::vector<int64_t> partial_on_dims;

for (auto& it : axis_to_dim_map) {
if (tensor_axes.find(it.first) == std::string::npos) {
if (it.second > -1) {
partial_on_dims.push_back(it.second);
}
}
}
return partial_on_dims;
}

std::string GetBroadcastAxes(const int64_t& tensor_ndim,
                             const int64_t& broadcast_ndim,
                             const std::string& alphabet) {
  PADDLE_ENFORCE_GE(
      alphabet.size(),
      broadcast_ndim,
      phi::errors::InvalidArgument(
          "size of alphabet [%d] is less than broadcast ndim [%d]",
          alphabet.size(),
          broadcast_ndim));
  PADDLE_ENFORCE_GE(broadcast_ndim,
                    tensor_ndim,
                    phi::errors::InvalidArgument(
                        "broadcast ndim [%d] is less than tensor ndim [%d]",
                        broadcast_ndim,
                        tensor_ndim));
  if (tensor_ndim <= 0) {
    return std::string();
  }
  return alphabet.substr(broadcast_ndim - tensor_ndim, tensor_ndim);
}

// SPMDRuleMap
SPMDRuleMap& SPMDRuleMap::Instance() {
static SPMDRuleMap g_spmd_rule_map;
return g_spmd_rule_map;
}

// A default replicated SPMD rule could be enabled for ops that are NOT
// registered, in which all input and output tensors are replicated across
// all ranks of the mesh. For now, an unregistered op raises NotFound.
SPMDRuleBase* SPMDRuleMap::Get(const std::string& op_type) const {
auto rule_ptr = GetNullable(op_type);
if (rule_ptr == nullptr) {
std::string str;
for (const auto& item : map_) {
str += item.first + ", ";
}
VLOG(4) << "Size of current map [" << map_.size() << "]";
VLOG(4) << "Keys are [" << str << "]";
}
PADDLE_ENFORCE_NOT_NULL(
rule_ptr,
platform::errors::NotFound(
"NO SPMD Rule has been registered for Operator [%s].", op_type));
return rule_ptr;
}

SPMDRuleBase* SPMDRuleMap::GetNullable(const std::string& op_type) const {
auto it = map_.find(op_type);
if (it == map_.end()) {
return nullptr;
} else {
return it->second.get();
}
}

int SPMDRuleMap::Insert(const std::string& op_type,
std::unique_ptr<SPMDRuleBase> rule) {
VLOG(4) << "Call SPMDRuleMap::Insert!";
PADDLE_ENFORCE_NE(
Has(op_type),
true,
platform::errors::AlreadyExists(
"SPMD Rule for Operator [%s] has been registered.", op_type));
map_.insert({op_type, std::move(rule)});

return 1;
}

} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
161 changes: 161 additions & 0 deletions paddle/fluid/distributed/auto_parallel/spmd_rules/common.h
@@ -0,0 +1,161 @@
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <iterator>
#include <map>
#include <string>
#include <vector>

#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h"
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/utils/flat_hash_map.h"

namespace paddle {
namespace distributed {
namespace auto_parallel {

using paddle::framework::Attribute;

class SPMDRuleBase {
public:
virtual ~SPMDRuleBase() {}

// Based on the information of the input tensors and the op attributes:
// 1. Merge the sharding (dims_mapping) among the input tensors.
// 2. Infer the sharding (dims_mapping) for the output tensors.
// The info of each input tensor (shape and DistAttr) is wrapped as a
// DistTensorSpec, and the op attributes should be given as an AttributeMap.
// The output is a pair of two vectors:
// 1. The merged DistAttrs of the input tensors.
// 2. The inferred DistAttrs of the output tensors.
// A merged DistAttr might differ from the original input DistAttr, which
// means that the corresponding input tensor needs to be resharded.
virtual std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
InferForward(const std::vector<DistTensorSpec>& input_specs,
const paddle::framework::AttributeMap& attrs);

// Based on the information of the output tensors and the op attributes:
// 1. Merge the sharding (dims_mapping) among the output tensors.
// 2. Infer the sharding (dims_mapping) for the input tensors.
// The info of each output tensor (shape and DistAttr) is wrapped as a
// DistTensorSpec, and the op attributes should be given as an AttributeMap.
// The output is a pair of two vectors:
// 1. The merged DistAttrs of the output tensors.
// 2. The inferred DistAttrs of the input tensors.
virtual std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
InferBackward(const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs);

template <typename T>
inline const T ExtractAttr(
const std::string& name,
const paddle::framework::AttributeMap& attrs) const {
auto& attr = GetAttr(name, attrs);

// A bool attribute may be stored as an int in the AttributeMap, so cast
// integer attributes when the caller requests a bool.
framework::proto::AttrType attr_type =
static_cast<framework::proto::AttrType>(attr.index() - 1);
if (attr_type == framework::proto::AttrType::INT) {
if (std::is_same<bool, T>::value) {
return static_cast<bool>(PADDLE_GET_CONST(int, attr));
}
}

return PADDLE_GET_CONST(T, attr);
}

const Attribute& GetAttr(const std::string& name,
const paddle::framework::AttributeMap& attrs) const {
auto iter = attrs.find(name);
PADDLE_ENFORCE_NE(iter,
attrs.end(),
paddle::platform::errors::NotFound(
"Attribute (%s) is not found in AttributeMap.", name));
return iter->second;
}
};

// Merge sharding specification (dims mapping) of given tensors.
// The same axes of different tensors will be merged.
std::unordered_map<std::string, int64_t> ShardingMergeForTensors(
const std::vector<std::pair<const std::string, const std::vector<int64_t>>>&
tensor_axes_to_dim_pairs);

// Merge the sharding specification (dims_mapping) for one tensor axis.
// Rule1: A replicated dimension can be merged with any sharded dimension.
// Rule2: A tensor axis can be sharded by at most one mesh dimension.
// TODO: trigger a heuristic cost model and reshard to handle the case of an
// axis sharded by multiple mesh dimensions.
int64_t ShardingMergeForAxis(const std::string& axis,
const int64_t& mesh_dim1,
const int64_t& mesh_dim2);

TensorDistAttr CopyTensorDistAttrForOutput(const TensorDistAttr& src_dist_attr);

// Resolve the partial mesh dimensions of an output tensor, given the merged
// sharding specification of the input tensors and the axis names of the
// output tensor.
std::vector<int64_t> ResoluteOutputPartialDimension(
const std::unordered_map<std::string, int64_t>& axis_to_dim_map,
const std::string& tensor_axes);

// Generate the axis notation of a tensor for the einsum notation of a
// broadcast operation (alignment starts from the rightmost axis).
// tensor_ndim: the rank of the tensor.
// broadcast_ndim: the maximum rank of the tensors in this broadcast operation.
// alphabet: the characters used to represent the tensor axes; its length
// should be >= broadcast_ndim.
std::string GetBroadcastAxes(const int64_t& tensor_ndim,
                             const int64_t& broadcast_ndim,
                             const std::string& alphabet);

// The static map that stores and initializes all the registered SPMD rules.
class SPMDRuleMap {
public:
~SPMDRuleMap() = default;

// A singleton
static SPMDRuleMap& Instance();

// Returns the SPMD rule for the given op_type.
SPMDRuleBase* Get(const std::string& op_type) const;

// Returns the SPMD rule by name, or nullptr if it is not registered.
SPMDRuleBase* GetNullable(const std::string& op_type) const;

// Registers an SPMD rule for an op_type.
int Insert(const std::string& op_type, std::unique_ptr<SPMDRuleBase> rule);

bool Has(const std::string& op_type) const {
return map_.find(op_type) != map_.end();
}

private:
SPMDRuleMap() = default;
paddle::flat_hash_map<std::string, std::unique_ptr<SPMDRuleBase>> map_;
DISABLE_COPY_AND_ASSIGN(SPMDRuleMap);
};

#define REGISTER_SPMD_RULE(op_type, rule_class, ...) \
UNUSED static int __spmd_rule_holder_##op_type = \
::paddle::distributed::auto_parallel::SPMDRuleMap::Instance().Insert( \
#op_type, std::make_unique<rule_class>(__VA_ARGS__))

} // namespace auto_parallel
} // namespace distributed
} // namespace paddle