Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update of #36 and #35 #93

Merged
merged 20 commits into from
Sep 20, 2023
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "AbstractDifferentiation"
uuid = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
authors = ["Mohamed Tarek <mohamed82008@gmail.com> and contributors"]
version = "0.5.2"
version = "0.6.0-DEV"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ This operation goes by a few names. Refer to the [ChainRules documentation](http

The following functions can be used to request the pullback operator/function with or without the function value. In order to request the pullback function `pb_f` of a function `f` at the inputs `xs`, you can use either of:
- `pb_f = AD.pullback_function(ab::AD.AbstractBackend, f, xs...)`: returns the pullback function `pb_f` of the function `f` at the inputs `xs`. `pb_f` is a function that accepts the co-tangents `vs` as input which is a tuple of length equal to the number of outputs of `f`. If `f` has a single output, `pb_f` can also accept a single input instead of a 1-tuple.
- `value_and_pb_f = AD.value_and_pullback_function(ab::AD.AbstractBackend, f, xs...)`: returns a function `value_and_pb_f` which accepts the co-tangent `vs` as input which is a tuple of length equal to the number of outputs of `f`. If `f` has a single output, `value_and_pb_f` can accept a single input instead of a 1-tuple. `value_and_pb_f` returns a 2-tuple, namely the value `f(xs...)` and output of the pullback operator.
- `value_and_pb_f = AD.value_and_pullback_function(ab::AD.AbstractBackend, f, xs...)`: computes the function value `v = f(xs...)` and returns a 2-tuple containing the value `v` and a function `pb_f` that accepts the co-tangents `vs` as input, which is a tuple of length equal to the number of outputs of `f`. If `f` has a single output, `pb_f` can accept a single input instead of a 1-tuple.

### Lazy operators

Expand Down
16 changes: 11 additions & 5 deletions ext/AbstractDifferentiationChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,17 @@ module AbstractDifferentiationChainRulesCoreExt
import AbstractDifferentiation as AD
using ChainRulesCore: ChainRulesCore

AD.@primitive function pullback_function(ba::AD.ReverseRuleConfigBackend, f, xs...)
_, back = ChainRulesCore.rrule_via_ad(AD.ruleconfig(ba), f, xs...)
pullback(vs) = Base.tail(back(vs))
pullback(vs::Tuple{Any}) = Base.tail(back(first(vs)))
return pullback
AD.@primitive function value_and_pullback_function(ba::AD.ReverseRuleConfigBackend, f, xs...)
    # Build the primal value and the ChainRules pullback in one shot via the
    # backend's rule config.
    value, back = ChainRulesCore.rrule_via_ad(AD.ruleconfig(ba), f, xs...)
    function rrule_pullback(vs)
        # A single-output function may receive its cotangent wrapped in a
        # 1-tuple; unwrap it so `back` sees the bare cotangent.
        cotangent = (vs isa Tuple && !(value isa Tuple)) ? only(vs) : vs
        # Drop the cotangent w.r.t. `f` itself; keep only those for `xs`.
        return Base.tail(back(cotangent))
    end
    return value, rrule_pullback
end

end # module
13 changes: 13 additions & 0 deletions ext/AbstractDifferentiationFiniteDifferencesExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ AD.@primitive function jacobian(ba::AD.FiniteDifferencesBackend, f, xs...)
return FiniteDifferences.jacobian(ba.method, f, xs...)
end

# Delegate directly to `FiniteDifferences.grad`, which evaluates the gradient
# with the backend's configured finite-difference method.
function AD.gradient(ba::AD.FiniteDifferencesBackend, f, xs...)
    return FiniteDifferences.grad(ba.method, f, xs...)
end

function AD.pushforward_function(ba::AD.FiniteDifferencesBackend, f, xs...)
return function pushforward(vs)
ws = FiniteDifferences.jvp(ba.method, f, tuple.(xs, vs)...)
Expand All @@ -32,4 +36,13 @@ function AD.pullback_function(ba::AD.FiniteDifferencesBackend, f, xs...)
end
end

# Mirror the semantics of `value_and_pullback`: evaluate the primal eagerly,
# then return it together with a pullback closure based on `j′vp`.
function AD.value_and_pullback_function(ba::AD.FiniteDifferencesBackend, f, xs...)
    primal = f(xs...)
    fd_pullback(vs) = FiniteDifferences.j′vp(ba.method, f, vs, xs...)
    return primal, fd_pullback
end

end # module
16 changes: 9 additions & 7 deletions ext/AbstractDifferentiationTrackerExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,18 @@ AD.primal_value(x::Tracker.TrackedReal) = Tracker.data(x)
AD.primal_value(x::Tracker.TrackedArray) = Tracker.data(x)
AD.primal_value(x::AbstractArray{<:Tracker.TrackedReal}) = Tracker.data.(x)

AD.@primitive function pullback_function(ba::AD.TrackerBackend, f, xs...)
value, back = Tracker.forward(f, xs...)
function pullback(ws)
if ws isa Tuple && !(value isa Tuple)
map(Tracker.data, back(only(ws)))
AD.@primitive function value_and_pullback_function(ba::AD.TrackerBackend, f, xs...)
_value, back = Tracker.forward(f, xs...)
value = map(Tracker.data, _value)
function tracker_pullback(ws)
_ws = if ws isa Tuple && !(value isa Tuple)
only(ws)
else
map(Tracker.data, back(ws))
ws
end
return map(Tracker.data, back(_ws))
end
return pullback
return value, tracker_pullback
end

function AD.derivative(::AD.TrackerBackend, f, xs::Number...)
Expand Down
62 changes: 20 additions & 42 deletions src/AbstractDifferentiation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@
function value_and_hessian(ab::AbstractBackend, f, x)
if x isa Tuple
# only support computation of Hessian for functions with single input argument
x = only(x)

Check warning on line 89 in src/AbstractDifferentiation.jl

View check run for this annotation

Codecov / codecov/patch

src/AbstractDifferentiation.jl#L89

Added line #L89 was not covered by tests
end

local value
Expand All @@ -108,7 +108,7 @@
function value_and_hessian(ab::HigherOrderBackend, f, x)
if x isa Tuple
# only support computation of Hessian for functions with single input argument
x = only(x)

Check warning on line 111 in src/AbstractDifferentiation.jl

View check run for this annotation

Codecov / codecov/patch

src/AbstractDifferentiation.jl#L111

Added line #L111 was not covered by tests
end
local value
primalcalled = false
Expand All @@ -125,7 +125,7 @@
function value_gradient_and_hessian(ab::AbstractBackend, f, x)
if x isa Tuple
# only support computation of Hessian for functions with single input argument
x = only(x)

Check warning on line 128 in src/AbstractDifferentiation.jl

View check run for this annotation

Codecov / codecov/patch

src/AbstractDifferentiation.jl#L128

Added line #L128 was not covered by tests
end
local value
primalcalled = false
Expand All @@ -142,7 +142,7 @@
function value_gradient_and_hessian(ab::HigherOrderBackend, f, x)
if x isa Tuple
# only support computation of Hessian for functions with single input argument
x = only(x)

Check warning on line 145 in src/AbstractDifferentiation.jl

View check run for this annotation

Codecov / codecov/patch

src/AbstractDifferentiation.jl#L145

Added line #L145 was not covered by tests
end
local value
primalcalled = false
Expand All @@ -169,7 +169,7 @@
newxs = xs .+ ds .* xds
return f(newxs...)
else
newx = only(xs) + ds * only(xds)

Check warning on line 172 in src/AbstractDifferentiation.jl

View check run for this annotation

Codecov / codecov/patch

src/AbstractDifferentiation.jl#L172

Added line #L172 was not covered by tests
return f(newx)
end
end, _zero.(xs, ds)...)
Expand Down Expand Up @@ -213,7 +213,7 @@

