Skip to content

Commit

Permalink
Update of #36 and #35 (#93)
Browse files Browse the repository at this point in the history
* Define value_and_pullback_function as returning value

* Update definition of Jacobian

* Update tests

* Update API definition

* Increment minor version number

* Make value_and_pullback_function the primitive

* Define pullback_function in terms of value_and_pullback_function

* Define value_and_pullback_function in tests

* Increment version number

* Use value_and_pb_function in CRC and Tracker extensions

* Fix bug in Tracker backend

* More updates

* Fix end

* Handle `nothing`

* Fix test failures

* Use named functions

---------

Co-authored-by: Seth Axen <seth.axen@gmail.com>
Co-authored-by: Guillaume Dalle <22795598+gdalle@users.noreply.github.com>
  • Loading branch information
3 people committed Sep 20, 2023
1 parent 00181f8 commit 3c18e86
Show file tree
Hide file tree
Showing 9 changed files with 92 additions and 73 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "AbstractDifferentiation"
uuid = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
authors = ["Mohamed Tarek <mohamed82008@gmail.com> and contributors"]
version = "0.5.3"
version = "0.6.0-DEV"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ This operation goes by a few names. Refer to the [ChainRules documentation](http

The following functions can be used to request the pullback operator/function with or without the function value. In order to request the pullback function `pb_f` of a function `f` at the inputs `xs`, you can use either of:
- `pb_f = AD.pullback_function(ab::AD.AbstractBackend, f, xs...)`: returns the pullback function `pb_f` of the function `f` at the inputs `xs`. `pb_f` is a function that accepts the co-tangents `vs` as input which is a tuple of length equal to the number of outputs of `f`. If `f` has a single output, `pb_f` can also accept a single input instead of a 1-tuple.
- `value_and_pb_f = AD.value_and_pullback_function(ab::AD.AbstractBackend, f, xs...)`: returns a function `value_and_pb_f` which accepts the co-tangent `vs` as input which is a tuple of length equal to the number of outputs of `f`. If `f` has a single output, `value_and_pb_f` can accept a single input instead of a 1-tuple. `value_and_pb_f` returns a 2-tuple, namely the value `f(xs...)` and output of the pullback operator.
- `v, pb_f = AD.value_and_pullback_function(ab::AD.AbstractBackend, f, xs...)`: computes the function value `v = f(xs...)` and returns a 2-tuple containing the value `v` and the pullback function `pb_f`. `pb_f` accepts the co-tangents `vs` as input, which is a tuple of length equal to the number of outputs of `f`. If `f` has a single output, `pb_f` can also accept a single input instead of a 1-tuple.

### Lazy operators

Expand Down
16 changes: 11 additions & 5 deletions ext/AbstractDifferentiationChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,17 @@ module AbstractDifferentiationChainRulesCoreExt
import AbstractDifferentiation as AD
using ChainRulesCore: ChainRulesCore

AD.@primitive function pullback_function(ba::AD.ReverseRuleConfigBackend, f, xs...)
_, back = ChainRulesCore.rrule_via_ad(AD.ruleconfig(ba), f, xs...)
pullback(vs) = Base.tail(back(vs))
pullback(vs::Tuple{Any}) = Base.tail(back(first(vs)))
return pullback
# Primitive for rule-config based reverse-mode backends.
#
# Calls `ChainRulesCore.rrule_via_ad` once with the backend's rule config and returns a
# 2-tuple `(primal, rrule_pullback)`: the primal output of `f(xs...)` and a closure that
# maps co-tangent(s) `vs` to the tangents with respect to `xs`.
AD.@primitive function value_and_pullback_function(ba::AD.ReverseRuleConfigBackend, f, xs...)
    primal, rrule_back = ChainRulesCore.rrule_via_ad(AD.ruleconfig(ba), f, xs...)
    function rrule_pullback(vs)
        # A co-tangent wrapped in a tuple for a non-tuple primal is unwrapped first;
        # `only` throws unless the tuple has exactly one element.
        unwrap = vs isa Tuple && !(primal isa Tuple)
        cotangent = unwrap ? only(vs) : vs
        # Drop the first entry of the result (by ChainRules convention the tangent
        # for `f` itself) so that only the tangents for `xs` are returned.
        return Base.tail(rrule_back(cotangent))
    end
    return primal, rrule_pullback
end

end # module
13 changes: 13 additions & 0 deletions ext/AbstractDifferentiationFiniteDifferencesExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ function AD.jacobian(ba::AD.FiniteDifferencesBackend, f, xs...)
return FiniteDifferences.jacobian(ba.method, f, xs...)
end

# Gradient for the FiniteDifferences backend: forwards directly to
# `FiniteDifferences.grad` using the finite-difference method stored in `ba.method`.
AD.gradient(ba::AD.FiniteDifferencesBackend, f, xs...) =
    FiniteDifferences.grad(ba.method, f, xs...)

function AD.pushforward_function(ba::AD.FiniteDifferencesBackend, f, xs...)
return function pushforward(vs)
ws = FiniteDifferences.jvp(ba.method, f, tuple.(xs, vs)...)
Expand All @@ -32,6 +36,15 @@ function AD.pullback_function(ba::AD.FiniteDifferencesBackend, f, xs...)
end
end

# Ensure consistency with `value_and_pullback` function
#
# Returns a 2-tuple: the primal value `f(xs...)` and a closure `fd_pullback` that, given
# co-tangent(s) `vs`, evaluates `FiniteDifferences.j′vp` with the backend's stored method.
function AD.value_and_pullback_function(ba::AD.FiniteDifferencesBackend, f, xs...)
    # `vs` is forwarded to `j′vp` unchanged (no tuple unwrapping is done here).
    fd_pullback(vs) = FiniteDifferences.j′vp(ba.method, f, vs, xs...)
    return f(xs...), fd_pullback
end

# Better performance: issue #87
function AD.derivative(ba::AD.FiniteDifferencesBackend, f::TF, x::Real) where {TF<:Function}
return (ba.method(f, x),)
Expand Down
16 changes: 9 additions & 7 deletions ext/AbstractDifferentiationTrackerExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,18 @@ AD.primal_value(x::Tracker.TrackedReal) = Tracker.data(x)
AD.primal_value(x::Tracker.TrackedArray) = Tracker.data(x)
AD.primal_value(x::AbstractArray{<:Tracker.TrackedReal}) = Tracker.data.(x)

AD.@primitive function pullback_function(ba::AD.TrackerBackend, f, xs...)
value, back = Tracker.forward(f, xs...)
function pullback(ws)
if ws isa Tuple && !(value isa Tuple)
map(Tracker.data, back(only(ws)))
AD.@primitive function value_and_pullback_function(ba::AD.TrackerBackend, f, xs...)
_value, back = Tracker.forward(f, xs...)
value = map(Tracker.data, _value)
function tracker_pullback(ws)
_ws = if ws isa Tuple && !(value isa Tuple)
only(ws)
else
map(Tracker.data, back(ws))
ws
end
return map(Tracker.data, back(_ws))
end
return pullback
return value, tracker_pullback
end

function AD.derivative(::AD.TrackerBackend, f, xs::Number...)
Expand Down
65 changes: 25 additions & 40 deletions src/AbstractDifferentiation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -221,48 +221,28 @@ end
end

function pullback_function(ab::AbstractBackend, f, xs...)
return (ws) -> begin
return gradient(lowest(ab), (xs...,) -> begin
vs = f(xs...)
if ws isa Tuple
@assert length(vs) == length(ws)
return sum(Base.splat(_dot), zip(ws, vs))
else
return _dot(vs, ws)
end
end, xs...)
end
_, pbf = value_and_pullback_function(ab, f, xs...)
return pbf
end
function value_and_pullback_function(
ab::AbstractBackend,
f,
xs...,
)
return (ws) -> begin
local value
primalcalled = false
if ab isa AbstractFiniteDifference
value = primal_value(ab, nothing, f, xs)
primalcalled = true
end
if ws === nothing
vs = f(xs...)
if !primalcalled
value = primal_value(lowest(ab), vs, f, xs)
primalcalled = true
end
return value, nothing
end
pb = pullback_function(lowest(ab), (_xs...,) -> begin
value = f(xs...)
function pullback_function(ws)
function pullback_gradient_function(_xs...)
vs = f(_xs...)
if !primalcalled
value = primal_value(lowest(ab), vs, f, xs)
primalcalled = true
if ws isa Tuple
@assert length(vs) == length(ws)
return sum(Base.splat(_dot), zip(ws, vs))
else
return _dot(vs, ws)
end
return vs
end, xs...)(ws)
return value, pb
end
return gradient(lowest(ab), pullback_gradient_function, xs...)
end
return value, pullback_function
end

struct LazyDerivative{B, F, X}
Expand Down Expand Up @@ -494,6 +474,12 @@ macro primitive(expr)
name = fdef[:name]
if name == :pushforward_function
return define_pushforward_function_and_friends(fdef) |> esc
elseif name == :value_and_pullback_function
return define_value_and_pullback_function_and_friends(fdef) |> esc
elseif name == :jacobian
return define_jacobian_and_friends(fdef) |> esc
elseif name == :primal_value
return define_primal_value(fdef) |> esc
elseif name == :pullback_function
return define_pullback_function_and_friends(fdef) |> esc
else
Expand Down Expand Up @@ -537,30 +523,29 @@ function define_pushforward_function_and_friends(fdef)
return funcs
end

function define_pullback_function_and_friends(fdef)
fdef[:name] = :($(AbstractDifferentiation).pullback_function)
function define_value_and_pullback_function_and_friends(fdef)
fdef[:name] = :($(AbstractDifferentiation).value_and_pullback_function)
args = fdef[:args]
funcs = quote
$(ExprTools.combinedef(fdef))
function $(AbstractDifferentiation).jacobian($(args...),)
value_and_pbf = $(value_and_pullback_function)($(args...),)
value, _ = value_and_pbf(nothing)
value, pbf = $(value_and_pullback_function)($(args...),)
identity_like = $(identity_matrix_like)(value)
if eltype(identity_like) <: Tuple{Vararg{AbstractMatrix}}
return map(identity_like) do identity_like_i
return mapreduce(vcat, $(_eachcol).(identity_like_i)...) do (cols...)
value_and_pbf(cols)[2]'
pbf(cols)'
end
end
elseif eltype(identity_like) <: AbstractMatrix
# needed for Hessian computation:
# value is a (grad,). Then, identity_like is a (matrix,).
# cols loops over columns of the matrix
return vcat.(mapslices(identity_like[1], dims=1) do cols
adjoint.(value_and_pbf((cols,))[2])
adjoint.(pbf((cols,)))
end ...)
else
return adjoint.(value_and_pbf(identity_like)[2])
return adjoint.(pbf(identity_like))
end
end
end
Expand Down
29 changes: 13 additions & 16 deletions test/defaults.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,14 @@ struct FDMBackend3{A} <: AD.AbstractFiniteDifference
end
FDMBackend3() = FDMBackend3(central_fdm(5, 1))
const fdm_backend3 = FDMBackend3()
AD.@primitive function pullback_function(ab::FDMBackend3, f, xs...)
return function (vs)
AD.@primitive function value_and_pullback_function(ab::FDMBackend3, f, xs...)
value = f(xs...)
function fd3_pullback(vs)
# Supports only single output
if vs isa AbstractVector
return FDM.j′vp(ab.alg, f, vs, xs...)
else
return FDM.j′vp(ab.alg, f, only(vs), xs...)
end
_vs = vs isa AbstractVector ? vs : only(vs)
return FDM.j′vp(ab.alg, f, _vs, xs...)
end
return value, fd3_pullback
end
##

Expand Down Expand Up @@ -90,16 +89,14 @@ AD.primal_value(::ForwardDiffBackend2, ::Any, f, xs) = ForwardDiff.value.(f(xs..
## Zygote
struct ZygoteBackend1 <: AD.AbstractReverseMode end
const zygote_backend1 = ZygoteBackend1()
AD.@primitive function pullback_function(ab::ZygoteBackend1, f, xs...)
return function (vs)
# Supports only single output
_, back = Zygote.pullback(f, xs...)
if vs isa AbstractVector
back(vs)
else
back(only(vs))
end
# Test backend primitive built on `Zygote.pullback`: the primal and the Zygote pullback
# are computed once, and a closure over that pullback is returned alongside the primal.
AD.@primitive function value_and_pullback_function(ab::ZygoteBackend1, f, xs...)
    # Supports only single output
    primal, zygote_back = Zygote.pullback(f, xs...)
    function zygote_pullback(vs)
        # Vectors are passed through as the co-tangent itself ...
        if vs isa AbstractVector
            return zygote_back(vs)
        end
        # ... anything else must be a single-element collection wrapping it.
        return zygote_back(only(vs))
    end
    return primal, zygote_pullback
end

@testset "defaults" begin
Expand Down
13 changes: 13 additions & 0 deletions test/ruleconfig.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using AbstractDifferentiation
using ChainRulesCore
using Test
using Zygote

Expand Down Expand Up @@ -52,4 +53,16 @@ using Zygote
end
@test AD.jacobian(ad, f, [1, 2, 3], 3) == ([6.0 0.0 0.0; 0.0 6.0 0.0; 0.0 0.0 6.0], [2.0, 4.0, 6.0])
end

# issue #57
# Checks that differentiating a function that has a custom `rrule` does not execute the
# function body itself: the body logs a message, and both tests require that no log
# record is emitted.
@testset "primal computation in rrule" begin
# `myfunc` is the identity, but logs at info level every time its body actually runs.
function myfunc(x)
@info "This should not be logged if I have an rrule"
x
end
# Custom rule: returns the primal `x` directly (without calling `myfunc`) together with
# a pullback that passes the co-tangent `y` through unchanged.
ChainRulesCore.rrule(::typeof(myfunc), x) = (x, (y -> (NoTangent(), y)))

# `@test_logs` without any log patterns passes only if the expression emits no logs.
@test_logs Zygote.gradient(myfunc, 1) # nothing is logged
@test_logs AD.derivative(AD.ZygoteBackend(), myfunc, 1) # nothing is logged
end
end
9 changes: 6 additions & 3 deletions test/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,8 @@ function test_j′vp(backend; multiple_inputs=true, rng=Random.GLOBAL_RNG, test_
w = rand(rng, length(fjac(xvec, yvec)))
if multiple_inputs
pb1 = AD.pullback_function(backend, fjac, xvec, yvec)(w)
valvec, pb2 = AD.value_and_pullback_function(backend, fjac, xvec, yvec)(w)
valvec, pbf2 = AD.value_and_pullback_function(backend, fjac, xvec, yvec)
pb2 = pbf2(w)

if test_types
@test valvec isa Vector{Float64}
Expand All @@ -263,8 +264,10 @@ function test_j′vp(backend; multiple_inputs=true, rng=Random.GLOBAL_RNG, test_
@test yvec == yvec2
end

valvec1, pb1 = AD.value_and_pullback_function(backend, x -> fjac(x, yvec), xvec)(w)
valvec2, pb2 = AD.value_and_pullback_function(backend, y -> fjac(xvec, y), yvec)(w)
valvec1, pbf1 = AD.value_and_pullback_function(backend, x -> fjac(x, yvec), xvec)
pb1 = pbf1(w)
valvec2, pbf2 = AD.value_and_pullback_function(backend, y -> fjac(xvec, y), yvec)
pb2 = pbf2(w)
if test_types
@test valvec1 isa Vector{Float64}
@test valvec2 isa Vector{Float64}
Expand Down

0 comments on commit 3c18e86

Please sign in to comment.