-
Notifications
You must be signed in to change notification settings - Fork 40
/
stochastic_gradient_descent.jl
79 lines (74 loc) · 2.8 KB
/
stochastic_gradient_descent.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
@doc raw"""
    stochastic_gradient_descent(M, ∇F, x0; kwargs...)

Perform a stochastic gradient descent on the manifold `M`, starting at `x0`.

# Input
* `M` – a manifold ``\mathcal M``
* `∇F` – either a single function that returns a vector of gradients,
  or a vector of gradient functions
* `x0` – an initial value ``x_0 ∈ \mathcal M``

# Optional
* `cost` – (`missing`) a cost function, for example to track the function value
* `direction` – (`StochasticGradient()`) the `DirectionUpdateRule` that is called
  in every iteration to obtain the update
* `stopping_criterion` – ([`StopAfterIteration`](@ref)`(10000)`) a [`StoppingCriterion`](@ref)
* `stoping_criterion` – (`stopping_criterion`) the historical, misspelled name of the
  previous keyword. It is kept for backward compatibility and, if given, takes precedence.
* `stepsize` – ([`ConstantStepsize`](@ref)`(1.0)`) a [`Stepsize`](@ref)
* `order_type` – (`:Random`) a symbol describing the order in which the gradients are
  evaluated. With `:FixedRandom` the `order` is shuffled once when the solver is
  initialized; how `:Random` and any other value (e.g. `:Linear`) are interpreted is
  decided by the `direction` rule.
* `order` – (`collect(1:n)`) the initial permutation, where `n` is the number of gradients,
  that is `length(∇F)` for a vector of functions and `length(∇F(x0))` for a single function.
* `retraction_method` – (`ExponentialRetraction()`) the retraction used to perform a step
* `vector_transport_method` – (`ParallelTransport()`) accepted for interface compatibility;
  this function itself does not forward it to the solver options.
* `return_options` – (`false`) whether to return the options instead of just the result

All further keyword arguments are passed on to `decorate_options`.

# Output
* `x_opt` – the result of the solver as obtained by `get_solver_result`
OR
* `options` – the options returned by the solver (if `return_options` is `true`)
"""
function stochastic_gradient_descent(
    M::Manifold,
    ∇F::Union{Function,AbstractVector{<:Function}},
    x0;
    cost::Union{Function,Missing}=Missing(),
    direction::DirectionUpdateRule=StochasticGradient(),
    stopping_criterion::StoppingCriterion=StopAfterIteration(10000),
    # Misspelled legacy keyword: defaults to the correctly spelled one, so that both
    # spellings work and existing callers are not broken.
    stoping_criterion::StoppingCriterion=stopping_criterion,
    stepsize::Stepsize=ConstantStepsize(1.0),
    order_type::Symbol=:Random,
    order=collect(1:(∇F isa Function ? length(∇F(x0)) : length(∇F))),
    retraction_method::AbstractRetractionMethod=ExponentialRetraction(),
    vector_transport_method::AbstractVectorTransportMethod=ParallelTransport(),
    return_options=false,
    kwargs...,
)
    p = StochasticGradientProblem(M, ∇F; cost=cost)
    # The options constructor is defined elsewhere; it is called with exactly the
    # keyword names that were used before, including the misspelled one.
    o = StochasticGradientDescentOptions(
        x0;
        stoping_criterion=stoping_criterion,
        stepsize=stepsize,
        order_type=order_type,
        order=order,
        direction=direction,
        retraction_method=retraction_method,
    )
    # Remaining keyword arguments are handed to the decorator.
    o = decorate_options(o; kwargs...)
    resultO = solve(p, o)
    return return_options ? resultO : get_solver_result(resultO)
end
"""
    initialize_solver!(p::StochasticGradientProblem, o::StochasticGradientDescentOptions)

Prepare the options `o` for a run of the solver: reset the position `o.k` within the
evaluation order to `1` and, if `o.order_type` is `:FixedRandom`, shuffle `o.order`
in place. Returns `o`.
"""
function initialize_solver!(
    ::StochasticGradientProblem, o::StochasticGradientDescentOptions
)
    # start at the beginning of the order
    o.k = 1
    # a fixed random order is permuted exactly once, here
    if o.order_type == :FixedRandom
        shuffle!(o.order)
    end
    return o
end
"""
    step_solver!(p::StochasticGradientProblem, o::StochasticGradientDescentOptions, iter)

Perform one iteration: ask the update rule `o.direction` for the current update,
retract from `o.x` along its negative (in place, overwriting `o.x`) and advance the
position `o.k` within `o.order` cyclically. Returns `o`.
"""
function step_solver!(
    p::StochasticGradientProblem, o::StochasticGradientDescentOptions, iter
)
    # The rule returns a pair; presumably (step size, tangent direction) – confirm
    # against the direction update rule in use.
    factor, direction = o.direction(p, o, iter)
    retract!(p.M, o.x, o.x, -factor * direction)
    # cyclic increment: after the last entry of the order continue with the first
    n = length(o.order)
    o.k = rem(o.k, n) + 1
    return o
end
"""
    get_solver_result(o::StochasticGradientDescentOptions)

Return the current iterate `o.x` stored in the options `o`.
"""
function get_solver_result(o::StochasticGradientDescentOptions)
    return o.x
end