Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Constant representation #79

Merged
merged 5 commits into from
Jan 11, 2019
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ include("domains.jl")
include("variables.jl")

Base.promote_rule(::Type{T},::Type{T2}) where {T<:Number,T2<:Expression} = Expression
Base.one(::Type{T}) where T<:Expression = Constant(1)
Base.zero(::Type{T}) where T<:Expression = Constant(0)
Base.zero(::Type{<:Expression}) = Constant(0)
Base.one(::Type{<:Expression}) = Constant(1)
Base.convert(::Type{Variable},x::Int64) = Constant(x)

function caclulate_jacobian end
Expand Down
4 changes: 2 additions & 2 deletions src/differentials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ function (D::Differential)(x::Variable)
end
Base.:(==)(D1::Differential, D2::Differential) = D1.order == D2.order && D1.x == D2.x

Variable(x::Variable, D::Differential) = Variable(x.name,x.value,x.value_type,
Variable(x::Variable, D::Differential) = Variable(x.name,x.value_type,
x.subtype,D,x.dependents,x.description,x.flow,x.domain,
x.size,x.context)

Expand All @@ -31,7 +31,7 @@ function expand_derivatives(O::Operation)

return O
end
expand_derivatives(x::Variable) = x
expand_derivatives(x) = x

# Don't specialize on the function here
function Derivative(O::Operation,idx)
Expand Down
22 changes: 11 additions & 11 deletions src/operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ end
function Base.:(==)(x::Operation,y::Operation)
x.op == y.op && length(x.args) == length(y.args) && all(isequal.(x.args,y.args))
end
Base.:(==)(x::Operation,y::Number) = false
Base.:(==)(x::Number,y::Operation) = false
Base.:(==)(x::Operation,y::Nothing) = false
Base.:(==)(x::Nothing,y::Operation) = false
Base.:(==)(x::Variable,y::Operation) = false
Base.:(==)(x::Operation,y::Variable) = false
Base.:(==)(::Operation, ::Number ) = false
Base.:(==)(::Number , ::Operation) = false
Base.:(==)(::Operation, ::Variable ) = false
Base.:(==)(::Variable , ::Operation) = false
Base.:(==)(::Operation, ::Constant ) = false
Base.:(==)(::Constant , ::Operation) = false

Base.convert(::Type{Expr}, O::Operation) =
build_expr(:call, Any[Symbol(O.op); convert.(Expr, O.args)])
Expand All @@ -36,8 +36,8 @@ function find_replace!(O::Operation,x::Variable,y::Expression)
end

# For inv
Base.convert(::Type{Operation},x::Int) = Operation(identity,Expression[Constant(x)])
Base.convert(::Type{Operation},x::Bool) = Operation(identity,Expression[Constant(x)])
Base.convert(::Type{Operation},x::Variable) = Operation(identity,Expression[x])
Operation(x) = convert(Operation,x)
Operation(x::Operation) = x
Base.convert(::Type{Operation}, x::Int) = Operation(identity, Expression[Constant(x)])
Base.convert(::Type{Operation}, x::Bool) = Operation(identity, Expression[Constant(x)])
Base.convert(::Type{Operation}, x::Operation) = x
Base.convert(::Type{Operation}, x::Expression) = Operation(identity, Expression[x])
Operation(x) = convert(Operation, x)
11 changes: 7 additions & 4 deletions src/simplify.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function simplify_constants(O::Operation, shorten_tree = true)
function simplify_constants(O::Operation, shorten_tree)
while true
O′ = _simplify_constants(O, shorten_tree)
if is_operation(O′)
Expand All @@ -8,10 +8,13 @@ function simplify_constants(O::Operation, shorten_tree = true)
O = O′
end
end
simplify_constants(x, shorten_tree) = x
simplify_constants(x) = simplify_constants(x, true)


const AC_OPERATORS = (*, +)

function _simplify_constants(O, shorten_tree = true)
function _simplify_constants(O::Operation, shorten_tree)
# Tree shrinking
if shorten_tree && O.op ∈ AC_OPERATORS
# Flatten tree
Expand Down Expand Up @@ -67,7 +70,7 @@ function _simplify_constants(O, shorten_tree = true)

return O
end
simplify_constants(x::Variable, y=false) = x
_simplify_constants(x::Variable, y=false) = x
_simplify_constants(x, shorten_tree) = x
_simplify_constants(x) = _simplify_constants(x, true)

export simplify_constants
2 changes: 1 addition & 1 deletion src/systems/diffeqs/first_order_transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ function lower_varname(var::Variable, naming_scheme; lower=false)
lower_varname(var.name, D.x, order, var.subtype, naming_scheme)
end
function lower_varname(sym::Symbol, idv, order::Int, subtype::Symbol, naming_scheme)
order == 0 && return Variable(sym, subtype)
order == 0 && return Variable(sym, subtype=subtype)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What a subtle bug!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does that fix it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep - I also changed this function afterwards, to make the goal more explicit.

name = Symbol(String(sym)*naming_scheme*String(idv.name)^order)
Variable(name, subtype=subtype)
end
Expand Down
2 changes: 1 addition & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ end

toexpr(ex) = MacroTools.postwalk(x -> isa(x, Expression) ? convert(Expr, x) : x, ex)

is_constant(x::Variable) = x.subtype === :Constant
is_constant(::Constant) = true
is_constant(::Any) = false

is_operation(::Operation) = true
Expand Down
115 changes: 40 additions & 75 deletions src/variables.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
mutable struct Variable <: Expression
name::Symbol
value
value_type::DataType
subtype::Symbol
diff::Union{Function,Nothing} # FIXME
Expand All @@ -13,26 +12,23 @@ mutable struct Variable <: Expression
end

Variable(name,
value = nothing,
value_type = typeof(value);
value_type = Any;
subtype::Symbol=:Variable,
dependents::Vector{Variable} = Variable[],
flow::Bool = false,
description::String = "",
domain = Reals(),
size = nothing,
context = nothing) =
Variable(name,value,value_type,subtype,nothing,
Variable(name,value_type,subtype,nothing,
dependents,description,flow,domain,size,context)
Variable(name,args...;kwargs...) = Variable(name,args...;subtype=:Variable,kwargs...)

Variable(name,x::Variable) = Variable(name,x.value,x.value_type,
Variable(name,x::Variable) = Variable(name,x.value_type,
x.subtype,D,x.dependents,x.description,x.flow,x.domain,
x.size,x.context)

Parameter(name,args...;kwargs...) = Variable(name,args...;subtype=:Parameter,kwargs...)
Constant(value::Number) = Variable(Symbol(value),value,typeof(value);subtype=:Constant)
Constant(name,args...;kwargs...) = Variable(name,args...;subtype=:Constant,kwargs...)
IndependentVariable(name,args...;kwargs...) = Variable(name,args...;subtype=:IndependentVariable,kwargs...)

function DependentVariable(name,args...;dependents = [],kwargs...)
Expand Down Expand Up @@ -64,53 +60,38 @@ export Variable,Parameter,Constant,DependentVariable,IndependentVariable,JumpVar
@Var, @DVar, @IVar, @Param, @Const


Base.get(x::Variable) = x.value
struct Constant <: Expression
value::Number
end
Base.get(c::Constant) = c.value


Base.iszero(::Expression) = false
Base.iszero(c::Variable) = get(c) isa Number && iszero(get(c))
Base.isone(::Expression) = false
Base.isone(c::Variable) = get(c) isa Number && isone(get(c))
Base.iszero(ex::Expression) = isa(ex, Constant) && iszero(ex.value)
Base.isone(ex::Expression) = isa(ex, Constant) && isone(ex.value)


# Variables use isequal for equality since == is an Operation
function Base.:(==)(x::Variable,y::Variable)
x.name == y.name && x.subtype == y.subtype && x.value == y.value &&
function Base.:(==)(x::Variable, y::Variable)
x.name == y.name && x.subtype == y.subtype &&
x.value_type == y.value_type && x.diff == y.diff
end

function Base.:(==)(x::Variable,y::Number)
x == Constant(y)
end

function Base.:(==)(x::Number,y::Variable)
Constant(x) == y
end
Base.:(==)(::Variable, ::Number) = false
Base.:(==)(::Number, ::Variable) = false
Base.:(==)(::Variable, ::Constant) = false
Base.:(==)(::Constant, ::Variable) = false
Base.:(==)(c::Constant, n::Number) = c.value == n
Base.:(==)(n::Number, c::Constant) = c.value == n
Base.:(==)(a::Constant, b::Constant) = a.value == b.value

function Base.convert(::Type{Expr}, x::Variable)
if x.subtype == :Constant
return x.value
elseif x.diff == nothing
return :($(x.name))
else
return :($(Symbol("$(x.name)_$(x.diff.x.name)")))
end
x.diff === nothing && return x.name
return Symbol("$(x.name)_$(x.diff.x.name)")
end
Base.convert(::Type{Expr}, c::Constant) = c.value

function Base.show(io::IO, A::Variable)
if A.subtype == :Constant
print(io,"Constant($(A.value))")
else
str = "$(A.subtype)($(A.name))"
if A.value != nothing
str *= ", value = " * string(A.value)
end

if A.diff != nothing
str *= ", diff = " * string(A.diff)
end

print(io,str)
end
function Base.show(io::IO, x::Variable)
print(io, x.subtype, '(', x.name, ')')
x.diff === nothing || print(io, ", diff = ", x.diff)
end

extract_idv(eq) = eq.args[1].diff.x
Expand Down Expand Up @@ -159,45 +140,29 @@ function _parse_vars(macroname, fun, x)
# begin
# x
# y
# z = exp(2)
# z
# end
x = flatten_expr!(x)
for _var in x
iscall = typeof(_var) <: Expr && _var.head == :call
issym = _var isa Symbol
isassign = issym ? false : _var.head == :(=)
@assert iscall || issym || isassign "@$macroname expects a tuple of expressions!\nE.g. `@$macroname x y z=1`"
if iscall || issym
if iscall
dependents = :([$(_var.args[2:end]...)])
var = _var.args[1]
else
dependents = Variable[]
var = _var
end
lhs = var
push!(lhss, lhs)
expr = :( $lhs = $fun( Symbol($(String(lhs))) ,
dependents = $dependents))
end
if isassign
iscall = typeof(_var.args[1]) <: Expr && _var.args[1].head == :call
if iscall
dependents = :([$(_var.args[1].args[2:end]...)])
lhs = _var.args[1].args[1]
else
dependents = Variable[]
lhs = _var.args[1]
end
rhs = _var.args[2]
push!(lhss, lhs)
expr = :( $lhs = $fun( Symbol($(String(lhs))) , $rhs,
dependents = $dependents))
@assert iscall || issym "@$macroname expects a tuple of expressions!\nE.g. `@$macroname x y z`"

if iscall
dependents = :([$(_var.args[2:end]...)])
lhs = _var.args[1]
else
dependents = Variable[]
lhs = _var
end

push!(lhss, lhs)
expr = :( $lhs = $fun( Symbol($(String(lhs))) ,
dependents = $dependents))
push!(ex.args, expr)
end
push!(ex.args, Expr(:tuple, lhss...))
ex
push!(ex.args, build_expr(:tuple, lhss))
return ex
end

for funs in ((:DVar, :DependentVariable), (:IVar, :IndependentVariable),
Expand Down
8 changes: 4 additions & 4 deletions test/basic_variables_and_operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ using Test
@Const c=0

# Default values
p = Parameter(:p, 1)
u = DependentVariable(:u, [1], dependents = [t])
p = Parameter(:p)
u = DependentVariable(:u, dependents = [t])

s = JumpVariable(:s,3,dependents=[t])
n = NoiseVariable(:n,dependents=[t])
s = JumpVariable(:s, dependents=[t])
n = NoiseVariable(:n, dependents=[t])

σ*(y-x)
D(x)
Expand Down
2 changes: 1 addition & 1 deletion test/derivatives.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ expand_derivatives(dsin)

@test expand_derivatives(dsin) == cos(t)
dcsch = D(csch(t))
@test expand_derivatives(dcsch) == simplify_constants(Operation(coth(t)*csch(t)*-1))
@test expand_derivatives(dcsch) == simplify_constants(coth(t) * csch(t) * -1)

# Chain rule
dsinsin = D(sin(sin(t)))
Expand Down
5 changes: 3 additions & 2 deletions test/system_construction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,15 @@ function test_eqs(eqs1, eqs2)
eq = true
for i in eachindex(eqs1)
lhs1, lhs2 = eqs1[i].args[1], eqs2[i].args[1]
typeof(lhs1) === typeof(lhs2) || return false
for f in fieldnames(typeof(lhs1))
eq = eq && isequal(getfield(lhs1, f), getfield(lhs2, f))
end
eq = eq && isequal(eqs1[i].args[2], eqs2[i].args[2])
end
@test_broken eq
eq
end
test_eqs(de1.eqs, lowered_eqs)
@test_broken test_eqs(de1.eqs, lowered_eqs)

# Internal calculations
eqs = [a ~ y-x,
Expand Down
18 changes: 9 additions & 9 deletions test/variable_parsing.jl
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
using ModelingToolkit
using Test

@Var a=1.0 b
a1 = Variable(:a,1.0)
@Var a b
a1 = Variable(:a)
@test a1 == a
@test convert(Expr, a) == :a

@Var begin
a = 1.0
a
b
end

@IVar t
@DVar x(t)
@DVar y(t)=sin(1)+exp(1)
@DVar y(t)
@DVar z(t)
x1 = DependentVariable(:x,dependents = [t])
y1 = DependentVariable(:y, sin(1) + exp(1),dependents = [t])
z1 = DependentVariable(:z,dependents = [t])
x1 = DependentVariable(:x ,dependents = [t])
y1 = DependentVariable(:y ,dependents = [t])
z1 = DependentVariable(:z ,dependents = [t])
@test x1 == x
@test y1 == y
@test z1 == z
Expand All @@ -27,10 +27,10 @@ z1 = DependentVariable(:z,dependents = [t])

@IVar begin
t
s = cos(2.5)
s
end
t1 = IndependentVariable(:t)
s1 = IndependentVariable(:s, cos(2.5))
s1 = IndependentVariable(:s)
@test t1 == t
@test s1 == s
@test convert(Expr, t) == :t
Expand Down