Assertion error in record_branches #5

Closed
baggepinnen opened this issue Aug 18, 2018 · 6 comments

@baggepinnen
Contributor

Julia v0.7 / Zygote v0.1

using Zygote
x      = randn(3)             # Input
v      = randn(3)             # Vector
H      = randn(3,3); H = H+H' # Hessian
f(x)   = 0.5*x'*(H*x)         # 0.5*x'H*x, the function to take the Hessian of
fp  = Zygote.gradient(f,x)
fp(x)

ERROR: AssertionError: length(preds) <= 2
Stacktrace:
 [1] record_branches!(::Core.Compiler.IRCode) at /home/fredrikb/.julia/packages/Zygote/g8WMA/src/compiler/reverse.jl:76
 [2] #grad_ir#50(::Nothing, ::Function, ::Core.Compiler.IRCode) at /home/fredrikb/.julia/packages/Zygote/g8WMA/src/compiler/reverse.jl:319
 [3] (::getfield(Zygote, Symbol("#kw##grad_ir")))(::NamedTuple{(:varargs,),Tuple{Nothing}}, ::typeof(Zygote.grad_ir), ::Core.Compiler.IRCode) at ./none:0
 [4] _lookup_grad(::Type) at /home/fredrikb/.julia/packages/Zygote/g8WMA/src/compiler/emit.jl:118
 [5] #s21#533 at /home/fredrikb/.julia/packages/Zygote/g8WMA/src/compiler/interface2.jl:18 [inlined]
 [6] #s21#533(::Any, ::Any, ::Any) at ./none:0
 [7] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any,N} where N) at ./boot.jl:506
 [8] _broadcast_getindex at ./broadcast.jl:525 [inlined]
 [9] (::Zygote.J{Tuple{typeof(Base.Broadcast._broadcast_getindex),Tuple{Float64,LinearAlgebra.Adjoint{Float64,Array{Float64,1}}},Int64},Any})(::NamedTuple{(:parent,),Tuple{Array{Float64,1}}}) at /home/fredrikb/.julia/packages/Zygote/g8WMA/src/compiler/interface2.jl:0
 [10] _getindex at ./broadcast.jl:571 [inlined]
 [11] (::Zygote.J{Tuple{typeof(Base.Broadcast._getindex),Tuple{Tuple{Float64,LinearAlgebra.Adjoint{Float64,Array{Float64,1}}}},Int64},Any})(::Tuple{NamedTuple{(:parent,),Tuple{Array{Float64,1}}}}) at /home/fredrikb/.julia/packages/Zygote/g8WMA/src/compiler/interface2.jl:0
 [12] _broadcast_getindex at ./broadcast.jl:546 [inlined]
 [13] (::Zygote.J{Tuple{typeof(Base.Broadcast._broadcast_getindex),Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple},Nothing,typeof(LinearAlgebra.quasiparenta),Tuple{Tuple{Float64,LinearAlgebra.Adjoint{Float64,Array{Float64,1}}}}},Int64},Any})(::Array{Float64,1}) at /home/fredrikb/.julia/packages/Zygote/g8WMA/src/compiler/interface2.jl:0
 [14] #17 at ./broadcast.jl:922 [inlined]
 [15] (::Zygote.J{Tuple{getfield(Base.Broadcast, Symbol("##17#18")){Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple},Nothing,typeof(LinearAlgebra.quasiparenta),Tuple{Tuple{Float64,LinearAlgebra.Adjoint{Float64,Array{Float64,1}}}}}},Int64},Any})(::Array{Float64,1}) at /home/fredrikb/.julia/packages/Zygote/g8WMA/src/compiler/interface2.jl:0
 [16] ntuple at ./tuple.jl:157 [inlined]
 [17] (::Zygote.J{Tuple{typeof(ntuple),getfield(Base.Broadcast, Symbol("##17#18")){Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple},Nothing,typeof(LinearAlgebra.quasiparenta),Tuple{Tuple{Float64,LinearAlgebra.Adjoint{Float64,Array{Float64,1}}}}}},Val{2}},Any})(::Tuple{Float64,Array{Float64,1}}) at /home/fredrikb/.julia/packages/Zygote/g8WMA/src/compiler/interface2.jl:0
 [18] tuplebroadcast at ./broadcast.jl:922 [inlined]
 [19] (::Zygote.J{Tuple{typeof(Base.Broadcast.tuplebroadcast),Tuple{Float64,LinearAlgebra.Adjoint{Float64,Array{Float64,1}}},Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple},Nothing,typeof(LinearAlgebra.quasiparenta),Tuple{Tuple{Float64,LinearAlgebra.Adjoint{Float64,Array{Float64,1}}}}}},Any})(::Tuple{Float64,Array{Float64,1}}) at /home/fredrikb/.julia/packages/Zygote/g8WMA/src/compiler/interface2.jl:0
 [20] copy at ./broadcast.jl:920 [inlined]
 [21] (::Zygote.J{Tuple{typeof(copy),Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple},Nothing,typeof(LinearAlgebra.quasiparenta),Tuple{Tuple{Float64,LinearAlgebra.Adjoint{Float64,Array{Float64,1}}}}}},Any})(::Tuple{Float64,Array{Float64,1}}) at /home/fredrikb/.julia/packages/Zygote/g8WMA/src/compiler/interface2.jl:0
 [22] materialize at ./broadcast.jl:724 [inlined]
 [23] (::Zygote.J{Tuple{typeof(Base.Broadcast.materialize),Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple},Nothing,typeof(LinearAlgebra.quasiparenta),Tuple{Tuple{Float64,LinearAlgebra.Adjoint{Float64,Array{Float64,1}}}}}},Any})(::Tuple{Float64,Array{Float64,1}}) at /home/fredrikb/.julia/packages/Zygote/g8WMA/src/compiler/interface2.jl:0
 [24] broadcast at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v0.7/LinearAlgebra/src/adjtrans.jl:185 [inlined]
 [25] (::Zygote.J{Tuple{typeof(broadcast),typeof(*),Float64,LinearAlgebra.Adjoint{Float64,Array{Float64,1}}},Any})(::LinearAlgebra.Transpose{Float64,Array{Float64,1}}) at /home/fredrikb/.julia/packages/Zygote/g8WMA/src/compiler/interface2.jl:0
 [26] * at ./arraymath.jl:52 [inlined]
 [27] (::Zygote.J{Tuple{typeof(*),Float64,LinearAlgebra.Adjoint{Float64,Array{Float64,1}}},Any})(::LinearAlgebra.Transpose{Float64,Array{Float64,1}}) at /home/fredrikb/.julia/packages/Zygote/g8WMA/src/compiler/interface2.jl:0
 [28] * at ./operators.jl:502 [inlined]
 [29] (::Zygote.J{Tuple{typeof(*),Float64,LinearAlgebra.Adjoint{Float64,Array{Float64,1}},Array{Float64,1}},Any})(::Int64) at /home/fredrikb/.julia/packages/Zygote/g8WMA/src/compiler/interface2.jl:0
 [30] f at ./REPL[5]:1 [inlined]
 [31] (::Zygote.J{Tuple{typeof(f),Array{Float64,1}},Any})(::Int64) at /home/fredrikb/.julia/packages/Zygote/g8WMA/src/compiler/interface2.jl:0
 [32] (::getfield(Zygote, Symbol("##51#52")){Zygote.J{Tuple{typeof(f),Array{Float64,1}},Any}})(::Int64) at /home/fredrikb/.julia/packages/Zygote/g8WMA/src/compiler/interface.jl:28
 [33] gradient(::Function, ::Array{Float64,1}) at /home/fredrikb/.julia/packages/Zygote/g8WMA/src/compiler/interface.jl:34
 [34] top-level scope at none:0
@MikeInnes
Member

Problem seems to be this odd definition, which is called by 0.5x'. Until we have a better fix, 0.5(x'*H*x) will work.
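A small sanity check of the workaround, as a sketch only: the two spellings are the same quadratic form and differ only in where the scalar is applied. Frames [26] to [28] of the trace above show Julia 0.7 forming `0.5*x'` first, which is the path that hit the assertion. The variable names mirror the report; nothing here calls Zygote.

```julia
x = randn(3)
H = randn(3, 3); H = H + H'   # symmetric, as in the report

# `0.5*x'*(H*x)` parses as a single n-ary call `*(0.5, x', H*x)`.
ex = Meta.parse("0.5*x'*(H*x)")
@assert ex.head == :call && ex.args[1] == :* && length(ex.args) == 4

lhs = 0.5 * x' * (H * x)      # original spelling from the report
rhs = 0.5 * (x' * H * x)      # suggested workaround: scalar applied last
@assert lhs ≈ rhs             # same value, different call path
```

The workaround therefore changes which `*` methods are differentiated, not the value of `f`.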

@jekbradbury

Those two definitions predate the new broadcast interface and feel like they might be inconsistent with its spirit (although I can’t quite see what the right replacement is). I’m going to be slightly obnoxious and tag @Sacha0 and @mbauman, because they’ll probably know pretty quickly if the definitions are wrong.

@Sacha0

Sacha0 commented Sep 7, 2018

> Those two definitions predate the new broadcast interface and feel like they might be inconsistent with its spirit (although I can’t quite see what the right replacement is).

You have the right of it I think :). While semantically correct, those definitions predate the broadcast interface overhaul by a few months, and likely should be expressed differently now. Best!

@baggepinnen
Contributor Author

f(x)   = 0.5*(x'*(H*x))

does not fix the issue on latest master

julia> f(x)   = 0.5*(x'*(H*x))
f (generic function with 1 method)

julia> fp  = Zygote.gradient(f,x)
ERROR: MethodError: no method matching exprtype(::Core.Compiler.IRCode, ::String)
Closest candidates are:
  exprtype(::Core.Compiler.IRCode, ::Expr) at /local/home/fredrikb/.julia/dev/Zygote/src/tools/ir.jl:71
  exprtype(::Core.Compiler.IRCode, ::QuoteNode) at /local/home/fredrikb/.julia/dev/Zygote/src/tools/ir.jl:68
  exprtype(::Core.Compiler.IRCode, ::GlobalRef) at /local/home/fredrikb/.julia/dev/Zygote/src/tools/ir.jl:67
  ...
Stacktrace:
 [1] _broadcast_getindex_evalf at ./broadcast.jl:574 [inlined]
 [2] _broadcast_getindex at ./broadcast.jl:547 [inlined]
 [3] getindex at ./broadcast.jl:507 [inlined]
 [4] copyto_nonleaf!(::Array{DataType,1}, ::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Tuple{Base.OneTo{Int64}},typeof(Zygote.exprtype),Tuple{Base.RefValue{Core.Compiler.IRCode},Base.Broadcast.Extruded{Array{Any,1},Tuple{Bool},Tuple{Int64}}}}, ::Base.OneTo{Int64}, ::Int64, ::Int64) at ./broadcast.jl:899
 [5] copy(::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Tuple{Base.OneTo{Int64}},typeof(Zygote.exprtype),Tuple{Base.RefValue{Core.Compiler.IRCode},Array{Any,1}}}) at ./broadcast.jl:762
 [6] materialize(::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(Zygote.exprtype),Tuple{Base.RefValue{Core.Compiler.IRCode},Array{Any,1}}}) at ./broadcast.jl:724
 [7] record!(::Core.Compiler.IRCode) at /local/home/fredrikb/.julia/dev/Zygote/src/compiler/reverse.jl:132
 [8] #Primal#46(::Int64, ::Type, ::Core.Compiler.IRCode) at /local/home/fredrikb/.julia/dev/Zygote/src/compiler/reverse.jl:177
 [9] Type at ./none:0 [inlined]
 [10] #Adjoint#72 at /local/home/fredrikb/.julia/dev/Zygote/src/compiler/reverse.jl:375 [inlined]
 [11] (::getfield(Core, Symbol("#kw#Type")))(::NamedTuple{(:varargs,),Tuple{Int64}}, ::Type{Zygote.Adjoint}, ::Core.Compiler.IRCode) at ./none:0
 [12] _lookup_grad(::Type) at /local/home/fredrikb/.julia/dev/Zygote/src/compiler/emit.jl:121
 [13] #s18#627 at /local/home/fredrikb/.julia/dev/Zygote/src/compiler/interface2.jl:19 [inlined]
 [14] #s18#627(::Any, ::Any, ::Any) at ./none:0
 [15] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any,N} where N) at ./boot.jl:506
 [16] f at ./REPL[9]:1 [inlined]
 [17] (::Zygote.J{Tuple{typeof(f),Array{Float64,1}},Tuple{typeof(f),Array{Float64,1},getfield(Zygote, Symbol("##1010#back2#569")){getfield(Zygote, Symbol("##567#568")){Transpose{Float64,Array{Float64,1}},Array{Float64,1}}},getfield(Zygote, Symbol("##1010#back2#569")){getfield(Zygote, Symbol("##567#568")){Array{Float64,2},Array{Float64,1}}},getfield(Zygote, Symbol("##1016#back2#572")){getfield(Zygote, Symbol("##570#571"))},Zygote.J{Tuple{typeof(hvcat),Tuple{Int64,Int64,Int64},Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64},Tuple{typeof(hvcat)}},getfield(Zygote, Symbol("##143#back2#115")){typeof(identity)}}})(::Int64) at /local/home/fredrikb/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [18] (::getfield(Zygote, Symbol("##73#74")){Zygote.J{Tuple{typeof(f),Array{Float64,1}},Tuple{typeof(f),Array{Float64,1},getfield(Zygote, Symbol("##1010#back2#569")){getfield(Zygote, Symbol("##567#568")){Transpose{Float64,Array{Float64,1}},Array{Float64,1}}},getfield(Zygote, Symbol("##1010#back2#569")){getfield(Zygote, Symbol("##567#568")){Array{Float64,2},Array{Float64,1}}},getfield(Zygote, Symbol("##1016#back2#572")){getfield(Zygote, Symbol("##570#571"))},Zygote.J{Tuple{typeof(hvcat),Tuple{Int64,Int64,Int64},Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64},Tuple{typeof(hvcat)}},getfield(Zygote, Symbol("##143#back2#115")){typeof(identity)}}}})(::Int64) at /local/home/fredrikb/.julia/dev/Zygote/src/compiler/interface.jl:28
 [19] gradient(::Function, ::Array{Float64,1}) at /local/home/fredrikb/.julia/dev/Zygote/src/compiler/interface.jl:34

@baggepinnen
Contributor Author

baggepinnen commented Sep 22, 2018

The code in my previous post now works on latest master, but the extended example below still produces the same error:

using Zygote, LinearAlgebra
x        = randn(3)             # Input
v        = randn(3)             # Vector
H        = randn(3,3); H = H+H' # Hessian
f(x)     = 0.5*(x'*(H*x))

hvp    = H*v                  # True Hessian vector product
gg     = H*x                # True gradient
ggvp   = gg'v                 # True gradient vector product

@assert f'(x) ≈ gg     # Works on latest master
gvp(x) = f'(x)'v       # Gradient vector product
@assert gvp(x) ≈ ggvp  # Works
Hvp(x) = gvp'(x)       # This contains nested differentiation
@assert Hvp(x) ≈ hvp   # Error

MethodError: no method matching exprtype(::Core.Compiler.IRCode, ::String)
Closest candidates are:
  exprtype(::Core.Compiler.IRCode, !Matched::Expr) at /local/home/fredrikb/.julia/dev/Zygote/src/tools/ir.jl:54
  exprtype(::Core.Compiler.IRCode, !Matched::QuoteNode) at /local/home/fredrikb/.julia/dev/Zygote/src/tools/ir.jl:51
  exprtype(::Core.Compiler.IRCode, !Matched::GlobalRef) at /local/home/fredrikb/.julia/dev/Zygote/src/tools/ir.jl:50
  ...
_broadcast_getindex_evalf at broadcast.jl:574 [inlined]
_broadcast_getindex at broadcast.jl:547 [inlined]
getindex at broadcast.jl:507 [inlined]
copyto_nonleaf!(::Array{DataType,1}, ::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Tuple{Base.OneTo{Int64}},typeof(Zygote.exprtype),Tuple{Base.RefValue{Core.Compiler.IRCode},Base.Broadcast.Extruded{Array{Any,1},Tuple{Bool},Tuple{Int64}}}}, ::Base.OneTo{Int64}, ::Int64, ::Int64) at broadcast.jl:899
copy(::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Tuple{Base.OneTo{Int64}},typeof(Zygote.exprtype),Tuple{Base.RefValue{Core.Compiler.IRCode},Array{Any,1}}}) at broadcast.jl:762
materialize(::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(Zygote.exprtype),Tuple{Base.RefValue{Core.Compiler.IRCode},Array{Any,1}}}) at broadcast.jl:724
...
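
For reference, the "true" values used in the example (`gg = H*x` and `hvp = H*v`) follow from the symmetry of `H`. A short derivation, assuming only that $H = H^\top$ as constructed by `H = H+H'`:

$$
f(x) = \tfrac{1}{2}\, x^\top H x
\quad\Longrightarrow\quad
\nabla f(x) = \tfrac{1}{2}\,(H + H^\top)\,x = H x
$$

$$
g_v(x) = \nabla f(x)^\top v = x^\top H v
\quad\Longrightarrow\quad
\nabla g_v(x) = H v
$$

So `Hvp(x)` should equal `H*v` independently of `x`, which is what the final assertion checks.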

@MikeInnes
Member

MikeInnes commented Nov 5, 2018

Fixed the original issue:

julia> derivative(x -> sum(0.5x'), [1, 2, 3])
3-element Array{Float64,1}:
 0.5
 0.5
 0.5

Nested differentiation is still ropey, but we'll figure that out separately.
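
While nested differentiation is unreliable, the expected values from the extended example can be cross-checked without any AD. The sketch below uses central finite differences; `fd_gradient` is a hypothetical helper written for this illustration, not part of Zygote, and the step sizes and tolerances are assumptions that suit this small quadratic.

```julia
using LinearAlgebra

# Hypothetical helper (not a Zygote function): central finite-difference gradient.
function fd_gradient(f, x; h = 1e-6)
    g = similar(x)
    for i in eachindex(x)
        e = zeros(length(x)); e[i] = h
        g[i] = (f(x + e) - f(x - e)) / (2h)
    end
    return g
end

x = randn(3)
v = randn(3)
H = randn(3, 3); H = H + H'          # symmetric, as in the report
f(x) = 0.5 * (x' * (H * x))

g_fd   = fd_gradient(f, x)           # approximates the gradient H*x
gvp(x) = dot(fd_gradient(f, x), v)   # gradient-vector product
hvp_fd = fd_gradient(gvp, x; h = 1e-4)   # nested differences, approximates H*v

@assert isapprox(g_fd, H * x; rtol = 1e-5)
@assert isapprox(hvp_fd, H * v; rtol = 1e-3)
```

The looser tolerance on the second check reflects that differencing an already-differenced function amplifies rounding error.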

Keno added a commit that referenced this issue Feb 24, 2019
Right now Zygote inserts stacks whenever it needs to use an SSA value
not defined in the first basic block. This is of course unnecessary.
The condition for needing stacks is that the basic block that defines
it is self-reachable (i.e. in a loop). Otherwise, we can simply insert
phi nodes to thread the desired SSA value through to the exit block
(we don't need to do anything in the adjoint, since the reversal of
the CFG ensures dominance). Removing stacks allows for both more
efficient code generation and enables higher order auto-diff (since
we use control flow in Zygote, but can't handle differentiating code
that contains stacks). The headline example is something like the following:

```
function foo(b, x)
    if b
        sin(x)
    else
        cos(x)
    end
end
```

Then looking at `@code_typed derivative(x->foo(true, x), 1.0)`, we get:

Before:
```
CodeInfo(
1 ── %1  = $(Expr(:foreigncall, :(:jl_alloc_array_1d), Array{Int8,1}, svec(Any, Int64), :(:ccall), 2, Array{Int8,1}, 0, 0))::Array{Int8,1}
│    %2  = $(Expr(:foreigncall, :(:jl_alloc_array_1d), Array{Any,1}, svec(Any, Int64), :(:ccall), 2, Array{Any,1}, 0, 0))::Array{Any,1}
│    %3  = $(Expr(:foreigncall, :(:jl_alloc_array_1d), Array{Any,1}, svec(Any, Int64), :(:ccall), 2, Array{Any,1}, 0, 0))::Array{Any,1}
│    %4  = Base.sin::typeof(sin)
│          invoke %4(_3::Float64)::Float64
│    %6  = %new(##334#335{Float64}, x)::##334#335{Float64}
│    %7  = %new(##758#back#336{##334#335{Float64}}, %6)::##758#back#336{##334#335{Float64}}
[snip]
23 ─ %52 = invoke %47(1::Int8)::Tuple{Nothing,Nothing,Any}
│    %53 = Base.getfield(%52, 3, true)::Any
└───       goto #24
24 ─       return %53
) => Any
```

After:
```
CodeInfo(
1 ─ %1 = Base.sin::typeof(sin)
│        invoke %1(_3::Float64)::Float64
│   %3 = Core.Intrinsics.not_int(true)::Bool
└──      goto #3 if not %3
2 ─      invoke Zygote.notnothing(nothing::Nothing)::Union{}
└──      $(Expr(:unreachable))::Union{}
3 ┄ %7 = invoke Zygote.cos(_3::Float64)::Float64
│   %8 = Base.mul_float(1.0, %7)::Float64
└──      goto #4
4 ─      goto #5
5 ─      goto #6
6 ─      goto #7
7 ─      return %8
) => Float64
```

Which is essentially perfect: there's a bit of junk left over, but LLVM
can take care of that. The only thing that doesn't get removed is the
useless invocation of `sin`, but that's a separate and known issue.
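
The stack condition described in the commit message (a defining block needs a stack only if it is self-reachable, i.e. sits in a loop) can be illustrated on a toy control-flow graph. This is a sketch of the stated condition only, not Zygote's implementation; `self_reachable` and the successor-list encoding of the CFG are assumptions made for the illustration.

```julia
# Toy CFG: succs[i] lists the successor blocks of block i.
# A block is self-reachable if some path of length >= 1 leads back to it.
function self_reachable(succs::Vector{Vector{Int}})
    n = length(succs)
    result = falses(n)
    for start in 1:n
        seen = falses(n)
        stack = copy(succs[start])      # depth-first search from the successors
        while !isempty(stack)
            b = pop!(stack)
            if b == start
                result[start] = true    # found a cycle through `start`
                break
            end
            seen[b] && continue
            seen[b] = true
            append!(stack, succs[b])
        end
    end
    return result
end

# if/else diamond, shaped like `foo`: 1 -> {2, 3}, 2 -> 4, 3 -> 4
branch_cfg = [[2, 3], [4], [4], Int[]]
# simple loop: 1 -> 2 (header), 2 -> {3 (body), 4 (exit)}, 3 -> 2
loop_cfg = [[2], [3, 4], [2], Int[]]

@assert !any(self_reachable(branch_cfg))                        # no block would need a stack
@assert self_reachable(loop_cfg) == [false, true, true, false]  # header and body would
```

Under this reading, the branch-only `foo` needs no stacks at all, which matches the "After" listing above.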
Keno added a commit that referenced this issue Mar 6, 2019