Skip to content

Commit

Permalink
implemented iip_config
Browse files Browse the repository at this point in the history
  • Loading branch information
Brad Carman committed Oct 18, 2022
1 parent 9544860 commit 7d0c5c2
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
12 changes: 8 additions & 4 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -487,8 +487,10 @@ function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
sparse = false, simplify = false,
steady_state = false,
sparsity = false,
iip_config = (true, true),
kwargs...) where {iip}
f_oop, f_iip = generate_function(sys, dvs, ps; expression = Val{true}, kwargs...)
f_oop, f_iip = generate_function(sys, dvs, ps; expression = Val{true}, iip_config,
kwargs...)

dict = Dict()

Expand All @@ -498,7 +500,8 @@ function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
if tgrad
tgrad_oop, tgrad_iip = generate_tgrad(sys, dvs, ps;
simplify = simplify,
expression = Val{true}, kwargs...)
expression = Val{true},
iip_config, kwargs...)
_tgrad = :($tgradsym = $ODEFunctionClosure($tgrad_oop, $tgrad_iip))
else
_tgrad = :($tgradsym = nothing)
Expand All @@ -508,7 +511,8 @@ function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
if jac
jac_oop, jac_iip = generate_jacobian(sys, dvs, ps;
sparse = sparse, simplify = simplify,
expression = Val{true}, kwargs...)
expression = Val{true},
iip_config, kwargs...)
_jac = :($jacsym = $ODEFunctionClosure($jac_oop, $jac_iip))
else
_jac = :($jacsym = nothing)
Expand Down Expand Up @@ -537,7 +541,7 @@ function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
jac_prototype = $jp_expr,
syms = $(Symbol.(states(sys))),
indepsym = $(QuoteNode(Symbol(get_iv(sys)))),
paramsyms = $((Symbol.(parameters(sys))),
paramsyms = $(Symbol.(parameters(sys))),
sparsity = $(jacobian_sparsity(sys)))
end
!linenumbers ? striplines(ex) : ex
Expand Down
9 changes: 9 additions & 0 deletions test/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,15 @@ for f in [
@test J == f.jac(u, p, t)
end

#check iip_config
f = eval(ODEFunctionExpr(de, [x, y, z], [σ, ρ, β], iip_config = (false, true)))
du = zeros(3)
u = collect(1:3)
p = collect(4:6)
f.f(du, u, p, 0.1)
@test du == [4, 0, -16]
@test_throws ArgumentError f.f(u, p, 0.1)

eqs = [D(x) ~ σ * (y - x),
D(y) ~ x *- z) - y * t,
D(z) ~ x * y - β * z]
Expand Down

0 comments on commit 7d0c5c2

Please sign in to comment.