Skip to content

Commit

Permalink
Add image sampler.
Browse files Browse the repository at this point in the history
ghstack-source-id: 90ce47103d33d5c3eb647b183d8e556ff1172e10
Pull Request resolved: pytorch/pytorch#45037
  • Loading branch information
Ashkan Aliabadi committed Oct 13, 2020
1 parent 53b79cd commit 48fdf9d
Show file tree
Hide file tree
Showing 6 changed files with 228 additions and 62 deletions.
13 changes: 6 additions & 7 deletions aten/src/ATen/native/vulkan/api/Command.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ void Command::Buffer::barrier(
barrier.layout.dst,
VK_QUEUE_FAMILY_IGNORED,
VK_QUEUE_FAMILY_IGNORED,
barrier.handle,
barrier.object.handle,
VkImageSubresourceRange{
VK_IMAGE_ASPECT_COLOR_BIT,
0u,
Expand Down Expand Up @@ -212,9 +212,8 @@ void Command::Buffer::bind(
}

void Command::Buffer::copy(
const VkBuffer source,
const VkBuffer destination,
const size_t size) {
const Resource::Buffer::Object source,
const Resource::Buffer::Object destination) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
command_buffer_,
"This command buffer is in an invalid state! "
Expand All @@ -231,13 +230,13 @@ void Command::Buffer::copy(
const VkBufferCopy buffer_copy{
0u,
0u,
size,
std::min(source.range, destination.range),
};

vkCmdCopyBuffer(
command_buffer_,
source,
destination,
source.handle,
destination.handle,
1u,
&buffer_copy);
}
Expand Down
3 changes: 2 additions & 1 deletion aten/src/ATen/native/vulkan/api/Command.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <ATen/native/vulkan/api/Common.h>
#include <ATen/native/vulkan/api/Descriptor.h>
#include <ATen/native/vulkan/api/Pipeline.h>
#include <ATen/native/vulkan/api/Resource.h>
#include <ATen/native/vulkan/api/Shader.h>

namespace at {
Expand All @@ -29,7 +30,7 @@ struct Command final {
void barrier(const Pipeline::Barrier& barrier);
void bind(const Pipeline::Object& pipeline);
void bind(const Descriptor::Set& set);
void copy(VkBuffer source, VkBuffer destination, size_t size);
void copy(Resource::Buffer::Object source, Resource::Buffer::Object destination);
void dispatch(const Shader::WorkGroup& work_group);
void submit(VkQueue queue, VkFence fence);

Expand Down
14 changes: 7 additions & 7 deletions aten/src/ATen/native/vulkan/api/Descriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ void Descriptor::Set::update(const Item& item) {
Descriptor::Set& Descriptor::Set::bind(
const uint32_t binding,
const VkDescriptorType type,
const Resource::Buffer& buffer) {
const Resource::Buffer::Object& buffer) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
device_,
"This descriptor set is in an invalid state! "
Expand All @@ -169,8 +169,8 @@ Descriptor::Set& Descriptor::Set::bind(
{
.buffer = {
buffer.handle,
0u, // buffer.offset,
0u, // buffer.range,
buffer.offset,
buffer.range,
},
},
});
Expand All @@ -181,7 +181,7 @@ Descriptor::Set& Descriptor::Set::bind(
Descriptor::Set& Descriptor::Set::bind(
const uint32_t binding,
const VkDescriptorType type,
const Resource::Image& image) {
const Resource::Image::Object& image) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
device_,
"This descriptor set is in an invalid state! "
Expand All @@ -192,9 +192,9 @@ Descriptor::Set& Descriptor::Set::bind(
type,
{
.image = {
// image.sampler,
// image.view,
// image.layout
image.sampler,
image.view,
image.layout
},
},
});
Expand Down
4 changes: 2 additions & 2 deletions aten/src/ATen/native/vulkan/api/Descriptor.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,12 @@ struct Descriptor final {
Set& bind(
uint32_t binding,
VkDescriptorType type,
const Resource::Buffer& buffer);
const Resource::Buffer::Object& buffer);

Set& bind(
uint32_t binding,
VkDescriptorType type,
const Resource::Image& image);
const Resource::Image::Object& image);

VkDescriptorSet handle() const;

Expand Down
83 changes: 72 additions & 11 deletions aten/src/ATen/native/vulkan/api/Resource.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,21 +63,23 @@ void release_buffer(const Resource::Buffer& buffer) {
// Safe to pass null as buffer or allocation.
vmaDestroyBuffer(
buffer.memory.allocator,
buffer.handle,
buffer.object.handle,
buffer.memory.allocation);
}

void release_image(const Resource::Image& image) {
if (VK_NULL_HANDLE != image.view) {
// Sampler is an immutable object. Its lifetime is managed through the cache.

if (VK_NULL_HANDLE != image.object.view) {
VmaAllocatorInfo allocator_info{};
vmaGetAllocatorInfo(image.memory.allocator, &allocator_info);
vkDestroyImageView(allocator_info.device, image.view, nullptr);
vkDestroyImageView(allocator_info.device, image.object.view, nullptr);
}

// Safe to pass null as image or allocation.
vmaDestroyImage(
image.memory.allocator,
image.handle,
image.object.handle,
image.memory.allocation);
}

Expand Down Expand Up @@ -127,23 +129,73 @@ void Resource::Memory::Scope::operator()(const void* const data) const {
}
}

Resource::Image::Sampler::Factory::Factory(const GPU& gpu)
: device_(gpu.device) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
device_,
"Invalid Vulkan device!");
}

typename Resource::Image::Sampler::Factory::Handle
Resource::Image::Sampler::Factory::operator()(
const Descriptor& descriptor) const {
const VkSamplerCreateInfo sampler_create_info{
VK_STRUCTURE_TYPE_SAMPLER_CREATE_INFO,
nullptr,
0u,
descriptor.filter,
descriptor.filter,
descriptor.mipmap_mode,
descriptor.address_mode,
descriptor.address_mode,
descriptor.address_mode,
0.0f,
VK_FALSE,
0.0f,
VK_FALSE,
VK_COMPARE_OP_NEVER,
0.0f,
0.0f,
descriptor.border,
VK_FALSE,
};

VkSampler sampler{};
VK_CHECK(vkCreateSampler(
device_,
&sampler_create_info,
nullptr,
&sampler));

TORCH_CHECK(
sampler,
"Invalid Vulkan image sampler!");

return Handle{
sampler,
Deleter(device_),
};
}

Resource::Pool::Pool(const GPU& gpu)
: device_(gpu.device),
allocator_(
create_allocator(
gpu.adapter->runtime->instance(),
gpu.adapter->handle,
device_),
vmaDestroyAllocator) {
buffers_.reserve(Configuration::kReserve);
images_.reserve(Configuration::kReserve);
vmaDestroyAllocator),
sampler_(gpu) {
buffers_.reserve(Configuration::kReserve);
images_.reserve(Configuration::kReserve);
}

Resource::Pool::Pool(Pool&& pool)
: device_(std::move(pool.device_)),
allocator_(std::move(pool.allocator_)),
buffers_(std::move(pool.buffers_)),
images_(std::move(pool.images_)) {
images_(std::move(pool.images_)),
sampler_(std::move(pool.sampler_)) {
pool.device_ = VK_NULL_HANDLE;
}

Expand All @@ -153,6 +205,7 @@ Resource::Pool& Resource::Pool::operator=(Pool&& pool) {
allocator_ = std::move(pool.allocator_);
buffers_ = std::move(pool.buffers_);
images_ = std::move(pool.images_);
sampler_ = std::move(pool.sampler_);

pool.device_ = VK_NULL_HANDLE;
};
Expand Down Expand Up @@ -198,7 +251,11 @@ Resource::Buffer Resource::Pool::buffer(

buffers_.emplace_back(
Buffer{
buffer,
Buffer::Object{
buffer,
0u,
descriptor.size,
},
Memory{
allocator_.get(),
allocation,
Expand Down Expand Up @@ -287,8 +344,12 @@ Resource::Image Resource::Pool::image(

images_.emplace_back(
Image{
image,
view,
Image::Object{
image,
VK_IMAGE_LAYOUT_UNDEFINED,
view,
sampler_.cache.retrieve(descriptor.sampler),
},
Memory{
allocator_.get(),
allocation,
Expand Down

0 comments on commit 48fdf9d

Please sign in to comment.