Skip to content

Commit

Permalink
write => for OptimiserChain
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Apr 6, 2023
1 parent 14949f1 commit 0258de8
Showing 1 changed file with 18 additions and 5 deletions.
23 changes: 18 additions & 5 deletions src/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -621,23 +621,26 @@ function apply!(o::ClipNorm, state, x, dx)
end

"""
OptimiserChain(opts...)
OptimiserChain(o1, o2, o34...)
o1 => o2 => o3
Compose a sequence of optimisers so that each `opt` in `opts`
Compose a sequence of optimisers so that each `opt` in `(o1, o2, o34...)`
updates the gradient, in the order specified.
May be entered using `Pair` syntax with several `AbstractRule`s.
With an empty sequence, `OptimiserChain()` is the identity,
so `update!` will subtract the full gradient from the parameters.
This is equivalent to `Descent(1)`.
# Example
```jldoctest
julia> o = OptimiserChain(ClipGrad(1.0), Descent(0.1));
julia> o = ClipGrad(1.0) => Descent(0.1)
OptimiserChain(ClipGrad{Float64}(1.0), Descent{Float64}(0.1))
julia> m = (zeros(3),);
julia> s = Optimisers.setup(o, m)
(Leaf(OptimiserChain(ClipGrad{Float64}(1.0), Descent{Float64}(0.1)), (nothing, nothing)),)
(Leaf(ClipGrad{Float64}(1.0) => Descent{Float64}(0.1), (nothing, nothing)),)
julia> Optimisers.update(s, m, ([0.3, 1, 7],))[2] # clips before discounting
([-0.03, -0.1, -0.1],)
Expand All @@ -648,6 +651,9 @@ struct OptimiserChain{O<:Tuple} <: AbstractRule
end
OptimiserChain(opts...) = OptimiserChain(opts)

Base.Pair(a::AbstractRule, b::AbstractRule) = OptimiserChain(a, b)
Base.Pair(a::AbstractRule, bc::OptimiserChain) = OptimiserChain(a, bc.opts...)

@functor OptimiserChain

init(o::OptimiserChain, x::AbstractArray) = map(opt -> init(opt, x), o.opts)
Expand All @@ -659,7 +665,14 @@ function apply!(o::OptimiserChain, states, x, dx, dxs...)
end
end

function Base.show(io::IO, c::OptimiserChain)
function Base.show(io::IO, c::OptimiserChain) # compact show
if length(c.opts) > 1
join(io, c.opts, " => ")
else
show(io, MIME"text/plain"(), c)
end
end
function Base.show(io::IO, ::MIME"text/plain", c::OptimiserChain)
print(io, "OptimiserChain(")
join(io, c.opts, ", ")
print(io, ")")
Expand Down

0 comments on commit 0258de8

Please sign in to comment.