diff --git a/src/TermInterface.jl b/src/TermInterface.jl index f22ca5c..3d94ec9 100644 --- a/src/TermInterface.jl +++ b/src/TermInterface.jl @@ -3,7 +3,7 @@ module TermInterface """ istree(x) -Returns `true` if `x` is a term. If true, `operation`, `arguments` +Returns `true` if `x` is a term. If true, `head` and `children` must also be defined for `x` appropriately. """ istree(x) = false @@ -12,14 +12,12 @@ export istree """ symtype(x) -Returns the symbolic type of `x`. By default this is just `typeof(x)`. +Returns the symbolic type of `x`. By default this is just `Any`. Define this for your symbolic types if you want `SymbolicUtils.simplify` to apply rules specific to numbers (such as commutativity of multiplication). Or such rules that may be implemented in the future. """ -function symtype(x) - typeof(x) -end +symtype(x) = Any export symtype """ @@ -31,20 +29,6 @@ on `x` and must return a Symbol. issym(x) = false export issym -""" - exprhead(x) - -If `x` is a term as defined by `istree(x)`, `exprhead(x)` must return a symbol, -corresponding to the head of the `Expr` most similar to the term `x`. -If `x` represents a function call, for example, the `exprhead` is `:call`. -If `x` represents an indexing operation, such as `arr[i]`, then `exprhead` is `:ref`. -Note that `exprhead` is different from `operation` and both functions should -be defined correctly in order to let other packages provide code generation -and pattern matching features. -""" -function exprhead end -export exprhead - """ head(x) @@ -54,33 +38,28 @@ term if `x`. The `head` type has to be provided by the package. function head end export head -""" - head_symbol(x::HeadType) - -If `x` is a head object, `head_symbol(T, x)` returns a `Symbol` object that -corresponds to `y.head` if `y` was the representation of the corresponding term -as a Julia Expression. This is useful to define interoperability between -symbolic term types defined in different packages and should be used when -calling `maketerm`. -""" -function head_symbol end -export head_symbol - """ children(x) -Get the arguments of `x`, must be defined if `istree(x)` is `true`. +Get the children of `x`, must be defined if `istree(x)` is `true`. """ function children end export children +""" + is_function_call(x) + +Return true if `x` is a term as defined by `istree(x)` and corresponds to a +function call. If true, `operation` and `arguments` must be defined +appropriately. +""" +is_function_call(x) = false """ operation(x) -If `x` is a term as defined by `istree(x)`, `operation(x)` returns the -operation of the term if `x` represents a function call, for example, the head -is the function being called. +If `x` is a function call as defined by `is_function_call(x)`, `operation(x)` +returns the function being called. """ function operation end export operation @@ -88,7 +67,8 @@ export operation """ arguments(x) -Get the arguments of `x`, must be defined if `istree(x)` is `true`. +If `x` is a function call as defined by `is_function_call(x)`, `arguments(x)` +returns the arguments on which the function is called. """ function arguments end export arguments @@ -137,74 +117,17 @@ end """ - maketerm(head::H, children; type=Any, metadata=nothing) + maketerm(head::H, children; symtype=Any, metadata=nothing) -Has to be implemented by the provider of H. -Returns a term that is in the same closure of types as `typeof(x)`, -with `head` as the head and `children` as the arguments, `type` as the symtype -and `metadata` as the metadata. +Has to be implemented by the provider of H. Returns a term that is in the same +closure of types as `H`, with `head` as the head and `children` as the children, +`symtype` as the symtype and `metadata` as the metadata. """ function maketerm end export maketerm -""" - is_operation(f) - -Returns a single argument anonymous function predicate, that returns `true` if and only if -the argument to the predicate satisfies `istree` and `operation(x) == f` -""" -is_operation(f) = @nospecialize(x) -> istree(x) && (operation(x) == f) -export is_operation - - -""" - node_count(t) -Count the nodes in a symbolic expression tree satisfying `istree` and `arguments`. -""" -node_count(t) = istree(t) ? reduce(+, node_count(x) for x in arguments(t), init = 0) + 1 : 1 -export node_count - +include("utils.jl") include("expr.jl") -""" - @matchable struct Foo fields... end [HeadType] - -Take a struct definition and automatically define `TermInterface` methods. This -will automatically define a head type. If `HeadType` is given then it will be -used as `head(::Foo)`. If it is omitted, and the struct is called `Foo`, then -the head type will be called `FooHead`. The `head_symbol` of such head types -will default to `:call`. -""" -macro matchable(expr, head_name=nothing) - @assert expr.head == :struct - name = expr.args[2] - if name isa Expr - name.head === :(<:) && (name = name.args[1]) - name isa Expr && name.head === :curly && (name = name.args[1]) - end - fields = filter(x -> x isa Symbol || (x isa Expr && x.head == :(::)), expr.args[3].args) - get_name(s::Symbol) = s - get_name(e::Expr) = (@assert(e.head == :(::)); e.args[1]) - fields = map(get_name, fields) - head_name = isnothing(head_name) ? Symbol(name, :Head) : head_name - - quote - $expr - struct $head_name - head - end - TermInterface.head_symbol(x::$head_name) = x.head - # TODO default to call? - TermInterface.head(::$name) = $head_name(:call) - TermInterface.istree(::$name) = true - TermInterface.operation(::$name) = $name - TermInterface.arguments(x::$name) = getfield.((x,), ($(QuoteNode.(fields)...),)) - TermInterface.children(x::$name) = [operation(x); arguments(x)...] - TermInterface.arity(x::$name) = $(length(fields)) - Base.length(x::$name) = $(length(fields) + 1) - end |> esc -end -export @matchable - end # module diff --git a/src/expr.jl b/src/expr.jl index 9a0fdd2..102b2ae 100644 --- a/src/expr.jl +++ b/src/expr.jl @@ -6,12 +6,12 @@ struct ExprHead end export ExprHead -head_symbol(eh::ExprHead) = eh.head - istree(x::Expr) = true head(e::Expr) = ExprHead(e.head) children(e::Expr) = e.args +is_function_call(e::Expr) = head(e).head in (:call, :macrocall) + # See https://docs.julialang.org/en/v1/devdocs/ast/ function operation(e::Expr) h = head(e) @@ -19,7 +19,7 @@ function operation(e::Expr) if hh in (:call, :macrocall) e.args[1] else - hh + throw(ArgumentError("Not a function call")) end end @@ -29,7 +29,7 @@ function arguments(e::Expr) if hh in (:call, :macrocall) e.args[2:end] else - e.args + throw(ArgumentError("Not a function call")) end end diff --git a/src/utils.jl b/src/utils.jl new file mode 100644 index 0000000..ef15427 --- /dev/null +++ b/src/utils.jl @@ -0,0 +1,16 @@ +""" + is_head(f) + +Returns a single argument anonymous function predicate, that returns `true` if and only if +the argument to the predicate satisfies `istree` and `head(x) == f` +""" +is_head(f) = @nospecialize(x) -> istree(x) && (head(x) == f) +export is_head + + +""" + node_count(t) +Count the nodes in a symbolic expression tree satisfying `istree` and `arguments`. +""" +node_count(t) = istree(t) ? reduce(+, node_count(x) for x in children(t), init = 0) + 1 : 1 +export node_count diff --git a/test/runtests.jl b/test/runtests.jl index 17c791c..a2b5906 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,6 +3,7 @@ using TermInterface, Test @testset "Expr" begin ex = :(f(a, b)) @test head(ex) == ExprHead(:call) + @test is_function_call(ex) == true @test children(ex) == [:f, :a, :b] @test operation(ex) == :f @test arguments(ex) == [:a, :b] @@ -10,23 +11,21 @@ using TermInterface, Test ex = :(arr[i, j]) @test head(ex) == ExprHead(:ref) - @test operation(ex) == :ref - @test arguments(ex) == [:arr, :i, :j] + @test children(ex) == [:arr, :i, :j] + @test is_function_call(ex) == false @test ex == maketerm(ExprHead(:ref), [:arr, :i, :j]) ex = :(i, j) @test head(ex) == ExprHead(:tuple) - @test operation(ex) == :tuple - @test arguments(ex) == [:i, :j] @test children(ex) == [:i, :j] + @test is_function_call(ex) == false @test ex == maketerm(ExprHead(:tuple), [:i, :j]) - ex = Expr(:block, :a, :b, :c) @test head(ex) == ExprHead(:block) - @test operation(ex) == :block - @test children(ex) == arguments(ex) == [:a, :b, :c] + @test children(ex) == [:a, :b, :c] + @test is_function_call(ex) == false @test ex == maketerm(ExprHead(:block), [:a, :b, :c]) end @@ -40,31 +39,12 @@ end end TermInterface.head(::Foo) = FooHead(:call) TermInterface.head_symbol(q::FooHead) = q.head - TermInterface.operation(::Foo) = Foo TermInterface.istree(::Foo) = true - TermInterface.arguments(x::Foo) = [x.args...] - TermInterface.children(x::Foo) = [operation(x); x.args...] + TermInterface.children(x::Foo) = x.args t = Foo(1, 2) @test head(t) == FooHead(:call) @test head_symbol(head(t)) == :call - @test operation(t) == Foo - @test istree(t) == true - @test arguments(t) == [1, 2] - @test children(t) == [Foo, 1, 2] -end - -@testset "Automatically Generated Methods" begin - @matchable struct Bar - a - b::Int - end - - t = Bar(1, 2) - @test head(t) == BarHead(:call) - @test head_symbol(head(t)) == :call - @test operation(t) == Bar @test istree(t) == true - @test arguments(t) == (1, 2) - @test children(t) == [Bar, 1, 2] + @test children(t) == [1, 2] end \ No newline at end of file