Reduce code duplication in backprop_params for compiled gf
marcoct committed Aug 30, 2018
1 parent 13d87f4 commit 79744ee
Showing 1 changed file with 111 additions and 171 deletions.
282 changes: 111 additions & 171 deletions src/generators/basic/backprop.jl
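
For orientation, here is a minimal, self-contained sketch (not taken from the diff) of the deduplication pattern this commit applies: the per-input gradient-increment loop that previously appeared in each process! method is factored into a single helper that both code-generation paths call. All names below (FakeNode, emit_input_grad_increments!, the two codegen_for_* functions, placeholder_backprop) are made up for illustration and are not part of the package.

struct FakeNode
    input_grad_vars::Vector{Symbol}   # gradient variable for each input
    inputs_do_ad::Vector{Bool}        # whether each input participates in AD
end

# Shared helper: emit `grad_var += incr` for every AD-enabled input,
# mirroring the role of increment_input_gradients! in the diff below.
function emit_input_grad_increments!(stmts::Vector{Expr}, node::FakeNode, incrs)
    for (grad_var, do_ad, incr) in zip(node.input_grad_vars, node.inputs_do_ad, incrs)
        do_ad && push!(stmts, :($grad_var += $incr))
    end
end

# Two code-generation paths that previously each inlined the same loop
# now both delegate to the helper.
function codegen_for_dist_node(node::FakeNode, incrs)
    stmts = Expr[]
    emit_input_grad_increments!(stmts, node, incrs)
    Expr(:block, stmts...)
end

function codegen_for_gen_node(node::FakeNode, incrs)
    stmts = Expr[]
    push!(stmts, :(output_grad = placeholder_backprop()))  # path-specific work
    emit_input_grad_increments!(stmts, node, incrs)
    Expr(:block, stmts...)
end

node = FakeNode([:grad_mu, :grad_sigma], [true, false])
println(codegen_for_dist_node(node, [:d_mu, :d_sigma]))
# prints a begin/end block containing only `grad_mu += d_mu`
# (the second input is skipped because it does not participate in AD)
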
@@ -96,16 +96,86 @@ function process!(state::Union{BBBackpropParamsState,BBBackpropTraceState}, node
# skip
end

function initialize_backprop!(ir::BasicBlockIR, stmts::Vector{Expr})

# create a gradient variable for each value node, initialize them to zero.
value_refs = Dict{ValueNode,Expr}()
grad_vars = Dict{ValueNode,Symbol}()
for (name, node) in ir.value_nodes
value_refs[node] = value_trace_ref(:trace, node)
grad_var = gensym("grad_$name")
if node !== ir.output_node || !ir.output_ad
if is_differentiable(get_type(node))
grad_vars[node] = grad_var
push!(stmts, quote
$grad_var = zero($(value_refs[node]))
end)
end
end
end

# initialize the gradient variable for the output node
if ir.output_node !== nothing && ir.output_ad
grad_var = grad_vars[something(ir.output_node)]
push!(stmts, quote
$grad_var = retval_grad
end)
end

(value_refs, grad_vars)
end

function input_gradients(ir::BasicBlockIR, grad_vars)
input_grads_var = gensym("input_grads")
input_grads = []
for (node, has_grad) in zip(ir.arg_nodes, ir.args_ad)
if has_grad
push!(input_grads, grad_vars[node])
else
push!(input_grads, QuoteNode(nothing))
end
end
Expr(:tuple, input_grads...)
end

function increment_input_gradients!(stmts, node, dist_or_gen, grad_vars, increments)
inputs_do_ad = has_argument_grads(dist_or_gen)
for (in_node, do_ad, incr) in zip(node.input_nodes, inputs_do_ad, increments)
if do_ad
if !haskey(grad_vars, in_node)
error("$(dist_or_gen) has AD for an input that is not floating point, node: $node")
end
grad_var = grad_vars[in_node]
push!(stmts, quote
$grad_var += $incr
end)
end
end
end


function increment_output_gradient!(stmts, node::AddrDistNode, grad_vars, increment)
# NOTE: if the output has a gradient, then it must be a float...
# but it may be a float and not have a gradient; currently this is
# silent; could warn?
if has_output_grad(node.dist)
if !haskey(grad_vars, node.output)
error("Distribution $(node.dist) has AD but the return value is not floating point, node: $node")
end
output_grad_var = grad_vars[node.output]
push!(stmts, quote
$output_grad_var += $increment
end)
end
end
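

As a rough guide to what the emitted statements unroll to at runtime, the sketch below (illustrative only, with made-up names and a single hard-coded normal choice) follows the same shape: zero-initialize gradient accumulators, add the log-density gradient increments, and return one entry per argument that supports AD. The real generated code references trace fields and the package's logpdf_grad machinery rather than the inline derivatives written out here.

function toy_backprop_params(mu, sigma, x, retval_grad)
    # zero-initialize one gradient accumulator per value (initialize_backprop!)
    grad_mu = zero(mu)
    grad_sigma = zero(sigma)
    grad_x = zero(x)

    # gradient of the normal log density at x with respect to (x, mu, sigma)
    z = (x - mu) / sigma
    d_x, d_mu, d_sigma = -z / sigma, z / sigma, (z^2 - 1) / sigma

    # accumulate into the output and input gradients
    # (increment_output_gradient!, increment_input_gradients!);
    # grad_x would only matter if x fed later computation
    grad_x += d_x
    grad_mu += d_mu
    grad_sigma += d_sigma

    # return one gradient per argument that supports AD (input_gradients);
    # retval_grad is unused because the toy return value carries no gradient
    return (grad_mu, grad_sigma)
end

toy_backprop_params(0.0, 1.0, 0.5, nothing)  # == (0.5, -0.75)
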


###################
# backprop_params #
###################

function process!(state::BBBackpropParamsState, node::AddrDistNode)

# NOTE: if the output does AD, then it must be a float...
# but it may be a float and the grad may not do AD.

# get gradient of log density with respect to output and inputs
input_value_refs = [state.value_refs[in_node] for in_node in node.input_nodes]
output_value_ref = state.value_refs[node.output]
@@ -116,30 +186,8 @@ function process!(state::BBBackpropParamsState, node::AddrDistNode)
$(QuoteNode(node.dist)), $output_value_ref, $(input_value_refs...))
end)

# increment output gradient
if has_output_grad(node.dist)
if !haskey(state.grad_vars, node.output)
error("Distribution $(node.dist) has AD but the return value is not floating point, node: $node")
end
output_grad_var = state.grad_vars[node.output]
push!(state.stmts, quote
$output_grad_var += $output_grad_incr
end)
end

# increment input gradients
inputs_do_ad = has_argument_grads(node.dist)
for (in_node, do_ad, incr) in zip(node.input_nodes, inputs_do_ad, input_grad_incrs)
if do_ad
if !haskey(state.grad_vars, in_node)
error("Distribution $(node.dist) has AD for an input that is not floating point, node: $node")
end
grad_var = state.grad_vars[in_node]
push!(state.stmts, quote
$grad_var += $incr
end)
end
end
increment_output_gradient!(state.stmts, node, state.grad_vars, output_grad_incr)
increment_input_gradients!(state.stmts, node, node.dist, state.grad_vars, input_grad_incrs)
end

function process!(state::BBBackpropParamsState, node::AddrGeneratorNode)
@@ -154,50 +202,18 @@ function process!(state::BBBackpropParamsState, node::AddrGeneratorNode)
($(input_grad_incrs...),) = backprop_params($(QuoteNode(node.gen)), $subtrace, $output_grad)
end)

# increment input gradients
inputs_do_ad = has_argument_grads(node.gen)
for (in_node, do_ad, incr) in zip(node.input_nodes, inputs_do_ad, input_grad_incrs)
if do_ad
if !haskey(state.grad_vars, in_node)
error("Generator $(node.gen) has AD for an input that is not floating point, node: $node")
end
grad_var = state.grad_vars[in_node]
push!(state.stmts, quote
$grad_var += $incr
end)
end
end
increment_input_gradients!(state.stmts, node, node.gen, state.grad_vars, input_grad_incrs)
end

function codegen_backprop_params(gen::Type{T}, trace, retval_grad) where {T <: BasicGenFunction}
ir = get_ir(gen)
stmts = Expr[]

# create a gradient variable for each value node, initialize them to zero.
grad_vars = Dict{ValueNode,Symbol}()
value_refs = Dict{ValueNode,Expr}()
for (name, node) in ir.value_nodes
value_refs[node] = value_trace_ref(:trace, node)
grad_var = gensym("grad_$name")
if node !== ir.output_node || !ir.output_ad
if is_differentiable(get_type(node))
grad_vars[node] = grad_var
push!(stmts, quote
$grad_var = zero($(value_refs[node]))
end)
end
end
end
# also get trace references for each value node
(value_refs, grad_vars) = initialize_backprop!(ir, stmts)

# initialize the gradient variable for the output node
if ir.output_node !== nothing && ir.output_ad
grad_var = grad_vars[something(ir.output_node)]
push!(stmts, quote
$grad_var = retval_grad
end)
end

# visit statements in reverse topological order, generating code
# visit statements in reverse topological order
state = BBBackpropParamsState(gen, :trace, stmts, value_refs, grad_vars)
for node in reverse(ir.expr_nodes_sorted)
process!(state, node)
@@ -212,17 +228,9 @@ function codegen_backprop_params(gen::Type{T}, trace, retval_grad) where {T <: B
end)
end

# gradients with respect to inputs
input_grads = []
for (node, has_grad) in zip(ir.arg_nodes, ir.args_ad)
if has_grad
push!(input_grads, grad_vars[node])
else
push!(input_grads, QuoteNode(nothing))
end
end
# return statement
push!(stmts, quote
return ($(input_grads...),)
return $(input_gradients(ir, grad_vars))
end)
Expr(:block, stmts...)
end
@@ -240,9 +248,6 @@ end)

function process!(state::BBBackpropTraceState, node::AddrDistNode)

# NOTE: if the output does AD, then it must be a float...
# but it may be a float and the grad may not do AD.

# get gradient of log density with respect to output and inputs
input_value_refs = [state.value_refs[in_node] for in_node in node.input_nodes]
output_value_ref = state.value_refs[node.output]
@@ -253,30 +258,8 @@ function process!(state::BBBackpropTraceState, node::AddrDistNode)
$(QuoteNode(node.dist)), $output_value_ref, $(input_value_refs...))
end)

# increment output gradient
if has_output_grad(node.dist)
if !haskey(state.grad_vars, node.output)
error("Distribution $(node.dist) has AD but the return value is not floating point, node: $node")
end
output_grad_var = state.grad_vars[node.output]
push!(state.stmts, quote
$output_grad_var += $output_grad_incr
end)
end

# increment input gradients
inputs_do_ad = has_argument_grads(node.dist)
for (in_node, do_ad, incr) in zip(node.input_nodes, inputs_do_ad, input_grad_incrs)
if do_ad
if !haskey(state.grad_vars, in_node)
error("Distribution $(node.dist) has AD for an input that is not floating point, node: $node")
end
grad_var = state.grad_vars[in_node]
push!(state.stmts, quote
$grad_var += $incr
end)
end
end
increment_output_gradient!(state.stmts, node, state.grad_vars, output_grad_incr)
increment_input_gradients!(state.stmts, node, node.dist, state.grad_vars, input_grad_incrs)

# handle selected address
addr = node.address
@@ -313,95 +296,52 @@ function process!(state::BBBackpropTraceState, node::AddrGeneratorNode)
$(QuoteNode(node.gen)), $subtrace, $selection, $output_grad)
end)

# increment input gradients
inputs_do_ad = has_argument_grads(node.gen)
for (in_node, do_ad, incr) in zip(node.input_nodes, inputs_do_ad, input_grad_incrs)
if do_ad
if !haskey(state.grad_vars, in_node)
error("Generator $(node.gen) has AD for an input that is not floating point, node: $node")
end
grad_var = state.grad_vars[in_node]
push!(state.stmts, quote
$grad_var += $incr
end)
end
increment_input_gradients!(state.stmts, node, node.gen, state.grad_vars, input_grad_incrs)
end

const backprop_values_trie = gensym("values")
const backprop_gradients_trie = gensym("gradients")

function choice_trie_construction(leaf_nodes_set, internal_nodes_set)
leaf_nodes = collect(leaf_nodes_set)
quoted_leaf_keys = map((node) -> QuoteNode(node.addr), leaf_nodes)
leaf_values = map((node) -> node.value_ref, leaf_nodes)
leaf_gradients = map((node) -> node.gradient_var, leaf_nodes)
internal_nodes = collect(internal_nodes_set)
quoted_internal_keys = map((node) -> QuoteNode(node.addr), internal_nodes)
internal_values = map((node) -> node.values_var, internal_nodes)
internal_gradients = map((node) -> node.gradients_var, internal_nodes)
quote
$backprop_values_trie = StaticChoiceTrie(
NamedTuple{($(quoted_leaf_keys...),)}(($(leaf_values...),)),
NamedTuple{($(quoted_internal_keys...),)}(($(internal_values...),)))
$backprop_gradients_trie = StaticChoiceTrie(
NamedTuple{($(quoted_leaf_keys...),)}(($(leaf_gradients...),)),
NamedTuple{($(quoted_internal_keys...),)}(($(internal_gradients...),)))
end
end
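
The expression built by choice_trie_construction pairs each selected address with its value and gradient via NamedTuple type parameters. A tiny standalone illustration of that construction pattern follows; the addresses :mu and :sigma, the trace.mu/trace.sigma references, and values_nt are made up, and the actual StaticChoiceTrie wrapper from the package is omitted.

quoted_keys = (QuoteNode(:mu), QuoteNode(:sigma))
value_exprs = (:(trace.mu), :(trace.sigma))
trie_expr = quote
    values_nt = NamedTuple{($(quoted_keys...),)}(($(value_exprs...),))
end
println(trie_expr)
# the generated statement is equivalent to:
#   values_nt = NamedTuple{(:mu, :sigma)}((trace.mu, trace.sigma))
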

function codegen_backprop_trace(gen::Type{T}, trace, selection, retval_grad) where {T <: BasicGenFunction}
#Core.println("generating backprop_trace($gen, selection: $selection...)")
schema = get_address_schema(selection)
ir = get_ir(gen)
stmts = Expr[]

# create a gradient variable for each value node, initialize them to zero.
grad_vars = Dict{ValueNode,Symbol}()
value_refs = Dict{ValueNode,Expr}()
for (name, node) in ir.value_nodes
value_refs[node] = value_trace_ref(:trace, node)
grad_var = gensym("grad_$name")
if node !== ir.output_node || !ir.output_ad
if is_differentiable(get_type(node))
grad_vars[node] = grad_var
push!(stmts, quote
$grad_var = zero($(value_refs[node]))
end)
end
end
end

# initialize the gradient variable for the output node
if ir.output_node !== nothing && ir.output_ad
grad_var = grad_vars[something(ir.output_node)]
push!(stmts, quote
$grad_var = retval_grad
end)
end
# also get trace references for each value node
(value_refs, grad_vars) = initialize_backprop!(ir, stmts)

# visit statements in reverse topological order, generating code
# visit statements in reverse topological order
state = BBBackpropTraceState(gen, :trace, stmts, schema, value_refs, grad_vars)
for node in reverse(ir.expr_nodes_sorted)
process!(state, node)
end

# construct values and gradients static choice tries
values = gensym("values")
gradients = gensym("gradients")
push!(stmts, choice_trie_construction(state.leaf_nodes, state.internal_nodes))

leaf_nodes = collect(state.leaf_nodes)
quoted_leaf_keys = map((node) -> QuoteNode(node.addr), leaf_nodes)
leaf_values = map((node) -> node.value_ref, leaf_nodes)
leaf_gradients = map((node) -> node.gradient_var, leaf_nodes)

internal_nodes = collect(state.internal_nodes)
quoted_internal_keys = map((node) -> QuoteNode(node.addr), internal_nodes)
internal_values = map((node) -> node.values_var, internal_nodes)
internal_gradients = map((node) -> node.gradients_var, internal_nodes)

push!(stmts, quote
$values = StaticChoiceTrie(
NamedTuple{($(quoted_leaf_keys...),)}(($(leaf_values...),)),
NamedTuple{($(quoted_internal_keys...),)}(($(internal_values...),)))
$gradients = StaticChoiceTrie(
NamedTuple{($(quoted_leaf_keys...),)}(($(leaf_gradients...),)),
NamedTuple{($(quoted_internal_keys...),)}(($(internal_gradients...),)))
end)

# gradients with respect to inputs
input_grads_var = gensym("input_grads")
input_grads = []
for (node, has_grad) in zip(ir.arg_nodes, ir.args_ad)
if has_grad
push!(input_grads, grad_vars[node])
else
push!(input_grads, QuoteNode(nothing))
end
end
push!(stmts, quote
$input_grads_var = ($(input_grads...),)
end)
# return statement
push!(stmts, quote
return ($input_grads_var, $values, $gradients)
return ($(input_gradients(ir, grad_vars)), $backprop_values_trie, $backprop_gradients_trie)
end)
Expr(:block, stmts...)
end
