/
common.jl
62 lines (51 loc) · 1.67 KB
/
common.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
"""
    solve(prob::AbstractODEProblem, alg::ODEJLAlgorithm, timeseries=[], ts=[], ks=[];
          dense=true, save_timeseries=true, saveat=tType[], timeseries_errors=true,
          reltol=1e-5, abstol=1e-8, dtmin=span/1e9, dtmax=span/2.5, dt=0.,
          norm=Base.vecnorm, kwargs...)

Solve `prob` by delegating to the ODE.jl-style integrator `AlgType`, called as
`AlgType(f, u0, Ts; norm, abstol, reltol, maxstep, minstep, initstep, points)`.
The call is expected to return a `(ts, timeseries)` pair. The result is wrapped
with `build_solution`.

Array-valued states are flattened with `vec` before integration. Each saved state
is reshaped back to `size(prob.u0)` afterwards. The right-hand side always sees
states with the caller's original shape, for both in-place and out-of-place
problems.

# Arguments
- `timeseries`, `ts`, `ks`: positional slots that this method overwrites or ignores.
  They are kept only for signature compatibility.

# Keywords
- `save_timeseries`: when `true`, passes `points=:all` to the integrator; otherwise
  passes `points=:specified`.
- `saveat`: extra save times, merged with both endpoints of `prob.tspan`.
- `dtmin`, `dtmax`, `dt`: forwarded as `minstep`, `maxstep`, `initstep`. Here
  `span = abs(prob.tspan[2] - prob.tspan[1])`.
- `reltol`, `abstol`, `norm`: forwarded unchanged.
- `timeseries_errors`: forwarded to `build_solution`.
- `dense` and any other `kwargs` are accepted but not used by this method.

# Throws
- `ErrorException` if `prob.tspan[end] < prob.tspan[1]`. Equal endpoints are not
  rejected here.
"""
function solve{uType,tType,isinplace,AlgType<:ODEJLAlgorithm,F}(
        prob::AbstractODEProblem{uType,tType,isinplace,F},
        alg::AlgType, timeseries=[], ts=[], ks=[];
        dense=true, save_timeseries=true,
        saveat=tType[], timeseries_errors=true, reltol=1e-5, abstol=1e-8,
        # The minimum step must be a tiny fraction of the span. Dividing by 1e-9
        # (the previous default) made it 1e9 times the whole span.
        dtmin=abs(prob.tspan[2] - prob.tspan[1]) / 1e9,
        dtmax=abs(prob.tspan[2] - prob.tspan[1]) / 2.5,
        dt=0., norm=Base.vecnorm,
        kwargs...)
    tspan = prob.tspan
    if tspan[end] - tspan[1] < tType(0)
        error("final time must be greater than starting time. Aborting.")
    end

    # Requested output times always include both endpoints of tspan.
    Ts = sort(unique([tspan[1]; saveat; tspan[2]]))
    # NOTE(review): assumed ODE.jl semantics: :all = every step, :specified = only Ts.
    points = save_timeseries ? :all : :specified

    sizeu = size(prob.u0)
    isarray = uType <: AbstractArray
    # The integrator works on flat state vectors, so array states are vectorized.
    u0 = isarray ? vec(prob.u0) : prob.u0

    if isinplace
        # Give the user's in-place RHS arrays in the shape of u0, then flatten du.
        # `reshape`/`vec` share memory with their argument, so no data is copied.
        f = (t, u) -> begin
            ushaped = reshape(u, sizeu)
            du = zero(ushaped)
            prob.f(t, ushaped, du)
            vec(du)
        end
    elseif isarray
        f = (t, u) -> vec(prob.f(t, reshape(u, sizeu)))
    else
        f = prob.f
    end

    ts, timeseries_tmp = AlgType(f, u0, Ts;
                                 norm=norm,
                                 abstol=abstol,
                                 reltol=reltol,
                                 maxstep=dtmax,
                                 minstep=dtmin,
                                 initstep=dt,
                                 points=points)

    # Undo the vectorization so every saved state has the caller's original shape.
    if isarray
        timeseries = uType[reshape(x, sizeu) for x in timeseries_tmp]
    else
        timeseries = timeseries_tmp
    end

    return build_solution(prob, alg, ts, timeseries,
                          timeseries_errors=timeseries_errors)
end