Skip to content

Commit

Permalink
Merge pull request #79 from JuliaDiffEq/hg/fix/constant
Browse files Browse the repository at this point in the history
Refactor Constant representation
  • Loading branch information
ChrisRackauckas committed Jan 11, 2019
2 parents bc7e230 + 9bff84f commit 0097578
Show file tree
Hide file tree
Showing 12 changed files with 87 additions and 120 deletions.
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -147,16 +147,12 @@ In this section we define the core pieces of the IR and what they mean.

### Variables

The simplest piece of the IR is the `Variable`. The `Variable` is the
The most fundamental part of the IR is the `Variable`. The `Variable` is the
context-aware single variable of the IR. Its fields are described as follows:

- `name`: the name of the `Variable`. Note that this is not necessarily
the same as the name of the Julia variable. But this symbol itself is considered
the core identifier of the `Variable` in the sense of equality.
- `value`: the value of the `Variable`. The meaning of the value can be
interpreted differently for different systems, but in most cases it's tied to
whatever value information would be required for the system to be well-defined
such as the initial condition of a differential equation.
- `value_type`: the type that the values have to be. It's disconnected
from the `value` because in many cases the `value` may not be able to be
specified in advance even when we may already know the type. This can be used
Expand All @@ -179,6 +175,10 @@ context-aware single variable of the IR. Its fields are described as follows:
- `context`: this is an open field for DSLs to carry along more context
in the variables, but is not used in the systems themselves.

### Constants

`Constant` is a simple wrapper type to store numerical Julia constants.

### Operations

Operations are the basic composition of variables and puts together the pieces
Expand Down
4 changes: 2 additions & 2 deletions src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ include("domains.jl")
include("variables.jl")

Base.promote_rule(::Type{T},::Type{T2}) where {T<:Number,T2<:Expression} = Expression
Base.one(::Type{T}) where T<:Expression = Constant(1)
Base.zero(::Type{T}) where T<:Expression = Constant(0)
Base.zero(::Type{<:Expression}) = Constant(0)
Base.one(::Type{<:Expression}) = Constant(1)
Base.convert(::Type{Variable},x::Int64) = Constant(x)

function caclulate_jacobian end
Expand Down
4 changes: 2 additions & 2 deletions src/differentials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ function (D::Differential)(x::Variable)
end
Base.:(==)(D1::Differential, D2::Differential) = D1.order == D2.order && D1.x == D2.x

