Commit
add unit test
RichardWooSJTU committed Nov 2, 2022
1 parent e152480 commit 15ea2d7
Showing 4 changed files with 324 additions and 31 deletions.
4 changes: 4 additions & 0 deletions paddle/fluid/framework/ir/CMakeLists.txt
@@ -324,6 +324,10 @@ cc_test(
test_fused_multi_transformer_decoder_pass
SRCS fused_multi_transformer_decoder_pass_tester.cc
DEPS fused_multi_transformer_decoder_pass)
+cc_test(
+  test_fuse_multi_transformer_layer_pass
+  SRCS fuse_multi_transformer_layer_pass_tester.cc
+  DEPS fuse_multi_transformer_layer_pass)
cc_test(
test_conv_bn_fuse_pass_cc
SRCS conv_bn_fuse_pass_tester.cc
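Usage note: once the new target is built, the test can be run from the build directory with ctest, for example: ctest -R test_fuse_multi_transformer_layer_pass (the exact invocation depends on the local build setup).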
39 changes: 8 additions & 31 deletions paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.cc
@@ -129,7 +129,7 @@ inline void MergeInput(OpDesc* op,
const std::vector<VariableNameMap>& input_name_maps,
const std::string& input_name) {
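  // Concatenate the variable names registered under `input_name` across all
  // layers' input maps, so the merged op ends up carrying every layer's inputs.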
std::vector<std::string> tmp = input_name_maps[0].at(input_name);
-  for (int i = 1; i < input_name_maps.size(); ++i) {
+  for (size_t i = 1; i < input_name_maps.size(); ++i) {
tmp.insert(tmp.end(),
input_name_maps[i].at(input_name).begin(),
input_name_maps[i].at(input_name).end());
@@ -141,7 +141,7 @@ template <typename T>
inline void MergeAttrs(const std::vector<OpDesc*>& ops,
const std::string& attr_name) {
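  // Concatenate a std::vector<T>-typed attribute `attr_name` across all ops,
  // building one merged vector for the fused op.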
std::vector<T> res;
-  for (int i = 0; i < ops.size(); ++i) {
+  for (size_t i = 0; i < ops.size(); ++i) {
auto scale_vec =
PADDLE_GET_CONST(std::vector<T>, ops[i]->GetAttr(attr_name));
res.insert(res.end(), scale_vec.begin(), scale_vec.end());
@@ -156,25 +156,19 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph,
auto* pattern = gpd.mutable_pattern();
VLOG(0) << "In build fusion";

+  // TODO(wufeisheng): Get enable_int8 attr from graph after
+  // fused_multi_transformer pass with int8 merged
  bool enable_int8 = false;
-  if (graph->Has("enable_int8")) {
-    enable_int8 = graph->Get<bool>("enable_int8");
-  }
-  if (!enable_int8) {
-    VLOG(4)
-        << "fuse_multi_transformer_layer_pass will match float transformer op "
-           "cause enable_int8 is not been set or set to false";
-  }

int num_fuse_op = 0;
bool is_decoder = false;
-  if (graph->Has("enable_int8")) {
-    num_fuse_op = graph->Get<int>("num_fused_multi_transformer_op");
-  }

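  // The encoder/decoder fusion passes record how many fused_multi_transformer
  // ops they produced; that count determines how many layers this pass merges
  // and whether the decoder variant of the pattern is present.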
if (graph->Has(kFusedMultiTransformerEncoderFusionCount)) {
VLOG(0) << "encoder fusion count";
num_fuse_op = graph->Get<int>(kFusedMultiTransformerEncoderFusionCount);
is_decoder = false;
} else if (graph->Has(kFusedMultiTransformerDecoderFusionCount)) {
VLOG(0) << "decoder fusion count";
num_fuse_op = graph->Get<int>(kFusedMultiTransformerDecoderFusionCount);
is_decoder = true;
}
@@ -280,13 +274,7 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph,
"OutLinearW",
"QKVBias",
"QKVW"};
-    if (enable_int8) {
-      std::vector<std::string> inputs_names_int8_supp = {
-          "FFN1OutScale", "FFN2OutScale", "OutLinearOutScale", "QKVOutScale"};
-      inputs_names.insert(inputs_names.end(),
-                          inputs_names_int8_supp.begin(),
-                          inputs_names_int8_supp.end());
-    }

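    // Merge every per-layer weight/bias input list into the first op's inputs.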
for (const auto& input_name : inputs_names) {
MergeInput(fuse_op_descs[0], fuse_op_input_var_name_maps, input_name);
}
@@ -308,17 +296,6 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph,
// }
fuse_op_descs[0]->SetOutput("CacheKVOut", merged_cache_kv_out_names);

-    if (enable_int8) {
-      // Merge inputs scale
-      std::vector<std::string> attr_names = {"qkv_in_scale",
-                                             "out_linear_in_scale",
-                                             "ffn1_in_scale",
-                                             "ffn2_in_scale"};
-      for (const auto& name : attr_names) {
-        MergeAttrs<float>(fuse_op_descs, name);
-      }
-      VLOG(0) << "Finsh Merge attrs";
-    }
////////////////
//// ReLink ////
////////////////
199 changes: 199 additions & 0 deletions paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass_tester.cc
@@ -0,0 +1,199 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <gtest/gtest.h>

#include "paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/op_version_registry.h"

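// The macro below builds the inputs shared by both tests: a 3-layer
// transformer with hidden size 1024, 16 attention heads and head dim 64
// (qkv_w is laid out as {3, num_head, dim_head, dim_embed}).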
#define DEF_INPUT_DATA \
Layers layers; \
int num_layers = 3; \
auto* x = layers.data("x", {1, 128, 1024}); \
auto* src_mask = layers.data("src_mask", {1, 16, 128, 128}); \
auto* ln_scale = layers.data("ln_scale", {1024}, true); \
auto* ln_bias = layers.data("ln_bias", {1024}, true); \
auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true); \
auto* ffn_ln_bias = layers.data("ffn_ln_bias", {1024}, true); \
auto* qkv_w = layers.data("qkv_w", {3, 16, 64, 1024}, true); \
auto* out_linear_w = layers.data("out_linear_w", {1024, 1024}, true); \
auto* ffn1_w = layers.data("ffn1_w", {1024, 4096}, true); \
auto* ffn2_w = layers.data("ffn2_w", {4096, 1024}, true); \
auto* qkv_bias = layers.data("qkv_bias", {3072}, true); \
auto* out_linear_bias = layers.data("out_linear_bias", {1024}, true); \
auto* ffn1_bias = layers.data("ffn1_bias", {4096}, true); \
auto* ffn2_bias = layers.data("ffn2_bias", {1024}, true);

namespace paddle {
namespace framework {
namespace ir {

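// Create a float tensor variable of the given shape in the scope, so the
// pass can look up the parameters it expects.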
void AddVarToScope(Scope* param_scope,
const std::string& name,
const DDim& dims) {
auto* tensor = param_scope->Var(name)->GetMutable<phi::DenseTensor>();
tensor->Resize(dims);
tensor->mutable_data<float>(platform::CPUPlace());
}

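// Build a parameter scope holding per-layer weights and biases with the same
// shapes as the graph-level inputs defined in DEF_INPUT_DATA.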
Scope* CreateParamScope() {
auto param_scope = new Scope();
AddVarToScope(param_scope, "ln_scale", {1024});
AddVarToScope(param_scope, "ln_bias", {1024});
AddVarToScope(param_scope, "ffn_ln_scale", {1024});
AddVarToScope(param_scope, "ffn_ln_bias", {1024});

AddVarToScope(param_scope, "qkv_w", {3, 16, 64, 1024});
AddVarToScope(param_scope, "out_linear_w", {1024, 1024});
AddVarToScope(param_scope, "ffn1_w", {1024, 4096});
AddVarToScope(param_scope, "ffn2_w", {4096, 1024});
AddVarToScope(param_scope, "qkv_bias", {3072});
AddVarToScope(param_scope, "out_linear_bias", {1024});
AddVarToScope(param_scope, "ffn1_bias", {4096});
AddVarToScope(param_scope, "ffn2_bias", {1024});

return param_scope;
}
TEST(FuseMultiTransformerLayerPass, encoder_fp) {

DEF_INPUT_DATA

// Layers
for (int i = 0; i < num_layers; ++i) {
std::cout << "begin to add fill const layer " << i << std::endl;
auto* cache_kv = layers.fill_constant_batch_size_like(
x,
static_cast<int>(proto::VarType::FP32),
0,
1,
{2, -1, 16, 1024, 64},
0);
std::cout << "begin to add fused_multi_transformer layer " << i
<< std::endl;
auto* out = layers.fused_multi_transformer(x,
cache_kv,
src_mask,
qkv_w,
qkv_bias,
out_linear_w,
out_linear_bias,
ffn1_w,
ffn1_bias,
ffn2_w,
ffn2_bias,
ln_scale,
ln_bias,
ffn_ln_scale,
ffn_ln_bias,
0.1,
1e-12);

x = out;
}
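  // The loop above chains num_layers fused_multi_transformer ops end to end;
  // the pass under test should collapse them into a single op.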
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope());
graph->Set(kFusedMultiTransformerEncoderFusionCount, new int(num_layers));

auto pass = PassRegistry::Instance().Get("fuse_multi_transformer_layer_pass");
  ASSERT_TRUE(pass != nullptr) << "get fuse_multi_transformer_layer_pass failed";

graph.reset(pass->Apply(graph.release()));
int num_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer");

PADDLE_ENFORCE_EQ(
num_nodes_after,
1,
platform::errors::InvalidArgument(
"After the fuse_multi_transformer_layer_pass, "
"The node num in graph should be 1, but the result is %d",
num_nodes_after));
}
TEST(FuseMultiTransformerLayerPass, decoder_fp) {
DEF_INPUT_DATA

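  // Override with decoder-style inputs: a single-token query, a pre-populated
  // cache, and a mask over 129 positions (128 cached steps plus the current
  // token).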
x = layers.data("x", {1, 1, 1024});
auto* cache_kv = layers.data("cache_kv", {2, 1, 16, 1024, 64}, true);
src_mask = layers.data("src_mask", {1, 16, 1, 129});

// Layers
for (int i = 0; i < num_layers; ++i) {
auto* shape_out = layers.shape(src_mask);
auto* time_stamp = layers.slice(shape_out, {0}, {3}, {4});
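    // The shape+slice above recovers the mask's last dimension (129) as a
    // 1-element tensor, which is passed to the op as its time-step input.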
std::cout << "begin to add fused_multi_transformer layer " << i
<< std::endl;
auto* out = layers.fused_multi_transformer(x,
cache_kv,
src_mask,
qkv_w,
qkv_bias,
out_linear_w,
out_linear_bias,
ffn1_w,
ffn1_bias,
ffn2_w,
ffn2_bias,
ln_scale,
ln_bias,
ffn_ln_scale,
ffn_ln_bias,
0.1,
1e-12,
time_stamp);

x = out;
}
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
auto param_scope = CreateParamScope();
AddVarToScope(param_scope, "cache_kv", {2, 1, 16, 1024, 64});
graph->Set("__param_scope__", param_scope);

graph->Set(kFusedMultiTransformerDecoderFusionCount, new int(num_layers));

auto pass = PassRegistry::Instance().Get("fuse_multi_transformer_layer_pass");
  ASSERT_TRUE(pass != nullptr) << "get fuse_multi_transformer_layer_pass failed";

graph.reset(pass->Apply(graph.release()));
int num_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer");

PADDLE_ENFORCE_EQ(
num_nodes_after,
1,
platform::errors::InvalidArgument(
"After the fuse_multi_transformer_layer_pass, "
"The node num in graph should be 1, but the result is %d",
num_nodes_after));
}
} // namespace ir
} // namespace framework
} // namespace paddle

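// USE_PASS makes the pass registration from the pass library visible to this
// test binary at link time.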
USE_PASS(fuse_multi_transformer_layer_pass);