Error Code 4: Internal Error (Network must have at least one output) #1728

Closed
Tian14267 opened this issue Jan 13, 2022 · 7 comments
Labels: Topic: ND shape, triaged (Issue has been triaged by maintainers)

Comments

@Tian14267

Tian14267 commented Jan 13, 2022

Description

I converted my ONNX model to TensorRT, but I get this error: Error Code 4: Internal Error (Network must have at least one output)
The detailed log is:

Loading ONNX file from path ./onnx_model/FastSpeech_.onnx...
Beginning ONNX file parsing
[01/13/2022-20:56:48] [TRT] [W] onnx2trt_utils.cpp:366: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.
[01/13/2022-20:57:21] [TRT] [W] onnx2trt_utils.cpp:392: One or more weights outside the range of INT32 was clamped
[01/13/2022-20:57:21] [TRT] [E] [shuffleNode.cpp::symbolicExecute::387] Error Code 4: Internal Error (Reshape_3578: IShuffleLayer applied to shape tensor must have 0 or 1 reshape dimensions: dimensions were [-1,2])
Completed parsing of ONNX file
Building an engine from file ./onnx_model/FastSpeech_.onnx; this may take a while...
6875
[01/13/2022-20:57:21] [TRT] [E] 4: [network.cpp::validate::2633] Error Code 4: Internal Error (Network must have at least one output)
Completed creating Engine
Traceback (most recent call last):
  File "onnx2trt.py", line 102, in <module>
    engine = ONNX_build_engine_acoustic_model(onnx_file_path, write_engine)
  File "onnx2trt.py", line 88, in ONNX_build_engine_acoustic_model
    f.write(engine.serialize())
AttributeError: 'NoneType' object has no attribute 'serialize'

Here is my code:

    G_LOGGER = trt.Logger(trt.Logger.WARNING)
    explicit_batch = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    batch_size = 8 
    with trt.Builder(G_LOGGER) as builder, builder.create_network(explicit_batch) as network, \
            trt.OnnxParser(network, G_LOGGER) as parser:
        builder.max_batch_size = batch_size
        config = builder.create_builder_config()
        config.max_workspace_size = common.GiB(2)
        config.set_flag(trt.BuilderFlag.FP16)
        print('Loading ONNX file from path {}...'.format(onnx_file_path))
        with open(onnx_file_path, 'rb') as model:
            print('Beginning ONNX file parsing')
            parser.parse(model.read())
        print('Completed parsing of ONNX file')
        print('Building an engine from file {}; this may take a while...'.format(onnx_file_path))
        profile = builder.create_optimization_profile() 
        profile.set_shape("texts", (1, 2), (1, 50), (1, 500))
        profile.set_shape("src_lens.1", [1], [50], [500])
        profile.set_shape("speakers", [0], [1], [10])
        profile.set_shape("max_src_len", [1], [50], [500])
        config.add_optimization_profile(profile)

        engine = builder.build_engine(network, config)
        print("Completed creating Engine")
        if write_engine:
            engine_file_path = 'acoustic_model_efficientnet_b1.trt'
            with open(engine_file_path, "wb") as f:
                f.write(engine.serialize())
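
The log above shows that the parser had already reported an error (Reshape_3578) before the build started, so the network ended up with no outputs and `build_engine` returned `None`. A minimal sketch of checking both steps explicitly, assuming `parser` is the `trt.OnnxParser` from the script and `engine` is whatever `build_engine` returned. The helper names `parse_or_raise` and `serialize_or_raise` are hypothetical, not part of TensorRT:

```python
def parse_or_raise(parser, model_bytes):
    """Parse ONNX bytes; raise with every parser error if parsing fails.

    `parser` is assumed to be a trt.OnnxParser, which exposes parse(),
    num_errors and get_error(). Nothing here imports tensorrt itself.
    """
    if parser.parse(model_bytes):
        return
    errors = [str(parser.get_error(i)) for i in range(parser.num_errors)]
    raise RuntimeError("ONNX parsing failed:\n" + "\n".join(errors))


def serialize_or_raise(engine):
    """Return the serialized engine, or raise if the build returned None."""
    if engine is None:
        raise RuntimeError("Engine build failed; check the TensorRT log")
    return engine.serialize()
```

In the script above, this would mean calling `parse_or_raise(parser, model.read())` instead of `parser.parse(model.read())`, and `f.write(serialize_or_raise(engine))` instead of `f.write(engine.serialize())`, so the run stops at the first real failure rather than at the `AttributeError`.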

Here is my onnx_model

My ONNX graph looks like this:
[image]

Environment

TensorRT Version: TensorRT-8.2.1.8
NVIDIA GPU: GeForce RTX 3090
NVIDIA Driver Version:
CUDA Version: 11.2
CUDNN Version:
Operating System: centos 7
Python Version (if applicable): 3.6
Tensorflow Version (if applicable): 1.15.4
PyTorch Version (if applicable): 1.8.1
Baremetal or Container (if so, version):

Can anybody help me?

@Tian14267
Author

If I add

last_layer = network.get_layer(network.num_layers - 1)
network.mark_output(last_layer.get_output(0))

then this happens:
[01/14/2022-09:21:08] [TRT] [E] [shuffleNode.cpp::symbolicExecute::387] Error Code 4: Internal Error (Reshape_3578: IShuffleLayer applied to shape tensor must have 0 or 1 reshape dimensions: dimensions were [-1,2])

@ttyio
Collaborator

ttyio commented Feb 23, 2022

@Tian14267
The error is the one reported during ONNX parsing:

[01/13/2022-20:57:21] [TRT] [E] [shuffleNode.cpp::symbolicExecute::387] Error Code 4: Internal Error (Reshape_3578: IShuffleLayer applied to shape tensor must have 0 or 1 reshape dimensions: dimensions were [-1,2])

Have you tried:

 polygraphy surgeon sanitize model.onnx --fold-constants --output model_folded.onnx

If this does not work, we will have to wait for one of the next one or two major releases, which will add support for N-D shape tensors and fix this.
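
For anyone scripting the conversion, the command above can be run from Python before the TensorRT build. This is only a sketch: it assumes the `polygraphy` CLI is installed and on `PATH`, and the helper names `sanitize_command` and `fold_constants` are hypothetical. The subcommand and flags are exactly the ones quoted in the command above.

```python
import shutil
import subprocess


def sanitize_command(src, dst):
    """Build the argv for the constant-folding command quoted above."""
    return ["polygraphy", "surgeon", "sanitize", src,
            "--fold-constants", "--output", dst]


def fold_constants(src, dst):
    """Run the command; raise if polygraphy is missing or exits non-zero."""
    if shutil.which("polygraphy") is None:
        raise FileNotFoundError("polygraphy not found on PATH")
    subprocess.run(sanitize_command(src, dst), check=True)
    return dst
```

The folded file (for example `model_folded.onnx`) would then be passed to the ONNX parser in place of the original model.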

ttyio added the "Topic: ND shape" and "triaged" labels on Feb 23, 2022
@Monibsediqi

Hi! Any update on this? I'm facing a similar issue.

@nvpohanh
Collaborator

N-D shape tensors will be supported in the next version after TRT 8.4 GA.

@zzg-tju

zzg-tju commented Aug 17, 2022

Problem solved, thanks for the solution~

@ttyio
Collaborator

ttyio commented Aug 24, 2022

Closing since it is fixed, thanks!

ttyio closed this as completed on Aug 24, 2022
@catalwaysright

> Problem solved, thanks for the solution~

@zzg-tju How did you solve this?


6 participants