Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 21 additions & 98 deletions src/TermInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module TermInterface
"""
istree(x)

Returns `true` if `x` is a term. If true, `operation`, `arguments`
Returns `true` if `x` is a term. If true, `head` and `children`
must also be defined for `x` appropriately.
"""
istree(x) = false
Expand All @@ -12,14 +12,12 @@ export istree
"""
symtype(x)

Returns the symbolic type of `x`. By default this is just `typeof(x)`.
Returns the symbolic type of `x`. By default this is just `Any`.
Define this for your symbolic types if you want `SymbolicUtils.simplify` to apply rules
specific to numbers (such as commutativity of multiplication). Or such
rules that may be implemented in the future.
"""
function symtype(x)
typeof(x)
end
symtype(x) = Any
export symtype

"""
Expand All @@ -31,20 +29,6 @@ on `x` and must return a Symbol.
issym(x) = false
export issym

"""
exprhead(x)

If `x` is a term as defined by `istree(x)`, `exprhead(x)` must return a symbol,
corresponding to the head of the `Expr` most similar to the term `x`.
If `x` represents a function call, for example, the `exprhead` is `:call`.
If `x` represents an indexing operation, such as `arr[i]`, then `exprhead` is `:ref`.
Note that `exprhead` is different from `operation` and both functions should
be defined correctly in order to let other packages provide code generation
and pattern matching features.
"""
function exprhead end
export exprhead

"""
head(x)

Expand All @@ -54,41 +38,37 @@ term if `x`. The `head` type has to be provided by the package.
function head end
export head

"""
head_symbol(x::HeadType)

If `x` is a head object, `head_symbol(T, x)` returns a `Symbol` object that
corresponds to `y.head` if `y` was the representation of the corresponding term
as a Julia Expression. This is useful to define interoperability between
symbolic term types defined in different packages and should be used when
calling `maketerm`.
"""
function head_symbol end
export head_symbol

"""
children(x)

Get the arguments of `x`, must be defined if `istree(x)` is `true`.
Get the children of `x`, must be defined if `istree(x)` is `true`.
"""
function children end
export children

"""
is_function_call(x)

Return true if `x` is a term as defined by `istree(x)` and corresponds to a
function call. If true, `operation` and `arguments` must be defined
appropriately.
"""
is_function_call(x) = false

"""
operation(x)

If `x` is a term as defined by `istree(x)`, `operation(x)` returns the
operation of the term if `x` represents a function call, for example, the head
is the function being called.
If `x` is a function call as defined by `is_function_call(x)`, `operation(x)`
returns the function being called.
"""
function operation end
export operation

"""
arguments(x)

Get the arguments of `x`, must be defined if `istree(x)` is `true`.
If `x` is a function call as defined by `is_function_call(x)`, `arguments(x)`
returns the arguments on which the function is called.
"""
function arguments end
export arguments
Expand Down Expand Up @@ -137,74 +117,17 @@ end


"""
maketerm(head::H, children; type=Any, metadata=nothing)
maketerm(head::H, children; symtype=Any, metadata=nothing)

Has to be implemented by the provider of H.
Returns a term that is in the same closure of types as `typeof(x)`,
with `head` as the head and `children` as the arguments, `type` as the symtype
and `metadata` as the metadata.
Has to be implemented by the provider of H. Returns a term that is in the same
closure of types as `H`, with `head` as the head and `children` as the children,
`symtype` as the symtype and `metadata` as the metadata.
"""
function maketerm end
export maketerm

"""
is_operation(f)

Returns a single argument anonymous function predicate, that returns `true` if and only if
the argument to the predicate satisfies `istree` and `operation(x) == f`
"""
is_operation(f) = @nospecialize(x) -> istree(x) && (operation(x) == f)
export is_operation


"""
node_count(t)
Count the nodes in a symbolic expression tree satisfying `istree` and `arguments`.
"""
node_count(t) = istree(t) ? reduce(+, node_count(x) for x in arguments(t), init = 0) + 1 : 1
export node_count

include("utils.jl")
include("expr.jl")

"""
@matchable struct Foo fields... end [HeadType]

Take a struct definition and automatically define `TermInterface` methods. This
will automatically define a head type. If `HeadType` is given then it will be
used as `head(::Foo)`. If it is omitted, and the struct is called `Foo`, then
the head type will be called `FooHead`. The `head_symbol` of such head types
will default to `:call`.
"""
macro matchable(expr, head_name=nothing)
@assert expr.head == :struct
name = expr.args[2]
if name isa Expr
name.head === :(<:) && (name = name.args[1])
name isa Expr && name.head === :curly && (name = name.args[1])
end
fields = filter(x -> x isa Symbol || (x isa Expr && x.head == :(::)), expr.args[3].args)
get_name(s::Symbol) = s
get_name(e::Expr) = (@assert(e.head == :(::)); e.args[1])
fields = map(get_name, fields)
head_name = isnothing(head_name) ? Symbol(name, :Head) : head_name

