-
Notifications
You must be signed in to change notification settings - Fork 610
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Dump operator stats #2039
Dump operator stats #2039
Changes from all commits
8ab6cfc
416e5a7
22fd3bc
7f75031
ce82218
41bd1b1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,6 +21,8 @@ | |
#include <string> | ||
#include <utility> | ||
#include <vector> | ||
#include <unordered_map> | ||
#include <mutex> | ||
|
||
#include "dali/core/common.h" | ||
#include "dali/core/error_handling.h" | ||
|
@@ -38,15 +40,26 @@ | |
#include "dali/pipeline/workspace/host_workspace.h" | ||
#include "dali/pipeline/workspace/mixed_workspace.h" | ||
#include "dali/pipeline/workspace/workspace_data_factory.h" | ||
#include "dali/pipeline/data/backend.h" | ||
|
||
namespace dali { | ||
|
||
// Memory statistics collected for a single operator output, in bytes.
// Member order matters: instances are aggregate-initialized (e.g. {0, 0}).
struct DLL_PUBLIC ExecutorMeta {
  size_t real_size;      // peak nbytes() observed for the whole output batch
  size_t max_real_size;  // peak per-sample size (see GetMaxSizes* helpers)
  size_t reserved;       // peak capacity() observed for the whole output batch
  size_t max_reserved;   // peak per-sample reserved size
};
// Maps a stage-prefixed operator instance name (e.g. "CPU_<name>",
// "MIXED_<name>", "GPU_<name>") to per-output statistics, one entry per output.
using ExecutorMetaMap = std::unordered_map<std::string, std::vector<ExecutorMeta>>;
|
||
namespace detail { | ||
// This is stream callback used on GPU stream to indicate that GPU work for this | ||
// pipeline run is finished | ||
static void gpu_finished_callback(cudaStream_t stream, cudaError_t status, void *userData); | ||
|
||
// helper function to concatenate ExecutorMetaMap maps | ||
static void AppendToMap(ExecutorMetaMap &ret, ExecutorMetaMap &in_stats, std::mutex &mutex); | ||
|
||
} // namespace detail | ||
|
||
class DLL_PUBLIC ExecutorBase { | ||
|
@@ -62,6 +75,8 @@ class DLL_PUBLIC ExecutorBase { | |
DLL_PUBLIC virtual void ShareOutputs(DeviceWorkspace *ws) = 0; | ||
DLL_PUBLIC virtual void ReleaseOutputs() = 0; | ||
DLL_PUBLIC virtual void SetCompletionCallback(ExecutorCallback cb) = 0; | ||
DLL_PUBLIC virtual void EnableMemoryStats(bool enable_memory_stats = false) = 0; | ||
DLL_PUBLIC virtual ExecutorMetaMap GetExecutorMeta() = 0; | ||
|
||
protected: | ||
// virtual to allow the TestPruneWholeGraph test in gcc | ||
|
@@ -94,13 +109,17 @@ class DLL_PUBLIC Executor : public ExecutorBase, public WorkspacePolicy, public | |
exec_error_(false), | ||
queue_sizes_(prefetch_queue_depth), | ||
mixed_op_stream_(0), | ||
gpu_op_stream_(0) { | ||
gpu_op_stream_(0), | ||
enable_memory_stats_(false) { | ||
DALI_ENFORCE(batch_size_ > 0, "Batch size must be greater than 0."); | ||
DALI_ENFORCE(device_id >= 0, "Device id must be non-negative."); | ||
|
||
stage_queue_depths_ = QueuePolicy::GetQueueSizes(prefetch_queue_depth); | ||
} | ||
|
||
// Turns collection of per-operator memory statistics on or off.
// NOTE(review): the default argument is false, so calling EnableMemoryStats()
// with no arguments actually *disables* stats collection - confirm this is
// the intended contract for an "Enable" function.
DLL_PUBLIC void EnableMemoryStats(bool enable_memory_stats = false) override {
  enable_memory_stats_ = enable_memory_stats;
}
DLL_PUBLIC void Build(OpGraph *graph, vector<string> output_names) override; | ||
DLL_PUBLIC void Init() override {} | ||
DLL_PUBLIC void RunCPU() override; | ||
|
@@ -110,6 +129,7 @@ class DLL_PUBLIC Executor : public ExecutorBase, public WorkspacePolicy, public | |
DLL_PUBLIC void ShareOutputs(DeviceWorkspace *ws) override; | ||
DLL_PUBLIC void ReleaseOutputs() override; | ||
DLL_PUBLIC void SetCompletionCallback(ExecutorCallback cb) override; | ||
DLL_PUBLIC ExecutorMetaMap GetExecutorMeta() override; | ||
|
||
DLL_PUBLIC void ShutdownQueue() { | ||
QueuePolicy::SignalStop(); | ||
|
@@ -118,6 +138,75 @@ class DLL_PUBLIC Executor : public ExecutorBase, public WorkspacePolicy, public | |
DISABLE_COPY_MOVE_ASSIGN(Executor); | ||
|
||
protected: | ||
// Accumulates per-sample size maxima for a contiguous batch: the batch-level
// nbytes()/capacity() divided by the number of samples, rounded up.
// @param in                container exposing nbytes(), capacity(), ntensor()
// @param max_out_size      running maximum of per-sample payload size (updated)
// @param max_reserved_size running maximum of per-sample reservation (updated)
template <typename T>
inline void GetMaxSizesCont(T &in, size_t &max_out_size, size_t &max_reserved_size) {
  const size_t num_samples = in.ntensor();
  // Guard against empty batches: the previous floating-point division would
  // divide by zero (inf, then UB when cast back to size_t).
  if (num_samples == 0)
    return;
  const size_t out_size = in.nbytes();
  const size_t reserved_size = in.capacity();
  // Integer ceiling division is exact even for sizes large enough that
  // double arithmetic would lose precision.
  max_out_size = std::max(max_out_size, (out_size + num_samples - 1) / num_samples);
  max_reserved_size = std::max(max_reserved_size,
                               (reserved_size + num_samples - 1) / num_samples);
}
|
||
// Accumulates per-sample size maxima for a non-contiguous batch by inspecting
// every sample individually through operator[].
// @param in                container exposing ntensor() and operator[]
// @param max_out_size      running maximum of per-sample payload size (updated)
// @param max_reserved_size running maximum of per-sample reservation (updated)
template <typename T>
inline void GetMaxSizesNonCont(T &in, size_t &max_out_size, size_t &max_reserved_size) {
  const size_t num_samples = in.ntensor();
  for (size_t idx = 0; idx < num_samples; ++idx) {
    auto &sample = in[idx];
    max_out_size = std::max(max_out_size, sample.nbytes());
    max_reserved_size = std::max(max_reserved_size, sample.capacity());
  }
}
|
||
// Per-sample size estimate for a TensorList.
// A TensorList stores its samples in one backing allocation, so the
// contiguous (averaged) estimate is used directly.
template<typename backend>
inline void GetMaxSizes(TensorList<backend> &in, size_t &max_out_size,
                        size_t &max_reserved_size) {
  GetMaxSizesCont(in, max_out_size, max_reserved_size);
}
|
||
template<typename backend> | ||
inline void GetMaxSizes(TensorVector<backend> &in, size_t &max_out_size, | ||
size_t &max_reserved_size) { | ||
if (in.IsContiguous()) { | ||
GetMaxSizesCont(in, max_out_size, max_reserved_size); | ||
} else { | ||
GetMaxSizesNonCont(in, max_out_size, max_reserved_size); | ||
} | ||
} | ||
|
||
template <typename W> | ||
inline void FillStats(ExecutorMetaMap &memory_stats, W ws, std::string op_name, | ||
std::mutex &write_mutex) { | ||
if (enable_memory_stats_) { | ||
size_t out_size = 0; | ||
size_t max_out_size = 0; | ||
size_t reserved_size = 0; | ||
size_t max_reserved_size = 0; | ||
std::lock_guard<std::mutex> lck(write_mutex); | ||
auto &stats = memory_stats[op_name]; | ||
stats.resize(ws.NumOutput(), {0, 0}); | ||
|
||
for (int i = 0; i < ws.NumOutput(); ++i) { | ||
out_size = 0; | ||
max_out_size = 0; | ||
reserved_size = 0; | ||
max_reserved_size = 0; | ||
if (ws.template OutputIsType<CPUBackend>(i)) { | ||
auto &out = ws.template OutputRef<CPUBackend>(i); | ||
out_size = out.nbytes(); | ||
reserved_size = out.capacity(); | ||
GetMaxSizes(out, max_out_size, max_reserved_size); | ||
} else { | ||
auto &out = ws.template OutputRef<GPUBackend>(i); | ||
out_size = out.nbytes(); | ||
reserved_size = out.capacity(); | ||
GetMaxSizes(out, max_out_size, max_reserved_size); | ||
} | ||
stats[i].real_size = std::max(out_size, stats[i].real_size); | ||
stats[i].max_real_size = std::max(max_out_size, stats[i].max_real_size); | ||
stats[i].reserved = std::max(reserved_size, stats[i].reserved); | ||
stats[i].max_reserved = std::max(max_reserved_size, stats[i].max_reserved); | ||
} | ||
} | ||
} | ||
|
||
void HandleError(const char *message = "Unknown exception") { | ||
exec_error_ = true; | ||
ShutdownQueue(); | ||
|
@@ -220,6 +309,12 @@ class DLL_PUBLIC Executor : public ExecutorBase, public WorkspacePolicy, public | |
// in some edge cases where there are no operators | ||
std::vector<cudaEvent_t> mixed_callback_events_; | ||
|
||
std::atomic<bool> enable_memory_stats_; | ||
ExecutorMetaMap cpu_memory_stats_, mixed_memory_stats_, gpu_memory_stats_; | ||
std::mutex cpu_memory_stats_mutex_; | ||
std::mutex mixed_memory_stats_mutex_; | ||
std::mutex gpu_memory_stats_mutex_; | ||
|
||
private: | ||
template <typename Workspace> | ||
void RunHelper(OpNode &op_node, Workspace &ws) { | ||
|
@@ -266,6 +361,15 @@ void Executor<WorkspacePolicy, QueuePolicy>::SetCompletionCallback(ExecutorCallb | |
} | ||
} | ||
|
||
template <typename WorkspacePolicy, typename QueuePolicy> | ||
ExecutorMetaMap Executor<WorkspacePolicy, QueuePolicy>::GetExecutorMeta() { | ||
ExecutorMetaMap ret; | ||
detail::AppendToMap(ret, cpu_memory_stats_, cpu_memory_stats_mutex_); | ||
detail::AppendToMap(ret, mixed_memory_stats_, mixed_memory_stats_mutex_); | ||
detail::AppendToMap(ret, gpu_memory_stats_, gpu_memory_stats_mutex_); | ||
return ret; | ||
} | ||
|
||
template <typename WorkspacePolicy, typename QueuePolicy> | ||
void Executor<WorkspacePolicy, QueuePolicy>::Build(OpGraph *graph, vector<string> output_names) { | ||
DALI_ENFORCE(graph != nullptr, "Input graph is nullptr."); | ||
|
@@ -347,6 +451,7 @@ void Executor<WorkspacePolicy, QueuePolicy>::RunCPU() { | |
|
||
try { | ||
RunHelper(op_node, ws); | ||
FillStats(cpu_memory_stats_, ws, "CPU_" + op_node.instance_name, cpu_memory_stats_mutex_); | ||
Reviewer: Instance names should be unique — what's the rationale for the stage prefix? — Author: If you have an operator for CPU and GPU then it is hard to tell which instance is placed where. The name is unique but a bit mangled and not always self-explanatory to the user. |
||
} catch (std::exception &e) { | ||
HandleError(e.what()); | ||
} catch (...) { | ||
|
@@ -381,6 +486,8 @@ void Executor<WorkspacePolicy, QueuePolicy>::RunMixed() { | |
TimeRange tr("[Executor] Run Mixed op " + op_node.instance_name, | ||
TimeRange::kOrange); | ||
RunHelper(op_node, ws); | ||
FillStats(mixed_memory_stats_, ws, "MIXED_" + op_node.instance_name, | ||
mixed_memory_stats_mutex_); | ||
if (ws.has_stream() && ws.has_event()) { | ||
CUDA_CALL(cudaEventRecord(ws.event(), ws.stream())); | ||
} | ||
|
@@ -438,6 +545,7 @@ void Executor<WorkspacePolicy, QueuePolicy>::RunGPU() { | |
TimeRange tr("[Executor] Run GPU op " + op_node.instance_name, | ||
TimeRange::knvGreen); | ||
RunHelper(op_node, ws); | ||
FillStats(gpu_memory_stats_, ws, "GPU_" + op_node.instance_name, gpu_memory_stats_mutex_); | ||
if (ws.has_event()) { | ||
CUDA_CALL(cudaEventRecord(ws.event(), ws.stream())); | ||
} | ||
|
@@ -724,15 +832,20 @@ void Executor<WorkspacePolicy, QueuePolicy>::SetupOutputQueuesForGraph() { | |
|
||
using SimpleExecutor = Executor<AOT_WS_Policy<UniformQueuePolicy>, UniformQueuePolicy>; | ||
|
||
|
||
namespace detail { | ||
|
||
// CUDA stream callback fired once all GPU work of a pipeline run completes;
// forwards to the user-provided completion callback smuggled in as userData.
void gpu_finished_callback(cudaStream_t stream, cudaError_t status, void *userData) {
  auto *cb = static_cast<ExecutorBase::ExecutorCallback *>(userData);
  (*cb)();
}
|
||
} // namespace detail | ||
// Merges the contents of in_stats into ret while holding the mutex that
// guards in_stats. Keys already present in ret are left untouched
// (std::unordered_map::insert never overwrites).
void AppendToMap(ExecutorMetaMap &ret, ExecutorMetaMap &in_stats, std::mutex &mutex) {
  std::lock_guard<std::mutex> guard(mutex);
  for (const auto &entry : in_stats) {
    ret.insert(entry);
  }
}
|
||
} // namespace detail | ||
|
||
} // namespace dali | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have some reservations regarding this. In the case of a contiguous TensorVector, the Tensors that you're accessing with operator[] here are kind of views into the backing TensorList. So the `capacity` will probably match exactly the `nbytes` here, but the total capacity of the TensorVector might be bigger than `nsamples * max(number of bytes)` (I suspect). This can probably be checked with a test that sets the TensorVector to contiguous mode, resizes it to some big shapes, and then sets all the shapes to, for example, half the initial size.
Can you check this? I think the best solution would be to take the max of this reserved size and the average reserved size you calculate for the TensorList in the function above.