diff --git a/src/systems/diffeqs/diffeqsystem.jl b/src/systems/diffeqs/diffeqsystem.jl index 24204c4832..37db56e41b 100644 --- a/src/systems/diffeqs/diffeqsystem.jl +++ b/src/systems/diffeqs/diffeqsystem.jl @@ -76,31 +76,29 @@ function generate_function(sys::DiffEqSystem; version::FunctionVersion = ArrayFu end -function generate_ode_iW(sys::DiffEqSystem, simplify=true; version::FunctionVersion = ArrayFunction) +function generate_factorized_W(sys::DiffEqSystem, simplify=true; version::FunctionVersion = ArrayFunction) jac = calculate_jacobian(sys) gam = Variable(:gam; known = true) W = LinearAlgebra.I - gam*jac - W = SMatrix{size(W,1),size(W,2)}(W) - iW = inv(W) + Wfact = lu(W, Val(false), check=false).factors if simplify - iW = simplify_constants.(iW) + Wfact = simplify_constants.(Wfact) end - W = inv(LinearAlgebra.I/gam - jac) - W = SMatrix{size(W,1),size(W,2)}(W) - iW_t = inv(W) + W_t = LinearAlgebra.I/gam - jac + Wfact_t = lu(W_t, Val(false), check=false).factors if simplify - iW_t = simplify_constants.(iW_t) + Wfact_t = simplify_constants.(Wfact_t) end vs, ps = sys.dvs, sys.ps - iW_func = build_function(iW , vs, ps, (:gam,:t); version = version) - iW_t_func = build_function(iW_t, vs, ps, (:gam,:t); version = version) + Wfact_func = build_function(Wfact , vs, ps, (:gam,:t); version = version) + Wfact_t_func = build_function(Wfact_t, vs, ps, (:gam,:t); version = version) - return (iW_func, iW_t_func) + return (Wfact_func, Wfact_t_func) end function DiffEqBase.ODEFunction(sys::DiffEqSystem; version::FunctionVersion = ArrayFunction) diff --git a/test/system_construction.jl b/test/system_construction.jl index 6f7b106504..dc103a9e17 100644 --- a/test/system_construction.jl +++ b/test/system_construction.jl @@ -16,7 +16,15 @@ generate_function(de;version=ModelingToolkit.SArrayFunction) jac_expr = generate_jacobian(de) jac = calculate_jacobian(de) f = ODEFunction(de) -ModelingToolkit.generate_ode_iW(de) +fw, fwt = map(eval, ModelingToolkit.generate_factorized_W(de)) +du = zeros(3) +u = collect(1:3) +p = collect(4:6) +f(du, u, p, 0.1) +@test du == [4, 0, -16] +FW = zeros(3, 3) +fw(FW, u, p, 0.2, 0.1) +fwt(FW, u, p, 0.2, 0.1) # Differential equation with automatic extraction of variables de2 = DiffEqSystem(eqs, t)