diff --git a/test/defaults.jl b/test/defaults.jl index 28eca3c..5664e03 100644 --- a/test/defaults.jl +++ b/test/defaults.jl @@ -24,7 +24,8 @@ FDMBackend2() = FDMBackend2(central_fdm(5, 1)) const fdm_backend2 = FDMBackend2() AD.@primitive function pushforward_function(ab::FDMBackend2, f, xs...) return function (vs) - FDM.jvp(ab.alg, f, tuple.(xs, vs)...) + ws = FDM.jvp(ab.alg, f, tuple.(xs, vs)...) + return length(xs) == 1 ? (ws,) : ws end end