-
-
Notifications
You must be signed in to change notification settings - Fork 26
/
utils.jl
84 lines (72 loc) · 4 KB
/
utils.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
"""
    fsal_typeof(integrator::ODEIntegrator)

Return the FSAL type of `integrator`, i.e. the type parameter that `ODEIntegrator`
carries for its FSAL ("first same as last") values.

Only the type of `integrator` is inspected; no field of it is read, so the result is
fixed at dispatch time.
"""
function fsal_typeof(integrator::ODEIntegrator{<:OrdinaryDiffEq.OrdinaryDiffEqAlgorithm,
                                               U,T,TU,TD,K,S,F,P,C,O,FSAL}) where {
                     U,T,TU,TD,K,S,F,P,C,O,FSAL}
    return FSAL
end
"""
    build_linked_cache(cache, alg, u, uprev, uprev2, f, t, dt)

Create a cache for algorithm `alg` from the existing `cache`, linked to the new state
variables `u`, `uprev`, `uprev2`, the function `f`, the time `t`, and the step `dt`.

Fields named `u`, `uprev`, `uprev2`, `t`, and `dt` are taken directly from the
arguments. Every other field is produced by the statement returned from `assign_expr`.
That statement either reads the field from `cache` or rebuilds it from the arguments.
The new cache is then constructed positionally from all fields, in declaration order.
"""
@generated function build_linked_cache(cache, alg, u, uprev, uprev2, f, t, dt)
    # these fields come straight from the arguments instead of from `cache`
    linked = (:u, :uprev, :uprev2, :t, :dt)
    fields = fieldnames(cache)
    # statements are emitted in field declaration order, so the expression for a later
    # field may use local variables bound for earlier fields
    body = Expr(:block)
    for name in fields
        name in linked && continue
        push!(body.args, assign_expr(Val{name}(), fieldtype(cache, name), cache))
    end
    constructor = DiffEqBase.parameterless_type(cache)
    push!(body.args, :($constructor($(fields...))))
    return body
end
"""
    assign_expr(::Val{name}, ::Type{T}, ::Type{C})

Return an expression that binds the local variable `name` to the value of field `name`
(of type `T`) for a cache of type `C`.

This fallback method reads the field unchanged, via `getfield`, from whatever is bound
to `cache` where the expression is evaluated. More specific methods match particular
field names or field types. Instead of copying, they rebuild the value from the new
arguments, for example `f`, `t`, `u`, or `uprev`.
"""
function assign_expr(::Val{name}, ::Type, ::Type) where {name}
    fieldsym = Meta.quot(name)
    return :($name = getfield(cache, $fieldsym))
end
# update matrix exponential
function assign_expr(::Val{:expA}, ::Type, ::Type)
    # bind `A` to `f.f1`, then recompute `expA` for the new step size `dt`
    # NOTE(review): `expm` is the pre-1.0 name of the matrix exponential (`exp` in
    # Julia >= 0.7/1.0); kept because this file targets that older API
    bind_operator = :(A = f.f1)
    exponentiate = :(expA = expm(A * dt))
    return Expr(:block, bind_operator, exponentiate)
end
# NorsettEuler: rebuild `phi1` as (expA - I) / A from the freshly bound `A` and `expA`
function assign_expr(::Val{:phi1}, ::Type, ::Type{<:OrdinaryDiffEq.NorsettEulerCache})
    # NOTE(review): assumes `expA` precedes `phi1` in the cache's field order, so that
    # the locals `A` and `expA` are already bound when this statement runs
    return Expr(:(=), :phi1, :((expA - I) / A))
end
# update derivative wrappers
function assign_expr(::Val{name}, ::Type{<:OrdinaryDiffEq.TimeDerivativeWrapper},
                     ::Type) where {name}
    # rewrap the new function `f` together with the current `u`
    wrapper = :(OrdinaryDiffEq.TimeDerivativeWrapper(f, u))
    return Expr(:(=), name, wrapper)
end
function assign_expr(::Val{name}, ::Type{<:OrdinaryDiffEq.UDerivativeWrapper},
                     ::Type) where {name}
    # rewrap the new function `f` together with the current time `t`
    wrapper = :(OrdinaryDiffEq.UDerivativeWrapper(f, t))
    return Expr(:(=), name, wrapper)
