Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sparse_fc supported #41770

Closed
wants to merge 21 commits into from
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ option(WITH_IPU "Compile PaddlePaddle with Graphcore IPU" OFF)
option(WITH_ASCEND_CL "Compile PaddlePaddle with ASCEND CL" ${WITH_ASCEND})
option(WITH_ASCEND_CXX11 "Compile PaddlePaddle with ASCEND and CXX11 ABI" OFF)
option(WITH_ONNXRUNTIME "Compile PaddlePaddle with ONNXRUNTIME" OFF)
option(WITH_CUSPARSELT "Compile PaddlePaddle with CUSPARSELT" OFF)
# Note(zhouwei): It use option above, so put here
include(init)
include(generic) # simplify cmake module
Expand Down
51 changes: 51 additions & 0 deletions cmake/external/cusparselt.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# cuSPARSELt is only pulled in when both the cuSPARSELt integration and
# TensorRT are requested: the sparse_fc TensorRT converter needs both.
if(NOT (WITH_CUSPARSELT AND WITH_TENSORRT))
  return()
endif()

# NVIDIA only ships Linux x86-64 archives for this cuSPARSELt release.
if(WITH_ARM OR WIN32)
  message(SEND_ERROR "cuSPARSELt is currently supported on Linux x86-64 only")
  return()
endif()

include(ExternalProject)

set(CUSPARSELT_PROJECT "extern_cusparselt")
set(CUSPARSELT_URL
    "https://developer.download.nvidia.com/compute/libcusparse-lt/0.2.0/local_installers/libcusparse_lt-linux-x86_64-0.2.0.1.tar.gz"
    CACHE STRING "" FORCE)
set(CUSPARSELT_PREFIX_DIR ${THIRD_PARTY_PATH}/cusparselt)
set(CUSPARSELT_INSTALL_DIR ${THIRD_PARTY_PATH}/install/cusparselt)
set(CUSPARSELT_INC_DIR "${CUSPARSELT_INSTALL_DIR}/include"
    CACHE PATH "cusparselt include directory." FORCE)
set(CUSPARSELT_LIB_DIR "${CUSPARSELT_INSTALL_DIR}/lib64"
    CACHE PATH "cusparselt lib directory." FORCE)
set_directory_properties(PROPERTIES CLEAN_NO_CUSTOM 1)
include_directories(${CUSPARSELT_INC_DIR})

# The archive is a prebuilt binary release: there is nothing to configure or
# build, the "install" step just copies the unpacked headers and libraries
# into the third_party install tree.
ExternalProject_Add(
  ${CUSPARSELT_PROJECT}
  ${EXTERNAL_PROJECT_LOG_ARGS}
  URL                  ${CUSPARSELT_URL}
  PREFIX               ${CUSPARSELT_PREFIX_DIR}
  DOWNLOAD_NO_PROGRESS 1
  CONFIGURE_COMMAND    ""
  BUILD_COMMAND        ""
  UPDATE_COMMAND       ""
  INSTALL_COMMAND
    ${CMAKE_COMMAND} -E copy_directory
      ${CUSPARSELT_PREFIX_DIR}/src/${CUSPARSELT_PROJECT}/lib64 ${CUSPARSELT_LIB_DIR} &&
    ${CMAKE_COMMAND} -E copy_directory
      ${CUSPARSELT_PREFIX_DIR}/src/${CUSPARSELT_PROJECT}/include ${CUSPARSELT_INC_DIR})

add_library(cusparselt INTERFACE)
# add_dependencies only orders the build (it does not link); it guarantees the
# download/extract step ran before anything consuming the interface target.
add_dependencies(cusparselt ${CUSPARSELT_PROJECT})
set(CUSPARSELT_FOUND ON)
add_definitions(-DPADDLE_WITH_CUSPARSELT)
7 changes: 7 additions & 0 deletions cmake/inference_lib.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,13 @@ function(copy_part_of_thrid_party TARGET DST)
endif()
endif()

