Skip to content

Commit

Permalink
Merge pull request #283 from SciML/variablerate_oop
Browse files Browse the repository at this point in the history
Out of place variable rate jump extensions
  • Loading branch information
ChrisRackauckas committed Jan 6, 2023
2 parents da0418e + 9c70cda commit 1586fce
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 18 deletions.
7 changes: 7 additions & 0 deletions src/extended_jump_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,3 +189,10 @@ unpack(x::ExtendedJumpArray, ::Val{:jump_u}) = x.jump_u
end
unpack_args(i, args::Tuple{Any}) = (unpack(args[1], i),)
unpack_args(::Any, args::Tuple{}) = ()

Base.:*(x::ExtendedJumpArray, y::Number) = ExtendedJumpArray(y .* x.u, y .* x.jump_u)
Base.:*(y::Number, x::ExtendedJumpArray) = ExtendedJumpArray(y .* x.u, y .* x.jump_u)
Base.:/(x::ExtendedJumpArray, y::Number) = ExtendedJumpArray(x.u ./ y, x.jump_u ./ y)
function Base.:+(x::ExtendedJumpArray, y::ExtendedJumpArray)
ExtendedJumpArray(x.u .+ y.u, x.jump_u .+ y.jump_u)
end
89 changes: 72 additions & 17 deletions src/problem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -252,22 +252,47 @@ end
function extend_problem(prob::DiffEqBase.AbstractODEProblem, jumps; rng = DEFAULT_RNG)
_f = SciMLBase.unwrapped_f(prob.f)

jump_f = let _f = _f
function jump_f(du::ExtendedJumpArray, u::ExtendedJumpArray, p, t)
_f(du.u, u.u, p, t)
update_jumps!(du, u, p, t, length(u.u), jumps...)
if isinplace(prob)
jump_f = let _f = _f
function (du::ExtendedJumpArray, u::ExtendedJumpArray, p, t)
_f(du.u, u.u, p, t)
update_jumps!(du, u, p, t, length(u.u), jumps...)
end
end
else
jump_f = let _f = _f
function (u::ExtendedJumpArray, p, t)
du = ExtendedJumpArray(_f(u.u, p, t), u.jump_u)
update_jumps!(du, u, p, t, length(u.u), jumps...)
return du
end
end
end

ttype = eltype(prob.tspan)
u0 = ExtendedJumpArray(prob.u0,
[-randexp(rng, ttype) for i in 1:length(jumps)])
remake(prob, f = ODEFunction{true}(jump_f), u0 = u0)
remake(prob, f = ODEFunction{isinplace(prob)}(jump_f), u0 = u0)
end

function extend_problem(prob::DiffEqBase.AbstractSDEProblem, jumps; rng = DEFAULT_RNG)
function jump_f(du, u, p, t)
prob.f(du.u, u.u, p, t)
update_jumps!(du, u, p, t, length(u.u), jumps...)
_f = SciMLBase.unwrapped_f(prob.f)

if isinplace(prob)
jump_f = let _f = _f
function (du::ExtendedJumpArray, u::ExtendedJumpArray, p, t)
_f(du.u, u.u, p, t)
update_jumps!(du, u, p, t, length(u.u), jumps...)
end
end
else
jump_f = let _f = _f
function (u::ExtendedJumpArray, p, t)
du = ExtendedJumpArray(_f(u.u, p, t), u.jump_u)
update_jumps!(du, u, p, t, length(u.u), jumps...)
return du
end
end
end

