
Segmentation fault with TensorRT 10.3.0 when running the SAM2 decoder on Orin #4504

@04633435

Hi there, I am running into a problem when doing inference with the TensorRT Python API. I detail the problem below; any advice would be appreciated, thanks~

Description

I tried to run the SAM2 decoder on Orin with the Python API, but it fails with the error below:

[07/01/2025-14:14:38] [TRT] [E] IExecutionContext::enqueueV3: Error Code 1: Cuda Runtime (an illegal memory access was encountered)
...
pycuda._driver.LogicError: cuMemcpyDtoHAsync failed: an illegal memory access was encountered
PyCUDA WARNING: a clean-up operation failed (dead context maybe?)
cuMemFreeHost failed: an illegal memory access was encountered
PyCUDA WARNING: a clean-up operation failed (dead context maybe?)
cuMemFree failed: an illegal memory access was encountered
[07/01/2025-14:14:38] [TRT] [E] [graphContext.h::~MyelinGraphContext::84] Error Code 1: Myelin ([impl.cpp:cuda_object_deallocate:435] Error 700 destroying stream '0xaaaaed9cd400'.)
[07/01/2025-14:14:38] [TRT] [E] [graphContext.h::~MyelinGraphContext::84] Error Code 1: Myelin ([impl.cpp:cuda_object_deallocate:435] Error 700 destroying stream '0xaaaaeda249f0'.)
[07/01/2025-14:14:38] [TRT] [E] [graphContext.h::~MyelinGraphContext::84] Error Code 1: Myelin ([impl.cpp:cuda_object_deallocate:435] Error 700 destroying stream '0xaaaaedd49440'.)
...

I thought it might have something to do with the engine I built, but I also tried inference via trtexec, and that worked fine.
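
For reference, loading a saved engine and running inference with trtexec generally takes a form like the command below; the engine path and the tensor shapes here are only placeholders, not my exact invocation.

    trtexec --loadEngine=sam2_decoder.engine \
            --shapes=point_coords:1x2x2,point_labels:1x2,mask_input:1x1x256x256,has_mask_input:1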

This confuses me: is there something wrong in my inference code? The code I use for inference is shown below:

    import tensorrt as trt
    import pycuda.autoinit  # establishes a CUDA context for the pycuda calls below
    import pycuda.driver as cuda

    # `engine` (the deserialized ICudaEngine) and `input_tensor` (the list of input
    # numpy arrays, ordered like the engine I/O tensors) are prepared earlier.
    output_shapes, output_buffers, output_memorys = [], [], []

    with engine.create_execution_context() as context:
        # Allocate host and device buffers
        tensor_names = [engine.get_tensor_name(i) for i in range(engine.num_io_tensors)]
        print(f"[INFO]tensor_names: {tensor_names}")

        stream = cuda.Stream()
        # context.set_optimization_profile_async(1, stream)
        batch_size = 1
        for i, tensor_name in enumerate(tensor_names):
            if engine.get_tensor_mode(tensor_name) == trt.TensorIOMode.INPUT:
                if tensor_name in ('point_coords', 'point_labels', 'mask_input', 'has_mask_input'):
                    print(f"[INFO]{tensor_name} set_shape: {input_tensor[i].shape}")
                    context.set_input_shape(tensor_name, input_tensor[i].shape)
                input_np_data = input_tensor[i].ravel()
                input_device_memory = cuda.mem_alloc(input_np_data.nbytes)
                context.set_tensor_address(tensor_name, int(input_device_memory))
                cuda.memcpy_htod_async(input_device_memory, input_np_data, stream)
            else:
                output_shape = context.get_tensor_shape(tensor_name)
                output_shapes.append(output_shape)
                size = trt.volume(output_shape)
                dtype = trt.nptype(engine.get_tensor_dtype(tensor_name))
                print(f"[INFO]output shape: {output_shape}, size: {size}, dtype: {dtype}")
                output_buffer = cuda.pagelocked_empty(size, dtype)
                output_memory = cuda.mem_alloc(output_buffer.nbytes)
                context.set_tensor_address(tensor_name, int(output_memory))
                output_buffers.append(output_buffer)
                output_memorys.append(output_memory)

        # Run inference
        context.execute_async_v3(stream_handle=stream.handle)

        # Transfer prediction outputs from the GPU.
        for output_buffer, output_memory in zip(output_buffers, output_memorys):
            cuda.memcpy_dtoh_async(output_buffer, output_memory, stream)
        # Synchronize the stream
        stream.synchronize()
        out0 = output_buffers[0].reshape(output_shapes[0])
        out1 = output_buffers[1].reshape(output_shapes[1])
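
One thing I am not sure about: in this loop the output shapes are queried while the input shapes are still being set, and the async host-to-device copies are launched from regular (pageable) numpy arrays. In case it helps narrow things down, a two-pass version of the same binding logic, with page-locked staging buffers, would look roughly like the sketch below (it reuses the variables from the snippet above and is only an illustration, not the script I actually ran):

    # Sketch only: set every dynamic input shape first, then allocate and bind in a
    # second pass, so output shapes are queried after all input shapes are known.
    dynamic_inputs = ('point_coords', 'point_labels', 'mask_input', 'has_mask_input')

    # Pass 1: input shapes.
    for i, tensor_name in enumerate(tensor_names):
        if (engine.get_tensor_mode(tensor_name) == trt.TensorIOMode.INPUT
                and tensor_name in dynamic_inputs):
            context.set_input_shape(tensor_name, input_tensor[i].shape)

    # Pass 2: allocations, copies, and address binding.
    for i, tensor_name in enumerate(tensor_names):
        if engine.get_tensor_mode(tensor_name) == trt.TensorIOMode.INPUT:
            # Stage the input through a page-locked host buffer before the async copy.
            host_buf = cuda.pagelocked_empty(input_tensor[i].size, input_tensor[i].dtype)
            host_buf[:] = input_tensor[i].ravel()
            dev_buf = cuda.mem_alloc(host_buf.nbytes)
            cuda.memcpy_htod_async(dev_buf, host_buf, stream)
            context.set_tensor_address(tensor_name, int(dev_buf))
        else:
            output_shape = context.get_tensor_shape(tensor_name)
            size = trt.volume(output_shape)
            dtype = trt.nptype(engine.get_tensor_dtype(tensor_name))
            output_buffer = cuda.pagelocked_empty(size, dtype)
            output_memory = cuda.mem_alloc(output_buffer.nbytes)
            context.set_tensor_address(tensor_name, int(output_memory))
            output_shapes.append(output_shape)
            output_buffers.append(output_buffer)
            output_memorys.append(output_memory)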

Environment

TensorRT Version: 10.3.0

NVIDIA GPU: Orin (64 GB)

NVIDIA Driver Version: 540.4.0

CUDA Version: 12.6

CUDNN Version: 9.3.0

Operating System:

Python Version (if applicable): 3.10.12

Tensorflow Version (if applicable):

PyTorch Version (if applicable):

Baremetal or Container (if so, version):

Relevant Files

Model link:

Steps To Reproduce

Commands or scripts:

Have you tried the latest release?:

Can this model run on other frameworks? For example run ONNX model with ONNXRuntime (polygraphy run <model.onnx> --onnxrt):


Labels

Investigating (issue is under investigation by TensorRT devs), Module:Embedded (issues when using TensorRT on embedded platforms), Module:Runtime (other generic runtime issues)
