diff --git a/src/ChainRules.jl b/src/ChainRules.jl index c0ae991ed..7fd6aca8e 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -6,6 +6,8 @@ using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broad import NaNMath, SpecialFunctions, LinearAlgebra, LinearAlgebra.BLAS +export AbstractRule, Rule, frule, rrule + include("differentials.jl") include("rules.jl") include("rules/base.jl") diff --git a/src/rules.jl b/src/rules.jl index 24492bf66..ff2a09433 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -244,7 +244,7 @@ unary input, unary output scalar function: ```julia-repl julia> x = rand(); -julia> sinx, dsin = ChainRules.frule(sin, x); +julia> sinx, dsin = frule(sin, x); julia> sinx == sin(x) true @@ -258,7 +258,7 @@ unary input, binary output scalar function: ```julia-repl julia> x = rand(); -julia> sincosx, (dsin, dcos) = ChainRules.frule(sincos, x); +julia> sincosx, (dsin, dcos) = frule(sincos, x); julia> sincosx == sincos(x) true @@ -300,7 +300,7 @@ unary input, unary output scalar function: ```julia-repl julia> x = rand(); -julia> sinx, dx = ChainRules.rrule(sin, x); +julia> sinx, dx = rrule(sin, x); julia> sinx == sin(x) true @@ -314,7 +314,7 @@ binary input, unary output scalar function: ```julia-repl julia> x, y = rand(2); -julia> hypotxy, (dx, dy) = ChainRules.rrule(hypot, x, y); +julia> hypotxy, (dx, dy) = rrule(hypot, x, y); julia> hypotxy == hypot(x, y) true diff --git a/test/runtests.jl b/test/runtests.jl index 0f6324d4c..17a5ad2f5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,10 +1,10 @@ # TODO: more tests! using ChainRules, Test, FDM, LinearAlgebra, Random -using ChainRules: rrule, frule, extern, accumulate, accumulate!, store!, @scalar_rule, +using ChainRules: extern, accumulate, accumulate!, store!, @scalar_rule, Wirtinger, wirtinger_primal, wirtinger_conjugate, add_wirtinger, mul_wirtinger, Zero, add_zero, mul_zero, One, add_one, mul_one, Casted, cast, add_casted, mul_casted, - DNE, Thunk, Casted, Wirtinger + DNE, Thunk, Casted using Base.Broadcast: broadcastable include("test_util.jl")