Skip to content

Commit

Permalink
Merge pull request #146 from JuliaDiff/myb/fuse_frule
Browse files Browse the repository at this point in the history
Update to ChainRulesCore 0.5.1
  • Loading branch information
YingboMa committed Jan 12, 2020
2 parents bc600ce + c1a4451 commit 3d2618e
Show file tree
Hide file tree
Showing 11 changed files with 71 additions and 142 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.2.5"
version = "0.3"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -10,7 +10,7 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ChainRulesCore = "0.4"
ChainRulesCore = "0.5.1"
FiniteDifferences = "^0.7"
Reexport = "0.2"
Requires = "0.5.2, 1"
Expand Down
1 change: 0 additions & 1 deletion docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,6 @@ The most important `AbstractDifferential`s when getting started are the ones abo
- `One`, `Zero`: There are special representations of `1` and `0`. They do great things around avoiding expanding `Thunks` in multiplication and (for `Zero`) addition.

#### Other `AbstractDifferential`s: don't worry about them right now
- `Wirtinger`: it is complex. The docs need to be better. [Read the links in this issue](https://github.com/JuliaDiff/ChainRulesCore.jl/issues/40).
- `Casted`: it implements broadcasting mechanics. See [#10](https://github.com/JuliaDiff/ChainRulesCore.jl/issues/10)
- `InplaceableThunk`: it is like a Thunk but it can do `store!` and `accumulate!` in-place.

Expand Down
25 changes: 7 additions & 18 deletions src/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
@scalar_rule(zero(x), Zero())
@scalar_rule(sign(x), Zero())

@scalar_rule(abs2(x), Wirtinger(x', x))
@scalar_rule(abs2(x), 2x)
@scalar_rule(log(x), inv(x))
@scalar_rule(log10(x), inv(x) / log(oftype(x, 10)))
@scalar_rule(log2(x), inv(x) / log(oftype(x, 2)))
Expand Down Expand Up @@ -50,14 +50,12 @@
@scalar_rule(deg2rad(x), π / oftype(x, 180))
@scalar_rule(rad2deg(x), oftype(x, 180) / π)

@scalar_rule(conj(x), Wirtinger(Zero(), One()))
@scalar_rule(adjoint(x), Wirtinger(Zero(), One()))
@scalar_rule(conj(x::Real), One())
@scalar_rule(adjoint(x::Real), One())
@scalar_rule(transpose(x), One())

@scalar_rule(abs(x::Real), sign(x))
@scalar_rule(abs(x::Complex), Wirtinger(x' / 2Ω, x / 2Ω))
@scalar_rule(hypot(x::Real), sign(x))
@scalar_rule(hypot(x::Complex), Wirtinger(x' / 2Ω, x / 2Ω))
@scalar_rule(rem2pi(x, r::RoundingMode), (One(), DoesNotExist()))

@scalar_rule(+(x), One())
Expand Down Expand Up @@ -98,20 +96,14 @@
(ifelse(isint, nan, one(u)), ifelse(isint, nan, -trunc(u))))
@scalar_rule(fma(x, y, z), (y, x, One()))
@scalar_rule(muladd(x, y, z), (y, x, One()))
@scalar_rule(angle(x::Complex), @setup(u = abs2(x)), Wirtinger(-im//2 * x' / u, im//2 * x / u))
@scalar_rule(angle(x::Real), Zero())
@scalar_rule(real(x::Complex), Wirtinger(1//2, 1//2))
@scalar_rule(real(x::Real), One())
@scalar_rule(imag(x::Complex), Wirtinger(-im//2, im//2))
@scalar_rule(imag(x::Real), Zero())

# product rule requires special care for arguments where `mul` is non-commutative

function frule(::typeof(*), x::Number, y::Number)
function times_pushforward(_, Δx, Δy)
return Δx * y + x * Δy
end
return x * y, times_pushforward
function frule(::typeof(*), x::Number, y::Number, _, Δx, Δy)
return x * y, Δx * y + x * Δy
end

function rrule(::typeof(*), x::Number, y::Number)
Expand All @@ -121,11 +113,8 @@ function rrule(::typeof(*), x::Number, y::Number)
return x * y, times_pullback
end

function frule(::typeof(identity), x)
function identity_pushforward(_, ẏ)
return
end
return x, identity_pushforward
function frule(::typeof(identity), x, _, ẏ)
return x, ẏ
end

function rrule(::typeof(identity), x)
Expand Down
12 changes: 5 additions & 7 deletions src/rulesets/Base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,17 @@ https://github.com/JuliaLang/julia/issues/22129.
=#
function _cast_diff(f, x)
function element_rule(u)
fu, du = frule(f, u)
fu, extern(du(NamedTuple(), One()))
dself = Zero()
fu, du = frule(f, u, dself, One())
fu, extern(du)
end
results = broadcast(element_rule, x)
return first.(results), last.(results)
end

function frule(::typeof(broadcast), f, x)
function frule(::typeof(broadcast), f, x, _, Δf, Δx)
Ω, ∂x = _cast_diff(f, x)
function broadcast_pushforward(_, Δf, Δx)
return Δx .* ∂x
end
return Ω, broadcast_pushforward
return Ω, Δx .* ∂x
end

function rrule(::typeof(broadcast), f, x)
Expand Down
7 changes: 2 additions & 5 deletions src/rulesets/Base/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,8 @@ end
##### `sum`
#####

function frule(::typeof(sum), x)
function sum_pushforward(_, ẋ)
return sum(ẋ)
end
return sum(x), sum_pushforward
function frule(::typeof(sum), x, _, ẋ)
return sum(x), sum(ẋ)
end

function rrule(::typeof(sum), f, x::AbstractArray{<:Real}; dims=:)
Expand Down
18 changes: 6 additions & 12 deletions src/rulesets/LinearAlgebra/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ _zeros(x) = fill!(similar(x), zero(eltype(x)))
##### `BLAS.dot`
#####

frule(::typeof(BLAS.dot), x, y) = frule(dot, x, y)
frule(::typeof(BLAS.dot), x, y, Δself, Δx, Δy) = frule(dot, x, y, Δself, Δx, Δy)

rrule(::typeof(BLAS.dot), x, y) = rrule(dot, x, y)

Expand All @@ -35,12 +35,9 @@ end
##### `BLAS.nrm2`
#####

function frule(::typeof(BLAS.nrm2), x)
function frule(::typeof(BLAS.nrm2), x, _, Δx)
Ω = BLAS.nrm2(x)
function nrm2_pushforward(_, Δx)
return sum(Δx * cast(@thunk(x * inv(Ω))))
end
return Ω, nrm2_pushforward
return Ω, sum(Δx .* @thunk(x * inv(Ω)))
end

function rrule(::typeof(BLAS.nrm2), x)
Expand Down Expand Up @@ -70,16 +67,13 @@ end
##### `BLAS.asum`
#####

function frule(::typeof(BLAS.asum), x)
function asum_pushforward(_, Δx)
return sum(cast(sign, x) * Δx)
end
return BLAS.asum(x), asum_pushforward
function frule(::typeof(BLAS.asum), x, _, Δx)
return BLAS.asum(x), sum(sign.(x) .* Δx)
end

function rrule(::typeof(BLAS.asum), x)
function asum_pullback(ΔΩ)
return (NO_FIELDS, @thunk(ΔΩ * cast(sign, x)))
return (NO_FIELDS, @thunk(ΔΩ * sign.(x)))
end
return BLAS.asum(x), asum_pullback
end
Expand Down
43 changes: 13 additions & 30 deletions src/rulesets/LinearAlgebra/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,8 @@ const SquareMatrix{T} = Union{Diagonal{T},AbstractTriangular{T}}
##### `dot`
#####

function frule(::typeof(dot), x, y)
function dot_pushforward(Δself, Δx, Δy)
return sum(Δx .* y) + sum(x .* Δy)
end
return dot(x, y), dot_pushforward
function frule(::typeof(dot), x, y, _, Δx, Δy)
return dot(x, y), sum(Δx .* y) + sum(x .* Δy)
end

function rrule(::typeof(dot), x, y)
Expand All @@ -26,20 +23,15 @@ end
##### `inv`
#####

function frule(::typeof(inv), x::AbstractArray)
function frule(::typeof(inv), x::AbstractArray, _, Δx)
Ω = inv(x)
m = @thunk(-Ω)
function inv_pushforward(_, Δx)
return m * Δx * Ω
end
return Ω, inv_pushforward
return Ω, -Ω * Δx * Ω
end

function rrule(::typeof(inv), x::AbstractArray)
Ω = inv(x)
m = @thunk(-Ω')
function inv_pullback(ΔΩ)
return NO_FIELDS, m * ΔΩ * Ω'
return NO_FIELDS, -Ω' * ΔΩ * Ω'
end
return Ω, inv_pullback
end
Expand All @@ -48,14 +40,11 @@ end
##### `det`
#####

function frule(::typeof(det), x)
function frule(::typeof(det), x, _, ẋ)
Ω = det(x)
function det_pushforward(_, ẋ)
# TODO Performance optimization: probably there is an efficent
# way to compute this trace without during the full compution within
return Ω * tr(inv(x) * ẋ)
end
return Ω, det_pushforward
# TODO Performance optimization: probably there is an efficent
# way to compute this trace without during the full compution within
return Ω, Ω * tr(inv(x) * ẋ)
end

function rrule(::typeof(det), x)
Expand All @@ -70,12 +59,9 @@ end
##### `logdet`
#####

function frule(::typeof(logdet), x)
function frule(::typeof(logdet), x, _, Δx)
Ω = logdet(x)
function logdet_pushforward(_, Δx)
return tr(inv(x) * Δx)
end
return Ω, logdet_pushforward
return Ω, tr(inv(x) * Δx)
end

function rrule(::typeof(logdet), x)
Expand All @@ -90,11 +76,8 @@ end
##### `trace`
#####

function frule(::typeof(tr), x)
function tr_pushforward(_, Δx)
return tr(Δx)
end
return tr(x), tr_pushforward
function frule(::typeof(tr), x, _, Δx)
return tr(x), tr(Δx)
end

function rrule(::typeof(tr), x)
Expand Down
39 changes: 19 additions & 20 deletions test/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,18 +51,19 @@
test_scalar(acscd, 1/x)
test_scalar(acotd, 1/x)
end

@testset "sincos" begin
x, Δx, x̄ = randn(3)
Δz = (randn(), randn())
@testset "Multivariate" begin
@testset "sincos" begin
x, Δx, x̄ = randn(3)
Δz = (randn(), randn())

frule_test(sincos, (x, Δx))
rrule_test(sincos, Δz, (x, x̄))
frule_test(sincos, (x, Δx))
rrule_test(sincos, Δz, (x, x̄))
end
end
end # Trig

@testset "math" begin
for x in (-0.1, 6.4, 1.0+0.5im, -10.0+0im)
for x in (-0.1, 6.4)
test_scalar(deg2rad, x)
test_scalar(rad2deg, x)

Expand All @@ -72,22 +73,20 @@
test_scalar(exp2, x)
test_scalar(exp10, x)

x isa Real && test_scalar(cbrt, x)
if (x isa Real && x >= 0) || x isa Complex
# this check is needed because these have discontinuities between
# `-10 + im*eps()` and `-10 - im*eps()`
should_test_wirtinger = imag(x) != 0 && real(x) < 0
test_scalar(sqrt, x; test_wirtinger=should_test_wirtinger)
test_scalar(log, x; test_wirtinger=should_test_wirtinger)
test_scalar(log2, x; test_wirtinger=should_test_wirtinger)
test_scalar(log10, x; test_wirtinger=should_test_wirtinger)
test_scalar(log1p, x; test_wirtinger=should_test_wirtinger)
test_scalar(cbrt, x)

if x >= 0
test_scalar(sqrt, x)
test_scalar(log, x)
test_scalar(log2, x)
test_scalar(log10, x)
test_scalar(log1p, x)
end
end
end

@testset "Unary complex functions" begin
for x in (-4.1, 6.4, 1.0+0.5im, -10.0+1.5im)
for x in (-4.1, 6.4)
test_scalar(real, x)
test_scalar(imag, x)

Expand All @@ -97,6 +96,7 @@
test_scalar(angle, x)
test_scalar(abs2, x)
test_scalar(conj, x)
test_scalar(adjoint, x)
end
end

Expand Down Expand Up @@ -152,8 +152,7 @@
_, x̄ = pb(10.5)
@test extern(x̄) == 0

_, pf = frule(sign, 0.0)
= pf(NamedTuple(), 10.5)
_, ẏ = frule(sign, 0.0, Zero(), 10.5)
@test extern(ẏ) == 0
end
end
Expand Down
4 changes: 1 addition & 3 deletions test/rulesets/Base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,8 @@
end
@testset "frule" begin
x = rand(3, 3)
y, pushforward = frule(broadcast, sin, x)
y, = frule(broadcast, sin, x, Zero(), Zero(), One())
@test y == sin.(x)

= pushforward(NamedTuple(), NamedTuple(), One())
@test extern(ẏ) == cos.(x)
end
end
Expand Down
1 change: 0 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ using Test

# For testing purposes we use a lot of
using ChainRulesCore: extern, accumulate, accumulate!, store!, @scalar_rule,
Wirtinger, wirtinger_primal, wirtinger_conjugate,
Zero, One, DoesNotExist, Thunk, AbstractDifferential

Random.seed!(1) # Set seed that all testsets should reset to.
Expand Down
Loading

2 comments on commit 3d2618e

@YingboMa
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/7827

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if Julia TagBot is installed, or can be done manually through the github interface, or via:

git tag -a v0.3.0 -m "<description of version>" 3d2618e58dc55cf688e8570f97648ed9b4cb60f3
git push origin v0.3.0

Please sign in to comment.