Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SymbolicUtils integration #326

Merged
merged 8 commits into from Apr 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Expand Up @@ -18,6 +18,7 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
TreeViews = "a2a6695c-b41b-5b7d-aed9-dbfdeacea5d7"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
Expand All @@ -34,6 +35,7 @@ NaNMath = "0.3"
SafeTestsets = "0.0.1"
SpecialFunctions = "0.7, 0.8, 0.9, 0.10"
StaticArrays = "0.10, 0.11, 0.12"
SymbolicUtils = "0.1.1"
TreeViews = "0.3"
UnPack = "0.1"
Unitful = "1.1"
Expand Down
2 changes: 1 addition & 1 deletion src/differentials.jl
Expand Up @@ -52,7 +52,7 @@ function expand_derivatives(O::Operation)
end |> simplify_constants
end

return O
return simplify_constants(O)
end
expand_derivatives(x) = x

Expand Down
125 changes: 40 additions & 85 deletions src/simplify.jl
@@ -1,97 +1,52 @@
function simplify_constants(O::Operation, shorten_tree)
while true
O′ = _simplify_constants(O, shorten_tree)
if is_operation(O′)
O′ = Operation(O′.op, simplify_constants.(O′.args, shorten_tree))
end
isequal(O, O′) && return O
O = O′
import SymbolicUtils
import SymbolicUtils: FnType

# ModelingToolkit -> SymbolicUtils
SymbolicUtils.istree(x::Operation) = true
function SymbolicUtils.operation(x::Operation)
if x.op isa Variable
T = FnType{NTuple{length(x.args), Any}, vartype(x.op)}
SymbolicUtils.Variable{T}(x.op.name)
else
x.op
end
end
simplify_constants(x, shorten_tree) = x

"""
simplify_constants(x::Operation)
# This is required to infer the right type for
# Operation(Variable{Parameter{Number}}(:foo), [])
# While keeping the metadata that the variable is a parameter.
SymbolicUtils.promote_symtype(f::SymbolicUtils.Sym{FnType{X,Parameter{Y}}},
xs...) where {X, Y} = Y

Simplifies the constants within an expression, for example removing equations
multiplied by a zero and summing constant values.
"""
simplify_constants(x) = simplify_constants(x, true)
SymbolicUtils.arguments(x::Operation) = x.args

Base.isone(x::Operation) = x.op == one || x.op == Constant && isone(x.args)
const AC_OPERATORS = (*, +)
# SymbolicUtils wants raw numbers
SymbolicUtils.to_symbolic(x::Constant) = x.value
SymbolicUtils.to_symbolic(x::Variable{T}) where {T} = SymbolicUtils.Sym{T}(x.name)

function _simplify_constants(O::Operation, shorten_tree)
# Tree shrinking
if shorten_tree && O.op ∈ AC_OPERATORS
# Flatten tree
idxs = findall(x -> is_operation(x) && x.op === O.op, O.args)
if !isempty(idxs)
keep_idxs = eachindex(O.args) .∉ (idxs,)
args = Vector{Expression}[O.args[i].args for i in idxs]
push!(args, O.args[keep_idxs])
return Operation(O.op, vcat(args...))
end
# Optional types of vars
# Once converted to SymbolicUtils Variable, a Parameter needs to hide its metadata
_vartype(x::Variable{<:Parameter{T}}) where {T} = T
_vartype(x::Variable{T}) where {T} = T
SymbolicUtils.symtype(x::Variable) = _vartype(x) # needed for a()
SymbolicUtils.symtype(x::SymbolicUtils.Sym{<:Parameter{T}}) where {T} = T

# Collapse constants
idxs = findall(is_constant, O.args)
if length(idxs) > 1
other_idxs = eachindex(O.args) .∉ (idxs,)
new_const = Constant(mapreduce(get, O.op, O.args[idxs]))
args = push!(O.args[other_idxs], new_const)
# returning Any causes SymbolicUtils to infer the type using `promote_symtype`
# But we are OK with Number here for now I guess
SymbolicUtils.symtype(x::Expression) = Number

length(args) == 1 && return first(args)
return Operation(O.op, args)
end
end

if O.op === (*)
# If any variable is `Constant(0)`, zero the whole thing
any(iszero, O.args) && return Constant(0)

# If any variable is `Constant(1)`, remove that `Constant(1)` unless
# they are all `Constant(1)`, in which case simplify to a single variable
if any(isone, O.args)
args = filter(!isone, O.args)

isempty(args) && return Constant(1)
length(args) == 1 && return first(args)
return Operation(O.op, args)
end

return O
end

if O.op === (^) && length(O.args) == 2 && iszero(O.args[2])
return Constant(1)
end

if O.op === (^) && length(O.args) == 2 && isone(O.args[2])
return O.args[1]
end

if O.op === (+) && any(iszero, O.args)
# If there are Constant(0)s in a big `+` expression, get rid of them
args = filter(!iszero, O.args)

isempty(args) && return Constant(0)
length(args) == 1 && return first(args)
return Operation(O.op, args)
end
# SymbolicUtils -> ModelingToolkit

