diff --git a/Android.mk b/Android.mk index afa04039337..3b3b319e67e 100644 --- a/Android.mk +++ b/Android.mk @@ -288,6 +288,7 @@ $(eval $(call gen_spvtools_vendor_tables,$(SPVTOOLS_OUT_PATH),spv-amd-shader-bal $(eval $(call gen_spvtools_vendor_tables,$(SPVTOOLS_OUT_PATH),spv-amd-shader-explicit-vertex-parameter,"")) $(eval $(call gen_spvtools_vendor_tables,$(SPVTOOLS_OUT_PATH),spv-amd-shader-trinary-minmax,"")) $(eval $(call gen_spvtools_vendor_tables,$(SPVTOOLS_OUT_PATH),nonsemantic.clspvreflection,"")) +$(eval $(call gen_spvtools_vendor_tables,$(SPVTOOLS_OUT_PATH),nonsemantic.vkspreflection,"")) define gen_spvtools_enum_string_mapping $(call generate-file-dir,$(1)/extension_enum.inc.inc) diff --git a/BUILD.bazel b/BUILD.bazel index b83fd5aeadc..c7b03c5ac40 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -58,6 +58,8 @@ generate_vendor_tables(extension = "debuginfo") generate_vendor_tables(extension = "nonsemantic.clspvreflection") +generate_vendor_tables(extension = "nonsemantic.vkspreflection") + generate_vendor_tables( extension = "opencl.debuginfo.100", operand_kind_prefix = "CLDEBUG100_", @@ -147,6 +149,7 @@ cc_library( ":gen_vendor_tables_debuginfo", ":gen_vendor_tables_nonsemantic_clspvreflection", ":gen_vendor_tables_nonsemantic_shader_debuginfo_100", + ":gen_vendor_tables_nonsemantic_vkspreflection", ":gen_vendor_tables_opencl_debuginfo_100", ":gen_vendor_tables_spv_amd_gcn_shader", ":gen_vendor_tables_spv_amd_shader_ballot", diff --git a/BUILD.gn b/BUILD.gn index 9ff36d83db1..fb860bafa1d 100644 --- a/BUILD.gn +++ b/BUILD.gn @@ -331,6 +331,10 @@ spvtools_vendor_tables = [ "nonsemantic.shader.debuginfo.100", "SHDEBUG100_", ], + [ + "nonsemantic.vkspreflection", + "...nil...", + ], ] foreach(table_def, spvtools_vendor_tables) { diff --git a/include/spirv-tools/libspirv.h b/include/spirv-tools/libspirv.h index b70f084a8e4..0de38758bcd 100644 --- a/include/spirv-tools/libspirv.h +++ b/include/spirv-tools/libspirv.h @@ -327,6 +327,7 @@ typedef enum spv_ext_inst_type_t { SPV_EXT_INST_TYPE_OPENCL_DEBUGINFO_100, SPV_EXT_INST_TYPE_NONSEMANTIC_CLSPVREFLECTION, SPV_EXT_INST_TYPE_NONSEMANTIC_SHADER_DEBUGINFO_100, + SPV_EXT_INST_TYPE_NONSEMANTIC_VKSPREFLECTION, // Multiple distinct extended instruction set types could return this // value, if they are prefixed with NonSemantic. and are otherwise diff --git a/include/spirv-tools/optimizer.hpp b/include/spirv-tools/optimizer.hpp index 53ebc59f005..263a73a9311 100644 --- a/include/spirv-tools/optimizer.hpp +++ b/include/spirv-tools/optimizer.hpp @@ -1000,6 +1000,133 @@ Optimizer::PassToken CreateSwitchDescriptorSetPass(uint32_t ds_from, // OpBeginInterlockInvocationEXT and one OpEndInterlockInvocationEXT, in that // order. Optimizer::PassToken CreateInvocationInterlockPlacementPass(); + +struct vksp_push_constant { + uint32_t offset; + uint32_t size; + uint32_t stageFlags; + const char* pValues; +}; + +#define VKSP_DESCRIPTOR_TYPE_STORAGE_BUFFER_COUNTER_BITS (0xf0000000) +#define VKSP_DESCRIPTOR_TYPE_STORAGE_BUFFER_COUNTER_MASK \ + (~VKSP_DESCRIPTOR_TYPE_STORAGE_BUFFER_COUNTER_BITS) +#define VKSP_DESCRIPTOR_TYPE_STORAGE_BUFFER_COUNTER \ + (VK_DESCRIPTOR_TYPE_STORAGE_BUFFER | \ + VKSP_DESCRIPTOR_TYPE_STORAGE_BUFFER_COUNTER_BITS) +struct vksp_descriptor_set { + uint32_t ds; + uint32_t binding; + uint32_t type; + union { + struct { + uint32_t flags; + uint32_t queueFamilyIndexCount; + uint32_t sharingMode; + uint32_t size; + uint32_t usage; + uint32_t range; + uint32_t offset; + uint32_t memorySize; + uint32_t memoryType; + uint32_t bindOffset; + } buffer; + struct { + uint32_t imageLayout; + uint32_t imageFlags; + uint32_t imageType; + uint32_t format; + uint32_t width; + uint32_t height; + uint32_t depth; + uint32_t mipLevels; + uint32_t arrayLayers; + uint32_t samples; + uint32_t tiling; + uint32_t usage; + uint32_t sharingMode; + uint32_t queueFamilyIndexCount; + uint32_t initialLayout; + uint32_t aspectMask; + uint32_t baseMipLevel; + uint32_t levelCount; + uint32_t baseArrayLayer; + uint32_t layerCount; + uint32_t viewFlags; + uint32_t viewType; + uint32_t viewFormat; + uint32_t component_a; + uint32_t component_b; + uint32_t component_g; + uint32_t component_r; + uint32_t memorySize; + uint32_t memoryType; + uint32_t bindOffset; + } image; + struct { + uint32_t flags; + uint32_t magFilter; + uint32_t minFilter; + uint32_t mipmapMode; + uint32_t addressModeU; + uint32_t addressModeV; + uint32_t addressModeW; + union { + float fMipLodBias; + uint32_t uMipLodBias; + }; + uint32_t anisotropyEnable; + union { + float fMaxAnisotropy; + uint32_t uMaxAnisotropy; + }; + uint32_t compareEnable; + uint32_t compareOp; + union { + float fMinLod; + uint32_t uMinLod; + }; + union { + float fMaxLod; + uint32_t uMaxLod; + }; + uint32_t borderColor; + uint32_t unnormalizedCoordinates; + } sampler; + }; +}; + +struct vksp_configuration { + const char* enabledExtensionNames; + uint32_t specializationInfoDataSize; + const char* specializationInfoData; + const char* shaderName; + const char* entryPoint; + uint32_t groupCountX; + uint32_t groupCountY; + uint32_t groupCountZ; +}; + +struct vksp_specialization_map_entry { + uint32_t constantID; + uint32_t offset; + uint32_t size; +}; + +struct vksp_counter { + uint32_t index; + const char* name; +}; + +Optimizer::PassToken CreateInsertVkspReflectInfoPass( + std::vector* pc, std::vector* ds, + std::vector* me, vksp_configuration* config); + +Optimizer::PassToken CreateExtractVkspReflectInfoPass( + std::vector* pc, std::vector* ds, + std::vector* me, + std::vector* counters, vksp_configuration* config); + } // namespace spvtools #endif // INCLUDE_SPIRV_TOOLS_OPTIMIZER_HPP_ diff --git a/source/CMakeLists.txt b/source/CMakeLists.txt index f4ee3c84cf6..d0454c6c706 100644 --- a/source/CMakeLists.txt +++ b/source/CMakeLists.txt @@ -156,6 +156,7 @@ spvtools_vendor_tables("debuginfo" "debuginfo" "") spvtools_vendor_tables("opencl.debuginfo.100" "cldi100" "CLDEBUG100_") spvtools_vendor_tables("nonsemantic.shader.debuginfo.100" "shdi100" "SHDEBUG100_") spvtools_vendor_tables("nonsemantic.clspvreflection" "clspvreflection" "") +spvtools_vendor_tables("nonsemantic.vkspreflection" "vkspreflection" "") spvtools_extinst_lang_headers("DebugInfo" ${DEBUGINFO_GRAMMAR_JSON_FILE}) spvtools_extinst_lang_headers("OpenCLDebugInfo100" ${CLDEBUGINFO100_GRAMMAR_JSON_FILE}) spvtools_extinst_lang_headers("NonSemanticShaderDebugInfo100" ${VKDEBUGINFO100_GRAMMAR_JSON_FILE}) diff --git a/source/ext_inst.cpp b/source/ext_inst.cpp index 4e2795453f4..9a5ba84e466 100644 --- a/source/ext_inst.cpp +++ b/source/ext_inst.cpp @@ -30,6 +30,7 @@ #include "glsl.std.450.insts.inc" #include "nonsemantic.clspvreflection.insts.inc" #include "nonsemantic.shader.debuginfo.100.insts.inc" +#include "nonsemantic.vkspreflection.insts.inc" #include "opencl.debuginfo.100.insts.inc" #include "opencl.std.insts.inc" @@ -62,6 +63,9 @@ static const spv_ext_inst_group_t kGroups_1_0[] = { {SPV_EXT_INST_TYPE_NONSEMANTIC_CLSPVREFLECTION, ARRAY_SIZE(nonsemantic_clspvreflection_entries), nonsemantic_clspvreflection_entries}, + {SPV_EXT_INST_TYPE_NONSEMANTIC_VKSPREFLECTION, + ARRAY_SIZE(nonsemantic_vkspreflection_entries), + nonsemantic_vkspreflection_entries}, }; static const spv_ext_inst_table_t kTable_1_0 = {ARRAY_SIZE(kGroups_1_0), @@ -138,6 +142,9 @@ spv_ext_inst_type_t spvExtInstImportTypeGet(const char* name) { if (!strncmp("NonSemantic.ClspvReflection.", name, 28)) { return SPV_EXT_INST_TYPE_NONSEMANTIC_CLSPVREFLECTION; } + if (!strncmp("NonSemantic.VkspReflection.", name, 27)) { + return SPV_EXT_INST_TYPE_NONSEMANTIC_VKSPREFLECTION; + } // ensure to add any known non-semantic extended instruction sets // above this point, and update spvExtInstIsNonSemantic() if (!strncmp("NonSemantic.", name, 12)) { @@ -149,7 +156,8 @@ spv_ext_inst_type_t spvExtInstImportTypeGet(const char* name) { bool spvExtInstIsNonSemantic(const spv_ext_inst_type_t type) { if (type == SPV_EXT_INST_TYPE_NONSEMANTIC_UNKNOWN || type == SPV_EXT_INST_TYPE_NONSEMANTIC_SHADER_DEBUGINFO_100 || - type == SPV_EXT_INST_TYPE_NONSEMANTIC_CLSPVREFLECTION) { + type == SPV_EXT_INST_TYPE_NONSEMANTIC_CLSPVREFLECTION || + type == SPV_EXT_INST_TYPE_NONSEMANTIC_VKSPREFLECTION) { return true; } return false; diff --git a/source/opt/CMakeLists.txt b/source/opt/CMakeLists.txt index 6ebbfbf0053..f183d47b154 100644 --- a/source/opt/CMakeLists.txt +++ b/source/opt/CMakeLists.txt @@ -133,6 +133,7 @@ set(SPIRV_TOOLS_OPT_SOURCES vector_dce.h workaround1209.h wrap_opkill.h + vksp_passes.h fix_func_call_arguments.cpp aggressive_dead_code_elim_pass.cpp @@ -250,6 +251,7 @@ set(SPIRV_TOOLS_OPT_SOURCES vector_dce.cpp workaround1209.cpp wrap_opkill.cpp + vksp_passes.cpp ) if(MSVC AND (NOT ("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang"))) diff --git a/source/opt/constants.cpp b/source/opt/constants.cpp index a487a45b885..6eebbb572a9 100644 --- a/source/opt/constants.cpp +++ b/source/opt/constants.cpp @@ -498,7 +498,7 @@ const Constant* ConstantManager::GetIntConst(uint64_t val, int32_t bitWidth, int32_t num_of_bit_to_ignore = 64 - bitWidth; val = static_cast(val << num_of_bit_to_ignore) >> num_of_bit_to_ignore; - } else { + } else if (bitWidth < 64) { // Clear the upper bit that are not used. uint64_t mask = ((1ull << bitWidth) - 1); val &= mask; @@ -511,7 +511,7 @@ const Constant* ConstantManager::GetIntConst(uint64_t val, int32_t bitWidth, // If the value is more than 32-bit, we need to split the operands into two // 32-bit integers. return GetConstant( - int_type, {static_cast(val >> 32), static_cast(val)}); + int_type, {static_cast(val), static_cast(val >> 32)}); } uint32_t ConstantManager::GetUIntConstId(uint32_t val) { diff --git a/source/opt/optimizer.cpp b/source/opt/optimizer.cpp index d865cf1d47d..4d825f0645e 100644 --- a/source/opt/optimizer.cpp +++ b/source/opt/optimizer.cpp @@ -1122,6 +1122,23 @@ Optimizer::PassToken CreateInvocationInterlockPlacementPass() { return MakeUnique( MakeUnique()); } + +Optimizer::PassToken CreateInsertVkspReflectInfoPass( + std::vector* pc, std::vector* ds, + std::vector* me, + vksp_configuration* config) { + return MakeUnique( + MakeUnique(pc, ds, me, config)); +} +Optimizer::PassToken CreateExtractVkspReflectInfoPass( + std::vector* pc, std::vector* ds, + std::vector* me, + std::vector* counters, vksp_configuration* config) { + return MakeUnique( + MakeUnique(pc, ds, me, counters, + config)); +} + } // namespace spvtools extern "C" { diff --git a/source/opt/passes.h b/source/opt/passes.h index 305f5782792..c5b743235cf 100644 --- a/source/opt/passes.h +++ b/source/opt/passes.h @@ -90,5 +90,6 @@ #include "source/opt/vector_dce.h" #include "source/opt/workaround1209.h" #include "source/opt/wrap_opkill.h" +#include "source/opt/vksp_passes.h" #endif // SOURCE_OPT_PASSES_H_ diff --git a/source/opt/vksp_passes.cpp b/source/opt/vksp_passes.cpp new file mode 100644 index 00000000000..baf01992c19 --- /dev/null +++ b/source/opt/vksp_passes.cpp @@ -0,0 +1,770 @@ +// Copyright (c) 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "source/opt/vksp_passes.h" + +#include "source/opt/instruction.h" +#include "source/opt/types.h" +#include "spirv-tools/libspirv.h" +#include "spirv-tools/optimizer.hpp" +#include "spirv/unified1/NonSemanticVkspReflection.h" +#include "spirv/unified1/spirv.hpp11" +#include "vulkan/vulkan.h" + +namespace spvtools { +namespace opt { + +Pass::Status InsertVkspReflectInfoPass::Process() { + auto module = context()->module(); + + std::vector ext_words = + spvtools::utils::MakeVector("NonSemantic.VkspReflection.1"); + auto ExtInstId = context()->TakeNextId(); + auto ExtInst = + new Instruction(context(), spv::Op::OpExtInstImport, 0u, ExtInstId, + {{SPV_OPERAND_TYPE_LITERAL_STRING, ext_words}}); + module->AddExtInstImport(std::unique_ptr(ExtInst)); + + uint32_t void_ty_id = context()->get_type_mgr()->GetVoidTypeId(); + + std::vector enabledExtensions = + spvtools::utils::MakeVector(config_->enabledExtensionNames); + std::vector pData = + spvtools::utils::MakeVector(config_->specializationInfoData); + std::vector shaderName = + spvtools::utils::MakeVector(config_->shaderName); + std::vector entryPoint = + spvtools::utils::MakeVector(config_->entryPoint); + auto ConfigId = context()->TakeNextId(); + auto ConfigInst = new Instruction( + context(), spv::Op::OpExtInst, void_ty_id, ConfigId, + { + {SPV_OPERAND_TYPE_ID, {ExtInstId}}, + {SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER, + {NonSemanticVkspReflectionConfiguration}}, + {SPV_OPERAND_TYPE_LITERAL_STRING, enabledExtensions}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, + {config_->specializationInfoDataSize}}, + {SPV_OPERAND_TYPE_LITERAL_STRING, pData}, + {SPV_OPERAND_TYPE_LITERAL_STRING, shaderName}, + {SPV_OPERAND_TYPE_LITERAL_STRING, entryPoint}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {config_->groupCountX}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {config_->groupCountY}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {config_->groupCountZ}}, + }); + module->AddExtInstDebugInfo(std::unique_ptr(ConfigInst)); + + for (auto& pc : *pc_) { + std::vector pValues = spvtools::utils::MakeVector(pc.pValues); + auto PcInstId = context()->TakeNextId(); + auto PcInst = + new Instruction(context(), spv::Op::OpExtInst, void_ty_id, PcInstId, + { + {SPV_OPERAND_TYPE_ID, {ExtInstId}}, + {SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER, + {NonSemanticVkspReflectionPushConstants}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {pc.offset}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {pc.size}}, + {SPV_OPERAND_TYPE_LITERAL_STRING, pValues}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {pc.stageFlags}}, + }); + module->AddExtInstDebugInfo(std::unique_ptr(PcInst)); + } + + for (auto& ds : *ds_) { + switch (ds.type) { + case VK_DESCRIPTOR_TYPE_STORAGE_BUFFER: + case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER: { + auto DsInstId = context()->TakeNextId(); + auto DstInst = new Instruction( + context(), spv::Op::OpExtInst, void_ty_id, DsInstId, + { + {SPV_OPERAND_TYPE_ID, {ExtInstId}}, + {SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER, + {NonSemanticVkspReflectionDescriptorSetBuffer}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.ds}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.binding}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.type}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.buffer.flags}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, + {ds.buffer.queueFamilyIndexCount}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.buffer.sharingMode}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.buffer.size}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.buffer.usage}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.buffer.range}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.buffer.offset}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.buffer.memorySize}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.buffer.memoryType}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.buffer.bindOffset}}, + }); + module->AddExtInstDebugInfo(std::unique_ptr(DstInst)); + } break; + case VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE: + case VK_DESCRIPTOR_TYPE_STORAGE_IMAGE: { + auto DsInstId = context()->TakeNextId(); + auto DstInst = new Instruction( + context(), spv::Op::OpExtInst, void_ty_id, DsInstId, + { + {SPV_OPERAND_TYPE_ID, {ExtInstId}}, + {SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER, + {NonSemanticVkspReflectionDescriptorSetImage}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.ds}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.binding}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.type}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.image.imageLayout}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.image.imageFlags}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.image.imageType}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.image.format}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.image.width}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.image.height}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.image.depth}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.image.mipLevels}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.image.arrayLayers}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.image.samples}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.image.tiling}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.image.usage}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.image.sharingMode}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, + {ds.image.queueFamilyIndexCount}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.image.initialLayout}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.image.aspectMask}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.image.baseMipLevel}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.image.levelCount}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.image.baseArrayLayer}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.image.layerCount}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.image.viewFlags}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.image.viewType}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.image.viewFormat}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.image.component_a}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.image.component_b}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.image.component_g}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.image.component_r}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.image.memorySize}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.image.memoryType}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.image.bindOffset}}, + }); + module->AddExtInstDebugInfo(std::unique_ptr(DstInst)); + } break; + case VK_DESCRIPTOR_TYPE_SAMPLER: { + auto DsInstId = context()->TakeNextId(); + auto DstInst = new Instruction( + context(), spv::Op::OpExtInst, void_ty_id, DsInstId, + { + {SPV_OPERAND_TYPE_ID, {ExtInstId}}, + {SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER, + {NonSemanticVkspReflectionDescriptorSetSampler}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.ds}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.binding}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.type}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.sampler.flags}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.sampler.magFilter}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.sampler.minFilter}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.sampler.mipmapMode}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.sampler.addressModeU}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.sampler.addressModeV}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.sampler.addressModeW}}, + {SPV_OPERAND_TYPE_LITERAL_FLOAT, {ds.sampler.uMipLodBias}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, + {ds.sampler.anisotropyEnable}}, + {SPV_OPERAND_TYPE_LITERAL_FLOAT, {ds.sampler.uMaxAnisotropy}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.sampler.compareEnable}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.sampler.compareOp}}, + {SPV_OPERAND_TYPE_LITERAL_FLOAT, {ds.sampler.uMinLod}}, + {SPV_OPERAND_TYPE_LITERAL_FLOAT, {ds.sampler.uMaxLod}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {ds.sampler.borderColor}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, + {ds.sampler.unnormalizedCoordinates}}, + }); + module->AddExtInstDebugInfo(std::unique_ptr(DstInst)); + } break; + default: + break; + } + } + + for (auto& me : *me_) { + auto MapEntryId = context()->TakeNextId(); + auto MapEntryInst = + new Instruction(context(), spv::Op::OpExtInst, void_ty_id, MapEntryId, + { + {SPV_OPERAND_TYPE_ID, {ExtInstId}}, + {SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER, + {NonSemanticVkspReflectionSpecializationMapEntry}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {me.constantID}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {me.offset}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {me.size}}, + }); + module->AddExtInstDebugInfo(std::unique_ptr(MapEntryInst)); + } + + return Status::SuccessWithChange; +} + +int32_t ExtractVkspReflectInfoPass::UpdateMaxBinding(uint32_t ds, + uint32_t binding, + int32_t max_binding) { + if (ds != 0) { + return max_binding; + } else { + return std::max(max_binding, (int32_t)binding); + } +} + +void ExtractVkspReflectInfoPass::ParseInstruction( + Instruction* inst, uint32_t ext_inst_id, + std::map& id_to_descriptor_set, + std::map& id_to_binding, + int32_t& descriptor_set_0_max_binding, + std::vector& start_counters, + std::vector& stop_counters) { + uint32_t op_id = 2; + if (inst->opcode() == spv::Op::OpDecorate) { + spv::Decoration decoration = (spv::Decoration)inst->GetOperand(1).words[0]; + if (decoration == spv::Decoration::DescriptorSet) { + auto id = inst->GetOperand(0).AsId(); + auto ds = inst->GetOperand(2).words[0]; + id_to_descriptor_set[id] = ds; + if (ds == 0 && id_to_binding.count(id) > 0) { + descriptor_set_0_max_binding = UpdateMaxBinding( + ds, id_to_binding[id], descriptor_set_0_max_binding); + } + } else if (decoration == spv::Decoration::Binding) { + auto id = inst->GetOperand(0).AsId(); + auto binding = inst->GetOperand(2).words[0]; + id_to_binding[id] = binding; + if (id_to_descriptor_set.count(id) > 0) { + descriptor_set_0_max_binding = UpdateMaxBinding( + id_to_descriptor_set[id], binding, descriptor_set_0_max_binding); + } + } + return; + } else if (inst->opcode() != spv::Op::OpExtInst || + ext_inst_id != inst->GetOperand(op_id++).AsId()) { + return; + } + + auto vksp_inst = inst->GetOperand(op_id++).words[0]; + switch (vksp_inst) { + case NonSemanticVkspReflectionConfiguration: + config_->enabledExtensionNames = + strdup(inst->GetOperand(op_id++).AsString().c_str()); + config_->specializationInfoDataSize = inst->GetOperand(op_id++).words[0]; + config_->specializationInfoData = + strdup(inst->GetOperand(op_id++).AsString().c_str()); + config_->shaderName = + strdup(inst->GetOperand(op_id++).AsString().c_str()); + config_->entryPoint = + strdup(inst->GetOperand(op_id++).AsString().c_str()); + config_->groupCountX = inst->GetOperand(op_id++).words[0]; + config_->groupCountY = inst->GetOperand(op_id++).words[0]; + config_->groupCountZ = inst->GetOperand(op_id++).words[0]; + break; + case NonSemanticVkspReflectionDescriptorSetBuffer: { + vksp_descriptor_set ds; + ds.ds = inst->GetOperand(op_id++).words[0]; + ds.binding = inst->GetOperand(op_id++).words[0]; + ds.type = inst->GetOperand(op_id++).words[0]; + ds.buffer.flags = inst->GetOperand(op_id++).words[0]; + ds.buffer.queueFamilyIndexCount = inst->GetOperand(op_id++).words[0]; + ds.buffer.sharingMode = inst->GetOperand(op_id++).words[0]; + ds.buffer.size = inst->GetOperand(op_id++).words[0]; + ds.buffer.usage = inst->GetOperand(op_id++).words[0]; + ds.buffer.range = inst->GetOperand(op_id++).words[0]; + ds.buffer.offset = inst->GetOperand(op_id++).words[0]; + ds.buffer.memorySize = inst->GetOperand(op_id++).words[0]; + ds.buffer.memoryType = inst->GetOperand(op_id++).words[0]; + ds.buffer.bindOffset = inst->GetOperand(op_id++).words[0]; + ds_->push_back(ds); + descriptor_set_0_max_binding = + UpdateMaxBinding(ds.ds, ds.binding, descriptor_set_0_max_binding); + } break; + case NonSemanticVkspReflectionDescriptorSetImage: { + vksp_descriptor_set ds; + ds.ds = inst->GetOperand(op_id++).words[0]; + ds.binding = inst->GetOperand(op_id++).words[0]; + ds.type = inst->GetOperand(op_id++).words[0]; + ds.image.imageLayout = inst->GetOperand(op_id++).words[0]; + ds.image.imageFlags = inst->GetOperand(op_id++).words[0]; + ds.image.imageType = inst->GetOperand(op_id++).words[0]; + ds.image.format = inst->GetOperand(op_id++).words[0]; + ds.image.width = inst->GetOperand(op_id++).words[0]; + ds.image.height = inst->GetOperand(op_id++).words[0]; + ds.image.depth = inst->GetOperand(op_id++).words[0]; + ds.image.mipLevels = inst->GetOperand(op_id++).words[0]; + ds.image.arrayLayers = inst->GetOperand(op_id++).words[0]; + ds.image.samples = inst->GetOperand(op_id++).words[0]; + ds.image.tiling = inst->GetOperand(op_id++).words[0]; + ds.image.usage = inst->GetOperand(op_id++).words[0]; + ds.image.sharingMode = inst->GetOperand(op_id++).words[0]; + ds.image.queueFamilyIndexCount = inst->GetOperand(op_id++).words[0]; + ds.image.initialLayout = inst->GetOperand(op_id++).words[0]; + ds.image.aspectMask = inst->GetOperand(op_id++).words[0]; + ds.image.baseMipLevel = inst->GetOperand(op_id++).words[0]; + ds.image.levelCount = inst->GetOperand(op_id++).words[0]; + ds.image.baseArrayLayer = inst->GetOperand(op_id++).words[0]; + ds.image.layerCount = inst->GetOperand(op_id++).words[0]; + ds.image.viewFlags = inst->GetOperand(op_id++).words[0]; + ds.image.viewType = inst->GetOperand(op_id++).words[0]; + ds.image.viewFormat = inst->GetOperand(op_id++).words[0]; + ds.image.component_a = inst->GetOperand(op_id++).words[0]; + ds.image.component_b = inst->GetOperand(op_id++).words[0]; + ds.image.component_g = inst->GetOperand(op_id++).words[0]; + ds.image.component_r = inst->GetOperand(op_id++).words[0]; + ds.image.memorySize = inst->GetOperand(op_id++).words[0]; + ds.image.memoryType = inst->GetOperand(op_id++).words[0]; + ds.image.bindOffset = inst->GetOperand(op_id++).words[0]; + ds_->push_back(ds); + descriptor_set_0_max_binding = + UpdateMaxBinding(ds.ds, ds.binding, descriptor_set_0_max_binding); + } break; + case NonSemanticVkspReflectionDescriptorSetSampler: { + vksp_descriptor_set ds; + ds.ds = inst->GetOperand(op_id++).words[0]; + ds.binding = inst->GetOperand(op_id++).words[0]; + ds.type = inst->GetOperand(op_id++).words[0]; + ds.sampler.flags = inst->GetOperand(op_id++).words[0]; + ds.sampler.magFilter = inst->GetOperand(op_id++).words[0]; + ds.sampler.minFilter = inst->GetOperand(op_id++).words[0]; + ds.sampler.mipmapMode = inst->GetOperand(op_id++).words[0]; + ds.sampler.addressModeU = inst->GetOperand(op_id++).words[0]; + ds.sampler.addressModeV = inst->GetOperand(op_id++).words[0]; + ds.sampler.addressModeW = inst->GetOperand(op_id++).words[0]; + ds.sampler.uMipLodBias = inst->GetOperand(op_id++).words[0]; + ds.sampler.anisotropyEnable = inst->GetOperand(op_id++).words[0]; + ds.sampler.uMaxAnisotropy = inst->GetOperand(op_id++).words[0]; + ds.sampler.compareEnable = inst->GetOperand(op_id++).words[0]; + ds.sampler.compareOp = inst->GetOperand(op_id++).words[0]; + ds.sampler.uMinLod = inst->GetOperand(op_id++).words[0]; + ds.sampler.uMaxLod = inst->GetOperand(op_id++).words[0]; + ds.sampler.borderColor = inst->GetOperand(op_id++).words[0]; + ds.sampler.unnormalizedCoordinates = inst->GetOperand(op_id++).words[0]; + ds_->push_back(ds); + descriptor_set_0_max_binding = + UpdateMaxBinding(ds.ds, ds.binding, descriptor_set_0_max_binding); + } break; + case NonSemanticVkspReflectionPushConstants: + vksp_push_constant pc; + pc.offset = inst->GetOperand(op_id++).words[0]; + pc.size = inst->GetOperand(op_id++).words[0]; + pc.pValues = strdup(inst->GetOperand(op_id++).AsString().c_str()); + pc.stageFlags = inst->GetOperand(op_id++).words[0]; + pc_->push_back(pc); + break; + case NonSemanticVkspReflectionSpecializationMapEntry: + vksp_specialization_map_entry me; + me.constantID = inst->GetOperand(op_id++).words[0]; + me.offset = inst->GetOperand(op_id++).words[0]; + me.size = inst->GetOperand(op_id++).words[0]; + me_->push_back(me); + break; + case NonSemanticVkspReflectionStartCounter: + start_counters.push_back(inst); + break; + case NonSemanticVkspReflectionStopCounter: + stop_counters.push_back(inst); + break; + default: + break; + } +} + +void ExtractVkspReflectInfoPass::CreateVariables( + uint32_t u64_arr_ty_id, uint32_t u64_arr_st_ty_id, + uint32_t local_counters_ty_id, uint32_t counters_ty_id, + uint32_t global_counters_ds, uint32_t global_counters_binding, + uint32_t& global_counters_id, uint32_t& local_counters_id) { + auto module = context()->module(); + + auto decorate_arr_inst = new Instruction( + context(), spv::Op::OpDecorate, 0, 0, + {{SPV_OPERAND_TYPE_ID, {u64_arr_ty_id}}, + {SPV_OPERAND_TYPE_DECORATION, {(uint32_t)spv::Decoration::ArrayStride}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {8}}}); + module->AddAnnotationInst(std::unique_ptr(decorate_arr_inst)); + + auto decorate_member_offset_inst = new Instruction( + context(), spv::Op::OpMemberDecorate, 0, 0, + {{SPV_OPERAND_TYPE_ID, {u64_arr_st_ty_id}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {0}}, + {SPV_OPERAND_TYPE_DECORATION, {(uint32_t)spv::Decoration::Offset}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {0}}}); + module->AddAnnotationInst( + std::unique_ptr(decorate_member_offset_inst)); + + auto decorate_arr_st_inst = new Instruction( + context(), spv::Op::OpDecorate, 0, 0, + {{SPV_OPERAND_TYPE_ID, {u64_arr_st_ty_id}}, + {SPV_OPERAND_TYPE_DECORATION, {(uint32_t)spv::Decoration::Block}}}); + module->AddAnnotationInst(std::unique_ptr(decorate_arr_st_inst)); + + local_counters_id = context()->TakeNextId(); + auto local_counters_inst = new Instruction( + context(), spv::Op::OpVariable, local_counters_ty_id, local_counters_id, + {{SPV_OPERAND_TYPE_LITERAL_INTEGER, + {(uint32_t)spv::StorageClass::Private}}}); + module->AddGlobalValue(std::unique_ptr(local_counters_inst)); + + global_counters_id = context()->TakeNextId(); + auto global_counters_inst = new Instruction( + context(), spv::Op::OpVariable, counters_ty_id, global_counters_id, + {{SPV_OPERAND_TYPE_LITERAL_INTEGER, + {(uint32_t)spv::StorageClass::StorageBuffer}}}); + module->AddGlobalValue(std::unique_ptr(global_counters_inst)); + + auto counters_descriptor_set_inst = new Instruction( + context(), spv::Op::OpDecorate, 0, 0, + {{SPV_OPERAND_TYPE_ID, {global_counters_inst->result_id()}}, + {SPV_OPERAND_TYPE_DECORATION, + {(uint32_t)spv::Decoration::DescriptorSet}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {global_counters_ds}}}); + module->AddAnnotationInst( + std::unique_ptr(counters_descriptor_set_inst)); + + auto counters_binding_inst = new Instruction( + context(), spv::Op::OpDecorate, 0, 0, + {{SPV_OPERAND_TYPE_ID, {global_counters_inst->result_id()}}, + {SPV_OPERAND_TYPE_DECORATION, {(uint32_t)spv::Decoration::Binding}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {global_counters_binding}}}); + module->AddAnnotationInst( + std::unique_ptr(counters_binding_inst)); +} + +void ExtractVkspReflectInfoPass::CreatePrologue( + Instruction* entry_point_inst, uint32_t u64_private_ptr_ty_id, + uint32_t u64_ty_id, uint32_t subgroup_scope_id, uint32_t global_counters_id, + uint32_t local_counters_id, std::vector& start_counters, + Function*& function, uint32_t& read_clock_id) { + auto* cst_mgr = context()->get_constant_mgr(); + entry_point_inst->AddOperand({SPV_OPERAND_TYPE_ID, {global_counters_id}}); + entry_point_inst->AddOperand({SPV_OPERAND_TYPE_ID, {local_counters_id}}); + + auto function_id = entry_point_inst->GetOperand(1).AsId(); + function = context()->GetFunction(function_id); + + auto& function_first_inst = *function->entry()->begin(); + + auto u64_cst0_id = + cst_mgr->GetDefiningInstruction(cst_mgr->GetIntConst(0, 64, 0)) + ->result_id(); + + for (unsigned i = 0; i < start_counters.size(); i++) { + auto get_id = context()->TakeNextId(); + auto gep_inst = new Instruction( + context(), spv::Op::OpAccessChain, u64_private_ptr_ty_id, get_id, + {{SPV_OPERAND_TYPE_ID, {local_counters_id}}, + {SPV_OPERAND_TYPE_ID, {cst_mgr->GetUIntConstId(i)}}}); + gep_inst->InsertBefore(&function_first_inst); + + auto store_inst = new Instruction(context(), spv::Op::OpStore, 0, 0, + {{SPV_OPERAND_TYPE_ID, {get_id}}, + {SPV_OPERAND_TYPE_ID, {u64_cst0_id}}}); + store_inst->InsertAfter(gep_inst); + } + + read_clock_id = context()->TakeNextId(); + auto read_clock_inst = new Instruction( + context(), spv::Op::OpReadClockKHR, u64_ty_id, read_clock_id, + {{SPV_OPERAND_TYPE_SCOPE_ID, {subgroup_scope_id}}}); + read_clock_inst->InsertBefore(&function_first_inst); +} + +void ExtractVkspReflectInfoPass::CreateEpilogue( + Instruction* return_inst, uint32_t read_clock_id, uint32_t u64_ty_id, + uint32_t u64_ptr_ty_id, uint32_t u64_private_ptr_ty_id, + uint32_t subgroup_scope_id, uint32_t device_scope_id, + uint32_t acq_rel_mem_sem_id, uint32_t global_counters_id, + uint32_t local_counters_id, std::vector& start_counters) { + auto* cst_mgr = context()->get_constant_mgr(); + + auto read_clock_end_id = context()->TakeNextId(); + auto read_clock_end_inst = new Instruction( + context(), spv::Op::OpReadClockKHR, u64_ty_id, read_clock_end_id, + {{SPV_OPERAND_TYPE_SCOPE_ID, {subgroup_scope_id}}}); + read_clock_end_inst->InsertBefore(return_inst); + + auto substraction_id = context()->TakeNextId(); + auto substraction_inst = + new Instruction(context(), spv::Op::OpISub, u64_ty_id, substraction_id, + {{SPV_OPERAND_TYPE_ID, {read_clock_end_id}}, + {SPV_OPERAND_TYPE_ID, {read_clock_id}}}); + substraction_inst->InsertAfter(read_clock_end_inst); + + auto gep_invocations_id = context()->TakeNextId(); + auto gep_invocations_inst = new Instruction( + context(), spv::Op::OpAccessChain, u64_ptr_ty_id, gep_invocations_id, + {{SPV_OPERAND_TYPE_ID, {global_counters_id}}, + {SPV_OPERAND_TYPE_ID, {cst_mgr->GetUIntConstId(0)}}, + {SPV_OPERAND_TYPE_ID, {cst_mgr->GetUIntConstId(0)}}}); + gep_invocations_inst->InsertAfter(substraction_inst); + + auto atomic_incr_id = context()->TakeNextId(); + auto atomic_incr_inst = new Instruction( + context(), spv::Op::OpAtomicIIncrement, u64_ty_id, atomic_incr_id, + {{SPV_OPERAND_TYPE_ID, {gep_invocations_id}}, + {SPV_OPERAND_TYPE_SCOPE_ID, {device_scope_id}}, + {SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID, {acq_rel_mem_sem_id}}}); + atomic_incr_inst->InsertAfter(gep_invocations_inst); + + auto gep_entrypoint_counter_id = context()->TakeNextId(); + auto gep_entrypoint_counter_inst = + new Instruction(context(), spv::Op::OpAccessChain, u64_ptr_ty_id, + gep_entrypoint_counter_id, + {{SPV_OPERAND_TYPE_ID, {global_counters_id}}, + {SPV_OPERAND_TYPE_ID, {cst_mgr->GetUIntConstId(0)}}, + {SPV_OPERAND_TYPE_ID, {cst_mgr->GetUIntConstId(1)}}}); + gep_entrypoint_counter_inst->InsertAfter(atomic_incr_inst); + + auto atomic_add_id = context()->TakeNextId(); + auto atomic_add_inst = new Instruction( + context(), spv::Op::OpAtomicIAdd, u64_ty_id, atomic_add_id, + {{SPV_OPERAND_TYPE_ID, {gep_entrypoint_counter_id}}, + {SPV_OPERAND_TYPE_SCOPE_ID, {device_scope_id}}, + {SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID, {acq_rel_mem_sem_id}}, + {SPV_OPERAND_TYPE_ID, {substraction_id}}}); + atomic_add_inst->InsertAfter(gep_entrypoint_counter_inst); + + for (unsigned i = 0; i < start_counters.size(); i++) { + auto gep_id = context()->TakeNextId(); + auto gep_inst = new Instruction( + context(), spv::Op::OpAccessChain, u64_private_ptr_ty_id, gep_id, + {{SPV_OPERAND_TYPE_ID, {local_counters_id}}, + {SPV_OPERAND_TYPE_ID, {cst_mgr->GetUIntConstId(i)}}}); + gep_inst->InsertAfter(atomic_add_inst); + + auto load_id = context()->TakeNextId(); + auto load_inst = + new Instruction(context(), spv::Op::OpLoad, u64_ty_id, load_id, + {{SPV_OPERAND_TYPE_ID, {gep_id}}}); + load_inst->InsertAfter(gep_inst); + + auto gep_atomic_id = context()->TakeNextId(); + auto gep_atomic_inst = new Instruction( + context(), spv::Op::OpAccessChain, u64_ptr_ty_id, gep_atomic_id, + {{SPV_OPERAND_TYPE_ID, {global_counters_id}}, + {SPV_OPERAND_TYPE_ID, {cst_mgr->GetUIntConstId(0)}}, + {SPV_OPERAND_TYPE_ID, {cst_mgr->GetUIntConstId(2 + i)}}}); + gep_atomic_inst->InsertAfter(load_inst); + + atomic_add_id = context()->TakeNextId(); + atomic_add_inst = new Instruction( + context(), spv::Op::OpAtomicIAdd, u64_ty_id, atomic_add_id, + {{SPV_OPERAND_TYPE_ID, {gep_atomic_id}}, + {SPV_OPERAND_TYPE_SCOPE_ID, {device_scope_id}}, + {SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID, {acq_rel_mem_sem_id}}, + {SPV_OPERAND_TYPE_ID, {load_id}}}); + atomic_add_inst->InsertAfter(gep_atomic_inst); + } +} + +void ExtractVkspReflectInfoPass::CreateCounters( + uint32_t u64_ty_id, uint32_t u64_private_ptr_ty_id, + uint32_t subgroup_scope_id, std::vector& start_counters, + std::vector& stop_counters, uint32_t local_counters_id) { + auto* cst_mgr = context()->get_constant_mgr(); + std::map> start_counters_id_map; + uint32_t next_counter_id = 2; + + for (auto* inst : start_counters) { + const char* counter_name = strdup(inst->GetOperand(4).AsString().c_str()); + + auto read_clock_id = context()->TakeNextId(); + auto read_clock_inst = new Instruction( + context(), spv::Op::OpReadClockKHR, u64_ty_id, read_clock_id, + {{SPV_OPERAND_TYPE_SCOPE_ID, {subgroup_scope_id}}}); + read_clock_inst->InsertBefore(inst); + + counters_->push_back({next_counter_id, counter_name}); + start_counters_id_map[inst->result_id()] = + std::make_pair(read_clock_id, next_counter_id); + next_counter_id++; + } + + for (auto* inst : stop_counters) { + auto read_clock_ext_inst_id = inst->GetOperand(4).AsId(); + if (start_counters_id_map.count(read_clock_ext_inst_id) == 0) { + continue; + } + auto pair = start_counters_id_map[read_clock_ext_inst_id]; + auto read_clock_id = pair.first; + auto counters_var_index = pair.second; + + auto read_clock_end_id = context()->TakeNextId(); + auto read_clock_end_inst = new Instruction( + context(), spv::Op::OpReadClockKHR, u64_ty_id, read_clock_end_id, + {{SPV_OPERAND_TYPE_SCOPE_ID, {subgroup_scope_id}}}); + read_clock_end_inst->InsertAfter(inst); + + auto substraction_id = context()->TakeNextId(); + auto substraction_inst = + new Instruction(context(), spv::Op::OpISub, u64_ty_id, substraction_id, + {{SPV_OPERAND_TYPE_ID, {read_clock_end_id}}, + {SPV_OPERAND_TYPE_ID, {read_clock_id}}}); + substraction_inst->InsertAfter(read_clock_end_inst); + + auto gep_id = context()->TakeNextId(); + auto gep_inst = new Instruction( + context(), spv::Op::OpAccessChain, u64_private_ptr_ty_id, gep_id, + {{SPV_OPERAND_TYPE_ID, {local_counters_id}}, + {SPV_OPERAND_TYPE_ID, + {cst_mgr->GetUIntConstId(counters_var_index - 2)}}}); + gep_inst->InsertAfter(substraction_inst); + + auto load_id = context()->TakeNextId(); + auto load_inst = + new Instruction(context(), spv::Op::OpLoad, u64_ty_id, load_id, + {{SPV_OPERAND_TYPE_ID, {gep_id}}}); + load_inst->InsertAfter(gep_inst); + + auto add_id = context()->TakeNextId(); + auto add_inst = + new Instruction(context(), spv::Op::OpIAdd, u64_ty_id, add_id, + {{SPV_OPERAND_TYPE_ID, {load_id}}, + {SPV_OPERAND_TYPE_ID, {substraction_id}}}); + add_inst->InsertAfter(load_inst); + + auto store_inst = new Instruction( + context(), spv::Op::OpStore, 0, 0, + {{SPV_OPERAND_TYPE_ID, {gep_id}}, {SPV_OPERAND_TYPE_ID, {add_id}}}); + store_inst->InsertAfter(add_inst); + } +} + +Pass::Status ExtractVkspReflectInfoPass::Process() { + auto module = context()->module(); + uint32_t ext_inst_id = + module->GetExtInstImportId("NonSemantic.VkspReflection.1"); + int32_t descriptor_set_0_max_binding = -1; + std::map id_to_descriptor_set; + std::map id_to_binding; + std::vector start_counters; + std::vector stop_counters; + + module->ForEachInst([this, ext_inst_id, &id_to_descriptor_set, &id_to_binding, + &descriptor_set_0_max_binding, &start_counters, + &stop_counters](Instruction* inst) { + ParseInstruction(inst, ext_inst_id, id_to_descriptor_set, id_to_binding, + descriptor_set_0_max_binding, start_counters, + stop_counters); + }); + + context()->AddExtension("SPV_KHR_shader_clock"); + context()->AddExtension("SPV_KHR_storage_buffer_storage_class"); + context()->AddCapability(spv::Capability::ShaderClockKHR); + context()->AddCapability(spv::Capability::Int64); + context()->AddCapability(spv::Capability::Int64Atomics); + + uint32_t global_counters_ds = 0; + uint32_t global_counters_binding = descriptor_set_0_max_binding + 1; + auto counters_size = + (uint32_t)(sizeof(uint64_t) * + (2 + start_counters + .size())); // 2 for the number of invocations and the + // time of the whole entry point + ds_->push_back( + {global_counters_ds, + global_counters_binding, + (uint32_t)VKSP_DESCRIPTOR_TYPE_STORAGE_BUFFER_COUNTER, + {.buffer = {0, 0, VK_SHARING_MODE_EXCLUSIVE, counters_size, + VK_BUFFER_USAGE_TRANSFER_DST_BIT | + VK_BUFFER_USAGE_STORAGE_BUFFER_BIT, + counters_size, 0, counters_size, UINT32_MAX, 0}}}); + + auto* cst_mgr = context()->get_constant_mgr(); + auto* type_mgr = context()->get_type_mgr(); + + auto u64_ty = type_mgr->GetIntType(64, 0); + analysis::RuntimeArray run_arr(u64_ty); + auto u64_run_arr_ty = type_mgr->GetRegisteredType(&run_arr); + analysis::Struct st({u64_run_arr_ty}); + auto u64_run_arr_st_ty = type_mgr->GetRegisteredType(&st); + analysis::Pointer u64_run_arr_st_ty_ptr(u64_run_arr_st_ty, + spv::StorageClass::StorageBuffer); + auto u64_run_arr_st_ptr_ty = + type_mgr->GetRegisteredType(&u64_run_arr_st_ty_ptr); + + analysis::Pointer u64_ty_ptr(u64_ty, spv::StorageClass::StorageBuffer); + auto u64_ptr_ty = type_mgr->GetRegisteredType(&u64_ty_ptr); + + analysis::Array arr( + u64_ty, analysis::Array::LengthInfo{ + cst_mgr->GetUIntConstId((uint32_t)start_counters.size()), + {0, (uint32_t)start_counters.size()}}); + auto u64_arr_ty = type_mgr->GetRegisteredType(&arr); + analysis::Pointer u64_arr_ty_ptr(u64_arr_ty, spv::StorageClass::Private); + auto u64_arr_ptr_ty = type_mgr->GetRegisteredType(&u64_arr_ty_ptr); + analysis::Pointer u64_ty_ptr_private(u64_ty, spv::StorageClass::Private); + auto u64_private_ptr_ty = type_mgr->GetRegisteredType(&u64_ty_ptr_private); + + auto local_counters_ty_id = type_mgr->GetId(u64_arr_ptr_ty); + auto counters_ty_id = type_mgr->GetId(u64_run_arr_st_ptr_ty); + auto u64_ty_id = type_mgr->GetId(u64_ty); + auto u64_ptr_ty_id = type_mgr->GetId(u64_ptr_ty); + auto u64_arr_ty_id = type_mgr->GetId(u64_run_arr_ty); + auto u64_arr_st_ty_id = type_mgr->GetId(u64_run_arr_st_ty); + auto u64_private_ptr_ty_id = type_mgr->GetId(u64_private_ptr_ty); + + auto subgroup_scope_id = + cst_mgr->GetUIntConstId((uint32_t)spv::Scope::Subgroup); + auto device_scope_id = cst_mgr->GetUIntConstId((uint32_t)spv::Scope::Device); + auto acq_rel_mem_sem_id = cst_mgr->GetUIntConstId( + (uint32_t)spv::MemorySemanticsMask::AcquireRelease); + + uint32_t global_counters_id; + uint32_t local_counters_id; + CreateVariables(u64_arr_ty_id, u64_arr_st_ty_id, local_counters_ty_id, + counters_ty_id, global_counters_ds, global_counters_binding, + global_counters_id, local_counters_id); + + bool found = false; + for (auto& entry_point_inst : module->entry_points()) { + auto function_name = entry_point_inst.GetOperand(2).AsString(); + if (function_name != std::string(config_->entryPoint)) { + continue; + } + found = true; + + uint32_t read_clock_id; + Function* function; + CreatePrologue(&entry_point_inst, u64_private_ptr_ty_id, u64_ty_id, + subgroup_scope_id, global_counters_id, local_counters_id, + start_counters, function, read_clock_id); + + function->ForEachInst([this, read_clock_id, u64_ty_id, u64_ptr_ty_id, + u64_private_ptr_ty_id, subgroup_scope_id, + device_scope_id, acq_rel_mem_sem_id, + global_counters_id, local_counters_id, + &start_counters](Instruction* inst) { + if (inst->opcode() != spv::Op::OpReturn) { + return; + } + CreateEpilogue(inst, read_clock_id, u64_ty_id, u64_ptr_ty_id, + u64_private_ptr_ty_id, subgroup_scope_id, device_scope_id, + acq_rel_mem_sem_id, global_counters_id, local_counters_id, + start_counters); + }); + + break; + } + if (!found) { + return Status::Failure; + } + + CreateCounters(u64_ty_id, u64_private_ptr_ty_id, subgroup_scope_id, + start_counters, stop_counters, local_counters_id); + + return Status::SuccessWithChange; +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/vksp_passes.h b/source/opt/vksp_passes.h new file mode 100644 index 00000000000..cd9fdfd376b --- /dev/null +++ b/source/opt/vksp_passes.h @@ -0,0 +1,101 @@ +// Copyright (c) 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_VKSP_PASSES_H_ +#define SOURCE_OPT_VKSP_PASSES_H_ + +#include "source/opt/module.h" +#include "source/opt/pass.h" +#include "spirv-tools/optimizer.hpp" + +namespace spvtools { +namespace opt { + +class InsertVkspReflectInfoPass : public Pass { + public: + InsertVkspReflectInfoPass(std::vector* pc, + std::vector* ds, + std::vector* me, + vksp_configuration* config) + : pc_(pc), ds_(ds), me_(me), config_(config) {} + const char* name() const override { return "insert-vksp-reflect-info"; } + Status Process() override; + + private: + std::vector* pc_; + std::vector* ds_; + std::vector* me_; + vksp_configuration* config_; +}; + +class ExtractVkspReflectInfoPass : public Pass { + public: + ExtractVkspReflectInfoPass(std::vector* pc, + std::vector* ds, + std::vector* me, + std::vector* counters, + vksp_configuration* config) + : pc_(pc), ds_(ds), me_(me), counters_(counters), config_(config) {} + const char* name() const override { return "extract-vksp-reflect-info"; } + Status Process() override; + + private: + int32_t UpdateMaxBinding(uint32_t ds, uint32_t binding, int32_t max_binding); + + void ParseInstruction(Instruction* inst, uint32_t ext_inst_id, + std::map& id_to_descriptor_set, + std::map& id_to_binding, + int32_t& descriptor_set_0_max_binding, + std::vector& start_counters, + std::vector& stop_counters); + + void CreateVariables(uint32_t u64_arr_ty_id, uint32_t u64_arr_st_ty_id, + uint32_t local_counters_ty_id, uint32_t counters_ty_id, + uint32_t global_counters_ds, + uint32_t global_counters_binding, + uint32_t& global_counters_id, + uint32_t& local_counters_id); + + void CreatePrologue(Instruction* entry_point_inst, + uint32_t u64_private_ptr_ty_id, uint32_t u64_ty_id, + uint32_t subgroup_scope_id, uint32_t global_counters_id, + uint32_t local_counters_id, + std::vector& start_counters, + Function*& function, uint32_t& read_clock_id); + + void CreateEpilogue(Instruction* inst, uint32_t read_clock_id, + uint32_t u64_ty_id, uint32_t u64_ptr_ty_id, + uint32_t u64_private_ptr_ty_id, + uint32_t subgroup_scope_id, uint32_t device_scope_id, + uint32_t acq_rel_mem_sem_id, uint32_t global_counters_id, + uint32_t local_counters_id, + std::vector& start_counters); + + void CreateCounters(uint32_t u64_ty_id, uint32_t u64_private_ptr_ty_id, + uint32_t subgroup_scope_id, + std::vector& start_counters, + std::vector& stop_counters, + uint32_t local_counters_id); + + std::vector* pc_; + std::vector* ds_; + std::vector* me_; + std::vector* counters_; + vksp_configuration* config_; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_VKSP_REFLECT_INFO_PASS_H_ diff --git a/source/table.h b/source/table.h index 8097f13f776..f1e971bf83b 100644 --- a/source/table.h +++ b/source/table.h @@ -74,7 +74,7 @@ typedef struct spv_ext_inst_desc_t { const uint32_t ext_inst; const uint32_t numCapabilities; const spv::Capability* capabilities; - const spv_operand_type_t operandTypes[16]; // TODO: Smaller/larger? + const spv_operand_type_t operandTypes[64]; // TODO: Smaller/larger? } spv_ext_inst_desc_t; typedef struct spv_ext_inst_group_t {