Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

spirv-val: Validate outter descriptor array index #5577

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
57 changes: 38 additions & 19 deletions source/val/validate_memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1276,14 +1276,13 @@ spv_result_t ValidateCopyMemory(ValidationState_t& _, const Instruction* inst) {

spv_result_t ValidateAccessChain(ValidationState_t& _,
const Instruction* inst) {
std::string instr_name =
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

moved string logic to error message time as assume most OpAccessChain are valid and don't need to spend time building this string

"Op" + std::string(spvOpcodeString(static_cast<spv::Op>(inst->opcode())));
const spv::Op opcode = inst->opcode();

// The result type must be OpTypePointer.
auto result_type = _.FindDef(inst->type_id());
if (spv::Op::OpTypePointer != result_type->opcode()) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "The Result Type of " << instr_name << " <id> "
<< "The Result Type of Op" << spvOpcodeString(opcode) << " <id> "
<< _.getIdName(inst->id()) << " must be OpTypePointer. Found Op"
<< spvOpcodeString(static_cast<spv::Op>(result_type->opcode()))
<< ".";
Expand All @@ -1301,8 +1300,8 @@ spv_result_t ValidateAccessChain(ValidationState_t& _,
const auto base_type = _.FindDef(base->type_id());
if (!base_type || spv::Op::OpTypePointer != base_type->opcode()) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "The Base <id> " << _.getIdName(base_id) << " in " << instr_name
<< " instruction must be a pointer.";
<< "The Base <id> " << _.getIdName(base_id) << " in Op"
<< spvOpcodeString(opcode) << " instruction must be a pointer.";
}

// The result pointer storage class and base pointer storage class must match.
Expand All @@ -1312,8 +1311,8 @@ spv_result_t ValidateAccessChain(ValidationState_t& _,
if (result_type_storage_class != base_type_storage_class) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "The result pointer storage class and base "
"pointer storage class in "
<< instr_name << " do not match.";
"pointer storage class in Op"
<< spvOpcodeString(opcode) << " do not match.";
}

// The type pointed to by OpTypePointer (word 3) must be a composite type.
Expand All @@ -1323,8 +1322,8 @@ spv_result_t ValidateAccessChain(ValidationState_t& _,
// The number of indexes passed to OpAccessChain may not exceed 255
// The instruction includes 4 words + N words (for N indexes)
size_t num_indexes = inst->words().size() - 4;
if (inst->opcode() == spv::Op::OpPtrAccessChain ||
inst->opcode() == spv::Op::OpInBoundsPtrAccessChain) {
if (opcode == spv::Op::OpPtrAccessChain ||
opcode == spv::Op::OpInBoundsPtrAccessChain) {
// In pointer access chains, the element operand is required, but not
// counted as an index.
--num_indexes;
Expand All @@ -1333,8 +1332,9 @@ spv_result_t ValidateAccessChain(ValidationState_t& _,
_.options()->universal_limits_.max_access_chain_indexes;
if (num_indexes > num_indexes_limit) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "The number of indexes in " << instr_name << " may not exceed "
<< num_indexes_limit << ". Found " << num_indexes << " indexes.";
<< "The number of indexes in Op" << spvOpcodeString(opcode)
<< " may not exceed " << num_indexes_limit << ". Found "
<< num_indexes << " indexes.";
}
// Indexes walk the type hierarchy to the desired depth, potentially down to
// scalar granularity. The first index in Indexes will select the top-level
Expand All @@ -1344,8 +1344,8 @@ spv_result_t ValidateAccessChain(ValidationState_t& _,
// on. Once any non-composite type is reached, there must be no remaining
// (unused) indexes.
auto starting_index = 4;
if (inst->opcode() == spv::Op::OpPtrAccessChain ||
inst->opcode() == spv::Op::OpInBoundsPtrAccessChain) {
if (opcode == spv::Op::OpPtrAccessChain ||
opcode == spv::Op::OpInBoundsPtrAccessChain) {
++starting_index;
}
for (size_t i = starting_index; i < inst->words().size(); ++i) {
Expand All @@ -1356,27 +1356,46 @@ spv_result_t ValidateAccessChain(ValidationState_t& _,
auto index_type = _.FindDef(cur_word_instr->type_id());
if (!index_type || spv::Op::OpTypeInt != index_type->opcode()) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "Indexes passed to " << instr_name
<< "Indexes passed to Op" << spvOpcodeString(opcode)
<< " must be of type integer.";
}
switch (type_pointee->opcode()) {
case spv::Op::OpTypeMatrix:
case spv::Op::OpTypeVector:
case spv::Op::OpTypeCooperativeMatrixNV:
case spv::Op::OpTypeCooperativeMatrixKHR:
case spv::Op::OpTypeArray:
case spv::Op::OpTypeRuntimeArray: {
// In OpTypeMatrix, OpTypeVector, spv::Op::OpTypeCooperativeMatrixNV,
// OpTypeArray, and OpTypeRuntimeArray, word 2 is the Element Type.
type_pointee = _.FindDef(type_pointee->word(2));
break;
}
case spv::Op::OpTypeArray: {
// If using a descriptor array, make sure the outer array is not OOB
if (spv::Op::OpConstant == cur_word_instr->opcode()) {
const uint32_t cur_index = cur_word_instr->word(3);
auto array_length_id = _.FindDef(type_pointee->word(3));
if (spv::Op::OpConstant == array_length_id->opcode()) {
const uint32_t array_length = array_length_id->word(3);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

while the OpAccessChain indexes are required to be 32-bit int scalar but can be negative

I realize from the failing spirv-opt tests is the OpTypeArray length operand can be a 64-bit length

if (cur_index >= array_length) {
return _.diag(SPV_ERROR_INVALID_ID, cur_word_instr)
<< "Index is out of bounds: Op" << spvOpcodeString(opcode)
<< " can not find index " << cur_index
<< " into the array <id> "
<< _.getIdName(type_pointee->id()) << " of length "
<< array_length << ".";
}
}
}
type_pointee = _.FindDef(type_pointee->word(2));
break;
}
case spv::Op::OpTypeStruct: {
// In case of structures, there is an additional constraint on the
// index: the index must be an OpConstant.
if (spv::Op::OpConstant != cur_word_instr->opcode()) {
return _.diag(SPV_ERROR_INVALID_ID, cur_word_instr)
<< "The <id> passed to " << instr_name
<< "The <id> passed to Op" << spvOpcodeString(opcode)
<< " to index into a "
"structure must be an OpConstant.";
}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@alan-baker working more into this, I found the line below makes it seem like it is legal to have a -1 access chain index?

So this means if I have a

layout(set=0, binding=0) buffer foo { int x[3]; } var;

These are valid?

%int_n1 = OpConstant %int -1
%int_5 = OpConstant %int 5

%ac = OpAccessChain %ptr %var %int_0 %int_n1
// or
%ac = OpAccessChain %sb_ptr %var %int_0 %int_5

(I realize this code should be checking 64-bit, I am adding that)

Expand All @@ -1392,7 +1411,7 @@ spv_result_t ValidateAccessChain(ValidationState_t& _,
static_cast<uint32_t>(type_pointee->words().size() - 2);
if (cur_index >= num_struct_members) {
return _.diag(SPV_ERROR_INVALID_ID, cur_word_instr)
<< "Index is out of bounds: " << instr_name
<< "Index is out of bounds: Op" << spvOpcodeString(opcode)
<< " can not find index " << cur_index
<< " into the structure <id> "
<< _.getIdName(type_pointee->id()) << ". This structure has "
Expand All @@ -1407,7 +1426,7 @@ spv_result_t ValidateAccessChain(ValidationState_t& _,
default: {
// Give an error. reached non-composite type while indexes still remain.
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< instr_name
<< "Op" << spvOpcodeString(opcode)
<< " reached non-composite type while indexes "
"still remain to be traversed.";
}
Expand All @@ -1417,7 +1436,7 @@ spv_result_t ValidateAccessChain(ValidationState_t& _,
// The type being pointed to should be the same as the result type.
if (type_pointee->id() != result_type_pointee->id()) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< instr_name << " result type (Op"
<< "Op" << spvOpcodeString(opcode) << " result type (Op"
<< spvOpcodeString(
static_cast<spv::Op>(result_type_pointee->opcode()))
<< ") does not match the type that results from indexing into the "
Expand Down
73 changes: 73 additions & 0 deletions test/val/val_memory_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5115,6 +5115,79 @@ TEST_F(ValidateMemory, VulkanPtrAccessChainWorkgroupNoArrayStrideSuccess) {
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions(SPV_ENV_VULKAN_1_2));
}

TEST_F(ValidateMemory, StorageBufferArrayBadIndexVulkan) {
const std::string spirv = R"(
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main" %var
OpExecutionMode %main LocalSize 1 1 1
OpMemberDecorate %storage_buffer_b 0 Offset 0
OpDecorate %storage_buffer_b Block
OpDecorate %var DescriptorSet 0
OpDecorate %var Binding 1
%void = OpTypeVoid
%func = OpTypeFunction %void
%int = OpTypeInt 32 1
%storage_buffer_b = OpTypeStruct %int
%uint = OpTypeInt 32 0
%uint_3 = OpConstant %uint 3
%array = OpTypeArray %storage_buffer_b %uint_3
%var_ptr = OpTypePointer StorageBuffer %array
%var = OpVariable %var_ptr StorageBuffer
%int_3 = OpConstant %int 3
%int_0 = OpConstant %int 0
%sb_ptr = OpTypePointer StorageBuffer %int
%main = OpFunction %void None %func
%label = OpLabel
%ac = OpAccessChain %sb_ptr %var %int_3 %int_0
OpStore %ac %int_3
OpReturn
OpFunctionEnd
)";

CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_2);
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_2));
EXPECT_THAT(
getDiagnosticString(),
HasSubstr("Index is out of bounds: OpAccessChain can not find index 3 "
"into the array <id> '9[%_arr__struct_3_uint_3]' of length 3"));
}

TEST_F(ValidateMemory, SamplerArrayBadIndexVulkan) {
const std::string spirv = R"(
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main" %sampler_array
OpExecutionMode %main LocalSize 1 1 1
OpDecorate %sampler_array DescriptorSet 2
OpDecorate %sampler_array Binding 1
%void = OpTypeVoid
%func = OpTypeFunction %void
%sampler = OpTypeSampler
%uint = OpTypeInt 32 0
%uint_2 = OpConstant %uint 2
%array = OpTypeArray %sampler %uint_2
%var_ptr = OpTypePointer UniformConstant %array
%sampler_array = OpVariable %var_ptr UniformConstant
%int = OpTypeInt 32 1
%int_2 = OpConstant %int 2
%uc_ptr = OpTypePointer UniformConstant %sampler
%main = OpFunction %void None %func
%label = OpLabel
%ac = OpAccessChain %uc_ptr %sampler_array %int_2
%load = OpLoad %sampler %ac
OpReturn
OpFunctionEnd
)";

CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_2);
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_2));
EXPECT_THAT(
getDiagnosticString(),
HasSubstr("Index is out of bounds: OpAccessChain can not find index 2 "
"into the array <id> '8[%_arr_5_uint_2]' of length 2"));
}

} // namespace
} // namespace val
} // namespace spvtools