Skip to content

Commit cb25390

Browse files
fix: bug fixes
1 parent 1211cb8 commit cb25390

File tree

3 files changed

+30
-10
lines changed

3 files changed

+30
-10
lines changed

ext/MTKHomotopyContinuationExt.jl

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ using HomotopyContinuation
88
using ModelingToolkit: iscomplete, parameters, has_index_cache, get_index_cache, get_u0,
99
get_u0_p, check_eqs_u0, CommonSolve
1010

11+
const MTK = ModelingToolkit
12+
1113
function contains_variable(x, wrt)
1214
any(isequal(x), wrt) && return true
1315
istree(x) || return false
@@ -19,6 +21,7 @@ function is_polynomial(x, wrt)
1921
symbolic_type(x) == NotSymbolic() && return true
2022
istree(x) || return true
2123
contains_variable(x, wrt) || return true
24+
any(isequal(x), wrt) && return true
2225

2326
if operation(x) in (*, +, -)
2427
return all(y -> is_polynomial(y, wrt), arguments(x))
@@ -69,8 +72,10 @@ function ModelKit.evaluate_and_jacobian!(u, U, sys::MTKHomotopySystem, x, p = no
6972
sys.jac(U, x, sys.p)
7073
end
7174

72-
function ModelingToolkit.HomotopyContinuationProblem(
73-
sys::NonlinearSystem, u0map, parammap; compile = :all, kwargs...)
75+
SymbolicIndexingInterface.parameter_values(s::MTKHomotopySystem) = s.p
76+
77+
function MTK.HomotopyContinuationProblem(
78+
sys::NonlinearSystem, u0map, parammap; compile = :all, eval_expression = false, eval_module = ModelingToolkit, kwargs...)
7479
if !iscomplete(sys)
7580
error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `HomotopyContinuationProblem`")
7681
end
@@ -85,24 +90,29 @@ function ModelingToolkit.HomotopyContinuationProblem(
8590
end
8691
end
8792

88-
nlfn = NonlinearFunction(sys; jac = true)
93+
nlfn = NonlinearFunction(sys; jac = true, eval_expression, eval_module)
8994
hvars = symbolics_to_hc.(dvs)
9095

96+
u0map = MTK.todict(u0map)
97+
parammap = MTK.todict(parammap)
98+
9199
if has_index_cache(sys) && get_index_cache(sys) !== nothing
92100
u0, defs = get_u0(sys, u0map, parammap)
93101
check_eqs_u0(eqs, dvs, u0; kwargs...)
94102
p = MTKParameters(sys, parammap, u0map)
95103
else
96-
u0, p, defs = get_u0_p(sys, u0map, parammap; tofloat, use_union)
104+
u0, p, defs = get_u0_p(sys, u0map, parammap)
97105
check_eqs_u0(eqs, dvs, u0; kwargs...)
98106
end
99107

100108
mtkhsys = MTKHomotopySystem(nlfn.f, p, nlfn.jac, hvars, length(eqs))
101109

102-
return ModelingToolkit.HomotopyContinuationProblem(u0, mtkhsys, sys)
110+
obsfn = MTK.ObservedFunctionCache(sys; eval_expression, eval_module)
111+
112+
return MTK.HomotopyContinuationProblem(u0, mtkhsys, sys, obsfn)
103113
end
104114

105-
function CommonSolve.solve(prob::ModelingToolkit.HomotopyContinuationProblem; kwargs...)
115+
function CommonSolve.solve(prob::MTK.HomotopyContinuationProblem; kwargs...)
106116
sol = HomotopyContinuation.solve(prob.homotopy_continuation_system; kwargs...)
107117
realsols = HomotopyContinuation.results(sol; only_real = true)
108118
if isempty(realsols)
@@ -118,7 +128,7 @@ function CommonSolve.solve(prob::ModelingToolkit.HomotopyContinuationProblem; kw
118128
retcode = SciMLBase.ReturnCode.Success
119129
end
120130

121-
return SciMLBase.build_solution(prob, :HomotopyContinuation, u, resid; retcode)
131+
return SciMLBase.build_solution(prob, :HomotopyContinuation, u, resid; retcode, original = sol)
122132
end
123133

124134
end

src/ModelingToolkit.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,4 +281,6 @@ export Clock, SolverStepClock, TimeDomain
281281

282282
export MTKParameters, reorder_dimension_by_tunables!, reorder_dimension_by_tunables
283283

284+
export HomotopyContinuationProblem
285+
284286
end # module

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -599,10 +599,11 @@ function Base.:(==)(sys1::NonlinearSystem, sys2::NonlinearSystem)
599599
all(s1 == s2 for (s1, s2) in zip(get_systems(sys1), get_systems(sys2)))
600600
end
601601

602-
struct HomotopyContinuationProblem{uType, H} <: SciMLBase.AbstractNonlinearProblem{uType, true}
602+
struct HomotopyContinuationProblem{uType, H, O} <: SciMLBase.AbstractNonlinearProblem{uType, true}
603603
u0::uType
604604
homotopy_continuation_system::H
605605
sys::NonlinearSystem
606+
obsfn::O
606607
end
607608

608609
function HomotopyContinuationProblem(args...; kwargs...)
@@ -615,8 +616,15 @@ function SymbolicIndexingInterface.set_state!(p::HomotopyContinuationProblem, ar
615616
set_state!(p.u0, args...)
616617
end
617618
function SymbolicIndexingInterface.parameter_values(p::HomotopyContinuationProblem)
618-
parameter_values(p.hcsys)
619+
parameter_values(p.homotopy_continuation_system)
619620
end
620621
function SymbolicIndexingInterface.set_parameter!(p::HomotopyContinuationProblem, args...)
621-
set_parameter!(p.hcsys, args...)
622+
set_parameter!(parameter_values(p), args...)
623+
end
624+
function SymbolicIndexingInterface.observed(p::HomotopyContinuationProblem, sym)
625+
if p.obsfn !== nothing
626+
return p.obsfn(sym)
627+
else
628+
return SymbolicIndexingInterface.observed(p.sys, sym)
629+
end
622630
end

0 commit comments

Comments
 (0)