Skip to content

Commit

Permalink
Fix the bug of compile model to tensorrt
Browse files Browse the repository at this point in the history
Closes #491

Signed-off-by: zhangkaili <zhang.kaili@zte.com.cn>
  • Loading branch information
KellyZhang2020 committed Jul 19, 2021
1 parent 3436030 commit 1067863
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
import os
from typing import Any, Mapping, NamedTuple, Optional, Sequence, List
import numpy as np
import caffe2.python.onnx.frontend
from caffe2.proto import caffe2_pb2
from . import repository
from .. import utilities
from ..models.data_format import DataFormat
Expand Down Expand Up @@ -48,6 +46,9 @@ def parse_caffe_net(net, pb_path):

@repository.REPOSITORY.register(source_type=CaffeModelFile, target_type=OnnxModel, config_type=Config)
def compile_source(source: CaffeModelFile, config: Config) -> OnnxModel:
import caffe2.python.onnx.frontend # pylint: disable=import-outside-toplevel
from caffe2.proto import caffe2_pb2 # pylint: disable=import-outside-toplevel

predict_net = parse_caffe_net(caffe2_pb2.NetDef(), os.path.join(source.model_path, 'predict_net.pb'))
predict_net.name = "model" if predict_net.name == "" else predict_net.name # pylint: disable=no-member
init_net = parse_caffe_net(caffe2_pb2.NetDef(), os.path.join(source.model_path, 'init_net.pb'))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from tempfile import NamedTemporaryFile
import onnx
import onnx.utils
from paddle2onnx.command import program2onnx
from . import repository
from .. import utilities

Expand Down Expand Up @@ -43,6 +42,7 @@ def from_env(env: Mapping[str, str]) -> 'Config':

@repository.REPOSITORY.register(source_type=PaddlePaddleModelFile, target_type=OnnxModel, config_type=Config)
def compile_source(source: PaddlePaddleModelFile, config: Config) -> OnnxModel:
from paddle2onnx.command import program2onnx # pylint: disable=import-outside-toplevel
with NamedTemporaryFile(suffix='.onnx') as onnx_file:
program2onnx(model_dir=source.model_path,
save_file=onnx_file.name,
Expand Down

0 comments on commit 1067863

Please sign in to comment.