ERROR: throw not supported #11

Open
chriselrod opened this issue Sep 6, 2018 · 3 comments

Comments

@chriselrod

chriselrod commented Sep 6, 2018

julia> using StaticArrays

julia> import Zygote: gradient

julia> const S = (@SMatrix randn(6,4)) |> x -> x' * x;

julia> g(x) = 0.5 * x' * S * x
g (generic function with 1 method)

julia> x = @SVector randn(4);

julia> g(x)
1.8359824145349797

julia> gradient(g, x)
ERROR: `throw` not supported
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] merge_returns(::Core.Compiler.IRCode) at /home/chris/.julia/packages/Zygote/zd432/src/compiler/reverse.jl:15
 [3] #Primal#46(::Nothing, ::Type, ::Core.Compiler.IRCode) at /home/chris/.julia/packages/Zygote/zd432/src/compiler/reverse.jl:176
 [4] Type at ./none:0 [inlined]
 [5] #Adjoint#72 at /home/chris/.julia/packages/Zygote/zd432/src/compiler/reverse.jl:375 [inlined]
 [6] (::getfield(Core, Symbol("#kw#Type")))(::NamedTuple{(:varargs,),Tuple{Nothing}}, ::Type{Zygote.Adjoint}, ::Core.Compiler.IRCode) at ./none:0
 [7] _lookup_grad(::Type) at /home/chris/.julia/packages/Zygote/zd432/src/compiler/emit.jl:116
 [8] #s18#615 at /home/chris/.julia/packages/Zygote/zd432/src/compiler/interface2.jl:19 [inlined]
 [9] #s18#615(::Any, ::Any, ::Any) at ./none:0
 [10] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any,N} where N) at ./boot.jl:506
 [11] _broadcast_getindex at ./broadcast.jl:525 [inlined]
 [12] _getindex at ./broadcast.jl:571 [inlined]
 [13] (::Zygote.J{Tuple{typeof(Base.Broadcast._getindex),Tuple{Tuple{Float64,Float64}},Int64},Tuple{typeof(Base.Broadcast._getindex),Tuple{Tuple{Float64,Float64}},Int64,getfield(Zygote, Symbol("##148#back2#115")){typeof(identity)},Zygote.J{Tuple{typeof(Base.Broadcast._broadcast_getindex),Tuple{Float64,Float64},Int64},Tuple{typeof(Base.Broadcast._broadcast_getindex),Tuple{Float64,Float64},Int64,getfield(Zygote, Symbol("##154#back2#120")){getfield(Zygote, Symbol("##116#118")){2,Int64}},Zygote.J{Tuple{typeof(getindex),Int64,Int64},Tuple{typeof(getindex)}}}},getfield(Zygote, Symbol("##154#back2#120")){getfield(Zygote, Symbol("##116#118")){1,Int64}}}})(::Tuple{Float64}) at /home/chris/.julia/packages/Zygote/zd432/src/compiler/interface2.jl:0
 [14] _broadcast_getindex at ./broadcast.jl:546 [inlined]
 [15] (::Zygote.J{Tuple{typeof(Base.Broadcast._broadcast_getindex),Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple},Nothing,typeof(adjoint),Tuple{Tuple{Float64,Float64}}},Int64},Any})(::Float64) at /home/chris/.julia/packages/Zygote/zd432/src/compiler/interface2.jl:0
 [16] #17 at ./broadcast.jl:922 [inlined]
 [17] (::Zygote.J{Tuple{getfield(Base.Broadcast, Symbol("##17#18")){Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple},Nothing,typeof(adjoint),Tuple{Tuple{Float64,Float64}}}},Int64},Any})(::Float64) at /home/chris/.julia/packages/Zygote/zd432/src/compiler/interface2.jl:0
 [18] ntuple at ./tuple.jl:157 [inlined]
 [19] (::Zygote.J{Tuple{typeof(ntuple),getfield(Base.Broadcast, Symbol("##17#18")){Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple},Nothing,typeof(adjoint),Tuple{Tuple{Float64,Float64}}}},Val{2}},Tuple{typeof(ntuple),getfield(Base.Broadcast, Symbol("##17#18")){Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple},Nothing,typeof(adjoint),Tuple{Tuple{Float64,Float64}}}},Val{2},getfield(Zygote, Symbol("##148#back2#115")){typeof(identity)},Zygote.J{Tuple{getfield(Base.Broadcast, Symbol("##17#18")){Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple},Nothing,typeof(adjoint),Tuple{Tuple{Float64,Float64}}}},Int64},Any},Zygote.J{Tuple{getfield(Base.Broadcast, Symbol("##17#18")){Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple},Nothing,typeof(adjoint),Tuple{Tuple{Float64,Float64}}}},Int64},Any}}})(::Tuple{Float64,Float64}) at /home/chris/.julia/packages/Zygote/zd432/src/compiler/interface2.jl:0
 [20] tuplebroadcast at ./broadcast.jl:922 [inlined]
 [21] (::Zygote.J{Tuple{typeof(Base.Broadcast.tuplebroadcast),Tuple{Float64,Float64},Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple},Nothing,typeof(adjoint),Tuple{Tuple{Float64,Float64}}}},Any})(::Tuple{Float64,Float64}) at /home/chris/.julia/packages/Zygote/zd432/src/compiler/interface2.jl:0
 [22] copy at ./broadcast.jl:920 [inlined]
 [23] (::Zygote.J{Tuple{typeof(copy),Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple},Nothing,typeof(adjoint),Tuple{Tuple{Float64,Float64}}}},Any})(::Tuple{Float64,Float64}) at /home/chris/.julia/packages/Zygote/zd432/src/compiler/interface2.jl:0
 [24] materialize at ./broadcast.jl:724 [inlined]
 [25] (::Zygote.J{Tuple{typeof(Base.Broadcast.materialize),Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple},Nothing,typeof(adjoint),Tuple{Tuple{Float64,Float64}}}},Tuple{typeof(Base.Broadcast.materialize),Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple},Nothing,typeof(adjoint),Tuple{Tuple{Float64,Float64}}},Zygote.J{Tuple{typeof(copy),Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple},Nothing,typeof(adjoint),Tuple{Tuple{Float64,Float64}}}},Any},Zygote.J{Tuple{typeof(Base.Broadcast.instantiate),Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple},Nothing,typeof(adjoint),Tuple{Tuple{Float64,Float64}}}},Tuple{typeof(Base.Broadcast.instantiate),Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple},Nothing,typeof(adjoint),Tuple{Tuple{Float64,Float64}}}}}}})(::Tuple{Float64,Float64}) at /home/chris/.julia/packages/Zygote/zd432/src/compiler/interface2.jl:0
 [26] #5 at /home/chris/Documents/prog/zygote/usr/share/julia/stdlib/v1.0/LinearAlgebra/src/adjtrans.jl:185 [inlined]
 [27] (::Zygote.J{Tuple{getfield(LinearAlgebra, Symbol("##5#6")){typeof(*)},Float64,Float64},Any})(::Float64) at /home/chris/.julia/packages/Zygote/zd432/src/compiler/interface2.jl:0
 [28] macro expansion at /home/chris/.julia/packages/StaticArrays/Ze5H3/src/broadcast.jl:133 [inlined]
 [29] _broadcast at /home/chris/.julia/packages/StaticArrays/Ze5H3/src/broadcast.jl:94 [inlined]
 [30] (::Zygote.J{Tuple{typeof(StaticArrays._broadcast),getfield(LinearAlgebra, Symbol("##5#6")){typeof(*)},Size{(4,)},Tuple{Size{()},Size{(4,)}},Float64,SArray{Tuple{4},Float64,1,4}},Any})(::SArray{Tuple{4},Float64,1,4}) at /home/chris/.julia/packages/Zygote/zd432/src/compiler/interface2.jl:0
 [31] (::getfield(Zygote, Symbol("##126#127")){Tuple{Tuple{Nothing,Nothing,Tuple{Nothing,Nothing}},Tuple{Nothing,Nothing}},Zygote.J{Tuple{typeof(StaticArrays._broadcast),getfield(LinearAlgebra, Symbol("##5#6")){typeof(*)},Size{(4,)},Tuple{Size{()},Size{(4,)}},Float64,SArray{Tuple{4},Float64,1,4}},Any}})(::SArray{Tuple{4},Float64,1,4}) at /home/chris/.julia/packages/Zygote/zd432/src/lib/lib.jl:156
 [32] (::getfield(Zygote, Symbol("##166#back2#128")){getfield(Zygote, Symbol("##126#127")){Tuple{Tuple{Nothing,Nothing,Tuple{Nothing,Nothing}},Tuple{Nothing,Nothing}},Zygote.J{Tuple{typeof(StaticArrays._broadcast),getfield(LinearAlgebra, Symbol("##5#6")){typeof(*)},Size{(4,)},Tuple{Size{()},Size{(4,)}},Float64,SArray{Tuple{4},Float64,1,4}},Any}}})(::SArray{Tuple{4},Float64,1,4}) at /home/chris/.julia/packages/Zygote/zd432/src/lib/lib.jl:36
 [33] copy at /home/chris/.julia/packages/StaticArrays/Ze5H3/src/broadcast.jl:24 [inlined]
 [34] (::Zygote.J{Tuple{typeof(copy),Base.Broadcast.Broadcasted{StaticArrays.StaticArrayStyle{1},Tuple{Base.OneTo{Int64}},getfield(LinearAlgebra, Symbol("##5#6")){typeof(*)},Tuple{Float64,SArray{Tuple{4},Float64,1,4}}}},Any})(::SArray{Tuple{4},Float64,1,4}) at /home/chris/.julia/packages/Zygote/zd432/src/compiler/interface2.jl:0
 [35] materialize at ./broadcast.jl:724 [inlined]
 [36] broadcast at ./broadcast.jl:702 [inlined]
 [37] (::Zygote.J{Tuple{typeof(broadcast),getfield(LinearAlgebra, Symbol("##5#6")){typeof(*)},Float64,SArray{Tuple{4},Float64,1,4}},Tuple{typeof(broadcast),getfield(LinearAlgebra, Symbol("##5#6")){typeof(*)},Tuple{Float64,SArray{Tuple{4},Float64,1,4}},Zygote.J{Tuple{typeof(Base.Broadcast.materialize),Base.Broadcast.Broadcasted{StaticArrays.StaticArrayStyle{1},Nothing,getfield(LinearAlgebra, Symbol("##5#6")){typeof(*)},Tuple{Float64,SArray{Tuple{4},Float64,1,4}}}},Tuple{typeof(Base.Broadcast.materialize),Base.Broadcast.Broadcasted{StaticArrays.StaticArrayStyle{1},Nothing,getfield(LinearAlgebra, Symbol("##5#6")){typeof(*)},Tuple{Float64,SArray{Tuple{4},Float64,1,4}}},Zygote.J{Tuple{typeof(copy),Base.Broadcast.Broadcasted{StaticArrays.StaticArrayStyle{1},Tuple{Base.OneTo{Int64}},getfield(LinearAlgebra, Symbol("##5#6")){typeof(*)},Tuple{Float64,SArray{Tuple{4},Float64,1,4}}}},Any},Zygote.J{Tuple{typeof(Base.Broadcast.instantiate),Base.Broadcast.Broadcasted{StaticArrays.StaticArrayStyle{1},Nothing,getfield(LinearAlgebra, Symbol("##5#6")){typeof(*)},Tuple{Float64,SArray{Tuple{4},Float64,1,4}}}},Tuple{typeof(Base.Broadcast.instantiate)}}}},getfield(Zygote, Symbol("##166#back2#128")){getfield(Zygote, Symbol("##126#127")){Tuple{Tuple{Nothing},Tuple{Nothing,Nothing}},getfield(Zygote, Symbol("##1051#back2#594")){getfield(Zygote, Symbol("##592#593"))}}},getfield(Zygote, Symbol("##148#back2#115")){typeof(identity)}}})(::SArray{Tuple{4},Float64,1,4}) at /home/chris/.julia/packages/Zygote/zd432/src/compiler/interface2.jl:0
 [38] (::getfield(Zygote, Symbol("##126#127")){Tuple{Tuple{Nothing},Tuple{Nothing,Nothing}},Zygote.J{Tuple{typeof(broadcast),getfield(LinearAlgebra, Symbol("##5#6")){typeof(*)},Float64,SArray{Tuple{4},Float64,1,4}},Tuple{typeof(broadcast),getfield(LinearAlgebra, Symbol("##5#6")){typeof(*)},Tuple{Float64,SArray{Tuple{4},Float64,1,4}},Zygote.J{Tuple{typeof(Base.Broadcast.materialize),Base.Broadcast.Broadcasted{StaticArrays.StaticArrayStyle{1},Nothing,getfield(LinearAlgebra, Symbol("##5#6")){typeof(*)},Tuple{Float64,SArray{Tuple{4},Float64,1,4}}}},Tuple{typeof(Base.Broadcast.materialize),Base.Broadcast.Broadcasted{StaticArrays.StaticArrayStyle{1},Nothing,getfield(LinearAlgebra, Symbol("##5#6")){typeof(*)},Tuple{Float64,SArray{Tuple{4},Float64,1,4}}},Zygote.J{Tuple{typeof(copy),Base.Broadcast.Broadcasted{StaticArrays.StaticArrayStyle{1},Tuple{Base.OneTo{Int64}},getfield(LinearAlgebra, Symbol("##5#6")){typeof(*)},Tuple{Float64,SArray{Tuple{4},Float64,1,4}}}},Any},Zygote.J{Tuple{typeof(Base.Broadcast.instantiate),Base.Broadcast.Broadcasted{StaticArrays.StaticArrayStyle{1},Nothing,getfield(LinearAlgebra, Symbol("##5#6")){typeof(*)},Tuple{Float64,SArray{Tuple{4},Float64,1,4}}}},Tuple{typeof(Base.Broadcast.instantiate)}}}},getfield(Zygote, Symbol("##166#back2#128")){getfield(Zygote, Symbol("##126#127")){Tuple{Tuple{Nothing},Tuple{Nothing,Nothing}},getfield(Zygote, Symbol("##1051#back2#594")){getfield(Zygote, Symbol("##592#593"))}}},getfield(Zygote, Symbol("##148#back2#115")){typeof(identity)}}}})(::SArray{Tuple{4},Float64,1,4}) at /home/chris/.julia/packages/Zygote/zd432/src/lib/lib.jl:156
 [39] #166#back2 at /home/chris/.julia/packages/Zygote/zd432/src/lib/lib.jl:36 [inlined]
 [40] broadcast at /home/chris/Documents/prog/zygote/usr/share/julia/stdlib/v1.0/LinearAlgebra/src/adjtrans.jl:185 [inlined]
 [41] (::Zygote.J{Tuple{typeof(broadcast),typeof(*),Float64,LinearAlgebra.Adjoint{Float64,SArray{Tuple{4},Float64,1,4}}},Tuple{typeof(broadcast),typeof(*),Tuple{Float64,LinearAlgebra.Adjoint{Float64,SArray{Tuple{4},Float64,1,4}}},getfield(Zygote, Symbol("##1027#back2#575")){getfield(Zygote, Symbol("##573#574"))},getfield(Zygote, Symbol("##166#back2#128")){getfield(Zygote, Symbol("##126#127")){Tuple{Tuple{Nothing},Tuple{Nothing,Nothing}},Zygote.J{Tuple{typeof(broadcast),getfield(LinearAlgebra, Symbol("##5#6")){typeof(*)},Float64,SArray{Tuple{4},Float64,1,4}},Tuple{typeof(broadcast),getfield(LinearAlgebra, Symbol("##5#6")){typeof(*)},Tuple{Float64,SArray{Tuple{4},Float64,1,4}},Zygote.J{Tuple{typeof(Base.Broadcast.materialize),Base.Broadcast.Broadcasted{StaticArrays.StaticArrayStyle{1},Nothing,getfield(LinearAlgebra, Symbol("##5#6")){typeof(*)},Tuple{Float64,SArray{Tuple{4},Float64,1,4}}}},Tuple{typeof(Base.Broadcast.materialize),Base.Broadcast.Broadcasted{StaticArrays.StaticArrayStyle{1},Nothing,getfield(LinearAlgebra, Symbol("##5#6")){typeof(*)},Tuple{Float64,SArray{Tuple{4},Float64,1,4}}},Zygote.J{Tuple{typeof(copy),Base.Broadcast.Broadcasted{StaticArrays.StaticArrayStyle{1},Tuple{Base.OneTo{Int64}},getfield(LinearAlgebra, Symbol("##5#6")){typeof(*)},Tuple{Float64,SArray{Tuple{4},Float64,1,4}}}},Any},Zygote.J{Tuple{typeof(Base.Broadcast.instantiate),Base.Broadcast.Broadcasted{StaticArrays.StaticArrayStyle{1},Nothing,getfield(LinearAlgebra, Symbol("##5#6")){typeof(*)},Tuple{Float64,SArray{Tuple{4},Float64,1,4}}}},Tuple{typeof(Base.Broadcast.instantiate)}}}},getfield(Zygote, Symbol("##166#back2#128")){getfield(Zygote, Symbol("##126#127")){Tuple{Tuple{Nothing},Tuple{Nothing,Nothing}},getfield(Zygote, Symbol("##1051#back2#594")){getfield(Zygote, Symbol("##592#593"))}}},getfield(Zygote, 
Symbol("##148#back2#115")){typeof(identity)}}}}},Zygote.J{Tuple{typeof(Base.Broadcast.materialize),Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple},Nothing,typeof(LinearAlgebra.quasiparenta),Tuple{Tuple{Float64,LinearAlgebra.Adjoint{Float64,SArray{Tuple{4},Float64,1,4}}}}}},Tuple{typeof(Base.Broadcast.materialize),Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple},Nothing,typeof(LinearAlgebra.quasiparenta),Tuple{Tuple{Float64,LinearAlgebra.Adjoint{Float64,SArray{Tuple{4},Float64,1,4}}}}},Zygote.J{Tuple{typeof(copy),Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple},Nothing,typeof(LinearAlgebra.quasiparenta),Tuple{Tuple{Float64,LinearAlgebra.Adjoint{Float64,SArray{Tuple{4},Float64,1,4}}}}}},Any},Zygote.J{Tuple{typeof(Base.Broadcast.instantiate),Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple},Nothing,typeof(LinearAlgebra.quasiparenta),Tuple{Tuple{Float64,LinearAlgebra.Adjoint{Float64,SArray{Tuple{4},Float64,1,4}}}}}},Tuple{typeof(Base.Broadcast.instantiate),Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple},Nothing,typeof(LinearAlgebra.quasiparenta),Tuple{Tuple{Float64,LinearAlgebra.Adjoint{Float64,SArray{Tuple{4},Float64,1,4}}}}}}}}},getfield(Zygote, Symbol("##1051#back2#594")){getfield(Zygote, Symbol("##592#593"))},getfield(Zygote, Symbol("##148#back2#115")){typeof(identity)},getfield(Zygote, Symbol("##188#back2#145")){Zygote.Jnew{getfield(LinearAlgebra, Symbol("##5#6")){typeof(*)},Nothing}}}})(::LinearAlgebra.Transpose{Float64,SArray{Tuple{4},Float64,1,4}}) at /home/chris/.julia/packages/Zygote/zd432/src/compiler/interface2.jl:0
 [42] * at ./arraymath.jl:52 [inlined]
 [43] * at ./operators.jl:502 [inlined]
 [44] (::Zygote.J{Tuple{typeof(*),Float64,LinearAlgebra.Adjoint{Float64,SArray{Tuple{4},Float64,1,4}},SArray{Tuple{4,4},Float64,2,16},SArray{Tuple{4},Float64,1,4}},Tuple{typeof(*),Float64,LinearAlgebra.Adjoint{Float64,SArray{Tuple{4},Float64,1,4}},SArray{Tuple{4,4},Float64,2,16},Tuple{SArray{Tuple{4},Float64,1,4}},getfield(Zygote, Symbol("##166#back2#128")){getfield(Zygote, Symbol("##126#127")){Tuple{Tuple{Nothing,Nothing},Tuple{Nothing}},Zygote.J{Tuple{typeof(Base.afoldl),typeof(*),LinearAlgebra.Adjoint{Float64,SArray{Tuple{4},Float64,1,4}},SArray{Tuple{4},Float64,1,4}},Tuple{typeof(Base.afoldl),typeof(*),LinearAlgebra.Adjoint{Float64,SArray{Tuple{4},Float64,1,4}},SArray{Tuple{4},Float64,1,4},getfield(Zygote, Symbol("##1015#back2#569")){getfield(Zygote, Symbol("##567#568")){LinearAlgebra.Adjoint{Float64,SArray{Tuple{4},Float64,1,4}},SArray{Tuple{4},Float64,1,4}}}}}}},getfield(Zygote, Symbol("##148#back2#115")){typeof(identity)},getfield(Zygote, Symbol("##1015#back2#569")){getfield(Zygote, Symbol("##567#568")){LinearAlgebra.Adjoint{Float64,SArray{Tuple{4},Float64,1,4}},SArray{Tuple{4,4},Float64,2,16}}},Zygote.J{Tuple{typeof(*),Float64,LinearAlgebra.Adjoint{Float64,SArray{Tuple{4},Float64,1,4}}},Tuple{typeof(*),Float64,LinearAlgebra.Adjoint{Float64,SArray{Tuple{4},Float64,1,4}},Zygote.J{Tuple{typeof(broadcast),typeof(*),Float64,LinearAlgebra.Adjoint{Float64,SArray{Tuple{4},Float64,1,4}}},Tuple{typeof(broadcast),typeof(*),Tuple{Float64,LinearAlgebra.Adjoint{Float64,SArray{Tuple{4},Float64,1,4}}},getfield(Zygote, Symbol("##1027#back2#575")){getfield(Zygote, Symbol("##573#574"))},getfield(Zygote, Symbol("##166#back2#128")){getfield(Zygote, Symbol("##126#127")){Tuple{Tuple{Nothing},Tuple{Nothing,Nothing}},Zygote.J{Tuple{typeof(broadcast),getfield(LinearAlgebra, Symbol("##5#6")){typeof(*)},Float64,SArray{Tuple{4},Float64,1,4}},Tuple{typeof(broadcast),getfield(LinearAlgebra, 
Symbol("##5#6")){typeof(*)},Tuple{Float64,SArray{Tuple{4},Float64,1,4}},Zygote.J{Tuple{typeof(Base.Broadcast.materialize),Base.Broadcast.Broadcasted{StaticArrays.StaticArrayStyle{1},Nothing,getfield(LinearAlgebra, Symbol("##5#6")){typeof(*)},Tuple{Float64,SArray{Tuple{4},Float64,1,4}}}},Tuple{typeof(Base.Broadcast.materialize),Base.Broadcast.Broadcasted{StaticArrays.StaticArrayStyle{1},Nothing,getfield(LinearAlgebra, Symbol("##5#6")){typeof(*)},Tuple{Float64,SArray{Tuple{4},Float64,1,4}}},Zygote.J{Tuple{typeof(copy),Base.Broadcast.Broadcasted{StaticArrays.StaticArrayStyle{1},Tuple{Base.OneTo{Int64}},getfield(LinearAlgebra, Symbol("##5#6")){typeof(*)},Tuple{Float64,SArray{Tuple{4},Float64,1,4}}}},Any},Zygote.J{Tuple{typeof(Base.Broadcast.instantiate),Base.Broadcast.Broadcasted{StaticArrays.StaticArrayStyle{1},Nothing,getfield(LinearAlgebra, Symbol("##5#6")){typeof(*)},Tuple{Float64,SArray{Tuple{4},Float64,1,4}}}},Tuple{typeof(Base.Broadcast.instantiate)}}}},getfield(Zygote, Symbol("##166#back2#128")){getfield(Zygote, Symbol("##126#127")){Tuple{Tuple{Nothing},Tuple{Nothing,Nothing}},getfield(Zygote, Symbol("##1051#back2#594")){getfield(Zygote, Symbol("##592#593"))}}},getfield(Zygote, 
Symbol("##148#back2#115")){typeof(identity)}}}}},Zygote.J{Tuple{typeof(Base.Broadcast.materialize),Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple},Nothing,typeof(LinearAlgebra.quasiparenta),Tuple{Tuple{Float64,LinearAlgebra.Adjoint{Float64,SArray{Tuple{4},Float64,1,4}}}}}},Tuple{typeof(Base.Broadcast.materialize),Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple},Nothing,typeof(LinearAlgebra.quasiparenta),Tuple{Tuple{Float64,LinearAlgebra.Adjoint{Float64,SArray{Tuple{4},Float64,1,4}}}}},Zygote.J{Tuple{typeof(copy),Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple},Nothing,typeof(LinearAlgebra.quasiparenta),Tuple{Tuple{Float64,LinearAlgebra.Adjoint{Float64,SArray{Tuple{4},Float64,1,4}}}}}},Any},Zygote.J{Tuple{typeof(Base.Broadcast.instantiate),Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple},Nothing,typeof(LinearAlgebra.quasiparenta),Tuple{Tuple{Float64,LinearAlgebra.Adjoint{Float64,SArray{Tuple{4},Float64,1,4}}}}}},Tuple{typeof(Base.Broadcast.instantiate),Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple},Nothing,typeof(LinearAlgebra.quasiparenta),Tuple{Tuple{Float64,LinearAlgebra.Adjoint{Float64,SArray{Tuple{4},Float64,1,4}}}}}}}}},getfield(Zygote, Symbol("##1051#back2#594")){getfield(Zygote, Symbol("##592#593"))},getfield(Zygote, Symbol("##148#back2#115")){typeof(identity)},getfield(Zygote, Symbol("##188#back2#145")){Zygote.Jnew{getfield(LinearAlgebra, Symbol("##5#6")){typeof(*)},Nothing}}}}}}}})(::Int64) at /home/chris/.julia/packages/Zygote/zd432/src/compiler/interface2.jl:0
 [45] g at ./REPL[4]:1 [inlined]
 [46] (::Zygote.J{Tuple{typeof(g),SArray{Tuple{4},Float64,1,4}},Any})(::Int64) at /home/chris/.julia/packages/Zygote/zd432/src/compiler/interface2.jl:0
 [47] (::getfield(Zygote, Symbol("##73#74")){Zygote.J{Tuple{typeof(g),SArray{Tuple{4},Float64,1,4}},Any}})(::Int64) at /home/chris/.julia/packages/Zygote/zd432/src/compiler/interface.jl:28
 [48] gradient(::Function, ::SArray{Tuple{4},Float64,1,4}) at /home/chris/.julia/packages/Zygote/zd432/src/compiler/interface.jl:34
 [49] top-level scope at none:0

julia> versioninfo() # latest commit from here https://github.com/JuliaLang/julia/tree/mji/zygote
Julia Version 1.0.0
Commit ece21e1ae5 (2018-08-29 14:15 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: AMD Ryzen Threadripper 1950X 16-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-6.0.0 (ORCJIT, znver1)

I am on the latest commit of this branch of Julia, and on Zygote master.

@MikeInnes
Member

I expect this has the same root cause as #5, though the specific error comes out differently. The same band-aid works: write the body as 0.5(x'*S*x).

julia> derivative(g, @SVector randn(4))
4-element SArray{Tuple{4},Float64,1,4}:
 -8.43300310272326
 28.03595720287608
 -4.489528942268124
 -6.965151312812216
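A note on why the grouping matters: Julia parses a chain of * as one n-ary call, so 0.5 * x' * S * x becomes *(0.5, x', S, x), which Base folds left to right (the Base.afoldl entries in the trace above). The first product Zygote differentiates is therefore a scalar times an Adjoint vector, which the trace shows being dispatched to broadcast in LinearAlgebra's adjtrans.jl and from there into tuple broadcasting. In 0.5(x'*S*x) the scalar only multiplies the final Float64. The snippet below checks only the parse difference, using Base's Meta.parse; that this is what avoids the unsupported throw is an inference from the trace, not something confirmed in this thread.

```julia
# Compare how Julia parses the original body of g and the suggested band-aid.
ex_chain   = Meta.parse("0.5 * x' * S * x")
ex_grouped = Meta.parse("0.5(x'*S*x)")

# A chain of * is a single n-ary call: *(0.5, x', S, x)
@show ex_chain.args

# Numeric juxtaposition gives a binary call: *(0.5, *(x', S, x))
@show ex_grouped.args
@show ex_grouped.args[3].args
```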

As an aside, it's very pleasing that custom array types already work nicely like this; it's not something I had tested up to now. StaticArrays are an interesting case, because for things like this they should be very fast.

@chriselrod
Author

Feel free to close this issue in favor of the old one. This works for me.
Thanks for the great work here - I'm excitedly following the development of this and Capstan.

@chriselrod
Author

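Right now this code is type stable, but it is much slower than ForwardDiff or the analytic gradient (a minimum of about 1.33 μs and 22 allocations for Zygote, versus about 48 ns and no allocations for ForwardDiff):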

julia> using StaticArrays, BenchmarkTools, ForwardDiff

julia> import Zygote: gradient #always re-precompiles?
[ Info: Precompiling Zygote [e88e6eb3-aa80-5325-afca-941959d7151f]
[ Info: Precompiling IRTools [7869d1d1-7146-5819-86e3-90919afe41df]

julia> const S = (@SMatrix randn(6,4)) |> x -> x' * x;

julia> g(x) = 0.5 * (x' * S * x);

julia> x = @SVector randn(4);

julia> g(x)
12.515760061124144

julia> x' * S
1×4 LinearAlgebra.Adjoint{Float64,SArray{Tuple{4},Float64,1,4}}:
 -1.82664  -13.4078  -5.39434  1.61724

julia> gradient(g, x)
([-1.82664, -13.4078, -5.39434, 1.61724],)

julia> ForwardDiff.gradient(g, x)'
1×4 LinearAlgebra.Adjoint{Float64,SArray{Tuple{4},Float64,1,4}}:
 -1.82664  -13.4078  -5.39434  1.61724
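For reference, the analytic gradient here is easy to write down. Because S = A'A is symmetric, the gradient of g(x) = ½xᵀSx is ½(S + Sᵀ)x = Sx, which is why x' * S above matches both AD results (transposed). A minimal sketch that checks this with plain Arrays, so it needs only the LinearAlgebra and Random stdlibs rather than StaticArrays; g_analytic_grad and fd_grad are hypothetical helper names, and the finite-difference step h = 1e-6 is an arbitrary choice.

```julia
using LinearAlgebra, Random

Random.seed!(1)
A = randn(6, 4)
const S = A' * A              # symmetric positive semidefinite, as in the issue
g(x) = 0.5 * (x' * S * x)

# Analytic gradient: ½(S + Sᵀ)x, which reduces to S*x because S is symmetric
g_analytic_grad(x) = S * x

# Central finite differences, one coordinate at a time (hypothetical helper)
function fd_grad(f, x; h = 1e-6)
    grad = similar(x)
    for i in eachindex(x)
        e = zeros(length(x))
        e[i] = h
        grad[i] = (f(x + e) - f(x - e)) / (2h)
    end
    return grad
end

x = randn(4)
@show g_analytic_grad(x) fd_grad(g, x)
```

For a quadratic like g, central differences have no truncation error, so any mismatch comes only from floating-point rounding.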

julia> @benchmark gradient(g, $x)
BenchmarkTools.Trial: 
  memory estimate:  1.64 KiB
  allocs estimate:  22
  --------------
  minimum time:     1.327 μs (0.00% GC)
  median time:      1.374 μs (0.00% GC)
  mean time:        2.254 μs (35.01% GC)
  maximum time:     5.420 ms (99.91% GC)
  --------------
  samples:          10000
  evals/sample:     10

julia> @benchmark ForwardDiff.gradient(g, $x)
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     47.753 ns (0.00% GC)
  median time:      48.698 ns (0.00% GC)
  mean time:        48.696 ns (0.00% GC)
  maximum time:     70.348 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     988

julia> @benchmark $x' * S
BenchmarkTools.Trial: 
  memory estimate:  0 bytes
  allocs estimate:  0
  --------------
  minimum time:     4.029 ns (0.00% GC)
  median time:      4.034 ns (0.00% GC)
  mean time:        4.209 ns (0.00% GC)
  maximum time:     18.588 ns (0.00% GC)
  --------------
  samples:          10000
  evals/sample:     1000
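Type stability can also be asserted in a test rather than read off @code_warntype. A minimal sketch using Test.@inferred, which throws an error when the inferred return type differs from the type actually returned; it uses plain Arrays and checks g itself rather than the Zygote gradient call. The non-const global S_nonconst and the function g_nc are hypothetical additions, included to show why S is declared const: a non-const global's type cannot be inferred, so g_nc infers as Any.

```julia
using LinearAlgebra, Test

A = randn(6, 4)
const S = A' * A         # const global: the compiler can rely on its type
S_nonconst = A' * A      # hypothetical non-const global: its type could change at any time

g(x) = 0.5 * (x' * S * x)
g_nc(x) = 0.5 * (x' * S_nonconst * x)   # hypothetical unstable variant

x = randn(4)

# Passes: inferred and actual return types are both Float64.
y = @inferred g(x)

# @inferred errors here: g_nc is inferred as Any but returns a Float64.
@test_throws ErrorException @inferred g_nc(x)
```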

julia> @code_warntype gradient(g, x)
Body::Tuple{SArray{Tuple{4},Float64,1,4}}
32 1 ── %1   = (getfield)(args, 1)::SArray{Tuple{4},Float64,1,4}                  │                 
   │    %2   = Zygote.nothing::Nothing                                            │╻╷                forward
   │    %3   = %new(Zygote.Context, %2)::Zygote.Context                           ││┃││               _forward
   │    %4   = %new(LinearAlgebra.Adjoint{Float64,SArray{Tuple{4},Float64,1,4}}, %1)::LinearAlgebra.Adjoint{Float64,SArray{Tuple{4},Float64,1,4}}
   │    %5   = Main.S::SArray{Tuple{4,4},Float64,2,16}                            ││││
   │    %6   = invoke Zygote._forward(%3::Zygote.Context, Main.:*::typeof(*), %4::LinearAlgebra.Adjoint{Float64,SArray{Tuple{4},Float64,1,4}}, $(QuoteNode([4.22863 0.218509 -1.17347 1.24761; 0.218509 8.47895 2.104 0.21028; -1.17347 2.104 3.31319 -1.31538; 1.24761 0.21028 -1.31538 3.23605]))::SArray{Tuple{4,4},Float64,2,16}, %1::SArray{Tuple{4},Float64,1,4})::Tuple{Float64,Zygote.J{Tuple{typeof(*),LinearAlgebra.Adjoint{Float64,SArray{Tuple{4},Float64,1,4}},SArray{Tuple{4,4},Float64,2,16},SArray{Tuple{4},Float64,1,4}},Tuple{typeof(*),LinearAlgebra.Adjoint{Float64,SArray{Tuple{4},Float64,1,4}},SArray{Tuple{4,4},Float64,2,16},SArray{Tuple{4},Float64,1,4},Tuple{},getfield(Zygote, Symbol("##165#back2#123")){getfield(Zygote, Symbol("##121#122")){Tuple{Tuple{Nothing,Nothing},Tuple{}},Zygote.J{Tuple{typeof(Base.afoldl),typeof(*),Float64},Tuple{typeof(Base.afoldl),typeof(*),Float64}}}},getfield(Zygote, Symbol("##147#back2#110")){typeof(identity)},getfield(Zygote, Symbol("##1014#back2#564")){getfield(Zygote, Symbol("##562#563")){LinearAlgebra.Adjoint{Float64,SArray{Tuple{4},Float64,1,4}},SArray{Tuple{4},Float64,1,4}}},getfield(Zygote, Symbol("##1014#back2#564")){getfield(Zygote, Symbol("##562#563")){LinearAlgebra.Adjoint{Float64,SArray{Tuple{4},Float64,1,4}},SArray{Tuple{4,4},Float64,2,16}}}}}}
   │    %7   = (Base.getfield)(%6, 1, true)::Float64                              ││││╻                 getindex
   │    %8   = (Base.getfield)(%6, 2, true)::Zygote.J{Tuple{typeof(*),LinearAlgebra.Adjoint{Float64,SArray{Tuple{4},Float64,1,4}},SArray{Tuple{4,4},Float64,2,16},SArray{Tuple{4},Float64,1,4}},Tuple{typeof(*),LinearAlgebra.Adjoint{Float64,SArray{Tuple{4},Float64,1,4}},SArray{Tuple{4,4},Float64,2,16},SArray{Tuple{4},Float64,1,4},Tuple{},getfield(Zygote, Symbol("##165#back2#123")){getfield(Zygote, Symbol("##121#122")){Tuple{Tuple{Nothing,Nothing},Tuple{}},Zygote.J{Tuple{typeof(Base.afoldl),typeof(*),Float64},Tuple{typeof(Base.afoldl),typeof(*),Float64}}}},getfield(Zygote, Symbol("##147#back2#110")){typeof(identity)},getfield(Zygote, Symbol("##1014#back2#564")){getfield(Zygote, Symbol("##562#563")){LinearAlgebra.Adjoint{Float64,SArray{Tuple{4},Float64,1,4}},SArray{Tuple{4},Float64,1,4}}},getfield(Zygote, Symbol("##1014#back2#564")){getfield(Zygote, Symbol("##562#563")){LinearAlgebra.Adjoint{Float64,SArray{Tuple{4},Float64,1,4}},SArray{Tuple{4,4},Float64,2,16}}}}}
   └───        (Base.mul_float)(0.5, %7)                                          │││││╻╷                macro expansion
34 2 ┄─        (Base.mul_float)(1.0, %7)                                          │╻╷╷╷╷╷            #66
   │    %11  = (Base.mul_float)(1.0, 0.5)::Float64                                ││╻                 g
   │    %12  = (Base.getfield)(%8, :t)::Tuple{typeof(*),LinearAlgebra.Adjoint{Float64,SArray{Tuple{4},Float64,1,4}},SArray{Tuple{4,4},Float64,2,16},SArray{Tuple{4},Float64,1,4},Tuple{},getfield(Zygote, Symbol("##165#back2#123")){getfield(Zygote, Symbol("##121#122")){Tuple{Tuple{Nothing,Nothing},Tuple{}},Zygote.J{Tuple{typeof(Base.afoldl),typeof(*),Float64},Tuple{typeof(Base.afoldl),typeof(*),Float64}}}},getfield(Zygote, Symbol("##147#back2#110")){typeof(identity)},getfield(Zygote, Symbol("##1014#back2#564")){getfield(Zygote, Symbol("##562#563")){LinearAlgebra.Adjoint{Float64,SArray{Tuple{4},Float64,1,4}},SArray{Tuple{4},Float64,1,4}}},getfield(Zygote, Symbol("##1014#back2#564")){getfield(Zygote, Symbol("##562#563")){LinearAlgebra.Adjoint{Float64,SArray{Tuple{4},Float64,1,4}},SArray{Tuple{4,4},Float64,2,16}}}}
   │    %13  = (Base.getfield)(%12, 8, true)::getfield(Zygote, Symbol("##1014#back2#564")){getfield(Zygote, Symbol("##562#563")){LinearAlgebra.Adjoint{Float64,SArray{Tuple{4},Float64,1,4}},SArray{Tuple{4},Float64,1,4}}}
   │    %14  = (Base.getfield)(%12, 9, true)::getfield(Zygote, Symbol("##1014#back2#564")){getfield(Zygote, Symbol("##562#563")){LinearAlgebra.Adjoint{Float64,SArray{Tuple{4},Float64,1,4}},SArray{Tuple{4,4},Float64,2,16}}}
   │    %15  = π (%11, Float64)                                                   │││││╻                 #121
   │    %16  = (Core.getfield)(%13, Symbol("#1013#back"))::getfield(Zygote, Symbol("##562#563")){LinearAlgebra.Adjoint{Float64,SArray{Tuple{4},Float64,1,4}},SArray{Tuple{4},Float64,1,4}}
   │    %17  = (Core.getfield)(%16, :b)::SArray{Tuple{4},Float64,1,4}             │││││╻                 #562
   │           (Base.ifelse)(false, 0, 4)                                         ││││││╻╷╷╷╷╷╷╷╷╷╷╷      *
   │    %19  = (Base.getfield)(%17, :data)::NTuple{4,Float64}                     │││││││╻╷╷╷              broadcast
   │    %20  = (Base.getfield)(%19, 1, false)::Float64                            ││││││││╻                 broadcast
   │    %21  = (Base.mul_float)(%15, %20)::Float64                                │││││││││╻                 materialize
   │    %22  = (Base.getfield)(%17, :data)::NTuple{4,Float64}                     ││││││││││╻                 copy
   │    %23  = (Base.getfield)(%22, 2, false)::Float64                            │││││││││││╻                 _broadcast
   │    %24  = (Base.mul_float)(%15, %23)::Float64                                ││││││││││││╻                 macro expansion
   │    %25  = (Base.getfield)(%17, :data)::NTuple{4,Float64}                     │││││││││││││╻                 getindex
   │    %26  = (Base.getfield)(%25, 3, false)::Float64                            ││││││││││││││╻                 getindex
   │    %27  = (Base.mul_float)(%15, %26)::Float64                                ││││││││││││││╻                 *
   │    %28  = (Base.getfield)(%17, :data)::NTuple{4,Float64}                     ││││││││││││││╻                 getproperty
   │    %29  = (Base.getfield)(%28, 4, false)::Float64                            ││││││││││││││╻                 getindex
   │    %30  = (Base.mul_float)(%15, %29)::Float64                                ││││││││││││││╻                 *
   │    %31  = (StaticArrays.tuple)(%21, %24, %27, %30)::NTuple{4,Float64}        │││││││││││││
   │    %32  = %new(SArray{Tuple{4},Float64,1,4}, %31)::SArray{Tuple{4},Float64,1,4}│││││││││││╻                 Type
   └───        goto #4                                                            │││││││││││││     
   3 ──        $(Expr(:unreachable))                                              │││││││││││││     
   4 ┄─        goto #5                                                            │││││││││││       
   5 ──        goto #6                                                            ││││││││││        
   6 ──        goto #7                                                            │││││││││         
   7 ── %38  = %new(LinearAlgebra.Transpose{Float64,SArray{Tuple{4},Float64,1,4}}, %32)::LinearAlgebra.Transpose{Float64,SArray{Tuple{4},Float64,1,4}}
   └───        goto #8                                                            ││││││││          
   8 ──        goto #9                                                            │││││││           
   9 ── %41  = (Core.getfield)(%16, :a)::LinearAlgebra.Adjoint{Float64,SArray{Tuple{4},Float64,1,4}}
   │    %42  = (Base.getfield)(%41, :parent)::SArray{Tuple{4},Float64,1,4}        │││││││╻                 getproperty
   │           (Base.ifelse)(false, 0, 4)                                         │││││││╻╷╷╷╷╷╷╷╷╷        broadcast
   │    %44  = (Base.getfield)(%42, :data)::NTuple{4,Float64}                     ││││││││╻╷╷╷              materialize
   │    %45  = (Base.getfield)(%44, 1, false)::Float64                            │││││││││╻                 copy
   │    %46  = (Base.mul_float)(%45, %15)::Float64                                ││││││││││╻                 _broadcast
   │    %47  = (Base.getfield)(%42, :data)::NTuple{4,Float64}                     │││││││││││╻                 macro expansion
   │    %48  = (Base.getfield)(%47, 2, false)::Float64                            ││││││││││││╻                 getindex
   │    %49  = (Base.mul_float)(%48, %15)::Float64                                ││││││││││││╻                 *
   │    %50  = (Base.getfield)(%42, :data)::NTuple{4,Float64}                     │││││││││││││╻                 getproperty
   │    %51  = (Base.getfield)(%50, 3, false)::Float64                            │││││││││││││╻                 getindex
   │    %52  = (Base.mul_float)(%51, %15)::Float64                                ││││││││││││╻                 *
   │    %53  = (Base.getfield)(%42, :data)::NTuple{4,Float64}                     │││││││││││││╻                 getproperty
   │    %54  = (Base.getfield)(%53, 4, false)::Float64                            │││││││││││││╻                 getindex
   │    %55  = (Base.mul_float)(%54, %15)::Float64                                ││││││││││││╻                 *
   └───        goto #11                                                           ││││││││││││      
   10 ─        $(Expr(:unreachable))                                              ││││││││││││      
   11 ┄        goto #12                                                           ││││││││││        
   12 ─        goto #13                                                           │││││││││         
   13 ─        goto #14                                                           ││││││││          
   14 ─        goto #15                                                           │││││││           
   15 ─        goto #16                                                           ││││││            
   16 ─        goto #17                                                           │││││             
   17 ─ %64  = (Core.getfield)(%14, Symbol("#1013#back"))::getfield(Zygote, Symbol("##562#563")){LinearAlgebra.Adjoint{Float64,SArray{Tuple{4},Float64,1,4}},SArray{Tuple{4,4},Float64,2,16}}
   │    %65  = invoke %64(%38::LinearAlgebra.Transpose{Float64,SArray{Tuple{4},Float64,1,4}})::Tuple{LinearAlgebra.Transpose{Float64,SArray{Tuple{4},Float64,1,4}},SArray{Tuple{4,4},Float64,2,16}}
   │    %66  = (getfield)(%65, 1)::LinearAlgebra.Transpose{Float64,SArray{Tuple{4},Float64,1,4}}         _gradtuple
   │    %67  = (getfield)(%65, 2)::SArray{Tuple{4,4},Float64,2,16}                ││││││            
   └───        goto #18                                                           ││││              
   18 ─        invoke Zygote.accum_param(%3::Zygote.Context, %5::SArray{Tuple{4,4},Float64,2,16}, %67::SArray{Tuple{4,4},Float64,2,16})
   │    %70  = (Base.getfield)(%66, :parent)::SArray{Tuple{4},Float64,1,4}        ││││╻╷╷               #568
   │           (Base.ifelse)(false, 0, 4)                                         ││││╻╷╷╷╷╷╷╷╷         materialize
   │           (Base.ifelse)(false, 0, 4)                                         │││││╻╷╷╷╷╷            instantiate
   │    %73  = (4 === 4)::Bool                                                    ││││││╻╷╷╷╷             combine_axes
   │    %74  = (Base.and_int)(true, %73)::Bool                                    │││││││╻                 broadcast_shape
   └───        goto #20 if not %74                                                ││││││││┃││               _bcs
   19 ─        goto #21                                                           │││││││││┃│                _bcs1
   20 ─        goto #21                                                           ││││││││││┃                 _bcsm
   21 ┄ %78  = φ (#19 => %74, #20 => false)::Bool                                 ││││││││││        
   └───        goto #23 if not %78                                                ││││││││││        
   22 ─        goto #29                                                           ││││││││││        
   23 ─ %81  = (4 === 4)::Bool                                                    │││││││││││╻╷                ==
   │    %82  = (Base.and_int)(true, %81)::Bool                                    ││││││││││││╻                 &
   └───        goto #25 if not %82                                                │││││││││││       
   24 ─        goto #26                                                           │││││││││││       
   25 ─        goto #26                                                           │││││││││││       
   26 ┄ %86  = φ (#24 => %82, #25 => false)::Bool                                 ││││││││││        
   └───        goto #28 if not %86                                                ││││││││││        
   27 ─        goto #29                                                           ││││││││││        
   28 ─ %89  = %new(Base.DimensionMismatch, "arrays could not be broadcast to a common size")::DimensionMismatch
   │           (Base.Broadcast.throw)(%89)                                        ││││││││││        
   └───        $(Expr(:unreachable))                                              ││││││││││        
   29 ┄        goto #30                                                           │││││││││         
   30 ─        goto #31                                                           ││││││││          
   31 ─        goto #32                                                           │││││││           
   32 ─        goto #33                                                           │││││╻                 instantiate
   33 ─ %96  = (Base.getfield)(%70, :data)::NTuple{4,Float64}                     ││││││╻╷╷╷              _broadcast
   │    %97  = (Base.getfield)(%96, 1, false)::Float64                            │││││││╻                 macro expansion
   │    %98  = (Base.add_float)(%46, %97)::Float64                                ││││││││╻                 accum
   │    %99  = (Base.getfield)(%70, :data)::NTuple{4,Float64}                     │││││││││╻                 getproperty
   │    %100 = (Base.getfield)(%99, 2, false)::Float64                            │││││││││╻                 getindex
   │    %101 = (Base.add_float)(%49, %100)::Float64                               │││││││││╻                 +
   │    %102 = (Base.getfield)(%70, :data)::NTuple{4,Float64}                     │││││││││╻                 getproperty
   │    %103 = (Base.getfield)(%102, 3, false)::Float64                           │││││││││╻                 getindex
   │    %104 = (Base.add_float)(%52, %103)::Float64                               │││││││││╻                 +
   │    %105 = (Base.getfield)(%70, :data)::NTuple{4,Float64}                     │││││││││╻                 getproperty
   │    %106 = (Base.getfield)(%105, 4, false)::Float64                           │││││││││╻                 getindex
   │    %107 = (Base.add_float)(%55, %106)::Float64                               │││││││││╻                 +
   │    %108 = (StaticArrays.tuple)(%98, %101, %104, %107)::NTuple{4,Float64}     ││││││││          
   │    %109 = %new(SArray{Tuple{4},Float64,1,4}, %108)::SArray{Tuple{4},Float64,1,4}│││││╻                 Type
   └───        goto #35                                                           ││││││││          
   34 ─        $(Expr(:unreachable))                                              ││││││││          
   35 ┄        goto #36                                                           ││││││            
   36 ─        goto #37                                                           │││││             
   37 ─        goto #38                                                           ││││              
   38 ─        goto #39                                                           │││               
   39 ─ %116 = (Core.tuple)(%109)::Tuple{SArray{Tuple{4},Float64,1,4}}            │││╻                 tail
   └───        goto #40                                                           ││                
   40 ─        return %116
   41 ─        goto #2
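For context on the `throw` that appears in this dump: the statements tagged `combine_axes`, `broadcast_shape`, `_bcs`, `_bcs1` and `_bcsm` in the line info (blocks 18 to 28) are the broadcast shape check, and block 28 constructs a `DimensionMismatch` and passes it to `throw`, a path taken only when the axes disagree. Both lengths here are 4, so the `4 === 4` comparisons always succeed, yet they still appear as runtime branches in the IR; this is the kind of `throw` statement the original `merge_returns` error refers to. A minimal sketch of both sides of that check, assuming plain `Vector`s rather than the `SVector`s from the issue:

```julia
# Two length-4 vectors: the shape check passes, as in the dump where both axes are 4.
a = [1.0, 2.0, 3.0, 4.0]
b = a .+ [10.0, 20.0, 30.0, 40.0]

# Mismatched lengths (4 vs 3) take the error branch: a DimensionMismatch is thrown.
msg = try
    a .+ [1.0, 2.0, 3.0]
    ""  # not reached
catch err
    err isa DimensionMismatch || rethrow()
    String(err.msg)  # newer Julia versions append the offending lengths to this message
end
println(msg)
```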

I also do not see any change after calling Zygote.refresh():

julia> using Zygote

julia> Zygote.refresh()

julia> gradient(g, x)
([-1.82664, -13.4078, -5.39434, 1.61724],)

julia> @benchmark gradient(g, $x)
BenchmarkTools.Trial: 
  memory estimate:  1.64 KiB
  allocs estimate:  22
  --------------
  minimum time:     1.346 μs (0.00% GC)
  median time:      1.393 μs (0.00% GC)
  mean time:        2.312 μs (35.93% GC)
  maximum time:     5.704 ms (99.92% GC)
  --------------
  samples:          10000
  evals/sample:     10
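As a sanity check on the values `gradient(g, x)` returns: in g(x) = 0.5 * x' * S * x the argument appears twice, so reverse mode sums one contribution through `x'` (0.5 * S * x) and one through the trailing `x` (0.5 * S' * x); the `accum`/`add_float` lines near the end of the dump appear to perform that sum. Because S is built as A' * A it is symmetric, so the total is S * x, and with Zygote one would expect `gradient(g, x)[1] ≈ S * x`. The sketch below checks the same identity without Zygote or StaticArrays, assuming plain `Array`s, a fixed random seed, and a hypothetical helper `fd_gradient` for central finite differences:

```julia
using LinearAlgebra, Random

Random.seed!(1)

# Plain-Array stand-in for the issue's `S`, built the same way (A' * A), hence symmetric.
A = randn(6, 4)
S = A' * A
g(x) = 0.5 * x' * S * x

x = randn(4)

# x appears twice in g, so the reverse pass sums two contributions.
from_adjoint = 0.5 * S * x    # contribution flowing through x'
from_vector  = 0.5 * S' * x   # contribution flowing through the trailing x
analytic = from_adjoint + from_vector   # equals S * x because S is symmetric

# Hypothetical helper: central finite differences as an independent check.
function fd_gradient(f, x; h = 1e-6)
    grad = similar(x)
    for i in eachindex(x)
        e = zeros(length(x))
        e[i] = h
        grad[i] = (f(x + e) - f(x - e)) / (2h)
    end
    return grad
end

numeric = fd_gradient(g, x)
println(isapprox(analytic, numeric; rtol = 1e-6))
```

For a quadratic like g, central differences have no truncation error, so any mismatch beyond rounding would point at the AD result rather than the check.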
