From 7fee77252382a6bdd0a6ead7c37fa1136eb14653 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 23 Jun 2024 17:08:05 -0400 Subject: [PATCH] Fix shadow return usage (#1939) --- enzyme/Enzyme/EnzymeLogic.cpp | 27 +++---- enzyme/Enzyme/EnzymeLogic.h | 6 +- enzyme/test/Enzyme/ReverseMode/shadowret.ll | 80 +++++++++++++++++++++ 3 files changed, 98 insertions(+), 15 deletions(-) create mode 100644 enzyme/test/Enzyme/ReverseMode/shadowret.ll diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 348b25f8143..b8a8c87d8e9 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -2017,7 +2017,7 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( return insert_or_assign( AugmentedCachedFunctions, tup, AugmentedReturn(newFunc, nullptr, {}, returnMapping, {}, {}, - constant_args)) + constant_args, shadowReturnUsed)) ->second; } if (context.req) { @@ -2028,7 +2028,7 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( return insert_or_assign( AugmentedCachedFunctions, tup, AugmentedReturn(newFunc, nullptr, {}, returnMapping, {}, {}, - constant_args)) + constant_args, shadowReturnUsed)) ->second; } llvm::errs() << "mod: " << *todiff->getParent() << "\n"; @@ -2149,7 +2149,8 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( AugmentedCachedFunctions, tup, AugmentedReturn(NewF, aug.tapeType, aug.tapeIndices, aug.returns, aug.overwritten_args_map, - aug.can_modref_map, next_constant_args)) + aug.can_modref_map, next_constant_args, + shadowReturnUsed)) ->second; } @@ -2213,7 +2214,7 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( return insert_or_assign( AugmentedCachedFunctions, tup, AugmentedReturn(foundcalled, nullptr, {}, returnMapping, {}, - {}, constant_args)) + {}, constant_args, shadowReturnUsed)) ->second; } @@ -2271,7 +2272,7 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( return insert_or_assign( AugmentedCachedFunctions, tup, AugmentedReturn(foundcalled, nullptr, {}, returnMapping, {}, - {}, constant_args)) + {}, constant_args, shadowReturnUsed)) ->second; } if (ST->getNumElements() == 2 && @@ -2282,7 +2283,7 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( return insert_or_assign( AugmentedCachedFunctions, tup, AugmentedReturn(foundcalled, nullptr, {}, returnMapping, {}, - {}, constant_args)) + {}, constant_args, shadowReturnUsed)) ->second; } if (ST->getNumElements() == 2) { @@ -2336,7 +2337,7 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( return insert_or_assign( AugmentedCachedFunctions, tup, AugmentedReturn(foundcalled, nullptr, {}, returnMapping, {}, - {}, constant_args)) + {}, constant_args, shadowReturnUsed)) ->second; } } @@ -2348,7 +2349,7 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( return insert_or_assign( AugmentedCachedFunctions, tup, AugmentedReturn(foundcalled, nullptr, {}, returnMapping, {}, {}, - constant_args)) + constant_args, shadowReturnUsed)) ->second; // dyn_cast(st->getElementType(0))); } @@ -2393,7 +2394,7 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( return insert_or_assign( AugmentedCachedFunctions, tup, AugmentedReturn(newFunc, nullptr, {}, returnMapping, {}, {}, - constant_args)) + constant_args, shadowReturnUsed)) ->second; } if (context.req) { @@ -2404,7 +2405,7 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( return insert_or_assign( AugmentedCachedFunctions, tup, AugmentedReturn(newFunc, nullptr, {}, returnMapping, {}, {}, - constant_args)) + constant_args, shadowReturnUsed)) ->second; } llvm::errs() << "mod: " << *todiff->getParent() << "\n"; @@ -2455,7 +2456,7 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal( insert_or_assign(AugmentedCachedFunctions, tup, AugmentedReturn(gutils->newFunc, nullptr, {}, returnMapping, overwritten_args_map, can_modref_map, - constant_args)); + constant_args, shadowReturnUsed)); auto getIndex = [&](Instruction *I, CacheType u, IRBuilder<> &B) -> unsigned { return gutils->getIndex( @@ -4135,8 +4136,8 @@ Function *EnzymeLogic::CreatePrimalAndGradient( DiffeGradientUtils *gutils = DiffeGradientUtils::CreateFromClone( *this, key.mode, key.width, key.todiff, TLI, TA, oldTypeInfo, key.retType, - key.shadowReturnUsed, diffeReturnArg, key.constant_args, retVal, - key.additionalType, omp); + augmenteddata ? augmenteddata->shadowReturnUsed : key.shadowReturnUsed, + diffeReturnArg, key.constant_args, retVal, key.additionalType, omp); gutils->AtomicAdd = key.AtomicAdd; gutils->FreeMemory = key.freeMemory; diff --git a/enzyme/Enzyme/EnzymeLogic.h b/enzyme/Enzyme/EnzymeLogic.h index ae23fb74a78..09de0c67c99 100644 --- a/enzyme/Enzyme/EnzymeLogic.h +++ b/enzyme/Enzyme/EnzymeLogic.h @@ -121,6 +121,8 @@ class AugmentedReturn { const std::vector constant_args; + bool shadowReturnUsed; + bool isComplete; AugmentedReturn( @@ -129,11 +131,11 @@ class AugmentedReturn { std::map returns, std::map> overwritten_args_map, std::map can_modref_map, - const std::vector &constant_args) + const std::vector &constant_args, bool shadowReturnUsed) : fn(fn), tapeType(tapeType), tapeIndices(tapeIndices), returns(returns), overwritten_args_map(overwritten_args_map), can_modref_map(can_modref_map), constant_args(constant_args), - isComplete(false) {} + shadowReturnUsed(shadowReturnUsed), isComplete(false) {} }; /// \p todiff is the function to differentiate diff --git a/enzyme/test/Enzyme/ReverseMode/shadowret.ll b/enzyme/test/Enzyme/ReverseMode/shadowret.ll new file mode 100644 index 00000000000..2ed324af0ec --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/shadowret.ll @@ -0,0 +1,80 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -early-cse -instsimplify -simplifycfg -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,early-cse,instsimplify,%simplifycfg)" -S | FileCheck %s + +define internal fastcc nonnull double* @julia_mygetindex_2341(double* %a12) { +top: + %a13 = load double, double* %a12, align 8 + %a22 = call noalias nonnull double* @malloc(i64 24) + br label %L8 + +L8: ; preds = %idxend2, %idxend + %a21 = phi i64 [ 0, %top ], [ %a24, %L8 ] + %a23 = getelementptr inbounds double, double* %a22, i64 %a21 + store double %a13, double* %a23, align 8 + %a24 = add nuw nsw i64 %a21, 1 + %.not5 = icmp ne i64 1, %a24 + br i1 %.not5, label %L8, label %L28 + +L28: ; preds = %idxend2 + ret double* %a22 +} + +define internal fastcc nonnull double* @mydiag(double* %a12) { +top: + %a4 = call fastcc nonnull double* @julia_mygetindex_2341(double* %a12) + ret double* %a4 +} + +declare void @__enzyme_autodiff(...) + +define void @caller(double* %x, double* %dx) { +entry: + call void (...) @__enzyme_autodiff(double (double*)* @f, metadata !"enzyme_dup", double* %x, double* %dx) + ret void +} + +define double @f(double* %a13) { +top: + %b13 = call fastcc nonnull double* @mydiag(double* %a13) + %a14 = load double, double* %b13, align 8 + ret double %a14 +} + +declare double* @malloc(i64) local_unnamed_addr + +; CHECK: define internal fastcc void @diffejulia_mygetindex_2341(double* %a12, double* %"a12'", { double*, double* } %tapeArg) +; CHECK-NEXT: top: +; CHECK-NEXT: %"a22'mi" = extractvalue { double*, double* } %tapeArg, 0 +; CHECK-NEXT: %a22 = extractvalue { double*, double* } %tapeArg, 1 +; CHECK-NEXT: br label %L8 + +; CHECK: L8: ; preds = %L8, %top +; CHECK-NEXT: %iv = phi i64 [ %iv.next, %L8 ], [ 0, %top ] +; CHECK-NEXT: %iv.next = add nuw nsw i64 %iv, 1 +; CHECK-NEXT: %.not5 = icmp ne i64 %iv.next, 1 +; CHECK-NEXT: br i1 %.not5, label %L8, label %invertL8 + +; CHECK: inverttop: ; preds = %invertL8 +; CHECK-NEXT: %[[a0:.+]] = bitcast double* %"a22'mi" to i8* +; CHECK-NEXT: call void @free(i8* nonnull %[[a0]]) +; CHECK-NEXT: %[[a1:.+]] = bitcast double* %a22 to i8* +; CHECK-NEXT: call void @free(i8* %[[a1]]) +; CHECK-NEXT: %[[a2:.+]] = load double, double* %"a12'", align 8, !alias.scope !5, !noalias !8 +; CHECK-NEXT: %[[a3:.+]] = fadd fast double %[[a2]], %[[a5:.+]] +; CHECK-NEXT: store double %[[a3]], double* %"a12'", align 8, !alias.scope !5, !noalias !8 +; CHECK-NEXT: ret void + +; CHECK: invertL8: ; preds = %L8, %incinvertL8 +; CHECK-NEXT: %"a13'de.0" = phi double [ %[[a5]], %incinvertL8 ], [ 0.000000e+00, %L8 ] +; CHECK-NEXT: %"iv'ac.0" = phi i64 [ %[[a7:.+]], %incinvertL8 ], [ 0, %L8 ] +; CHECK-NEXT: %"a23'ipg_unwrap" = getelementptr inbounds double, double* %"a22'mi", i64 %"iv'ac.0" +; CHECK-NEXT: %[[a4:.+]] = load double, double* %"a23'ipg_unwrap", align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"a23'ipg_unwrap", align 8 +; CHECK-NEXT: %[[a5]] = fadd fast double %"a13'de.0", %[[a4]] +; CHECK-NEXT: %[[a6:.+]] = icmp eq i64 %"iv'ac.0", 0 +; CHECK-NEXT: br i1 %[[a6]], label %inverttop, label %incinvertL8 + +; CHECK: incinvertL8: ; preds = %invertL8 +; CHECK-NEXT: %[[a7]] = add nsw i64 %"iv'ac.0", -1 +; CHECK-NEXT: br label %invertL8 +; CHECK-NEXT: }