@inline _dot(x, y) = dot(x, y)
@inline function _dot(x::AbstractVector, y::UniformScaling)
return @inbounds dot(only(x), y.λ)

Check warning on line 216 in src/AbstractDifferentiation.jl

View check run for this annotation

Codecov / codecov/patch

src/AbstractDifferentiation.jl#L216

Added line #L216 was not covered by tests
end
@inline function _dot(x::AbstractVector, y::AbstractMatrix)
@assert size(y, 2) == 1
Expand All @@ -221,48 +221,27 @@
end

function pullback_function(ab::AbstractBackend, f, xs...)
return (ws) -> begin
return gradient(lowest(ab), (xs...,) -> begin
vs = f(xs...)
if ws isa Tuple
@assert length(vs) == length(ws)
return sum(Base.splat(_dot), zip(ws, vs))
else
return _dot(vs, ws)
end
end, xs...)
end
_, pbf = value_and_pullback_function(ab, f, xs...)
return pbf
end
function value_and_pullback_function(
ab::AbstractBackend,
f,
xs...,
)
return (ws) -> begin
local value
primalcalled = false
if ab isa AbstractFiniteDifference
value = primal_value(ab, nothing, f, xs)
primalcalled = true
end
if ws === nothing
vs = f(xs...)
if !primalcalled
value = primal_value(lowest(ab), vs, f, xs)
primalcalled = true
end
return value, nothing
end
pb = pullback_function(lowest(ab), (_xs...,) -> begin
value = f(xs...)
pbf = ws -> begin
gdalle marked this conversation as resolved.
Show resolved Hide resolved
return gradient(lowest(ab), (_xs...,) -> begin
vs = f(_xs...)
if !primalcalled
value = primal_value(lowest(ab), vs, f, xs)
primalcalled = true
if ws isa Tuple
@assert length(vs) == length(ws)
gdalle marked this conversation as resolved.
Show resolved Hide resolved
return sum(Base.splat(_dot), zip(ws, vs))

Check warning on line 238 in src/AbstractDifferentiation.jl

View check run for this annotation

Codecov / codecov/patch

src/AbstractDifferentiation.jl#L237-L238

Added lines #L237 - L238 were not covered by tests
else
return _dot(vs, ws)
end
return vs
end, xs...)(ws)
return value, pb
end, xs...)
end
return value, pbf
end

