
[Question] Is IOutputAllocator::reallocateOutput guaranteed to be called before context->enqueueV3 returns? #3875

Open · mirzadeh opened this issue May 16, 2024 · 5 comments
Labels: triaged (Issue has been triaged by maintainers)

@mirzadeh

Description

I cannot find any information about when IOutputAllocator::reallocateOutput is called relative to context->enqueueV3. Is this function guaranteed to be called before enqueueV3 returns, or should I explicitly synchronize the stream?

In other words, in the following pseudo-code:

context_->setOutputAllocator(name, allocator);
// ...
context_->enqueueV3(stream);
cudaStreamSynchronize(stream); // <-- Is this necessary?

// Memcpy device -> host; is it valid to ask the allocator for the device buffer without stream synchronization?
cudaMemcpyAsync(hostBuffer, allocator->getDeviceBuffer(), ...);

Should I explicitly synchronize the stream after enqueueV3 for allocator->getDeviceBuffer() to be valid? Or is allocator->reallocateOutput guaranteed to be called before enqueueV3 returns, in which case stream synchronization would be unnecessary?

Environment

TensorRT Version:

NVIDIA GPU:

NVIDIA Driver Version:

CUDA Version:

CUDNN Version:

Operating System:

Python Version (if applicable):

Tensorflow Version (if applicable):

PyTorch Version (if applicable):

Baremetal or Container (if so, version):

Relevant Files

Model link:

Steps To Reproduce

Commands or scripts:

Have you tried the latest release?:

Can this model run on other frameworks? For example run ONNX model with ONNXRuntime (polygraphy run <model.onnx> --onnxrt):

@zerollzeng (Collaborator)

Please refer to our API doc: https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_execution_context.html#aa174ba57c44df821625ce4d3317dd7aa

> should I explicitly synchronize the stream?

Yes.

> Should I explicitly synchronize the stream after enqueueV3 for allocator->getDeviceBuffer() to be valid?

The pointer itself stays valid until you free the memory, but the output contents are correct only after stream synchronization is done.
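
A minimal sketch of the call order this answer implies, following the question's own pseudo-code (hostBuffer, byteSize, and the allocator's getDeviceBuffer() accessor are the asker's names, not TensorRT API):

context->setOutputAllocator(name, allocator);
if (!context->enqueueV3(stream))  // returns as soon as the work is enqueued
{
    // handle the enqueue failure
}
cudaStreamSynchronize(stream);    // wait for inference to finish

// Only now are the output contents safe to read back.
cudaMemcpyAsync(hostBuffer, allocator->getDeviceBuffer(), byteSize,
                cudaMemcpyDeviceToHost, stream);
cudaStreamSynchronize(stream);    // wait for the copy itself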

zerollzeng self-assigned this May 26, 2024
zerollzeng added the triaged label May 26, 2024
@mirzadeh (Author)

I think my question was more about the calling order of reallocateOutput and enqueueV3. Since enqueueV3 is asynchronous, is it possible that by the time cudaMemcpyAsync is called, TensorRT has not yet called reallocateOutput, so the device pointer is invalid (because reallocateOutput might return a different pointer)?

If reallocateOutput is guaranteed to have been called by the time enqueueV3 returns, there is no need for explicit synchronization before the memcpy.
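
For context, a minimal sketch of the kind of IOutputAllocator under discussion (written against the TensorRT 8.5-era pure-virtual interface; newer releases also add an asynchronous variant, so check your headers). The grow-only reuse policy and the member names here are illustrative assumptions, not something this thread specifies:

#include "NvInfer.h"
#include <cuda_runtime_api.h>

class MyOutputAllocator : public nvinfer1::IOutputAllocator
{
public:
    // TensorRT calls this once it knows how many bytes the output needs;
    // grow the device buffer on demand, otherwise reuse it.
    void* reallocateOutput(char const* tensorName, void* currentMemory,
                           uint64_t size, uint64_t alignment) noexcept override
    {
        if (size > mCapacity)
        {
            cudaFree(mBuffer);
            mBuffer = nullptr;
            mCapacity = 0;
            if (cudaMalloc(&mBuffer, size) != cudaSuccess)
                return nullptr; // reports the allocation failure to TensorRT
            mCapacity = size;
        }
        return mBuffer;
    }

    // TensorRT reports the final output dimensions here.
    void notifyShape(char const* tensorName, nvinfer1::Dims const& dims) noexcept override
    {
        mDims = dims;
    }

    void* getDeviceBuffer() const noexcept { return mBuffer; }

private:
    void* mBuffer{nullptr};
    uint64_t mCapacity{0};
    nvinfer1::Dims mDims{};
};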

@GDUTLinsy

[screenshot did not upload]
I'm having the following problem; what should I do? Here's my code.


@GDUTLinsy

void* buffers[2]{};

const int inputIndex = 0;
const int outputIndex = 1;

// Create GPU buffers on device
CHECK(cudaMalloc(&buffers[inputIndex], batchSize * 3 * IN_H * IN_W * sizeof(float)));
CHECK(cudaMalloc(&buffers[outputIndex], batchSize * 4 * sizeof(float))); // 4 floats per batch element, matching the copy below

// Pass the TensorRT buffers for input and output
context.setTensorAddress(IN_NAME, buffers[inputIndex]);
context.setTensorAddress(OUT_NAME, buffers[outputIndex]);

// Create stream
cudaStream_t stream{};
CHECK(cudaStreamCreate(&stream));

// DMA input batch data to device (destination first, source second), infer on the batch asynchronously, and DMA output back to host
CHECK(cudaMemcpyAsync(buffers[inputIndex], input, batchSize * 3 * IN_H * IN_W * sizeof(float), cudaMemcpyHostToDevice, stream));

// Run inference
context.enqueueV3(stream);

CHECK(cudaMemcpyAsync(output, buffers[outputIndex], batchSize * 4 * sizeof(float), cudaMemcpyDeviceToHost, stream));
CHECK(cudaStreamSynchronize(stream));

// Release stream and buffers
CHECK(cudaStreamDestroy(stream));
CHECK(cudaFree(buffers[inputIndex]));
CHECK(cudaFree(buffers[outputIndex]));
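
CHECK isn't defined in the snippet; it is presumably the usual CUDA error-checking macro from the TensorRT samples. A minimal sketch if you need to supply your own:

#include <cstdio>
#include <cstdlib>
#include <cuda_runtime_api.h>

// Abort with a readable message if a CUDA runtime call fails.
#define CHECK(call)                                                        \
    do {                                                                   \
        cudaError_t status_ = (call);                                      \
        if (status_ != cudaSuccess) {                                      \
            std::fprintf(stderr, "CUDA error '%s' at %s:%d\n",             \
                         cudaGetErrorString(status_), __FILE__, __LINE__); \
            std::exit(EXIT_FAILURE);                                       \
        }                                                                  \
    } while (0)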
