delete unnecessary shape and slice op
RichardWooSJTU committed Nov 17, 2022
1 parent 071708f commit 4f068e8
Showing 3 changed files with 113 additions and 149 deletions.
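
What this commit does: in the three decoder fusion passes, the shape + slice chain that derives the TimeStep input from the attention mask used to be rebuilt for every transformer layer, and fuse_multi_transformer_layer_pass then had to match and delete the redundant copies. After this change the chain is created only for layer 0, and every layer's fused_multi_transformer op points its TimeStep input at the shared "slice_out.0" variable. The sketch below is a standalone illustration (plain C++, not Paddle code) of what the shape → slice(axes={0}, starts={3}, ends={4}) chain computes, under the assumption that the attention mask has layout [batch, num_head, 1, time_step]; because every layer reads the same mask, one shared result behaves identically to a per-layer copy.

```cpp
// Standalone sketch (not Paddle code): what the shape + slice pair computes.
// Assumption: the attention mask has layout [batch, num_head, 1, time_step],
// so slicing element 3 of its shape yields the current decoding step.
#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int32_t> TimeStepFromMaskShape(const std::array<int64_t, 4>& mask_shape) {
  // "shape" op: the 4-D shape of the attention mask as an int32 tensor.
  std::vector<int32_t> shape(mask_shape.begin(), mask_shape.end());
  // "slice" op with axes={0}, starts={3}, ends={4}: keep only element 3.
  return {shape[3]};
}

int main() {
  const std::array<int64_t, 4> mask_shape = {8, 16, 1, 42};
  const std::vector<int32_t> time_step = TimeStepFromMaskShape(mask_shape);
  // Every layer reads the same mask, so a single shared "slice_out.0" is
  // equivalent to the per-layer copies this commit removes.
  for (int layer = 0; layer < 4; ++layer) {
    assert(time_step == TimeStepFromMaskShape(mask_shape));
  }
  return time_step[0] == 42 ? 0 : 1;
}
```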
45 changes: 1 addition & 44 deletions paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.cc
@@ -62,34 +62,7 @@ MultiTransformerLayerPattern::operator()(bool enable_int8,
fused_multi_transformer_name, "Out");

if (is_decoder) {
auto shape_repr =
PDNodeName(name_scope_, repr_, id_, "shape_" + std::to_string(i));
node_reprs["shape_" + std::to_string(i)] = shape_repr;
auto* shape = pattern->NewNode(shape_repr)->assert_is_op("shape");

auto shape_out_repr =
PDNodeName(name_scope_, repr_, id_, "shape_out_" + std::to_string(i));
node_reprs["shape_out_" + std::to_string(i)] = shape_out_repr;
auto* shape_out =
pattern->NewNode(shape_out_repr)->assert_is_op_output("shape", "Out");

shape->LinksFrom({src_mask}).LinksTo({shape_out});

auto slice_repr =
PDNodeName(name_scope_, repr_, id_, "slice_" + std::to_string(i));
node_reprs["slice_" + std::to_string(i)] = slice_repr;
auto* slice = pattern->NewNode(slice_repr)->assert_is_op("slice");

auto slice_out_repr =
PDNodeName(name_scope_, repr_, id_, "slice_out_" + std::to_string(i));
node_reprs["slice_out_" + std::to_string(i)] = slice_out_repr;
auto* slice_out =
pattern->NewNode(slice_out_repr)->assert_is_op_output("slice", "Out");

slice->LinksFrom({shape_out}).LinksTo({slice_out});

fused_multi_transformer->LinksFrom({x0, src_mask, slice_out})
.LinksTo({out});
fused_multi_transformer->LinksFrom({x0, src_mask}).LinksTo({out});
} else {
auto cache_kv_repr =
PDNodeName(name_scope_, repr_, id_, "cache_kv_" + std::to_string(i));
@@ -187,10 +160,6 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph,
std::vector<Node*> fuse_op_nodes;
std::vector<Node*> out_nodes;

std::vector<std::string> unused_node_prefixes = {
"shape_", "shape_out_", "slice_", "slice_out_"};
std::vector<Node*> unused_nodes;

std::vector<OpDesc*> fuse_op_descs;
std::vector<VariableNameMap> fuse_op_input_var_name_maps;
std::vector<VariableNameMap> fuse_op_output_var_name_maps;
@@ -219,14 +188,6 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph,
fill_op_node->Op()->SetInput("Input", {x0->Name()});
IR_NODE_UNLINK(out_nodes[i - 1], fill_op_node);
IR_NODE_LINK_TO(x0, fill_op_node);
} else if (is_decoder && i != 0) {
for (const auto& unused_node_prefix : unused_node_prefixes) {
PDNode* unused_pdnode =
multi_layer_pattern.PatternBase::pattern->RetrieveNode(
node_reprs[unused_node_prefix + std::to_string(i)]);
Node* unused_node = subgraph.at(unused_pdnode);
unused_nodes.push_back(unused_node);
}
}
}

@@ -293,10 +254,6 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph,
std::unordered_set<const Node*> marked_fuse_op_nodes(
fuse_op_nodes.begin() + 1, fuse_op_nodes.end());

if (is_decoder) {
marked_fuse_op_nodes.insert(unused_nodes.begin(), unused_nodes.end());
}

GraphSafeRemoveNodes(graph, marked_fuse_op_nodes);
++fusion_count;
};
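Note on the pass above: since the decoder pass now emits a single shape/slice chain for the whole model, the decoder branch of MultiTransformerLayerPattern no longer declares per-layer shape_i / slice_i nodes, the fused op links only from {x0, src_mask}, and BuildFusion drops the bookkeeping that gathered those nodes for deletion. A minimal stand-in (stub types, not Paddle's real PDNode/PDPattern classes) showing the simplified wiring:

```cpp
// Stub pattern nodes, only to illustrate the simplified decoder wiring after
// this commit: the fused op takes the block input and the attention mask
// directly, with no shape/slice nodes in between.
#include <string>
#include <vector>

struct PDNodeStub {
  std::string name;
  std::vector<PDNodeStub*> inputs, outputs;
  PDNodeStub& LinksFrom(std::vector<PDNodeStub*> ins) {
    for (auto* in : ins) { inputs.push_back(in); in->outputs.push_back(this); }
    return *this;
  }
  PDNodeStub& LinksTo(std::vector<PDNodeStub*> outs) {
    for (auto* out : outs) { outputs.push_back(out); out->inputs.push_back(this); }
    return *this;
  }
};

int main() {
  PDNodeStub x0{"x0"}, src_mask{"src_mask"}, out{"out"};
  PDNodeStub fused{"fused_multi_transformer"};
  // Decoder branch after the commit: no per-layer shape_i / slice_i chain.
  fused.LinksFrom({&x0, &src_mask}).LinksTo({&out});
  return fused.inputs.size() == 2 ? 0 : 1;
}
```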
216 changes: 111 additions & 105 deletions paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc
@@ -1146,35 +1146,7 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
auto cache_kv_name = "cache_kv" + std::to_string(layer_idx);
fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv_name});

VarDesc shape_out_desc("shape_out." + std::to_string(layer_idx));
shape_out_desc.SetDataType(proto::VarType::INT32);
shape_out_desc.SetPersistable(false);
auto* shape_out = graph->CreateVarNode(&shape_out_desc);

OpDesc shape_op_desc(layer_norm->Op()->Block());
shape_op_desc.SetType("shape");
shape_op_desc.SetInput("Input", {eltadd_qk_b->Name()});
shape_op_desc.SetOutput("Out", {shape_out->Name()});
auto* shape_op = graph->CreateOpNode(&shape_op_desc);

VarDesc slice_out_desc("slice_out." + std::to_string(layer_idx));
slice_out_desc.SetDataType(proto::VarType::INT32);
slice_out_desc.SetPersistable(false);
auto* slice_out = graph->CreateVarNode(&slice_out_desc);

OpDesc slice_op_desc(layer_norm->Op()->Block());
slice_op_desc.SetType("slice");
slice_op_desc.SetInput("Input", {shape_out->Name()});
slice_op_desc.SetOutput("Out", {slice_out->Name()});
std::vector<int> axes = {0};
std::vector<int> starts = {3};
std::vector<int> ends = {4};
slice_op_desc.SetAttr("axes", axes);
slice_op_desc.SetAttr("starts", starts);
slice_op_desc.SetAttr("ends", ends);
auto* slice_op = graph->CreateOpNode(&slice_op_desc);

fused_multi_transformer_op_desc.SetInput("TimeStep", {slice_out->Name()});
fused_multi_transformer_op_desc.SetInput("TimeStep", {"slice_out.0"});

// Out Linear input
fused_multi_transformer_op_desc.SetInput("OutLinearW",
@@ -1219,12 +1191,42 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
IR_NODE_LINK_TO(eltadd0_b, fused_multi_transformer);
IR_NODE_LINK_TO(eltadd_qk_b, fused_multi_transformer);

// TimeStep link
IR_NODE_LINK_TO(eltadd_qk_b, shape_op);
IR_NODE_LINK_TO(shape_op, shape_out);
IR_NODE_LINK_TO(shape_out, slice_op);
IR_NODE_LINK_TO(slice_op, slice_out);
IR_NODE_LINK_TO(slice_out, fused_multi_transformer)
if (layer_idx == 0) {
VarDesc shape_out_desc("shape_out.0");
shape_out_desc.SetDataType(proto::VarType::INT32);
shape_out_desc.SetPersistable(false);
auto* shape_out = graph->CreateVarNode(&shape_out_desc);

OpDesc shape_op_desc(layer_norm->Op()->Block());
shape_op_desc.SetType("shape");
shape_op_desc.SetInput("Input", {eltadd_qk_b->Name()});
shape_op_desc.SetOutput("Out", {shape_out->Name()});
auto* shape_op = graph->CreateOpNode(&shape_op_desc);

VarDesc slice_out_desc("slice_out.0");
slice_out_desc.SetDataType(proto::VarType::INT32);
slice_out_desc.SetPersistable(false);
auto* slice_out = graph->CreateVarNode(&slice_out_desc);

OpDesc slice_op_desc(layer_norm->Op()->Block());
slice_op_desc.SetType("slice");
slice_op_desc.SetInput("Input", {shape_out->Name()});
slice_op_desc.SetOutput("Out", {slice_out->Name()});
std::vector<int> axes = {0};
std::vector<int> starts = {3};
std::vector<int> ends = {4};
slice_op_desc.SetAttr("axes", axes);
slice_op_desc.SetAttr("starts", starts);
slice_op_desc.SetAttr("ends", ends);
auto* slice_op = graph->CreateOpNode(&slice_op_desc);

// TimeStep link
IR_NODE_LINK_TO(eltadd_qk_b, shape_op);
IR_NODE_LINK_TO(shape_op, shape_out);
IR_NODE_LINK_TO(shape_out, slice_op);
IR_NODE_LINK_TO(slice_op, slice_out);
IR_NODE_LINK_TO(slice_out, fused_multi_transformer)
}

IR_NODE_LINK_TO(matmul_linear_w, fused_multi_transformer);
IR_NODE_LINK_TO(eltadd_linear_b, fused_multi_transformer);
@@ -1789,35 +1791,7 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
auto cache_kv_name = "cache_kv" + std::to_string(layer_idx);
fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv_name});

VarDesc shape_out_desc("shape_out." + std::to_string(layer_idx));
shape_out_desc.SetDataType(proto::VarType::INT32);
shape_out_desc.SetPersistable(false);
auto* shape_out = graph->CreateVarNode(&shape_out_desc);

OpDesc shape_op_desc(layer_norm->Op()->Block());
shape_op_desc.SetType("shape");
shape_op_desc.SetInput("Input", {eltadd_qk_b->Name()});
shape_op_desc.SetOutput("Out", {shape_out->Name()});
auto* shape_op = graph->CreateOpNode(&shape_op_desc);

VarDesc slice_out_desc("slice_out." + std::to_string(layer_idx));
slice_out_desc.SetDataType(proto::VarType::INT32);
slice_out_desc.SetPersistable(false);
auto* slice_out = graph->CreateVarNode(&slice_out_desc);

OpDesc slice_op_desc(layer_norm->Op()->Block());
slice_op_desc.SetType("slice");
slice_op_desc.SetInput("Input", {shape_out->Name()});
slice_op_desc.SetOutput("Out", {slice_out->Name()});
std::vector<int> axes = {0};
std::vector<int> starts = {3};
std::vector<int> ends = {4};
slice_op_desc.SetAttr("axes", axes);
slice_op_desc.SetAttr("starts", starts);
slice_op_desc.SetAttr("ends", ends);
auto* slice_op = graph->CreateOpNode(&slice_op_desc);

fused_multi_transformer_op_desc.SetInput("TimeStep", {slice_out->Name()});
fused_multi_transformer_op_desc.SetInput("TimeStep", {"slice_out.0"});

// Out Linear input
fused_multi_transformer_op_desc.SetInput("OutLinearW",
@@ -1862,12 +1836,42 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
IR_NODE_LINK_TO(eltadd0_b, fused_multi_transformer);
IR_NODE_LINK_TO(eltadd_qk_b, fused_multi_transformer);

// TimeStep link
IR_NODE_LINK_TO(eltadd_qk_b, shape_op);
IR_NODE_LINK_TO(shape_op, shape_out);
IR_NODE_LINK_TO(shape_out, slice_op);
IR_NODE_LINK_TO(slice_op, slice_out);
IR_NODE_LINK_TO(slice_out, fused_multi_transformer)
if (layer_idx == 0) {
VarDesc shape_out_desc("shape_out.0");
shape_out_desc.SetDataType(proto::VarType::INT32);
shape_out_desc.SetPersistable(false);
auto* shape_out = graph->CreateVarNode(&shape_out_desc);

OpDesc shape_op_desc(layer_norm->Op()->Block());
shape_op_desc.SetType("shape");
shape_op_desc.SetInput("Input", {eltadd_qk_b->Name()});
shape_op_desc.SetOutput("Out", {shape_out->Name()});
auto* shape_op = graph->CreateOpNode(&shape_op_desc);

VarDesc slice_out_desc("slice_out.0");
slice_out_desc.SetDataType(proto::VarType::INT32);
slice_out_desc.SetPersistable(false);
auto* slice_out = graph->CreateVarNode(&slice_out_desc);

OpDesc slice_op_desc(layer_norm->Op()->Block());
slice_op_desc.SetType("slice");
slice_op_desc.SetInput("Input", {shape_out->Name()});
slice_op_desc.SetOutput("Out", {slice_out->Name()});
std::vector<int> axes = {0};
std::vector<int> starts = {3};
std::vector<int> ends = {4};
slice_op_desc.SetAttr("axes", axes);
slice_op_desc.SetAttr("starts", starts);
slice_op_desc.SetAttr("ends", ends);
auto* slice_op = graph->CreateOpNode(&slice_op_desc);

// TimeStep link
IR_NODE_LINK_TO(eltadd_qk_b, shape_op);
IR_NODE_LINK_TO(shape_op, shape_out);
IR_NODE_LINK_TO(shape_out, slice_op);
IR_NODE_LINK_TO(slice_op, slice_out);
IR_NODE_LINK_TO(slice_out, fused_multi_transformer)
}

IR_NODE_LINK_TO(matmul_linear_w, fused_multi_transformer);
IR_NODE_LINK_TO(eltadd_linear_b, fused_multi_transformer);
@@ -2405,35 +2409,7 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
auto cache_kv_name = "cache_kv" + std::to_string(layer_idx);
fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv_name});

VarDesc shape_out_desc("shape_out." + std::to_string(layer_idx));
shape_out_desc.SetDataType(proto::VarType::INT32);
shape_out_desc.SetPersistable(false);
auto* shape_out = graph->CreateVarNode(&shape_out_desc);

OpDesc shape_op_desc(layer_norm->Op()->Block());
shape_op_desc.SetType("shape");
shape_op_desc.SetInput("Input", {eltadd_qk_b->Name()});
shape_op_desc.SetOutput("Out", {shape_out->Name()});
auto* shape_op = graph->CreateOpNode(&shape_op_desc);

VarDesc slice_out_desc("slice_out." + std::to_string(layer_idx));
slice_out_desc.SetDataType(proto::VarType::INT32);
slice_out_desc.SetPersistable(false);
auto* slice_out = graph->CreateVarNode(&slice_out_desc);

OpDesc slice_op_desc(layer_norm->Op()->Block());
slice_op_desc.SetType("slice");
slice_op_desc.SetInput("Input", {shape_out->Name()});
slice_op_desc.SetOutput("Out", {slice_out->Name()});
std::vector<int> axes = {0};
std::vector<int> starts = {3};
std::vector<int> ends = {4};
slice_op_desc.SetAttr("axes", axes);
slice_op_desc.SetAttr("starts", starts);
slice_op_desc.SetAttr("ends", ends);
auto* slice_op = graph->CreateOpNode(&slice_op_desc);

fused_multi_transformer_op_desc.SetInput("TimeStep", {slice_out->Name()});
fused_multi_transformer_op_desc.SetInput("TimeStep", {"slice_out.0"});

// Out Linear input
fused_multi_transformer_op_desc.SetInput("OutLinearW",
@@ -2483,12 +2459,42 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
IR_NODE_LINK_TO(eltadd0_b, fused_multi_transformer);
IR_NODE_LINK_TO(eltadd_qk_b, fused_multi_transformer);

// TimeStep link
IR_NODE_LINK_TO(eltadd_qk_b, shape_op);
IR_NODE_LINK_TO(shape_op, shape_out);
IR_NODE_LINK_TO(shape_out, slice_op);
IR_NODE_LINK_TO(slice_op, slice_out);
IR_NODE_LINK_TO(slice_out, fused_multi_transformer)
if (layer_idx == 0) {
VarDesc shape_out_desc("shape_out.0");
shape_out_desc.SetDataType(proto::VarType::INT32);
shape_out_desc.SetPersistable(false);
auto* shape_out = graph->CreateVarNode(&shape_out_desc);

OpDesc shape_op_desc(layer_norm->Op()->Block());
shape_op_desc.SetType("shape");
shape_op_desc.SetInput("Input", {eltadd_qk_b->Name()});
shape_op_desc.SetOutput("Out", {shape_out->Name()});
auto* shape_op = graph->CreateOpNode(&shape_op_desc);

VarDesc slice_out_desc("slice_out.0");
slice_out_desc.SetDataType(proto::VarType::INT32);
slice_out_desc.SetPersistable(false);
auto* slice_out = graph->CreateVarNode(&slice_out_desc);

OpDesc slice_op_desc(layer_norm->Op()->Block());
slice_op_desc.SetType("slice");
slice_op_desc.SetInput("Input", {shape_out->Name()});
slice_op_desc.SetOutput("Out", {slice_out->Name()});
std::vector<int> axes = {0};
std::vector<int> starts = {3};
std::vector<int> ends = {4};
slice_op_desc.SetAttr("axes", axes);
slice_op_desc.SetAttr("starts", starts);
slice_op_desc.SetAttr("ends", ends);
auto* slice_op = graph->CreateOpNode(&slice_op_desc);

// TimeStep link
IR_NODE_LINK_TO(eltadd_qk_b, shape_op);
IR_NODE_LINK_TO(shape_op, shape_out);
IR_NODE_LINK_TO(shape_out, slice_op);
IR_NODE_LINK_TO(slice_op, slice_out);
IR_NODE_LINK_TO(slice_out, fused_multi_transformer)
}

IR_NODE_LINK_TO(matmul_linear_w, fused_multi_transformer);
IR_NODE_LINK_TO(eltadd_linear_b, fused_multi_transformer);
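Note on the three decoder passes above: construction of the shape/slice ops, their int32 output variables, and the IR_NODE_LINK_TO wiring is now guarded by if (layer_idx == 0), while every layer's TimeStep input is set to the shared name "slice_out.0". A standalone sketch of that create-once, reuse-everywhere pattern (illustrative types and names only, not Paddle's graph API):

```cpp
// Sketch of the reuse pattern introduced here: build the TimeStep-producing
// subgraph once, keyed by the shared variable name "slice_out.0", and hand
// every layer the same node instead of one copy per layer.
#include <cstddef>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

struct Var {
  std::string name;
};

class Graph {
 public:
  Var* GetOrCreateTimeStep() {
    auto it = vars_.find("slice_out.0");
    if (it != vars_.end()) return it->second.get();  // reuse the existing chain
    auto var = std::make_unique<Var>();
    var->name = "slice_out.0";
    Var* raw = var.get();
    vars_.emplace(raw->name, std::move(var));
    return raw;
  }
  std::size_t NumVars() const { return vars_.size(); }

 private:
  std::unordered_map<std::string, std::unique_ptr<Var>> vars_;
};

int main() {
  Graph graph;
  std::vector<Var*> time_step_inputs;
  for (int layer_idx = 0; layer_idx < 24; ++layer_idx) {
    // Mirrors the new `if (layer_idx == 0)` guard: only the first iteration
    // actually creates the shared output; later layers simply reuse it.
    time_step_inputs.push_back(graph.GetOrCreateTimeStep());
  }
  std::cout << "TimeStep subgraphs in the graph: " << graph.NumVars() << "\n";
  return graph.NumVars() == 1 ? 0 : 1;
}
```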
1 change: 1 addition & 0 deletions paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -177,6 +177,7 @@ const std::vector<std::string> kGpuLowerPrecisionPasses{
"fused_multi_transformer_decoder_fuse_qkv_pass",
"multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass",
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass",
"fuse_multi_transformer_layer_pass",
"gpu_cpu_map_matmul_v2_to_mul_pass",
"gpu_cpu_map_matmul_v2_to_matmul_pass",
"fc_fuse_pass",
