Crashing on functions involving array operations #24
Comments
No misunderstanding, these are definitely bugs on Zygote's end, at least insofar as they should give more helpful errors. Your first example does, now:
julia> f1'(1)
ERROR: Mutating arrays is not supported
You'll have to write that differently until we add support for mutation. The reason the loop example fails in a worse way is that, although it has the same behaviour, it nonetheless emits more complex code (containing a loop), and Zygote has to deal with that because it runs before any other optimisations are applied.
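As a rough illustration of what "writing that differently" might look like: the original examples are not shown in this thread, so both functions below are hypothetical stand-ins, not the reporter's code. The first fills a buffer element by element, which is the kind of `setindex!` call the error message is about; the second builds the same values in a single expression with no in-place writes.

```julia
# Hypothetical stand-in for the missing example: computes p + 2p + 3p.

# Mutating version: allocates a buffer and writes into it with setindex!.
function f_mutating(p)
    a = zeros(typeof(p), 3)
    for i in 1:3
        a[i] = i * p   # in-place write; this is the mutation
    end
    return sum(a)
end

# Non-mutating version: the array is produced by one broadcast expression,
# so no element is ever overwritten after creation.
f_pure(p) = sum(p .* (1:3))
```

Both versions return the same value for the same input. Whether a given Zygote version can differentiate the second form is something to check against that version; the sketch only shows how the mutation can be avoided.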
Just a heads up, this example now works on
julia> f1'(1)
1.0
Closing for now as I think this is giving the right error (we can track mutation support in #61).
Right now Zygote inserts stacks whenever it needs to use an SSA value not defined in the first basic block. This is of course unnecessary. The condition for needing stacks is that the basic block that defines the value is self-reachable (i.e. in a loop). Otherwise, we can simply insert phi nodes to thread the desired SSA value through to the exit block (we don't need to do anything in the adjoint, since the reversal of the CFG ensures dominance). Removing stacks both allows more efficient code generation and enables higher-order auto-diff (since we use control flow in Zygote, but can't handle differentiating code that contains stacks).

The headline example is something like the following:

```
function foo(b, x)
    if b
        sin(x)
    else
        cos(x)
    end
end
```

Then looking at `@code_typed derivative(x->foo(true, x), 1.0)`, we get:

Before:

```
CodeInfo(
1 ── %1 = $(Expr(:foreigncall, :(:jl_alloc_array_1d), Array{Int8,1}, svec(Any, Int64), :(:ccall), 2, Array{Int8,1}, 0, 0))::Array{Int8,1}
│    %2 = $(Expr(:foreigncall, :(:jl_alloc_array_1d), Array{Any,1}, svec(Any, Int64), :(:ccall), 2, Array{Any,1}, 0, 0))::Array{Any,1}
│    %3 = $(Expr(:foreigncall, :(:jl_alloc_array_1d), Array{Any,1}, svec(Any, Int64), :(:ccall), 2, Array{Any,1}, 0, 0))::Array{Any,1}
│    %4 = Base.sin::typeof(sin)
│    invoke %4(_3::Float64)::Float64
│    %6 = %new(##334#335{Float64}, x)::##334#335{Float64}
│    %7 = %new(##758#back#336{##334#335{Float64}}, %6)::##758#back#336{##334#335{Float64}}
[snip]
23 ─ %52 = invoke %47(1::Int8)::Tuple{Nothing,Nothing,Any}
│    %53 = Base.getfield(%52, 3, true)::Any
└─── goto #24
24 ─ return %53
) => Any
```

After:

```
CodeInfo(
1 ─ %1 = Base.sin::typeof(sin)
│   invoke %1(_3::Float64)::Float64
│   %3 = Core.Intrinsics.not_int(true)::Bool
└── goto #3 if not %3
2 ─ invoke Zygote.notnothing(nothing::Nothing)::Union{}
└── $(Expr(:unreachable))::Union{}
3 ┄ %7 = invoke Zygote.cos(_3::Float64)::Float64
│   %8 = Base.mul_float(1.0, %7)::Float64
└── goto #4
4 ─ goto #5
5 ─ goto #6
6 ─ goto #7
7 ─ return %8
) => Float64
```

Which is essentially perfect. There's a bit of junk left over, but LLVM can take care of that. The only thing that doesn't get removed is the useless invocation of `sin`, but that's a separate and known issue.
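To make the loop condition concrete, here is a hand-written sketch of reverse mode for two small functions. It is not Zygote's generated code, and the names `foo_with_pullback` and `iterated_sin_with_pullback` are invented for illustration. With only a branch, each block runs at most once, so the values the pullback needs can simply be captured. With a loop, the body runs several times, so one pullback per iteration has to be recorded and replayed in reverse, which is the role the stacks play.

```julia
# Branch only: no block is self-reachable, so the pullback for the branch
# that was taken is captured directly and no stack is needed.
function foo_with_pullback(b, x)
    if b
        y = sin(x)
        back = dy -> dy * cos(x)
    else
        y = cos(x)
        back = dy -> -dy * sin(x)
    end
    return y, back
end

# Loop: the body executes n times, so each iteration pushes its own pullback
# onto a stack, and the reverse pass pops them in the opposite order.
function iterated_sin_with_pullback(x, n)
    stack = Function[]
    for _ in 1:n
        let xi = x                       # fresh binding per iteration
            push!(stack, dy -> dy * cos(xi))
        end
        x = sin(x)
    end
    function back(dy)
        for i in length(stack):-1:1
            dy = stack[i](dy)
        end
        return dy
    end
    return x, back
end
```

Calling `foo_with_pullback(true, 1.0)` returns `sin(1.0)` together with a closure whose value at `1.0` is `cos(1.0)`, matching the `cos` call in the "After" listing.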
This is probably due to a known unimplemented feature, but I'm having problems getting gradients of functions involving arrays. Sometimes f'(p) works but derivative(f, p) does not. I want to use the latter because I need gradients with respect to parameter vectors, not scalars as in these minimal examples.
EXAMPLE 1:
The next example with a 'for' loop actually segfaults.
EXAMPLE 2:
Clearly f'(p) and derivative(f, p) work differently, although I would have expected the former to be syntactic sugar for the latter. This is true even for simple working examples:
f(x) = x^2 gives f'(1.5) == 2.9999999999973244 while derivative(f, 1.5) == 3.0.
Also, these problems seem to have to do with array operations, because it works if I write out the array operations by hand as variables:
Am I misunderstanding how derivative() is to be used?
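The thread does not say where the small discrepancy between the two results for f(x) = x^2 comes from. Purely as a general illustration, and not as a diagnosis of this case, the sketch below shows that a numerical central-difference estimate of a derivative typically differs from the exact value in the trailing digits because of floating-point rounding. The helper `central_difference` and its step size are assumptions made for this example and are not part of Zygote.

```julia
f(x) = x^2

# Hypothetical helper: central finite-difference estimate of f'(x).
# The default step is a common heuristic choice, not taken from any library.
central_difference(f, x; h = cbrt(eps(Float64))) = (f(x + h) - f(x - h)) / (2h)

approx = central_difference(f, 1.5)   # numerical estimate, subject to rounding
exact  = 2 * 1.5                      # analytic derivative of x^2 at 1.5
```

Comparing `approx` with `exact` shows agreement to many digits but not necessarily bit-for-bit equality.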