Add adjoint for pairs(::NamedTuple) and Pair #452

Conversation
julia> @nograd pairs

julia> Zygote.gradient(()->sum(neural_ode(dudt,x,tspan,Tsit5(),save_everystep=false,save_start=false,sensealg=BacksolveAdjoint())),p)
Grads(...)
Thanks! FWIW, a minimal reproducer is

julia> foo(;kw...) = 1
foo (generic function with 1 method)

julia> gradient(() -> foo(a=1,b=2.0))
ERROR: Can't differentiate gc_preserve_end expression

(This would be good to have as a test case.) This works if all kwargs are of the same type; it would be great to understand why that is, to make sure we've got the whole issue fixed.
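To see why `pairs(::NamedTuple)` is the culprit here: Julia passes keyword arguments as a pairs-iterator over a `NamedTuple`, so differentiating any kwarg-taking call goes through `pairs`. A quick illustration of that lowering (plain Julia, no Zygote involved):

```julia
# Keyword arguments arrive in the callee as pairs(nt) for a NamedTuple nt,
# i.e. a Base.Pairs iterator of Symbol => value pairs.
nt = (a = 1, b = 2.0)
kw = pairs(nt)          # Base.Pairs wrapping the NamedTuple
collect(kw)             # [:a => 1, :b => 2.0]
```

Any adjoint (or `@nograd`) for `pairs` therefore covers every kwarg call site at once, which is why the reproducer above is so small.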
pairs(::NamedTuple)
and Pair
CI failure looks unrelated. @MikeInnes could you review again?
The patch looks good to me. I'd still like to understand what went wrong in detail before merging, but I can look into that myself if it's not already clear.

bors try
try: Build succeeded
@MikeInnes what should we do here?
bump |
related to #669, but of the tests in this PR

@test (x->10*pairs((a=x, b=2))[1])'(100) === 10
@test (x->10*pairs((a=x, b=2))[2])'(100) === 0
foo(;kw...) = 1
@test gradient(() -> foo(a=1,b=2.0)) === ()
@test (x->10*(x => 2)[1])'(100) === 10
@test (x->10*(x => 2)[2])'(100) === 0

only the 3rd passes on current master, so we should merge this as well
bors r+
Build succeeded: |
@@ -115,3 +115,11 @@ end
   end
   return pairs(t), pairs_namedtuple
 end

+@adjoint function Base.getfield(p::Pair, i::Int)
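The hunk above is truncated at the new `@adjoint` definition. A plausible completion, written as a hedged sketch rather than the merged code (the pullback name `pair_getfield_pullback` is invented here): route the cotangent to whichever field was read and leave the other as `nothing`.

```julia
using Zygote: @adjoint

# Sketch: adjoint for reading a field of a Pair by integer index.
# Pair fields are `first` (i == 1) and `second` (i == 2); the structural
# gradient for a Pair is a NamedTuple over those field names. The second
# `nothing` in the pullback's return is the (non-)gradient for the index i.
@adjoint function Base.getfield(p::Pair, i::Int)
    function pair_getfield_pullback(Δ)
        f, s = i == 1 ? (Δ, nothing) : (nothing, Δ)
        return ((first = f, second = s), nothing)
    end
    return getfield(p, i), pair_getfield_pullback
end
```

Since `getindex(p::Pair, i)` is defined in terms of `getfield`, this single adjoint is what makes the `(x => 2)[1]` and `(x => 2)[2]` tests in this PR differentiable.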
How do I write the adjoint method for LinearAlgebra.diagm of a Pair{Int64,Array{Float64,1}}?
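One way to approach that question, following the same `@adjoint` pattern as this PR (a hedged sketch, not library code — the pullback name and the NamedTuple cotangent shape are assumptions, and newer Zygote/ChainRules releases may already handle `diagm`): the pullback of `diagm(k => v)` just reads the k-th diagonal of the output cotangent back out with `diag`.

```julia
using LinearAlgebra
using Zygote: @adjoint

# Hedged sketch: adjoint for diagm on a single (offset => vector) Pair.
# Forward builds the banded matrix; the pullback extracts the matching
# diagonal of the cotangent. The integer offset gets no gradient (nothing).
@adjoint function LinearAlgebra.diagm(p::Pair{<:Integer,<:AbstractVector})
    k = p.first
    function diagm_pair_pullback(Δ)
        return ((first = nothing, second = diag(Δ, k)),)
    end
    return diagm(p), diagm_pair_pullback
end
```

With something like this in place, `gradient(v -> sum(diagm(1 => v)), v)` should return a vector of ones, since each entry of `v` appears exactly once in the output matrix.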
Fix SciML/DiffEqFlux.jl#105