
Prevent clobbering of outputs before non-blocking copy_to_external finishes. #3953

Merged: 11 commits merged into NVIDIA:main on Jun 8, 2022

Conversation

@mzient (Contributor) commented Jun 2, 2022

Category:

Bug fix
Tests

Description:

This PR fixes an issue with copy_to_external / daliCopyOutput where not requiring host synchronization introduced a race condition: once ReleaseOutputs was called, an output buffer that was still being copied could be clobbered by the next iteration of the pipeline.
This PR adds a device-to-device synchronization between the stream associated with the tensor being copied (usually the GPU stage's stream) and the user-provided stream. With this change, any work submitted to the GPU stream after copy_to_external returns is scheduled after the copy.
Extensive tests with PyTorch CUDA streams and in the C API reproduce the issue; a test-driven approach was used to make the issue go away without altering the tests.
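To illustrate the resulting ordering guarantee, here is a minimal sketch using PyTorch CUDA events (not DALI's internal code; all stream and tensor names below are made up):

import torch

producer = torch.cuda.Stream()   # stands in for the GPU stage's stream
user = torch.cuda.Stream()       # stands in for the user-provided stream

src = torch.randn(1 << 20, device="cuda")
dst = torch.empty_like(src)

with torch.cuda.stream(user):
    dst.copy_(src, non_blocking=True)   # the asynchronous copy on the user stream

copied = torch.cuda.Event()
copied.record(user)              # mark the end of the copy on the user stream
producer.wait_event(copied)      # work submitted to the producer stream from now on
                                 # runs after the copy, so the source buffer can be
                                 # safely recycled on that stream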

Additional information:

The tests are not 100% reliable, as is always the case with race conditions; that is, there can be false negatives (the tests may still pass) if the issue reappears.

Affected modules and functionalities:

backend_impl - copy_to_external
C API
C API tests

Key points relevant for the review:

N/A

Tests:

test_copy_to_external_torch.py - all tests (it's a new file)
c_api_test.cu - likewise
existing tests - many of these should be checked for regressions; they include
c_api_test.cc - tests with daliCopyOutput, daliCopyOutputSamples in their name
framework tests - they use copy_to_external to populate framework tensors

  • Existing tests apply
  • New tests added
    • Python tests
    • GTests
    • Benchmark
    • Other
  • N/A

Checklist

Documentation

  • Existing documentation applies
  • Documentation updated
    • Docstring
    • Doxygen
    • RST
    • Jupyter
    • Other
  • N/A

DALI team only

Requirements

  • Implements new requirements
  • Affects existing requirements
  • N/A

REQ IDs: N/A

JIRA TASK: DALI-2467

@dali-automaton (Collaborator): CI MESSAGE: [4994090]: BUILD STARTED

@mzient changed the title from "Async feed nd array" to "Make feed_ndarray non-blocking" on Jun 2, 2022
@dali-automaton (Collaborator): CI MESSAGE: [4994090]: BUILD PASSED

@dali-automaton (Collaborator): CI MESSAGE: [5004719]: BUILD STARTED

@mzient changed the title from "Make feed_ndarray non-blocking" to "Prevent clobbering of outputs before non-blocking copy_to_external finishes." on Jun 3, 2022
@dali-automaton (Collaborator): CI MESSAGE: [5004719]: BUILD PASSED

@dali-automaton (Collaborator): CI MESSAGE: [5032484]: BUILD STARTED

@dali-automaton (Collaborator): CI MESSAGE: [5032484]: BUILD PASSED

* The copy will be scheduled on the provided `cuda_stream` or, if left out, on an internal DALI
* stream.
* If a non-blocking copy is requested, the function will synchronize the source buffer's
* associated access order with the provided stream; otherwie, the function will wait until the
Contributor suggested a change:
* associated access order with the provided stream; otherwie, the function will wait until the
* associated access order with the provided stream; otherwise, the function will wait until the

* associated access order with the provided stream; otherwie, the function will wait until the
* copy completes.
*
* @tparam SourceObject a data store on GPUBackend (Tensor, TensorList, TensorVector)
Contributor suggested a change:
* @tparam SourceObject a data store on GPUBackend (Tensor, TensorList, TensorVector)
* @tparam SourceObject a data store on GPUBackend (Tensor, TensorList, TensorVector)

Comment on lines +34 to +43
to_torch_type = {
    types.DALIDataType.FLOAT   : torch.float32,
    types.DALIDataType.FLOAT64 : torch.float64,
    types.DALIDataType.FLOAT16 : torch.float16,
    types.DALIDataType.UINT8   : torch.uint8,
    types.DALIDataType.INT8    : torch.int8,
    types.DALIDataType.INT16   : torch.int16,
    types.DALIDataType.INT32   : torch.int32,
    types.DALIDataType.INT64   : torch.int64
}
Contributor commented:

Use from nvidia.dali.plugin.pytorch import to_torch_type?
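A sketch of what the suggestion amounts to (assuming the plugin exports a compatible mapping, as the comment implies):

from nvidia.dali.plugin.pytorch import to_torch_type  # reuse the plugin's mapping
# ...and drop the local DALIDataType -> torch.dtype dictionary above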

Comment on lines +45 to +80
def feed_ndarray(tensor_or_tl, arr, cuda_stream=None, non_blocking=False):
    """
    Copy contents of DALI tensor to PyTorch's Tensor.

    Parameters
    ----------
    `tensor_or_tl` : TensorGPU or TensorListGPU
    `arr` : torch.Tensor
        Destination of the copy
    `cuda_stream` : torch.cuda.Stream, cudaStream_t or any value that can be cast to cudaStream_t.
        CUDA stream to be used for the copy
        (if not provided, an internal user stream will be selected)
        In most cases, using PyTorch's current stream is expected (for example,
        if we are copying to a tensor allocated with torch.zeros(...))
    `non_blocking` : bool
        If True, the copy is only stream-ordered and no host synchronization is performed
    """
    dali_type = to_torch_type[tensor_or_tl.dtype]
    if isinstance(tensor_or_tl, TensorListGPU):
        dali_tensor = tensor_or_tl.as_tensor()
    else:
        dali_tensor = tensor_or_tl

    assert dali_type == arr.dtype, ("The element type of DALI Tensor/TensorList"
                                    " doesn't match the element type of the target PyTorch Tensor: "
                                    "{} vs {}".format(dali_type, arr.dtype))
    assert dali_tensor.shape() == list(arr.size()), \
        ("Shapes do not match: DALI tensor has size {0}, "
         "but PyTorch Tensor has size {1}".format(dali_tensor.shape(), list(arr.size())))

    cuda_stream = types._raw_cuda_stream(cuda_stream)
    # turn the raw int into a C void pointer
    c_type_pointer = ctypes.c_void_p(arr.data_ptr())
    stream = None if cuda_stream is None else ctypes.c_void_p(cuda_stream)
    tensor_or_tl.copy_to_external(c_type_pointer, stream, non_blocking)
    return arr
Contributor commented:
Ideally, we would remove this and extend feed_ndarray in the plugin to support the non_blocking argument. This can be handled as a separate task.
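For context, a hedged usage sketch of the non-blocking path (it assumes pipe is a built DALI pipeline with a single GPU output of type FLOAT, and uses the feed_ndarray helper defined above):

import torch

pipe.schedule_run()
(out,) = pipe.share_outputs()         # borrow DALI's output buffers; no copy yet
tensor = out.as_tensor()
arr = torch.empty(tensor.shape(), dtype=torch.float32, device="cuda")
stream = torch.cuda.current_stream()
feed_ndarray(tensor, arr, cuda_stream=stream, non_blocking=True)
pipe.release_outputs()                # with this PR, work on the pipeline's stream is
                                      # ordered after the copy, so this is now safe
stream.synchronize()                  # host sync only when the data is needed on the CPU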

pipe.release_outputs()
# if no appropriate synchronization is done, the array is likely
# clobbered with the results from the second iteration
assert check(arr, ref)
Contributor suggested a change:
assert check(arr, ref)
assert torch.equal(arr, ref)

would read better, IMHO
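Condensed, the race this test guards against looks like the following sketch (continuing the names from the usage sketch above; reference_for_iteration is a hypothetical helper standing in for the expected output):

feed_ndarray(tensor, arr, cuda_stream=stream, non_blocking=True)
ref = reference_for_iteration(0)   # hypothetical: expected contents of iteration 0
pipe.release_outputs()             # recycles the output buffer...
pipe.schedule_run()                # ...which the next iteration would overwrite
stream.synchronize()
assert torch.equal(arr, ref)       # fails if the buffer was clobbered mid-copy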


// This loop is tuned so that if the output buffer is recycled before the asynchronous copy
// finishes, the buffer is clobbered and an error is detected.
// (michalz) Verified on my desktop. The changes in c_api that came with this test
Contributor commented:
This makes sense in the context of this PR, but it will have no meaning some time in the future. Do we keep such a comment?

@mzient (Author) replied:
Well, this is a test, and a delicate one, too. It was very hard to get a repro, so this comment is a word of caution for whoever touches this code. I could remove the sentence about my desktop, but I'd add a repro describing how to break the code to trigger a failure.

}
wait_order.wait(copy_order);
Contributor commented:
You can extract

if (!host_sync)
    wait_order = src.order();

outside of the if/else

@mzient (Author) replied on Jun 8, 2022:
I cannot - src is a local variable initialized inside the if/else and has a different type in these branches.
I can, however, remove the duplicate assignment, which I've just noticed here.

mzient and others added 11 commits June 8, 2022 11:52
Signed-off-by: Michał Zientkiewicz <mzient@gmail.com>
Signed-off-by: Michal Zientkiewicz <michalz@nvidia.com>
@dali-automaton (Collaborator): CI MESSAGE: [5042627]: BUILD STARTED

@dali-automaton (Collaborator): CI MESSAGE: [5042627]: BUILD PASSED

@mzient merged commit 80fce13 into NVIDIA:main on Jun 8, 2022