diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 7cf978ff9a..e263e4d716 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -496,12 +496,11 @@ function overload_autodiff( func2.operation = MLIR.API.MlirOperation(C_NULL) if reverse - resv = if EnzymeCore.needs_primal(CMode) - result + if EnzymeCore.needs_primal(CMode) + return ((restup...,), result) else - nothing + return ((restup...,),) end - return ((restup...,), resv) else if EnzymeCore.needs_primal(CMode) if CMode <: ForwardMode && !(A <: Const) diff --git a/test/autodiff.jl b/test/autodiff.jl index 2c3f8a0fac..10daa32ee8 100644 --- a/test/autodiff.jl +++ b/test/autodiff.jl @@ -407,3 +407,12 @@ end @test results_fd[2].y ≈ results_enz[2].y @test results_fd[3] ≈ results_enz[3] end + +@testset "Correct return tuple" begin + # issue 1875 + x = ones(2) + xr = Reactant.to_rarray(x) + res = autodiff(Reverse, sum, Duplicated(x, zero(x))) + res_reactant = @jit autodiff(Reverse, sum, Duplicated(xr, zero(xr))) + @test length(res) == length(res_reactant) +end