if(WITH_CUSPARSELT)
  # NOTE: the feature toggle is WITH_CUSPARSELT (see the top-level option and
  # cmake/external/cusparselt.cmake). The previous guard tested the undefined
  # WITH_SPARSELT variable, so the headers/libs were never packaged into the
  # inference distribution.
  set(dst_dir "${DST}/third_party/install/cusparselt")
  copy(${TARGET}
        SRCS ${CUSPARSELT_INC_DIR} ${CUSPARSELT_LIB_DIR}
        DSTS ${dst_dir} ${dst_dir})
endif()

set(dst_dir "${DST}/third_party/install/gflags")
copy(${TARGET}
SRCS ${GFLAGS_INCLUDE_DIR} ${GFLAGS_LIBRARIES}
Expand Down
5 changes: 5 additions & 0 deletions cmake/third_party.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -426,4 +426,9 @@ if (WITH_IPU)
list(APPEND third_party_deps extern_poplar)
endif()

if(WITH_CUSPARSELT AND WITH_TENSORRT)
  # external/cusparselt.cmake returns early (defining nothing) unless both
  # WITH_CUSPARSELT and WITH_TENSORRT are ON; only append the dependency when
  # the extern_cusparselt target will actually exist, otherwise the
  # third_party custom target below would depend on a missing target.
  include(external/cusparselt) # download, install the cusparselt binary release
  list(APPEND third_party_deps extern_cusparselt)
endif()

add_custom_target(third_party ALL DEPENDS ${third_party_deps})
1 change: 1 addition & 0 deletions paddle/fluid/framework/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ pass_library(add_support_int8_pass inference)
pass_library(matmul_scale_fuse_pass inference)
pass_library(gpu_cpu_map_matmul_to_mul_pass inference)
pass_library(mixed_precision_configure_pass inference)
pass_library(replace_dense_with_sparse_pass inference)
pass_library(generate_pass DEPS pass_desc_proto)
target_link_libraries(generate_pass pass_desc_proto)

Expand Down
25 changes: 25 additions & 0 deletions paddle/fluid/framework/ir/graph_pattern_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3424,6 +3424,31 @@ PDNode *patterns::AddSupportInt8::operator()() {
return quant_out;
}

PDNode *patterns::DenseFC::operator()() {
  // Match a standalone, already-fused "fc" op together with all three of its
  // inputs (Input, W, Bias) and its single output (Out).
  auto *fc_op = pattern->NewNode(fc_repr())->assert_is_op("fc");

  auto *input_node = pattern->NewNode(fc_input_repr())
                         ->AsInput()
                         ->assert_is_op_input("fc", "Input");
  auto *weight_node = pattern->NewNode(fc_weights_repr())
                          ->AsInput()
                          ->assert_is_op_input("fc", "W");
  auto *bias_node = pattern->NewNode(fc_bias_repr())
                        ->AsInput()
                        ->assert_is_op_input("fc", "Bias");
  // Out must be the op's only output so the rewrite cannot orphan edges.
  auto *output_node = pattern->NewNode(fc_out_repr())
                          ->AsOutput()
                          ->assert_is_op_output("fc", "Out")
                          ->assert_is_only_output_of_op("fc");

  fc_op->LinksFrom({input_node, weight_node, bias_node})
      .LinksTo({output_node});

  return output_node;
}

} // namespace ir
} // namespace framework
} // namespace paddle
17 changes: 17 additions & 0 deletions paddle/fluid/framework/ir/graph_pattern_detector.h
Original file line number Diff line number Diff line change
Expand Up @@ -1941,6 +1941,23 @@ struct AddSupportInt8 : public PatternBase {
PATTERN_DECL_NODE(quant_out);
};

//
// \brief Pattern looking for dense fc.
//
struct DenseFC : public PatternBase {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个新增的Pattern和当前文件中已定义的patterns::FC有什么区别么?为什么不直接用patterns::FC?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

patterns::FC是找mul+element的 不一样

DenseFC(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "dense_fc") {}

PDNode* operator()();

// declare operator node's name
PATTERN_DECL_NODE(fc);
PATTERN_DECL_NODE(fc_out);
PATTERN_DECL_NODE(fc_input);
PATTERN_DECL_NODE(fc_weights);
PATTERN_DECL_NODE(fc_bias);
};

} // namespace patterns

