diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 4b429d5f72347..3c2e8bf85a799 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -103,7 +103,7 @@ pass_library(delete_c_identity_op_pass inference) pass_library(preln_residual_bias_fuse_pass inference) pass_library(delete_fill_constant_op_pass inference) pass_library(constant_folding_pass inference) -pass_library(float_to_half_pass inference) +pass_library(auto_mixed_precision_pass inference) pass_library(conv2d_fusion_layout_transfer_pass inference) pass_library(simplify_with_basic_ops_pass base) pass_library(fc_elementwise_layernorm_fuse_pass base) diff --git a/paddle/fluid/framework/ir/float_to_half_pass.cc b/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc similarity index 74% rename from paddle/fluid/framework/ir/float_to_half_pass.cc rename to paddle/fluid/framework/ir/auto_mixed_precision_pass.cc index 9389490712c65..bc034301989b0 100644 --- a/paddle/fluid/framework/ir/float_to_half_pass.cc +++ b/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/framework/ir/float_to_half_pass.h" +#include "paddle/fluid/framework/ir/auto_mixed_precision_pass.h" #include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/operator.h" @@ -29,7 +29,7 @@ namespace ir { namespace { -using VarType = FloatToHalfPass::VarType; +using VarType = AutoMixedPrecisionPass::VarType; bool PhiKernelSupportPrecision( const std::string& op_type, @@ -71,6 +71,23 @@ bool GpuKernelSupportPrecision( return support; } +inline bool VarNodeHasDtype(Node* var_node) { + auto type = var_node->Var()->GetType(); + return (type == VarType::SELECTED_ROWS) || (type == VarType::LOD_TENSOR) || + (type == VarType::LOD_TENSOR_ARRAY) || (type == VarType::STRINGS) || + (type == VarType::VOCAB); +} + +inline bool IsFloatType(VarType::Type type) { + return (type == VarType::FP64) || (type == VarType::FP32); +} + +inline bool IsHalfType(VarType::Type type) { + return (type == VarType::FP16) || (type == VarType::BF16); +} + +}; // namespace + void DoInsertCastOp(Graph* graph, Node* var_node, Node* op_node, @@ -123,27 +140,26 @@ void DoInsertCastOp(Graph* graph, IR_NODE_UNLINK(var_node, op_node); } -inline bool VarNodeHasDtype(Node* var_node) { - auto type = var_node->Var()->GetType(); - return (type == VarType::SELECTED_ROWS) || (type == VarType::LOD_TENSOR) || - (type == VarType::LOD_TENSOR_ARRAY) || (type == VarType::STRINGS) || - (type == VarType::VOCAB); -} - -inline bool IsFloatType(VarType::Type type) { - return (type == VarType::FP64) || (type == VarType::FP32); -} - -inline bool IsHalfType(VarType::Type type) { - return (type == VarType::FP16) || (type == VarType::BF16); +bool OpSupportPrecision(const std::string& op_type, + phi::Backend backend, + phi::DataType precision, + const std::unordered_set& black_list) { + bool support = false; + if (black_list.count(op_type) == 0) { + if (backend == phi::Backend::GPU) { + support = GpuKernelSupportPrecision(op_type, precision); + } else { + PADDLE_THROW(paddle::platform::errors::InvalidArgument( + "Now, only support backend of GPU.")); + } + } + return support; } -}; // namespace - // The set of ops that support fp16 calculation and are considered // numerically-dangerous, slower and whose effects may also be observed in // downstream ops. -void FloatToHalfPass::SetDefaultBlacklist() const { +void AutoMixedPrecisionPass::SetDefaultBlacklist() const { black_list_.insert({ // numerically-dangerous "acos", @@ -175,12 +191,27 @@ void FloatToHalfPass::SetDefaultBlacklist() const { }); } -void FloatToHalfPass::Init(Graph* graph) const { - keep_io_types_ = true; - half_precision_ = - static_cast(Get("mixed_precision_mode")); +void AutoMixedPrecisionPass::Init(Graph* graph) const { + bool enable_gpu_mixed = Get("enable_gpu_mixed"); + if (enable_gpu_mixed) { + backend_ = phi::Backend::GPU; + } + + skip_pass_ = !enable_gpu_mixed; + + low_precision_ = static_cast(Get("mixed_precision_mode")); + black_list_ = Get>("mixed_black_list"); SetDefaultBlacklist(); + VLOG(4) << "black_list has "; + for (const auto& name : black_list_) { + VLOG(4) << " - " << name; + } + + keep_io_types_ = true; + if (Has("keep_io_types")) { + keep_io_types_ = Get("keep_io_types"); + } auto graph_size = graph->SubGraphsSize(); VLOG(4) << "graph size: " << graph_size; @@ -204,24 +235,27 @@ void FloatToHalfPass::Init(Graph* graph) const { } } -void FloatToHalfPass::ApplyImpl(Graph* graph) const { - auto enable_gpu_half = Get("enable_gpu_half"); - if (!enable_gpu_half) return; - - PADDLE_ENFORCE_NOT_NULL( - graph, - platform::errors::PreconditionNotMet( - "During the float to half pass, the graph should not be nullptr.")); - PADDLE_ENFORCE_EQ( - graph->IsMainGraph(), - true, - platform::errors::PreconditionNotMet( - "During the float to half pass, the graph should be main graph.")); +void AutoMixedPrecisionPass::ApplyImpl(Graph* graph) const { + PADDLE_ENFORCE_NOT_NULL(graph, + platform::errors::PreconditionNotMet( + "During the auto_mixed_precision_pass, the graph " + "should not be nullptr.")); + PADDLE_ENFORCE_EQ(graph->IsMainGraph(), + true, + platform::errors::PreconditionNotMet( + "During the auto_mixed_precision_pass, the graph " + "should be main graph.")); - FusePassBase::Init("float_to_half", graph); + FusePassBase::Init("auto_mixed_precision", graph); Init(graph); VLOG(4) << "Init done"; + + if (skip_pass_) { + VLOG(3) << "Skip auto_mixed_precision_pass."; + return; + } + SetOpUniqueType(); VLOG(4) << "SetOpUniqueType done"; GetOpPrecision(); @@ -240,19 +274,7 @@ void FloatToHalfPass::ApplyImpl(Graph* graph) const { VLOG(4) << "RestoreOpOriginType done"; } -bool FloatToHalfPass::OpSupportPrecision(const std::string& op_type, - phi::DataType precision, - phi::Backend backend) const { - bool support = false; - if (black_list_.count(op_type) == 0) { - if (backend == phi::Backend::GPU) { - support = GpuKernelSupportPrecision(op_type, precision); - } - } - return support; -} - -void FloatToHalfPass::SetOpUniqueType() const { +void AutoMixedPrecisionPass::SetOpUniqueType() const { int suffix = 0; for (const auto& nodes : all_op_nodes_) { for (auto* op_node : nodes) { @@ -269,7 +291,7 @@ void FloatToHalfPass::SetOpUniqueType() const { } } -void FloatToHalfPass::RestoreOpOriginType() const { +void AutoMixedPrecisionPass::RestoreOpOriginType() const { for (const auto& nodes : all_op_nodes_) { for (auto* op_node : nodes) { auto op_type = op_node->Op()->Type(); @@ -281,7 +303,7 @@ void FloatToHalfPass::RestoreOpOriginType() const { } } -inline std::string FloatToHalfPass::GetOpOriginalType( +inline std::string AutoMixedPrecisionPass::GetOpOriginalType( const std::string& op_type) const { if (op_original_type_.count(op_type)) { return op_original_type_.at(op_type); @@ -289,22 +311,21 @@ inline std::string FloatToHalfPass::GetOpOriginalType( return op_type; } -void FloatToHalfPass::ProcessOpWithDtypeAttr() const { +void AutoMixedPrecisionPass::ProcessOpWithDtypeAttr() const { for (const auto& nodes : all_op_nodes_) { for (auto* op_node : nodes) { auto op_type = op_node->Op()->Type(); - if (op_run_half_.count(op_type) == 0) continue; + if (op_run_low_precision_.count(op_type) == 0) continue; if (op_node->Op()->HasAttr("dtype")) { auto dtype = op_node->Op()->GetAttrIfExists("dtype"); if (IsFloatType(static_cast(dtype))) { op_node->Op()->SetAttr( "dtype", - static_cast( - framework::TransToProtoVarType(half_precision_))); + static_cast(framework::TransToProtoVarType(low_precision_))); op_node->Op()->Flush(); VLOG(4) << "process op with dtype attr: " << op_type << " ( " << dtype - << " --->" << static_cast(half_precision_) << " )"; + << " --->" << static_cast(low_precision_) << " )"; } } if (op_node->Op()->HasAttr("out_dtype")) { @@ -312,11 +333,10 @@ void FloatToHalfPass::ProcessOpWithDtypeAttr() const { if (IsFloatType(static_cast(out_dtype))) { op_node->Op()->SetAttr( "out_dtype", - static_cast( - framework::TransToProtoVarType(half_precision_))); + static_cast(framework::TransToProtoVarType(low_precision_))); op_node->Op()->Flush(); VLOG(4) << "process op with out_dtype attr: " << op_type << " ( " - << out_dtype << " --->" << static_cast(half_precision_) + << out_dtype << " --->" << static_cast(low_precision_) << " )"; } } @@ -324,37 +344,39 @@ void FloatToHalfPass::ProcessOpWithDtypeAttr() const { } } -void FloatToHalfPass::GetOpPrecision() const { +void AutoMixedPrecisionPass::GetOpPrecision() const { for (const auto& nodes : all_op_nodes_) { for (auto* op_node : nodes) { auto op_type = op_node->Op()->Type(); - bool support_half = true; + bool support_low_precision = true; if (GetOpOriginalType(op_type) == "feed" || GetOpOriginalType(op_type) == "fetch") { - support_half = !keep_io_types_; + support_low_precision = !keep_io_types_; } else { - support_half = - OpSupportPrecision(GetOpOriginalType(op_type), half_precision_); + support_low_precision = OpSupportPrecision( + GetOpOriginalType(op_type), backend_, low_precision_, black_list_); } if (op_node->Op()->HasAttr("dtype")) { auto dtype = op_node->Op()->GetAttrIfExists("dtype"); - support_half = - support_half && IsFloatType(static_cast(dtype)); + support_low_precision = support_low_precision && + IsFloatType(static_cast(dtype)); } else if (op_node->Op()->HasAttr("out_dtype")) { auto out_dtype = op_node->Op()->GetAttrIfExists("out_dtype"); - support_half = - support_half && IsFloatType(static_cast(out_dtype)); + support_low_precision = + support_low_precision && + IsFloatType(static_cast(out_dtype)); } else { // if op's input var and output var is not dense tensor, the op should - // not run half. + // not run at low precision. for (auto* in_var_node : op_node->inputs) { CHECK_EQ(in_var_node->IsVar(), true); auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()]; if (real_in_var_node->Var()->Persistable()) continue; - support_half = support_half && (real_in_var_node->Var()->GetType() == - VarType::LOD_TENSOR); + support_low_precision = + support_low_precision && + (real_in_var_node->Var()->GetType() == VarType::LOD_TENSOR); } for (auto* out_var_node : op_node->outputs) { @@ -362,23 +384,25 @@ void FloatToHalfPass::GetOpPrecision() const { auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()]; if (real_out_var_node->Var()->Persistable()) continue; - support_half = support_half && (real_out_var_node->Var()->GetType() == - VarType::LOD_TENSOR); + support_low_precision = + support_low_precision && + (real_out_var_node->Var()->GetType() == VarType::LOD_TENSOR); } } - if (support_half) { - op_run_half_.insert(op_type); - VLOG(4) << "support precision: " << op_type << " run at half"; + if (support_low_precision) { + op_run_low_precision_.insert(op_type); + VLOG(4) << "support precision: " << op_type << " run at low precision"; } else { - VLOG(4) << "support precision: " << op_type << " not run at half"; + VLOG(4) << "support precision: " << op_type + << " not run at low precision"; } } } } -void FloatToHalfPass::UpdateOpPrecision() const { - std::unordered_set vars_should_not_half; +void AutoMixedPrecisionPass::UpdateOpPrecision() const { + std::unordered_set vars_should_not_low_precision; // var -> the var's all input op std::unordered_map> var_input_ops; @@ -401,30 +425,16 @@ void FloatToHalfPass::UpdateOpPrecision() const { << " is output of " << op_type; } - // the select_input op's input var should not convert to half. when - // op's output var is select_input op's input var, the op should not run - // half. + // the select_input op's input var should not convert to low precision. + // when op's output var is select_input op's input var, the op should + // not run at low precision. if (GetOpOriginalType(op_node->Op()->Type()) == "select_input") { for (auto* in_var_node : op_node->inputs) { CHECK_EQ(in_var_node->IsVar(), true); if (in_var_node->Var()->Persistable()) continue; if (!VarNodeHasDtype(in_var_node)) continue; - vars_should_not_half.insert(in_var_node->Var()->Name()); - } - } - - // when op_1 only support cpu kernel. if op_2's intput var is op_1's - // output var, then op_2 should not run half. - if (GetOpOriginalType(op_type) != "feed" && - !GpuKernelSupportPrecision(GetOpOriginalType(op_type), - phi::DataType::FLOAT32)) { - for (auto* out_var_node : op_node->outputs) { - CHECK_EQ(out_var_node->IsVar(), true); - if (out_var_node->Var()->Persistable()) continue; - if (!VarNodeHasDtype(out_var_node)) continue; - - vars_should_not_half.insert(out_var_node->Var()->Name()); + vars_should_not_low_precision.insert(in_var_node->Var()->Name()); } } } @@ -437,25 +447,7 @@ void FloatToHalfPass::UpdateOpPrecision() const { precision_updated = false; for (const auto& nodes : all_op_nodes_) { for (auto* op_node : nodes) { - if (op_run_half_.count(op_node->Op()->Type()) == 0) continue; - - for (auto* in_var_node : op_node->inputs) { - CHECK_EQ(in_var_node->IsVar(), true); - if (!VarNodeHasDtype(in_var_node)) continue; - - auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()]; - if (real_in_var_node->Var()->Persistable()) continue; - - if (vars_should_not_half.count(real_in_var_node->Var()->Name())) { - op_run_half_.erase(op_node->Op()->Type()); - precision_updated = true; - VLOG(4) << op_node->Op()->Type() - << " should not support half precision."; - break; - } - } - - if (op_run_half_.count(op_node->Op()->Type()) == 0) continue; + if (op_run_low_precision_.count(op_node->Op()->Type()) == 0) continue; for (auto* out_var_node : op_node->outputs) { CHECK_EQ(out_var_node->IsVar(), true); @@ -464,24 +456,25 @@ void FloatToHalfPass::UpdateOpPrecision() const { auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()]; if (real_out_var_node->Var()->Persistable()) continue; - bool not_run_half = false; + bool not_run_low_precision = false; const auto& input_op_nodes = var_input_ops[real_out_var_node->Var()->Name()]; - if (vars_should_not_half.count(real_out_var_node->Var()->Name())) { - not_run_half = true; + if (vars_should_not_low_precision.count( + real_out_var_node->Var()->Name())) { + not_run_low_precision = true; } else { for (auto* node : input_op_nodes) { - if (op_run_half_.count(node->Op()->Type()) == 0) { - not_run_half = true; + if (op_run_low_precision_.count(node->Op()->Type()) == 0) { + not_run_low_precision = true; break; } } } - if (not_run_half) { - op_run_half_.erase(op_node->Op()->Type()); + if (not_run_low_precision) { + op_run_low_precision_.erase(op_node->Op()->Type()); precision_updated = true; VLOG(4) << op_node->Op()->Type() - << " should not support half precision."; + << " should not run at low precision."; break; } } @@ -491,8 +484,8 @@ void FloatToHalfPass::UpdateOpPrecision() const { } // special ops, its weights should not be low precision. -bool FloatToHalfPass::InputVarsNotConvert(Node* op_node, - const std::string& var_name) const { +bool AutoMixedPrecisionPass::InputVarsNotConvert( + Node* op_node, const std::string& var_name) const { auto* op_desc = op_node->Op(); if (GetOpOriginalType(op_desc->Type()) == "batch_norm") { auto vecs = op_desc->Input("Bias"); @@ -532,8 +525,8 @@ bool FloatToHalfPass::InputVarsNotConvert(Node* op_node, return false; } -bool FloatToHalfPass::OutputVarsNotConvert(Node* op_node, - const std::string& var_name) const { +bool AutoMixedPrecisionPass::OutputVarsNotConvert( + Node* op_node, const std::string& var_name) const { auto* op_desc = op_node->Op(); // batch_norm's input and output (variance and mean) are the same. if (GetOpOriginalType(op_desc->Type()) == "batch_norm") { @@ -557,10 +550,14 @@ bool FloatToHalfPass::OutputVarsNotConvert(Node* op_node, return false; } -void FloatToHalfPass::SetVarPrecision() const { +void AutoMixedPrecisionPass::SetVarPrecision() const { for (const auto& nodes : all_op_nodes_) { for (auto* op_node : nodes) { - if (op_run_half_.count(op_node->Op()->Type())) { + if (op_run_low_precision_.count(op_node->Op()->Type()) == 0) { + continue; + } + + if (GetOpOriginalType(op_node->Op()->Type()) != "feed") { for (auto* in_var_node : op_node->inputs) { CHECK_EQ(in_var_node->IsVar(), true); @@ -573,11 +570,13 @@ void FloatToHalfPass::SetVarPrecision() const { if (real_in_var_node->Var()->Persistable()) { real_in_var_node->Var()->SetDataType( - framework::TransToProtoVarType(half_precision_)); - vars_convert_to_half_.insert(in_var_name); + framework::TransToProtoVarType(low_precision_)); + vars_convert_to_low_precision_.insert(in_var_name); } } + } + if (GetOpOriginalType(op_node->Op()->Type()) != "fetch") { for (auto* out_var_node : op_node->outputs) { CHECK_EQ(out_var_node->IsVar(), true); @@ -589,9 +588,9 @@ void FloatToHalfPass::SetVarPrecision() const { if (OutputVarsNotConvert(op_node, out_var_name)) continue; real_out_var_node->Var()->SetDataType( - framework::TransToProtoVarType(half_precision_)); + framework::TransToProtoVarType(low_precision_)); if (real_out_var_node->Var()->Persistable()) { - vars_convert_to_half_.insert(out_var_name); + vars_convert_to_low_precision_.insert(out_var_name); } } } @@ -606,24 +605,24 @@ void FloatToHalfPass::SetVarPrecision() const { if (!VarNodeHasDtype(var_node)) continue; auto var_name = var_node->Var()->Name(); - if (vars_convert_to_half_.count(var_name)) { + if (vars_convert_to_low_precision_.count(var_name)) { var_node->Var()->SetDataType( - framework::TransToProtoVarType(half_precision_)); + framework::TransToProtoVarType(low_precision_)); } } } } -void FloatToHalfPass::ConvertWeightsData() const { +void AutoMixedPrecisionPass::ConvertWeightsData() const { auto* scope = param_scope(); - PADDLE_ENFORCE_NOT_NULL( - scope, - platform::errors::PreconditionNotMet( - "During the float to half pass, the scope should not be null.")); + PADDLE_ENFORCE_NOT_NULL(scope, + platform::errors::PreconditionNotMet( + "During the auto_mixed_precision_pass, the scope " + "should not be null.")); auto var_names = scope->LocalVarNames(); for (const auto& var_name : var_names) { - if (vars_convert_to_half_.count(var_name)) { + if (vars_convert_to_low_precision_.count(var_name)) { VLOG(4) << var_name << "'s data type was convert to half"; auto* var = scope->FindLocalVar(var_name); @@ -631,25 +630,29 @@ void FloatToHalfPass::ConvertWeightsData() const { auto* origin_tensor = var->GetMutable(); - phi::DenseTensor half_tensor; - half_tensor.Resize(origin_tensor->dims()); - half_tensor.set_type(half_precision_); + phi::DenseTensor low_precision_tensor; + low_precision_tensor.Resize(origin_tensor->dims()); + low_precision_tensor.set_type(low_precision_); - if (half_precision_ == phi::DataType::FLOAT16) { - auto* half_data = - half_tensor.mutable_data(phi::CPUPlace{}); + if (low_precision_ == phi::DataType::FLOAT16) { + auto* low_precision_data = + low_precision_tensor.mutable_data( + phi::CPUPlace{}); for (int64_t i = 0; i < origin_tensor->numel(); i++) { if (origin_tensor->dtype() == phi::DataType::FLOAT64) { auto* origin_data = origin_tensor->data(); - half_data[i] = static_cast(origin_data[i]); + low_precision_data[i] = + static_cast(origin_data[i]); } else if (origin_tensor->dtype() == phi::DataType::FLOAT32) { auto* origin_data = origin_tensor->data(); - half_data[i] = static_cast(origin_data[i]); + low_precision_data[i] = + static_cast(origin_data[i]); } } - } else if (half_precision_ == phi::DataType::BFLOAT16) { + } else if (low_precision_ == phi::DataType::BFLOAT16) { auto* half_data = - half_tensor.mutable_data(phi::CPUPlace{}); + low_precision_tensor.mutable_data( + phi::CPUPlace{}); for (int64_t i = 0; i < origin_tensor->numel(); i++) { if (origin_tensor->dtype() == phi::DataType::FLOAT64) { auto* origin_data = origin_tensor->data(); @@ -662,12 +665,12 @@ void FloatToHalfPass::ConvertWeightsData() const { } origin_tensor->clear(); paddle::framework::TensorCopySync( - half_tensor, phi::CPUPlace{}, origin_tensor); + low_precision_tensor, phi::CPUPlace{}, origin_tensor); } } } -void FloatToHalfPass::InsertCastOp() const { +void AutoMixedPrecisionPass::InsertCastOp() const { int suffix = 0; std::unordered_map cache; @@ -681,7 +684,7 @@ void FloatToHalfPass::InsertCastOp() const { if (op_node->Op()->HasAttr("sub_block")) continue; VLOG(4) << "process op: " << op_type - << " run half: " << op_run_half_.count(op_type); + << " run low precision: " << op_run_low_precision_.count(op_type); auto inputs = op_node->inputs; for (auto* in_var_node : inputs) { @@ -696,17 +699,17 @@ void FloatToHalfPass::InsertCastOp() const { VLOG(4) << "process var: " << real_in_var_node->Var()->Name() << " with type " << in_var_type; - if (IsFloatType(in_var_type) && op_run_half_.count(op_type)) { + if (IsFloatType(in_var_type) && op_run_low_precision_.count(op_type)) { DoInsertCastOp(subgraphes_[i], in_var_node, op_node, in_var_type, - framework::TransToProtoVarType(half_precision_), + framework::TransToProtoVarType(low_precision_), block_desc, &suffix, &cache); } else if (IsHalfType(in_var_type) && - op_run_half_.count(op_type) == 0) { + op_run_low_precision_.count(op_type) == 0) { DoInsertCastOp(subgraphes_[i], in_var_node, op_node, @@ -738,4 +741,5 @@ void FloatToHalfPass::InsertCastOp() const { } // namespace framework } // namespace paddle -REGISTER_PASS(float_to_half_pass, paddle::framework::ir::FloatToHalfPass); +REGISTER_PASS(auto_mixed_precision_pass, + paddle::framework::ir::AutoMixedPrecisionPass); diff --git a/paddle/fluid/framework/ir/float_to_half_pass.h b/paddle/fluid/framework/ir/auto_mixed_precision_pass.h similarity index 66% rename from paddle/fluid/framework/ir/float_to_half_pass.h rename to paddle/fluid/framework/ir/auto_mixed_precision_pass.h index 1af59f5fbc30d..578d47282b76d 100644 --- a/paddle/fluid/framework/ir/float_to_half_pass.h +++ b/paddle/fluid/framework/ir/auto_mixed_precision_pass.h @@ -27,13 +27,13 @@ namespace paddle { namespace framework { namespace ir { -class FloatToHalfPass : public FusePassBase { +class AutoMixedPrecisionPass : public FusePassBase { public: using VarType = framework::proto::VarType; public: - FloatToHalfPass() = default; - ~FloatToHalfPass() = default; + AutoMixedPrecisionPass() = default; + ~AutoMixedPrecisionPass() = default; protected: void ApplyImpl(Graph* graph) const override; @@ -43,10 +43,6 @@ class FloatToHalfPass : public FusePassBase { void SetDefaultBlacklist() const; - bool OpSupportPrecision(const std::string& op_type, - phi::DataType precision, - phi::Backend backend = phi::Backend::GPU) const; - void SetOpUniqueType() const; void RestoreOpOriginType() const; @@ -70,9 +66,13 @@ class FloatToHalfPass : public FusePassBase { void ConvertWeightsData() const; private: - mutable bool keep_io_types_; + mutable bool skip_pass_{false}; + + mutable bool keep_io_types_{false}; // float16 or bfloat16 now - mutable phi::DataType half_precision_; + mutable phi::DataType low_precision_{phi::DataType::FLOAT16}; + + mutable phi::Backend backend_{phi::Backend::GPU}; mutable std::unordered_set black_list_; @@ -84,12 +84,26 @@ class FloatToHalfPass : public FusePassBase { mutable std::vector> all_op_nodes_; // op's unique type -> the op's origin type mutable std::unordered_map op_original_type_; - // op's unique type -> whether the op run at half precision - mutable std::unordered_set op_run_half_; + // op's unique type -> whether the op run at low precision + mutable std::unordered_set op_run_low_precision_; - mutable std::unordered_set vars_convert_to_half_; + mutable std::unordered_set vars_convert_to_low_precision_; }; +bool OpSupportPrecision(const std::string& op_type, + phi::Backend backend, + phi::DataType precision, + const std::unordered_set& black_list); + +void DoInsertCastOp(Graph* graph, + Node* var_node, + Node* op_node, + proto::VarType::Type from_type, + proto::VarType::Type to_type, + framework::BlockDesc* block_desc, + int* suffix, + std::unordered_map* cache); + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/conv2d_fusion_layout_transfer_pass.cc b/paddle/fluid/framework/ir/conv2d_fusion_layout_transfer_pass.cc index dbba001d52101..efed7dd6e637b 100644 --- a/paddle/fluid/framework/ir/conv2d_fusion_layout_transfer_pass.cc +++ b/paddle/fluid/framework/ir/conv2d_fusion_layout_transfer_pass.cc @@ -142,7 +142,7 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const { bool is_fp16_precision = static_cast(Get("model_precision")) == phi::DataType::FLOAT16 || - Get("enable_gpu_half"); + Get("enable_gpu_mixed"); bool cutlass_enable = false; #ifdef PADDLE_WITH_CUTLASS diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc b/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc index 063eb90d90af1..2f527ff1e707b 100644 --- a/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc @@ -165,7 +165,7 @@ void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const { bool is_fp16_precision = static_cast(Get("model_precision")) == phi::DataType::FLOAT16 || - Get("enable_gpu_half"); + Get("enable_gpu_mixed"); constexpr int CUTLASS_NHWC_ALIGNMENT = 8; if (is_fp16_precision) { #ifdef PADDLE_WITH_CUTLASS diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h index c386bdcb2e45c..002eb29b776ea 100644 --- a/paddle/fluid/inference/analysis/argument.h +++ b/paddle/fluid/inference/analysis/argument.h @@ -365,7 +365,7 @@ struct Argument { DECL_ARGUMENT_FIELD(mixed_black_list, MixedBlackList, std::unordered_set); - DECL_ARGUMENT_FIELD(enable_gpu_half, EnableGPUHalf, bool); + DECL_ARGUMENT_FIELD(enable_gpu_mixed, EnableGPUMixed, bool); DECL_ARGUMENT_FIELD(mixed_precision_mode, MixedPrecisionMode, int); // cinn compiler related diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc index f84ed64e7009a..734c8a60fb86b 100644 --- a/paddle/fluid/inference/analysis/ir_pass_manager.cc +++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc @@ -45,8 +45,10 @@ IRPassManager::IRPassManager(Argument *argument) { void IRPassManager::CreatePasses(Argument *argument, const std::vector &passes) { + // For graph_viz_pass std::string pre_pass; int pass_num = 0; + for (const std::string &pass_name : passes) { auto pass = framework::ir::PassRegistry::Instance().Get(pass_name); pass->Set("use_varseqlen", new bool(argument->tensorrt_use_varseqlen())); @@ -87,14 +89,14 @@ void IRPassManager::CreatePasses(Argument *argument, argument->tensorrt_tuned_dynamic_shape(); pass->Set("with_dynamic_shape", new bool(with_dynamic_shape)); - // mixed precision related - pass->Set("model_precision", new int(argument->model_precision())); + // Mixed precision related. pass->Set( "mixed_black_list", new std::unordered_set(argument->mixed_black_list())); - pass->Set("enable_gpu_half", new bool(argument->enable_gpu_half())); + pass->Set("enable_gpu_mixed", new bool(argument->enable_gpu_mixed())); pass->Set("mixed_precision_mode", new int(argument->mixed_precision_mode())); + pass->Set("model_precision", new int(argument->model_precision())); if (pass_name == "graph_viz_pass") { std::string optim_cache_dir = argument->optim_cache_dir(); @@ -210,6 +212,7 @@ void IRPassManager::CreatePasses(Argument *argument, new std::vector(argument->tensorrt_disabled_ops())); pass->Set("trt_use_dla", new bool(argument->tensorrt_use_dla())); pass->Set("trt_dla_core", new int(argument->tensorrt_dla_core())); + // Setting the disable_trt_plugin_fp16 to true means that TRT plugin will // not run fp16. pass->Set("disable_trt_plugin_fp16", @@ -238,8 +241,7 @@ void IRPassManager::CreatePasses(Argument *argument, pass->Set("root_predictor_id", new int(argument->root_predictor_id())); } else if (pass_name == "build_cinn_pass") { pass->Set("is_inference_stage", new bool(argument->use_cinn_compiler())); - } - if (pass_name == "lite_subgraph_pass") { + } else if (pass_name == "lite_subgraph_pass") { bool lite_enable_int8 = argument->lite_precision_mode() == AnalysisConfig::Precision::kInt8; pass->Set("program", @@ -287,8 +289,7 @@ void IRPassManager::CreatePasses(Argument *argument, pass->Set("nnadapter_model_cache_token", new std::vector( argument->nnadapter_model_cache_token())); - } - if (pass_name == "fc_fuse_pass") { + } else if (pass_name == "fc_fuse_pass") { pass->Set("use_gpu", new bool(argument->use_gpu())); bool fc_mkldnn_pass = 0; for (const std::string &pass_n : passes) { diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc index 4d95551ec1114..f765d9c22bbd5 100644 --- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc +++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc @@ -83,14 +83,14 @@ void OutputProcess(framework::ir::Graph *graph, backend, precision, blacklist)) { - AddCastOp(graph, - var_node, - next_op, - framework::proto::VarType::FP32, - to_type, - &suffix, - block_desc, - &var_to_cast_op_map); + InsertCastOp(graph, + var_node, + next_op, + framework::proto::VarType::FP32, + to_type, + block_desc, + &suffix, + &var_to_cast_op_map); var_node->Var()->SetDataType(framework::proto::VarType::FP32); } } diff --git a/paddle/fluid/inference/analysis/passes/CMakeLists.txt b/paddle/fluid/inference/analysis/passes/CMakeLists.txt index fa074f962eb3d..96121601cb6fd 100644 --- a/paddle/fluid/inference/analysis/passes/CMakeLists.txt +++ b/paddle/fluid/inference/analysis/passes/CMakeLists.txt @@ -13,7 +13,7 @@ cc_library( cc_library( convert_to_mixed_precision SRCS convert_to_mixed_precision.cc - DEPS analysis_pass ir_graph_build_pass) + DEPS analysis_pass ir_graph_build_pass auto_mixed_precision_pass) cc_library( ir_params_sync_among_devices_pass SRCS ir_params_sync_among_devices_pass.cc diff --git a/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc b/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc index afc1d8a882ca6..f1939fc8b328b 100644 --- a/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc +++ b/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc @@ -14,662 +14,72 @@ #include "paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h" -#include -#include -#include -#include -#include -#include -#include - -#include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/executor.h" -#include "paddle/fluid/framework/framework.pb.h" -#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/auto_mixed_precision_pass.h" #include "paddle/fluid/framework/ir/graph_helper.h" -#include "paddle/fluid/framework/ir/graph_pattern_detector.h" -#include "paddle/fluid/framework/ir/node.h" -#include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/framework/scope.h" -#include "paddle/fluid/framework/var_desc.h" #include "paddle/fluid/inference/io.h" -#include "paddle/phi/common/bfloat16.h" -#include "paddle/phi/common/data_type.h" -#include "paddle/phi/common/float16.h" -#include "paddle/phi/common/layout.h" -#include "paddle/phi/common/place.h" +#include "paddle/phi/common/backend.h" namespace paddle { namespace inference { namespace analysis { -namespace { -using VarType = framework::proto::VarType; - -bool PhiKernelSupportPrecision( - const std::string& op_type, +ConvertToMixedPrecisionPass::ConvertToMixedPrecisionPass( + const std::string& model_file, + const std::string& params_file, + const std::string& mixed_model_file, + const std::string& mixed_params_file, + phi::DataType mixed_precision, phi::Backend backend, - phi::DataType data_type, - phi::DataLayout layout = phi::DataLayout::ALL_LAYOUT) { - auto kernels = phi::KernelFactory::Instance().kernels(); - if (kernels.find(op_type) == kernels.end()) { - return false; - } - phi::KernelKey kernel_key(backend, layout, data_type); - return phi::KernelFactory::Instance().HasKernel(op_type, kernel_key); -} - -bool GpuKernelSupportPrecision( - const std::string& op_type, - phi::DataType data_type, - phi::DataLayout layout = phi::DataLayout::ALL_LAYOUT) { - auto phi_op_type = phi::TransToPhiKernelName(op_type); - bool res = PhiKernelSupportPrecision( - phi_op_type, phi::Backend::GPU, data_type, layout); - res |= PhiKernelSupportPrecision( - phi_op_type, phi::Backend::GPUDNN, data_type, layout); - - if (!res) { - auto& all_kernels = framework::OperatorWithKernel::AllOpKernels(); - auto it = all_kernels.find(op_type); - if (it != all_kernels.end()) { - for (auto& kern_pair : it->second) { - if (platform::is_gpu_place(kern_pair.first.place_) && - kern_pair.first.data_type_ == VarType::FP16) { - res = true; - break; - } - } - } - } - return res; -} - -class ConvertToMixedPrecisionPass { - using BlockID = size_t; - - public: - explicit ConvertToMixedPrecisionPass( - const std::string& model_file, - const std::string& params_file, - const std::string& mixed_model_file, - const std::string& mixed_params_file, - phi::DataType mixed_precision, - phi::Backend backend, - bool keep_io_types, - const std::unordered_set& black_list) - : model_file_(model_file), - params_file_(params_file), - mixed_model_file_(mixed_model_file), - mixed_params_file_(mixed_params_file), - mixed_precision_(mixed_precision), - backend_(backend), - keep_io_types_(keep_io_types), - black_list_(black_list), - place_(paddle::CPUPlace()), - executor_(place_) { - VLOG(4) << "black_list has "; - for (auto& name : black_list_) { - VLOG(4) << " - " << name; - } - } - - void Run(); - - private: - void LoadAndPrepare(); - inline bool VarNodeHasDtype(framework::ir::Node* node); - void ConvertAllFp64ToFp32(framework::ir::Graph* graph); - void FixCastAttr(framework::ir::Graph* graph); - void SaveMixedModel(); - void ConvertTensorDtype(BlockID block_idx); - void ProcessInputNode(bool support_precision, - framework::ir::Node* in_node, - framework::ir::Node* op_node, - int* suffix, - framework::BlockDesc* block_desc, - VarType::Type to_type, - BlockID block_idx); - - void ProcessOutputNode(BlockID block_idx, - framework::ir::Node* var_node, - VarType::Type to_type); - inline bool IsFloatVarType(VarType::Type type); - - bool OutShouldNotConvert(framework::ir::Node* var_node); - // Just process special cases for weights conversion. - bool WeightsShouldNotConvert(framework::ir::Node* var_node); - - // Return Node* which first appers in block. - framework::ir::Node* GetRealVarNode(framework::ir::Node* node); - - // Fallback to fp32 dtype when encounter circle (Not a DAG graph). - void ProcessCircleCases(); - - private: - std::string model_file_; - std::string params_file_; - std::string mixed_model_file_; - std::string mixed_params_file_; - phi::DataType mixed_precision_; - phi::Backend backend_; - bool keep_io_types_; - std::unordered_set black_list_; - paddle::CPUPlace place_; - framework::Executor executor_; - framework::Scope scope_; - - std::unordered_map name2node_; - std::unordered_map cast_map_; - int suffix_{0}; - - std::set var_names_in_circles_; - - std::unique_ptr program_desc_{nullptr}; - std::unique_ptr main_graph_{nullptr}; - std::vector graphes_; -}; - -framework::ir::Node* ConvertToMixedPrecisionPass::GetRealVarNode( - framework::ir::Node* var_node) { - CHECK_EQ(var_node->IsVar(), true); - if (name2node_.count(var_node->Name())) return name2node_[var_node->Name()]; - return var_node; -} - -inline bool ConvertToMixedPrecisionPass::VarNodeHasDtype( - framework::ir::Node* var_node) { - CHECK_EQ(var_node->IsVar(), true); - auto type = var_node->Var()->GetType(); - return (type == VarType::SELECTED_ROWS) || (type == VarType::LOD_TENSOR) || - (type == VarType::LOD_TENSOR_ARRAY) || (type == VarType::STRINGS) || - (type == VarType::VOCAB); -} - -void ConvertToMixedPrecisionPass::ProcessInputNode( - bool support_precision, - framework::ir::Node* in_node, - framework::ir::Node* op_node, - int* suffix, - framework::BlockDesc* block_desc, - VarType::Type to_type, - BlockID block_idx) { - if (!in_node->IsVar()) return; - auto* real_node = GetRealVarNode(in_node); - if (!VarNodeHasDtype(real_node)) return; - auto* graph = graphes_[block_idx]; - auto* in_var = real_node->Var(); - auto in_var_type = in_var->GetDataType(); - auto prev_type = in_var_type; - - if (support_precision) { - if (in_var->Persistable() && in_var_type == VarType::FP32) { - if (WeightsShouldNotConvert(in_node)) return; - in_var->SetDataType(to_type); - in_var_type = to_type; - VLOG(3) << " in_node name " << in_var->Name() << " from " << prev_type - << " to " << to_type; - } else if (!in_var->Persistable() && IsFloatVarType(in_var_type) && - in_var_type != to_type) { - AddCastOp(graph, - in_node, - op_node, - in_var_type, - to_type, - suffix, - block_desc, - &cast_map_); - VLOG(3) << " in_node name " << in_var->Name() << "(" << prev_type - << ") to " << cast_map_[in_node]->Name() << "(" << to_type << ")"; - } - } else { - if (!in_var->Persistable() && IsFloatVarType(in_var_type) && - in_var_type != to_type) { - AddCastOp(graph, - in_node, - op_node, - in_var_type, - to_type, - suffix, - block_desc, - &cast_map_); - VLOG(3) << " in_node name " << in_var->Name() << "(" << prev_type - << ") to " << cast_map_[in_node]->Name() << "(" << to_type << ")"; - } - } -} - -void ConvertToMixedPrecisionPass::ProcessOutputNode( - BlockID block_idx, framework::ir::Node* var_node, VarType::Type to_type) { - if (!var_node->IsVar()) return; - auto* real_node = GetRealVarNode(var_node); - if (!VarNodeHasDtype(real_node)) return; - auto* out_var = real_node->Var(); - auto prev_type = out_var->GetDataType(); - if (out_var->GetDataType() == VarType::FP32) { - if (OutShouldNotConvert(var_node)) return; - out_var->SetDataType(to_type); - } - VLOG(3) << " out_node name " << var_node->Name() << " from dtype " - << prev_type << " to " << out_var->GetDataType(); -} - -// Just process special cases. -bool ConvertToMixedPrecisionPass::OutShouldNotConvert( - framework::ir::Node* var_node) { - auto op_node = var_node->inputs[0]; - auto* op_desc = op_node->Op(); - - // batch_norm's input and output (variance and mean) are the same. - if (op_desc->Type() == "batch_norm") { - auto vecs = op_desc->Output("MeanOut"); - if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) { - return true; - } - vecs = op_desc->Output("VarianceOut"); - if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) { - return true; - } - vecs = op_desc->Output("SavedMean"); - if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) { - return true; - } - vecs = op_desc->Output("SavedVariance"); - if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) { - return true; - } + bool keep_io_types, + const std::unordered_set& black_list) + : model_file_(model_file), + params_file_(params_file), + mixed_model_file_(mixed_model_file), + mixed_params_file_(mixed_params_file), + mixed_precision_(mixed_precision), + backend_(backend), + keep_io_types_(keep_io_types), + black_list_(black_list) { + if (mixed_precision_ != phi::DataType::FLOAT16 && + mixed_precision_ != phi::DataType::BFLOAT16) { + PADDLE_THROW(paddle::platform::errors::InvalidArgument( + "mixed_precision currently not supported dtype %d, we now only " + "support fp16 and bf16.", + static_cast(mixed_precision_))); } - - return false; -} - -bool ConvertToMixedPrecisionPass::WeightsShouldNotConvert( - framework::ir::Node* var_node) { - auto op_nodes = var_node->outputs; - for (auto* op_node : op_nodes) { - auto* op_desc = op_node->Op(); - // batch_norm op's bias, mean, scale and variance just be float32, so we can - // not convert the dtype. - if (op_desc->Type() == "batch_norm") { - auto vecs = op_desc->Input("Bias"); - if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) { - return true; - } - vecs = op_desc->Input("Mean"); - if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) { - return true; - } - vecs = op_desc->Input("Scale"); - if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) { - return true; - } - vecs = op_desc->Input("Variance"); - if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) { - return true; - } - } else if (op_desc->Type() == "fused_multi_transformer") { - auto vecs = op_desc->Input("LnScale"); - if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) { - return true; - } - - vecs = op_desc->Input("LnBias"); - if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) { - return true; - } - - vecs = op_desc->Input("FFNLnScale"); - if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) { - return true; - } - - vecs = op_desc->Input("FFNLnBias"); - if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) { - return true; - } - } + if (backend_ != phi::Backend::GPU) { + PADDLE_THROW(paddle::platform::errors::InvalidArgument( + "mixed_precision currently not supported place %d, we now only " + "support gpu.", + static_cast(backend_))); } - - return false; } -inline bool ConvertToMixedPrecisionPass::IsFloatVarType(VarType::Type type) { - return (type == VarType::FP16) || (type == VarType::FP32) || - (type == VarType::BF16); -} +void ConvertToMixedPrecisionPass::LoadModel() { + framework::Executor exe{platform::CPUPlace{}}; -void ConvertToMixedPrecisionPass::LoadAndPrepare() { - program_desc_ = - inference::Load(&executor_, &scope_, model_file_, params_file_); + auto program_desc = inference::Load(&exe, &scope_, model_file_, params_file_); main_graph_ = std::unique_ptr( - new framework::ir::Graph(*program_desc_)); - - for (size_t i = 0; i < main_graph_->SubGraphsSize(); ++i) { - auto* graph = main_graph_->GetSubGraph(i); - graphes_.push_back(graph); - - for (auto* node : graph->Nodes()) { - if (!node->IsVar()) continue; - if (!name2node_.count(node->Name())) { - name2node_[node->Name()] = node; - } - } - } - - ProcessCircleCases(); -} - -// Find var names which in circles. -void ConvertToMixedPrecisionPass::ProcessCircleCases() { - std::vector vars_in_circles; - for (size_t idx = 0; idx < program_desc_->Size(); ++idx) { - for (auto* op : program_desc_->Block(idx).AllOps()) { - // TODO(inference): batch_norm has circle, but we need to fuse it in conv - // op. - if (op->Type() == "batch_norm") continue; - const auto& in_names = op->InputArgumentNames(); - const auto& out_names = op->OutputArgumentNames(); - std::set in_names_set(in_names.begin(), in_names.end()); - std::set out_names_set(out_names.begin(), out_names.end()); - std::set_intersection(in_names_set.begin(), - in_names_set.end(), - out_names_set.begin(), - out_names_set.end(), - std::back_inserter(vars_in_circles)); - } - } - - for (auto& name : vars_in_circles) { - var_names_in_circles_.insert(name); - } - for (auto& name : var_names_in_circles_) { - LOG(INFO) << name - << " in circles, so we will skip process those vars and ops."; - } -} - -inline void ProcessConstantOpAttr(framework::ir::Node* op_node, - VarType::Type from_type, - VarType::Type to_type) { - if (!op_node->IsOp()) return; - auto op_type = op_node->Op()->Type(); - if (op_type == "feed" || op_type == "fetch") return; - - if (op_type == "fill_constant") { - if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) == - static_cast(from_type)) - op_node->Op()->SetAttr("dtype", static_cast(to_type)); - } else if (op_type == "assign_value") { - if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) == - static_cast(from_type)) - op_node->Op()->SetAttr("dtype", static_cast(to_type)); - } else if (op_type == "eye") { - if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) == - static_cast(from_type)) - op_node->Op()->SetAttr("dtype", static_cast(to_type)); - } else if (op_type == "fill_any_like") { - if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) == - static_cast(from_type)) - op_node->Op()->SetAttr("dtype", static_cast(to_type)); - } else if (op_type == "cast") { - if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("in_dtype")) == - static_cast(from_type)) - op_node->Op()->SetAttr("in_dtype", static_cast(to_type)); - if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("out_dtype")) == - static_cast(from_type)) - op_node->Op()->SetAttr("out_dtype", static_cast(to_type)); - } -} - -void ConvertToMixedPrecisionPass::ConvertAllFp64ToFp32( - framework::ir::Graph* graph) { - auto op_nodes = framework::ir::TopologySortOperations(*graph); - for (auto* op_node : op_nodes) { - if (!op_node->IsOp()) continue; - auto op_type = op_node->Op()->Type(); - ProcessConstantOpAttr(op_node, VarType::FP64, VarType::FP32); - auto inputs = op_node->inputs; - for (auto* in_node : inputs) { - auto* in_var = in_node->Var(); - if (!in_var->Persistable() && in_var->GetDataType() == VarType::FP64) { - in_var->SetDataType(VarType::FP32); - } - } - } + new framework::ir::Graph(*program_desc)); + main_graph_->SetNotOwned(framework::ir::kParamScopeAttr, &scope_); } void ConvertToMixedPrecisionPass::Run() { - LoadAndPrepare(); + LoadModel(); - for (size_t i = 0; i < graphes_.size(); ++i) { - auto* graph = graphes_[i]; - VLOG(2) << " -------- handle subgraph " << i << ", has " - << graph->Nodes().size() << " nodes --------"; + framework::ir::AutoMixedPrecisionPass pass; + pass.Set("mixed_precision_mode", new int{static_cast(mixed_precision_)}); + pass.Set("mixed_black_list", + new std::unordered_set{black_list_}); + pass.Set("enable_gpu_mixed", new bool{true}); + pass.Set("keep_io_types", new bool{keep_io_types_}); - ConvertAllFp64ToFp32(graph); - ConvertTensorDtype(i); - FixCastAttr(graph); - - CHECK_EQ(framework::ir::VarDescIsConsistency(*graph), true); - } + pass.Apply(main_graph_.get()); SaveMixedModel(); } -void ConvertToMixedPrecisionPass::ConvertTensorDtype(BlockID block_idx) { - auto* graph = graphes_[block_idx]; - VarType::Type to_type; - if (mixed_precision_ == phi::DataType::FLOAT16) { - to_type = VarType::FP16; - } else if (mixed_precision_ == phi::DataType::BFLOAT16) { - to_type = VarType::BF16; - } else { - PADDLE_THROW(paddle::platform::errors::InvalidArgument( - "mixed_precision currently not supported dtype %d, we now only " - "support fp16 and bf16.", - static_cast(mixed_precision_))); - } - - auto op_nodes = framework::ir::TopologySortOperations(*graph); - auto* block_desc = op_nodes[0]->Op()->Block(); - int num_low_precision = 0; - std::vector output_nodes; - - for (auto* op_node : op_nodes) { - if (!op_node->IsOp()) continue; - auto op_type = op_node->Op()->Type(); - VLOG(3) << "-------------------- op_type " << op_type << ", phi_type " - << phi::TransToPhiKernelName(op_type); - // 1. set input dtype. - if (op_type == "feed") { - auto feed_var = op_node->outputs[0]->Var(); - if (!keep_io_types_ && feed_var->GetDataType() == VarType::FP32) { - feed_var->SetDataType(to_type); - } - } else if (op_type == "fetch") { - auto* fetch_var = op_node->inputs[0]; - output_nodes.push_back(fetch_var); - continue; - } else if (op_type == "cast") { - continue; - } - - // We can not add cast operator before ops who have sub_block, as in - // sub_block we may get a var which may be transformer by cast op. - else if (op_node->Op()->HasAttr("sub_block")) { // NOLINT - continue; - } - - // 2. if op support fp16/bf16 and not in blacklist. - // - cast weight to fp16/bf16. - // - add cast op if the input dtype is not fp16/bf16. - // - set output dtype. - else if (black_list_.count(op_type) == 0) { // NOLINT - bool support_precision = - OpSupportPrecision(op_type, backend_, mixed_precision_, black_list_); - - // If op's output in circle, we should not convert to fp16. - for (auto* out_node : op_node->outputs) { - if (var_names_in_circles_.count(out_node->Name())) { - support_precision = false; - VLOG(2) << " op's output " << out_node->Name() - << " is in circle, we can not support this case, just skip."; - break; - } - } - - // If the op has no input or output of float type, we will not choose the - // low precision kernel. - if (support_precision) { - bool has_float_in_out{false}; - for (auto* in_node : op_node->inputs) { - if (!in_node->IsVar()) continue; - if (in_node->Var()->GetType() != VarType::LOD_TENSOR) { - support_precision = false; - VLOG(2) << " op has tensor array input[" << in_node->Name() - << "], just skip."; - break; - } - auto* real_node = GetRealVarNode(in_node); - if (real_node->Var()->GetDataType() == VarType::FP16 || - real_node->Var()->GetDataType() == VarType::FP32 || - real_node->Var()->GetDataType() == VarType::FP64 || - real_node->Var()->GetDataType() == VarType::BF16) { - has_float_in_out = true; - break; - } - } - for (auto* out_node : op_node->outputs) { - if (!out_node->IsVar()) continue; - auto* real_node = GetRealVarNode(out_node); - if (real_node->Var()->GetDataType() == VarType::FP16 || - real_node->Var()->GetDataType() == VarType::FP32 || - real_node->Var()->GetDataType() == VarType::FP64 || - real_node->Var()->GetDataType() == VarType::BF16) { - has_float_in_out = true; - break; - } - } - - if (!has_float_in_out) { - support_precision = false; - VLOG(2) << " op doesn't has float input and output, just skip."; - } - } - - VLOG(2) << "op type: " << op_type - << " support low precision: " << support_precision; - - if (support_precision) { - ProcessConstantOpAttr(op_node, VarType::FP32, to_type); - VLOG(2) << " process input nodes:"; - ++num_low_precision; - auto inputs = op_node->inputs; - for (auto* in_node : inputs) { - ProcessInputNode( - true, in_node, op_node, &suffix_, block_desc, to_type, block_idx); - } - - VLOG(2) << " process output nodes:"; - auto outputs = op_node->outputs; - for (auto* out_node : outputs) { - ProcessOutputNode(block_idx, out_node, to_type); - } - } else { - auto inputs = op_node->inputs; - for (auto* in_node : inputs) { - ProcessInputNode(false, - in_node, - op_node, - &suffix_, - block_desc, - VarType::FP32, - block_idx); - } - } - } - - // 3. check op not support fp16/bf16 or in blacklist. - // - add cast op if the input dtype is not fp32. - else { // NOLINT - VLOG(3) << "not to run fp16 op_type: " << op_type << ", node input size " - << op_node->inputs.size(); - auto in_nodes = op_node->inputs; - for (auto* in_node : in_nodes) { - auto* in_var = in_node->Var(); - if (in_var->GetDataType() == to_type) { - AddCastOp(graph, - in_node, - op_node, - to_type, - VarType::FP32, - &suffix_, - block_desc, - &cast_map_); - VLOG(3) << "-- " << in_node->Name() << "(" << to_type << ") to " - << cast_map_[in_node]->Name() << "(" << VarType::FP32 << ")"; - } - } - } - } - - // 4. if output_op's dtype is not compatible to output dtype, then just - // insert cast. - for (auto* node : output_nodes) { - framework::ir::Node* fetch_op{nullptr}; - for (auto* op_node : node->outputs) { - if (op_node->IsOp() && op_node->Op()->Type() == "fetch") { - fetch_op = op_node; - } - } - CHECK_NOTNULL(fetch_op); - auto* var = node->Var(); - if (keep_io_types_ && var->GetDataType() == to_type) { - // fp16/bf16 -> fp32. - AddCastOp(graph, - node, - fetch_op, - to_type, - VarType::FP32, - &suffix_, - block_desc, - &cast_map_); - } else if (!keep_io_types_ && var->GetDataType() == VarType::FP32) { - // fp32 -> fp16/bf16 - AddCastOp(graph, - node, - fetch_op, - VarType::FP32, - to_type, - &suffix_, - block_desc, - &cast_map_); - } - } - - if (num_low_precision) - LOG(INFO) << "--- detected " << num_low_precision - << " low precision ops in " << block_idx << " subgraph"; -} - -// We modify op's input output precision, and we need to fix cast op in_dtype -// and out_dtype attribute. -// TODO(inference): we need a cast elimination pass. -void ConvertToMixedPrecisionPass::FixCastAttr(framework::ir::Graph* graph) { - auto op_nodes = framework::ir::TopologySortOperations(*graph); - for (auto* op_node : op_nodes) { - if (!op_node->IsOp()) continue; - auto op_type = op_node->Op()->Type(); - if (op_type != "cast") continue; - auto input = op_node->inputs[0]; - auto output = op_node->outputs[0]; - op_node->Op()->SetAttr("in_dtype", - static_cast(input->Var()->GetDataType())); - op_node->Op()->SetAttr("out_dtype", - static_cast(output->Var()->GetDataType())); - } -} - void ConvertToMixedPrecisionPass::SaveMixedModel() { framework::ProgramDesc mixed_program_desc; framework::ir::GraphToProgram(*main_graph_, &mixed_program_desc); @@ -677,51 +87,6 @@ void ConvertToMixedPrecisionPass::SaveMixedModel() { auto parameters = scope_.LocalVarNames(); std::sort(parameters.begin(), parameters.end()); - std::unordered_set weights_should_be_fp32; - for (auto* node : main_graph_->Nodes()) { - if (!node->IsVar()) continue; - if (VarNodeHasDtype(node)) { - if (node->Var()->Persistable() && - node->Var()->GetDataType() == VarType::FP32) { - VLOG(2) << "weights keep to fp32: " << node->Name() << ", ptr " - << reinterpret_cast(node->Var()); - weights_should_be_fp32.insert(node->Name()); - } - } - } - -#define CONVERT_TENSOR_DTYPE(DTYPE, dtype) \ - mixed_tensor.set_type(DTYPE); \ - auto* mixed_data = mixed_tensor.mutable_data(platform::CPUPlace()); \ - for (int64_t i = 0; i < origin_tensor->numel(); i++) { \ - mixed_data[i] = static_cast(origin_data[i]); \ - } \ - origin_tensor->clear(); \ - paddle::framework::TensorCopySync( \ - mixed_tensor, platform::CPUPlace(), origin_tensor) - - for (const auto& param_name : parameters) { - if (weights_should_be_fp32.count(param_name)) continue; - auto* var = scope_.FindLocalVar(param_name); - if (var->IsType()) { - auto* origin_tensor = var->GetMutable(); - if (origin_tensor->dtype() != phi::DataType::FLOAT32) continue; - phi::DenseTensor mixed_tensor; - mixed_tensor.Resize(origin_tensor->dims()); - auto* origin_data = - origin_tensor->mutable_data(platform::CPUPlace()); - if (mixed_precision_ == phi::DataType::FLOAT16) { - CONVERT_TENSOR_DTYPE(paddle::experimental::DataType::FLOAT16, - phi::dtype::float16); - } else if (mixed_precision_ == phi::DataType::BFLOAT16) { - CONVERT_TENSOR_DTYPE(paddle::experimental::DataType::BFLOAT16, - phi::dtype::bfloat16); - } - } - } - -#undef CONVERT_TENSOR_DTYPE - auto SerializeParams = [&]() -> std::string { std::ostringstream os; phi::CPUContext ctx; @@ -746,73 +111,32 @@ void ConvertToMixedPrecisionPass::SaveMixedModel() { mixed_program_desc.Proto()->SerializeAsString()); StrToBinary(mixed_params_file_, SerializeParams()); } -} // namespace - -void AddCastOp( - framework::ir::Graph* graph, - framework::ir::Node* node, - framework::ir::Node* next_op, - VarType::Type from_type, - VarType::Type to_type, - int* suffix, - framework::BlockDesc* block_desc, - std::unordered_map* map) { - auto update_cast_desc = [&](framework::OpDesc& desc, - const std::string& x_name, - const std::string& out_name, - const int in_dtype, - const int out_dtype) { - desc.SetType("cast"); - desc.SetInput("X", {x_name}); - desc.SetOutput("Out", {out_name}); - desc.SetAttr("in_dtype", in_dtype); - desc.SetAttr("out_dtype", out_dtype); - desc.SetAttr("use_mkldnn", false); - desc.SetAttr("with_quant_attr", false); - desc.Flush(); - }; - - if (map->count(node) == 0) { - // insert cast op before node. - std::string cast_input_name = node->Var()->Name(); - std::string cast_output_name = - node->Var()->Name() + "_cast.tmp_" + std::to_string((*suffix)++); - CHECK_NOTNULL(block_desc); - framework::OpDesc cast_op_desc(block_desc); - update_cast_desc(cast_op_desc, - cast_input_name, - cast_output_name, - static_cast(from_type), - static_cast(to_type)); - auto* cast_op_node = graph->CreateOpNode(&cast_op_desc); - auto* cast_output_vardesc = block_desc->Var(cast_output_name); - cast_output_vardesc->SetPersistable(false); - cast_output_vardesc->SetDataType(to_type); - cast_output_vardesc->SetShape(node->Var()->GetShape()); - auto* cast_output_node = graph->CreateVarNode(cast_output_vardesc); - IR_NODE_LINK_TO(cast_op_node, cast_output_node); - (*map)[node] = cast_output_node; - } - next_op->Op()->Rename(node->Name(), map->at(node)->Name()); - IR_NODE_LINK_TO(node, map->at(node)->inputs[0]); - IR_NODE_UNLINK(node, next_op); - IR_NODE_LINK_TO(map->at(node), next_op); -} bool OpSupportPrecision(const std::string& op_type, phi::Backend backend, phi::DataType precision, - const std::unordered_set& blacklist) { - auto phi_op_type = phi::TransToPhiKernelName(op_type); - bool support_precision = false; - if (blacklist.count(op_type) == 0) { - if (backend == phi::Backend::GPU) - support_precision = GpuKernelSupportPrecision(op_type, precision); - else - support_precision = - PhiKernelSupportPrecision(phi_op_type, backend, precision); - } - return support_precision; + const std::unordered_set& black_list) { + return framework::ir::OpSupportPrecision( + op_type, backend, precision, black_list); +} + +void InsertCastOp( + framework::ir::Graph* graph, + framework::ir::Node* var_node, + framework::ir::Node* op_node, + framework::proto::VarType::Type from_type, + framework::proto::VarType::Type to_type, + framework::BlockDesc* block_desc, + int* suffix, + std::unordered_map* visited) { + framework::ir::DoInsertCastOp(graph, + var_node, + op_node, + from_type, + to_type, + block_desc, + suffix, + visited); } void ConvertToMixedPrecision( diff --git a/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h b/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h index 583512408c586..3a1e5fbb30a21 100644 --- a/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h +++ b/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h @@ -15,14 +15,12 @@ #pragma once #include -#include #include #include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/ir/graph.h" -#include "paddle/fluid/framework/ir/graph_helper.h" -#include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/scope.h" #include "paddle/phi/common/backend.h" #include "paddle/phi/common/data_type.h" @@ -30,20 +28,52 @@ namespace paddle { namespace inference { namespace analysis { +class ConvertToMixedPrecisionPass { + public: + explicit ConvertToMixedPrecisionPass( + const std::string& model_file, + const std::string& params_file, + const std::string& mixed_model_file, + const std::string& mixed_params_file, + phi::DataType mixed_precision, + phi::Backend backend, + bool keep_io_types, + const std::unordered_set& black_list); + + void Run(); + + private: + void LoadModel(); + void SaveMixedModel(); + + private: + std::string model_file_; + std::string params_file_; + std::string mixed_model_file_; + std::string mixed_params_file_; + phi::DataType mixed_precision_; + phi::Backend backend_; + bool keep_io_types_; + std::unordered_set black_list_; + + framework::Scope scope_; + std::unique_ptr main_graph_{nullptr}; +}; + bool OpSupportPrecision(const std::string& op_type, phi::Backend backend, phi::DataType precision, - const std::unordered_set& blacklist); + const std::unordered_set& black_list); -void AddCastOp( +void InsertCastOp( framework::ir::Graph* graph, - framework::ir::Node* node, - framework::ir::Node* next_op, + framework::ir::Node* var_node, + framework::ir::Node* op_node, framework::proto::VarType::Type from_type, framework::proto::VarType::Type to_type, - int* suffix, framework::BlockDesc* block_desc, - std::unordered_map* map); + int* suffix, + std::unordered_map* visited); void ConvertToMixedPrecision(const std::string& model_file, const std::string& params_file, diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc index 1224caf88e668..58b0d2a1189ad 100644 --- a/paddle/fluid/inference/api/analysis_config.cc +++ b/paddle/fluid/inference/api/analysis_config.cc @@ -99,7 +99,7 @@ void AnalysisConfig::EnableUseGpu(uint64_t memory_pool_init_size_mb, // default } else if (precision_mode == Precision::kHalf || precision_mode == Precision::kBf16) { - enable_gpu_half_ = true; + enable_gpu_mixed_ = true; } else { LOG(ERROR) << "The Paddle-GPU inference currently only supports " @@ -396,7 +396,7 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) { // Mixed precision related. CP_MEMBER(mixed_black_list_); - CP_MEMBER(enable_gpu_half_); + CP_MEMBER(enable_gpu_mixed_); CP_MEMBER(mixed_precision_mode_); CP_MEMBER(enable_memory_optim_); @@ -1017,7 +1017,7 @@ std::string AnalysisConfig::SerializeInfoCache() { ss << params_file_; ss << use_gpu_; - ss << enable_gpu_half_; + ss << enable_gpu_mixed_; ss << use_external_stream_; ss << exec_stream_; ss << use_fc_padding_; @@ -1234,7 +1234,7 @@ std::string AnalysisConfig::Summary() { os.InsertRow({"use_gpu", use_gpu_ ? "true" : "false"}); if (use_gpu_) { os.InsertRow({"gpu_device_id", std::to_string(gpu_device_id_)}); - os.InsertRow({"enable_gpu_half_", std::to_string(enable_gpu_half_)}); + os.InsertRow({"enable_gpu_mixed_", std::to_string(enable_gpu_mixed_)}); os.InsertRow({"memory_pool_init_size", std::to_string(memory_pool_init_size_mb_) + "MB"}); os.InsertRow( diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index d3d1f62a20885..76de4ffd6ce60 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -1277,10 +1277,10 @@ void AnalysisPredictor::PrepareArgument() { if (!config_.ir_optim()) { argument_.SetEnableIrOptim(false); - if (config_.enable_gpu_half_) { + if (config_.enable_gpu_mixed_) { argument_.SetEnableIrOptim(true); pass_builder->ClearPasses(); - pass_builder->AppendPass("float_to_half_pass"); + pass_builder->AppendPass("auto_mixed_precision_pass"); LOG(INFO) << "This model run in Paddle-GPU mixed precision mode with no ir " "optimization."; @@ -1291,7 +1291,7 @@ void AnalysisPredictor::PrepareArgument() { if (config_.ir_debug_) { pass_builder->TurnOnDebug(); } - if (config_.enable_gpu_half_) { + if (config_.enable_gpu_mixed_) { LOG(INFO) << "This model run in Paddle-GPU mixed precision mode."; } } @@ -1303,7 +1303,7 @@ void AnalysisPredictor::PrepareArgument() { // mixed precison. argument_.SetModelPrecision(static_cast(model_precision_)); argument_.SetMixedBlackList(config_.mixed_black_list_); - argument_.SetEnableGPUHalf(config_.enable_gpu_half_); + argument_.SetEnableGPUMixed(config_.enable_gpu_mixed_); argument_.SetMixedPrecisionMode(static_cast( paddle::ConvertPrecision(config_.mixed_precision_mode_))); } diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h index b4c5a0d293574..41eea1fb98c31 100644 --- a/paddle/fluid/inference/api/paddle_analysis_config.h +++ b/paddle/fluid/inference/api/paddle_analysis_config.h @@ -1049,7 +1049,7 @@ struct PD_INFER_DECL AnalysisConfig { bool use_gpu_{false}; int gpu_device_id_{0}; uint64_t memory_pool_init_size_mb_{100}; // initial size is 100MB. - bool enable_gpu_half_{false}; + bool enable_gpu_mixed_{false}; bool thread_local_stream_{false}; bool use_cudnn_{false}; diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 7d325498f7a60..0f8da2894f413 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -245,7 +245,8 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) { "conv_elementwise_add_fuse_pass", // #endif // "transpose_flatten_concat_fuse_pass", // - "float_to_half_pass", // + "constant_folding_pass", // + "auto_mixed_precision_pass", // }); use_gpu_ = true;