if prob.noise_rate_prototype === nothing
Expand All @@ -283,30 +308,60 @@ function extend_problem(prob::DiffEqBase.AbstractSDEProblem, jumps; rng = DEFAUL
ttype = eltype(prob.tspan)
u0 = ExtendedJumpArray(prob.u0,
[-randexp(rng, ttype) for i in 1:length(jumps)])
remake(prob, f = SDEFunction{true}(jump_f, jump_g), g = jump_g, u0 = u0)
remake(prob, f = SDEFunction{isinplace(prob)}(jump_f, jump_g), g = jump_g, u0 = u0)
end

function extend_problem(prob::DiffEqBase.AbstractDDEProblem, jumps; rng = DEFAULT_RNG)
jump_f = function (du, u, h, p, t)
prob.f(du.u, u.u, h, p, t)
update_jumps!(du, u, p, t, length(u.u), jumps...)
_f = SciMLBase.unwrapped_f(prob.f)

if isinplace(prob)
jump_f = let _f = _f
function (du::ExtendedJumpArray, u::ExtendedJumpArray, h, p, t)
_f(du.u, u.u, h, p, t)
update_jumps!(du, u, p, t, length(u.u), jumps...)
end
end
else
jump_f = let _f = _f
function (u::ExtendedJumpArray, h, p, t)
du = ExtendedJumpArray(_f(u.u, h, p, t), u.jump_u)
update_jumps!(du, u, p, t, length(u.u), jumps...)
return du
end
end
end

ttype = eltype(prob.tspan)
u0 = ExtendedJumpArray(prob.u0,
[-randexp(rng, ttype) for i in 1:length(jumps)])
remake(prob, f = DDEFunction{true}(jump_f), u0 = u0)
remake(prob, f = DDEFunction{isinplace(prob)}(jump_f), u0 = u0)
end

# Not sure if the DAE one is correct: Should be a residual of sorts
function extend_problem(prob::DiffEqBase.AbstractDAEProblem, jumps; rng = DEFAULT_RNG)
jump_f = function (out, du, u, p, t)
prob.f(out.u, du.u, u.u, t)
update_jumps!(du, u, t, length(u.u), jumps...)
_f = SciMLBase.unwrapped_f(prob.f)

if isinplace(prob)
jump_f = let _f = _f
function (out, du::ExtendedJumpArray, u::ExtendedJumpArray, h, p, t)
_f(out, du.u, u.u, h, p, t)
update_jumps!(out, u, p, t, length(u.u), jumps...)
end
end
else
jump_f = let _f = _f
function (du, u::ExtendedJumpArray, h, p, t)
out = ExtendedJumpArray(_f(du.u, u.u, h, p, t), u.jump_u)
update_jumps!(du, u, p, t, length(u.u), jumps...)
return du
end
end
end

ttype = eltype(prob.tspan)
u0 = ExtendedJumpArray(prob.u0,
[-randexp(rng, ttype) for i in 1:length(jumps)])
remake(prob, f = DAEFunction{true}(jump_f), u0 = u0)
remake(prob, f = DAEFunction{isinplace(prob)}(jump_f), u0 = u0)
end

function build_variable_callback(cb, idx, jump, jumps...; rng = DEFAULT_RNG)
Expand Down
20 changes: 19 additions & 1 deletion test/variable_rate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,24 @@ prob = ODEProblem(f4, [x₀], Δt)
jumpProblem = JumpProblem(prob, Direct(), jump)
sol = solve(jumpProblem, Tsit5())

# Out of place test

function drift(x, p, t)
return p * x
end

function rate2(x, p, t)
return 3 * max(0.0, x[1])
end

function affect!2(integrator)
integrator.u ./= 2
end
x0 = rand(2)
prob = ODEProblem(drift, x0, (0.0, 10.0), 2.0)
jump = VariableRateJump(rate2, affect!2)
jump_prob = JumpProblem(prob, Direct(), jump)

# test to check lack of dependency graphs is caught in Coevolve for systems with non-maj
# jumps
let
Expand Down Expand Up @@ -172,4 +190,4 @@ let
rateinterval = ((u, p, t) -> 1.0))
@test_throws ErrorException JumpProblem(dprob_, alg, mass_action_jump_, vrj;
save_positions = (false, false))
end
end

0 comments on commit 1586fce

Please sign in to comment.