Merge #1362
1362: Implement AdaBelief r=DhairyaLGandhi a=willtebbutt

Implements the fancy new [AdaBelief](https://arxiv.org/abs/2010.07468) algorithm. I guess time will tell whether or not it's helpful, but there's no harm in including it, and it's a pretty trivial change from ADAM.
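For context (not part of the PR itself), the conceptual change from ADAM is in what the second-moment accumulator tracks: ADAM accumulates the squared gradient, while AdaBelief accumulates the squared deviation of the gradient from its running mean (the "belief"). A minimal sketch of both rules under that reading, omitting ADAM's bias correction (the PR's AdaBelief implementation omits it too); the names `m`, `s`, `Δ`, `η`, `β`, `ϵ` are illustrative and loosely follow the paper's notation:

```julia
# Sketch only, not the PR's code: one ADAM-style step, with a switch for
# which quantity the second-moment accumulator `s` tracks.
function adam_like_step!(Δ, m, s; η = 0.001, β = (0.9, 0.999), ϵ = 1e-8, adabelief = true)
    @. m = β[1] * m + (1 - β[1]) * Δ
    if adabelief
        @. s = β[2] * s + (1 - β[2]) * (Δ - m)^2   # AdaBelief: squared deviation from m
    else
        @. s = β[2] * s + (1 - β[2]) * Δ^2         # ADAM: raw squared gradient
    end
    @. Δ = η * m / (√s + ϵ)                        # same final update in both cases
    return Δ
end
```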

### PR Checklist

- [x] Tests are added
- [x] Entry in NEWS.md
- [x] Documentation, if applicable
- [ ] Final review from `@dhairyagandhi96` (for API changes).

edit: I've tweaked the docstring from ADAM, which I think counts as documentation, since I'm assuming it will get autodoc-ed into the optimiser documentation.

Co-authored-by: Will Tebbutt <will.tebbutt@invenialabs.co.uk>
Co-authored-by: wt <wt0881@my.bristol.ac.uk>
3 people committed Oct 19, 2020
2 parents 98e7222 + d4a8252 commit c5c35cc
Showing 6 changed files with 48 additions and 6 deletions.
4 changes: 4 additions & 0 deletions NEWS.md
@@ -1,3 +1,7 @@
# v0.11.2

* Adds the [AdaBelief](https://arxiv.org/abs/2010.07468) optimiser.

# v0.11

* Moved CUDA compatibility to use [CUDA.jl instead of CuArrays.jl](https://github.com/FluxML/Flux.jl/pull/1204)
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "Flux"
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
version = "0.11.1"
version = "0.11.2"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
4 changes: 2 additions & 2 deletions src/Flux.jl
@@ -23,8 +23,8 @@ using .Optimise: @epochs
using .Optimise: skip
export Descent, ADAM, Momentum, Nesterov, RMSProp,
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, OADAM,
ADAMW, RADAM, InvDecay, ExpDecay, WeightDecay,
ClipValue, ClipNorm
ADAMW, RADAM, AdaBelief, InvDecay, ExpDecay,
WeightDecay, ClipValue, ClipNorm


using CUDA
2 changes: 1 addition & 1 deletion src/optimise/Optimise.jl
@@ -4,7 +4,7 @@ using LinearAlgebra

export train!, update!,
Descent, ADAM, Momentum, Nesterov, RMSProp,
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,RADAM, OADAM,
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,RADAM, OADAM, AdaBelief,
InvDecay, ExpDecay, WeightDecay, stop, skip, Optimiser,
ClipValue, ClipNorm

38 changes: 38 additions & 0 deletions src/optimise/optimisers.jl
@@ -460,6 +460,44 @@ opt = ADAMW(0.001, (0.89, 0.995), 0.1)
ADAMW(η = 0.001, β = (0.9, 0.999), decay = 0) =
Optimiser(ADAM(η, β), WeightDecay(decay))

"""
AdaBelief(η = 0.001, β::Tuple = (0.9, 0.999))
The [AdaBelief](https://arxiv.org/abs/2010.07468) optimiser is a variant of the well-known
ADAM optimiser.
# Parameters
- Learning rate (`η`): Amount by which gradients are discounted before updating
the weights.
- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
second (β2) momentum estimate.
# Examples
```julia
opt = AdaBelief()
opt = AdaBelief(0.001, (0.9, 0.8))
```
"""
mutable struct AdaBelief
eta::Float64
beta::Tuple{Float64,Float64}
state::IdDict
end

AdaBelief(η = 0.001, β = (0.9, 0.999)) = AdaBelief(η, β, IdDict())

function apply!(o::AdaBelief, x, Δ)
η, β = o.eta, o.beta
mt, st = get!(o.state, x, (zero(x), zero(x)))
@. mt = β[1] * mt + (1 - β[1]) * Δ
@. st = β[2] * st + (1 - β[2]) * (Δ - mt)^2
@. Δ = η * mt / (√(st) + ϵ)
o.state[x] = (mt, st)
return Δ
end


# Compose optimizers

"""
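For reference (not from the commit), a hypothetical usage sketch showing the new optimiser dropped into a standard Flux 0.11 training loop; the model, data, and loss below are made up for illustration, and `AdaBelief()` takes the same `(η, β)` arguments as `ADAM()`:

```julia
using Flux

# Toy model and data, invented for this example.
model = Dense(10, 1)
x, y = randn(Float32, 10, 100), randn(Float32, 1, 100)
loss(x, y) = Flux.Losses.mse(model(x), y)

# Drop-in replacement wherever ADAM() was used.
opt = AdaBelief(0.001, (0.9, 0.999))
Flux.train!(loss, Flux.params(model), [(x, y)], opt)
```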
4 changes: 2 additions & 2 deletions test/optimise.jl
@@ -10,8 +10,8 @@ using Random
Random.seed!(84)
w = randn(10, 10)
@testset for opt in [ADAMW(), ADAGrad(0.1), AdaMax(), ADADelta(0.9), AMSGrad(),
NADAM(), RADAM(), Descent(0.1), ADAM(), OADAM(), Nesterov(), RMSProp(),
Momentum()]
NADAM(), RADAM(), Descent(0.1), ADAM(), OADAM(), AdaBelief(),
Nesterov(), RMSProp(), Momentum()]
Random.seed!(42)
w′ = randn(10, 10)
loss(x) = Flux.Losses.mse(w*x, w′*x)
