-
Notifications
You must be signed in to change notification settings - Fork 38
/
objective.jl
164 lines (133 loc) · 6.18 KB
/
objective.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
@doc raw"""
    AbstractEvaluationType

An abstract type to specify the kind of evaluation an [`AbstractManifoldObjective`](@ref) supports.

Concrete subtypes include [`AllocatingEvaluation`](@ref) (functions return newly allocated
results) and [`InplaceEvaluation`](@ref) (functions write their result into an argument).
"""
abstract type AbstractEvaluationType end
@doc raw"""
    AbstractManifoldObjective{E<:AbstractEvaluationType}

Describe the collection of the optimization function ``f: \mathcal M → \mathbb R``
(or even a vectorial range) and its corresponding elements, which might for example be
a gradient or (one or more) proximal maps.

All these elements should usually be implemented as functions
`(M, p) -> ...` or `(M, X, p) -> ...`, that is

* the first argument of these functions should be the manifold `M` they are defined on
* the argument `X` is present if the computation is performed in place of `X`
  (see [`InplaceEvaluation`](@ref))
* the argument `p` is the point the function (``f`` or one of its elements) is evaluated __at__.

The type parameter `E` indicates the global [`AbstractEvaluationType`](@ref).
"""
abstract type AbstractManifoldObjective{E<:AbstractEvaluationType} end
@doc raw"""
    AbstractDecoratedManifoldObjective{E<:AbstractEvaluationType,O<:AbstractManifoldObjective}

A common supertype for all decorators of [`AbstractManifoldObjective`](@ref)s to simplify
dispatch.

The first parameter `E` is passed on to the supertype `AbstractManifoldObjective{E}`.
The second parameter `O` should refer to the undecorated objective (the innermost one),
so that one can still dispatch on it through any number of decorators.
"""
abstract type AbstractDecoratedManifoldObjective{E,O<:AbstractManifoldObjective} <:
AbstractManifoldObjective{E} end
@doc raw"""
    AllocatingEvaluation <: AbstractEvaluationType

A parameter for an [`AbstractManoptProblem`](@ref) indicating that the problem uses
functions that allocate memory for their result, i.e. they work out of place
(functions of the form `(M, p) -> ...`).
"""
struct AllocatingEvaluation <: AbstractEvaluationType end
@doc raw"""
    InplaceEvaluation <: AbstractEvaluationType

A parameter for an [`AbstractManoptProblem`](@ref) indicating that the problem uses
functions that do not allocate memory for their result but work on their input in place
(functions of the form `(M, X, p) -> ...` that write their result into `X`).
"""
struct InplaceEvaluation <: AbstractEvaluationType end
@doc raw"""
    ReturnManifoldObjective{E,O2,O1<:AbstractManifoldObjective{E}} <:
        AbstractDecoratedManifoldObjective{E,O2}

A wrapper to indicate that `get_solver_result` should return the inner objective.

The types are such that one can still dispatch on the undecorated type `O2` of the
original objective as well, while `O1` is the type of the directly wrapped objective.

# Fields

* `objective::O1` – the wrapped (possibly itself decorated) objective.
"""
struct ReturnManifoldObjective{E,O2,O1<:AbstractManifoldObjective{E}} <:
AbstractDecoratedManifoldObjective{E,O2}
objective::O1
end
# Wrap an undecorated objective: the wrapped type and the undecorated type coincide,
# so the same type `Inner` is used for both the `O2` and `O1` parameters.
ReturnManifoldObjective(
    objective::Inner,
) where {E<:AbstractEvaluationType,Inner<:AbstractManifoldObjective{E}} =
    ReturnManifoldObjective{E,Inner,Inner}(objective)
# Wrap an already decorated objective: record its innermost (undecorated) type `U` as the
# second parameter, so dispatch on the original objective type keeps working, and its own
# (decorated) type `D` as the third parameter.
function ReturnManifoldObjective(decorated::D) where {
    E<:AbstractEvaluationType,
    U<:AbstractManifoldObjective,
    D<:AbstractDecoratedManifoldObjective{E,U},
}
    wrapper_type = ReturnManifoldObjective{E,U,D}
    return wrapper_type(decorated)
