diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 06ea7acb3315e..85deab25dee44 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -96,6 +96,8 @@ pass_library(shuffle_channel_detect_pass inference) pass_library(delete_quant_dequant_op_pass inference) pass_library(delete_quant_dequant_filter_op_pass inference) pass_library(delete_weight_dequant_linear_op_pass inference) +pass_library(delete_weight_dequant_linear_op_encoder_pass inference) +pass_library(delete_weight_dequant_linear_op_decoder_pass inference) pass_library(delete_quant_dequant_linear_op_pass inference) pass_library(delete_dropout_op_pass inference) pass_library(delete_c_identity_op_pass inference) diff --git a/paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc b/paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc index 9057f3450453a..e5ecbea39061a 100644 --- a/paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc +++ b/paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc @@ -121,14 +121,27 @@ void DeleteQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const { true, platform::errors::InvalidArgument( "Input scale tensor's place should be CPU.")); - const float* input_scale_data = input_scale_tensor.data(); - float input_scale = input_scale_data[0]; + + float input_scale; + if (input_scale_tensor.dtype() == paddle::experimental::DataType::FLOAT32) { + const float* input_scale_data = input_scale_tensor.data(); + input_scale = input_scale_data[0]; + } else if (input_scale_tensor.dtype() == + paddle::experimental::DataType::FLOAT16) { + const phi::dtype::float16* input_scale_data = + input_scale_tensor.data(); + input_scale = static_cast(input_scale_data[0]); + } else { + PADDLE_THROW(platform::errors::Unimplemented("%d is not supported.", + input_scale_tensor.dtype())); + } int nums_any_ops = dequantize_linear_op_out->outputs.size(); for (int i = 0; i < nums_any_ops; ++i) { auto* any_op_desc = dequantize_linear_op_out->outputs[i]->Op(); any_op_desc->SetAttr("Input_scale_" + quantize_linear_op_x->Var()->Name(), input_scale); + // link x to any_op2 any_op_desc->RenameInput(dequantize_linear_op_out->Var()->Name(), quantize_linear_op_x->Var()->Name()); diff --git a/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_decoder_pass.cc b/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_decoder_pass.cc new file mode 100644 index 0000000000000..fe692d01928f7 --- /dev/null +++ b/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_decoder_pass.cc @@ -0,0 +1,373 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
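A note on the delete_quant_dequant_linear_op_pass.cc change above, before the new decoder pass: the activation scale is now read from either an FP32 or an FP16 scale tensor instead of assuming FP32. A minimal sketch of that dispatch with the template arguments written out; ReadScalarScale is an illustrative helper name, not part of the patch, and the body assumes the pass's existing headers.

// Sketch only: mirrors the FP32/FP16 branch added above.
static float ReadScalarScale(const phi::DenseTensor& scale_tensor) {
  float scale = 0.0f;
  if (scale_tensor.dtype() == paddle::experimental::DataType::FLOAT32) {
    scale = scale_tensor.data<float>()[0];
  } else if (scale_tensor.dtype() == paddle::experimental::DataType::FLOAT16) {
    scale = static_cast<float>(scale_tensor.data<phi::dtype::float16>()[0]);
  } else {
    PADDLE_THROW(platform::errors::Unimplemented("%d is not supported.",
                                                 scale_tensor.dtype()));
  }
  return scale;
}

The value is then written onto every consumer op as an "Input_scale_" + <x name> attribute, which the fused_multi_transformer passes further down pick up.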
+ +#include "paddle/fluid/framework/ir/delete_weight_dequant_linear_op_decoder_pass.h" + +#include +#include +#include +#include +#include + +namespace paddle { +namespace framework { +namespace ir { + +#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern); +#define GET_NODES \ + GET_IR_NODE(weight_dequantize_linear_op_x); \ + GET_IR_NODE(weight_dequantize_linear_op_scale); \ + GET_IR_NODE(weight_dequantize_linear_op); \ + GET_IR_NODE(weight_dequantize_linear_op_out); \ + GET_IR_NODE(any_op2); + +DeleteWeightDequantLinearOpDecoderPass:: + DeleteWeightDequantLinearOpDecoderPass() { + AddOpCompat(OpCompat("quantize_linear")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Scale") + .IsTensor() + .End() + .AddInput("ZeroPoint") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Y") + .IsTensor() + .End() + .AddAttr("bit_length") + .IsType() + .End() + .AddAttr("quant_axis") + .IsType() + .End() + .AddAttr("round_type") + .IsOptional() + .IsType() + .End(); + AddOpCompat(OpCompat("dequantize_linear")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Scale") + .IsTensor() + .End() + .AddInput("ZeroPoint") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Y") + .IsTensor() + .End() + .AddAttr("bit_length") + .IsType() + .End() + .AddAttr("quant_axis") + .IsType() + .End() + .AddAttr("round_type") + .IsOptional() + .IsType() + .End(); + AddOpCompat(OpCompat("conv2d")) + .AddInput("Input") + .IsTensor() + .End() + .AddInput("Filter") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .IsOptional() + .End() + .AddInput("ResidualData") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Output") + .IsTensor() + .End() + .AddAttr("strides") + .IsType>() + .End() + .AddAttr("paddings") + .IsType>() + .End() + .AddAttr("padding_algorithm") + .IsOptional() + .IsStringIn({"EXPLICIT", "SAME", "VALID"}) + .End() + .AddAttr("groups") + .IsNumGE(1) + .End() + .AddAttr("dilations") + .IsType>() + .End() + .AddAttr("data_format") + .IsStringIn({"NCHW", "NHWC", "AnyLayout"}) + .End(); + AddOpCompat(OpCompat("depthwise_conv2d")) + .AddInput("Input") + .IsTensor() + .End() + .AddInput("Filter") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .IsOptional() + .End() + .AddInput("ResidualData") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Output") + .IsTensor() + .End() + .AddAttr("strides") + .IsType>() + .End() + .AddAttr("paddings") + .IsType>() + .End() + .AddAttr("padding_algorithm") + .IsOptional() + .IsStringIn({"EXPLICIT", "SAME", "VALID"}) + .End() + .AddAttr("groups") + .IsNumGE(1) + .End() + .AddAttr("dilations") + .IsType>() + .End() + .AddAttr("data_format") + .IsStringIn({"NCHW", "NHWC", "AnyLayout"}) + .End(); + AddOpCompat(OpCompat("mul")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("x_num_col_dims") + .IsNumGE(1) + .End() + .AddAttr("y_num_col_dims") + .IsNumEQ(1) + .End(); + AddOpCompat(OpCompat("matmul_v2")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("trans_x") + .IsBoolEQ(false) + .End() + .AddAttr("trans_y") + .IsBoolEQ(false) + .End(); + AddOpCompat(OpCompat("matmul")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("alpha") + .IsNumGE(0.99f) + .IsNumLE(1.01f) + .End() + .AddAttr("transpose_X") + .IsBoolEQ(false) + .End() + .AddAttr("transpose_Y") + .IsBoolEQ(false) 
+ .End(); + AddOpCompat(OpCompat("fc")) + .AddInput("Input") + .IsTensor() + .End() + .AddInput("W") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("in_num_col_dims") + .IsNumGE(1) + .End() + .AddAttr("activation_type") + .IsStringIn({"relu", ""}) + .End(); + AddOpCompat(OpCompat("conv2d_transpose")) + .AddInput("Input") + .IsTensor() + .End() + .AddInput("Filter") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Output") + .IsTensor() + .End() + .AddAttr("output_padding") + .IsType>() + .IsOptional() + .End() + .AddAttr("output_size") + .IsType>() + .IsOptional() + .End() + .AddAttr("groups") + .IsNumGE(1) + .End() + .AddAttr("dilations") + .IsType>() + .End() + .AddAttr("strides") + .IsType>() + .End() + .AddAttr("paddings") + .IsType>() + .End() + .AddAttr("padding_algorithm") + .IsOptional() + .IsStringIn({"EXPLICIT", "SAME", "VALID"}) + .End() + .AddAttr("data_format") + .IsStringIn({"NCHW", "NHWC", "AnyLayout"}) + .End(); +} +// Delete dequantize_linear_op, then dequantize weight +void DeleteWeightDequantLinearOpDecoderPass::ApplyImpl(ir::Graph* graph) const { + const std::string pattern_name = + "delete_weight_dequant_linear_op_decoder_pattern"; + FusePassBase::Init(pattern_name, graph); + + GraphPatternDetector gpd; + auto* scope = param_scope(); + PADDLE_ENFORCE_NOT_NULL(scope, + platform::errors::InvalidArgument( + "Scope in DeleteWeightDequantLinearOpDecoderPass " + "should not be null.")); + // Create pattern + patterns::DeleteWeightDequantLinearOpDecoderPattern pattern( + gpd.mutable_pattern(), pattern_name); + pattern(); + int found_count = 0; + bool is_int8 = false; + + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + GET_NODES; + /* + if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "delete_weight_dequant_linear_op_pass " + "compat check failed."; + return; + } + */ + is_int8 = true; + std::unordered_set nodes2rm = {}; + + auto* any_op2_desc = any_op2->Op(); + + // Get weight scale + std::vector weight_scale; + auto* weight_scale_tensor = + scope->GetVar(weight_dequantize_linear_op_scale->Name()) + ->GetMutable(); + auto weight_scale_nums = weight_scale_tensor->numel(); + + if (weight_scale_tensor->dtype() == + paddle::experimental::DataType::FLOAT32) { + float* weight_scale_data = weight_scale_tensor->data(); + for (int i = 0; i < weight_scale_nums; i++) { + weight_scale.push_back(weight_scale_data[i]); + } + } else if (weight_scale_tensor->dtype() == + paddle::experimental::DataType::FLOAT16) { + phi::dtype::float16* weight_scale_data = + weight_scale_tensor->data(); + for (int i = 0; i < weight_scale_nums; i++) { + weight_scale.push_back(static_cast(weight_scale_data[i])); + } + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "%d is not supported.", weight_scale_tensor->dtype())); + } + + int quant_axis = PADDLE_GET_CONST( + int, weight_dequantize_linear_op->Op()->GetAttr("quant_axis")); + if (quant_axis == -1) { // per_layer quant_dequant: all OP + PADDLE_ENFORCE_EQ(weight_scale_nums, + 1, + platform::errors::InvalidArgument( + "When quant_axis == -1 means use per_layer " + "quant_dequant, weight_scale'number should be 1.")); + + // Add attr to anyop 2 + any_op2_desc->SetAttr("weight_scale", weight_scale[0]); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Delete Weight Dequant Linear Op Encoder Pass is not supported for " + "per-channel quantization")); + } + + 
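The handler above only supports per-layer quantization: quant_axis == -1 means a single scale covers the whole weight tensor, so the scale tensor must contain exactly one element, and that value is attached to the consumer op as a weight_scale attribute; per-channel scales still raise Unimplemented. A hedged sketch of that branch (illustrative helper name and signature; the FP32/FP16 scale read itself follows the same dispatch shown earlier):

// Sketch only: per-layer handling of the weight scales collected above.
// Assumes weight_scale already holds the scale values as floats.
static void AttachPerLayerWeightScale(const std::vector<float>& weight_scale,
                                      int quant_axis,
                                      OpDesc* consumer_op) {
  if (quant_axis == -1) {  // per-layer: one scale for the whole weight
    PADDLE_ENFORCE_EQ(
        weight_scale.size(),
        1UL,
        platform::errors::InvalidArgument(
            "Per-layer quant_dequant expects exactly one weight scale."));
    consumer_op->SetAttr("weight_scale", weight_scale[0]);
  } else {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Per-channel quantization is not supported by this pass."));
  }
}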
nodes2rm.insert(weight_dequantize_linear_op_scale); + nodes2rm.insert(weight_dequantize_linear_op); + nodes2rm.insert(weight_dequantize_linear_op_out); + + // relink weight to any_op2 + any_op2_desc->RenameInput(weight_dequantize_linear_op_out->Var()->Name(), + weight_dequantize_linear_op_x->Var()->Name()); + any_op2_desc->Flush(); + IR_NODE_LINK_TO(weight_dequantize_linear_op_x, any_op2); + GraphSafeRemoveNodes(graph, nodes2rm); + found_count++; + }; + gpd(graph, handler); + if (is_int8) { + auto& enable_int8 = graph->Get("enable_int8"); + enable_int8 = true; + } + AddStatis(found_count); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(delete_weight_dequant_linear_op_decoder_pass, + paddle::framework::ir::DeleteWeightDequantLinearOpDecoderPass); diff --git a/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_decoder_pass.h b/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_decoder_pass.h new file mode 100644 index 0000000000000..866bfb7b73654 --- /dev/null +++ b/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_decoder_pass.h @@ -0,0 +1,34 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" + +namespace paddle { +namespace framework { +namespace ir { + +class DeleteWeightDequantLinearOpDecoderPass : public FusePassBase { + public: + DeleteWeightDequantLinearOpDecoderPass(); + virtual ~DeleteWeightDequantLinearOpDecoderPass() {} + + protected: + void ApplyImpl(ir::Graph* graph) const override; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_encoder_pass.cc b/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_encoder_pass.cc new file mode 100644 index 0000000000000..0cffcd38b3466 --- /dev/null +++ b/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_encoder_pass.cc @@ -0,0 +1,370 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
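The encoder variant that follows is nearly identical to the decoder pass above, so the shared rewrite step is worth spelling out: the consumer op is re-pointed from the dequantize_linear output to the still-quantized weight, the weight var node is linked to the consumer, and the dequantize_linear subgraph is removed. A hedged sketch of that rewrite (illustrative helper name; assumes the pass's existing headers):

// Sketch only: relink the consumer op to the quantized weight and drop the
// dequantize_linear nodes, mirroring the handler above.
static void CutOutWeightDequant(ir::Graph* graph,
                                Node* weight,         // dequantize_linear X
                                Node* dequant_op,     // dequantize_linear op
                                Node* dequant_out,    // dequantize_linear Y
                                Node* dequant_scale,  // dequantize_linear Scale
                                Node* consumer) {     // any_op2
  auto* consumer_desc = consumer->Op();
  consumer_desc->RenameInput(dequant_out->Var()->Name(),
                             weight->Var()->Name());
  consumer_desc->Flush();
  IR_NODE_LINK_TO(weight, consumer);
  GraphSafeRemoveNodes(graph, {dequant_scale, dequant_op, dequant_out});
}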
+ +#include "paddle/fluid/framework/ir/delete_weight_dequant_linear_op_encoder_pass.h" + +#include +#include +#include +#include +#include + +namespace paddle { +namespace framework { +namespace ir { + +#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern); +#define GET_NODES \ + GET_IR_NODE(weight_dequantize_linear_op_x); \ + GET_IR_NODE(weight_dequantize_linear_op_scale); \ + GET_IR_NODE(weight_dequantize_linear_op); \ + GET_IR_NODE(weight_dequantize_linear_op_out); \ + GET_IR_NODE(any_op2); + +DeleteWeightDequantLinearOpEncoderPass:: + DeleteWeightDequantLinearOpEncoderPass() { + AddOpCompat(OpCompat("quantize_linear")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Scale") + .IsTensor() + .End() + .AddInput("ZeroPoint") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Y") + .IsTensor() + .End() + .AddAttr("bit_length") + .IsType() + .End() + .AddAttr("quant_axis") + .IsType() + .End() + .AddAttr("round_type") + .IsOptional() + .IsType() + .End(); + AddOpCompat(OpCompat("dequantize_linear")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Scale") + .IsTensor() + .End() + .AddInput("ZeroPoint") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Y") + .IsTensor() + .End() + .AddAttr("bit_length") + .IsType() + .End() + .AddAttr("quant_axis") + .IsType() + .End() + .AddAttr("round_type") + .IsOptional() + .IsType() + .End(); + AddOpCompat(OpCompat("conv2d")) + .AddInput("Input") + .IsTensor() + .End() + .AddInput("Filter") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .IsOptional() + .End() + .AddInput("ResidualData") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Output") + .IsTensor() + .End() + .AddAttr("strides") + .IsType>() + .End() + .AddAttr("paddings") + .IsType>() + .End() + .AddAttr("padding_algorithm") + .IsOptional() + .IsStringIn({"EXPLICIT", "SAME", "VALID"}) + .End() + .AddAttr("groups") + .IsNumGE(1) + .End() + .AddAttr("dilations") + .IsType>() + .End() + .AddAttr("data_format") + .IsStringIn({"NCHW", "NHWC", "AnyLayout"}) + .End(); + AddOpCompat(OpCompat("depthwise_conv2d")) + .AddInput("Input") + .IsTensor() + .End() + .AddInput("Filter") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .IsOptional() + .End() + .AddInput("ResidualData") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Output") + .IsTensor() + .End() + .AddAttr("strides") + .IsType>() + .End() + .AddAttr("paddings") + .IsType>() + .End() + .AddAttr("padding_algorithm") + .IsOptional() + .IsStringIn({"EXPLICIT", "SAME", "VALID"}) + .End() + .AddAttr("groups") + .IsNumGE(1) + .End() + .AddAttr("dilations") + .IsType>() + .End() + .AddAttr("data_format") + .IsStringIn({"NCHW", "NHWC", "AnyLayout"}) + .End(); + AddOpCompat(OpCompat("mul")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("x_num_col_dims") + .IsNumGE(1) + .End() + .AddAttr("y_num_col_dims") + .IsNumEQ(1) + .End(); + AddOpCompat(OpCompat("matmul_v2")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("trans_x") + .IsBoolEQ(false) + .End() + .AddAttr("trans_y") + .IsBoolEQ(false) + .End(); + AddOpCompat(OpCompat("matmul")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("alpha") + .IsNumGE(0.99f) + .IsNumLE(1.01f) + .End() + .AddAttr("transpose_X") + .IsBoolEQ(false) + .End() + .AddAttr("transpose_Y") + .IsBoolEQ(false) 
+ .End(); + AddOpCompat(OpCompat("fc")) + .AddInput("Input") + .IsTensor() + .End() + .AddInput("W") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("in_num_col_dims") + .IsNumGE(1) + .End() + .AddAttr("activation_type") + .IsStringIn({"relu", ""}) + .End(); + AddOpCompat(OpCompat("conv2d_transpose")) + .AddInput("Input") + .IsTensor() + .End() + .AddInput("Filter") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Output") + .IsTensor() + .End() + .AddAttr("output_padding") + .IsType>() + .IsOptional() + .End() + .AddAttr("output_size") + .IsType>() + .IsOptional() + .End() + .AddAttr("groups") + .IsNumGE(1) + .End() + .AddAttr("dilations") + .IsType>() + .End() + .AddAttr("strides") + .IsType>() + .End() + .AddAttr("paddings") + .IsType>() + .End() + .AddAttr("padding_algorithm") + .IsOptional() + .IsStringIn({"EXPLICIT", "SAME", "VALID"}) + .End() + .AddAttr("data_format") + .IsStringIn({"NCHW", "NHWC", "AnyLayout"}) + .End(); +} +// Delete dequantize_linear_op, then dequantize weight +void DeleteWeightDequantLinearOpEncoderPass::ApplyImpl(ir::Graph* graph) const { + const std::string pattern_name = + "delete_weight_dequant_linear_op_encoder_pattern"; + FusePassBase::Init(pattern_name, graph); + + GraphPatternDetector gpd; + auto* scope = param_scope(); + PADDLE_ENFORCE_NOT_NULL(scope, + platform::errors::InvalidArgument( + "Scope in DeleteWeightDequantLinearOpEncoderPass " + "should not be null.")); + // Create pattern + patterns::DeleteWeightDequantLinearOpEncoderPattern pattern( + gpd.mutable_pattern(), pattern_name); + pattern(); + int found_count = 0; + bool is_int8 = false; + + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + GET_NODES; + /* + if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "delete_weight_dequant_linear_op_pass " + "compat check failed."; + return; + } + */ + is_int8 = true; + std::unordered_set nodes2rm = {}; + + auto* any_op2_desc = any_op2->Op(); + + // Get weight scale + std::vector weight_scale; + auto* weight_scale_tensor = + scope->GetVar(weight_dequantize_linear_op_scale->Name()) + ->GetMutable(); + auto weight_scale_nums = weight_scale_tensor->numel(); + + if (weight_scale_tensor->dtype() == + paddle::experimental::DataType::FLOAT32) { + float* weight_scale_data = weight_scale_tensor->data(); + for (int i = 0; i < weight_scale_nums; i++) { + weight_scale.push_back(weight_scale_data[i]); + } + } else if (weight_scale_tensor->dtype() == + paddle::experimental::DataType::FLOAT16) { + phi::dtype::float16* weight_scale_data = + weight_scale_tensor->data(); + for (int i = 0; i < weight_scale_nums; i++) { + weight_scale.push_back(static_cast(weight_scale_data[i])); + } + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "%d is not supported.", weight_scale_tensor->dtype())); + } + + int quant_axis = PADDLE_GET_CONST( + int, weight_dequantize_linear_op->Op()->GetAttr("quant_axis")); + if (quant_axis == -1) { // per_layer quant_dequant: all OP + PADDLE_ENFORCE_EQ(weight_scale_nums, + 1, + platform::errors::InvalidArgument( + "When quant_axis == -1 means use per_layer " + "quant_dequant, weight_scale'number should be 1.")); + + // Add attr to anyop 2 + any_op2_desc->SetAttr("weight_scale", weight_scale[0]); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Delete Weight Dequant Linear Op Encoder Pass is not supported for " + "per-channel quantization")); + } + + 
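Both new passes record whether a quantized-weight pattern was matched and hand that information to later passes through a graph-level "enable_int8" attribute: the encoder pass publishes it, the decoder pass only flips an already-published flag to true, and fuse_multi_transformer_layer_pass reads it defensively (see the remainder of this file and the hunks below). A minimal sketch of that handshake with the template argument written out; the helper names are illustrative, not part of the patch:

// Sketch only: the "enable_int8" graph-attribute handshake used by these passes.
static void PublishInt8Flag(ir::Graph* graph, bool is_int8) {
  // Encoder pass: create the attribute; the graph takes ownership.
  graph->Set("enable_int8", new bool(is_int8));
}
static void UpgradeInt8Flag(ir::Graph* graph, bool is_int8) {
  // Decoder pass: only flip an already-published flag to true.
  if (is_int8) {
    graph->Get<bool>("enable_int8") = true;
  }
}
static bool ReadInt8Flag(ir::Graph* graph) {
  // Downstream passes: read defensively, defaulting to the float path.
  return graph->Has("enable_int8") && graph->Get<bool>("enable_int8");
}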
nodes2rm.insert(weight_dequantize_linear_op_scale); + nodes2rm.insert(weight_dequantize_linear_op); + nodes2rm.insert(weight_dequantize_linear_op_out); + + // relink weight to any_op2 + any_op2_desc->RenameInput(weight_dequantize_linear_op_out->Var()->Name(), + weight_dequantize_linear_op_x->Var()->Name()); + any_op2_desc->Flush(); + IR_NODE_LINK_TO(weight_dequantize_linear_op_x, any_op2); + GraphSafeRemoveNodes(graph, nodes2rm); + found_count++; + }; + gpd(graph, handler); + graph->Set("enable_int8", new bool(is_int8)); + AddStatis(found_count); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(delete_weight_dequant_linear_op_encoder_pass, + paddle::framework::ir::DeleteWeightDequantLinearOpEncoderPass); diff --git a/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_encoder_pass.h b/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_encoder_pass.h new file mode 100644 index 0000000000000..8aead6bd5cc58 --- /dev/null +++ b/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_encoder_pass.h @@ -0,0 +1,34 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" + +namespace paddle { +namespace framework { +namespace ir { + +class DeleteWeightDequantLinearOpEncoderPass : public FusePassBase { + public: + DeleteWeightDequantLinearOpEncoderPass(); + virtual ~DeleteWeightDequantLinearOpEncoderPass() {} + + protected: + void ApplyImpl(ir::Graph* graph) const override; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.cc b/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.cc index b730d46ab7c5f..7cbb601a68cbf 100644 --- a/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.cc +++ b/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.cc @@ -118,9 +118,15 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph, GraphPatternDetector gpd; auto* pattern = gpd.mutable_pattern(); - // TODO(wufeisheng): Get enable_int8 attr from graph after - // fused_multi_transformer pass with int8 merged bool enable_int8 = false; + if (graph->Has("enable_int8")) { + enable_int8 = graph->Get("enable_int8"); + } + if (!enable_int8) { + VLOG(4) + << "fuse_multi_layer_transformer_pass will match float transformer op " + "cause enable_int8 is not been set or set to false"; + } int num_fuse_op = 0; bool is_decoder = false; @@ -209,7 +215,13 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph, "OutLinearW", "QKVBias", "QKVW"}; - + if (enable_int8) { + std::vector inputs_names_int8_supp = { + "FFN1OutScale", "FFN2OutScale", "OutLinearOutScale", "QKVOutScale"}; + inputs_names.insert(inputs_names.end(), + inputs_names_int8_supp.begin(), + inputs_names_int8_supp.end()); + } for (const auto& input_name : inputs_names) { MergeInput(fuse_op_descs[0], fuse_op_input_var_name_maps, input_name); } @@ 
-227,6 +239,17 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph, } fuse_op_descs[0]->SetOutput("CacheKVOut", merged_cache_kv_out_names); + if (enable_int8) { + // Merge inputs scale + std::vector attr_names = {"qkv_in_scale", + "out_linear_in_scale", + "ffn1_in_scale", + "ffn2_in_scale"}; + for (const auto& name : attr_names) { + MergeAttrs(fuse_op_descs, name); + } + } + //////////////// //// ReLink //// //////////////// diff --git a/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass_tester.cc b/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass_tester.cc index 72635d1c95855..c96935a9ac649 100644 --- a/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass_tester.cc +++ b/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass_tester.cc @@ -98,6 +98,7 @@ TEST(FuseMultiTransformerLayerPass, encoder_fp) { std::unique_ptr graph(new ir::Graph(layers.main_program())); graph->Set("__param_scope__", CreateParamScope()); graph->Set(kFusedMultiTransformerEncoderFusionCount, new int(num_layers)); + graph->Set("enable_int8", new bool(false)); auto pass = PassRegistry::Instance().Get("fuse_multi_transformer_layer_pass"); if (pass.get() == nullptr) diff --git a/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc b/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc index 2d93758f177d2..bc1a2dd0ed4de 100644 --- a/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc +++ b/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc @@ -1075,12 +1075,27 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() { } // namespace patterns +inline Node* CreatePersistableVarNode(Graph* graph, const std::string& name) { + auto var_desc = VarDesc(name); + var_desc.SetDataType(framework::proto::VarType::FP32); + var_desc.SetPersistable(true); + auto node = graph->CreateVarNode(&var_desc); + return node; +} + int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope) const { GraphPatternDetector gpd; auto* pattern = gpd.mutable_pattern(); + bool enable_int8 = graph->Get("enable_int8"); + if (enable_int8) { + VLOG(3) << "FusedMultiTransformerDecoderPass with int8"; + } else { + VLOG(3) << "FusedMultiTransformerDecoderPass with fp"; + } + // Create pattern. patterns::FusedMultiTransformerDecoderPattern fused_multi_transformer_pattern( pattern, name_scope); @@ -1093,6 +1108,7 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph, Node* layer_norm_bias, Node* layer_norm_mean, Node* layer_norm_variance, + Node* matmul0, Node* matmul0_w, Node* matmul1_w, Node* matmul2_w, @@ -1103,6 +1119,7 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph, Node* transpose2_2_out, Node* eltadd_qk_b, Node* reshape2_0, + Node* matmul_linear, Node* matmul_linear_w, Node* eltadd_linear_b, Node* ffn_layer_norm, @@ -1110,11 +1127,17 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph, Node* ffn_layer_norm_bias, Node* ffn_layer_norm_mean, Node* ffn_layer_norm_variance, + Node* ffn_matmul0, Node* ffn_matmul0_w, + Node* ffn_matmul1, Node* ffn_matmul1_w, Node* ffn_eltadd0_b, Node* ffn_eltadd1_b, Node* ffn_output) { + auto* matmul0_op = matmul0->Op(); + auto* matmul_linear_op = matmul_linear->Op(); + auto* ffn_matmul_0_op = ffn_matmul0->Op(); + auto* ffn_matmul_1_op = ffn_matmul1->Op(); // Calc index of transformer layer by LayerNorm Scale name // This calculation assumes: // 1. 
no LayerNorm before all transformer layer @@ -1126,7 +1149,9 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph, // create fused_multi_transformer OpDesc fused_multi_transformer_op_desc(layer_norm->Op()->Block()); - fused_multi_transformer_op_desc.SetType("fused_multi_transformer"); + fused_multi_transformer_op_desc.SetType(enable_int8 + ? "fused_multi_transformer_int8" + : "fused_multi_transformer"); // 1. Input setting fused_multi_transformer_op_desc.SetInput("X", {input0->Name()}); @@ -1181,8 +1206,66 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph, fused_multi_transformer_op_desc.SetAttr("is_test", true); fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f); + if (enable_int8) { + // Set input scale + std::string qkv_input_name = matmul0_op->Input("X")[0]; + auto qkv_in_scale = PADDLE_GET_CONST( + float, matmul0_op->GetAttr("Input_scale_" + qkv_input_name)); + std::string out_linear_input_name = matmul_linear_op->Input("X")[0]; + auto out_linear_in_scale = PADDLE_GET_CONST( + float, + matmul_linear_op->GetAttr("Input_scale_" + out_linear_input_name)); + std::string ffn0_input_name = ffn_matmul_0_op->Input("X")[0]; + auto ffn0_in_scale = PADDLE_GET_CONST( + float, ffn_matmul_0_op->GetAttr("Input_scale_" + ffn0_input_name)); + std::string ffn1_input_name = ffn_matmul_1_op->Input("X")[0]; + auto ffn1_in_scale = PADDLE_GET_CONST( + float, ffn_matmul_1_op->GetAttr("Input_scale_" + ffn1_input_name)); + + // Inverse input scale + qkv_in_scale = 1.0f / qkv_in_scale; + out_linear_in_scale = 1.0f / out_linear_in_scale; + ffn0_in_scale = 1.0f / ffn0_in_scale; + ffn1_in_scale = 1.0f / ffn1_in_scale; + + fused_multi_transformer_op_desc.SetAttr("qkv_in_scale", + std::vector{qkv_in_scale}); + fused_multi_transformer_op_desc.SetAttr( + "out_linear_in_scale", std::vector{out_linear_in_scale}); + fused_multi_transformer_op_desc.SetAttr( + "ffn1_in_scale", std::vector{ffn0_in_scale}); + fused_multi_transformer_op_desc.SetAttr( + "ffn2_in_scale", std::vector{ffn1_in_scale}); + + fused_multi_transformer_op_desc.SetInput( + "QKVOutScale", {matmul0_w->Name() + "_out_scale"}); + fused_multi_transformer_op_desc.SetInput( + "OutLinearOutScale", {matmul_linear_w->Name() + "_out_scale"}); + fused_multi_transformer_op_desc.SetInput( + "FFN1OutScale", {ffn_matmul0_w->Name() + "_out_scale"}); + fused_multi_transformer_op_desc.SetInput( + "FFN2OutScale", {ffn_matmul1_w->Name() + "_out_scale"}); + } + auto* fused_multi_transformer = graph->CreateOpNode(&fused_multi_transformer_op_desc); + + if (enable_int8) { + auto qkv_out_scale_node = + CreatePersistableVarNode(graph, matmul0_w->Name() + "_out_scale"); + auto out_out_scale_node = CreatePersistableVarNode( + graph, matmul_linear_w->Name() + "_out_scale"); + auto ffn0_out_scale_node = + CreatePersistableVarNode(graph, ffn_matmul0_w->Name() + "_out_scale"); + auto ffn1_out_scale_node = + CreatePersistableVarNode(graph, ffn_matmul1_w->Name() + "_out_scale"); + + IR_NODE_LINK_TO(qkv_out_scale_node, fused_multi_transformer); + IR_NODE_LINK_TO(out_out_scale_node, fused_multi_transformer); + IR_NODE_LINK_TO(ffn0_out_scale_node, fused_multi_transformer); + IR_NODE_LINK_TO(ffn1_out_scale_node, fused_multi_transformer); + } + IR_NODE_LINK_TO(input0, fused_multi_transformer); IR_NODE_LINK_TO(layer_norm_scale, fused_multi_transformer); IR_NODE_LINK_TO(layer_norm_bias, fused_multi_transformer); @@ -1456,6 +1539,7 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph, layer_norm_bias, layer_norm_mean, layer_norm_variance, + 
matmul0, matmul0_w, matmul1_w, matmul2_w, @@ -1466,6 +1550,7 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph, transpose2_2_out, eltadd_qk_b, reshape2_0, + matmul_linear, matmul_linear_w, eltadd_linear_b, ffn_layer_norm, @@ -1473,7 +1558,9 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph, ffn_layer_norm_bias, ffn_layer_norm_mean, ffn_layer_norm_variance, + ffn_matmul0, ffn_matmul0_w, + ffn_matmul1, ffn_matmul1_w, ffn_eltadd0_b, ffn_eltadd1_b, @@ -1732,6 +1819,13 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion( GraphPatternDetector gpd; auto* pattern = gpd.mutable_pattern(); + bool enable_int8 = graph->Get("enable_int8"); + if (enable_int8) { + VLOG(3) << "FusedMultiTransformerDecoderFuseQKVPass with int8"; + } else { + VLOG(3) << "FusedMultiTransformerDecoderFuseQKVPass with fp"; + } + // Create pattern. patterns::FusedMultiTransformerDecoderFuseQKVPattern fused_multi_transformer_fuse_qkv_pattern(pattern, name_scope); @@ -1744,10 +1838,12 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion( Node* layer_norm_bias, Node* layer_norm_mean, Node* layer_norm_variance, + Node* matmul0, Node* matmul0_w, Node* eltadd0_b, Node* eltadd_qk_b, Node* reshape2_0, + Node* matmul_linear, Node* matmul_linear_w, Node* eltadd_linear_b, Node* ffn_layer_norm, @@ -1755,11 +1851,17 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion( Node* ffn_layer_norm_bias, Node* ffn_layer_norm_mean, Node* ffn_layer_norm_variance, + Node* ffn_matmul0, Node* ffn_matmul0_w, + Node* ffn_matmul1, Node* ffn_matmul1_w, Node* ffn_eltadd0_b, Node* ffn_eltadd1_b, Node* ffn_output) { + auto* matmul0_op = matmul0->Op(); + auto* matmul_linear_op = matmul_linear->Op(); + auto* ffn_matmul_0_op = ffn_matmul0->Op(); + auto* ffn_matmul_1_op = ffn_matmul1->Op(); // Calc index of transformer layer by LayerNorm Scale name // This calculation assumes: // 1. no LayerNorm before all transformer layer @@ -1771,7 +1873,9 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion( // create fused_multi_transformer OpDesc fused_multi_transformer_op_desc(layer_norm->Op()->Block()); - fused_multi_transformer_op_desc.SetType("fused_multi_transformer"); + fused_multi_transformer_op_desc.SetType(enable_int8 + ? "fused_multi_transformer_int8" + : "fused_multi_transformer"); // 1. 
Input setting fused_multi_transformer_op_desc.SetInput("X", {input0->Name()}); @@ -1826,8 +1930,65 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion( fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f); fused_multi_transformer_op_desc.SetAttr("is_test", true); + if (enable_int8) { + // Set input scale + std::string qkv_input_name = matmul0_op->Input("X")[0]; + auto qkv_in_scale = PADDLE_GET_CONST( + float, matmul0_op->GetAttr("Input_scale_" + qkv_input_name)); + std::string out_linear_input_name = matmul_linear_op->Input("X")[0]; + auto out_linear_in_scale = PADDLE_GET_CONST( + float, + matmul_linear_op->GetAttr("Input_scale_" + out_linear_input_name)); + std::string ffn0_input_name = ffn_matmul_0_op->Input("X")[0]; + auto ffn0_in_scale = PADDLE_GET_CONST( + float, ffn_matmul_0_op->GetAttr("Input_scale_" + ffn0_input_name)); + std::string ffn1_input_name = ffn_matmul_1_op->Input("X")[0]; + auto ffn1_in_scale = PADDLE_GET_CONST( + float, ffn_matmul_1_op->GetAttr("Input_scale_" + ffn1_input_name)); + + // Inverse input scale + qkv_in_scale = 1.0f / qkv_in_scale; + out_linear_in_scale = 1.0f / out_linear_in_scale; + ffn0_in_scale = 1.0f / ffn0_in_scale; + ffn1_in_scale = 1.0f / ffn1_in_scale; + + fused_multi_transformer_op_desc.SetAttr("qkv_in_scale", + std::vector{qkv_in_scale}); + fused_multi_transformer_op_desc.SetAttr( + "out_linear_in_scale", std::vector{out_linear_in_scale}); + fused_multi_transformer_op_desc.SetAttr( + "ffn1_in_scale", std::vector{ffn0_in_scale}); + fused_multi_transformer_op_desc.SetAttr( + "ffn2_in_scale", std::vector{ffn1_in_scale}); + + fused_multi_transformer_op_desc.SetInput( + "QKVOutScale", {matmul0_w->Name() + "_out_scale"}); + fused_multi_transformer_op_desc.SetInput( + "OutLinearOutScale", {matmul_linear_w->Name() + "_out_scale"}); + fused_multi_transformer_op_desc.SetInput( + "FFN1OutScale", {ffn_matmul0_w->Name() + "_out_scale"}); + fused_multi_transformer_op_desc.SetInput( + "FFN2OutScale", {ffn_matmul1_w->Name() + "_out_scale"}); + } + auto* fused_multi_transformer = graph->CreateOpNode(&fused_multi_transformer_op_desc); + + if (enable_int8) { + auto qkv_out_scale_node = + CreatePersistableVarNode(graph, matmul0_w->Name() + "_out_scale"); + auto out_out_scale_node = CreatePersistableVarNode( + graph, matmul_linear_w->Name() + "_out_scale"); + auto ffn0_out_scale_node = + CreatePersistableVarNode(graph, ffn_matmul0_w->Name() + "_out_scale"); + auto ffn1_out_scale_node = + CreatePersistableVarNode(graph, ffn_matmul1_w->Name() + "_out_scale"); + + IR_NODE_LINK_TO(qkv_out_scale_node, fused_multi_transformer); + IR_NODE_LINK_TO(out_out_scale_node, fused_multi_transformer); + IR_NODE_LINK_TO(ffn0_out_scale_node, fused_multi_transformer); + IR_NODE_LINK_TO(ffn1_out_scale_node, fused_multi_transformer); + } IR_NODE_LINK_TO(input0, fused_multi_transformer); IR_NODE_LINK_TO(layer_norm_scale, fused_multi_transformer); IR_NODE_LINK_TO(layer_norm_bias, fused_multi_transformer); @@ -2088,10 +2249,12 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion( layer_norm_bias, layer_norm_mean, layer_norm_variance, + matmul0, matmul0_w, eltadd0_b, eltadd_qk_b, reshape2_0, + matmul_linear, matmul_linear_w, eltadd_linear_b, ffn_layer_norm, @@ -2099,7 +2262,9 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion( ffn_layer_norm_bias, ffn_layer_norm_mean, ffn_layer_norm_variance, + ffn_matmul0, ffn_matmul0_w, + ffn_matmul1, ffn_matmul1_w, ffn_eltadd0_b, ffn_eltadd1_b, @@ -2349,6 +2514,13 @@ int 
MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion( GraphPatternDetector gpd; auto* pattern = gpd.mutable_pattern(); + bool enable_int8 = graph->Get("enable_int8"); + if (enable_int8) { + VLOG(3) << "MultiDevicesFusedMultiTransformerDecoderFuseQKVPass with int8"; + } else { + VLOG(3) << "MultiDevicesFusedMultiTransformerDecoderFuseQKVPass with fp"; + } + // Create pattern. patterns::MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern fused_multi_transformer_fuse_qkv_pattern(pattern, name_scope); @@ -2362,10 +2534,12 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion( Node* layer_norm_mean, Node* layer_norm_variance, Node* c_identity, + Node* matmul0, Node* matmul0_w, Node* eltadd0_b, Node* eltadd_qk_b, Node* reshape2_0, + Node* matmul_linear, Node* matmul_linear_w, Node* eltadd_linear_b, Node* ffn_layer_norm, @@ -2373,11 +2547,16 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion( Node* ffn_layer_norm_bias, Node* ffn_layer_norm_mean, Node* ffn_layer_norm_variance, + Node* ffn_c_identity, + Node* ffn_matmul0, Node* ffn_matmul0_w, + Node* ffn_matmul1, Node* ffn_matmul1_w, Node* ffn_eltadd0_b, Node* ffn_eltadd1_b, Node* ffn_output) { + auto* matmul_linear_op = matmul_linear->Op(); + auto* ffn_matmul_1_op = ffn_matmul1->Op(); // Calc index of transformer layer by LayerNorm Scale name // This calculation assumes: // 1. no LayerNorm before all transformer layer @@ -2389,7 +2568,9 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion( // create fused_multi_transformer OpDesc fused_multi_transformer_op_desc(layer_norm->Op()->Block()); - fused_multi_transformer_op_desc.SetType("fused_multi_transformer"); + fused_multi_transformer_op_desc.SetType(enable_int8 + ? "fused_multi_transformer_int8" + : "fused_multi_transformer"); // 1. 
Input setting fused_multi_transformer_op_desc.SetInput("X", {input0->Name()}); @@ -2449,8 +2630,71 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion( fused_multi_transformer_op_desc.SetAttr("ring_id", c_identity_op->GetAttr("ring_id")); + if (enable_int8) { + std::string matmul_input_scale_suffix = c_identity_op->Input("X")[0]; + auto qkv_in_scale = PADDLE_GET_CONST( + float, + c_identity_op->GetAttr("Input_scale_" + matmul_input_scale_suffix)); + + std::string out_linear_input_name = matmul_linear_op->Input("X")[0]; + auto out_linear_in_scale = PADDLE_GET_CONST( + float, + matmul_linear_op->GetAttr("Input_scale_" + out_linear_input_name)); + + auto* ffn_c_identity_op = ffn_c_identity->Op(); + std::string ffn_input_scale_suffix = ffn_c_identity_op->Input("X")[0]; + auto ffn0_in_scale = PADDLE_GET_CONST( + float, + ffn_c_identity_op->GetAttr("Input_scale_" + ffn_input_scale_suffix)); + + std::string ffn1_input_name = ffn_matmul_1_op->Input("X")[0]; + auto ffn1_in_scale = PADDLE_GET_CONST( + float, ffn_matmul_1_op->GetAttr("Input_scale_" + ffn1_input_name)); + + // Inverse input scale + qkv_in_scale = 1.0f / qkv_in_scale; + out_linear_in_scale = 1.0f / out_linear_in_scale; + ffn0_in_scale = 1.0f / ffn0_in_scale; + ffn1_in_scale = 1.0f / ffn1_in_scale; + + fused_multi_transformer_op_desc.SetAttr("qkv_in_scale", + std::vector{qkv_in_scale}); + fused_multi_transformer_op_desc.SetAttr( + "out_linear_in_scale", std::vector{out_linear_in_scale}); + fused_multi_transformer_op_desc.SetAttr( + "ffn1_in_scale", std::vector{ffn0_in_scale}); + fused_multi_transformer_op_desc.SetAttr( + "ffn2_in_scale", std::vector{ffn1_in_scale}); + + fused_multi_transformer_op_desc.SetInput( + "QKVOutScale", {matmul0_w->Name() + "_out_scale"}); + fused_multi_transformer_op_desc.SetInput( + "OutLinearOutScale", {matmul_linear_w->Name() + "_out_scale"}); + fused_multi_transformer_op_desc.SetInput( + "FFN1OutScale", {ffn_matmul0_w->Name() + "_out_scale"}); + fused_multi_transformer_op_desc.SetInput( + "FFN2OutScale", {ffn_matmul1_w->Name() + "_out_scale"}); + } + auto* fused_multi_transformer = graph->CreateOpNode(&fused_multi_transformer_op_desc); + + if (enable_int8) { + auto qkv_out_scale_node = + CreatePersistableVarNode(graph, matmul0_w->Name() + "_out_scale"); + auto out_out_scale_node = CreatePersistableVarNode( + graph, matmul_linear_w->Name() + "_out_scale"); + auto ffn0_out_scale_node = + CreatePersistableVarNode(graph, ffn_matmul0_w->Name() + "_out_scale"); + auto ffn1_out_scale_node = + CreatePersistableVarNode(graph, ffn_matmul1_w->Name() + "_out_scale"); + + IR_NODE_LINK_TO(qkv_out_scale_node, fused_multi_transformer); + IR_NODE_LINK_TO(out_out_scale_node, fused_multi_transformer); + IR_NODE_LINK_TO(ffn0_out_scale_node, fused_multi_transformer); + IR_NODE_LINK_TO(ffn1_out_scale_node, fused_multi_transformer); + } + IR_NODE_LINK_TO(input0, fused_multi_transformer); IR_NODE_LINK_TO(layer_norm_scale, fused_multi_transformer); IR_NODE_LINK_TO(layer_norm_bias, fused_multi_transformer); @@ -2737,10 +2981,12 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion( layer_norm_mean, layer_norm_variance, c_identity, + matmul0, matmul0_w, eltadd0_b, eltadd_qk_b, reshape2_0, + matmul_linear, matmul_linear_w, eltadd_linear_b, ffn_layer_norm, @@ -2748,7 +2994,10 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion( ffn_layer_norm_bias, ffn_layer_norm_mean, ffn_layer_norm_variance, + ffn_c_identity, + ffn_matmul0, ffn_matmul0_w, + ffn_matmul1, ffn_matmul1_w, 
ffn_eltadd0_b, ffn_eltadd1_b, diff --git a/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass_tester.cc b/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass_tester.cc index dbb6781442492..2e54196e599a8 100644 --- a/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass_tester.cc +++ b/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass_tester.cc @@ -193,6 +193,7 @@ TEST(FusedMultiTransformerDecoderPass, basic) { std::unique_ptr graph(new ir::Graph(layers.main_program())); graph->Set("__param_scope__", CreateParamScope()); + graph->Set("enable_int8", new bool(false)); auto pass = PassRegistry::Instance().Get("fused_multi_transformer_decoder_pass"); @@ -344,6 +345,7 @@ TEST(FusedMultiTransformerDecoderFuseQKVPass, basic) { std::unique_ptr graph(new ir::Graph(layers.main_program())); graph->Set("__param_scope__", CreateParamScope()); + graph->Set("enable_int8", new bool(false)); auto pass = PassRegistry::Instance().Get( "fused_multi_transformer_decoder_fuse_qkv_pass"); @@ -503,6 +505,7 @@ TEST(MultiDevicesFusedMultiTransformerDecoderFuseQKVPass, basic) { std::unique_ptr graph(new ir::Graph(layers.main_program())); graph->Set("__param_scope__", CreateParamScope()); + graph->Set("enable_int8", new bool(false)); auto pass = PassRegistry::Instance().Get( "multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass"); diff --git a/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.cc b/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.cc index 0503b3a0a3d59..3635613f8c54b 100644 --- a/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.cc +++ b/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.cc @@ -1025,21 +1025,14 @@ template inline void QKVWeightsProcess(phi::DenseTensor* wq_tensor, phi::DenseTensor* wk_tensor, phi::DenseTensor* wv_tensor, - phi::DenseTensor* bq_tensor, - phi::DenseTensor* bk_tensor, - phi::DenseTensor* bv_tensor, const int num_head, const int dim_head, const int dim_embed) { auto* wq_data = wq_tensor->mutable_data(platform::CPUPlace()); auto* wk_data = wk_tensor->mutable_data(platform::CPUPlace()); auto* wv_data = wv_tensor->mutable_data(platform::CPUPlace()); - auto* bq_data = bq_tensor->mutable_data(platform::CPUPlace()); - auto* bk_data = bk_tensor->mutable_data(platform::CPUPlace()); - auto* bv_data = bv_tensor->mutable_data(platform::CPUPlace()); auto combined_w_dims = phi::make_ddim({3, num_head, dim_head, dim_embed}); - auto combined_bias_dims = phi::make_ddim({3, num_head, dim_head}); phi::DenseTensor tmp_combined_w_tensor; tmp_combined_w_tensor.Resize(combined_w_dims); @@ -1065,6 +1058,20 @@ inline void QKVWeightsProcess(phi::DenseTensor* wq_tensor, auto* new_combined_w_data = wq_tensor->mutable_data(platform::CPUPlace()); memcpy( new_combined_w_data, tmp_combined_w_data, sizeof(T) * wq_tensor->numel()); +} + +template +inline void QKVBiasProcess(phi::DenseTensor* bq_tensor, + phi::DenseTensor* bk_tensor, + phi::DenseTensor* bv_tensor, + const int num_head, + const int dim_head, + const int dim_embed) { + auto* bq_data = bq_tensor->mutable_data(platform::CPUPlace()); + auto* bk_data = bk_tensor->mutable_data(platform::CPUPlace()); + auto* bv_data = bv_tensor->mutable_data(platform::CPUPlace()); + + auto combined_bias_dims = phi::make_ddim({3, num_head, dim_head}); phi::DenseTensor tmp_combined_bias_tensor; tmp_combined_bias_tensor.Resize(combined_bias_dims); @@ -1085,13 +1092,57 @@ inline void QKVWeightsProcess(phi::DenseTensor* wq_tensor, sizeof(T) * bq_tensor->numel()); } 
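The refactor above splits the QKV bias combination out of QKVWeightsProcess because, after the weight-only dequant passes, the QKV weights may already be int8 while the biases stay FP32/FP16, so the two can no longer share one template instantiation; the QKVWeightsBiasProcess wrapper added next dispatches each tensor on its own dtype. For reference, a hedged sketch of the bias combination in plain C++, assuming the combined layout is simply Q, K, V stacked along the leading axis (shape [3, num_head, dim_head]):

#include <vector>

// Sketch only: QKVBiasProcess-style packing of three per-head bias vectors
// into one buffer of numel 3 * num_head * dim_head.
template <typename T>
std::vector<T> CombineQKVBias(const std::vector<T>& bq,
                              const std::vector<T>& bk,
                              const std::vector<T>& bv) {
  std::vector<T> combined;
  combined.reserve(bq.size() + bk.size() + bv.size());
  combined.insert(combined.end(), bq.begin(), bq.end());
  combined.insert(combined.end(), bk.begin(), bk.end());
  combined.insert(combined.end(), bv.begin(), bv.end());
  return combined;
}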
+inline void QKVWeightsBiasProcess(phi::DenseTensor* wq_tensor, + phi::DenseTensor* wk_tensor, + phi::DenseTensor* wv_tensor, + phi::DenseTensor* bq_tensor, + phi::DenseTensor* bk_tensor, + phi::DenseTensor* bv_tensor, + const int num_head, + const int dim_head, + const int dim_embed) { + switch (wq_tensor->dtype()) { + case paddle::experimental::DataType::FLOAT16: + QKVWeightsProcess( + wq_tensor, wk_tensor, wv_tensor, num_head, dim_head, dim_embed); + break; + case paddle::experimental::DataType::FLOAT32: + QKVWeightsProcess( + wq_tensor, wk_tensor, wv_tensor, num_head, dim_head, dim_embed); + break; + case paddle::experimental::DataType::INT8: + QKVWeightsProcess( + wq_tensor, wk_tensor, wv_tensor, num_head, dim_head, dim_embed); + break; + default: + PADDLE_THROW(platform::errors::Unavailable( + "fused_multi_transformer not supported weight dtype. " + "we now only support fp32/fp16/int8.")); + break; + } + switch (bq_tensor->dtype()) { + case paddle::experimental::DataType::FLOAT16: + QKVBiasProcess( + bq_tensor, bk_tensor, bv_tensor, num_head, dim_head, dim_embed); + break; + case paddle::experimental::DataType::FLOAT32: + QKVBiasProcess( + bq_tensor, bk_tensor, bv_tensor, num_head, dim_head, dim_embed); + break; + default: + PADDLE_THROW(platform::errors::Unavailable( + "fused_multi_transformer not supported bias dtype. " + "we now only support fp32/fp16.")); + break; + } +} + template inline void QKVWeightsProcessFuseQKV(phi::DenseTensor* qkv_w_tensor, - phi::DenseTensor* qkv_b_tensor, const int num_head, const int dim_head, const int dim_embed) { - auto* qkv_w_data = qkv_w_tensor->mutable_data(platform::CPUPlace()); + auto* qkv_w_data = qkv_w_tensor->data(); auto transpose_w_dims = phi::make_ddim({3, num_head, dim_head, dim_embed}); phi::DenseTensor tmp_transpose_w_tensor; @@ -1120,8 +1171,14 @@ inline void QKVWeightsProcessFuseQKV(phi::DenseTensor* qkv_w_tensor, memcpy(new_transpose_w_data, tmp_transpose_w_data, sizeof(T) * qkv_w_tensor->numel()); +} - auto* qkv_b_data = qkv_b_tensor->mutable_data(platform::CPUPlace()); +template +inline void QKVBiasProcessFuseQKV(phi::DenseTensor* qkv_b_tensor, + const int num_head, + const int dim_head, + const int dim_embed) { + auto* qkv_b_data = qkv_b_tensor->data(); auto transpose_b_dims = phi::make_ddim({3, num_head, dim_head}); phi::DenseTensor tmp_transpose_b_tensor; @@ -1148,11 +1205,86 @@ inline void QKVWeightsProcessFuseQKV(phi::DenseTensor* qkv_w_tensor, sizeof(T) * qkv_b_tensor->numel()); } +inline void QKVWeightsBiasProcessFuseQKV(phi::DenseTensor* qkv_w_tensor, + phi::DenseTensor* qkv_b_tensor, + const int num_head, + const int dim_head, + const int dim_embed) { + switch (qkv_w_tensor->dtype()) { + case paddle::experimental::DataType::FLOAT16: + QKVWeightsProcessFuseQKV( + qkv_w_tensor, num_head, dim_head, dim_embed); + break; + case paddle::experimental::DataType::FLOAT32: + QKVWeightsProcessFuseQKV( + qkv_w_tensor, num_head, dim_head, dim_embed); + break; + case paddle::experimental::DataType::INT8: + QKVWeightsProcessFuseQKV( + qkv_w_tensor, num_head, dim_head, dim_embed); + break; + default: + PADDLE_THROW(platform::errors::Unavailable( + "fused_multi_transformer not supported weight dtype. 
" + "we now only support fp32/fp16/int8.")); + break; + } + switch (qkv_b_tensor->dtype()) { + case paddle::experimental::DataType::FLOAT16: + QKVBiasProcessFuseQKV( + qkv_b_tensor, num_head, dim_head, dim_embed); + break; + case paddle::experimental::DataType::FLOAT32: + QKVBiasProcessFuseQKV(qkv_b_tensor, num_head, dim_head, dim_embed); + break; + default: + PADDLE_THROW(platform::errors::Unavailable( + "fused_multi_transformer not supported bias dtype. " + "we now only support fp32/fp16.")); + break; + } +} + +// Just use for fused_multi_transformer_int8 +inline void TransposeWeights(phi::DenseTensor* weight_tensor) { + int m = weight_tensor->dims()[0]; + int n = weight_tensor->dims()[1]; + phi::DenseTensor tmp_weight_tensor; + auto tmp_weight_data = + tmp_weight_tensor.mutable_data({n, m}, platform::CPUPlace()); + auto weight_data = weight_tensor->data(); + for (int i = 0; i < m; ++i) { + for (int j = 0; j < n; ++j) { + int in_idx = i * n + j; + int out_idx = j * m + i; + tmp_weight_data[out_idx] = weight_data[in_idx]; + } + } + weight_tensor->Resize({n, m}); + auto new_weight_data = + weight_tensor->mutable_data(platform::CPUPlace()); + memcpy(new_weight_data, tmp_weight_data, sizeof(int8_t) * m * n); +} + +inline Node* CreatePersistableVarNode(Graph* graph, const std::string& name) { + auto var_desc = VarDesc(name); + var_desc.SetDataType(framework::proto::VarType::FP32); + var_desc.SetPersistable(true); + auto node = graph->CreateVarNode(&var_desc); + return node; +} + int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope) const { GraphPatternDetector gpd; auto* pattern = gpd.mutable_pattern(); + bool enable_int8 = graph->Get("enable_int8"); + if (enable_int8) { + VLOG(3) << "FusedMultiTransformerEncoderPass with int8"; + } else { + VLOG(3) << "FusedMultiTransformerEncoderPass with fp"; + } // Create pattern. 
patterns::FusedMultiTransformerEncoderPattern fused_multi_transformer_pattern( @@ -1166,6 +1298,7 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, Node* layer_norm_bias, Node* layer_norm_mean, Node* layer_norm_variance, + Node* matmul0, Node* matmul0_w, Node* matmul1_w, Node* matmul2_w, @@ -1176,6 +1309,7 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, Node* transpose2_2_out, Node* eltadd_qk_b, Node* reshape2_0, + Node* matmul_linear, Node* matmul_linear_w, Node* eltadd_linear_b, Node* while0, @@ -1184,7 +1318,9 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, Node* ffn_layer_norm_bias, Node* ffn_layer_norm_mean, Node* ffn_layer_norm_variance, + Node* ffn_matmul0, Node* ffn_matmul0_w, + Node* ffn_matmul1, Node* ffn_matmul1_w, Node* ffn_eltadd0_b, Node* ffn_eltadd1_b, @@ -1196,7 +1332,14 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, int dim_head = PADDLE_GET_CONST(std::vector, reshape_desc->GetAttr("shape")) .at(3); - int dim_embed = num_head * dim_head; + auto* layer_norm_bias_tensor = + scope->FindVar(layer_norm_bias->Name())->GetMutable(); + int dim_embed = layer_norm_bias_tensor->dims()[0]; + + auto* matmul0_op = matmul0->Op(); + auto* matmul_linear_op = matmul_linear->Op(); + auto* ffn_matmul_0_op = ffn_matmul0->Op(); + auto* ffn_matmul_1_op = ffn_matmul1->Op(); // Calc index of transformer layer by LayerNorm Scale name // This calculation assumes: @@ -1221,30 +1364,27 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, auto* bv_tensor = scope->FindVar(eltadd2_b->Name())->GetMutable(); - if (wq_tensor->dtype() == phi::DataType::FLOAT32) { - QKVWeightsProcess(wq_tensor, - wk_tensor, - wv_tensor, - bq_tensor, - bk_tensor, - bv_tensor, - num_head, - dim_head, - dim_embed); - } else if (wq_tensor->dtype() == phi::DataType::FLOAT16) { - QKVWeightsProcess(wq_tensor, - wk_tensor, - wv_tensor, - bq_tensor, - bk_tensor, - bv_tensor, - num_head, - dim_head, - dim_embed); - } else { - PADDLE_THROW(platform::errors::Unavailable( - "fused_multi_transformer not supported weight dtype. " - "we now only support fp32 and fp16.")); + QKVWeightsBiasProcess(wq_tensor, + wk_tensor, + wv_tensor, + bq_tensor, + bk_tensor, + bv_tensor, + num_head, + dim_head, + dim_embed); + + if (enable_int8) { + auto* out_linear_w_tensor = scope->FindVar(matmul_linear_w->Name()) + ->GetMutable(); + auto* ffn0_w_tensor = + scope->FindVar(ffn_matmul0_w->Name())->GetMutable(); + auto* ffn1_w_tensor = + scope->FindVar(ffn_matmul1_w->Name())->GetMutable(); + + TransposeWeights(out_linear_w_tensor); + TransposeWeights(ffn0_w_tensor); + TransposeWeights(ffn1_w_tensor); } // reuse the mul0_w and eltadd_0_b nodes for the combined nodes. @@ -1261,7 +1401,9 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, // create fused_multi_transformer OpDesc fused_multi_transformer_op_desc(layer_norm->Op()->Block()); - fused_multi_transformer_op_desc.SetType("fused_multi_transformer"); + fused_multi_transformer_op_desc.SetType(enable_int8 + ? "fused_multi_transformer_int8" + : "fused_multi_transformer"); // 1. 
Input setting fused_multi_transformer_op_desc.SetInput("X", {input0->Name()}); @@ -1281,7 +1423,7 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, VarDesc cache_kv_desc("cache_kv" + std::to_string(layer_idx)); // FIXME: only support max_seq_len <= 1024 cache_kv_desc.SetDataType( - framework::TransToProtoVarType(wq_tensor->dtype())); + framework::TransToProtoVarType(bq_tensor->dtype())); cache_kv_desc.SetPersistable(false); auto* cache_kv = graph->CreateVarNode(&cache_kv_desc); @@ -1296,7 +1438,7 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, fill_const_op_desc.SetAttr("value", 0); fill_const_op_desc.SetAttr( "dtype", - static_cast(framework::TransToProtoVarType(wq_tensor->dtype()))); + static_cast(framework::TransToProtoVarType(bq_tensor->dtype()))); auto* fill_const_op = graph->CreateOpNode(&fill_const_op_desc); fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv->Name()}); @@ -1333,8 +1475,123 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, fused_multi_transformer_op_desc.SetAttr("is_test", true); fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f); + // Quantization attribute/Input + if (enable_int8) { + // Set input scale + std::string qkv_input_name = matmul0_op->Input("X")[0]; + auto qkv_in_scale = PADDLE_GET_CONST( + float, matmul0_op->GetAttr("Input_scale_" + qkv_input_name)); + std::string out_linear_input_name = matmul_linear_op->Input("X")[0]; + auto out_linear_in_scale = PADDLE_GET_CONST( + float, + matmul_linear_op->GetAttr("Input_scale_" + out_linear_input_name)); + std::string ffn0_input_name = ffn_matmul_0_op->Input("X")[0]; + auto ffn0_in_scale = PADDLE_GET_CONST( + float, ffn_matmul_0_op->GetAttr("Input_scale_" + ffn0_input_name)); + std::string ffn1_input_name = ffn_matmul_1_op->Input("X")[0]; + auto ffn1_in_scale = PADDLE_GET_CONST( + float, ffn_matmul_1_op->GetAttr("Input_scale_" + ffn1_input_name)); + + // Calc outscale and Set them + auto qkv_weight_scale = + PADDLE_GET_CONST(float, matmul0_op->GetAttr("weight_scale")); + auto out_weight_scale = + PADDLE_GET_CONST(float, matmul_linear_op->GetAttr("weight_scale")); + auto ffn0_weight_scale = + PADDLE_GET_CONST(float, ffn_matmul_0_op->GetAttr("weight_scale")); + auto ffn1_weight_scale = + PADDLE_GET_CONST(float, ffn_matmul_1_op->GetAttr("weight_scale")); + + auto qkv_out_scales = std::vector( + 3 * dim_embed, (qkv_weight_scale / 127.0f) * (qkv_in_scale / 127.0f)); + auto out_out_scales = std::vector( + dim_embed, + (out_weight_scale / 127.0f) * (out_linear_in_scale / 127.0f)); + auto ffn0_out_scales = std::vector( + 4 * dim_embed, + (ffn0_weight_scale / 127.0f) * (ffn0_in_scale / 127.0f)); + auto ffn1_out_scales = std::vector( + dim_embed, (ffn1_weight_scale / 127.0f) * (ffn1_in_scale / 127.0f)); + + // Inverse input scale + qkv_in_scale = 1.0f / qkv_in_scale; + out_linear_in_scale = 1.0f / out_linear_in_scale; + ffn0_in_scale = 1.0f / ffn0_in_scale; + ffn1_in_scale = 1.0f / ffn1_in_scale; + + fused_multi_transformer_op_desc.SetAttr("qkv_in_scale", + std::vector{qkv_in_scale}); + fused_multi_transformer_op_desc.SetAttr( + "out_linear_in_scale", std::vector{out_linear_in_scale}); + fused_multi_transformer_op_desc.SetAttr( + "ffn1_in_scale", std::vector{ffn0_in_scale}); + fused_multi_transformer_op_desc.SetAttr( + "ffn2_in_scale", std::vector{ffn1_in_scale}); + + auto qkv_out_scale_var = scope->Var(matmul0_w->Name() + "_out_scale"); + auto out_out_scale_var = + scope->Var(matmul_linear_w->Name() + "_out_scale"); + auto ffn0_out_scale_var = + 
scope->Var(ffn_matmul0_w->Name() + "_out_scale"); + auto ffn1_out_scale_var = + scope->Var(ffn_matmul1_w->Name() + "_out_scale"); + + auto qkv_out_scale_data = + qkv_out_scale_var->GetMutable() + ->mutable_data({3 * dim_embed}, platform::CPUPlace()); + memcpy(qkv_out_scale_data, + qkv_out_scales.data(), + qkv_out_scales.size() * sizeof(float)); + fused_multi_transformer_op_desc.SetInput( + "QKVOutScale", {matmul0_w->Name() + "_out_scale"}); + + auto out_out_scale_data = + out_out_scale_var->GetMutable() + ->mutable_data({dim_embed}, platform::CPUPlace()); + memcpy(out_out_scale_data, + out_out_scales.data(), + out_out_scales.size() * sizeof(float)); + fused_multi_transformer_op_desc.SetInput( + "OutLinearOutScale", {matmul_linear_w->Name() + "_out_scale"}); + + auto ffn0_out_scale_data = + ffn0_out_scale_var->GetMutable() + ->mutable_data({4 * dim_embed}, platform::CPUPlace()); + memcpy(ffn0_out_scale_data, + ffn0_out_scales.data(), + ffn0_out_scales.size() * sizeof(float)); + fused_multi_transformer_op_desc.SetInput( + "FFN1OutScale", {ffn_matmul0_w->Name() + "_out_scale"}); + + auto ffn1_out_scale_data = + ffn1_out_scale_var->GetMutable() + ->mutable_data({dim_embed}, platform::CPUPlace()); + memcpy(ffn1_out_scale_data, + ffn1_out_scales.data(), + ffn1_out_scales.size() * sizeof(float)); + fused_multi_transformer_op_desc.SetInput( + "FFN2OutScale", {ffn_matmul1_w->Name() + "_out_scale"}); + } + auto* fused_multi_transformer = graph->CreateOpNode(&fused_multi_transformer_op_desc); + + if (enable_int8) { + auto qkv_out_scale_node = + CreatePersistableVarNode(graph, matmul0_w->Name() + "_out_scale"); + auto out_out_scale_node = CreatePersistableVarNode( + graph, matmul_linear_w->Name() + "_out_scale"); + auto ffn0_out_scale_node = + CreatePersistableVarNode(graph, ffn_matmul0_w->Name() + "_out_scale"); + auto ffn1_out_scale_node = + CreatePersistableVarNode(graph, ffn_matmul1_w->Name() + "_out_scale"); + + IR_NODE_LINK_TO(qkv_out_scale_node, fused_multi_transformer); + IR_NODE_LINK_TO(out_out_scale_node, fused_multi_transformer); + IR_NODE_LINK_TO(ffn0_out_scale_node, fused_multi_transformer); + IR_NODE_LINK_TO(ffn1_out_scale_node, fused_multi_transformer); + } + IR_NODE_LINK_TO(input0, fused_multi_transformer); IR_NODE_LINK_TO(layer_norm_scale, fused_multi_transformer); IR_NODE_LINK_TO(layer_norm_bias, fused_multi_transformer); @@ -1622,6 +1879,7 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, layer_norm_bias, layer_norm_mean, layer_norm_variance, + matmul0, matmul0_w, matmul1_w, matmul2_w, @@ -1632,6 +1890,7 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, transpose2_2_out, eltadd_qk_b, reshape2_0, + matmul_linear, matmul_linear_w, eltadd_linear_b, while0, @@ -1640,7 +1899,9 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, ffn_layer_norm_bias, ffn_layer_norm_mean, ffn_layer_norm_variance, + ffn_matmul0, ffn_matmul0_w, + ffn_matmul1, ffn_matmul1_w, ffn_eltadd0_b, ffn_eltadd1_b, @@ -1892,6 +2153,12 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( Graph* graph, const std::string& name_scope, Scope* scope) const { GraphPatternDetector gpd; auto* pattern = gpd.mutable_pattern(); + bool enable_int8 = graph->Get("enable_int8"); + if (enable_int8) { + VLOG(3) << "FusedMultiTransformerEncoderFuseQKVPass with int8"; + } else { + VLOG(3) << "FusedMultiTransformerEncoderFuseQKVPass with fp"; + } // Create pattern. 
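Aside on the int8 branch assembled above: the PTQ scales reduce to two quantities per GEMM, an inverted activation scale stored as an op attribute and a per-output-channel dequantization scale stored as a persistable *_out_scale tensor. The standalone sketch below restates that arithmetic; the function and parameter names are illustrative and not part of the patch, and it assumes a single layer-wise weight scale, as the pass currently does.

#include <utility>
#include <vector>

// Illustrative only: build the per-output-channel dequant scale applied to the
// int32 GEMM accumulator and the inverted activation scale the fused op consumes.
std::pair<float, std::vector<float>> MakeGemmScales(float activation_in_scale,
                                                    float weight_scale,
                                                    int out_channels) {
  // out_scale = (weight_scale / 127) * (in_scale / 127); multiplying the int32
  // accumulator by out_scale recovers the floating-point GEMM output.
  std::vector<float> out_scales(
      out_channels, (weight_scale / 127.0f) * (activation_in_scale / 127.0f));
  // The op attribute stores 1 / in_scale so quantization becomes a multiply.
  return {1.0f / activation_in_scale, std::move(out_scales)};
}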
patterns::FusedMultiTransformerEncoderFuseQKVPattern @@ -1905,12 +2172,14 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( Node* layer_norm_bias, Node* layer_norm_mean, Node* layer_norm_variance, + Node* matmul0, Node* matmul0_w, Node* eltadd0_b, Node* split0_k_out, Node* split0_v_out, Node* eltadd_qk_b, Node* reshape2_0, + Node* matmul_linear, Node* matmul_linear_w, Node* eltadd_linear_b, Node* while0, @@ -1919,7 +2188,9 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( Node* ffn_layer_norm_bias, Node* ffn_layer_norm_mean, Node* ffn_layer_norm_variance, + Node* ffn_matmul0, Node* ffn_matmul0_w, + Node* ffn_matmul1, Node* ffn_matmul1_w, Node* ffn_eltadd0_b, Node* ffn_eltadd1_b, @@ -1932,7 +2203,14 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( PADDLE_GET_CONST(std::vector, reshape_desc->GetAttr("shape")) .at(3) / 3; // 3 for qkv - int dim_embed = num_head * dim_head; + auto* layer_norm_bias_tensor = + scope->FindVar(layer_norm_bias->Name())->GetMutable(); + int dim_embed = layer_norm_bias_tensor->dims()[0]; + + auto* matmul0_op = matmul0->Op(); + auto* matmul_linear_op = matmul_linear->Op(); + auto* ffn_matmul_0_op = ffn_matmul0->Op(); + auto* ffn_matmul_1_op = ffn_matmul1->Op(); // Calc index of transformer layer by LayerNorm Scale name // This calculation assumes: @@ -1948,21 +2226,27 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( auto* qkv_b_tensor = scope->FindVar(eltadd0_b->Name())->GetMutable(); - if (qkv_w_tensor->dtype() == phi::DataType::FLOAT32) { - QKVWeightsProcessFuseQKV( - qkv_w_tensor, qkv_b_tensor, num_head, dim_head, dim_embed); - } else if (qkv_w_tensor->dtype() == phi::DataType::FLOAT16) { - QKVWeightsProcessFuseQKV( - qkv_w_tensor, qkv_b_tensor, num_head, dim_head, dim_embed); - } else { - PADDLE_THROW(platform::errors::Unavailable( - "fused_multi_transformer not supported weight dtype. " - "we now only support fp32 and fp16.")); + QKVWeightsBiasProcessFuseQKV( + qkv_w_tensor, qkv_b_tensor, num_head, dim_head, dim_embed); + + if (enable_int8) { + auto* out_linear_w_tensor = scope->FindVar(matmul_linear_w->Name()) + ->GetMutable(); + auto* ffn0_w_tensor = + scope->FindVar(ffn_matmul0_w->Name())->GetMutable(); + auto* ffn1_w_tensor = + scope->FindVar(ffn_matmul1_w->Name())->GetMutable(); + + TransposeWeights(out_linear_w_tensor); + TransposeWeights(ffn0_w_tensor); + TransposeWeights(ffn1_w_tensor); } // create fused_multi_transformer OpDesc fused_multi_transformer_op_desc(layer_norm->Op()->Block()); - fused_multi_transformer_op_desc.SetType("fused_multi_transformer"); + fused_multi_transformer_op_desc.SetType(enable_int8 + ? "fused_multi_transformer_int8" + : "fused_multi_transformer"); // 1. 
Input setting
  fused_multi_transformer_op_desc.SetInput("X", {input0->Name()});
@@ -1982,7 +2266,7 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
  VarDesc cache_kv_desc("cache_kv" + std::to_string(layer_idx));
  // FIXME: only support max_seq_len <= 1024
  cache_kv_desc.SetDataType(
-     framework::TransToProtoVarType(qkv_w_tensor->dtype()));
+     framework::TransToProtoVarType(qkv_b_tensor->dtype()));
  cache_kv_desc.SetPersistable(false);
  auto* cache_kv = graph->CreateVarNode(&cache_kv_desc);
@@ -1997,7 +2281,7 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
  fill_const_op_desc.SetAttr("value", 0);
  fill_const_op_desc.SetAttr("dtype",
                             static_cast(framework::TransToProtoVarType(
-                                 qkv_w_tensor->dtype())));
+                                 qkv_b_tensor->dtype())));
  auto* fill_const_op = graph->CreateOpNode(&fill_const_op_desc);
  fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv->Name()});
@@ -2035,8 +2319,125 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
  fused_multi_transformer_op_desc.SetAttr("is_test", true);
  fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f);
+ // Quantization attribute/Input
+ if (enable_int8) {
+   // Set input scale
+   std::string qkv_input_name = matmul0_op->Input("X")[0];
+   auto qkv_in_scale = PADDLE_GET_CONST(
+       float, matmul0_op->GetAttr("Input_scale_" + qkv_input_name));
+   std::string out_linear_input_name = matmul_linear_op->Input("X")[0];
+   auto out_linear_in_scale = PADDLE_GET_CONST(
+       float,
+       matmul_linear_op->GetAttr("Input_scale_" + out_linear_input_name));
+   std::string ffn0_input_name = ffn_matmul_0_op->Input("X")[0];
+   auto ffn0_in_scale = PADDLE_GET_CONST(
+       float, ffn_matmul_0_op->GetAttr("Input_scale_" + ffn0_input_name));
+   std::string ffn1_input_name = ffn_matmul_1_op->Input("X")[0];
+   auto ffn1_in_scale = PADDLE_GET_CONST(
+       float, ffn_matmul_1_op->GetAttr("Input_scale_" + ffn1_input_name));
+
+   // Calc out scales and set them
+   // TODO(wufeisheng): Currently only layer-wise weight scale is matched;
+   // channel-wise weight scale should also be supported.
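All four activation scales read above follow one convention: the scale is stored on the consumer op under the attribute "Input_scale_" + <name of the variable feeding its X slot>. A minimal helper expressing that lookup is sketched below; the helper name is invented for illustration and it assumes the Node/OpDesc API and PADDLE_GET_CONST macro already used in this pass.

// Illustrative: fetch the activation scale recorded for a matmul's "X" input,
// e.g. the attribute "Input_scale_<layer_norm_out_name>".
static float GetActivationInScale(Node* matmul_node) {
  auto* op = matmul_node->Op();
  const std::string x_name = op->Input("X")[0];
  return PADDLE_GET_CONST(float, op->GetAttr("Input_scale_" + x_name));
}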
+ auto qkv_weight_scale = + PADDLE_GET_CONST(float, matmul0_op->GetAttr("weight_scale")); + auto out_weight_scale = + PADDLE_GET_CONST(float, matmul_linear_op->GetAttr("weight_scale")); + auto ffn0_weight_scale = + PADDLE_GET_CONST(float, ffn_matmul_0_op->GetAttr("weight_scale")); + auto ffn1_weight_scale = + PADDLE_GET_CONST(float, ffn_matmul_1_op->GetAttr("weight_scale")); + + auto qkv_out_scales = std::vector( + 3 * dim_embed, (qkv_weight_scale / 127.0f) * (qkv_in_scale / 127.0f)); + auto out_out_scales = std::vector( + dim_embed, + (out_weight_scale / 127.0f) * (out_linear_in_scale / 127.0f)); + auto ffn0_out_scales = std::vector( + 4 * dim_embed, + (ffn0_weight_scale / 127.0f) * (ffn0_in_scale / 127.0f)); + auto ffn1_out_scales = std::vector( + dim_embed, (ffn1_weight_scale / 127.0f) * (ffn1_in_scale / 127.0f)); + + // Inverse input scale + qkv_in_scale = 1.0f / qkv_in_scale; + out_linear_in_scale = 1.0f / out_linear_in_scale; + ffn0_in_scale = 1.0f / ffn0_in_scale; + ffn1_in_scale = 1.0f / ffn1_in_scale; + + fused_multi_transformer_op_desc.SetAttr("qkv_in_scale", + std::vector{qkv_in_scale}); + fused_multi_transformer_op_desc.SetAttr( + "out_linear_in_scale", std::vector{out_linear_in_scale}); + fused_multi_transformer_op_desc.SetAttr( + "ffn1_in_scale", std::vector{ffn0_in_scale}); + fused_multi_transformer_op_desc.SetAttr( + "ffn2_in_scale", std::vector{ffn1_in_scale}); + + auto qkv_out_scale_var = scope->Var(matmul0_w->Name() + "_out_scale"); + auto out_out_scale_var = + scope->Var(matmul_linear_w->Name() + "_out_scale"); + auto ffn0_out_scale_var = + scope->Var(ffn_matmul0_w->Name() + "_out_scale"); + auto ffn1_out_scale_var = + scope->Var(ffn_matmul1_w->Name() + "_out_scale"); + + auto qkv_out_scale_data = + qkv_out_scale_var->GetMutable() + ->mutable_data({3 * dim_embed}, platform::CPUPlace()); + memcpy(qkv_out_scale_data, + qkv_out_scales.data(), + qkv_out_scales.size() * sizeof(float)); + fused_multi_transformer_op_desc.SetInput( + "QKVOutScale", {matmul0_w->Name() + "_out_scale"}); + + auto out_out_scale_data = + out_out_scale_var->GetMutable() + ->mutable_data({dim_embed}, platform::CPUPlace()); + memcpy(out_out_scale_data, + out_out_scales.data(), + out_out_scales.size() * sizeof(float)); + fused_multi_transformer_op_desc.SetInput( + "OutLinearOutScale", {matmul_linear_w->Name() + "_out_scale"}); + + auto ffn0_out_scale_data = + ffn0_out_scale_var->GetMutable() + ->mutable_data({4 * dim_embed}, platform::CPUPlace()); + memcpy(ffn0_out_scale_data, + ffn0_out_scales.data(), + ffn0_out_scales.size() * sizeof(float)); + fused_multi_transformer_op_desc.SetInput( + "FFN1OutScale", {ffn_matmul0_w->Name() + "_out_scale"}); + + auto ffn1_out_scale_data = + ffn1_out_scale_var->GetMutable() + ->mutable_data({dim_embed}, platform::CPUPlace()); + memcpy(ffn1_out_scale_data, + ffn1_out_scales.data(), + ffn1_out_scales.size() * sizeof(float)); + fused_multi_transformer_op_desc.SetInput( + "FFN2OutScale", {ffn_matmul1_w->Name() + "_out_scale"}); + } + auto* fused_multi_transformer = graph->CreateOpNode(&fused_multi_transformer_op_desc); + + if (enable_int8) { + auto qkv_out_scale_node = + CreatePersistableVarNode(graph, matmul0_w->Name() + "_out_scale"); + auto out_out_scale_node = CreatePersistableVarNode( + graph, matmul_linear_w->Name() + "_out_scale"); + auto ffn0_out_scale_node = + CreatePersistableVarNode(graph, ffn_matmul0_w->Name() + "_out_scale"); + auto ffn1_out_scale_node = + CreatePersistableVarNode(graph, ffn_matmul1_w->Name() + "_out_scale"); + + 
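Each *_out_scale input in the block above is materialized the same way: a float tensor is created in the scope, filled on CPU, registered as an input of the fused op, and exposed to the graph through a persistable var node linked to that op. The condensed sketch below shows the repeated pattern; the function name is illustrative, and it assumes this pass's Scope/Graph/Node types, the IR_NODE_LINK_TO macro, and its CreatePersistableVarNode helper.

#include <cstring>
#include <string>
#include <vector>

// Illustrative: publish `scales` as the persistable tensor `<w_name>_out_scale`
// and wire it into `fused_op`. Assumes the namespaces and helpers of this pass.
void PublishOutScale(Scope* scope, Graph* graph, Node* fused_op,
                     const std::string& w_name,
                     const std::vector<float>& scales) {
  const std::string name = w_name + "_out_scale";
  const int n = static_cast<int>(scales.size());
  auto* tensor = scope->Var(name)->GetMutable<phi::DenseTensor>();
  float* data = tensor->mutable_data<float>({n}, platform::CPUPlace());
  std::memcpy(data, scales.data(), scales.size() * sizeof(float));
  IR_NODE_LINK_TO(CreatePersistableVarNode(graph, name), fused_op);
}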
IR_NODE_LINK_TO(qkv_out_scale_node, fused_multi_transformer); + IR_NODE_LINK_TO(out_out_scale_node, fused_multi_transformer); + IR_NODE_LINK_TO(ffn0_out_scale_node, fused_multi_transformer); + IR_NODE_LINK_TO(ffn1_out_scale_node, fused_multi_transformer); + } + IR_NODE_LINK_TO(input0, fused_multi_transformer); IR_NODE_LINK_TO(layer_norm_scale, fused_multi_transformer); IR_NODE_LINK_TO(layer_norm_bias, fused_multi_transformer); @@ -2290,12 +2691,14 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( layer_norm_bias, layer_norm_mean, layer_norm_variance, + matmul0, matmul0_w, eltadd0_b, split0_k_out, split0_v_out, eltadd_qk_b, reshape2_0, + matmul_linear, matmul_linear_w, eltadd_linear_b, while0, @@ -2304,7 +2707,9 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ffn_layer_norm_bias, ffn_layer_norm_mean, ffn_layer_norm_variance, + ffn_matmul0, ffn_matmul0_w, + ffn_matmul1, ffn_matmul1_w, ffn_eltadd0_b, ffn_eltadd1_b, @@ -2546,6 +2951,12 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( Graph* graph, const std::string& name_scope, Scope* scope) const { GraphPatternDetector gpd; auto* pattern = gpd.mutable_pattern(); + bool enable_int8 = graph->Get("enable_int8"); + if (enable_int8) { + VLOG(3) << "MultiDevicesFusedMultiTransformerEncoderFuseQKVPass with int8"; + } else { + VLOG(3) << "MultiDevicesFusedMultiTransformerEncoderFuseQKVPass with fp"; + } // Create pattern. patterns::MultiDevicesFusedMultiTransformerEncoderFuseQKVPattern @@ -2560,12 +2971,14 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( Node* layer_norm_mean, Node* layer_norm_variance, Node* c_identity, + Node* matmul0, Node* matmul0_w, Node* eltadd0_b, Node* split0_k_out, Node* split0_v_out, Node* eltadd_qk_b, Node* reshape2_0, + Node* matmul_linear, Node* matmul_linear_w, Node* eltadd_linear_b, Node* while0, @@ -2574,7 +2987,10 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( Node* ffn_layer_norm_bias, Node* ffn_layer_norm_mean, Node* ffn_layer_norm_variance, + Node* ffn_c_identity, + Node* ffn_matmul0, Node* ffn_matmul0_w, + Node* ffn_matmul1, Node* ffn_matmul1_w, Node* ffn_eltadd0_b, Node* ffn_eltadd1_b, @@ -2588,6 +3004,11 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( .at(3) / 3; // 3 for qkv + auto* matmul0_op = matmul0->Op(); + auto* matmul_linear_op = matmul_linear->Op(); + auto* ffn_matmul_0_op = ffn_matmul0->Op(); + auto* ffn_matmul_1_op = ffn_matmul1->Op(); + // Calc index of transformer layer by LayerNorm Scale name // This calculation assumes: // 1. no LayerNorm before all transformer layer @@ -2602,23 +3023,31 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( auto* qkv_b_tensor = scope->FindVar(eltadd0_b->Name())->GetMutable(); - int dim_embed = qkv_w_tensor->dims()[0]; + auto* layer_norm_bias_tensor = + scope->FindVar(layer_norm_bias->Name())->GetMutable(); + int dim_embed = layer_norm_bias_tensor->dims()[0]; - if (qkv_w_tensor->dtype() == phi::DataType::FLOAT32) { - QKVWeightsProcessFuseQKV( - qkv_w_tensor, qkv_b_tensor, num_head, dim_head, dim_embed); - } else if (qkv_w_tensor->dtype() == phi::DataType::FLOAT16) { - QKVWeightsProcessFuseQKV( - qkv_w_tensor, qkv_b_tensor, num_head, dim_head, dim_embed); - } else { - PADDLE_THROW(platform::errors::Unavailable( - "fused_multi_transformer not supported weight dtype. 
" - "we now only support fp32 and fp16.")); + QKVWeightsBiasProcessFuseQKV( + qkv_w_tensor, qkv_b_tensor, num_head, dim_head, dim_embed); + + if (enable_int8) { + auto* out_linear_w_tensor = scope->FindVar(matmul_linear_w->Name()) + ->GetMutable(); + auto* ffn0_w_tensor = + scope->FindVar(ffn_matmul0_w->Name())->GetMutable(); + auto* ffn1_w_tensor = + scope->FindVar(ffn_matmul1_w->Name())->GetMutable(); + + TransposeWeights(out_linear_w_tensor); + TransposeWeights(ffn0_w_tensor); + TransposeWeights(ffn1_w_tensor); } // create fused_multi_transformer OpDesc fused_multi_transformer_op_desc(layer_norm->Op()->Block()); - fused_multi_transformer_op_desc.SetType("fused_multi_transformer"); + fused_multi_transformer_op_desc.SetType(enable_int8 + ? "fused_multi_transformer_int8" + : "fused_multi_transformer"); // 1. Input setting fused_multi_transformer_op_desc.SetInput("X", {input0->Name()}); @@ -2638,7 +3067,7 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( VarDesc cache_kv_desc("cache_kv" + std::to_string(layer_idx)); // FIXME: only support max_seq_len <= 1024 cache_kv_desc.SetDataType( - framework::TransToProtoVarType(qkv_w_tensor->dtype())); + framework::TransToProtoVarType(qkv_b_tensor->dtype())); cache_kv_desc.SetPersistable(false); auto* cache_kv = graph->CreateVarNode(&cache_kv_desc); @@ -2653,7 +3082,7 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( fill_const_op_desc.SetAttr("value", 0); fill_const_op_desc.SetAttr("dtype", static_cast(framework::TransToProtoVarType( - qkv_w_tensor->dtype()))); + qkv_b_tensor->dtype()))); auto* fill_const_op = graph->CreateOpNode(&fill_const_op_desc); fused_multi_transformer_op_desc.SetInput("CacheKV", {cache_kv->Name()}); @@ -2696,8 +3125,129 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( fused_multi_transformer_op_desc.SetAttr("ring_id", c_identity_op->GetAttr("ring_id")); + // Quantization attribute/Input + if (enable_int8) { + // Set input scale + std::string matmul_input_scale_suffix = c_identity_op->Input("X")[0]; + auto qkv_in_scale = PADDLE_GET_CONST( + float, + c_identity_op->GetAttr("Input_scale_" + matmul_input_scale_suffix)); + + std::string out_linear_input_name = matmul_linear_op->Input("X")[0]; + auto out_linear_in_scale = PADDLE_GET_CONST( + float, + matmul_linear_op->GetAttr("Input_scale_" + out_linear_input_name)); + + auto* ffn_c_identity_op = ffn_c_identity->Op(); + std::string ffn_input_scale_suffix = ffn_c_identity_op->Input("X")[0]; + auto ffn0_in_scale = PADDLE_GET_CONST( + float, + ffn_c_identity_op->GetAttr("Input_scale_" + ffn_input_scale_suffix)); + + std::string ffn1_input_name = ffn_matmul_1_op->Input("X")[0]; + auto ffn1_in_scale = PADDLE_GET_CONST( + float, ffn_matmul_1_op->GetAttr("Input_scale_" + ffn1_input_name)); + + // Calc outscale and Set them + auto qkv_weight_scale = + PADDLE_GET_CONST(float, matmul0_op->GetAttr("weight_scale")); + auto out_weight_scale = + PADDLE_GET_CONST(float, matmul_linear_op->GetAttr("weight_scale")); + auto ffn0_weight_scale = + PADDLE_GET_CONST(float, ffn_matmul_0_op->GetAttr("weight_scale")); + auto ffn1_weight_scale = + PADDLE_GET_CONST(float, ffn_matmul_1_op->GetAttr("weight_scale")); + + auto qkv_out_scales = std::vector( + 3 * dim_embed, (qkv_weight_scale / 127.0f) * (qkv_in_scale / 127.0f)); + auto out_out_scales = std::vector( + dim_embed, + (out_weight_scale / 127.0f) * (out_linear_in_scale / 127.0f)); + auto ffn0_out_scales = std::vector( + 4 * dim_embed, + (ffn0_weight_scale / 127.0f) * (ffn0_in_scale / 
127.0f)); + auto ffn1_out_scales = std::vector( + dim_embed, (ffn1_weight_scale / 127.0f) * (ffn1_in_scale / 127.0f)); + + // Inverse input scale + qkv_in_scale = 1.0f / qkv_in_scale; + out_linear_in_scale = 1.0f / out_linear_in_scale; + ffn0_in_scale = 1.0f / ffn0_in_scale; + ffn1_in_scale = 1.0f / ffn1_in_scale; + + fused_multi_transformer_op_desc.SetAttr("qkv_in_scale", + std::vector{qkv_in_scale}); + fused_multi_transformer_op_desc.SetAttr( + "out_linear_in_scale", std::vector{out_linear_in_scale}); + fused_multi_transformer_op_desc.SetAttr( + "ffn1_in_scale", std::vector{ffn0_in_scale}); + fused_multi_transformer_op_desc.SetAttr( + "ffn2_in_scale", std::vector{ffn1_in_scale}); + + auto qkv_out_scale_var = scope->Var(matmul0_w->Name() + "_out_scale"); + auto out_out_scale_var = + scope->Var(matmul_linear_w->Name() + "_out_scale"); + auto ffn0_out_scale_var = + scope->Var(ffn_matmul0_w->Name() + "_out_scale"); + auto ffn1_out_scale_var = + scope->Var(ffn_matmul1_w->Name() + "_out_scale"); + + auto qkv_out_scale_data = + qkv_out_scale_var->GetMutable() + ->mutable_data({3 * dim_embed}, platform::CPUPlace()); + memcpy(qkv_out_scale_data, + qkv_out_scales.data(), + qkv_out_scales.size() * sizeof(float)); + fused_multi_transformer_op_desc.SetInput( + "QKVOutScale", {matmul0_w->Name() + "_out_scale"}); + + auto out_out_scale_data = + out_out_scale_var->GetMutable() + ->mutable_data({dim_embed}, platform::CPUPlace()); + memcpy(out_out_scale_data, + out_out_scales.data(), + out_out_scales.size() * sizeof(float)); + fused_multi_transformer_op_desc.SetInput( + "OutLinearOutScale", {matmul_linear_w->Name() + "_out_scale"}); + + auto ffn0_out_scale_data = + ffn0_out_scale_var->GetMutable() + ->mutable_data({4 * dim_embed}, platform::CPUPlace()); + memcpy(ffn0_out_scale_data, + ffn0_out_scales.data(), + ffn0_out_scales.size() * sizeof(float)); + fused_multi_transformer_op_desc.SetInput( + "FFN1OutScale", {ffn_matmul0_w->Name() + "_out_scale"}); + + auto ffn1_out_scale_data = + ffn1_out_scale_var->GetMutable() + ->mutable_data({dim_embed}, platform::CPUPlace()); + memcpy(ffn1_out_scale_data, + ffn1_out_scales.data(), + ffn1_out_scales.size() * sizeof(float)); + fused_multi_transformer_op_desc.SetInput( + "FFN2OutScale", {ffn_matmul1_w->Name() + "_out_scale"}); + } + auto* fused_multi_transformer = graph->CreateOpNode(&fused_multi_transformer_op_desc); + + if (enable_int8) { + auto qkv_out_scale_node = + CreatePersistableVarNode(graph, matmul0_w->Name() + "_out_scale"); + auto out_out_scale_node = CreatePersistableVarNode( + graph, matmul_linear_w->Name() + "_out_scale"); + auto ffn0_out_scale_node = + CreatePersistableVarNode(graph, ffn_matmul0_w->Name() + "_out_scale"); + auto ffn1_out_scale_node = + CreatePersistableVarNode(graph, ffn_matmul1_w->Name() + "_out_scale"); + + IR_NODE_LINK_TO(qkv_out_scale_node, fused_multi_transformer); + IR_NODE_LINK_TO(out_out_scale_node, fused_multi_transformer); + IR_NODE_LINK_TO(ffn0_out_scale_node, fused_multi_transformer); + IR_NODE_LINK_TO(ffn1_out_scale_node, fused_multi_transformer); + } + IR_NODE_LINK_TO(input0, fused_multi_transformer); IR_NODE_LINK_TO(layer_norm_scale, fused_multi_transformer); IR_NODE_LINK_TO(layer_norm_bias, fused_multi_transformer); @@ -2977,12 +3527,14 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( layer_norm_mean, layer_norm_variance, c_identity, + matmul0, matmul0_w, eltadd0_b, split0_k_out, split0_v_out, eltadd_qk_b, reshape2_0, + matmul_linear, matmul_linear_w, eltadd_linear_b, while0, @@ -2991,7 
+3543,10 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( ffn_layer_norm_bias, ffn_layer_norm_mean, ffn_layer_norm_variance, + ffn_c_identity, + ffn_matmul0, ffn_matmul0_w, + ffn_matmul1, ffn_matmul1_w, ffn_eltadd0_b, ffn_eltadd1_b, diff --git a/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass_tester.cc b/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass_tester.cc index 2e356d0dc1997..08f4dc06f58aa 100644 --- a/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass_tester.cc +++ b/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass_tester.cc @@ -188,6 +188,7 @@ TEST(FusedMultiTransformerEncoderPass, basic) { std::unique_ptr graph(new ir::Graph(layers.main_program())); graph->Set("__param_scope__", CreateParamScope()); + graph->Set("enable_int8", new bool(false)); auto pass = PassRegistry::Instance().Get("fused_multi_transformer_encoder_pass"); @@ -334,6 +335,7 @@ TEST(FusedMultiTransformerEncoderFuseQKVPass, basic) { layers.elementwise_add(attention_out, ffn_eltadd1_out); std::unique_ptr graph(new ir::Graph(layers.main_program())); + graph->Set("enable_int8", new bool(false)); graph->Set("__param_scope__", CreateParamScope()); auto pass = PassRegistry::Instance().Get( @@ -489,6 +491,7 @@ TEST(MultiDevicesFusedMultiTransformerEncoderFuseQKVPass, basic) { layers.elementwise_add(attention_out, ffn_eltadd1_out); std::unique_ptr graph(new ir::Graph(layers.main_program())); + graph->Set("enable_int8", new bool(false)); graph->Set("__param_scope__", CreateParamScope()); auto pass = PassRegistry::Instance().Get( diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 7f509d64b5c23..753c169f8f6d6 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -3175,6 +3175,73 @@ void patterns::DeleteWeightQuantDequantLinearOpPattern::operator()() { any_op2->LinksFrom({weight_dequantize_linear_op_out}); } +void patterns::DeleteWeightDequantLinearOpEncoderPattern::operator()() { + auto weight_dequantize_linear_op_x = + pattern->NewNode(weight_dequantize_linear_op_x_repr()) + ->AsInput() + ->assert_is_op_input("dequantize_linear", "X") + ->assert_is_persistable_var(); + + auto weight_dequantize_linear_op_scale = + pattern->NewNode(weight_dequantize_linear_op_scale_repr()) + ->AsInput() + ->assert_is_op_input("dequantize_linear", "Scale") + ->assert_is_persistable_var(); + + auto weight_dequantize_linear_op = + pattern->NewNode(weight_dequantize_linear_op_repr()) + ->assert_is_op("dequantize_linear"); + + auto weight_dequantize_linear_op_out = + pattern->NewNode(weight_dequantize_linear_op_out_repr()) + ->AsIntermediate() + ->assert_is_op_output("dequantize_linear", "Y"); + + auto any_op2 = pattern->NewNode(any_op2_repr())->assert_is_op()->AsOutput(); + + // while loop + auto *while0 = + pattern->NewNode(while0_repr())->assert_is_op("while")->AsOutput(); + while0->LinksFrom({weight_dequantize_linear_op_out}); + + weight_dequantize_linear_op + ->LinksFrom( + {weight_dequantize_linear_op_x, weight_dequantize_linear_op_scale}) + .LinksTo({weight_dequantize_linear_op_out}); + any_op2->LinksFrom({weight_dequantize_linear_op_out}); +} + +void patterns::DeleteWeightDequantLinearOpDecoderPattern::operator()() { + auto weight_dequantize_linear_op_x = + pattern->NewNode(weight_dequantize_linear_op_x_repr()) + ->AsInput() + ->assert_is_op_input("dequantize_linear", "X") + ->assert_is_persistable_var(); + + auto 
weight_dequantize_linear_op_scale = + pattern->NewNode(weight_dequantize_linear_op_scale_repr()) + ->AsInput() + ->assert_is_op_input("dequantize_linear", "Scale") + ->assert_is_persistable_var(); + + auto weight_dequantize_linear_op = + pattern->NewNode(weight_dequantize_linear_op_repr()) + ->assert_is_op("dequantize_linear"); + + auto weight_dequantize_linear_op_out = + pattern->NewNode(weight_dequantize_linear_op_out_repr()) + ->AsIntermediate() + ->assert_is_op_output("dequantize_linear", "Y"); + + auto any_op2 = pattern->NewNode(any_op2_repr())->assert_is_op()->AsOutput(); + + weight_dequantize_linear_op + ->LinksFrom( + {weight_dequantize_linear_op_x, weight_dequantize_linear_op_scale}) + .LinksTo({weight_dequantize_linear_op_out}); + any_op2->LinksFrom({weight_dequantize_linear_op_out}); +} + void patterns::DeleteQuantDequantLinearOpPattern::operator()() { auto quantize_linear_op_x = pattern->NewNode(quantize_linear_op_x_repr()) ->AsInput() diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index fdff82d30caaa..cb1b9266b1530 100755 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -1765,6 +1765,39 @@ struct DeleteWeightQuantDequantLinearOpPattern : public PatternBase { PATTERN_DECL_NODE(any_op2); }; +struct DeleteWeightDequantLinearOpEncoderPattern : public PatternBase { + DeleteWeightDequantLinearOpEncoderPattern(PDPattern* pattern, + const std::string& name_scope) + : PatternBase(pattern, + name_scope, + "delete_weight_quant_dequant_linear_op_pattern") {} + + void operator()(); + + PATTERN_DECL_NODE(weight_dequantize_linear_op_x); + PATTERN_DECL_NODE(weight_dequantize_linear_op_scale); + PATTERN_DECL_NODE(while0); + PATTERN_DECL_NODE(weight_dequantize_linear_op); + PATTERN_DECL_NODE(weight_dequantize_linear_op_out); + PATTERN_DECL_NODE(any_op2); +}; + +struct DeleteWeightDequantLinearOpDecoderPattern : public PatternBase { + DeleteWeightDequantLinearOpDecoderPattern(PDPattern* pattern, + const std::string& name_scope) + : PatternBase(pattern, + name_scope, + "delete_weight_quant_dequant_linear_op_pattern") {} + + void operator()(); + + PATTERN_DECL_NODE(weight_dequantize_linear_op_x); + PATTERN_DECL_NODE(weight_dequantize_linear_op_scale); + PATTERN_DECL_NODE(weight_dequantize_linear_op); + PATTERN_DECL_NODE(weight_dequantize_linear_op_out); + PATTERN_DECL_NODE(any_op2); +}; + struct DeleteQuantDequantLinearOpPattern : public PatternBase { DeleteQuantDequantLinearOpPattern(PDPattern* pattern, const std::string& name_scope) diff --git a/paddle/fluid/framework/ir/pass.cc b/paddle/fluid/framework/ir/pass.cc index 4ad93183996fa..5c05a9f27c1fc 100644 --- a/paddle/fluid/framework/ir/pass.cc +++ b/paddle/fluid/framework/ir/pass.cc @@ -46,7 +46,10 @@ static const std::vector support_subgraph_passes = { "fused_multi_transformer_decoder_fuse_qkv_pass", "multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass", "multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass", - "fuse_multi_transformer_layer_pass"}; + "fuse_multi_transformer_layer_pass", + "delete_quant_dequant_linear_op_pass", + "delete_weight_dequant_linear_op_encoder_pass", + "delete_weight_dequant_linear_op_decoder_pass"}; Graph *Pass::Apply(Graph *graph) const { VLOG(10) << "start to apply pass " << Type() << " to graph"; diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index c964ce7e4d0d2..f02776d00f8c7 100644 --- 
a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -165,6 +165,9 @@ const std::vector kLiteSubgraphPasses({ // running errors. After fusion operator supports low precision, delete this. const std::vector kGpuLowerPrecisionPasses{ "simplify_with_basic_ops_pass", + "delete_quant_dequant_linear_op_pass", + "delete_weight_dequant_linear_op_encoder_pass", + "delete_weight_dequant_linear_op_decoder_pass", "map_depthwise_conv_to_conv_pass", "conv_bn_fuse_pass", "conv_eltwiseadd_bn_fuse_pass", @@ -203,9 +206,12 @@ const std::vector kTrtLowerPrecisionPasses{ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) { passes_.assign({ // "identity_scale_op_clean_pass", // - "is_test_pass", // - "simplify_with_basic_ops_pass", // - "map_depthwise_conv_to_conv_pass", + "is_test_pass", // + "simplify_with_basic_ops_pass", // + "delete_quant_dequant_linear_op_pass", // + "delete_weight_dequant_linear_op_encoder_pass", // + "delete_weight_dequant_linear_op_decoder_pass", // + "map_depthwise_conv_to_conv_pass", // "conv_bn_fuse_pass", // "conv_eltwiseadd_bn_fuse_pass", // "embedding_eltwise_layernorm_fuse_pass", // diff --git a/paddle/fluid/operators/fused/attn_gemm_int8.h b/paddle/fluid/operators/fused/attn_gemm_int8.h index 98a45deac3c8d..cdbd5b2e0b821 100644 --- a/paddle/fluid/operators/fused/attn_gemm_int8.h +++ b/paddle/fluid/operators/fused/attn_gemm_int8.h @@ -20,6 +20,7 @@ limitations under the License. */ #include "paddle/fluid/operators/fused/quant_dequant_kernel.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/fluid/platform/float16.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/kernels/funcs/broadcast_function.h" #include "paddle/phi/kernels/funcs/elementwise_functor.h" @@ -27,6 +28,7 @@ namespace paddle { namespace operators { using Tensor = phi::DenseTensor; +using phi::backends::gpu::GpuLaunchConfig; template class AttnMatmulINT8 { @@ -36,6 +38,9 @@ class AttnMatmulINT8 { : dev_ctx_(dev_ctx), m_(m), n_(n), k_(k), compute_bias_(compute_bias) { auto helper = std::make_shared(m, k, n); helpers_.emplace_back(helper); + gpu_config_ = std::make_unique( + phi::backends::gpu::GetGpuLaunchConfig1D( + dev_ctx, m * n, DequantKernelVecSize)); } ~AttnMatmulINT8() {} @@ -50,7 +55,6 @@ class AttnMatmulINT8 { phi::DenseTensor* bias_out, const float quant_in_scale, const phi::DenseTensor* dequant_out_scale, - const int quant_out_scale_offset, const int quant_round_type = 1, const float quant_max_bound = 127.0, const float quant_min_bound = -127.0) { @@ -74,9 +78,9 @@ class AttnMatmulINT8 { m_, n_, dev_ctx_.stream(), + gpu_config_.get(), quant_in_scale, - dequant_out_scale->data(), - quant_out_scale_offset); + dequant_out_scale->data()); if (compute_bias_) { // bias_out = output + bias @@ -99,11 +103,13 @@ class AttnMatmulINT8 { phi::DenseTensor* input, const phi::DenseTensor* bias, phi::DenseTensor* output, - phi::DenseTensor* bias_out) { + phi::DenseTensor* bias_out, + void* workspace = nullptr) { helpers_[0]->GEMM(input->data(), weight->data(), output->data(), - dev_ctx_.stream()); + dev_ctx_.stream(), + workspace); } // This function is used to execute GEMM, with input and output's types are @@ -115,8 +121,7 @@ class AttnMatmulINT8 { phi::DenseTensor* output, phi::DenseTensor* output_tmp, phi::DenseTensor* bias_out, - const phi::DenseTensor* dequant_out_scale, - const int quant_out_scale_offset) { + const phi::DenseTensor* dequant_out_scale) { helpers_[0]->GEMM(input->data(), weight->data(), 
output_tmp->data(), @@ -127,9 +132,9 @@ class AttnMatmulINT8 { m_, n_, dev_ctx_.stream(), + gpu_config_.get(), quant_in_scale, - dequant_out_scale->data(), - quant_out_scale_offset); + dequant_out_scale->data()); if (compute_bias_) { // bias_out = output + bias @@ -183,6 +188,7 @@ class AttnMatmulINT8 { int compute_bias_; std::vector> helpers_; + std::unique_ptr gpu_config_; }; } // namespace operators diff --git a/paddle/fluid/operators/fused/cublaslt.h b/paddle/fluid/operators/fused/cublaslt.h index b9cc6b56f13ee..e9728c58b55dc 100644 --- a/paddle/fluid/operators/fused/cublaslt.h +++ b/paddle/fluid/operators/fused/cublaslt.h @@ -24,6 +24,20 @@ namespace dyl = paddle::platform::dynload; namespace paddle { namespace operators { + +struct CublasLtAlgoParam { + int algoId; + int swizzle; + int customOption; + int tile; + int splitK_val; + int reductionScheme; + int stages; + size_t workspace_size; +}; + +const std::map, CublasLtAlgoParam> AlgoParamCache{}; + class CublasLtHelper { public: CublasLtHelper(int m, int k, int n) @@ -99,38 +113,34 @@ class CublasLtHelper { "cublasLtMatrixLayoutCreate execution error" "refer https://docs.nvidia.com/cuda/cublas/index.html to get more " "information")); - } - ~CublasLtHelper() { - if (handle_) dyl::cublasLtDestroy(handle_); - if (matmul_desc_) dyl::cublasLtMatmulDescDestroy(matmul_desc_); - if (A_desc_) dyl::cublasLtMatrixLayoutDestroy(A_desc_); - if (B_desc_) dyl::cublasLtMatrixLayoutDestroy(B_desc_); - if (C_desc_) dyl::cublasLtMatrixLayoutDestroy(C_desc_); - } - void GEMM(int8_t* A_dev, - const int8_t* B_dev, - int32_t* C_dev, - cudaStream_t stream) { - cublasStatus_t status; +#if CUDA_VERSION >= 11020 -#if __CUDA_ARCH__ >= 800 && CUDA_VERSION >= 11020 - cublasLtMatmulAlgo_t algo; int algoId = 21; int swizzle = 0; int customOption = 0; int tile = 15; int splitK_val = 0; int reductionScheme = 0; -#if CUDA_VERSION >= 11000 int stages = 23; -#endif - -#if CUBLAS_VER_MAJOR < 11 - cudaDataType_t cudaComputeType = CUDA_R_32I; -#else - cublasComputeType_t cudaComputeType = CUBLAS_COMPUTE_32I; -#endif + workspace_size_ = 0; + if (m >= 128) { + tile = 20; + stages = 17; + } + + std::tuple key(m_, k_, n_); + if (AlgoParamCache.count(key) != 0) { + auto value = AlgoParamCache.at(key); + algoId = value.algoId; + swizzle = value.swizzle; + customOption = value.customOption; + tile = value.tile; + splitK_val = value.splitK_val; + reductionScheme = value.reductionScheme; + stages = value.stages; + workspace_size_ = value.workspace_size; + } dyl::cublasLtMatmulAlgoInit(handle_, cudaComputeType, @@ -140,30 +150,43 @@ class CublasLtHelper { CUDA_R_32I, CUDA_R_32I, algoId, - &algo); + &algo_); dyl::cublasLtMatmulAlgoConfigSetAttribute( - &algo, + &algo_, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &(customOption), sizeof(customOption)); dyl::cublasLtMatmulAlgoConfigSetAttribute( - &algo, CUBLASLT_ALGO_CONFIG_TILE_ID, &(tile), sizeof(tile)); - dyl::cublasLtMatmulAlgoConfigSetAttribute(&algo, + &algo_, CUBLASLT_ALGO_CONFIG_TILE_ID, &(tile), sizeof(tile)); + dyl::cublasLtMatmulAlgoConfigSetAttribute(&algo_, CUBLASLT_ALGO_CONFIG_SPLITK_NUM, &(splitK_val), sizeof(splitK_val)); dyl::cublasLtMatmulAlgoConfigSetAttribute( - &algo, CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, &(swizzle), sizeof(swizzle)); + &algo_, + CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, + &(swizzle), + sizeof(swizzle)); dyl::cublasLtMatmulAlgoConfigSetAttribute( - &algo, + &algo_, CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, &(reductionScheme), sizeof(int)); #if CUDA_VERSION >= 11000 dyl::cublasLtMatmulAlgoConfigSetAttribute( - 
&algo, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(stages), sizeof(stages)); + &algo_, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(stages), sizeof(stages)); #endif #endif + } + ~CublasLtHelper() {} + + void GEMM(int8_t* A_dev, + const int8_t* B_dev, + int32_t* C_dev, + cudaStream_t stream, + void* workspace = nullptr) { + cublasStatus_t status; + status = dyl::cublasLtMatmul(handle_, matmul_desc_, &alpha_, @@ -176,13 +199,15 @@ class CublasLtHelper { C_desc_, C_dev, C_desc_, -#if __CUDA_ARCH__ >= 800 && CUDA_VERSION >= 11020 - &algo, +#if CUDA_VERSION >= 11020 + &algo_, + workspace, + workspace_size_, #else nullptr, -#endif nullptr, 0, +#endif stream); PADDLE_ENFORCE_EQ( status, @@ -199,12 +224,17 @@ class CublasLtHelper { cublasLtMatrixLayout_t A_desc_; cublasLtMatrixLayout_t B_desc_; cublasLtMatrixLayout_t C_desc_; + + cublasLtMatmulAlgo_t algo_; + int32_t alpha_; int32_t beta_; int m_; int k_; int n_; + + size_t workspace_size_; }; } // namespace operators diff --git a/paddle/fluid/operators/fused/fused_dropout_act_bias.h b/paddle/fluid/operators/fused/fused_dropout_act_bias.h index 553fb8d7be604..1156d04b8f557 100644 --- a/paddle/fluid/operators/fused/fused_dropout_act_bias.h +++ b/paddle/fluid/operators/fused/fused_dropout_act_bias.h @@ -86,7 +86,6 @@ __global__ void FusedDropoutActBias( MaskType *mask, const float quant_last_in_scale = 1.0, const float *dequant_out_scale_data = nullptr, - const int quant_out_scale_offset = 0, const float quant_next_in_scale = 1.0, const int quant_round_type = 1, const float quant_max_bound = 127.0, @@ -127,7 +126,6 @@ __global__ void FusedDropoutActBias( act, quant_last_in_scale, dequant_out_scale_data, - quant_out_scale_offset, quant_next_in_scale, quant_round_type, quant_max_bound, @@ -146,7 +144,13 @@ __global__ void FusedActBias(Functor act, const uint64_t cols, const InType *__restrict__ src, const T *__restrict__ bias, - OutType *dst) { + OutType *dst, + const float quant_last_in_scale = 1.0, + const float *dequant_out_scale_data = nullptr, + const float quant_next_in_scale = 1.0, + const int quant_round_type = 1, + const float quant_max_bound = 127.0, + const float quant_min_bound = -127.0) { const int32_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; using LoadT = phi::AlignedVector; using LoadInType = phi::AlignedVector; @@ -156,23 +160,42 @@ __global__ void FusedActBias(Functor act, LoadInType src_vec; LoadT bias_vec; StoreOutType out_vec; + LoadFloat dequant_out_scale_vec; for (int32_t idx = global_thread_idx * VecSize, step = blockDim.x * gridDim.x * VecSize; idx < elem_cnt; idx += step) { const int32_t col_idx = idx % cols; phi::Load(&src[idx], &src_vec); + phi::Load(&dequant_out_scale_data[col_idx], + &dequant_out_scale_vec); if (bias) { phi::Load(&bias[col_idx], &bias_vec); } #pragma unroll for (int32_t unroll_idx = 0; unroll_idx < VecSize; unroll_idx++) { - if (bias) { - out_vec[unroll_idx] = static_cast( - act(static_cast(src_vec[unroll_idx]) + bias_vec[unroll_idx])); + T tmp; + if (std::is_same::value) { + tmp = static_cast(static_cast(src_vec[unroll_idx]) * + dequant_out_scale_vec[unroll_idx]); + if (bias) { + tmp = static_cast(act(tmp + bias_vec[unroll_idx])); + } else { + tmp = static_cast(act(tmp)); + } + out_vec[unroll_idx] = quant_helper(tmp, + quant_next_in_scale, + quant_round_type, + quant_max_bound, + quant_min_bound); } else { - out_vec[unroll_idx] = - static_cast(act(static_cast(src_vec[unroll_idx]))); + if (bias) { + out_vec[unroll_idx] = static_cast( + act(static_cast(src_vec[unroll_idx]) + bias_vec[unroll_idx])); + } else { 
+ out_vec[unroll_idx] = + static_cast(act(static_cast(src_vec[unroll_idx]))); + } } } phi::Store(out_vec, &dst[idx]); @@ -202,7 +225,6 @@ void LaunchDropoutActBias(Functor act_functor, const phi::GPUContext &ctx, const float quant_last_in_scale = 1.0, const float *dequant_out_scale_data = nullptr, - const int quant_out_scale_offset = 0, const float quant_next_in_scale = 1.0, const int quant_round_type = 1, const float quant_max_bound = 127.0, @@ -218,7 +240,7 @@ void LaunchDropoutActBias(Functor act_functor, const int real_vec_size = cols % VecSize == 0 ? VecSize : 1; const auto config = Get1DBlocksAnd2DGrids(ctx, rows, cols, real_vec_size); if (cols % VecSize == 0) { - if (is_test && (dequant_out_scale_data == nullptr)) { + if (is_test) { const int32_t elem_cnt = rows * cols; const int32_t pack_num = elem_cnt / VecSize; const int32_t tmp_cols = cols / VecSize; @@ -227,8 +249,15 @@ void LaunchDropoutActBias(Functor act_functor, const int grid_size = std::max(static_cast(1), (pack_num + block_size - 1) / block_size); FusedActBias - <<>>( - act_functor, elem_cnt, cols, src, bias, dst); + <<>>(act_functor, + elem_cnt, + cols, + src, + bias, + dst, + quant_last_in_scale, + dequant_out_scale_data, + quant_next_in_scale); } else { FusedDropoutActBias <<>>( @@ -246,7 +275,6 @@ void LaunchDropoutActBias(Functor act_functor, mask_data, quant_last_in_scale, dequant_out_scale_data, - quant_out_scale_offset, quant_next_in_scale); } } else { @@ -266,7 +294,6 @@ void LaunchDropoutActBias(Functor act_functor, mask_data, quant_last_in_scale, dequant_out_scale_data, - quant_out_scale_offset, quant_next_in_scale); } } diff --git a/paddle/fluid/operators/fused/fused_dropout_helper.h b/paddle/fluid/operators/fused/fused_dropout_helper.h index 708aef3d690f9..f95d159144f37 100644 --- a/paddle/fluid/operators/fused/fused_dropout_helper.h +++ b/paddle/fluid/operators/fused/fused_dropout_helper.h @@ -154,7 +154,6 @@ class FusedDropoutHelper { MaskType* mask, const float quant_last_in_scale = 1.0, const float* dequant_out_scale_data = nullptr, - const int quant_out_scale_offset = 0, const float quant_next_in_scale = 1.0) { auto increment = GetIncrement(ctx); LaunchResidualDropoutBias( @@ -173,7 +172,6 @@ class FusedDropoutHelper { ctx, quant_last_in_scale, dequant_out_scale_data, - quant_out_scale_offset, quant_next_in_scale); } @@ -212,7 +210,6 @@ class FusedDropoutHelper { MaskType* mask, const float quant_last_in_scale = 1.0, const float* dequant_out_scale_data = nullptr, - const int quant_out_scale_offset = 0, const float quant_next_in_scale = 1.0, const int quant_round_type = 1, const float quant_max_bound = 127.0, @@ -237,7 +234,6 @@ class FusedDropoutHelper { ctx, quant_last_in_scale, dequant_out_scale_data, - quant_out_scale_offset, quant_next_in_scale, quant_round_type, quant_max_bound, @@ -260,7 +256,6 @@ class FusedDropoutHelper { ctx, quant_last_in_scale, dequant_out_scale_data, - quant_out_scale_offset, quant_next_in_scale, quant_round_type, quant_max_bound, @@ -287,7 +282,6 @@ class FusedDropoutHelper { ctx, quant_last_in_scale, dequant_out_scale_data, - quant_out_scale_offset, quant_next_in_scale, quant_round_type, quant_max_bound, @@ -454,7 +448,6 @@ class FusedDropoutLayerNormHelper LayerNormParamType* variance, const float quant_last_in_scale = 1.0, const float* dequant_out_scale_data = nullptr, - const int quant_out_scale_offset = 0, const float quant_next_in_scale = 1.0, const int quant_round_type = 1, const float quant_max_bound = 127.0, @@ -494,7 +487,6 @@ class FusedDropoutLayerNormHelper 
ctx, quant_last_in_scale, dequant_out_scale_data, - quant_out_scale_offset, quant_next_in_scale, quant_round_type, quant_max_bound, diff --git a/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h index 137943afbfb94..a529271250e5d 100644 --- a/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h +++ b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h @@ -442,7 +442,6 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel( OutType *__restrict__ y_ptr, const float quant_last_in_scale = 1.0, const float *__restrict__ quant_out_scale_ptr = nullptr, - const int quant_out_scale_offset = 0, const float quant_next_in_scale = 1.0, const int quant_round_type = 1, const float quant_max_bound = 127.0, @@ -504,9 +503,8 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel( phi::Load(x_ptr + row * ELTS_PER_ROW + col * VecSize, &x_input[it]); if (quant_out_scale_ptr != nullptr) { - phi::Load( - quant_out_scale_ptr + quant_out_scale_offset + col * VecSize, - &dequant_out_scale[it]); + phi::Load(quant_out_scale_ptr + col * VecSize, + &dequant_out_scale[it]); } col += THREADS_PER_ROW; } @@ -543,7 +541,6 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel( // dropout(x) + residual if (std::is_same::value) { T tmp = (static_cast(static_cast(x_input[it][jt]) * - quant_last_in_scale / dequant_out_scale[it][jt]) + bias[it][jt]) * static_cast(mask_vec[it][jt]) * factor + @@ -567,7 +564,6 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel( if (std::is_same::value) { // for int32 input, we need to dequantize. T tmp = static_cast(static_cast(x_input[it][jt]) * - quant_last_in_scale / dequant_out_scale[it][jt]) * static_cast(mask_vec[it][jt]) * factor + residual[it][jt]; @@ -752,7 +748,6 @@ void LaunchLayernormResidualDropoutBias( const phi::GPUContext &ctx, const float quant_last_in_scale = 1.0, const float *dequant_out_scale_data = nullptr, - const int quant_out_scale_offset = 0, const float quant_next_in_scale = 1.0, const int quant_round_type = 1, const float quant_max_bound = 127.0, @@ -844,7 +839,6 @@ void LaunchLayernormResidualDropoutBias( layernorm_dst, \ quant_last_in_scale, \ dequant_out_scale_data, \ - quant_out_scale_offset, \ quant_next_in_scale, \ quant_round_type, \ quant_max_bound, \ diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cc b/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cc index 2a2d1f27edd9c..3a9bd15c101e9 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cc +++ b/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cc @@ -58,6 +58,12 @@ class FusedMultiTransformerINT8Op : public framework::OperatorWithKernel { CHECK_INPUTS(FFN1Weight); CHECK_INPUTS(FFN2Weight); + // scale + CHECK_INPUTS(QKVOutScale); + CHECK_INPUTS(OutLinearOutScale); + CHECK_INPUTS(FFN1OutScale); + CHECK_INPUTS(FFN2OutScale); + CHECK_OUTPUT(Out); // x: qkv's input [batch_size, seq_len, dim_embed] @@ -232,20 +238,24 @@ class FusedMultiTransformerINT8OpMaker "In order to keep consistent with the PTQ/QAT calculation logic," "QKVOutScale should be max_bound * max_bound / max_range." "Here max_range is per-channel weight scale." 
- "The shape of QKVOutScale is [num_layers, num_channels]") - .AsDispensable(); + "The shape of QKVOutScale is [num_channels]") + .AsDispensable() + .AsDuplicable(); AddInput("OutLinearOutScale", "OutLinearOutScale is used to dequantize out_linear output tensor." "The definition and shape is the same as QKVOutScale") - .AsDispensable(); + .AsDispensable() + .AsDuplicable(); AddInput("FFN1OutScale", "FFN1OutScale is used to dequantize ffn1 output tensor." "The definition and shape is the same as QKVOutScale") - .AsDispensable(); + .AsDispensable() + .AsDuplicable(); AddInput("FFN2OutScale", "FFN2OutScale is used to dequantize ffn2 output tensor." "The definition and shape is the same as QKVOutScale") - .AsDispensable(); + .AsDispensable() + .AsDuplicable(); AddOutput("CacheKVOut", "The updated cache KV. Inplace with CacheKV") .AsDispensable() diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cu b/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cu index 681748c71c91a..fa22ee8d57e65 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cu +++ b/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cu @@ -48,16 +48,11 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { // dequant output scales, tensor, size = [num_layers, n], n is gemm output // size - auto *qkv_out_scale = ctx.Input("QKVOutScale"); - auto *out_linear_out_scale = - ctx.Input("OutLinearOutScale"); - auto *ffn1_out_scale = ctx.Input("FFN1OutScale"); - auto *ffn2_out_scale = ctx.Input("FFN2OutScale"); - - int qkv_out_scale_n = qkv_out_scale->dims()[1]; - int out_linear_out_scale_n = out_linear_out_scale->dims()[1]; - int ffn1_out_scale_n = ffn1_out_scale->dims()[1]; - int ffn2_out_scale_n = ffn2_out_scale->dims()[1]; + auto qkv_out_scales = ctx.MultiInput("QKVOutScale"); + auto out_linear_out_scales = + ctx.MultiInput("OutLinearOutScale"); + auto ffn1_out_scales = ctx.MultiInput("FFN1OutScale"); + auto ffn2_out_scales = ctx.MultiInput("FFN2OutScale"); // 1. layer norm const auto pre_layer_norm = ctx.Attr("pre_layer_norm"); @@ -132,6 +127,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { transpose_out_2.Resize({{3, bsz, num_head, seq_len, dim_head}}); auto *transpose_out_2_data = dev_ctx.Alloc(&transpose_out_2, transpose_out_2.numel() * sizeof(T)); + qk_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); auto *qk_out_data = dev_ctx.Alloc(&qk_out, qk_out.numel() * sizeof(T)); @@ -232,19 +228,23 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { dev_ctx, bsz_seq, dim_embed, ffn2_dropout_param, epsilon); // []. 
init workspace for cublasLt transform - Tensor input_workspace, output_workspace; + Tensor input_workspace, output_workspace, cublaslt_workspace; // for input and output transform data is CUBLASLT_ORDER_COL32 format, int m_max = bsz_seq, k_max = std::max(dim_embed, dim_ffn), n_max = std::max({output_size, dim_embed, dim_ffn}); - input_workspace.Resize( - {{32 * ((m_max + 32 - 1) / 32), (k_max + 31) / 32 * 32}}); + input_workspace.Resize({{(m_max * k_max + 31) / 32 * 32}}); dev_ctx.Alloc(&input_workspace, input_workspace.numel() * sizeof(int8_t)); - output_workspace.Resize({{n_max * 4, (m_max + 31) / 32 * 32 * 4}}); + + output_workspace.Resize({{(n_max * m_max + 31) / 32 * 32}}); dev_ctx.Alloc(&output_workspace, output_workspace.numel() * sizeof(int32_t)); + cublaslt_workspace.Resize({{3000000}}); + dev_ctx.Alloc(&cublaslt_workspace, + cublaslt_workspace.numel() * sizeof(int8_t)); + // calc auto *out = ctx.Output("Out"); auto *from_data = dev_ctx.Alloc(out, out->numel() * sizeof(T)); @@ -305,8 +305,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { &output_workspace, &qkv_out, qkv_in_scale[i], - qkv_out_scale, - i * qkv_out_scale_n, + qkv_out_scales[i], quant_round_type, quant_max_bound, quant_min_bound); @@ -319,8 +318,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { &output_workspace, &qkv_out, qkv_in_scale[i], - qkv_out_scale, - i * qkv_out_scale_n, + qkv_out_scales[i], quant_round_type, quant_max_bound, quant_min_bound); @@ -332,8 +330,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { &qkv_out, &output_workspace, &qkv_out, - qkv_out_scale, - i * qkv_out_scale_n); + qkv_out_scales[i]); } #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER VLOG(0) << "step2"; @@ -441,8 +438,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { &output_workspace, nullptr, out_linear_in_scale[i], - out_linear_out_scale, - i * out_linear_out_scale_n, + out_linear_out_scales[i], quant_round_type, quant_max_bound, quant_min_bound); @@ -473,8 +469,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { ln_mean_data, ln_var_data, out_linear_in_scale[i], - out_linear_out_scale->data(), - i * out_linear_out_scale_n, + out_linear_out_scales[i]->data(), ffn1_in_scale[i], quant_round_type, quant_max_bound, @@ -504,11 +499,13 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { // step6. ffn matmul1 if (pre_layer_norm) { - ffn1_linear_compute.ComputeForwardINT8ToINT8(ffn1_weights[i], - &input_workspace, - nullptr, - &output_workspace, - nullptr); + ffn1_linear_compute.ComputeForwardINT8ToINT8( + ffn1_weights[i], + &input_workspace, + nullptr, + &output_workspace, + nullptr, + cublaslt_workspace.data()); } else { ffn1_linear_compute.ComputeForward(ffn1_weights[i], buf1, @@ -518,8 +515,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { &output_workspace, nullptr, ffn1_in_scale[i], - ffn1_out_scale, - i * ffn1_out_scale_n, + ffn1_out_scales[i], quant_round_type, quant_max_bound, quant_min_bound); @@ -539,8 +535,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { input_workspace.data(), ffn1_dropout_mask_data, ffn1_in_scale[i], - ffn1_out_scale->data(), - i * ffn1_out_scale_n, + ffn1_out_scales[i]->data(), ffn2_in_scale[i], quant_round_type, quant_max_bound, @@ -560,11 +555,13 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { // step8. 
ffn matmul2 if (pre_layer_norm) { - ffn2_linear_compute.ComputeForwardINT8ToINT8(ffn2_weights[i], - &input_workspace, - nullptr, - &output_workspace, - nullptr); + ffn2_linear_compute.ComputeForwardINT8ToINT8( + ffn2_weights[i], + &input_workspace, + nullptr, + &output_workspace, + nullptr, + cublaslt_workspace.data()); } else { ffn2_linear_compute.ComputeForward(ffn2_weights[i], &ffn1_dropout_out, @@ -574,8 +571,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { &output_workspace, nullptr, ffn2_in_scale[i], - ffn2_out_scale, - i * ffn2_out_scale_n, + ffn2_out_scales[i], quant_round_type, quant_max_bound, quant_min_bound); @@ -616,8 +612,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { ln_mean_data, ln_var_data, ffn2_in_scale[i], - ffn2_out_scale->data(), - i * ffn2_out_scale_n, + ffn2_out_scales[i]->data(), qkv_in_scale[i + 1], quant_round_type, quant_max_bound, @@ -631,8 +626,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { buf1->data(), dropout_mask_out_data, ffn2_in_scale[i], - ffn2_out_scale->data(), - i * ffn2_out_scale_n, + ffn2_out_scales[i]->data(), 1.0); } } else { diff --git a/paddle/fluid/operators/fused/fused_residual_dropout_bias.h b/paddle/fluid/operators/fused/fused_residual_dropout_bias.h index f162d200abfe1..972bbe3326a5d 100644 --- a/paddle/fluid/operators/fused/fused_residual_dropout_bias.h +++ b/paddle/fluid/operators/fused/fused_residual_dropout_bias.h @@ -49,7 +49,6 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread( Functor act_func, const float quant_last_in_scale = 1.0, const float *dequant_out_scale_data = nullptr, - const int quant_out_scale_offset = 0, const float quant_next_in_scale = 1.0, const int quant_round_type = 1, const float quant_max_bound = 127.0, @@ -74,9 +73,8 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread( } // vectorize load data from global phi::Load(&src[row_id * cols + col_id], &src_vec); - phi::Load( - &dequant_out_scale_data[quant_out_scale_offset + col_id], - &quant_out_scale_vec); + phi::Load(&dequant_out_scale_data[col_id], + &quant_out_scale_vec); if (residual) { phi::Load(&residual[row_id * cols + col_id], &residual_vec); } @@ -108,7 +106,7 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread( T tmp; if (std::is_same::value) { T tmp0 = static_cast(static_cast(src_vec[ii]) * - quant_last_in_scale / quant_out_scale_vec[ii]); + quant_out_scale_vec[ii]); tmp = tmp0 + bias_vec[ii]; } else { tmp = static_cast(src_vec[ii]) + bias_vec[ii]; @@ -172,7 +170,6 @@ __global__ void FusedResidualDropoutBias( const bool is_test, const float quant_last_in_scale = 1.0, const float *dequant_out_scale_data = nullptr, - const int quant_out_scale_offset = 0, const float quant_next_in_scale = 1.0) { int col_id = blockDim.x * blockIdx.x + threadIdx.x; int row_id = blockIdx.y; @@ -208,7 +205,6 @@ __global__ void FusedResidualDropoutBias( relu, quant_last_in_scale, dequant_out_scale_data, - quant_out_scale_offset, quant_next_in_scale); } } @@ -236,7 +232,6 @@ void LaunchResidualDropoutBias(const uint32_t rows, const phi::GPUContext &ctx, const float quant_last_in_scale = 1.0, const float *dequant_out_scale_data = nullptr, - const int quant_out_scale_offset = 0, const float quant_next_in_scale = 1.0) { // dropout_prob == 1.0f if (std::abs(dropout_prob - 1.0f) < 1e-5) { @@ -278,7 +273,6 @@ void LaunchResidualDropoutBias(const uint32_t rows, is_test, quant_last_in_scale, dequant_out_scale_data, - quant_out_scale_offset, 
quant_next_in_scale); } else { FusedResidualDropoutBias @@ -297,7 +291,6 @@ void LaunchResidualDropoutBias(const uint32_t rows, is_test, quant_last_in_scale, dequant_out_scale_data, - quant_out_scale_offset, quant_next_in_scale); } } diff --git a/paddle/fluid/operators/fused/quant_dequant_kernel.h b/paddle/fluid/operators/fused/quant_dequant_kernel.h index 21b7b0f345466..164effe01d316 100644 --- a/paddle/fluid/operators/fused/quant_dequant_kernel.h +++ b/paddle/fluid/operators/fused/quant_dequant_kernel.h @@ -18,17 +18,24 @@ limitations under the License. */ #include "paddle/fluid/operators/fake_quantize_op.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/fluid/platform/float16.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/kernels/funcs/aligned_vector.h" namespace paddle { namespace operators { +using phi::backends::gpu::GpuLaunchConfig; + +constexpr int DequantKernelVecSize = 4; + template __forceinline__ __device__ int8_t quant_helper(const T input, const float scale, const int round_type, const float max_bound, const float min_bound) { - float quant_value = max_bound * inverse(scale) * static_cast(input); + float quant_value = max_bound * scale * static_cast(input); + if (round_type == 0) { quant_value = static_cast(roundWithTiesToEven(quant_value)); } else { @@ -77,7 +84,7 @@ void quantize_kernel_launcher(const T* input, const float min_bound, gpuStream_t stream) { // TODO(minghaoBD): optimize the kennel launch times when m==1 or n==1 - dim3 grid((n + 31) / 32, (m + 31) / 32); + dim3 grid((n >> 2 + 31) / 32, (m + 31) / 32); dim3 block(32, 32); quantize_kernel<<>>(input, @@ -90,46 +97,48 @@ void quantize_kernel_launcher(const T* input, min_bound); } -// dequantize using weight scales and input scales -template +template __global__ void dequantize_kernel(T* output, const int32_t* input, - const int m, // hidden - const int n, // batch size + const int m, // batch size + const int n, // hidden const float quant_in_scale, - const float* dequant_out_scale_data, - const int quant_out_scale_offset) { - int m_id = blockIdx.x * blockDim.x + threadIdx.x; // hidden - int n_id = blockIdx.y * blockDim.y + threadIdx.y; // batch size - - bool check = ((m_id < m) && (n_id < n)); - if (check) { - float out_scale = dequant_out_scale_data[quant_out_scale_offset + m_id]; - output[n_id * m + m_id] = - static_cast(static_cast(input[n_id * m + m_id]) * - quant_in_scale / out_scale); + const float* dequant_out_scale_data) { + int numel = m * n; + int stride = blockDim.x * gridDim.x * VecSize; + int idx = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize; + int col_id = idx % n; + + phi::AlignedVector in_vec; + phi::AlignedVector out_scale_vec; + phi::AlignedVector out_vec; + + for (; idx < numel; idx += stride) { + phi::Load(input + idx, &in_vec); + phi::Load(dequant_out_scale_data + col_id, &out_scale_vec); + +#pragma unroll + for (int i = 0; i < VecSize; ++i) { + out_vec[i] = + static_cast(static_cast(in_vec[i]) * out_scale_vec[i]); + } + + phi::Store(out_vec, output + idx); } } template void dequantize_kernel_launcher(const int32_t* input, T* output, - const int batch_size, // m - const int hidden_units, // n + const int m, // m + const int n, // n gpuStream_t stream, + GpuLaunchConfig* gpu_config, const float quant_in_scale, - const float* dequant_out_scale_data, - const int quant_out_scale_offset) { - dim3 grid((hidden_units + 31) / 32, (batch_size + 31) / 32); - dim3 block(32, 32); - - dequantize_kernel<<>>(output, - input, - hidden_units, - 
diff --git a/python/paddle/fluid/tests/unittests/test_fused_multi_transformer_int8_op.py b/python/paddle/fluid/tests/unittests/test_fused_multi_transformer_int8_op.py
index 3f91f9b6e6d90..fbbe2d65418af 100644
--- a/python/paddle/fluid/tests/unittests/test_fused_multi_transformer_int8_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fused_multi_transformer_int8_op.py
@@ -307,7 +307,7 @@ def generate_input_data(self):
         self.attn_mask = None

     def fake_quant(self, input, scale):
-        quant_value = 127.0 * (1.0 / scale) * paddle.cast(input, 'float32')
+        quant_value = 127.0 * scale * paddle.cast(input, 'float32')
         quant_value = paddle.round(quant_value)

         # No need to clip here because scale is the max value
@@ -333,11 +333,8 @@ def GetBaselineOut(self):
             if self.pre_layer_norm:
                 ln1_out = self.norm(tensor_query)
                 max_v = paddle.max(paddle.abs(paddle.cast(ln1_out, 'float32')))[0]
-                # self.qkv_in_scales.append(127.0 / max_v)
-                self.qkv_in_scales.append(max_v)
-                self.qkv_out_scales.append(127.0 * 127.0)
-                # print('qkv_in_scales ', i, self.qkv_in_scales[i])
-                # print('qkv_out_scales ', i, self.qkv_out_scales[i])
+                self.qkv_in_scales.append(1 / max_v)
+                self.qkv_out_scales.append(max_v / (127.0 * 127.0))

                 # quant ln1_out
                 ln1_out = self.fake_quant(ln1_out, self.qkv_in_scales[i])
@@ -345,9 +342,7 @@ def GetBaselineOut(self):

                 q = paddle.nn.functional.linear(ln1_out, self.q_weight_tensor)
                 # de quant
                 q = paddle.cast(
-                    paddle.cast(q, 'float32')
-                    * self.qkv_in_scales[i]
-                    / self.qkv_out_scales[i],
+                    paddle.cast(q, 'float32') * self.qkv_out_scales[i],
                     self.x_type,
                 )
@@ -357,17 +352,13 @@ def GetBaselineOut(self):
                 k = paddle.nn.functional.linear(ln1_out, self.k_weight_tensor)
                 k = paddle.cast(
-                    paddle.cast(k, 'float32')
-                    * self.qkv_in_scales[i]
-                    / self.qkv_out_scales[i],
+                    paddle.cast(k, 'float32') * self.qkv_out_scales[i],
                     self.x_type,
                 )
                 k = k + self.k_proj_bias_tensor
                 v = paddle.nn.functional.linear(ln1_out, self.v_weight_tensor)
                 v = paddle.cast(
-                    paddle.cast(v, 'float32')
-                    * self.qkv_in_scales[i]
-                    / self.qkv_out_scales[i],
+                    paddle.cast(v, 'float32') * self.qkv_out_scales[i],
                     self.x_type,
                 )
                 v = v + self.v_proj_bias_tensor

@@ -442,10 +433,10 @@ def GetBaselineOut(self):

                 max_v = paddle.max(
                     paddle.abs(paddle.cast(out_linear_in, 'float32'))
                 )[0]
-                # self.out_linear_in_scales.append(127.0 / max_v)
-                self.out_linear_in_scales.append(max_v)
-                self.out_linear_out_scales.append((127.0 * 127.0))
+                self.out_linear_in_scales.append(1 / max_v)
+                self.out_linear_out_scales.append(max_v / (127.0 * 127.0))
+
                 out_linear_in = self.fake_quant(
                     out_linear_in, self.out_linear_in_scales[i]
                 )
@@ -455,9 +446,7 @@ def GetBaselineOut(self):
                 )

                 out = paddle.cast(
-                    paddle.cast(out, 'float32')
-                    * self.out_linear_in_scales[i]
-                    / self.out_linear_out_scales[i],
+                    paddle.cast(out, 'float32') * self.out_linear_out_scales[i],
                     self.x_type,
                 )
@@ -476,8 +465,8 @@ def GetBaselineOut(self):
             max_v = paddle.max(paddle.abs(paddle.cast(ffn_ln_out, 'float32')))[
                 0
             ]
-            self.ffn1_in_scales.append(max_v)
-            self.ffn1_out_scales.append((127.0 * 127.0))
+            self.ffn1_in_scales.append(1 / max_v)
+            self.ffn1_out_scales.append(max_v / (127.0 * 127.0))

             ffn_ln_out = self.fake_quant(ffn_ln_out, self.ffn1_in_scales[i])
             ffn1_out = paddle.nn.functional.linear(
@@ -485,9 +474,7 @@ def GetBaselineOut(self):
             )

             ffn1_out = paddle.cast(
-                paddle.cast(ffn1_out, 'float32')
-                * self.ffn1_in_scales[i]
-                / self.ffn1_out_scales[i],
+                paddle.cast(ffn1_out, 'float32') * self.ffn1_out_scales[i],
                 self.x_type,
             )
@@ -495,10 +482,8 @@ def GetBaselineOut(self):
             ffn1_out = self.dropout(self.activation(ffn1_out))

             max_v = paddle.max(paddle.abs(paddle.cast(ffn1_out, 'float32')))[0]
-            # self.ffn2_in_scales.append(127.0 / max_v)
-            self.ffn2_in_scales.append(max_v)
-            self.ffn2_out_scales.append((127.0 * 127.0))
-            # print('ffn2_in_scales ', i, self.ffn2_in_scales[i])
+            self.ffn2_in_scales.append(1 / max_v)
+            self.ffn2_out_scales.append(max_v / (127.0 * 127.0))

             ffn1_out = self.fake_quant(ffn1_out, self.ffn2_in_scales[i])
             ffn2_out = paddle.nn.functional.linear(
@@ -506,16 +491,12 @@ def GetBaselineOut(self):
             )

             ffn2_out = paddle.cast(
-                paddle.cast(ffn2_out, 'float32')
-                * self.ffn2_in_scales[i]
-                / self.ffn2_out_scales[i],
+                paddle.cast(ffn2_out, 'float32') * self.ffn2_out_scales[i],
                 self.x_type,
             )
             ffn2_out = ffn2_out + self.ffn2_proj_bias_tensor
             residual_out = attn_out + self.dropout(ffn2_out)
-            # print("residual ", attn_out)
-            # print("residual_out ", residual_out)
             final_out = residual_out
             if not self.pre_layer_norm:
                 final_out = self.ffn_norm(residual_out)
@@ -644,23 +625,18 @@ def GetFusedMultiTransformerOut(self):
         ffn1_weights, ffn1_biases = [], []
         ffn2_weights, ffn2_biases = [], []
         ffn_ln_scales, ffn_ln_biases = [], []
+
+        # Input scales: list of value
         qkv_in_scale = []
        out_linear_in_scale = []
         ffn1_in_scale = []
         ffn2_in_scale = []

-        qkv_out_scales_tensor = paddle.ones(
-            [self.layers, 3 * self.embed_dim], 'float32'
-        )
-        out_linear_out_scales_tensor = paddle.ones(
-            [self.layers, self.embed_dim], 'float32'
-        )
-        ffn1_out_scales_tensor = paddle.ones(
-            [self.layers, 4 * self.embed_dim], 'float32'
-        )
-        ffn2_out_scales_tensor = paddle.ones(
-            [self.layers, self.embed_dim], 'float32'
-        )
+        # Output dequant scales: list of tensor
+        qkv_out_scales = []
+        out_linear_out_scales = []
+        ffn1_out_scales = []
+        ffn2_out_scales = []

         for i in range(self.layers):
             qkv_weights.append(qkv_weight_tensor)
@@ -680,10 +656,30 @@ def GetFusedMultiTransformerOut(self):
             ffn1_in_scale.append(self.ffn1_in_scales[i])
             ffn2_in_scale.append(self.ffn2_in_scales[i])

-            qkv_out_scales_tensor[i, :] *= self.qkv_out_scales[i]
-            out_linear_out_scales_tensor[i, :] *= self.out_linear_out_scales[i]
-            ffn1_out_scales_tensor[i, :] *= self.ffn1_out_scales[i]
-            ffn2_out_scales_tensor[i, :] *= self.ffn2_out_scales[i]
+            qkv_out_scale = (
+                paddle.ones([3 * self.embed_dim], 'float32')
+                * self.qkv_out_scales[i]
+            )
+
+            out_linear_out_scale = (
+                paddle.ones([self.embed_dim], 'float32')
+                * self.out_linear_out_scales[i]
+            )
+
+            ffn1_out_scale = (
+                paddle.ones([4 * self.embed_dim], 'float32')
+                * self.ffn1_out_scales[i]
+            )
+
+            ffn2_out_scale = (
+                paddle.ones([self.embed_dim], 'float32')
+                * self.ffn2_out_scales[i]
+            )
+
+            qkv_out_scales.append(qkv_out_scale)
+            out_linear_out_scales.append(out_linear_out_scale)
+            ffn1_out_scales.append(ffn1_out_scale)
+            ffn2_out_scales.append(ffn2_out_scale)

             if self.has_cache_kv:
                 cache_kvs.append(paddle.to_tensor(cache_kv, stop_gradient=True))
@@ -713,10 +709,10 @@ def GetFusedMultiTransformerOut(self):
             trans_qkvw=True,
             ring_id=-1,
             name=None,
-            qkv_out_scales=qkv_out_scales_tensor,
-            out_linear_out_scales=out_linear_out_scales_tensor,
-            ffn1_out_scales=ffn1_out_scales_tensor,
-            ffn2_out_scales=ffn2_out_scales_tensor,
+            qkv_out_scales=qkv_out_scales,
+            out_linear_out_scales=out_linear_out_scales,
+            ffn1_out_scales=ffn1_out_scales,
+            ffn2_out_scales=ffn2_out_scales,
             num_head=self.num_heads,
             dim_head=self.head_dim,
             dim_ffn=4 * self.embed_dim,