Skip to content

Commit

Permalink
Merge pull request #514 from JuliaDiffEq/xg/save
Browse files Browse the repository at this point in the history
`save_on` option in `solve`
  • Loading branch information
ChrisRackauckas committed Oct 5, 2018
2 parents 2ddf833 + c570cd3 commit 269794e
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 39 deletions.
78 changes: 40 additions & 38 deletions src/integrators/integrator_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,30 +55,54 @@ function modify_dt_for_tstops!(integrator)
end

function savevalues!(integrator::ODEIntegrator,force_save=false,reduce_size=true)
while !isempty(integrator.opts.saveat) && integrator.tdir*top(integrator.opts.saveat) <= integrator.tdir*integrator.t # Perform saveat
integrator.saveiter += 1
curt = pop!(integrator.opts.saveat)
if curt!=integrator.t # If <t, interpolate
ode_addsteps!(integrator)
Θ = (curt - integrator.tprev)/integrator.dt
val = ode_interpolant(Θ,integrator,integrator.opts.save_idxs,Val{0}) # out of place, but no force copy later
copyat_or_push!(integrator.sol.t,integrator.saveiter,curt)
save_val = val
copyat_or_push!(integrator.sol.u,integrator.saveiter,save_val,Val{false})
if typeof(integrator.alg) <: OrdinaryDiffEqCompositeAlgorithm
copyat_or_push!(integrator.sol.alg_choice,integrator.saveiter,integrator.cache.current)
if integrator.opts.save_on
while !isempty(integrator.opts.saveat) && integrator.tdir*top(integrator.opts.saveat) <= integrator.tdir*integrator.t # Perform saveat
integrator.saveiter += 1
curt = pop!(integrator.opts.saveat)
if curt!=integrator.t # If <t, interpolate
ode_addsteps!(integrator)
Θ = (curt - integrator.tprev)/integrator.dt
val = ode_interpolant(Θ,integrator,integrator.opts.save_idxs,Val{0}) # out of place, but no force copy later
copyat_or_push!(integrator.sol.t,integrator.saveiter,curt)
save_val = val
copyat_or_push!(integrator.sol.u,integrator.saveiter,save_val,Val{false})
if typeof(integrator.alg) <: OrdinaryDiffEqCompositeAlgorithm
copyat_or_push!(integrator.sol.alg_choice,integrator.saveiter,integrator.cache.current)
end
else # ==t, just save
copyat_or_push!(integrator.sol.t,integrator.saveiter,integrator.t)
if integrator.opts.save_idxs == nothing
copyat_or_push!(integrator.sol.u,integrator.saveiter,integrator.u)
else
copyat_or_push!(integrator.sol.u,integrator.saveiter,integrator.u[integrator.opts.save_idxs],Val{false})
end
if typeof(integrator.alg) <: FunctionMap || integrator.opts.dense
integrator.saveiter_dense +=1
if integrator.opts.dense
if integrator.opts.save_idxs ==nothing
copyat_or_push!(integrator.sol.k,integrator.saveiter_dense,integrator.k)
else
copyat_or_push!(integrator.sol.k,integrator.saveiter_dense,[k[integrator.opts.save_idxs] for k in integrator.k],Val{false})
end
end
end
if typeof(integrator.alg) <: OrdinaryDiffEqCompositeAlgorithm
copyat_or_push!(integrator.sol.alg_choice,integrator.saveiter,integrator.cache.current)
end
end
else # ==t, just save
copyat_or_push!(integrator.sol.t,integrator.saveiter,integrator.t)
end
if force_save || (integrator.opts.save_everystep && integrator.iter%integrator.opts.timeseries_steps==0)
integrator.saveiter += 1
if integrator.opts.save_idxs == nothing
copyat_or_push!(integrator.sol.u,integrator.saveiter,integrator.u)
else
copyat_or_push!(integrator.sol.u,integrator.saveiter,integrator.u[integrator.opts.save_idxs],Val{false})
end
copyat_or_push!(integrator.sol.t,integrator.saveiter,integrator.t)
if typeof(integrator.alg) <: FunctionMap || integrator.opts.dense
integrator.saveiter_dense +=1
if integrator.opts.dense
if integrator.opts.save_idxs ==nothing
if integrator.opts.save_idxs == nothing
copyat_or_push!(integrator.sol.k,integrator.saveiter_dense,integrator.k)
else
copyat_or_push!(integrator.sol.k,integrator.saveiter_dense,[k[integrator.opts.save_idxs] for k in integrator.k],Val{false})
Expand All @@ -89,30 +113,8 @@ function savevalues!(integrator::ODEIntegrator,force_save=false,reduce_size=true
copyat_or_push!(integrator.sol.alg_choice,integrator.saveiter,integrator.cache.current)
end
end
reduce_size && resize!(integrator.k,integrator.kshortsize)
end
if force_save || (integrator.opts.save_everystep && integrator.iter%integrator.opts.timeseries_steps==0)
integrator.saveiter += 1
if integrator.opts.save_idxs == nothing
copyat_or_push!(integrator.sol.u,integrator.saveiter,integrator.u)
else
copyat_or_push!(integrator.sol.u,integrator.saveiter,integrator.u[integrator.opts.save_idxs],Val{false})
end
copyat_or_push!(integrator.sol.t,integrator.saveiter,integrator.t)
if typeof(integrator.alg) <: FunctionMap || integrator.opts.dense
integrator.saveiter_dense +=1
if integrator.opts.dense
if integrator.opts.save_idxs == nothing
copyat_or_push!(integrator.sol.k,integrator.saveiter_dense,integrator.k)
else
copyat_or_push!(integrator.sol.k,integrator.saveiter_dense,[k[integrator.opts.save_idxs] for k in integrator.k],Val{false})
end
end
end
if typeof(integrator.alg) <: OrdinaryDiffEqCompositeAlgorithm
copyat_or_push!(integrator.sol.alg_choice,integrator.saveiter,integrator.cache.current)
end
end
reduce_size && resize!(integrator.k,integrator.kshortsize)
end

function postamble!(integrator::ODEIntegrator)
Expand Down
1 change: 1 addition & 0 deletions src/integrators/type.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ mutable struct DEOptions{absType,relType,QT,tType,F1,F2,F3,F4,F5,F6,tstopsType,d
beta2::QT
qoldinit::QT
dense::Bool
save_on::Bool
save_start::Bool
save_end::Bool
callback::F3
Expand Down
3 changes: 2 additions & 1 deletion src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ function DiffEqBase.__init(
save_idxs = nothing,
save_everystep = isempty(saveat),
save_timeseries = nothing,
save_on = true,
save_start = save_everystep || isempty(saveat) || typeof(saveat) <: Number ? true : prob.tspan[1] in saveat,
save_end = save_everystep || isempty(saveat) || typeof(saveat) <: Number ? true : prob.tspan[2] in saveat,
callback=nothing,
Expand Down Expand Up @@ -252,7 +253,7 @@ function DiffEqBase.__init(
userdata,progress,progress_steps,
progress_name,progress_message,timeseries_errors,dense_errors,
QT(beta1),QT(beta2),QT(qoldinit),dense,
save_start,save_end,callbacks_internal,isoutofdomain,
save_on,save_start,save_end,callbacks_internal,isoutofdomain,
unstable_check,verbose,
calck,force_dtmin,advance_to_tstop,stop_at_next_tstop)

Expand Down
6 changes: 6 additions & 0 deletions test/ode/ode_saveat_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,9 @@ sol2=solve(prob2,DP5(),dt=1//2^(2),saveat=.1,save_idxs=1:2:5,save_everystep=true
sol=solve(prob2,DP5(),dt=1//2^(2),save_start=false)

@test sol.t[1] == 1//2^(2)

# Test save_on switch
sol = solve(prob, DP5(), save_on=false, save_start=false, save_end=false)
@test isempty(sol.t) && isempty(sol.u)
sol = solve(prob, DP5(), saveat=0.2, save_on=false, save_start=false, save_end=false)
@test isempty(sol.t) && isempty(sol.u)

0 comments on commit 269794e

Please sign in to comment.