diff --git a/paddle/fluid/platform/profiler/chrometracing_logger.cc b/paddle/fluid/platform/profiler/chrometracing_logger.cc index f728a820bd73c..72d343692df73 100644 --- a/paddle/fluid/platform/profiler/chrometracing_logger.cc +++ b/paddle/fluid/platform/profiler/chrometracing_logger.cc @@ -27,7 +27,7 @@ limitations under the License. */ namespace paddle { namespace platform { -static const char* kSchemaVersion = "1.0.0"; +static const char* kSchemaVersion = "1.0.1"; static const char* kDefaultFilename = "pid_%s_time_%s.paddle_trace.json"; static uint32_t span_indx = 0; @@ -37,14 +37,6 @@ static std::string DefaultFileName() { GetStringFormatLocalTime().c_str()); } -const char* ChromeTracingLogger::categary_name_[] = { - "Operator", "Dataloader", "ProfileStep", - "CudaRuntime", "Kernel", "Memcpy", - "Memset", "UserDefined", "OperatorInner", - "Forward", "Backward", "Optimization", - "Communication", "PythonOp", "PythonUserDefined", - "MluRuntime"}; - void ChromeTracingLogger::OpenFile() { output_file_stream_.open(filename_, std::ofstream::out | std::ofstream::trunc); @@ -116,10 +108,41 @@ void ChromeTracingLogger::LogNodeTrees(const NodeTrees& node_trees) { (*devicenode)->LogMe(this); } } + for (auto memnode = (*hostnode)->GetMemTraceEventNodes().begin(); + memnode != (*hostnode)->GetMemTraceEventNodes().end(); ++memnode) { + (*memnode)->LogMe(this); + } } } } +void ChromeTracingLogger::LogMemTraceEventNode( + const MemTraceEventNode& mem_node) { + if (!output_file_stream_) { + return; + } + output_file_stream_ << string_format( + std::string( + R"JSON( + { + "name": "[memory]", "pid": %lld, "tid": "%lld", + "ts": %lld, + "ph": "i", "cat": "%s", + "args": { + "place": "%s", + "addr": "%llu", + "current_allocated": %llu, + "current_reserved": %llu, + "increase_bytes": %lld + } + }, + )JSON"), + mem_node.ProcessId(), mem_node.ThreadId(), mem_node.TimeStampNs(), + StringTracerMemEventType(mem_node.Type()), mem_node.Place().c_str(), + mem_node.Addr(), mem_node.CurrentAllocated(), mem_node.CurrentReserved(), + mem_node.IncreaseBytes()); +} + void ChromeTracingLogger::LogHostTraceEventNode( const HostTraceEventNode& host_node) { if (!output_file_stream_) { @@ -132,6 +155,16 @@ void ChromeTracingLogger::LogHostTraceEventNode( } else { dur_display = string_format(std::string("%.3f us"), dur * 1000); } + std::map>> input_shapes; + std::map> input_dtypes; + std::string callstack; + OperatorSupplementEventNode* op_supplement_node = + host_node.GetOperatorSupplementEventNode(); + if (op_supplement_node != nullptr) { + input_shapes = op_supplement_node->InputShapes(); + input_dtypes = op_supplement_node->Dtypes(); + callstack = op_supplement_node->CallStack(); + } switch (host_node.Type()) { case TracerEventType::ProfileStep: case TracerEventType::Forward: @@ -159,10 +192,48 @@ void ChromeTracingLogger::LogHostTraceEventNode( host_node.Name().c_str(), dur_display.c_str(), host_node.ProcessId(), host_node.ThreadId(), nsToUs(host_node.StartNs()), nsToUsFloat(host_node.Duration()), - categary_name_[static_cast(host_node.Type())], + StringTracerEventType(host_node.Type()), nsToUsFloat(host_node.StartNs(), start_time_), nsToUsFloat(host_node.EndNs(), start_time_)); break; + + case TracerEventType::Operator: + + output_file_stream_ << string_format( + std::string( + R"JSON( + { + "name": "%s[%s]", "pid": %lld, "tid": "%lld(C++)", + "ts": %lld, "dur": %.3f, + "ph": "X", "cat": "%s", + "cname": "thread_state_runnable", + "args": { + "start_time": "%.3f us", + "end_time": "%.3f us", + "input_shapes": %s, + "input_dtypes": %s, + "callstack": "%s" + } + }, + )JSON"), + host_node.Name().c_str(), dur_display.c_str(), host_node.ProcessId(), + host_node.ThreadId(), nsToUs(host_node.StartNs()), + nsToUsFloat(host_node.Duration()), + StringTracerEventType(host_node.Type()), + nsToUsFloat(host_node.StartNs(), start_time_), + nsToUsFloat(host_node.EndNs(), start_time_), + json_dict(input_shapes).c_str(), json_dict(input_dtypes).c_str(), + callstack.c_str()); + break; + case TracerEventType::CudaRuntime: + case TracerEventType::Kernel: + case TracerEventType::Memcpy: + case TracerEventType::Memset: + case TracerEventType::UserDefined: + case TracerEventType::OperatorInner: + case TracerEventType::Communication: + case TracerEventType::MluRuntime: + case TracerEventType::NumTypes: default: output_file_stream_ << string_format( std::string( @@ -181,7 +252,7 @@ void ChromeTracingLogger::LogHostTraceEventNode( host_node.Name().c_str(), dur_display.c_str(), host_node.ProcessId(), host_node.ThreadId(), nsToUs(host_node.StartNs()), nsToUsFloat(host_node.Duration()), - categary_name_[static_cast(host_node.Type())], + StringTracerEventType(host_node.Type()), nsToUsFloat(host_node.StartNs(), start_time_), nsToUsFloat(host_node.EndNs(), start_time_)); break; @@ -220,8 +291,7 @@ void ChromeTracingLogger::LogRuntimeTraceEventNode( runtime_node.Name().c_str(), dur_display.c_str(), runtime_node.ProcessId(), runtime_node.ThreadId(), nsToUs(runtime_node.StartNs()), nsToUsFloat(runtime_node.Duration()), - categary_name_[static_cast(runtime_node.Type())], - runtime_node.CorrelationId(), + StringTracerEventType(runtime_node.Type()), runtime_node.CorrelationId(), nsToUsFloat(runtime_node.StartNs(), start_time_), nsToUsFloat(runtime_node.EndNs(), start_time_)); pid_tid_set_.insert({runtime_node.ProcessId(), runtime_node.ThreadId()}); @@ -347,7 +417,7 @@ void ChromeTracingLogger::HandleTypeKernel( device_node.Name().c_str(), dur_display.c_str(), device_node.DeviceId(), device_node.StreamId(), nsToUs(device_node.StartNs()), nsToUsFloat(device_node.Duration()), - categary_name_[static_cast(device_node.Type())], + StringTracerEventType(device_node.Type()), nsToUsFloat(device_node.StartNs(), start_time_), nsToUsFloat(device_node.EndNs(), start_time_), device_node.DeviceId(), device_node.ContextId(), device_node.StreamId(), @@ -391,7 +461,7 @@ void ChromeTracingLogger::HandleTypeMemcpy( device_node.Name().c_str(), dur_display.c_str(), device_node.DeviceId(), device_node.StreamId(), nsToUs(device_node.StartNs()), nsToUsFloat(device_node.Duration()), - categary_name_[static_cast(device_node.Type())], + StringTracerEventType(device_node.Type()), nsToUsFloat(device_node.StartNs(), start_time_), nsToUsFloat(device_node.EndNs(), start_time_), device_node.StreamId(), device_node.CorrelationId(), memcpy_info.num_bytes, memory_bandwidth); @@ -427,7 +497,7 @@ void ChromeTracingLogger::HandleTypeMemset( device_node.Name().c_str(), dur_display.c_str(), device_node.DeviceId(), device_node.StreamId(), nsToUs(device_node.StartNs()), nsToUsFloat(device_node.Duration()), - categary_name_[static_cast(device_node.Type())], + StringTracerEventType(device_node.Type()), nsToUsFloat(device_node.StartNs(), start_time_), nsToUsFloat(device_node.EndNs(), start_time_), device_node.DeviceId(), device_node.ContextId(), device_node.StreamId(), diff --git a/paddle/fluid/platform/profiler/chrometracing_logger.h b/paddle/fluid/platform/profiler/chrometracing_logger.h index 12d98d1ef0c63..3cbf9ccf6a0cc 100644 --- a/paddle/fluid/platform/profiler/chrometracing_logger.h +++ b/paddle/fluid/platform/profiler/chrometracing_logger.h @@ -37,6 +37,7 @@ class ChromeTracingLogger : public BaseLogger { void LogRuntimeTraceEventNode(const CudaRuntimeTraceEventNode&) override; void LogNodeTrees(const NodeTrees&) override; void LogMetaInfo(const std::unordered_map); + void LogMemTraceEventNode(const MemTraceEventNode&) override; private: void OpenFile(); diff --git a/paddle/fluid/platform/profiler/dump/test_serialization_logger.cc b/paddle/fluid/platform/profiler/dump/test_serialization_logger.cc index 5253ecc505dbb..002071de0d1ef 100644 --- a/paddle/fluid/platform/profiler/dump/test_serialization_logger.cc +++ b/paddle/fluid/platform/profiler/dump/test_serialization_logger.cc @@ -27,7 +27,9 @@ using paddle::platform::HostTraceEventNode; using paddle::platform::KernelEventInfo; using paddle::platform::MemcpyEventInfo; using paddle::platform::MemsetEventInfo; +using paddle::platform::MemTraceEvent; using paddle::platform::NodeTrees; +using paddle::platform::OperatorSupplementEvent; using paddle::platform::ProfilerResult; using paddle::platform::RuntimeTraceEvent; using paddle::platform::SerializationLogger; @@ -37,6 +39,8 @@ TEST(SerializationLoggerTest, dump_case0) { std::list host_events; std::list runtime_events; std::list device_events; + std::list mem_events; + std::list op_supplement_events; host_events.push_back(HostTraceEvent(std::string("dataloader#1"), TracerEventType::Dataloader, 1000, 10000, 10, 10)); @@ -72,7 +76,8 @@ TEST(SerializationLoggerTest, dump_case0) { DeviceTraceEvent(std::string("memset1"), TracerEventType::Memset, 66000, 69000, 0, 10, 11, 5, MemsetEventInfo())); SerializationLogger logger("test_serialization_logger_case0.pb"); - NodeTrees tree(host_events, runtime_events, device_events); + NodeTrees tree(host_events, runtime_events, device_events, mem_events, + op_supplement_events); std::map> nodes = tree.Traverse(true); EXPECT_EQ(nodes[10].size(), 4u); @@ -101,6 +106,8 @@ TEST(SerializationLoggerTest, dump_case1) { std::list host_events; std::list runtime_events; std::list device_events; + std::list mem_events; + std::list op_supplement_events; runtime_events.push_back(RuntimeTraceEvent(std::string("cudalaunch1"), 15000, 17000, 10, 10, 1, 0)); runtime_events.push_back(RuntimeTraceEvent(std::string("cudalaunch2"), 25000, @@ -127,7 +134,8 @@ TEST(SerializationLoggerTest, dump_case1) { DeviceTraceEvent(std::string("memset1"), TracerEventType::Memset, 66000, 69000, 0, 10, 11, 5, MemsetEventInfo())); SerializationLogger logger("test_serialization_logger_case1.pb"); - NodeTrees tree(host_events, runtime_events, device_events); + NodeTrees tree(host_events, runtime_events, device_events, mem_events, + op_supplement_events); std::map> nodes = tree.Traverse(true); EXPECT_EQ(nodes[10].size(), 1u); diff --git a/paddle/fluid/platform/profiler/event_node.cc b/paddle/fluid/platform/profiler/event_node.cc index e1af63ad8909c..ca555d9c7b928 100644 --- a/paddle/fluid/platform/profiler/event_node.cc +++ b/paddle/fluid/platform/profiler/event_node.cc @@ -18,6 +18,8 @@ limitations under the License. */ #include #include +#include "paddle/fluid/platform/profiler/utils.h" + namespace paddle { namespace platform { @@ -50,8 +52,10 @@ NodeTrees::~NodeTrees() { void NodeTrees::BuildTrees( const std::vector& host_event_nodes, - std::vector& runtime_event_nodes, - const std::vector& device_event_nodes) { + const std::vector& runtime_event_nodes, + const std::vector& device_event_nodes, + const std::vector& mem_event_nodes, + const std::vector& op_supplement_events) { // separate Host Event Nodes into different threads std::map> thread2host_event_nodes; // used to store HostTraceEventNodes per thread @@ -59,6 +63,15 @@ void NodeTrees::BuildTrees( thread2runtime_event_nodes; // used to store CudaRuntimeTraceEventNode // per // thread + std::map> + thread2mem_event_nodes; // used to store MemTraceEventNode + // per + // thread + std::map> + thread2op_supplement_event_nodes; // used to store + // OperatorSupplementEventNode + // per + // thread std::map correlation_id2runtime_event_node; // used to store the relation between // correlation id and runtime node @@ -85,6 +98,15 @@ void NodeTrees::BuildTrees( "no corresponding cuda runtime events")); dst_iter->second->AddDeviceTraceEventNode(*it); } + // construct thread2mem_event_nodes + for (auto it = mem_event_nodes.begin(); it != mem_event_nodes.end(); ++it) { + thread2mem_event_nodes[(*it)->ThreadId()].push_back(*it); + } + // construct thread2op_supplement_event_nodes + for (auto it = op_supplement_events.begin(); it != op_supplement_events.end(); + ++it) { + thread2op_supplement_event_nodes[(*it)->ThreadId()].push_back(*it); + } // sort host event nodes and runtime event nodes according to start_ns and // end_ns // the smaller start_ns is, the further ahead position is. @@ -119,6 +141,29 @@ void NodeTrees::BuildTrees( return false; }); } + // sort mem event nodes and operator supplement event nodes + for (auto it = thread2mem_event_nodes.begin(); + it != thread2mem_event_nodes.end(); ++it) { + std::sort(it->second.begin(), it->second.end(), + [](MemTraceEventNode* node1, MemTraceEventNode* node2) { + if (node1->TimeStampNs() <= node2->TimeStampNs()) { + return true; + } + return false; + }); + } + + for (auto it = thread2op_supplement_event_nodes.begin(); + it != thread2op_supplement_event_nodes.end(); ++it) { + std::sort(it->second.begin(), it->second.end(), + [](OperatorSupplementEventNode* node1, + OperatorSupplementEventNode* node2) { + if (node1->TimeStampNs() <= node2->TimeStampNs()) { + return true; + } + return false; + }); + } // construct trees std::set thread_set; @@ -131,16 +176,27 @@ void NodeTrees::BuildTrees( it != thread2runtime_event_nodes.end(); ++it) { thread_set.insert(it->first); } + for (auto it = thread2mem_event_nodes.begin(); + it != thread2mem_event_nodes.end(); ++it) { + thread_set.insert(it->first); + } + for (auto it = thread2op_supplement_event_nodes.begin(); + it != thread2op_supplement_event_nodes.end(); ++it) { + thread_set.insert(it->first); + } for (auto it = thread_set.begin(); it != thread_set.end(); ++it) { thread_event_trees_map_[*it] = BuildTreeRelationship( - thread2host_event_nodes[*it], thread2runtime_event_nodes[*it]); + thread2host_event_nodes[*it], thread2runtime_event_nodes[*it], + thread2mem_event_nodes[*it], thread2op_supplement_event_nodes[*it]); } } HostTraceEventNode* NodeTrees::BuildTreeRelationship( std::vector host_event_nodes, - std::vector runtime_event_nodes) { + std::vector runtime_event_nodes, + std::vector mem_event_nodes, + std::vector op_supplement_events) { // a stack used for analyse relationship auto node_stack = std::vector(); // root node, top level @@ -226,6 +282,99 @@ HostTraceEventNode* NodeTrees::BuildTreeRelationship( } node_stack.pop_back(); } + + // build relationship between host event node and mem event node + // First, post-order traverse the tree. Then, insert the memory and op + // supplement node into correct host nodes. + auto stack = std::stack(); + auto flag_stack = std::stack(); + auto post_order_nodes = std::vector(); + stack.push(root_node); + flag_stack.push(0); + while (!stack.empty()) { + auto current_node = stack.top(); + stack.pop(); + auto flag = flag_stack.top(); + flag_stack.pop(); + if (flag == 0) { + stack.push(current_node); + flag_stack.push(1); + for (auto child = current_node->GetChildren().rbegin(); + child != current_node->GetChildren().rend(); ++child) { + stack.push(*child); + flag_stack.push(0); + } + } else { + post_order_nodes.push_back(current_node); + } + } + + for (auto it = post_order_nodes.begin(); it < post_order_nodes.end(); ++it) { + bool hasenter = false; + std::vector::iterator firstposition; + std::vector::iterator lastposition = + mem_event_nodes.end(); + for (auto mem_it = mem_event_nodes.begin(); mem_it < mem_event_nodes.end(); + ++mem_it) { + if ((*mem_it)->TimeStampNs() >= (*it)->StartNs() && + (*mem_it)->TimeStampNs() <= (*it)->EndNs()) { + (*it)->AddMemNode(*mem_it); + if (!hasenter) { + firstposition = mem_it; + hasenter = true; + } + } else { + if ((*mem_it)->TimeStampNs() > (*it)->EndNs()) { + lastposition = mem_it; + break; + } + } + } + if (hasenter) { + mem_event_nodes.erase(firstposition, lastposition); + } + } + + // build relationship between host event node and op supplement node + for (auto it = post_order_nodes.begin(); it < post_order_nodes.end(); ++it) { + int op_supplement_count = 0; + bool hasenter = false; + std::vector::iterator firstposition; + std::vector::iterator lastposition = + op_supplement_events.end(); + for (auto op_supplement_it = op_supplement_events.begin(); + op_supplement_it < op_supplement_events.end(); ++op_supplement_it) { + if ((*op_supplement_it)->TimeStampNs() >= (*it)->StartNs() && + (*op_supplement_it)->TimeStampNs() <= (*it)->EndNs()) { + if (!hasenter) { + firstposition = op_supplement_it; + hasenter = true; + } + (*it)->SetOperatorSupplementNode(*op_supplement_it); + PADDLE_ENFORCE_EQ((*it)->Type(), TracerEventType::Operator, + platform::errors::PreconditionNotMet( + "Operator supplement events should be embraced " + "by event of type TracerEventType::Operator, " + "but got type TracerEventType::%s", + StringTracerEventType((*it)->Type()))); + op_supplement_count += 1; + } else { + if ((*op_supplement_it)->TimeStampNs() > (*it)->EndNs()) { + PADDLE_ENFORCE_LE(op_supplement_count, 1, + platform::errors::PreconditionNotMet( + "One event of TracerEventType::Operator has no " + "more than 1 op supplement event, but got %d.", + op_supplement_count)); + lastposition = op_supplement_it; + break; + } + } + } + if (hasenter) { + op_supplement_events.erase(firstposition, lastposition); + } + } + return root_node; } @@ -263,8 +412,8 @@ std::map> NodeTrees::Traverse( auto current_node = stack.top(); stack.pop(); thread2host_event_nodes[thread_id].push_back(current_node); - for (auto child = current_node->GetChildren().begin(); - child != current_node->GetChildren().end(); ++child) { + for (auto child = current_node->GetChildren().rbegin(); + child != current_node->GetChildren().rend(); ++child) { stack.push(*child); } } @@ -278,7 +427,10 @@ void NodeTrees::LogMe(BaseLogger* logger) { logger->LogNodeTrees(*this); } void NodeTrees::HandleTrees( std::function host_event_node_handle, std::function runtime_event_node_handle, - std::function device_event_node_handle) { + std::function device_event_node_handle, + std::function mem_event_node_handle, + std::function + op_supplement_node_handle) { // using different user-defined function to handle different nodes const std::map> thread2host_event_nodes = Traverse(true); @@ -300,6 +452,15 @@ void NodeTrees::HandleTrees( device_event_node_handle(*devicenode); } } + for (auto memeventnode = (*hostnode)->GetMemTraceEventNodes().begin(); + memeventnode != (*hostnode)->GetMemTraceEventNodes().end(); + ++memeventnode) { + mem_event_node_handle(*memeventnode); + } + if ((*hostnode)->GetOperatorSupplementEventNode()) { + op_supplement_node_handle( + (*hostnode)->GetOperatorSupplementEventNode()); + } } } } diff --git a/paddle/fluid/platform/profiler/event_node.h b/paddle/fluid/platform/profiler/event_node.h index 3e589b0be2e04..acd5a03109f72 100644 --- a/paddle/fluid/platform/profiler/event_node.h +++ b/paddle/fluid/platform/profiler/event_node.h @@ -21,12 +21,67 @@ limitations under the License. */ #include #include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/profiler/output_logger.h" #include "paddle/fluid/platform/profiler/trace_event.h" namespace paddle { namespace platform { +class MemTraceEventNode { + public: + // constructor + explicit MemTraceEventNode(const MemTraceEvent& mem_event) + : mem_event_(mem_event) {} + + // destructor + ~MemTraceEventNode(); + + // getter + TracerMemEventType Type() const { return mem_event_.type; } + uint64_t Addr() const { return mem_event_.addr; } + uint64_t TimeStampNs() const { return mem_event_.timestamp_ns; } + uint64_t ProcessId() const { return mem_event_.process_id; } + uint64_t ThreadId() const { return mem_event_.thread_id; } + int64_t IncreaseBytes() const { return mem_event_.increase_bytes; } + std::string Place() const { return mem_event_.place; } + uint64_t CurrentAllocated() const { return mem_event_.current_allocated; } + uint64_t CurrentReserved() const { return mem_event_.current_reserved; } + + // member function + void LogMe(BaseLogger* logger) { logger->LogMemTraceEventNode(*this); } + + private: + // data + MemTraceEvent mem_event_; +}; + +class OperatorSupplementEventNode { + public: + // constructor + explicit OperatorSupplementEventNode( + const OperatorSupplementEvent& op_supplement_event) + : op_supplement_event_(op_supplement_event) {} + // destructor + ~OperatorSupplementEventNode() {} + // getter + std::string Name() const { return op_supplement_event_.op_type; } + uint64_t TimeStampNs() const { return op_supplement_event_.timestamp_ns; } + std::map>>& InputShapes() { + return op_supplement_event_.input_shapes; + } + std::map>& Dtypes() { + return op_supplement_event_.dtypes; + } + std::string CallStack() { return op_supplement_event_.callstack; } + uint64_t ProcessId() const { return op_supplement_event_.process_id; } + uint64_t ThreadId() const { return op_supplement_event_.thread_id; } + + private: + // data + OperatorSupplementEvent op_supplement_event_; +}; + class DeviceTraceEventNode { public: // constructor @@ -139,6 +194,10 @@ class HostTraceEventNode { void AddCudaRuntimeNode(CudaRuntimeTraceEventNode* node) { runtime_node_ptrs_.push_back(node); } + void AddMemNode(MemTraceEventNode* node) { mem_node_ptrs_.push_back(node); } + void SetOperatorSupplementNode(OperatorSupplementEventNode* node) { + op_supplement_node_ptr_ = node; + } const std::vector& GetChildren() const { return children_; } @@ -146,6 +205,14 @@ class HostTraceEventNode { const { return runtime_node_ptrs_; } + const std::vector& GetMemTraceEventNodes() const { + return mem_node_ptrs_; + } + + OperatorSupplementEventNode* GetOperatorSupplementEventNode() const { + return op_supplement_node_ptr_; + } + void LogMe(BaseLogger* logger) { logger->LogHostTraceEventNode(*this); } private: @@ -155,6 +222,9 @@ class HostTraceEventNode { std::vector runtime_node_ptrs_; // host events called by this std::vector children_; + // memory events happened in this event period + std::vector mem_node_ptrs_; + OperatorSupplementEventNode* op_supplement_node_ptr_ = nullptr; }; class NodeTrees { @@ -162,10 +232,14 @@ class NodeTrees { // constructor NodeTrees(const std::list& host_events, const std::list& runtime_events, - const std::list& device_events) { + const std::list& device_events, + const std::list& mem_events, + const std::list& op_supplement_events) { std::vector host_event_nodes; std::vector runtime_event_nodes; std::vector device_event_nodes; + std::vector mem_event_nodes; + std::vector op_supplement_event_nodes; // encapsulate event into nodes for (auto it = host_events.begin(); it != host_events.end(); ++it) { host_event_nodes.push_back(new HostTraceEventNode(*it)); @@ -176,8 +250,16 @@ class NodeTrees { for (auto it = device_events.begin(); it != device_events.end(); ++it) { device_event_nodes.push_back(new DeviceTraceEventNode(*it)); } + for (auto it = mem_events.begin(); it != mem_events.end(); ++it) { + mem_event_nodes.push_back(new MemTraceEventNode(*it)); + } + for (auto it = op_supplement_events.begin(); + it != op_supplement_events.end(); ++it) { + op_supplement_event_nodes.push_back(new OperatorSupplementEventNode(*it)); + } // build tree - BuildTrees(host_event_nodes, runtime_event_nodes, device_event_nodes); + BuildTrees(host_event_nodes, runtime_event_nodes, device_event_nodes, + mem_event_nodes, op_supplement_event_nodes); } explicit NodeTrees( @@ -190,7 +272,9 @@ class NodeTrees { void LogMe(BaseLogger* logger); void HandleTrees(std::function, std::function, - std::function); + std::function, + std::function, + std::function); const std::map& GetNodeTrees() const { return thread_event_trees_map_; } @@ -199,11 +283,15 @@ class NodeTrees { private: std::map thread_event_trees_map_; void BuildTrees(const std::vector&, - std::vector&, - const std::vector&); + const std::vector&, + const std::vector&, + const std::vector&, + const std::vector&); HostTraceEventNode* BuildTreeRelationship( std::vector host_event_nodes, - std::vector runtime_event_nodes); + std::vector runtime_event_nodes, + std::vector mem_event_nodes, + std::vector op_supplement_event_nodes); }; } // namespace platform diff --git a/paddle/fluid/platform/profiler/output_logger.h b/paddle/fluid/platform/profiler/output_logger.h index 05a68cf2a4a8d..47429eafa64ef 100644 --- a/paddle/fluid/platform/profiler/output_logger.h +++ b/paddle/fluid/platform/profiler/output_logger.h @@ -24,6 +24,7 @@ class DeviceTraceEventNode; // forward declaration class HostTraceEventNode; // forward declaration class CudaRuntimeTraceEventNode; // forward declaration class NodeTrees; // forward declaration +class MemTraceEventNode; // forward declaration class BaseLogger { public: @@ -33,6 +34,7 @@ class BaseLogger { virtual void LogHostTraceEventNode(const HostTraceEventNode&) {} virtual void LogRuntimeTraceEventNode(const CudaRuntimeTraceEventNode&) {} virtual void LogNodeTrees(const NodeTrees&) {} + virtual void LogMemTraceEventNode(const MemTraceEventNode&) {} }; } // namespace platform diff --git a/paddle/fluid/platform/profiler/profiler.cc b/paddle/fluid/platform/profiler/profiler.cc index 8bcf856c01ab6..8e9d8bef605e6 100644 --- a/paddle/fluid/platform/profiler/profiler.cc +++ b/paddle/fluid/platform/profiler/profiler.cc @@ -101,9 +101,10 @@ std::unique_ptr Profiler::Stop() { tracer.Get().StopTracing(); tracer.Get().CollectTraceData(&collector); } - std::unique_ptr tree(new NodeTrees(collector.HostEvents(), - collector.RuntimeEvents(), - collector.DeviceEvents())); + std::unique_ptr tree( + new NodeTrees(collector.HostEvents(), collector.RuntimeEvents(), + collector.DeviceEvents(), collector.MemEvents(), + collector.OperatorSupplementEvents())); cpu_utilization_.RecordEndTimeInfo(); ExtraInfo extrainfo; extrainfo.AddExtraInfo(std::string("System Cpu Utilization"), diff --git a/paddle/fluid/platform/profiler/test_event_node.cc b/paddle/fluid/platform/profiler/test_event_node.cc index 23ad917b57d0e..b70034633ae66 100644 --- a/paddle/fluid/platform/profiler/test_event_node.cc +++ b/paddle/fluid/platform/profiler/test_event_node.cc @@ -25,13 +25,20 @@ using paddle::platform::HostTraceEventNode; using paddle::platform::KernelEventInfo; using paddle::platform::MemcpyEventInfo; using paddle::platform::MemsetEventInfo; +using paddle::platform::MemTraceEvent; +using paddle::platform::MemTraceEventNode; using paddle::platform::NodeTrees; +using paddle::platform::OperatorSupplementEvent; +using paddle::platform::OperatorSupplementEventNode; using paddle::platform::RuntimeTraceEvent; using paddle::platform::TracerEventType; +using paddle::platform::TracerMemEventType; TEST(NodeTreesTest, LogMe_case0) { std::list host_events; std::list runtime_events; std::list device_events; + std::list mem_events; + std::list op_supplement_events; host_events.push_back(HostTraceEvent(std::string("dataloader#1"), TracerEventType::Dataloader, 1000, 10000, 10, 10)); @@ -41,6 +48,19 @@ TEST(NodeTreesTest, LogMe_case0) { std::string("op2"), TracerEventType::Operator, 21000, 30000, 10, 10)); host_events.push_back(HostTraceEvent( std::string("op3"), TracerEventType::Operator, 31000, 40000, 10, 11)); + mem_events.push_back(MemTraceEvent(11500, 0x1000, + TracerMemEventType::Allocate, 10, 10, 50, + "GPU:0", 50, 50)); + mem_events.push_back(MemTraceEvent(11900, 0x1000, TracerMemEventType::Free, + 10, 10, -50, "GPU:0", 0, 50)); + std::map>> input_shapes; + std::map> dtypes; + input_shapes[std::string("X")].push_back(std::vector{1, 2, 3}); + input_shapes[std::string("X")].push_back(std::vector{4, 5, 6, 7}); + dtypes[std::string("X")].push_back(std::string("int8")); + dtypes[std::string("X")].push_back(std::string("float32")); + op_supplement_events.push_back(OperatorSupplementEvent( + 11600, "op1", input_shapes, dtypes, "op1()", 10, 10)); runtime_events.push_back(RuntimeTraceEvent(std::string("cudalaunch1"), 15000, 17000, 10, 10, 1, 0)); runtime_events.push_back(RuntimeTraceEvent(std::string("cudalaunch2"), 25000, @@ -67,7 +87,8 @@ TEST(NodeTreesTest, LogMe_case0) { DeviceTraceEvent(std::string("memset1"), TracerEventType::Memset, 66000, 69000, 0, 10, 11, 5, MemsetEventInfo())); ChromeTracingLogger logger("test_nodetrees_logme_case0.json"); - NodeTrees tree(host_events, runtime_events, device_events); + NodeTrees tree(host_events, runtime_events, device_events, mem_events, + op_supplement_events); std::map> nodes = tree.Traverse(true); EXPECT_EQ(nodes[10].size(), 4u); @@ -81,6 +102,8 @@ TEST(NodeTreesTest, LogMe_case0) { if ((*it)->Name() == "op1") { EXPECT_EQ((*it)->GetChildren().size(), 0u); EXPECT_EQ((*it)->GetRuntimeTraceEventNodes().size(), 2u); + EXPECT_EQ((*it)->GetMemTraceEventNodes().size(), 2u); + EXPECT_NE((*it)->GetOperatorSupplementEventNode(), nullptr); } } for (auto it = thread2_nodes.begin(); it != thread2_nodes.end(); it++) { @@ -90,12 +113,15 @@ TEST(NodeTreesTest, LogMe_case0) { } } tree.LogMe(&logger); + logger.LogMetaInfo(std::unordered_map()); } TEST(NodeTreesTest, LogMe_case1) { std::list host_events; std::list runtime_events; std::list device_events; + std::list mem_events; + std::list op_supplement_events; runtime_events.push_back(RuntimeTraceEvent(std::string("cudalaunch1"), 15000, 17000, 10, 10, 1, 0)); runtime_events.push_back(RuntimeTraceEvent(std::string("cudalaunch2"), 25000, @@ -122,7 +148,8 @@ TEST(NodeTreesTest, LogMe_case1) { DeviceTraceEvent(std::string("memset1"), TracerEventType::Memset, 66000, 69000, 0, 10, 11, 5, MemsetEventInfo())); ChromeTracingLogger logger("test_nodetrees_logme_case1.json"); - NodeTrees tree(host_events, runtime_events, device_events); + NodeTrees tree(host_events, runtime_events, device_events, mem_events, + op_supplement_events); std::map> nodes = tree.Traverse(true); EXPECT_EQ(nodes[10].size(), 1u); @@ -141,18 +168,29 @@ TEST(NodeTreesTest, LogMe_case1) { } } tree.LogMe(&logger); + logger.LogMetaInfo(std::unordered_map()); } TEST(NodeTreesTest, HandleTrees_case0) { std::list host_events; std::list runtime_events; std::list device_events; + std::list mem_events; + std::list op_supplement_events; host_events.push_back(HostTraceEvent( std::string("op1"), TracerEventType::Operator, 10000, 100000, 10, 10)); host_events.push_back(HostTraceEvent( std::string("op2"), TracerEventType::Operator, 30000, 70000, 10, 10)); host_events.push_back(HostTraceEvent( std::string("op3"), TracerEventType::Operator, 2000, 120000, 10, 11)); + mem_events.push_back(MemTraceEvent(11500, 0x1000, + TracerMemEventType::Allocate, 10, 10, 50, + "GPU:0", 50, 50)); + mem_events.push_back(MemTraceEvent(11900, 0x1000, TracerMemEventType::Free, + 10, 10, -50, "GPU:0", 0, 50)); + op_supplement_events.push_back(OperatorSupplementEvent( + 11600, "op1", std::map>>(), + std::map>(), "op1()", 10, 10)); runtime_events.push_back(RuntimeTraceEvent(std::string("cudalaunch1"), 15000, 25000, 10, 10, 1, 0)); runtime_events.push_back(RuntimeTraceEvent(std::string("cudalaunch2"), 35000, @@ -169,7 +207,8 @@ TEST(NodeTreesTest, HandleTrees_case0) { DeviceTraceEvent(std::string("kernel3"), TracerEventType::Kernel, 60000, 75000, 0, 10, 11, 3, KernelEventInfo())); ChromeTracingLogger logger("test_nodetrees_handletrees_case0.json"); - NodeTrees tree(host_events, runtime_events, device_events); + NodeTrees tree(host_events, runtime_events, device_events, mem_events, + op_supplement_events); std::map> nodes = tree.Traverse(true); EXPECT_EQ(nodes[10].size(), 3u); @@ -199,6 +238,12 @@ TEST(NodeTreesTest, HandleTrees_case0) { }); std::function device_event_node_handle( [&](DeviceTraceEventNode* a) { logger.LogDeviceTraceEventNode(*a); }); + std::function mem_event_node_handle( + [&](MemTraceEventNode* a) { logger.LogMemTraceEventNode(*a); }); + std::function + op_supplement_event_node_handle([&](OperatorSupplementEventNode* a) {}); tree.HandleTrees(host_event_node_handle, runtime_event_node_handle, - device_event_node_handle); + device_event_node_handle, mem_event_node_handle, + op_supplement_event_node_handle); + logger.LogMetaInfo(std::unordered_map()); } diff --git a/paddle/fluid/platform/profiler/trace_event.h b/paddle/fluid/platform/profiler/trace_event.h index 6d398a26eda10..bfa000e2683de 100644 --- a/paddle/fluid/platform/profiler/trace_event.h +++ b/paddle/fluid/platform/profiler/trace_event.h @@ -14,7 +14,9 @@ limitations under the License. */ #pragma once +#include #include +#include namespace paddle { namespace platform { @@ -56,6 +58,15 @@ enum class TracerEventType { NumTypes }; +enum class TracerMemEventType { + // Used to mark memory allocation + Allocate = 0, + // Used to mark memory free + Free = 1, + // A flag to denote the number of current types + NumTypes +}; + struct KernelEventInfo { // The X-dimension block size for the kernel. uint32_t block_x; @@ -118,6 +129,36 @@ struct MemsetEventInfo { uint32_t value; }; +struct OperatorSupplementEvent { + OperatorSupplementEvent() = default; + OperatorSupplementEvent( + uint64_t timestamp_ns, const std::string& op_type, + const std::map>>& + input_shapes, + const std::map>& dtypes, + const std::string& callstack, uint64_t process_id, uint64_t thread_id) + : timestamp_ns(timestamp_ns), + op_type(op_type), + input_shapes(input_shapes), + dtypes(dtypes), + callstack(callstack), + process_id(process_id), + thread_id(thread_id) {} + // timestamp of the record + uint64_t timestamp_ns; + // op type name + std::string op_type; + // input shapes + std::map>> input_shapes; + std::map> dtypes; + // call stack + std::string callstack; + // process id of the record + uint64_t process_id; + // thread id of the record + uint64_t thread_id; +}; + struct HostTraceEvent { HostTraceEvent() = default; HostTraceEvent(const std::string& name, TracerEventType type, @@ -242,5 +283,42 @@ struct DeviceTraceEvent { }; }; +struct MemTraceEvent { + MemTraceEvent() = default; + MemTraceEvent(uint64_t timestamp_ns, uint64_t addr, TracerMemEventType type, + uint64_t process_id, uint64_t thread_id, int64_t increase_bytes, + const std::string& place, uint64_t current_allocated, + uint64_t current_reserved) + : timestamp_ns(timestamp_ns), + addr(addr), + type(type), + process_id(process_id), + thread_id(thread_id), + increase_bytes(increase_bytes), + place(place), + current_allocated(current_allocated), + current_reserved(current_reserved) {} + + // timestamp of the record + uint64_t timestamp_ns; + // memory addr of allocation or free + uint64_t addr; + // memory manipulation type + TracerMemEventType type; + // process id of the record + uint64_t process_id; + // thread id of the record + uint64_t thread_id; + // increase bytes after this manipulation, allocation for sign +, free for + // sign - + int64_t increase_bytes; + // place + std::string place; + // current total allocated memory + uint64_t current_allocated; + // current total reserved memory + uint64_t current_reserved; +}; + } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/profiler/trace_event_collector.h b/paddle/fluid/platform/profiler/trace_event_collector.h index d1593bc1bfcd7..7c2ea9e16c423 100644 --- a/paddle/fluid/platform/profiler/trace_event_collector.h +++ b/paddle/fluid/platform/profiler/trace_event_collector.h @@ -39,6 +39,12 @@ class TraceEventCollector { thread_names_[tid] = name; } + void AddMemEvent(MemTraceEvent&& event) { mem_events_.push_back(event); } + + void AddOperatorSupplementEvent(OperatorSupplementEvent&& event) { + op_supplement_events_.push_back(event); + } + const std::list& HostEvents() const { return host_events_; } const std::list& RuntimeEvents() const { @@ -49,6 +55,12 @@ class TraceEventCollector { return device_events_; } + const std::list& MemEvents() const { return mem_events_; } + + const std::list& OperatorSupplementEvents() const { + return op_supplement_events_; + } + const std::unordered_map& ThreadNames() const { return thread_names_; } @@ -58,6 +70,8 @@ class TraceEventCollector { host_events_.clear(); runtime_events_.clear(); device_events_.clear(); + mem_events_.clear(); + op_supplement_events_.clear(); } private: @@ -65,6 +79,8 @@ class TraceEventCollector { std::list host_events_; std::list runtime_events_; std::list device_events_; + std::list mem_events_; + std::list op_supplement_events_; }; } // namespace platform diff --git a/paddle/fluid/platform/profiler/utils.cc b/paddle/fluid/platform/profiler/utils.cc index de314d298c90e..1f8e113fdd914 100644 --- a/paddle/fluid/platform/profiler/utils.cc +++ b/paddle/fluid/platform/profiler/utils.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/fluid/platform/profiler/utils.h" +#include #include #include "glog/logging.h" @@ -21,6 +22,26 @@ limitations under the License. */ namespace paddle { namespace platform { + +template <> +std::string json_vector( + const std::vector type_vector) { + std::ostringstream res_stream; + auto count = type_vector.size(); + res_stream << "["; + for (auto it = type_vector.begin(); it != type_vector.end(); it++) { + if (count > 1) { + res_stream << "\"" << (*it) << "\"" + << ","; + } else { + res_stream << "\"" << (*it) << "\""; + } + count--; + } + res_stream << "]"; + return res_stream.str(); +} + #ifdef PADDLE_WITH_CUPTI float CalculateEstOccupancy(uint32_t DeviceId, uint16_t RegistersPerThread, int32_t StaticSharedMemory, @@ -61,5 +82,21 @@ float CalculateEstOccupancy(uint32_t DeviceId, uint16_t RegistersPerThread, } #endif +const char* StringTracerMemEventType(TracerMemEventType type) { + static const char* categary_name_[] = {"Allocate", "Free"}; + return categary_name_[static_cast(type)]; +} + +const char* StringTracerEventType(TracerEventType type) { + static const char* categary_name_[] = { + "Operator", "Dataloader", "ProfileStep", + "CudaRuntime", "Kernel", "Memcpy", + "Memset", "UserDefined", "OperatorInner", + "Forward", "Backward", "Optimization", + "Communication", "PythonOp", "PythonUserDefined", + "MluRuntime"}; + return categary_name_[static_cast(type)]; +} + } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/profiler/utils.h b/paddle/fluid/platform/profiler/utils.h index 433fd0b825a11..5f7c420789f80 100644 --- a/paddle/fluid/platform/profiler/utils.h +++ b/paddle/fluid/platform/profiler/utils.h @@ -14,11 +14,15 @@ limitations under the License. */ #pragma once #include +#include +#include #include +#include #include "paddle/fluid/platform/dynload/cupti.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/os_info.h" +#include "paddle/fluid/platform/profiler/trace_event.h" namespace paddle { namespace platform { @@ -36,6 +40,64 @@ std::string string_format(const std::string& format, Args... args) { return std::string(buf.get(), size - 1); // exclude the '\0' } +template +std::string json_vector(const std::vector type_vector) { + std::ostringstream res_stream; + auto count = type_vector.size(); + res_stream << "["; + for (auto it = type_vector.begin(); it != type_vector.end(); it++) { + if (count > 1) { + res_stream << (*it) << ","; + } else { + res_stream << (*it); + } + count--; + } + res_stream << "]"; + return res_stream.str(); +} + +template +std::string json_vector( + const std::vector> shape_vector) { + std::ostringstream res_stream; + auto count = shape_vector.size(); + res_stream << "["; + for (auto it = shape_vector.begin(); it != shape_vector.end(); it++) { + if (count > 1) { + res_stream << json_vector(*it) << ","; + } else { + res_stream << json_vector(*it); + } + count--; + } + res_stream << "]"; + return res_stream.str(); +} + +template <> +std::string json_vector( + const std::vector type_vector); + +template +std::string json_dict(const std::map> data_map) { + std::ostringstream res_stream; + auto count = data_map.size(); + res_stream << "{"; + for (auto it = data_map.begin(); it != data_map.end(); it++) { + if (count > 1) { + res_stream << "\"" << it->first << "\"" + << ":" << json_vector(it->second) << ","; + } else { + res_stream << "\"" << it->first << "\"" + << ":" << json_vector(it->second); + } + count--; + } + res_stream << "}"; + return res_stream.str(); +} + static std::string GetStringFormatLocalTime() { std::time_t rawtime; std::tm* timeinfo; @@ -50,6 +112,10 @@ static int64_t nsToUs(uint64_t end_ns, uint64_t start_ns = 0) { return (end_ns - start_ns) / 1000; } +const char* StringTracerMemEventType(TracerMemEventType type); + +const char* StringTracerEventType(TracerEventType type); + static float nsToUsFloat(uint64_t end_ns, uint64_t start_ns = 0) { return static_cast(end_ns - start_ns) / 1000; } @@ -63,5 +129,6 @@ float CalculateEstOccupancy(uint32_t deviceId, uint16_t registersPerThread, int32_t dynamicSharedMemory, int32_t blockX, int32_t blockY, int32_t blockZ, float blocksPerSm); #endif + } // namespace platform } // namespace paddle