Skip to content

Commit

Permalink
Merge a716ae4 into 59f06bc
Browse files Browse the repository at this point in the history
  • Loading branch information
oxinabox committed Apr 17, 2020
2 parents 59f06bc + a716ae4 commit 725b43b
Show file tree
Hide file tree
Showing 11 changed files with 7 additions and 213 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.4.1"
version = "0.5.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
3 changes: 0 additions & 3 deletions src/ChainRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,8 @@ if VERSION < v"1.3.0-DEV.142"
import LinearAlgebra: dot
end

include("helper_functions.jl")

include("rulesets/Base/base.jl")
include("rulesets/Base/array.jl")
include("rulesets/Base/broadcast.jl")
include("rulesets/Base/mapreduce.jl")

include("rulesets/Statistics/statistics.jl")
Expand Down
24 changes: 0 additions & 24 deletions src/helper_functions.jl

This file was deleted.

28 changes: 0 additions & 28 deletions src/rulesets/Base/broadcast.jl

This file was deleted.

64 changes: 6 additions & 58 deletions src/rulesets/Base/mapreduce.jl
Original file line number Diff line number Diff line change
@@ -1,51 +1,3 @@
#####
##### `map`
#####

function rrule(::typeof(map), f, xs...)
y = map(f, xs...)
function map_pullback(ȳ)
ntuple(length(xs)+2) do full_i
full_i == 1 && return NO_FIELDS
full_i == 2 && return DoesNotExist()
i = full_i-2
@thunk map(ȳ, xs...) do ȳi, xis...
_, pullback = _checked_rrule(f, xis...)
∂xis = pullback(ȳi)
extern(∂xis[i+1]) #+1 to skp ∂self
end
end
end
return y, map_pullback
end

#####
##### `mapreduce`, `mapfoldl`, `mapfoldr`
#####

for mf in (:mapreduce, :mapfoldl, :mapfoldr)
sig = :(rrule(::typeof($mf), f, op, x::AbstractArray{<:Real}))
call = :($mf(f, op, x))
if mf === :mapreduce
insert!(sig.args, 2, Expr(:parameters, Expr(:kw, :dims, :(:))))
insert!(call.args, 2, Expr(:parameters, Expr(:kw, :dims, :dims)))
end
pullback_name = Symbol(mf, :_pullback)
body = quote
y = $call
function $pullback_name(ȳ)
∂x = @thunk broadcast(x, ȳ) do xi, ȳi
_, pullback_f = _checked_rrule(f, xi)
_, ∂xi = pullback_f(ȳi)
extern(∂xi)
end
(NO_FIELDS, DoesNotExist(), DoesNotExist(), ∂x)
end
return y, $pullback_name
end
eval(Expr(:function, sig, body))
end

#####
##### `sum`
#####
Expand All @@ -54,18 +6,14 @@ function frule((_, ẋ), ::typeof(sum), x)
return sum(x), sum(ẋ)
end

function rrule(::typeof(sum), f, x::AbstractArray{<:Real}; dims=:)
y, mr_pullback = rrule(mapreduce, f, Base.add_sum, x; dims=dims)
function sum_pullback(ȳ)
return NO_FIELDS, DoesNotExist(), last(mr_pullback(ȳ))
end
return y, sum_pullback
end

function rrule(::typeof(sum), x::AbstractArray{<:Real}; dims=:)
y, inner_pullback = rrule(sum, identity, x; dims=dims)
y = sum(sum, x; dims=dims)
function sum_pullback(ȳ)
return NO_FIELDS, last(inner_pullback(ȳ))
# broadcasting the two works out the size no-matter `dims`
= @thunk broadcast(x, ȳ) do xi, ȳi
ȳi
end
return (NO_FIELDS, x̄)
end
return y, sum_pullback
end
Expand Down
13 changes: 0 additions & 13 deletions src/rulesets/Statistics/statistics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,3 @@ function rrule(::typeof(mean), x::AbstractArray{<:Real}; dims=:)
end
return y_sum / n, mean_pullback
end