Variable(x::Variable, D::Differential) = Variable(x.name,x.value,x.value_type,
Variable(x::Variable, D::Differential) = Variable(x.name,x.value_type,
x.subtype,D,x.dependents,x.description,x.flow,x.domain,
x.size,x.context)

Expand All @@ -31,7 +31,7 @@ function expand_derivatives(O::Operation)

return O
end
expand_derivatives(x::Variable) = x
expand_derivatives(x) = x

# Don't specialize on the function here
function Derivative(O::Operation,idx)
Expand Down
21 changes: 10 additions & 11 deletions src/operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ end
function Base.:(==)(x::Operation,y::Operation)
x.op == y.op && length(x.args) == length(y.args) && all(isequal.(x.args,y.args))
end
Base.:(==)(x::Operation,y::Number) = false
Base.:(==)(x::Number,y::Operation) = false
Base.:(==)(x::Operation,y::Nothing) = false
Base.:(==)(x::Nothing,y::Operation) = false
Base.:(==)(x::Variable,y::Operation) = false
Base.:(==)(x::Operation,y::Variable) = false
Base.:(==)(::Operation, ::Number ) = false
Base.:(==)(::Number , ::Operation) = false
Base.:(==)(::Operation, ::Variable ) = false
Base.:(==)(::Variable , ::Operation) = false
Base.:(==)(::Operation, ::Constant ) = false
Base.:(==)(::Constant , ::Operation) = false

Base.convert(::Type{Expr}, O::Operation) =
build_expr(:call, Any[Symbol(O.op); convert.(Expr, O.args)])
Expand All @@ -36,8 +36,7 @@ function find_replace!(O::Operation,x::Variable,y::Expression)
end

# For inv
Base.convert(::Type{Operation},x::Int) = Operation(identity,Expression[Constant(x)])
Base.convert(::Type{Operation},x::Bool) = Operation(identity,Expression[Constant(x)])
Base.convert(::Type{Operation},x::Variable) = Operation(identity,Expression[x])
Operation(x) = convert(Operation,x)
Operation(x::Operation) = x
Base.convert(::Type{Operation}, x::Number) = Operation(identity, Expression[Constant(x)])
Base.convert(::Type{Operation}, x::Operation) = x
Base.convert(::Type{Operation}, x::Expression) = Operation(identity, Expression[x])
Operation(x) = convert(Operation, x)
11 changes: 7 additions & 4 deletions src/simplify.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function simplify_constants(O::Operation, shorten_tree = true)
function simplify_constants(O::Operation, shorten_tree)
while true
O′ = _simplify_constants(O, shorten_tree)
if is_operation(O′)
Expand All @@ -8,10 +8,13 @@ function simplify_constants(O::Operation, shorten_tree = true)
O = O′
end
end
simplify_constants(x, shorten_tree) = x
simplify_constants(x) = simplify_constants(x, true)


const AC_OPERATORS = (*, +)

function _simplify_constants(O, shorten_tree = true)
function _simplify_constants(O::Operation, shorten_tree)
# Tree shrinking
if shorten_tree && O.op AC_OPERATORS
# Flatten tree
Expand Down Expand Up @@ -67,7 +70,7 @@ function _simplify_constants(O, shorten_tree = true)

return O
end
simplify_constants(x::Variable, y=false) = x
_simplify_constants(x::Variable, y=false) = x
_simplify_constants(x, shorten_tree) = x
_simplify_constants(x) = _simplify_constants(x, true)

export simplify_constants
7 changes: 3 additions & 4 deletions src/systems/diffeqs/first_order_transform.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
function lower_varname(var::Variable, naming_scheme; lower=false)
D = var.diff
D == nothing && return var
D === nothing && return var
order = lower ? D.order-1 : D.order
lower_varname(var.name, D.x, order, var.subtype, naming_scheme)
end
function lower_varname(sym::Symbol, idv, order::Int, subtype::Symbol, naming_scheme)
order == 0 && return Variable(sym, subtype)
name = Symbol(String(sym)*naming_scheme*String(idv.name)^order)
Variable(name, subtype=subtype)
name = order == 0 ? sym : Symbol(sym, naming_scheme, string(idv.name)^order)
return Variable(name, subtype=subtype)
end

function ode_order_lowering(sys::DiffEqSystem; kwargs...)
Expand Down
2 changes: 1 addition & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ end

toexpr(ex) = MacroTools.postwalk(x -> isa(x, Expression) ? convert(Expr, x) : x, ex)

is_constant(x::Variable) = x.subtype === :Constant
is_constant(::Constant) = true
is_constant(::Any) = false

is_operation(::Operation) = true
Expand Down
115 changes: 40 additions & 75 deletions src/variables.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
mutable struct Variable <: Expression
name::Symbol
value
value_type::DataType
subtype::Symbol
diff::Union{Function,Nothing} # FIXME
Expand All @@ -13,26 +12,23 @@ mutable struct Variable <: Expression
end

Variable(name,
value = nothing,
value_type = typeof(value);
value_type = Any;
subtype::Symbol=:Variable,
dependents::Vector{Variable} = Variable[],
flow::Bool = false,
description::String = "",
domain = Reals(),
size = nothing,
context = nothing) =
Variable(name,value,value_type,subtype,nothing,
Variable(name,value_type,subtype,nothing,
dependents,description,flow,domain,size,context)
Variable(name,args...;kwargs...) = Variable(name,args...;subtype=:Variable,kwargs...)

Variable(name,x::Variable) = Variable(name,x.value,x.value_type,
Variable(name,x::Variable) = Variable(name,x.value_type,
x.subtype,D,x.dependents,x.description,x.flow,x.domain,
x.size,x.context)

Parameter(name,args...;kwargs...) = Variable(name,args...;subtype=:Parameter,kwargs...)
Constant(value::Number) = Variable(Symbol(value),value,typeof(value);subtype=:Constant)
Constant(name,args...;kwargs...) = Variable(name,args...;subtype=:Constant,kwargs...)
IndependentVariable(name,args...;kwargs...) = Variable(name,args...;subtype=:IndependentVariable,kwargs...)

function DependentVariable(name,args...;dependents = [],kwargs...)
Expand Down Expand Up @@ -64,53 +60,38 @@ export Variable,Parameter,Constant,DependentVariable,IndependentVariable,JumpVar
@Var, @DVar, @IVar, @Param, @Const


Base.get(x::Variable) = x.value
struct Constant <: Expression
value::Number
end
Base.get(c::Constant) = c.value


Base.iszero(::Expression) = false
Base.iszero(c::Variable) = get(c) isa Number && iszero(get(c))
Base.isone(::Expression) = false
Base.isone(c::Variable) = get(c) isa Number && isone(get(c))
Base.iszero(ex::Expression) = isa(ex, Constant) && iszero(ex.value)
Base.isone(ex::Expression) = isa(ex, Constant) && isone(ex.value)


# Variables use isequal for equality since == is an Operation
function Base.:(==)(x::Variable,y::Variable)
x.name == y.name && x.subtype == y.subtype && x.value == y.value &&
function Base.:(==)(x::Variable, y::Variable)
x.name == y.name && x.subtype == y.subtype &&
x.value_type == y.value_type && x.diff == y.diff
end

function Base.:(==)(x::Variable,y::Number)
x == Constant(y)
end

function Base.:(==)(x::Number,y::Variable)
Constant(x) == y
end
Base.:(==)(::Variable, ::Number) = false
Base.:(==)(::Number, ::Variable) = false
Base.:(==)(::Variable, ::Constant) = false
Base.:(==)(::Constant, ::Variable) = false
Base.:(==)(c::Constant, n::Number) = c.value == n
Base.:(==)(n::Number, c::Constant) = c.value == n
Base.:(==)(a::Constant, b::Constant) = a.value == b.value

function Base.convert(::Type{Expr}, x::Variable)
if x.subtype == :Constant
return x.value
elseif x.diff == nothing
return :($(x.name))
else
return :($(Symbol("$(x.name)_$(x.diff.x.name)")))
end
x.diff === nothing && return x.name
return Symbol("$(x.name)_$(x.diff.x.name)")
end
Base.convert(::Type{Expr}, c::Constant) = c.value

function Base.show(io::IO, A::Variable)
if A.subtype == :Constant
print(io,"Constant($(A.value))")
else
str = "$(A.subtype)($(A.name))"
if A.value != nothing
str *= ", value = " * string(A.value)
end

if A.diff != nothing
str *= ", diff = " * string(A.diff)
end

print(io,str)
end
function Base.show(io::IO, x::Variable)
print(io, x.subtype, '(', x.name, ')')
x.diff === nothing || print(io, ", diff = ", x.diff)
end

extract_idv(eq) = eq.args[1].diff.x
Expand Down Expand Up @@ -159,45 +140,29 @@ function _parse_vars(macroname, fun, x)
# begin
# x
# y
# z = exp(2)
# z
# end
x = flatten_expr!(x)
for _var in x
iscall = typeof(_var) <: Expr && _var.head == :call
issym = _var isa Symbol
isassign = issym ? false : _var.head == :(=)
@assert iscall || issym || isassign "@$macroname expects a tuple of expressions!\nE.g. `@$macroname x y z=1`"
if iscall || issym
if iscall
dependents = :([$(_var.args[2:end]...)])
var = _var.args[1]
else
dependents = Variable[]
var = _var
end
lhs = var
push!(lhss, lhs)
expr = :( $lhs = $fun( Symbol($(String(lhs))) ,
dependents = $dependents))
end
if isassign
iscall = typeof(_var.args[1]) <: Expr && _var.args[1].head == :call
if iscall
dependents = :([$(_var.args[1].args[2:end]...)])
lhs = _var.args[1].args[1]
else
dependents = Variable[]
lhs = _var.args[1]
end
rhs = _var.args[2]
push!(lhss, lhs)
expr = :( $lhs = $fun( Symbol($(String(lhs))) , $rhs,
dependents = $dependents))
@assert iscall || issym "@$macroname expects a tuple of expressions!\nE.g. `@$macroname x y z`"

if iscall
dependents = :([$(_var.args[2:end]...)])
lhs = _var.args[1]
else
dependents = Variable[]
lhs = _var
end

push!(lhss, lhs)
expr = :( $lhs = $fun( Symbol($(String(lhs))) ,
dependents = $dependents))
push!(ex.args, expr)
end
push!(ex.args, Expr(:tuple, lhss...))
ex
push!(ex.args, build_expr(:tuple, lhss))
return ex
end

for funs in ((:DVar, :DependentVariable), (:IVar, :IndependentVariable),
Expand Down
8 changes: 4 additions & 4 deletions test/basic_variables_and_operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ using Test
@Const c=0

# Default values
p = Parameter(:p, 1)
u = DependentVariable(:u, [1], dependents = [t])
p = Parameter(:p)
u = DependentVariable(:u, dependents = [t])

s = JumpVariable(:s,3,dependents=[t])
n = NoiseVariable(:n,dependents=[t])
s = JumpVariable(:s, dependents=[t])
n = NoiseVariable(:n, dependents=[t])

σ*(y-x)
D(x)
Expand Down
2 changes: 1 addition & 1 deletion test/derivatives.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ expand_derivatives(dsin)

@test expand_derivatives(dsin) == cos(t)
dcsch = D(csch(t))
@test expand_derivatives(dcsch) == simplify_constants(Operation(coth(t)*csch(t)*-1))
@test expand_derivatives(dcsch) == simplify_constants(coth(t) * csch(t) * -1)

# Chain rule
dsinsin = D(sin(sin(t)))
Expand Down
5 changes: 3 additions & 2 deletions test/system_construction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,15 @@ function test_eqs(eqs1, eqs2)
eq = true
for i in eachindex(eqs1)
lhs1, lhs2 = eqs1[i].args[1], eqs2[i].args[1]
typeof(lhs1) === typeof(lhs2) || return false
for f in fieldnames(typeof(lhs1))
eq = eq && isequal(getfield(lhs1, f), getfield(lhs2, f))
end
eq = eq && isequal(eqs1[i].args[2], eqs2[i].args[2])
end
@test_broken eq
eq
end
test_eqs(de1.eqs, lowered_eqs)
@test_broken test_eqs(de1.eqs, lowered_eqs)

# Internal calculations
eqs = [a ~ y-x,
Expand Down

0 comments on commit 0097578

Please sign in to comment.