diff --git a/source/opt/remove_duplicates_pass.cpp b/source/opt/remove_duplicates_pass.cpp index 0a54d76ea9..753f232457 100644 --- a/source/opt/remove_duplicates_pass.cpp +++ b/source/opt/remove_duplicates_pass.cpp @@ -105,6 +105,22 @@ bool RemoveDuplicatesPass::RemoveDuplicateTypes( return modified; } + std::unordered_set types_ids_of_resources; + for (auto& decoration_inst : ir_context->annotations()) { + if (decoration_inst.opcode() != SpvOpDecorate) { + continue; + } + + if (decoration_inst.GetSingleWordInOperand(1) != SpvDecorationBinding) { + continue; + } + + uint32_t var_id = decoration_inst.GetSingleWordInOperand(0); + ir::Instruction* var_inst = ir_context->get_def_use_mgr()->GetDef(var_id); + uint32_t type_id = var_inst->type_id(); + AddStructuresToSet(type_id, ir_context, &types_ids_of_resources); + } + std::vector visited_types; std::vector to_delete; for (auto* i = &*ir_context->types_values_begin(); i; i = i->NextNode()) { @@ -116,12 +132,26 @@ bool RemoveDuplicatesPass::RemoveDuplicateTypes( // Is the current type equal to one of the types we have aready visited? SpvId id_to_keep = 0u; + bool i_is_resource_type = types_ids_of_resources.count(i->result_id()) != 0; + // TODO(dneto0): Use a trie to avoid quadratic behaviour? Extract the // ResultIdTrie from unify_const_pass.cpp for this. - for (auto j : visited_types) { + for (auto& j : visited_types) { + if (!j) { + continue; + } + if (AreTypesEqual(*i, *j, ir_context)) { - id_to_keep = j->result_id(); - break; + if (!i_is_resource_type) { + id_to_keep = j->result_id(); + break; + } else if (!types_ids_of_resources.count(j->result_id())) { + ir_context->KillNamesAndDecorates(j->result_id()); + ir_context->ReplaceAllUsesWith(j->result_id(), i->result_id()); + modified = true; + to_delete.emplace_back(j); + j = nullptr; + } } } @@ -142,7 +172,7 @@ bool RemoveDuplicatesPass::RemoveDuplicateTypes( } return modified; -} +} // namespace opt // TODO(pierremoreau): Duplicate decoration groups should be removed. For // example, in @@ -202,5 +232,31 @@ bool RemoveDuplicatesPass::AreTypesEqual(const Instruction& inst1, return false; } +void RemoveDuplicatesPass::AddStructuresToSet( + uint32_t type_id, ir::IRContext* ctx, + std::unordered_set* set_of_ids) const { + ir::Instruction* type_inst = ctx->get_def_use_mgr()->GetDef(type_id); + switch (type_inst->opcode()) { + case SpvOpTypeStruct: + set_of_ids->insert(type_id); + for (uint32_t i = 0; i < type_inst->NumInOperands(); ++i) { + AddStructuresToSet(type_inst->GetSingleWordInOperand(i), ctx, + set_of_ids); + } + break; + case SpvOpTypeArray: + case SpvOpTypeRuntimeArray: + set_of_ids->insert(type_id); + AddStructuresToSet(type_inst->GetSingleWordInOperand(0), ctx, set_of_ids); + break; + case SpvOpTypePointer: + set_of_ids->insert(type_id); + AddStructuresToSet(type_inst->GetSingleWordInOperand(1), ctx, set_of_ids); + break; + default: + break; + } +} + } // namespace opt } // namespace spvtools diff --git a/source/opt/remove_duplicates_pass.h b/source/opt/remove_duplicates_pass.h index d766f6733b..c55d4cf590 100644 --- a/source/opt/remove_duplicates_pass.h +++ b/source/opt/remove_duplicates_pass.h @@ -59,6 +59,11 @@ class RemoveDuplicatesPass : public Pass { // // Returns true if the module was modified, false otherwise. bool RemoveDuplicateDecorations(ir::IRContext* ir_context) const; + + // Adds |type_id| to |set_of_ids| if |type_id| is the id of a a structure. + // Adds any structures that are subtypes of |type_id| to |set_of_ids|. + void AddStructuresToSet(uint32_t type_id, ir::IRContext* ctx, + std::unordered_set* set_of_ids) const; }; } // namespace opt diff --git a/test/opt/pass_remove_duplicates_test.cpp b/test/opt/pass_remove_duplicates_test.cpp index d269daa41f..d944be1b9d 100644 --- a/test/opt/pass_remove_duplicates_test.cpp +++ b/test/opt/pass_remove_duplicates_test.cpp @@ -25,8 +25,8 @@ namespace { -using spvtools::ir::IRContext; using spvtools::ir::Instruction; +using spvtools::ir::IRContext; using spvtools::opt::PassManager; using spvtools::opt::RemoveDuplicatesPass; @@ -129,8 +129,8 @@ OpCapability Linkage OpMemoryModel Logical GLSL450 )"; - EXPECT_THAT(RunPass(spirv), after); - EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_EQ(RunPass(spirv), after); + EXPECT_EQ(GetErrorMessage(), ""); } TEST_F(RemoveDuplicatesTest, DuplicateExtInstImports) { @@ -149,8 +149,8 @@ OpCapability Linkage OpMemoryModel Logical GLSL450 )"; - EXPECT_THAT(RunPass(spirv), after); - EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_EQ(RunPass(spirv), after); + EXPECT_EQ(GetErrorMessage(), ""); } TEST_F(RemoveDuplicatesTest, DuplicateTypes) { @@ -169,8 +169,8 @@ OpMemoryModel Logical GLSL450 %3 = OpTypeStruct %1 %1 )"; - EXPECT_THAT(RunPass(spirv), after); - EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_EQ(RunPass(spirv), after); + EXPECT_EQ(GetErrorMessage(), ""); } TEST_F(RemoveDuplicatesTest, SameTypeDifferentMemberDecoration) { @@ -192,8 +192,8 @@ OpDecorate %1 GLSLPacked %3 = OpTypeStruct %2 %2 )"; - EXPECT_THAT(RunPass(spirv), after); - EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_EQ(RunPass(spirv), after); + EXPECT_EQ(GetErrorMessage(), ""); } TEST_F(RemoveDuplicatesTest, SameTypeAndMemberDecoration) { @@ -215,8 +215,8 @@ OpDecorate %1 GLSLPacked %1 = OpTypeStruct %3 %3 )"; - EXPECT_THAT(RunPass(spirv), after); - EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_EQ(RunPass(spirv), after); + EXPECT_EQ(GetErrorMessage(), ""); } TEST_F(RemoveDuplicatesTest, SameTypeAndDifferentName) { @@ -238,8 +238,8 @@ OpName %1 "Type1" %1 = OpTypeStruct %3 %3 )"; - EXPECT_THAT(RunPass(spirv), after); - EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_EQ(RunPass(spirv), after); + EXPECT_EQ(GetErrorMessage(), ""); } // Check that #1033 has been fixed. @@ -268,8 +268,8 @@ OpGroupDecorate %3 %1 %2 %3 = OpVariable %4 Uniform )"; - EXPECT_THAT(RunPass(spirv), after); - EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_EQ(RunPass(spirv), after); + EXPECT_EQ(GetErrorMessage(), ""); } TEST_F(RemoveDuplicatesTest, DifferentDecorationGroup) { @@ -303,8 +303,194 @@ OpGroupDecorate %2 %4 %4 = OpVariable %5 Uniform )"; - EXPECT_THAT(RunPass(spirv), after); - EXPECT_THAT(GetErrorMessage(), ""); + EXPECT_EQ(RunPass(spirv), after); + EXPECT_EQ(GetErrorMessage(), ""); } +TEST_F(RemoveDuplicatesTest, DontMergeNestedResourceTypes) { + const std::string spirv = R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpSource HLSL 600 +OpName %1 "PositionAdjust" +OpMemberName %1 0 "XAdjust" +OpName %2 "NormalAdjust" +OpMemberName %2 0 "XDir" +OpMemberName %3 0 "AdjustXYZ" +OpMemberName %3 1 "AdjustDir" +OpName %4 "Constants" +OpMemberDecorate %1 0 Offset 0 +OpMemberDecorate %2 0 Offset 0 +OpMemberDecorate %3 0 Offset 0 +OpMemberDecorate %3 1 Offset 16 +OpDecorate %3 Block +OpDecorate %4 DescriptorSet 0 +OpDecorate %4 Binding 0 +%5 = OpTypeFloat 32 +%6 = OpTypeVector %5 3 +%1 = OpTypeStruct %6 +%2 = OpTypeStruct %6 +%3 = OpTypeStruct %1 %2 +%7 = OpTypePointer Uniform %3 +%4 = OpVariable %7 Uniform +)"; + + EXPECT_EQ(RunPass(spirv), spirv); + EXPECT_EQ(GetErrorMessage(), ""); +} + +TEST_F(RemoveDuplicatesTest, DontMergeResourceTypes) { + const std::string spirv = R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpSource HLSL 600 +OpName %1 "PositionAdjust" +OpMemberName %1 0 "XAdjust" +OpName %2 "NormalAdjust" +OpMemberName %2 0 "XDir" +OpName %3 "Constants" +OpMemberDecorate %1 0 Offset 0 +OpMemberDecorate %2 0 Offset 0 +OpDecorate %3 DescriptorSet 0 +OpDecorate %3 Binding 0 +OpDecorate %4 DescriptorSet 1 +OpDecorate %4 Binding 0 +%5 = OpTypeFloat 32 +%6 = OpTypeVector %5 3 +%1 = OpTypeStruct %6 +%2 = OpTypeStruct %6 +%7 = OpTypePointer Uniform %1 +%8 = OpTypePointer Uniform %2 +%3 = OpVariable %7 Uniform +%4 = OpVariable %8 Uniform +)"; + + EXPECT_EQ(RunPass(spirv), spirv); + EXPECT_EQ(GetErrorMessage(), ""); +} + +TEST_F(RemoveDuplicatesTest, DontMergeResourceTypesContainingArray) { + const std::string spirv = R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpSource HLSL 600 +OpName %1 "PositionAdjust" +OpMemberName %1 0 "XAdjust" +OpName %2 "NormalAdjust" +OpMemberName %2 0 "XDir" +OpName %3 "Constants" +OpMemberDecorate %1 0 Offset 0 +OpMemberDecorate %2 0 Offset 0 +OpDecorate %3 DescriptorSet 0 +OpDecorate %3 Binding 0 +OpDecorate %4 DescriptorSet 1 +OpDecorate %4 Binding 0 +%5 = OpTypeFloat 32 +%6 = OpTypeVector %5 3 +%1 = OpTypeStruct %6 +%2 = OpTypeStruct %6 +%7 = OpTypeInt 32 0 +%8 = OpConstant %7 4 +%9 = OpTypeArray %1 %8 +%10 = OpTypeArray %2 %8 +%11 = OpTypePointer Uniform %9 +%12 = OpTypePointer Uniform %10 +%3 = OpVariable %11 Uniform +%4 = OpVariable %12 Uniform +)"; + + EXPECT_EQ(RunPass(spirv), spirv); + EXPECT_EQ(GetErrorMessage(), ""); +} + +// Test that we merge the type of a resource with a type that is not the type +// a resource. The resource type appears first in this case. We must keep +// the resource type. +TEST_F(RemoveDuplicatesTest, MergeResourceTypeWithNonresourceType1) { + const std::string spirv = R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpSource HLSL 600 +OpName %1 "PositionAdjust" +OpMemberName %1 0 "XAdjust" +OpName %2 "NormalAdjust" +OpMemberName %2 0 "XDir" +OpName %3 "Constants" +OpMemberDecorate %1 0 Offset 0 +OpMemberDecorate %2 0 Offset 0 +OpDecorate %3 DescriptorSet 0 +OpDecorate %3 Binding 0 +%4 = OpTypeFloat 32 +%5 = OpTypeVector %4 3 +%1 = OpTypeStruct %5 +%2 = OpTypeStruct %5 +%6 = OpTypePointer Uniform %1 +%7 = OpTypePointer Uniform %2 +%3 = OpVariable %6 Uniform +%8 = OpVariable %7 Uniform +)"; + + const std::string result = R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpSource HLSL 600 +OpName %1 "PositionAdjust" +OpMemberName %1 0 "XAdjust" +OpName %3 "Constants" +OpMemberDecorate %1 0 Offset 0 +OpDecorate %3 DescriptorSet 0 +OpDecorate %3 Binding 0 +%4 = OpTypeFloat 32 +%5 = OpTypeVector %4 3 +%1 = OpTypeStruct %5 +%6 = OpTypePointer Uniform %1 +%3 = OpVariable %6 Uniform +%8 = OpVariable %6 Uniform +)"; + + EXPECT_EQ(RunPass(spirv), result); + EXPECT_EQ(GetErrorMessage(), ""); +} + +// Test that we merge the type of a resource with a type that is not the type +// a resource. The resource type appears second in this case. We must keep +// the resource type. +TEST_F(RemoveDuplicatesTest, MergeResourceTypeWithNonresourceType2) { + const std::string spirv = R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpSource HLSL 600 +OpName %1 "PositionAdjust" +OpMemberName %1 0 "XAdjust" +OpName %2 "NormalAdjust" +OpMemberName %2 0 "XDir" +OpName %3 "Constants" +OpMemberDecorate %1 0 Offset 0 +OpMemberDecorate %2 0 Offset 0 +OpDecorate %3 DescriptorSet 0 +OpDecorate %3 Binding 0 +%4 = OpTypeFloat 32 +%5 = OpTypeVector %4 3 +%1 = OpTypeStruct %5 +%2 = OpTypeStruct %5 +%6 = OpTypePointer Uniform %1 +%7 = OpTypePointer Uniform %2 +%8 = OpVariable %6 Uniform +%3 = OpVariable %7 Uniform +)"; + + const std::string result = R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpSource HLSL 600 +OpName %2 "NormalAdjust" +OpMemberName %2 0 "XDir" +OpName %3 "Constants" +OpMemberDecorate %2 0 Offset 0 +OpDecorate %3 DescriptorSet 0 +OpDecorate %3 Binding 0 +%4 = OpTypeFloat 32 +%5 = OpTypeVector %4 3 +%2 = OpTypeStruct %5 +%7 = OpTypePointer Uniform %2 +%8 = OpVariable %7 Uniform +%3 = OpVariable %7 Uniform +)"; + + EXPECT_EQ(RunPass(spirv), result); + EXPECT_EQ(GetErrorMessage(), ""); +} } // namespace