
cuda error in TensorRT5 when using DALI preprocessing #408

Closed
herbiezhao opened this issue Jan 7, 2019 · 17 comments
Labels
bug Something isn't working

Comments

@herbiezhao

I use DALI to preprocess my data, then feed DALI's pipe_out into TensorRT for inference. I got a CUDA error:
Cuda error in file src/implicit_gemm.cu at line 1214: invalid resource handle
[TensorRT] ERROR: cuda/customWinogradConvActLayer.cpp (310) - Cuda Error in execute: 33
[TensorRT] ERROR: cuda/customWinogradConvActLayer.cpp (310) - Cuda Error in execute: 33

Partial code:
pipe_out = pipe.run()
pre_input, labels = pipe_out
pre_input_cpu = pre_input.asCPU()
pre_input_tensor = pre_input_cpu.as_tensor()
pre_input_ = np.array(pre_input_tensor)
input_ = np.array(pre_input_.ravel())
np.copyto(pagelocked_buffer, input_)
[output] = common.do_inference(context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream)

The error happens at the call "[output] = common.do_inference(context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream)".

DALI alone produces correct output, and TensorRT 5 alone (without DALI) also produces correct output.

JanuszL added the bug label Jan 7, 2019
@JanuszL
Contributor

JanuszL commented Jan 7, 2019

Hi,
@herbiezhao could you provide a simple yet complete reproduction script (following the points from https://github.com/NVIDIA/DALI#reporting-problems-asking-questions)?
It will help us confirm we are chasing exactly the same problem you encountered.
Br,
Janusz

@JanuszL
Contributor

JanuszL commented Jan 7, 2019

Tracked as DALI-472.

@herbiezhao
Author

My environment is:
4 GPUs (P100)
CUDA 9.0, TensorRT 5
You need test data and a UFF file to run sample_bug.py; you can produce the UFF yourself. The model is resnet_v1_50 from slim. If I set device="cpu" in DALI, there is no problem, so I think it is a GPU resource conflict (a sketch of the CPU-only variant follows).
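For reference, here is a minimal sketch of that CPU-only variant (the class name is my own; it assumes the DALI version in use provides CPU backends for Resize and CropMirrorNormalize, and only the device placement differs from the SimplePipeline in the script below):

class SimplePipelineCPU(Pipeline):
    # Hypothetical CPU-only variant: keeping every operator on the CPU avoids
    # sharing the GPU/CUDA context with TensorRT and PyCUDA.
    def __init__(self, batch_size, num_threads, device_id, image_dir):
        super(SimplePipelineCPU, self).__init__(batch_size, num_threads, device_id, seed=12)
        self.input = ops.FileReader(file_root=image_dir)
        self.decode = ops.HostDecoder(output_type=types.RGB)          # CPU decode
        self.resize = ops.Resize(device="cpu", resize_shorter=256.)   # CPU resize
        self.cmnp = ops.CropMirrorNormalize(device="cpu",             # CPU normalize
                                            output_dtype=types.FLOAT,
                                            output_layout=types.NCHW,
                                            crop=(224, 224),
                                            image_type=types.RGB,
                                            mean=[123.68, 116.78, 103.94],
                                            std=[1., 1., 1.])
        self.uniform = ops.Uniform(range=(0.5, 0.5))

    def define_graph(self):
        jpegs, labels = self.input()
        images = self.decode(jpegs)        # stays on the CPU -- no .gpu() call
        resized = self.resize(images)
        output = self.cmnp(resized,
                           crop_pos_x=self.uniform(),
                           crop_pos_y=self.uniform())
        return (output, labels)

With this variant the pipeline output is already on the host, so the asCPU() step in main() can be dropped.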

[sample_bug.py]

# This sample uses a UFF model (resnet_v1_50) to create a TensorRT inference engine.

import numpy as np
import tensorrt as trt
import sys, os
import common
import argparse
from nvidia.dali.pipeline import Pipeline
import pycuda.driver as cuda
import pycuda.autoinit

parser = argparse.ArgumentParser()
parser.add_argument('--data', help='test data')
args = parser.parse_args()

# You can set the logger severity higher to suppress messages (or lower to display more messages).

TRT_LOGGER = trt.Logger(trt.Logger.INFO)

import nvidia.dali.ops as ops
import nvidia.dali.types as types

batch_size = 1

class SimplePipeline(Pipeline):
    def __init__(self, batch_size, num_threads, device_id, image_dir):
        super(SimplePipeline, self).__init__(batch_size, num_threads, device_id, seed = 12)
        self.input = ops.FileReader(file_root = image_dir)
        # instead of a path to a file directory, a file with pairs "image_name image_label_value" can be provided
        # self.input = ops.FileReader(file_root = image_dir, file_list = image_dir + '/file_list.txt')
        self.decode = ops.HostDecoder(output_type = types.RGB)
        self.resize = ops.Resize(device="gpu", resize_shorter=256.)
        self.cmnp = ops.CropMirrorNormalize(device="gpu",
                                            output_dtype=types.FLOAT,
                                            output_layout = types.NCHW,
                                            crop = (224, 224),
                                            image_type = types.RGB,
                                            mean = [123.68, 116.78, 103.94],
                                            std = [1., 1., 1.])
        self.uniform = ops.Uniform(range = (0.5, 0.5))

    def define_graph(self):
        jpegs, labels = self.input()
        images = self.decode(jpegs)
        #tensor = TensorCPU(images)
        #size = tensor.shape()
        #print(size)
        resize = self.resize(images.gpu())
        output = self.cmnp(resize, crop_pos_x = self.uniform(),
                           crop_pos_y = self.uniform())
        return (output, labels)

class ModelData(object):
    MODEL_FILE = os.path.join(os.path.dirname(__file__), "models/resnet_v1_50.uff")
    INPUT_NAME = "input"
    INPUT_SHAPE = (3, 224, 224)
    OUTPUT_NAME = "resnet_v1_50/SpatialSqueeze"

def build_engine(model_file):
    # For more information on TRT basics, refer to the introductory samples.
    print("build engine begin\n")
    with trt.Builder(TRT_LOGGER) as builder, builder.create_network() as network, trt.UffParser() as parser:
        builder.max_batch_size = 1
        builder.max_workspace_size = common.GiB(1)
        # Parse the Uff Network
        parser.register_input(ModelData.INPUT_NAME, ModelData.INPUT_SHAPE, trt.UffInputOrder.NHWC)
        parser.register_output(ModelData.OUTPUT_NAME)
        parser.parse(model_file, network)
        # Build and return an engine.
        return builder.build_cuda_engine(network)

def main():
    #data_path = common.find_sample_data(description="Runs an MNIST network using a UFF model file", subfolder="mnist")
    model_file = ModelData.MODEL_FILE
    with build_engine(model_file) as engine:
        # Build an engine, allocate buffers and create a stream.
        # For more information on buffer allocation, refer to the introductory samples.
        inputs, outputs, bindings, stream = common.allocate_buffers(engine)
        pagelocked_buffer = inputs[0].host
        with engine.create_execution_context() as context:
            pipe = SimplePipeline(batch_size, 1, 0, args.data)
            pipe.build()

            pipe_out = pipe.run()
            pre_input, labels = pipe_out
            #print(pre_input.asCPU().at(0).shape)
            pre_input_cpu = pre_input.asCPU()
            pre_input_tensor = pre_input_cpu.as_tensor()
            pre_input_ = np.array(pre_input_tensor)

            input_ = np.array(pre_input_.ravel())
            np.copyto(pagelocked_buffer, input_)

            [output] = common.do_inference(context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream)
            #print(output)

if __name__ == '__main__':
    main()

[common.py]
import os
import argparse
import numpy as np
import pycuda.driver as cuda
import tensorrt as trt

try:
    # Sometimes python2 does not understand FileNotFoundError
    FileNotFoundError
except NameError:
    FileNotFoundError = IOError

def GiB(val):
    return val * 1 << 30

def find_sample_data(description="Runs a TensorRT Python sample", subfolder="", find_files=[]):
    '''
    Parses sample arguments.
    Args:
        description (str): Description of the sample.
        subfolder (str): The subfolder containing data relevant to this sample.
        find_files (str): A list of filenames to find. Each filename will be replaced with an absolute path.
    Returns:
        str: Path of data directory.
    Raises:
        FileNotFoundError
    '''
    kDEFAULT_DATA_ROOT = os.path.abspath("/usr/src/tensorrt/data")

    # Standard command-line arguments for all samples.
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument("-d", "--datadir", help="Location of the TensorRT sample data directory.")
    args, unknown_args = parser.parse_known_args()

    # If data directory is not specified, use the default.
    data_root = args.datadir if args.datadir else kDEFAULT_DATA_ROOT
    # If the subfolder exists, append it to the path, otherwise use the provided path as-is.
    subfolder_path = os.path.join(data_root, subfolder)
    if not os.path.exists(subfolder_path):
        print("WARNING: " + subfolder_path + " does not exist. Using " + data_root + " instead.")
    data_path = subfolder_path if os.path.exists(subfolder_path) else data_root

    # Make sure data directory exists.
    if not (os.path.exists(data_path)):
        raise FileNotFoundError(data_path + " does not exist. Please provide the correct data path with the -d option.")

    # Find all requested files.
    for index, f in enumerate(find_files):
        find_files[index] = os.path.abspath(os.path.join(data_path, f))
        if not os.path.exists(find_files[index]):
            raise FileNotFoundError(find_files[index] + " does not exist. Please provide the correct data path with the -d option.")
    if find_files:
        return data_path, find_files
    else:
        return data_path

# Simple helper data class that's a little nicer to use than a 2-tuple.

class HostDeviceMem(object):
    def __init__(self, host_mem, device_mem):
        self.host = host_mem
        self.device = device_mem

    def __str__(self):
        return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device)

    def __repr__(self):
        return self.__str__()

# Allocates all buffers required for an engine, i.e. host/device inputs/outputs.

def allocate_buffers(engine):
    inputs = []
    outputs = []
    bindings = []
    stream = cuda.Stream()
    for binding in engine:
        size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size
        dtype = trt.nptype(engine.get_binding_dtype(binding))
        # Allocate host and device buffers
        host_mem = cuda.pagelocked_empty(size, dtype)
        device_mem = cuda.mem_alloc(host_mem.nbytes)
        # Append the device buffer to device bindings.
        bindings.append(int(device_mem))
        # Append to the appropriate list.
        if engine.binding_is_input(binding):
            inputs.append(HostDeviceMem(host_mem, device_mem))
        else:
            outputs.append(HostDeviceMem(host_mem, device_mem))
    return inputs, outputs, bindings, stream

# This function is generalized for multiple inputs/outputs.
# inputs and outputs are expected to be lists of HostDeviceMem objects.

def do_inference(context, bindings, inputs, outputs, stream, batch_size=1):
    # Transfer input data to the GPU.
    [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs]
    # Run inference.
    context.execute_async(batch_size=batch_size, bindings=bindings, stream_handle=stream.handle)
    # Transfer predictions back from the GPU.
    [cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs]
    # Synchronize the stream
    stream.synchronize()
    # Return only the host outputs.
    return [out.host for out in outputs]

@JanuszL
Contributor

JanuszL commented Jan 8, 2019

Hi,
Thanks for the extensive description. I will reach out to the TRT guys and we will try to reproduce this.
Br,
Janusz

@herbiezhao
Author

Hi,
Have you reproduced this issue?

@JanuszL
Contributor

JanuszL commented Jan 18, 2019

Hi,
We have, and the TRT team is looking into it.

@herbiezhao
Author

If there is any further information, please let me know. Thanks.

@JanuszL
Contributor

JanuszL commented Jan 18, 2019

@herbiezhao I think the repro you provided is sufficient for us. We got the same error, so it is up to us at NVIDIA to debug it.
Anyway, thanks for all the details and the script you provided. They allowed us to start working on this easily.

@herbiezhao
Author

I got the same error when I used TF-GPU + TRT.
TF + TRT is no problem.

@JanuszL
Contributor

JanuszL commented Jan 18, 2019

Even without DALI?

@fkaster-nvidia

Hi @herbiezhao - I am looking at the problem from the TRT side. Thank you for all your help so far. If you have a reproduction script that does not use DALI (but uses TF-GPU), it would be very helpful for narrowing down the cause.

@herbiezhao
Author

@JanuszL @fkaster-nvidia

from random import randint
from PIL import Image
import numpy as np
import pycuda.driver as cuda
import pycuda.autoinit
import tensorrt as trt
import sys, os
import common
from preprocessing import preprocessing_factory
from nets import nets_factory
import tensorflow as tf
import time

# You can set the logger severity higher to suppress messages (or lower to display more messages).

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

class ModelData(object):
    MODEL_FILE = os.path.join(os.path.dirname(__file__), "model.uff")
    INPUT_NAME = "input"
    INPUT_SHAPE = (3, 224, 224)
    OUTPUT_NAME = "resnet_v1_50/SpatialSqueeze"

def build_engine(model_file):
    # For more information on TRT basics, refer to the introductory samples.
    print("build engine begin\n")
    with trt.Builder(TRT_LOGGER) as builder, builder.create_network() as network, trt.UffParser() as parser:
        builder.max_batch_size = 1
        #builder.int8_mode = True
        #builder.fp16_mode = True
        builder.max_workspace_size = common.GiB(1)
        # Parse the Uff Network
        parser.register_input(ModelData.INPUT_NAME, ModelData.INPUT_SHAPE, trt.UffInputOrder.NHWC)
        parser.register_output(ModelData.OUTPUT_NAME)
        parser.parse(model_file, network)
        # Build and return an engine.
        return builder.build_cuda_engine(network)

def main():
    #data_path = common.find_sample_data(description="Runs an MNIST network using a UFF model file", subfolder="mnist")
    model_file = ModelData.MODEL_FILE

    g = tf.Graph()
    sess = tf.Session(graph=g)
    with g.as_default():
        image = tf.placeholder(tf.uint8, shape=[None, None, 3])
        model_name = "resnet_v1_50"
        network_fn = nets_factory.get_network_fn(model_name, 8, is_training=False)
        image_size = network_fn.default_image_size
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(model_name, is_training=False)
        arg_scope = nets_factory.arg_scopes_map[model_name]()
        pre_image = image_preprocessing_fn(image, image_size, image_size)
        input = tf.expand_dims(pre_image, 0)

        with build_engine(model_file) as engine:
            # Build an engine, allocate buffers and create a stream.
            # For more information on buffer allocation, refer to the introductory samples.
            inputs, outputs, bindings, stream = common.allocate_buffers(engine)
            pagelocked_buffer = inputs[0].host
            with engine.create_execution_context() as context:
                image_path = os.path.join("test.bmp")
                img = Image.open(image_path)
                img = np.array(img)

                input_ = sess.run(pre_image, feed_dict={image: img})
                input_ = np.array(input_.ravel())
                np.copyto(pagelocked_buffer, input_)
                [output] = common.do_inference(context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream)

if __name__ == '__main__':
    main()

@herbiezhao
Author

@fkaster-nvidia
Could you give a working sample that uses tensorflow-gpu and TensorRT at the same time?

@JanuszL
Contributor

JanuszL commented May 8, 2019

Hi,
For a TF-GPU + TRT example please check the attached file.
inference_no_dali.txt
For the DALI bug we have two workarounds:

  1. Isolate TRT+PyCUDA in a separate CUDA context.
  2. Move DALI's pipeline initialization so that it happens before any TRT/PyCUDA calls (see the sketch after this list).

But we will try to find a proper solution.
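A minimal sketch of workaround 2, reusing SimplePipeline, build_engine, common.py, args, and batch_size from the repro script above (this is an assumed reorganization of that script's main(), not the contents of the attached inference_no_dali.txt): build the DALI pipeline before PyCUDA/TRT touch the device. Note that "import pycuda.autoinit" creates a CUDA context at import time, so that import is delayed here as well.

def main():
    # 1) Build the DALI pipeline first, before any TRT/PyCUDA call.
    pipe = SimplePipeline(batch_size, 1, 0, args.data)
    pipe.build()

    # 2) Only now let PyCUDA create its context and build the TRT engine/buffers.
    import pycuda.autoinit  # delayed import: the CUDA context is created after DALI init
    with build_engine(ModelData.MODEL_FILE) as engine:
        inputs, outputs, bindings, stream = common.allocate_buffers(engine)
        pagelocked_buffer = inputs[0].host
        with engine.create_execution_context() as context:
            pre_input, labels = pipe.run()
            pre_input_ = np.array(pre_input.asCPU().as_tensor())
            np.copyto(pagelocked_buffer, pre_input_.ravel())
            [output] = common.do_inference(context, bindings=bindings,
                                           inputs=inputs, outputs=outputs,
                                           stream=stream)

Workaround 1 is the same idea from the other direction: create a dedicated context for TRT+PyCUDA (for example with pycuda.driver.Device(0).make_context()) and push/pop it around the TRT calls so it does not clash with the context DALI uses.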

@JanuszL
Contributor

JanuszL commented May 13, 2019

A more permanent solution is WIP - #882

JanuszL added this to the Release_0.11.0 milestone May 29, 2019
@JanuszL
Contributor

JanuszL commented May 29, 2019

The DALI workaround for pycuda context management was merged in #882; it will be available in the next nightly build and in releases >= 0.11.0.

@JanuszL
Contributor

JanuszL commented Jul 2, 2019

Please retest with 0.11 and reopen if it still doesn't work.
