Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom gradient macro proof of concept #123

Merged
merged 8 commits into from
Apr 12, 2020

Conversation

mohamed82008
Copy link
Member

@mohamed82008 mohamed82008 commented Mar 24, 2020

This PR implements a proof of concept for an @grad macro that enables users to define custom gradients for some functions in a user-friendly syntax, identical to Tracker's syntax. Please let me know if I made some obvious mistakes, this is my first contribution to ReverseDiff :)

Here is a usage example that is possible with this PR:

julia> using ReverseDiff, LinearAlgebra
[ Info: Precompiling ReverseDiff [37e2e3b7-166d-5795-8a7a-e32c996b4267]

julia> f(x) = dot(x, x)
f (generic function with 1 method)

julia> f(x::ReverseDiff.TrackedVector) = ReverseDiff.track(f, x)
f (generic function with 2 methods)

julia> ReverseDiff.@grad function f(x)
           println("The custom gradient has been used.")
           xv = ReverseDiff.value(x)
           return dot(xv, xv), ∇ -> (∇ * 2 * xv,)
       end

julia> x = rand(3);

julia> ReverseDiff.gradient(x -> dot(x, x), x)
3-element Array{Float64,1}:
 1.5585250070434835
 1.8206637829549495
 0.32560088933194864

julia> ReverseDiff.gradient(f, x)
The custom gradient has been used.
3-element Array{Float64,1}:
 1.5585250070434835
 1.8206637829549495
 0.32560088933194864

@mohamed82008
Copy link
Member Author

mohamed82008 commented Mar 24, 2020

I will add some tests once the general approach is approved.

Edit: added tests

@AStupidBear
Copy link

@mohamed82008 I want to do turing_inference (DiffEqBayes.jl) with ReverseDiff backend because it's considered the fastest. I need to define custom gradient for concrete_solve

function DiffEqBase.concrete_solve(prob::DiffEqBase.DEProblem, alg::DiffEqBase.DEAlgorithm, u0, p::ReverseDiff.TrackedArray, args...; kwargs...)
    ReverseDiff.track(concrete_solve, prob, alg, u0, p, args...; kwargs...)
end

function DiffEqBase.concrete_solve(prob::DiffEqBase.DEProblem, alg::DiffEqBase.DEAlgorithm, u0, p::Array{<:ReverseDiff.TrackedReal}, args...; kwargs...)
    # need to defined ReverseDiff.collect just like Tracker.collect
    ReverseDiff.track(concrete_solve, prob, alg, ReverseDiff.collect(u0), ReverseDiff.collect(u0), args...; kwargs...)
end

ReverseDiff.@grad function concrete_solve(prob, alg, u0, p::ReverseDiff.TrackedArray, args...; kwargs...)
    y, back = DiffEqBase._concrete_solve_adjoint(prob, alg, nothing, ReverseDiff.value(u0), ReverseDiff.value(p), args...; kwargs...)
    return y, back
    isnothing(save_idxs) && return y, back
    y[save_idxs, :], function (Δ)
        Δ′ = zero(ReverseDiff.value(y))
        Δ′[save_idxs, :] .= ReverseDiff.value(Δ)
        return back(Δ′)
    end
end

However, this PR still cannot handle keyword arguments.

@mohamed82008
Copy link
Member Author

@AStupidBear ReverseDiff doesn't seem very actively developed so I have the macro in TuringLang/DistributionsAD.jl#58 hopefully temporarily. I will try to add kwarg support before merging.

src/macros.jl Outdated Show resolved Hide resolved
@mohamed82008
Copy link
Member Author

I believe this is ready for a second review.

@ChrisRackauckas
Copy link
Member

Let's get a check from @oxinabox as well.

@mohamed82008
Copy link
Member Author

@oxinabox can I merge this?

@mohamed82008
Copy link
Member Author

Ping @oxinabox

@mohamed82008
Copy link
Member Author

Given the lack of response and given that the tests are passing, I will merge now and release. This is an independent feature so if there are bugs, I can fix them in patch releases later. I need this for other developments in TuringLang, so sorry for my lack of patience.

@mohamed82008 mohamed82008 merged commit a7441ce into JuliaDiff:master Apr 12, 2020
@ChrisRackauckas
Copy link
Member

Sounds good to me. No worries!

@mohamed82008 mohamed82008 deleted the mt/custom_grad_syntax branch April 13, 2020 14:00
@mohamed82008 mohamed82008 restored the mt/custom_grad_syntax branch April 13, 2020 14:00
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants