add cutlass act set in conv_elementwise_add_act_fuse_pass
add cutlass act set in conv2d_fusion_layout_transfer_pass
zhoutianzi666 committed Dec 8, 2022
1 parent 6542027 commit a2fe918
Showing 2 changed files with 93 additions and 7 deletions.
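Taken together, the two passes now share one gating rule: the CUTLASS-specific rewrites apply only when the model runs at float16 precision and Paddle was built with CUTLASS; cuDNN's smaller activation set remains the fallback. A minimal, self-contained sketch of that selection logic (plain C++; FusableActivations and its parameters are illustrative names, not Paddle API):

#include <string>
#include <unordered_set>

// Sketch of the gating this commit adds: cuDNN can always fuse a small
// activation set; CUTLASS extends it, but only for fp16 models in a build
// with PADDLE_WITH_CUTLASS defined.
std::unordered_set<std::string> FusableActivations(bool is_fp16_precision,
                                                   bool built_with_cutlass) {
  std::unordered_set<std::string> all_act_set = {"identity", "relu"};
  if (is_fp16_precision && built_with_cutlass) {
    // The activations CUTLASS supports as of this commit.
    for (const char *act : {"swish", "relu", "identity", "leaky_relu"}) {
      all_act_set.insert(act);
    }
  }
  return all_act_set;
}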
30 changes: 27 additions & 3 deletions paddle/fluid/framework/ir/conv2d_fusion_layout_transfer_pass.cc
@@ -138,6 +138,19 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
FusePassBase::Init("data_layout_transfer", graph);
auto *scope = param_scope();

// Only float16 compute precision needs transfer_layout inserted.
bool is_fp16_precision =
static_cast<phi::DataType>(Get<int>("model_precision")) ==
phi::DataType::FLOAT16 ||
Get<bool>("enable_gpu_half");
bool cutlass_enable = false;

#ifdef PADDLE_WITH_CUTLASS
cutlass_enable = true;
#endif

if (!(is_fp16_precision && cutlass_enable)) return;

PADDLE_ENFORCE_EQ(graph->IsMainGraph(),
true,
platform::errors::InvalidArgument(
@@ -169,14 +182,24 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
if (data_format != "NCHW") return false;

auto filter_names = op_node->Op()->Input("Filter");
auto act_type = op_node->Op()->GetAttrIfExists<std::string>("activation");
const int cutlass_alignment = 8;
std::unordered_set<std::string> cutlass_act_set = {
"relu", "swish", "identity", "leaky_relu"};
if (!cutlass_act_set.count(act_type)) {
return false;
}

// If the filter's channels are not a multiple of 8, conv2d_fusion will not run in NHWC layout.
for (const auto &filter_name : filter_names) {
auto *filter_var = scope->FindLocalVar(filter_name);
const auto &filter_tensor = filter_var->Get<phi::DenseTensor>();
CHECK_EQ(filter_tensor.dims().size(), 4);
int oc = filter_tensor.dims()[0];
int ic = filter_tensor.dims()[1];
bool cutlass_can_support =
    oc % cutlass_alignment == 0 && ic % cutlass_alignment == 0;
if (!cutlass_can_support) {
return false;
}
}
@@ -190,6 +213,7 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
auto *op_desc = op_node->Op();
auto nhwc_attr = framework::Attribute(std::string("NHWC"));
op_desc->SetAttr("data_format", nhwc_attr);
op_desc->SetType("conv2d_fusion_cutlass");
op_desc->Flush();

// transfer weights
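The alignment test above is what keeps CUTLASS safe to enable: its fp16 NHWC kernels read channels in aligned vectors, so the pass bails out unless both channel counts of the OIHW filter are multiples of 8. A standalone sketch of the same predicate (illustrative helper, not part of the pass):

#include <vector>

// Sketch: for an OIHW filter, require the output-channel (dims[0]) and
// input-channel (dims[1]) counts to be multiples of 8 before switching the
// op to CUTLASS's NHWC path.
bool CutlassCanSupport(const std::vector<int> &filter_dims_oihw) {
  constexpr int kCutlassAlignment = 8;
  if (filter_dims_oihw.size() != 4) return false;
  return filter_dims_oihw[0] % kCutlassAlignment == 0 &&
         filter_dims_oihw[1] % kCutlassAlignment == 0;
}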
70 changes: 66 additions & 4 deletions paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc
@@ -36,7 +36,8 @@ framework::proto::OpDesc PrepareOpDesc(
const framework::proto::OpDesc& base_desc,
const std::string& bias,
const std::string& activation,
const std::string& output,
float alpha) {
auto proto = base_desc;
framework::OpDesc desc(proto, nullptr);
desc.SetType("conv2d_fusion");
@@ -46,6 +47,8 @@
desc.SetOutput("Output", {output});
desc.SetAttr("is_test", true);
desc.SetAttr("use_cudnn", false);
// Used by leaky_relu.
desc.SetAttr("fuse_alpha", alpha);
desc.Flush();
return *desc.Proto();
}
@@ -118,6 +121,25 @@ ConvElementwiseAddActFusePass::ConvElementwiseAddActFusePass() {
.AddOutput("Out")
.IsTensor()
.End();

AddOpCompat(OpCompat("swish"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End();

AddOpCompat(OpCompat("leaky_relu"))
.AddInput("X")
.IsTensor()
.End()
.AddAttr("alpha")
.IsType<float>()
.End()
.AddOutput("Out")
.IsTensor()
.End();
}

void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
@@ -137,8 +159,28 @@ void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
std::unordered_set<std::string> cudnn_act_set({"identity", "relu"});
#endif

std::unordered_set<std::string> cutlass_act_set;
std::unordered_set<std::string> all_act_set = cudnn_act_set;

bool is_fp16_precision =
static_cast<phi::DataType>(Get<int>("model_precision")) ==
phi::DataType::FLOAT16 ||
Get<bool>("enable_gpu_half");
const int cutlass_alignment = 8;
if (is_fp16_precision) {
#ifdef PADDLE_WITH_CUTLASS
// CUTLASS currently supports these activations.
cutlass_act_set.insert("swish");
cutlass_act_set.insert("relu");
cutlass_act_set.insert("identity");
cutlass_act_set.insert("leaky_relu");

all_act_set.insert(cutlass_act_set.begin(), cutlass_act_set.end());
#endif
}

patterns::ConvElementwiseaddAct pattern(gpd.mutable_pattern(), pattern_name);
pattern(x, all_act_set);

auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
@@ -152,9 +194,27 @@
std::string bias_name = elementwise_add_in_y->Name();
std::string act_op_type = act_op->Op()->Type();
std::string act_op_out = act_out->Name();
auto* scope = param_scope();
auto* filter_var = scope->FindLocalVar(conv_filter->Name());
auto* filter_tensor = filter_var->GetMutable<phi::DenseTensor>();
CHECK_EQ(filter_tensor->dims().size(), 4);
// When this conv2d_fusion problem size is supported by neither CUTLASS nor
// cuDNN, we should not apply this pass.
int oc = filter_tensor->dims()[0];
int ic = filter_tensor->dims()[1];
bool cutlass_can_fuse = oc % cutlass_alignment == 0 &&
ic % cutlass_alignment == 0 &&
cutlass_act_set.count(act_op_type);
bool cudnn_can_fuse = cudnn_act_set.count(act_op_type);
if (!cutlass_can_fuse && !cudnn_can_fuse) {
return;
}

float alpha = act_op->Op()->GetAttrIfExists<float>("alpha");

auto new_op_proto =
PrepareOpDesc(base_op_desc, bias_name, act_op_type, act_op_out, alpha);
framework::OpDesc new_op_desc(new_op_proto, nullptr);

// Create a new node for the fused op.
@@ -195,4 +255,6 @@ REGISTER_PASS_CAPABILITY(conv_elementwise_add_act_fuse_pass)
.EQ("relu", 0)
.EQ("sigmoid", 0)
.EQ("tanh", 0)
.EQ("identity", 0));
.EQ("identity", 0)
.EQ("leaky_relu", 0)
.EQ("swish", 0));

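For reference, the alpha plumbing in the second file is small: the handler reads the activation op's alpha attribute when present (of the fused activations, only leaky_relu declares one) and stores it on the fused op as fuse_alpha, defaulting to 0.f. A hedged sketch of the equivalent lookup outside Paddle (ResolveFuseAlpha is an illustrative name):

#include <map>
#include <string>

// Sketch mirroring GetAttrIfExists<float>("alpha"): return the activation's
// "alpha" attribute if it exists, else 0.f. The fused conv2d_fusion op
// stores the result under "fuse_alpha" for the kernel to read back.
float ResolveFuseAlpha(const std::map<std::string, float> &act_attrs) {
  auto it = act_attrs.find("alpha");
  return it != act_attrs.end() ? it->second : 0.f;
}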