This repository has been archived by the owner on Mar 1, 2023. It is now read-only.
/
MultirateRungeKuttaMethod.jl
166 lines (138 loc) · 4.54 KB
/
MultirateRungeKuttaMethod.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
export MultirateRungeKutta
# Short alias for the low-storage 2N Runge-Kutta solver type. It is used in the
# method signatures below to restrict which solvers may act as the slow
# component of a multirate scheme.
# NOTE(review): this is a non-`const` global binding; `const` would be more
# idiomatic, but first confirm no other file in the module rebinds `LSRK2N`.
LSRK2N = LowStorageRungeKutta2N
"""
    MultirateRungeKutta(slow_solver, fast_solver, Q = nothing; dt, t0)

This is a time stepping object for explicitly time stepping the differential
equation given by the sum of a fast and a slow right-hand-side function for the
state `Q`, i.e.,

```math
\\dot{Q} = f_{fast}(Q, t) + f_{slow}(Q, t)
```

with the required time step size `dt` and optional initial time `t0`. This
time stepping object is intended to be passed to the `solve!` command.

The constructor builds a multirate Runge-Kutta scheme from two different RK
solvers: `slow_solver` advances ``f_{slow}`` with the step `dt`, while within
each slow stage `fast_solver` advances ``f_{fast}`` using substeps no larger
than its own time step. This is based on the multirate approach of Schlegel et
al. (2012), listed in the references below. A tuple of more than two solvers
(slowest first) can be passed to the `MultirateRungeKutta(solvers::Tuple, ...)`
method, which nests them recursively.

Currently only the low storage RK methods (`LowStorageRungeKutta2N`) can be
used as slow solvers.

# Arguments
- `slow_solver`: low-storage RK solver for the slow tendency.
- `fast_solver`: solver for the fast tendency; it may itself be a
  `MultirateRungeKutta` to build schemes with more than two rates.
- `Q`: accepted for interface compatibility; not used by this constructor.

# Keywords
- `dt`: slow time step; defaults to `getdt(slow_solver)`.
- `t0`: initial time; defaults to `slow_solver.t`.

Both `dt` and `t0` are converted to the real element type of `slow_solver.dQ`.

### References

    @article{SchlegelKnothArnoldWolke2012,
      title={Implementation of multirate time integration methods for air
        pollution modelling},
      author={Schlegel, M and Knoth, O and Arnold, M and Wolke, R},
      journal={Geoscientific Model Development},
      volume={5},
      number={6},
      pages={1395--1405},
      year={2012},
      publisher={Copernicus GmbH}
    }
"""
mutable struct MultirateRungeKutta{SS, FS, RT} <: AbstractODESolver
    "slow solver"
    slow_solver::SS
    "fast solver"
    fast_solver::FS
    "time step"
    dt::RT
    "time"
    t::RT
    "elapsed time steps"
    steps::Int

    # No `where` clause: the previous `where {AT <: AbstractArray}` declared a
    # static parameter that never appeared in the signature (dead code that
    # newer Julia versions warn about).
    function MultirateRungeKutta(
        slow_solver::LSRK2N,
        fast_solver,
        Q = nothing;
        dt = getdt(slow_solver),
        t0 = slow_solver.t,
    )
        SS = typeof(slow_solver)
        FS = typeof(fast_solver)
        # The real element type of the slow solver's tendency storage fixes
        # the floating-point type used for `dt` and `t`.
        RT = real(eltype(slow_solver.dQ))
        return new{SS, FS, RT}(slow_solver, fast_solver, RT(dt), RT(t0), 0)
    end
end
# Build a multirate scheme from a tuple of solvers ordered slowest first.
#
# With two solvers this is `MultirateRungeKutta(solvers[1], solvers[2])`. With
# more, the trailing solvers are recursively nested into their own
# `MultirateRungeKutta`, which then serves as the fast solver of this level.
#
# Keywords:
# - `dt`: step of the slowest (outermost) level; defaults to `getdt(solvers[1])`
# - `t0`: initial time; defaults to `solvers[1].t` and is shared by all levels
#
# Throws an error when fewer than two solvers are given. NOTE: an empty tuple
# already fails while evaluating the keyword defaults (`solvers[1]`).
function MultirateRungeKutta(
    solvers::Tuple,
    Q = nothing;
    dt = getdt(solvers[1]),
    t0 = solvers[1].t,
)
    if length(solvers) < 2
        error("Must specify atleast two solvers")
    elseif length(solvers) == 2
        fast_solver = solvers[2]
    else
        # Do not forward the slowest `dt`: the nested scheme must default to
        # `getdt(solvers[2])`. Otherwise `dostep!` reads the slow step back via
        # `getdt(fast)`, always computes a single substep for this level, and
        # silently ignores the intermediate solver's own time step.
        fast_solver = MultirateRungeKutta(Base.tail(solvers), Q; t0 = t0)
    end
    slow_solver = solvers[1]
    return MultirateRungeKutta(slow_solver, fast_solver, Q; dt = dt, t0 = t0)