end
function assign_expr(::Val{name}, ::Type{<:OrdinaryDiffEq.TimeGradientWrapper},
                     ::Type) where {name}
    # `f` and `uprev` are the new values; `fx1` is reused from the wrapper in `cache`
    previous = :(getfield(cache, $(Meta.quot(name))))
    return :($name = OrdinaryDiffEq.TimeGradientWrapper(f, uprev, $(previous).fx1))
end
function assign_expr(::Val{name}, ::Type{<:OrdinaryDiffEq.UJacobianWrapper},
                     ::Type) where {name}
    # `f`, `t`, and `uprev` are the new values; `fx1` is reused from the wrapper in `cache`
    previous = :(getfield(cache, $(Meta.quot(name))))
    return :($name = OrdinaryDiffEq.UJacobianWrapper(f, t, uprev, $(previous).fx1))
end
# create new config of Jacobian
function assign_expr(::Val{name}, ::Type{ForwardDiff.JacobianConfig{T,V,N,D}},
                     ::Type) where {name,T,V,N,D}
    # keep the chunk size `N` of the existing config
    # NOTE(review): assumes locals `uf` and `du1` were bound by earlier cache fields
    chunk = :(ForwardDiff.Chunk{$N}())
    return :($name = ForwardDiff.JacobianConfig(uf, du1, uprev, $chunk))
end
# update implicit RHS
function assign_expr(::Val{name}, ::Type{<:OrdinaryDiffEq.ImplicitRHS},
                     ::Type) where {name}
    # `tmp` and `dual_cache` are read straight from the old `cache`; the three `t`
    # slots are all filled with the new `t`
    call = :(OrdinaryDiffEq.ImplicitRHS(f, cache.tmp, t, t, t, cache.dual_cache))
    return Expr(:(=), name, call)
end
function assign_expr(::Val{name}, ::Type{<:OrdinaryDiffEq.ImplicitRHS_Scalar},
                     ::Type) where {name}
    # scalar variant: a fresh `zero(u)` is passed where the in-place form takes `cache.tmp`
    call = :(OrdinaryDiffEq.ImplicitRHS_Scalar(f, zero(u), t, t, t))
    return Expr(:(=), name, call)
end
function assign_expr(::Val{name}, ::Type{<:OrdinaryDiffEq.RHS_IIF},
                     ::Type) where {name}
    # `cache.tmp` is passed as both the second and the fifth argument
    call = :(OrdinaryDiffEq.RHS_IIF(f, cache.tmp, t, t, cache.tmp, cache.dual_cache))
    return Expr(:(=), name, call)
end
function assign_expr(::Val{name}, ::Type{<:OrdinaryDiffEq.RHS_IIF_Scalar},
                     ::Type) where {name}
    # the `a` field is carried over from the wrapper stored in `cache`
    previous = :(getfield(cache, $(Meta.quot(name))))
    return :($name = OrdinaryDiffEq.RHS_IIF_Scalar(f, zero(u), t, t, $(previous).a))
end
# create new NLsolve differentiable function
function assign_expr(::Val{name}, ::Type{<:NLsolve.DifferentiableMultivariateFunction},
                     ::Type{<:OrdinaryDiffEq.OrdinaryDiffEqMutableCache}) where {name}
    # mutable caches: call `alg.nlsolve(Val{:init}, rhs, u)` with the new `u`
    # NOTE(review): assumes a local `rhs` was bound by an earlier cache field
    init = :(alg.nlsolve(Val{:init}, rhs, u))
    return Expr(:(=), name, init)
end
function assign_expr(::Val{name}, ::Type{<:NLsolve.DifferentiableMultivariateFunction},
                     ::Type{<:OrdinaryDiffEq.OrdinaryDiffEqConstantCache}) where {name}
    # constant caches: same call as the mutable case, but with `uhold` instead of `u`
    # NOTE(review): assumes locals `rhs` and `uhold` were bound by earlier cache fields
    init = :(alg.nlsolve(Val{:init}, rhs, uhold))
    return Expr(:(=), name, init)
end