Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Runtime][Vulkan] Add RGP support to TVM for vulkan device #10953

Merged
merged 1 commit into from
Apr 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions src/runtime/vulkan/vulkan_amdrgp.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include "vulkan_device.h"

namespace tvm {
namespace runtime {
namespace vulkan {

VulkanStreamProfiler::VulkanStreamProfiler(const VulkanDevice* device)
: device_(device), curr_state_(READY), available_(device->UseDebugUtilsLabel()) {}

void AmdRgpProfiler::capture() {
if (!available_) {
return;
}

// Trigger RGP capture by using dummy present and switch state from READY to RUNNING
if (curr_state_ == READY) {
VkDebugUtilsLabelEXT frame_end_label = {
VK_STRUCTURE_TYPE_DEBUG_UTILS_LABEL_EXT, NULL, "AmdFrameEnd", {0.0f, 0.0f, 0.0f, 0.0f}};
device_->queue_insert_debug_utils_label_functions->vkQueueInsertDebugUtilsLabelEXT(
device_->Queue(), &frame_end_label);

VkDebugUtilsLabelEXT frame_begin_label = {
VK_STRUCTURE_TYPE_DEBUG_UTILS_LABEL_EXT, NULL, "AmdFrameBegin", {0.0f, 0.0f, 0.0f, 0.0f}};
device_->queue_insert_debug_utils_label_functions->vkQueueInsertDebugUtilsLabelEXT(
device_->Queue(), &frame_begin_label);

// Set state as RUNNING
curr_state_ = RUNNING;
}
}

} // namespace vulkan
} // namespace runtime
} // namespace tvm
63 changes: 63 additions & 0 deletions src/runtime/vulkan/vulkan_amdrgp.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#ifndef TVM_RUNTIME_VULKAN_VULKAN_AMDRGP_H_
#define TVM_RUNTIME_VULKAN_VULKAN_AMDRGP_H_

namespace tvm {
namespace runtime {
namespace vulkan {

class VulkanDevice;

/*!
 * \brief Base class for profilers driven by a Vulkan stream.
 *
 * Keeps a small state machine (READY, RUNNING, RESET) and leaves the actual
 * capture trigger to subclasses through capture().
 */
class VulkanStreamProfiler {
 public:
  /*! \brief States the profiler can be in. */
  enum state {
    READY = 0,
    RUNNING,
    RESET
  };

  /*!
   * \brief Construct a profiler for the given device.
   * \param device The Vulkan device; the pointer is stored in device_.
   */
  explicit VulkanStreamProfiler(const VulkanDevice* device);

  virtual ~VulkanStreamProfiler() = default;

  /*! \brief Unconditionally move the profiler to the RESET state. */
  virtual void reset() { curr_state_ = RESET; }

  /*! \brief Move the profiler to READY, but only when it is currently in RESET. */
  virtual void ready() {
    if (curr_state_ != RESET) {
      return;
    }
    curr_state_ = READY;
  }

  /*! \brief Trigger a capture; must be implemented by subclasses. */
  virtual void capture() = 0;

 protected:
  /*! \brief Device this profiler was constructed with. */
  const VulkanDevice* device_;
  /*! \brief Current state of the profiler. */
  state curr_state_;
  /*! \brief Availability flag; subclasses may consult it before capturing. */
  bool available_;
};

class AmdRgpProfiler : public VulkanStreamProfiler {
public:
explicit AmdRgpProfiler(const VulkanDevice* device) : VulkanStreamProfiler(device) {}

void capture();
};

} // namespace vulkan
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_VULKAN_VULKAN_AMDRGP_H_
13 changes: 13 additions & 0 deletions src/runtime/vulkan/vulkan_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,12 @@ VulkanGetBufferMemoryRequirements2Functions::VulkanGetBufferMemoryRequirements2F
vkGetDeviceProcAddr(device, "vkGetBufferMemoryRequirements2KHR"));
}

/*!
 * \brief Look up vkQueueInsertDebugUtilsLabelEXT on the given instance.
 *
 * The address is fetched with vkGetInstanceProcAddr and passed through
 * ICHECK_NOTNULL before being stored in the member function pointer.
 *
 * \param instance The Vulkan instance used for the lookup.
 */
VulkanQueueInsertDebugUtilsLabelFunctions::VulkanQueueInsertDebugUtilsLabelFunctions(
    VkInstance instance) {
  auto proc_addr =
      ICHECK_NOTNULL(vkGetInstanceProcAddr(instance, "vkQueueInsertDebugUtilsLabelEXT"));
  // The loader returns a generic function pointer; convert it to the typed entry point.
  vkQueueInsertDebugUtilsLabelEXT =
      reinterpret_cast<PFN_vkQueueInsertDebugUtilsLabelEXT>(proc_addr);
}

