Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent clobbering of outputs before non-blocking copy_to_external finishes. #3953

Merged
merged 11 commits into from
Jun 8, 2022
56 changes: 30 additions & 26 deletions dali/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -545,57 +545,61 @@ void daliOutputCopy(daliPipelineHandle *pipe_handle, void *dst, int output_idx,
dali::DomainTimeRange tr("[DALI][C API] daliOutputCopy", dali::DomainTimeRange::kGreen);

bool is_pinned = flags & DALI_ext_pinned;
bool sync = flags & DALI_ext_force_sync;
bool host_sync = flags & DALI_ext_force_sync;
bool use_copy_kernel = flags & DALI_use_copy_kernel;
auto dst_mem_kind = GetMemKind(dst_type, is_pinned);

dali::DeviceWorkspace *ws = reinterpret_cast<dali::DeviceWorkspace *>(pipe_handle->ws);
assert(ws != nullptr);

auto &type_info = dali::TypeTable::GetTypeInfo(dali::DALIDataType::DALI_UINT8);
AccessOrder wait_order = AccessOrder::host();
AccessOrder copy_order = AccessOrder::host();

if (ws->OutputIsType<CPUBackend>(output_idx)) {
AccessOrder order = is_pinned ? AccessOrder(stream) : AccessOrder::host();
CopyToExternal(dst, dst_mem_kind, ws->Output<CPUBackend>(output_idx),
order, use_copy_kernel);
if (!is_pinned) {
sync = false;
}
copy_order = is_pinned ? AccessOrder(stream) : AccessOrder::host();
auto &src = ws->Output<CPUBackend>(output_idx);
CopyToExternal(dst, dst_mem_kind, src, copy_order, use_copy_kernel);
if (!host_sync)
wait_order = src.order(); // if the copy order is host, then wait will be no-op
} else {
CopyToExternal(dst, dst_mem_kind, ws->Output<GPUBackend>(output_idx),
stream, use_copy_kernel);
}
if (sync) {
CUDA_CALL(cudaStreamSynchronize(stream));
auto &src = ws->Output<GPUBackend>(output_idx);
copy_order = stream;
CopyToExternal(dst, dst_mem_kind, src, copy_order, use_copy_kernel);
if (!host_sync)
wait_order = src.order();
}
wait_order.wait(copy_order);
}

void daliOutputCopySamples(daliPipelineHandle *pipe_handle, void **dsts, int output_idx,
device_type_t dst_type, cudaStream_t stream, unsigned int flags) {
dali::DomainTimeRange tr("[DALI][C API] daliOutputCopySamples", dali::DomainTimeRange::kGreen);

bool is_pinned = flags & DALI_ext_pinned;
bool sync = flags & DALI_ext_force_sync;
bool host_sync = flags & DALI_ext_force_sync;
bool use_copy_kernel = flags & DALI_use_copy_kernel;
auto dst_mem_kind = GetMemKind(dst_type, is_pinned);

dali::DeviceWorkspace *ws = reinterpret_cast<dali::DeviceWorkspace *>(pipe_handle->ws);
assert(ws != nullptr);

auto &type_info = dali::TypeTable::GetTypeInfo(dali::DALIDataType::DALI_UINT8);
AccessOrder wait_order = AccessOrder::host();
AccessOrder copy_order = AccessOrder::host();

if (ws->OutputIsType<CPUBackend>(output_idx)) {
AccessOrder order = is_pinned ? AccessOrder(stream) : AccessOrder::host();
if (!is_pinned) {
sync = false;
}
CopyToExternal(dsts, dst_mem_kind, ws->Output<CPUBackend>(output_idx),
order, use_copy_kernel);
copy_order = is_pinned ? AccessOrder(stream) : AccessOrder::host();
auto & src = ws->Output<CPUBackend>(output_idx);
CopyToExternal(dsts, dst_mem_kind, src, copy_order, use_copy_kernel);
if (!host_sync)
wait_order = src.order(); // if the copy order is host, then wait will be no-op
} else {
CopyToExternal(dsts, dst_mem_kind, ws->Output<GPUBackend>(output_idx),
stream, use_copy_kernel);
}
if (sync) {
CUDA_CALL(cudaStreamSynchronize(stream));
auto &src = ws->Output<GPUBackend>(output_idx);
copy_order = stream;
CopyToExternal(dsts, dst_mem_kind, src, copy_order, use_copy_kernel);
if (!host_sync)
wait_order = src.order();
}
wait_order.wait(copy_order);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can extract

if (!host_sync)
    wait_order = src.order();

outside of the if/else

Copy link
Contributor Author

@mzient mzient Jun 8, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I cannot - src is a local variable initialized inside if/else and has a different type in these branches.
I can, however, remove the duplicate asssignment, which I've just noticed here.

}