end
# Advance `Q` by one multirate step of size `mrrk.dt` starting at `time`.
#
# For each stage of the low-storage slow solver this:
#   1. adds the slow right-hand side into `slow.dQ` (`increment = true`);
#   2. if an outer slow tendency was passed in (`in_slow_δ !== nothing`, i.e.
#      this scheme is the fast part of an outer multirate scheme), adds
#      `in_slow_δ * in_slow_rv_dQ` into `slow.dQ` and, on the last stage only,
#      rescales `in_slow_rv_dQ` in place by `in_slow_scaling`;
#   3. runs the fast solver over the fraction `γ * dt` between this stage time
#      and the next one, in `nsubsteps` equal substeps, passing down the slow
#      tendency, its weight `slow_δ`, and (on the last substep) a rescaling
#      factor taken from `slow.RKA`.
#
# Arguments:
# - `Q`: state; this method modifies it only through the fast solver's `dostep!`
# - `mrrk`: the multirate scheme (slow solver restricted to `LSRK2N`)
# - `param`: forwarded unchanged to the right-hand-side functions
# - `time`: time at the start of the step
# - `in_slow_δ`, `in_slow_rv_dQ`, `in_slow_scaling`: outer slow-tendency
#   coupling; default to `nothing` (presumably the outermost-scheme case)
#
# The fast solver's time step is overwritten with the substep size during the
# step and restored to its original value (`fast_dt_in`) at the end.
function dostep!(
Q,
mrrk::MultirateRungeKutta{SS},
param,
time,
in_slow_δ = nothing,
in_slow_rv_dQ = nothing,
in_slow_scaling = nothing,
) where {SS <: LSRK2N}
dt = mrrk.dt
slow = mrrk.slow_solver
fast = mrrk.fast_solver
slow_rv_dQ = realview(slow.dQ)
# work-group size for the `update!` kernel launch below
groupsize = 256
# fast solver's own step; bounds the substep size and is restored at the end
fast_dt_in = getdt(fast)
for slow_s in 1:length(slow.RKA)
# Current slow stage time
slow_stage_time = time + slow.RKC[slow_s] * dt
# Evaluate the slow mode
slow.rhs!(slow.dQ, Q, param, slow_stage_time, increment = true)
if in_slow_δ !== nothing
# The outer tendency is rescaled only in this scheme's last slow stage.
slow_scaling = nothing
if slow_s == length(slow.RKA)
slow_scaling = in_slow_scaling
end
# Add the outer scheme's weighted slow tendency into this level's tendency
# (this updates tendencies, not the solution) and optionally rescale the
# outer tendency. NOTE(review): `Event(array_device(Q))` appears to serve
# only as the launch dependency for the kernel (older KernelAbstractions
# API) — confirm.
event = Event(array_device(Q))
event = update!(array_device(Q), groupsize)(
slow_rv_dQ,
in_slow_rv_dQ,
in_slow_δ,
slow_scaling;
ndrange = length(realview(Q)),
dependencies = (event,),
)
wait(array_device(Q), event)
end
# Fractional time for slow stage: distance from this stage's RKC to the
# next stage's RKC, or to the end of the step (1) for the last stage
if slow_s == length(slow.RKA)
γ = 1 - slow.RKC[slow_s]
else
γ = slow.RKC[slow_s + 1] - slow.RKC[slow_s]
end
# RKB for the slow with fractional time factor removed (since full
# integration of fast will result in scaling by γ)
# NOTE(review): if two consecutive RKC values coincide, γ == 0, making
# slow_δ infinite and nsubsteps 0 — confirm the slow tableaux never do this.
slow_δ = slow.RKB[slow_s] / (γ)
# Number of equal fast substeps covering γ * dt, each no larger than the fast
# solver's own step; a single substep if that step is not positive
nsubsteps = fast_dt_in > 0 ? ceil(Int, γ * dt / fast_dt_in) : 1
fast_dt = γ * dt / nsubsteps
updatedt!(fast, fast_dt)
for substep in 1:nsubsteps
# Only the final substep passes a rescaling factor for the slow tendency:
# the RKA coefficient of the next slow stage (wrapping to RKA[1] after the
# last stage), presumably preparing slow.dQ for the next low-storage stage.
slow_rka = nothing
if substep == nsubsteps
slow_rka = slow.RKA[slow_s % length(slow.RKA) + 1]
end
fast_time = slow_stage_time + (substep - 1) * fast_dt
dostep!(Q, fast, param, fast_time, slow_δ, slow_rv_dQ, slow_rka)
end
end
updatedt!(fast, fast_dt_in)
end
# Element-wise kernel: accumulate the weighted slow tendency `δ * slow_dQ` into
# `fast_dQ`; when a `slow_rka` factor is supplied, afterwards rescale `slow_dQ`
# in place by that factor. One work item handles one linear index.
@kernel function update!(fast_dQ, slow_dQ, δ, slow_rka = nothing)
    idx = @index(Global, Linear)
    @inbounds begin
        fast_dQ[idx] += δ * slow_dQ[idx]
        # A `nothing` factor means the slow tendency is left untouched.
        isnothing(slow_rka) || (slow_dQ[idx] *= slow_rka)
    end
end