quote
$expr
struct $head_name
head
end
TermInterface.head_symbol(x::$head_name) = x.head
# TODO default to call?
TermInterface.head(::$name) = $head_name(:call)
TermInterface.istree(::$name) = true
TermInterface.operation(::$name) = $name
TermInterface.arguments(x::$name) = getfield.((x,), ($(QuoteNode.(fields)...),))
TermInterface.children(x::$name) = [operation(x); arguments(x)...]
TermInterface.arity(x::$name) = $(length(fields))
Base.length(x::$name) = $(length(fields) + 1)
end |> esc
end
export @matchable

end # module

8 changes: 4 additions & 4 deletions src/expr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,20 @@ struct ExprHead
end
export ExprHead

head_symbol(eh::ExprHead) = eh.head

istree(x::Expr) = true
head(e::Expr) = ExprHead(e.head)
children(e::Expr) = e.args

is_function_call(e::Expr) = head(e).head in (:call, :macrocall)

# See https://docs.julialang.org/en/v1/devdocs/ast/
function operation(e::Expr)
h = head(e)
hh = h.head
if hh in (:call, :macrocall)
e.args[1]
else
hh
throw(ArgumentError("Not a function call"))
end
end

Expand All @@ -29,7 +29,7 @@ function arguments(e::Expr)
if hh in (:call, :macrocall)
e.args[2:end]
else
e.args
throw(ArgumentError("Not a function call"))
end
end

Expand Down
16 changes: 16 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
"""
is_head(f)

Returns a single argument anonymous function predicate, that returns `true` if and only if
the argument to the predicate satisfies `istree` and `head(x) == f`
"""
is_head(f) = @nospecialize(x) -> istree(x) && (head(x) == f)
export is_head


"""
node_count(t)
Count the nodes in a symbolic expression tree satisfying `istree` and `arguments`.
"""
node_count(t) = istree(t) ? reduce(+, node_count(x) for x in children(t), init = 0) + 1 : 1
export node_count
36 changes: 8 additions & 28 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,29 @@ using TermInterface, Test
@testset "Expr" begin
ex = :(f(a, b))
@test head(ex) == ExprHead(:call)
@test is_function_call(ex) == true
@test children(ex) == [:f, :a, :b]
@test operation(ex) == :f
@test arguments(ex) == [:a, :b]
@test ex == maketerm(ExprHead(:call), [:f, :a, :b])

ex = :(arr[i, j])
@test head(ex) == ExprHead(:ref)
@test operation(ex) == :ref
@test arguments(ex) == [:arr, :i, :j]
@test children(ex) == [:arr, :i, :j]
@test is_function_call(ex) == false
@test ex == maketerm(ExprHead(:ref), [:arr, :i, :j])


ex = :(i, j)
@test head(ex) == ExprHead(:tuple)
@test operation(ex) == :tuple
@test arguments(ex) == [:i, :j]
@test children(ex) == [:i, :j]
@test is_function_call(ex) == false
@test ex == maketerm(ExprHead(:tuple), [:i, :j])


ex = Expr(:block, :a, :b, :c)
@test head(ex) == ExprHead(:block)
@test operation(ex) == :block
@test children(ex) == arguments(ex) == [:a, :b, :c]
@test children(ex) == [:a, :b, :c]
@test is_function_call(ex) == false
@test ex == maketerm(ExprHead(:block), [:a, :b, :c])
end

Expand All @@ -40,31 +39,12 @@ end
end
TermInterface.head(::Foo) = FooHead(:call)
TermInterface.head_symbol(q::FooHead) = q.head
TermInterface.operation(::Foo) = Foo
TermInterface.istree(::Foo) = true
TermInterface.arguments(x::Foo) = [x.args...]
TermInterface.children(x::Foo) = [operation(x); x.args...]
TermInterface.children(x::Foo) = x.args

t = Foo(1, 2)
@test head(t) == FooHead(:call)
@test head_symbol(head(t)) == :call
@test operation(t) == Foo
@test istree(t) == true
@test arguments(t) == [1, 2]
@test children(t) == [Foo, 1, 2]
end

@testset "Automatically Generated Methods" begin
@matchable struct Bar
a
b::Int
end

t = Bar(1, 2)
@test head(t) == BarHead(:call)
@test head_symbol(head(t)) == :call
@test operation(t) == Bar
@test istree(t) == true
@test arguments(t) == (1, 2)
@test children(t) == [Bar, 1, 2]
@test children(t) == [1, 2]
end