Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup and document Type Analysis code #58

Merged
merged 1 commit into from
Sep 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions enzyme/Enzyme/ActiveVariable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ cl::opt<bool> emptyfnconst("enzyme_emptyfnconst", cl::init(false), cl::Hidden,
#include <set>
#include <unordered_map>


constexpr uint8_t UP = 1;
constexpr uint8_t DOWN = 2;

bool isFunctionArgumentConstant(TypeResults &TR, CallInst *CI, Value *val,
SmallPtrSetImpl<Value *> &constants,
SmallPtrSetImpl<Value *> &nonconstant,
Expand Down Expand Up @@ -180,11 +184,11 @@ bool isFunctionArgumentConstant(TypeResults &TR, CallInst *CI, Value *val,
FnTypeInfo nextTypeInfo(F);
int argnum = 0;
for (auto &arg : F->args()) {
nextTypeInfo.first.insert(std::pair<Argument *, TypeTree>(
nextTypeInfo.Arguments.insert(std::pair<Argument *, TypeTree>(
&arg, TR.query(CI->getArgOperand(argnum))));
++argnum;
}
nextTypeInfo.second = TR.query(CI);
nextTypeInfo.Return = TR.query(CI);
TypeResults TR2 = TR.analysis.analyzeFunction(nextTypeInfo);

for (unsigned i = 0; i < CI->getNumArgOperands(); ++i) {
Expand Down Expand Up @@ -352,9 +356,7 @@ bool isconstantM(TypeResults &TR, Instruction *inst,
SmallPtrSetImpl<Value *> &retvals, AAResults &AA,
uint8_t directions) {
assert(inst);
assert(TR.info.function == inst->getParent()->getParent());
constexpr uint8_t UP = 1;
constexpr uint8_t DOWN = 2;
assert(TR.info.Function == inst->getParent()->getParent());
// assert(directions >= 0);
assert(directions <= 3);
if (isa<ReturnInst>(inst))
Expand Down Expand Up @@ -473,7 +475,7 @@ bool isconstantM(TypeResults &TR, Instruction *inst,
auto q = TR.query(storeinst->getPointerOperand()).Data0();
for (int i = -1; i < (int)storeSize; ++i) {
auto dt = q[{i}];
if (dt.isIntegral() || dt.typeEnum == BaseType::Anything) {
if (dt.isIntegral() || dt == BaseType::Anything) {
anIntegral = true;
} else if (dt.isKnown()) {
allIntegral = false;
Expand Down Expand Up @@ -964,13 +966,11 @@ bool isconstantValueM(TypeResults &TR, Value *val,
uint8_t directions) {
assert(val);
if (auto inst = dyn_cast<Instruction>(val)) {
assert(TR.info.function == inst->getParent()->getParent());
assert(TR.info.Function == inst->getParent()->getParent());
}
if (auto arg = dyn_cast<Argument>(val)) {
assert(TR.info.function == arg->getParent());
assert(TR.info.Function == arg->getParent());
}
// constexpr uint8_t UP = 1;
constexpr uint8_t DOWN = 2;
// assert(directions >= 0);
assert(directions <= 3);

Expand Down Expand Up @@ -1017,8 +1017,8 @@ bool isconstantValueM(TypeResults &TR, Value *val,
assert(0 && "must've put arguments in constant/nonconstant");
}

//! This value is certainly an integer (and only and integer, not a pointer or
//! float). Therefore its value is constant
// This value is certainly an integer (and only and integer, not a pointer or
// float). Therefore its value is constant
if (TR.intType(val, /*errIfNotFound*/ false).isIntegral()) {
if (printconst)
llvm::errs() << " Value const as integral " << (int)directions << " "
Expand All @@ -1028,8 +1028,8 @@ bool isconstantValueM(TypeResults &TR, Value *val,
return true;
}

//! This value is certainly a pointer to an integer (and only and integer, not
//! a pointer or float). Therefore its value is constant
// This value is certainly a pointer to an integer (and only and integer, not
// a pointer or float). Therefore its value is constant
// TODO use typeInfo for more aggressive activity analysis
if (val->getType()->isPointerTy() &&
cast<PointerType>(val->getType())->isIntOrIntVectorTy() &&
Expand Down
14 changes: 7 additions & 7 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class DerivativeMaker
unnecessaryInstructions(unnecessaryInstructions),
unnecessaryStores(unnecessaryStores), dretAlloca(dretAlloca) {

assert(TR.info.function == gutils->oldFunc);
assert(TR.info.Function == gutils->oldFunc);
for (auto &pair :
TR.analysis.analyzedFunctions.find(TR.info)->second.analysis) {
if (auto in = dyn_cast<Instruction>(pair.first)) {
Expand Down Expand Up @@ -1129,9 +1129,9 @@ class DerivativeMaker

auto dt = vd[{-1}];
for (size_t i = start; i < size; ++i) {
bool legal = true;
dt.legalMergeIn(vd[{(int)i}], /*pointerIntSame*/ true, legal);
if (!legal) {
bool Legal = true;
dt.checkedOrIn(vd[{(int)i}], /*PointerIntSame*/ true, Legal);
if (!Legal) {
nextStart = i;
break;
}
Expand Down Expand Up @@ -1912,15 +1912,15 @@ class DerivativeMaker
std::map<Value *, std::set<int64_t>> intseen;

for (auto &arg : called->args()) {
nextTypeInfo.first.insert(std::pair<Argument *, TypeTree>(
nextTypeInfo.Arguments.insert(std::pair<Argument *, TypeTree>(
&arg, TR.query(orig->getArgOperand(argnum))));
nextTypeInfo.knownValues.insert(
nextTypeInfo.KnownValues.insert(
std::pair<Argument *, std::set<int64_t>>(
&arg, TR.knownIntegralValues(orig->getArgOperand(argnum))));

++argnum;
}
nextTypeInfo.second = TR.query(orig);
nextTypeInfo.Return = TR.query(orig);
}

// llvm::Optional<std::map<std::pair<Instruction*, std::string>, unsigned>>
Expand Down
6 changes: 3 additions & 3 deletions enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ void HandleAutoDiff(CallInst *CI, TargetLibraryInfo &TLI, AAResults &AA) {

std::map<Argument *, bool> volatile_args;
FnTypeInfo type_args(cast<Function>(fn));
for (auto &a : type_args.function->args()) {
for (auto &a : type_args.Function->args()) {
volatile_args[&a] = false;
TypeTree dt;
if (a.getType()->isFPOrFPVectorTy()) {
Expand All @@ -223,10 +223,10 @@ void HandleAutoDiff(CallInst *CI, TargetLibraryInfo &TLI, AAResults &AA) {
dt = TypeTree(ConcreteType(BaseType::Pointer)).Only(-1);
}
}
type_args.first.insert(std::pair<Argument *, TypeTree>(&a, dt.Only(-1)));
type_args.Arguments.insert(std::pair<Argument *, TypeTree>(&a, dt.Only(-1)));
// TODO note that here we do NOT propagate constants in type info (and
// should consider whether we should)
type_args.knownValues.insert(
type_args.KnownValues.insert(
std::pair<Argument *, std::set<int64_t>>(&a, {}));
}

Expand Down
36 changes: 18 additions & 18 deletions enzyme/Enzyme/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -943,7 +943,7 @@ CreateAugmentedPrimal(Function *todiff, DIFFE_TYPE retType,
!todiff->getReturnType()->isVoidTy());

FnTypeInfo oldTypeInfo = oldTypeInfo_;
for (auto &pair : oldTypeInfo.knownValues) {
for (auto &pair : oldTypeInfo.KnownValues) {
if (pair.second.size() != 0) {
bool recursiveUse = false;
for (auto user : pair.first->users()) {
Expand Down Expand Up @@ -1082,23 +1082,23 @@ CreateAugmentedPrimal(Function *todiff, DIFFE_TYPE retType,
for (; toarg != todiff->arg_end(); ++toarg, ++olarg) {

{
auto fd = oldTypeInfo.first.find(toarg);
assert(fd != oldTypeInfo.first.end());
typeInfo.first.insert(
auto fd = oldTypeInfo.Arguments.find(toarg);
assert(fd != oldTypeInfo.Arguments.end());
typeInfo.Arguments.insert(
std::pair<Argument *, TypeTree>(olarg, fd->second));
}

{
auto cfd = oldTypeInfo.knownValues.find(toarg);
assert(cfd != oldTypeInfo.knownValues.end());
typeInfo.knownValues.insert(
auto cfd = oldTypeInfo.KnownValues.find(toarg);
assert(cfd != oldTypeInfo.KnownValues.end());
typeInfo.KnownValues.insert(
std::pair<Argument *, std::set<int64_t>>(olarg, cfd->second));
}
}
typeInfo.second = oldTypeInfo.second;
typeInfo.Return = oldTypeInfo.Return;
}
TypeResults TR = TA.analyzeFunction(typeInfo);
assert(TR.info.function == gutils->oldFunc);
assert(TR.info.Function == gutils->oldFunc);
gutils->forceActiveDetection(AA, TR);

gutils->forceAugmentedReturns(TR, guaranteedUnreachable);
Expand Down Expand Up @@ -1831,7 +1831,7 @@ Function *CreatePrimalAndGradient(
const AugmentedReturn *augmenteddata) {

FnTypeInfo oldTypeInfo = oldTypeInfo_;
for (auto &pair : oldTypeInfo.knownValues) {
for (auto &pair : oldTypeInfo.KnownValues) {
if (pair.second.size() != 0) {
bool recursiveUse = false;
for (auto user : pair.first->users()) {
Expand Down Expand Up @@ -2045,24 +2045,24 @@ Function *CreatePrimalAndGradient(
for (; toarg != todiff->arg_end(); ++toarg, ++olarg) {

{
auto fd = oldTypeInfo.first.find(toarg);
assert(fd != oldTypeInfo.first.end());
typeInfo.first.insert(
auto fd = oldTypeInfo.Arguments.find(toarg);
assert(fd != oldTypeInfo.Arguments.end());
typeInfo.Arguments.insert(
std::pair<Argument *, TypeTree>(olarg, fd->second));
}

{
auto cfd = oldTypeInfo.knownValues.find(toarg);
assert(cfd != oldTypeInfo.knownValues.end());
typeInfo.knownValues.insert(
auto cfd = oldTypeInfo.KnownValues.find(toarg);
assert(cfd != oldTypeInfo.KnownValues.end());
typeInfo.KnownValues.insert(
std::pair<Argument *, std::set<int64_t>>(olarg, cfd->second));
}
}
typeInfo.second = oldTypeInfo.second;
typeInfo.Return = oldTypeInfo.Return;
}

TypeResults TR = TA.analyzeFunction(typeInfo);
assert(TR.info.function == gutils->oldFunc);
assert(TR.info.Function == gutils->oldFunc);

gutils->forceActiveDetection(AA, TR);
gutils->forceAugmentedReturns(TR, guaranteedUnreachable);
Expand Down
4 changes: 2 additions & 2 deletions enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -510,8 +510,8 @@ Value *GradientUtils::invertPointerM(Value *oval, IRBuilder<> &BuilderM) {
std::vector<DIFFE_TYPE> types;
for (auto &a : fn->args()) {
uncacheable_args[&a] = !a.getType()->isFPOrFPVectorTy();
type_args.first.insert(std::pair<Argument *, TypeTree>(&a, {}));
type_args.knownValues.insert(
type_args.Arguments.insert(std::pair<Argument *, TypeTree>(&a, {}));
type_args.KnownValues.insert(
std::pair<Argument *, std::set<int64_t>>(&a, {}));
DIFFE_TYPE typ;
if (a.getType()->isFPOrFPVectorTy()) {
Expand Down
2 changes: 1 addition & 1 deletion enzyme/Enzyme/GradientUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -1170,7 +1170,7 @@ class GradientUtils {
void forceAugmentedReturns(
TypeResults &TR,
const SmallPtrSetImpl<BasicBlock *> &guaranteedUnreachable) {
assert(TR.info.function == oldFunc);
assert(TR.info.Function == oldFunc);

for (BasicBlock &oBB : *oldFunc) {
// Don't create derivatives for code that results in termination
Expand Down
7 changes: 5 additions & 2 deletions enzyme/Enzyme/TypeAnalysis/BaseType.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,21 @@
#include <string>
#include "llvm/Support/ErrorHandling.h"

/// Categories of potential types
enum class BaseType {
// integral type
// integral type which doesn't represent a pointer
Integer,
// floating point
Float,
// pointer
Pointer,
// can be anything of users choosing [usually result of a constant]
// can be anything of users choosing [usually result of a constant such as 0]
Anything,
// insufficient information
Unknown
};

/// Convert Basetype to string
static inline std::string to_string(BaseType t) {
switch (t) {
case BaseType::Integer:
Expand All @@ -56,6 +58,7 @@ static inline std::string to_string(BaseType t) {
llvm_unreachable("unknown inttype");
}

/// Convert string to BaseType
static inline BaseType parseBaseType(std::string str) {
if (str == "Integer")
return BaseType::Integer;
Expand Down
Loading