Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hookup demand-driven forward mode to the Diffractor runtime #99

Merged
merged 2 commits into from
Jan 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 119 additions & 18 deletions src/codegen/forward_demand.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
using Core.Compiler: IRInterpretationState, construct_postdomtree, PiNode,
is_known_call, argextype, postdominates

function forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, pantelides::Vector{SSAValue}; custom_diff! = (args...)->nothing, diff_cache=Dict{SSAValue, SSAValue}())
#=
function forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, to_diff::Vector{Pair{SSAValue, Int}}; custom_diff! = (args...)->nothing, diff_cache=Dict{SSAValue, SSAValue}())
Δs = SSAValue[]
rets = findall(@nospecialize(x)->isa(x, ReturnNode) && isdefined(x, :val), ir.stmts.inst)
postdomtree = construct_postdomtree(ir.cfg.blocks)
for ssa in pantelides
Δssa = forward_diff!(ir, interp, irsv, ssa; custom_diff!, diff_cache)
for (ssa, order) in to_diff
Δssa = forward_diff!(ir, interp, irsv, ssa, order; custom_diff!, diff_cache)
Δblock = block_for_inst(ir, Δssa.id)
for idx in rets
retblock = block_for_inst(ir, idx)
Expand All @@ -18,31 +19,24 @@ function forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, pantelid
end
return (ir, Δs)
end
=#

function diff_unassigned_variable!(ir, ssa)
return insert_node!(ir, ssa, NewInstruction(
Expr(:call, GlobalRef(Intrinsics, :state_ddt), ssa), Float64), #=attach_after=#true)
end

function forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, ssa::SSAValue; custom_diff!, diff_cache)
function forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, ssa::SSAValue, order::Int; custom_diff!, diff_cache)
if haskey(diff_cache, ssa)
return diff_cache[ssa]
end
inst = ir[ssa]
stmt = inst[:inst]
if isa(stmt, SSAValue)
return forward_diff!(ir, interp, irsv, stmt; custom_diff!, diff_cache)
end
Δssa = forward_diff_uncached!(ir, interp, irsv, ssa, inst; custom_diff!, diff_cache)
Δssa = forward_diff_uncached!(ir, interp, irsv, ssa, inst, order::Int; custom_diff!, diff_cache)
@assert Δssa !== nothing
if isa(Δssa, SSAValue)
diff_cache[ssa] = Δssa
end
return Δssa
end
forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, val::Union{Integer, AbstractFloat}; custom_diff!, diff_cache) = zero(val)
forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, @nospecialize(arg); custom_diff!, diff_cache) = ChainRulesCore.NoTangent()
function forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, arg::Argument; custom_diff!, diff_cache)
forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, val::Union{Integer, AbstractFloat}, order::Int; custom_diff!, diff_cache) = zero(val)
forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, @nospecialize(arg), order::Int; custom_diff!, diff_cache) = ChainRulesCore.NoTangent()
function forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, arg::Argument, order::Int; custom_diff!, diff_cache)
recurse(x) = forward_diff!(ir, interp, irsv, x; custom_diff!, diff_cache)
val = custom_diff!(ir, SSAValue(0), arg, recurse)
if val !== nothing
Expand All @@ -51,13 +45,15 @@ function forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, arg::Arg
return ChainRulesCore.NoTangent()
end