Expand Down
170 changes: 170 additions & 0 deletions dali/c_api/c_api_test.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <gtest/gtest.h>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "dali/c_api.h"
#include "dali/core/cuda_stream_pool.h"
#include "dali/core/dev_buffer.h"
#include "dali/pipeline/pipeline.h"


namespace dali {


namespace {

__global__ void hog(float *f, size_t n) {
size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < n) {
for (int k = 0; k < 100; k++)
f[i] = sqrt(f[i]);
}
}

struct GPUHog {
void init() {
if (!mem)
mem = mm::alloc_raw_unique<float, mm::memory_kind::device>(size);
CUDA_CALL(cudaMemset(mem.get(), 1, size * sizeof(float)));
}

void run(cudaStream_t stream, int count = 1) {
for (int i = 0; i < count; i++) {
hog<<<div_ceil(size, 512), 512, 0, stream>>>(mem.get(), size);
}
CUDA_CALL(cudaGetLastError());
}

mm::uptr<float> mem;
size_t size = 16<<20;
};

} // namespace

enum class Method {
Contiguous, Samples
};

void TestCopyOutput(Method method) {
int batch_size = 2;
dali::Pipeline pipe(batch_size, 4, 0);
std::string es_cpu_name = "pipe_in";
pipe.AddExternalInput(es_cpu_name, "cpu");

std::string cont_name = "pipe_out";
pipe.AddOperator(OpSpec("MakeContiguous")
.AddArg("device", "mixed")
.AddArg("name", cont_name)
.AddInput(es_cpu_name, "cpu")
.AddOutput(cont_name, "gpu"), cont_name);
std::vector<std::pair<std::string, std::string>> outputs = {{"pipe_out", "gpu"}};

GPUHog hog;
hog.init();

std::vector<std::vector<int>> in_data;
int sample_size = 400000;
TensorListShape<> shape = uniform_list_shape(8, {sample_size});
in_data.resize(2);
for (int i = 0; i < 2; i++) {
in_data[i].resize(batch_size*sample_size);
for (int j = 0; j < batch_size*sample_size; j++)
in_data[i][j] = i + j;
}
pipe.SetOutputDescs(outputs);

std::vector<int> out_cpu(batch_size*sample_size);
DeviceBuffer<int> out_gpu;
out_gpu.resize(batch_size*sample_size);
CUDAStreamLease stream = CUDAStreamPool::instance().Get();

CUDA_CALL(cudaMemsetAsync(out_gpu.data(), -1, out_gpu.size() * sizeof(int), stream));
std::string ser = pipe.SerializeToProtobuf();
hog.run(stream, 50);


// This loop is tuned so that if the output buffer is recycled before the asynchronous copy
// finishes, the buffer is clobbered and an error is detected.
// In order to trigger a failure, remove the `wait_order.wait` at the end of
// daliOutputCopy / daliOutputCopySamples
for (int attempt = 0; attempt < 20; attempt++) {
daliPipelineHandle handle;

// create a new instance of the pipeline
daliDeserializeDefault(&handle, ser.c_str(), ser.size());

// feed the data & run - this is the iteration from which we want to see the data
daliSetExternalInput(&handle, "pipe_in", CPU, in_data[0].data(), ::DALI_INT32,
shape.shapes.data(), 1, nullptr, 0);
daliRun(&handle);

// schedule an extra iteration
daliSetExternalInput(&handle, "pipe_in", CPU, in_data[1].data(), ::DALI_INT32,
shape.shapes.data(), 1, nullptr, 0);
daliRun(&handle);
// ...and prepare for one more
daliSetExternalInput(&handle, "pipe_in", CPU, in_data[1].data(), ::DALI_INT32,
shape.shapes.data(), 1, nullptr, 0);

// get the outputs - this contains some synchronization, so it comes before dispatching the hog
daliShareOutput(&handle);
// hog the GPU on the stream on which we'll copy the output
hog.run(stream, 10);

// copy the output on our stream, without waiting on host
if (method == Method::Contiguous) {
daliOutputCopy(&handle, out_gpu.data(), 0, GPU, stream, 0);
} else if (method == Method::Samples) {
void *dsts[batch_size];
for (int i = 0; i < batch_size; i++)
dsts[i] = out_gpu.data() + i*sample_size;
daliOutputCopySamples(&handle, dsts, 0, GPU, stream, 0);
}

// release the buffer - it can be immediately recycled (in appropriate stream order)
daliOutputRelease(&handle);
daliRun(&handle);

// now, copy the buffer to host...
CUDA_CALL(cudaMemcpyAsync(out_cpu.data(), out_gpu.data(), batch_size*sample_size*sizeof(int),
cudaMemcpyDeviceToHost, stream));

// ...and verify the contents
for (int i = 0; i < batch_size * sample_size; i++) {
// check for race condition...
ASSERT_NE(out_cpu[i], in_data[1][i])
<< " synchronization failed - data clobbered by next iteration " << i;

// ...and for any other corruption
ASSERT_EQ(out_cpu[i], in_data[0][i]) << " data corrupted at index " << i;
}

daliDeletePipeline(&handle);
}
}