if (O.op === (-) || O.op === (+) || O.op === (*)) && all(is_constant, O.args) && !isempty(O.args)
v = O.args[1].value
for i in 2:length(O.args)
v = O.op(v, O.args[i].value)
end
return Constant(v)
end

(O.op, length(O.args)) === (identity, 1) && return O.args[1]

(O.op, length(O.args)) === (-, 1) && return Operation(*, Expression[-1, O.args[1]])
function simplify_constants(expr)
SymbolicUtils.simplify(expr) |> to_mtk
end

return O
to_mtk(x) = x
to_mtk(x::Number) = Constant(x)
to_mtk(v::SymbolicUtils.Sym{T}) where {T} = Variable{T}(nameof(v))
to_mtk(v::SymbolicUtils.Sym{FnType{X,Y}}) where {X,Y} = Variable{Y}(nameof(v))
function to_mtk(expr::SymbolicUtils.Term)
Operation(to_mtk(SymbolicUtils.operation(expr)),
map(to_mtk, SymbolicUtils.arguments(expr)))
end
_simplify_constants(x, shorten_tree) = x
_simplify_constants(x) = _simplify_constants(x, true)
2 changes: 2 additions & 0 deletions src/systems/diffeqs/abstractodesystem.jl
Expand Up @@ -114,6 +114,8 @@ function calculate_massmatrix(sys::AbstractODESystem, simplify=true)
end
end
M = simplify ? simplify_constants.(M) : M
# M should only contain concrete numbers
M = map(x->x isa Constant ? x.value : x, M)
M == I ? I : M
end

Expand Down
30 changes: 16 additions & 14 deletions test/derivatives.jl
Expand Up @@ -6,6 +6,8 @@ using Test
@variables x y z
@derivatives D'~t D2''~t Dx'~x

test_equal(a, b) = @test isequal(simplify_constants(a), simplify_constants(b))

@test @macroexpand(@derivatives D'~t D2''~t) == @macroexpand(@derivatives (D'~t), (D2''~t))

@test isequal(expand_derivatives(D(t)), 1)
Expand All @@ -15,12 +17,12 @@ dsin = D(sin(t))
@test isequal(expand_derivatives(dsin), cos(t))

dcsch = D(csch(t))
@test isequal(expand_derivatives(dcsch), simplify_constants(coth(t) * csch(t) * -1))
@test isequal(expand_derivatives(dcsch), simplify_constants(-coth(t) * csch(t)))

@test isequal(expand_derivatives(D(-7)), 0)
@test isequal(expand_derivatives(D(sin(2t))), simplify_constants(cos(2t) * 2))
@test isequal(expand_derivatives(D2(sin(t))), simplify_constants(-sin(t)))
@test isequal(expand_derivatives(D2(sin(2t))), simplify_constants(sin(2t) * -4))
@test isequal(expand_derivatives(D2(sin(2t))), simplify_constants(-sin(2t) * 4))
@test isequal(expand_derivatives(D2(t)), 0)
@test isequal(expand_derivatives(D2(5)), 0)

Expand All @@ -30,23 +32,23 @@ dsinsin = D(sin(sin(t)))

d1 = D(sin(t)*t)
d2 = D(sin(t)*cos(t))
@test isequal(expand_derivatives(d1), t*cos(t)+sin(t))
@test isequal(expand_derivatives(d2), simplify_constants(cos(t)*cos(t)+(sin(t)*-1)*sin(t)))
@test isequal(expand_derivatives(d1), simplify_constants(t*cos(t)+sin(t)))
@test isequal(expand_derivatives(d2), simplify_constants(cos(t)*cos(t)+(-sin(t))*sin(t)))

eqs = [0 ~ σ*(y-x),
0 ~ x*(ρ-z)-y,
0 ~ x*y - β*z]
sys = NonlinearSystem(eqs, [x,y,z], [σ,ρ,β])
jac = calculate_jacobian(sys)
@test isequal(jac[1,1], σ*-1)
@test isequal(jac[1,2], σ)
@test isequal(jac[1,3], 0)
@test isequal(jac[2,1], ρ-z)
@test isequal(jac[2,2], -1)
@test isequal(jac[2,3], x*-1)
@test isequal(jac[3,1], y)
@test isequal(jac[3,2], x)
@test isequal(jac[3,3], -1*β)
test_equal(jac[1,1], -1σ)
test_equal(jac[1,2], σ)
test_equal(jac[1,3], 0)
test_equal(jac[2,1], ρ - z)
test_equal(jac[2,2], -1)
test_equal(jac[2,3], -1x)
test_equal(jac[3,1], y)
test_equal(jac[3,2], x)
test_equal(jac[3,3], -)

# Variable dependence checking in differentiation
@variables a(t) b(a)
Expand All @@ -57,7 +59,7 @@ jac = calculate_jacobian(sys)
@variables x(t) y(t) z(t)

@test isequal(expand_derivatives(D(x * y)), simplify_constants(y*D(x) + x*D(y)))
@test_broken isequal(expand_derivatives(D(x * y)), simplify_constants(D(x)*y + x*D(y)))
@test isequal(expand_derivatives(D(x * y)), simplify_constants(D(x)*y + x*D(y)))

