diff --git a/src/systems/clock_inference.jl b/src/systems/clock_inference.jl index ee9e92ad03..76c91f61c6 100644 --- a/src/systems/clock_inference.jl +++ b/src/systems/clock_inference.jl @@ -285,6 +285,7 @@ function is_time_domain_conversion(v) o isa Operator || return false itd = input_timedomain(o) allequal(itd) || return true + isempty(itd) && return true otd = output_timedomain(o) itd[1] == otd || return true return false diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index b5066a3e98..c867306c34 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -456,7 +456,8 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true) if !symbolic_contains(v, dvs) isvalid = iscall(v) && - (operation(v) isa Shift || is_transparent_operator(operation(v))) + (operation(v) isa Shift || isempty(arguments(v)) || + is_transparent_operator(operation(v))) v′ = v while !isvalid && iscall(v′) && operation(v′) isa Union{Differential, Shift} v′ = arguments(v′)[1] diff --git a/src/utils.jl b/src/utils.jl index 0da7e4860b..355c4858d2 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -633,15 +633,21 @@ can be checked using `check_scope_depth`. This function should return `nothing`. """ function collect_vars!(unknowns, parameters, expr, iv; depth = 0, op = Symbolics.Operator) + expr = unwrap(expr) if issym(expr) return collect_var!(unknowns, parameters, expr, iv; depth) end - for var in vars(expr; op) - while iscall(var) && operation(var) isa op - validate_operator(operation(var), arguments(var), iv; context = expr) - var = arguments(var)[1] + varsbuf = OrderedSet() + vars!(varsbuf, expr; op) + for var in varsbuf + if iscall(var) && operation(var) isa op + args = arguments(var) + validate_operator(operation(var), args, iv; context = expr) + isempty(args) && continue + push!(varsbuf, args[1]) + else + collect_var!(unknowns, parameters, var, iv; depth) end - collect_var!(unknowns, parameters, var, iv; depth) end return nothing end @@ -1184,4 +1190,4 @@ function wrap_with_D(n, D, repeats) else wrap_with_D(D(n), D, repeats - 1) end -end \ No newline at end of file +end diff --git a/test/clock.jl b/test/clock.jl index bb16884fe7..6eaed52b73 100644 --- a/test/clock.jl +++ b/test/clock.jl @@ -1,6 +1,7 @@ using ModelingToolkit, Test, Setfield, OrdinaryDiffEq, DiffEqCallbacks using ModelingToolkit: ContinuousClock using ModelingToolkit: t_nounits as t, D_nounits as D +using Symbolics, SymbolicUtils function infer_clocks(sys) ts = TearingState(sys) @@ -146,6 +147,29 @@ eqs = [yd ~ Sample(dt)(y) @test varmap[z] == clk end +struct ZeroArgOp <: Symbolics.Operator end +(o::ZeroArgOp)() = Symbolics.Term{Bool}(o, Any[]) +SymbolicUtils.promote_symtype(::ZeroArgOp, T) = Union{Bool, T} +SymbolicUtils.isbinop(::ZeroArgOp) = false +Base.nameof(::ZeroArgOp) = :ZeroArgOp +ModelingToolkit.input_timedomain(::ZeroArgOp, _ = nothing) = () +ModelingToolkit.output_timedomain(::ZeroArgOp, _ = nothing) = Clock(0.1) +ModelingToolkit.validate_operator(::ZeroArgOp, args, iv; context = nothing) = nothing +SciMLBase.is_discrete_time_domain(::ZeroArgOp) = true + +@testset "Zero-argument clock operators" begin + @variables x(t) y(t) + clk = Clock(0.1) + eqs = [D(x) ~ x + y ~ ZeroArgOp()()] + @named sys = System(eqs, t) + @test issetequal(unknowns(sys), [x, y]) + ts = TearingState(sys) + @test issetequal(ts.fullvars, [D(x), x, y, ZeroArgOp()()]) + ci, clkmap = infer_clocks(sys) + @test clkmap[ZeroArgOp()()] == clk +end + @test_skip begin Tf = 1.0 prob = ODEProblem(