Add the first implementation of fusion_group op #19621
Merged
Changes from all commits (25 commits):
6f26abb Add the dynamic load of nvrtc, and support runtime compiling of CUDA … (Xreki)
8c04f5b Call CUDA driver api to launch the kernel compiled by nvrtc. (Xreki)
1dad4e6 Merge branch 'develop' into fuse_jit_compile (Xreki)
fc953b6 Disable for mac and windows. (Xreki)
b5fc76c Refine the codes to support manually specified num_threads and worklo… (Xreki)
a32fc9a Merge branch 'develop' into fuse_jit_compile (Xreki)
3c879d8 Refine the CUDA kernel to support large dims. (Xreki)
e0ac413 Merge branch 'develop' into fuse_fusion_group (Xreki)
dc9947f Add DeviceCodePool to manage all device codes. (Xreki)
9c51937 Add the first implementation fusion_group op. (Xreki)
7becebd Add unit-test for fusion_group op. (Xreki)
9f0c215 Add the check of result. (Xreki)
cfc91e1 Merge branch 'develop' into fuse_fusion_group (Xreki)
07aeb1a Add the check of nvrtc in unit-test. (Xreki)
583a254 Add comment to explain the inputs, outputs and features of fusion_gro… (Xreki)
56bb1f1 Disable fusion_group op for mac and windows. (Xreki)
6b3a03e Merge branch 'develop' into fuse_fusion_group (Xreki)
a1fd12e Make the compiling of device code return status instead of hanging up. (Xreki)
121a919 Merge branch 'develop' into fuse_fusion_group (Xreki)
b907704 Add the check of whether there is CUDA driver library, and do not cor… (Xreki)
fb8ddc9 Unify fusion_group_op's input and output names. (Xreki)
b1fb85a Merge branch 'develop' into fuse_fusion_group (Xreki)
3e73678 Add the check of CUDA driver library in unittest. (Xreki)
fc8a6eb Merge branch 'develop' into fuse_fusion_group (Xreki)
cf97946 Refine the calling of PADDLE_ENFORCE. (Xreki)
@@ -0,0 +1,90 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/fused/fusion_group_op.h"

namespace paddle {
namespace operators {

class FusionGroupOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    const size_t num_ins = ctx->Inputs("Inputs").size();
    const size_t num_outs = ctx->Outputs("Outs").size();

    PADDLE_ENFORCE_GE(
        num_ins, 1UL,
        platform::errors::InvalidArgument(
            "Expected the number of inputs >= 1. Received %d.", num_ins));
    PADDLE_ENFORCE_GE(
        num_outs, 1UL,
        platform::errors::InvalidArgument(
            "Expected the number of outputs >= 1. Received %d.", num_outs));

    int type = ctx->Attrs().Get<int>("type");
    PADDLE_ENFORCE_EQ(type, 0UL,
                      platform::errors::InvalidArgument(
                          "Only support fusion of elementwise operations."));

    std::vector<framework::DDim> x_dims = ctx->GetInputsDim("Inputs");
    if (type == 0) {
      for (size_t i = 1; i < num_ins; ++i) {
        PADDLE_ENFORCE_EQ(x_dims[0], x_dims[i],
                          platform::errors::InvalidArgument(
                              "All the inputs' dims should be the same."));
      }
      std::vector<framework::DDim> out_dims;
      for (size_t j = 0; j < num_outs; ++j) {
        out_dims.push_back(x_dims[0]);
      }
      ctx->SetOutputsDim("Outs", out_dims);
    }

    // Only lod of Inputs[0] would be shared with Outs.
    for (size_t j = 0; j < num_outs; ++j) {
      ctx->ShareLoD("Inputs", /*->*/ "Outs", 0, j);
    }
  }
};

class FusionGroupOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Inputs",
             "(std::vector<LoDTensor>) The inputs of fusion_group op.")
        .AsDuplicable();
    AddOutput("Outs",
              "(std::vector<LoDTensor>) The outputs of fusion_group op.")
        .AsDuplicable();
    AddAttr<int>("type", "Fusion type.").SetDefault(0);
    AddAttr<std::string>("func_name", "Name of the generated functions.")
        .SetDefault("");
    AddComment(R"DOC(
fusion_group Operator.

It is used to execute a generated CUDA kernel which fuses the computation of
multiple operators into one. It supports several types:
0, fused computation of elementwise operations in which all the dims of inputs
and outputs should be exactly the same.
)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(fusion_group, ops::FusionGroupOp, ops::FusionGroupOpMaker);
@@ -0,0 +1,22 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/fused/fusion_group_op.h"

namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
    fusion_group,
    ops::FusionGroupKernel<paddle::platform::CUDADeviceContext, double>,
    ops::FusionGroupKernel<paddle::platform::CUDADeviceContext, float>);
@@ -0,0 +1,65 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device_code.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
class FusionGroupKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto ins = ctx.MultiInput<framework::LoDTensor>("Inputs");
    auto outs = ctx.MultiOutput<framework::LoDTensor>("Outs");
    int type = ctx.Attr<int>("type");

    size_t num_ins = ins.size();
    size_t num_outs = outs.size();

    auto place = ctx.GetPlace();
    for (size_t i = 0; i < num_outs; ++i) {
      outs[i]->mutable_data<T>(place);
    }

    std::string func_name = ctx.Attr<std::string>("func_name");
    platform::DeviceCode* dev_code =
        platform::DeviceCodePool::Instance().Get(place, func_name);
    VLOG(3) << "func_name: " << func_name;

    if (type == 0) {
      size_t n = ins[0]->numel();
      std::vector<void*> args;
      args.push_back(&n);
      std::vector<const T*> ptrs(num_ins + num_outs);
      for (size_t i = 0; i < num_ins; ++i) {
        ptrs[i] = ins[i]->data<T>();
        args.push_back(&ptrs[i]);
      }
      for (size_t j = 0; j < num_outs; ++j) {
        ptrs[num_ins + j] = outs[j]->data<T>();
        args.push_back(&ptrs[num_ins + j]);
      }
      dev_code->Launch(n, &args);
    }
  }
};

}  // namespace operators
}  // namespace paddle

Review thread on the `if (type == 0)` line:

Reviewer: type==elementwise_relu

Xreki: Let's use int for now and unify it with the pass later; we're considering an enum type. But since only this one type is supported at the moment, we haven't settled on it yet.
Reviewer: Does this dev_code need to be looked up on every Compute? The network won't change after it is created; if there are many entries, won't the lookup be slow?
Xreki: Every Compute fetches it from the pool, which is just a map access; compared with the program's other map accesses, this one should be minor, so the overhead should be fine. If it turns out to be a performance problem later, we can improve it, e.g. by saving the DeviceCode pointer directly in the op's information.