Skip to content

Commit

Permalink
Merge c1a4451 into bc600ce
Browse files Browse the repository at this point in the history
  • Loading branch information
YingboMa committed Jan 12, 2020
2 parents bc600ce + c1a4451 commit 00b588c
Show file tree
Hide file tree
Showing 11 changed files with 71 additions and 142 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.2.5"
version = "0.3"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -10,7 +10,7 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ChainRulesCore = "0.4"
ChainRulesCore = "0.5.1"
FiniteDifferences = "^0.7"
Reexport = "0.2"
Requires = "0.5.2, 1"
Expand Down
1 change: 0 additions & 1 deletion docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,6 @@ The most important `AbstractDifferential`s when getting started are the ones abo
- `One`, `Zero`: There are special representations of `1` and `0`. They do great things around avoiding expanding `Thunks` in multiplication and (for `Zero`) addition.

#### Other `AbstractDifferential`s: don't worry about them right now
- `Wirtinger`: it is complex. The docs need to be better. [Read the links in this issue](https://github.com/JuliaDiff/ChainRulesCore.jl/issues/40).
- `Casted`: it implements broadcasting mechanics. See [#10](https://github.com/JuliaDiff/ChainRulesCore.jl/issues/10)
- `InplaceableThunk`: it is like a Thunk but it can do `store!` and `accumulate!` in-place.

Expand Down
25 changes: 7 additions & 18 deletions src/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
@scalar_rule(zero(x), Zero())
@scalar_rule(sign(x), Zero())

@scalar_rule(abs2(x), Wirtinger(x', x))
@scalar_rule(abs2(x), 2x)
@scalar_rule(log(x), inv(x))
@scalar_rule(log10(x), inv(x) / log(oftype(x, 10)))
@scalar_rule(log2(x), inv(x) / log(oftype(x, 2)))
Expand Down Expand Up @@ -50,14 +50,12 @@
@scalar_rule(deg2rad(x), π / oftype(x, 180))
@scalar_rule(rad2deg(x), oftype(x, 180) / π)

@scalar_rule(conj(x), Wirtinger(Zero(), One()))
@scalar_rule(adjoint(x), Wirtinger(Zero(), One()))
@scalar_rule(conj(x::Real), One())
@scalar_rule(adjoint(x::Real), One())
@scalar_rule(transpose(x), One())

@scalar_rule(abs(x::Real), sign(x))
@scalar_rule(abs(x::Complex), Wirtinger(x' / 2Ω, x / 2Ω))
@scalar_rule(hypot(x::Real), sign(x))
@scalar_rule(hypot(x::Complex), Wirtinger(x' / 2Ω, x / 2Ω))
@scalar_rule(rem2pi(x, r::RoundingMode), (One(), DoesNotExist()))

@scalar_rule(+(x), One())
Expand Down Expand Up @@ -98,20 +96,14 @@
(ifelse(isint, nan, one(u)), ifelse(isint, nan, -trunc(u))))
@scalar_rule(fma(x, y, z), (y, x, One()))
@scalar_rule(muladd(x, y, z), (y, x, One()))
@scalar_rule(angle(x::Complex), @setup(u = abs2(x)), Wirtinger(-im//2 * x' / u, im//2 * x / u))
@scalar_rule(angle(x::Real), Zero())
@scalar_rule(real(x::Complex), Wirtinger(1//2, 1//2))
@scalar_rule(real(x::Real), One())
@scalar_rule(imag(x::Complex), Wirtinger(-im//2, im//2))
@scalar_rule(imag(x::Real), Zero())

# product rule requires special care for arguments where `mul` is non-commutative

function frule(::typeof(*), x::Number, y::Number)
function times_pushforward(_, Δx, Δy)
return Δx * y + x * Δy
end
return x * y, times_pushforward
function frule(::typeof(*), x::Number, y::Number, _, Δx, Δy)
return x * y, Δx * y + x * Δy
end

function rrule(::typeof(*), x::Number, y::Number)
Expand All @@ -121,11 +113,8 @@ function rrule(::typeof(*), x::Number, y::Number)
return x * y, times_pullback
end

function frule(::typeof(identity), x)
function identity_pushforward(_, ẏ)
return
end
return x, identity_pushforward
function frule(::typeof(identity), x, _, ẏ)
return x, ẏ
end

function rrule(::typeof(identity), x)
Expand Down
12 changes: 5 additions & 7 deletions src/rulesets/Base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,17 @@ https://github.com/JuliaLang/julia/issues/22129.
=#
function _cast_diff(f, x)
function element_rule(u)
fu, du = frule(f, u)
fu, extern(du(NamedTuple(), One()))
dself = Zero()
fu, du = frule(f, u, dself, One())
fu, extern(du)
end
results = broadcast(element_rule, x)
return first.(results), last.(results)
end

function frule(::typeof(broadcast), f, x)
function frule(::typeof(broadcast), f, x, _, Δf, Δx)
Ω, ∂x = _cast_diff(f, x)
function broadcast_pushforward(_, Δf, Δx)
return Δx .* ∂x
end
return Ω, broadcast_pushforward
return Ω, Δx .* ∂x
end

function rrule(::typeof(broadcast), f, x)
Expand Down
7 changes: 2 additions & 5 deletions src/rulesets/Base/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,8 @@ end
##### `sum`
#####

function frule(::typeof(sum), x)
function sum_pushforward(_, ẋ)
return sum(ẋ)
end
return sum(x), sum_pushforward
function frule(::typeof(sum), x, _, ẋ)
return sum(x), sum(ẋ)
end

function rrule(::typeof(sum), f, x::AbstractArray{<:Real}; dims=:)
Expand Down
18 changes: 6 additions & 12 deletions src/rulesets/LinearAlgebra/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ _zeros(x) = fill!(similar(x), zero(eltype(x)))
##### `BLAS.dot`
#####

frule(::typeof(BLAS.dot), x, y) = frule(dot, x, y)
frule(::typeof(BLAS.dot), x, y, Δself, Δx, Δy) = frule(dot, x, y, Δself, Δx, Δy)

rrule(::typeof(BLAS.dot), x, y) = rrule(dot, x, y)

Expand All @@ -35,12 +35,9 @@ end
##### `BLAS.nrm2`
#####

function frule(::typeof(BLAS.nrm2), x)
function frule(::typeof(BLAS.nrm2), x, _, Δx)
Ω = BLAS.nrm2(x)
function nrm2_pushforward(_, Δx)
return sum(Δx * cast(@thunk(x * inv(Ω))))
end
return Ω, nrm2_pushforward
return Ω, sum(Δx .* @thunk(x * inv(Ω)))
end

function rrule(::typeof(BLAS.nrm2), x)
Expand Down Expand Up @@ -70,16 +67,13 @@ end
##### `BLAS.asum`
#####

function frule(::typeof(BLAS.asum), x)
function asum_pushforward(_, Δx)
return sum(cast(sign, x) * Δx)
end
return BLAS.asum(x), asum_pushforward
function frule(::typeof(BLAS.asum), x, _, Δx)
return BLAS.asum(x), sum(sign.(x) .* Δx)
end

function rrule(::typeof(BLAS.asum), x)
function asum_pullback(ΔΩ)
return (NO_FIELDS, @thunk(ΔΩ * cast(sign, x)))
return (NO_FIELDS, @thunk(ΔΩ * sign.(x)))
end
return BLAS.asum(x), asum_pullback
end
Expand Down
43 changes: 13 additions & 30 deletions src/rulesets/LinearAlgebra/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,8 @@ const SquareMatrix{T} = Union{Diagonal{T},AbstractTriangular{T}}
##### `dot`
#####

function frule(::typeof(dot), x, y)
function dot_pushforward(Δself, Δx, Δy)
return sum(Δx .* y) + sum(x .* Δy)
end
return dot(x, y), dot_pushforward
function frule(::typeof(dot), x, y, _, Δx, Δy)
return dot(x, y), sum(Δx .* y) + sum(x .* Δy)
end

function rrule(::typeof(dot), x, y)
Expand All @@ -26,20 +23,15 @@ end
##### `inv`
#####

function frule(::typeof(inv), x::AbstractArray)
function frule(::typeof(inv), x::AbstractArray, _, Δx)
Ω = inv(x)
m = @thunk(-Ω)
function inv_pushforward(_, Δx)
return m * Δx * Ω
end
return Ω, inv_pushforward
return Ω, -Ω * Δx * Ω
end

function rrule(::typeof(inv), x::AbstractArray)
Ω = inv(x)
m = @thunk(-Ω')
function inv_pullback(ΔΩ)
return NO_FIELDS, m * ΔΩ * Ω'
return NO_FIELDS, -Ω' * ΔΩ * Ω'
end
return Ω, inv_pullback
end
Expand All @@ -48,14 +40,11 @@ end
##### `det`
#####

function frule(::typeof(det), x)
function frule(::typeof(det), x, _, ẋ)
Ω = det(x)
function det_pushforward(_, ẋ)
# TODO Performance optimization: probably there is an efficent
# way to compute this trace without during the full compution within
return Ω * tr(inv(x) * ẋ)
end
return Ω, det_pushforward
# TODO Performance optimization: probably there is an efficent
# way to compute this trace without during the full compution within
return Ω, Ω * tr(inv(x) * ẋ)
end

function rrule(::typeof(det), x)
Expand All @@ -70,12 +59,9 @@ end
##### `logdet`
#####

function frule(::typeof(logdet), x)
function frule(::typeof(logdet), x, _, Δx)
Ω = logdet(x)
function logdet_pushforward(_, Δx)
return tr(inv(x) * Δx)
end
return Ω, logdet_pushforward
return Ω, tr(inv(x) * Δx)
end

function rrule(::typeof(logdet), x)
Expand All @@ -90,11 +76,8 @@ end
##### `trace`
#####

function frule(::typeof(tr), x)
function tr_pushforward(_, Δx)
return tr(Δx)
end
return tr(x), tr_pushforward
function frule(::typeof(tr), x, _, Δx)
return tr(x), tr(Δx)
end

function rrule(::typeof(tr), x)
Expand Down
39 changes: 19 additions & 20 deletions test/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,18 +51,19 @@
test_scalar(acscd, 1/x)
test_scalar(acotd, 1/x)
end

@testset "sincos" begin
x, Δx, x̄ = randn(3)
Δz = (randn(), randn())
@testset "Multivariate" begin
@testset "sincos" begin
x, Δx, x̄ = randn(3)
Δz = (randn(), randn())

frule_test(sincos, (x, Δx))
rrule_test(sincos, Δz, (x, x̄))
frule_test(sincos, (x, Δx))
rrule_test(sincos, Δz, (x, x̄))
end
end
end # Trig

@testset "math" begin
for x in (-0.1, 6.4, 1.0+0.5im, -10.0+0im)
for x in (-0.1, 6.4)
test_scalar(deg2rad, x)
test_scalar(rad2deg, x)

Expand All @@ -72,22 +73,20 @@
test_scalar(exp2, x)
test_scalar(exp10, x)

x isa Real && test_scalar(cbrt, x)
if (x isa Real && x >= 0) || x isa Complex
# this check is needed because these have discontinuities between
# `-10 + im*eps()` and `-10 - im*eps()`
should_test_wirtinger = imag(x) != 0 && real(x) < 0
test_scalar(sqrt, x; test_wirtinger=should_test_wirtinger)
test_scalar(log, x; test_wirtinger=should_test_wirtinger)
test_scalar(log2, x; test_wirtinger=should_test_wirtinger)
test_scalar(log10, x; test_wirtinger=should_test_wirtinger)
test_scalar(log1p, x; test_wirtinger=should_test_wirtinger)
test_scalar(cbrt, x)

if x >= 0
test_scalar(sqrt, x)
test_scalar(log, x)
test_scalar(log2, x)
test_scalar(log10, x)
test_scalar(log1p, x)
end
end
end

@testset "Unary complex functions" begin
for x in (-4.1, 6.4, 1.0+0.5im, -10.0+1.5im)
for x in (-4.1, 6.4)
test_scalar(real, x)
test_scalar(imag, x)

Expand All @@ -97,6 +96,7 @@
test_scalar(angle, x)
test_scalar(abs2, x)
test_scalar(conj, x)
test_scalar(adjoint, x)
end
end

Expand Down Expand Up @@ -152,8 +152,7 @@
_, x̄ = pb(10.5)
@test extern(x̄) == 0

_, pf = frule(sign, 0.0)
= pf(NamedTuple(), 10.5)
_, ẏ = frule(sign, 0.0, Zero(), 10.5)
@test extern(ẏ) == 0
end
end
Expand Down
4 changes: 1 addition & 3 deletions test/rulesets/Base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,8 @@
end
@testset "frule" begin
x = rand(3, 3)
y, pushforward = frule(broadcast, sin, x)
y, = frule(broadcast, sin, x, Zero(), Zero(), One())
@test y == sin.(x)

= pushforward(NamedTuple(), NamedTuple(), One())
@test extern(ẏ) == cos.(x)
end
end
Expand Down
1 change: 0 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ using Test

# For testing purposes we use a lot of
using ChainRulesCore: extern, accumulate, accumulate!, store!, @scalar_rule,
Wirtinger, wirtinger_primal, wirtinger_conjugate,
Zero, One, DoesNotExist, Thunk, AbstractDifferential

Random.seed!(1) # Set seed that all testsets should reset to.
Expand Down
Loading

0 comments on commit 00b588c

Please sign in to comment.