-
Notifications
You must be signed in to change notification settings - Fork 537
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
spirv-val: Validate outter descriptor array index #5577
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1276,14 +1276,13 @@ spv_result_t ValidateCopyMemory(ValidationState_t& _, const Instruction* inst) { | |
|
||
spv_result_t ValidateAccessChain(ValidationState_t& _, | ||
const Instruction* inst) { | ||
std::string instr_name = | ||
"Op" + std::string(spvOpcodeString(static_cast<spv::Op>(inst->opcode()))); | ||
const spv::Op opcode = inst->opcode(); | ||
|
||
// The result type must be OpTypePointer. | ||
auto result_type = _.FindDef(inst->type_id()); | ||
if (spv::Op::OpTypePointer != result_type->opcode()) { | ||
return _.diag(SPV_ERROR_INVALID_ID, inst) | ||
<< "The Result Type of " << instr_name << " <id> " | ||
<< "The Result Type of Op" << spvOpcodeString(opcode) << " <id> " | ||
<< _.getIdName(inst->id()) << " must be OpTypePointer. Found Op" | ||
<< spvOpcodeString(static_cast<spv::Op>(result_type->opcode())) | ||
<< "."; | ||
|
@@ -1301,8 +1300,8 @@ spv_result_t ValidateAccessChain(ValidationState_t& _, | |
const auto base_type = _.FindDef(base->type_id()); | ||
if (!base_type || spv::Op::OpTypePointer != base_type->opcode()) { | ||
return _.diag(SPV_ERROR_INVALID_ID, inst) | ||
<< "The Base <id> " << _.getIdName(base_id) << " in " << instr_name | ||
<< " instruction must be a pointer."; | ||
<< "The Base <id> " << _.getIdName(base_id) << " in Op" | ||
<< spvOpcodeString(opcode) << " instruction must be a pointer."; | ||
} | ||
|
||
// The result pointer storage class and base pointer storage class must match. | ||
|
@@ -1312,8 +1311,8 @@ spv_result_t ValidateAccessChain(ValidationState_t& _, | |
if (result_type_storage_class != base_type_storage_class) { | ||
return _.diag(SPV_ERROR_INVALID_ID, inst) | ||
<< "The result pointer storage class and base " | ||
"pointer storage class in " | ||
<< instr_name << " do not match."; | ||
"pointer storage class in Op" | ||
<< spvOpcodeString(opcode) << " do not match."; | ||
} | ||
|
||
// The type pointed to by OpTypePointer (word 3) must be a composite type. | ||
|
@@ -1323,8 +1322,8 @@ spv_result_t ValidateAccessChain(ValidationState_t& _, | |
// The number of indexes passed to OpAccessChain may not exceed 255 | ||
// The instruction includes 4 words + N words (for N indexes) | ||
size_t num_indexes = inst->words().size() - 4; | ||
if (inst->opcode() == spv::Op::OpPtrAccessChain || | ||
inst->opcode() == spv::Op::OpInBoundsPtrAccessChain) { | ||
if (opcode == spv::Op::OpPtrAccessChain || | ||
opcode == spv::Op::OpInBoundsPtrAccessChain) { | ||
// In pointer access chains, the element operand is required, but not | ||
// counted as an index. | ||
--num_indexes; | ||
|
@@ -1333,8 +1332,9 @@ spv_result_t ValidateAccessChain(ValidationState_t& _, | |
_.options()->universal_limits_.max_access_chain_indexes; | ||
if (num_indexes > num_indexes_limit) { | ||
return _.diag(SPV_ERROR_INVALID_ID, inst) | ||
<< "The number of indexes in " << instr_name << " may not exceed " | ||
<< num_indexes_limit << ". Found " << num_indexes << " indexes."; | ||
<< "The number of indexes in Op" << spvOpcodeString(opcode) | ||
<< " may not exceed " << num_indexes_limit << ". Found " | ||
<< num_indexes << " indexes."; | ||
} | ||
// Indexes walk the type hierarchy to the desired depth, potentially down to | ||
// scalar granularity. The first index in Indexes will select the top-level | ||
|
@@ -1344,8 +1344,8 @@ spv_result_t ValidateAccessChain(ValidationState_t& _, | |
// on. Once any non-composite type is reached, there must be no remaining | ||
// (unused) indexes. | ||
auto starting_index = 4; | ||
if (inst->opcode() == spv::Op::OpPtrAccessChain || | ||
inst->opcode() == spv::Op::OpInBoundsPtrAccessChain) { | ||
if (opcode == spv::Op::OpPtrAccessChain || | ||
opcode == spv::Op::OpInBoundsPtrAccessChain) { | ||
++starting_index; | ||
} | ||
for (size_t i = starting_index; i < inst->words().size(); ++i) { | ||
|
@@ -1356,27 +1356,46 @@ spv_result_t ValidateAccessChain(ValidationState_t& _, | |
auto index_type = _.FindDef(cur_word_instr->type_id()); | ||
if (!index_type || spv::Op::OpTypeInt != index_type->opcode()) { | ||
return _.diag(SPV_ERROR_INVALID_ID, inst) | ||
<< "Indexes passed to " << instr_name | ||
<< "Indexes passed to Op" << spvOpcodeString(opcode) | ||
<< " must be of type integer."; | ||
} | ||
switch (type_pointee->opcode()) { | ||
case spv::Op::OpTypeMatrix: | ||
case spv::Op::OpTypeVector: | ||
case spv::Op::OpTypeCooperativeMatrixNV: | ||
case spv::Op::OpTypeCooperativeMatrixKHR: | ||
case spv::Op::OpTypeArray: | ||
case spv::Op::OpTypeRuntimeArray: { | ||
// In OpTypeMatrix, OpTypeVector, spv::Op::OpTypeCooperativeMatrixNV, | ||
// OpTypeArray, and OpTypeRuntimeArray, word 2 is the Element Type. | ||
type_pointee = _.FindDef(type_pointee->word(2)); | ||
break; | ||
} | ||
case spv::Op::OpTypeArray: { | ||
// If using a descriptor array, make sure the outer array is not OOB | ||
if (spv::Op::OpConstant == cur_word_instr->opcode()) { | ||
const uint32_t cur_index = cur_word_instr->word(3); | ||
auto array_length_id = _.FindDef(type_pointee->word(3)); | ||
if (spv::Op::OpConstant == array_length_id->opcode()) { | ||
const uint32_t array_length = array_length_id->word(3); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. while the I realize from the failing |
||
if (cur_index >= array_length) { | ||
return _.diag(SPV_ERROR_INVALID_ID, cur_word_instr) | ||
<< "Index is out of bounds: Op" << spvOpcodeString(opcode) | ||
<< " can not find index " << cur_index | ||
<< " into the array <id> " | ||
<< _.getIdName(type_pointee->id()) << " of length " | ||
<< array_length << "."; | ||
} | ||
} | ||
} | ||
type_pointee = _.FindDef(type_pointee->word(2)); | ||
break; | ||
} | ||
case spv::Op::OpTypeStruct: { | ||
// In case of structures, there is an additional constraint on the | ||
// index: the index must be an OpConstant. | ||
if (spv::Op::OpConstant != cur_word_instr->opcode()) { | ||
return _.diag(SPV_ERROR_INVALID_ID, cur_word_instr) | ||
<< "The <id> passed to " << instr_name | ||
<< "The <id> passed to Op" << spvOpcodeString(opcode) | ||
<< " to index into a " | ||
"structure must be an OpConstant."; | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @alan-baker working more into this, I found the line below makes it seem like it is legal to have a So this means if I have a
These are valid? %int_n1 = OpConstant %int -1
%int_5 = OpConstant %int 5
%ac = OpAccessChain %ptr %var %int_0 %int_n1
// or
%ac = OpAccessChain %sb_ptr %var %int_0 %int_5 (I realize this code should be checking 64-bit, I am adding that) |
||
|
@@ -1392,7 +1411,7 @@ spv_result_t ValidateAccessChain(ValidationState_t& _, | |
static_cast<uint32_t>(type_pointee->words().size() - 2); | ||
if (cur_index >= num_struct_members) { | ||
return _.diag(SPV_ERROR_INVALID_ID, cur_word_instr) | ||
<< "Index is out of bounds: " << instr_name | ||
<< "Index is out of bounds: Op" << spvOpcodeString(opcode) | ||
<< " can not find index " << cur_index | ||
<< " into the structure <id> " | ||
<< _.getIdName(type_pointee->id()) << ". This structure has " | ||
|
@@ -1407,7 +1426,7 @@ spv_result_t ValidateAccessChain(ValidationState_t& _, | |
default: { | ||
// Give an error. reached non-composite type while indexes still remain. | ||
return _.diag(SPV_ERROR_INVALID_ID, inst) | ||
<< instr_name | ||
<< "Op" << spvOpcodeString(opcode) | ||
<< " reached non-composite type while indexes " | ||
"still remain to be traversed."; | ||
} | ||
|
@@ -1417,7 +1436,7 @@ spv_result_t ValidateAccessChain(ValidationState_t& _, | |
// The type being pointed to should be the same as the result type. | ||
if (type_pointee->id() != result_type_pointee->id()) { | ||
return _.diag(SPV_ERROR_INVALID_ID, inst) | ||
<< instr_name << " result type (Op" | ||
<< "Op" << spvOpcodeString(opcode) << " result type (Op" | ||
<< spvOpcodeString( | ||
static_cast<spv::Op>(result_type_pointee->opcode())) | ||
<< ") does not match the type that results from indexing into the " | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
moved string logic to error message time as assume most
OpAccessChain
are valid and don't need to spend time building this string