Merge #890
890: Add `jacobian`, at last? r=CarloLucibello a=mcabbott

This adds a Jacobian function. 

Compared to #747 this one:
* has tests
* accepts multiple arguments
* shouldn't fail if `back` returns `nothing`
* inserts `vec` in a few more places
* has a method which accepts implicit `Params` and returns `Grads`; see the example just after this list. (This was the only part I actually needed in real life.)
* now works on the GPU too!
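
For the implicit-parameters method, usage looks roughly like this (adapted from the docstring added in this PR; the exact array type shown assumes a 64-bit machine):

```julia
julia> using Zygote

julia> xs = [1 2; 3 4]; ys = [5, 7, 9];

julia> Jxy = jacobian(() -> ys[1:2] .+ sum(xs.^2), Params([xs, ys]))
Grads(...)

julia> Jxy[ys]              # 2 outputs, 3 entries of ys
2×3 Array{Int64,2}:
 1  0  0
 0  1  0

julia> Jxy[xs]              # xs is indexed linearly, like vec(xs)
2×4 Array{Int64,2}:
 2  6  4  8
 2  6  4  8
```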

Compared to #414 this one:
* always inserts `vec`, never makes higher-dimensional arrays
* runs on current Zygote
* has tests.

Compared to #235 this one:
* doesn't try to provide a numerical Jacobian
* doesn't alter testing infrastructure
* doesn't provide `jacobian!`, nor any code for structured matrices.

This does not address #564's concerns about functions which return a tuple of arrays. Only functions returning an array (or a scalar) are permitted. Similar considerations might give sensible jacobians when the argument of a function is a tuple, or some other struct, but for now these are handled by putting up a giant warning sign.
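
To illustrate that warning sign, these lines from the new docstring show the behaviour: an argument which is neither a number nor an array simply gets `nothing` in its slot, even where `gradient` would produce something sensible:

```julia
julia> jacobian((a,t) -> sum(a .* t[1]) + t[2], [1,2,3], (4,5))
([4 4 4], nothing)

julia> gradient((a,t) -> sum(a .* t[1]) + t[2], [1,2,3], (4,5))  # gradient understands the tuple
([4, 4, 4], (6, 1))
```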

Nothing in the file `utils.jl` seems to have any tests at all. So I also added tests for `hessian`. 

And, while I was there, I made `hessian` actually accept a real number, as its docstring promises (hence this closes #891). I also added `hessian_reverse`, a reverse-over-reverse version built on this `jacobian`; it works less well, of course, but may as well exist for testing (see for example #865). Ideally there would be a pure-Zygote version using its own forward mode, but I didn't write that.
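
A rough sketch of the new `hessian` behaviour (the scalar example is from the docstring; the `hessian_reverse` output is what I would expect, with array printing depending on the Julia version):

```julia
julia> hessian(sin, pi/2)          # scalar argument, forward-over-reverse
-1.0

julia> Zygote.hessian_reverse(x -> sum(x.^3), [1.0, 2.0])   # reverse-over-reverse, all Zygote
2×2 Array{Float64,2}:
 6.0   0.0
 0.0  12.0
```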

Fixes #51, fixes #98, fixes #413. Closes #747.

Co-authored-by: Michael Abbott <me@escbook>
Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
3 people committed Feb 2, 2021
2 parents 1007950 + b1d868e commit f41d803
Showing 8 changed files with 290 additions and 31 deletions.
11 changes: 9 additions & 2 deletions docs/src/utils.md
@@ -1,14 +1,21 @@
# Utilities

Zygote provides a set of helpful utilities. These are all "user-level" tools –
Zygote's gradients can be used to construct a Jacobian (by repeated evaluation)
or a Hessian (by taking a second derivative).

```@docs
Zygote.jacobian
Zygote.hessian
```

Zygote also provides a set of helpful utilities. These are all "user-level" tools –
in other words you could have written them easily yourself, but they live in
Zygote for convenience.

```@docs
Zygote.@showgrad
Zygote.hook
Zygote.dropgrad
Zygote.hessian
Zygote.Buffer
Zygote.forwarddiff
Zygote.ignore
2 changes: 1 addition & 1 deletion src/Zygote.jl
@@ -12,7 +12,7 @@ using MacroTools, Requires
using MacroTools: @forward

import Distributed: pmap, CachingPool, workers
export Params, gradient, pullback, pushforward, @code_adjoint
export Params, gradient, jacobian, hessian, pullback, pushforward, @code_adjoint

include("tools/idset.jl")
include("tools/buffer.jl")
9 changes: 9 additions & 0 deletions src/compiler/interface.jl
@@ -42,8 +42,17 @@ end

sensitivity(y::Number) = one(y)
sensitivity(y::Complex) = error("Output is complex, so the gradient is not defined.")
sensitivity(y::AbstractArray) = error("Output is an array, so the gradient is not defined. Perhaps you wanted jacobian.")
sensitivity(y) = error("Output should be scalar; gradients are not defined for output $(repr(y))")

"""
gradient(f, args...)
Returns a tuple containing `∂f/∂x` for each argument `x`,
the derivative (for scalar x) or the gradient.
`f(args...)` must be a real number, see [`jacobian`](@ref) for array output.
"""
function gradient(f, args...)
y, back = pullback(f, args...)
return back(sensitivity(y))
175 changes: 175 additions & 0 deletions src/lib/grad.jl
@@ -38,3 +38,178 @@ function Zygote._pullback(ctx::Zygote.AContext, ::typeof(checkpointed), f, xs...
end
return y, pullback_checkpointed
end

"""
hessian(f, x)
Construct the Hessian `∂²f/∂x²`, where `x` is a real number or an array,
and `f(x)` is a real number. When `x` is an array, the result is a matrix
`H[i,j] = ∂²f/∂x[i]∂x[j]`, using linear indexing `x[i]` even if the argument
is higher-dimensional.
This uses forward over reverse, ForwardDiff over Zygote, calling `hessian_dual(f, x)`.
See [`hessian_reverse`](@ref) for an all-Zygote alternative.
# Examples
```jldoctest; setup=:(using Zygote)
julia> hessian(x -> x[1]*x[2], randn(2))
2×2 Array{Float64,2}:
0.0 1.0
1.0 0.0
julia> hessian(x -> sum(x.^3), [1 2; 3 4]) # uses linear indexing of x
4×4 Array{$Int,2}:
6 0 0 0
0 18 0 0
0 0 12 0
0 0 0 24
julia> hessian(sin, pi/2)
-1.0
```
"""
hessian(f, x) = hessian_dual(f, x)

hessian_dual(f, x::AbstractArray) = forward_jacobian(x -> gradient(f, x)[1], x)[2]

hessian_dual(f, x::Number) = ForwardDiff.derivative(x -> gradient(f, x)[1], x)

"""
hessian_reverse(f, x)
This should be equivalent to [`hessian(f, x)`](@ref hessian),
but implemented using reverse over reverse mode, all Zygote.
(This is usually much slower, and more likely to find errors.)
"""
hessian_reverse(f, x::AbstractArray) = jacobian(x -> gradient(f, x)[1], x)[1]

hessian_reverse(f, x::Number) = gradient(x -> gradient(f, x)[1], x)[1]


"""
jacobian(f, args...) -> Tuple
For each array `a ∈ args` this returns a matrix with `Ja[k,i] = ∂y[k]/∂a[i]`
where `y = f(args...)` is usually a vector.
Arrays of higher dimension are treated like `vec(a)`, or `vec(y)` for output.
For scalar `x::Number ∈ args`, the result is a vector `Jx[k] = ∂y[k]/∂x`,
while for scalar `y` all results have just one row.
With any other argument type, no result is produced, even if [`gradient`](@ref) would work.
This reverse-mode Jacobian needs to evaluate the pullback once for each element of `y`.
Doing so is usually only efficient when `length(y)` is small compared to `length(a)`,
otherwise forward mode is likely to be better.
See also [`hessian`](@ref), [`hessian_reverse`](@ref).
# Examples
```jldoctest; setup=:(using Zygote)
julia> jacobian(a -> 100*a[1:3].^2, 1:7)[1] # first index (rows) is output
3×7 Array{$Int,2}:
200 0 0 0 0 0 0
0 400 0 0 0 0 0
0 0 600 0 0 0 0
julia> jacobian((a,x) -> a.^2 .* x, [1,2,3], 1) # scalar argument has vector jacobian
([2 0 0; 0 4 0; 0 0 6], [1, 4, 9])
julia> jacobian((a,d) -> prod(a, dims=d), [1 2; 3 4; 5 6], 2)
([2 0 … 0 0; 0 4 … 3 0; 0 0 … 0 5], [0, 0, 0])
```
!!! warning
For arguments of any type except `Number` & `AbstractArray`, the result is `nothing`.
```jldoctest; setup=:(using Zygote)
julia> jacobian((a,s) -> a.^length(s), [1,2,3], "str")
([3 0 0; 0 12 0; 0 0 27], nothing)
julia> jacobian((a,t) -> sum(a .* t[1]) + t[2], [1,2,3], (4,5))
([4 4 4], nothing)
julia> gradient((a,t) -> sum(a .* t[1]) + t[2], [1,2,3], (4,5)) # gradient understands the tuple
([4, 4, 4], (6, 1))
```
"""
function jacobian(f, args...)
y, back = pullback(_jvec∘f, args...)
out = map(args) do x
T = promote_type(eltype(x), eltype(y))
dx = x isa AbstractArray ? similar(x, T, length(y), length(x)) :
x isa Number ? similar(y, T, length(y)) :
nothing
end
delta = _eyelike(y)
for k in LinearIndices(y)
grads = back(delta[:,k])
for (dx, grad) in zip(out, grads)
dx isa AbstractArray || continue
_gradcopy!(view(dx,k,:), grad)
end
end
out
end

_jvec(x::AbstractArray) = vec(x)
_jvec(x::Number) = _jvec(vcat(x))
_jvec(x) = throw(ArgumentError("jacobian expected a function which returns an array, or a scalar, got $(typeof(x))"))
_jvec(x::AbstractArray{<:Complex}) = throw(ArgumentError("jacobian does not accept complex output"))

_eyelike(y::Vector) = Matrix{eltype(y)}(I, length(y), length(y))
function _eyelike(y::AbstractVector) # version which works on GPU
out = fill!(similar(y, length(y), length(y)), 0)
out[LinearAlgebra.diagind(out)] .= 1
out
end

_gradcopy!(dst::AbstractArray, src::AbstractArray{<:Number}) = copyto!(dst, src)
_gradcopy!(dst::AbstractArray, src::Number) = copyto!(dst, src)
_gradcopy!(dst::AbstractArray, src::Nothing) = dst .= 0
_gradcopy!(dst::AbstractArray, src::AbstractArray) = copyto!(dst, (g isa Number ? g : 0 for g in src)) # e.g. Union{Nothing,Float64}

"""
jacobian(loss, ::Params)
Like [`gradient`](@ref) with implicit parameters, this method takes a zero-argument function
and returns an `IdDict`-like object, now containing the Jacobian for each parameter.
# Examples
```jldoctest; setup=:(using Zygote)
julia> xs = [1 2; 3 4]; ys = [5,7,9];
julia> Jxy = jacobian(() -> ys[1:2] .+ sum(xs.^2), Params([xs, ys]))
Grads(...)
julia> Jxy[ys]
2×3 Array{$Int,2}:
1 0 0
0 1 0
julia> Jxy[xs]
2×4 Array{$Int,2}:
2 6 4 8
2 6 4 8
```
"""
function jacobian(f, pars::Params)
y, back = pullback(_jvec∘f, pars)
out = IdDict()
for p in pars
T = Base.promote_type(eltype(p), eltype(y))
J = similar(y, T, length(y), length(p))
out[p] = J
end
delta = _eyelike(y)
for k in LinearIndices(y)
grads = back(delta[:,k])
for p in pars
out[p] isa AbstractArray || continue
_gradcopy!(view(out[p],k,:), grads[p])
end
end
Grads(out, pars)
end
26 changes: 7 additions & 19 deletions src/lib/utils.jl
@@ -31,11 +31,11 @@ ignore(f) = f()
Tell Zygote to ignore an expression. Equivalent to `ignore() do (...) end`.
Example:
```julia-repl
julia> f(x) = (y = Zygote.@ignore x; x * y);
julia> f'(1)
1
```
"""
macro ignore(ex)
return :(Zygote.ignore() do
@@ -99,19 +99,6 @@ macro showgrad(x)
end)
end

"""
hessian(f, x)
Construct the Hessian of `f`, where `x` is a real or real array and `f(x)` is
a real.
julia> hessian(((a, b),) -> a*b, [2, 3])
2×2 Array{Int64,2}:
0 1
1 0
"""
hessian(f, x::AbstractArray) = forward_jacobian(x -> gradient(f, x)[1], x)[2]

"""
isderiving()
isderiving(x)
@@ -122,18 +122,19 @@ Check whether the current function call is happening while taking the derivative
julia> function f(x)
@show isderiving()
end
f (generic function with 1 method)
julia> f(3)
isderiving() = false
false
julia> gradient(f, 4)
isderiving() = true
(nothing,)
"""
isderiving() = false
isderiving(x) = false

@adjoint isderiving() = true, _ -> nothing
@adjoint isderiving(x) = true, x -> (nothing,)
14 changes: 14 additions & 0 deletions test/cuda.jl
@@ -15,3 +15,17 @@ end
log_grada = cu(Float32[1.0, 0.5, 0.33333334, 0.25, 0.2, 0.16666667, 0.14285715, 0.125, 0.11111111])
@test gradient(x -> w(x) |> sum, a) == (log_grada,)
end

@testset "jacobian" begin
v1 = cu(collect(1:3f0))

res1 = jacobian(x -> x .* x', 1:3f0)[1]
j1 = jacobian(x -> x .* x', v1)[1]
@test j1 isa CuArray
@test j1 ≈ cu(res1)

res2 = jacobian(x -> x ./ sum(x), 1:3f0)[1]
j2 = jacobian(() -> v1 ./ sum(v1), Params([v1]))
@test j2[v1] isa CuArray
@test j2[v1] ≈ cu(res2)
end
21 changes: 12 additions & 9 deletions test/runtests.jl
@@ -2,15 +2,26 @@ using Zygote, Test
using Zygote: gradient
using CUDA: has_cuda

if has_cuda()
@testset "CUDA tests" begin
include("cuda.jl")
end
else
@warn "CUDA not found - Skipping CUDA Tests"
end

@testset "Interface" begin
include("interface.jl")
end


@testset "Tools" begin
include("tools.jl")
end

@testset "Utils" begin
include("utils.jl")
end

@testset "lib/number" begin
include("lib/number.jl")
end
@@ -42,11 +53,3 @@
@testset "Compiler" begin
include("compiler.jl")
end

if has_cuda()
@testset "CUDA tests" begin
include("cuda.jl")
end
else
@warn "CUDA not found - Skipping CUDA Tests"
end
