Skip to content

Commit

Permalink
Strengthen any type checks (#1864)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed May 7, 2024
1 parent 7d09d5e commit 2e164a0
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 10 deletions.
4 changes: 2 additions & 2 deletions enzyme/Enzyme/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3330,8 +3330,8 @@ void createInvertedTerminator(DiffeGradientUtils *gutils,
auto PNtype = PNtypeT[{-1}];

// TODO remove explicit type check and only use PNtype
if (PNtype == BaseType::Anything || PNtype == BaseType::Pointer ||
PNtype == BaseType::Integer || orig->getType()->isPointerTy())
if (!gutils->TR.anyFloat(orig, /*anythingIsFloat*/ false) ||
orig->getType()->isPointerTy())
continue;

Type *PNfloatType = PNtype.isFloat();
Expand Down
50 changes: 43 additions & 7 deletions enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5888,29 +5888,60 @@ TypeTree TypeResults::query(Value *val) const {
return analyzer->getAnalysis(val);
}

bool TypeResults::anyFloat(Value *val) const {
// Returns last non-padding/alignment location of the corresponding subtype T.
size_t skippedBytes(SmallSet<size_t, 8> &offs, Type *T, const DataLayout &DL,
size_t offset = 0) {
auto ST = dyn_cast<StructType>(T);
if (!ST)
return (DL.getTypeSizeInBits(T) + 7) / 8;

auto SL = DL.getStructLayout(ST);
size_t prevOff = 0;
for (size_t idx = 0; idx < ST->getNumElements(); idx++) {
auto off = SL->getElementOffset(idx);
if (off > prevOff)
for (size_t i = prevOff; i < off; i++)
offs.insert(offset + i);
size_t subSize = skippedBytes(offs, ST->getElementType(idx), DL, prevOff);
prevOff = off + subSize;
}
return prevOff;
}

bool TypeResults::anyFloat(Value *val, bool anythingIsFloat) const {
assert(val);
assert(val->getType());
auto q = query(val);
auto dt = q[{-1}];
if (!anythingIsFloat && dt == BaseType::Anything)
return false;
if (dt != BaseType::Anything && dt != BaseType::Unknown)
return dt.isFloat();

size_t ObjSize = 1;
if (val->getType()->isTokenTy())
return false;
auto &dl = analyzer->fntypeinfo.Function->getParent()->getDataLayout();
if (val->getType()->isSized())
ObjSize = (dl.getTypeSizeInBits(val->getType()) + 7) / 8;
SmallSet<size_t, 8> offs;
size_t ObjSize = skippedBytes(offs, val->getType(), dl);

for (size_t i = 0; i < ObjSize;) {
dt = q[{(int)i}];
if (dt == BaseType::Integer) {
i++;
continue;
}
if (!anythingIsFloat && dt == BaseType::Integer) {
i++;
continue;
}
if (dt == BaseType::Pointer) {
i += dl.getPointerSize(0);
continue;
}
if (offs.count(i)) {
i++;
continue;
}
return true;
}
return false;
Expand All @@ -5923,11 +5954,12 @@ bool TypeResults::anyPointer(Value *val) const {
auto dt = q[{-1}];
if (dt != BaseType::Anything && dt != BaseType::Unknown)
return dt == BaseType::Pointer;
if (val->getType()->isTokenTy())
return false;

size_t ObjSize = 1;
auto &dl = analyzer->fntypeinfo.Function->getParent()->getDataLayout();
if (val->getType()->isSized())
ObjSize = (dl.getTypeSizeInBits(val->getType()) + 7) / 8;
SmallSet<size_t, 8> offs;
size_t ObjSize = skippedBytes(offs, val->getType(), dl);

for (size_t i = 0; i < ObjSize;) {
dt = q[{(int)i}];
Expand All @@ -5939,6 +5971,10 @@ bool TypeResults::anyPointer(Value *val) const {
i += (dl.getTypeSizeInBits(FT) + 7) / 8;
continue;
}
if (offs.count(i)) {
i++;
continue;
}
return true;
}
return false;
Expand Down
4 changes: 3 additions & 1 deletion enzyme/Enzyme/TypeAnalysis/TypeAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,9 @@ class TypeResults {
/// Whether any part of the top level register can contain a float
/// e.g. { i64, float } can contain a float, but { i64, i8* } would not.
// Of course, here we compute with type analysis rather than llvm type
bool anyFloat(llvm::Value *val) const;
// The flag `anythingIsFloat` specifies whether an anything should
// be considered a float.
bool anyFloat(llvm::Value *val, bool anythingIsFloat = true) const;

/// Whether any part of the top level register can contain a pointer
/// e.g. { i64, i8* } can contain a pointer, but { i64, float } would not.
Expand Down

0 comments on commit 2e164a0

Please sign in to comment.