[Semi-Auto] SPMD Parallel Rule Base #53863
Merged
Changes from all commits (46 commits):
a47bf99  base rule (JZ-LIANG)
21b5a75  add sharidng merge (JZ-LIANG)
f7e39d7  add sharidng axis merge (JZ-LIANG)
c92992d  define unified data class for inferencing dist_attr (pkuzyc)
42a7b77  test wrap DistTensorSpec in dygraph mode (pkuzyc)
180edcc  matmul main logic done (JZ-LIANG)
ecbb1ae  Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (JZ-LIANG)
f314b56  Merge remote-tracking branch 'zyc/develop' into semi-auto/rule-base (JZ-LIANG)
46153c7  shape int64 (JZ-LIANG)
1dcb80e  common cc (JZ-LIANG)
198bc1f  define unified data class for inferencing dist_attr (pkuzyc)
09d82a5  test wrap DistTensorSpec in dygraph mode (pkuzyc)
c3ea2a6  define python api and wrap function in static mode for DistTensorSpec (pkuzyc)
4cd1a2c  revise syntax (JZ-LIANG)
3631f06  Merge remote-tracking branch 'zyc/develop' into semi-auto/rule-base (JZ-LIANG)
ed0c31e  map bugfix (JZ-LIANG)
701d3fa  broadcast func (JZ-LIANG)
c1545a4  compile 1 (JZ-LIANG)
3ca2b73  add unitest (JZ-LIANG)
747de08  add registry (JZ-LIANG)
968ce61  Merge branch 'semi-auto/rule-base' of https://github.com/JZ-LIANG/Pad… (JZ-LIANG)
7be672d  update unitest (JZ-LIANG)
3389b7e  bugfix (JZ-LIANG)
3882a2c  bugfix (JZ-LIANG)
3719a5a  add pybind (JZ-LIANG)
73f49a8  bugfix (JZ-LIANG)
aced5ea  bugfix macro gloabl name space (JZ-LIANG)
ef92dc4  bugfix macro gloabl name space (JZ-LIANG)
adcb470  segment fault (JZ-LIANG)
43df148  pybind (JZ-LIANG)
27803af  pybind test (JZ-LIANG)
5612da9  pybind bugfixed1 (JZ-LIANG)
f9bd281  pybind bugfixed2 (JZ-LIANG)
18f8d29  pybind unitest (JZ-LIANG)
2628043  Merge remote-tracking branch 'upstream/develop' into semi-auto/rule-base (JZ-LIANG)
68a512a  merge dev (JZ-LIANG)
f2b2edb  merge dev (JZ-LIANG)
132558a  merge dev (JZ-LIANG)
f3bc740  fixed cmake conflict (JZ-LIANG)
c11cdd2  fixed cmake conflict (JZ-LIANG)
491bf65  rename get method (JZ-LIANG)
041abd4  revise inferforward output type (JZ-LIANG)
60c90d3  revise comment (JZ-LIANG)
0fc8ff9  Merge remote-tracking branch 'upstream/develop' into semi-auto/rule-base (JZ-LIANG)
eecd184  update unitest (JZ-LIANG)
b70b3d5  update cmake deps (JZ-LIANG)
4 changes: 4 additions & 0 deletions
paddle/fluid/distributed/auto_parallel/spmd_rules/CMakeLists.txt
@@ -0,0 +1,4 @@
cc_library(
  spmd_rule
  SRCS common.cc dist_tensor_spec.cc matmul_spmd_rule.cc
  DEPS phi)
213 changes: 213 additions & 0 deletions
paddle/fluid/distributed/auto_parallel/spmd_rules/common.cc
@@ -0,0 +1,213 @@
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h"

#include <glog/logging.h>

#include "paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h"

namespace paddle {
namespace distributed {
namespace auto_parallel {

std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
SPMDRuleBase::InferForward(const std::vector<DistTensorSpec>& input_specs,
                           const paddle::framework::AttributeMap& attrs) {
  PADDLE_THROW(
      phi::errors::Unimplemented("InferForward should be called from a "
                                 "derived class of SPMDRuleBase !"));
}

std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
SPMDRuleBase::InferBackward(const std::vector<DistTensorSpec>& output_specs,
                            const paddle::framework::AttributeMap& attrs) {
  PADDLE_THROW(
      phi::errors::Unimplemented("InferBackward should be called from a "
                                 "derived class of SPMDRuleBase !"));
}

std::unordered_map<std::string, int64_t> ShardingMergeForTensors(
    const std::vector<std::pair<const std::string, const std::vector<int64_t>>>&
        tensor_axes_to_dim_pairs) {
  std::unordered_map<std::string, int64_t> axis_to_dim_map;
  std::unordered_map<int64_t, std::string> dim_to_axis_map;
  int64_t merge_dim;

  for (auto& pair : tensor_axes_to_dim_pairs) {
    for (size_t i = 0; i < pair.second.size(); ++i) {
      auto tensor_axis = pair.first.substr(i, 1);
      auto mesh_dim = pair.second[i];

      if (axis_to_dim_map.count(tensor_axis) == 0) {
        merge_dim = mesh_dim;
      } else {
        merge_dim = ShardingMergeForAxis(
            tensor_axis, mesh_dim, axis_to_dim_map[tensor_axis]);
      }
      axis_to_dim_map[tensor_axis] = merge_dim;
      if (merge_dim != -1) {
        if (dim_to_axis_map.count(merge_dim) == 0) {
          dim_to_axis_map.insert({merge_dim, tensor_axis});
        } else if (dim_to_axis_map[merge_dim].find(tensor_axis) ==
                   std::string::npos) {
          dim_to_axis_map[merge_dim] += tensor_axis;
        }
      }
    }
  }

  // Resolve the "mesh_dim sharded by more than one axis" conflict.
  // For now we naively pick the first axis.
  // TODO: use a local cost model to pick the axis with the lowest cost (in
  // terms of memory, communication, or computation).
  for (auto& it : dim_to_axis_map) {
    if (it.second.size() > 1) {
      VLOG(4) << "Sharding Conflict: Mesh_Dim [" << it.first
              << "] is Sharding Multiple Tensor Axes: [" << it.second
              << "]. The Axis: [" << it.second[0] << "] is Picked.";
      for (size_t i = 1; i < it.second.size(); ++i) {
        axis_to_dim_map[it.second.substr(i, 1)] = -1;
      }
    }
  }

  return axis_to_dim_map;
}
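
As a usage sketch (the values below are illustrative, not taken from this PR): merging the dims_mapping of the two inputs of a matmul whose einsum-style notation is "mk,kn->mn" would look like this.

```cpp
// Hypothetical inputs: x with axes "mk" sharded as {0, -1} (m on mesh dim 0),
// y with axes "kn" sharded as {-1, 1} (n on mesh dim 1).
std::vector<std::pair<const std::string, const std::vector<int64_t>>>
    tensor_axes_to_dims = {{"mk", {0, -1}}, {"kn", {-1, 1}}};

auto axis_to_dim = ShardingMergeForTensors(tensor_axes_to_dims);
// Expected merged mapping: {"m": 0, "k": -1, "n": 1}.
```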

// Rule1: A replicated dimension could be merged by any sharded dimension.
// Rule2: A tensor axis could at most be sharded by one mesh dimension.
// TODO: trigger a heuristic cost model and reshard to handle the case of an
// axis sharded by multiple mesh dimensions.
int64_t ShardingMergeForAxis(const std::string& axis,
                             const int64_t& mesh_dim1,
                             const int64_t& mesh_dim2) {
  if (mesh_dim1 != mesh_dim2) {
    if (mesh_dim1 == -1) {
      return mesh_dim2;
    } else if (mesh_dim2 == -1) {
      return mesh_dim1;
    } else {
      // TODO: local cost model here.
      PADDLE_THROW(
          phi::errors::Unimplemented("Tensor Axis[%s] is Sharded by two "
                                     "different mesh dimension [%d] and [%d].",
                                     axis,
                                     mesh_dim1,
                                     mesh_dim2));
    }
  } else {
    return mesh_dim1;
  }
}
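
A few worked merges under these two rules (values are illustrative):

```cpp
ShardingMergeForAxis("k", /*mesh_dim1=*/-1, /*mesh_dim2=*/0);  // returns 0 (Rule1)
ShardingMergeForAxis("k", /*mesh_dim1=*/0, /*mesh_dim2=*/0);   // returns 0 (equal dims)
// ShardingMergeForAxis("k", 0, 1) would throw Unimplemented: by Rule2 a
// tensor axis may be sharded by at most one mesh dimension.
```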

TensorDistAttr CopyTensorDistAttrForOutput(
    const TensorDistAttr& src_dist_attr) {
  TensorDistAttr new_dist_attr = TensorDistAttr();
  new_dist_attr.set_process_mesh(src_dist_attr.process_mesh());
  new_dist_attr.set_batch_dim(src_dist_attr.batch_dim());
  new_dist_attr.set_dynamic_dims(src_dist_attr.dynamic_dims());
  // new_dist_attr.set_annotated(false); TODO: an unset field is false by default.
  return new_dist_attr;
}
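
A minimal sketch of the intended use; `input_attr` is a hypothetical placeholder, and the `set_dims_mapping` call is an assumption about the TensorDistAttr API rather than something shown in this diff.

```cpp
// Derive the output dist_attr from an input: process mesh, batch_dim and
// dynamic_dims are carried over, while dims_mapping is left for the rule to
// fill in afterwards.
TensorDistAttr out_attr = CopyTensorDistAttrForOutput(input_attr);
out_attr.set_dims_mapping({0, -1});  // e.g. taken from the merged axis map
```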

std::vector<int64_t> ResoluteOutputPartialDimension(
    const std::unordered_map<std::string, int64_t>& axis_to_dim_map,
    const std::string& tensor_axes) {
  std::vector<int64_t> partial_on_dims;

  for (auto& it : axis_to_dim_map) {
    if (tensor_axes.find(it.first) == std::string::npos) {
      if (it.second > -1) {
        partial_on_dims.push_back(it.second);
      }
    }
  }
  return partial_on_dims;
}
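
For instance (illustrative values): in matmul "mk,kn->mn", the contracted axis k does not appear in the output, so if k is sharded the output becomes partial on that mesh dimension.

```cpp
std::unordered_map<std::string, int64_t> axis_to_dim = {
    {"m", -1}, {"n", 1}, {"k", 0}};
auto partial_dims = ResoluteOutputPartialDimension(axis_to_dim, /*tensor_axes=*/"mn");
// partial_dims == {0}: "k" is sharded on mesh dim 0 but absent from "mn".
```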

std::string GetBroadcastAxes(const int64_t& tenosr_ndim,
                             const int64_t& broadcast_ndim,
                             const std::string& alphabet) {
  PADDLE_ENFORCE_GE(
      alphabet.size(),
      broadcast_ndim,
      phi::errors::InvalidArgument(
          "size of alphabet [%d] is less than broadcast ndim [%d]",
          alphabet.size(),
          broadcast_ndim));
  PADDLE_ENFORCE_GE(broadcast_ndim,
                    tenosr_ndim,
                    phi::errors::InvalidArgument(
                        "broadcast ndim [%d] is less than tenosr ndim [%d]",
                        broadcast_ndim,
                        tenosr_ndim));
  if (tenosr_ndim <= 0) {
    return std::string();
  }
  return alphabet.substr(broadcast_ndim - tenosr_ndim, tenosr_ndim);
}
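
A quick sketch of the expected behavior, following directly from the `substr` call above:

```cpp
// Axes are aligned starting from the rightmost dimension of the broadcast shape.
GetBroadcastAxes(/*tenosr_ndim=*/4, /*broadcast_ndim=*/4, "abcd");  // "abcd"
GetBroadcastAxes(/*tenosr_ndim=*/2, /*broadcast_ndim=*/4, "abcd");  // "cd"
GetBroadcastAxes(/*tenosr_ndim=*/0, /*broadcast_ndim=*/4, "abcd");  // ""
```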

// SPMDRuleMap
SPMDRuleMap& SPMDRuleMap::Instance() {
  static SPMDRuleMap g_spmd_rule_map;
  return g_spmd_rule_map;
}

// To enable a default replicated spmd rule for ops that are NOT registered,
// in which all input and output tensors would be replicated across all ranks
// of the mesh.
SPMDRuleBase* SPMDRuleMap::Get(const std::string& op_type) const {
  auto rule_ptr = GetNullable(op_type);
  if (rule_ptr == nullptr) {
    std::string str;
    for (const auto& item : map_) {
      str += item.first + ", ";
    }
    VLOG(4) << "Size of current map [" << map_.size() << "]";
    VLOG(4) << "Keys are [" << str << "]";
  }
  PADDLE_ENFORCE_NOT_NULL(
      rule_ptr,
      platform::errors::NotFound(
          "NO SPMD Rule has been registered for Operator [%s].", op_type));
  return rule_ptr;
}

SPMDRuleBase* SPMDRuleMap::GetNullable(const std::string& op_type) const {
  auto it = map_.find(op_type);
  if (it == map_.end()) {
    return nullptr;
  } else {
    return it->second.get();
  }
}

int SPMDRuleMap::Insert(const std::string& op_type,
                        std::unique_ptr<SPMDRuleBase> rule) {
  VLOG(4) << "Call SPMDRuleMap::Insert!";
  PADDLE_ENFORCE_NE(
      Has(op_type),
      true,
      platform::errors::AlreadyExists(
          "SPMD Rule for Operator [%s] has been registered.", op_type));
  map_.insert({op_type, std::move(rule)});

  return 1;
}

}  // namespace auto_parallel
}  // namespace distributed
}  // namespace paddle
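
Taken together, a lookup through the registry might look like the following sketch; `input_specs` and `attrs` are hypothetical placeholders that would be built from the operator being processed.

```cpp
std::vector<DistTensorSpec> input_specs;  // built from the op's input tensors
paddle::framework::AttributeMap attrs;    // the op's attributes

// Get() throws NotFound if no rule is registered for the op type.
SPMDRuleBase* rule = SPMDRuleMap::Instance().Get("matmul");
auto inferred = rule->InferForward(input_specs, attrs);
const auto& merged_input_attrs = inferred.first;    // may imply a reshard of inputs
const auto& inferred_output_attrs = inferred.second;
```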
161 changes: 161 additions & 0 deletions
paddle/fluid/distributed/auto_parallel/spmd_rules/common.h
@@ -0,0 +1,161 @@
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <iterator>
#include <map>
#include <string>
#include <vector>

#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h"
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/utils/flat_hash_map.h"

namespace paddle {
namespace distributed {
namespace auto_parallel {

using paddle::framework::Attribute;

class SPMDRuleBase {
 public:
  virtual ~SPMDRuleBase() {}

  // Based on the information of the input tensors and the op attributes:
  // 1. Merge the sharding (dims_mapping) among the input tensors.
  // 2. Infer the sharding (dims_mapping) for the output tensors.
  // The info of each input tensor (shape and DistAttr) is wrapped as a
  // DistTensorSpec, and the op attributes should be given as an AttributeMap.
  // The output is a pair of two vectors:
  // 1. The first vector: the merged DistAttrs of the input tensors.
  // 2. The second vector: the inferred DistAttrs of the output tensors.
  // A merged DistAttr might differ from the original input DistAttr, which
  // means that the corresponding input tensor needs to be resharded.
  virtual std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
  InferForward(const std::vector<DistTensorSpec>& input_specs,
               const paddle::framework::AttributeMap& attrs);

  // Based on the information of the output tensors and the op attributes:
  // 1. Merge the sharding (dims_mapping) among the output tensors.
  // 2. Infer the sharding (dims_mapping) for the input tensors.
  // The info of each output tensor (shape and DistAttr) is wrapped as a
  // DistTensorSpec, and the op attributes should be given as an AttributeMap.
  // The output is a pair of two vectors:
  // 1. The first vector: the merged DistAttrs of the output tensors.
  // 2. The second vector: the inferred DistAttrs of the input tensors.
  virtual std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
  InferBackward(const std::vector<DistTensorSpec>& output_specs,
                const paddle::framework::AttributeMap& attrs);

  template <typename T>
  inline const T ExtractAttr(
      const std::string& name,
      const paddle::framework::AttributeMap& attrs) const {
    auto& attr = GetAttr(name, attrs);

    // In order to get a bool attr properly.
    framework::proto::AttrType attr_type =
        static_cast<framework::proto::AttrType>(attr.index() - 1);
    if (attr_type == framework::proto::AttrType::INT) {
      if (std::is_same<bool, T>::value) {
        return static_cast<bool>(PADDLE_GET_CONST(int, attr));
      }
    }

    return PADDLE_GET_CONST(T, attr);
  }

  const Attribute& GetAttr(const std::string& name,
                           const paddle::framework::AttributeMap& attrs) const {
    auto iter = attrs.find(name);
    PADDLE_ENFORCE_NE(iter,
                      attrs.end(),
                      paddle::platform::errors::NotFound(
                          "(%s) is not found in AttributeMap.", name));
    return iter->second;
  }
};
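
For orientation, here is a minimal hypothetical derived rule for a unary elementwise op, built only from the helpers declared in this header. It is not part of this PR, and it assumes DistTensorSpec exposes shape(), dims_mapping() and dist_attr() accessors and that TensorDistAttr has a set_dims_mapping() setter, none of which are shown in this diff; the real matmul rule lives in matmul_spmd_rule.cc.

```cpp
// Hypothetical sketch, not part of this PR: the output simply reuses the
// (merged) input sharding.
class ElementwiseUnarySPMDRule : public SPMDRuleBase {
 public:
  std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
  InferForward(const std::vector<DistTensorSpec>& input_specs,
               const paddle::framework::AttributeMap& attrs) override {
    const auto& in = input_specs[0];
    int64_t ndim = static_cast<int64_t>(in.shape().size());

    // Name the axes "a", "b", ... and merge (trivial for a single input).
    std::string axes =
        GetBroadcastAxes(ndim, ndim, "abcdefghijklmnopqrstuvwxyz");
    auto axis_to_dim = ShardingMergeForTensors({{axes, in.dims_mapping()}});

    // Copy mesh info and fill the output dims_mapping from the merged map.
    TensorDistAttr out_attr = CopyTensorDistAttrForOutput(in.dist_attr());
    std::vector<int64_t> out_mapping;
    for (int64_t i = 0; i < ndim; ++i) {
      out_mapping.push_back(axis_to_dim[axes.substr(i, 1)]);
    }
    out_attr.set_dims_mapping(out_mapping);
    return {{in.dist_attr()}, {out_attr}};
  }
};
```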

// Merge the sharding specifications (dims mapping) of the given tensors.
// The same axes of different tensors will be merged.
std::unordered_map<std::string, int64_t> ShardingMergeForTensors(
    const std::vector<std::pair<const std::string, const std::vector<int64_t>>>&
        tensor_axes_to_dim_pairs);

// Merge the sharding specification (dims mapping) for one tensor axis.
// Rule1: A replicated dimension could be merged by any sharded dimension.
// Rule2: A tensor axis could at most be sharded by one mesh dimension.
// TODO: trigger a heuristic cost model and reshard to handle the case of an
// axis sharded by multiple mesh dimensions.
int64_t ShardingMergeForAxis(const std::string& axis,
                             const int64_t& mesh_dim1,
                             const int64_t& mesh_dim2);

TensorDistAttr CopyTensorDistAttrForOutput(const TensorDistAttr& src_dist_attr);

// Resolve the partial mesh dimensions of an output tensor, given the merged
// sharding specification of the input tensors and the axis names of the
// output tensor.
std::vector<int64_t> ResoluteOutputPartialDimension(
    const std::unordered_map<std::string, int64_t>& axis_to_dim_map,
    const std::string& tensor_axes);

// Generate the axis notation of a tensor for the einsum notation of a
// broadcast operation (alignment starts from the rightmost axis).
// tenosr_ndim: the ndim of the tensor. broadcast_ndim: the maximum ndim of
// the tensors in this broadcast operation. alphabet: the characters used to
// represent the tensor axes; the length of alphabet should be >=
// broadcast_ndim.
std::string GetBroadcastAxes(const int64_t& tenosr_ndim,
                             const int64_t& broadcast_ndim,
                             const std::string& alphabet);

// The static map that stores and initializes all the registered SPMD rules.
class SPMDRuleMap {
 public:
  ~SPMDRuleMap() = default;

  // A singleton.
  static SPMDRuleMap& Instance();

  // Returns the spmd rule for the given op_type.
  SPMDRuleBase* Get(const std::string& op_type) const;

  // Returns the spmd rule by name, or nullptr if not registered.
  SPMDRuleBase* GetNullable(const std::string& op_type) const;

  // Register a spmd rule for an op_type.
  int Insert(const std::string& op_type, std::unique_ptr<SPMDRuleBase> rule);

  bool Has(const std::string& op_type) const {
    return map_.find(op_type) != map_.end();
  }

 private:
  SPMDRuleMap() = default;
  paddle::flat_hash_map<std::string, std::unique_ptr<SPMDRuleBase>> map_;
  DISABLE_COPY_AND_ASSIGN(SPMDRuleMap);
};

#define REGISTER_SPMD_RULE(op_type, rule_class, ...)                        \
  UNUSED static int __spmd_rule_holder_##op_type =                          \
      ::paddle::distributed::auto_parallel::SPMDRuleMap::Instance().Insert( \
          #op_type, std::make_unique<rule_class>(__VA_ARGS__))

}  // namespace auto_parallel
}  // namespace distributed
}  // namespace paddle
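
Registration then happens at static-initialization time via this macro. The rule class name below is an assumption inferred from the matmul_spmd_rule.cc entry in this PR's CMake list, not something shown in this diff:

```cpp
// Assumed example: register the matmul rule under the op type "matmul".
REGISTER_SPMD_RULE(matmul, MatmulSPMDRule);
```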
Review comment: Better add a copy constructor.
Reply: A copy constructor does not fit this case, since some of the data members are changed after the "copy".