From 265646fe0fc4ac0f3add2216fcd171735610ea5f Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 26 Jan 2022 01:13:19 -0500 Subject: [PATCH] Fix compilation memory error (#433) --- enzyme/Enzyme/AdjointGenerator.h | 2 +- enzyme/Enzyme/GradientUtils.cpp | 24 ++++++++++++------- enzyme/Enzyme/GradientUtils.h | 10 ++++---- .../test/Enzyme/ReverseMode/initializemany.ll | 4 ++-- 4 files changed, 25 insertions(+), 15 deletions(-) diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index e4c9fd06fe35c..150c6c3ca20e9 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -5076,7 +5076,7 @@ class AdjointGenerator #if LLVM_VERSION_MAJOR > 7 Value *d_reqp = Builder2.CreateLoad( - impi, + PointerType::getUnqual(impi), Builder2.CreatePointerCast( d_req, PointerType::getUnqual(PointerType::getUnqual(impi)))); #else diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index 1328c32473e6a..d9d25204c8b95 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -1068,8 +1068,10 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM, reverseBlockToPrimal[blocks[i]] = fwd; IRBuilder<> B(blocks[i]); - unwrap_cache[blocks[i]] = unwrap_cache[oldB]; - lookup_cache[blocks[i]] = lookup_cache[oldB]; + for (auto pair : unwrap_cache[oldB]) + unwrap_cache[blocks[i]].insert(pair); + for (auto pair : lookup_cache[oldB]) + lookup_cache[blocks[i]].insert(pair); auto PB = *done[std::make_pair(valparent, predBlocks[i])].begin(); if (auto inst = dyn_cast( @@ -1151,8 +1153,10 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM, unwrap_cache[bret][idx.first][idx.second] = toret; } unwrappedLoads[toret] = val; - unwrap_cache[bret] = unwrap_cache[oldB]; - lookup_cache[bret] = lookup_cache[oldB]; + for (auto pair : unwrap_cache[oldB]) + unwrap_cache[bret].insert(pair); + for (auto pair : lookup_cache[oldB]) + lookup_cache[bret].insert(pair); return toret; } } @@ -1253,8 +1257,10 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM, } IRBuilder<> B(blocks[i]); - unwrap_cache[blocks[i]] = unwrap_cache[oldB]; - lookup_cache[blocks[i]] = lookup_cache[oldB]; + for (auto pair : unwrap_cache[oldB]) + unwrap_cache[blocks[i]].insert(pair); + for (auto pair : lookup_cache[oldB]) + lookup_cache[blocks[i]].insert(pair); if (auto inst = dyn_cast(phi->getIncomingValueForBlock(PB))) { @@ -1390,8 +1396,10 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM, if (permitCache) { unwrap_cache[bret][idx.first][idx.second] = toret; } - unwrap_cache[bret] = unwrap_cache[oldB]; - lookup_cache[bret] = lookup_cache[oldB]; + for (auto pair : unwrap_cache[oldB]) + unwrap_cache[bret].insert(pair); + for (auto pair : lookup_cache[oldB]) + lookup_cache[bret].insert(pair); unwrappedLoads[toret] = val; return toret; } diff --git a/enzyme/Enzyme/GradientUtils.h b/enzyme/Enzyme/GradientUtils.h index f14fb9ea2f6ff..1882f39163ebe 100644 --- a/enzyme/Enzyme/GradientUtils.h +++ b/enzyme/Enzyme/GradientUtils.h @@ -536,9 +536,9 @@ class GradientUtils : public CacheUtility { Value *tape; std::map>> + ValueMap>> unwrap_cache; - std::map> lookup_cache; + std::map> lookup_cache; public: BasicBlock *addReverseBlock(BasicBlock *currentBlock, Twine name, @@ -557,8 +557,10 @@ class GradientUtils : public CacheUtility { vec.push_back(rev); reverseBlockToPrimal[rev] = found->second; if (forkCache) { - unwrap_cache[rev] = unwrap_cache[currentBlock]; - lookup_cache[rev] = lookup_cache[currentBlock]; + for (auto pair : unwrap_cache[currentBlock]) + unwrap_cache[rev].insert(pair); + for (auto pair : lookup_cache[currentBlock]) + lookup_cache[rev].insert(pair); } return rev; } diff --git a/enzyme/test/Enzyme/ReverseMode/initializemany.ll b/enzyme/test/Enzyme/ReverseMode/initializemany.ll index 5f1932a7c44f5..43c77a2443618 100644 --- a/enzyme/test/Enzyme/ReverseMode/initializemany.ll +++ b/enzyme/test/Enzyme/ReverseMode/initializemany.ll @@ -196,8 +196,8 @@ attributes #4 = { nounwind } ; CHECK-NEXT: store double 0.000000e+00, double* %[[bc]], align 8 ; CHECK-NEXT: %[[added]] = fadd fast double %"x'de.0", %[[load]] ; CHECK-NEXT: tail call void @free(i8* nonnull %[[metaload]]) -; CHECK-NEXT: %_unwrap8 = getelementptr inbounds i8*, i8** %1, i64 %"iv'ac.0" -; CHECK-NEXT: %call_unwrap = load i8*, i8** %_unwrap8, align 8, !invariant.group !11 +; CHECK-NEXT: %[[_unwrap8:.+]] = getelementptr inbounds i8*, i8** %1, i64 %"iv'ac.0" +; CHECK-NEXT: %call_unwrap = load i8*, i8** %[[_unwrap8]], align 8, !invariant.group ! ; CHECK-NEXT: tail call void @free(i8* %call_unwrap) ; CHECK-NEXT: %[[lcmp:.+]] = icmp eq i64 %[[sub]], 0 ; CHECK-NEXT: br i1 %[[lcmp]], label %invertentry, label %invertfor.body