Skip to content

Commit

Permalink
Fix the bug of compile checkpoint to tensorrt
Browse files Browse the repository at this point in the history
Closes #491

Signed-off-by: zhangkaili <zhang.kaili@zte.com.cn>
  • Loading branch information
KellyZhang2020 committed Jul 9, 2021
1 parent 9ba353a commit f304c36
Showing 1 changed file with 4 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
import os
from typing import Any, Mapping, NamedTuple, Optional, Sequence, List
import numpy as np
import caffe2.python.onnx.frontend
from caffe2.proto import caffe2_pb2
from . import repository
from .. import utilities
from ..models.data_format import DataFormat
Expand Down Expand Up @@ -48,6 +46,9 @@ def parse_caffe_net(net, pb_path):

@repository.REPOSITORY.register(source_type=CaffeModelFile, target_type=OnnxModel, config_type=Config)
def compile_source(source: CaffeModelFile, config: Config) -> OnnxModel:
from caffe2.python.onnx.frontend import caffe2_net_to_onnx_model # pylint: disable=import-outside-toplevel
from caffe2.proto import caffe2_pb2 # pylint: disable=import-outside-toplevel

predict_net = parse_caffe_net(caffe2_pb2.NetDef(), os.path.join(source.model_path, 'predict_net.pb'))
predict_net.name = "model" if predict_net.name == "" else predict_net.name # pylint: disable=no-member
init_net = parse_caffe_net(caffe2_pb2.NetDef(), os.path.join(source.model_path, 'init_net.pb'))
Expand All @@ -57,7 +58,7 @@ def compile_source(source: CaffeModelFile, config: Config) -> OnnxModel:
input_shape.insert(0, config.max_batch_size)
value_info[config.input_names[i]] = (config.input_type, input_shape)

onnx_model = caffe2.python.onnx.frontend.caffe2_net_to_onnx_model(predict_net, init_net, value_info)
onnx_model = caffe2_net_to_onnx_model(predict_net, init_net, value_info)

graph = onnx_model.graph # pylint: disable=no-member
return OnnxModel(model_proto=onnx_model,
Expand Down

0 comments on commit f304c36

Please sign in to comment.