style check
minghaoBD committed May 23, 2022
1 parent e574ad9 · commit 5231918
Showing 7 changed files with 98 additions and 72 deletions.
45 changes: 27 additions & 18 deletions paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -3450,30 +3450,39 @@ PDNode *patterns::DenseFC::operator()() {
}

PDNode *patterns::MultiheadMatmul::operator()() {
- auto *multihead_matmul = pattern->NewNode(multihead_matmul_repr())->assert_is_op("multihead_matmul");
+ auto *multihead_matmul = pattern->NewNode(multihead_matmul_repr())
+     ->assert_is_op("multihead_matmul");
// Input
- auto *multihead_matmul_input = pattern->NewNode(multihead_matmul_input_repr())
-     ->AsInput()
-     ->assert_is_op_input("multihead_matmul", "Input");
+ auto *multihead_matmul_input =
+     pattern->NewNode(multihead_matmul_input_repr())
+         ->AsInput()
+         ->assert_is_op_input("multihead_matmul", "Input");
// Filter
- auto *multihead_matmul_weights = pattern->NewNode(multihead_matmul_weights_repr())
-     ->AsInput()
-     ->assert_is_op_input("multihead_matmul", "W");
+ auto *multihead_matmul_weights =
+     pattern->NewNode(multihead_matmul_weights_repr())
+         ->AsInput()
+         ->assert_is_op_input("multihead_matmul", "W");
// Bias
- auto *multihead_matmul_bias = pattern->NewNode(multihead_matmul_bias_repr())
-     ->AsInput()
-     ->assert_is_op_input("multihead_matmul", "Bias");
+ auto *multihead_matmul_bias =
+     pattern->NewNode(multihead_matmul_bias_repr())
+         ->AsInput()
+         ->assert_is_op_input("multihead_matmul", "Bias");
// BiasQK
- auto *multihead_matmul_biasqk = pattern->NewNode(multihead_matmul_biasqk_repr())
-     ->AsInput()
-     ->assert_is_op_input("multihead_matmul", "BiasQK");
+ auto *multihead_matmul_biasqk =
+     pattern->NewNode(multihead_matmul_biasqk_repr())
+         ->AsInput()
+         ->assert_is_op_input("multihead_matmul", "BiasQK");
// Output
- auto *multihead_matmul_out = pattern->NewNode(multihead_matmul_out_repr())
-     ->AsOutput()
-     ->assert_is_op_output("multihead_matmul", "Out")
-     ->assert_is_only_output_of_op("multihead_matmul");
+ auto *multihead_matmul_out =
+     pattern->NewNode(multihead_matmul_out_repr())
+         ->AsOutput()
+         ->assert_is_op_output("multihead_matmul", "Out")
+         ->assert_is_only_output_of_op("multihead_matmul");

- multihead_matmul->LinksFrom({multihead_matmul_input, multihead_matmul_weights, multihead_matmul_bias, multihead_matmul_biasqk}).LinksTo({multihead_matmul_out});
+ multihead_matmul
+     ->LinksFrom({multihead_matmul_input, multihead_matmul_weights,
+                  multihead_matmul_bias, multihead_matmul_biasqk})
+     .LinksTo({multihead_matmul_out});

return multihead_matmul_out;
}
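Note: this pattern is consumed by the replacement pass in the next file. As a hedged sketch (assuming only the GraphPatternDetector API already used below; the handler body is illustrative), the pattern is typically driven like this:

  GraphPatternDetector gpd;
  patterns::MultiheadMatmul multihead_matmul_pattern(
      gpd.mutable_pattern(), "dense_multihead_matmul_replace_pass");
  multihead_matmul_pattern();  // declares the nodes built above
  auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
                     Graph *g) {
    // GET_IR_NODE_FROM_SUBGRAPH(...) retrieves each matched node from
    // `subgraph`, exactly as the pass below does.
  };
  gpd(graph, handler);  // runs the handler on every match found in `graph`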
@@ -20,7 +20,8 @@ namespace paddle {
namespace framework {
namespace ir {

- ReplaceDenseMultiheadMatmulWithSparsePass::ReplaceDenseMultiheadMatmulWithSparsePass() {
+ ReplaceDenseMultiheadMatmulWithSparsePass::
+     ReplaceDenseMultiheadMatmulWithSparsePass() {
AddOpCompat(OpCompat("multihead_matmul"))
.AddInput("Input")
.IsTensor()
@@ -47,8 +48,8 @@ void ReplaceDenseMultiheadMatmulWithSparsePass::ApplyImpl(Graph *graph) const {
FusePassBase::Init(name_scope, graph);
GraphPatternDetector gpd;

- patterns::MultiheadMatmul multihead_matmul_pattern(gpd.mutable_pattern(),
-     "dense_multihead_matmul_replace_pass");
+ patterns::MultiheadMatmul multihead_matmul_pattern(
+     gpd.mutable_pattern(), "dense_multihead_matmul_replace_pass");
multihead_matmul_pattern();
int found_multihead_matmul_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
@@ -60,12 +61,19 @@ void ReplaceDenseMultiheadMatmulWithSparsePass::ApplyImpl(Graph *graph) const {
return;
}*/

- GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_out, multihead_matmul_out, multihead_matmul_pattern);
- GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul, multihead_matmul, multihead_matmul_pattern);
- GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_input, multihead_matmul_input, multihead_matmul_pattern);
- GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_weights, multihead_matmul_weights, multihead_matmul_pattern);
- GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_bias, multihead_matmul_bias, multihead_matmul_pattern);
- GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_biasqk, multihead_matmul_biasqk, multihead_matmul_pattern);
+ GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_out, multihead_matmul_out,
+                           multihead_matmul_pattern);
+ GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul, multihead_matmul,
+                           multihead_matmul_pattern);
+ GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_input, multihead_matmul_input,
+                           multihead_matmul_pattern);
+ GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_weights,
+                           multihead_matmul_weights,
+                           multihead_matmul_pattern);
+ GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_bias, multihead_matmul_bias,
+                           multihead_matmul_pattern);
+ GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_biasqk, multihead_matmul_biasqk,
+                           multihead_matmul_pattern);

auto *multihead_matmul_op = multihead_matmul->Op();
auto w_name = multihead_matmul_op->Input("W")[0];
@@ -84,19 +92,23 @@ void ReplaceDenseMultiheadMatmulWithSparsePass::ApplyImpl(Graph *graph) const {
desc.SetAttr("alpha", multihead_matmul_op->GetAttr("alpha"));
desc.SetAttr("head_number", multihead_matmul_op->GetAttr("head_number"));
if (multihead_matmul_op->HasAttr("Input_scale")) {
desc.SetAttr("Input_scale", multihead_matmul_op->GetAttr("Input_scale"));
desc.SetAttr("Input_scale",
multihead_matmul_op->GetAttr("Input_scale"));
}
if (multihead_matmul_op->HasAttr("fc_out_threshold")) {
desc.SetAttr("fc_out_threshold", multihead_matmul_op->GetAttr("fc_out_threshold"));
desc.SetAttr("fc_out_threshold",
multihead_matmul_op->GetAttr("fc_out_threshold"));
}
if (multihead_matmul_op->HasAttr("qkv2context_plugin_int8")) {
desc.SetAttr("qkv2context_plugin_int8", multihead_matmul_op->GetAttr("qkv2context_plugin_int8"));
desc.SetAttr("qkv2context_plugin_int8",
multihead_matmul_op->GetAttr("qkv2context_plugin_int8"));
}
if (multihead_matmul_op->HasAttr("dp_probs")) {
desc.SetAttr("dp_probs", multihead_matmul_op->GetAttr("dp_probs"));
}
if (multihead_matmul_op->HasAttr("out_threshold")) {
desc.SetAttr("out_threshold", multihead_matmul_op->GetAttr("out_threshold"));
desc.SetAttr("out_threshold",
multihead_matmul_op->GetAttr("out_threshold"));
}
desc.Flush();
GraphSafeRemoveNodes(g, {multihead_matmul});
@@ -37,7 +37,8 @@ class ReplaceDenseMultiheadMatmulWithSparsePass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;

- const std::string name_scope_{"replace_dense_multihead_matmul_with_sparse_pass"};
+ const std::string name_scope_{
+     "replace_dense_multihead_matmul_with_sparse_pass"};
};

} // namespace ir
42 changes: 21 additions & 21 deletions paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -91,27 +91,27 @@ const std::vector<std::string> kTRTSubgraphPasses({
"delete_quant_dequant_linear_op_pass", //
"add_support_int8_pass", //
// "fc_fuse_pass", //
"simplify_with_basic_ops_pass", //
"embedding_eltwise_layernorm_fuse_pass", //
"preln_embedding_eltwise_layernorm_fuse_pass", //
"multihead_matmul_fuse_pass_v2", //
"multihead_matmul_fuse_pass_v3", //
"skip_layernorm_fuse_pass", //
"preln_skip_layernorm_fuse_pass", //
"conv_bn_fuse_pass", //
"unsqueeze2_eltwise_fuse_pass", //
"trt_squeeze2_matmul_fuse_pass", //
"trt_reshape2_matmul_fuse_pass", //
"trt_flatten2_matmul_fuse_pass", //
"trt_map_matmul_v2_to_mul_pass", //
"trt_map_matmul_v2_to_matmul_pass", //
"trt_map_matmul_to_mul_pass", //
"fc_fuse_pass", //
"conv_elementwise_add_fuse_pass", //
"replace_dense_with_sparse_pass", //
"replace_dense_multihead_matmul_with_sparse_pass", //
"tensorrt_subgraph_pass", //
"conv_bn_fuse_pass", //
"simplify_with_basic_ops_pass", //
"embedding_eltwise_layernorm_fuse_pass", //
"preln_embedding_eltwise_layernorm_fuse_pass", //
"multihead_matmul_fuse_pass_v2", //
"multihead_matmul_fuse_pass_v3", //
"skip_layernorm_fuse_pass", //
"preln_skip_layernorm_fuse_pass", //
"conv_bn_fuse_pass", //
"unsqueeze2_eltwise_fuse_pass", //
"trt_squeeze2_matmul_fuse_pass", //
"trt_reshape2_matmul_fuse_pass", //
"trt_flatten2_matmul_fuse_pass", //
"trt_map_matmul_v2_to_mul_pass", //
"trt_map_matmul_v2_to_matmul_pass", //
"trt_map_matmul_to_mul_pass", //
"fc_fuse_pass", //
"conv_elementwise_add_fuse_pass", //
"replace_dense_with_sparse_pass", //
"replace_dense_multihead_matmul_with_sparse_pass", //
"tensorrt_subgraph_pass", //
"conv_bn_fuse_pass", //
#if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be
// guaranteed at least v7
// cudnn8.0 has memory leak problem in conv + eltwise + act, so we
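Note: both sparse replacement passes are registered in kTRTSubgraphPasses above, so they run automatically when the TensorRT engine is enabled. A hedged usage sketch, assuming the public AnalysisConfig/PassStrategy API rather than anything in this commit (model_dir is illustrative):

  paddle_infer::Config config(model_dir);
  config.EnableTensorRtEngine();  // the sparse passes only target TRT subgraphs
  config.pass_builder()->AppendPass(
      "replace_dense_multihead_matmul_with_sparse_pass");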
@@ -35,7 +35,8 @@ class SparseMultiheadMatMulOpConverter : public OpConverter {

void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
VLOG(3) << "convert a fluid sparse_multihead_matmul op to a corresponding tensorrt "
VLOG(3) << "convert a fluid sparse_multihead_matmul op to a corresponding "
"tensorrt "
"network structure";
framework::OpDesc op_desc(op, nullptr);
// Declare inputs
@@ -335,9 +336,9 @@ class SparseMultiheadMatMulOpConverter : public OpConverter {
w_data = static_cast<void*>(weight_data);
}

- TensorRTEngine::Weight weight{with_fp16 ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT,
-                               static_cast<void*>(w_data),
-                               static_cast<size_t>(weight_t->numel())};
+ TensorRTEngine::Weight weight{
+     with_fp16 ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT,
+     static_cast<void*>(w_data), static_cast<size_t>(weight_t->numel())};
weight.dims.assign({n, m});

TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT,
Expand All @@ -360,22 +361,23 @@ class SparseMultiheadMatMulOpConverter : public OpConverter {
}
reshape_before_fc_layer->setReshapeDimensions(reshape_before_fc_dim);
reshape_before_fc_layer->setName(
("shuffle_before_sparse_multihead_mamul(Output: " + output_name + ")")
("shuffle_before_sparse_multihead_mamul(Output: " + output_name +
")")
.c_str());

// add layer fc
nvinfer1::ILayer* fc_layer = nullptr;
if (op_desc.HasAttr("Input_scale")) {
- plugin::SpmmPluginDynamic* plugin = new_spmm_plugin(
-     &weight, &bias, nvinfer1::DataType::kINT8, n);
+ plugin::SpmmPluginDynamic* plugin =
+     new_spmm_plugin(&weight, &bias, nvinfer1::DataType::kINT8, n);
std::vector<nvinfer1::ITensor*> plugin_inputs;
plugin_inputs.emplace_back(reshape_before_fc_layer->getOutput(0));
fc_layer = engine_->network()->addPluginV2(
plugin_inputs.data(), plugin_inputs.size(), *plugin);
} else {
plugin::SpmmPluginDynamic* plugin = new_spmm_plugin(
&weight, &bias,
with_fp16 ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT,
&weight, &bias, with_fp16 ? nvinfer1::DataType::kHALF
: nvinfer1::DataType::kFLOAT,
n);
std::vector<nvinfer1::ITensor*> plugin_inputs;
plugin_inputs.emplace_back(reshape_before_fc_layer->getOutput(0));
@@ -430,4 +432,5 @@ class SparseMultiheadMatMulOpConverter : public OpConverter {
} // namespace inference
} // namespace paddle

- REGISTER_TRT_OP_CONVERTER(sparse_multihead_matmul, SparseMultiheadMatMulOpConverter);
+ REGISTER_TRT_OP_CONVERTER(sparse_multihead_matmul,
+                           SparseMultiheadMatMulOpConverter);
5 changes: 3 additions & 2 deletions paddle/fluid/inference/tensorrt/op_teller.cc
@@ -49,7 +49,7 @@ struct SimpleOpTypeSetTeller : public Teller {
teller_set.insert("sparse_fc");
int8_teller_set.insert("sparse_fc");
teller_set.insert("sparse_multihead_matmul");
int8_teller_set.insert("sparse_multihead_matmul");
int8_teller_set.insert("sparse_multihead_matmul");
#endif
}

@@ -1742,7 +1742,8 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
#if IS_TRT_VERSION_GE(8000)
if (op_type == "sparse_fc" || op_type == "sparse_multihead_matmul") {
if (!with_dynamic_shape) {
VLOG(3) << "the sparse_fc and sparse_multihead_matmul does not support static shape yet";
VLOG(3) << "the sparse_fc and sparse_multihead_matmul does not support "
"static shape yet";
return false;
}
}
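Note: because the teller above rejects sparse_fc and sparse_multihead_matmul under static shape, callers need a dynamic-shape profile. A hedged sketch, assuming the usual SetTRTDynamicShapeInfo API and a `config` as in the earlier sketch; the tensor name and shapes are illustrative (needs <map>, <string>, <vector>):

  std::map<std::string, std::vector<int>> min_shape{{"input", {1, 1, 768}}};
  std::map<std::string, std::vector<int>> max_shape{{"input", {8, 128, 768}}};
  std::map<std::string, std::vector<int>> opt_shape{{"input", {8, 128, 768}}};
  // signature order: min, max, optimal
  config.SetTRTDynamicShapeInfo(min_shape, max_shape, opt_shape);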
14 changes: 7 additions & 7 deletions paddle/fluid/inference/tensorrt/plugin/spmm_plugin.cu
@@ -377,7 +377,7 @@ SpmmPluginDynamic::SpmmPluginDynamic(const std::string& layer_name,
compressed_size);
cudaMemcpy(weight_compressed_dev_, weight_compressed_, compressed_size,
cudaMemcpyHostToDevice);

has_bias_ = (bias != nullptr);
if (has_bias_) {
// Each plugin has a copy of bias
@@ -552,7 +552,7 @@ void SpmmPluginDynamic::configurePlugin(
platform::errors::InvalidArgument(
"precision_ should be equal to inputs[0].desc.type"));
const auto& inDims0 = inputs[0].desc.dims;
- if (inDims0.nbDims==5) {
+ if (inDims0.nbDims == 5) {
PADDLE_ENFORCE_EQ(inDims0.nbDims, 5, platform::errors::InvalidArgument(
"inDims0.nbDims should be 5"));
PADDLE_ENFORCE_EQ(k_, inDims0.d[2],
@@ -565,7 +565,7 @@
const int BS = inputs->max.d[0];
const int Seq = inputs->max.d[1];
m_max_ = BS * Seq;
- } else if (inDims0.nbDims==4) {
+ } else if (inDims0.nbDims == 4) {
PADDLE_ENFORCE_EQ(inDims0.nbDims, 4, platform::errors::InvalidArgument(
"inDims0.nbDims should be 4"));
PADDLE_ENFORCE_EQ(k_, inDims0.d[1],
@@ -576,7 +576,7 @@
PADDLE_ENFORCE_EQ(inDims0.d[3], 1, platform::errors::InvalidArgument(
"inDims0.d[3] should be 1"));
const int BS_Seq = inputs->max.d[0];
- m_max_ = BS_Seq;
+ m_max_ = BS_Seq;
}
// The optimal algorighm id is for m = m_max_
// To Do: configurePlugin takes time when m is changed
@@ -642,15 +642,15 @@ int SpmmPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
PADDLE_ENFORCE_EQ(is_configured_, true,
platform::errors::InvalidArgument(
"The plugin is not configured before enqueue"));
if (inputDesc->dims.nbDims==5){
if (inputDesc->dims.nbDims == 5) {
PADDLE_ENFORCE_EQ(
k_, inputDesc->dims.d[2],
platform::errors::InvalidArgument("k_ == inputDesc->dims.d[2]"));
} else if (inputDesc->dims.nbDims==4) {
} else if (inputDesc->dims.nbDims == 4) {
PADDLE_ENFORCE_EQ(
k_, inputDesc->dims.d[1],
platform::errors::InvalidArgument("k_ == inputDesc->dims.d[1]"));
- }
+ }
float alpha = 1.0f;
float beta = 0.0f;
if (inputDesc->type == nvinfer1::DataType::kFLOAT) {
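Note: the nbDims checks above encode how the GEMM row count is derived from the profile's maximum input dims. A hedged standalone sketch of that logic (the helper name and example numbers are illustrative; the dimension layout is the one asserted in this diff):

  #include <cstdint>

  // 5-D input [BS, Seq, K, 1, 1] -> rows = BS * Seq (m_max_ above)
  // 4-D input [BS*Seq, K, 1, 1]  -> rows = BS_Seq
  int64_t MaxSpmmRows(int nb_dims, const int64_t* max_dims) {
    if (nb_dims == 5) return max_dims[0] * max_dims[1];
    if (nb_dims == 4) return max_dims[0];
    return -1;  // unsupported rank
  }
  // e.g. nb_dims == 5, max_dims = {8, 128, 768, 1, 1} -> 8 * 128 = 1024 rows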
