Skip to content

Commit

Permalink
more generate_function controls
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Dec 11, 2019
1 parent 55b2bbf commit 1bebe9c
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 14 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "ModelingToolkit"
uuid = "961ee093-0014-501f-94e3-6117800e7a78"
authors = ["Chris Rackauckas <accounts@chrisrackauckas.com>"]
version = "1.0.1"
version = "1.0.2"

[deps]
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
Expand Down
2 changes: 1 addition & 1 deletion src/ModelingToolkit.jl
Expand Up @@ -13,7 +13,7 @@ using StaticArrays, LinearAlgebra
using Latexify

using MacroTools
import MacroTools: splitdef, combinedef, postwalk
import MacroTools: splitdef, combinedef, postwalk, striplines
import GeneralizedGenerated
using DocStringExtensions

Expand Down
14 changes: 7 additions & 7 deletions src/systems/diffeqs/diffeqsystem.jl
Expand Up @@ -160,16 +160,16 @@ function (f::ODEToExpr)(O::Operation)
end
(f::ODEToExpr)(x) = convert(Expr, x)

function generate_jacobian(sys::ODESystem, dvs = sys.dvs, ps = sys.ps, expression = Val{true})
function generate_jacobian(sys::ODESystem, dvs = sys.dvs, ps = sys.ps, expression = Val{true}; kwargs...)
jac = calculate_jacobian(sys)
return build_function(jac, dvs, ps, (sys.iv.name,), ODEToExpr(sys), expression)
return build_function(jac, dvs, ps, (sys.iv.name,), ODEToExpr(sys), expression; kwargs...)
end

function generate_function(sys::ODESystem, dvs = sys.dvs, ps = sys.ps, expression = Val{true})
function generate_function(sys::ODESystem, dvs = sys.dvs, ps = sys.ps, expression = Val{true}; kwargs...)
rhss = [deq.rhs for deq sys.eqs]
dvs′ = [clean(dv) for dv dvs]
ps′ = [clean(p) for p ps]
return build_function(rhss, dvs′, ps′, (sys.iv.name,), ODEToExpr(sys), expression)
return build_function(rhss, dvs′, ps′, (sys.iv.name,), ODEToExpr(sys), expression; kwargs...)
end

function calculate_factorized_W(sys::ODESystem, simplify=true)
Expand All @@ -196,16 +196,16 @@ function calculate_factorized_W(sys::ODESystem, simplify=true)
(Wfact,Wfact_t)
end

function generate_factorized_W(sys::ODESystem, vs = sys.dvs, ps = sys.ps, simplify=true, expression = Val{true})
function generate_factorized_W(sys::ODESystem, vs = sys.dvs, ps = sys.ps, simplify=true, expression = Val{true}; kwargs...)
(Wfact,Wfact_t) = calculate_factorized_W(sys,simplify)
siz = size(Wfact)
constructor = :(x -> begin
A = SMatrix{$siz...}(x)
StaticArrays.LU(LowerTriangular( SMatrix{$siz...}(UnitLowerTriangular(A)) ), UpperTriangular(A), SVector(ntuple(n->n, max($siz...))))
end)

Wfact_func = build_function(Wfact , vs, ps, (:gam,:t), ODEToExpr(sys), expression;constructor=constructor)
Wfact_t_func = build_function(Wfact_t, vs, ps, (:gam,:t), ODEToExpr(sys), expression;constructor=constructor)
Wfact_func = build_function(Wfact , vs, ps, (:gam,:t), ODEToExpr(sys), expression;constructor=constructor,kwargs...)
Wfact_t_func = build_function(Wfact_t, vs, ps, (:gam,:t), ODEToExpr(sys), expression;constructor=constructor,kwargs...)

return (Wfact_func, Wfact_t_func)
end
Expand Down
4 changes: 2 additions & 2 deletions src/systems/nonlinear/nonlinear_system.jl
Expand Up @@ -84,9 +84,9 @@ end
(f::NLSysToExpr)(x) = convert(Expr, x)


function generate_function(sys::NonlinearSystem, vs, ps, expression = Val{true}; version = nothing)
function generate_function(sys::NonlinearSystem, vs, ps, expression = Val{true}; kwargs...)
rhss = [eq.rhs for eq sys.eqs]
vs′ = [clean(v) for v vs]
ps′ = [clean(p) for p ps]
return build_function(rhss, vs′, ps′, (), NLSysToExpr(sys))
return build_function(rhss, vs′, ps′, (), NLSysToExpr(sys), expression; kwargs...)
end
11 changes: 8 additions & 3 deletions src/utils.jl
Expand Up @@ -31,8 +31,8 @@ function flatten_expr!(x)
x
end

function build_function(rhss, vs, ps = (), args = (), conv = simplified_expr, expression = Val{true};
checkbounds = false, constructor=nothing)
function build_function(rhss, vs, ps = (), args = (), conv = simplified_expr, expression = Val{true};
checkbounds = false, constructor=nothing, linenumbers = true)
_vs = map(x-> x isa Operation ? x.op : x, vs)
_ps = map(x-> x isa Operation ? x.op : x, ps)
var_pairs = [(u.name, :(u[$i])) for (i, u) enumerate(_vs)]
Expand All @@ -51,7 +51,7 @@ function build_function(rhss, vs, ps = (), args = (), conv = simplified_expr, ex
let_expr = Expr(:let, var_eqs, sys_expr)
bounds_block = checkbounds ? let_expr : :(@inbounds begin $let_expr end)
ip_bounds_block = checkbounds ? ip_let_expr : :(@inbounds begin $ip_let_expr end)

fargs = ps == () ? :(u,$(args...)) : :(u,p,$(args...))

oop_ex = :(
Expand All @@ -75,6 +75,11 @@ function build_function(rhss, vs, ps = (), args = (), conv = simplified_expr, ex
end
)

if !linenumbers
oop_ex = striplines(oop_ex)
iip_ex = striplines(iip_ex)
end

if expression == Val{true}
return oop_ex, iip_ex
else
Expand Down

0 comments on commit 1bebe9c

Please sign in to comment.