struct LazyDerivative{B, F, X}
Expand Down Expand Up @@ -494,8 +473,8 @@
name = fdef[:name]
if name == :pushforward_function
return define_pushforward_function_and_friends(fdef) |> esc
elseif name == :pullback_function
return define_pullback_function_and_friends(fdef) |> esc
elseif name == :value_and_pullback_function
return define_value_and_pullback_function_and_friends(fdef) |> esc
elseif name == :jacobian
return define_jacobian_and_friends(fdef) |> esc
elseif name == :primal_value
Expand Down Expand Up @@ -541,30 +520,29 @@
return funcs
end

function define_pullback_function_and_friends(fdef)
fdef[:name] = :($(AbstractDifferentiation).pullback_function)
function define_value_and_pullback_function_and_friends(fdef)
fdef[:name] = :($(AbstractDifferentiation).value_and_pullback_function)
args = fdef[:args]
funcs = quote
$(ExprTools.combinedef(fdef))
function $(AbstractDifferentiation).jacobian($(args...),)
value_and_pbf = $(value_and_pullback_function)($(args...),)
value, _ = value_and_pbf(nothing)
value, pbf = $(value_and_pullback_function)($(args...),)
identity_like = $(identity_matrix_like)(value)
if eltype(identity_like) <: Tuple{Vararg{AbstractMatrix}}
return map(identity_like) do identity_like_i
return mapreduce(vcat, $(_eachcol).(identity_like_i)...) do (cols...)
value_and_pbf(cols)[2]'
pbf(cols)'
end
end
elseif eltype(identity_like) <: AbstractMatrix
# needed for Hessian computation:
# value is a (grad,). Then, identity_like is a (matrix,).
# cols loops over columns of the matrix
return vcat.(mapslices(identity_like[1], dims=1) do cols
adjoint.(value_and_pbf((cols,))[2])
adjoint.(pbf((cols,)))
end ...)
else
return adjoint.(value_and_pbf(identity_like)[2])
return adjoint.(pbf(identity_like))
end
end
end
Expand All @@ -580,7 +558,7 @@
end

function define_primal_value(fdef)
fdef[:name] = :($(AbstractDifferentiation).primal_value)

Check warning on line 561 in src/AbstractDifferentiation.jl

View check run for this annotation

Codecov / codecov/patch

src/AbstractDifferentiation.jl#L561

Added line #L561 was not covered by tests
return ExprTools.combinedef(fdef)
end

Expand Down Expand Up @@ -629,14 +607,14 @@
include("../ext/AbstractDifferentiationChainRulesCoreExt.jl")
end
@static if !EXTENSIONS_SUPPORTED
function __init__()
@require DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" begin
@require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include("../ext/AbstractDifferentiationForwardDiffExt.jl")
@require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include("../ext/AbstractDifferentiationReverseDiffExt.jl")

Check warning on line 613 in src/AbstractDifferentiation.jl

View check run for this annotation

Codecov / codecov/patch

src/AbstractDifferentiation.jl#L610-L613

Added lines #L610 - L613 were not covered by tests
end
@require FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" include("../ext/AbstractDifferentiationFiniteDifferencesExt.jl")
@require Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" include("../ext/AbstractDifferentiationTrackerExt.jl")
@require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" include("../ext/AbstractDifferentiationZygoteExt.jl")