// Link two ir::Nodes from each other.
Expand Down
119 changes: 119 additions & 0 deletions paddle/fluid/framework/ir/replace_dense_with_sparse_pass.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/replace_dense_with_sparse_pass.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_version_registry.h"

namespace paddle {
namespace framework {
namespace ir {

// Constructor: registers the operator-compatibility contract for "fc" —
// every declared input/output must be a tensor.
// NOTE(review): ApplyImpl currently has its IsCompat() call commented out,
// so this contract is defined but not yet enforced.
ReplaceDenseWithSparsePass::ReplaceDenseWithSparsePass() {
AddOpCompat(OpCompat("fc"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("W")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End();
}

// Scans the graph for fused "fc" ops whose weight variable name carries the
// "sparse_2_4" tag and replaces each with an equivalent "sparse_fc" op,
// preserving inputs, output and all relevant attributes.
void ReplaceDenseWithSparsePass::ApplyImpl(Graph *graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));

  std::string name_scope = "replace_dense_with_sparse_pass";
  FusePassBase::Init(name_scope, graph);
  GraphPatternDetector gpd;

  patterns::DenseFC dense_fc_pattern(gpd.mutable_pattern(),
                                     "dense_replace_pass");
  dense_fc_pattern();
  int found_dense_fc_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
                     Graph *g) {
    VLOG(4) << "Replace dense fc with sparse_fc.";

    // NOTE(review): the op-compat contract registered in the constructor is
    // intentionally not enforced here yet; re-enable IsCompat(subgraph, g)
    // once the fc compat rules are finalized.

    GET_IR_NODE_FROM_SUBGRAPH(fc_out, fc_out, dense_fc_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(fc, fc, dense_fc_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(fc_input, fc_input, dense_fc_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(fc_weights, fc_weights, dense_fc_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(fc_bias, fc_bias, dense_fc_pattern);

    auto *fc_op = fc->Op();
    auto w_name = fc_op->Input("W")[0];
    // Sparse (2:4 pruned) weights are recognized purely by naming
    // convention: the pruning flow tags such variables with "sparse_2_4".
    if (w_name.find("sparse_2_4") != w_name.npos) {
      // Build a replacement sparse_fc op mirroring the dense fc's I/O.
      OpDesc desc(fc_op->Block());
      desc.SetType("sparse_fc");
      desc.SetInput("Input", {fc_input->Name()});
      desc.SetInput("W", {fc_weights->Name()});
      desc.SetInput("Bias", {fc_bias->Name()});
      desc.SetOutput("Out", {fc_out->Name()});

      // activation_type is expected on every fc op, so copy unconditionally;
      // the remaining attributes are optional and copied only when present.
      desc.SetAttr("activation_type", fc_op->GetAttr("activation_type"));
      const char *optional_attrs[] = {"x_num_col_dims", "in_num_col_dims",
                                      "enable_int8",    "Input_scale",
                                      "support_int8",   "out_threshold"};
      for (const char *attr : optional_attrs) {
        if (fc_op->HasAttr(attr)) {
          desc.SetAttr(attr, fc_op->GetAttr(attr));
        }
      }
      desc.Flush();

      // Remove the dense fc node and splice the new sparse_fc node into the
      // graph with identical connectivity.
      GraphSafeRemoveNodes(g, {fc});
      auto sparse_fc_node = g->CreateOpNode(&desc);

      IR_NODE_LINK_TO(fc_input, sparse_fc_node);
      IR_NODE_LINK_TO(fc_weights, sparse_fc_node);
      IR_NODE_LINK_TO(fc_bias, sparse_fc_node);
      IR_NODE_LINK_TO(sparse_fc_node, fc_out);
      found_dense_fc_count++;
    }
  };

  gpd(graph, handler);
  AddStatis(found_dense_fc_count);
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(replace_dense_with_sparse_pass,
              paddle::framework::ir::ReplaceDenseWithSparsePass);
45 changes: 45 additions & 0 deletions paddle/fluid/framework/ir/replace_dense_with_sparse_pass.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */

#pragma once

#include <string>

#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/inference/api/paddle_analysis_config.h"

namespace paddle {
namespace framework {
namespace ir {

/**
 * Replace dense op with sparse op.
 *
 * Rewrites fused "fc" ops whose weight variable name contains the
 * "sparse_2_4" tag into "sparse_fc" ops (see the matching .cc file).
 */
class Graph;

class ReplaceDenseWithSparsePass : public FusePassBase {
 public:
  // Registers the op-compat contract for "fc".
  ReplaceDenseWithSparsePass();

 protected:
  // Applies the dense-fc -> sparse_fc rewrite to the whole graph.
  void ApplyImpl(ir::Graph* graph) const override;

  // Scope name used for pass statistics/initialization.
  const std::string name_scope_{"replace_dense_with_sparse_pass"};
};

} // namespace ir
} // namespace framework
} // namespace paddle
3 changes: 3 additions & 0 deletions paddle/fluid/inference/api/analysis_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1765,6 +1765,9 @@ USE_TRT_CONVERTER(fused_preln_embedding_eltwise_layernorm)
USE_TRT_CONVERTER(preln_skip_layernorm)
USE_TRT_CONVERTER(roll)
USE_TRT_CONVERTER(strided_slice)
// Use defined(): PADDLE_WITH_CUSPARSELT is injected by the build system
// (cmake/external/cusparselt.cmake) and defined() does not depend on the
// macro expanding to a numeric value.
#if defined(PADDLE_WITH_CUSPARSELT) && IS_TRT_VERSION_GE(8000)
USE_TRT_CONVERTER(sparse_fc)
#endif
#endif

namespace paddle_infer {
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/inference/api/paddle_pass_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ const std::vector<std::string> kTRTSubgraphPasses({
"trt_map_matmul_to_mul_pass", //
"fc_fuse_pass", //
"conv_elementwise_add_fuse_pass", //
"replace_dense_with_sparse_pass", //
"tensorrt_subgraph_pass", //
"conv_bn_fuse_pass", //
#if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be
Expand Down
41 changes: 15 additions & 26 deletions paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,31 +1,20 @@
# Add TRT tests
# All converter sources are accumulated in CONVERT_FILES so that optional
# converters can be appended conditionally below.
list(APPEND CONVERT_FILES matmul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
              batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc group_norm_op.cc
              pad_op.cc split_op.cc prelu_op.cc leaky_relu_op.cc gelu_op.cc layer_norm_op.cc multihead_matmul_op.cc
              shuffle_channel_op.cc swish_op.cc instance_norm_op.cc stack_op.cc transpose_op.cc flatten_op.cc flatten_contiguous_range_op.cc
              emb_eltwise_layernorm.cc skip_layernorm.cc scale_op.cc slice_op.cc hard_sigmoid_op.cc hard_swish_op.cc clip_op.cc
              gather_op.cc anchor_generator_op.cc yolo_box_op.cc roi_align_op.cc affine_channel_op.cc multiclass_nms_op.cc
              multiclass_nms3_op.cc nearest_interp_op.cc reshape_op.cc reduce_op.cc gather_nd_op.cc tile_op.cc
              conv3d_op.cc mish_op.cc nearest_interp_v2_op.cc pool3d_op.cc deformable_conv_op.cc preln_emb_eltwise_layernorm.cc
              preln_skip_layernorm.cc strided_slice_op.cc roll_op.cc)

# The sparse_fc converter needs cuSPARSELt and TensorRT >= 8. Quote the
# version variable so an empty/undefined value cannot break the if() parse.
if(CUSPARSELT_FOUND AND "${TENSORRT_MAJOR_VERSION}" GREATER_EQUAL 8)
  list(APPEND CONVERT_FILES sparse_fc_op.cc)
endif()

# The converter library is built from the accumulated CONVERT_FILES list
# (the old inline source list was replaced by the variable so converters can
# be enabled conditionally above).
nv_library(tensorrt_converter
  SRCS ${CONVERT_FILES}
  DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry)

nv_test(test_op_converter SRCS test_op_converter.cc DEPS
Expand Down
Loading