Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,23 +1,27 @@
name = "OrdinaryDiffEqOperatorSplitting"
uuid = "760fc936-9fa0-4281-9c1a-468957eedc89"
authors = ["Dennis Ogiermann <termi-official@users.noreply.github.com> and contributors"]
version = "0.2.2"
authors = ["Dennis Ogiermann <termi-official@users.noreply.github.com> and contributors"]

[deps]
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
Unrolled = "9602ed7d-8fef-5bc8-8597-8f21381861e8"

[compat]
CommonSolve = "0.2.4"
DataStructures = "0.18.22, 0.19"
DiffEqBase = "6.165.1"
ModelingToolkit = "10"
OrdinaryDiffEqCore = "1.19.0"
OrdinaryDiffEqTsit5 = "1.1.0"
RecursiveArrayTools = "3.39.0"
SafeTestsets = "0.1.0"
SciMLBase = "2.77.0"
TimerOutputs = "0.5.28"
Expand Down
9 changes: 6 additions & 3 deletions src/OrdinaryDiffEqOperatorSplitting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,20 @@ timeit_debug_enabled() = false
import Unrolled: @unroll

import SciMLBase, DiffEqBase, DataStructures
import SciMLBase: ReturnCode
import SciMLBase: DEIntegrator, NullParameters, isadaptive

import RecursiveArrayTools

import OrdinaryDiffEqCore

import UnPack: @unpack
import DiffEqBase: init, TimeChoiceIterator

abstract type AbstractOperatorSplitFunction <: DiffEqBase.AbstractODEFunction{true} end
abstract type AbstractOperatorSplitFunction <: SciMLBase.AbstractODEFunction{true} end
abstract type AbstractOperatorSplittingAlgorithm end
abstract type AbstractOperatorSplittingCache end

@inline DiffEqBase.isadaptive(::AbstractOperatorSplittingAlgorithm) = false
@inline SciMLBase.isadaptive(::AbstractOperatorSplittingAlgorithm) = false

