Skip to content

Commit

Permalink
WIP: Don't use stacks for simple control flow
Browse files Browse the repository at this point in the history
Right now Zygote inserts stacks whenever it needs to use an ssa value
not defined in the first basic block. This is of course unnecessary.
The condition for needing stacks is that the basic block that defines
it is self-reachable (i.e. in a loop). Otherwise, we can simply insert
phi nodes to thread the desired SSA value through to the exit block
(we don't need to do anything in the adjoint, since the reversal of
the CFG ensures dominance). Removing stacks allows for both more
efficient code generation and enables higher order auto-diff (since
we use control flow in Zygote, but can't handle differentiating code
that contains stacks). The headline example is something like the following:

```
function foo(b, x)
    if b
        sin(x)
    else
        cos(x)
    end
end
```

Then looking at `@code_typed derivative(x->foo(true, x), 1.0)`, we get:

Before:
```
CodeInfo(
1 ── %1  = $(Expr(:foreigncall, :(:jl_alloc_array_1d), Array{Int8,1}, svec(Any, Int64), :(:ccall), 2, Array{Int8,1}, 0, 0))::Array{Int8,1}
│    %2  = $(Expr(:foreigncall, :(:jl_alloc_array_1d), Array{Any,1}, svec(Any, Int64), :(:ccall), 2, Array{Any,1}, 0, 0))::Array{Any,1}
│    %3  = $(Expr(:foreigncall, :(:jl_alloc_array_1d), Array{Any,1}, svec(Any, Int64), :(:ccall), 2, Array{Any,1}, 0, 0))::Array{Any,1}
│    %4  = Base.sin::typeof(sin)
│          invoke %4(_3::Float64)::Float64
│    %6  = %new(##334#335{Float64}, x)::##334#335{Float64}
│    %7  = %new(##758#back#336{##334#335{Float64}}, %6)::##758#back#336{##334#335{Float64}}
[snip]
23 ─ %52 = invoke %47(1::Int8)::Tuple{Nothing,Nothing,Any}
│    %53 = Base.getfield(%52, 3, true)::Any
└───       goto #24
24 ─       return %53
) => Any
```

After:
```
CodeInfo(
1 ─ %1 = Base.sin::typeof(sin)
│        invoke %1(_3::Float64)::Float64
│   %3 = Core.Intrinsics.not_int(true)::Bool
└──      goto #3 if not %3
2 ─      invoke Zygote.notnothing(nothing::Nothing)::Union{}
└──      $(Expr(:unreachable))::Union{}
3 ┄ %7 = invoke Zygote.cos(_3::Float64)::Float64
│   %8 = Base.mul_float(1.0, %7)::Float64
└──      goto #4
4 ─      goto #5
5 ─      goto #6
6 ─      goto #7
7 ─      return %8
) => Float64
```

This is essentially perfect (there's a bit of junk left over, but LLVM
can take care of that). The only thing that doesn't get removed is the
useless invocation of `sin`, but that's a separate and known issue.
  • Loading branch information
Keno committed Mar 1, 2019
1 parent 9dc74d1 commit aad4a75
Showing 1 changed file with 84 additions and 10 deletions.
94 changes: 84 additions & 10 deletions src/compiler/emit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,87 @@ function stacklines(adj::Adjoint)
return recs
end

"""
    do_dfs(cfg, bb)

Walk the successor edges of `cfg` starting at block `bb`.

Returns `nothing` as soon as an edge leading back to `bb` is encountered (i.e. `bb`
can reach itself). Otherwise returns a vector containing `bb` together with every
block reached from it; the order of that vector is whatever `collect` yields for
the underlying `Set`.
"""
function do_dfs(cfg, bb)
    # TODO: There are algorithms to do this with much better
    # asymptotics, but this'll do for now.
    seen = Set{Int}(bb)
    pending = Int[bb]
    while !isempty(pending)
        current = pop!(pending)
        for next in cfg.blocks[current].succs
            # An edge back to the start block means `bb` is self-reachable.
            if next == bb
                return nothing
            end
            if !(next in seen)
                push!(seen, next)
                push!(pending, next)
            end
        end
    end
    return collect(seen)
end

"""
    find_dominating_for_bb(domtree, bb, alpha_ssa, def_bb, phi_nodes)

Starting at block `bb`, follow `domtree.idoms` upwards until block `0` is reached.

Returns `alpha_ssa` if the walk hits `def_bb`, the second element of
`phi_nodes[block]` if it first hits a block that is a key of `phi_nodes`, and
`nothing` if neither is found. At each block `def_bb` is checked before `phi_nodes`.
"""
function find_dominating_for_bb(domtree, bb, alpha_ssa, def_bb, phi_nodes)
    cursor = bb
    while cursor != 0
        # The defining block takes precedence over a phi block.
        cursor == def_bb && return alpha_ssa
        haskey(phi_nodes, cursor) && return phi_nodes[cursor][2]
        # Neither matched: step up to the immediate dominator.
        cursor = domtree.idoms[cursor]
    end
    return nothing
end

