Skip to content

Commit

Permalink
use istracked instead of pure type checking to decide whether to upda…
Browse files Browse the repository at this point in the history
…te partials cache for pow/div broadcast methods
  • Loading branch information
jrevels committed Dec 12, 2016
1 parent c8c47d6 commit db3d105
Showing 1 changed file with 8 additions and 12 deletions.
20 changes: 8 additions & 12 deletions src/derivatives/elementwise.jl
Original file line number Diff line number Diff line change
Expand Up @@ -433,9 +433,7 @@ denom_partials(n, d) = broadcast(denom_partials_kernel, n, d)
denom_partials!(out::Ref, n, d) = (out[] = denom_partials_kernel(n, d); nothing)
denom_partials!(out::AbstractArray, n, d) = (broadcast!(denom_partials_kernel, out, n, d); nothing)

rdiv_cache(x::TrackedType, y::TrackedType) = (numer_partials(value(y)), denom_partials(value(x), value(y)))
rdiv_cache(x::TrackedType, y) = (numer_partials(value(y)), nothing)
rdiv_cache(x, y::TrackedType) = (nothing, denom_partials(value(x), value(y)))
rdiv_cache(x, y) = (numer_partials(value(y)), denom_partials(value(x), value(y)))

function broadcast_rdiv{D}(x, y, ::Type{D})
tp = tape(x, y)
Expand All @@ -452,8 +450,8 @@ end
pull_value!(a)
pull_value!(b)
broadcast!(/, value(output), a_value, b_value)
!(isa(n_partials, Void)) && numer_partials!(n_partials, b_value)
!(isa(d_partials, Void)) && denom_partials!(d_partials, a_value, b_value)
istracked(a) && numer_partials!(n_partials, b_value)
istracked(b) && denom_partials!(d_partials, a_value, b_value)
return nothing
end

Expand Down Expand Up @@ -485,8 +483,8 @@ end
pull_value!(a)
pull_value!(b)
broadcast!(\, value(output), a_value, b_value)
!(isa(n_partials, Void)) && numer_partials!(n_partials, a_value)
!(isa(d_partials, Void)) && denom_partials!(d_partials, b_value, a_value)
istracked(b) && numer_partials!(n_partials, a_value)
istracked(a) && denom_partials!(d_partials, b_value, a_value)
return nothing
end

Expand Down Expand Up @@ -515,9 +513,7 @@ exp_partials(b, e) = broadcast(exp_partials_kernel, b, e)
exp_partials!(out::Ref, b, e) = (out[] = exp_partials_kernel(b, e); nothing)
exp_partials!(out::AbstractArray, b, e) = (broadcast!(exp_partials_kernel, out, b, e); nothing)

pow_cache(x::TrackedType, y::TrackedType) = (base_partials(value(x), value(y)), exp_partials(value(x), value(y)))
pow_cache(x::TrackedType, y) = (base_partials(value(x), value(y)), nothing)
pow_cache(x, y::TrackedType) = (nothing, exp_partials(value(x), value(y)))
pow_cache(x, y) = (base_partials(value(x), value(y)), exp_partials(value(x), value(y)))

function broadcast_pow{D}(x, y, ::Type{D})
tp = tape(x, y)
Expand All @@ -534,8 +530,8 @@ end
pull_value!(a)
pull_value!(b)
broadcast!(^, value(output), a_value, b_value)
!(isa(bs_partials, Void)) && base_partials!(bs_partials, a_value, b_value)
!(isa(ex_partials, Void)) && exp_partials!(ex_partials, a_value, b_value)
istracked(a) && base_partials!(bs_partials, a_value, b_value)
istracked(b) && exp_partials!(ex_partials, a_value, b_value)
return nothing
end

Expand Down

0 comments on commit db3d105

Please sign in to comment.