
multi-stream parallel execution with one GPU ERROR #846

Closed
Jacoobr opened this issue Oct 24, 2020 · 10 comments
Labels: documentation, triaged (Issue has been triaged by maintainers)

Comments


Jacoobr commented Oct 24, 2020

I tried to execute two parallel inferences with one context and two streams, but I got the error below:

\cudaDeconvolutionRunner.cpp (295) - Cudnn Error in nvinfer1::rt::cuda::DeconvolutionRunner::execute: 8 (CUDNN_STATUS_EXECUTION_FAILED)
[10/24/2020-15:52:49] [E] [TRT] FAILED_EXECUTION: Unknown exception
[10/24/2020-15:52:49] [E] [TRT] C:\source\rtSafe\safeContext.cpp (133) - Cudnn Error in nvinfer1::rt::CommonContext::configure: 7 (CUDNN_STATUS_MAPPING_ERROR)

The code for executing the two parallel inferences (the `enqueueV2()` calls are asynchronous) with one context and two streams follows:

```cpp
cudaStream_t streamS{ NULL };
cudaStreamCreate(&streamS);
cudaStream_t streamS2{ NULL };
cudaStreamCreate(&streamS2);
auto context = SampleUniquePtr<nvinfer1::IExecutionContext>(mEngine->createExecutionContext());
if (!context)
{
    return false;
}
// Read the input data into the managed buffers
assert(mParams.inputTensorNames.size() == 1);
if (!processInput(buffers))
{
    return false;
}

processInput(buffers2);
// Memcpy from host input buffers to device input buffers
//***buffers.copyInputToDevice();
buffers.copyInputToDeviceAsync(streamS);
buffers2.copyInputToDeviceAsync(streamS2);
std::cout << "before execute inter time (ms):" << (clock() - start_) << "\n";
bool status = true;
clock_t a = clock();
for (unsigned i = 0; i < 100; ++i) {
    //context->executeV2(buffers.getDeviceBindings().data());     // calculate the prediction result
    context->enqueueV2(buffers.getDeviceBindings().data(), streamS, nullptr);
    context->enqueueV2(buffers2.getDeviceBindings().data(), streamS2, nullptr);
}
```

Can someone give me some advice on running parallel inference with this code?

Environment

Windows OS + TensorRT 7.1.3 + one 2060super GPU


ttyio commented Oct 26, 2020

Hello @Jacoobr, thanks for reporting.
The IExecutionContext holds shared resources, so if you want parallel execution, you have to create two IExecutionContexts, one assigned to each CUDA stream. Could you give that a try? Thanks!
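A minimal sketch of this suggestion (`mEngine` is the shared engine as in the sample code; `bindings1` and `bindings2` are hypothetical placeholders for two separate device-binding arrays):

```cpp
// Sketch: one IExecutionContext per CUDA stream, both sharing one engine.
// Each context owns its own per-execution state, so the two enqueueV2
// calls can overlap on the GPU instead of corrupting shared resources.
cudaStream_t stream1, stream2;
cudaStreamCreate(&stream1);
cudaStreamCreate(&stream2);

auto ctx1 = SampleUniquePtr<nvinfer1::IExecutionContext>(mEngine->createExecutionContext());
auto ctx2 = SampleUniquePtr<nvinfer1::IExecutionContext>(mEngine->createExecutionContext());

// bindings1 / bindings2 must point at separate device buffers per context.
ctx1->enqueueV2(bindings1, stream1, nullptr);
ctx2->enqueueV2(bindings2, stream2, nullptr);

cudaStreamSynchronize(stream1);
cudaStreamSynchronize(stream2);

cudaStreamDestroy(stream1);
cudaStreamDestroy(stream2);
```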

ttyio added the triaged label on Oct 26, 2020

Jacoobr commented Oct 26, 2020

Hi @ttyio, thanks for your reply. I created one IExecutionContext for each CUDA stream, and that error no longer appears. But I got another error about the dynamic-shape profile setting with IExecutionContext:

[10/26/2020-19:31:20] [E] [TRT] Profile 0 has been chosen by another IExecutionContext. Use another profileIndex or destroy the IExecutionContext that use this profile.
[10/26/2020-19:31:20] [W] [TRT] Could not set default profile 0 for execution context. Profile index must be set explicitly.
[10/26/2020-19:31:20] [E] [TRT] Parameter check failed at: engine.cpp::nvinfer1::rt::ExecutionContext::setBindingDimensions::1033, condition: mOptimizationProfile >= 0 && mOptimizationProfile < mEngine.getNbOptimizationProfiles()
read input finished!!
before execute inter time (ms):1746
[10/26/2020-19:31:20] [E] [TRT] Parameter check failed at: engine.cpp::nvinfer1::rt::ExecutionContext::enqueueV2::545, condition: mOptimizationProfile >= 0 && mOptimizationProfile < mEngine.getNbOptimizationProfiles()
[10/26/2020-19:31:20] [E] [TRT] Parameter check failed at: engine.cpp::nvinfer1::rt::ExecutionContext::enqueueV2::545, condition: mOptimizationProfile >= 0 && mOptimizationProfile < mEngine.getNbOptimizationProfiles()
[10/26/2020-19:31:20] [E] [TRT] Parameter check failed at: engine.cpp::nvinfer1::rt::ExecutionContext::enqueueV2::545, condition: mOptimizationProfile >= 0 && mOptimizationProfile < mEngine.getNbOptimizationProfiles()
[10/26/2020-19:31:20] [E] [TRT] Parameter check failed at: engine.cpp::nvinfer1::rt::ExecutionContext::enqueueV2::545, condition: mOptimizationProfile >= 0 && mOptimizationProfile < mEngine.getNbOptimizationProfiles()
[10/26/2020-19:31:20] [E] [TRT] Parameter check failed at: engine.cpp::nvinfer1::rt::ExecutionContext::enqueueV2::545, condition: mOptimizationProfile >= 0 && mOptimizationProfile < mEngine.getNbOptimizationProfiles()
[10/26/2020-19:31:20] [E] [TRT] Parameter check failed at: engine.cpp::nvinfer1::rt::ExecutionContext::enqueueV2::545, condition: mOptimizationProfile >= 0 && mOptimizationProfile < mEngine.getNbOptimizationProfiles()
[10/26/2020-19:31:20] [E] [TRT] Parameter check failed at: engine.cpp::nvinfer1::rt::ExecutionContext::enqueueV2::545, condition: mOptimizationProfile >= 0 && mOptimizationProfile < mEngine.getNbOptimizationProfiles()
[10/26/2020-19:31:20] [E] [TRT] Parameter check failed at: engine.cpp::nvinfer1::rt::ExecutionContext::enqueueV2::545, condition: mOptimizationProfile >= 0 && mOptimizationProfile < mEngine.getNbOptimizationProfiles()
[10/26/2020-19:31:20] [E] [TRT] Parameter check failed at: engine.cpp::nvinfer1::rt::ExecutionContext::enqueueV2::545, condition: mOptimizationProfile >= 0 && mOptimizationProfile < mEngine.getNbOptimizationProfiles()
[10/26/2020-19:31:20] [E] [TRT] Parameter check failed at: engine.cpp::nvinfer1::rt::ExecutionContext::enqueueV2::545, condition: mOptimizationProfile >= 0 && mOptimizationProfile < mEngine.getNbOptimizationProfiles()

Because the engine needs to handle inference on images of different input sizes, I create an optimization profile at build time so that the engine accepts a range of input sizes, with the code below:

```cpp
auto input_dim = network->getInput(0)->getDimensions();
IOptimizationProfile* profile = builder->createOptimizationProfile();
profile->setDimensions("input", OptProfileSelector::kMIN, input_dim);
input_dim.d[3] = xxx;
input_dim.d[4] = xxx;
profile->setDimensions("input", OptProfileSelector::kOPT, input_dim);
input_dim.d[3] = xxx;
input_dim.d[4] = xxx;
profile->setDimensions("input", OptProfileSelector::kMAX, input_dim);
config->addOptimizationProfile(profile);
```

When I run with one IExecutionContext, inference works well and the results are correct, without any error. But with two IExecutionContexts in parallel, the error above occurs, and only the first of the two executions produces a correct result; the second one is wrong. Can the profile (created at build time) not be shared across parallel executions at inference time?
How can I solve this error so that I can use the profile with two parallel executions at inference time? I appreciate any answer.


ttyio commented Oct 27, 2020

Hello @Jacoobr, thanks for the reply.
Could you share more code? I'd like to see in more detail how you create the dynamic-shape profile, how you set up the profile, and how the two contexts are used with the CUDA streams. Thanks!


Jacoobr commented Oct 28, 2020

Hi @ttyio. At the engine build stage, I create only one dynamic-shape profile, in the constructNetwork function. The detailed code:

```cpp
bool SampleOnnxMNIST::constructNetwork(SampleUniquePtr<nvinfer1::IBuilder>& builder,
    SampleUniquePtr<nvinfer1::INetworkDefinition>& network, SampleUniquePtr<nvinfer1::IBuilderConfig>& config,
    SampleUniquePtr<nvonnxparser::IParser>& parser)
{
    /* auto parsed = parser->parseFromFile(
         locateFile(mParams.onnxFileName, mParams.dataDirs).c_str(), static_cast<int>(gLogger.getReportableSeverity()));*/
    clock_t parser_t = clock();
    // parse the ONNX file
    auto parsed = parser->parseFromFile(mParams.onnxFileName.c_str(), static_cast<int>(sample::gLogger.getReportableSeverity()));

    // create optimization profile for dynamic input support
    IOptimizationProfile* profile = builder->createOptimizationProfile();
    profile->setDimensions("input", OptProfileSelector::kMIN, input_dim_min);
    profile->setDimensions("input", OptProfileSelector::kOPT, input_dim_op);
    profile->setDimensions("input", OptProfileSelector::kMAX, input_dim_max);
    config->addOptimizationProfile(profile);
    if (profile->isValid()) {
        std::cout << "set profile succesfully!!\n";
    }
    else {
        std::cout << "set profile failed!!\n";
    }
    std::cout << "parser tackes (ms): " << (clock() - parser_t) << "\n";
    if (!parsed)
    {
        return false;
    }
    builder->setMaxBatchSize(mParams.batchSize);
    //config->setMaxWorkspaceSize(5_GiB);
    if (mParams.maxWorkSpaceSizeUnit == 3) {
        config->setMaxWorkspaceSize(mParams.maxWorkSpaceSizeVal * (1 << 30));
    }
    else if (mParams.maxWorkSpaceSizeUnit == 2) {
        config->setMaxWorkspaceSize(mParams.maxWorkSpaceSizeVal * (1 << 20));
    }
    else if (mParams.maxWorkSpaceSizeUnit == 1) {
        config->setMaxWorkspaceSize(mParams.maxWorkSpaceSizeVal * (1 << 10));
    }
    if (mParams.fp16)
    {
        config->setFlag(BuilderFlag::kFP16);
    }
    if (mParams.int8)
    {
        config->setFlag(BuilderFlag::kINT8);
        samplesCommon::setAllTensorScales(network.get(), 127.0f, 127.0f);
    }
    if (mParams.fp32) {
        std::cout << "start try to use TF32\n";
        bool hasTf32 = builder->platformHasTf32();
        std::cout << "build engine with TF32?: " << hasTf32 << "\n";
        config->setFlag(BuilderFlag::kTF32);
    }
    samplesCommon::enableDLA(builder.get(), config.get(), mParams.dlaCore);

    return true;
}
```

At the inference stage, I set up two contexts, each with its own buffers and stream. I call `context->setBindingDimensions(0, dims_infer);`, where `dims_infer` is the specific image size that I will input at inference time. The detailed code follows:

```cpp
bool SampleOnnxMNIST::infer()
{
    std::cout << "infer1 start!\n";
    clock_t start_ = clock();
    auto runtime = SampleUniquePtr<nvinfer1::IRuntime>(nvinfer1::createInferRuntime(sample::gLogger.getTRTLogger()));
    std::string engineFile = mParams.engineStoredDir + mParams.engineFileName;
    std::ifstream ifs(engineFile.c_str(), std::ios::binary);
    size_t s = 0;
    ifs >> s;
    std::unique_ptr<char[]> pModel(new char[s]());
    ifs.read(pModel.get(), s);
    ifs.close();
    mEngine = std::shared_ptr<nvinfer1::ICudaEngine>(runtime->deserializeCudaEngine(pModel.get(), s, nullptr), samplesCommon::InferDeleter());

    size_t wsp = mEngine->getDeviceMemorySize();

    auto dims_kopt = mEngine->getProfileDimensions(0, 0, OptProfileSelector::kMAX);
    for (int i = 0; i < mParams.inputDims; i++) {
        dims_kopt.d[i] = mParams.inputDataSize[i];
    }

    std::cout << "workspace: " << wsp / 1024 / 1024.0 << "MB" << std::endl;
    std::cout << "deserializeCudaEngine time (ms):" << (clock() - start_) << "\n";

    auto context = SampleUniquePtr<nvinfer1::IExecutionContext>(mEngine->createExecutionContext());
    auto context2 = SampleUniquePtr<nvinfer1::IExecutionContext>(mEngine->createExecutionContext());
    context->setBindingDimensions(0, dims_kopt);    // set the specific input image size for context 1
    context2->setBindingDimensions(0, dims_kopt);   // set the specific input image size for context 2

    samplesCommon::BufferManager buffers(mEngine, dims_kopt);
    samplesCommon::BufferManager buffers2(mEngine, dims_kopt);

    cudaStream_t streamS{ NULL };
    cudaStreamCreate(&streamS);
    cudaStream_t streamS2{ NULL };
    cudaStreamCreate(&streamS2);

    // Read the input data into the managed buffers
    assert(mParams.inputTensorNames.size() == 1);
    if (!processInput(buffers))
    {
        return false;
    }
    std::cout << "read input finished!!\n";
    processInput(buffers2);
    // Memcpy from host input buffers to device input buffers
    buffers.copyInputToDeviceAsync(streamS);
    buffers2.copyInputToDeviceAsync(streamS2);

    std::cout << "before execute inter time (ms):" << (clock() - start_) << "\n";
    bool status = true;
    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);
    cudaEventRecord(start);
    for (unsigned i = 0; i < 1; ++i) {
        // parallel-execute the two inferences with different contexts and streams
        context->enqueueV2(buffers.getDeviceBindings().data(), streamS, nullptr);
        context2->enqueueV2(buffers2.getDeviceBindings().data(), streamS2, nullptr);
    }
    cudaEventRecord(stop);
    cudaEventSynchronize(stop);
    float ftm(0);
    cudaEventElapsedTime(&ftm, start, stop);
    std::cout << "execute1 time (ms): " << ftm / 1 << std::endl;

    // Memcpy from device output buffers to host output buffers
    clock_t write_t = clock();
    //***buffers.copyOutputToHost();
    buffers.copyOutputToHostAsync(streamS);
    buffers2.copyOutputToHostAsync(streamS2);
    cudaStreamSynchronize(streamS);
    cudaStreamSynchronize(streamS2);
    cudaStreamDestroy(streamS);
    cudaStreamDestroy(streamS2);
    // Write execution result
    ...

    return true;
}
```

I don't understand why the profile I set at build time cannot be used in parallel by different contexts with asynchronous `enqueueV2()` execution. Or is there an error in my parallel-execution code?


ttyio commented Oct 28, 2020

Ah, thanks for the code @Jacoobr.
Could you create two profiles and assign them to separate contexts? Here is the relevant explanation from https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_execution_context.html#aeaa47145eac410c2ff9df8a29b64a96f:

> If the associated CUDA engine has dynamic inputs, this method must be called at least once with a unique profileIndex before calling execute or enqueue (i.e. the profile index may not be in use by another execution context that has not been destroyed yet).
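The requirement can be sketched as follows. This is a sketch, not the sample's code: it assumes the engine was built with at least two optimization profiles, and `dims` is a placeholder for the runtime input shape.

```cpp
// Sketch: with dynamic shapes, each live execution context must select
// a distinct optimization profile before enqueue.
auto ctx1 = SampleUniquePtr<nvinfer1::IExecutionContext>(mEngine->createExecutionContext());
auto ctx2 = SampleUniquePtr<nvinfer1::IExecutionContext>(mEngine->createExecutionContext());

ctx1->setOptimizationProfile(0);  // context 1 uses profile 0
ctx2->setOptimizationProfile(1);  // context 2 uses profile 1

// In TensorRT 7, bindings are replicated per profile: profile p's bindings
// start at p * (getNbBindings() / getNbOptimizationProfiles()).
const int bindingsPerProfile =
    mEngine->getNbBindings() / mEngine->getNbOptimizationProfiles();
ctx1->setBindingDimensions(0, dims);                   // input binding of profile 0
ctx2->setBindingDimensions(bindingsPerProfile, dims);  // same input, profile 1
```

Note that the binding arrays passed to `enqueueV2` must likewise use each profile's own binding slots.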


Jacoobr commented Oct 29, 2020

Hello @ttyio, I created two profiles and assigned them to separate contexts, and the profile error at parallel inference time is gone. Thank you very much. But of the two parallel inference results, only the first is correct; the second context's result is wrong. What's more, I compared the time of the two parallel executions against a single execution: the two parallel executions take 2x as long, so I think the parallel execution with two contexts is failing. Would you mind giving some advice on parallel execution, or pointing out what's wrong with my process?
To narrow it down, I tested the sampleMNIST project the same way, creating two contexts and two streams for parallel inference, and got the same errors: the second context's result is wrong, and the two-context parallel inference takes 2x as long as a single-context execution.
Here are the code and the result:

```cpp
bool SampleMNIST::infer()
{
    //auto dims = mEngine->getBindingDimensions(0);
    samplesCommon::BufferManager buffers(mEngine);
    samplesCommon::BufferManager buffers2(mEngine);

    auto context = SampleUniquePtr<nvinfer1::IExecutionContext>(mEngine->createExecutionContext());
    auto context2 = SampleUniquePtr<nvinfer1::IExecutionContext>(mEngine->createExecutionContext());

    if (!context)
    {
        return false;
    }

    // Pick a random digit to try to infer
    srand(time(NULL));
    const int digit1 = rand() % 10;
    const int digit2 = rand() % 10;

    // Read the input data into the managed buffers
    // There should be just 1 input tensor
    assert(mParams.inputTensorNames.size() == 1);
    if (!processInput(buffers, mParams.inputTensorNames[0], digit1))
    {
        return false;
    }
    processInput(buffers2, mParams.inputTensorNames[0], digit2);
    // Create CUDA streams for the execution of this inference
    cudaStream_t stream;
    CHECK(cudaStreamCreate(&stream));
    cudaStream_t stream2;
    CHECK(cudaStreamCreate(&stream2));

    // Asynchronously copy data from host input buffers to device input buffers
    buffers.copyInputToDeviceAsync(stream);
    buffers.copyInputToDeviceAsync(stream2);

    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);
    cudaEventRecord(start);
    // Asynchronously enqueue the inference work
    context->enqueue(1, buffers.getDeviceBindings().data(), stream, nullptr);
    context2->enqueue(1, buffers2.getDeviceBindings().data(), stream2, nullptr);

    cudaEventRecord(stop);
    cudaEventSynchronize(stop);
    float ftm(0);
    cudaEventElapsedTime(&ftm, start, stop);
    std::cout << "forward elapsed time: " << ftm << std::endl;

    // Asynchronously copy data from device output buffers to host output buffers
    buffers.copyOutputToHostAsync(stream);
    buffers.copyOutputToHostAsync(stream2);
    // Wait for the work in the streams to complete
    cudaStreamSynchronize(stream);
    cudaStreamSynchronize(stream2);

    // Release streams
    cudaStreamDestroy(stream);
    cudaStreamDestroy(stream2);

    // Check and print the output of the inference
    // There should be just one output tensor
    assert(mParams.outputTensorNames.size() == 1);
    bool outputCorrect = verifyOutput(buffers, mParams.outputTensorNames[0], digit1);
    bool outputCorrect2 = verifyOutput(buffers2, mParams.outputTensorNames[0], digit2);

    bool result = outputCorrect && outputCorrect2;
    return result;
}
```

Terminal console:

&&&& RUNNING TensorRT.sample_mnist # C:\NVIDIA\TensorRT\TensorRT-7.1.3.4\samples\sampleMNIST\\..\..\bin\sample_mnist.exe[10/29/2020-14:57:35] [I] Building and running a GPU inference engine for MNIST
[10/29/2020-14:58:03] [W] [TRT] No implementation obeys reformatting-free rules, at least 2 reformatting nodes are needed, now picking the fastest path instead.
[10/29/2020-14:58:03] [I] [TRT] Detected 1 inputs and 1 output network tensors.
mdoel size: 891513
[10/29/2020-14:58:03] [I] Input:
@@@@@@@@@@@@@@@@@@@@@@@@@@@@
@@@@@@@@@@@@@@@@@@@@@@@@@@@@
@@@@@@@@@@@@@@@@@@@@@@@@@@@@
@@@@@@@@@@@@@@@@@@@@@@@@@@@@
@@@@@@@@@@@@@@@@@@@@@@@@@@@@
@@@@@@@@@@@@@@@@@@@@@@@@@@@@
@@@@@@@@@@@@@@@@@@@@@@@@@@@@
@@@@@@@@@@@@#=.  +*=#@@@@@@@
@@@@@@@@@@@*   :.   -@@@@@@@
@@@@@@@@@@#  :#@@:  +@@@@@@@
@@@@@@@@@*  :@@@*  .@@@@@@@@
@@@@@@@@=  =@@@@.  *@@@@@@@@
@@@@@@@=  -@@@@*  =@@@@@@@@@
@@@@@@@  -@@@%:  -@@@@@@@@@@
@@@@@@%  %%+:    *@@@@@@@@@@
@@@@@@@      ..  @@@@@@@@@@@
@@@@@@@#  .=%%: =@@@@@@@@@@@
@@@@@@@@@@@@@#  +@@@@@@@@@@@
@@@@@@@@@@@@@#  @@@@@@@@@@@@
@@@@@@@@@@@@@@  @@@@@@@@@@@@
@@@@@@@@@@@@@#  @@@@@@@@@@@@
@@@@@@@@@@@@@+  @@@@@@@@@@@@
@@@@@@@@@@@@@%  @@@@@@@@@@@@
@@@@@@@@@@@@@@. #@@@@@@@@@@@
@@@@@@@@@@@@@@* :%@@@@@@@@@@
@@@@@@@@@@@@@@@: -@@@@@@@@@@
@@@@@@@@@@@@@@@@= %@@@@@@@@@
@@@@@@@@@@@@@@@@@@@@@@@@@@@@

[10/29/2020-14:58:03] [I] Input:
@@@@@@@@@@@@@@@@@@@@@@@@@@@@
@@@@@@@@@@@@@@@@@@@@@@@@@@@@
@@@@@@@@@@@@@@@@@@@@@@@@@@@@
@@@@@@@@@@@@@@@@@@@@@@@@@@@@
@@@@@@@@@@@@@@@@@@@@@@@@@@@@
@@@@@@@@@@@@@@@@@@@@@@@@@@@@
@@@@@@@@@@@@@@@@@@@@@@@@@@@@
@@@@@@@@@@@@@@@++-  .@@@@@@@
@@@@@@@@@@@@#+-      .@@@@@@
@@@@@@@@%:..   :     .@@@@@@
@@@@@@@#.     %@%#.  :@@@@@@
@@@@@@@:    +#@@@=  .@@@@@@@
@@@@@@#   .#@@@@@:  =@@@@@@@
@@@@@@:  :#@@@@@*  :#@@@@@@@
@@@@@@*.:%@@@@@+  .%@@@@@@@@
@@@@@@@@@@@@@@*   +@@@@@@@@@
@@@@@@@@@@@@@@   +@@@@@@@@@@
@@@@@@@@@@@@@=   %@@@@@@@@@@
@@@@@@@@@@@@@:  +@@@@@@@@@@@
@@@@@@@@@@@@+  -@@@@@@@@@@@@
@@@@@@@@@@@-  +@@@@@@@@@@@@@
@@@@@@@@@@%  .@@@@@@@@@@@@@@
@@@@@@@@@%- .%@@@@@@@@@@@@@@
@@@@@@@@@-  +@@@@@@@@@@@@@@@
@@@@@@@@%: =%@@@@@@@@@@@@@@@
@@@@@@@@= .%@@@@@@@@@@@@@@@@
@@@@@@@@# *@@@@@@@@@@@@@@@@@
@@@@@@@@@@@@@@@@@@@@@@@@@@@@

forward elapsed time: 1.19808
[10/29/2020-14:58:03] [I] Output:
0:
1:
2:
3:
4:
5:
6:
7:
8:
9: **********

[10/29/2020-14:58:03] [I] Output:
0:
1:
2:
3:
4:
5:
6:
7:
8:
9:

&&&& FAILED TensorRT.sample_mnist # C:\NVIDIA\TensorRT\TensorRT-7.1.3.4\samples\sampleMNIST\\..\..\bin\sample_mnist.exe

I appreciate any reply, thanks.


ttyio commented Nov 2, 2020

Hello @Jacoobr,

For the accuracy issue: I have not run it, but there is a typo in your code — `buffers2` is never copied to the device. Could you fix that and retry?

For the throughput issue: this is case by case for different networks. For example, when the GPU is already fully occupied by one network, kicking off another network simultaneously won't yield higher occupancy. You can use Nsight Systems for overall performance and Nsight Compute for CUDA kernel performance. Hope this helps, thanks!


Jacoobr commented Nov 3, 2020

Hi @ttyio. Sorry for the mistake in the code. I fixed the typo and now both inference results are correct. Thank you very much. But the time for two parallel executions is still 2x that of a single execution, so it's not certain that the two executions run in parallel asynchronously. I will check it later with the Nsight Compute tool. Anyway, thank you again.


OpDaSo commented Jul 15, 2021

I created a network with multiple profiles so that more than one context can be created.
The first time I run the following code, everything is fine, but the second time I get the error:
[E] [TRT] 1: [resources.cpp::nvinfer1::ScopedCudaStream::ScopedCudaStream::447] Error Code 1: Cuda Runtime (an illegal memory access was encountered)
The error occurs at the line: cudaMemcpy(outputPred, outputBuffer, outputSize*sizeof(float), cudaMemcpyDeviceToHost);

This is the code:

```cpp
nvinfer1::IRuntime* runtime = nvinfer1::createInferRuntime(sample::gLogger);
_engine = runtime->deserializeCudaEngine(data.data(), data.size(), nullptr);
_context = _engine->createExecutionContext();

for (int i = 0; i < 10; i++) {
    _context->setOptimizationProfile(i);
    _context->setBindingDimensions(2 * i, nvinfer1::Dims4(4, 224, 224, 3));
    _inputIndex = _engine->getBindingIndex("input_layer:0");
    _outputIndex = _engine->getBindingIndex("Identity:0");
    _input_height = _engine->getBindingDimensions(_inputIndex).d[1];
    _input_width = _engine->getBindingDimensions(_inputIndex).d[2];
    _input_channels = _engine->getBindingDimensions(_inputIndex).d[3];

    cudaError_t cudaErr = cudaStreamCreate(&_stream);
    int input_batch_size = 1;
    int inputSize = input_batch_size * _input_height * _input_width * _input_channels;
    int outputSize = input_batch_size * _engine->getBindingDimensions(_outputIndex).d[1];

    void* inputBuffer;
    cudaMalloc(&inputBuffer, inputSize * sizeof(float));
    void* outputBuffer;
    cudaMalloc(&outputBuffer, outputSize * sizeof(float));

    cv::Mat dummyImg = cv::Mat::ones(224, 224, CV_8UC3);

    int volChl = _input_channels * _input_width;
    int volImg = _input_channels * _input_height * _input_width;
    float* inputImg = (float*) malloc(inputSize * sizeof(float));

    // write dummyImg to inputImg

    cudaMemcpy(inputBuffer, inputImg, inputSize * sizeof(float), cudaMemcpyHostToDevice);

    float* outputPred = (float*) malloc(outputSize * sizeof(float));
    cudaMemcpy(outputBuffer, outputPred, outputSize * sizeof(float), cudaMemcpyHostToDevice);

    void* buffers[2];
    buffers[_inputIndex] = inputBuffer;
    buffers[_outputIndex] = outputBuffer;

    bool status = _context->enqueueV2(buffers, _stream, nullptr);

    cudaMemcpy(outputPred, outputBuffer, outputSize * sizeof(float), cudaMemcpyDeviceToHost);

    free(outputPred);
    free(inputImg);

    cudaFree(outputBuffer);
    cudaFree(inputBuffer);
}
```

My goal is to have multiple contexts that do parallel inference on one GPU.

@awarebayes

> I created a network with multiple profiles therefore it can be created more than one context. If I run the following code the first time everything is fine but for the second time I get the error: [E] [TRT] 1: [resources.cpp::nvinfer1::ScopedCudaStream::ScopedCudaStream::447] Error Code 1: Cuda Runtime (an illegal memory access was encountered) The error occurs at the line: cudaMemcpy(outputPred, outputBuffer, outputSize*sizeof(float), cudaMemcpyDeviceToHost);
>
> [...]
>
> My goal is to have multiple contexts that do parallel inference on one GPU.

He is using cudaMemcpy, which is device-blocking; use cudaMemcpyAsync instead.
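A sketch of that fix applied to the loop body above, using the same variable names as the quoted code (`outputSize`, `outputBuffer`, `buffers`, `_context`, `_stream` are assumed set up as before):

```cpp
// Use stream-ordered copies so the transfer is queued behind enqueueV2
// on the same stream, instead of cudaMemcpy's implicit device-wide sync.
// Pinned host memory (cudaMallocHost) is needed for truly async copies.
float* outputPred = nullptr;
cudaMallocHost(&outputPred, outputSize * sizeof(float));

bool status = _context->enqueueV2(buffers, _stream, nullptr);
cudaMemcpyAsync(outputPred, outputBuffer, outputSize * sizeof(float),
                cudaMemcpyDeviceToHost, _stream);
cudaStreamSynchronize(_stream);  // wait only for this stream's work

// ... use outputPred ...
cudaFreeHost(outputPred);
```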