@test isequal(expand_derivatives(D(2t)), 2)
@test isequal(expand_derivatives(D(2x)), 2D(x))
Expand Down
8 changes: 5 additions & 3 deletions test/direct.jl
Expand Up @@ -2,6 +2,8 @@ using ModelingToolkit, StaticArrays, LinearAlgebra, SparseArrays
using DiffEqBase
using Test

canonequal(a, b) = isequal(simplify_constants(a), simplify_constants(b))

# Calculus
@parameters t σ ρ β
@variables x y z
Expand All @@ -24,11 +26,11 @@ end
∂ = ModelingToolkit.jacobian(eqs,[x,y,z])
for i in 1:3
∇ = ModelingToolkit.gradient(eqs[i],[x,y,z])
@test isequal(∂[i,:],∇)
@test canonequal(∂[i,:],∇)
end

@test all(isequal.(ModelingToolkit.gradient(eqs[1],[x,y,z]),[σ * -1,σ,0]))
@test all(isequal.(ModelingToolkit.hessian(eqs[1],[x,y,z]),0))
@test all(canonequal.(ModelingToolkit.gradient(eqs[1],[x,y,z]),[σ * -1,σ,0]))
@test all(canonequal.(ModelingToolkit.hessian(eqs[1],[x,y,z]),0))

Joop,Jiip = eval.(ModelingToolkit.build_function(∂,[x,y,z],[σ,ρ,β],t))
J = Joop([1.0,2.0,3.0],[1.0,2.0,3.0],1.0)
Expand Down
6 changes: 3 additions & 3 deletions test/mass_matrix.jl
Expand Up @@ -10,9 +10,9 @@ eqs = [D(y[1]) ~ -k[1]*y[1] + k[3]*y[2]*y[3],

sys = ODESystem(eqs,t,y,k)
M = calculate_massmatrix(sys)
M == [1 0 0
0 1 0
0 0 0]
@test M == [1 0 0
0 1 0
0 0 0]

f = ODEFunction(sys)
prob_mm = ODEProblem(f,[1.0,0.0,0.0],(0.0,1e5),(0.04,3e7,1e4))
Expand Down
20 changes: 11 additions & 9 deletions test/nonlinearsystem.jl
Expand Up @@ -2,6 +2,8 @@ using ModelingToolkit, StaticArrays, LinearAlgebra
using DiffEqBase
using Test

canonequal(a, b) = isequal(simplify_constants(a), simplify_constants(b))

# Define some variables
@parameters t σ ρ β
@variables x y z
Expand Down Expand Up @@ -37,15 +39,15 @@ eqs = [0 ~ σ*(y-x),
ns = NonlinearSystem(eqs, [x,y,z], [σ,ρ,β])
jac = calculate_jacobian(ns)
@testset "nlsys jacobian" begin
@test isequal(jac[1,1], σ * -1)
@test isequal(jac[1,2], σ)
@test isequal(jac[1,3], 0)
@test isequal(jac[2,1], ρ - z)
@test isequal(jac[2,2], -1)
@test isequal(jac[2,3], x * -1)
@test isequal(jac[3,1], y)
@test isequal(jac[3,2], x)
@test isequal(jac[3,3], -1 * β)
@test canonequal(jac[1,1], σ * -1)
@test canonequal(jac[1,2], σ)
@test canonequal(jac[1,3], 0)
@test canonequal(jac[2,1], ρ - z)
@test canonequal(jac[2,2], -1)
@test canonequal(jac[2,3], x * -1)
@test canonequal(jac[3,1], y)
@test canonequal(jac[3,2], x)
@test canonequal(jac[3,3], -1 * β)
end
nlsys_func = generate_function(ns, [x,y,z], [σ,ρ,β])
jac_func = generate_jacobian(ns)
Expand Down
11 changes: 6 additions & 5 deletions test/simplify.jl
Expand Up @@ -14,17 +14,18 @@ identity_op = Operation(identity,[x])
@test isequal(simplify_constants(identity_op), x)

minus_op = -x
@test isequal(simplify_constants(minus_op), -1*x)
@test isequal(simplify_constants(minus_op), -x)
simplify_constants(minus_op)

@variables x

@test simplified_expr(expand_derivatives(Differential(x)((x-2)^2))) == :((x-2) * 2)
@test simplified_expr(expand_derivatives(Differential(x)((x-2)^3))) == :((x-2)^2 * 3)
@test simplified_expr(simplify_constants(x+2+3)) == :(x + 5)
@test simplified_expr(expand_derivatives(Differential(x)((x-2)^2))) == :(2 * (-2 + x))
@test simplified_expr(expand_derivatives(Differential(x)((x-2)^3))) == :(3 * (-2 + x)^2)
@test simplified_expr(simplify_constants(x+2+3)) == :(5 + x)

d1 = Differential(x)((x-2)^2)
d1 = Differential(x)((-2 + x)^2)
d2 = Differential(x)(d1)
d3 = Differential(x)(d2)

@test simplified_expr(expand_derivatives(d3)) == :(0)
@test simplified_expr(simplify_constants(x^0)) == :(1)