include("function.jl")
include("problem.jl")
Expand Down
86 changes: 43 additions & 43 deletions src/integrator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ mutable struct OperatorSplittingIntegrator{
syncTreeType,
controllerType,
optionsType
} <: DiffEqBase.AbstractODEIntegrator{algType, true, uType, tType}
} <: SciMLBase.AbstractODEIntegrator{algType, true, uType, tType}
const f::fType
const alg::algType
u::uType # Master Solution
Expand Down Expand Up @@ -76,7 +76,7 @@ mutable struct OperatorSplittingIntegrator{
end

# called by DiffEqBase.init and DiffEqBase.solve
function DiffEqBase.__init(
function SciMLBase.__init(
prob::OperatorSplittingProblem,
alg::AbstractOperatorSplittingAlgorithm,
args...;
Expand All @@ -87,7 +87,7 @@ function DiffEqBase.__init(
save_everystep = false,
callback = nothing,
advance_to_tstop = false,
adaptive = DiffEqBase.isadaptive(alg),
adaptive = isadaptive(alg),
controller = nothing,
alias_u0 = false,
verbose = true,
Expand All @@ -102,9 +102,9 @@ function DiffEqBase.__init(
tType = typeof(dt)

# Warn if the algorithm is non-adaptive but the user tries to make it adaptive.
(!DiffEqBase.isadaptive(alg) && adaptive && verbose) && warn("The algorithm $alg is not adaptive.")
(!isadaptive(alg) && adaptive && verbose) && warn("The algorithm $alg is not adaptive.")

dtchangeable = true # DiffEqBase.isadaptive(alg)
dtchangeable = true # isadaptive(alg)

if tstops isa AbstractArray || tstops isa Tuple || tstops isa Number
_tstops = nothing
Expand All @@ -125,7 +125,7 @@ function DiffEqBase.__init(
tmp = setup_u(prob, alg, false)
uType = typeof(u)

sol = DiffEqBase.build_solution(prob, alg, tType[], uType[])
sol = SciMLBase.build_solution(prob, alg, tType[], uType[])

callback = DiffEqBase.CallbackSet(callback)

Expand Down Expand Up @@ -176,7 +176,7 @@ function DiffEqBase.__init(
return integrator
end

DiffEqBase.has_reinit(integrator::OperatorSplittingIntegrator) = true
SciMLBase.has_reinit(integrator::OperatorSplittingIntegrator) = true
function DiffEqBase.reinit!(
integrator::OperatorSplittingIntegrator,
u0 = integrator.sol.prob.u0;
Expand Down Expand Up @@ -209,8 +209,8 @@ function DiffEqBase.reinit!(
DiffEqBase.initialize!(saving_callback, u0, t0, integrator)
end
if reinit_retcode
integrator.sol = DiffEqBase.solution_new_retcode(
integrator.sol, SciMLBase.ReturnCode.Default)
integrator.sol = SciMLBase.solution_new_retcode(
integrator.sol, ReturnCode.Default)
end

subreinit!(
Expand All @@ -231,7 +231,7 @@ function subreinit!(
f,
u0,
solution_indices,
subintegrator::DiffEqBase.DEIntegrator;
subintegrator::DEIntegrator;
dt,
kwargs...
)
Expand Down Expand Up @@ -325,7 +325,7 @@ function store_previous_info!(integrator::OperatorSplittingIntegrator)
end

function update_uprev!(integrator::OperatorSplittingIntegrator)
SciMLBase.recursivecopy!(integrator.uprev, integrator.u)
RecursiveArrayTools.recursivecopy!(integrator.uprev, integrator.u)
nothing
end

Expand Down Expand Up @@ -376,7 +376,7 @@ function step_footer!(integrator::OperatorSplittingIntegrator)
# OrdinaryDiffEqCore.handle_callbacks!(integrator)
step_accept_controller!(integrator) # Noop for non-adaptive algorithms
elseif integrator.force_stepfail
if SciMLBase.isadaptive(integrator)
if isadaptive(integrator)
step_reject_controller!(integrator)
OrdinaryDiffEqCore.post_newton_controller!(integrator, integrator.alg)
elseif integrator.dtchangeable # Non-adaptive but can change dt
Expand All @@ -393,9 +393,9 @@ function step_footer!(integrator::OperatorSplittingIntegrator)
end

# called by DiffEqBase.solve
function DiffEqBase.__solve(prob::OperatorSplittingProblem,
function SciMLBase.__solve(prob::OperatorSplittingProblem,
alg::AbstractOperatorSplittingAlgorithm, args...; kwargs...)
integrator = DiffEqBase.__init(prob, alg, args...; kwargs...)
integrator = SciMLBase.__init(prob, alg, args...; kwargs...)
DiffEqBase.solve!(integrator)
end

Expand All @@ -404,8 +404,8 @@ function DiffEqBase.solve!(integrator::OperatorSplittingIntegrator)
while !isempty(integrator.tstops)
while tdir(integrator) * integrator.t < SciMLBase.first_tstop(integrator)
step_header!(integrator)
@timeit_debug "check_error" DiffEqBase.check_error!(integrator) ∉ (
SciMLBase.ReturnCode.Success, SciMLBase.ReturnCode.Default)&&return
@timeit_debug "check_error" SciMLBase.check_error!(integrator) ∉ (
ReturnCode.Success, ReturnCode.Default)&&return
__step!(integrator)
step_footer!(integrator)
if !SciMLBase.has_tstop(integrator)
Expand All @@ -414,21 +414,21 @@ function DiffEqBase.solve!(integrator::OperatorSplittingIntegrator)
end
OrdinaryDiffEqCore.handle_tstop!(integrator)
end
OrdinaryDiffEqCore.postamble!(integrator)
if integrator.sol.retcode != SciMLBase.ReturnCode.Default
SciMLBase.postamble!(integrator)
if integrator.sol.retcode != ReturnCode.Default
return integrator.sol
end
return integrator.sol = SciMLBase.solution_new_retcode(
integrator.sol, SciMLBase.ReturnCode.Success)
integrator.sol, ReturnCode.Success)
end

function DiffEqBase.step!(integrator::OperatorSplittingIntegrator)
@timeit_debug "step!" if integrator.advance_to_tstop
tstop = first_tstop(integrator)
while !reached_tstop(integrator, tstop)
step_header!(integrator)
@timeit_debug "check_error" DiffEqBase.check_error!(integrator) ∉ (
SciMLBase.ReturnCode.Success, SciMLBase.ReturnCode.Default)&&return
@timeit_debug "check_error" SciMLBase.check_error!(integrator) ∉ (
ReturnCode.Success, ReturnCode.Default)&&return
__step!(integrator)
step_footer!(integrator)
if !SciMLBase.has_tstop(integrator)
Expand All @@ -437,14 +437,14 @@ function DiffEqBase.step!(integrator::OperatorSplittingIntegrator)
end
else
step_header!(integrator)
@timeit_debug "check_error" DiffEqBase.check_error!(integrator) ∉ (
SciMLBase.ReturnCode.Success, SciMLBase.ReturnCode.Default)&&return
@timeit_debug "check_error" SciMLBase.check_error!(integrator) ∉ (
ReturnCode.Success, ReturnCode.Default)&&return
__step!(integrator)
step_footer!(integrator)
while !should_accept_step(integrator)
step_header!(integrator)
@timeit_debug "check_error" DiffEqBase.check_error!(integrator) ∉ (
SciMLBase.ReturnCode.Success, SciMLBase.ReturnCode.Default)&&return
@timeit_debug "check_error" SciMLBase.check_error!(integrator) ∉ (
ReturnCode.Success, ReturnCode.Default)&&return
__step!(integrator)
step_footer!(integrator)
end
Expand All @@ -454,7 +454,7 @@ end

function SciMLBase.check_error(integrator::OperatorSplittingIntegrator)
if !SciMLBase.successful_retcode(integrator.sol) &&
integrator.sol.retcode != SciMLBase.ReturnCode.Default
integrator.sol.retcode != ReturnCode.Default
return integrator.sol.retcode
end

Expand All @@ -464,7 +464,7 @@ function SciMLBase.check_error(integrator::OperatorSplittingIntegrator)
if verbose
@warn("NaN dt detected. Likely a NaN value in the state, parameters, or derivative value caused this outcome.")
end
return SciMLBase.ReturnCode.DtNaN
return ReturnCode.DtNaN
end

return check_error_subintegrators(integrator, integrator.subintegrator_tree)
Expand All @@ -473,14 +473,14 @@ end
function check_error_subintegrators(integrator, subintegrator_tree::Tuple)
for subintegrator in subintegrator_tree
retcode = check_error_subintegrators(integrator, subintegrator)
if !SciMLBase.successful_retcode(retcode) && retcode != SciMLBase.ReturnCode.Default
if !SciMLBase.successful_retcode(retcode) && retcode != ReturnCode.Default
return retcode
end
end
return integrator.sol.retcode
end

function check_error_subintegrators(integrator, subintegrator::SciMLBase.DEIntegrator)
function check_error_subintegrators(integrator, subintegrator::DEIntegrator)
return SciMLBase.check_error(subintegrator)
end

Expand All @@ -494,8 +494,8 @@ function DiffEqBase.step!(integrator::OperatorSplittingIntegrator, dt, stop_at_t
stop_at_tdt && DiffEqBase.add_tstop!(integrator, tnext)
while !reached_tstop(integrator, tnext, stop_at_tdt)
step_header!(integrator)
@timeit_debug "check_error" DiffEqBase.check_error!(integrator) ∉ (
SciMLBase.ReturnCode.Success, SciMLBase.ReturnCode.Default)&&return
@timeit_debug "check_error" SciMLBase.check_error!(integrator) ∉ (
ReturnCode.Success, ReturnCode.Default)&&return
__step!(integrator)
step_footer!(integrator)
end
Expand All @@ -506,7 +506,7 @@ function setup_u(prob::OperatorSplittingProblem, solver, alias_u0)
if alias_u0
return prob.u0
else
return OrdinaryDiffEqCore.recursivecopy(prob.u0)
return RecursiveArrayTools.recursivecopy(prob.u0)
end
end

Expand Down Expand Up @@ -535,7 +535,7 @@ Updates the controller using the current state of the integrator if the operator
"""
@inline function stepsize_controller!(integrator::OperatorSplittingIntegrator)
algorithm = integrator.alg
DiffEqBase.isadaptive(algorithm) || return nothing
isadaptive(algorithm) || return nothing
stepsize_controller!(integrator, algorithm)
end

Expand All @@ -546,7 +546,7 @@ Updates `dtcache` of the integrator if the step is accepted and the operator spl
"""
@inline function step_accept_controller!(integrator::OperatorSplittingIntegrator)
algorithm = integrator.alg
DiffEqBase.isadaptive(algorithm) || return nothing
isadaptive(algorithm) || return nothing
step_accept_controller!(integrator, algorithm, nothing)
end

Expand All @@ -557,7 +557,7 @@ Updates `dtcache` of the integrator if the step is rejected and the the operator
"""
@inline function step_reject_controller!(integrator::OperatorSplittingIntegrator)
algorithm = integrator.alg
DiffEqBase.isadaptive(algorithm) || return nothing
isadaptive(algorithm) || return nothing
step_reject_controller!(integrator, algorithm, nothing)
end

Expand All @@ -578,16 +578,16 @@ end
# Dunno stuff
function SciMLBase.done(integrator::OperatorSplittingIntegrator)
if !(integrator.sol.retcode in (
SciMLBase.ReturnCode.Default, SciMLBase.ReturnCode.Success))
ReturnCode.Default, ReturnCode.Success))
return true
elseif isempty(integrator.tstops)
DiffEqBase.postamble!(integrator)
SciMLBase.postamble!(integrator)
return true
end
false
end

function DiffEqBase.postamble!(integrator::OperatorSplittingIntegrator)
function SciMLBase.postamble!(integrator::OperatorSplittingIntegrator)
DiffEqBase.finalize!(integrator.callback, integrator.u, integrator.t, integrator)
end

Expand All @@ -604,7 +604,7 @@ function advance_solution_to!(integrator::OperatorSplittingIntegrator, tnext)
end

function advance_solution_to!(outer_integrator::OperatorSplittingIntegrator,
integrator::DiffEqBase.DEIntegrator, solution_indices, sync, cache, tend)
integrator::DEIntegrator, solution_indices, sync, cache, tend)
dt = tend - integrator.t
SciMLBase.step!(integrator, dt, true)
end
Expand Down Expand Up @@ -653,10 +653,10 @@ end
end

function synchronize_subintegrator!(
subintegrator::SciMLBase.DEIntegrator, integrator::OperatorSplittingIntegrator)
subintegrator::DEIntegrator, integrator::OperatorSplittingIntegrator)
@unpack t, dt = integrator
@assert subintegrator.t == t
if !DiffEqBase.isadaptive(subintegrator)
if !isadaptive(subintegrator)
SciMLBase.set_proposed_dt!(subintegrator, dt)
end
end
Expand Down Expand Up @@ -763,13 +763,13 @@ function build_subintegrator_tree_with_cache(
# In that case ODEProblem constructs the correct parameter struct.
# If the system does not have parameters in first place, then
# The NullParameters object will be constructed automatically.
prob2 = if p isa DiffEqBase.NullParameters
prob2 = if p isa NullParameters
SciMLBase.ODEProblem(f, u, (t0, min(t0 + dt, tf)))
else
SciMLBase.ODEProblem(f, u, (t0, min(t0 + dt, tf)), p)
end

integrator = DiffEqBase.__init(
integrator = SciMLBase.__init(
prob2,
alg;
dt,
Expand Down
4 changes: 2 additions & 2 deletions src/problem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""
mutable struct OperatorSplittingProblem{
fType <: AbstractOperatorSplitFunction, uType, tType, pType <: Tuple, K} <:
DiffEqBase.AbstractODEProblem{uType, tType, true}
SciMLBase.AbstractODEProblem{uType, tType, true}
f::fType
u0::uType
tspan::tType
Expand All @@ -29,5 +29,5 @@ function recursive_null_parameters(f::GenericSplitFunction)
ntuple(i->recursive_null_parameters(get_operator(f, i)), length(f.functions))
end
function recursive_null_parameters(f) # Wildcard for leafs
DiffEqBase.NullParameters()
NullParameters()
end
6 changes: 3 additions & 3 deletions src/solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ function init_cache(f::GenericSplitFunction, alg::LieTrotterGodunov;
alias_uprev = true,
alias_u = false
)
_uprev = alias_uprev ? uprev : SciMLBase.recursivecopy(uprev)
_u = alias_u ? u : SciMLBase.recursivecopy(u)
_uprev = alias_uprev ? uprev : RecursiveArrayTools.recursivecopy(uprev)
_u = alias_u ? u : RecursiveArrayTools.recursivecopy(u)
LieTrotterGodunovCache(_u, _uprev, inner_caches)
end

Expand All @@ -45,7 +45,7 @@ end
outer_integrator, subinteg, idxs, synchronizer, cache, tnext)
if !(subinteg isa Tuple) &&
subinteg.sol.retcode ∉
(SciMLBase.ReturnCode.Default, SciMLBase.ReturnCode.Success)
(ReturnCode.Default, ReturnCode.Success)
return
end
backward_sync_subintegrator!(outer_integrator, subinteg, idxs, synchronizer)
Expand Down
Loading
Loading