Skip to content

Commit

Permalink
MLIR fix return primal in reverse mode (#1856)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed May 1, 2024
1 parent 0246d1a commit d5eac0f
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
2 changes: 1 addition & 1 deletion enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff(
auto buildFuncReturnOp = [&](OpBuilder &builder, Block *oBB) {
SmallVector<mlir::Value> retargs;
for (auto [arg, returnPrimal] :
llvm::zip(oBB->getArguments(), returnPrimals)) {
llvm::zip(oBB->getTerminator()->getOperands(), returnPrimals)) {
if (returnPrimal) {
retargs.push_back(gutils->getNewFromOriginal(arg));
}
Expand Down
13 changes: 8 additions & 5 deletions enzyme/test/MLIR/ReverseMode/multiret.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,13 @@ module {
// CHECK-NEXT: return %3, %1 : f64, f32
// CHECK-NEXT: }


// CHECK: func.func private @diffesquare_0(%arg0: f64, %arg1: i32, %arg2: f32, %arg3: f64, %arg4: f32) -> (f64, i32, f64, f32) {
// CHECK-NEXT: %0 = arith.mulf %arg4, %arg2 : f32
// CHECK-NEXT: %1 = arith.addf %0, %0 : f32
// CHECK-NEXT: %2 = arith.mulf %arg3, %arg0 : f64
// CHECK-NEXT: %3 = arith.addf %2, %2 : f64
// CHECK-NEXT: return %arg0, %arg1, %3, %1 : f64, i32, f64, f32
// CHECK-NEXT: %0 = arith.mulf %arg0, %arg0 : f64
// CHECK-NEXT: %1 = arith.muli %arg1, %arg1 : i32
// CHECK-NEXT: %2 = arith.mulf %arg4, %arg2 : f32
// CHECK-NEXT: %3 = arith.addf %2, %2 : f32
// CHECK-NEXT: %4 = arith.mulf %arg3, %arg0 : f64
// CHECK-NEXT: %5 = arith.addf %4, %4 : f64
// CHECK-NEXT: return %0, %1, %5, %3 : f64, i32, f64, f32
// CHECK-NEXT: }

0 comments on commit d5eac0f

Please sign in to comment.