From 8dc1541f789defb35e31e1347165823685a8dd50 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Thu, 24 Sep 2020 12:51:41 -0400 Subject: [PATCH] Cleanup and document Type Analysis code --- enzyme/Enzyme/ActiveVariable.cpp | 28 +- enzyme/Enzyme/AdjointGenerator.h | 14 +- enzyme/Enzyme/Enzyme.cpp | 6 +- enzyme/Enzyme/EnzymeLogic.cpp | 36 +- enzyme/Enzyme/GradientUtils.cpp | 4 +- enzyme/Enzyme/GradientUtils.h | 2 +- enzyme/Enzyme/TypeAnalysis/BaseType.h | 7 +- enzyme/Enzyme/TypeAnalysis/ConcreteType.h | 533 +++++++++--------- enzyme/Enzyme/TypeAnalysis/TBAA.h | 149 +++-- enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp | 245 ++++---- enzyme/Enzyme/TypeAnalysis/TypeAnalysis.h | 85 ++- .../TypeAnalysis/TypeAnalysisPrinter.cpp | 19 +- enzyme/Enzyme/TypeAnalysis/TypeTree.h | 339 +++++------ enzyme/Enzyme/Utils.h | 21 +- enzyme/test/Enzyme/unusedalloc.ll | 39 ++ 15 files changed, 788 insertions(+), 739 deletions(-) create mode 100644 enzyme/test/Enzyme/unusedalloc.ll diff --git a/enzyme/Enzyme/ActiveVariable.cpp b/enzyme/Enzyme/ActiveVariable.cpp index a70ade8f2736..c569368e2df3 100644 --- a/enzyme/Enzyme/ActiveVariable.cpp +++ b/enzyme/Enzyme/ActiveVariable.cpp @@ -67,6 +67,10 @@ cl::opt emptyfnconst("enzyme_emptyfnconst", cl::init(false), cl::Hidden, #include #include + +constexpr uint8_t UP = 1; +constexpr uint8_t DOWN = 2; + bool isFunctionArgumentConstant(TypeResults &TR, CallInst *CI, Value *val, SmallPtrSetImpl &constants, SmallPtrSetImpl &nonconstant, @@ -180,11 +184,11 @@ bool isFunctionArgumentConstant(TypeResults &TR, CallInst *CI, Value *val, FnTypeInfo nextTypeInfo(F); int argnum = 0; for (auto &arg : F->args()) { - nextTypeInfo.first.insert(std::pair( + nextTypeInfo.Arguments.insert(std::pair( &arg, TR.query(CI->getArgOperand(argnum)))); ++argnum; } - nextTypeInfo.second = TR.query(CI); + nextTypeInfo.Return = TR.query(CI); TypeResults TR2 = TR.analysis.analyzeFunction(nextTypeInfo); for (unsigned i = 0; i < CI->getNumArgOperands(); ++i) { @@ -352,9 +356,7 @@ bool isconstantM(TypeResults &TR, Instruction *inst, SmallPtrSetImpl &retvals, AAResults &AA, uint8_t directions) { assert(inst); - assert(TR.info.function == inst->getParent()->getParent()); - constexpr uint8_t UP = 1; - constexpr uint8_t DOWN = 2; + assert(TR.info.Function == inst->getParent()->getParent()); // assert(directions >= 0); assert(directions <= 3); if (isa(inst)) @@ -473,7 +475,7 @@ bool isconstantM(TypeResults &TR, Instruction *inst, auto q = TR.query(storeinst->getPointerOperand()).Data0(); for (int i = -1; i < (int)storeSize; ++i) { auto dt = q[{i}]; - if (dt.isIntegral() || dt.typeEnum == BaseType::Anything) { + if (dt.isIntegral() || dt == BaseType::Anything) { anIntegral = true; } else if (dt.isKnown()) { allIntegral = false; @@ -964,13 +966,11 @@ bool isconstantValueM(TypeResults &TR, Value *val, uint8_t directions) { assert(val); if (auto inst = dyn_cast(val)) { - assert(TR.info.function == inst->getParent()->getParent()); + assert(TR.info.Function == inst->getParent()->getParent()); } if (auto arg = dyn_cast(val)) { - assert(TR.info.function == arg->getParent()); + assert(TR.info.Function == arg->getParent()); } - // constexpr uint8_t UP = 1; - constexpr uint8_t DOWN = 2; // assert(directions >= 0); assert(directions <= 3); @@ -1017,8 +1017,8 @@ bool isconstantValueM(TypeResults &TR, Value *val, assert(0 && "must've put arguments in constant/nonconstant"); } - //! This value is certainly an integer (and only and integer, not a pointer or - //! float). Therefore its value is constant + // This value is certainly an integer (and only and integer, not a pointer or + // float). Therefore its value is constant if (TR.intType(val, /*errIfNotFound*/ false).isIntegral()) { if (printconst) llvm::errs() << " Value const as integral " << (int)directions << " " @@ -1028,8 +1028,8 @@ bool isconstantValueM(TypeResults &TR, Value *val, return true; } - //! This value is certainly a pointer to an integer (and only and integer, not - //! a pointer or float). Therefore its value is constant + // This value is certainly a pointer to an integer (and only and integer, not + // a pointer or float). Therefore its value is constant // TODO use typeInfo for more aggressive activity analysis if (val->getType()->isPointerTy() && cast(val->getType())->isIntOrIntVectorTy() && diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index 24ade74b8d6b..bc177ee129c9 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -74,7 +74,7 @@ class DerivativeMaker unnecessaryInstructions(unnecessaryInstructions), unnecessaryStores(unnecessaryStores), dretAlloca(dretAlloca) { - assert(TR.info.function == gutils->oldFunc); + assert(TR.info.Function == gutils->oldFunc); for (auto &pair : TR.analysis.analyzedFunctions.find(TR.info)->second.analysis) { if (auto in = dyn_cast(pair.first)) { @@ -1129,9 +1129,9 @@ class DerivativeMaker auto dt = vd[{-1}]; for (size_t i = start; i < size; ++i) { - bool legal = true; - dt.legalMergeIn(vd[{(int)i}], /*pointerIntSame*/ true, legal); - if (!legal) { + bool Legal = true; + dt.checkedOrIn(vd[{(int)i}], /*PointerIntSame*/ true, Legal); + if (!Legal) { nextStart = i; break; } @@ -1912,15 +1912,15 @@ class DerivativeMaker std::map> intseen; for (auto &arg : called->args()) { - nextTypeInfo.first.insert(std::pair( + nextTypeInfo.Arguments.insert(std::pair( &arg, TR.query(orig->getArgOperand(argnum)))); - nextTypeInfo.knownValues.insert( + nextTypeInfo.KnownValues.insert( std::pair>( &arg, TR.knownIntegralValues(orig->getArgOperand(argnum)))); ++argnum; } - nextTypeInfo.second = TR.query(orig); + nextTypeInfo.Return = TR.query(orig); } // llvm::Optional, unsigned>> diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index 5422cf6dcf3b..645fca8ca19f 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -210,7 +210,7 @@ void HandleAutoDiff(CallInst *CI, TargetLibraryInfo &TLI, AAResults &AA) { std::map volatile_args; FnTypeInfo type_args(cast(fn)); - for (auto &a : type_args.function->args()) { + for (auto &a : type_args.Function->args()) { volatile_args[&a] = false; TypeTree dt; if (a.getType()->isFPOrFPVectorTy()) { @@ -223,10 +223,10 @@ void HandleAutoDiff(CallInst *CI, TargetLibraryInfo &TLI, AAResults &AA) { dt = TypeTree(ConcreteType(BaseType::Pointer)).Only(-1); } } - type_args.first.insert(std::pair(&a, dt.Only(-1))); + type_args.Arguments.insert(std::pair(&a, dt.Only(-1))); // TODO note that here we do NOT propagate constants in type info (and // should consider whether we should) - type_args.knownValues.insert( + type_args.KnownValues.insert( std::pair>(&a, {})); } diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index f6daa7ee8bf0..a1b822507c2a 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -943,7 +943,7 @@ CreateAugmentedPrimal(Function *todiff, DIFFE_TYPE retType, !todiff->getReturnType()->isVoidTy()); FnTypeInfo oldTypeInfo = oldTypeInfo_; - for (auto &pair : oldTypeInfo.knownValues) { + for (auto &pair : oldTypeInfo.KnownValues) { if (pair.second.size() != 0) { bool recursiveUse = false; for (auto user : pair.first->users()) { @@ -1082,23 +1082,23 @@ CreateAugmentedPrimal(Function *todiff, DIFFE_TYPE retType, for (; toarg != todiff->arg_end(); ++toarg, ++olarg) { { - auto fd = oldTypeInfo.first.find(toarg); - assert(fd != oldTypeInfo.first.end()); - typeInfo.first.insert( + auto fd = oldTypeInfo.Arguments.find(toarg); + assert(fd != oldTypeInfo.Arguments.end()); + typeInfo.Arguments.insert( std::pair(olarg, fd->second)); } { - auto cfd = oldTypeInfo.knownValues.find(toarg); - assert(cfd != oldTypeInfo.knownValues.end()); - typeInfo.knownValues.insert( + auto cfd = oldTypeInfo.KnownValues.find(toarg); + assert(cfd != oldTypeInfo.KnownValues.end()); + typeInfo.KnownValues.insert( std::pair>(olarg, cfd->second)); } } - typeInfo.second = oldTypeInfo.second; + typeInfo.Return = oldTypeInfo.Return; } TypeResults TR = TA.analyzeFunction(typeInfo); - assert(TR.info.function == gutils->oldFunc); + assert(TR.info.Function == gutils->oldFunc); gutils->forceActiveDetection(AA, TR); gutils->forceAugmentedReturns(TR, guaranteedUnreachable); @@ -1831,7 +1831,7 @@ Function *CreatePrimalAndGradient( const AugmentedReturn *augmenteddata) { FnTypeInfo oldTypeInfo = oldTypeInfo_; - for (auto &pair : oldTypeInfo.knownValues) { + for (auto &pair : oldTypeInfo.KnownValues) { if (pair.second.size() != 0) { bool recursiveUse = false; for (auto user : pair.first->users()) { @@ -2045,24 +2045,24 @@ Function *CreatePrimalAndGradient( for (; toarg != todiff->arg_end(); ++toarg, ++olarg) { { - auto fd = oldTypeInfo.first.find(toarg); - assert(fd != oldTypeInfo.first.end()); - typeInfo.first.insert( + auto fd = oldTypeInfo.Arguments.find(toarg); + assert(fd != oldTypeInfo.Arguments.end()); + typeInfo.Arguments.insert( std::pair(olarg, fd->second)); } { - auto cfd = oldTypeInfo.knownValues.find(toarg); - assert(cfd != oldTypeInfo.knownValues.end()); - typeInfo.knownValues.insert( + auto cfd = oldTypeInfo.KnownValues.find(toarg); + assert(cfd != oldTypeInfo.KnownValues.end()); + typeInfo.KnownValues.insert( std::pair>(olarg, cfd->second)); } } - typeInfo.second = oldTypeInfo.second; + typeInfo.Return = oldTypeInfo.Return; } TypeResults TR = TA.analyzeFunction(typeInfo); - assert(TR.info.function == gutils->oldFunc); + assert(TR.info.Function == gutils->oldFunc); gutils->forceActiveDetection(AA, TR); gutils->forceAugmentedReturns(TR, guaranteedUnreachable); diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index e81ee8c9c1a6..a17aa725574f 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -510,8 +510,8 @@ Value *GradientUtils::invertPointerM(Value *oval, IRBuilder<> &BuilderM) { std::vector types; for (auto &a : fn->args()) { uncacheable_args[&a] = !a.getType()->isFPOrFPVectorTy(); - type_args.first.insert(std::pair(&a, {})); - type_args.knownValues.insert( + type_args.Arguments.insert(std::pair(&a, {})); + type_args.KnownValues.insert( std::pair>(&a, {})); DIFFE_TYPE typ; if (a.getType()->isFPOrFPVectorTy()) { diff --git a/enzyme/Enzyme/GradientUtils.h b/enzyme/Enzyme/GradientUtils.h index ed7382054a8e..63ef2d5df6c3 100644 --- a/enzyme/Enzyme/GradientUtils.h +++ b/enzyme/Enzyme/GradientUtils.h @@ -1170,7 +1170,7 @@ class GradientUtils { void forceAugmentedReturns( TypeResults &TR, const SmallPtrSetImpl &guaranteedUnreachable) { - assert(TR.info.function == oldFunc); + assert(TR.info.Function == oldFunc); for (BasicBlock &oBB : *oldFunc) { // Don't create derivatives for code that results in termination diff --git a/enzyme/Enzyme/TypeAnalysis/BaseType.h b/enzyme/Enzyme/TypeAnalysis/BaseType.h index 2bea0c378655..5ba2e86dfcc2 100644 --- a/enzyme/Enzyme/TypeAnalysis/BaseType.h +++ b/enzyme/Enzyme/TypeAnalysis/BaseType.h @@ -27,19 +27,21 @@ #include #include "llvm/Support/ErrorHandling.h" +/// Categories of potential types enum class BaseType { - // integral type + // integral type which doesn't represent a pointer Integer, // floating point Float, // pointer Pointer, - // can be anything of users choosing [usually result of a constant] + // can be anything of users choosing [usually result of a constant such as 0] Anything, // insufficient information Unknown }; +/// Convert Basetype to string static inline std::string to_string(BaseType t) { switch (t) { case BaseType::Integer: @@ -56,6 +58,7 @@ static inline std::string to_string(BaseType t) { llvm_unreachable("unknown inttype"); } +/// Convert string to BaseType static inline BaseType parseBaseType(std::string str) { if (str == "Integer") return BaseType::Integer; diff --git a/enzyme/Enzyme/TypeAnalysis/ConcreteType.h b/enzyme/Enzyme/TypeAnalysis/ConcreteType.h index 6fe21ae4239a..19052ef4289a 100644 --- a/enzyme/Enzyme/TypeAnalysis/ConcreteType.h +++ b/enzyme/Enzyme/TypeAnalysis/ConcreteType.h @@ -1,4 +1,4 @@ -//===- ConcreteType.h - Underlying type used in Type Analysis ------------===// +//===- ConcreteType.h - Underlying SubType used in Type Analysis ------------===// // // Enzyme Project // @@ -18,8 +18,8 @@ //===----------------------------------------------------------------------===// // // This file contains the implementation of an a class representing all potential -// end types used in Type Analysis. This ``ConcreteType`` contains an the type -// category ``BaseType`` as well as the type of float, if relevant. This also +// end SubTypes used in Type Analysis. This ``ConcreteType`` contains an the SubType +// category ``BaseType`` as well as the SubType of float, if relevant. This also // contains several helper utility functions. // //===----------------------------------------------------------------------===// @@ -34,239 +34,306 @@ #include "BaseType.h" +/// Concrete SubType of a given value. Consists of a category `BaseType` and the +/// particular floating point value, if relevant. class ConcreteType { public: - llvm::Type *type; - BaseType typeEnum; - - ConcreteType(const ConcreteType&) = default; - ConcreteType(ConcreteType&&) = default; - ConcreteType(llvm::Type *type) : type(type), typeEnum(BaseType::Float) { - assert(type != nullptr); - assert(!llvm::isa(type)); - if (!type->isFloatingPointTy()) { - llvm::errs() << " passing in non FP type: " << *type << "\n"; + /// Category of underlying type + BaseType SubTypeEnum; + /// Floating point type, if relevant, otherwise nullptr + llvm::Type *SubType; + + /// Construct a ConcreteType from an existing FloatingPoint Type + ConcreteType(llvm::Type *SubType) : SubTypeEnum(BaseType::Float), SubType(SubType) { + assert(SubType != nullptr); + assert(!llvm::isa(SubType)); + if (!SubType->isFloatingPointTy()) { + llvm::errs() << " passing in non FP SubType: " << *SubType << "\n"; } - assert(type->isFloatingPointTy()); + assert(SubType->isFloatingPointTy()); } - ConcreteType(BaseType typeEnum) : type(nullptr), typeEnum(typeEnum) { - assert(typeEnum != BaseType::Float); + /// Construct a non-floating Concrete type from a BaseType + ConcreteType(BaseType SubTypeEnum) : SubTypeEnum(SubTypeEnum), SubType(nullptr) { + assert(SubTypeEnum != BaseType::Float); } - ConcreteType(std::string str, llvm::LLVMContext &C) { - auto fd = str.find('@'); - if (fd != std::string::npos) { - typeEnum = BaseType::Float; - assert(str.substr(0, fd) == "Float"); - auto subt = str.substr(fd + 1); - if (subt == "half") { - type = llvm::Type::getHalfTy(C); - } else if (subt == "float") { - type = llvm::Type::getFloatTy(C); - } else if (subt == "double") { - type = llvm::Type::getDoubleTy(C); - } else if (subt == "fp80") { - type = llvm::Type::getX86_FP80Ty(C); - } else if (subt == "fp128") { - type = llvm::Type::getFP128Ty(C); - } else if (subt == "ppc128") { - type = llvm::Type::getPPC_FP128Ty(C); + /// Construct a ConcreteType from a string + /// A Concrete Type's string representation is given by the string of the enum + /// If it is a floating point it is given by Float@ + ConcreteType(std::string Str, llvm::LLVMContext &C) { + auto Sep = Str.find('@'); + if (Sep != std::string::npos) { + SubTypeEnum = BaseType::Float; + assert(Str.substr(0, Sep) == "Float"); + auto SubName = Str.substr(Sep + 1); + if (SubName == "half") { + SubType = llvm::Type::getHalfTy(C); + } else if (SubName == "float") { + SubType = llvm::Type::getFloatTy(C); + } else if (SubName == "double") { + SubType = llvm::Type::getDoubleTy(C); + } else if (SubName == "fp80") { + SubType = llvm::Type::getX86_FP80Ty(C); + } else if (SubName == "fp128") { + SubType = llvm::Type::getFP128Ty(C); + } else if (SubName == "ppc128") { + SubType = llvm::Type::getPPC_FP128Ty(C); } else { - llvm_unreachable("unknown data type"); + llvm_unreachable("unknown data SubType"); } } else { - type = nullptr; - typeEnum = parseBaseType(str); + SubType = nullptr; + SubTypeEnum = parseBaseType(Str); } } - bool isIntegral() const { - return typeEnum == BaseType::Integer || typeEnum == BaseType::Anything; + /// Convert the ConcreteType to a string + std::string str() const { + std::string Result = to_string(SubTypeEnum); + if (SubTypeEnum == BaseType::Float) { + if (SubType->isHalfTy()) { + Result += "@half"; + } else if (SubType->isFloatTy()) { + Result += "@float"; + } else if (SubType->isDoubleTy()) { + Result += "@double"; + } else if (SubType->isX86_FP80Ty()) { + Result += "@fp80"; + } else if (SubType->isFP128Ty()) { + Result += "@fp128"; + } else if (SubType->isPPC_FP128Ty()) { + Result += "@ppc128"; + } else { + llvm_unreachable("unknown data SubType"); + } + } + return Result; } - bool isKnown() const { return typeEnum != BaseType::Unknown; } + /// Whether this ConcreteType has information (is not unknown) + bool isKnown() const { return SubTypeEnum != BaseType::Unknown; } + + /// Whether this ConcreteType can be used as an integer (SubTypeEnum is Integer or Anything) + bool isIntegral() const { + return SubTypeEnum == BaseType::Integer || SubTypeEnum == BaseType::Anything; + } + /// Whether this ConcreteType could be a pointer (SubTypeEnum is unknown or a pointer) bool isPossiblePointer() const { - return !isKnown() || typeEnum == BaseType::Pointer; + return !isKnown() || SubTypeEnum == BaseType::Pointer; } + /// Whether this ConcreteType could be a float (SubTypeEnum is unknown or a float) bool isPossibleFloat() const { - return !isKnown() || typeEnum == BaseType::Float; + return !isKnown() || SubTypeEnum == BaseType::Float; } - llvm::Type *isFloat() const { return type; } + /// Return the floating point type, if this is a float + llvm::Type *isFloat() const { return SubType; } - bool operator==(const BaseType dt) const { return typeEnum == dt; } + /// Return if this is known to be the BaseType BT + /// This cannot be called with BaseType::Float as it lacks information + bool operator==(const BaseType BT) const { + if (BT == BaseType::Float) { + assert(0 && "Cannot do comparision between ConcreteType and BaseType::Float"); + llvm_unreachable("Cannot do comparision between ConcreteType and BaseType::Float"); + } + return SubTypeEnum == BT; + } - bool operator!=(const BaseType dt) const { return typeEnum != dt; } + /// Return if this is known not to be the BaseType BT + /// This cannot be called with BaseType::Float as it lacks information + bool operator!=(const BaseType BT) const { + if (BT == BaseType::Float) { + assert(0 && "Cannot do comparision between ConcreteType and BaseType::Float"); + llvm_unreachable("Cannot do comparision between ConcreteType and BaseType::Float"); + } + return SubTypeEnum != BT; + } - bool operator==(const ConcreteType dt) const { - return type == dt.type && typeEnum == dt.typeEnum; + /// Return if this is known to be the ConcreteType CT + bool operator==(const ConcreteType CT) const { + return SubType == CT.SubType && SubTypeEnum == CT.SubTypeEnum; } - bool operator!=(const ConcreteType dt) const { return !(*this == dt); } + /// Return if this is known not to be the ConcreteType CT + bool operator!=(const ConcreteType CT) const { return !(*this == CT); } - bool operator=(const BaseType bt) { - assert(bt != BaseType::Float); + /// Set this to the given ConcreteType, returning true if + /// this ConcreteType has changed + bool operator=(const ConcreteType CT) { bool changed = false; - if (typeEnum != bt) + if (SubTypeEnum != CT.SubTypeEnum) changed = true; - typeEnum = bt; - if (type != nullptr) + SubTypeEnum = CT.SubTypeEnum; + if (SubType != CT.SubType) changed = true; - type = nullptr; + SubType = CT.SubType; return changed; } - // returns whether changed - bool operator=(const ConcreteType dt) { - bool changed = false; - if (typeEnum != dt.typeEnum) - changed = true; - typeEnum = dt.typeEnum; - if (type != dt.type) - changed = true; - type = dt.type; - return changed; - } - bool operator=(ConcreteType&& dt) { - bool changed = false; - if (typeEnum != dt.typeEnum) - changed = true; - typeEnum = dt.typeEnum; - if (type != dt.type) - changed = true; - type = dt.type; - return changed; + /// Set this to the given BaseType, returning true if + /// this ConcreteType has changed + bool operator=(const BaseType BT) { + assert(BT != BaseType::Float); + return ConcreteType::operator=(ConcreteType(BT)); } - // returns whether changed - bool legalMergeIn(const ConcreteType dt, bool pointerIntSame, bool &legal) { - if (typeEnum == BaseType::Anything) { + /// Set this to the logical or of itself and CT, returning whether this value changed + /// Setting `PointerIntSame` considers pointers and integers as equivalent + /// If this is an illegal operation, `LegalOr` will be set to false + bool checkedOrIn(const ConcreteType CT, bool PointerIntSame, bool &LegalOr) { + LegalOr = true; + if (SubTypeEnum == BaseType::Anything) { return false; } - if (dt.typeEnum == BaseType::Anything) { - return *this = dt; + if (CT.SubTypeEnum == BaseType::Anything) { + return *this = CT; } - if (typeEnum == BaseType::Unknown) { - return *this = dt; + if (SubTypeEnum == BaseType::Unknown) { + return *this = CT; } - if (dt.typeEnum == BaseType::Unknown) { + if (CT.SubTypeEnum == BaseType::Unknown) { return false; } - if (dt.typeEnum != typeEnum) { - if (pointerIntSame) { - if ((typeEnum == BaseType::Pointer && dt.typeEnum == BaseType::Integer) || - (typeEnum == BaseType::Integer && dt.typeEnum == BaseType::Pointer)) { + if (CT.SubTypeEnum != SubTypeEnum) { + if (PointerIntSame) { + if ((SubTypeEnum == BaseType::Pointer && CT.SubTypeEnum == BaseType::Integer) || + (SubTypeEnum == BaseType::Integer && CT.SubTypeEnum == BaseType::Pointer)) { return false; } } - legal = false; + LegalOr = false; return false; } - assert(dt.typeEnum == typeEnum); - if (dt.type != type) { - legal = false; + assert(CT.SubTypeEnum == SubTypeEnum); + if (CT.SubType != SubType) { + LegalOr = false; return false; } - assert(dt.type == type); + assert(CT.SubType == SubType); return false; } - // returns whether changed - bool mergeIn(const ConcreteType dt, bool pointerIntSame) { - bool legal = true; - bool res = legalMergeIn(dt, pointerIntSame, legal); - if (!legal) { - llvm::errs() << "me: " << str() << " right: " << dt.str() << "\n"; + /// Set this to the logical or of itself and CT, returning whether this value changed + /// Setting `PointerIntSame` considers pointers and integers as equivalent + /// This function will error if doing an illegal Operation + bool orIn(const ConcreteType CT, bool PointerIntSame) { + bool Legal = true; + bool Result = checkedOrIn(CT, PointerIntSame, Legal); + if (!Legal) { + llvm::errs() << "Illegal orIn: " << str() << " right: " << CT.str() << " PointerIntSame=" << PointerIntSame << "\n"; + assert(0 && "Performed illegal ConcreteType::orIn"); + llvm_unreachable("Performed illegal ConcreteType::orIn"); } - assert(legal); - return res; + return Result; } - // returns whether changed - bool operator|=(const ConcreteType dt) { - return mergeIn(dt, /*pointerIntSame*/ false); + /// Set this to the logical or of itself and CT, returning whether this value changed + /// This assumes that pointers and integers are distinct + /// This function will error if doing an illegal Operation + bool operator|=(const ConcreteType CT) { + return orIn(CT, /*pointerIntSame*/ false); } - bool pointerIntMerge(const ConcreteType dt, llvm::BinaryOperator::BinaryOps op) { - bool changed = false; + /// Set this to the logical and of itself and CT, returning whether this value changed + /// If this and CT are incompatible, the result will be BaseType::Unknown + bool andIn(const ConcreteType CT) { + if (SubTypeEnum == BaseType::Anything) { + return *this = CT; + } + if (CT.SubTypeEnum == BaseType::Anything) { + return false; + } + if (SubTypeEnum == BaseType::Unknown) { + return false; + } + if (CT.SubTypeEnum == BaseType::Unknown) { + return *this = CT; + } + + if (CT.SubTypeEnum != SubTypeEnum) { + return *this = BaseType::Unknown; + } + if (CT.SubType != SubType) { + return *this = BaseType::Unknown; + } + return false; + } + + /// Set this to the logical and of itself and CT, returning whether this value changed + /// If this and CT are incompatible, the result will be BaseType::Unknown + bool operator&=(const ConcreteType CT) { + return andIn(CT); + } + + /// Set this to the logical `binop` of itself and RHS, using the Binop Op, + /// returning true if this was changed. + /// This function will error on an invalid type combination + bool binopIn(const ConcreteType RHS, llvm::BinaryOperator::BinaryOps Op) { + bool Changed = false; using namespace llvm; - if (typeEnum == BaseType::Anything && dt.typeEnum == BaseType::Anything) { - return changed; + // Anything op Anyhting => Anything + if (SubTypeEnum == BaseType::Anything && RHS.SubTypeEnum == BaseType::Anything) { + return Changed; } - if (op == BinaryOperator::And && - (((typeEnum == BaseType::Anything || typeEnum == BaseType::Integer) && - dt.isFloat()) || - (isFloat() && (dt.typeEnum == BaseType::Anything || - dt.typeEnum == BaseType::Integer)))) { - typeEnum = BaseType::Unknown; - type = nullptr; - changed = true; - return changed; + // Constant & float => Unknown + if (Op == BinaryOperator::And && + (((SubTypeEnum == BaseType::Anything || SubTypeEnum == BaseType::Integer) && + RHS.isFloat()) || + (isFloat() && (RHS.SubTypeEnum == BaseType::Anything || + RHS.SubTypeEnum == BaseType::Integer)))) { + SubTypeEnum = BaseType::Unknown; + SubType = nullptr; + Changed = true; + return Changed; } - if ((typeEnum == BaseType::Unknown && dt.typeEnum == BaseType::Anything) || - (typeEnum == BaseType::Anything && dt.typeEnum == BaseType::Unknown)) { - if (typeEnum != BaseType::Unknown) { - typeEnum = BaseType::Unknown; - changed = true; + // Unknown op Anything => Unknown + if ((SubTypeEnum == BaseType::Unknown && RHS.SubTypeEnum == BaseType::Anything) || + (SubTypeEnum == BaseType::Anything && RHS.SubTypeEnum == BaseType::Unknown)) { + if (SubTypeEnum != BaseType::Unknown) { + SubTypeEnum = BaseType::Unknown; + Changed = true; } - return changed; + return Changed; } - if ((typeEnum == BaseType::Integer && dt.typeEnum == BaseType::Integer) || - (typeEnum == BaseType::Unknown && dt.typeEnum == BaseType::Integer) || - (typeEnum == BaseType::Integer && dt.typeEnum == BaseType::Unknown) || - (typeEnum == BaseType::Anything && dt.typeEnum == BaseType::Integer) || - (typeEnum == BaseType::Integer && dt.typeEnum == BaseType::Anything)) { - switch (op) { - case BinaryOperator::Add: - case BinaryOperator::Sub: - // if one of these is unknown we cannot deduce the result - // e.g. pointer + int = pointer and int + int = int - if (typeEnum == BaseType::Unknown || dt.typeEnum == BaseType::Unknown) { - if (typeEnum != BaseType::Unknown) { - typeEnum = BaseType::Unknown; - changed = true; - } - return changed; - } + // Integer op Integer => Integer + if (SubTypeEnum == BaseType::Integer && RHS.SubTypeEnum == BaseType::Integer) { + return Changed; + } - case BinaryOperator::Mul: - case BinaryOperator::UDiv: - case BinaryOperator::SDiv: - case BinaryOperator::URem: - case BinaryOperator::SRem: - case BinaryOperator::And: - case BinaryOperator::Or: - case BinaryOperator::Xor: - case BinaryOperator::Shl: - case BinaryOperator::AShr: - case BinaryOperator::LShr: - //! Anything << 16 ==> Anything - if (typeEnum == BaseType::Anything) { - break; - } - if (typeEnum != BaseType::Integer) { - typeEnum = BaseType::Integer; - changed = true; - } - break; - default: - llvm_unreachable("unknown binary operator"); + // Integer op Anything => Anything + if ((SubTypeEnum == BaseType::Anything && RHS.SubTypeEnum == BaseType::Integer) || + (SubTypeEnum == BaseType::Integer && RHS.SubTypeEnum == BaseType::Anything)) { + if (SubTypeEnum != BaseType::Anything) { + SubTypeEnum = BaseType::Anything; + Changed = true; + } + return Changed; + } + + // Integer op Unknown => Unknown + // e.g. pointer + int = pointer and int + int = int + if ((SubTypeEnum == BaseType::Unknown && RHS.SubTypeEnum == BaseType::Integer) || + (SubTypeEnum == BaseType::Integer && RHS.SubTypeEnum == BaseType::Unknown)) { + if (SubTypeEnum != BaseType::Unknown) { + SubTypeEnum = BaseType::Unknown; + Changed = true; } - return changed; + return Changed; } - if (typeEnum == BaseType::Pointer && dt.typeEnum == BaseType::Pointer) { - switch (op) { + // Pointer op Pointer => {Integer, Illegal} + if (SubTypeEnum == BaseType::Pointer && RHS.SubTypeEnum == BaseType::Pointer) { + switch (Op) { case BinaryOperator::Sub: - typeEnum = BaseType::Integer; - changed = true; + SubTypeEnum = BaseType::Integer; + Changed = true; break; case BinaryOperator::Add: case BinaryOperator::Mul: @@ -285,42 +352,43 @@ class ConcreteType { default: llvm_unreachable("unknown binary operator"); } - return changed; + return Changed; } - if ((typeEnum == BaseType::Integer && dt.typeEnum == BaseType::Pointer) || - (typeEnum == BaseType::Pointer && dt.typeEnum == BaseType::Integer) || - (typeEnum == BaseType::Integer && dt.typeEnum == BaseType::Pointer) || - (typeEnum == BaseType::Pointer && dt.typeEnum == BaseType::Unknown) || - (typeEnum == BaseType::Unknown && dt.typeEnum == BaseType::Pointer) || - (typeEnum == BaseType::Pointer && dt.typeEnum == BaseType::Anything) || - (typeEnum == BaseType::Anything && dt.typeEnum == BaseType::Pointer)) { + // Pointer op ? => {Pointer, Unknown} + if ((SubTypeEnum == BaseType::Integer && RHS.SubTypeEnum == BaseType::Pointer) || + (SubTypeEnum == BaseType::Pointer && RHS.SubTypeEnum == BaseType::Integer) || + (SubTypeEnum == BaseType::Integer && RHS.SubTypeEnum == BaseType::Pointer) || + (SubTypeEnum == BaseType::Pointer && RHS.SubTypeEnum == BaseType::Unknown) || + (SubTypeEnum == BaseType::Unknown && RHS.SubTypeEnum == BaseType::Pointer) || + (SubTypeEnum == BaseType::Pointer && RHS.SubTypeEnum == BaseType::Anything) || + (SubTypeEnum == BaseType::Anything && RHS.SubTypeEnum == BaseType::Pointer)) { - switch (op) { + switch (Op) { case BinaryOperator::Sub: - if (typeEnum == BaseType::Anything || dt.typeEnum == BaseType::Anything) { - if (typeEnum != BaseType::Unknown) { - typeEnum = BaseType::Unknown; - changed = true; + if (SubTypeEnum == BaseType::Anything || RHS.SubTypeEnum == BaseType::Anything) { + if (SubTypeEnum != BaseType::Unknown) { + SubTypeEnum = BaseType::Unknown; + Changed = true; } break; } case BinaryOperator::Add: case BinaryOperator::Mul: - if (typeEnum != BaseType::Pointer) { - typeEnum = BaseType::Pointer; - changed = true; + if (SubTypeEnum != BaseType::Pointer) { + SubTypeEnum = BaseType::Pointer; + Changed = true; } break; case BinaryOperator::UDiv: case BinaryOperator::SDiv: case BinaryOperator::URem: case BinaryOperator::SRem: - if (dt.typeEnum == BaseType::Pointer) { + if (RHS.SubTypeEnum == BaseType::Pointer) { llvm_unreachable("cannot divide integer by pointer"); - } else if (typeEnum != BaseType::Unknown) { - typeEnum = BaseType::Unknown; - changed = true; + } else if (SubTypeEnum != BaseType::Unknown) { + SubTypeEnum = BaseType::Unknown; + Changed = true; } break; case BinaryOperator::And: @@ -329,106 +397,33 @@ class ConcreteType { case BinaryOperator::Shl: case BinaryOperator::AShr: case BinaryOperator::LShr: - if (typeEnum != BaseType::Unknown) { - typeEnum = BaseType::Unknown; - changed = true; + if (SubTypeEnum != BaseType::Unknown) { + SubTypeEnum = BaseType::Unknown; + Changed = true; } break; default: llvm_unreachable("unknown binary operator"); } - return changed; - } - - if (dt.typeEnum == BaseType::Integer) { - switch (op) { - case BinaryOperator::Shl: - case BinaryOperator::AShr: - case BinaryOperator::LShr: - if (typeEnum != BaseType::Unknown) { - typeEnum = BaseType::Unknown; - changed = true; - return changed; - } - break; - default: - break; - } + return Changed; } - llvm::errs() << "self: " << str() << " other: " << dt.str() << " op: " << op + llvm::errs() << "self: " << str() << " RHS: " << RHS.str() << " Op: " << Op << "\n"; - llvm_unreachable("unknown case"); - } - - bool andIn(const ConcreteType dt, bool assertIfIllegal = true) { - if (typeEnum == BaseType::Anything) { - return *this = dt; - } - if (dt.typeEnum == BaseType::Anything) { - return false; - } - if (typeEnum == BaseType::Unknown) { - return false; - } - if (dt.typeEnum == BaseType::Unknown) { - return *this = dt; - } - - if (dt.typeEnum != typeEnum) { - if (!assertIfIllegal) { - return *this = BaseType::Unknown; - } - llvm::errs() << "&= typeEnum: " << to_string(typeEnum) - << " dt.typeEnum.str(): " << to_string(dt.typeEnum) << "\n"; - return *this = BaseType::Unknown; - } - assert(dt.typeEnum == typeEnum); - if (dt.type != type) { - if (!assertIfIllegal) { - return *this = BaseType::Unknown; - } - llvm::errs() << "type: " << *type << " dt.type: " << *dt.type << "\n"; - } - assert(dt.type == type); - return false; - } - - // returns whether changed - bool operator&=(const ConcreteType dt) { - return andIn(dt, /*assertIfIllegal*/ true); + llvm_unreachable("Unknown ConcreteType::binopIn"); } + /// Compare concrete types for use in map's bool operator<(const ConcreteType dt) const { - if (typeEnum == dt.typeEnum) { - return type < dt.type; + if (SubTypeEnum == dt.SubTypeEnum) { + return SubType < dt.SubType; } else { - return typeEnum < dt.typeEnum; - } - } - std::string str() const { - std::string res = to_string(typeEnum); - if (typeEnum == BaseType::Float) { - if (type->isHalfTy()) { - res += "@half"; - } else if (type->isFloatTy()) { - res += "@float"; - } else if (type->isDoubleTy()) { - res += "@double"; - } else if (type->isX86_FP80Ty()) { - res += "@fp80"; - } else if (type->isFP128Ty()) { - res += "@fp128"; - } else if (type->isPPC_FP128Ty()) { - res += "@ppc128"; - } else { - llvm_unreachable("unknown data type"); - } + return SubTypeEnum < dt.SubTypeEnum; } - return res; } }; +// Convert ConcreteType to string static inline std::string to_string(const ConcreteType dt) { return dt.str(); } #endif diff --git a/enzyme/Enzyme/TypeAnalysis/TBAA.h b/enzyme/Enzyme/TypeAnalysis/TBAA.h index 62ec25dd88f1..2be51edebe97 100644 --- a/enzyme/Enzyme/TypeAnalysis/TBAA.h +++ b/enzyme/Enzyme/TypeAnalysis/TBAA.h @@ -354,49 +354,52 @@ getAccessNameTBAA(Instruction *Inst, const std::set &legalnames) { return ""; } - -//! The following is new -extern llvm::cl::opt printtype; - -static inline ConcreteType getTypeFromTBAAString(std::string typeNameStringRef, - Instruction *inst) { - if (typeNameStringRef == "long long" || typeNameStringRef == "long" || - typeNameStringRef == "int" || typeNameStringRef == "bool" || - typeNameStringRef == "jtbaa_arraysize" || - typeNameStringRef == "jtbaa_arraylen") { - if (printtype) { - llvm::errs() << "known tbaa " << *inst << " " << typeNameStringRef +//! The following is not taken from LLVM + +/// Flag to print Type Analysis results as they are derived +extern llvm::cl::opt PrintType; + +/// Derive the ConcreteType corresponding to the string TypeName +/// The Instruction I denotes the context in which this was found +static inline ConcreteType getTypeFromTBAAString(std::string TypeName, + Instruction &I) { + if (TypeName == "long long" || TypeName == "long" || + TypeName == "int" || TypeName == "bool" || + TypeName == "jtbaa_arraysize" || + TypeName == "jtbaa_arraylen") { + if (PrintType) { + llvm::errs() << "known tbaa " << I << " " << TypeName << "\n"; } return ConcreteType(BaseType::Integer); - } else if (typeNameStringRef == "any pointer" || - typeNameStringRef == "vtable pointer" || - typeNameStringRef == "jtbaa_arrayptr" || - typeNameStringRef == "jtbaa_tag") { - if (printtype) { - llvm::errs() << "known tbaa " << *inst << " " << typeNameStringRef + } else if (TypeName == "any pointer" || + TypeName == "vtable pointer" || + TypeName == "jtbaa_arrayptr" || + TypeName == "jtbaa_tag") { + if (PrintType) { + llvm::errs() << "known tbaa " << I << " " << TypeName << "\n"; } return ConcreteType(BaseType::Pointer); - } else if (typeNameStringRef == "float") { - if (printtype) - llvm::errs() << "known tbaa " << *inst << " " << typeNameStringRef + } else if (TypeName == "float") { + if (PrintType) + llvm::errs() << "known tbaa " << I << " " << TypeName << "\n"; - return Type::getFloatTy(inst->getContext()); - } else if (typeNameStringRef == "double") { - if (printtype) - llvm::errs() << "known tbaa " << *inst << " " << typeNameStringRef + return Type::getFloatTy(I.getContext()); + } else if (TypeName == "double") { + if (PrintType) + llvm::errs() << "known tbaa " << I << " " << TypeName << "\n"; - return Type::getDoubleTy(inst->getContext()); - } else if (typeNameStringRef == "jtbaa_arraybuf") { - if (printtype) - llvm::errs() << "known tbaa " << *inst << " " << typeNameStringRef + return Type::getDoubleTy(I.getContext()); + } else if (TypeName == "jtbaa_arraybuf") { + if (PrintType) + llvm::errs() << "known tbaa " << I << " " << TypeName << "\n"; - if (isa(inst)) { - if (inst->getType()->isFPOrFPVectorTy()) { - return inst->getType()->getScalarType(); + if (isa(&I)) { + if (I.getType()->isFPOrFPVectorTy()) { + return I.getType()->getScalarType(); } - if (inst->getType()->isIntOrIntVectorTy()) { + if (I.getType()->isIntOrIntVectorTy()) { return BaseType::Integer; } } @@ -404,44 +407,41 @@ static inline ConcreteType getTypeFromTBAAString(std::string typeNameStringRef, return ConcreteType(BaseType::Unknown); } +/// Given a TBAA access node return the corresponding TypeTree +/// This includes recursively parsing the access nodes, with +/// corresponding offsets in the result static inline TypeTree parseTBAA(TBAAStructTypeNode AccessType, - Instruction *inst, - const llvm::DataLayout &dl) { - // llvm::errs() << "AT: " << *AccessType.getNode() << "\n"; + Instruction &I, + const llvm::DataLayout &DL) { if (auto *Id = dyn_cast(AccessType.getId())) { - // llvm::errs() << "cur access type: " << Id->getString() << "\n"; - auto dt = getTypeFromTBAAString(Id->getString().str(), inst); - if (dt.isKnown()) { - return TypeTree(dt).Only(-1); + auto CT = getTypeFromTBAAString(Id->getString().str(), I); + if (CT.isKnown()) { + return TypeTree(CT).Only(-1); } } - // llvm::errs() << "numfields: " << AccessType.getNumFields() << "\n"; - - TypeTree dat(BaseType::Pointer); - for (unsigned i = 0; i < AccessType.getNumFields(); ++i) { - auto at = AccessType.getFieldType(i); - auto start = AccessType.getFieldOffset(i); - // llvm::errs() << " f at i: " << i << " at: " << start << " fd: " << - // *at.getNode() << "\n"; - auto vd = parseTBAA(at, inst, dl); - // llvm::errs() << " _^ for f found " << vd.str() << "\n"; - dat |= vd.ShiftIndices(dl, /*init offset*/ 0, /*max size*/ -1, - /*addOffset*/ start); + TypeTree Result(BaseType::Pointer); + for (unsigned i = 0, size = AccessType.getNumFields(); i < size; ++i) { + auto SubAccess = AccessType.getFieldType(i); + auto Offset = AccessType.getFieldOffset(i); + auto SubResult = parseTBAA(SubAccess, I, DL); + Result |= SubResult.ShiftIndices(DL, /*init offset*/ 0, /*max size*/ -1, + /*addOffset*/ Offset); } - return dat; + return Result; } -// Modified from MDNode::isTBAAVtableAccess() -static inline TypeTree parseTBAA(const MDNode *M, Instruction *inst, - const llvm::DataLayout &dl) { +/// Given a TBAA metadata node return the corresponding TypeTree +/// Modified from MDNode::isTBAAVtableAccess() +static inline TypeTree parseTBAA(const MDNode *M, Instruction &I, + const llvm::DataLayout &DL) { if (!isStructPathTBAA(M)) { if (M->getNumOperands() < 1) return TypeTree(); if (const MDString *Tag1 = dyn_cast(M->getOperand(0))) { - return TypeTree(getTypeFromTBAAString(Tag1->getString().str(), inst)) + return TypeTree(getTypeFromTBAAString(Tag1->getString().str(), I)) .Only(0); } return TypeTree(); @@ -450,36 +450,35 @@ static inline TypeTree parseTBAA(const MDNode *M, Instruction *inst, // For struct-path aware TBAA, we use the access type of the tag. TBAAStructTagNode Tag(M); TBAAStructTypeNode AccessType(Tag.getAccessType()); - return parseTBAA(AccessType, inst, dl); + return parseTBAA(AccessType, I, DL); } -static inline TypeTree parseTBAA(Instruction *Inst, - const llvm::DataLayout &dl) { - TypeTree dat; - if (const MDNode *M = Inst->getMetadata(LLVMContext::MD_tbaa_struct)) { - for (unsigned i = 0; i < M->getNumOperands(); i += 3) { +/// Given an Instruction, return a TypeTree representing any +/// types that can be derived from TBAA metadata attached +static inline TypeTree parseTBAA(Instruction &I, + const llvm::DataLayout &DL) { + TypeTree Result; + if (const MDNode *M = I.getMetadata(LLVMContext::MD_tbaa_struct)) { + for (unsigned i = 0, size = M->getNumOperands(); i < size; i += 3) { if (const MDNode *M2 = dyn_cast(M->getOperand(i + 2))) { - auto vd = parseTBAA(M2, Inst, dl); - auto start = cast( + auto SubResult = parseTBAA(M2, I, DL); + auto Start = cast( cast(M->getOperand(i))->getValue()) ->getLimitedValue(); - auto len = + auto Len = cast( cast(M->getOperand(i + 1))->getValue()) ->getLimitedValue(); - // llvm::errs() << "inst: " << *Inst << " vd " << vd.str() << " len: " - // << len << " start: " << start << "\n"; - dat |= vd.ShiftIndices(dl, /*init offset*/ 0, /*max size*/ len, - /*add offset*/ start); + Result |= SubResult.ShiftIndices(DL, /*init offset*/ 0, /*max size*/ Len, + /*add offset*/ Start); } } } - if (const MDNode *M = Inst->getMetadata(LLVMContext::MD_tbaa)) { - dat |= parseTBAA(M, Inst, dl); + if (const MDNode *M = I.getMetadata(LLVMContext::MD_tbaa)) { + Result |= parseTBAA(M, I, DL); } - dat |= TypeTree(BaseType::Pointer); - // llvm::errs() << "overall parsed: " << *Inst << " " << dat.str() << "\n"; - return dat; + Result |= TypeTree(BaseType::Pointer); + return Result; } #endif diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index 3f6ed2dd4fd0..09143f52193c 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -46,19 +46,17 @@ #include "TBAA.h" -llvm::cl::opt printtype("enzyme_printtype", cl::init(false), cl::Hidden, +llvm::cl::opt PrintType("enzyme_PrintType", cl::init(false), cl::Hidden, cl::desc("Print type detection algorithm")); TypeAnalyzer::TypeAnalyzer(const FnTypeInfo &fn, TypeAnalysis &TA) - : intseen(), fntypeinfo(fn), interprocedural(TA), DT(*fn.function) { - // assert(fntypeinfo.knownValues.size() == - // fntypeinfo.function->getFunctionType()->getNumParams()); - for (BasicBlock &BB : *fntypeinfo.function) { + : intseen(), fntypeinfo(fn), interprocedural(TA), DT(*fn.Function) { + for (BasicBlock &BB : *fntypeinfo.Function) { for (auto &inst : BB) { workList.push_back(&inst); } } - for (BasicBlock &BB : *fntypeinfo.function) { + for (BasicBlock &BB : *fntypeinfo.Function) { for (auto &inst : BB) { for (auto &op : inst.operands()) { addToWorkList(op); @@ -67,9 +65,10 @@ TypeAnalyzer::TypeAnalyzer(const FnTypeInfo &fn, TypeAnalysis &TA) } } +/// Given a constant value, deduce any type information applicable TypeTree getConstantAnalysis(Constant *val, const FnTypeInfo &nfti, TypeAnalysis &TA) { - auto &dl = nfti.function->getParent()->getDataLayout(); + auto &dl = nfti.Function->getParent()->getDataLayout(); // Undefined value is an anything everywhere if (isa(val) || isa(val)) { return TypeTree(BaseType::Anything).Only(-1); @@ -91,11 +90,11 @@ TypeTree getConstantAnalysis(Constant *val, const FnTypeInfo &nfti, TypeTree res; int off = 0; for (unsigned i = 0; i < ca->getNumOperands(); ++i) { - assert(nfti.function); + assert(nfti.Function); auto op = ca->getOperand(i); // TODO check this for i1 constant aggregates packing/etc auto size = - (nfti.function->getParent()->getDataLayout().getTypeSizeInBits( + (nfti.Function->getParent()->getDataLayout().getTypeSizeInBits( op->getType()) + 7) / 8; @@ -111,11 +110,11 @@ TypeTree getConstantAnalysis(Constant *val, const FnTypeInfo &nfti, TypeTree res; int off = 0; for (unsigned i = 0; i < ca->getNumElements(); ++i) { - assert(nfti.function); + assert(nfti.Function); auto op = ca->getElementAsConstant(0); // TODO check this for i1 constant aggregates packing/etc auto size = - (nfti.function->getParent()->getDataLayout().getTypeSizeInBits( + (nfti.Function->getParent()->getDataLayout().getTypeSizeInBits( op->getType()) + 7) / 8; @@ -151,7 +150,7 @@ TypeTree getConstantAnalysis(Constant *val, const FnTypeInfo &nfti, TypeTree vd; auto ae = ce->getAsInstruction(); - ae->insertBefore(nfti.function->getEntryBlock().getTerminator()); + ae->insertBefore(nfti.Function->getEntryBlock().getTerminator()); { TypeAnalyzer tmp(nfti, TA); @@ -199,21 +198,21 @@ TypeTree TypeAnalyzer::getAnalysis(Value *val) { } if (auto inst = dyn_cast(val)) { - if (inst->getParent()->getParent() != fntypeinfo.function) { - llvm::errs() << " function: " << *fntypeinfo.function << "\n"; + if (inst->getParent()->getParent() != fntypeinfo.Function) { + llvm::errs() << " function: " << *fntypeinfo.Function << "\n"; llvm::errs() << " instParent: " << *inst->getParent()->getParent() << "\n"; llvm::errs() << " inst: " << *inst << "\n"; } - assert(inst->getParent()->getParent() == fntypeinfo.function); + assert(inst->getParent()->getParent() == fntypeinfo.Function); } if (auto arg = dyn_cast(val)) { - if (arg->getParent() != fntypeinfo.function) { - llvm::errs() << " function: " << *fntypeinfo.function << "\n"; + if (arg->getParent() != fntypeinfo.Function) { + llvm::errs() << " function: " << *fntypeinfo.Function << "\n"; llvm::errs() << " argParent: " << *arg->getParent() << "\n"; llvm::errs() << " arg: " << *arg << "\n"; } - assert(arg->getParent() == fntypeinfo.function); + assert(arg->getParent() == fntypeinfo.Function); } if (isa(val) || isa(val)) @@ -235,22 +234,20 @@ void TypeAnalyzer::updateAnalysis(Value *val, BaseType data, Value *origin) { void TypeAnalyzer::addToWorkList(Value *val) { if (!isa(val) && !isa(val)) return; - // llvm::errs() << " - adding to work list: " << *val << "\n"; if (std::find(workList.begin(), workList.end(), val) != workList.end()) return; if (auto inst = dyn_cast(val)) { - if (fntypeinfo.function != inst->getParent()->getParent()) { - llvm::errs() << "function: " << *fntypeinfo.function << "\n"; + if (fntypeinfo.Function != inst->getParent()->getParent()) { + llvm::errs() << "function: " << *fntypeinfo.Function << "\n"; llvm::errs() << "instf: " << *inst->getParent()->getParent() << "\n"; llvm::errs() << "inst: " << *inst << "\n"; } - assert(fntypeinfo.function == inst->getParent()->getParent()); + assert(fntypeinfo.Function == inst->getParent()->getParent()); } if (auto arg = dyn_cast(val)) - assert(fntypeinfo.function == arg->getParent()); + assert(fntypeinfo.Function == arg->getParent()); - // llvm::errs() << " - - true add : " << *val << "\n"; workList.push_back(val); } @@ -259,7 +256,7 @@ void TypeAnalyzer::updateAnalysis(Value *val, TypeTree data, Value *origin) { return; } - if (printtype) { + if (PrintType) { llvm::errs() << "updating analysis of val: " << *val << " current: " << analysis[val].str() << " new " << data.str(); @@ -269,15 +266,15 @@ void TypeAnalyzer::updateAnalysis(Value *val, TypeTree data, Value *origin) { } if (auto inst = dyn_cast(val)) { - if (fntypeinfo.function != inst->getParent()->getParent()) { - llvm::errs() << "function: " << *fntypeinfo.function << "\n"; + if (fntypeinfo.Function != inst->getParent()->getParent()) { + llvm::errs() << "function: " << *fntypeinfo.Function << "\n"; llvm::errs() << "instf: " << *inst->getParent()->getParent() << "\n"; llvm::errs() << "inst: " << *inst << "\n"; } - assert(fntypeinfo.function == inst->getParent()->getParent()); + assert(fntypeinfo.Function == inst->getParent()->getParent()); } if (auto arg = dyn_cast(val)) - assert(fntypeinfo.function == arg->getParent()); + assert(fntypeinfo.Function == arg->getParent()); if (isa(val) && data[{}] == BaseType::Integer) { llvm::errs() << "illegal gep update\n"; @@ -302,7 +299,7 @@ void TypeAnalyzer::updateAnalysis(Value *val, TypeTree data, Value *origin) { if (use != origin) { if (auto inst = dyn_cast(use)) { - if (fntypeinfo.function != inst->getParent()->getParent()) { + if (fntypeinfo.Function != inst->getParent()->getParent()) { continue; } } @@ -322,22 +319,22 @@ void TypeAnalyzer::updateAnalysis(Value *val, TypeTree data, Value *origin) { } void TypeAnalyzer::prepareArgs() { - for (auto &pair : fntypeinfo.first) { - assert(pair.first->getParent() == fntypeinfo.function); + for (auto &pair : fntypeinfo.Arguments) { + assert(pair.first->getParent() == fntypeinfo.Function); updateAnalysis(pair.first, pair.second, nullptr); } - for (auto &arg : fntypeinfo.function->args()) { + for (auto &arg : fntypeinfo.Function->args()) { // Get type and other information about argument updateAnalysis(&arg, getAnalysis(&arg), &arg); } // Propagate return value type information - for (auto &BB : *fntypeinfo.function) { + for (auto &BB : *fntypeinfo.Function) { for (auto &inst : BB) { if (auto ri = dyn_cast(&inst)) { if (auto rv = ri->getReturnValue()) { - updateAnalysis(rv, fntypeinfo.second, nullptr); + updateAnalysis(rv, fntypeinfo.Return, nullptr); } } } @@ -345,12 +342,12 @@ void TypeAnalyzer::prepareArgs() { } void TypeAnalyzer::considerTBAA() { - auto &dl = fntypeinfo.function->getParent()->getDataLayout(); + auto &dl = fntypeinfo.Function->getParent()->getDataLayout(); - for (auto &BB : *fntypeinfo.function) { + for (auto &BB : *fntypeinfo.Function) { for (auto &inst : BB) { - auto vdptr = parseTBAA(&inst, dl); + auto vdptr = parseTBAA(inst, dl); if (!vdptr.isKnownPastPointer()) continue; @@ -477,7 +474,6 @@ bool hasAnyUse(TypeAnalyzer &TAZ, } unknownUse = true; - // llvm::errs() << "unknown use : " << *use << " of v: " << *v << "\n"; continue; } @@ -611,7 +607,7 @@ bool TypeAnalyzer::runUnusedChecks() { std::map anyseen; std::map intseen; - for (BasicBlock &BB : *fntypeinfo.function) { + for (BasicBlock &BB : *fntypeinfo.Function) { for (auto &inst : BB) { auto analysis = getAnalysis(&inst); if (analysis[{0}] != BaseType::Unknown) @@ -730,10 +726,6 @@ void TypeAnalyzer::visitLoadInst(LoadInst &I) { .ShiftIndices(dl, /*start*/ 0, loadSize, /*addOffset*/ 0) .PurgeAnything(); ptr |= TypeTree(BaseType::Pointer); - // llvm::errs() << "LI: " << I << " prev i0: " << - // getAnalysis(I.getOperand(0)).str() << " ptr only-1:" << ptr.Only(-1).str() - // << "\n"; llvm::errs() << " + " << " prev i: " << getAnalysis(&I).str() <<" - // ga lu:" << getAnalysis(I.getOperand(0)).Lookup(loadSize).str() << "\n"; updateAnalysis(I.getOperand(0), ptr.Only(-1), &I); updateAnalysis(&I, getAnalysis(I.getOperand(0)).Lookup(loadSize, dl), &I); } @@ -749,11 +741,6 @@ void TypeAnalyzer::visitStoreInst(StoreInst &I) { .PurgeAnything(); ptr |= purged; - // llvm::errs() << "considering si: " << I << "\n"; - // llvm::errs() << " prevanalysis: " << - // getAnalysis(I.getPointerOperand()).str() << "\n"; llvm::errs() << " new: " - // << ptr.str() << "\n"; - updateAnalysis(I.getPointerOperand(), ptr.Only(-1), &I); updateAnalysis( I.getValueOperand(), @@ -784,7 +771,7 @@ std::set> getSet(const std::vector> &todo, } void TypeAnalyzer::visitGetElementPtrInst(GetElementPtrInst &gep) { - auto &dl = fntypeinfo.function->getParent()->getDataLayout(); + auto &dl = fntypeinfo.Function->getParent()->getDataLayout(); auto pointerAnalysis = getAnalysis(gep.getPointerOperand()); updateAnalysis(&gep, pointerAnalysis.KeepMinusOne(), &gep); @@ -887,7 +874,7 @@ void TypeAnalyzer::visitPHINode(PHINode &phi) { auto consider = [&](TypeTree &&newData, Value *v) { if (set) { - vd.andIn(newData, /*assertIfIllegal*/ false); + vd &= newData; } else { set = true; vd = newData; @@ -947,9 +934,8 @@ void TypeAnalyzer::visitPHINode(PHINode &phi) { TypeTree vd2 = isa(bo->getOperand(1)) ? getAnalysis(bo->getOperand(1)).Data0() : vd.Data0(); - vd1.pointerIntMerge(vd2, bo->getOpcode()); - vd.andIn(vd1.Only(bo->getType()->isIntegerTy() ? -1 : 0), - /*assertIfIllegal*/ false); + vd1.binopIn(vd2, bo->getOpcode()); + vd &= vd1.Only(bo->getType()->isIntegerTy() ? -1 : 0); } updateAnalysis(&phi, vd, &phi); @@ -1030,7 +1016,7 @@ void TypeAnalyzer::visitBitCastInst(BitCastInst &I) { &I, getAnalysis(I.getOperand(0)) .Data0() - .KeepForCast(fntypeinfo.function->getParent()->getDataLayout(), et2, + .KeepForCast(fntypeinfo.Function->getParent()->getDataLayout(), et2, et1) .Only(-1), &I); @@ -1038,7 +1024,7 @@ void TypeAnalyzer::visitBitCastInst(BitCastInst &I) { I.getOperand(0), getAnalysis(&I) .Data0() - .KeepForCast(fntypeinfo.function->getParent()->getDataLayout(), et1, + .KeepForCast(fntypeinfo.Function->getParent()->getDataLayout(), et1, et2) .Only(-1), &I); @@ -1050,7 +1036,7 @@ void TypeAnalyzer::visitSelectInst(SelectInst &I) { updateAnalysis(I.getFalseValue(), getAnalysis(&I), &I); TypeTree vd = getAnalysis(I.getTrueValue()); - vd.andIn(getAnalysis(I.getFalseValue()), /*assertIfIllegal*/ false); + vd &= getAnalysis(I.getFalseValue()); updateAnalysis(&I, vd, &I); } @@ -1080,13 +1066,13 @@ void TypeAnalyzer::visitShuffleVectorInst(ShuffleVectorInst &I) { updateAnalysis(I.getOperand(1), getAnalysis(&I), &I); TypeTree vd = getAnalysis(I.getOperand(0)); - vd.andIn(getAnalysis(I.getOperand(1)), /*assertIfIllegal*/ false); + vd &= getAnalysis(I.getOperand(1)); updateAnalysis(&I, vd, &I); } void TypeAnalyzer::visitExtractValueInst(ExtractValueInst &I) { - auto &dl = fntypeinfo.function->getParent()->getDataLayout(); + auto &dl = fntypeinfo.Function->getParent()->getDataLayout(); std::vector vec; vec.push_back(ConstantInt::get(Type::getInt64Ty(I.getContext()), 0)); for (auto ind : I.indices()) { @@ -1118,7 +1104,7 @@ void TypeAnalyzer::visitExtractValueInst(ExtractValueInst &I) { } void TypeAnalyzer::visitInsertValueInst(InsertValueInst &I) { - auto &dl = fntypeinfo.function->getParent()->getDataLayout(); + auto &dl = fntypeinfo.Function->getParent()->getDataLayout(); std::vector vec; vec.push_back(ConstantInt::get(Type::getInt64Ty(I.getContext()), 0)); for (auto ind : I.indices()) { @@ -1181,41 +1167,34 @@ void TypeAnalyzer::visitBinaryOperator(BinaryOperator &I) { updateAnalysis(I.getOperand(1), TypeTree(dt).Only(-1), &I); updateAnalysis(&I, TypeTree(dt).Only(-1), &I); } else { - auto analysis = getAnalysis(&I).Data0(); + auto AnalysisLHS = getAnalysis(I.getOperand(0)).Data0(); + auto AnalysisRHS = getAnalysis(I.getOperand(1)).Data0(); + auto AnalysisRet = getAnalysis(&I).Data0(); + switch (I.getOpcode()) { case BinaryOperator::Sub: - // TODO propagate this info // ptr - ptr => int and int - int => int; thus int = a - b says only that // these are equal ptr - int => ptr and int - ptr => ptr; thus - analysis = ConcreteType(BaseType::Unknown); + // howerver we do not want to propagate underlying ptr types since it's legal to subtract unrelated pointer + if (AnalysisRet[{}] == BaseType::Integer) { + updateAnalysis(I.getOperand(0), TypeTree(AnalysisRHS[{}]).Only(-1), &I); + updateAnalysis(I.getOperand(1), TypeTree(AnalysisLHS[{}]).Only(-1), &I); + } break; case BinaryOperator::Add: case BinaryOperator::Mul: // if a + b or a * b == int, then a and b must be ints - analysis = analysis.JustInt(); + updateAnalysis(I.getOperand(0), TypeTree(AnalysisRet.JustInt()[{}]).Only(-1), &I); + updateAnalysis(I.getOperand(1), TypeTree(AnalysisRet.JustInt()[{}]).Only(-1), &I); break; - case BinaryOperator::UDiv: - case BinaryOperator::SDiv: - case BinaryOperator::URem: - case BinaryOperator::SRem: - case BinaryOperator::And: - case BinaryOperator::Or: - case BinaryOperator::Xor: - case BinaryOperator::Shl: - case BinaryOperator::AShr: - case BinaryOperator::LShr: - analysis = ConcreteType(BaseType::Unknown); - break; default: - llvm_unreachable("unknown binary operator"); + break; } - updateAnalysis(I.getOperand(0), analysis.Only(-1), &I); - updateAnalysis(I.getOperand(1), analysis.Only(-1), &I); - TypeTree vd = getAnalysis(I.getOperand(0)).Data0(); - vd.pointerIntMerge(getAnalysis(I.getOperand(1)).Data0(), I.getOpcode()); + TypeTree vd = AnalysisLHS; + vd.binopIn(AnalysisRHS, I.getOpcode()); if (I.getOpcode() == BinaryOperator::And) { for (int i = 0; i < 2; ++i) { @@ -1329,7 +1308,7 @@ void TypeAnalyzer::visitIntrinsicInst(llvm::IntrinsicInst &I) { auto analysis = getAnalysis(&I).Data0(); BinaryOperator::BinaryOps opcode; - + // TODO update to use better rules in regular binop switch (I.getIntrinsicID()) { case Intrinsic::ssub_with_overflow: case Intrinsic::usub_with_overflow: { @@ -1363,7 +1342,7 @@ void TypeAnalyzer::visitIntrinsicInst(llvm::IntrinsicInst &I) { updateAnalysis(I.getOperand(1), analysis.Only(-1), &I); TypeTree vd = getAnalysis(I.getOperand(0)).Data0(); - vd.pointerIntMerge(getAnalysis(I.getOperand(1)).Data0(), opcode); + vd.binopIn(getAnalysis(I.getOperand(1)).Data0(), opcode); TypeTree overall = vd.Only(0); @@ -1573,8 +1552,8 @@ void analyzeFuncTypes(RT (*fn)(Args...), CallInst &call, TypeAnalyzer &TA) { } void TypeAnalyzer::visitCallInst(CallInst &call) { - assert(fntypeinfo.knownValues.size() == - fntypeinfo.function->getFunctionType()->getNumParams()); + assert(fntypeinfo.KnownValues.size() == + fntypeinfo.Function->getFunctionType()->getNumParams()); #if LLVM_VERSION_MAJOR >= 11 if (auto iasm = dyn_cast(call.getCalledOperand())) { @@ -1753,7 +1732,7 @@ void TypeAnalyzer::visitCallInst(CallInst &call) { TypeTree TypeAnalyzer::getReturnAnalysis() { bool set = false; TypeTree vd; - for (BasicBlock &BB : *fntypeinfo.function) { + for (BasicBlock &BB : *fntypeinfo.Function) { for (auto &inst : BB) { if (auto ri = dyn_cast(&inst)) { if (auto rv = ri->getReturnValue()) { @@ -1762,7 +1741,7 @@ TypeTree TypeAnalyzer::getReturnAnalysis() { vd = getAnalysis(rv); continue; } - vd.andIn(getAnalysis(rv), /*assertIfIllegal*/ false); + vd &= getAnalysis(rv); } } } @@ -1777,19 +1756,19 @@ std::set FnTypeInfo::knownIntegralValues( return {constant->getSExtValue()}; } - assert(knownValues.size() == function->getFunctionType()->getNumParams()); + assert(KnownValues.size() == Function->getFunctionType()->getNumParams()); if (auto arg = dyn_cast(val)) { - auto found = knownValues.find(arg); - if (found == knownValues.end()) { - for (const auto &pair : knownValues) { - llvm::errs() << " knownValues[" << *pair.first << "] - " + auto found = KnownValues.find(arg); + if (found == KnownValues.end()) { + for (const auto &pair : KnownValues) { + llvm::errs() << " KnownValues[" << *pair.first << "] - " << pair.first->getParent()->getName() << "\n"; } llvm::errs() << " arg: " << *arg << " - " << arg->getParent()->getName() << "\n"; } - assert(found != knownValues.end()); + assert(found != KnownValues.end()); return found->second; } @@ -1906,24 +1885,24 @@ std::set FnTypeInfo::knownIntegralValues( } void TypeAnalyzer::visitIPOCall(CallInst &call, Function &fn) { - assert(fntypeinfo.knownValues.size() == - fntypeinfo.function->getFunctionType()->getNumParams()); + assert(fntypeinfo.KnownValues.size() == + fntypeinfo.Function->getFunctionType()->getNumParams()); FnTypeInfo typeInfo(&fn); int argnum = 0; for (auto &arg : fn.args()) { auto dt = getAnalysis(call.getArgOperand(argnum)); - typeInfo.first.insert(std::pair(&arg, dt)); - typeInfo.knownValues.insert(std::pair>( + typeInfo.Arguments.insert(std::pair(&arg, dt)); + typeInfo.KnownValues.insert(std::pair>( &arg, fntypeinfo.knownIntegralValues(call.getArgOperand(argnum), DT, intseen))); ++argnum; } - typeInfo.second = getAnalysis(&call); + typeInfo.Return = getAnalysis(&call); - if (printtype) + if (PrintType) llvm::errs() << " starting IPO of " << call << "\n"; auto a = fn.arg_begin(); @@ -1942,12 +1921,12 @@ TypeResults TypeAnalysis::analyzeFunction(const FnTypeInfo &fn) { auto found = analyzedFunctions.find(fn); if (found != analyzedFunctions.end()) { auto &analysis = found->second; - if (analysis.fntypeinfo.function != fn.function) { - llvm::errs() << " queryFunc: " << *fn.function << "\n"; - llvm::errs() << " analysisFunc: " << *analysis.fntypeinfo.function + if (analysis.fntypeinfo.Function != fn.Function) { + llvm::errs() << " queryFunc: " << *fn.Function << "\n"; + llvm::errs() << " analysisFunc: " << *analysis.fntypeinfo.Function << "\n"; } - assert(analysis.fntypeinfo.function == fn.function); + assert(analysis.fntypeinfo.Function == fn.Function); return TypeResults(*this, fn); } @@ -1955,38 +1934,38 @@ TypeResults TypeAnalysis::analyzeFunction(const FnTypeInfo &fn) { auto res = analyzedFunctions.emplace(fn, TypeAnalyzer(fn, *this)); auto &analysis = res.first->second; - if (printtype) { - llvm::errs() << "analyzing function " << fn.function->getName() << "\n"; - for (auto &pair : fn.first) { + if (PrintType) { + llvm::errs() << "analyzing function " << fn.Function->getName() << "\n"; + for (auto &pair : fn.Arguments) { llvm::errs() << " + knowndata: " << *pair.first << " : " << pair.second.str(); - auto found = fn.knownValues.find(pair.first); - if (found != fn.knownValues.end()) { + auto found = fn.KnownValues.find(pair.first); + if (found != fn.KnownValues.end()) { llvm::errs() << " - " << to_string(found->second); } llvm::errs() << "\n"; } - llvm::errs() << " + retdata: " << fn.second.str() << "\n"; + llvm::errs() << " + retdata: " << fn.Return.str() << "\n"; } analysis.prepareArgs(); analysis.considerTBAA(); analysis.run(); - if (analysis.fntypeinfo.function != fn.function) { - llvm::errs() << " queryFunc: " << *fn.function << "\n"; - llvm::errs() << " analysisFunc: " << *analysis.fntypeinfo.function << "\n"; + if (analysis.fntypeinfo.Function != fn.Function) { + llvm::errs() << " queryFunc: " << *fn.Function << "\n"; + llvm::errs() << " analysisFunc: " << *analysis.fntypeinfo.Function << "\n"; } - assert(analysis.fntypeinfo.function == fn.function); + assert(analysis.fntypeinfo.Function == fn.Function); { auto &analysis = analyzedFunctions.find(fn)->second; - if (analysis.fntypeinfo.function != fn.function) { - llvm::errs() << " queryFunc: " << *fn.function << "\n"; - llvm::errs() << " analysisFunc: " << *analysis.fntypeinfo.function + if (analysis.fntypeinfo.Function != fn.Function) { + llvm::errs() << " queryFunc: " << *fn.Function << "\n"; + llvm::errs() << " analysisFunc: " << *analysis.fntypeinfo.Function << "\n"; } - assert(analysis.fntypeinfo.function == fn.function); + assert(analysis.fntypeinfo.Function == fn.Function); } return TypeResults(*this, fn); @@ -2012,11 +1991,11 @@ TypeTree TypeAnalysis::query(Value *val, const FnTypeInfo &fn) { analyzeFunction(fn); auto &found = analyzedFunctions.find(fn)->second; - if (func && found.fntypeinfo.function != func) { + if (func && found.fntypeinfo.Function != func) { llvm::errs() << " queryFunc: " << *func << "\n"; - llvm::errs() << " foundFunc: " << *found.fntypeinfo.function << "\n"; + llvm::errs() << " foundFunc: " << *found.fntypeinfo.Function << "\n"; } - assert(!func || found.fntypeinfo.function == func); + assert(!func || found.fntypeinfo.Function == func); return found.getAnalysis(val); } @@ -2027,7 +2006,7 @@ ConcreteType TypeAnalysis::intType(Value *val, const FnTypeInfo &fn, auto q = query(val, fn).Data0(); auto dt = q[{}]; // dump(); - if (errIfNotFound && (!dt.isKnown() || dt.typeEnum == BaseType::Anything)) { + if (errIfNotFound && (!dt.isKnown() || dt == BaseType::Anything)) { if (auto inst = dyn_cast(val)) { llvm::errs() << *inst->getParent()->getParent()->getParent() << "\n"; llvm::errs() << *inst->getParent()->getParent() << "\n"; @@ -2050,12 +2029,12 @@ ConcreteType TypeAnalysis::firstPointer(size_t num, Value *val, assert(val->getType()->isPointerTy()); auto q = query(val, fn).Data0(); auto dt = q[{0}]; - dt.mergeIn(q[{-1}], pointerIntSame); + dt.orIn(q[{-1}], pointerIntSame); for (size_t i = 1; i < num; ++i) { - dt.mergeIn(q[{(int)i}], pointerIntSame); + dt.orIn(q[{(int)i}], pointerIntSame); } - if (errIfNotFound && (!dt.isKnown() || dt.typeEnum == BaseType::Anything)) { + if (errIfNotFound && (!dt.isKnown() || dt == BaseType::Anything)) { auto &res = analyzedFunctions.find(fn)->second; if (auto inst = dyn_cast(val)) { llvm::errs() << *inst->getParent()->getParent() << "\n"; @@ -2097,25 +2076,25 @@ TypeResults::TypeResults(TypeAnalysis &analysis, const FnTypeInfo &fn) : analysis(analysis), info(fn) {} FnTypeInfo TypeResults::getAnalyzedTypeInfo() { - FnTypeInfo res(info.function); - for (auto &arg : info.function->args()) { - res.first.insert( + FnTypeInfo res(info.Function); + for (auto &arg : info.Function->args()) { + res.Arguments.insert( std::pair(&arg, analysis.query(&arg, info))); } - res.second = getReturnAnalysis(); - res.knownValues = info.knownValues; + res.Return = getReturnAnalysis(); + res.KnownValues = info.KnownValues; return res; } TypeTree TypeResults::query(Value *val) { if (auto inst = dyn_cast(val)) { - assert(inst->getParent()->getParent() == info.function); + assert(inst->getParent()->getParent() == info.Function); } if (auto arg = dyn_cast(val)) { - assert(arg->getParent() == info.function); + assert(arg->getParent() == info.Function); } - for (auto &pair : info.first) { - assert(pair.first->getParent() == info.function); + for (auto &pair : info.Arguments) { + assert(pair.first->getParent() == info.Function); } return analysis.query(val, info); } diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.h b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.h index 8fdce9f12a30..b4fabb2df95c 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.h +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.h @@ -39,21 +39,27 @@ #include "TypeTree.h" -class FnTypeInfo { -public: - llvm::Function *function; - FnTypeInfo(llvm::Function *fn) : function(fn) {} +/// Struct containing all contextual type information for a +/// particular function call +struct FnTypeInfo { + /// Function being analyzed + llvm::Function *Function; + + FnTypeInfo(llvm::Function *fn) : Function(fn) {} FnTypeInfo(const FnTypeInfo &) = default; FnTypeInfo &operator=(FnTypeInfo &) = default; FnTypeInfo &operator=(FnTypeInfo &&) = default; - // arguments:type - std::map first; - // return type - TypeTree second; - // the specific constant of an argument, if it is constant - std::map> knownValues; + /// Types of arguments + std::map Arguments; + + /// Type of return + TypeTree Return; + + /// The specific constant(s) known to represented by an argument, if constant + std::map> KnownValues; + /// The set of known values val will take std::set knownIntegralValues(llvm::Value *val, const llvm::DominatorTree &DT, std::map> &intseen) const; @@ -62,25 +68,27 @@ class FnTypeInfo { static inline bool operator<(const FnTypeInfo &lhs, const FnTypeInfo &rhs) { - if (lhs.function < rhs.function) + if (lhs.Function < rhs.Function) return true; - if (rhs.function < lhs.function) + if (rhs.Function < lhs.Function) return false; - if (lhs.first < rhs.first) + if (lhs.Arguments < rhs.Arguments) return true; - if (rhs.first < lhs.first) + if (rhs.Arguments < lhs.Arguments) return false; - if (lhs.second < rhs.second) + if (lhs.Return < rhs.Return) return true; - if (rhs.second < lhs.second) + if (rhs.Return < lhs.Return) return false; - return lhs.knownValues < rhs.knownValues; + return lhs.KnownValues < rhs.KnownValues; } class TypeAnalyzer; class TypeAnalysis; +/// A holder class representing the results of running TypeAnalysis +/// on a given function class TypeResults { public: TypeAnalysis &analysis; @@ -90,34 +98,50 @@ class TypeResults { TypeResults(TypeAnalysis &analysis, const FnTypeInfo &fn); ConcreteType intType(llvm::Value *val, bool errIfNotFound = true); - //! Returns whether in the first num bytes there is pointer, int, float, or - //! none If pointerIntSame is set to true, then consider either as the same - //! (and thus mergable) + /// Returns whether in the first num bytes there is pointer, int, float, or + /// none If pointerIntSame is set to true, then consider either as the same + /// (and thus mergable) ConcreteType firstPointer(size_t num, llvm::Value *val, bool errIfNotFound = true, bool pointerIntSame = false); + /// The TypeTree of a particular Value TypeTree query(llvm::Value *val); + + /// The TypeInfo calling convention FnTypeInfo getAnalyzedTypeInfo(); + + /// The Type of the return TypeTree getReturnAnalysis(); + + /// Prints all known information void dump(); + + ///The set of values val will take on during this program std::set knownIntegralValues(llvm::Value *val) const; }; +/// Helper class that computes the fixed-point type results of a given function class TypeAnalyzer : public llvm::InstVisitor { public: - // List of value's which should be re-analyzed now with new information + /// List of value's which should be re-analyzed now with new information std::deque workList; private: + /// Tell TypeAnalyzer to reanalyze this value void addToWorkList(llvm::Value *val); + + /// Map of Value to known integer constants that it will take on std::map> intseen; public: - // Calling context + /// Calling context const FnTypeInfo fntypeinfo; + /// Calling TypeAnalysis to be used in the case of calls to other + /// functions TypeAnalysis &interprocedural; + /// Intermediate conservative, but correct Type analysis results std::map analysis; llvm::DominatorTree DT; @@ -126,16 +150,22 @@ class TypeAnalyzer : public llvm::InstVisitor { TypeTree getAnalysis(llvm::Value *val); + /// Add additional information to the Type info of val, readding it to the + /// work queue as necessary void updateAnalysis(llvm::Value *val, BaseType data, llvm::Value *origin); void updateAnalysis(llvm::Value *val, ConcreteType data, llvm::Value *origin); void updateAnalysis(llvm::Value *val, TypeTree data, llvm::Value *origin); + /// Analyze type info given by the arguments, possibly adding to work queue void prepareArgs(); + /// Analyze type info given by the TBAA, possibly adding to work queue void considerTBAA(); + /// Run the interprocedural type analysis starting from this function void run(); + /// Set any unused values to a particular type bool runUnusedChecks(); void visitValue(llvm::Value &val); @@ -207,19 +237,30 @@ class TypeAnalyzer : public llvm::InstVisitor { //TODO handle fneg on LLVM 10+ }; +/// Full interprocedural TypeAnalysis class TypeAnalysis { public: + /// Map of possible query states to TypeAnalyzer intermediate results std::map analyzedFunctions; + /// Analyze a particular function, returning the results TypeResults analyzeFunction(const FnTypeInfo &fn); + /// Get the TypeTree of a given value from a given function context TypeTree query(llvm::Value *val, const FnTypeInfo &fn); + /// Get the underlying data type of value val given a particular context + /// If the type is not known err if errIfNotFound ConcreteType intType(llvm::Value *val, const FnTypeInfo &fn, bool errIfNotFound = true); + + /// Get the underlying data type of first num bytes of val given a particular context + /// If the type is not known err if errIfNotFound. Consider ints and pointers + /// the same if pointerIntSame. ConcreteType firstPointer(size_t num, llvm::Value *val, const FnTypeInfo &fn, bool errIfNotFound = true, bool pointerIntSame = false); + /// Get the TyeTree of the returned value of a given function and context inline TypeTree getReturnAnalysis(const FnTypeInfo &fn) { analyzeFunction(fn); return analyzedFunctions.find(fn)->second.getReturnAnalysis(); diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysisPrinter.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysisPrinter.cpp index cf6733efa137..f5d45c15aa8a 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysisPrinter.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysisPrinter.cpp @@ -53,7 +53,8 @@ using namespace llvm; #endif #define DEBUG_TYPE "type-analysis-results" -llvm::cl::opt functionToAnalyzeTypes("type-analysis-func", cl::init(""), cl::Hidden, +/// Function TypeAnalysis will be starting its run from +llvm::cl::opt FunctionToAnalyze("type-analysis-func", cl::init(""), cl::Hidden, cl::desc("Which function to analyze/print")); namespace { @@ -67,11 +68,11 @@ class TypeAnalysisPrinter : public FunctionPass { } bool runOnFunction(Function &F) override { - if (F.getName() != functionToAnalyzeTypes) + if (F.getName() != FunctionToAnalyze) return /*changed*/ false; FnTypeInfo type_args(&F); - for (auto &a : type_args.function->args()) { + for (auto &a : type_args.Function->args()) { TypeTree dt; if (a.getType()->isFPOrFPVectorTy()) { dt = ConcreteType(a.getType()->getScalarType()); @@ -83,10 +84,10 @@ class TypeAnalysisPrinter : public FunctionPass { dt = TypeTree(ConcreteType(BaseType::Pointer)).Only({-1}); } } - type_args.first.insert(std::pair(&a, dt.Only(-1))); + type_args.Arguments.insert(std::pair(&a, dt.Only(-1))); // TODO note that here we do NOT propagate constants in type info (and // should consider whether we should) - type_args.knownValues.insert( + type_args.KnownValues.insert( std::pair>(&a, {})); } @@ -95,15 +96,15 @@ class TypeAnalysisPrinter : public FunctionPass { for (Function &f : *F.getParent()) { for (auto &analysis : TA.analyzedFunctions) { - if (analysis.first.function != &f) + if (analysis.first.Function != &f) continue; auto &ta = analysis.second; - llvm::outs() << f.getName() << " - " << analysis.first.second.str() + llvm::outs() << f.getName() << " - " << analysis.first.Return.str() << " |"; for (auto &a : f.args()) { - llvm::outs() << analysis.first.first.find(&a)->second.str() << ":" - << to_string(analysis.first.knownValues.find(&a)->second) + llvm::outs() << analysis.first.Arguments.find(&a)->second.str() << ":" + << to_string(analysis.first.KnownValues.find(&a)->second) << " "; } llvm::outs() << "\n"; diff --git a/enzyme/Enzyme/TypeAnalysis/TypeTree.h b/enzyme/Enzyme/TypeAnalysis/TypeTree.h index 8bbbc63189bb..47ea401626cf 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeTree.h +++ b/enzyme/Enzyme/TypeAnalysis/TypeTree.h @@ -40,6 +40,7 @@ #include "BaseType.h" #include "ConcreteType.h" +/// Helper function to print a vector of ints to a string static inline std::string to_string(const std::vector x) { std::string out = "["; for (unsigned i = 0; i < x.size(); ++i) { @@ -57,56 +58,63 @@ typedef std::shared_ptr TypeResult; typedef std::map, ConcreteType> ConcreteTypeMapType; typedef std::map, const TypeResult> TypeTreeMapType; +/// Class representing the underlying types of values as +/// sequences of offsets to a ConcreteType class TypeTree : public std::enable_shared_from_this { private: // mapping of known indices to type if one exists ConcreteTypeMapType mapping; - // mapping of known indices to type if one exists - // TypeTreeMapType recur_mapping; +public: - static std::map, TypeResult> - cache; + TypeTree() {} + TypeTree(ConcreteType dat) { + if (dat != ConcreteType(BaseType::Unknown)) { + insert({}, dat); + } + } -public: - ConcreteType operator[](const std::vector v) const { - auto found = mapping.find(v); - if (found != mapping.end()) { - return found->second; + /// Lookup the underlying ConcreteType at a given offset sequence + /// or Unknown if none exists + ConcreteType operator[](const std::vector Seq) const { + auto Found = mapping.find(Seq); + if (Found != mapping.end()) { + return Found->second; } for (const auto &pair : mapping) { - if (pair.first.size() != v.size()) + if (pair.first.size() != Seq.size()) continue; - bool match = true; - for (unsigned i = 0; i < pair.first.size(); ++i) { + bool Match = true; + for (unsigned i = 0, size = pair.first.size(); i < size; ++i) { if (pair.first[i] == -1) continue; - if (pair.first[i] != v[i]) { - match = false; + if (pair.first[i] != Seq[i]) { + Match = false; break; } } - if (!match) + if (!Match) continue; return pair.second; } return BaseType::Unknown; } - void erase(const std::vector v) { mapping.erase(v); } + /// Remove a given offset sequence + void erase(const std::vector Seq) { mapping.erase(Seq); } - void insert(const std::vector v, ConcreteType d, + void insert(const std::vector Seq, ConcreteType CT, bool intsAreLegalSubPointer = false) { - if (v.size() > 0) { + if (Seq.size() > 0) { // check pointer abilities from before { - std::vector tmp(v.begin(), v.end() - 1); + std::vector tmp(Seq.begin(), Seq.end() - 1); auto found = mapping.find(tmp); if (found != mapping.end()) { if (!(found->second == BaseType::Pointer || found->second == BaseType::Anything)) { - llvm::errs() << "FAILED dt: " << str() - << " adding v: " << to_string(v) << ": " << d.str() + llvm::errs() << "FAILED CT: " << str() + << " adding Seq: " << to_string(Seq) << ": " << CT.str() << "\n"; } assert(found->second == BaseType::Pointer || @@ -116,53 +124,53 @@ class TypeTree : public std::enable_shared_from_this { // don't insert if there's an existing ending -1 { - std::vector tmp(v.begin(), v.end() - 1); + std::vector tmp(Seq.begin(), Seq.end() - 1); tmp.push_back(-1); auto found = mapping.find(tmp); if (found != mapping.end()) { - if (found->second != d) { - if (d == BaseType::Anything) { - found->second = d; + if (found->second != CT) { + if (CT == BaseType::Anything) { + found->second = CT; } else { llvm::errs() << "FAILED dt: " << str() - << " adding v: " << to_string(v) << ": " << d.str() + << " adding v: " << to_string(Seq) << ": " << CT.str() << "\n"; } } - assert(found->second == d); + assert(found->second == CT); return; } } // don't insert if there's an existing starting -1 { - std::vector tmp(v.begin(), v.end()); + std::vector tmp(Seq.begin(), Seq.end()); tmp[0] = -1; auto found = mapping.find(tmp); if (found != mapping.end()) { - if (found->second != d) { - if (d == BaseType::Anything) { - found->second = d; + if (found->second != CT) { + if (CT == BaseType::Anything) { + found->second = CT; } else { llvm::errs() << "FAILED dt: " << str() - << " adding v: " << to_string(v) << ": " << d.str() + << " adding v: " << to_string(Seq) << ": " << CT.str() << "\n"; } } - assert(found->second == d); + assert(found->second == CT); return; } } // if this is a ending -1, remove other -1's - if (v.back() == -1) { + if (Seq.back() == -1) { std::set> toremove; for (const auto &pair : mapping) { - if (pair.first.size() == v.size()) { + if (pair.first.size() == Seq.size()) { bool matches = true; for (unsigned i = 0; i < pair.first.size() - 1; ++i) { - if (pair.first[i] != v[i]) { + if (pair.first[i] != Seq[i]) { matches = false; break; } @@ -171,15 +179,15 @@ class TypeTree : public std::enable_shared_from_this { continue; if (intsAreLegalSubPointer && - pair.second.typeEnum == BaseType::Integer && - d.typeEnum == BaseType::Pointer) { + pair.second == BaseType::Integer && + CT == BaseType::Pointer) { } else { - if (pair.second != d) { + if (pair.second != CT) { llvm::errs() << "inserting into : " << str() << " with " - << to_string(v) << " of " << d.str() << "\n"; + << to_string(Seq) << " of " << CT.str() << "\n"; } - assert(pair.second == d); + assert(pair.second == CT); } toremove.insert(pair.first); } @@ -191,20 +199,20 @@ class TypeTree : public std::enable_shared_from_this { } // if this is a starting -1, remove other -1's - if (v[0] == -1) { + if (Seq[0] == -1) { std::set> toremove; for (const auto &pair : mapping) { - if (pair.first.size() == v.size()) { + if (pair.first.size() == Seq.size()) { bool matches = true; for (unsigned i = 1; i < pair.first.size(); ++i) { - if (pair.first[i] != v[i]) { + if (pair.first[i] != Seq[i]) { matches = false; break; } } if (!matches) continue; - assert(pair.second == d); + assert(pair.second == CT); toremove.insert(pair.first); } } @@ -214,31 +222,25 @@ class TypeTree : public std::enable_shared_from_this { } } } - if (v.size() > 6) { + if (Seq.size() > 6) { llvm::errs() << "not handling more than 6 pointer lookups deep dt:" - << str() << " adding v: " << to_string(v) << ": " << d.str() + << str() << " adding v: " << to_string(Seq) << ": " << CT.str() << "\n"; return; } - for (auto a : v) { - if (a > 1000) { - // llvm::errs() << "not handling more than 1000B offset pointer dt:" << - // str() << " adding v: " << to_string(v) << ": " << d.str() << "\n"; + for (auto Off : Seq) { + if (Off > 1000) { + // TODO perhaps issue warning for too large an offset return; } } - mapping.insert(std::pair, ConcreteType>(v, d)); + mapping.insert(std::pair, ConcreteType>(Seq, CT)); } + /// How this TypeTree compares with another bool operator<(const TypeTree &vd) const { return mapping < vd.mapping; } - TypeTree() {} - TypeTree(ConcreteType dat) { - if (dat != ConcreteType(BaseType::Unknown)) { - insert({}, dat); - } - } - + /// Whether this TypeTree contains any information bool isKnown() { for (auto &pair : mapping) { // we should assert here as we shouldn't keep any unknown maps for @@ -248,6 +250,7 @@ class TypeTree : public std::enable_shared_from_this { return mapping.size() != 0; } + /// Whether this TypeTree knows any non-pointer information bool isKnownPastPointer() { for (auto &pair : mapping) { // we should assert here as we shouldn't keep any unknown maps for @@ -262,12 +265,11 @@ class TypeTree : public std::enable_shared_from_this { return false; } - static TypeTree Unknown() { return TypeTree(); } - + /// Select only the Integer ConcreteTypes TypeTree JustInt() const { TypeTree vd; for (auto &pair : mapping) { - if (pair.second.typeEnum == BaseType::Integer) { + if (pair.second == BaseType::Integer) { vd.insert(pair.first, pair.second); } } @@ -282,48 +284,27 @@ class TypeTree : public std::enable_shared_from_this { TypeTree KeepForCast(const llvm::DataLayout &dl, llvm::Type *from, llvm::Type *to) const; - static std::vector appendIndex(int off, const std::vector &first) { - std::vector out; - out.push_back(off); - for (auto a : first) - out.push_back(a); - return out; + /// Helper function to prepend an offset + static std::vector prependIndex(int Off, const std::vector &Array) { + std::vector Result; + Result.push_back(Off); + for (auto Val : Array) + Result.push_back(Val); + return Result; } - TypeTree Only(int off) const { - TypeTree dat; - + /// Prepend an offset to all mappings + TypeTree Only(int Off) const { + TypeTree Result; for (const auto &pair : mapping) { - dat.insert(appendIndex(off, pair.first), pair.second); - // if (pair.first.size() > 0) { - // dat.insert(indices, ConcreteType(BaseType::Pointer)); - //} + Result.insert(prependIndex(Off, pair.first), pair.second); } - - return dat; - } - - static bool lookupIndices(std::vector &first, int idx, - const std::vector &second) { - if (second.size() == 0) - return false; - - assert(first.size() == 0); - - if (idx == -1) { - } else if (second[0] == -1) { - } else if (idx != second[0]) { - return false; - } - - for (size_t i = 1; i < second.size(); ++i) { - first.push_back(second[i]); - } - return true; + return Result; } + /// Peel off the outermost index at offset 0 TypeTree Data0() const { - TypeTree dat; + TypeTree Result; for (const auto &pair : mapping) { assert(pair.first.size() != 0); @@ -332,45 +313,60 @@ class TypeTree : public std::enable_shared_from_this { std::vector next; for (size_t i = 1; i < pair.first.size(); ++i) next.push_back(pair.first[i]); - TypeTree dat2; - dat2.insert(next, pair.second); - dat |= dat2; + // We do insertion like this to force an error + // on the |= operation if there is an incompatible + // merge. The insert operation does not error. + TypeTree SubResult; + SubResult.insert(next, pair.second); + Result |= SubResult; } } - return dat; + return Result; } + /// Remove any mappings in the range [start, end) or [len, inf) + /// This function has special handling for -1's TypeTree Clear(size_t start, size_t end, size_t len) const { - TypeTree dat; + TypeTree Result; + + // Note that below do insertion with the |= operator + // to force an error if there is an incompatible + // merge. The insert operation does not error. for (const auto &pair : mapping) { assert(pair.first.size() != 0); if (pair.first[0] == -1) { - TypeTree dat2; + // For "all index" calculations, explicitly + // add mappings for regions in range + TypeTree SubResult; auto next = pair.first; for (size_t i = 0; i < start; ++i) { next[0] = i; - dat2.insert(next, pair.second); + SubResult.insert(next, pair.second); } for (size_t i = end; i < len; ++i) { next[0] = i; - dat2.insert(next, pair.second); + SubResult.insert(next, pair.second); } - dat |= dat2; + Result |= SubResult; } else if ((size_t)pair.first[0] > start && (size_t)pair.first[0] >= end && (size_t)pair.first[0] < len) { - TypeTree dat2; - dat2.insert(pair.first, pair.second); - dat |= dat2; + // Otherwise simply check that the given offset is in range + + TypeTree SubResult; + SubResult.insert(pair.first, pair.second); + Result |= SubResult; } } // TODO canonicalize this - return dat; + return Result; } + /// Select all submappings whose first index is in range [0, len) and remove + /// the first index. This is the inverse of the `Only` operation TypeTree Lookup(size_t len, const llvm::DataLayout &dl) const { // Map of indices[1:] => ( End => possible Index[0] ) @@ -403,7 +399,7 @@ class TypeTree : public std::enable_shared_from_this { staging[next][pair.second].insert(pair.first[1]); } - TypeTree dat; + TypeTree Result; for (auto &pair : staging) { auto &pnext = pair.first; for (auto &pair2 : pair.second) { @@ -426,7 +422,7 @@ class TypeTree : public std::enable_shared_from_this { llvm::errs() << *flt << "\n"; assert(0 && "unhandled float type"); } - } else if (dt.typeEnum == BaseType::Pointer) { + } else if (dt == BaseType::Pointer) { chunk = dl.getPointerSizeInBits() / 8; } @@ -445,19 +441,21 @@ class TypeTree : public std::enable_shared_from_this { next.push_back(v); if (legalCombine) { - dat.insert(next, dt, /*intsAreLegalPointerSub*/ true); + Result.insert(next, dt, /*intsAreLegalPointerSub*/ true); } else { for (auto e : set) { next[0] = e; - dat.insert(next, dt); + Result.insert(next, dt); } } } } - return dat; + return Result; } + /// Given that this tree represents something of at most size len, canonicalize + /// this, creating -1's where possible TypeTree CanonicalizeValue(size_t len, const llvm::DataLayout &dl) const { // Map of indices[1:] => ( End => possible Index[0] ) @@ -497,7 +495,7 @@ class TypeTree : public std::enable_shared_from_this { llvm::errs() << *flt << "\n"; assert(0 && "unhandled float type"); } - } else if (dt.typeEnum == BaseType::Pointer) { + } else if (dt == BaseType::Pointer) { chunk = dl.getPointerSizeInBits() / 8; } @@ -529,6 +527,7 @@ class TypeTree : public std::enable_shared_from_this { return dat; } + /// Keep only pointers (or anything's) to a repeated value (represented by -1) TypeTree KeepMinusOne() const { TypeTree dat; @@ -557,23 +556,25 @@ class TypeTree : public std::enable_shared_from_this { return dat; } - //! Replace offsets in [offset, offset+maxSize] with [addOffset, addOffset + - //! maxSize] + /// Replace mappings in the range in [offset, offset+maxSize] with those in + // [addOffset, addOffset + maxSize]. In other worse, select all mappings in + // [offset, offset+maxSize] then add `addOffset` TypeTree ShiftIndices(const llvm::DataLayout &dl, int offset, int maxSize, size_t addOffset = 0) const { - TypeTree dat; + TypeTree Result; for (const auto &pair : mapping) { if (pair.first.size() == 0) { if (pair.second == BaseType::Pointer || pair.second == BaseType::Anything) { - dat.insert(pair.first, pair.second); + Result.insert(pair.first, pair.second); continue; } llvm::errs() << "could not unmerge " << str() << "\n"; + assert(0 && "ShiftIndices called on a nonpointer/anything"); + llvm_unreachable("ShiftIndices called on a nonpointer/anything"); } - assert(pair.first.size() > 0); std::vector next(pair.first); @@ -606,8 +607,6 @@ class TypeTree : public std::enable_shared_from_this { } TypeTree dat2; - // llvm::errs() << "next: " << to_string(next) << " indices: " << - // to_string(indices) << " pair.first: " << to_string(pair.first) << "\n"; if (next[0] == -1 && maxSize != -1) { size_t chunk = 1; auto op = operator[]({pair.first[0]}); @@ -622,7 +621,7 @@ class TypeTree : public std::enable_shared_from_this { llvm::errs() << *flt << "\n"; assert(0 && "unhandled float type"); } - } else if (op.typeEnum == BaseType::Pointer) { + } else if (op == BaseType::Pointer) { chunk = dl.getPointerSizeInBits() / 8; } @@ -633,13 +632,13 @@ class TypeTree : public std::enable_shared_from_this { } else { dat2.insert(next, pair.second); } - dat |= dat2; + Result |= dat2; } - return dat; + return Result; } - // Removes any anything types + /// Keep only mappings where the type is not an `Anything` TypeTree PurgeAnything() const { TypeTree dat; for (const auto &pair : mapping) { @@ -650,7 +649,7 @@ class TypeTree : public std::enable_shared_from_this { return dat; } - // TODO note that this keeps -1's + /// Select mappings in range [0, max), preserving -1's TypeTree AtMost(size_t max) const { assert(max > 0); TypeTree dat; @@ -663,70 +662,66 @@ class TypeTree : public std::enable_shared_from_this { return dat; } - static TypeTree Argument(ConcreteType type, llvm::Value *v) { - if (v->getType()->isIntOrIntVectorTy()) - return TypeTree(type); - return TypeTree(type).Only(0); - } - - bool operator==(const TypeTree &v) const { return mapping == v.mapping; } + /// Chceck equality of two TypeTrees + bool operator==(const TypeTree &RHS) const { return mapping == RHS.mapping; } - // Return if changed - bool operator=(const TypeTree &v) { - if (*this == v) + /// Set this to another TypeTree, returning if this was changed + bool operator=(const TypeTree &RHS) { + if (*this == RHS) return false; mapping.clear(); - for(const auto& elems : v.mapping) { + for(const auto& elems : RHS.mapping) { mapping.emplace(elems); } - //mapping = v.mapping; return true; } - bool mergeIn(const TypeTree &v, bool pointerIntSame) { - //! Todo detect recursive merge + /// Set this to the logical or of itself and RHS, returning whether this value changed + /// Setting `PointerIntSame` considers pointers and integers as equivalent + /// This function will error if doing an illegal Operation + bool orIn(const TypeTree &RHS, bool PointerIntSame) { + // TODO detect recursive merge and simplify bool changed = false; - if (v[{-1}] != BaseType::Unknown) { + if (RHS[{-1}] != BaseType::Unknown) { for (auto &pair : mapping) { if (pair.first.size() == 1 && pair.first[0] != -1) { - pair.second.mergeIn(v[{-1}], pointerIntSame); + pair.second.orIn(RHS[{-1}], PointerIntSame); // if (pair.second == ) // NOTE DELETE the non -1 } } } - for (auto &pair : v.mapping) { + for (auto &pair : RHS.mapping) { assert(pair.second != BaseType::Unknown); - ConcreteType dt = operator[](pair.first); - // llvm::errs() << "merging @ " << to_string(pair.first) << " old:" << - // dt.str() << " new:" << pair.second.str() << "\n"; - changed |= (dt.mergeIn(pair.second, pointerIntSame)); - insert(pair.first, dt); + ConcreteType CT = operator[](pair.first); + changed |= (CT.orIn(pair.second, PointerIntSame)); + insert(pair.first, CT); } return changed; } - bool operator|=(const TypeTree &v) { - return mergeIn(v, /*pointerIntSame*/ false); - } - - bool operator&=(const TypeTree &v) { - return andIn(v, /*assertIfIllegal*/ true); + /// Set this to the logical or of itself and RHS, returning whether this value changed + /// This assumes that pointers and integers are distinct + /// This function will error if doing an illegal Operation + bool operator|=(const TypeTree &RHS) { + return orIn(RHS, /*PointerIntSame*/ false); } - bool andIn(const TypeTree &v, bool assertIfIllegal = true) { + /// Set this to the logical and of itself and RHS, returning whether this value changed + /// If this and RHS are incompatible at an index, the result will be BaseType::Unknown + bool andIn(const TypeTree &RHS) { bool changed = false; std::vector> keystodelete; for (auto &pair : mapping) { ConcreteType other = BaseType::Unknown; - auto fd = v.mapping.find(pair.first); - if (fd != v.mapping.end()) { + auto fd = RHS.mapping.find(pair.first); + if (fd != RHS.mapping.end()) { other = fd->second; } - changed = (pair.second.andIn(other, assertIfIllegal)); + changed = (pair.second &= other); if (pair.second == BaseType::Unknown) { keystodelete.push_back(pair.first); } @@ -739,21 +734,30 @@ class TypeTree : public std::enable_shared_from_this { return changed; } - bool pointerIntMerge(const TypeTree &v, llvm::BinaryOperator::BinaryOps op) { + /// Set this to the logical and of itself and RHS, returning whether this value changed + /// If this and RHS are incompatible at an index, the result will be BaseType::Unknown + bool operator&=(const TypeTree &RHS) { + return andIn(RHS); + } + + /// Set this to the logical `binop` of itself and RHS, using the Binop Op, + /// returning true if this was changed. + /// This function will error on an invalid type combination + bool binopIn(const TypeTree &RHS, llvm::BinaryOperator::BinaryOps Op) { bool changed = false; auto found = mapping.find({}); if (found != mapping.end()) { - changed |= (found->second.pointerIntMerge(v[{}], op)); + changed |= (found->second.binopIn(RHS[{}], Op)); if (found->second == BaseType::Unknown) { mapping.erase(std::vector({})); } - } else if (v.mapping.find({}) != v.mapping.end()) { - ConcreteType dt(BaseType::Unknown); - dt.pointerIntMerge(v[{}], op); - if (dt != BaseType::Unknown) { + } else if (RHS.mapping.find({}) != RHS.mapping.end()) { + ConcreteType CT(BaseType::Unknown); + CT.binopIn(RHS[{}], Op); + if (CT != BaseType::Unknown) { changed = true; - mapping.emplace(std::vector({}), dt); + mapping.emplace(std::vector({}), CT); } } @@ -772,6 +776,7 @@ class TypeTree : public std::enable_shared_from_this { return changed; } + /// Returns a string representation of this TypeTree std::string str() const { std::string out = "{"; bool first = true; diff --git a/enzyme/Enzyme/Utils.h b/enzyme/Enzyme/Utils.h index 0788e6485b39..8a1af5b3ff75 100644 --- a/enzyme/Enzyme/Utils.h +++ b/enzyme/Enzyme/Utils.h @@ -358,15 +358,15 @@ static inline bool isCertainPrintMallocOrFree(llvm::Function *called) { return false; } -//! Create function for type that performs the derivative memcpy on floating -//! point memory +/// Create function for type that performs the derivative memcpy on floating +/// point memory llvm::Function *getOrInsertDifferentialFloatMemcpy(llvm::Module &M, llvm::PointerType *T, unsigned dstalign, unsigned srcalign); -//! Create function for type that performs the derivative memmove on floating -//! point memory +/// Create function for type that performs the derivative memmove on floating +/// point memory llvm::Function *getOrInsertDifferentialFloatMemmove(llvm::Module &M, llvm::PointerType *T, unsigned dstalign, @@ -375,7 +375,6 @@ llvm::Function *getOrInsertDifferentialFloatMemmove(llvm::Module &M, template static inline typename std::map::iterator insert_or_assign(std::map &map, K& key, V &&val) { - // map.insert_or_assign(key, val); auto found = map.find(key); if (found != map.end()) { map.erase(found); @@ -386,7 +385,6 @@ insert_or_assign(std::map &map, K& key, V &&val) { template static inline typename std::map::iterator insert_or_assign2(std::map &map, K key, V val) { - // map.insert_or_assign(key, val); auto found = map.find(key); if (found != map.end()) { map.erase(found); @@ -401,10 +399,8 @@ insert_or_assign2(std::map &map, K key, V val) { static inline void allFollowersOf(llvm::Instruction *inst, std::function f) { - // llvm::errs() << "all followers of: " << *inst << "\n"; for (auto uinst = inst->getNextNode(); uinst != nullptr; uinst = uinst->getNextNode()) { - // llvm::errs() << " + bb1: " << *uinst << "\n"; if (f(uinst)) return; } @@ -436,10 +432,8 @@ static inline void allPredecessorsOf(llvm::Instruction *inst, std::function f) { - // llvm::errs() << "all followers of: " << *inst << "\n"; for (auto uinst = inst->getPrevNode(); uinst != nullptr; uinst = uinst->getPrevNode()) { - // llvm::errs() << " + bb1: " << *uinst << "\n"; if (f(uinst)) return; } @@ -476,7 +470,6 @@ allInstructionsBetween(llvm::LoopInfo &LI, llvm::Instruction *inst1, std::function f) { for (auto uinst = inst1->getNextNode(); uinst != nullptr; uinst = uinst->getNextNode()) { - // llvm::errs() << " + bb1: " << *uinst << "\n"; if (f(uinst)) return; if (uinst == inst2) @@ -488,11 +481,6 @@ allInstructionsBetween(llvm::LoopInfo &LI, llvm::Instruction *inst1, llvm::Loop *l1 = LI.getLoopFor(inst1->getParent()); while (l1 && !l1->contains(inst2->getParent())) l1 = l1->getParentLoop(); - /* - llvm::errs() << " l1: " << l1; - if (l1) llvm::errs() << " " << *l1; - llvm::errs() << "\n"; - */ // Do all instructions from inst1 up to first instance of inst2's start block { @@ -508,7 +496,6 @@ allInstructionsBetween(llvm::LoopInfo &LI, llvm::Instruction *inst1, continue; done.insert(BB); - // llvm::errs() << " block: " << BB->getName() << "\n"; for (auto &ni : *BB) { instructions.insert(&ni); } diff --git a/enzyme/test/Enzyme/unusedalloc.ll b/enzyme/test/Enzyme/unusedalloc.ll new file mode 100644 index 000000000000..df37adf835cd --- /dev/null +++ b/enzyme/test/Enzyme/unusedalloc.ll @@ -0,0 +1,39 @@ +; RUN: %opt < %s %loadEnzyme -enzyme -enzyme_preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s +; XFAIL: * +; TODO, this currently fails because as Enzyme we don't run DSE + +declare noalias i8* @malloc(i64) + +define double @sub(double %x, i64 %y) { +entry: + %malloccall = tail call i8* @malloc(i64 8) + %bc = bitcast i8* %malloccall to i64* + store i64 %y, i64* %bc, align 8 + ret double %x +} + +define double @caller(double %x) { +entry: + %call = tail call double @sub(double %x, i64 0) + ret double %call +} + +define dso_local double @dcaller(double %x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_autodiff(double (double)* nonnull @caller, double %x) + ret double %0 +} + +declare double @__enzyme_autodiff(double (double)*, ...) + +; CHECK: define internal {{(dso_local )?}}{ double } @diffeadd4(double %x, double %[[differet:.+]]) +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = call { double } @diffeadd2(double %x, double %[[differet]]) +; CHECK-NEXT: ret { double } %0 +; CHECK-NEXT: } + +; CHECK: define internal {{(dso_local )?}}{ double } @diffeadd2(double %x, double %[[differet:.+]]) +; CHECK-NEXT: entry: +; CHECK-NEXT: %[[result2:.+]] = insertvalue { double } undef, double %[[differet]], 0 +; CHECK-NEXT: ret { double } %[[result2]] +; CHECK-NEXT: }