Check warning on line 617 in src/AbstractDifferentiation.jl

View check run for this annotation

Codecov / codecov/patch

src/AbstractDifferentiation.jl#L615-L617

Added lines #L615 - L617 were not covered by tests
end
end

Expand Down
29 changes: 13 additions & 16 deletions test/defaults.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,14 @@ struct FDMBackend3{A} <: AD.AbstractFiniteDifference
end
FDMBackend3() = FDMBackend3(central_fdm(5, 1))
const fdm_backend3 = FDMBackend3()
AD.@primitive function pullback_function(ab::FDMBackend3, f, xs...)
return function (vs)
AD.@primitive function value_and_pullback_function(ab::FDMBackend3, f, xs...)
value = f(xs...)
function fd3_pullback(vs)
# Supports only single output
if vs isa AbstractVector
return FDM.j′vp(ab.alg, f, vs, xs...)
else
return FDM.j′vp(ab.alg, f, only(vs), xs...)
end
_vs = vs isa AbstractVector ? vs : only(vs)
return FDM.j′vp(ab.alg, f, _vs, xs...)
end
return value, fd3_pullback
end
##

Expand Down Expand Up @@ -90,16 +89,14 @@ AD.primal_value(::ForwardDiffBackend2, ::Any, f, xs) = ForwardDiff.value.(f(xs..
## Zygote
struct ZygoteBackend1 <: AD.AbstractReverseMode end
const zygote_backend1 = ZygoteBackend1()
AD.@primitive function pullback_function(ab::ZygoteBackend1, f, xs...)
return function (vs)
# Supports only single output
_, back = Zygote.pullback(f, xs...)
if vs isa AbstractVector
back(vs)
else
back(only(vs))
end
AD.@primitive function value_and_pullback_function(ab::ZygoteBackend1, f, xs...)
# Supports only single output
value, back = Zygote.pullback(f, xs...)
function zygote_pullback(vs)
_vs = vs isa AbstractVector ? vs : only(vs)
return back(_vs)
end
return value, zygote_pullback
end

@testset "defaults" begin
Expand Down
13 changes: 13 additions & 0 deletions test/ruleconfig.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using AbstractDifferentiation
using ChainRulesCore
using Test
using Zygote

Expand Down Expand Up @@ -52,4 +53,16 @@ using Zygote
end
@test AD.jacobian(ad, f, [1, 2, 3], 3) == ([6.0 0.0 0.0; 0.0 6.0 0.0; 0.0 0.0 6.0], [2.0, 4.0, 6.0])
end

# issue #57
@testset "primal computation in rrule" begin
function myfunc(x)
@info "This should not be logged if I have an rrule"
x
end
ChainRulesCore.rrule(::typeof(myfunc), x) = (x, (y -> (NoTangent(), y)))

@test_logs Zygote.gradient(myfunc, 1) # nothing is logged
@test_logs AD.derivative(AD.ZygoteBackend(), myfunc, 1) # nothing is logged
end
end
9 changes: 6 additions & 3 deletions test/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,8 @@ function test_j′vp(backend; multiple_inputs=true, rng=Random.GLOBAL_RNG, test_
w = rand(rng, length(fjac(xvec, yvec)))
if multiple_inputs
pb1 = AD.pullback_function(backend, fjac, xvec, yvec)(w)
valvec, pb2 = AD.value_and_pullback_function(backend, fjac, xvec, yvec)(w)
valvec, pbf2 = AD.value_and_pullback_function(backend, fjac, xvec, yvec)
pb2 = pbf2(w)

if test_types
@test valvec isa Vector{Float64}
Expand All @@ -263,8 +264,10 @@ function test_j′vp(backend; multiple_inputs=true, rng=Random.GLOBAL_RNG, test_
@test yvec == yvec2
end

valvec1, pb1 = AD.value_and_pullback_function(backend, x -> fjac(x, yvec), xvec)(w)
valvec2, pb2 = AD.value_and_pullback_function(backend, y -> fjac(xvec, y), yvec)(w)
valvec1, pbf1 = AD.value_and_pullback_function(backend, x -> fjac(x, yvec), xvec)
pb1 = pbf1(w)
valvec2, pbf2 = AD.value_and_pullback_function(backend, y -> fjac(xvec, y), yvec)
pb2 = pbf2(w)
if test_types
@test valvec1 isa Vector{Float64}
@test valvec2 isa Vector{Float64}
Expand Down
Loading