make frule/rrule definitions more rigorous
jrevels committed Apr 4, 2019
1 parent 2a7b666 commit 35f717b
Showing 1 changed file with 92 additions and 12 deletions.
104 changes: 92 additions & 12 deletions src/rules.jl
@@ -186,26 +186,106 @@ end
 =#

 """
-    frule(f, xs...)
+    frule(f, x...)
-Apply the forward-mode differentiation rule to `f` with the given arguments `xs`,
-returning a tuple of `f(xs...)` and an [`AbstractRule`](@ref) object which can be
-called to evaluate the rule. If no forward-mode rule has been defined for `f`,
-`nothing` is returned.
+Expressing `x` as the tuple `(x₁, x₂, ...)` and the output tuple of `f(x...)`
+as `Ω`, return the tuple:
-See also: [`rrule`](@ref)
+    (Ω, (rule_for_ΔΩ₁::AbstractRule, rule_for_ΔΩ₂::AbstractRule, ...))
+where each returned propagation rule `rule_for_ΔΩᵢ` can be invoked as
+    rule_for_ΔΩᵢ(previous_ΔΩᵢ, Δx₁, Δx₂, ...)
+to yield `Ωᵢ`'s corresponding differential `ΔΩᵢ`. To illustrate, if all involved
+values are real-valued scalars, this differential can be written as:
+    previous_ΔΩᵢ + ∂Ωᵢ_∂x₁ * Δx₁ + ∂Ωᵢ_∂x₂ * Δx₂ + ...
+If no method matching `frule(f, x...)` has been defined, then return `nothing`.
+Examples:
+unary input, unary output scalar function:
+    julia> x = rand();
+    julia> sinx, dsin = ChainRules.frule(sin, x);
+    julia> sinx == sin(x)
+    true
+    julia> dsin(0, 1) == cos(x)
+    true
+unary input, binary output scalar function:
+    julia> x = rand();
+    julia> sincosx, (dsin, dcos) = ChainRules.frule(sincos, x);
+    julia> sincosx == sincos(x)
+    true
+    julia> dsin(0, 1) == cos(x)
+    true
+    julia> dcos(0, 1) == -sin(x)
+    true
+See also: [`rrule`](@ref), [`AbstractRule`](@ref)
 """
 frule(::Any, ::Vararg{Any}) = nothing

 """
-    rrule(f, xs...)
+    rrule(f, x...)
+Expressing `x` as the tuple `(x₁, x₂, ...)` and the output tuple of `f(x...)`
+as `Ω`, return the tuple:
+    (Ω, (rule_for_Δx₁::AbstractRule, rule_for_Δx₂::AbstractRule, ...))
+where each returned propagation rule `rule_for_Δxᵢ` can be invoked as
+    rule_for_Δxᵢ(previous_Δxᵢ, ΔΩ₁, ΔΩ₂, ...)
+to yield `xᵢ`'s corresponding differential `Δxᵢ`. To illustrate, if all involved
+values are real-valued scalars, this differential can be written as:
+    previous_Δxᵢ + ∂Ω₁_∂xᵢ * ΔΩ₁ + ∂Ω₂_∂xᵢ * ΔΩ₂ + ...
+If no method matching `rrule(f, x...)` has been defined, then return `nothing`.
+Examples:
+unary input, unary output scalar function:
+    julia> x = rand();
+    julia> sinx, dx = ChainRules.rrule(sin, x);
+    julia> sinx == sin(x)
+    true
+    julia> dx(0, 1) == cos(x)
+    true
+binary input, unary output scalar function:
+    julia> x, y = rand(2);
+    julia> hypotxy, (dx, dy) = ChainRules.rrule(hypot, x, y);
+    julia> hypotxy == hypot(x, y)
+    true
+    julia> dx(0, 1) == (x / hypot(x, y))
+    true
-Apply the reverse-mode differentiation rule to `f` with the given arguments `xs`,
-returning a tuple of `f(xs...)` and an [`AbstractRule`](@ref) object which can be
-called to evaluate the rule. If no reverse-mode rule has been defined for `f`,
-`nothing` is returned.
+    julia> dy(0, 1) == (y / hypot(x, y))
+    true
-See also: [`frule`](@ref)
+See also: [`frule`](@ref), [`AbstractRule`](@ref)
 """
 rrule(::Any, ::Vararg{Any}) = nothing
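
The calling convention described in the two docstrings can be illustrated without the package itself. The sketch below is not ChainRules' implementation: `sketch_frule` and `sketch_rrule` are hypothetical names, and plain closures stand in for `AbstractRule` objects. It assumes scalar real inputs and only mirrors the documented shapes, that is, one rule per output in forward mode, one rule per input in reverse mode, and `nothing` when no rule is defined.

```julia
# Hypothetical stand-ins for the documented convention (not the ChainRules API).
# Each "rule" is a closure called as rule(previous_Δ, Δ...) that accumulates
# onto the previously computed differential.

# Fallbacks, as in the diff: no matching rule means `nothing`.
sketch_frule(::Any, ::Vararg{Any}) = nothing
sketch_rrule(::Any, ::Vararg{Any}) = nothing

# Forward mode, unary input, binary output: one rule per output Ωᵢ.
function sketch_frule(::typeof(sincos), x::Real)
    sinx, cosx = sincos(x)
    dsin = (previous_ΔΩ₁, Δx) -> previous_ΔΩ₁ + cosx * Δx   # ∂sin/∂x = cos(x)
    dcos = (previous_ΔΩ₂, Δx) -> previous_ΔΩ₂ - sinx * Δx   # ∂cos/∂x = -sin(x)
    return (sinx, cosx), (dsin, dcos)
end

# Reverse mode, binary input, unary output: one rule per input xᵢ.
function sketch_rrule(::typeof(hypot), x::Real, y::Real)
    Ω = hypot(x, y)
    dx = (previous_Δx, ΔΩ) -> previous_Δx + (x / Ω) * ΔΩ    # ∂Ω/∂x = x / Ω
    dy = (previous_Δy, ΔΩ) -> previous_Δy + (y / Ω) * ΔΩ    # ∂Ω/∂y = y / Ω
    return Ω, (dx, dy)
end

Ω, (dsin, dcos) = sketch_frule(sincos, 0.5)
h, (dx, dy) = sketch_rrule(hypot, 3.0, 4.0)
```

Passing `0` as the first argument of a rule, as the docstring examples do, starts the accumulation from zero, so the call returns the bare partial derivative scaled by the seed. A nonzero first argument adds the new contribution to a differential that was already accumulated elsewhere.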
