diff --git a/source/fuzz/fact_manager.cpp b/source/fuzz/fact_manager.cpp index 6dff669f25..97474ea491 100644 --- a/source/fuzz/fact_manager.cpp +++ b/source/fuzz/fact_manager.cpp @@ -475,10 +475,13 @@ class FactManager::DataSynonymAndIdEquationFacts { const protobufs::DataDescriptor& dd2); // Returns true if and only if |dd1| and |dd2| are valid data descriptors - // whose associated data have the same type (modulo integer signedness). - bool DataDescriptorsAreWellFormedAndComparable( + // whose associated data have compatible types. Two types are compatible if: + // - they are the same + // - they both are numerical or vectors of numerical components with the same + // number of components and the same bit count per component + static bool DataDescriptorsAreWellFormedAndComparable( opt::IRContext* context, const protobufs::DataDescriptor& dd1, - const protobufs::DataDescriptor& dd2) const; + const protobufs::DataDescriptor& dd2); OperationSet GetEquations(const protobufs::DataDescriptor* lhs) const; @@ -1206,7 +1209,7 @@ void FactManager::DataSynonymAndIdEquationFacts::MakeEquivalent( bool FactManager::DataSynonymAndIdEquationFacts:: DataDescriptorsAreWellFormedAndComparable( opt::IRContext* context, const protobufs::DataDescriptor& dd1, - const protobufs::DataDescriptor& dd2) const { + const protobufs::DataDescriptor& dd2) { auto end_type_id_1 = fuzzerutil::WalkCompositeTypeIndices( context, context->get_def_use_mgr()->GetDef(dd1.object())->type_id(), dd1.index()); @@ -1225,30 +1228,49 @@ bool FactManager::DataSynonymAndIdEquationFacts:: // vectors that differ only in signedness. // Get both types. - const opt::analysis::Type* type_1 = - context->get_type_mgr()->GetType(end_type_id_1); - const opt::analysis::Type* type_2 = - context->get_type_mgr()->GetType(end_type_id_2); - - // If the first type is a vector, check that the second type is a vector of - // the same width, and drill down to the vector element types. - if (type_1->AsVector()) { - if (!type_2->AsVector()) { + const auto* type_a = context->get_type_mgr()->GetType(end_type_id_1); + const auto* type_b = context->get_type_mgr()->GetType(end_type_id_2); + assert(type_a && type_b && "Data descriptors have invalid type(s)"); + + // If both types are numerical or vectors of numerical components, then they + // are compatible if they have the same number of components and the same bit + // count per component. + + if (type_a->AsVector() && type_b->AsVector()) { + const auto* vector_a = type_a->AsVector(); + const auto* vector_b = type_b->AsVector(); + + if (vector_a->element_count() != vector_b->element_count() || + vector_a->element_type()->AsBool() || + vector_b->element_type()->AsBool()) { + // The case where both vectors have boolean elements and the same number + // of components is handled by the direct equality check earlier. + // You can't have multiple identical boolean vector types. return false; } - if (type_1->AsVector()->element_count() != - type_2->AsVector()->element_count()) { - return false; - } - type_1 = type_1->AsVector()->element_type(); - type_2 = type_2->AsVector()->element_type(); + + type_a = vector_a->element_type(); + type_b = vector_b->element_type(); } - // Check that type_1 and type_2 are both integer types of the same bit-width - // (but with potentially different signedness). - auto integer_type_1 = type_1->AsInteger(); - auto integer_type_2 = type_2->AsInteger(); - return integer_type_1 && integer_type_2 && - integer_type_1->width() == integer_type_2->width(); + + auto get_bit_count_for_numeric_type = + [](const opt::analysis::Type& type) -> uint32_t { + if (const auto* integer = type.AsInteger()) { + return integer->width(); + } else if (const auto* floating = type.AsFloat()) { + return floating->width(); + } else { + assert(false && "|type| must be a numerical type"); + return 0; + } + }; + + // Checks that both |type_a| and |type_b| are either numerical or vectors of + // numerical components and have the same number of bits. + return (type_a->AsInteger() || type_a->AsFloat()) && + (type_b->AsInteger() || type_b->AsFloat()) && + (get_bit_count_for_numeric_type(*type_a) == + get_bit_count_for_numeric_type(*type_b)); } std::vector diff --git a/source/fuzz/fuzzer_pass_apply_id_synonyms.cpp b/source/fuzz/fuzzer_pass_apply_id_synonyms.cpp index 2808ad5610..bb676e8961 100644 --- a/source/fuzz/fuzzer_pass_apply_id_synonyms.cpp +++ b/source/fuzz/fuzzer_pass_apply_id_synonyms.cpp @@ -60,7 +60,7 @@ void FuzzerPassApplyIdSynonyms::Apply() { } }); - for (auto& use : uses) { + for (const auto& use : uses) { auto use_inst = use.first; auto use_index = use.second; auto block_containing_use = GetIRContext()->get_instr_block(use_inst); @@ -82,7 +82,7 @@ void FuzzerPassApplyIdSynonyms::Apply() { } std::vector synonyms_to_try; - for (auto& data_descriptor : + for (const auto* data_descriptor : GetTransformationContext()->GetFactManager()->GetSynonymsForId( id_with_known_synonyms)) { protobufs::DataDescriptor descriptor_for_this_id = @@ -91,7 +91,12 @@ void FuzzerPassApplyIdSynonyms::Apply() { // Exclude the fact that the id is synonymous with itself. continue; } - synonyms_to_try.push_back(data_descriptor); + + if (DataDescriptorsHaveCompatibleTypes( + use_inst->opcode(), use_in_operand_index, + descriptor_for_this_id, *data_descriptor)) { + synonyms_to_try.push_back(data_descriptor); + } } while (!synonyms_to_try.empty()) { auto synonym_to_try = @@ -162,5 +167,26 @@ void FuzzerPassApplyIdSynonyms::Apply() { } } +bool FuzzerPassApplyIdSynonyms::DataDescriptorsHaveCompatibleTypes( + SpvOp opcode, uint32_t use_in_operand_index, + const protobufs::DataDescriptor& dd1, + const protobufs::DataDescriptor& dd2) { + auto base_object_type_id_1 = + fuzzerutil::GetTypeId(GetIRContext(), dd1.object()); + auto base_object_type_id_2 = + fuzzerutil::GetTypeId(GetIRContext(), dd2.object()); + assert(base_object_type_id_1 && base_object_type_id_2 && + "Data descriptors are invalid"); + + auto type_id_1 = fuzzerutil::WalkCompositeTypeIndices( + GetIRContext(), base_object_type_id_1, dd1.index()); + auto type_id_2 = fuzzerutil::WalkCompositeTypeIndices( + GetIRContext(), base_object_type_id_2, dd2.index()); + assert(type_id_1 && type_id_2 && "Data descriptors have invalid types"); + + return TransformationReplaceIdWithSynonym::TypesAreCompatible( + GetIRContext(), opcode, use_in_operand_index, type_id_1, type_id_2); +} + } // namespace fuzz } // namespace spvtools diff --git a/source/fuzz/fuzzer_pass_apply_id_synonyms.h b/source/fuzz/fuzzer_pass_apply_id_synonyms.h index 1a9213db1a..2feb65c205 100644 --- a/source/fuzz/fuzzer_pass_apply_id_synonyms.h +++ b/source/fuzz/fuzzer_pass_apply_id_synonyms.h @@ -16,7 +16,6 @@ #define SOURCE_FUZZ_FUZZER_PASS_APPLY_ID_SYNONYMS_ #include "source/fuzz/fuzzer_pass.h" - #include "source/opt/ir_context.h" namespace spvtools { @@ -34,6 +33,16 @@ class FuzzerPassApplyIdSynonyms : public FuzzerPass { ~FuzzerPassApplyIdSynonyms() override; void Apply() override; + + private: + // Returns true if uses of |dd1| can be replaced with |dd2| and vice-versa + // with respect to the type. Concretely, returns true if |dd1| and |dd2| have + // the same type or both |dd1| and |dd2| are either a numerical or a vector + // type of integral components with possibly different signedness. + bool DataDescriptorsHaveCompatibleTypes(SpvOp opcode, + uint32_t use_in_operand_index, + const protobufs::DataDescriptor& dd1, + const protobufs::DataDescriptor& dd2); }; } // namespace fuzz diff --git a/source/fuzz/transformation_replace_id_with_synonym.cpp b/source/fuzz/transformation_replace_id_with_synonym.cpp index 55607e1035..fbbeab2e6a 100644 --- a/source/fuzz/transformation_replace_id_with_synonym.cpp +++ b/source/fuzz/transformation_replace_id_with_synonym.cpp @@ -66,12 +66,9 @@ bool TransformationReplaceIdWithSynonym::IsApplicable( // If the id of interest and the synonym are scalar or vector integer // constants with different signedness, their use can only be swapped if the // instruction is agnostic to the signedness of the operand. - if (type_id_of_interest != type_id_synonym && - fuzzerutil::TypesAreEqualUpToSign(ir_context, type_id_of_interest, - type_id_synonym) && - !IsAgnosticToSignednessOfOperand( - use_instruction->opcode(), - message_.id_use_descriptor().in_operand_index())) { + if (!TypesAreCompatible(ir_context, use_instruction->opcode(), + message_.id_use_descriptor().in_operand_index(), + type_id_of_interest, type_id_synonym)) { return false; } @@ -241,5 +238,17 @@ bool TransformationReplaceIdWithSynonym::IsAgnosticToSignednessOfOperand( } } +bool TransformationReplaceIdWithSynonym::TypesAreCompatible( + opt::IRContext* ir_context, SpvOp opcode, uint32_t use_in_operand_index, + uint32_t type_id_1, uint32_t type_id_2) { + assert(ir_context->get_type_mgr()->GetType(type_id_1) && + ir_context->get_type_mgr()->GetType(type_id_2) && + "Type ids are invalid"); + + return type_id_1 == type_id_2 || + (IsAgnosticToSignednessOfOperand(opcode, use_in_operand_index) && + fuzzerutil::TypesAreEqualUpToSign(ir_context, type_id_1, type_id_2)); +} + } // namespace fuzz } // namespace spvtools diff --git a/source/fuzz/transformation_replace_id_with_synonym.h b/source/fuzz/transformation_replace_id_with_synonym.h index 78878b292d..fe3fac0521 100644 --- a/source/fuzz/transformation_replace_id_with_synonym.h +++ b/source/fuzz/transformation_replace_id_with_synonym.h @@ -64,6 +64,15 @@ class TransformationReplaceIdWithSynonym : public Transformation { opt::Instruction* use_instruction, uint32_t use_in_operand_index); + // Returns true if |type_id_1| and |type_id_2| represent compatible types + // given the context of the instruction with |opcode| (i.e. we can replace + // an operand of |opcode| of the first type with an id of the second type + // and vice-versa). + static bool TypesAreCompatible(opt::IRContext* ir_context, SpvOp opcode, + uint32_t use_in_operand_index, + uint32_t type_id_1, uint32_t type_id_2); + + private: // Returns true if the instruction with opcode |opcode| does not change its // behaviour depending on the signedness of the operand at // |use_in_operand_index|. @@ -71,7 +80,6 @@ class TransformationReplaceIdWithSynonym : public Transformation { static bool IsAgnosticToSignednessOfOperand(SpvOp opcode, uint32_t use_in_operand_index); - private: protobufs::TransformationReplaceIdWithSynonym message_; }; diff --git a/test/fuzz/transformation_replace_id_with_synonym_test.cpp b/test/fuzz/transformation_replace_id_with_synonym_test.cpp index e4a3f004df..33713c2bb6 100644 --- a/test/fuzz/transformation_replace_id_with_synonym_test.cpp +++ b/test/fuzz/transformation_replace_id_with_synonym_test.cpp @@ -13,6 +13,7 @@ // limitations under the License. #include "source/fuzz/transformation_replace_id_with_synonym.h" + #include "source/fuzz/data_descriptor.h" #include "source/fuzz/id_use_descriptor.h" #include "source/fuzz/instruction_descriptor.h" @@ -1715,6 +1716,70 @@ TEST(TransformationReplaceIdWithSynonymTest, EquivalentIntegerVectorConstants) { ASSERT_TRUE(IsEqual(env, after_transformation, context.get())); } +TEST(TransformationReplaceIdWithSynonymTest, IncompatibleTypes) { + const std::string shader = R"( + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %2 "main" + OpExecutionMode %2 OriginUpperLeft + OpSource ESSL 310 + %5 = OpTypeVoid + %6 = OpTypeFunction %5 + %7 = OpTypeInt 32 1 + %8 = OpTypeInt 32 0 + %9 = OpTypeFloat 32 + %12 = OpConstant %7 1 + %13 = OpConstant %8 1 + %10 = OpConstant %9 1 + %2 = OpFunction %5 None %6 + %17 = OpLabel + %18 = OpIAdd %7 %12 %13 + %19 = OpFAdd %9 %10 %10 + OpReturn + OpFunctionEnd + )"; + + const auto env = SPV_ENV_UNIVERSAL_1_3; + const auto consumer = nullptr; + const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption); + ASSERT_TRUE(IsValid(env, context.get())); + + FactManager fact_manager; + spvtools::ValidatorOptions validator_options; + TransformationContext transformation_context(&fact_manager, + validator_options); + + auto* op_i_add = context->get_def_use_mgr()->GetDef(18); + ASSERT_TRUE(op_i_add); + + auto* op_f_add = context->get_def_use_mgr()->GetDef(19); + ASSERT_TRUE(op_f_add); + + fact_manager.AddFactDataSynonym(MakeDataDescriptor(12, {}), + MakeDataDescriptor(13, {}), context.get()); + fact_manager.AddFactDataSynonym(MakeDataDescriptor(12, {}), + MakeDataDescriptor(10, {}), context.get()); + + // Synonym differs only in signedness for OpIAdd. + ASSERT_TRUE(TransformationReplaceIdWithSynonym( + MakeIdUseDescriptorFromUse(context.get(), op_i_add, 0), 13) + .IsApplicable(context.get(), transformation_context)); + + // Synonym has wrong type for OpIAdd. + ASSERT_FALSE(TransformationReplaceIdWithSynonym( + MakeIdUseDescriptorFromUse(context.get(), op_i_add, 0), 10) + .IsApplicable(context.get(), transformation_context)); + + // Synonym has wrong type for OpFAdd. + ASSERT_FALSE(TransformationReplaceIdWithSynonym( + MakeIdUseDescriptorFromUse(context.get(), op_f_add, 0), 12) + .IsApplicable(context.get(), transformation_context)); + ASSERT_FALSE(TransformationReplaceIdWithSynonym( + MakeIdUseDescriptorFromUse(context.get(), op_f_add, 0), 13) + .IsApplicable(context.get(), transformation_context)); +} + } // namespace } // namespace fuzz } // namespace spvtools