Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
WIP: Don't use stacks for simple control flow
Right now Zygote inserts stacks whenever it needs to use an ssa value not defined in the first basic block. This is of course unnecessary. The condition for needing stacks is that the basic block that defines it is self-reachable (i.e. in a loop). Otherwise, we can simply insert phi nodes to thread the desired SSA value through to the exit block (we don't need to do anything in the adjoint, since the reversal of the CFG ensures dominance). Removing stacks allows for both more efficient code generation and enables higher order auto-diff (since we use control flow in Zygote, but can't handle differentiating code that contains stacks). The headline example is something like the following: ``` function foo(b, x) if b sin(x) else cos(x) end end ``` Then looking at `@code_typed derivative(x->foo(true, x), 1.0)`, we get: Before: ``` CodeInfo( 1 ── %1 = $(Expr(:foreigncall, :(:jl_alloc_array_1d), Array{Int8,1}, svec(Any, Int64), :(:ccall), 2, Array{Int8,1}, 0, 0))::Array{Int8,1} │ %2 = $(Expr(:foreigncall, :(:jl_alloc_array_1d), Array{Any,1}, svec(Any, Int64), :(:ccall), 2, Array{Any,1}, 0, 0))::Array{Any,1} │ %3 = $(Expr(:foreigncall, :(:jl_alloc_array_1d), Array{Any,1}, svec(Any, Int64), :(:ccall), 2, Array{Any,1}, 0, 0))::Array{Any,1} │ %4 = Base.sin::typeof(sin) │ invoke %4(_3::Float64)::Float64 │ %6 = %new(##334#335{Float64}, x)::##334#335{Float64} │ %7 = %new(##758#back#336{##334#335{Float64}}, %6)::##758#back#336{##334#335{Float64}} [snip] 23 ─ %52 = invoke %47(1::Int8)::Tuple{Nothing,Nothing,Any} │ %53 = Base.getfield(%52, 3, true)::Any └─── goto #24 24 ─ return %53 ) => Any ``` After: ``` CodeInfo( 1 ─ %1 = Base.sin::typeof(sin) │ invoke %1(_3::Float64)::Float64 │ %3 = Core.Intrinsics.not_int(true)::Bool └── goto #3 if not %3 2 ─ invoke Zygote.notnothing(nothing::Nothing)::Union{} └── $(Expr(:unreachable))::Union{} 3 ┄ %7 = invoke Zygote.cos(_3::Float64)::Float64 │ %8 = Base.mul_float(1.0, %7)::Float64 └── goto #4 4 ─ goto #5 5 ─ goto #6 6 ─ goto #7 7 ─ return %8 ) => Float64 ``` Which is essentially perfect (there's a bit of junk left over, but LLVM can take care of that. The only thing that doesn't get removed is the useless invocation of `sin`, but that's a separate and known issue).
- Loading branch information