Skip to content

Commit

Permalink
change from => to >>
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Sep 12, 2023
1 parent 99ae33d commit 4a55975
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions src/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -637,11 +637,11 @@ end

"""
OptimiserChain(o1, o2, o34...)
o1 => o2 => o3
o1 >> o2 >> o3
Compose a sequence of optimisers so that each `opt` in `(o1, o2, o34...)`
updates the gradient, in the order specified.
May be entered using `Pair` syntax with several `AbstractRule`s.
May be entered using the `>>` operator with several `AbstractRule`s.
With an empty sequence, `OptimiserChain()` is the identity,
so `update!` will subtract the full gradient from the parameters.
Expand All @@ -650,8 +650,8 @@ This is equivalent to `Descent(1)`.
# Example
```jldoctest
julia> o = ClipGrad(1.0) => Descent(0.1)
OptimiserChain(ClipGrad{Float64}(1.0), Descent{Float64}(0.1))
julia> o = ClipGrad(1.0) >> Descent(0.1)
OptimiserChain(ClipGrad(1.0), Descent(0.1))
julia> m = (zeros(3),);
Expand All @@ -667,8 +667,10 @@ struct OptimiserChain{O<:Tuple} <: AbstractRule
end
OptimiserChain(opts...) = OptimiserChain(opts)

Base.Pair(a::AbstractRule, b::AbstractRule) = OptimiserChain(a, b)
Base.Pair(a::AbstractRule, bc::OptimiserChain) = OptimiserChain(a, bc.opts...)
@doc @doc(OptimiserChain)
Base.:(>>)(a::AbstractRule, b::AbstractRule) = OptimiserChain(a, b)
Base.:(>>)(a::AbstractRule, bc::OptimiserChain) = OptimiserChain(a, bc.opts...)
Base.:(>>)(ab::OptimiserChain, c::AbstractRule) = OptimiserChain(ab.opts..., c)

@functor OptimiserChain

Expand All @@ -687,7 +689,7 @@ end

function Base.show(io::IO, c::OptimiserChain) # compact show
if length(c.opts) > 1
join(io, c.opts, " => ")
join(io, c.opts, " >> ")
else
show(io, MIME"text/plain"(), c)
end
Expand Down

0 comments on commit 4a55975

Please sign in to comment.