
Fix handling of CUDA streams in Python frontend #2050

Merged: 12 commits into NVIDIA:master, Jun 25, 2020

Conversation

mzient (Contributor) commented Jun 23, 2020

Signed-off-by: Michał Zientkiewicz <mzient@gmail.com>

Why do we need this PR?

  • It fixes a bug: GPU external source test failure under GPU load

What happened in this PR?

  • What solution was applied:
    • When cuda_stream is None in ExternalSource and the object is a CuPy array, issue the copy on CuPy's current stream.
  • Affected modules and functionalities:
    • pipeline.py
  • Key points relevant for the review:
    • N/A
  • Validation and testing:
    • Existing tests apply
  • Documentation (including examples):
    • N/A

JIRA TASK: DALI-1474
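
The gist of the fix, as a minimal sketch - the helper name and exact control flow here are illustrative, not the verbatim pipeline.py code:

def _infer_copy_stream(data, cuda_stream):
    # Pick the stream for the copy when the user did not provide one.
    if cuda_stream is not None:
        return cuda_stream
    try:
        import cupy
    except ImportError:
        return None  # no CuPy: fall back to DALI's internal stream
    if isinstance(data, cupy.ndarray):
        # Use CuPy's current stream, so the copy is ordered after any
        # kernels the producer has already enqueued on it.
        return cupy.cuda.get_current_stream().ptr
    return None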

* When cuda_stream is None in ExternalSource and the object is a CuPy array, issue the copy on CuPy's current stream.

Signed-off-by: Michał Zientkiewicz <mzient@gmail.com>
mzient (Contributor, Author) commented Jun 23, 2020

!build

dali-automaton (Collaborator): CI MESSAGE: [1417005]: BUILD STARTED

klecki (Contributor) commented Jun 23, 2020

My points:

  • The docs say that if you don't provide a stream, we will use an internal one for the copy. If users know their memory is ready, they may want to preserve this behaviour.
  • This fixes the CuPy test by adding a CuPy-specific workaround. There are other sources of CUDA memory we can get here that will still trigger the bug.
  • I think we should mention in the docs that if the user doesn't provide a stream, the memory should already be ready.

Maybe we can make this configurable: by default, extract the stream from whatever library we recognize and know how to handle, and otherwise allow the user to turn it off?

This problem can also be fixed in the test by using the CuPy API to synchronize before passing the memory to DALI.
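
For reference, a minimal sketch of that alternative, test-side fix, assuming a CuPy-producing source:

import cupy

def make_batch():
    arr = cupy.arange(6, dtype=cupy.float32).reshape(2, 3)
    # Make sure the data is actually ready before DALI copies it
    # on a different (internal) stream.
    cupy.cuda.get_current_stream().synchronize()
    return arr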

dali-automaton (Collaborator): CI MESSAGE: [1417005]: BUILD PASSED

Signed-off-by: Michał Zientkiewicz <mzient@gmail.com>
mzient requested a review from JanuszL, June 23, 2020 17:17
mzient changed the title from "Fix CuPy external source." to "Fix handling of CUDA streams in Python frontend", Jun 23, 2020
mzient (Contributor, Author) commented Jun 23, 2020

!build

dali-automaton (Collaborator): CI MESSAGE: [1417491]: BUILD STARTED

@@ -62,9 +62,15 @@ def feed_ndarray(dali_tensor, arr, cuda_stream = None):
# Get CTypes void pointer to the underlying memory held by arr
ptr = ctypes.c_void_p()
mx.base._LIB.MXNDArrayGetData(arr.handle, ctypes.byref(ptr))

if hasattr(cuda_stream, "cuda_stream"): # torch
JanuszL (Contributor) commented Jun 23, 2020:

Extract this to a common utility.

dali-automaton (Collaborator): CI MESSAGE: [1417491]: BUILD FAILED

@@ -57,10 +57,14 @@ def feed_ndarray(dali_tensor, arr, cuda_stream = None):
assert dali_tensor.shape() == list(arr.size()), \
("Shapes do not match: DALI tensor has size {0}"
", but PyTorch Tensor has size {1}".format(dali_tensor.shape(), list(arr.size())))
#turn raw int to a c void pointer
if cuda_stream is torch.cuda.Stream:
JanuszL (Contributor) commented Jun 23, 2020:

Suggested change (the original line compares identity with the class object itself, so it is never true for an actual stream instance):
if cuda_stream is torch.cuda.Stream:
if isinstance(cuda_stream, torch.cuda.streams.Stream):

Signed-off-by: Janusz Lisiecki <jlisiecki@nvidia.com>
JanuszL (Contributor) commented Jun 24, 2020

!build

dali-automaton (Collaborator): CI MESSAGE: [1419648]: BUILD STARTED

dali-automaton (Collaborator): CI MESSAGE: [1419648]: BUILD FAILED

Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
mzient (Contributor, Author) commented Jun 24, 2020

!build

dali-automaton (Collaborator): CI MESSAGE: [1420074]: BUILD STARTED

Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
klecki (Contributor) left a comment:

Mostly wording and the repetition of c_void_p usage.

and all work is properly queued). If no stream is provided feeding input blocks until the
provided memory is copied to the internal buffer
and all work is properly queued). If no stream is provided, DALI will use a default, with
best-effort approach at correctness (see ``cuda_stream`` argument documentation for details).
Contributor:
So, we're ditching the idea of blocking copy? I see it will still happen as you didn't change the internals.

mzient (Contributor, Author):
It's sort of orthogonal - we'll still block when using the internal stream, because the kind of bug it protects against is even harder to detect. Maybe we should make it explicitly configurable. Let's discuss it on dev.

Contributor:
I don't think we need that at the Python level; in C++ it is there. I would still mention here that if no stream is provided, it will block.

mzient (Contributor, Author):
It's mentioned in the documentation of cuda_stream parameter.

@@ -70,13 +70,16 @@ def feed_ndarray(dali_tensor, ptr, cuda_stream = None):
Tensor from which to copy
`ptr` : LoDTensor data pointer
Destination of the copy
`cuda_stream` : Any value that can be casted to cudaStream_t
`cuda_stream` : Any value that can be caste to cudaStream_t
Contributor:
Suggested change
`cuda_stream` : Any value that can be caste to cudaStream_t
`cuda_stream` : Any value that can be cast to cudaStream_t

But maybe we should call it representing cudaStream_t? Accessing some attributes is not exactly casting, right?
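
To make "cast or represents" concrete, a hypothetical sketch of the kind of extraction types._raw_cuda_stream performs (the actual implementation may differ):

def raw_cuda_stream(cuda_stream):
    if cuda_stream is None:
        return None
    if hasattr(cuda_stream, "cuda_stream"):  # torch.cuda.Stream exposes the raw handle here
        return cuda_stream.cuda_stream
    if hasattr(cuda_stream, "ptr"):          # cupy.cuda.Stream exposes it as .ptr
        return cuda_stream.ptr
    return int(cuda_stream)                  # already a raw handle or int-like value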

@@ -48,7 +48,7 @@ def feed_ndarray(dali_tensor, arr, cuda_stream = None):
Tensor from which to copy
`arr` : mxnet.nd.NDArray
Destination of the copy
`cuda_stream` : Any value that can be casted to cudaStream_t
`cuda_stream` : Any value that can be cast to cudaStream_t
Contributor:
As mentioned elsewhere, maybe:

Suggested change
`cuda_stream` : Any value that can be cast to cudaStream_t
`cuda_stream` : Any value that can be cast or represents cudaStream_t

Comment on lines +464 to +468
The array(s) may be one of:
* NumPy ndarray (CPU)
* MXNet ndarray (CPU)
* PyTorch tensor (CPU or GPU)
* CuPy array (GPU)
Contributor:
Can't we handle anything with [cuda] array interface?

mzient (Contributor, Author):
Good point.

Contributor:
And the Python Buffer Protocol (NumPy is just one example of it). I would add this info to the ExternalSource docs as well.
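
A minimal sketch of such generic detection through the standard interfaces, instead of special-casing particular libraries (helper names are illustrative):

def is_gpu_array(obj):
    # CuPy, Numba and others expose GPU memory through this interface.
    return hasattr(obj, "__cuda_array_interface__")

def data_ptr(obj):
    if is_gpu_array(obj):
        return obj.__cuda_array_interface__["data"][0]
    if hasattr(obj, "__array_interface__"):  # NumPy and friends
        return obj.__array_interface__["data"][0]
    raise TypeError("not an array-like object")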

Comment on lines +496 to +502
infer_stream = False
if cuda_stream is None:
    infer_stream = True
if cuda_stream == -1:
    cuda_stream = None
else:
    cuda_stream = types._raw_cuda_stream(cuda_stream)
Contributor:
Can we maybe use another name here? The cuda_stream passed as the argument to this function and the one we use later have a bit different meanings, given the None and -1 handling.

Contributor:
Maybe something like stream_ptr, and have _raw_cuda_stream already return a c_void_p value? You're packing it by hand at every invocation site.

Contributor:
Or maybe it would be better to have some boolean for SetExternalTensorInput indicating if the stream should be generated internally?

mzient (Contributor, Author) commented Jun 24, 2020:
Unpacking it to a raw pointer may cause hard errors in the following (or a similar) scenario:

stream = torch.cuda.Stream()
fn.external_source(src, cuda_stream = stream)
# stream reference is forgotten

It's a bug, I agree. But if we unwrap immediately, we lose the only reference to the stream and it will be destroyed - our stream pointer becomes dangling and can even be recycled by the driver upon the next stream creation, in which case we'd coincidentally have a different, but still wrong, stream. We'd convert a Python-level logic error into a potentially disastrous hard error in native code. I don't want our users to debug THAT kind of error.
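
A hypothetical illustration of the design defended here: store the Python stream object, so the reference keeps it alive, and unwrap the raw handle only at the point of use.

import ctypes

class ExternalInput:
    def __init__(self, cuda_stream=None):
        # Holding the Python object keeps the underlying CUDA stream alive.
        self._cuda_stream = cuda_stream

    def stream_arg(self):
        s = self._cuda_stream
        if s is None:
            return None
        # torch streams expose the raw handle as .cuda_stream, CuPy as .ptr.
        raw = getattr(s, "cuda_stream", getattr(s, "ptr", s))
        return ctypes.c_void_p(int(raw))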

Contributor:
> Or maybe it would be better to have some boolean for SetExternalTensorInput indicating if the stream should be generated internally?

I think None is fine.

pipe = Pipeline(1, 3, 0)

def gen_batch():
    nonlocal t0
Contributor:
Why do you need nonlocal for t0 and not increment?

mzient (Contributor, Author):
Because I need the same tensor to change in place - simply returning a new one in PyTorch resulted in synchronization, and the error could not be reproduced even when the streams were wrong.
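
For context, a minimal sketch of that test pattern (names are illustrative, not the verbatim test code):

import torch

def make_source():
    # One pre-allocated tensor, mutated in place across iterations; a fresh
    # allocation per batch would synchronize and hide the stream-ordering bug.
    t0 = torch.zeros(2, 3, device="cuda")

    def gen_batch():
        nonlocal t0  # the augmented assignment below binds t0, so nonlocal is required
        t0 += 1      # in-place update, enqueued on the current stream
        return [t0]

    return gen_batch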

Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
klecki (Contributor) commented Jun 24, 2020

!build

dali-automaton (Collaborator): CI MESSAGE: [1420366]: BUILD STARTED

@JanuszL JanuszL self-requested a review June 24, 2020 14:34
provided GPU memory content only using provided stream (DALI schedules
a copy on it and all work is properly queued). If no stream is provided
feed_input blocks until the provided memory is copied to the internal buffer
"""Pass a mutlidimensional array (or a list thereof) to an output of ExternalSource.
Contributor:
I would copy/paste more info from ExternalSource docs.

c_type_pointer = ctypes.c_void_p(ptr)
if isinstance(dali_tensor, (TensorGPU, TensorListGPU)):
dali_tensor.copy_to_external(c_type_pointer, cuda_stream)
dali_tensor.copy_to_external(c_type_pointer, ctypes.c_void_p(cuda_stream))
Contributor:
Now you will get nullptr instead of None. This breaks the logic in:

    .def("copy_to_external",
        [](Tensor<GPUBackend> &t, py::object p, py::object cuda_stream, bool non_blocking) {
          void *ptr = ctypes_void_ptr(p);
          cudaStream_t stream = cuda_stream.is_none()
                ? UserStream::Get()->GetStream(t)
                : static_cast<cudaStream_t>(ctypes_void_ptr(cuda_stream));
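
One way to preserve the None case, as a hedged sketch (the actual fix applied in the PR may differ):

import ctypes

def to_stream_arg(cuda_stream):
    # Keep None as None, so the C++ binding can fall back to its internal
    # stream; otherwise wrap the raw handle for copy_to_external.
    return None if cuda_stream is None else ctypes.c_void_p(cuda_stream)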

c_type_pointer = ctypes.c_void_p(arr.data_ptr())
if isinstance(dali_tensor, (TensorGPU, TensorListGPU)):
dali_tensor.copy_to_external(c_type_pointer, cuda_stream)
dali_tensor.copy_to_external(c_type_pointer, ctypes.c_void_p(cuda_stream))
Contributor:
As above

dali-automaton (Collaborator): CI MESSAGE: [1420366]: BUILD PASSED

Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
mzient (Contributor, Author) commented Jun 24, 2020

!build

dali-automaton (Collaborator): CI MESSAGE: [1420580]: BUILD STARTED

dali-automaton (Collaborator): CI MESSAGE: [1420580]: BUILD PASSED

Reject masked tensors.
Fix documentation formatting issues.

Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
dali-automaton (Collaborator): CI MESSAGE: [1421224]: BUILD STARTED

dali-automaton (Collaborator): CI MESSAGE: [1421224]: BUILD PASSED

dali-automaton (Collaborator): CI MESSAGE: [1421798]: BUILD STARTED

dali-automaton (Collaborator): CI MESSAGE: [1421798]: BUILD PASSED

Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
dali-automaton (Collaborator): CI MESSAGE: [1422380]: BUILD STARTED

Comment on lines +152 to +156
for (int i = strides.size() - 1; i >= 0; i--) {
  DALI_ENFORCE(strides[i] == stride_from_shape,
      make_string("Strided data not supported. Dimension ", i, " has stride ", strides[i],
                  " whereas densely packed data of this shape would have a stride ", stride_from_shape));
  stride_from_shape *= shape[i];
Contributor:
We use this check in a couple of places now. Can you extract it into a function?
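
To make the check concrete, the same density validation rendered in Python (illustrative, assuming byte strides as in the array interfaces):

def check_dense_strides(shape, strides, element_size):
    # Strides must match those of densely packed (C-contiguous) data.
    expected = element_size
    for i in reversed(range(len(shape))):
        if strides[i] != expected:
            raise ValueError(
                "Strided data not supported. Dimension {} has stride {}, "
                "whereas densely packed data of this shape would have a stride {}"
                .format(i, strides[i], expected))
        expected *= shape[i]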

dali-automaton (Collaborator): CI MESSAGE: [1422380]: BUILD PASSED

mzient merged commit d9a5f03 into NVIDIA:master on Jun 25, 2020