diff --git a/Project.toml b/Project.toml index e9373264..a00e73c2 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "FiniteDifferences" uuid = "26cc04aa-876d-5657-8c51-4c34ba976000" -version = "0.10.2" +version = "0.10.3" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/grad.jl b/src/grad.jl index 461f0c11..673aeb34 100644 --- a/src/grad.jl +++ b/src/grad.jl @@ -59,7 +59,7 @@ function jvp(fdm, f, (x, ẋ)::Tuple{Any, Any}) end function jvp(fdm, f, xẋs::Tuple{Any, Any}...) x, ẋ = collect(zip(xẋs...)) - return jvp(fdm, xs->f(xs...)[1], (x, ẋ)) + return jvp(fdm, xs->f(xs...), (x, ẋ)) end """ diff --git a/test/grad.jl b/test/grad.jl index 44097ceb..1bdaf953 100644 --- a/test/grad.jl +++ b/test/grad.jl @@ -4,14 +4,36 @@ using FiniteDifferences: grad, jacobian, _jvp, jvp, j′vp, _j′vp, to_vec @testset "jvp(::$T)" for T in (Float64,) rng, N, M, fdm = MersenneTwister(123456), 2, 3, central_fdm(5, 1) - x, y = randn(rng, T, N), randn(rng, T, M) - ẋ, ẏ = randn(rng, T, N), randn(rng, T, M) - xy, ẋẏ = vcat(x, y), vcat(ẋ, ẏ) - ż_manual = _jvp(fdm, (xy)->sum(sin, xy), xy, ẋẏ)[1] - ż_auto = jvp(fdm, x->sum(sin, x[1]) + sum(sin, x[2]), ((x, y), (ẋ, ẏ))) - ż_multi = jvp(fdm, (x, y)->sum(sin, x) + sum(sin, y), (x, ẋ), (y, ẏ)) - @test ż_manual ≈ ż_auto - @test ż_manual ≈ ż_multi + @testset "scalar output" begin + x, y = randn(rng, T, N), randn(rng, T, M) + ẋ, ẏ = randn(rng, T, N), randn(rng, T, M) + xy, ẋẏ = vcat(x, y), vcat(ẋ, ẏ) + ż_manual = _jvp(fdm, (xy)->sum(sin, xy), xy, ẋẏ)[1] + ż_auto = jvp(fdm, x->sum(sin, x[1]) + sum(sin, x[2]), ((x, y), (ẋ, ẏ))) + ż_multi = jvp(fdm, (x, y)->sum(sin, x) + sum(sin, y), (x, ẋ), (y, ẏ)) + @test ż_manual ≈ ż_auto + @test ż_manual ≈ ż_multi + end + @testset "vector output" begin + x, y = randn(rng, T, N), randn(rng, T, N) + ẋ, ẏ = randn(rng, T, N), randn(rng, T, N) + ż_manual = @. cos(x) * ẋ + cos(y) * ẏ + ż_auto = jvp(fdm, x->sin.(x[1]) .+ sin.(x[2]), ((x, y), (ẋ, ẏ))) + ż_multi = jvp(fdm, (x, y)->sin.(x) .+ sin.(y), (x, ẋ), (y, ẏ)) + @test ż_manual ≈ ż_auto + @test ż_manual ≈ ż_multi + end + @testset "tuple output" begin + x, y = randn(rng, T, N), randn(rng, T, N) + ẋ, ẏ = randn(rng, T, N), randn(rng, T, N) + ż_manual = (cos.(x) .* ẋ, cos.(y) .* ẏ) + ż_auto = jvp(fdm, x->(sin.(x[1]), sin.(x[2])), ((x, y), (ẋ, ẏ))) + ż_multi = jvp(fdm, (x, y)->(sin.(x), sin.(y)), (x, ẋ), (y, ẏ)) + @test ż_auto isa Tuple + @test ż_multi isa Tuple + @test collect(ż_manual) ≈ collect(ż_auto) + @test collect(ż_manual) ≈ collect(ż_multi) + end end @testset "grad(::$T)" for T in (Float64,)