-
-
Notifications
You must be signed in to change notification settings - Fork 216
Closed
Description
This is more or less a copy of FluxML/Flux.jl#129, adapted to Zygote in order to track progress on Hessian-vector products (HVPs).
# Minimal reproducer: compute a Hessian-vector product (HVP) with Zygote by
# taking the gradient of a gradient-vector product (nested differentiation).
import Zygote                 # brings the module name into scope, for `Zygote.refresh()`
using Zygote: gradient
using LinearAlgebra           # provides `⋅` (dot), used in `gvp`

inp = randn(3)                # point at which derivatives are evaluated
v = randn(3)                  # direction vector for the products
H = randn(3, 3); H = H + H'   # symmetric, hence exactly the Hessian of `f`
f(inp, H) = 0.5 * inp' * H * inp   # quadratic form 0.5 * xᵀHx (gradient H*x, Hessian H)

hvp = H * v                   # true Hessian-vector product
gg = H * inp                  # true gradient at `inp`
ggvp = gg' * v                # true gradient-vector product

# NOTE(review): `refresh` existed in the Zygote version this issue was filed
# against; drop this line if the installed Zygote no longer defines it.
Zygote.refresh()
g(x) = gradient(y -> f(y, H), x)[1]   # gradient of `f` w.r.t. its first argument
gvp(x) = g(x) ⋅ v                     # directional derivative of `f` along `v`
@assert gvp(inp) ≈ ggvp               # first-order part is correct: equals gg'v

# d/dx (∇f(x) ⋅ v) = H*v, so the gradient of `gvp` should equal the true HVP.
dgvp(x) = gradient(gvp, x)
@assert dgvp(inp)[1] ≈ hvp  # at time of reporting this errored with:
                            # "Can't differentiate foreigncall expression"
error(::String) at error.jl:33
get at abstractdict.jl:595 [inlined]
(::typeof(∂(get)))(::Nothing) at interface2.jl:0
in at abstractdict.jl:665 [inlined]
(::typeof(∂(in)))(::Nothing) at interface2.jl:0
haskey at abstractdict.jl:17 [inlined]
(::typeof(∂(haskey)))(::Nothing) at interface2.jl:0
macro expansion at lib.jl:37 [inlined]
accum_param at lib.jl:35 [inlined]
(::typeof(∂(Zygote.accum_param)))(::Nothing) at interface2.jl:0
#137 at lib.jl:58 [inlined]
(::typeof(∂(λ)))(::Nothing) at interface2.jl:0
#209#back at grad.jl:41 [inlined]
(::typeof(∂(λ)))(::Nothing) at interface2.jl:0
#19 at hvp.jl:10 [inlined]
(::typeof(∂(λ)))(::Tuple{Nothing,Array{Float64,1}}) at interface2.jl:0
#73 at interface.jl:38 [inlined]
(::typeof(∂(λ)))(::Tuple{Array{Float64,1}}) at interface2.jl:0
gradient at interface.jl:44 [inlined]
(::typeof(∂(Zygote.gradient)))(::Tuple{Array{Float64,1}}) at interface2.jl:0
g at hvp.jl:10 [inlined]
(::typeof(∂(g)))(::Array{Float64,1}) at interface2.jl:0
gvp at hvp.jl:11 [inlined]
(::typeof(∂(gvp)))(::Int8) at interface2.jl:0
(::getfield(Zygote, Symbol("##73#74")){typeof(∂(gvp))})(::Int8) at interface.jl:38
gradient(::Function, ::Array{Float64,1}) at interface.jl:44
dgvp(::Array{Float64,1}) at hvp.jl:13
top-level scope at none:0
Metadata
Metadata
Assignees
Labels
No labels