TEST(CApiTest, daliOutputCopy_Async) {
TestCopyOutput(Method::Contiguous);
}

TEST(CApiTest, daliOutputCopySamples_Async) {
TestCopyOutput(Method::Samples);
}

} // namespace dali
57 changes: 41 additions & 16 deletions dali/python/backend_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,45 @@ void ExposeTensorLayout(py::module &m) {
// Placeholder enum for defining __call__ on dtype member of Tensor (to be deprecated).
enum DALIDataTypePlaceholder {};

/**
* @brief Copies the contents of the source DALI batch to an external buffer
*
* The function schedules a copy of the contents of src to the target destination buffer.
* The copy will be scheduled on the provided `cuda_stream` or, if left out, on an internal DALI
* stream.
* If a non-blocking copy is requested, the function will synchronize the source buffer's
* associated access order with the provided stream; otherwise, the function will wait until the
* copy completes.
*
* @tparam SourceObject a data store on GPUBackend (Tensor, TensorList, TensorVector)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
* @tparam SourceObject a data store on GPUBackend (Tensor, TensorList, TensorVector)
* @tparam SourceObject a data store on GPUBackend (Tensor, TensorList, TensorVector)

* @param src Source batch
* @param dst_ptr Destination pointer, wrapped in a C void_ptr Python type
* @param cuda_stream CUDA stream, wrapped in a C void_ptr type
* @param non_blocking whether the function should wait on host for the copy to complete
* @param use_copy_kernel if true, the copy will be done using a kernel instead of cudaMemcpyAsync
*/
template <typename SourceObject>
void CopyToExternalImplGPU(SourceObject &src,
py::object dst_ptr, py::object cuda_stream,
bool non_blocking, bool use_copy_kernel) {
CUDAStreamLease lease;
AccessOrder copy_order;
AccessOrder wait_order = non_blocking ? src.order() : AccessOrder::host();
int device = src.device_id();
if (!cuda_stream.is_none()) {
cudaStream_t stream = static_cast<cudaStream_t>(ctypes_void_ptr(cuda_stream));
copy_order = AccessOrder(stream, device);
} else {
lease = CUDAStreamPool::instance().Get(device);
copy_order = AccessOrder(lease, device);
}

void *ptr = ctypes_void_ptr(dst_ptr);
CopyToExternal<mm::memory_kind::device>(ptr, src, copy_order, use_copy_kernel);

wait_order.wait(copy_order);
}

/**
* Pipeline output descriptor.
*/
Expand Down Expand Up @@ -541,14 +580,7 @@ void ExposeTensor(py::module &m) {
.def("copy_to_external",
[](Tensor<GPUBackend> &t, py::object p, py::object cuda_stream,
bool non_blocking, bool use_copy_kernel) {
void *ptr = ctypes_void_ptr(p);
cudaStream_t stream = cuda_stream.is_none()
? UserStream::Get()->GetStream(t)
: static_cast<cudaStream_t>(ctypes_void_ptr(cuda_stream));
CopyToExternal<mm::memory_kind::device>(ptr, t, stream, use_copy_kernel);
if (!non_blocking) {
CUDA_CALL(cudaStreamSynchronize(stream));
}
CopyToExternalImplGPU(t, p, cuda_stream, non_blocking, use_copy_kernel);
},
"ptr"_a,
"cuda_stream"_a = py::none(),
Expand Down Expand Up @@ -1037,14 +1069,7 @@ void ExposeTensorList(py::module &m) {
.def("copy_to_external",
[](TensorList<GPUBackend> &t, py::object p, py::object cuda_stream,
bool non_blocking, bool use_copy_kernel) {
void *ptr = ctypes_void_ptr(p);
cudaStream_t stream = cuda_stream.is_none()
? UserStream::Get()->GetStream(t)
: static_cast<cudaStream_t>(ctypes_void_ptr(cuda_stream));
CopyToExternal<mm::memory_kind::device>(ptr, t, stream, use_copy_kernel);
if (!non_blocking) {
CUDA_CALL(cudaStreamSynchronize(stream));
}
CopyToExternalImplGPU(t, p, cuda_stream, non_blocking, use_copy_kernel);
},
"ptr"_a,
"cuda_stream"_a = py::none(),
Expand Down
3 changes: 2 additions & 1 deletion dali/python/nvidia/dali/plugin/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ def feed_ndarray(dali_tensor, arr, cuda_stream=None):

# Copy data from DALI tensor to ptr
if isinstance(dali_tensor, (TensorGPU, TensorListGPU)):
dali_tensor.copy_to_external(ptr, None if cuda_stream is None else ctypes.c_void_p(cuda_stream))
stream = None if cuda_stream is None else ctypes.c_void_p(cuda_stream)
dali_tensor.copy_to_external(ptr, stream, non_blocking=True)
else:
dali_tensor.copy_to_external(ptr)

Expand Down
4 changes: 2 additions & 2 deletions dali/python/nvidia/dali/plugin/paddle.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ def feed_ndarray(dali_tensor, ptr, cuda_stream=None):

c_type_pointer = ctypes.c_void_p(ptr)
if isinstance(dali_tensor, (TensorGPU, TensorListGPU)):
dali_tensor.copy_to_external(
c_type_pointer, None if cuda_stream is None else ctypes.c_void_p(cuda_stream))
stream = None if cuda_stream is None else ctypes.c_void_p(cuda_stream)
dali_tensor.copy_to_external(c_type_pointer, stream, non_blocking=True)
else:
dali_tensor.copy_to_external(c_type_pointer)
return ptr
Expand Down
3 changes: 2 additions & 1 deletion dali/python/nvidia/dali/plugin/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ def feed_ndarray(dali_tensor, arr, cuda_stream=None):
# turn raw int to a c void pointer
c_type_pointer = ctypes.c_void_p(arr.data_ptr())
if isinstance(dali_tensor, (TensorGPU, TensorListGPU)):
dali_tensor.copy_to_external(c_type_pointer, None if cuda_stream is None else ctypes.c_void_p(cuda_stream))
stream = None if cuda_stream is None else ctypes.c_void_p(cuda_stream)
dali_tensor.copy_to_external(c_type_pointer, stream, non_blocking=True)
else:
dali_tensor.copy_to_external(c_type_pointer)
return arr
Expand Down
Loading