-
Notifications
You must be signed in to change notification settings - Fork 609
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Prevent clobbering of outputs before non-blocking copy_to_external finishes. #3953
Merged
Merged
Changes from all commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
8786798
[WIP]
mzient 9972ecf
[WIP]
mzient ebd81cc
[WIP]
mzient 12c9d03
Trying to repro in C API.
mzient fde1396
Working in both Python and C API.
mzient 0f0aad5
Improved docs. Added test launch script.
mzient c0a8984
Minor refactoring. Improved tests.
mzient 375e09a
Revert useless change in c_api_test.cc
mzient 95b4ead
Make copy_to_external in framework iterators non-blocking.
mzient 1c45b0b
Removed comment accidentally copied from another file.
mzient 30d4b9f
Fix review issues.
mzient File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include <gtest/gtest.h> | ||
#include <memory> | ||
#include <string> | ||
#include <utility> | ||
#include <vector> | ||
|
||
#include "dali/c_api.h" | ||
#include "dali/core/cuda_stream_pool.h" | ||
#include "dali/core/dev_buffer.h" | ||
#include "dali/pipeline/pipeline.h" | ||
|
||
|
||
namespace dali { | ||
|
||
|
||
namespace { | ||
|
||
__global__ void hog(float *f, size_t n) { | ||
size_t i = blockIdx.x * blockDim.x + threadIdx.x; | ||
if (i < n) { | ||
for (int k = 0; k < 100; k++) | ||
f[i] = sqrt(f[i]); | ||
} | ||
} | ||
|
||
struct GPUHog { | ||
void init() { | ||
if (!mem) | ||
mem = mm::alloc_raw_unique<float, mm::memory_kind::device>(size); | ||
CUDA_CALL(cudaMemset(mem.get(), 1, size * sizeof(float))); | ||
} | ||
|
||
void run(cudaStream_t stream, int count = 1) { | ||
for (int i = 0; i < count; i++) { | ||
hog<<<div_ceil(size, 512), 512, 0, stream>>>(mem.get(), size); | ||
} | ||
CUDA_CALL(cudaGetLastError()); | ||
} | ||
|
||
mm::uptr<float> mem; | ||
size_t size = 16<<20; | ||
}; | ||
|
||
} // namespace | ||
|
||
enum class Method { | ||
Contiguous, Samples | ||
}; | ||
|
||
void TestCopyOutput(Method method) { | ||
int batch_size = 2; | ||
dali::Pipeline pipe(batch_size, 4, 0); | ||
std::string es_cpu_name = "pipe_in"; | ||
pipe.AddExternalInput(es_cpu_name, "cpu"); | ||
|
||
std::string cont_name = "pipe_out"; | ||
pipe.AddOperator(OpSpec("MakeContiguous") | ||
.AddArg("device", "mixed") | ||
.AddArg("name", cont_name) | ||
.AddInput(es_cpu_name, "cpu") | ||
.AddOutput(cont_name, "gpu"), cont_name); | ||
std::vector<std::pair<std::string, std::string>> outputs = {{"pipe_out", "gpu"}}; | ||
|
||
GPUHog hog; | ||
hog.init(); | ||
|
||
std::vector<std::vector<int>> in_data; | ||
int sample_size = 400000; | ||
TensorListShape<> shape = uniform_list_shape(8, {sample_size}); | ||
in_data.resize(2); | ||
for (int i = 0; i < 2; i++) { | ||
in_data[i].resize(batch_size*sample_size); | ||
for (int j = 0; j < batch_size*sample_size; j++) | ||
in_data[i][j] = i + j; | ||
} | ||
pipe.SetOutputDescs(outputs); | ||
|
||
std::vector<int> out_cpu(batch_size*sample_size); | ||
DeviceBuffer<int> out_gpu; | ||
out_gpu.resize(batch_size*sample_size); | ||
CUDAStreamLease stream = CUDAStreamPool::instance().Get(); | ||
|
||
CUDA_CALL(cudaMemsetAsync(out_gpu.data(), -1, out_gpu.size() * sizeof(int), stream)); | ||
std::string ser = pipe.SerializeToProtobuf(); | ||
hog.run(stream, 50); | ||
|
||
|
||
// This loop is tuned so that if the output buffer is recycled before the asynchronous copy | ||
// finishes, the buffer is clobbered and an error is detected. | ||
// In order to trigger a failure, remove the `wait_order.wait` at the end of | ||
// daliOutputCopy / daliOutputCopySamples | ||
for (int attempt = 0; attempt < 20; attempt++) { | ||
daliPipelineHandle handle; | ||
|
||
// create a new instance of the pipeline | ||
daliDeserializeDefault(&handle, ser.c_str(), ser.size()); | ||
|
||
// feed the data & run - this is the iteration from which we want to see the data | ||
daliSetExternalInput(&handle, "pipe_in", CPU, in_data[0].data(), ::DALI_INT32, | ||
shape.shapes.data(), 1, nullptr, 0); | ||
daliRun(&handle); | ||
|
||
// schedule an extra iteration | ||
daliSetExternalInput(&handle, "pipe_in", CPU, in_data[1].data(), ::DALI_INT32, | ||
shape.shapes.data(), 1, nullptr, 0); | ||
daliRun(&handle); | ||
// ...and prepare for one more | ||
daliSetExternalInput(&handle, "pipe_in", CPU, in_data[1].data(), ::DALI_INT32, | ||
shape.shapes.data(), 1, nullptr, 0); | ||
|
||
// get the outputs - this contains some synchronization, so it comes before dispatching the hog | ||
daliShareOutput(&handle); | ||
// hog the GPU on the stream on which we'll copy the output | ||
hog.run(stream, 10); | ||
|
||
// copy the output on our stream, without waiting on host | ||
if (method == Method::Contiguous) { | ||
daliOutputCopy(&handle, out_gpu.data(), 0, GPU, stream, 0); | ||
} else if (method == Method::Samples) { | ||
void *dsts[batch_size]; | ||
for (int i = 0; i < batch_size; i++) | ||
dsts[i] = out_gpu.data() + i*sample_size; | ||
daliOutputCopySamples(&handle, dsts, 0, GPU, stream, 0); | ||
} | ||
|
||
// release the buffer - it can be immediately recycled (in appropriate stream order) | ||
daliOutputRelease(&handle); | ||
daliRun(&handle); | ||
|
||
// now, copy the buffer to host... | ||
CUDA_CALL(cudaMemcpyAsync(out_cpu.data(), out_gpu.data(), batch_size*sample_size*sizeof(int), | ||
cudaMemcpyDeviceToHost, stream)); | ||
|
||
// ...and verify the contents | ||
for (int i = 0; i < batch_size * sample_size; i++) { | ||
// check for race condition... | ||
ASSERT_NE(out_cpu[i], in_data[1][i]) | ||
<< " synchronization failed - data clobbered by next iteration " << i; | ||
|
||
// ...and for any other corruption | ||
ASSERT_EQ(out_cpu[i], in_data[0][i]) << " data corrupted at index " << i; | ||
} | ||
|
||
daliDeletePipeline(&handle); | ||
} | ||
} | ||
|
||
TEST(CApiTest, daliOutputCopy_Async) { | ||
TestCopyOutput(Method::Contiguous); | ||
} | ||
|
||
TEST(CApiTest, daliOutputCopySamples_Async) { | ||
TestCopyOutput(Method::Samples); | ||
} | ||
|
||
} // namespace dali |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -283,6 +283,45 @@ void ExposeTensorLayout(py::module &m) { | |||||
// Placeholder enum for defining __call__ on dtype member of Tensor (to be deprecated). | ||||||
enum DALIDataTypePlaceholder {}; | ||||||
|
||||||
/** | ||||||
* @brief Copies the contents of the source DALI batch to an external buffer | ||||||
* | ||||||
* The function schedules a copy of the contents of src to the target destination buffer. | ||||||
* The copy will be scheduled on the provided `cuda_stream` or, if left out, on an internal DALI | ||||||
* stream. | ||||||
* If a non-blocking copy is requested, the function will synchronize the source buffer's | ||||||
* associated access order with the provided stream; otherwise, the function will wait until the | ||||||
* copy completes. | ||||||
* | ||||||
* @tparam SourceObject a data store on GPUBackend (Tensor, TensorList, TensorVector) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
* @param src Source batch | ||||||
* @param dst_ptr Destination pointer, wrapped in a C void_ptr Python type | ||||||
* @param cuda_stream CUDA stream, wrapped in a C void_ptr type | ||||||
* @param non_blocking whether the function should wait on host for the copy to complete | ||||||
* @param use_copy_kernel if true, the copy will be done using a kernel instead of cudaMemcpyAsync | ||||||
*/ | ||||||
template <typename SourceObject> | ||||||
void CopyToExternalImplGPU(SourceObject &src, | ||||||
py::object dst_ptr, py::object cuda_stream, | ||||||
bool non_blocking, bool use_copy_kernel) { | ||||||
CUDAStreamLease lease; | ||||||
AccessOrder copy_order; | ||||||
AccessOrder wait_order = non_blocking ? src.order() : AccessOrder::host(); | ||||||
int device = src.device_id(); | ||||||
if (!cuda_stream.is_none()) { | ||||||
cudaStream_t stream = static_cast<cudaStream_t>(ctypes_void_ptr(cuda_stream)); | ||||||
copy_order = AccessOrder(stream, device); | ||||||
} else { | ||||||
lease = CUDAStreamPool::instance().Get(device); | ||||||
copy_order = AccessOrder(lease, device); | ||||||
} | ||||||
|
||||||
void *ptr = ctypes_void_ptr(dst_ptr); | ||||||
CopyToExternal<mm::memory_kind::device>(ptr, src, copy_order, use_copy_kernel); | ||||||
|
||||||
wait_order.wait(copy_order); | ||||||
} | ||||||
|
||||||
/** | ||||||
* Pipeline output descriptor. | ||||||
*/ | ||||||
|
@@ -541,14 +580,7 @@ void ExposeTensor(py::module &m) { | |||||
.def("copy_to_external", | ||||||
[](Tensor<GPUBackend> &t, py::object p, py::object cuda_stream, | ||||||
bool non_blocking, bool use_copy_kernel) { | ||||||
void *ptr = ctypes_void_ptr(p); | ||||||
cudaStream_t stream = cuda_stream.is_none() | ||||||
? UserStream::Get()->GetStream(t) | ||||||
: static_cast<cudaStream_t>(ctypes_void_ptr(cuda_stream)); | ||||||
CopyToExternal<mm::memory_kind::device>(ptr, t, stream, use_copy_kernel); | ||||||
if (!non_blocking) { | ||||||
CUDA_CALL(cudaStreamSynchronize(stream)); | ||||||
} | ||||||
CopyToExternalImplGPU(t, p, cuda_stream, non_blocking, use_copy_kernel); | ||||||
}, | ||||||
"ptr"_a, | ||||||
"cuda_stream"_a = py::none(), | ||||||
|
@@ -1037,14 +1069,7 @@ void ExposeTensorList(py::module &m) { | |||||
.def("copy_to_external", | ||||||
[](TensorList<GPUBackend> &t, py::object p, py::object cuda_stream, | ||||||
bool non_blocking, bool use_copy_kernel) { | ||||||
void *ptr = ctypes_void_ptr(p); | ||||||
cudaStream_t stream = cuda_stream.is_none() | ||||||
? UserStream::Get()->GetStream(t) | ||||||
: static_cast<cudaStream_t>(ctypes_void_ptr(cuda_stream)); | ||||||
CopyToExternal<mm::memory_kind::device>(ptr, t, stream, use_copy_kernel); | ||||||
if (!non_blocking) { | ||||||
CUDA_CALL(cudaStreamSynchronize(stream)); | ||||||
} | ||||||
CopyToExternalImplGPU(t, p, cuda_stream, non_blocking, use_copy_kernel); | ||||||
}, | ||||||
"ptr"_a, | ||||||
"cuda_stream"_a = py::none(), | ||||||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can extract
outside of the if/else
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I cannot -
src
is a local variable initialized inside if/else and has a different type in these branches.I can, however, remove the duplicate asssignment, which I've just noticed here.