Skip to content

Commit

Permalink
Fix compilation memory error (rust-lang#433)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jan 26, 2022
1 parent 8dfb09f commit 265646f
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 15 deletions.
2 changes: 1 addition & 1 deletion enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -5076,7 +5076,7 @@ class AdjointGenerator

#if LLVM_VERSION_MAJOR > 7
Value *d_reqp = Builder2.CreateLoad(
impi,
PointerType::getUnqual(impi),
Builder2.CreatePointerCast(
d_req, PointerType::getUnqual(PointerType::getUnqual(impi))));
#else
Expand Down
24 changes: 16 additions & 8 deletions enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1068,8 +1068,10 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
reverseBlockToPrimal[blocks[i]] = fwd;
IRBuilder<> B(blocks[i]);

unwrap_cache[blocks[i]] = unwrap_cache[oldB];
lookup_cache[blocks[i]] = lookup_cache[oldB];
for (auto pair : unwrap_cache[oldB])
unwrap_cache[blocks[i]].insert(pair);
for (auto pair : lookup_cache[oldB])
lookup_cache[blocks[i]].insert(pair);
auto PB = *done[std::make_pair(valparent, predBlocks[i])].begin();

if (auto inst = dyn_cast<Instruction>(
Expand Down Expand Up @@ -1151,8 +1153,10 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
unwrap_cache[bret][idx.first][idx.second] = toret;
}
unwrappedLoads[toret] = val;
unwrap_cache[bret] = unwrap_cache[oldB];
lookup_cache[bret] = lookup_cache[oldB];
for (auto pair : unwrap_cache[oldB])
unwrap_cache[bret].insert(pair);
for (auto pair : lookup_cache[oldB])
lookup_cache[bret].insert(pair);
return toret;
}
}
Expand Down Expand Up @@ -1253,8 +1257,10 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
}
IRBuilder<> B(blocks[i]);

unwrap_cache[blocks[i]] = unwrap_cache[oldB];
lookup_cache[blocks[i]] = lookup_cache[oldB];
for (auto pair : unwrap_cache[oldB])
unwrap_cache[blocks[i]].insert(pair);
for (auto pair : lookup_cache[oldB])
lookup_cache[blocks[i]].insert(pair);

if (auto inst =
dyn_cast<Instruction>(phi->getIncomingValueForBlock(PB))) {
Expand Down Expand Up @@ -1390,8 +1396,10 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
if (permitCache) {
unwrap_cache[bret][idx.first][idx.second] = toret;
}
unwrap_cache[bret] = unwrap_cache[oldB];
lookup_cache[bret] = lookup_cache[oldB];
for (auto pair : unwrap_cache[oldB])
unwrap_cache[bret].insert(pair);
for (auto pair : lookup_cache[oldB])
lookup_cache[bret].insert(pair);
unwrappedLoads[toret] = val;
return toret;
}
Expand Down
10 changes: 6 additions & 4 deletions enzyme/Enzyme/GradientUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -536,9 +536,9 @@ class GradientUtils : public CacheUtility {
Value *tape;

std::map<BasicBlock *,
std::map<Value *, std::map<BasicBlock *, WeakTrackingVH>>>
ValueMap<Value *, std::map<BasicBlock *, WeakTrackingVH>>>
unwrap_cache;
std::map<BasicBlock *, std::map<Value *, WeakTrackingVH>> lookup_cache;
std::map<BasicBlock *, ValueMap<Value *, WeakTrackingVH>> lookup_cache;

public:
BasicBlock *addReverseBlock(BasicBlock *currentBlock, Twine name,
Expand All @@ -557,8 +557,10 @@ class GradientUtils : public CacheUtility {
vec.push_back(rev);
reverseBlockToPrimal[rev] = found->second;
if (forkCache) {
unwrap_cache[rev] = unwrap_cache[currentBlock];
lookup_cache[rev] = lookup_cache[currentBlock];
for (auto pair : unwrap_cache[currentBlock])
unwrap_cache[rev].insert(pair);
for (auto pair : lookup_cache[currentBlock])
lookup_cache[rev].insert(pair);
}
return rev;
}
Expand Down
4 changes: 2 additions & 2 deletions enzyme/test/Enzyme/ReverseMode/initializemany.ll
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,8 @@ attributes #4 = { nounwind }
; CHECK-NEXT: store double 0.000000e+00, double* %[[bc]], align 8
; CHECK-NEXT: %[[added]] = fadd fast double %"x'de.0", %[[load]]
; CHECK-NEXT: tail call void @free(i8* nonnull %[[metaload]])
; CHECK-NEXT: %_unwrap8 = getelementptr inbounds i8*, i8** %1, i64 %"iv'ac.0"
; CHECK-NEXT: %call_unwrap = load i8*, i8** %_unwrap8, align 8, !invariant.group !11
; CHECK-NEXT: %[[_unwrap8:.+]] = getelementptr inbounds i8*, i8** %1, i64 %"iv'ac.0"
; CHECK-NEXT: %call_unwrap = load i8*, i8** %[[_unwrap8]], align 8, !invariant.group !
; CHECK-NEXT: tail call void @free(i8* %call_unwrap)
; CHECK-NEXT: %[[lcmp:.+]] = icmp eq i64 %[[sub]], 0
; CHECK-NEXT: br i1 %[[lcmp]], label %invertentry, label %invertfor.body
Expand Down

0 comments on commit 265646f

Please sign in to comment.