/
differentiation.jl
135 lines (96 loc) · 4.5 KB
/
differentiation.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
"""
    AbstractDiffBackend

An abstract type for diff backends. See [`FiniteDifferencesBackend`](@ref) for
an example.

Concrete backends subtype this and are passed as the optional `backend` argument of
`_derivative`, `_gradient`, `_hessian` and `_jacobian`.
"""
abstract type AbstractDiffBackend end

"""
    NoneDiffBackend <: AbstractDiffBackend

Field-less placeholder backend. It is the value the global default-backend holder
`_current_default_differential_backend` is constructed with; presumably it stands for
"no differentiation backend selected yet" (NOTE(review): confirm no methods of the
differentiation functions are defined for it).
"""
struct NoneDiffBackend <: AbstractDiffBackend end
"""
    _derivative(f, t[, backend::AbstractDiffBackend])

Compute the derivative of a callable `f` at time `t` computed using the given `backend`,
an object of type [`Manifolds.AbstractDiffBackend`](@ref). If the backend is not explicitly
specified, it is obtained using the function [`default_differential_backend`](@ref).

This function calculates plain Euclidean derivatives. For Riemannian differentiation see,
for example, [`differential`](@ref Manifolds.differential(::AbstractManifold, ::Any, ::Real, ::AbstractRiemannianDiffBackend)).

!!! note

    Not specifying the backend explicitly will usually result in a type instability
    and decreased performance.
"""
function _derivative end

# The 3-argument (backend-specific) methods are not defined here; presumably each
# backend adds its own. This 2-argument form only fills in the global default backend.
_derivative(f, t) = _derivative(f, t, default_differential_backend())

# In-place variant: computes `_derivative(f, t, backend)` and copies it into `X`.
# Returns `X` (the destination returned by `copyto!`).
function _derivative!(f, X, t, backend::AbstractDiffBackend=default_differential_backend())
    return copyto!(X, _derivative(f, t, backend))
end
"""
    _gradient(f, p[, backend::AbstractDiffBackend])

Compute the gradient of a callable `f` at point `p` computed using the given `backend`,
an object of type [`AbstractDiffBackend`](@ref). If the backend is not explicitly
specified, it is obtained using the function [`default_differential_backend`](@ref).

This function calculates plain Euclidean gradients. For Riemannian gradient calculation
see, for example, [`gradient`](@ref Manifolds.gradient(::AbstractManifold, ::Any, ::Any, ::AbstractRiemannianDiffBackend)).

!!! note

    Not specifying the backend explicitly will usually result in a type instability
    and decreased performance.
"""
function _gradient end

# The 3-argument (backend-specific) methods are not defined here; presumably each
# backend adds its own. This 2-argument form only fills in the global default backend.
_gradient(f, p) = _gradient(f, p, default_differential_backend())

# In-place variant: computes `_gradient(f, p, backend)` and copies it into `X`.
# Returns `X` (the destination returned by `copyto!`).
function _gradient!(f, X, p, backend::AbstractDiffBackend=default_differential_backend())
    return copyto!(X, _gradient(f, p, backend))
end
"""
    _hessian(f, p[, backend::AbstractDiffBackend])

Compute the Hessian of a callable `f` at point `p` computed using the given `backend`,
an object of type [`AbstractDiffBackend`](@ref). If the backend is not explicitly
specified, it is obtained using the function [`default_differential_backend`](@ref).

This function calculates plain Euclidean Hessians.

!!! note

    Not specifying the backend explicitly will usually result in a type instability
    and decreased performance.
"""
function _hessian end

# The 3-argument (backend-specific) methods are not defined here; presumably each
# backend adds its own. This 2-argument form only fills in the global default backend.
_hessian(f, p) = _hessian(f, p, default_differential_backend())
"""
    _jacobian(f, p[, backend::AbstractDiffBackend])

Compute the Jacobian of a callable `f` at point `p` computed using the given `backend`,
an object of type [`AbstractDiffBackend`](@ref). If the backend is not explicitly
specified, it is obtained using the function [`default_differential_backend`](@ref).

This function calculates plain Euclidean Jacobians. For Riemannian Jacobian calculation
see, for example, [`gradient`](@ref Manifolds.gradient(::AbstractManifold, ::Any, ::Any, ::AbstractRiemannianDiffBackend)).

!!! note

    Not specifying the backend explicitly will usually result in a type instability
    and decreased performance.
"""
function _jacobian end

# The 3-argument (backend-specific) methods are not defined here; presumably each
# backend adds its own. This 2-argument form only fills in the global default backend.
_jacobian(f, p) = _jacobian(f, p, default_differential_backend())

# In-place variant: computes `_jacobian(f, p, backend)` and copies it into `X`.
# Returns `X` (the destination returned by `copyto!`).
function _jacobian!(f, X, p, backend::AbstractDiffBackend=default_differential_backend())
    return copyto!(X, _jacobian(f, p, backend))
end
"""
    CurrentDiffBackend(backend::AbstractDiffBackend)

A mutable struct for storing the current differentiation backend in a global
constant [`_current_default_differential_backend`](@ref).

The `backend` field is typed with the abstract [`AbstractDiffBackend`](@ref), so the
concrete backend type of a value read from it is not known to the compiler. This is why
calls that fall back to the default backend are usually type-unstable.

# See also

[`AbstractDiffBackend`](@ref), [`default_differential_backend`](@ref), [`set_default_differential_backend!`](@ref)
"""
mutable struct CurrentDiffBackend
    # Mutable so the default can be swapped while the holder itself stays a `const`.
    backend::AbstractDiffBackend
end
"""
    _current_default_differential_backend

The instance of [`Manifolds.CurrentDiffBackend`](@ref) that stores the globally default
differentiation backend.

It is constructed holding `NoneDiffBackend()`. Its `backend` field is read by
`default_differential_backend` and overwritten by `set_default_differential_backend!`.
"""
const _current_default_differential_backend = CurrentDiffBackend(NoneDiffBackend())
"""
    default_differential_backend() -> AbstractDiffBackend

Return the differentiation backend currently stored as the global default, i.e. the
`backend` field of the holder `_current_default_differential_backend`.
"""
function default_differential_backend()
    holder = _current_default_differential_backend
    return holder.backend
end
"""
    set_default_differential_backend!(backend::AbstractDiffBackend)

Make `backend` the global default differentiation backend and return it.

The new value is written into the `backend` field of the holder
`_current_default_differential_backend`, so later argument-less lookups of the default
backend see it.
"""
function set_default_differential_backend!(backend::AbstractDiffBackend)
    holder = _current_default_differential_backend
    holder.backend = backend
    return backend
end