function rrule(::typeof(mean), f, x::AbstractArray{<:Real})
y_sum, sum_pullback = rrule(sum, f, x)
n = _denom(x, :)
function mean_pullback(ȳ)
∂x = Thunk() do
_, _, ∂sum_x = sum_pullback(ȳ)
extern(∂sum_x) / n
end
return (NO_FIELDS, DoesNotExist(), ∂x)
end
return y_sum / n, mean_pullback
end
13 changes: 0 additions & 13 deletions test/helper_functions.jl

This file was deleted.

27 changes: 0 additions & 27 deletions test/rulesets/Base/broadcast.jl

This file was deleted.

39 changes: 0 additions & 39 deletions test/rulesets/Base/mapreduce.jl
Original file line number Diff line number Diff line change
@@ -1,37 +1,4 @@
@testset "Maps and Reductions" begin
@testset "map" begin
rng = MersenneTwister(42)
n = 10
x = randn(rng, n)
vx = randn(rng, n)
= randn(rng, n)
rrule_test(map, ȳ, (sin, nothing), (x, vx))
rrule_test(map, ȳ, (+, nothing), (x, vx), (randn(rng, n), randn(rng, n)))
end
@testset "mapreduce" begin
rng = MersenneTwister(6)
n = 10
x = randn(rng, n)
vx = randn(rng, n)
= randn(rng)
rrule_test(mapreduce, ȳ, (sin, nothing), (+, nothing), (x, vx))

# With keyword arguments (not yet supported in rrule_test)
X = randn(rng, n, n)
y, pullback = rrule(mapreduce, abs2, +, X; dims=2)
= randn(rng, size(y))
(_, _, _, x̄_ad) = pullback(ȳ)
x̄_fd = only(j′vp(central_fdm(5, 1), x->mapreduce(abs2, +, x; dims=2), ȳ, X))
@test x̄_ad x̄_fd atol=1e-9 rtol=1e-9
end
@testset "$f" for f in (mapfoldl, mapfoldr)
rng = MersenneTwister(10)
n = 7
x = randn(rng, n)
vx = randn(rng, n)
= randn(rng)
rrule_test(f, ȳ, (cos, nothing), (+, nothing), (x, vx))
end
@testset "sum" begin
@testset "Vector" begin
rng, M = MersenneTwister(123456), 3
Expand All @@ -48,12 +15,6 @@
frule_test(sum, (randn(rng, M, N, P), randn(rng, M, N, P)))
rrule_test(sum, randn(rng), (randn(rng, M, N, P), randn(rng, M, N, P)))
end
@testset "function argument" begin
rng = MersenneTwister(1)
n = 8
rrule_test(sum, randn(rng), (cos, nothing), (randn(rng, n), randn(rng, n)))
rrule_test(sum, randn(rng), (abs2, nothing), (randn(rng, n), randn(rng, n)))
end
@testset "keyword arguments" begin
rng = MersenneTwister(33)
n = 4
Expand Down
4 changes: 0 additions & 4 deletions test/rulesets/Statistics/statistics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,6 @@
rrule_test(mean, randn(rng), (randn(rng, n), randn(rng, n)))
end

@testset "with function arg" begin
rrule_test(mean, randn(rng), (abs2, nothing), (randn(rng, n), randn(rng, n)))
end

@testset "with dims kwargs" begin
X = randn(rng, n, n+1)
y, mean_pullback = rrule(mean, X; dims=1)
Expand Down
3 changes: 0 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,11 @@ Random.seed!(1) # Set seed that all testsets should reset to.

println("Testing ChainRules.jl")
@testset "ChainRules" begin
include("helper_functions.jl")
@testset "rulesets" begin

@testset "Base" begin
include(joinpath("rulesets", "Base", "base.jl"))
include(joinpath("rulesets", "Base", "array.jl"))
include(joinpath("rulesets", "Base", "mapreduce.jl"))
include(joinpath("rulesets", "Base", "broadcast.jl"))
end

print(" ")
Expand Down

0 comments on commit 725b43b

Please sign in to comment.