Skip to content

Commit

Permalink
Removes rules depending on casted (#120)
Browse files Browse the repository at this point in the history
* Remove used of Casted

* bump version

* Update Project.toml
  • Loading branch information
oxinabox authored and willtebbutt committed Oct 17, 2019
1 parent 296723c commit 9e4cb76
Show file tree
Hide file tree
Showing 7 changed files with 19 additions and 29 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.2.1"
version = "0.2.2-DEV"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -10,7 +10,7 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ChainRulesCore = "^0.3"
ChainRulesCore = "0.3, 0.4"
FiniteDifferences = "^0.7"
julia = "^1.0"

Expand Down
4 changes: 2 additions & 2 deletions src/rulesets/Base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@ end
function frule(::typeof(broadcast), f, x)
Ω, ∂x = _cast_diff(f, x)
function broadcast_pushforward(_, Δf, Δx)
return Δx * cast(∂x)
return Δx .* ∂x
end
return Ω, broadcast_pushforward
end

function rrule(::typeof(broadcast), f, x)
values, derivs = _cast_diff(f, x)
function broadcast_pullback(ΔΩ)
return (NO_FIELDS, DNE(), @thunk(ΔΩ * cast(derivs)))
return (NO_FIELDS, DNE(), @thunk(ΔΩ .* derivs))
end
return values, broadcast_pullback
end
2 changes: 1 addition & 1 deletion src/rulesets/Base/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ end

function rrule(::typeof(sum), x)
function sum_pullback(ȳ)
return (NO_FIELDS, cast(ȳ))
return (NO_FIELDS, @thunk(fill(ȳ, size(x))))
end
return sum(x), sum_pullback
end
Expand Down
4 changes: 2 additions & 2 deletions src/rulesets/LinearAlgebra/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@ const SquareMatrix{T} = Union{Diagonal{T},AbstractTriangular{T}}

function frule(::typeof(dot), x, y)
function dot_pushforward(Δself, Δx, Δy)
return sum(Δx * cast(y)) + sum(cast(x) * Δy)
return sum(Δx .* y) + sum(x .* Δy)
end
return dot(x, y), dot_pushforward
end

function rrule(::typeof(dot), x, y)
function dot_pullback(ΔΩ)
return (NO_FIELDS, ΔΩ * cast(y), cast(x) * ΔΩ,)
return (NO_FIELDS, @thunk(ΔΩ .* y), @thunk(x .* ΔΩ))
end
return dot(x, y), dot_pullback
end
Expand Down
25 changes: 8 additions & 17 deletions test/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@

= rand(3, 5)
(ds, dx, dy) = pullback(z̄)

@test ds === NO_FIELDS

@test extern(dx) == extern(accumulate(zeros(3, 2), dx))
Expand All @@ -147,22 +147,13 @@
end

@testset "hypot(x, y)" begin
x, y = rand(2)
h, pushforward = frule(hypot, x, y)
dxy(x, y) = pushforward(NamedTuple(), x, y)

@test extern(dxy(One(), Zero())) === x / h
@test extern(dxy(Zero(), One())) === y / h

cx, cy = cast((One(), Zero())), cast((Zero(), One()))
dx, dy = extern(dxy(cx, cy))
@test dx === x / h
@test dy === y / h

cx, cy = cast((rand(), Zero())), cast((Zero(), rand()))
dx, dy = extern(dxy(cx, cy))
@test dx === x / h * cx.value[1]
@test dy === y / h * cy.value[2]
rng = MersenneTwister(123456)
x, Δx, x̄ = randn(rng, 3)
y, Δy, ȳ = randn(rng, 3)
Δz = randn(rng)

frule_test(hypot, (x, Δx), (y, Δy))
rrule_test(hypot, Δz, (x, x̄), (y, ȳ))
end

@testset "identity" begin
Expand Down
4 changes: 2 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ using Statistics
using Test

# For testing purposes we use a lot of
using ChainRulesCore: cast, extern, accumulate, accumulate!, store!, @scalar_rule,
using ChainRulesCore: extern, accumulate, accumulate!, store!, @scalar_rule,
Wirtinger, wirtinger_primal, wirtinger_conjugate,
Zero, One, Casted, DNE, Thunk, AbstractDifferential
Zero, One, DNE, Thunk, AbstractDifferential

include("test_util.jl")

Expand Down
5 changes: 2 additions & 3 deletions test/test_util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,12 +193,11 @@ end
function Base.isapprox(ad::Wirtinger, fd; kwargs...)
error("Finite differencing with Wirtinger rules not implemented")
end
function Base.isapprox(d_ad::Casted, d_fd; kwargs...)
return all(isapprox.(extern(d_ad), d_fd; kwargs...))
end

function Base.isapprox(d_ad::DNE, d_fd; kwargs...)
error("Tried to differentiate w.r.t. a DNE")
end

function Base.isapprox(d_ad::AbstractDifferential, d_fd; kwargs...)
return isapprox(extern(d_ad), d_fd; kwargs...)
end
Expand Down

0 comments on commit 9e4cb76

Please sign in to comment.