Skip to content
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ MacroTools = "0.5"
NaNMath = "0.3"
RecursiveArrayTools = "2.3"
Requires = "1.0"
RuntimeGeneratedFunctions = "0.4"
RuntimeGeneratedFunctions = "0.4.3"
SafeTestsets = "0.0.1"
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0"
StaticArrays = "0.10, 0.11, 0.12, 1.0"
Expand Down
8 changes: 5 additions & 3 deletions src/build_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ end
# Scalar output
function _build_function(target::JuliaTarget, op, args...;
conv = toexpr, expression = Val{true},
expression_module = @__MODULE__,
checkbounds = false,
linenumbers = true, headerfun=addheader)

Expand Down Expand Up @@ -127,12 +128,12 @@ function _build_function(target::JuliaTarget, op, args...;
if expression == Val{true}
return ModelingToolkit.inject_registered_module_functions(oop_ex)
else
_build_and_inject_function(@__MODULE__, oop_ex)
_build_and_inject_function(expression_module, oop_ex)
end
end

function _build_and_inject_function(mod::Module, ex)
@RuntimeGeneratedFunction(ModelingToolkit.inject_registered_module_functions(ex))
@RuntimeGeneratedFunction(mod, ModelingToolkit.inject_registered_module_functions(ex))
end

# Detect heterogeneous element types of "arrays of matrices/sparce matrices"
Expand Down Expand Up @@ -218,6 +219,7 @@ Special Keyword Argumnets:
"""
function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
conv = toexpr, expression = Val{true},
expression_module = @__MODULE__,
checkbounds = false,
linenumbers = false, multithread=nothing,
headerfun = addheader, outputidxs=nothing,
Expand Down Expand Up @@ -457,7 +459,7 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
if expression == Val{true}
return ModelingToolkit.inject_registered_module_functions(oop_ex), ModelingToolkit.inject_registered_module_functions(iip_ex)
else
return _build_and_inject_function(@__MODULE__, oop_ex), _build_and_inject_function(@__MODULE__, iip_ex)
return _build_and_inject_function(expression_module, oop_ex), _build_and_inject_function(expression_module, iip_ex)
end
end

Expand Down
13 changes: 7 additions & 6 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,19 +124,20 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
version = nothing, tgrad=false,
jac = false,
eval_expression = true,
eval_module = @__MODULE__,
sparse = false, simplify = true,
kwargs...) where {iip}

f_gen = generate_function(sys, dvs, ps; expression=Val{eval_expression}, kwargs...)
f_oop,f_iip = eval_expression ? (@RuntimeGeneratedFunction(ex) for ex in f_gen) : f_gen
f_gen = generate_function(sys, dvs, ps; expression=Val{eval_expression}, expression_module=eval_module, kwargs...)
f_oop,f_iip = eval_expression ? (@RuntimeGeneratedFunction(eval_module, ex) for ex in f_gen) : f_gen
f(u,p,t) = f_oop(u,p,t)
f(du,u,p,t) = f_iip(du,u,p,t)

if tgrad
tgrad_gen = generate_tgrad(sys, dvs, ps;
simplify=simplify,
expression=Val{eval_expression}, kwargs...)
tgrad_oop,tgrad_iip = eval_expression ? (@RuntimeGeneratedFunction(ex) for ex in tgrad_gen) : tgrad_gen
expression=Val{eval_expression}, expression_module=eval_module, kwargs...)
tgrad_oop,tgrad_iip = eval_expression ? (@RuntimeGeneratedFunction(eval_module, ex) for ex in tgrad_gen) : tgrad_gen
_tgrad(u,p,t) = tgrad_oop(u,p,t)
_tgrad(J,u,p,t) = tgrad_iip(J,u,p,t)
else
Expand All @@ -146,8 +147,8 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
if jac
jac_gen = generate_jacobian(sys, dvs, ps;
simplify=simplify, sparse = sparse,
expression=Val{eval_expression}, kwargs...)
jac_oop,jac_iip = eval_expression ? (@RuntimeGeneratedFunction(ex) for ex in jac_gen) : jac_gen
expression=Val{eval_expression}, expression_module=eval_module, kwargs...)
jac_oop,jac_iip = eval_expression ? (@RuntimeGeneratedFunction(eval_module, ex) for ex in jac_gen) : jac_gen
_jac(u,p,t) = jac_oop(u,p,t)
_jac(J,u,p,t) = jac_iip(J,u,p,t)
else
Expand Down
26 changes: 26 additions & 0 deletions test/precompile_test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
using Test
using ModelingToolkit

# Test that the precompiled ODE system works
push!(LOAD_PATH, joinpath(@__DIR__, "precompile_test"))
using ODEPrecompileTest

u = collect(1:3)
p = collect(4:6)

# This case does not work, because "f_bad" gets defined in ModelingToolkit
# instead of in the compiled module!
@test parentmodule(typeof(ODEPrecompileTest.f_bad.f.f_iip).parameters[2]) == ModelingToolkit
@test parentmodule(typeof(ODEPrecompileTest.f_bad.f.f_oop).parameters[2]) == ModelingToolkit
@test parentmodule(typeof(ODEPrecompileTest.f_noeval_bad.f.f_iip).parameters[2]) == ModelingToolkit
@test parentmodule(typeof(ODEPrecompileTest.f_noeval_bad.f.f_oop).parameters[2]) == ModelingToolkit
@test_throws KeyError ODEPrecompileTest.f_bad(u, p, 0.1)
@test_throws KeyError ODEPrecompileTest.f_noeval_bad(u, p, 0.1)

# This case works, because "f_good" gets defined in the precompiled module.
@test parentmodule(typeof(ODEPrecompileTest.f_good.f.f_iip).parameters[2]) == ODEPrecompileTest
@test parentmodule(typeof(ODEPrecompileTest.f_good.f.f_oop).parameters[2]) == ODEPrecompileTest
@test parentmodule(typeof(ODEPrecompileTest.f_noeval_good.f.f_iip).parameters[2]) == ODEPrecompileTest
@test parentmodule(typeof(ODEPrecompileTest.f_noeval_good.f.f_oop).parameters[2]) == ODEPrecompileTest
@test ODEPrecompileTest.f_good(u, p, 0.1) == [4, 0, -16]
@test ODEPrecompileTest.f_noeval_good(u, p, 0.1) == [4, 0, -16]
32 changes: 32 additions & 0 deletions test/precompile_test/ODEPrecompileTest.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
module ODEPrecompileTest
using ModelingToolkit

function system(; kwargs...)
# Define some variables
@parameters t σ ρ β
@variables x(t) y(t) z(t)
@derivatives D'~t

# Define a differential equation
eqs = [D(x) ~ σ*(y-x),
D(y) ~ x*(ρ-z)-y,
D(z) ~ x*y - β*z]

de = ODESystem(eqs)
return ODEFunction(de, [x,y,z], [σ,ρ,β]; kwargs...)
end

# Build an ODEFunction as part of the module's precompilation. This case
# will not work, because the generated RGFs will be put into
# ModelingToolkit's RGF cache.
const f_bad = system()

# This case will work, because it will be put into our own module's cache.
using RuntimeGeneratedFunctions
RuntimeGeneratedFunctions.init(@__MODULE__)
const f_good = system(; eval_module=@__MODULE__)

# Also test that eval_expression=false works
const f_noeval_bad = system(; eval_expression=false)
const f_noeval_good = system(; eval_expression=false, eval_module=@__MODULE__)
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,4 @@ using SafeTestsets, Test
println("Last test requires gcc available in the path!")
@safetestset "C Compilation Test" begin include("ccompile.jl") end
@safetestset "Latexify recipes Test" begin include("latexify.jl") end
@safetestset "Precompiled Modules Test" begin include("precompile_test.jl") end