Skip to content

Commit

Permalink
update common interface for v0.15.0
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Apr 12, 2017
1 parent ddc898e commit b598ad6
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 18 deletions.
53 changes: 41 additions & 12 deletions src/common.jl
Expand Up @@ -9,26 +9,50 @@ abstract DASKRDAEAlgorithm{LinearSolver} <: AbstractDAEAlgorithm

# DAE Algorithms
immutable daskr{LinearSolver} <: DASKRDAEAlgorithm{LinearSolver} end
daskr(;linear_solver=:Dense) = daskr{linear_solver}()
Base.@pure daskr(;linear_solver=:Dense) = daskr{linear_solver}()

export daskr

## Solve for DAEs uses raw_solver

function solve{uType,duType,tType,isinplace,F,LinearSolver}(
prob::AbstractDAEProblem{uType,duType,tType,isinplace,F},
function solve{uType,duType,tType,isinplace,LinearSolver}(
prob::AbstractDAEProblem{uType,duType,tType,isinplace},
alg::DASKRDAEAlgorithm{LinearSolver},
timeseries = [], ts = [], ks = [];
callback = () -> nothing, abstol = 1/10^6, reltol = 1/10^3,
saveat = Float64[], adaptive = true, maxiter = Int(1e5),
timeseries_errors = true, save_timeseries = true,
timeseries_errors = true, save_everystep = isempty(saveat),
save_start = true, save_timeseries = nothing,
userdata = nothing, isdiff = fill(true, length(prob.u0)), kwargs...)


if save_timeseries != nothing
warn("save_timeseries is deprecated. Use save_everystep instead")
save_everystep = save_timeseries
end

tspan = prob.tspan
t0 = tspan[1]
T = tspan[end]

save_ts = sort(unique([t0;saveat;T]))


if typeof(saveat) <: Number
saveat_vec = convert(Vector{tType},saveat:saveat:(tspan[end]-saveat))
# Exclude the endpoint because of floating point issues
else
saveat_vec = convert(Vector{tType},collect(saveat))
end

if !isempty(saveat_vec) && saveat_vec[end] == tspan[2]
pop!(saveat_vec)
end

if !isempty(saveat_vec) && saveat_vec[1] == tspan[1]
save_ts = sort(unique([saveat_vec[2:end];tspan[2]]))
else
save_ts = sort(unique([saveat_vec;tspan[2]]))
end

if T < save_ts[end]
error("Final saving timepoint is past the solving timespan")
Expand Down Expand Up @@ -74,7 +98,7 @@ function solve{uType,duType,tType,isinplace,F,LinearSolver}(
idid = Int32[0]
info = zeros(Int32, 20)

info[3] = save_timeseries
info[3] = save_everystep
info[11] = 0
info[16] = 0 # == 1 to ignore algebraic variables in the error calculation
info[17] = 0
Expand All @@ -98,12 +122,14 @@ function solve{uType,duType,tType,isinplace,F,LinearSolver}(
psol = Int32[0]

ures = Vector{Vector{Float64}}()
ts = [t0]

save_start ? ts = [t0] : ts = Float64[]
save_start ? start_idx = 1 : start_idx = 2
save_start && push!(ures, copy(u0))

u = copy(u0)
du = copy(du0)
# The Inner Loops : Style depends on save_timeseries
for k in 1:length(save_ts)
for k in start_idx:length(save_ts)
tout = [save_ts[k]]
while t[1] < save_ts[k]
DASKR.unsafe_solve(res, N, t, u, du, tout, info, rtol, atol, idid, rwork, lrw, iwork, liw, rpar, ipar, jac, psol, rt, nrt, jroot)
Expand All @@ -113,17 +139,20 @@ function solve{uType,duType,tType,isinplace,F,LinearSolver}(
end
### Finishing Routine



timeseries = Vector{uType}(0)
if typeof(prob.u0)<:Number
for i=1:length(ures)
for i=start_idx:length(ures)
push!(timeseries,ures[i][1])
end
else
for i=1:length(ures)
for i=start_idx:length(ures)
push!(timeseries,reshape(ures[i],sizeu))
end
end

build_solution(prob,alg,ts,timeseries,
timeseries_errors = timeseries_errors)
timeseries_errors = timeseries_errors,
retcode = :Success)
end
15 changes: 9 additions & 6 deletions test/runtests.jl
Expand Up @@ -67,14 +67,17 @@ let
dt = 1000
saveat = float(collect(0:dt:100000))
sol = solve(prob, daskr())
sol = solve(prob, daskr(),save_timeseries=false)
@test length(sol.t) > 2
sol = solve(prob, daskr(),save_everystep=false)
@test length(sol.u) == length(sol.t) == 2
sol = solve(prob, daskr(), saveat = saveat, isdiff = [true, true, false])
@test sol.t == saveat
sol = solve(prob, daskr(), saveat = dt, isdiff = [true, true, false])
@test sol.t == saveat
sol = solve(prob, daskr(), saveat = saveat,
save_timeseries = false,
save_everystep = true,
isdiff = [true, true, false])
sol = solve(prob, daskr(), saveat = saveat, save_timeseries = false)

@test minimum([t sol.t for t in saveat])
sol = solve(prob, daskr(), saveat = saveat, save_everystep = true)
@test intersect(sol.t, saveat) == saveat

@test sol.t == saveat
end

0 comments on commit b598ad6

Please sign in to comment.