diff --git a/src/differentials.jl b/src/differentials.jl index c725f0360d..08bd80de3d 100644 --- a/src/differentials.jl +++ b/src/differentials.jl @@ -1,10 +1,8 @@ struct Differential <: Function x::Expression - order::Int end -Differential(x) = Differential(x,1) -Base.show(io::IO, D::Differential) = print(io,"($(D.x),$(D.order))") +Base.show(io::IO, D::Differential) = print(io, "(D'~", D.x, ")") Base.convert(::Type{Expr}, D::Differential) = D (D::Differential)(x::Operation) = Operation(D, Expression[x]) @@ -13,8 +11,8 @@ function (D::Differential)(x::Variable) has_dependent(x, D.x) || return Constant(0) return Operation(D, Expression[x]) end -(::Differential)(::Constant) = Constant(0) -Base.:(==)(D1::Differential, D2::Differential) = D1.order == D2.order && D1.x == D2.x +(::Differential)(::Any) = Constant(0) +Base.:(==)(D1::Differential, D2::Differential) = D1.x == D2.x function expand_derivatives(O::Operation) @. O.args = expand_derivatives(O.args) @@ -56,6 +54,7 @@ function count_order(x) n, x.args[1] end +_repeat_apply(f, n) = n == 1 ? f : f ∘ _repeat_apply(f, n-1) function _differential_macro(x) ex = Expr(:block) lhss = Symbol[] @@ -66,7 +65,7 @@ function _differential_macro(x) rhs = di.args[3] order, lhs = count_order(lhs) push!(lhss, lhs) - expr = :($lhs = Differential($rhs, $order)) + expr = :($lhs = $_repeat_apply(Differential($rhs), $order)) push!(ex.args, expr) end push!(ex.args, Expr(:tuple, lhss...)) diff --git a/src/systems/diffeqs/diffeqsystem.jl b/src/systems/diffeqs/diffeqsystem.jl index a3becf8d02..ad1e35393c 100644 --- a/src/systems/diffeqs/diffeqsystem.jl +++ b/src/systems/diffeqs/diffeqsystem.jl @@ -6,17 +6,28 @@ using Base: RefValue isintermediate(eq::Equation) = !(isa(eq.lhs, Operation) && isa(eq.lhs.op, Differential)) -struct DiffEq # D(x) = t - D::Differential # D - var::Variable # x - rhs::Expression # t +function flatten_differential(O::Operation) + @assert is_derivative(O) "invalid differential: $O" + is_derivative(O.args[1]) || return (O.args[1], O.op.x, 1) + (x, t, order) = flatten_differential(O.args[1]) + t == O.op.x || throw(ArgumentError("non-matching differentials on lhs: $t, $(O.op.x)")) + return (x, t, order + 1) +end + + +struct DiffEq # dⁿx/dtⁿ = rhs + x::Expression + t::Variable + n::Int + rhs::Expression end function Base.convert(::Type{DiffEq}, eq::Equation) isintermediate(eq) && throw(ArgumentError("intermediate equation received")) - return DiffEq(eq.lhs.op, eq.lhs.args[1], eq.rhs) + (x, t, n) = flatten_differential(eq.lhs) + return DiffEq(x, t, n, eq.rhs) end -Base.:(==)(a::DiffEq, b::DiffEq) = (a.D, a.var, a.rhs) == (b.D, b.var, b.rhs) -get_args(eq::DiffEq) = Expression[eq.var, eq.rhs] +Base.:(==)(a::DiffEq, b::DiffEq) = (a.x, a.t, a.n, a.rhs) == (b.x, b.t, b.n, b.rhs) +get_args(eq::DiffEq) = Expression[eq.x, eq.t, eq.rhs] struct DiffEqSystem <: AbstractSystem eqs::Vector{DiffEq} diff --git a/src/systems/diffeqs/first_order_transform.jl b/src/systems/diffeqs/first_order_transform.jl index b2ed598195..706de49f9a 100644 --- a/src/systems/diffeqs/first_order_transform.jl +++ b/src/systems/diffeqs/first_order_transform.jl @@ -1,33 +1,27 @@ -function lower_varname(D::Differential, x; lower=false) - order = lower ? D.order-1 : D.order - return lower_varname(x, D.x, order) -end -function lower_varname(var::Variable, idv, order::Int) - sym = var.name - name = order == 0 ? sym : Symbol(sym, :_, string(idv.name)^order) +function lower_varname(var::Variable, idv, order) + order == 0 && return var + name = Symbol(var.name, :_, string(idv.name)^order) return Variable(name, var.known, var.dependents) end function ode_order_lowering(sys::DiffEqSystem) eqs_lowered = ode_order_lowering(sys.eqs, sys.iv) - DiffEqSystem(eqs_lowered, sys.iv) + DiffEqSystem(eqs_lowered, sys.iv, sys.dvs, sys.ps) end function ode_order_lowering(eqs, iv) - D = Differential(iv, 1) var_order = Dict{Variable,Int}() vars = Variable[] new_eqs = similar(eqs, DiffEq) for (i, eq) ∈ enumerate(eqs) - var, maxorder = eq.var, eq.D.order - maxorder == 1 && continue # fast pass + var, maxorder = eq.x, eq.n if maxorder > get(var_order, var, 0) var_order[var] = maxorder var ∈ vars || push!(vars, var) end - var′ = lower_varname(eq.D, eq.var, lower = true) + var′ = lower_varname(eq.x, eq.t, eq.n - 1) rhs′ = rename(eq.rhs) - new_eqs[i] = DiffEq(D, var′, rhs′) + new_eqs[i] = DiffEq(var′, iv, 1, rhs′) end for var ∈ vars @@ -35,7 +29,7 @@ function ode_order_lowering(eqs, iv) for o in (order-1):-1:1 lvar = lower_varname(var, iv, o-1) rhs = lower_varname(var, iv, o) - eq = DiffEq(D, lvar, rhs) + eq = DiffEq(lvar, iv, 1, rhs) push!(new_eqs, eq) end end @@ -45,7 +39,10 @@ end function rename(O::Expression) isa(O, Operation) || return O - isa(O.op, Differential) && return lower_varname(O.op, O.args[1]) + if is_derivative(O) + (x, t, order) = flatten_differential(O) + return lower_varname(x, t, order) + end return Operation(O.op, rename.(O.args)) end diff --git a/src/utils.jl b/src/utils.jl index d9f2935c5d..f5e83aeae3 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -60,6 +60,9 @@ is_constant(::Any) = false is_operation(::Operation) = true is_operation(::Any) = false +is_derivative(O::Operation) = isa(O.op, Differential) +is_derivative(::Any) = false + has_dependent(t::Variable) = Base.Fix2(has_dependent, t) has_dependent(x::Variable, t::Variable) = t ∈ x.dependents || any(has_dependent(t), x.dependents) diff --git a/src/variables.jl b/src/variables.jl index 473a35409b..32f4a9ff79 100644 --- a/src/variables.jl +++ b/src/variables.jl @@ -38,12 +38,7 @@ function Base.convert(::Type{Expr}, x::Variable) end Base.convert(::Type{Expr}, c::Constant) = c.value -function Base.show(io::IO, x::Variable) - subtype = x.known ? :Parameter : :Unknown - print(io, subtype, '(', repr(x.name)) - isempty(x.dependents) || print(io, ", ", x.dependents) - print(io, ')') -end +Base.show(io::IO, x::Variable) = print(io, x.name) # Build variables more easily function _parse_vars(macroname, fun, x) diff --git a/test/derivatives.jl b/test/derivatives.jl index b9b346311d..768fd80b14 100644 --- a/test/derivatives.jl +++ b/test/derivatives.jl @@ -15,6 +15,13 @@ dsin = D(sin(t)) dcsch = D(csch(t)) @test expand_derivatives(dcsch) == simplify_constants(coth(t) * csch(t) * -1) +@test expand_derivatives(D(-7)) == 0 +@test expand_derivatives(D(sin(2t))) == simplify_constants(cos(2t) * 2) +@test expand_derivatives(D2(sin(t))) == simplify_constants(-sin(t)) +@test expand_derivatives(D2(sin(2t))) == simplify_constants(sin(2t) * -4) +@test expand_derivatives(D2(t)) == 0 +@test expand_derivatives(D2(5)) == 0 + # Chain rule dsinsin = D(sin(sin(t))) @test expand_derivatives(dsinsin) == cos(sin(t))*cos(t) diff --git a/test/variable_parsing.jl b/test/variable_parsing.jl index 95080830a2..96ba18aecd 100644 --- a/test/variable_parsing.jl +++ b/test/variable_parsing.jl @@ -27,7 +27,7 @@ s1 = Parameter(:s) @test convert(Expr, s) == :s @test convert(Expr, cos(t + sin(s))) == :(cos(t + sin(s))) -@Deriv D''~t -D1 = Differential(t, 2) +@Deriv D'~t +D1 = Differential(t) @test D1 == D @test convert(Expr, D) == D