end
"""
    dispatch_objective_decorator(o::AbstractManifoldObjective)

Indicate internally whether an [`AbstractManifoldObjective`](@ref) `o` is of decorating
type, i.e. whether it stores (encapsulates) another objective in itself, by default in the
field `o.objective`.

Decorators indicate this by returning `Val(true)` for further dispatch; every
[`AbstractDecoratedManifoldObjective`](@ref) does so by default.
The default for all other objectives is `Val(false)`, so by default an objective is not
decorated.
"""
dispatch_objective_decorator(::AbstractManifoldObjective) = Val(false)
dispatch_objective_decorator(::AbstractDecoratedManifoldObjective) = Val(true)
"""
    is_objective_decorator(s::AbstractManifoldObjective)

Indicate whether the [`AbstractManifoldObjective`](@ref) `s` is of decorator type.

This unwraps the `Val` returned by [`dispatch_objective_decorator`](@ref) for `s`.
"""
function is_objective_decorator(s::AbstractManifoldObjective)
    # `_extract_val` presumably turns `Val(b)` into `b` — defined elsewhere, TODO confirm.
    return _extract_val(dispatch_objective_decorator(s))
end
@doc raw"""
    get_objective(o::AbstractManifoldObjective, recursive=true)

Return the undecorated [`AbstractManifoldObjective`](@ref) of the (possibly) decorated `o`.

If `recursive` is `true` (the default), all decorators are removed and the innermost
objective is returned. If `recursive` is set to `false`, only the outermost decorator is
taken away (one step). An objective that is not a decorator is returned unchanged.

As long as your decorated objective stores the objective within `o.objective` and
[`dispatch_objective_decorator`](@ref) returns `Val(true)` for it, the inner objective is
extracted automatically.
To store the inner objective elsewhere, overwrite `_get_objective(o, ::Val{true}, recursive)`
for your objective `o`, covering both the recursive and the direct case.
"""
function get_objective(o::AbstractManifoldObjective, recursive=true)
    return _get_objective(o, dispatch_objective_decorator(o), recursive)
end
# Not a decorator: there is nothing to unwrap.
_get_objective(o::AbstractManifoldObjective, ::Val{false}, rec=true) = o
# Decorator: take the stored objective and, if `rec` is set, keep unwrapping it.
# The recursion goes through `get_objective`, so custom `_get_objective` overloads of
# nested decorators are respected.
function _get_objective(o::AbstractManifoldObjective, ::Val{true}, rec=true)
    return rec ? get_objective(o.objective, true) : o.objective
end
# NOTE: the line right after this docstring is a bare method signature. Julia's doc system
# attaches the docstring to that signature; the line itself neither defines nor calls a
# method. Keep it directly after the closing quotes.
"""
    set_manopt_parameter!(amo::AbstractManifoldObjective, element::Symbol, args...)

Set a parameter of the [`AbstractManifoldObjective`](@ref) `amo`, selected by `element`,
using the values `args...`.
This function should dispatch on `Val(element)`.

Currently supported

* `:Cost` passes `args...` on to the [`get_cost_function`](@ref) of `amo`
* `:Gradient` passes `args...` on to the [`get_gradient_function`](@ref) of `amo`
"""
set_manopt_parameter!(amo::AbstractManifoldObjective, e::Symbol, args...)
# `:Cost` — forward the remaining arguments to the cost function of `amo`,
# then hand back the (mutated) objective itself.
function set_manopt_parameter!(amo::AbstractManifoldObjective, ::Val{:Cost}, args...)
    cost = get_cost_function(amo)
    set_manopt_parameter!(cost, args...)
    return amo
end
# `:Gradient` — forward the remaining arguments to the gradient function of `amo`,
# then hand back the (mutated) objective itself.
function set_manopt_parameter!(amo::AbstractManifoldObjective, ::Val{:Gradient}, args...)
    gradient = get_gradient_function(amo)
    set_manopt_parameter!(gradient, args...)
    return amo
end
# Print the objective's type name followed by its evaluation type in braces, e.g. `Name{E}`.
function show(io::IO, o::AbstractManifoldObjective{E}) where {E}
    type_name = nameof(typeof(o))
    return print(io, type_name, "{", E, "}")
end
# By default a decorator does not change how an objective is shown:
# strip one decorator layer and show what is underneath.
show(io::IO, co::AbstractDecoratedManifoldObjective) = show(io, get_objective(co, false))
# Show a tuple whose first entry is an objective: the objective's status summary
# (if non-empty, followed by a blank line), then a hint on how to get the solver result.
function show(io::IO, t::Tuple{<:AbstractManifoldObjective,P}) where {P}
    status_text = string(status_summary(first(t)))
    separator = isempty(status_text) ? "" : "\n\n"
    return print(
        io,
        status_text,
        separator,
        "To access the solver result, call `get_solver_result` on this variable.",
    )
end
# Default: a plain objective contributes no status summary (the empty string).
status_summary(::AbstractManifoldObjective{E}) where {E} = ""
# By default a decorator adds nothing to the status summary:
# summarize the objective one decorator layer further in.
function status_summary(co::AbstractDecoratedManifoldObjective)
    undecorated_once = get_objective(co, false)
    return status_summary(undecorated_once)
end