Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Access Flags for DescriptorBinding #224

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions common/output_stream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,28 @@ std::string ToStringDecorationFlags(
return sstream.str();
}

std::string ToStringAccessFlags(SpvReflectAccessFlags access_flags) {
if (access_flags == SPV_REFLECT_ACCESS_NONE) {
return "NONE";
}

#define PRINT_AND_CLEAR_ACCESS_FLAG(stream, flags, bit) \
if (((flags) & (SPV_REFLECT_ACCESS_##bit)) == (SPV_REFLECT_ACCESS_##bit)) { \
stream << #bit << " "; \
flags ^= SPV_REFLECT_ACCESS_##bit; \
}
std::stringstream sstream;
PRINT_AND_CLEAR_ACCESS_FLAG(sstream, access_flags, READ);
PRINT_AND_CLEAR_ACCESS_FLAG(sstream, access_flags, WRITE);
PRINT_AND_CLEAR_ACCESS_FLAG(sstream, access_flags, ATOMIC);
#undef PRINT_AND_CLEAR_ACCESS_FLAG
if (access_flags != 0) {
// Unhandled SpvReflectAccessFlags bit
sstream << "???";
}
return sstream.str();
}

std::string ToStringFormat(SpvReflectFormat fmt) {
switch (fmt) {
case SPV_REFLECT_FORMAT_UNDEFINED:
Expand Down Expand Up @@ -1921,6 +1943,9 @@ void SpvReflectToYaml::WriteDescriptorBinding(

// uint32_t accessed;
os << t1 << "accessed: " << db.accessed << std::endl;
// SpvReflectAccessFlags access_flags;
os << t1 << "access_flags: " << AsHexString(db.access_flags) << " # "
<< ToStringAccessFlags(db.access_flags) << std::endl;

// uint32_t uav_counter_id;
os << t1 << "uav_counter_id: " << db.uav_counter_id << std::endl;
Expand Down
177 changes: 165 additions & 12 deletions spirv_reflect.c
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,22 @@ typedef struct SpvReflectPrvString {
} SpvReflectPrvString;
// clang-format on

// clang-format off
// There are a limit set of instructions that can touch an OpVariable,
// these are represented here with how it was accessed
// Examples:
// OpImageRead -> OpLoad -> OpVariable
// OpImageWrite -> OpLoad -> OpVariable
// OpStore -> OpAccessChain -> OpVariable
// OpAtomicIAdd -> OpAccessChain -> OpVariable
// OpAtomicLoad -> OpImageTexelPointer -> OpVariable
typedef struct SpvReflectPrvAccessedVariable {
uint32_t result_id;
uint32_t variable_ptr;
SpvReflectAccessFlags access_flags;
} SpvReflectPrvAccessedVariable;
// clang-format on

// clang-format off
typedef struct SpvReflectPrvFunction {
uint32_t id;
Expand All @@ -185,6 +201,8 @@ typedef struct SpvReflectPrvFunction {
struct SpvReflectPrvFunction** callee_ptrs;
uint32_t accessed_ptr_count;
uint32_t* accessed_ptrs;
uint32_t accessed_variable_ptr_count;
SpvReflectPrvAccessedVariable* accessed_variable_ptrs;
} SpvReflectPrvFunction;
// clang-format on

Expand Down Expand Up @@ -691,6 +709,7 @@ static void DestroyParser(SpvReflectPrvParser* p_parser)
SafeFree(p_parser->functions[i].callees);
SafeFree(p_parser->functions[i].callee_ptrs);
SafeFree(p_parser->functions[i].accessed_ptrs);
SafeFree(p_parser->functions[i].accessed_variable_ptrs);
}

// Free access chains
Expand Down Expand Up @@ -1166,6 +1185,7 @@ static SpvReflectResult ParseFunction(
p_func->callee_count = 0;
p_func->accessed_ptr_count = 0;

// First get count to know how much to allocate
for (size_t i = first_label_index; i < p_parser->node_count; ++i) {
SpvReflectPrvNode* p_node = &(p_parser->nodes[i]);
if (p_node->op == SpvOpFunctionEnd) {
Expand Down Expand Up @@ -1213,10 +1233,17 @@ static SpvReflectResult ParseFunction(
if (IsNull(p_func->accessed_ptrs)) {
return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED;
}

p_func->accessed_variable_ptrs = (SpvReflectPrvAccessedVariable*)calloc(
p_func->accessed_ptr_count, sizeof(*(p_func->accessed_variable_ptrs)));
if (IsNull(p_func->accessed_variable_ptrs)) {
return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED;
}
}

p_func->callee_count = 0;
p_func->accessed_ptr_count = 0;
// Now have allocation, fill in values
for (size_t i = first_label_index; i < p_parser->node_count; ++i) {
SpvReflectPrvNode* p_node = &(p_parser->nodes[i]);
if (p_node->op == SpvOpFunctionEnd) {
Expand All @@ -1227,26 +1254,74 @@ static SpvReflectResult ParseFunction(
CHECKED_READU32(p_parser, p_node->word_offset + 3,
p_func->callees[p_func->callee_count]);
(++p_func->callee_count);
}
break;
case SpvOpLoad:
} break;
case SpvOpAccessChain:
case SpvOpInBoundsAccessChain:
case SpvOpPtrAccessChain:
case SpvOpArrayLength:
case SpvOpGenericPtrMemSemantics:
case SpvOpInBoundsPtrAccessChain:
case SpvOpInBoundsPtrAccessChain: {
CHECKED_READU32(p_parser, p_node->word_offset + 3,
spencer-lunarg marked this conversation as resolved.
Show resolved Hide resolved
p_func->accessed_ptrs[p_func->accessed_ptr_count]);

CHECKED_READU32(
p_parser, p_node->word_offset + 3,
p_func->accessed_variable_ptrs[p_func->accessed_ptr_count]
.variable_ptr);
spencer-lunarg marked this conversation as resolved.
Show resolved Hide resolved
// Need to track Result ID as not sure there has been any memory access
// through here yet
CHECKED_READU32(
p_parser, p_node->word_offset + 2,
p_func->accessed_variable_ptrs[p_func->accessed_ptr_count]
.result_id);
(++p_func->accessed_ptr_count);
} break;
case SpvOpLoad: {
CHECKED_READU32(p_parser, p_node->word_offset + 3,
p_func->accessed_ptrs[p_func->accessed_ptr_count]);

CHECKED_READU32(
p_parser, p_node->word_offset + 3,
p_func->accessed_variable_ptrs[p_func->accessed_ptr_count]
.variable_ptr);
CHECKED_READU32(
p_parser, p_node->word_offset + 2,
p_func->accessed_variable_ptrs[p_func->accessed_ptr_count]
.result_id);
p_func->accessed_variable_ptrs[p_func->accessed_ptr_count]
.access_flags = SPV_REFLECT_ACCESS_READ;
(++p_func->accessed_ptr_count);
} break;
case SpvOpImageTexelPointer:
{
CHECKED_READU32(p_parser, p_node->word_offset + 3,
p_func->accessed_ptrs[p_func->accessed_ptr_count]);

CHECKED_READU32(
p_parser, p_node->word_offset + 3,
p_func->accessed_variable_ptrs[p_func->accessed_ptr_count]
.variable_ptr);
CHECKED_READU32(
p_parser, p_node->word_offset + 2,
p_func->accessed_variable_ptrs[p_func->accessed_ptr_count]
.result_id);
p_func->accessed_variable_ptrs[p_func->accessed_ptr_count]
.access_flags = SPV_REFLECT_ACCESS_ATOMIC |
SPV_REFLECT_ACCESS_READ | SPV_REFLECT_ACCESS_WRITE;
(++p_func->accessed_ptr_count);
}
break;
case SpvOpStore:
{
CHECKED_READU32(p_parser, p_node->word_offset + 2,
p_func->accessed_ptrs[p_func->accessed_ptr_count]);

CHECKED_READU32(
p_parser, p_node->word_offset + 2,
p_func->accessed_variable_ptrs[p_func->accessed_ptr_count]
.variable_ptr);
p_func->accessed_variable_ptrs[p_func->accessed_ptr_count]
.access_flags = SPV_REFLECT_ACCESS_WRITE;
(++p_func->accessed_ptr_count);
}
break;
Expand All @@ -1265,6 +1340,66 @@ static SpvReflectResult ParseFunction(
}
}

// Apply the SpvReflectAccessFlags to all things touching an OpVariable
for (size_t i = first_label_index; i < p_parser->node_count; ++i) {
SpvReflectPrvNode* p_node = &(p_parser->nodes[i]);
if (p_node->op == SpvOpFunctionEnd) {
break;
}
// These are memory accesses instruction
uint32_t memory_access_ptr = 0;
SpvReflectAccessFlags memory_access_type = SPV_REFLECT_ACCESS_NONE;
switch (p_node->op) {
case SpvOpLoad: {
CHECKED_READU32(p_parser, p_node->word_offset + 3, memory_access_ptr);
memory_access_type = SPV_REFLECT_ACCESS_READ;
} break;
case SpvOpImageWrite:
case SpvOpStore: {
CHECKED_READU32(p_parser, p_node->word_offset + 1, memory_access_ptr);
memory_access_type = SPV_REFLECT_ACCESS_WRITE;
} break;
case SpvOpImageTexelPointer:
case SpvOpAtomicLoad:
case SpvOpAtomicExchange:
case SpvOpAtomicCompareExchange:
case SpvOpAtomicIIncrement:
case SpvOpAtomicIDecrement:
case SpvOpAtomicIAdd:
case SpvOpAtomicISub:
case SpvOpAtomicSMin:
case SpvOpAtomicUMin:
case SpvOpAtomicSMax:
case SpvOpAtomicUMax:
case SpvOpAtomicAnd:
case SpvOpAtomicOr:
case SpvOpAtomicXor:
case SpvOpAtomicFMinEXT:
case SpvOpAtomicFMaxEXT:
case SpvOpAtomicFAddEXT: {
CHECKED_READU32(p_parser, p_node->word_offset + 3, memory_access_ptr);
memory_access_type = SPV_REFLECT_ACCESS_ATOMIC |
SPV_REFLECT_ACCESS_READ | SPV_REFLECT_ACCESS_WRITE;
} break;
case SpvOpAtomicStore: {
CHECKED_READU32(p_parser, p_node->word_offset + 1, memory_access_ptr);
memory_access_type = SPV_REFLECT_ACCESS_ATOMIC |
SPV_REFLECT_ACCESS_READ | SPV_REFLECT_ACCESS_WRITE;
} break;
default:
break;
}

if (memory_access_ptr == 0) {
continue;
}
for (uint32_t k = 0; k < p_func->accessed_ptr_count; k++) {
if (p_func->accessed_variable_ptrs[k].result_id == memory_access_ptr) {
p_func->accessed_variable_ptrs[k].access_flags |= memory_access_type;
}
}
}

if (p_func->callee_count > 0) {
qsort(p_func->callees, p_func->callee_count,
sizeof(*(p_func->callees)), SortCompareUint32);
Expand All @@ -1276,6 +1411,7 @@ static SpvReflectResult ParseFunction(
qsort(p_func->accessed_ptrs, p_func->accessed_ptr_count,
sizeof(*(p_func->accessed_ptrs)), SortCompareUint32);
}
p_func->accessed_variable_ptr_count = p_func->accessed_ptr_count;
p_func->accessed_ptr_count = (uint32_t)DedupSortedUint32(p_func->accessed_ptrs,
p_func->accessed_ptr_count);

Expand Down Expand Up @@ -3323,24 +3459,35 @@ static SpvReflectResult ParseStaticallyUsedResources(
called_function_count = DedupSortedUint32(p_called_functions, called_function_count);

uint32_t used_variable_count = 0;
uint32_t used_acessed_count = 0;
for (size_t i = 0, j = 0; i < called_function_count; ++i) {
// No need to bounds check j because a missing ID issue would have been
// found during TraverseCallGraph
while (p_parser->functions[j].id != p_called_functions[i]) {
++j;
}
used_variable_count += p_parser->functions[j].accessed_ptr_count;
used_acessed_count += p_parser->functions[j].accessed_variable_ptr_count;
}
uint32_t* used_variables = NULL;
SpvReflectPrvAccessedVariable* used_accesses = NULL;
if (used_variable_count > 0) {
used_variables = (uint32_t*)calloc(used_variable_count,
sizeof(*used_variables));
if (IsNull(used_variables)) {
SafeFree(p_called_functions);
return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED;
}

used_accesses = (SpvReflectPrvAccessedVariable*)calloc(
used_acessed_count, sizeof(SpvReflectPrvAccessedVariable));
if (IsNull(used_accesses)) {
SafeFree(p_called_functions);
return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED;
}
}
used_variable_count = 0;
used_acessed_count = 0;
for (size_t i = 0, j = 0; i < called_function_count; ++i) {
while (p_parser->functions[j].id != p_called_functions[i]) {
++j;
Expand All @@ -3350,6 +3497,12 @@ static SpvReflectResult ParseStaticallyUsedResources(
p_parser->functions[j].accessed_ptrs,
p_parser->functions[j].accessed_ptr_count * sizeof(*used_variables));
used_variable_count += p_parser->functions[j].accessed_ptr_count;

memcpy(&used_accesses[used_acessed_count],
p_parser->functions[j].accessed_variable_ptrs,
p_parser->functions[j].accessed_variable_ptr_count *
sizeof(SpvReflectPrvAccessedVariable));
used_acessed_count += p_parser->functions[j].accessed_variable_ptr_count;
}
SafeFree(p_called_functions);

Expand Down Expand Up @@ -3381,18 +3534,18 @@ static SpvReflectResult ParseStaticallyUsedResources(
&p_entry->used_push_constants,
&used_push_constant_count);

for (uint32_t j = 0; j < p_module->descriptor_binding_count; ++j) {
SpvReflectDescriptorBinding* p_binding = &p_module->descriptor_bindings[j];
bool found = SearchSortedUint32(
used_variables,
used_variable_count,
p_binding->spirv_id);
if (found) {
p_binding->accessed = 1;
for (uint32_t i = 0; i < p_module->descriptor_binding_count; ++i) {
SpvReflectDescriptorBinding* p_binding = &p_module->descriptor_bindings[i];
for (uint32_t j = 0; j < used_acessed_count; j++) {
if (used_accesses[j].variable_ptr == p_binding->spirv_id) {
p_binding->accessed = 1;
p_binding->access_flags |= used_accesses[j].access_flags;
}
}
}

SafeFree(used_variables);
SafeFree(used_accesses);
if (result0 != SPV_REFLECT_RESULT_SUCCESS) {
return result0;
}
Expand Down
17 changes: 17 additions & 0 deletions spirv_reflect.h
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,22 @@ typedef enum SpvReflectShaderStageFlagBits {

} SpvReflectShaderStageFlagBits;

/*! @enum SpvReflectAccessBits

NOTE: A variable may be "accessed" but still have SPV_REFLECT_ACCESS_NONE
Example is if there is a OpAccessChain, but then it is never used

*/
typedef enum SpvReflectAccessFlagBits {
SPV_REFLECT_ACCESS_NONE = 0x00000000,
SPV_REFLECT_ACCESS_READ = 0x00000001,
SPV_REFLECT_ACCESS_WRITE = 0x00000002,
// Atomic will always also be marked as READ and WRITE
SPV_REFLECT_ACCESS_ATOMIC = 0x00000004,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there ever a case where ATOMIC is used or checked by itself it it always inlcudes READ and WRITE? Wondering if there's a different way to express this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so I had it before as separate... #222 (comment) convinced me otherwise

} SpvReflectAccessFlagBits;

typedef uint32_t SpvReflectAccessFlags;

/*! @enum SpvReflectGenerator

*/
Expand Down Expand Up @@ -440,6 +456,7 @@ typedef struct SpvReflectDescriptorBinding {
SpvReflectBindingArrayTraits array;
uint32_t count;
uint32_t accessed;
SpvReflectAccessFlags access_flags;
uint32_t uav_counter_id;
struct SpvReflectDescriptorBinding* uav_counter_binding;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ all_descriptor_bindings:
block: *bv2 #
array: { dims_count: 0, dims: [] }
accessed: 1
access_flags: 0x00000000 # NONE
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

last thing left in PR is figure out why this is not consistent with accessed

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is

         %13 = OpAccessChain %_ptr_StorageBuffer__struct_3 %2
         %14 = OpArrayLength %uint %13 1
                     OpReturn
                     OpFunctionEn

and from looking at spec, this is not considered a read, just an access

uav_counter_id: 4294967295
uav_counter_binding:
type_description: *td2
Expand Down
2 changes: 2 additions & 0 deletions tests/cbuffer_unused/cbuffer_unused_001.spv.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3309,6 +3309,7 @@ all_descriptor_bindings:
block: *bv63 # "MyParams"
array: { dims_count: 0, dims: [] }
accessed: 1
access_flags: 0x00000001 # READ
uav_counter_id: 4294967295
uav_counter_binding:
type_description: *td63
Expand All @@ -3325,6 +3326,7 @@ all_descriptor_bindings:
block: *bv95 # "MyParams2"
array: { dims_count: 0, dims: [] }
accessed: 1
access_flags: 0x00000001 # READ
uav_counter_id: 4294967295
uav_counter_binding:
type_description: *td95
Expand Down
1 change: 1 addition & 0 deletions tests/entry_exec_mode/comp_local_size.spv.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ all_descriptor_bindings:
block: *bv1 # ""
array: { dims_count: 0, dims: [] }
accessed: 1
access_flags: 0x00000003 # READ WRITE
uav_counter_id: 4294967295
uav_counter_binding:
type_description: *td1
Expand Down
Loading