From e40ab55196e548d18d2992df61173c307d7f6741 Mon Sep 17 00:00:00 2001 From: Ashkan Aliabadi Date: Tue, 13 Oct 2020 16:02:24 -0700 Subject: [PATCH] Add fence. ghstack-source-id: 7c0b55debf2d4e4db4d0335e7bbc4f087e8d7365 Pull Request resolved: https://github.com/pytorch/pytorch/pull/45148 --- aten/src/ATen/native/vulkan/api/Command.cpp | 4 +- aten/src/ATen/native/vulkan/api/Command.h | 2 +- aten/src/ATen/native/vulkan/api/Resource.cpp | 120 ++++++++++++++++--- aten/src/ATen/native/vulkan/api/Resource.h | 60 +++++++++- 4 files changed, 163 insertions(+), 23 deletions(-) diff --git a/aten/src/ATen/native/vulkan/api/Command.cpp b/aten/src/ATen/native/vulkan/api/Command.cpp index 61a6ecd02c..4461240ba1 100644 --- a/aten/src/ATen/native/vulkan/api/Command.cpp +++ b/aten/src/ATen/native/vulkan/api/Command.cpp @@ -257,7 +257,7 @@ void Command::Buffer::dispatch( void Command::Buffer::submit( const VkQueue queue, - const VkFence fence) { + const Resource::Fence fence) { TORCH_INTERNAL_ASSERT_DEBUG_ONLY( command_buffer_, "This command buffer is in an invalid state! " @@ -279,7 +279,7 @@ void Command::Buffer::submit( nullptr, }; - VK_CHECK(vkQueueSubmit(queue, 1u, &submit_info, fence)); + VK_CHECK(vkQueueSubmit(queue, 1u, &submit_info, fence.handle())); } Command::Pool::Pool(const GPU& gpu) diff --git a/aten/src/ATen/native/vulkan/api/Command.h b/aten/src/ATen/native/vulkan/api/Command.h index a3c646c15d..aaec2df259 100644 --- a/aten/src/ATen/native/vulkan/api/Command.h +++ b/aten/src/ATen/native/vulkan/api/Command.h @@ -32,7 +32,7 @@ struct Command final { void bind(const Descriptor::Set& set); void copy(Resource::Buffer::Object source, Resource::Buffer::Object destination); void dispatch(const Shader::WorkGroup& work_group); - void submit(VkQueue queue, VkFence fence); + void submit(VkQueue queue, Resource::Fence fence = {}); private: VkCommandBuffer command_buffer_; diff --git a/aten/src/ATen/native/vulkan/api/Resource.cpp b/aten/src/ATen/native/vulkan/api/Resource.cpp index 206433c1fa..436288645e 100644 --- a/aten/src/ATen/native/vulkan/api/Resource.cpp +++ b/aten/src/ATen/native/vulkan/api/Resource.cpp @@ -177,6 +177,31 @@ Resource::Image::Sampler::Factory::operator()( }; } +void Resource::Fence::wait(const uint64_t timeout_nanoseconds) { + const VkFence fence = handle(/* add_to_waitlist = */ false); + + const auto waitlist_itr = std::find( + pool->fence_.waitlist.cbegin(), + pool->fence_.waitlist.cend(), + fence); + + if (pool->fence_.waitlist.cend() != waitlist_itr) { + VK_CHECK(vkWaitForFences( + pool->device_, + 1u, + &fence, + VK_TRUE, + timeout_nanoseconds)); + + VK_CHECK(vkResetFences( + pool->device_, + 1u, + &fence)); + + pool->fence_.waitlist.erase(waitlist_itr); + } +} + Resource::Pool::Pool(const GPU& gpu) : device_(gpu.device), allocator_( @@ -185,17 +210,30 @@ Resource::Pool::Pool(const GPU& gpu) gpu.adapter->handle, device_), vmaDestroyAllocator), - sampler_(gpu) { - buffers_.reserve(Configuration::kReserve); - images_.reserve(Configuration::kReserve); + buffer_{}, + image_{ + .sampler = Image::Sampler{gpu}, + }, + fence_{} { + buffer_.pool.reserve(Configuration::kReserve); + image_.pool.reserve(Configuration::kReserve); + fence_.pool.reserve(Configuration::kReserve); +} + +Resource::Pool::~Pool() { + try { + purge(); + } + catch (...) { + } } Resource::Pool::Pool(Pool&& pool) : device_(std::move(pool.device_)), allocator_(std::move(pool.allocator_)), - buffers_(std::move(pool.buffers_)), - images_(std::move(pool.images_)), - sampler_(std::move(pool.sampler_)) { + buffer_(std::move(pool.buffer_)), + image_(std::move(pool.image_)), + fence_(std::move(pool.fence_)) { pool.device_ = VK_NULL_HANDLE; } @@ -203,9 +241,9 @@ Resource::Pool& Resource::Pool::operator=(Pool&& pool) { if (&pool != this) { device_ = std::move(pool.device_); allocator_ = std::move(pool.allocator_); - buffers_ = std::move(pool.buffers_); - images_ = std::move(pool.images_); - sampler_ = std::move(pool.sampler_); + buffer_ = std::move(pool.buffer_); + image_ = std::move(pool.image_); + fence_ = std::move(pool.fence_); pool.device_ = VK_NULL_HANDLE; }; @@ -249,7 +287,7 @@ Resource::Buffer Resource::Pool::buffer( TORCH_CHECK(buffer, "Invalid Vulkan buffer!"); TORCH_CHECK(allocation, "Invalid VMA allocation!"); - buffers_.emplace_back( + buffer_.pool.emplace_back( Buffer{ Buffer::Object{ buffer, @@ -263,7 +301,7 @@ Resource::Buffer Resource::Pool::buffer( }, &release_buffer); - return buffers_.back().get(); + return buffer_.pool.back().get(); } Resource::Image Resource::Pool::image( @@ -342,13 +380,13 @@ Resource::Image Resource::Pool::image( view, "Invalid Vulkan image view!"); - images_.emplace_back( + image_.pool.emplace_back( Image{ Image::Object{ image, VK_IMAGE_LAYOUT_UNDEFINED, view, - sampler_.cache.retrieve(descriptor.sampler), + image_.sampler.cache.retrieve(descriptor.sampler), }, Memory{ allocator_.get(), @@ -357,7 +395,40 @@ Resource::Image Resource::Pool::image( }, &release_image); - return images_.back().get(); + return image_.pool.back().get(); +} + +Resource::Fence Resource::Pool::fence() { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + device_ && allocator_, + "This resource pool is in an invalid state! ", + "Potential reason: This resource pool is moved from."); + + if (fence_.pool.size() == fence_.in_use) { + const VkFenceCreateInfo fence_create_info{ + VK_STRUCTURE_TYPE_FENCE_CREATE_INFO, + nullptr, + 0u, + }; + + VkFence fence{}; + VK_CHECK(vkCreateFence( + device_, + &fence_create_info, + nullptr, + &fence)); + + TORCH_CHECK( + fence, + "Invalid Vulkan fence!"); + + fence_.pool.emplace_back(fence, VK_DELETER(Fence)(device_)); + } + + return Fence{ + this, + fence_.in_use++, + }; } void Resource::Pool::purge() { @@ -366,8 +437,25 @@ void Resource::Pool::purge() { "This resource pool is in an invalid state! ", "Potential reason: This resource pool is moved from."); - images_.clear(); - buffers_.clear(); + if (!fence_.waitlist.empty()) { + VK_CHECK(vkWaitForFences( + device_, + fence_.waitlist.size(), + fence_.waitlist.data(), + VK_TRUE, + UINT64_MAX)); + + VK_CHECK(vkResetFences( + device_, + fence_.waitlist.size(), + fence_.waitlist.data())); + + fence_.waitlist.clear(); + } + + fence_.in_use = 0u; + image_.pool.clear(); + buffer_.pool.clear(); } } // namespace api diff --git a/aten/src/ATen/native/vulkan/api/Resource.h b/aten/src/ATen/native/vulkan/api/Resource.h index 04b73bd0ec..2f65b7fbe6 100644 --- a/aten/src/ATen/native/vulkan/api/Resource.h +++ b/aten/src/ATen/native/vulkan/api/Resource.h @@ -11,6 +11,8 @@ namespace vulkan { namespace api { struct Resource final { + class Pool; + // // Memory // @@ -211,6 +213,19 @@ struct Resource final { operator bool() const; }; + // + // Fence + // + + struct Fence final { + Pool* pool; + size_t id; + + operator bool() const; + VkFence handle(bool add_to_waitlist = true) const; + void wait(uint64_t timeout_nanoseconds = UINT64_MAX); + }; + // // Pool // @@ -222,12 +237,16 @@ struct Resource final { Pool& operator=(const Pool&) = delete; Pool(Pool&&); Pool& operator=(Pool&&); - ~Pool() = default; + ~Pool(); Buffer buffer(const Buffer::Descriptor& descriptor); Image image(const Image::Descriptor& descriptor); + Fence fence(); void purge(); + private: + friend struct Fence; + private: struct Configuration final { static constexpr uint32_t kReserve = 256u; @@ -235,9 +254,21 @@ struct Resource final { VkDevice device_; Handle allocator_; - std::vector> buffers_; - std::vector> images_; - Image::Sampler sampler_; + + struct { + std::vector> pool; + } buffer_; + + struct { + std::vector> pool; + Image::Sampler sampler; + } image_; + + struct { + std::vector> pool; + mutable std::vector waitlist; + size_t in_use; + } fence_; } pool; explicit Resource(const GPU& gpu) @@ -319,6 +350,27 @@ inline Resource::Image::operator bool() const { return object; } +inline Resource::Fence::operator bool() const { + return pool; +} + +inline VkFence Resource::Fence::handle(const bool add_to_waitlist) const { + if (!pool) { + return VK_NULL_HANDLE; + } + + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + id < pool->fence_.pool.size(), + "Invalid Vulkan fence!"); + + const VkFence fence = pool->fence_.pool[id].get(); + if (add_to_waitlist) { + pool->fence_.waitlist.push_back(fence); + } + + return fence; +} + } // namespace api } // namespace vulkan } // namespace native