Assertion error in record_branches #5
Problem seems to be this odd definition, which is called by
Those two definitions predate the new broadcast interface and feel like they might be inconsistent with its spirit (although I can’t quite see what the right replacement is). I’m going to be slightly obnoxious and tag @Sacha0 and @mbauman, because they’ll probably know pretty quickly if the definitions are wrong.
You have the right of it I think :). While semantically correct, those definitions predate the broadcast interface overhaul by a few months, and likely should be expressed differently now. Best!
`f(x) = 0.5*(x'*(H*x))` does not fix the issue on latest master:
julia> f(x) = 0.5*(x'*(H*x))
f (generic function with 1 method)
julia> fp = Zygote.gradient(f,x)
ERROR: MethodError: no method matching exprtype(::Core.Compiler.IRCode, ::String)
Closest candidates are:
exprtype(::Core.Compiler.IRCode, ::Expr) at /local/home/fredrikb/.julia/dev/Zygote/src/tools/ir.jl:71
exprtype(::Core.Compiler.IRCode, ::QuoteNode) at /local/home/fredrikb/.julia/dev/Zygote/src/tools/ir.jl:68
exprtype(::Core.Compiler.IRCode, ::GlobalRef) at /local/home/fredrikb/.julia/dev/Zygote/src/tools/ir.jl:67
...
Stacktrace:
[1] _broadcast_getindex_evalf at ./broadcast.jl:574 [inlined]
[2] _broadcast_getindex at ./broadcast.jl:547 [inlined]
[3] getindex at ./broadcast.jl:507 [inlined]
[4] copyto_nonleaf!(::Array{DataType,1}, ::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Tuple{Base.OneTo{Int64}},typeof(Zygote.exprtype),Tuple{Base.RefValue{Core.Compiler.IRCode},Base.Broadcast.Extruded{Array{Any,1},Tuple{Bool},Tuple{Int64}}}}, ::Base.OneTo{Int64}, ::Int64, ::Int64) at ./broadcast.jl:899
[5] copy(::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Tuple{Base.OneTo{Int64}},typeof(Zygote.exprtype),Tuple{Base.RefValue{Core.Compiler.IRCode},Array{Any,1}}}) at ./broadcast.jl:762
[6] materialize(::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(Zygote.exprtype),Tuple{Base.RefValue{Core.Compiler.IRCode},Array{Any,1}}}) at ./broadcast.jl:724
[7] record!(::Core.Compiler.IRCode) at /local/home/fredrikb/.julia/dev/Zygote/src/compiler/reverse.jl:132
[8] #Primal#46(::Int64, ::Type, ::Core.Compiler.IRCode) at /local/home/fredrikb/.julia/dev/Zygote/src/compiler/reverse.jl:177
[9] Type at ./none:0 [inlined]
[10] #Adjoint#72 at /local/home/fredrikb/.julia/dev/Zygote/src/compiler/reverse.jl:375 [inlined]
[11] (::getfield(Core, Symbol("#kw#Type")))(::NamedTuple{(:varargs,),Tuple{Int64}}, ::Type{Zygote.Adjoint}, ::Core.Compiler.IRCode) at ./none:0
[12] _lookup_grad(::Type) at /local/home/fredrikb/.julia/dev/Zygote/src/compiler/emit.jl:121
[13] #s18#627 at /local/home/fredrikb/.julia/dev/Zygote/src/compiler/interface2.jl:19 [inlined]
[14] #s18#627(::Any, ::Any, ::Any) at ./none:0
[15] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any,N} where N) at ./boot.jl:506
[16] f at ./REPL[9]:1 [inlined]
[17] (::Zygote.J{Tuple{typeof(f),Array{Float64,1}},Tuple{typeof(f),Array{Float64,1},getfield(Zygote, Symbol("##1010#back2#569")){getfield(Zygote, Symbol("##567#568")){Transpose{Float64,Array{Float64,1}},Array{Float64,1}}},getfield(Zygote, Symbol("##1010#back2#569")){getfield(Zygote, Symbol("##567#568")){Array{Float64,2},Array{Float64,1}}},getfield(Zygote, Symbol("##1016#back2#572")){getfield(Zygote, Symbol("##570#571"))},Zygote.J{Tuple{typeof(hvcat),Tuple{Int64,Int64,Int64},Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64},Tuple{typeof(hvcat)}},getfield(Zygote, Symbol("##143#back2#115")){typeof(identity)}}})(::Int64) at /local/home/fredrikb/.julia/dev/Zygote/src/compiler/interface2.jl:0
[18] (::getfield(Zygote, Symbol("##73#74")){Zygote.J{Tuple{typeof(f),Array{Float64,1}},Tuple{typeof(f),Array{Float64,1},getfield(Zygote, Symbol("##1010#back2#569")){getfield(Zygote, Symbol("##567#568")){Transpose{Float64,Array{Float64,1}},Array{Float64,1}}},getfield(Zygote, Symbol("##1010#back2#569")){getfield(Zygote, Symbol("##567#568")){Array{Float64,2},Array{Float64,1}}},getfield(Zygote, Symbol("##1016#back2#572")){getfield(Zygote, Symbol("##570#571"))},Zygote.J{Tuple{typeof(hvcat),Tuple{Int64,Int64,Int64},Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64},Tuple{typeof(hvcat)}},getfield(Zygote, Symbol("##143#back2#115")){typeof(identity)}}}})(::Int64) at /local/home/fredrikb/.julia/dev/Zygote/src/compiler/interface.jl:28
[19] gradient(::Function, ::Array{Float64,1}) at /local/home/fredrikb/.julia/dev/Zygote/src/compiler/interface.jl:34
The code in my previous post now works on latest master, but the extended example below still produces the same error:
using Zygote, LinearAlgebra
x = randn(3) # Input
v = randn(3) # Vector
H = randn(3,3); H = H+H' # Hessian
f(x) = 0.5*(x'*(H*x))
hvp = H*v # True Hessian vector product
gg = H*x # True gradient
ggvp = gg'v # True gradient vector product
@assert f'(x) ≈ gg # Works on latest master
gvp(x) = f'(x)'v # Gradient vector product
@assert gvp(x) ≈ ggvp # Works
Hvp(x) = gvp'(x) # This contains nested differentiation
@assert Hvp(x) ≈ hvp # Error
MethodError: no method matching exprtype(::Core.Compiler.IRCode, ::String)
Closest candidates are:
exprtype(::Core.Compiler.IRCode, !Matched::Expr) at /local/home/fredrikb/.julia/dev/Zygote/src/tools/ir.jl:54
exprtype(::Core.Compiler.IRCode, !Matched::QuoteNode) at /local/home/fredrikb/.julia/dev/Zygote/src/tools/ir.jl:51
exprtype(::Core.Compiler.IRCode, !Matched::GlobalRef) at /local/home/fredrikb/.julia/dev/Zygote/src/tools/ir.jl:50
...
_broadcast_getindex_evalf at broadcast.jl:574 [inlined]
_broadcast_getindex at broadcast.jl:547 [inlined]
getindex at broadcast.jl:507 [inlined]
copyto_nonleaf!(::Array{DataType,1}, ::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Tuple{Base.OneTo{Int64}},typeof(Zygote.exprtype),Tuple{Base.RefValue{Core.Compiler.IRCode},Base.Broadcast.Extruded{Array{Any,1},Tuple{Bool},Tuple{Int64}}}}, ::Base.OneTo{Int64}, ::Int64, ::Int64) at broadcast.jl:899
copy(::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Tuple{Base.OneTo{Int64}},typeof(Zygote.exprtype),Tuple{Base.RefValue{Core.Compiler.IRCode},Array{Any,1}}}) at broadcast.jl:762
materialize(::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(Zygote.exprtype),Tuple{Base.RefValue{Core.Compiler.IRCode},Array{Any,1}}}) at broadcast.jl:724
...
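For reference, the expected values in the example above can be checked without any automatic differentiation: for `f(x) = 0.5*(x'*(H*x))` with symmetric `H`, the gradient is `H*x` and the Hessian-vector product is `H*v`. The sketch below confirms this with nested central finite differences. The helper `fd_gradient`, the fixed test values, and the step size `h` are assumptions made for illustration; none of this is part of Zygote.

```julia
using LinearAlgebra

# Fixed, symmetric test data (illustrative values, not from the issue).
H = [2.0 1.0 0.0; 1.0 3.0 1.0; 0.0 1.0 4.0]
x = [1.0, 2.0, 3.0]
v = [0.5, -1.0, 2.0]

f(x) = 0.5 * (x' * (H * x))

# Hypothetical helper: central-difference gradient of a scalar function.
# For a quadratic f the truncation error is zero, so only rounding remains.
function fd_gradient(f, x; h = 1e-3)
    g = similar(x)
    for i in eachindex(x)
        e = zeros(length(x))
        e[i] = 1.0
        g[i] = (f(x + h * e) - f(x - h * e)) / (2h)
    end
    return g
end

# Gradient-vector product, then its gradient: the nested step that
# corresponds to Hvp(x) = gvp'(x) in the example above.
fd_gvp(x) = fd_gradient(f, x)' * v
fd_hvp = fd_gradient(fd_gvp, x)

@assert isapprox(fd_gradient(f, x), H * x; rtol = 1e-6)
@assert isapprox(fd_hvp, H * v; rtol = 1e-6)
```

This only pins down what `Hvp(x)` should return once nested differentiation works; it says nothing about where the `exprtype` error comes from.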
Fixed the original issue:
julia> derivative(x -> sum(0.5x'), [1, 2, 3])
3-element Array{Float64,1}:
 0.5
 0.5
 0.5
Nested differentiation is still ropey, but we'll figure that out separately.
Right now Zygote inserts stacks whenever it needs to use an SSA value not defined in the first basic block. This is of course unnecessary. The condition for needing stacks is that the basic block that defines it is self-reachable (i.e. in a loop). Otherwise, we can simply insert phi nodes to thread the desired SSA value through to the exit block (we don't need to do anything in the adjoint, since the reversal of the CFG ensures dominance). Removing stacks allows for both more efficient code generation and enables higher order auto-diff (since we use control flow in Zygote, but can't handle differentiating code that contains stacks).

The headline example is something like the following:

```
function foo(b, x)
    if b
        sin(x)
    else
        cos(x)
    end
end
```

Then looking at `@code_typed derivative(x->foo(true, x), 1.0)`, we get:

Before:

```
CodeInfo(
1 ── %1 = $(Expr(:foreigncall, :(:jl_alloc_array_1d), Array{Int8,1}, svec(Any, Int64), :(:ccall), 2, Array{Int8,1}, 0, 0))::Array{Int8,1}
│    %2 = $(Expr(:foreigncall, :(:jl_alloc_array_1d), Array{Any,1}, svec(Any, Int64), :(:ccall), 2, Array{Any,1}, 0, 0))::Array{Any,1}
│    %3 = $(Expr(:foreigncall, :(:jl_alloc_array_1d), Array{Any,1}, svec(Any, Int64), :(:ccall), 2, Array{Any,1}, 0, 0))::Array{Any,1}
│    %4 = Base.sin::typeof(sin)
│    invoke %4(_3::Float64)::Float64
│    %6 = %new(##334#335{Float64}, x)::##334#335{Float64}
│    %7 = %new(##758#back#336{##334#335{Float64}}, %6)::##758#back#336{##334#335{Float64}}
[snip]
23 ─ %52 = invoke %47(1::Int8)::Tuple{Nothing,Nothing,Any}
│    %53 = Base.getfield(%52, 3, true)::Any
└─── goto #24
24 ─ return %53
) => Any
```

After:

```
CodeInfo(
1 ─ %1 = Base.sin::typeof(sin)
│   invoke %1(_3::Float64)::Float64
│   %3 = Core.Intrinsics.not_int(true)::Bool
└── goto #3 if not %3
2 ─ invoke Zygote.notnothing(nothing::Nothing)::Union{}
└── $(Expr(:unreachable))::Union{}
3 ┄ %7 = invoke Zygote.cos(_3::Float64)::Float64
│   %8 = Base.mul_float(1.0, %7)::Float64
└── goto #4
4 ─ goto #5
5 ─ goto #6
6 ─ goto #7
7 ─ return %8
) => Float64
```

Which is essentially perfect (there's a bit of junk left over, but LLVM can take care of that. The only thing that doesn't get removed is the useless invocation of `sin`, but that's a separate and known issue).
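
The "self-reachable" condition described above can be illustrated on a toy control-flow graph. The following is a minimal sketch, not Zygote's implementation: the adjacency-list encoding `succs`, the function name `self_reachable`, and both example CFGs are hypothetical. A block needs a stack only if some path of one or more edges leads from it back to itself.

```julia
# succs[i] lists the successor blocks of basic block i
# (hypothetical CFG encoding, chosen for illustration only).
function self_reachable(succs::Vector{Vector{Int}}, b::Int)
    seen = falses(length(succs))
    worklist = copy(succs[b])          # start from b's successors, not b itself
    while !isempty(worklist)
        n = pop!(worklist)
        n == b && return true          # found a path leading back to b
        seen[n] && continue
        seen[n] = true
        append!(worklist, succs[n])
    end
    return false
end

# Shape of `foo`: block 1 branches to 2 (sin) or 3 (cos); both fall through to 4.
branch_cfg = [[2, 3], [4], [4], Int[]]

# A simple loop: 1 -> 2 (header), 2 -> 3 (body) or 4 (exit), 3 -> 2 (back edge).
loop_cfg = [[2], [3, 4], [2], Int[]]

needs_stack_branch = [self_reachable(branch_cfg, b) for b in 1:4]
needs_stack_loop   = [self_reachable(loop_cfg, b) for b in 1:4]
```

Under this criterion no block of the branch-only CFG is self-reachable, so values defined in blocks 2 and 3 could be threaded to the exit with phi nodes, while blocks 2 and 3 of the loop CFG would still need stacks.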
Julia v0.7 / Zygote v0.1