diff --git a/source/val/validate_memory.cpp b/source/val/validate_memory.cpp index 5b25eeb3c7..f5bdd076d4 100644 --- a/source/val/validate_memory.cpp +++ b/source/val/validate_memory.cpp @@ -1276,14 +1276,13 @@ spv_result_t ValidateCopyMemory(ValidationState_t& _, const Instruction* inst) { spv_result_t ValidateAccessChain(ValidationState_t& _, const Instruction* inst) { - std::string instr_name = - "Op" + std::string(spvOpcodeString(static_cast(inst->opcode()))); + const spv::Op opcode = inst->opcode(); // The result type must be OpTypePointer. auto result_type = _.FindDef(inst->type_id()); if (spv::Op::OpTypePointer != result_type->opcode()) { return _.diag(SPV_ERROR_INVALID_ID, inst) - << "The Result Type of " << instr_name << " " + << "The Result Type of Op" << spvOpcodeString(opcode) << " " << _.getIdName(inst->id()) << " must be OpTypePointer. Found Op" << spvOpcodeString(static_cast(result_type->opcode())) << "."; @@ -1301,8 +1300,8 @@ spv_result_t ValidateAccessChain(ValidationState_t& _, const auto base_type = _.FindDef(base->type_id()); if (!base_type || spv::Op::OpTypePointer != base_type->opcode()) { return _.diag(SPV_ERROR_INVALID_ID, inst) - << "The Base " << _.getIdName(base_id) << " in " << instr_name - << " instruction must be a pointer."; + << "The Base " << _.getIdName(base_id) << " in Op" + << spvOpcodeString(opcode) << " instruction must be a pointer."; } // The result pointer storage class and base pointer storage class must match. @@ -1312,8 +1311,8 @@ spv_result_t ValidateAccessChain(ValidationState_t& _, if (result_type_storage_class != base_type_storage_class) { return _.diag(SPV_ERROR_INVALID_ID, inst) << "The result pointer storage class and base " - "pointer storage class in " - << instr_name << " do not match."; + "pointer storage class in Op" + << spvOpcodeString(opcode) << " do not match."; } // The type pointed to by OpTypePointer (word 3) must be a composite type. @@ -1323,8 +1322,8 @@ spv_result_t ValidateAccessChain(ValidationState_t& _, // The number of indexes passed to OpAccessChain may not exceed 255 // The instruction includes 4 words + N words (for N indexes) size_t num_indexes = inst->words().size() - 4; - if (inst->opcode() == spv::Op::OpPtrAccessChain || - inst->opcode() == spv::Op::OpInBoundsPtrAccessChain) { + if (opcode == spv::Op::OpPtrAccessChain || + opcode == spv::Op::OpInBoundsPtrAccessChain) { // In pointer access chains, the element operand is required, but not // counted as an index. --num_indexes; @@ -1333,8 +1332,9 @@ spv_result_t ValidateAccessChain(ValidationState_t& _, _.options()->universal_limits_.max_access_chain_indexes; if (num_indexes > num_indexes_limit) { return _.diag(SPV_ERROR_INVALID_ID, inst) - << "The number of indexes in " << instr_name << " may not exceed " - << num_indexes_limit << ". Found " << num_indexes << " indexes."; + << "The number of indexes in Op" << spvOpcodeString(opcode) + << " may not exceed " << num_indexes_limit << ". Found " + << num_indexes << " indexes."; } // Indexes walk the type hierarchy to the desired depth, potentially down to // scalar granularity. The first index in Indexes will select the top-level @@ -1344,8 +1344,8 @@ spv_result_t ValidateAccessChain(ValidationState_t& _, // on. Once any non-composite type is reached, there must be no remaining // (unused) indexes. auto starting_index = 4; - if (inst->opcode() == spv::Op::OpPtrAccessChain || - inst->opcode() == spv::Op::OpInBoundsPtrAccessChain) { + if (opcode == spv::Op::OpPtrAccessChain || + opcode == spv::Op::OpInBoundsPtrAccessChain) { ++starting_index; } for (size_t i = starting_index; i < inst->words().size(); ++i) { @@ -1356,7 +1356,7 @@ spv_result_t ValidateAccessChain(ValidationState_t& _, auto index_type = _.FindDef(cur_word_instr->type_id()); if (!index_type || spv::Op::OpTypeInt != index_type->opcode()) { return _.diag(SPV_ERROR_INVALID_ID, inst) - << "Indexes passed to " << instr_name + << "Indexes passed to Op" << spvOpcodeString(opcode) << " must be of type integer."; } switch (type_pointee->opcode()) { @@ -1364,19 +1364,38 @@ spv_result_t ValidateAccessChain(ValidationState_t& _, case spv::Op::OpTypeVector: case spv::Op::OpTypeCooperativeMatrixNV: case spv::Op::OpTypeCooperativeMatrixKHR: - case spv::Op::OpTypeArray: case spv::Op::OpTypeRuntimeArray: { // In OpTypeMatrix, OpTypeVector, spv::Op::OpTypeCooperativeMatrixNV, // OpTypeArray, and OpTypeRuntimeArray, word 2 is the Element Type. type_pointee = _.FindDef(type_pointee->word(2)); break; } + case spv::Op::OpTypeArray: { + // If using a descriptor array, make sure the outer array is not OOB + if (spv::Op::OpConstant == cur_word_instr->opcode()) { + const uint32_t cur_index = cur_word_instr->word(3); + auto array_length_id = _.FindDef(type_pointee->word(3)); + if (spv::Op::OpConstant == array_length_id->opcode()) { + const uint32_t array_length = array_length_id->word(3); + if (cur_index >= array_length) { + return _.diag(SPV_ERROR_INVALID_ID, cur_word_instr) + << "Index is out of bounds: Op" << spvOpcodeString(opcode) + << " can not find index " << cur_index + << " into the array " + << _.getIdName(type_pointee->id()) << " of length " + << array_length << "."; + } + } + } + type_pointee = _.FindDef(type_pointee->word(2)); + break; + } case spv::Op::OpTypeStruct: { // In case of structures, there is an additional constraint on the // index: the index must be an OpConstant. if (spv::Op::OpConstant != cur_word_instr->opcode()) { return _.diag(SPV_ERROR_INVALID_ID, cur_word_instr) - << "The passed to " << instr_name + << "The passed to Op" << spvOpcodeString(opcode) << " to index into a " "structure must be an OpConstant."; } @@ -1392,7 +1411,7 @@ spv_result_t ValidateAccessChain(ValidationState_t& _, static_cast(type_pointee->words().size() - 2); if (cur_index >= num_struct_members) { return _.diag(SPV_ERROR_INVALID_ID, cur_word_instr) - << "Index is out of bounds: " << instr_name + << "Index is out of bounds: Op" << spvOpcodeString(opcode) << " can not find index " << cur_index << " into the structure " << _.getIdName(type_pointee->id()) << ". This structure has " @@ -1407,7 +1426,7 @@ spv_result_t ValidateAccessChain(ValidationState_t& _, default: { // Give an error. reached non-composite type while indexes still remain. return _.diag(SPV_ERROR_INVALID_ID, inst) - << instr_name + << "Op" << spvOpcodeString(opcode) << " reached non-composite type while indexes " "still remain to be traversed."; } @@ -1417,7 +1436,7 @@ spv_result_t ValidateAccessChain(ValidationState_t& _, // The type being pointed to should be the same as the result type. if (type_pointee->id() != result_type_pointee->id()) { return _.diag(SPV_ERROR_INVALID_ID, inst) - << instr_name << " result type (Op" + << "Op" << spvOpcodeString(opcode) << " result type (Op" << spvOpcodeString( static_cast(result_type_pointee->opcode())) << ") does not match the type that results from indexing into the " diff --git a/test/val/val_memory_test.cpp b/test/val/val_memory_test.cpp index 8d0a94d2b0..bae0d43f53 100644 --- a/test/val/val_memory_test.cpp +++ b/test/val/val_memory_test.cpp @@ -5115,6 +5115,79 @@ TEST_F(ValidateMemory, VulkanPtrAccessChainWorkgroupNoArrayStrideSuccess) { EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2)); } +TEST_F(ValidateMemory, StorageBufferArrayBadIndexVulkan) { + const std::string spirv = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" %var +OpExecutionMode %main LocalSize 1 1 1 +OpMemberDecorate %storage_buffer_b 0 Offset 0 +OpDecorate %storage_buffer_b Block +OpDecorate %var DescriptorSet 0 +OpDecorate %var Binding 1 +%void = OpTypeVoid +%func = OpTypeFunction %void +%int = OpTypeInt 32 1 +%storage_buffer_b = OpTypeStruct %int +%uint = OpTypeInt 32 0 +%uint_3 = OpConstant %uint 3 +%array = OpTypeArray %storage_buffer_b %uint_3 +%var_ptr = OpTypePointer StorageBuffer %array +%var = OpVariable %var_ptr StorageBuffer +%int_3 = OpConstant %int 3 +%int_0 = OpConstant %int 0 +%sb_ptr = OpTypePointer StorageBuffer %int +%main = OpFunction %void None %func +%label = OpLabel +%ac = OpAccessChain %sb_ptr %var %int_3 %int_0 +OpStore %ac %int_3 +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_2); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_2)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Index is out of bounds: OpAccessChain can not find index 3 " + "into the array '9[%_arr__struct_3_uint_3]' of length 3")); +} + +TEST_F(ValidateMemory, SamplerArrayBadIndexVulkan) { + const std::string spirv = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" %sampler_array +OpExecutionMode %main LocalSize 1 1 1 +OpDecorate %sampler_array DescriptorSet 2 +OpDecorate %sampler_array Binding 1 +%void = OpTypeVoid +%func = OpTypeFunction %void +%sampler = OpTypeSampler +%uint = OpTypeInt 32 0 +%uint_2 = OpConstant %uint 2 +%array = OpTypeArray %sampler %uint_2 +%var_ptr = OpTypePointer UniformConstant %array +%sampler_array = OpVariable %var_ptr UniformConstant +%int = OpTypeInt 32 1 +%int_2 = OpConstant %int 2 +%uc_ptr = OpTypePointer UniformConstant %sampler +%main = OpFunction %void None %func +%label = OpLabel +%ac = OpAccessChain %uc_ptr %sampler_array %int_2 +%load = OpLoad %sampler %ac +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_2); + EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_2)); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("Index is out of bounds: OpAccessChain can not find index 2 " + "into the array '8[%_arr_5_uint_2]' of length 2")); +} + } // namespace } // namespace val } // namespace spvtools