Add the first implementation of fusion_group op #19621

Merged Jan 3, 2020 (25 commits)

Commits
6f26abb
Add the dynamic load of nvrtc, and support runtime compiling of CUDA …
Xreki Aug 23, 2019
8c04f5b
Call CUDA driver api to launch the kernel compiled by nvrtc.
Xreki Aug 26, 2019
1dad4e6
Merge branch 'develop' into fuse_jit_compile
Xreki Aug 26, 2019
fc953b6
Disable for mac and windows.
Xreki Aug 27, 2019
b5fc76c
Refine the codes to support manually specified num_threads and worklo…
Xreki Aug 27, 2019
a32fc9a
Merge branch 'develop' into fuse_jit_compile
Xreki Aug 29, 2019
3c879d8
Refine the CUDA kernel to support large dims.
Xreki Aug 29, 2019
e0ac413
Merge branch 'develop' into fuse_fusion_group
Xreki Sep 3, 2019
dc9947f
Add DeviceCodePool to manage all device codes.
Xreki Sep 3, 2019
9c51937
Add the first implementation of fusion_group op.
Xreki Sep 4, 2019
7becebd
Add unit-test for fusion_group op.
Xreki Sep 4, 2019
9f0c215
Add the check of result.
Xreki Sep 4, 2019
cfc91e1
Merge branch 'develop' into fuse_fusion_group
Xreki Sep 4, 2019
07aeb1a
Add the check of nvrtc in unit-test.
Xreki Sep 4, 2019
583a254
Add comment to explain the inputs, outputs and features of fusion_gro…
Xreki Sep 4, 2019
56bb1f1
Disable fusion_group op for mac and windows.
Xreki Sep 4, 2019
6b3a03e
Merge branch 'develop' into fuse_fusion_group
Xreki Sep 5, 2019
a1fd12e
Make the compiling of device code return status instead of hanging up.
Xreki Sep 5, 2019
121a919
Merge branch 'develop' into fuse_fusion_group
Xreki Oct 29, 2019
b907704
Add the check of whether there is CUDA driver library, and do not cor…
Xreki Oct 29, 2019
fb8ddc9
Unify fusion_group_op's input and output names.
Xreki Oct 29, 2019
b1fb85a
Merge branch 'develop' into fuse_fusion_group
Xreki Oct 30, 2019
3e73678
Add the check of CUDA driver library in unittest.
Xreki Oct 30, 2019
fc8a6eb
Merge branch 'develop' into fuse_fusion_group
Xreki Dec 27, 2019
cf97946
Refine the calling of PADDLE_ENFORCE.
Xreki Dec 27, 2019
4 changes: 3 additions & 1 deletion cmake/operators.cmake
@@ -116,7 +116,9 @@ function(op_library TARGET)
# Define operators that don't need pybind here.
foreach(manual_pybind_op "compare_op" "logical_op" "nccl_op"
"tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op"
"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "sync_batch_norm_op" "dgc_op" "fused_fc_elementwise_layernorm_op" "multihead_matmul_op")
"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op"
"sync_batch_norm_op" "dgc_op" "fused_fc_elementwise_layernorm_op"
"multihead_matmul_op" "fusion_group_op")
if ("${TARGET}" STREQUAL "${manual_pybind_op}")
set(pybind_flag 1)
endif()
@@ -82,7 +82,7 @@ void FusionGroupPass::InsertFusionGroupOp(
input_names.push_back(n->Name());
external_nodes.insert(n);
}
op_desc.SetInput("Xs", input_names);
op_desc.SetInput("Inputs", input_names);

std::vector<std::string> output_names;
for (auto* n : output_vars_of_subgraph) {
9 changes: 8 additions & 1 deletion paddle/fluid/operators/fused/CMakeLists.txt
@@ -4,7 +4,8 @@ register_operators(EXCLUDES
fusion_transpose_flatten_concat_op
fusion_conv_inception_op
fused_fc_elementwise_layernorm_op
multihead_matmul_op)
multihead_matmul_op
fusion_group_op)

if (WITH_GPU)
# conv_fusion_op needs cudnn 7 above
@@ -26,4 +27,10 @@ if (WITH_GPU)
# multihead_matmul_op
op_library(multihead_matmul_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(multihead_matmul);\n")
# fusion_group
if(NOT APPLE AND NOT WIN32)
op_library(fusion_group_op DEPS device_code)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fusion_group);\n")
cc_test(test_fusion_group_op SRCS fusion_group_op_test.cc DEPS fusion_group_op)
endif()
endif()
90 changes: 90 additions & 0 deletions paddle/fluid/operators/fused/fusion_group_op.cc
@@ -0,0 +1,90 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/fused/fusion_group_op.h"

namespace paddle {
namespace operators {

class FusionGroupOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext *ctx) const override {
const size_t num_ins = ctx->Inputs("Inputs").size();
const size_t num_outs = ctx->Outputs("Outs").size();

PADDLE_ENFORCE_GE(
num_ins, 1UL,
platform::errors::InvalidArgument(
"Expected the number of inputs >= 1. Received %d.", num_ins));
PADDLE_ENFORCE_GE(
num_outs, 1UL,
platform::errors::InvalidArgument(
"Expected the number of outputs >= 1. Recived %d.", num_outs));

int type = ctx->Attrs().Get<int>("type");
PADDLE_ENFORCE_EQ(type, 0UL,
platform::errors::InvalidArgument(
"Only support fusion of elementwise operations."));

std::vector<framework::DDim> x_dims = ctx->GetInputsDim("Inputs");
if (type == 0) {
for (size_t i = 1; i < num_ins; ++i) {
PADDLE_ENFORCE_EQ(x_dims[0], x_dims[i],
platform::errors::InvalidArgument(
"All the inputs' dims should be the same."));
}
std::vector<framework::DDim> out_dims;
for (size_t j = 0; j < num_outs; ++j) {
out_dims.push_back(x_dims[0]);
}
ctx->SetOutputsDim("Outs", out_dims);
}

// Only lod of Inputs[0] would be shared with Outs.
for (size_t j = 0; j < num_outs; ++j) {
ctx->ShareLoD("Inputs", /*->*/ "Outs", 0, j);
}
}
};

class FusionGroupOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Inputs",
"(std::vector<LoDTensor>) The inputs of fusion_group op.")
.AsDuplicable();
AddOutput("Outs",
"(std::vector<LoDTensor>) The outputs of fusion_group op.")
.AsDuplicable();
AddAttr<int>("type", "Fusion type.").SetDefault(0);
AddAttr<std::string>("func_name", "Name of the generated functions.")
.SetDefault("");
AddComment(R"DOC(
fusion_group Operator.

It is used to execute a generated CUDA kernel which fuses the computation of
multiple operators into one. It supports several types:
0, fused computation of elementwise operations, in which all the dims of inputs
and outputs should be exactly the same.
)DOC");
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(fusion_group, ops::FusionGroupOp, ops::FusionGroupOpMaker);
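The op comment above describes type 0 as fusing elementwise operations whose inputs and outputs all share identical dims. The actual kernel source is generated by the fusion_group pass and compiled at runtime with NVRTC, so the following is only a minimal sketch of what such a generated kernel could look like; the fused ops (elementwise_add followed by relu) and the function name are hypothetical, while the argument layout (n first, then input pointers, then output pointers) matches what FusionGroupKernel::Compute assembles below.

// Illustrative sketch only: the real kernel string is generated by the
// fusion_group pass and compiled with NVRTC at runtime. extern "C" avoids
// C++ name mangling so the function can be looked up by name after
// compilation. The fusion of elementwise_add + relu here is hypothetical.
extern "C" __global__ void fused_elementwise_add_relu(size_t n, const float* x,
                                                      const float* y,
                                                      float* out) {
  // Grid-stride loop, so large dims are handled regardless of grid size.
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x) {
    float tmp = x[i] + y[i];     // binary elementwise op: elementwise_add
    out[i] = tmp > 0 ? tmp : 0;  // unary elementwise op: relu
  }
}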
22 changes: 22 additions & 0 deletions paddle/fluid/operators/fused/fusion_group_op.cu.cc
@@ -0,0 +1,22 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/fused/fusion_group_op.h"

namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
fusion_group,
ops::FusionGroupKernel<paddle::platform::CUDADeviceContext, double>,
ops::FusionGroupKernel<paddle::platform::CUDADeviceContext, float>);
65 changes: 65 additions & 0 deletions paddle/fluid/operators/fused/fusion_group_op.h
@@ -0,0 +1,65 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device_code.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
class FusionGroupKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto ins = ctx.MultiInput<framework::LoDTensor>("Inputs");
auto outs = ctx.MultiOutput<framework::LoDTensor>("Outs");
int type = ctx.Attr<int>("type");

size_t num_ins = ins.size();
size_t num_outs = outs.size();

auto place = ctx.GetPlace();
for (size_t i = 0; i < num_outs; ++i) {
outs[i]->mutable_data<T>(place);
}

std::string func_name = ctx.Attr<std::string>("func_name");
platform::DeviceCode* dev_code =
platform::DeviceCodePool::Instance().Get(place, func_name);

Contributor: Does this dev_code have to be looked up on every Compute call? It should not change once the network has been built; if there are many entries, won't the lookup be slow?

Contributor Author (Xreki): Every Compute fetches it from the pool, which is just a map access. Compared with the program's other map accesses, this one should be minor, so the overhead should be acceptable. If it turns out to cause a performance problem later, we can improve it, for example by saving the DeviceCode pointer directly into the op's information.
VLOG(3) << "func_name: " << func_name;

if (type == 0) {
Contributor: type == elementwise_relu?

Contributor Author (Xreki, Oct 30, 2019):
  • Ops like elementwise_add/mul are treated as binary elementwise operations, and relu/sigmoid as unary elementwise operations, so the name elementwise_relu would not fit.
  • The type attribute is reserved space for now; more kinds of computation patterns may be supported later. Currently type has no effect on the op's compute kernel; it mainly decides how InferShape checks the dims of each input Tensor and infers the dims of the output Tensors.

Let's keep it as an int for now and unify it with the pass later, possibly as an enum. Since only this one type is supported at the moment, the final design is not settled.

size_t n = ins[0]->numel();
std::vector<void*> args;
args.push_back(&n);
std::vector<const T*> ptrs(num_ins + num_outs);
for (size_t i = 0; i < num_ins; ++i) {
ptrs[i] = ins[i]->data<T>();
args.push_back(&ptrs[i]);
}
for (size_t j = 0; j < num_outs; ++j) {
ptrs[num_ins + j] = outs[j]->data<T>();
args.push_back(&ptrs[num_ins + j]);
}
dev_code->Launch(n, &args);
}
}
};

} // namespace operators
} // namespace paddle
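The args vector assembled in Compute above holds the addresses of the kernel arguments, which is the kernelParams convention of the CUDA driver API (the commit log mentions calling the CUDA driver API to launch the NVRTC-compiled kernel). Below is a hedged sketch of how a call like dev_code->Launch(n, &args) could map onto cuLaunchKernel; the helper name and the num_threads/workload_per_thread values are assumptions for illustration, not Paddle's actual DeviceCode implementation.

// Illustrative sketch only: shows how a std::vector<void*> of argument
// addresses, like the one built in FusionGroupKernel::Compute, maps onto
// cuLaunchKernel's kernelParams. Error checking is omitted for brevity.
#include <cuda.h>

#include <vector>

void LaunchFusedKernel(CUfunction func, CUstream stream, size_t n,
                       std::vector<void*>* args) {
  // Hypothetical tuning knobs; the commit history mentions manually
  // specified num_threads and workload per thread.
  const unsigned int num_threads = 1024;
  const unsigned int workload_per_thread = 4;
  unsigned int num_blocks = static_cast<unsigned int>(
      (n + num_threads * workload_per_thread - 1) /
      (num_threads * workload_per_thread));
  cuLaunchKernel(func, num_blocks, 1, 1,  // grid dims
                 num_threads, 1, 1,       // block dims
                 0,                       // shared memory bytes
                 stream,
                 args->data(),  // kernelParams: array of argument addresses
                 nullptr);      // extra (unused)
}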