
tensor rt cuda graph use with plugin. #802

Closed
kycocotree opened this issue Sep 29, 2020 · 2 comments
Labels
Plugins, triaged (Issue has been triaged by maintainers)

Comments

@kycocotree

Description

Environment

TensorRT Version: 7.1.3.4
GPU Type: GTX 1080
Nvidia Driver Version: 440.100
CUDA Version: 10.2
CUDNN Version: 8.0
Operating System + Version: Ubuntu 18.04
Python Version (if applicable):
TensorFlow Version (if applicable):
PyTorch Version (if applicable):
Baremetal or Container (which commit + image + tag):

Relevant Files

Steps To Reproduce

Hi, I made a TensorRT engine by adding a plugin to the YOLOv3 model.
I want to launch it as a graph using stream capture (cudaStreamBeginCapture; see BuildGraph below).
However, while capturing the graph, CUDA error 900 (cudaErrorStreamCaptureUnsupported) occurs in the CUDA kernel function added as a plugin.
How can I write the plugin's kernel code so that it can be captured in a graph?

```cpp
bool TRTContext::BuildGraph(trt_device::TrtCudaStream& stream)
{
    /*
    https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_execution_context.html#ac7a5737264c2b7860baef0096d961f5a
    Note:
    Calling enqueueV2() with a stream in CUDA graph capture mode has a known issue.
    If dynamic shapes are used, the first enqueueV2() call after a setInputShapeBinding() call
    will cause failure in stream capture due to resource allocation.
    Please call enqueueV2() once before capturing the graph.
    */

    // Warm-up call before capture, as recommended in the note above.
    bool captured = true;
    if (m_bExplicitBatch_) {
        m_pExecutionContext_->enqueueV2(&m_vecDeviceBindings_[0], stream.get(), nullptr);
    }

    captured = m_cudaGraph_.beginCapture(stream);
    if (captured) {
        if (m_bExplicitBatch_) {
            if (!m_pExecutionContext_->enqueueV2(&m_vecDeviceBindings_[0], stream.get(), nullptr)) {
                captured = false;
            }
            printf("[%s] Use enqueueV2\n", __FUNCTION__);
        }
        else {
            if (!m_pExecutionContext_->enqueue(m_iMaxBatchSize_, &m_vecDeviceBindings_[0], stream.get(), nullptr)) {
                captured = false;
            }
        }

        if (captured) {
            captured = m_cudaGraph_.endCapture(stream);
        }
    }

    return captured;
}
```
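
For reference, `TrtCudaGraph` (`m_cudaGraph_`) is my own helper class, not a TensorRT API. A minimal sketch of what its `beginCapture`/`endCapture` might wrap, assuming the standard CUDA stream-capture APIs (CUDA 10.1+ for the explicit capture-mode argument); the class shape here is illustrative:

```cpp
// Illustrative sketch only: TrtCudaGraph is a user helper, not a TensorRT API.
// The capture/instantiate/launch calls are standard CUDA runtime APIs.
#include <cuda_runtime_api.h>

struct TrtCudaGraph
{
    cudaGraph_t graph{nullptr};
    cudaGraphExec_t graphExec{nullptr};

    bool beginCapture(cudaStream_t stream)
    {
        // Only work submitted to this stream is recorded into the graph.
        return cudaStreamBeginCapture(stream, cudaStreamCaptureModeThreadLocal) == cudaSuccess;
    }

    bool endCapture(cudaStream_t stream)
    {
        if (cudaStreamEndCapture(stream, &graph) != cudaSuccess)
            return false;
        // Instantiate once; the executable graph can then be relaunched cheaply.
        return cudaGraphInstantiate(&graphExec, graph, nullptr, nullptr, 0) == cudaSuccess;
    }

    bool launch(cudaStream_t stream)
    {
        return cudaGraphLaunch(graphExec, stream) == cudaSuccess;
    }
};
```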

In the .cu file:

```cpp
int cuda::decode(int batch_size,
                 const void* const* inputs, void* outputs,
                 size_t height, size_t width,
                 size_t num_count, size_t num_classes,
                 const std::vector<float>& anchors,
                 size_t stride, size_t grid_w, size_t grid_h,
                 void* workspace, size_t workspace_size, cudaStream_t stream)
{
    int anchor_size = anchors.size() / num_count;
    int anchors_byte_size = sizeof(float) * anchors.size();
    void* anchors_d = nullptr;
    CUDA_CHECK(cudaMalloc(&anchors_d, anchors_byte_size));
    CUDA_CHECK(cudaMemcpyAsync(anchors_d, anchors.data(), anchors_byte_size, cudaMemcpyHostToDevice, stream));
    cudaStreamSynchronize(stream);

    int input_elements_size = grid_w * grid_h * num_count * (1 + LOCATIONS + num_classes);
    int output_elements_size = (grid_w * grid_h * num_count * sizeof(Detection) / sizeof(float));

    CUDA_CHECK(cudaMemset(static_cast<float*>(outputs), 0, output_elements_size * batch_size * sizeof(float)));

    int num_threads = 1024;
    int num_elements = grid_h * grid_w;
    if (num_elements < num_threads)
        num_threads = num_elements;

    int num_blocks = 1;
    for (int b = 0; b < batch_size; b++) {
        const float* input = static_cast<const float*>(inputs[0]) + (b * input_elements_size);
        float* output = static_cast<float*>(outputs) + (b * output_elements_size);
        const int num_per_thread = num_elements / num_threads;
        if (num_per_thread >= 1)
            num_blocks = num_per_thread;
        //printf("num_threads: %d, num_per_thread: %d, num_elements:%d\n", num_threads, num_per_thread, num_elements);
        decode_kernel<<<num_blocks, num_threads, 0, stream>>>
            (num_per_thread, input, output, static_cast<float*>(anchors_d), anchor_size,
             grid_w, grid_h, width, height, num_classes, num_count, num_elements);
    }
    cudaStreamSynchronize(stream);
    CUDA_CHECK(cudaFree(anchors_d));

    return cudaSuccess;
}
```
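
Note that the `cudaMalloc`, `cudaMemset`, `cudaStreamSynchronize`, and `cudaFree` calls in `cuda::decode` above are all disallowed while a stream is being captured, which matches error 900. A hypothetical diagnostic helper (not part of the plugin above) using the standard `cudaStreamIsCapturing` API can confirm whether `decode` is being called under capture:

```cpp
// Hypothetical diagnostic helper, assumed names; cudaStreamIsCapturing is a
// standard CUDA runtime API since CUDA 10.0.
#include <cstdio>
#include <cuda_runtime_api.h>

static bool stream_is_capturing(cudaStream_t stream)
{
    cudaStreamCaptureStatus status = cudaStreamCaptureStatusNone;
    cudaStreamIsCapturing(stream, &status);
    return status == cudaStreamCaptureStatusActive;
}

// Called at the top of cuda::decode, this pinpoints the failure:
// if (stream_is_capturing(stream))
//     printf("decode called during capture: cudaMalloc/cudaMemset/"
//            "cudaStreamSynchronize/cudaFree here will fail (error 900)\n");
```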

@ttyio (Collaborator) commented Oct 14, 2020

Hello @kycocotree , thanks for reporting the issue.

CUDA graph capture does work with TRT plugins; our demoBERT supports CUDA graphs and most of its layers are written as plugins: https://github.com/NVIDIA/TensorRT/tree/release/7.1/demo/BERT

Is the above cuda::decode function called inside enqueue? If so, could you move all of the cudaMalloc, cudaMemcpyAsync, and cudaMemset calls out of it? Some of those runtime API calls might not be supported during CUDA graph capture. Also, an enqueue function usually does no resource allocation; cudaMalloc and cudaMemset are synchronizing APIs and hurt performance. You can check the other open-source plugins as a reference.
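
A minimal sketch of that restructuring (the `DecodePlugin` class, its members, and the assumed `decode_kernel` signature are illustrative, based only on the code in this issue): allocate and upload the anchors once in `initialize()`, and keep `enqueue()` limited to async, stream-ordered work so it can be captured:

```cpp
// Illustrative sketch, not demoBERT code: all names here are assumptions.
#include <cuda_runtime_api.h>
#include <vector>

// Assumed kernel signature, inferred from the launch in this issue:
__global__ void decode_kernel(int num_per_thread, const float* input, float* output,
                              const float* anchors, int anchor_size,
                              int grid_w, int grid_h, int width, int height,
                              int num_classes, int num_count, int num_elements);

// Minimal member declarations assumed for the sketch:
class DecodePlugin /* : public nvinfer1::IPluginV2 */
{
public:
    int initialize();
    void terminate();
    int enqueue(int batchSize, const void* const* inputs, void** outputs,
                void* workspace, cudaStream_t stream);

private:
    void* mAnchorsDevice{nullptr};
    std::vector<float> mAnchorsHost;
    size_t mAnchorsByteSize{0};
    int mInputElements{0}, mOutputElements{0};
    int mNumBlocks{1}, mNumThreads{1024}, mNumPerThread{1};
    int mAnchorSize{0}, mGridW{0}, mGridH{0}, mWidth{0}, mHeight{0};
    int mNumClasses{0}, mNumCount{0}, mNumElements{0};
};

int DecodePlugin::initialize()
{
    // One-time allocation + upload: legal here, illegal during graph capture.
    cudaMalloc(&mAnchorsDevice, mAnchorsByteSize);
    cudaMemcpy(mAnchorsDevice, mAnchorsHost.data(), mAnchorsByteSize,
               cudaMemcpyHostToDevice);
    return 0;
}

void DecodePlugin::terminate()
{
    cudaFree(mAnchorsDevice);
    mAnchorsDevice = nullptr;
}

int DecodePlugin::enqueue(int batchSize, const void* const* inputs,
                          void** outputs, void* workspace, cudaStream_t stream)
{
    // Only async, stream-ordered calls remain: all of these are capturable.
    cudaMemsetAsync(outputs[0], 0,
                    mOutputElements * batchSize * sizeof(float), stream);
    for (int b = 0; b < batchSize; ++b)
    {
        const float* in = static_cast<const float*>(inputs[0]) + b * mInputElements;
        float* out = static_cast<float*>(outputs[0]) + b * mOutputElements;
        decode_kernel<<<mNumBlocks, mNumThreads, 0, stream>>>(
            mNumPerThread, in, out, static_cast<const float*>(mAnchorsDevice),
            mAnchorSize, mGridW, mGridH, mWidth, mHeight,
            mNumClasses, mNumCount, mNumElements);
    }
    // No cudaStreamSynchronize: synchronizing would also break capture.
    return 0;
}
```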

Thanks!

@ttyio ttyio added the triaged Issue has been triaged by maintainers label Oct 14, 2020
@ttyio ttyio added the Plugins label Nov 2, 2020
@ttyio (Collaborator) commented Nov 2, 2020

Closing since there has been no response for more than 2 weeks; please reopen if you still have questions, thanks!

@ttyio ttyio closed this as completed Nov 2, 2020