From 680302c9997a1c5db2a527609338ccfe647216c5 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Wed, 27 Jan 2021 12:23:39 +0100 Subject: [PATCH 01/17] hessian tweaks, committed first to help git out --- src/lib/utils.jl | 46 ++++++++++++++++++++++++++++++++++------------ 1 file changed, 34 insertions(+), 12 deletions(-) diff --git a/src/lib/utils.jl b/src/lib/utils.jl index e51ff2d3c..23ca64749 100644 --- a/src/lib/utils.jl +++ b/src/lib/utils.jl @@ -31,11 +31,11 @@ ignore(f) = f() Tell Zygote to ignore an expression. Equivalent to `ignore() do (...) end`. Example: -```julia-repl -julia> f(x) = (y = Zygote.@ignore x; x * y); +```julia-repl +julia> f(x) = (y = Zygote.@ignore x; x * y); julia> f'(1) 1 -``` +``` """ macro ignore(ex) return :(Zygote.ignore() do @@ -102,16 +102,38 @@ end """ hessian(f, x) -Construct the Hessian of `f`, where `x` is a real or real array and `f(x)` is -a real. +Construct the Hessian `∂²f/∂x∂x`, where `x` is a real number or an array, +and `f(x)` is a real number. - julia> hessian(((a, b),) -> a*b, [2, 3]) - 2×2 Array{Int64,2}: - 0 1 - 1 0 +Uses forward over reverse, ForwardDiff over Zygote, by default: `hessian_dual(f, x)`. + +# Examples + +```jldoctest +julia> Zygote.hessian(x -> x[1]*x[2], randn(2)) +2×2 Array{Float64,2}: + 0.0 1.0 + 1.0 0.0 + +julia> Zygote.hessian(x -> sum(x.^3), [1 2; 3 4]) # uses linear indexing of x +4×4 Array{$Int,2}: + 6 0 0 0 + 0 18 0 0 + 0 0 12 0 + 0 0 0 24 + +julia> Zygote.hessian(sin, pi/2) +-1.0 +``` """ hessian(f, x::AbstractArray) = forward_jacobian(x -> gradient(f, x)[1], x)[2] +hessian(f, x) = hessian_dual(f, x) + +hessian_dual(f, x::AbstractArray) = forward_jacobian(x -> gradient(f, x)[1], x)[2] + +hessian_dual(f, x::Number) = ForwardDiff.derivative(x -> gradient(f, x)[1], x) + """ isderiving() isderiving(x) @@ -122,13 +144,13 @@ Check whether the current function call is happening while taking the derivative julia> function f(x) @show isderiving() end - + f (generic function with 1 method) - + julia> f(3) isderiving() = false false - + julia> gradient(f, 4) isderiving() = true (nothing,) From 93a4c5bfc350f2212f8ff8c93c9a597600c7ead3 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Wed, 27 Jan 2021 12:58:18 +0100 Subject: [PATCH 02/17] new jacobian implementation, and tests --- src/Zygote.jl | 2 +- src/lib/utils.jl | 154 +++++++++++++++++++++++++++++++++++++++++------ test/runtests.jl | 5 +- test/utils.jl | 30 +++++++++ 4 files changed, 172 insertions(+), 19 deletions(-) create mode 100644 test/utils.jl diff --git a/src/Zygote.jl b/src/Zygote.jl index a7b42cf2a..4b0401ada 100644 --- a/src/Zygote.jl +++ b/src/Zygote.jl @@ -12,7 +12,7 @@ using MacroTools, Requires using MacroTools: @forward import Distributed: pmap, CachingPool, workers -export Params, gradient, pullback, pushforward, @code_adjoint +export Params, gradient, jacobian, pullback, pushforward, @code_adjoint include("tools/idset.jl") include("tools/buffer.jl") diff --git a/src/lib/utils.jl b/src/lib/utils.jl index 23ca64749..8b9a2535e 100644 --- a/src/lib/utils.jl +++ b/src/lib/utils.jl @@ -99,6 +99,32 @@ macro showgrad(x) end) end +""" + isderiving() + isderiving(x) + +Check whether the current function call is happening while taking the derivative. + + + julia> function f(x) + @show isderiving() + end + + f (generic function with 1 method) + + julia> f(3) + isderiving() = false + false + + julia> gradient(f, 4) + isderiving() = true + (nothing,) +""" +isderiving() = false +isderiving(x) = false +@adjoint isderiving() = true, _ -> nothing +@adjoint isderiving(x) = true, x -> (nothing,) + """ hessian(f, x) @@ -134,28 +160,122 @@ hessian_dual(f, x::AbstractArray) = forward_jacobian(x -> gradient(f, x)[1], x)[ hessian_dual(f, x::Number) = ForwardDiff.derivative(x -> gradient(f, x)[1], x) + """ - isderiving() - isderiving(x) + jacobian(f, args...) -Check whether the current function call is happening while taking the derivative. +For each array `a ∈ args` this returns a matrix with `Ja[k,i] = ∂y[k]/∂a[i]` +where `y = f(args...)` is usually a vector. +For scalar `x::Number ∈ args`, the result `Jx[k,1] = ∂y[k]/∂x` is a vector, +while for scalar `y` all results have just one row. +For any other argument type, no result is produced, even if [`gradient`](@ref) would work. - julia> function f(x) - @show isderiving() - end +This reverse-mode Jacobian needs to evaluate the pullback once for each element of `y`. +This is usually only efficient when `length(y)` is small compared to `length(a)`, +otherwise forward mode is likely to be better. - f (generic function with 1 method) +# Examples - julia> f(3) - isderiving() = false - false +```jldoctest +julia> jacobian(a -> 100*a[1:3].^2, 1:7)[1] # first index (rows) is output +3×7 Array{$Int,2}: + 200 0 0 0 0 0 0 + 0 400 0 0 0 0 0 + 0 0 600 0 0 0 0 - julia> gradient(f, 4) - isderiving() = true - (nothing,) +julia> jacobian((a,x) -> a.^2 .* x, [1,2,3], 1) # scalar argument has vector jacobian +([2 0 0; 0 4 0; 0 0 6], [1, 4, 9]) + +julia> jacobian((a,d) -> prod(a, dims=d), [1 2; 3 4; 5 6], 2) +([2 0 … 0 0; 0 4 … 3 0; 0 0 … 0 5], [0, 0, 0]) +``` + +!!! Warning: for arguments of any type except `Number` & `AbstractArray`, the result is `nothing`. + +```jldoctest +julia> jacobian((a,s) -> a.^length(s), [1,2,3], "str") +([3 0 0; 0 12 0; 0 0 27], nothing) + +julia> jacobian((a,t) -> sum(a .* t[1]) + t[2], [1,2,3], (4,5)) +([4 4 4], nothing) + +julia> gradient((a,t) -> sum(a .* t[1]) + t[2], [1,2,3], (4,5)) # gradient undersands the tuple +([4, 4, 4], (6, 1)) + ``` """ -isderiving() = false -isderiving(x) = false -@adjoint isderiving() = true, _ -> nothing -@adjoint isderiving(x) = true, x -> (nothing,) +function jacobian(f, args...) + y, back = pullback(_jvec∘f, args...) + out = map(args) do x + T = promote_type(eltype(x), eltype(y)) + dx = x isa AbstractArray ? similar(x, T, length(y), length(x)) : + x isa Number ? similar(y, T, length(y)) : + nothing + end + delta = fill!(similar(y), 0) + for k in LinearIndices(y) + delta[k] = 1 + grads = back(delta) + for (dx, grad) in zip(out, grads) + dx isa AbstractArray || continue + _gradcopy!(view(dx,k,:), grad) + end + delta[k] = 0 + end + out +end + +_jvec(x::AbstractArray) = vec(x) +_jvec(x::Number) = vcat(x) +_jvec(x) = throw(ArgumentError("jacobian expected a function which returns an array, or a scalar, got $(typeof(x))")) + +_gradcopy!(dst::AbstractArray, src::AbstractArray{<:Number}) = copyto!(dst, src) +_gradcopy!(dst::AbstractArray, src::Number) = copyto!(dst, src) +_gradcopy!(dst::AbstractArray, ::Nothing) = dst .= false +_gradcopy!(dst::AbstractArray, src::AbstractArray) = copyto!(dst, g isa Number ? g : 0 for g in src) # e.g. Union{Nothing,Float64} + +""" + jacobian(loss, ::Params) + +Like `gradient` with implicit parameters, this method takes a zero-argument function +and returns an `IdDict`-like object, now containing the Jacobian for each parameter. + +# Examples +```jldoctest +julia> xs = [1 2; 3 4]; ys = [5,7,9]; + +julia> Jxy = jacobian(() -> ys[1:2] .+ sum(xs.^2), Params([xs, ys])) +Grads(...) + +julia> Jxy[ys] +2×3 Array{$Int,2}: + 1 0 0 + 0 1 0 + +julia> Jxy[xs] +2×4 Array{$Int,2}: + 2 6 4 8 + 2 6 4 8 +``` +""" +function jacobian(f, pars::Params) + y, back = pullback(_jvec∘f, pars) + out = IdDict() + for p in pars + T = Base.promote_type(eltype(p), eltype(y)) + J = similar(y, T, length(y), length(p)) + out[p] = J + end + delta = fill!(similar(y), 0) + for k in LinearIndices(y) + delta[k] = 1 + grads = back(delta) + for p in pars + out[p] isa AbstractArray || continue + _gradcopy!(view(out[p],k,:), grads[p]) + end + delta[k] = 0 + end + Grads(out, pars) +end + diff --git a/test/runtests.jl b/test/runtests.jl index 5e7ca4e62..c6737ac73 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,11 +6,14 @@ using CUDA: has_cuda include("interface.jl") end - @testset "Tools" begin include("tools.jl") end +@testset "Utils" begin + include("utils.jl") +end + @testset "lib/number" begin include("lib/number.jl") end diff --git a/test/utils.jl b/test/utils.jl new file mode 100644 index 000000000..9883fea0d --- /dev/null +++ b/test/utils.jl @@ -0,0 +1,30 @@ +using LinearAlgebra + +@testset "jacobian(f, x, y)" begin + j1 = jacobian((a,x) -> a.^2 .* x, [1,2,3], 1) + @test j1[1] ≈ Diagonal([2,4,6]) + @test j1[2] ≈ [1, 4, 9] + @test j1[2] isa Vector + + j2 = jacobian((a,t) -> sum(a .* t[1]) + t[2], [1,2,3], (4,5)) + @test j2[1] == [4 4 4] + @test j2[1] isa Matrix + @test j2[2] === nothing # input other than Number, Array is ignored + + j3 = jacobian((a,d) -> prod(a, dims=d), [1 2; 3 4], 1) + @test j3[1] ≈ [3 1 0 0; 0 0 4 2] + @test j3[2] ≈ [0, 0] # pullback is always Nothing, but array already allocated + + j4 = jacobian([1,2,-3,4,-5]) do xs + map(x -> x>0 ? x^3 : 0, xs) # pullback gives Nothing for some elements x + end + @test j4[1] ≈ Diagonal([3,12,0,48,0]) +end + +@testset "jacobian(loss, ::Params)" begin + xs = [1 2; 3 4] + ys = [5,7,9]; + Jxy = jacobian(() -> ys[1:2] .+ sum(xs.^2), Params([xs, ys])) + @test Jxy[ys] ≈ [1 0 0; 0 1 0] + @test Jxy[xs] ≈ [2 6 4 8; 2 6 4 8] +end From a6f0ef0f1330dbb96128bb169e01a9354e3895e8 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Wed, 27 Jan 2021 12:59:13 +0100 Subject: [PATCH 03/17] hessian_reverse, and tests --- src/lib/utils.jl | 13 +++++++++++++ test/utils.jl | 18 ++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/src/lib/utils.jl b/src/lib/utils.jl index 8b9a2535e..2315c6a7f 100644 --- a/src/lib/utils.jl +++ b/src/lib/utils.jl @@ -132,6 +132,7 @@ Construct the Hessian `∂²f/∂x∂x`, where `x` is a real number or an array, and `f(x)` is a real number. Uses forward over reverse, ForwardDiff over Zygote, by default: `hessian_dual(f, x)`. +See [`hessian_reverse`](@ref) for an all-Zygote version. # Examples @@ -160,6 +161,16 @@ hessian_dual(f, x::AbstractArray) = forward_jacobian(x -> gradient(f, x)[1], x)[ hessian_dual(f, x::Number) = ForwardDiff.derivative(x -> gradient(f, x)[1], x) +""" + hessian_reverse(f, x) + +This should be equivalent to [`hessian(f, x)`](@ref hessian), +but implemented using reverse over reverse mode, all Zygote. +(This is usually much slower, and more likely to find errors.) +""" +hessian_reverse(f, x::AbstractArray) = jacobian(x -> gradient(f, x)[1], x)[1] + +hessian_reverse(f, x::Number) = gradient(x -> gradient(f, x)[1], x)[1] """ jacobian(f, args...) @@ -175,6 +186,8 @@ This reverse-mode Jacobian needs to evaluate the pullback once for each element This is usually only efficient when `length(y)` is small compared to `length(a)`, otherwise forward mode is likely to be better. +See also [`hessian`](@ref), [`hessian_reverse`](@ref). + # Examples ```jldoctest diff --git a/test/utils.jl b/test/utils.jl index 9883fea0d..4d2488e71 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -1,4 +1,22 @@ using LinearAlgebra +using Zygote: hessian_dual, hessian_reverse + +@testset "hessian: $hess" for hess in [hessian_dual, hessian_reverse] + + if hess == hessian_dual + @test hess(x -> x[1]*x[2], randn(2)) ≈ [0 1; 1 0] + @test hess(((x,y),) -> x*y, randn(2)) ≈ [0 1; 1 0] # original docstring version + else + @test_broken hess(x -> x[1]*x[2], randn(2)) ≈ [0 1; 1 0] # can't differentiate ∇getindex + @test_broken hess(((x,y),) -> x*y, randn(2)) ≈ [0 1; 1 0] + end + @test hess(x -> sum(x.^3), [1 2; 3 4]) ≈ Diagonal([6, 18, 12, 24]) + @test hess(sin, pi/2) ≈ -1 + + @test_throws Exception hess(sin, im*pi) + @test_throws Exception hess(x -> x+im, pi) + @test_throws Exception hess(identity, randn(2)) +end @testset "jacobian(f, x, y)" begin j1 = jacobian((a,x) -> a.^2 .* x, [1,2,3], 1) From c1687cc89b0f49e92e344d85966cf35aa2430e90 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Wed, 27 Jan 2021 13:52:49 +0100 Subject: [PATCH 04/17] brief docstring & xref for gradient --- src/compiler/interface.jl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl index 38f7f5d7d..4535d946e 100644 --- a/src/compiler/interface.jl +++ b/src/compiler/interface.jl @@ -42,8 +42,17 @@ end sensitivity(y::Number) = one(y) sensitivity(y::Complex) = error("Output is complex, so the gradient is not defined.") +sensitivity(y::AbstractArray) = error("output an array, so the gradient is not defined. Perhaps you wanted jacobian.") sensitivity(y) = error("Output should be scalar; gradients are not defined for output $(repr(y))") +""" + gradient(f, args...) + +Returns a tuple containing `∂f/∂x` for each argument `x`, +the derivative (for scalar x) or the gradient. + +`f(x)` must be a real number, see [`jacobian`](@ref) for array output. +""" function gradient(f, args...) y, back = pullback(f, args...) return back(sensitivity(y)) From 418a74f8f829cfaedb97df67a4a19ad7388a2db1 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Wed, 27 Jan 2021 14:43:51 +0100 Subject: [PATCH 05/17] a few more tests --- src/lib/utils.jl | 3 ++- test/utils.jl | 19 +++++++++++++++++-- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/src/lib/utils.jl b/src/lib/utils.jl index 2315c6a7f..909b071c6 100644 --- a/src/lib/utils.jl +++ b/src/lib/utils.jl @@ -239,8 +239,9 @@ function jacobian(f, args...) end _jvec(x::AbstractArray) = vec(x) -_jvec(x::Number) = vcat(x) +_jvec(x::Number) = _jvec(vcat(x)) _jvec(x) = throw(ArgumentError("jacobian expected a function which returns an array, or a scalar, got $(typeof(x))")) +_jvec(x::AbstractArray{<:Complex}) = throw(ArgumentError("jacobian does not accept complex output")) _gradcopy!(dst::AbstractArray, src::AbstractArray{<:Number}) = copyto!(dst, src) _gradcopy!(dst::AbstractArray, src::Number) = copyto!(dst, src) diff --git a/test/utils.jl b/test/utils.jl index 4d2488e71..d09fc2dc2 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -18,13 +18,15 @@ using Zygote: hessian_dual, hessian_reverse @test_throws Exception hess(identity, randn(2)) end -@testset "jacobian(f, x, y)" begin +@testset "jacobian(f, args...)" begin + @test jacobian(identity, [1,2])[1] == [1 0; 0 1] + j1 = jacobian((a,x) -> a.^2 .* x, [1,2,3], 1) @test j1[1] ≈ Diagonal([2,4,6]) @test j1[2] ≈ [1, 4, 9] @test j1[2] isa Vector - j2 = jacobian((a,t) -> sum(a .* t[1]) + t[2], [1,2,3], (4,5)) + j2 = jacobian((a,t) -> sum(a .* t[1]) + t[2], [1,2,3], (4,5)) # scalar output is OK @test j2[1] == [4 4 4] @test j2[1] isa Matrix @test j2[2] === nothing # input other than Number, Array is ignored @@ -37,6 +39,19 @@ end map(x -> x>0 ? x^3 : 0, xs) # pullback gives Nothing for some elements x end @test j4[1] ≈ Diagonal([3,12,0,48,0]) + + j5 = jacobian((x,y) -> hcat(x[1], y), fill(pi), exp(1)) # zero-array + @test j5[1] isa Matrix + @test vec(j5[1]) == [1, 0] + + @test_throws ArgumentError jacobian(identity, [1,2,3+im]) + @test_throws ArgumentError jacobian(sum, [1,2,3+im]) # scalar, complex + + f6(x,y) = abs2.(x .* y) + g6 = gradient(first∘f6, [1+im, 2], 3+4im) + j6 = jacobian((x,y) -> abs2.(x .* y), [1+im, 2], 3+4im) + @test j6[1][1,:] ≈ g6[1] + @test j6[2][1] ≈ g6[2] end @testset "jacobian(loss, ::Params)" begin From 3f9e6c5e49c317ec6bbc11d7ee72d19e51f6a3db Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Wed, 27 Jan 2021 14:44:09 +0100 Subject: [PATCH 06/17] add to docs, Utilities page --- docs/src/utils.md | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/docs/src/utils.md b/docs/src/utils.md index af275c0dd..bafe5a27a 100644 --- a/docs/src/utils.md +++ b/docs/src/utils.md @@ -1,6 +1,14 @@ # Utilities -Zygote provides a set of helpful utilities. These are all "user-level" tools – +Zygote's gradients can be used to construct a Jacobian (by repeated evaluation) +or a Hessian (by taking a second derivative). + +```@docs +Zygote.jacobian +Zygote.hessian +``` + +Zygote also provides a set of helpful utilities. These are all "user-level" tools – in other words you could have written them easily yourself, but they live in Zygote for convenience. @@ -8,7 +16,6 @@ Zygote for convenience. Zygote.@showgrad Zygote.hook Zygote.dropgrad -Zygote.hessian Zygote.Buffer Zygote.forwarddiff Zygote.ignore From 62b64976f8f3284355d439c41a33e897af805afe Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Wed, 27 Jan 2021 17:17:48 +0100 Subject: [PATCH 07/17] docstring fix an indent bug --- src/lib/utils.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/lib/utils.jl b/src/lib/utils.jl index 909b071c6..6080ed9e1 100644 --- a/src/lib/utils.jl +++ b/src/lib/utils.jl @@ -204,7 +204,8 @@ julia> jacobian((a,d) -> prod(a, dims=d), [1 2; 3 4; 5 6], 2) ([2 0 … 0 0; 0 4 … 3 0; 0 0 … 0 5], [0, 0, 0]) ``` -!!! Warning: for arguments of any type except `Number` & `AbstractArray`, the result is `nothing`. +!!! warning + For arguments of any type except `Number` & `AbstractArray`, the result is `nothing`. ```jldoctest julia> jacobian((a,s) -> a.^length(s), [1,2,3], "str") @@ -215,7 +216,7 @@ julia> jacobian((a,t) -> sum(a .* t[1]) + t[2], [1,2,3], (4,5)) julia> gradient((a,t) -> sum(a .* t[1]) + t[2], [1,2,3], (4,5)) # gradient undersands the tuple ([4, 4, 4], (6, 1)) - ``` +``` """ function jacobian(f, args...) y, back = pullback(_jvec∘f, args...) From 1d0d9dddf0f1ba7bc33257b41e4082d39525b5df Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Wed, 27 Jan 2021 18:01:30 +0100 Subject: [PATCH 08/17] change delta approach, add explicit CuArray test --- src/lib/utils.jl | 12 ++++-------- test/cuda.jl | 14 ++++++++++++++ 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/src/lib/utils.jl b/src/lib/utils.jl index 6080ed9e1..8250fac80 100644 --- a/src/lib/utils.jl +++ b/src/lib/utils.jl @@ -226,15 +226,13 @@ function jacobian(f, args...) x isa Number ? similar(y, T, length(y)) : nothing end - delta = fill!(similar(y), 0) + delta = Diagonal(fill!(similar(y), 1)) for k in LinearIndices(y) - delta[k] = 1 - grads = back(delta) + grads = back(delta[:,k]) for (dx, grad) in zip(out, grads) dx isa AbstractArray || continue _gradcopy!(view(dx,k,:), grad) end - delta[k] = 0 end out end @@ -281,15 +279,13 @@ function jacobian(f, pars::Params) J = similar(y, T, length(y), length(p)) out[p] = J end - delta = fill!(similar(y), 0) + delta = Diagonal(fill!(similar(y), 1)) for k in LinearIndices(y) - delta[k] = 1 - grads = back(delta) + grads = back(delta[:,k]) for p in pars out[p] isa AbstractArray || continue _gradcopy!(view(out[p],k,:), grads[p]) end - delta[k] = 0 end Grads(out, pars) end diff --git a/test/cuda.jl b/test/cuda.jl index 2820a776b..43538840b 100644 --- a/test/cuda.jl +++ b/test/cuda.jl @@ -15,3 +15,17 @@ end log_grada = cu(Float32[1.0, 0.5, 0.33333334, 0.25, 0.2, 0.16666667, 0.14285715, 0.125, 0.11111111]) @test gradient(x -> w(x) |> sum, a) == (log_grada,) end + +@testset "jacobian" begin + v1 = cu(collect(1:3f0)) + + res1 = jacobian(x -> x .* x', 1:3f0)[1] + j1 = jacobian(x -> x .* x', v1)[1] + @test j1 isa CuArray + @test j1 ≈ cu(res1) + + res2 = jacobian(x -> x ./ sum(x), 1:3f0)[1] + j2 = jacobian(() -> v1 ./ sum(v1), Params([v1])) + @test j2[v1] isa CuArray + @test j2[v1] ≈ cu(res2) +end From f6812e71ef7123bf98c7a2565cc6ccb750b42ff7 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Wed, 27 Jan 2021 18:28:03 +0100 Subject: [PATCH 09/17] Diagonal -> diagm, sparse -> dense? --- src/lib/utils.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lib/utils.jl b/src/lib/utils.jl index 8250fac80..a034c03aa 100644 --- a/src/lib/utils.jl +++ b/src/lib/utils.jl @@ -226,7 +226,7 @@ function jacobian(f, args...) x isa Number ? similar(y, T, length(y)) : nothing end - delta = Diagonal(fill!(similar(y), 1)) + delta = diagm(fill!(similar(y), 1)) for k in LinearIndices(y) grads = back(delta[:,k]) for (dx, grad) in zip(out, grads) @@ -279,7 +279,7 @@ function jacobian(f, pars::Params) J = similar(y, T, length(y), length(p)) out[p] = J end - delta = Diagonal(fill!(similar(y), 1)) + delta = diagm(fill!(similar(y), 1)) for k in LinearIndices(y) grads = back(delta[:,k]) for p in pars From c4d75337fe687148f709dfe74e72317558bc4ae4 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Wed, 27 Jan 2021 19:50:19 +0100 Subject: [PATCH 10/17] another idea, and move Cu tests earlier --- src/lib/utils.jl | 7 +++++-- test/runtests.jl | 16 ++++++++-------- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/src/lib/utils.jl b/src/lib/utils.jl index a034c03aa..6fe4fc589 100644 --- a/src/lib/utils.jl +++ b/src/lib/utils.jl @@ -226,7 +226,9 @@ function jacobian(f, args...) x isa Number ? similar(y, T, length(y)) : nothing end - delta = diagm(fill!(similar(y), 1)) + # delta = diagm(fill!(similar(y), 1)) + delta = fill!(similar(y, length(y), length(y)), 0) + delta[LinearAlgebra.diagind(delta)] .= 1 for k in LinearIndices(y) grads = back(delta[:,k]) for (dx, grad) in zip(out, grads) @@ -279,7 +281,8 @@ function jacobian(f, pars::Params) J = similar(y, T, length(y), length(p)) out[p] = J end - delta = diagm(fill!(similar(y), 1)) + delta = fill!(similar(y, length(y), length(y)), 0) + delta[LinearAlgebra.diagind(delta)] .= 1 for k in LinearIndices(y) grads = back(delta[:,k]) for p in pars diff --git a/test/runtests.jl b/test/runtests.jl index c6737ac73..6fbc75341 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,6 +2,14 @@ using Zygote, Test using Zygote: gradient using CUDA: has_cuda +if has_cuda() + @testset "CUDA tests" begin + include("cuda.jl") + end +else + @warn "CUDA not found - Skipping CUDA Tests" +end + @testset "Interface" begin include("interface.jl") end @@ -45,11 +53,3 @@ end @testset "Compiler" begin include("compiler.jl") end - -if has_cuda() - @testset "CUDA tests" begin - include("cuda.jl") - end -else - @warn "CUDA not found - Skipping CUDA Tests" -end From a347029fc9a47aee2968bd9c237aad7185a152b9 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Wed, 27 Jan 2021 23:08:29 +0100 Subject: [PATCH 11/17] docstring tweaks --- src/lib/utils.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lib/utils.jl b/src/lib/utils.jl index 6fe4fc589..1695bee50 100644 --- a/src/lib/utils.jl +++ b/src/lib/utils.jl @@ -131,12 +131,12 @@ isderiving(x) = false Construct the Hessian `∂²f/∂x∂x`, where `x` is a real number or an array, and `f(x)` is a real number. -Uses forward over reverse, ForwardDiff over Zygote, by default: `hessian_dual(f, x)`. +Uses forward over reverse, ForwardDiff over Zygote, calling `hessian_dual(f, x)`. See [`hessian_reverse`](@ref) for an all-Zygote version. # Examples -```jldoctest +```jldoctest; setup=:(using Zygote) julia> Zygote.hessian(x -> x[1]*x[2], randn(2)) 2×2 Array{Float64,2}: 0.0 1.0 @@ -246,7 +246,7 @@ _jvec(x::AbstractArray{<:Complex}) = throw(ArgumentError("jacobian does not acce _gradcopy!(dst::AbstractArray, src::AbstractArray{<:Number}) = copyto!(dst, src) _gradcopy!(dst::AbstractArray, src::Number) = copyto!(dst, src) -_gradcopy!(dst::AbstractArray, ::Nothing) = dst .= false +_gradcopy!(dst::AbstractArray, src::Nothing) = dst .= 0 _gradcopy!(dst::AbstractArray, src::AbstractArray) = copyto!(dst, g isa Number ? g : 0 for g in src) # e.g. Union{Nothing,Float64} """ From 65b561d12eda343be2391c4f17e75ba3550d730f Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Wed, 27 Jan 2021 23:12:58 +0100 Subject: [PATCH 12/17] move identity matrix creation to a function --- src/lib/utils.jl | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/lib/utils.jl b/src/lib/utils.jl index 1695bee50..e1847d12b 100644 --- a/src/lib/utils.jl +++ b/src/lib/utils.jl @@ -226,9 +226,7 @@ function jacobian(f, args...) x isa Number ? similar(y, T, length(y)) : nothing end - # delta = diagm(fill!(similar(y), 1)) - delta = fill!(similar(y, length(y), length(y)), 0) - delta[LinearAlgebra.diagind(delta)] .= 1 + delta = _eyelike(y) for k in LinearIndices(y) grads = back(delta[:,k]) for (dx, grad) in zip(out, grads) @@ -244,6 +242,13 @@ _jvec(x::Number) = _jvec(vcat(x)) _jvec(x) = throw(ArgumentError("jacobian expected a function which returns an array, or a scalar, got $(typeof(x))")) _jvec(x::AbstractArray{<:Complex}) = throw(ArgumentError("jacobian does not accept complex output")) +_eyelike(y::Vector) = Matrix{eltype(y)}(I, length(y), length(y)) +function _eyelike(y::AbstractVector) # version which works on GPU + out = fill!(similar(y, length(y), length(y)), 0) + out[LinearAlgebra.diagind(out)] .= 1 + out +end + _gradcopy!(dst::AbstractArray, src::AbstractArray{<:Number}) = copyto!(dst, src) _gradcopy!(dst::AbstractArray, src::Number) = copyto!(dst, src) _gradcopy!(dst::AbstractArray, src::Nothing) = dst .= 0 @@ -281,8 +286,7 @@ function jacobian(f, pars::Params) J = similar(y, T, length(y), length(p)) out[p] = J end - delta = fill!(similar(y, length(y), length(y)), 0) - delta[LinearAlgebra.diagind(delta)] .= 1 + delta = _eyelike(y) for k in LinearIndices(y) grads = back(delta[:,k]) for p in pars @@ -292,4 +296,3 @@ function jacobian(f, pars::Params) end Grads(out, pars) end - From 100021352a5d69243729af3c8d7cb88316703a03 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Fri, 29 Jan 2021 10:22:33 +0100 Subject: [PATCH 13/17] doc tweaks --- src/lib/utils.jl | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/src/lib/utils.jl b/src/lib/utils.jl index e1847d12b..023200f9c 100644 --- a/src/lib/utils.jl +++ b/src/lib/utils.jl @@ -128,11 +128,13 @@ isderiving(x) = false """ hessian(f, x) -Construct the Hessian `∂²f/∂x∂x`, where `x` is a real number or an array, -and `f(x)` is a real number. +Construct the Hessian `∂²f/∂x²`, where `x` is a real number or an array, +and `f(x)` is a real number. When `x` is an array, the result is a matrix +`H[i,j] = ∂²f/∂x[i]∂x[j]`, using linear indexing `x[i]` even if the argument +is higher-dimensional. -Uses forward over reverse, ForwardDiff over Zygote, calling `hessian_dual(f, x)`. -See [`hessian_reverse`](@ref) for an all-Zygote version. +This uses forward over reverse, ForwardDiff over Zygote, calling `hessian_dual(f, x)`. +See [`hessian_reverse`](@ref) for an all-Zygote alternative. # Examples @@ -173,24 +175,26 @@ hessian_reverse(f, x::AbstractArray) = jacobian(x -> gradient(f, x)[1], x)[1] hessian_reverse(f, x::Number) = gradient(x -> gradient(f, x)[1], x)[1] """ - jacobian(f, args...) + jacobian(f, args...) -> Tuple For each array `a ∈ args` this returns a matrix with `Ja[k,i] = ∂y[k]/∂a[i]` where `y = f(args...)` is usually a vector. +Arrays of higher dimension are treated like `vec(a)`, or `vec(y)` for output. + For scalar `x::Number ∈ args`, the result `Jx[k,1] = ∂y[k]/∂x` is a vector, while for scalar `y` all results have just one row. -For any other argument type, no result is produced, even if [`gradient`](@ref) would work. +With any other argument type, no result is produced, even if [`gradient`](@ref) would work. This reverse-mode Jacobian needs to evaluate the pullback once for each element of `y`. -This is usually only efficient when `length(y)` is small compared to `length(a)`, +Doing so is usually only efficient when `length(y)` is small compared to `length(a)`, otherwise forward mode is likely to be better. See also [`hessian`](@ref), [`hessian_reverse`](@ref). # Examples -```jldoctest +```jldoctest; setup=:(using Zygote) julia> jacobian(a -> 100*a[1:3].^2, 1:7)[1] # first index (rows) is output 3×7 Array{$Int,2}: 200 0 0 0 0 0 0 @@ -207,7 +211,7 @@ julia> jacobian((a,d) -> prod(a, dims=d), [1 2; 3 4; 5 6], 2) !!! warning For arguments of any type except `Number` & `AbstractArray`, the result is `nothing`. -```jldoctest +```jldoctest; setup=:(using Zygote) julia> jacobian((a,s) -> a.^length(s), [1,2,3], "str") ([3 0 0; 0 12 0; 0 0 27], nothing) @@ -257,11 +261,11 @@ _gradcopy!(dst::AbstractArray, src::AbstractArray) = copyto!(dst, g isa Number ? """ jacobian(loss, ::Params) -Like `gradient` with implicit parameters, this method takes a zero-argument function +Like [`gradient`](@ref) with implicit parameters, this method takes a zero-argument function and returns an `IdDict`-like object, now containing the Jacobian for each parameter. # Examples -```jldoctest +```jldoctest; setup=:(using Zygote) julia> xs = [1 2; 3 4]; ys = [5,7,9]; julia> Jxy = jacobian(() -> ys[1:2] .+ sum(xs.^2), Params([xs, ys])) From 6a0023cc48867865cbbff7d661ddc934a5f6b9f5 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Fri, 29 Jan 2021 10:37:59 +0100 Subject: [PATCH 14/17] move to grad.jl, no other changes --- src/lib/grad.jl | 175 ++++++++++++++++++++++++++++++++++++++++++++++ src/lib/utils.jl | 177 +---------------------------------------------- 2 files changed, 176 insertions(+), 176 deletions(-) diff --git a/src/lib/grad.jl b/src/lib/grad.jl index a2b2cf6d7..87ea7211f 100644 --- a/src/lib/grad.jl +++ b/src/lib/grad.jl @@ -38,3 +38,178 @@ function Zygote._pullback(ctx::Zygote.AContext, ::typeof(checkpointed), f, xs... end return y, pullback_checkpointed end + +""" + hessian(f, x) + +Construct the Hessian `∂²f/∂x²`, where `x` is a real number or an array, +and `f(x)` is a real number. When `x` is an array, the result is a matrix +`H[i,j] = ∂²f/∂x[i]∂x[j]`, using linear indexing `x[i]` even if the argument +is higher-dimensional. + +This uses forward over reverse, ForwardDiff over Zygote, calling `hessian_dual(f, x)`. +See [`hessian_reverse`](@ref) for an all-Zygote alternative. + +# Examples + +```jldoctest; setup=:(using Zygote) +julia> Zygote.hessian(x -> x[1]*x[2], randn(2)) +2×2 Array{Float64,2}: + 0.0 1.0 + 1.0 0.0 + +julia> Zygote.hessian(x -> sum(x.^3), [1 2; 3 4]) # uses linear indexing of x +4×4 Array{$Int,2}: + 6 0 0 0 + 0 18 0 0 + 0 0 12 0 + 0 0 0 24 + +julia> Zygote.hessian(sin, pi/2) +-1.0 +``` +""" +hessian(f, x) = hessian_dual(f, x) + +hessian_dual(f, x::AbstractArray) = forward_jacobian(x -> gradient(f, x)[1], x)[2] + +hessian_dual(f, x::Number) = ForwardDiff.derivative(x -> gradient(f, x)[1], x) + +""" + hessian_reverse(f, x) + +This should be equivalent to [`hessian(f, x)`](@ref hessian), +but implemented using reverse over reverse mode, all Zygote. +(This is usually much slower, and more likely to find errors.) +""" +hessian_reverse(f, x::AbstractArray) = jacobian(x -> gradient(f, x)[1], x)[1] + +hessian_reverse(f, x::Number) = gradient(x -> gradient(f, x)[1], x)[1] + + +""" + jacobian(f, args...) -> Tuple + +For each array `a ∈ args` this returns a matrix with `Ja[k,i] = ∂y[k]/∂a[i]` +where `y = f(args...)` is usually a vector. +Arrays of higher dimension are treated like `vec(a)`, or `vec(y)` for output. + +For scalar `x::Number ∈ args`, the result `Jx[k,1] = ∂y[k]/∂x` is a vector, +while for scalar `y` all results have just one row. + +With any other argument type, no result is produced, even if [`gradient`](@ref) would work. + +This reverse-mode Jacobian needs to evaluate the pullback once for each element of `y`. +Doing so is usually only efficient when `length(y)` is small compared to `length(a)`, +otherwise forward mode is likely to be better. + +See also [`hessian`](@ref), [`hessian_reverse`](@ref). + +# Examples + +```jldoctest; setup=:(using Zygote) +julia> jacobian(a -> 100*a[1:3].^2, 1:7)[1] # first index (rows) is output +3×7 Array{$Int,2}: + 200 0 0 0 0 0 0 + 0 400 0 0 0 0 0 + 0 0 600 0 0 0 0 + +julia> jacobian((a,x) -> a.^2 .* x, [1,2,3], 1) # scalar argument has vector jacobian +([2 0 0; 0 4 0; 0 0 6], [1, 4, 9]) + +julia> jacobian((a,d) -> prod(a, dims=d), [1 2; 3 4; 5 6], 2) +([2 0 … 0 0; 0 4 … 3 0; 0 0 … 0 5], [0, 0, 0]) +``` + +!!! warning + For arguments of any type except `Number` & `AbstractArray`, the result is `nothing`. + +```jldoctest; setup=:(using Zygote) +julia> jacobian((a,s) -> a.^length(s), [1,2,3], "str") +([3 0 0; 0 12 0; 0 0 27], nothing) + +julia> jacobian((a,t) -> sum(a .* t[1]) + t[2], [1,2,3], (4,5)) +([4 4 4], nothing) + +julia> gradient((a,t) -> sum(a .* t[1]) + t[2], [1,2,3], (4,5)) # gradient undersands the tuple +([4, 4, 4], (6, 1)) +``` +""" +function jacobian(f, args...) + y, back = pullback(_jvec∘f, args...) + out = map(args) do x + T = promote_type(eltype(x), eltype(y)) + dx = x isa AbstractArray ? similar(x, T, length(y), length(x)) : + x isa Number ? similar(y, T, length(y)) : + nothing + end + delta = _eyelike(y) + for k in LinearIndices(y) + grads = back(delta[:,k]) + for (dx, grad) in zip(out, grads) + dx isa AbstractArray || continue + _gradcopy!(view(dx,k,:), grad) + end + end + out +end + +_jvec(x::AbstractArray) = vec(x) +_jvec(x::Number) = _jvec(vcat(x)) +_jvec(x) = throw(ArgumentError("jacobian expected a function which returns an array, or a scalar, got $(typeof(x))")) +_jvec(x::AbstractArray{<:Complex}) = throw(ArgumentError("jacobian does not accept complex output")) + +_eyelike(y::Vector) = Matrix{eltype(y)}(I, length(y), length(y)) +function _eyelike(y::AbstractVector) # version which works on GPU + out = fill!(similar(y, length(y), length(y)), 0) + out[LinearAlgebra.diagind(out)] .= 1 + out +end + +_gradcopy!(dst::AbstractArray, src::AbstractArray{<:Number}) = copyto!(dst, src) +_gradcopy!(dst::AbstractArray, src::Number) = copyto!(dst, src) +_gradcopy!(dst::AbstractArray, src::Nothing) = dst .= 0 +_gradcopy!(dst::AbstractArray, src::AbstractArray) = copyto!(dst, g isa Number ? g : 0 for g in src) # e.g. Union{Nothing,Float64} + +""" + jacobian(loss, ::Params) + +Like [`gradient`](@ref) with implicit parameters, this method takes a zero-argument function +and returns an `IdDict`-like object, now containing the Jacobian for each parameter. + +# Examples +```jldoctest; setup=:(using Zygote) +julia> xs = [1 2; 3 4]; ys = [5,7,9]; + +julia> Jxy = jacobian(() -> ys[1:2] .+ sum(xs.^2), Params([xs, ys])) +Grads(...) + +julia> Jxy[ys] +2×3 Array{$Int,2}: + 1 0 0 + 0 1 0 + +julia> Jxy[xs] +2×4 Array{$Int,2}: + 2 6 4 8 + 2 6 4 8 +``` +""" +function jacobian(f, pars::Params) + y, back = pullback(_jvec∘f, pars) + out = IdDict() + for p in pars + T = Base.promote_type(eltype(p), eltype(y)) + J = similar(y, T, length(y), length(p)) + out[p] = J + end + delta = _eyelike(y) + for k in LinearIndices(y) + grads = back(delta[:,k]) + for p in pars + out[p] isa AbstractArray || continue + _gradcopy!(view(out[p],k,:), grads[p]) + end + end + Grads(out, pars) +end diff --git a/src/lib/utils.jl b/src/lib/utils.jl index 023200f9c..86e6fff8c 100644 --- a/src/lib/utils.jl +++ b/src/lib/utils.jl @@ -122,181 +122,6 @@ Check whether the current function call is happening while taking the derivative """ isderiving() = false isderiving(x) = false + @adjoint isderiving() = true, _ -> nothing @adjoint isderiving(x) = true, x -> (nothing,) - -""" - hessian(f, x) - -Construct the Hessian `∂²f/∂x²`, where `x` is a real number or an array, -and `f(x)` is a real number. When `x` is an array, the result is a matrix -`H[i,j] = ∂²f/∂x[i]∂x[j]`, using linear indexing `x[i]` even if the argument -is higher-dimensional. - -This uses forward over reverse, ForwardDiff over Zygote, calling `hessian_dual(f, x)`. -See [`hessian_reverse`](@ref) for an all-Zygote alternative. - -# Examples - -```jldoctest; setup=:(using Zygote) -julia> Zygote.hessian(x -> x[1]*x[2], randn(2)) -2×2 Array{Float64,2}: - 0.0 1.0 - 1.0 0.0 - -julia> Zygote.hessian(x -> sum(x.^3), [1 2; 3 4]) # uses linear indexing of x -4×4 Array{$Int,2}: - 6 0 0 0 - 0 18 0 0 - 0 0 12 0 - 0 0 0 24 - -julia> Zygote.hessian(sin, pi/2) --1.0 -``` -""" -hessian(f, x::AbstractArray) = forward_jacobian(x -> gradient(f, x)[1], x)[2] - -hessian(f, x) = hessian_dual(f, x) - -hessian_dual(f, x::AbstractArray) = forward_jacobian(x -> gradient(f, x)[1], x)[2] - -hessian_dual(f, x::Number) = ForwardDiff.derivative(x -> gradient(f, x)[1], x) - -""" - hessian_reverse(f, x) - -This should be equivalent to [`hessian(f, x)`](@ref hessian), -but implemented using reverse over reverse mode, all Zygote. -(This is usually much slower, and more likely to find errors.) -""" -hessian_reverse(f, x::AbstractArray) = jacobian(x -> gradient(f, x)[1], x)[1] - -hessian_reverse(f, x::Number) = gradient(x -> gradient(f, x)[1], x)[1] - -""" - jacobian(f, args...) -> Tuple - -For each array `a ∈ args` this returns a matrix with `Ja[k,i] = ∂y[k]/∂a[i]` -where `y = f(args...)` is usually a vector. -Arrays of higher dimension are treated like `vec(a)`, or `vec(y)` for output. - -For scalar `x::Number ∈ args`, the result `Jx[k,1] = ∂y[k]/∂x` is a vector, -while for scalar `y` all results have just one row. - -With any other argument type, no result is produced, even if [`gradient`](@ref) would work. - -This reverse-mode Jacobian needs to evaluate the pullback once for each element of `y`. -Doing so is usually only efficient when `length(y)` is small compared to `length(a)`, -otherwise forward mode is likely to be better. - -See also [`hessian`](@ref), [`hessian_reverse`](@ref). - -# Examples - -```jldoctest; setup=:(using Zygote) -julia> jacobian(a -> 100*a[1:3].^2, 1:7)[1] # first index (rows) is output -3×7 Array{$Int,2}: - 200 0 0 0 0 0 0 - 0 400 0 0 0 0 0 - 0 0 600 0 0 0 0 - -julia> jacobian((a,x) -> a.^2 .* x, [1,2,3], 1) # scalar argument has vector jacobian -([2 0 0; 0 4 0; 0 0 6], [1, 4, 9]) - -julia> jacobian((a,d) -> prod(a, dims=d), [1 2; 3 4; 5 6], 2) -([2 0 … 0 0; 0 4 … 3 0; 0 0 … 0 5], [0, 0, 0]) -``` - -!!! warning - For arguments of any type except `Number` & `AbstractArray`, the result is `nothing`. - -```jldoctest; setup=:(using Zygote) -julia> jacobian((a,s) -> a.^length(s), [1,2,3], "str") -([3 0 0; 0 12 0; 0 0 27], nothing) - -julia> jacobian((a,t) -> sum(a .* t[1]) + t[2], [1,2,3], (4,5)) -([4 4 4], nothing) - -julia> gradient((a,t) -> sum(a .* t[1]) + t[2], [1,2,3], (4,5)) # gradient undersands the tuple -([4, 4, 4], (6, 1)) -``` -""" -function jacobian(f, args...) - y, back = pullback(_jvec∘f, args...) - out = map(args) do x - T = promote_type(eltype(x), eltype(y)) - dx = x isa AbstractArray ? similar(x, T, length(y), length(x)) : - x isa Number ? similar(y, T, length(y)) : - nothing - end - delta = _eyelike(y) - for k in LinearIndices(y) - grads = back(delta[:,k]) - for (dx, grad) in zip(out, grads) - dx isa AbstractArray || continue - _gradcopy!(view(dx,k,:), grad) - end - end - out -end - -_jvec(x::AbstractArray) = vec(x) -_jvec(x::Number) = _jvec(vcat(x)) -_jvec(x) = throw(ArgumentError("jacobian expected a function which returns an array, or a scalar, got $(typeof(x))")) -_jvec(x::AbstractArray{<:Complex}) = throw(ArgumentError("jacobian does not accept complex output")) - -_eyelike(y::Vector) = Matrix{eltype(y)}(I, length(y), length(y)) -function _eyelike(y::AbstractVector) # version which works on GPU - out = fill!(similar(y, length(y), length(y)), 0) - out[LinearAlgebra.diagind(out)] .= 1 - out -end - -_gradcopy!(dst::AbstractArray, src::AbstractArray{<:Number}) = copyto!(dst, src) -_gradcopy!(dst::AbstractArray, src::Number) = copyto!(dst, src) -_gradcopy!(dst::AbstractArray, src::Nothing) = dst .= 0 -_gradcopy!(dst::AbstractArray, src::AbstractArray) = copyto!(dst, g isa Number ? g : 0 for g in src) # e.g. Union{Nothing,Float64} - -""" - jacobian(loss, ::Params) - -Like [`gradient`](@ref) with implicit parameters, this method takes a zero-argument function -and returns an `IdDict`-like object, now containing the Jacobian for each parameter. - -# Examples -```jldoctest; setup=:(using Zygote) -julia> xs = [1 2; 3 4]; ys = [5,7,9]; - -julia> Jxy = jacobian(() -> ys[1:2] .+ sum(xs.^2), Params([xs, ys])) -Grads(...) - -julia> Jxy[ys] -2×3 Array{$Int,2}: - 1 0 0 - 0 1 0 - -julia> Jxy[xs] -2×4 Array{$Int,2}: - 2 6 4 8 - 2 6 4 8 -``` -""" -function jacobian(f, pars::Params) - y, back = pullback(_jvec∘f, pars) - out = IdDict() - for p in pars - T = Base.promote_type(eltype(p), eltype(y)) - J = similar(y, T, length(y), length(p)) - out[p] = J - end - delta = _eyelike(y) - for k in LinearIndices(y) - grads = back(delta[:,k]) - for p in pars - out[p] isa AbstractArray || continue - _gradcopy!(view(out[p],k,:), grads[p]) - end - end - Grads(out, pars) -end From 909cec5bdf6a7608daa433f50b503633c2e246e7 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Sat, 30 Jan 2021 12:59:52 +0100 Subject: [PATCH 15/17] tweak --- src/lib/grad.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib/grad.jl b/src/lib/grad.jl index 87ea7211f..c9910e2a7 100644 --- a/src/lib/grad.jl +++ b/src/lib/grad.jl @@ -94,7 +94,7 @@ For each array `a ∈ args` this returns a matrix with `Ja[k,i] = ∂y[k]/∂a[i where `y = f(args...)` is usually a vector. Arrays of higher dimension are treated like `vec(a)`, or `vec(y)` for output. -For scalar `x::Number ∈ args`, the result `Jx[k,1] = ∂y[k]/∂x` is a vector, +For scalar `x::Number ∈ args`, the result is a vector `Jx[k] = ∂y[k]/∂x`, while for scalar `y` all results have just one row. With any other argument type, no result is produced, even if [`gradient`](@ref) would work. From bc187985e27449f95bfc8014f5133450e30c40be Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 1 Feb 2021 10:39:19 +0100 Subject: [PATCH 16/17] Update src/compiler/interface.jl Co-authored-by: Carlo Lucibello --- src/compiler/interface.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl index 4535d946e..58a09dd9e 100644 --- a/src/compiler/interface.jl +++ b/src/compiler/interface.jl @@ -51,7 +51,7 @@ sensitivity(y) = error("Output should be scalar; gradients are not defined for o Returns a tuple containing `∂f/∂x` for each argument `x`, the derivative (for scalar x) or the gradient. -`f(x)` must be a real number, see [`jacobian`](@ref) for array output. +`f(args...)` must be a real number, see [`jacobian`](@ref) for array output. """ function gradient(f, args...) y, back = pullback(f, args...) From b1d868e67d058f333362e213c6cb843905b29b30 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 1 Feb 2021 11:30:33 +0100 Subject: [PATCH 17/17] Apply suggestions from code review Co-authored-by: Carlo Lucibello --- src/Zygote.jl | 2 +- src/lib/grad.jl | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/Zygote.jl b/src/Zygote.jl index 4b0401ada..fa0af463d 100644 --- a/src/Zygote.jl +++ b/src/Zygote.jl @@ -12,7 +12,7 @@ using MacroTools, Requires using MacroTools: @forward import Distributed: pmap, CachingPool, workers -export Params, gradient, jacobian, pullback, pushforward, @code_adjoint +export Params, gradient, jacobian, hessian, pullback, pushforward, @code_adjoint include("tools/idset.jl") include("tools/buffer.jl") diff --git a/src/lib/grad.jl b/src/lib/grad.jl index c9910e2a7..3999e51d7 100644 --- a/src/lib/grad.jl +++ b/src/lib/grad.jl @@ -53,19 +53,19 @@ See [`hessian_reverse`](@ref) for an all-Zygote alternative. # Examples ```jldoctest; setup=:(using Zygote) -julia> Zygote.hessian(x -> x[1]*x[2], randn(2)) +julia> hessian(x -> x[1]*x[2], randn(2)) 2×2 Array{Float64,2}: 0.0 1.0 1.0 0.0 -julia> Zygote.hessian(x -> sum(x.^3), [1 2; 3 4]) # uses linear indexing of x +julia> hessian(x -> sum(x.^3), [1 2; 3 4]) # uses linear indexing of x 4×4 Array{$Int,2}: 6 0 0 0 0 18 0 0 0 0 12 0 0 0 0 24 -julia> Zygote.hessian(sin, pi/2) +julia> hessian(sin, pi/2) -1.0 ``` """