Skip to content

Commit

Permalink
fix ernie serialize problem (#36769)
Browse files Browse the repository at this point in the history
  • Loading branch information
zlsh80826 committed Oct 27, 2021
1 parent 5e9845b commit d6b1beb
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
6 changes: 3 additions & 3 deletions paddle/fluid/inference/tensorrt/engine.cc
Expand Up @@ -233,11 +233,11 @@ void TensorRTEngine::FreezeNetwork() {
*network(), *infer_builder_config_));
#else
infer_builder_config_->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS);
infer_ptr<nvinfer1::IHostMemory> plan(infer_builder_->buildSerializedNetwork(
ihost_memory_.reset(infer_builder_->buildSerializedNetwork(
*network(), *infer_builder_config_));
infer_ptr<nvinfer1::IRuntime> runtime(createInferRuntime(&logger_));
infer_engine_.reset(
runtime->deserializeCudaEngine(plan->data(), plan->size()));
infer_engine_.reset(runtime->deserializeCudaEngine(ihost_memory_->data(),
ihost_memory_->size()));
#endif

PADDLE_ENFORCE_NOT_NULL(
Expand Down
7 changes: 7 additions & 0 deletions paddle/fluid/inference/tensorrt/engine.h
Expand Up @@ -273,7 +273,14 @@ class TensorRTEngine {
infer_engine_,
platform::errors::InvalidArgument(
"The TensorRT engine must be built first before serialization"));
#if IS_TRT_VERSION_LT(8000)
ihost_memory_.reset(infer_engine_->serialize());
#else
PADDLE_ENFORCE_NOT_NULL(
ihost_memory_,
platform::errors::InvalidArgument(
"TensorRT >= 8.0 requires that buildSerializedNetwork is called"));
#endif
return ihost_memory_.get();
}

Expand Down

0 comments on commit d6b1beb

Please sign in to comment.