Skip to content

Commit

Permalink
introduce cache array to make discrete-time state update atomic
Browse files Browse the repository at this point in the history
closes #2701
  • Loading branch information
baggepinnen committed May 27, 2024
1 parent e1befe0 commit 61b8ea9
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 3 deletions.
10 changes: 7 additions & 3 deletions src/systems/clock_inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,8 @@ function generate_discrete_affect(
end)
end

cache = copy(p[$disc_range]) # Cache needed for atomic state update

# @show disc_to_cont_idxs
# @show cont_to_disc_idxs
# @show disc_range
Expand Down Expand Up @@ -368,14 +370,16 @@ function generate_discrete_affect(
if use_index_cache
quote
if !$empty_disc
disc(disc_unknowns, integrator.u, p..., t)
for (val, i) in zip(disc_unknowns, $disc_range)
# NOTE: the first and third arguments to `disc` MAY NOT be aliased
disc(cache, integrator.u, p..., t) # Cache needed for atomic state update
for (val, i) in zip(cache, $disc_range)
$(_set_parameter_unchecked!)(p, val, i; update_dependent = false)
end
end
end
else
:($empty_disc || disc(disc_unknowns, disc_unknowns, p, t))
:($empty_disc || disc(cache, disc_unknowns, p, t)) # Cache needed for atomic state update
copyto!(disc_unknowns, cache)
end
)
# @show "after state update", p
Expand Down
29 changes: 29 additions & 0 deletions test/clock.jl
Original file line number Diff line number Diff line change
Expand Up @@ -528,3 +528,32 @@ prob = ODEProblem(sys, [], (0.0, 10.0), [x(k - 1) => 2.0])
int = init(prob, Tsit5(); kwargshandle = KeywordArgSilent)
@test int.ps[x] == 3.0
@test int.ps[x(k - 1)] == 2.0

## Atomic state upadte tested using simple delay https://github.com/SciML/ModelingToolkit.jl/issues/2701

using ModelingToolkit: t_nounits as t, D_nounits as D
##
k = ShiftIndex(Clock(t, 1))
@mtkmodel DelayModel begin
@variables begin
input(t) = 0
delay(t) = 0
x(t) = 0
end
@structural_parameters begin
d
end
@equations begin
input ~ (t >= 2)
delay(k) ~ input(k - d)
D(x) ~ (-x + Hold(delay)) / 1e-3
end
end

for d in 0:3
@mtkbuild m = DelayModel(; d)
prob = ODEProblem(
m, [m.delay(k - 3) => 0, m.delay(k - 2) => 0, m.delay(k - 1) => 0], (0.0, 10.0))
sol = solve(prob, Tsit5(), kwargshandle = KeywordArgSilent, dtmax = 0.5)
@test reduce(vcat, sol((0:1:10) .+ 0.1)[:])[zeros(2 + d); ones(10 - 1 - d)] atol=1e-3
end

0 comments on commit 61b8ea9

Please sign in to comment.