diff --git a/Project.toml b/Project.toml index ff8d8ce886..9dd107c933 100644 --- a/Project.toml +++ b/Project.toml @@ -46,7 +46,7 @@ MacroTools = "0.5" NaNMath = "0.3" RecursiveArrayTools = "2.3" Requires = "1.0" -RuntimeGeneratedFunctions = "0.4" +RuntimeGeneratedFunctions = "0.4.3" SafeTestsets = "0.0.1" SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0" StaticArrays = "0.10, 0.11, 0.12, 1.0" diff --git a/src/build_function.jl b/src/build_function.jl index 752f7e69ac..dc6ff3642b 100644 --- a/src/build_function.jl +++ b/src/build_function.jl @@ -100,6 +100,7 @@ end # Scalar output function _build_function(target::JuliaTarget, op, args...; conv = toexpr, expression = Val{true}, + expression_module = @__MODULE__, checkbounds = false, linenumbers = true, headerfun=addheader) @@ -127,12 +128,12 @@ function _build_function(target::JuliaTarget, op, args...; if expression == Val{true} return ModelingToolkit.inject_registered_module_functions(oop_ex) else - _build_and_inject_function(@__MODULE__, oop_ex) + _build_and_inject_function(expression_module, oop_ex) end end function _build_and_inject_function(mod::Module, ex) - @RuntimeGeneratedFunction(ModelingToolkit.inject_registered_module_functions(ex)) + @RuntimeGeneratedFunction(mod, ModelingToolkit.inject_registered_module_functions(ex)) end # Detect heterogeneous element types of "arrays of matrices/sparce matrices" @@ -218,6 +219,7 @@ Special Keyword Argumnets: """ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...; conv = toexpr, expression = Val{true}, + expression_module = @__MODULE__, checkbounds = false, linenumbers = false, multithread=nothing, headerfun = addheader, outputidxs=nothing, @@ -457,7 +459,7 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...; if expression == Val{true} return ModelingToolkit.inject_registered_module_functions(oop_ex), ModelingToolkit.inject_registered_module_functions(iip_ex) else - return _build_and_inject_function(@__MODULE__, oop_ex), _build_and_inject_function(@__MODULE__, iip_ex) + return _build_and_inject_function(expression_module, oop_ex), _build_and_inject_function(expression_module, iip_ex) end end diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index 388ad9c3de..a38b2a343d 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -124,19 +124,20 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys), version = nothing, tgrad=false, jac = false, eval_expression = true, + eval_module = @__MODULE__, sparse = false, simplify = true, kwargs...) where {iip} - f_gen = generate_function(sys, dvs, ps; expression=Val{eval_expression}, kwargs...) - f_oop,f_iip = eval_expression ? (@RuntimeGeneratedFunction(ex) for ex in f_gen) : f_gen + f_gen = generate_function(sys, dvs, ps; expression=Val{eval_expression}, expression_module=eval_module, kwargs...) + f_oop,f_iip = eval_expression ? (@RuntimeGeneratedFunction(eval_module, ex) for ex in f_gen) : f_gen f(u,p,t) = f_oop(u,p,t) f(du,u,p,t) = f_iip(du,u,p,t) if tgrad tgrad_gen = generate_tgrad(sys, dvs, ps; simplify=simplify, - expression=Val{eval_expression}, kwargs...) - tgrad_oop,tgrad_iip = eval_expression ? (@RuntimeGeneratedFunction(ex) for ex in tgrad_gen) : tgrad_gen + expression=Val{eval_expression}, expression_module=eval_module, kwargs...) + tgrad_oop,tgrad_iip = eval_expression ? (@RuntimeGeneratedFunction(eval_module, ex) for ex in tgrad_gen) : tgrad_gen _tgrad(u,p,t) = tgrad_oop(u,p,t) _tgrad(J,u,p,t) = tgrad_iip(J,u,p,t) else @@ -146,8 +147,8 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys), if jac jac_gen = generate_jacobian(sys, dvs, ps; simplify=simplify, sparse = sparse, - expression=Val{eval_expression}, kwargs...) - jac_oop,jac_iip = eval_expression ? (@RuntimeGeneratedFunction(ex) for ex in jac_gen) : jac_gen + expression=Val{eval_expression}, expression_module=eval_module, kwargs...) + jac_oop,jac_iip = eval_expression ? (@RuntimeGeneratedFunction(eval_module, ex) for ex in jac_gen) : jac_gen _jac(u,p,t) = jac_oop(u,p,t) _jac(J,u,p,t) = jac_iip(J,u,p,t) else diff --git a/test/precompile_test.jl b/test/precompile_test.jl new file mode 100644 index 0000000000..61abebe55e --- /dev/null +++ b/test/precompile_test.jl @@ -0,0 +1,26 @@ +using Test +using ModelingToolkit + +# Test that the precompiled ODE system works +push!(LOAD_PATH, joinpath(@__DIR__, "precompile_test")) +using ODEPrecompileTest + +u = collect(1:3) +p = collect(4:6) + +# This case does not work, because "f_bad" gets defined in ModelingToolkit +# instead of in the compiled module! +@test parentmodule(typeof(ODEPrecompileTest.f_bad.f.f_iip).parameters[2]) == ModelingToolkit +@test parentmodule(typeof(ODEPrecompileTest.f_bad.f.f_oop).parameters[2]) == ModelingToolkit +@test parentmodule(typeof(ODEPrecompileTest.f_noeval_bad.f.f_iip).parameters[2]) == ModelingToolkit +@test parentmodule(typeof(ODEPrecompileTest.f_noeval_bad.f.f_oop).parameters[2]) == ModelingToolkit +@test_throws KeyError ODEPrecompileTest.f_bad(u, p, 0.1) +@test_throws KeyError ODEPrecompileTest.f_noeval_bad(u, p, 0.1) + +# This case works, because "f_good" gets defined in the precompiled module. +@test parentmodule(typeof(ODEPrecompileTest.f_good.f.f_iip).parameters[2]) == ODEPrecompileTest +@test parentmodule(typeof(ODEPrecompileTest.f_good.f.f_oop).parameters[2]) == ODEPrecompileTest +@test parentmodule(typeof(ODEPrecompileTest.f_noeval_good.f.f_iip).parameters[2]) == ODEPrecompileTest +@test parentmodule(typeof(ODEPrecompileTest.f_noeval_good.f.f_oop).parameters[2]) == ODEPrecompileTest +@test ODEPrecompileTest.f_good(u, p, 0.1) == [4, 0, -16] +@test ODEPrecompileTest.f_noeval_good(u, p, 0.1) == [4, 0, -16] \ No newline at end of file diff --git a/test/precompile_test/ODEPrecompileTest.jl b/test/precompile_test/ODEPrecompileTest.jl new file mode 100644 index 0000000000..453cb0d774 --- /dev/null +++ b/test/precompile_test/ODEPrecompileTest.jl @@ -0,0 +1,32 @@ +module ODEPrecompileTest + using ModelingToolkit + + function system(; kwargs...) + # Define some variables + @parameters t σ ρ β + @variables x(t) y(t) z(t) + @derivatives D'~t + + # Define a differential equation + eqs = [D(x) ~ σ*(y-x), + D(y) ~ x*(ρ-z)-y, + D(z) ~ x*y - β*z] + + de = ODESystem(eqs) + return ODEFunction(de, [x,y,z], [σ,ρ,β]; kwargs...) + end + + # Build an ODEFunction as part of the module's precompilation. This case + # will not work, because the generated RGFs will be put into + # ModelingToolkit's RGF cache. + const f_bad = system() + + # This case will work, because it will be put into our own module's cache. + using RuntimeGeneratedFunctions + RuntimeGeneratedFunctions.init(@__MODULE__) + const f_good = system(; eval_module=@__MODULE__) + + # Also test that eval_expression=false works + const f_noeval_bad = system(; eval_expression=false) + const f_noeval_good = system(; eval_expression=false, eval_module=@__MODULE__) +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 5ad02512c9..0049c05ec1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -33,3 +33,4 @@ using SafeTestsets, Test println("Last test requires gcc available in the path!") @safetestset "C Compilation Test" begin include("ccompile.jl") end @safetestset "Latexify recipes Test" begin include("latexify.jl") end +@safetestset "Precompiled Modules Test" begin include("precompile_test.jl") end