"""
    insert_phi_nest!(ir, domtree, T, def_bb, exit_bb, alpha_ssa, phi_blocks)

Insert one empty `PhiNode`, typed `Union{T, Nothing}`, at the first statement of every
block in `phi_blocks`, then fill in its incoming edges.

For each predecessor of a phi block, the edge is recorded together with whatever
`find_dominating_for_bb` returns for that predecessor (which may be `nothing`).
The return value is the result of `find_dominating_for_bb` for `exit_bb`.
"""
function insert_phi_nest!(ir, domtree, T, def_bb, exit_bb, alpha_ssa, phi_blocks)
    # Builds the (phi node, inserted value) pair for a single block.
    make_phi = function (blk)
        phi = PhiNode()
        head = first(ir.cfg.blocks[blk].stmts)
        return (phi, insert_node!(ir, head, Union{T, Nothing}, phi))
    end
    phi_nodes = Dict(blk => make_phi(blk) for blk in phi_blocks)

    # TODO: This could be more efficient by joint ascension of the domtree
    for blk in phi_blocks
        phi = phi_nodes[blk][1]
        for pred in ir.cfg.blocks[blk].preds
            # `incoming` is `nothing` when no def/phi block is found above `pred`;
            # the edge is still recorded, carrying `nothing` as its value.
            incoming = find_dominating_for_bb(domtree, pred, alpha_ssa, def_bb, phi_nodes)
            push!(phi.edges, pred)
            push!(phi.values, incoming)
        end
    end

    return find_dominating_for_bb(domtree, exit_bb, alpha_ssa, def_bb, phi_nodes)
end

function forward_stacks!(adj, F)
stks, recs = [], []
for fb in adj.perm, α in alphauses(adj.back, invperm(adj.perm)[fb])
if fb == 1
pushfirst!(recs, α)
else
fwd_cfg = adj.forw.cfg
domtree = construct_domtree(fwd_cfg)
exit_bb = length(fwd_cfg.blocks)
for fb in adj.perm
# TODO: do_dfs does double duty here, computing self reachability and giving
# us the set of live in nodes. There are better algorithms for the former
# and the latter shouldn't be necessary.
live_in = do_dfs(fwd_cfg, fb)
in_loop = live_in === nothing
if !in_loop
if dominates(domtree, fb, exit_bb)
phi_blocks = Int[]
else
# Liveness is trivial here, so we could specialize idf
# on that fact, but good enough for now.
phi_blocks = Core.Compiler.idf(fwd_cfg, Core.Compiler.BlockLiveness([fb], live_in), domtree)
end
end
for α in alphauses(adj.back, invperm(adj.perm)[fb])
T = exprtype(adj.forw, α)
stk = insert_node!(adj.forw, 1, xstack(T)...)
pushfirst!(recs, stk)
insert_blockend!(adj.forw, blockidx(adj.forw, α.id), Any, xcall(Zygote, :_push!, stk, α))
if !in_loop
α′ = insert_phi_nest!(adj.forw, domtree, T, fb, exit_bb, SSAValue.id), phi_blocks)
pushfirst!(recs, α′)
else
stk = insert_node!(adj.forw, 1, xstack(T)...)
pushfirst!(recs, stk)
insert_blockend!(adj.forw, blockidx(adj.forw, α.id), Any, xcall(Zygote, :_push!, stk, α))
end
pushfirst!(stks, (invperm(adj.perm)[fb], alpha(α), in_loop))
end
pushfirst!(stks, (invperm(adj.perm)[fb], alpha(α)))
end
args = [Argument(i) for i = 3:length(adj.forw.argtypes)]
T = Tuple{concrete.(exprtype.((adj.forw,), recs))...}
Expand All @@ -79,17 +148,22 @@ function forward_stacks!(adj, F)
return forw, stks
end

# If we had the type, we could make this a PiNode
"""
    notnothing(x)

Return `x` unchanged. Throws an `ErrorException` when `x` is `nothing`.
"""
notnothing(x::Nothing) = error("notnothing: expected a value but got `nothing`")
notnothing(x) = x

function reverse_stacks!(adj, stks)
ir = adj.back
t = insert_node!(ir, 1, Any, xcall(Base, :getfield, Argument(1), QuoteNode(:t)))
for b = 1:length(ir.cfg.blocks)
repl = Dict()
for (i, (b′, α)) in enumerate(stks)
for (i, (b′, α, use_stack)) in enumerate(stks)
b == b′ || continue
loc, attach_after = afterphi(ir, range(ir.cfg.blocks[b])[1])
loc = max(2, loc)
if adj.perm[b′] == 1
if !use_stack
val = insert_node!(ir, loc, Any, xcall(:getindex, t, i), attach_after)
val = insert_node!(ir, loc, Any, xcall(Zygote, :notnothing, val), attach_after)
else
stk = insert_node!(ir, 1, Any, xcall(:getindex, t, i))
stk = insert_node!(ir, 1, Any, xcall(Zygote, :Stack, stk))
Expand Down

0 comments on commit aad4a75

Please sign in to comment.