VulkanDevice::VulkanDevice(const VulkanInstance& instance, VkPhysicalDevice phy_device)
: physical_device_(phy_device) {
queue_family_index = SelectComputeQueueFamily();
Expand Down Expand Up @@ -325,6 +331,11 @@ VulkanDevice::VulkanDevice(const VulkanInstance& instance, VkPhysicalDevice phy_
get_buffer_memory_requirements_2_functions =
std::make_unique<VulkanGetBufferMemoryRequirements2Functions>(device_);
}

if (instance.HasExtension("VK_EXT_debug_utils")) {
queue_insert_debug_utils_label_functions =
std::make_unique<VulkanQueueInsertDebugUtilsLabelFunctions>(instance);
}
}

VulkanDevice::~VulkanDevice() {
Expand Down Expand Up @@ -363,6 +374,8 @@ void VulkanDevice::do_swap(VulkanDevice&& other) {
std::swap(descriptor_template_khr_functions, other.descriptor_template_khr_functions);
std::swap(get_buffer_memory_requirements_2_functions,
other.get_buffer_memory_requirements_2_functions);
std::swap(queue_insert_debug_utils_label_functions,
other.queue_insert_debug_utils_label_functions);
std::swap(compute_mtype_index, other.compute_mtype_index);
std::swap(queue, other.queue);
std::swap(queue_family_index, other.queue_family_index);
Expand Down
12 changes: 12 additions & 0 deletions src/runtime/vulkan/vulkan_device.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@ struct VulkanGetBufferMemoryRequirements2Functions {
PFN_vkGetBufferMemoryRequirements2KHR vkGetBufferMemoryRequirements2KHR{nullptr};
};

/*!
 * \brief Holder for the vkQueueInsertDebugUtilsLabelEXT function pointer.
 */
struct VulkanQueueInsertDebugUtilsLabelFunctions {
  /*!
   * \brief Construct the holder from a Vulkan instance.
   * \param instance The Vulkan instance handle.
   */
  explicit VulkanQueueInsertDebugUtilsLabelFunctions(VkInstance instance);

  /*! \brief Entry point for inserting a debug label into a queue; null by default. */
  PFN_vkQueueInsertDebugUtilsLabelEXT vkQueueInsertDebugUtilsLabelEXT = nullptr;
};

/*!
* \brief Stores the capabilities/limits queried from the physical device.
*
Expand Down Expand Up @@ -212,6 +218,8 @@ class VulkanDevice {
std::unique_ptr<VulkanDescriptorTemplateKHRFunctions> descriptor_template_khr_functions{nullptr};
std::unique_ptr<VulkanGetBufferMemoryRequirements2Functions>
get_buffer_memory_requirements_2_functions{nullptr};
std::unique_ptr<VulkanQueueInsertDebugUtilsLabelFunctions>
queue_insert_debug_utils_label_functions{nullptr};
// Memory type index for compute
uint32_t compute_mtype_index{0};

Expand All @@ -220,6 +228,10 @@ class VulkanDevice {

bool UseImmediate() const { return descriptor_template_khr_functions != nullptr; }

/*! \brief Whether the debug-utils label function table has been created for this device. */
bool UseDebugUtilsLabel() const {
  return static_cast<bool>(queue_insert_debug_utils_label_functions);
}

/*! \brief Return the queue handle held by this device. */
VkQueue Queue() const {
  return queue;
}

private:
/*! \brief Helper function for move assignment/construction
*
Expand Down
3 changes: 3 additions & 0 deletions src/runtime/vulkan/vulkan_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ void VulkanDeviceAPI::CopyDataFromTo(const void* from, size_t from_offset, void*
&copy_info);
});
stream.Synchronize();
stream.ProfilerReset();
if (!device.coherent_staging) {
VkMappedMemoryRange mrange;
mrange.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE;
Expand Down Expand Up @@ -413,6 +414,8 @@ void VulkanDeviceAPI::CopyDataFromTo(const void* from, size_t from_offset, void*
vkCmdCopyBuffer(state->cmd_buffer_, staging_buffer.vk_buf.buffer, to_buf->buffer, 1,
&copy_info);
});

stream.ProfilerReady();
// TODO(tulloch): should we instead make the staging buffer a property of the
// Stream? This would allow us to elide synchronizations here.
stream.Synchronize();
Expand Down
7 changes: 7 additions & 0 deletions src/runtime/vulkan/vulkan_instance.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,13 @@ VulkanInstance::VulkanInstance() {
std::vector<const char*> required_extensions{};
std::vector<const char*> optional_extensions{"VK_KHR_get_physical_device_properties2"};

// Check if RGP support is needed. If needed, enable VK_EXT_debug_utils extension for
// inserting debug labels into the queue.
if (support::BoolEnvironmentVar("TVM_USE_AMD_RGP")) {
LOG(INFO) << "Push VK_EXT_debug_utils";
required_extensions.push_back("VK_EXT_debug_utils");
}

uint32_t inst_extension_prop_count;
VULKAN_CALL(
vkEnumerateInstanceExtensionProperties(nullptr, &inst_extension_prop_count, nullptr));
Expand Down
13 changes: 13 additions & 0 deletions src/runtime/vulkan/vulkan_stream.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include "vulkan_stream.h"

#include "../../support/utils.h"
#include "vulkan_device.h"

namespace tvm {
Expand Down Expand Up @@ -55,11 +56,19 @@ VulkanStream::VulkanStream(const VulkanDevice* device)
cb_begin.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
cb_begin.pInheritanceInfo = 0;
VULKAN_CALL(vkBeginCommandBuffer(state_->cmd_buffer_, &cb_begin));

if (support::BoolEnvironmentVar("TVM_USE_AMD_RGP")) {
profiler_ = new AmdRgpProfiler(device_);
}
}

/*!
 * \brief Destroy the stream's fence and command pool, then free the profiler.
 */
VulkanStream::~VulkanStream() {
  // Vulkan objects go first: the fence held by state_, then the command pool.
  vkDestroyFence(*device_, state_->fence_, nullptr);
  vkDestroyCommandPool(*device_, cmd_pool_, nullptr);

  // Deleting a null pointer is a no-op, so no explicit check is needed when
  // no profiler was ever allocated.
  delete profiler_;
}

void VulkanStream::Launch(const std::function<void(VulkanStreamState*)>& kernel) {
Expand Down Expand Up @@ -132,6 +141,10 @@ void VulkanStream::Synchronize() {
cb_submit.signalSemaphoreCount = 0;
cb_submit.pSignalSemaphores = nullptr;

if (profiler_) {
profiler_->capture();
}

device_->QueueSubmit(cb_submit, state_->fence_);

uint64_t timeout = 1UL << 30UL;
Expand Down
16 changes: 16 additions & 0 deletions src/runtime/vulkan/vulkan_stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <unordered_map>
#include <vector>

#include "vulkan_amdrgp.h"
#include "vulkan_common.h"

namespace tvm {
Expand Down Expand Up @@ -99,6 +100,20 @@ class VulkanStream {
const std::function<void(VulkanStreamState*)>& deferred_kernel,
const VulkanStreamToken& deferred_token);

/*! \brief Reset the profiler state; a no-op when no profiler is attached. */
void ProfilerReset() {
  if (profiler_ == nullptr) {
    return;
  }
  profiler_->reset();
}

/*!
 * \brief Ask the profiler to return to the READY state after a reset;
 *        a no-op when no profiler is attached.
 */
void ProfilerReady() {
  if (profiler_ == nullptr) {
    return;
  }
  profiler_->ready();
}

// Synchronize the current stream `state_` with respect to the host.
void Synchronize();

Expand All @@ -110,6 +125,7 @@ class VulkanStream {
std::unordered_map<VkDescriptorSet, std::vector<VulkanStreamToken>> deferred_tokens_;
std::vector<std::function<void(VulkanStreamState*)>> deferred_kernels_;
VkCommandPool cmd_pool_;
VulkanStreamProfiler* profiler_ = nullptr;
};

} // namespace vulkan
Expand Down
18 changes: 18 additions & 0 deletions src/runtime/vulkan/vulkan_wrapped_func.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,15 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv,
vkCmdPipelineBarrier(state->cmd_buffer_, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0,
1, &barrier_info, 0, nullptr, 0, nullptr);

if (device.UseDebugUtilsLabel()) {
VkDebugUtilsLabelEXT dispatch_label = {VK_STRUCTURE_TYPE_DEBUG_UTILS_LABEL_EXT,
NULL,
func_name_.c_str(),
{0.0f, 0.0f, 0.0f, 0.0f}};
device.queue_insert_debug_utils_label_functions->vkQueueInsertDebugUtilsLabelEXT(
device.Queue(), &dispatch_label);
}
});
return;
}
Expand Down Expand Up @@ -164,6 +173,15 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv,
deferred_token.buffers_[i] = descriptor_buffers[i].buffer;
}
device.ThreadLocalStream().LaunchDeferred(deferred_initializer, deferred_kernel, deferred_token);

if (device.UseDebugUtilsLabel()) {
VkDebugUtilsLabelEXT dispatch_label = {VK_STRUCTURE_TYPE_DEBUG_UTILS_LABEL_EXT,
NULL,
func_name_.c_str(),
{0.0f, 0.0f, 0.0f, 0.0f}};
device.queue_insert_debug_utils_label_functions->vkQueueInsertDebugUtilsLabelEXT(
device.Queue(), &dispatch_label);
}
masahi marked this conversation as resolved.
Show resolved Hide resolved
}

VulkanModuleNode::~VulkanModuleNode() {
Expand Down