Stop exposing internal contiguous TV storage (#3827)
Remove the AsTensorList method and the constructor
taking an external shared_ptr&lt;TensorList&gt; from TensorVector.

Both allowed the internal state of TensorVector to be
observed from the outside, breaking encapsulation.
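
As a toy illustration of the problem (plain C++, not DALI code; the Batch type below is hypothetical), handing out the internal shared_ptr lets callers observe and mutate the owner's allocation behind its back:

  #include <memory>
  #include <vector>

  // Hypothetical toy type, only to illustrate the encapsulation issue.
  class Batch {
   public:
    // Old-style accessor: exposes the internal allocation to the outside.
    std::shared_ptr<std::vector<float>> AsBuffer() { return buf_; }

   private:
    std::shared_ptr<std::vector<float>> buf_ =
        std::make_shared<std::vector<float>>();
  };

  int main() {
    Batch b;
    auto buf = b.AsBuffer();  // external code now co-owns the internal buffer...
    buf->assign(4, 1.f);      // ...and can change it without going through Batch
  }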

Adjust a few places that relied on changes of the internal
state being externally visible:
* Initializing the data graph in workspaces:
  the TV -> TL conversion is now done via the Mixed stage.
* ArgumentInputs that are produced as TensorList
  need to be resynced with ShareData instead (see the sketch below).
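
In sketch form, the resync is an explicit ShareData call once the producing batch has been (re)filled; the helper below is illustrative only and assumes a ShareData overload that adopts the other batch's allocation:

  // Illustrative helper, not DALI code: re-establish sharing explicitly instead of
  // relying on the argument input aliasing the producer's internal TensorList.
  template <typename TensorListT, typename TensorVectorT>
  void ResyncArgumentInput(TensorListT &produced, TensorVectorT &arg_input) {
    arg_input.ShareData(produced);  // explicit, on-demand sharing, no implicit aliasing
  }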

This reverts some changes from #2165:
  - CPU stage outputs cannot be used directly as pipeline
    outputs due to the TensorList/TensorVector mismatch;
    we can only share data downwards, but we cannot
    share and preemptively expect the allocation
    to be mirrored.
  - The CPU-only pipeline still uses the Mixed stage, just
    with MakeContiguous constrained to CPU outputs.

Workspace initialization now takes CPU_ONLY_DEVICE_ID
into consideration and does not set the stream (resulting
in has_stream() being false, which in turn keeps the
AccessOrder in Mixed ops host-only).

Memory is set to non-pinned when CPU_ONLY_DEVICE_ID is
detected.
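
A condensed sketch of that intent (the SetupForDevice helper and the placeholder constant are illustrative, not the real workspace-policy API):

  // Sketch only: DALI's real constant and workspace API are assumed, not reproduced here.
  constexpr int CPU_ONLY_DEVICE_ID = -99;  // placeholder; the real DALI value differs

  template <typename Workspace, typename Batch, typename Stream>
  void SetupForDevice(Workspace &ws, Batch &batch, int device_id, Stream stream) {
    if (device_id == CPU_ONLY_DEVICE_ID) {
      // No stream is attached: ws.has_stream() stays false, Mixed ops keep a host-only
      // access order, and host memory is left unpinned (pinning needs a CUDA context).
      batch.set_pinned(false);
    } else {
      ws.set_stream(stream);  // normal GPU-capable path
    }
  }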

The eager-mode optimization based on a contiguous
TensorVector -> TensorList conversion was disabled - it waits
for the rework in which TensorVector replaces TensorList.

An escape hatch to access the shared_ptr of the
allocation was ported from TensorList to TensorVector,
to allow ExternalSource to pass data without a copy.
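
The mechanism behind it is the shared_ptr aliasing constructor: the returned pointer addresses one sample but shares the use count and deleter of the whole contiguous allocation. A minimal standalone illustration (plain C++, not DALI code):

  #include <cassert>
  #include <memory>
  #include <vector>

  int main() {
    // One contiguous allocation standing in for the whole batch.
    auto batch = std::make_shared<std::vector<float>>(8, 0.f);

    // Aliasing constructor: points at element 2, but keeps the full buffer alive
    // for as long as any such sample pointer exists.
    std::shared_ptr<float> sample(batch, batch->data() + 2);

    assert(sample.use_count() == batch.use_count());
    *sample = 1.f;               // writes land in the shared allocation
    assert((*batch)[2] == 1.f);
  }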

Signed-off-by: Krzysztof Lecki <klecki@nvidia.com>
klecki authored Jul 14, 2022
1 parent 85a7001 commit 5c816bf
Showing 14 changed files with 257 additions and 203 deletions.
27 changes: 0 additions & 27 deletions dali/pipeline/data/tensor_vector.cc
@@ -32,20 +32,6 @@ TensorVector<Backend>::TensorVector(int batch_size)
resize_tensors(batch_size);
}


template <typename Backend>
TensorVector<Backend>::TensorVector(std::shared_ptr<TensorList<Backend>> tl)
: views_count_(0), curr_num_tensors_(0), tl_(std::move(tl)) {
assert(tl_ && "Construction with null TensorList is illegal");
pinned_ = tl_->is_pinned();
type_ = tl_->type_info();
sample_dim_ = tl_->shape().sample_dim();
state_ = State::contiguous;
resize_tensors(tl_->num_samples());
UpdateViews();
}


template <typename Backend>
TensorVector<Backend>::TensorVector(TensorVector<Backend> &&other) noexcept {
state_ = other.state_;
@@ -607,19 +593,6 @@ void TensorVector<Backend>::UpdateViews() {
}


template <typename Backend>
std::shared_ptr<TensorList<Backend>> TensorVector<Backend>::AsTensorList(bool check_contiguity) {
DALI_ENFORCE(IsContiguous() || !check_contiguity,
"Cannot cast non continuous TensorVector to TensorList.");
// Update the metadata when we are exposing the TensorList to the outside, as it might have been
// kept in the individual tensors
for (int idx = 0; idx < curr_num_tensors_; idx++) {
tl_->SetMeta(idx, tensors_[idx]->GetMeta());
}
return tl_;
}


template <typename Backend>
void TensorVector<Backend>::resize_tensors(int new_size) {
if (static_cast<size_t>(new_size) > tensors_.size()) {
28 changes: 24 additions & 4 deletions dali/pipeline/data/tensor_vector.h
@@ -62,8 +62,6 @@ class DLL_PUBLIC TensorVector {
*/
explicit TensorVector(int batch_size);

explicit TensorVector(std::shared_ptr<TensorList<Backend>> tl);

TensorVector(const TensorVector &) = delete;
TensorVector &operator=(const TensorVector &) = delete;

@@ -330,8 +328,6 @@

void UpdateViews();

shared_ptr<TensorList<Backend>> AsTensorList(bool check_contiguity = true);

private:
enum class State { contiguous, noncontiguous };

@@ -395,6 +391,30 @@
// with different template types
template <typename InBackend>
friend class TensorVector;

/** @defgroup AccessorFunctions Fallback for accessing pointers owning the samples
* Fallback access to contiguous data or samples of the batch. It should not be used for regular
* processing, intended mostly for batches that were made sure to be contiguous (mainly
* for pipeline outputs).
* @{
*/

/**
* @brief Return the shared pointer, that we can use to correctly share the ownership of sample
* with.
* Sample 0 is aliased with the whole buffer, if it is contiguous.
*/
friend shared_ptr<void> unsafe_sample_owner(TensorVector<Backend> &batch, int sample_idx) {
// create new aliasing pointer to current data allocation, so we share the use count
// and the deleter correctly.
if (batch.IsContiguous()) {
return {unsafe_sample_owner(*batch.tl_, 0), batch.raw_mutable_tensor(sample_idx)};
} else {
return batch.tensors_[sample_idx]->get_data_ptr();
}
}

/** @} */ // end of AccessorFunctions
};

} // namespace dali
67 changes: 38 additions & 29 deletions dali/pipeline/executor/executor.cc
@@ -65,14 +65,23 @@ void Executor<WorkspacePolicy, QueuePolicy>::SyncDevice() {
template <typename WorkspacePolicy, typename QueuePolicy>
void Executor<WorkspacePolicy, QueuePolicy>::RunCPUImpl() {
PreRun();

const char placement_error[] =
"Cannot run a pipeline with Mixed/GPU ops in CPU-only mode. Please provide "
"valid device id or change the operators' device.";
if (device_id_ < 0) {
DALI_ENFORCE(device_id_ == CPU_ONLY_DEVICE_ID,
"Wrong device_id provided, it should be >= 0, "
"or equal to CPU_ONLY_DEVICE_ID.");
DALI_ENFORCE(graph_->NumOp(OpType::GPU) == 0 && graph_->NumOp(OpType::MIXED) == 0,
"Cannot run a pipeline with Mixed/GPU ops in CPU-only mode. Please provide "
"valid device id or change the operators' device.");
DALI_ENFORCE(graph_->NumOp(OpType::GPU) == 0, placement_error);

for (int i = 0; i < graph_->NumOp(OpType::MIXED) && !exec_error_; ++i) {
const OpNode &op_node = graph_->Node(OpType::MIXED, i);
DALI_ENFORCE(op_node.spec.GetSchema().name() == "MakeContiguous", placement_error);
for (auto tensor_id : op_node.children_tensors) {
const TensorNode &tensor_node = graph_->Tensor(tensor_id);
DALI_ENFORCE(tensor_node.producer.storage_device == StorageDevice::CPU, placement_error);
}
}
}

DomainTimeRange tr("[DALI][Executor] RunCPU");
@@ -126,20 +135,11 @@ void Executor<WorkspacePolicy, QueuePolicy>::RunMixedImpl() {
return;
}

// short path for pure CPU pipeline
if (device_id_ == CPU_ONLY_DEVICE_ID) {
if (callback_) {
callback_();
}
// We do not release, but handle to used outputs
QueuePolicy::ReleaseIdxs(OpType::MIXED, mixed_idxs, mixed_op_stream_);
return;
}

// Enforce our assumed dependency between consecutive
// iterations of a stage of the pipeline.

CUDA_CALL(cudaEventSynchronize(mixed_stage_event_));
if (device_id_ != CPU_ONLY_DEVICE_ID)
CUDA_CALL(cudaEventSynchronize(mixed_stage_event_));

auto batch_size = batch_sizes_mixed_.front();
batch_sizes_mixed_.pop();
@@ -155,30 +155,39 @@
RunHelper(op_node, ws);
FillStats(mixed_memory_stats_, ws, "MIXED_" + op_node.instance_name,
mixed_memory_stats_mutex_);
if (ws.has_stream() && ws.has_event()) {
CUDA_CALL(cudaEventRecord(ws.event(), ws.stream()));
if (device_id_ != CPU_ONLY_DEVICE_ID) {
if (ws.has_stream() && ws.has_event()) {
CUDA_CALL(cudaEventRecord(ws.event(), ws.stream()));
}
CUDA_CALL(cudaGetLastError());
}
CUDA_CALL(cudaGetLastError());
} catch (std::exception &e) {
HandleError("Mixed", op_node, e.what());
} catch (...) {
HandleError();
}
}

if (callback_) {
// Record event that will allow to call the callback after whole run of this pipeline is
// finished.
CUDA_CALL(cudaEventRecord(mixed_callback_events_[mixed_idxs[OpType::MIXED]], mixed_op_stream_));
}
if (device_id_ != CPU_ONLY_DEVICE_ID) {
if (callback_) {
// Record event that will allow to call the callback after whole run of this pipeline is
// finished.
CUDA_CALL(
cudaEventRecord(mixed_callback_events_[mixed_idxs[OpType::MIXED]], mixed_op_stream_));
}

if (!mixed_output_events_.empty()) {
int queue_id = mixed_idxs[OpType::MIXED];
CUDA_CALL(cudaEventRecord(mixed_output_events_.GetEvent(queue_id), mixed_op_stream_));
}
if (!mixed_output_events_.empty()) {
int queue_id = mixed_idxs[OpType::MIXED];
CUDA_CALL(cudaEventRecord(mixed_output_events_.GetEvent(queue_id), mixed_op_stream_));
}

// We know that this is the proper stream, we do not need to look it up in any workspace
CUDA_CALL(cudaEventRecord(mixed_stage_event_, mixed_op_stream_));
// We know that this is the proper stream, we do not need to look it up in any workspace
CUDA_CALL(cudaEventRecord(mixed_stage_event_, mixed_op_stream_));
} else {
if (callback_) {
callback_();
}
}

// Pass the work to the gpu stage
QueuePolicy::ReleaseIdxs(OpType::MIXED, mixed_idxs, mixed_op_stream_);
29 changes: 26 additions & 3 deletions dali/pipeline/executor/executor.h
@@ -465,7 +465,7 @@ void Executor<WorkspacePolicy, QueuePolicy>::Build(OpGraph *graph, vector<string
// workspaces so that nothing has to be altered
// during execution (this is necessary for
// asynchronous executors that can overlap work issue)
ws_policy_.InitializeWorkspaceStore(*graph_, tensor_to_store_queue_, &thread_pool_,
ws_policy_.InitializeWorkspaceStore(*graph_, device_id_, tensor_to_store_queue_, &thread_pool_,
mixed_op_stream_, gpu_op_stream_, mixed_op_events_,
queue_sizes_);

@@ -509,12 +509,12 @@ void Executor<WorkspacePolicy, QueuePolicy>::ShareOutputs(DeviceWorkspace *ws) {
auto storage_dev = out_tensor.producer.storage_device;
VALUE_SWITCH(storage_dev, storage_dev_static, (StorageDevice::GPU, StorageDevice::CPU),
(
VALUE_SWITCH(op_type, op_type_static, (OpType::CPU, OpType::MIXED, OpType::GPU),
VALUE_SWITCH(op_type, op_type_static, (OpType::MIXED, OpType::GPU),
(
auto &queue = get_queue<op_type_static, storage_dev_static>(
tensor_to_store_queue_[out_tensor_id]);
auto stage_output_idx = output_idx[op_type_static];
ws->AddOutput(PresentAsTensorList(queue[stage_output_idx]));
ws->AddOutput(queue[stage_output_idx]);
), DALI_FAIL("Invalid op type")); // NOLINT(whitespace/parens)
), DALI_FAIL("Invalid storage device")); // NOLINT(whitespace/parens)
}
@@ -638,6 +638,29 @@ std::vector<int> Executor<WorkspacePolicy, QueuePolicy>::GetTensorQueueSizes(con
template <typename WorkspacePolicy, typename QueuePolicy>
void Executor<WorkspacePolicy, QueuePolicy>::PrepinData(
std::vector<tensor_data_store_queue_t> &tensor_to_store_queue, const OpGraph &graph) {
// No pinning when working in CPU only mode
if (device_id_ == CPU_ONLY_DEVICE_ID) {
for (int tid = 0; tid < graph.NumTensor(); tid++) {
// Only CPU storage device in CPU_ONLY mode
auto &cpu_cpu_queue =
get_queue<OpType::CPU, StorageDevice::CPU>(tensor_to_store_queue_[tid]);
auto &mixed_cpu_queue =
get_queue<OpType::MIXED, StorageDevice::CPU>(tensor_to_store_queue_[tid]);
auto &gpu_cpu_queue =
get_queue<OpType::GPU, StorageDevice::CPU>(tensor_to_store_queue_[tid]);

for (auto &t : cpu_cpu_queue) {
t->set_pinned(false);
}
for (auto &t : mixed_cpu_queue) {
t->set_pinned(false);
}
for (auto &t : gpu_cpu_queue) {
t->set_pinned(false);
}
}
return;
}
// We only pin what we need:
// The inputs of mixed ops are potentially used for H2D copies...
for (int i = 0; i < graph.NumOp(OpType::MIXED); i++) {
22 changes: 11 additions & 11 deletions dali/pipeline/executor/executor_test.cc
@@ -110,7 +110,7 @@ TYPED_TEST(ExecutorTest, TestPruneBasicGraph) {

graph.AddOp(this->PrepareSpec(
OpSpec("MakeContiguous")
.AddArg("device", "cpu")
.AddArg("device", "mixed")
.AddInput("data3", "cpu")
.AddOutput("data3_cont", "cpu")), "");

@@ -127,8 +127,8 @@
// Validate the graph - op 3 should
// have been pruned as its outputs
// are unused.
ASSERT_EQ(graph.NumOp(OpType::CPU), 3);
ASSERT_EQ(graph.NumOp(OpType::MIXED), 0);
ASSERT_EQ(graph.NumOp(OpType::CPU), 2);
ASSERT_EQ(graph.NumOp(OpType::MIXED), 1);
ASSERT_EQ(graph.NumOp(OpType::GPU), 0);

// Validate the source op
@@ -176,7 +176,7 @@ TYPED_TEST(ExecutorTest, TestPruneMultiple) {

graph.AddOp(this->PrepareSpec(
OpSpec("MakeContiguous")
.AddArg("device", "cpu")
.AddArg("device", "mixed")
.AddInput("data1", "cpu")
.AddOutput("data1_cont", "cpu")), "");

@@ -207,8 +207,8 @@
// Validate the graph - op 2&3 should
// have been pruned.
// Op 4 should not be pruned
ASSERT_EQ(graph.NumOp(OpType::CPU), 3);
ASSERT_EQ(graph.NumOp(OpType::MIXED), 0);
ASSERT_EQ(graph.NumOp(OpType::CPU), 2);
ASSERT_EQ(graph.NumOp(OpType::MIXED), 1);
ASSERT_EQ(graph.NumOp(OpType::GPU), 0);

// Validate the source op
@@ -248,7 +248,7 @@ TYPED_TEST(ExecutorTest, TestPruneRecursive) {

graph.AddOp(this->PrepareSpec(
OpSpec("MakeContiguous")
.AddArg("device", "cpu")
.AddArg("device", "mixed")
.AddInput("data1", "cpu")
.AddOutput("data1_cont", "cpu")), "");

@@ -271,8 +271,8 @@

// Validate the graph - op 2&3 should
// have been pruned
ASSERT_EQ(graph.NumOp(OpType::CPU), 2);
ASSERT_EQ(graph.NumOp(OpType::MIXED), 0);
ASSERT_EQ(graph.NumOp(OpType::CPU), 1);
ASSERT_EQ(graph.NumOp(OpType::MIXED), 1);
ASSERT_EQ(graph.NumOp(OpType::GPU), 0);

// Validate the source op
@@ -406,7 +406,7 @@ TYPED_TEST(ExecutorTest, TestRunBasicGraph) {

graph.AddOp(this->PrepareSpec(
OpSpec("MakeContiguous")
.AddArg("device", "cpu")
.AddArg("device", "mixed")
.AddInput("images", "cpu")
.AddOutput("final_images", "cpu")), "");

@@ -452,7 +452,7 @@ TYPED_TEST(ExecutorTest, TestRunBasicGraphWithCB) {

graph.AddOp(this->PrepareSpec(
OpSpec("MakeContiguous")
.AddArg("device", "cpu")
.AddArg("device", "mixed")
.AddInput("images", "cpu")
.AddOutput("final_images", "cpu")), "");

