-
-
Notifications
You must be signed in to change notification settings - Fork 29
/
lowerlevel_solve.jl
84 lines (67 loc) · 2.69 KB
/
lowerlevel_solve.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
"""
    vectorized_map_solve(probs, alg, ensemblealg::EnsembleArrayAlgorithm, I, adaptive;
                         kwargs...)

Lower-level API for `EnsembleArrayAlgorithm`s that avoids converting the solution to
CPU arrays.

The initial conditions and parameters of the first `length(I)` problems in `probs` are
stacked column-wise (one column per trajectory) and solved together as a single batched
problem.

## Arguments

  - `probs`: the problems generated by the ensemble, one per trajectory.
  - `alg`: the differential equation solver. Most of the solvers from OrdinaryDiffEq.jl
    are supported.
  - `ensemblealg`: the ensemble array algorithm, e.g. `EnsembleGPUArray()`.
  - `I`: the iterator of trajectory indices, e.g. `1:10_000` to simulate 10,000
    trajectories.
  - `adaptive`: a `Bool`; use `true` to enable adaptive time stepping.

## Keyword Arguments

Only a subset of the common solver arguments are supported.
"""
function vectorized_map_solve end
function vectorized_map_solve(probs, alg,
        ensemblealg::EnsembleArrayAlgorithm, I,
        adaptive;
        kwargs...)
    ntraj = length(I)
    # Stack one column per trajectory. `reduce(hcat, v)` dispatches to Base's
    # specialized method for vectors of arrays; splatting (`hcat(v...)`) thousands
    # of arrays into a varargs call is slow to compile and to run.
    u0 = reduce(hcat, [Array(probs[i].u0) for i in 1:ntraj])
    # NOTE(review): assumes every `probs[i].p` can be converted with `Array`; it is
    # unclear whether `SciMLBase.NullParameters` is supported on this path — confirm.
    p = reduce(hcat, [Array(probs[i].p) for i in 1:ntraj])
    # The first problem serves as the template (tspan, f, ...) for the batched solve.
    return vectorized_map_solve_up(probs[1], alg, ensemblealg, I, u0, p;
        adaptive = adaptive, kwargs...)
end
"""
    vectorized_map_solve_up(prob, alg, ensemblealg, I, u0, p; kwargs...)

Solve a batched ensemble in a single `solve` call.

`u0` and `p` hold the stacked per-trajectory initial conditions and parameters
(as built by `vectorized_map_solve`), and `prob` is the template problem. For
`EnsembleGPUArray`, `u0` and `p` are first moved to `ensemblealg.backend` with
`adapt`.

If `prob.f` provides a Jacobian, a zero-initialized `(len, len, length(I))`
Jacobian prototype is built, with one `len × len` slice per trajectory and
`len = length(prob.u0)`. Its element type matches `eltype(u0)`. A matching
color vector is built as well.

Algorithms with a `linsolve` field are remade to use
`LinSolveGPUSplitFactorize(len, -1)`.

Returns the solution produced by `solve`.
"""
function vectorized_map_solve_up(prob, alg, ensemblealg, I, u0, p; kwargs...)
    ntraj = length(I)
    if ensemblealg isa EnsembleGPUArray
        u0 = adapt(ensemblealg.backend, u0)
        p = adapt(ensemblealg.backend, p)
    end
    len = length(prob.u0)
    if SciMLBase.has_jac(prob.f)
        # The Jacobian storage follows the precision of the batched state, so
        # Float64 ensembles are not silently truncated to Float32.
        T = eltype(u0)
        if ensemblealg isa EnsembleGPUArray
            jac_prototype = allocate(ensemblealg.backend, T, (len, len, ntraj))
            fill!(jac_prototype, zero(T))
        else
            jac_prototype = zeros(T, len, len, ntraj)
        end
        # Repeat the per-trajectory coloring across the batch; without a
        # user-supplied colorvec, every state variable gets its own color.
        if prob.f.colorvec !== nothing
            colorvec = repeat(prob.f.colorvec, ntraj)
        else
            colorvec = repeat(1:len, ntraj)
        end
    else
        jac_prototype = nothing
        colorvec = nothing
    end
    _callback = generate_callback(prob, ntraj, ensemblealg; kwargs...)
    batched_prob = generate_problem(prob, u0, p, jac_prototype, colorvec)
    _alg = if hasproperty(alg, :linsolve)
        remake(alg, linsolve = LinSolveGPUSplitFactorize(len, -1))
    else
        alg
    end
    # `callback` is passed after `kwargs...` so the generated callback takes
    # precedence over any user-supplied one.
    return solve(batched_prob, _alg; kwargs..., callback = _callback,
        merge_callbacks = false, internalnorm = diffeqgpunorm)
end