function forward_diff_uncached!(ir::IRCode, interp, irsv::IRInterpretationState, ssa::SSAValue, inst::Core.Compiler.Instruction; custom_diff!, diff_cache)
function forward_diff_uncached!(ir::IRCode, interp, irsv::IRInterpretationState, ssa::SSAValue, inst::Core.Compiler.Instruction, order::Int; custom_diff!, diff_cache)
stmt = inst[:inst]
recurse(x) = forward_diff!(ir, interp, irsv, x; custom_diff!, diff_cache)
recurse(x) = forward_diff!(ir, interp, irsv, x, order; custom_diff!, diff_cache)
if (val = custom_diff!(ir, ssa, stmt, recurse)) !== nothing
return val
elseif isa(stmt, PiNode)
return recurse(stmt.val)
elseif isa(stmt, SSAValue)
return recurse(stmt)
elseif isa(stmt, PhiNode)
Δphi = PhiNode(copy(stmt.edges), similar(stmt.values))
T = Union{}
Expand Down Expand Up @@ -152,3 +148,108 @@ function forward_diff_uncached!(ir::IRCode, interp, irsv::IRInterpretationState,
return Δssa
end
end

function forward_visit!(ir::IRCode, ssa::SSAValue, order::Int, ssa_orders::Vector{Pair{Int, Bool}}, visit_custom!)
if ssa_orders[ssa.id][1] >= order
return
end
ssa_orders[ssa.id] = order => ssa_orders[ssa.id][2]
inst = ir[ssa]
stmt = inst[:inst]
recurse(@nospecialize(val)) = forward_visit!(ir, val, order, ssa_orders, visit_custom!)
if visit_custom!(ir, stmt, order, recurse)
ssa_orders[ssa.id] = order => true
return
elseif isa(stmt, PiNode)
return recurse(stmt.val)
elseif isa(stmt, PhiNode)
for i = 1:length(stmt.values)
isassigned(stmt.values, i) || continue
recurse(stmt.values[i])
end
return
elseif isexpr(stmt, :new) || isexpr(stmt, :invoke)
foreach(recurse, stmt.args[2:end])
elseif isexpr(stmt, :call)
foreach(recurse, stmt.args)
elseif isa(stmt, SSAValue)
recurse(stmt)
elseif !isa(stmt, Expr)
return
else
@show stmt
error()
end
end
forward_visit!(ir::IRCode, _, order::Int, ssa_orders::Vector{Pair{Int, Bool}}, visit_custom!) = nothing
function forward_visit!(ir::IRCode, a::Argument, order::Int, ssa_orders::Vector{Pair{Int, Bool}}, visit_custom!)
recurse(@nospecialize(val)) = forward_visit!(ir, val, order, ssa_orders, visit_custom!)
return visit_custom!(ir, a, order, recurse)
end


function forward_diff_no_inf!(ir::IRCode, interp, mi::MethodInstance, world, to_diff::Vector{Pair{SSAValue, Int}};
visit_custom! = (args...)->false, transform! = (args...)->error())
# Step 1: For each SSAValue in the IR, keep track of the differentiation order needed
ssa_orders = [0=>false for i = 1:length(ir.stmts)]
for (ssa, order) in to_diff
forward_visit!(ir, ssa, order, ssa_orders, visit_custom!)
end

# Step 2: Transform
function maparg(arg, ssa, order)
if isa(arg, Argument)
# TODO: Should we remember whether the callbacks wanted the arg?
return transform!(ir, arg, order)
elseif isa(arg, SSAValue)
# TODO: Bundle truncation if necessary
return arg
end
@assert !isa(arg, Expr)
return insert_node!(ir, ssa, NewInstruction(Expr(:call, ZeroBundle{order}, arg), Any))
end

for (ssa, (order, custom)) in enumerate(ssa_orders)
if order == 0
# TODO: Bundle truncation?
continue
end
if custom
transform!(ir, SSAValue(ssa), order)
else
inst = ir[SSAValue(ssa)]
stmt = inst[:inst]
if isexpr(stmt, :invoke)
inst[:inst] = Expr(:call, ∂☆{order}(), map(arg->maparg(arg, SSAValue(ssa), order), stmt.args[2:end])...)
inst[:type] = Any
elseif !isa(stmt, Expr)
inst[:inst] = maparg(stmt, ssa, order)
inst[:type] = Any
else
@show stmt
error()
end
end
end

end

function forward_diff!(ir::IRCode, interp, mi::MethodInstance, world, to_diff::Vector{Pair{SSAValue, Int}}; kwargs...)
forward_diff_no_inf!(ir, interp, mi, world, to_diff; kwargs...)

# Step 3: Re-inference
ir = compact!(ir)

extra_reprocess = CC.BitSet()
for i = 1:length(ir.stmts)
if ir[SSAValue(i)][:type] == Any
CC.push!(extra_reprocess, i)
end
end

interp′ = enable_reinference(interp)
irsv = IRInterpretationState(interp′, ir, mi, world, ir.argtypes[1:mi.def.nargs])
rt = CC._ir_abstract_constant_propagation(interp′, irsv; extra_reprocess)

return ir
end
17 changes: 17 additions & 0 deletions src/higher_fwd_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,23 @@ for f in (sin, cos, exp)
end
end

# TODO: It's a bit embarassing that we need to write these out, but currently the
# compiler is not strong enough to automatically lift the frule. Let's hope we
# can delete these in the near future.
function (∂☆ₙ::∂☆{N})(fb::ZeroBundle{N, typeof(+)}, a::TaylorBundle{N}, b::TaylorBundle{N}) where {N}
TaylorBundle{N}(primal(a) + primal(b),
map(+, a.tangent.coeffs, b.tangent.coeffs))
end

function (∂☆ₙ::∂☆{N})(fb::ZeroBundle{N, typeof(+)}, a::TaylorBundle{N}, b::ZeroBundle{N}) where {N}
TaylorBundle{N}(primal(a) + primal(b), a.tangent.coeffs)
end

function (∂☆ₙ::∂☆{N})(fb::ZeroBundle{N, typeof(-)}, a::TaylorBundle{N}, b::TaylorBundle{N}) where {N}
TaylorBundle{N}(primal(a) - primal(b),
map(-, a.tangent.coeffs, b.tangent.coeffs))
end

function (::Diffractor.∂☆new{N})(B::ATB{N, Type{T}}, args::ATB{N}...) where {N, T<:SArray}
error("Should have intercepted the constructor")
end
8 changes: 5 additions & 3 deletions src/stage1/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -205,13 +205,15 @@ struct FwdIterate{N, T<:AbstractTangentBundle{N}}
end
function (f::FwdIterate)(arg::ATB{N}) where {N}
r = ∂☆{N}()(f.f, arg)
primal(r) === nothing && return nothing
# `primal(r) === nothing` would work, but doesn't create `Conditional` in inference
isa(r, ATB{N, Nothing}) && return nothing
(∂☆{N}()(ZeroBundle{N}(getindex), r, ZeroBundle{N}(1)),
primal(∂☆{N}()(ZeroBundle{N}(getindex), r, ZeroBundle{N}(2))))
end
function (f::FwdIterate)(arg::ATB{N}, st) where {N}
@Base.constprop :aggressive function (f::FwdIterate)(arg::ATB{N}, st) where {N}
r = ∂☆{N}()(f.f, arg, ZeroBundle{N}(st))
primal(r) === nothing && return nothing
# `primal(r) === nothing` would work, but doesn't create `Conditional` in inference
isa(r, ATB{N, Nothing}) && return nothing
(∂☆{N}()(ZeroBundle{N}(getindex), r, ZeroBundle{N}(1)),
primal(∂☆{N}()(ZeroBundle{N}(getindex), r, ZeroBundle{N}(2))))
end
Expand Down
36 changes: 26 additions & 10 deletions src/stage2/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,50 @@ using .CC: compact!

# Engineering entry point for the 2nd-order forward AD functionality. This is
# unlikely to be the actual interface. For now, it is used for testing.
function dontuse_nth_order_forward_stage2(tt::Type)
function dontuse_nth_order_forward_stage2(tt::Type, order::Int=1)
interp = ADInterpreter(; forward=true, backward=false)
match = Base._which(tt)
frame = Core.Compiler.typeinf_frame(interp, match.method, match.spec_types, match.sparams, #=run_optimizer=#true)

ir = copy((interp.opt[0][frame.linfo].inferred).ir::IRCode)

# Find all Return Nodes
vals = SSAValue[]
vals = Pair{SSAValue, Int}[]
for i = 1:length(ir.stmts)
if isa(ir[SSAValue(i)][:inst], ReturnNode)
push!(vals, SSAValue(i))
push!(vals, SSAValue(i)=>order)
end
end

function custom_diff!(ir, ssa, stmt, recurse)
function visit_custom!(ir::IRCode, @nospecialize(stmt), order, recurse)
if isa(stmt, ReturnNode)
r = recurse(stmt.val)
ir[ssa][:inst] = ReturnNode(r)
return ssa
recurse(stmt.val)
return true
elseif isa(stmt, Argument)
return 1.0
return true
else
return false
end
return nothing
end

function transform!(ir::IRCode, ssa::SSAValue, _)
inst = ir[ssa]
stmt = inst[:inst]
if isa(stmt, ReturnNode)
nr = insert_node!(ir, ssa, NewInstruction(Expr(:call, getindex, stmt.val, TaylorTangentIndex(order)), Any))
inst[:inst] = ReturnNode(nr)
else
error()
end
end

function transform!(ir::IRCode, arg::Argument, _)
return insert_node!(ir, SSAValue(1), NewInstruction(Expr(:call, ∂xⁿ{order}(), arg), typeof(∂xⁿ{order}()(1.0))))
end


irsv = CC.IRInterpretationState(interp, ir, frame.linfo, CC.get_world_counter(interp), ir.argtypes[1:frame.linfo.def.nargs])
forward_diff!(ir, interp, irsv, vals; custom_diff!)
ir = forward_diff!(ir, interp, frame.linfo, CC.get_world_counter(interp), vals; visit_custom!, transform!)

ir = compact!(ir)
return OpaqueClosure(ir)
Expand Down
Loading