[Inference] optimize some code and fix some bug #48780

Merged 2 commits on Dec 9, 2022
5 changes: 5 additions & 0 deletions paddle/fluid/framework/ir/delete_fill_constant_op_pass.cc
@@ -29,6 +29,11 @@ void FillConstData(phi::DenseTensor* out_t, T value) {
}

void DeleteFillConstantOpPass::ApplyImpl(ir::Graph* graph) const {
+  bool with_dynamic_shape = Get<bool>("with_dynamic_shape");
+  // Not support
+  if (with_dynamic_shape) {
+    return;
+  }
FusePassBase::Init("delete_fill_constant_op_pass", graph);
GraphPatternDetector detector;
auto fill_constant_op =
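The guard added above is the counterpart of a removal further down in this diff: ir_pass_manager.cc used to special-case delete_fill_constant_op_pass under TensorRT dynamic shape, and now the pass disables itself instead. A minimal standalone sketch of that self-guarding pattern (illustrative names only, not Paddle's pass API):

#include <iostream>

// Hypothetical stand-in for the pass attribute store.
struct PassConfig {
  bool with_dynamic_shape = false;
};

void DeleteFillConstantPassSketch(const PassConfig& cfg) {
  // Mirrors the early return added above: the rewrite is unsupported
  // under dynamic shape, so the pass becomes a no-op.
  if (cfg.with_dynamic_shape) {
    return;
  }
  std::cout << "folding fill_constant ops\n";
}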
64 changes: 40 additions & 24 deletions paddle/fluid/framework/ir/float_to_half_pass.cc
@@ -16,7 +16,12 @@

#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/errors.h"

namespace paddle {
namespace framework {
@@ -620,34 +625,45 @@ void FloatToHalfPass::ConvertWeightsData() const {
for (const auto& var_name : var_names) {
if (vars_convert_to_half_.count(var_name)) {
VLOG(4) << var_name << "'s data type was convert to half";
-#define CONVERT_TENSOR_DTYPE(DTYPE, dtype)                                 \
-  half_tensor.set_type(DTYPE);                                             \
-  auto* half_data = half_tensor.mutable_data<dtype>(platform::CPUPlace()); \
-  for (int64_t i = 0; i < origin_tensor->numel(); i++) {                   \
-    half_data[i] = static_cast<dtype>(origin_data[i]);                     \
-  }                                                                        \
-  origin_tensor->clear();                                                  \
-  paddle::framework::TensorCopySync(                                       \
-      half_tensor, platform::CPUPlace(), origin_tensor)

      auto* var = scope->FindLocalVar(var_name);

-      if (var->IsType<phi::DenseTensor>()) {
-        auto* origin_tensor = var->GetMutable<phi::DenseTensor>();
-        phi::DenseTensor half_tensor;
-        half_tensor.Resize(origin_tensor->dims());
-        auto* origin_data =
-            origin_tensor->mutable_data<float>(platform::CPUPlace());
-        if (half_precision_ == phi::DataType::FLOAT16) {
-          CONVERT_TENSOR_DTYPE(paddle::experimental::DataType::FLOAT16,
-                               phi::dtype::float16);
-        } else if (half_precision_ == phi::DataType::BFLOAT16) {
-          CONVERT_TENSOR_DTYPE(paddle::experimental::DataType::BFLOAT16,
-                               phi::dtype::bfloat16);
+      CHECK_EQ(var->IsType<phi::DenseTensor>(), true);
+
+      auto* origin_tensor = var->GetMutable<phi::DenseTensor>();
+
+      phi::DenseTensor half_tensor;
+      half_tensor.Resize(origin_tensor->dims());
+      half_tensor.set_type(half_precision_);
+
+      if (half_precision_ == phi::DataType::FLOAT16) {
+        auto* half_data =
+            half_tensor.mutable_data<phi::dtype::float16>(phi::CPUPlace{});
+        for (int64_t i = 0; i < origin_tensor->numel(); i++) {
+          if (origin_tensor->dtype() == phi::DataType::FLOAT64) {
+            auto* origin_data = origin_tensor->data<double>();
+            half_data[i] = static_cast<phi::dtype::float16>(origin_data[i]);
+          } else if (origin_tensor->dtype() == phi::DataType::FLOAT32) {
+            auto* origin_data = origin_tensor->data<float>();
+            half_data[i] = static_cast<phi::dtype::float16>(origin_data[i]);
+          }
+        }
+      } else if (half_precision_ == phi::DataType::BFLOAT16) {
+        auto* half_data =
+            half_tensor.mutable_data<phi::dtype::bfloat16>(phi::CPUPlace{});
+        for (int64_t i = 0; i < origin_tensor->numel(); i++) {
+          if (origin_tensor->dtype() == phi::DataType::FLOAT64) {
+            auto* origin_data = origin_tensor->data<double>();
+            half_data[i] = static_cast<phi::dtype::bfloat16>(origin_data[i]);
+          } else if (origin_tensor->dtype() == phi::DataType::FLOAT32) {
+            auto* origin_data = origin_tensor->data<float>();
+            half_data[i] = static_cast<phi::dtype::bfloat16>(origin_data[i]);
+          }
+        }
+      }
+      origin_tensor->clear();
+      paddle::framework::TensorCopySync(
+          half_tensor, phi::CPUPlace{}, origin_tensor);
-      }
-#undef CONVERT_TENSOR_DTYPE
}
}

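Besides dropping the CONVERT_TENSOR_DTYPE macro, the rewrite above appears to fix a real bug: the old code always read the weights through mutable_data<float>, which would misread FP64 weights, while the new loops branch on the source dtype before casting. A standalone sketch of that dispatch-then-cast pattern (generic names; in Paddle, HalfT would be phi::dtype::float16 or phi::dtype::bfloat16):

#include <cstdint>
#include <vector>

enum class DType { kFloat32, kFloat64 };

// Cast n elements from a float32 or float64 buffer into HalfT,
// choosing the source type at runtime, as the rewritten loop does.
template <typename HalfT>
std::vector<HalfT> ToHalf(const void* src, DType src_dtype, int64_t n) {
  std::vector<HalfT> out(static_cast<size_t>(n));
  for (int64_t i = 0; i < n; ++i) {
    if (src_dtype == DType::kFloat64) {
      out[i] = static_cast<HalfT>(static_cast<const double*>(src)[i]);
    } else {  // DType::kFloat32
      out[i] = static_cast<HalfT>(static_cast<const float*>(src)[i]);
    }
  }
  return out;
}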
3 changes: 0 additions & 3 deletions paddle/fluid/framework/ir/float_to_half_pass.h
@@ -22,9 +22,6 @@
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/phi/common/backend.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/common/place.h"

namespace paddle {
namespace framework {
@@ -41,6 +41,7 @@ void MapDepthwiseConv2ConvPass::ApplyImpl(ir::Graph* graph) const {
std::string op_type = op_desc->Type();
if (!replaced_map.count(op_type)) continue;
op_desc->SetType(replaced_map[op_type]);
+    op_desc->SetAttr("use_cudnn", true);
op_desc->Flush();
++found_count;
}
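A sketch of what this one-line addition does in context (illustrative structs rather than Paddle's OpDesc; the single map entry is an assumption, since the file header and the full replaced_map are not visible here). Forcing use_cudnn after the rewrite makes sense if the point of mapping depthwise_conv2d to conv2d is to reach cuDNN's convolution kernels:

#include <map>
#include <string>

struct OpSketch {
  std::string type;
  std::map<std::string, bool> bool_attrs;
};

void MapDepthwiseConvToConvSketch(OpSketch* op) {
  // Assumed single entry; the real pass reads from its own replaced_map.
  static const std::map<std::string, std::string> replaced = {
      {"depthwise_conv2d", "conv2d"}};
  auto it = replaced.find(op->type);
  if (it == replaced.end()) return;
  op->type = it->second;               // rewrite the op type
  op->bool_attrs["use_cudnn"] = true;  // the attribute this PR adds
}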
27 changes: 2 additions & 25 deletions paddle/fluid/inference/analysis/ir_pass_manager.cc
@@ -27,6 +27,7 @@
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/analysis/argument.h"
#include "paddle/fluid/string/pretty_log.h"
#include "paddle/phi/core/errors.h"

namespace paddle {
namespace inference {
@@ -303,42 +304,18 @@ void IRPassManager::CreatePasses(Argument *argument,
}

std::unique_ptr<Graph> IRPassManager::Apply(std::unique_ptr<Graph> graph) {
-  if (passes_.empty()) {
-    return graph;
-  }
  PADDLE_ENFORCE_NOT_NULL(
-      graph.get(),
-      platform::errors::PreconditionNotMet("Graph cannot be NULL."));
+      graph.get(), platform::errors::InvalidArgument("Graph cannot be null."));
// Apply all the passes
for (const auto &pass : passes_) {
if (pass->Type() != "graph_viz_pass" && !disable_logs_) {
PrettyLogEndl(Style::H2(), "--- Running IR pass [%s]", pass->Type());
}
-    // delete_fill_constant_op_pass is not apply under trt dynamic shape
-    if (pass->Type() == "delete_fill_constant_op_pass") {
-      bool use_dynamic = pass->Get<bool>("with_dynamic_shape");
-      if (use_dynamic) continue;
-    }
graph.reset(pass->Apply(graph.release()));
}
return graph;
}

-framework::proto::ProgramDesc IRPassManager::AcquireProgram(
-    std::unique_ptr<Graph> *graph, ProgramDesc *program) const {
-  auto pass =
-      framework::ir::PassRegistry::Instance().Get("graph_to_program_pass");
-
-  // Direct using ProgramDesc desc(argument->main_program()) may cause
-  // incomplete copies of information.
-  ProgramDesc desc;
-  desc.CopyFrom(*program->Proto());
-  pass->SetNotOwned("program", &desc);
-  auto *the_graph = graph->release();
-  graph->reset(pass->Apply(the_graph));
-  return *desc.Proto();
-}

} // namespace analysis
} // namespace inference
} // namespace paddle
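Taken together with the header change below, Apply() is simplified in three ways: the empty-pass-list early return goes away, the per-pass special case for delete_fill_constant_op_pass moves into that pass itself (first file in this diff), and the apparently unused AcquireProgram() helper is deleted along with the graph_ member. A minimal sketch of the resulting loop, with illustrative types standing in for Paddle's:

#include <memory>
#include <stdexcept>
#include <vector>

struct Graph {};

struct Pass {
  virtual ~Pass() = default;
  // Passes self-guard: one that does not apply returns the graph untouched.
  virtual Graph* Apply(Graph* g) const = 0;
};

std::unique_ptr<Graph> ApplyAll(
    const std::vector<std::unique_ptr<Pass>>& passes,
    std::unique_ptr<Graph> graph) {
  if (!graph) {
    throw std::invalid_argument("Graph cannot be null.");
  }
  for (const auto& pass : passes) {
    // No pass-type inspection here any more.
    graph.reset(pass->Apply(graph.release()));
  }
  return graph;
}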
6 changes: 0 additions & 6 deletions paddle/fluid/inference/analysis/ir_pass_manager.h
@@ -48,15 +48,9 @@ class IRPassManager final {

std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph);

-  framework::proto::ProgramDesc AcquireProgram(std::unique_ptr<Graph> *graph,
-                                               ProgramDesc *program) const;
-
-  framework::ir::Graph &graph() const { return *graph_; }

private:
void CreatePasses(Argument *argument, const std::vector<std::string> &passes);

-  std::unique_ptr<Graph> graph_;
std::vector<std::unique_ptr<Pass>> passes_;
bool disable_logs_{false};
};
8 changes: 4 additions & 4 deletions paddle/fluid/inference/api/analysis_config.cc
@@ -108,6 +108,7 @@ void AnalysisConfig::EnableUseGpu(uint64_t memory_pool_init_size_mb,
}
#else
LOG(ERROR) << "Please use PaddlePaddle with GPU version.";
+  use_gpu_ = false;
#endif

Update();
@@ -299,7 +300,7 @@ void AnalysisConfig::LoadIpuConfig(const std::string &config_path) {

if (ipu_config_mapper_.find(key) == ipu_config_mapper_.end()) {
PADDLE_THROW(platform::errors::InvalidArgument(
"invalid key {} in IPU config: ", key));
"invalid key %s in IPU config: ", key));
}
switch (ipu_config_mapper_.at(key)) {
case ipu_config_code::ipu_device_num:
@@ -335,10 +336,9 @@
case ipu_config_code::ipu_enable_model_runtime_executor:
ipu_enable_model_runtime_executor_ = string2bool(value);
break;

default:
PADDLE_THROW(platform::errors::InvalidArgument(
"invalid key {} in IPU config", key));
"invalid key %s in IPU config", key));
break;
}
}
@@ -1424,7 +1424,7 @@ bool AnalysisConfig::trt_allow_build_at_runtime() const {
return trt_allow_build_at_runtime_;
}

-void AnalysisConfig::Exp_DisableMixedInferOps(
+void AnalysisConfig::Exp_DisableMixedPrecisionOps(
const std::unordered_set<std::string> &black_list) {
mixed_black_list_ = black_list;
}
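Two small fixes above are easy to miss. First, EnableUseGpu now resets use_gpu_ to false in CPU-only builds instead of only logging an error. Second, the IPU error messages switch from "{}" to "%s": the fix itself implies these messages go through printf-style formatting, where "{}" is printed literally and the offending key never appears. A tiny standalone illustration, using plain printf as a stand-in for Paddle's error helpers:

#include <cstdio>

int main() {
  const char* key = "ipu_device_num";  // hypothetical offending key
  std::printf("invalid key {} in IPU config\n");       // old: key never shown
  std::printf("invalid key %s in IPU config\n", key);  // new: key is reported
  return 0;
}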
2 changes: 1 addition & 1 deletion paddle/fluid/inference/api/paddle_analysis_config.h
@@ -1009,7 +1009,7 @@ struct PD_INFER_DECL AnalysisConfig {
/// interface is in the experimental stage and may change in the future. Note
/// that the blacklist must be the same as the model conversion blacklist.
///
-  void Exp_DisableMixedInferOps(
+  void Exp_DisableMixedPrecisionOps(
const std::unordered_set<std::string>& black_list);

void SetApplyOptim(bool value) { apply_optim_ = value; }
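Since Exp_DisableMixedInferOps is renamed to Exp_DisableMixedPrecisionOps in a public header, callers must be updated. A hedged usage sketch: the method and its std::unordered_set<std::string> parameter come from this diff, while the header path and the op names in the blacklist are assumptions.

#include "paddle_inference_api.h"  // assumed deployment header

int main() {
  paddle_infer::Config config;
  // Keep these (hypothetical) ops out of mixed-precision execution; per the
  // doc comment above, the list must match the model conversion blacklist.
  config.Exp_DisableMixedPrecisionOps({"softmax", "layer_norm"});
  return 0;
}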
2 changes: 1 addition & 1 deletion paddle/fluid/inference/tests/api/CMakeLists.txt
@@ -418,7 +418,7 @@ if(WITH_GPU)
analyzer_ernie_tester.cc)
inference_analysis_api_test(gpu_ernie_half_test ${ERNIE_INSTALL_DIR}
gpu_ernie_half_test.cc)
-  set_tests_properties(gpu_ernie_half_test PROPERTIES TIMEOUT 40)
+  set_tests_properties(gpu_ernie_half_test PROPERTIES TIMEOUT 60)
endif()
inference_analysis_api_int8_test(test_analyzer_ernie_int8 ${ERNIE_INSTALL_DIR}
analyzer_ernie_int8_tester.cc)