From d6b1beb063585dca355ced6f0f0ec09084b8c68d Mon Sep 17 00:00:00 2001 From: zlsh80826 Date: Wed, 27 Oct 2021 18:42:58 +0800 Subject: [PATCH] fix ernie serialize problem (#36769) --- paddle/fluid/inference/tensorrt/engine.cc | 6 +++--- paddle/fluid/inference/tensorrt/engine.h | 7 +++++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc index 26182a7932199..575c018586361 100644 --- a/paddle/fluid/inference/tensorrt/engine.cc +++ b/paddle/fluid/inference/tensorrt/engine.cc @@ -233,11 +233,11 @@ void TensorRTEngine::FreezeNetwork() { *network(), *infer_builder_config_)); #else infer_builder_config_->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS); - infer_ptr plan(infer_builder_->buildSerializedNetwork( + ihost_memory_.reset(infer_builder_->buildSerializedNetwork( *network(), *infer_builder_config_)); infer_ptr runtime(createInferRuntime(&logger_)); - infer_engine_.reset( - runtime->deserializeCudaEngine(plan->data(), plan->size())); + infer_engine_.reset(runtime->deserializeCudaEngine(ihost_memory_->data(), + ihost_memory_->size())); #endif PADDLE_ENFORCE_NOT_NULL( diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h index 0e1b9fe3366ca..9397d4e89de42 100644 --- a/paddle/fluid/inference/tensorrt/engine.h +++ b/paddle/fluid/inference/tensorrt/engine.h @@ -273,7 +273,14 @@ class TensorRTEngine { infer_engine_, platform::errors::InvalidArgument( "The TensorRT engine must be built first before serialization")); +#if IS_TRT_VERSION_LT(8000) ihost_memory_.reset(infer_engine_->serialize()); +#else + PADDLE_ENFORCE_NOT_NULL( + ihost_memory_, + platform::errors::InvalidArgument( + "TensorRT >= 8.0 requires that buildSerializedNetwork is called")); +#endif return ihost_memory_.get(); }