Dump operator stats #2039

Merged (6 commits) on Jun 23, 2020
Changes from 2 commits
35 changes: 34 additions & 1 deletion dali/c_api/c_api.cc
@@ -109,7 +109,8 @@ void daliCreatePipeline(daliPipelineHandle *pipe_handle,
int separated_execution,
int prefetch_queue_depth,
int cpu_prefetch_queue_depth,
int gpu_prefetch_queue_depth) {
int gpu_prefetch_queue_depth,
int get_memory_stats) {
Contributor: nitpick: get_memory_stats gives me the feeling that it is a function name. Consider naming it something like "memory_stats_enabled".

Contributor (Author): Done

bool se = separated_execution != 0;
auto pipeline = std::make_unique<dali::Pipeline>(std::string(serialized_pipeline, length),
batch_size, num_threads, device_id, true,
@@ -118,6 +119,7 @@ void daliCreatePipeline(daliPipelineHandle *pipe_handle,
if (se) {
pipeline->SetQueueSizes(cpu_prefetch_queue_depth, gpu_prefetch_queue_depth);
}
pipeline->EnableOperatorOutputMemoryStatistics(get_memory_stats);
pipeline->Build();
auto ws = std::make_unique<dali::DeviceWorkspace>();
auto stream = dali::CUDAStream::Create(true);
@@ -441,3 +443,34 @@ void daliGetReaderMetadata(daliPipelineHandle* pipe_handle, const char *reader_n
meta->pad_last_batch = returned_meta.pad_last_batch;
meta->stick_to_shard = returned_meta.stick_to_shard;
}

void daliGetExecutorMetadata(daliPipelineHandle* pipe_handle, daliExecutorMetadata **operator_meta,
size_t *operator_meta_num) {
dali::Pipeline* pipeline = reinterpret_cast<dali::Pipeline*>(pipe_handle->pipe);
auto returned_meta = pipeline->GetExecutorMeta();
*operator_meta_num = returned_meta.size();
*operator_meta = static_cast<daliExecutorMetadata*>(malloc(sizeof(daliExecutorMetadata) *
returned_meta.size()));

int i = 0;
for (const auto &stat : returned_meta) {
auto op_name_size = stat.first.size();
(*operator_meta)[i].operator_name = static_cast<char*>(malloc(sizeof(char) *
Contributor: Suggested change:
    (*operator_meta)[i].operator_name = static_cast<char*>(malloc(sizeof(char) *
becomes
    auto &op_meta = (*operator_meta)[i];
    op_meta.operator_name = static_cast<char*>(malloc(sizeof(char) *

Contributor (Author): Done

(op_name_size + 1)));
stat.first.copy((*operator_meta)[i].operator_name, op_name_size);
(*operator_meta)[i].operator_name[op_name_size] = '\0';

auto num_outputs = stat.second.size();
(*operator_meta)[i].out_num = num_outputs;
(*operator_meta)[i].real_size = static_cast<size_t*>(malloc(sizeof(size_t) * num_outputs));
(*operator_meta)[i].reserved = static_cast<size_t*>(malloc(sizeof(size_t) * num_outputs));

for (size_t j = 0; j < num_outputs; ++j) {
const auto &entry = stat.second[j];
(*operator_meta)[i].real_size[j] = entry.real_size;
(*operator_meta)[i].reserved[j] = entry.reserved;
}
++i;
}
}
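
For reference, a minimal sketch of the copy loop after applying the op_meta alias suggestion above; this is only an illustration based on the diff shown here, not the actual follow-up commit:

    int i = 0;
    for (const auto &stat : returned_meta) {
      auto &op_meta = (*operator_meta)[i];  // alias instead of repeating (*operator_meta)[i]
      auto op_name_size = stat.first.size();
      op_meta.operator_name = static_cast<char*>(malloc(op_name_size + 1));
      stat.first.copy(op_meta.operator_name, op_name_size);
      op_meta.operator_name[op_name_size] = '\0';

      auto num_outputs = stat.second.size();
      op_meta.out_num = num_outputs;
      op_meta.real_size = static_cast<size_t*>(malloc(sizeof(size_t) * num_outputs));
      op_meta.reserved = static_cast<size_t*>(malloc(sizeof(size_t) * num_outputs));

      for (size_t j = 0; j < num_outputs; ++j) {
        op_meta.real_size[j] = stat.second[j].real_size;
        op_meta.reserved[j] = stat.second[j].reserved;
      }
      ++i;
    }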

36 changes: 31 additions & 5 deletions dali/c_api/c_api_test.cc
@@ -166,7 +166,7 @@ TYPED_TEST(CApiTest, FileReaderPipe) {
daliPipelineHandle handle;
daliCreatePipeline(&handle, serialized.c_str(), serialized.size(), batch_size, num_thread,
device_id, false, prefetch_queue_depth, prefetch_queue_depth,
prefetch_queue_depth);
prefetch_queue_depth, false);
daliPrefetchUniform(&handle, prefetch_queue_depth);

dali::DeviceWorkspace ws;
@@ -223,7 +223,7 @@ TYPED_TEST(CApiTest, ExternalSourceSingleAllocPipe) {
daliPipelineHandle handle;
daliCreatePipeline(&handle, serialized.c_str(), serialized.size(), batch_size, num_thread,
device_id, false, prefetch_queue_depth, prefetch_queue_depth,
prefetch_queue_depth);
prefetch_queue_depth, false);

for (int i = 0; i < prefetch_queue_depth; i++) {
SequentialFill(view<uint8_t>(input_cpu), 42 * i);
@@ -280,7 +280,7 @@ TYPED_TEST(CApiTest, ExternalSourceMultipleAllocPipe) {
daliPipelineHandle handle;
daliCreatePipeline(&handle, serialized.c_str(), serialized.size(), batch_size, num_thread,
device_id, false, prefetch_queue_depth, prefetch_queue_depth,
prefetch_queue_depth);
prefetch_queue_depth, false);

for (int i = 0; i < prefetch_queue_depth; i++) {
SequentialFill(view<uint8_t>(input_cpu), 42 * i);
@@ -340,7 +340,7 @@ TYPED_TEST(CApiTest, ExternalSourceSingleAllocDifferentBackendsTest) {
daliPipelineHandle handle;
daliCreatePipeline(&handle, serialized.c_str(), serialized.size(), batch_size, num_thread,
device_id, false, prefetch_queue_depth, prefetch_queue_depth,
prefetch_queue_depth);
prefetch_queue_depth, false);

for (int i = 0; i < prefetch_queue_depth; i++) {
SequentialFill(view<uint8_t>(input_cpu), 42 * i);
@@ -404,7 +404,7 @@ TYPED_TEST(CApiTest, ExternalSourceMultipleAllocDifferentBackendsTest) {
daliPipelineHandle handle;
daliCreatePipeline(&handle, serialized.c_str(), serialized.size(), batch_size, num_thread,
device_id, false, prefetch_queue_depth, prefetch_queue_depth,
prefetch_queue_depth);
prefetch_queue_depth, false);

for (int i = 0; i < prefetch_queue_depth; i++) {
SequentialFill(view<uint8_t>(input_cpu), 42 * i);
@@ -445,4 +445,30 @@ TYPED_TEST(CApiTest, ExternalSourceMultipleAllocDifferentBackendsTest) {
ComparePipelinesOutputs<OpBackend>(handle, *pipe_ptr);
}

TYPED_TEST(CApiTest, TestExecutorMeta) {
auto pipe_ptr = GetTestPipeline<TypeParam>(true, this->output_device_);
auto serialized = pipe_ptr->SerializeToProtobuf();

pipe_ptr.reset();
daliPipelineHandle handle;
daliCreatePipeline(&handle, serialized.c_str(), serialized.size(), batch_size, num_thread,
device_id, false, prefetch_queue_depth, prefetch_queue_depth,
prefetch_queue_depth, true);

daliRun(&handle);
daliOutput(&handle);

size_t N;
daliExecutorMetadata *meta;
daliGetExecutorMetadata(&handle, &meta, &N);
EXPECT_EQ(N, 4);
for (size_t i = 0; i < N; ++i) {
free(meta[i].operator_name);
free(meta[i].real_size);
free(meta[i].reserved);
}
free(meta);
Contributor: Maybe it would be beneficial to add a function for freeing such metadata to the C API?

Contributor (Author): Done

CUDA_CALL(cudaDeviceSynchronize());
}

} // namespace dali
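
Regarding the suggestion above to add a freeing function to the C API: the name finally added is not part of this diff, but a hypothetical helper mirroring the manual cleanup in the test could look like this:

    // Hypothetical helper; the function actually introduced later may differ in name and signature.
    void daliFreeExecutorMetadata(daliExecutorMetadata *operator_meta, size_t operator_meta_num) {
      for (size_t i = 0; i < operator_meta_num; ++i) {
        free(operator_meta[i].operator_name);
        free(operator_meta[i].real_size);
        free(operator_meta[i].reserved);
      }
      free(operator_meta);
    }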
16 changes: 16 additions & 0 deletions dali/pipeline/data/tensor_vector.h
@@ -113,6 +113,22 @@ class TensorVector {
return tensors_.size();
}

size_t nbytes() const noexcept {
size_t total_nbytes = 0;
for (const auto &t : tensors_) {
total_nbytes += t->nbytes();
}
return total_nbytes;
}

size_t capacity() const noexcept {
size_t total_capacity = 0;
for (const auto &t : tensors_) {
total_capacity += t->capacity();
}
return total_capacity;
Contributor: Those functions need to be handled similarly to, for example, shape():

    if (state_ == State::contiguous) {
      return tl_->nbytes();
    }
    size_t total_nbytes = 0;
    for (const auto &t : tensors_) {
      total_nbytes += t->nbytes();
    }
    return total_nbytes;

as they can be backed by either a tensor list or a vector of tensors.

Contributor (Author): Done

}

TensorListShape<> shape() const {
if (state_ == State::contiguous) {
return tl_->shape();
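A sketch of nbytes() and capacity() after applying the review suggestion above, i.e. returning the TensorList-backed value when the state is contiguous; it assumes the member is tl_ (as used by shape()) and that the underlying TensorList exposes capacity():

    size_t nbytes() const noexcept {
      if (state_ == State::contiguous) {
        return tl_->nbytes();
      }
      size_t total_nbytes = 0;
      for (const auto &t : tensors_) {
        total_nbytes += t->nbytes();
      }
      return total_nbytes;
    }

    size_t capacity() const noexcept {
      if (state_ == State::contiguous) {
        return tl_->capacity();
      }
      size_t total_capacity = 0;
      for (const auto &t : tensors_) {
        total_capacity += t->capacity();
      }
      return total_capacity;
    }
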
66 changes: 66 additions & 0 deletions dali/pipeline/executor/executor.h
@@ -21,6 +21,8 @@
#include <string>
#include <utility>
#include <vector>
#include <unordered_map>
#include <mutex>

#include "dali/core/common.h"
#include "dali/core/error_handling.h"
@@ -38,9 +40,15 @@
#include "dali/pipeline/workspace/host_workspace.h"
#include "dali/pipeline/workspace/mixed_workspace.h"
#include "dali/pipeline/workspace/workspace_data_factory.h"
#include "dali/pipeline/data/backend.h"

namespace dali {

struct DLL_PUBLIC ExecutorMeta {
size_t real_size;
size_t reserved;
};
using ExecutorMetaMap = std::unordered_map<std::string, std::vector<ExecutorMeta>>;

namespace detail {
// This is stream callback used on GPU stream to indicate that GPU work for this
@@ -62,6 +70,8 @@ class DLL_PUBLIC ExecutorBase {
DLL_PUBLIC virtual void ShareOutputs(DeviceWorkspace *ws) = 0;
DLL_PUBLIC virtual void ReleaseOutputs() = 0;
DLL_PUBLIC virtual void SetCompletionCallback(ExecutorCallback cb) = 0;
DLL_PUBLIC virtual void EnableMemoryStats(bool get_memory_stats = false) = 0;
DLL_PUBLIC virtual ExecutorMetaMap GetExecutorMeta() = 0;

protected:
// virtual to allow the TestPruneWholeGraph test in gcc
@@ -101,6 +111,9 @@ class DLL_PUBLIC Executor : public ExecutorBase, public WorkspacePolicy, public
stage_queue_depths_ = QueuePolicy::GetQueueSizes(prefetch_queue_depth);
}

DLL_PUBLIC void EnableMemoryStats(bool get_memory_stats = false) override {
get_memory_stats_ = get_memory_stats;
}
DLL_PUBLIC void Build(OpGraph *graph, vector<string> output_names) override;
DLL_PUBLIC void Init() override {}
DLL_PUBLIC void RunCPU() override;
@@ -110,6 +123,7 @@
DLL_PUBLIC void ShareOutputs(DeviceWorkspace *ws) override;
DLL_PUBLIC void ReleaseOutputs() override;
DLL_PUBLIC void SetCompletionCallback(ExecutorCallback cb) override;
DLL_PUBLIC ExecutorMetaMap GetExecutorMeta() override;

DLL_PUBLIC void ShutdownQueue() {
QueuePolicy::SignalStop();
@@ -118,6 +132,31 @@ class DLL_PUBLIC Executor : public ExecutorBase, public WorkspacePolicy, public
DISABLE_COPY_MOVE_ASSIGN(Executor);

protected:
template <typename W>
inline void FillStats(ExecutorMetaMap &memory_stats, W ws, std::string op_name,
std::mutex &write_mutex) {
Contributor: Nitpick: please add a space here.

Contributor (Author): Done

if (get_memory_stats_) {
size_t out_size = 0;
size_t reserved_size = 0;
std::lock_guard<std::mutex> lck(write_mutex);
auto &stats = memory_stats[op_name];
stats.resize(ws.NumOutput(), {0, 0});
for (int i = 0; i < ws.NumOutput(); ++i) {
out_size = 0;
reserved_size = 0;
if (ws.template OutputIsType<CPUBackend>(i)) {
out_size = ws.template OutputRef<CPUBackend>(i).nbytes();
reserved_size = ws.template OutputRef<CPUBackend>(i).capacity();
} else {
out_size = ws.template OutputRef<GPUBackend>(i).nbytes();
reserved_size = ws.template OutputRef<GPUBackend>(i).capacity();
}
stats[i].real_size = std::max(out_size, stats[i].real_size);
stats[i].reserved = std::max(reserved_size, stats[i].reserved);
}
}
}

void HandleError(const char *message = "Unknown exception") {
exec_error_ = true;
ShutdownQueue();
@@ -220,6 +259,12 @@ class DLL_PUBLIC Executor : public ExecutorBase, public WorkspacePolicy, public
// in some edge cases where there are no operators
std::vector<cudaEvent_t> mixed_callback_events_;

std::atomic<bool> get_memory_stats_ = ATOMIC_VAR_INIT(false);;
Contributor: Double ;. Do we really need the init? Isn't the constructor enough? Maybe with the C API we do.

Contributor (Author): Done

ExecutorMetaMap cpu_memory_stats_, mixed_memory_stats_, gpu_memory_stats_;
std::mutex cpu_memory_stats_mutex_;
std::mutex mixed_memory_stats_mutex_;
std::mutex gpu_memory_stats_mutex_;

private:
template <typename Workspace>
void RunHelper(OpNode &op_node, Workspace &ws) {
@@ -266,6 +311,23 @@ void Executor<WorkspacePolicy, QueuePolicy>::SetCompletionCallback(ExecutorCallb
}
}

template<typename map>
void AppendToMap(map &ret, map &in_stats, std::mutex &mutex) {
Contributor: Suggested change:
    template<typename map>
    void AppendToMap(map &ret, map &in_stats, std::mutex &mutex) {
becomes
    void AppendToMap(ExecutorMetaMap &ret, const ExecutorMetaMap &in_stats, std::mutex &mutex) {

Contributor (Author): Done

const std::lock_guard<std::mutex> lock(mutex);
for (auto const& stats : in_stats) {
ret.emplace(stats);
}
Contributor: Wouldn't the suggested change
    for (auto const& stats : in_stats) {
      ret.emplace(stats);
    }
becomes
    ret.insert(stats.begin(), stats.end());
also work?

Contributor (Author): Done

}
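
After applying both suggestions above, the helper could collapse to the following sketch; note that, once the loop is gone, the inserted range has to come from in_stats:

    void AppendToMap(ExecutorMetaMap &ret, const ExecutorMetaMap &in_stats, std::mutex &mutex) {
      const std::lock_guard<std::mutex> lock(mutex);
      ret.insert(in_stats.begin(), in_stats.end());
    }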

template <typename WorkspacePolicy, typename QueuePolicy>
ExecutorMetaMap Executor<WorkspacePolicy, QueuePolicy>::GetExecutorMeta() {
ExecutorMetaMap ret;
AppendToMap(ret, cpu_memory_stats_, cpu_memory_stats_mutex_);
AppendToMap(ret, mixed_memory_stats_, mixed_memory_stats_mutex_);
AppendToMap(ret, gpu_memory_stats_, gpu_memory_stats_mutex_);
return ret;
}

template <typename WorkspacePolicy, typename QueuePolicy>
void Executor<WorkspacePolicy, QueuePolicy>::Build(OpGraph *graph, vector<string> output_names) {
DALI_ENFORCE(graph != nullptr, "Input graph is nullptr.");
@@ -347,6 +409,7 @@ void Executor<WorkspacePolicy, QueuePolicy>::RunCPU() {

try {
RunHelper(op_node, ws);
FillStats(cpu_memory_stats_, ws, "CPU_" + op_node.instance_name, cpu_memory_stats_mutex_);
Contributor: Instance names should be unique; what's the rationale for the CPU_ prefixes etc.?

Contributor (Author): If you have an operator for CPU and GPU then it is hard to tell which instance is placed where. The name is unique but a bit mangled and not always self-explanatory to the user.

} catch (std::exception &e) {
HandleError(e.what());
} catch (...) {
@@ -381,6 +444,8 @@ void Executor<WorkspacePolicy, QueuePolicy>::RunMixed() {
TimeRange tr("[Executor] Run Mixed op " + op_node.instance_name,
TimeRange::kOrange);
RunHelper(op_node, ws);
FillStats(mixed_memory_stats_, ws, "MIXED_" + op_node.instance_name,
mixed_memory_stats_mutex_);
if (ws.has_stream() && ws.has_event()) {
CUDA_CALL(cudaEventRecord(ws.event(), ws.stream()));
}
@@ -438,6 +503,7 @@ void Executor<WorkspacePolicy, QueuePolicy>::RunGPU() {
TimeRange tr("[Executor] Run GPU op " + op_node.instance_name,
TimeRange::knvGreen);
RunHelper(op_node, ws);
FillStats(gpu_memory_stats_, ws, "GPU_" + op_node.instance_name, gpu_memory_stats_mutex_);
if (ws.has_event()) {
CUDA_CALL(cudaEventRecord(ws.event(), ws.stream()));
}
5 changes: 3 additions & 2 deletions dali/pipeline/pipeline.cc
@@ -167,11 +167,11 @@ int Pipeline::AddOperator(const OpSpec &spec, const std::string& inst_name) {
}

int Pipeline::AddOperator(const OpSpec &spec, int logical_id) {
return AddOperator(spec, "<no name>", logical_id);
return AddOperator(spec, make_string("<no name>_", logical_id), logical_id);
Contributor: OK, now it looks like they are unique :D

Contributor: That's from the C++ API POV; in Python we use:
    self._name = '__' + type(op).__name__ + "_" + str(self._counter.id)
If you are already accessing this, maybe we can unify and set something similar here based on the operator name instead of <no name>?

Contributor (Author): Done

}
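
A hypothetical way to unify with the Python naming convention quoted above would be to derive the default instance name from the operator's schema name rather than "<no name>"; this sketch assumes OpSpec exposes the schema name via name() and is not part of this diff:

    int Pipeline::AddOperator(const OpSpec &spec, int logical_id) {
      return AddOperator(spec, make_string("__", spec.name(), "_", logical_id), logical_id);
    }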

int Pipeline::AddOperator(const OpSpec &spec) {
return AddOperator(spec, "<no name>", GetNextLogicalId());
return AddOperator(spec, GetNextLogicalId());
}


@@ -431,6 +431,7 @@ void Pipeline::Build(vector<std::pair<string, string>> output_names) {
executor_ = GetExecutor(pipelined_execution_, separated_execution_, async_execution_, batch_size_,
num_threads_, device_id_, bytes_per_sample_hint_, set_affinity_,
max_num_stream_, default_cuda_stream_priority_, prefetch_queue_depth_);
executor_->EnableMemoryStats(get_memory_stats_);
executor_->Init();

// Creating the graph
25 changes: 24 additions & 1 deletion dali/pipeline/pipeline.h
@@ -95,7 +95,6 @@
default_cuda_stream_priority, QueueSizes{prefetch_queue_depth});
}


DLL_PUBLIC Pipeline(const string &serialized_pipe, int batch_size = -1, int num_threads = -1,
int device_id = -1, bool pipelined_execution = true,
int prefetch_queue_depth = 2, bool async_execution = true,
@@ -310,6 +309,29 @@ class DLL_PUBLIC Pipeline {
async_execution_ = async_execution;
}

/**
* @brief Set if the DALI pipeline should operator output buffer statistics
Contributor: Something went wrong in this @brief. Maybe the description from @param would be enough.

*
* @param get_memory_stats If DALI should print operator output buffer statistics.
* Useful for the `bytes_per_sample_hint` operator parameter.
*/
DLL_PUBLIC void EnableOperatorOutputMemoryStatistics(bool get_memory_stats = true) {
get_memory_stats_ = get_memory_stats;
if (executor_) {
executor_->EnableMemoryStats(get_memory_stats_);
}
}

/**
* @brief Obtains the executor statistics
*/
DLL_PUBLIC ExecutorMetaMap GetExecutorMeta() {
if (executor_) {
return executor_->GetExecutorMeta();
} else {
return {};
}
}
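
A minimal usage sketch of the statistics API as exposed in this diff; serialized_pipeline, batch_size, num_threads, device_id and outputs are placeholders, and error handling is omitted:

    dali::Pipeline pipe(serialized_pipeline, batch_size, num_threads, device_id);
    pipe.EnableOperatorOutputMemoryStatistics(true);
    pipe.Build(outputs);
    pipe.RunCPU();
    pipe.RunGPU();
    for (const auto &op : pipe.GetExecutorMeta()) {
      for (const auto &out : op.second) {
        std::cout << op.first << ": used " << out.real_size
                  << " B, reserved " << out.reserved << " B" << std::endl;
      }
    }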

/**
* @brief Set queue sizes for Pipeline using Separated Queues
@@ -492,6 +514,7 @@ class DLL_PUBLIC Pipeline {
int next_logical_id_ = 0;
int next_internal_logical_id_ = -1;
QueueSizes prefetch_queue_depth_;
bool get_memory_stats_ = false;
Contributor: Suggested change:
    bool get_memory_stats_ = false;
becomes
    bool memory_stats_enabled_ = false;

Contributor (Author): Done


std::vector<int64_t> seed_;
int original_seed_;