Support Callable struct #144

Closed
wants to merge 2 commits into from

yiyuezhuo

Fix #135

Note that on the master branch (without this PR), the local unit tests already fail at:

running LinAlgTests...
  testing Array -> Number functions: `sum`...
  testing Array -> Number functions: `det`...
  testing Array -> Number functions: `mean`...
  testing Array -> Number functions: `#18`...
  testing Array -> Number functions: `#20`...
ERROR: LoadError: LoadError: MethodError: *(::LinearAlgebra.Adjoint{ReverseDiff.TrackedReal{Float64,Float64,ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}},ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}}, ::ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}) is ambiguous. Candidates:
  *(u::LinearAlgebra.Adjoint{var"#s826",var"#s8261"} where var"#s8261"<:(AbstractArray{T,1} where T) where var"#s826"<:Number, v::AbstractArray{var"#s825",1} where var"#s825"<:Number) in LinearAlgebra at /home/yiyuezhuo/Downloads/julia-1.5.0/share/julia/stdlib/v1.5/LinearAlgebra/src/adjtrans.jl:283
  *(A::LinearAlgebra.Adjoint{var"#s92",var"#s91"} where var"#s91"<:(ReverseDiff.TrackedArray{T,D,1,VA,DA} where DA where VA) where var"#s92", B::ReverseDiff.TrackedArray{T,D,1,VA,DA} where DA where VA) where {T<:Real, D} in ReverseDiff at /home/yiyuezhuo/.julia/dev/ReverseDiff/src/derivatives/linalg/arithmetic.jl:254
Possible fix, define
  *(::LinearAlgebra.Adjoint{var"#s92",var"#s91"} where var"#s91"<:(ReverseDiff.TrackedArray{T,D,1,VA,DA} where DA where VA) where var"#s92"<:Number, ::ReverseDiff.TrackedArray{T,D,1,VA,DA} where DA where VA) where {T<:Real, D}

This PR doesn't try to solve that problem, so I expect CI to fail. Locally, though, the test suite fails at the same location as on the master branch, and the added unit test CallableStructTests.jl passes:

using ReverseDiff, Test

# https://github.com/JuliaDiff/ReverseDiff.jl/issues/135

struct Over{T} den::T end

(o::Over)(x) = x ./ o.den

(o::Over)(x::ReverseDiff.TrackedArray) = ReverseDiff.track(o, x)

ReverseDiff.@grad function (o::Over)(x)
    # Deliberately wrong gradient (multiplies by den instead of dividing), so the test can
    # tell whether the result comes from the custom grad or from the "normal" Over call.
    ReverseDiff.value(x) ./ o.den, Δ -> (Δ .* o.den,)
end


struct Over2 end

(o::Over2)(x) = x ./ 2

(o::Over2)(x::ReverseDiff.TrackedArray) = ReverseDiff.track(o, x)

ReverseDiff.@grad function (o::Over2)(x)
    ReverseDiff.value(x) ./ 2, Δ -> (Δ .* 2,)
end

o3 = Over(3.)
o2 = Over2()

g3 = ReverseDiff.gradient([2., 1., 3.]) do x
    sum(o3(x))
end

g2 = ReverseDiff.gradient([2., 1., 3.]) do x
    sum(o2(x))
end

@test g3 == [3., 3., 3.]
@test g2 == [2., 2., 2.]

@yiyuezhuo
Author

The tests passed, so I think my local test failure is specific to Julia 1.5.

@mcabbott
Member

The test failure on Julia 1.5 looks identical on the tagged version of this package. It needs fixing, but it is orthogonal to this PR.

running LinAlgTests...
  testing Array -> Number functions: `sum`...
  testing Array -> Number functions: `det`...
  testing Array -> Number functions: `mean`...
  testing Array -> Number functions: `#18`...
  testing Array -> Number functions: `#20`...
ERROR: LoadError: LoadError: MethodError: *(::LinearAlgebra.Adjoint{ReverseDiff.TrackedReal{Float64,Float64,ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}},ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}}, ::ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}) is ambiguous. Candidates:
  *(u::LinearAlgebra.Adjoint{var"#s826",var"#s8261"} where var"#s8261"<:(AbstractArray{T,1} where T) where var"#s826"<:Number, v::AbstractArray{var"#s825",1} where var"#s825"<:Number) in LinearAlgebra at /Applications/Julia-1.5.app/Contents/Resources/julia/share/julia/stdlib/v1.5/LinearAlgebra/src/adjtrans.jl:283
  *(A::LinearAlgebra.Adjoint{var"#s92",var"#s91"} where var"#s91"<:(ReverseDiff.TrackedArray{T,D,1,VA,DA} where DA where VA) where var"#s92", B::ReverseDiff.TrackedArray{T,D,1,VA,DA} where DA where VA) where {T<:Real, D} in ReverseDiff at /Users/me/.julia/dev/ReverseDiff/src/derivatives/linalg/arithmetic.jl:254
Possible fix, define
  *(::LinearAlgebra.Adjoint{var"#s92",var"#s91"} where var"#s91"<:(ReverseDiff.TrackedArray{T,D,1,VA,DA} where DA where VA) where var"#s92"<:Number, ::ReverseDiff.TrackedArray{T,D,1,VA,DA} where DA where VA) where {T<:Real, D}
Stacktrace:
 [1] (::Main.LinAlgTests.var"#20#26")(::ReverseDiff.TrackedArray{Float64,Float64,2,Array{Float64,2},Array{Float64,2}}) at /Users/me/.julia/dev/ReverseDiff/test/derivatives/LinAlgTests.jl:215
 [2] test_arr2num(::Main.LinAlgTests.var"#20#26", ::Array{Float64,2}, ::Array{ReverseDiff.AbstractInstruction,1}; ignore_tape_length::Bool) at /Users/me/.julia/dev/ReverseDiff/test/derivatives/LinAlgTests.jl:15
 [3] top-level scope at /Users/me/.julia/dev/ReverseDiff/test/derivatives/LinAlgTests.jl:223
 [4] include(::String) at /Applications/Julia-1.5.app/Contents/Resources/julia/lib/julia/sys.dylib:?
 [5] top-level scope at ./timing.jl:233 [inlined]
 [6] top-level scope at /Users/me/.julia/dev/ReverseDiff/test/runtests.jl:0
 [7] include(::String) at /Applications/Julia-1.5.app/Contents/Resources/julia/lib/julia/sys.dylib:?
 [8] top-level scope at none:6
in expression starting at /Users/me/.julia/dev/ReverseDiff/test/derivatives/LinAlgTests.jl:214
in expression starting at /Users/me/.julia/dev/ReverseDiff/test/runtests.jl:22
ERROR: Package ReverseDiff errored during testing
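For readers unfamiliar with this class of error, here is a minimal sketch of a method ambiguity and the usual fix in plain Julia. It is a simplified analogue of the log above, not the actual ReverseDiff patch: `Wrapped`, `Special`, and `combine` are hypothetical stand-ins (for `Adjoint`, `TrackedArray`, and `*`). Each candidate is more specific in a different position, and the fix is a method on the intersection of the two signatures.

```julia
# Hypothetical stand-ins; none of these are ReverseDiff or LinearAlgebra types.
struct Wrapped{T}      # plays the role of LinearAlgebra.Adjoint
    parent::T
end

struct Special         # plays the role of ReverseDiff.TrackedArray
    data::Vector{Float64}
end

# Candidate 1: specific in the first argument, generic in the second.
combine(a::Wrapped, b) = :specific_first
# Candidate 2: generic in the first argument, specific in the second.
combine(a, b::Special) = :specific_second

w = Wrapped([1.0])
s = Special([1.0])

# Neither candidate is more specific for (Wrapped, Special), so the call is ambiguous.
ambiguous = try
    combine(w, s)
    false
catch err
    err isa MethodError
end

# The fix suggested by Julia: define a method on the intersection of both signatures.
combine(a::Wrapped, b::Special) = :intersection

resolved = combine(w, s)
```

In the real failure the intersection method is the one printed after "Possible fix, define": the ReverseDiff signature with the extra `<:Number` bound on the `Adjoint` element type.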

@mohamed82008
Member

The only problem I have with this PR is that it will probably give an error when the closure has a tracked field. Can you please add a test case like that?

@mohamed82008
Member

Hmm, I don't think there is currently a correct way to implement this generically. However, we already have ReverseDiff.NotTracked, which can wrap a struct that we know holds no tracked variables, allowing efficient broadcasting over constant structs and TrackedArrays. I think a similar trick can be used here: custom gradients would be defined only for callable structs wrapped in NotTracked. It is then the user's responsibility to wrap the closure in NotTracked at the call site, before calling it, so that it is treated as an AD primitive.
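The dispatch idea behind this suggestion can be sketched in plain Julia. This is only an illustration under stated assumptions, not ReverseDiff's implementation: `Constant` is a hypothetical wrapper playing the role of a NotTracked-style marker, `Tracked` is a hypothetical stand-in for a tracked array, and the returned symbols merely show which path a call took.

```julia
# Hypothetical stand-ins; none of these names are ReverseDiff API.
struct Tracked               # stands in for a tracked array
    value::Vector{Float64}
end

struct Constant{F}           # stands in for a NotTracked-style wrapper
    f::F
end

struct Over{T}
    den::T
end

# Ordinary call on plain arrays.
(o::Over)(x) = x ./ o.den

# Unwrapped callable on a tracked input: the AD system would trace through the
# body, which stays correct even if `o.den` is itself tracked.
(o::Over)(x::Tracked) = (:traced_through, x.value ./ o.den)

# Wrapped callable: the caller promises that `c.f` holds no tracked fields, so
# the call can be treated as a primitive with a user-supplied gradient rule.
(c::Constant{<:Over})(x::Tracked) = (:primitive, x.value ./ c.f.den)

o = Over(3.0)
x = Tracked([3.0, 6.0])

path_plain, y_plain = o(x)                 # generic path
path_wrapped, y_wrapped = Constant(o)(x)   # opt-in primitive path
```

The design point is that the custom-gradient path is opt-in at the call site, so a struct whose fields might be tracked never silently loses gradient information.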

Member

@mohamed82008 mohamed82008 left a comment

Please see the comments above.

@yiyuezhuo
Author

Sorry for the delay; I had a deadline last week.

@mohamed82008, I'm not sure what "closure has a tracked field" means exactly. Do you mean that a field of the struct is tracked?

@yiyuezhuo
Author

Closing to encourage better solutions.

Successfully merging this pull request may close these issues.

Custom gradient for callable struct
3 participants