diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm
index 37fb9dc347d4..42dd249630ff 100644
--- a/src/runtime/metal/metal_device_api.mm
+++ b/src/runtime/metal/metal_device_api.mm
@@ -21,6 +21,7 @@
  * \file metal_device_api.mm
  */
 #include <dmlc/thread_local.h>
+#include <tvm/runtime/profiling.h>
 #include <tvm/runtime/registry.h>
 
 #include "metal_common.h"
@@ -366,6 +367,64 @@ int GetWarpSize(id<MTLDevice> dev) {
   MetalWorkspace::Global()->ReinitializeDefaultStreams();
 });
 
+/*!
+ * \brief Timer node for Metal devices, built on MTLDevice timestamp sampling.
+ *
+ * Start() and Stop() each sample a CPU timestamp and a GPU timestamp from the
+ * Metal device; the elapsed time is the difference of the two GPU samples.
+ */
+class MetalTimerNode : public TimerNode {
+ public:
+  MetalTimerNode() = default;
+  /*!
+   * \brief Create a timer bound to a device.
+   * \param dev The device whose Metal handle is looked up in the global workspace.
+   */
+  explicit MetalTimerNode(Device dev)
+      : dev_(dev), mtl_dev_(MetalWorkspace::Global()->GetDevice(dev)) {}
+
+  /*! \brief Sample the starting CPU/GPU timestamps. */
+  void Start() override {
+    [mtl_dev_ sampleTimestamps:&start_cpu_time_ gpuTimestamp:&start_gpu_time_];
+  }
+  /*! \brief Call StreamSync on the device's current stream, then sample the ending timestamps. */
+  void Stop() override {
+    auto ws = MetalWorkspace::Global();
+    ws->StreamSync(dev_, ws->GetCurrentStream(dev_));
+    [mtl_dev_ sampleTimestamps:&stop_cpu_time_ gpuTimestamp:&stop_gpu_time_];
+  }
+  /*!
+   * \brief Difference between the stop and start GPU timestamps.
+   * \note NOTE(review): the raw GPU timestamp difference is returned as nanoseconds;
+   *       confirm the GPU timestamp unit on all supported devices.
+   */
+  int64_t SyncAndGetElapsedNanos() override {
+    return static_cast<int64_t>(stop_gpu_time_ - start_gpu_time_);
+  }
+
+  static constexpr const char* _type_key = "MetalTimerNode";
+  TVM_DECLARE_FINAL_OBJECT_INFO(MetalTimerNode, TimerNode);
+
+ private:
+  /*! \brief Device whose current stream Stop() passes to StreamSync. */
+  Device dev_{};
+  /*! \brief Metal device handle used for timestamp sampling; nil until bound. */
+  id<MTLDevice> mtl_dev_ = nil;
+
+  // Timestamps are zero-initialized so an unstarted timer reports 0 instead of garbage.
+  MTLTimestamp start_cpu_time_ = 0;
+  MTLTimestamp start_gpu_time_ = 0;
+  MTLTimestamp stop_cpu_time_ = 0;
+  MTLTimestamp stop_gpu_time_ = 0;
+};
+
+TVM_REGISTER_OBJECT_TYPE(MetalTimerNode);
+
+// Register the timer factory under the global name "profiling.timer.metal".
+TVM_REGISTER_GLOBAL("profiling.timer.metal").set_body_typed([](Device dev) {
+  return Timer(make_object<MetalTimerNode>(dev));
+});
+
 }  // namespace metal
 }  // namespace runtime
 }  // namespace tvm