Skip to content

Commit

Permalink
Merge branch 'master' into gd/adtypes
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed May 17, 2024
2 parents 535f889 + a129272 commit df51d34
Show file tree
Hide file tree
Showing 22 changed files with 118 additions and 117 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ SpecialFunctions = "2"
StaticArrays = "1.1"
SymbolicIndexingInterface = "0.3.14"
SymbolicLimits = "0.2.0"
SymbolicUtils = "1.4"
SymbolicUtils = "1.7"
julia = "1.10"

[extras]
Expand Down
4 changes: 2 additions & 2 deletions docs/src/getting_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ using Symbolics
```

After defining variables as symbolic, symbolic expressions, which we call a
`istree` object, can be generated by utilizing Julia expressions. For example:
`iscall` object, can be generated by utilizing Julia expressions. For example:

```@example symbolic_basics
z = x^2 + y
Expand All @@ -35,7 +35,7 @@ A = [x^2 + y 0 2x
y^2 + x 0 0]
```

Note that by default, `@variables` returns `Sym` or `istree` objects wrapped in
Note that by default, `@variables` returns `Sym` or `iscall` objects wrapped in
`Num` to make them behave like subtypes of `Real`. Any operation on
these `Num` objects will return a new `Num` object, wrapping the result of
computing symbolically on the underlying values.
Expand Down
8 changes: 4 additions & 4 deletions docs/src/manual/variables.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Symbolics IR mirrors the Julia AST but allows for easy mathematical
manipulation by itself following mathematical semantics. The base of the IR is
the `Sym` type, which defines a symbolic variable. Registered (mathematical)
functions on `Sym`s (or `istree` objects) return an expression that `istree`.
functions on `Sym`s (or `iscall` objects) return an expression that `iscall`.
For example, `op1 = x+y` is one symbolic object and `op2 = 2z` is another, and
so `op1*op2` is another tree object. Then, at the top, an `Equation`, normally
written as `op1 ~ op2`, defines the symbolic equality between two operations.
Expand All @@ -13,10 +13,10 @@ written as `op1 ~ op2`, defines the symbolic equality between two operations.
`Sym`, `Term`, and `FnType` are from
[SymbolicUtils.jl](https://symbolicutils.juliasymbolics.org/api/). Note that in
Symbolics, we always use `Sym{Real}`, `Term{Real}`, and
`FnType{Tuple{Any}, Real}`. To get the arguments of an `istree` object, use
`FnType{Tuple{Any}, Real}`. To get the arguments of an `iscall` object, use
`arguments(t::Term)`, and to get the operation, use `operation(t::Term)`.
However, note that one should never dispatch on `Term` or test `isa Term`.
Instead, one needs to use `SymbolicUtils.istree` to check if `arguments` and
Instead, one needs to use `SymbolicUtils.iscall` to check if `arguments` and
`operation` is defined.

```@docs
Expand Down Expand Up @@ -80,7 +80,7 @@ Control flow can be expressed in Symbolics.jl in the following ways:
## Inspection Functions

```@docs
SymbolicUtils.istree
SymbolicUtils.iscall
SymbolicUtils.operation
SymbolicUtils.arguments
```
4 changes: 2 additions & 2 deletions ext/SymbolicsSymPyExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ end

using Symbolics: value
using Symbolics.SymbolicUtils
using SymbolicUtils: istree, operation, arguments, symtype,
using SymbolicUtils: iscall, operation, arguments, symtype,
FnType, Symbolic

function Symbolics.symbolics_to_sympy(expr)
expr = value(expr)
expr isa Symbolic || return expr
if istree(expr)
if iscall(expr)
sop = symbolics_to_sympy(operation(expr))
sargs = map(symbolics_to_sympy, arguments(expr))
if sop === (^) && length(sargs) == 2 && sargs[2] isa Number
Expand Down
2 changes: 1 addition & 1 deletion src/Symbolics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ using PrecompileTools

import DomainSets: Domain

import SymbolicUtils: similarterm, istree, operation, arguments, symtype, metadata
import SymbolicUtils: similarterm, iscall, operation, arguments, symtype, metadata

import SymbolicUtils: Term, Add, Mul, Pow, Sym, Div, BasicSymbolic,
FnType, @rule, Rewriters, substitute,
Expand Down
8 changes: 4 additions & 4 deletions src/array-lib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ end
function Base.getindex(x::SymArray, idx...)
idx = unwrap.(idx)
meta = metadata(unwrap(x))
if istree(x) && (op = operation(x)) isa Operator
if iscall(x) && (op = operation(x)) isa Operator
args = arguments(x)
return op(only(args)[idx...])
elseif shape(x) !== Unknown() && all(i -> i isa Integer, idx)
Expand Down Expand Up @@ -111,7 +111,7 @@ end

import Base: +, -, *
tup(c::CartesianIndex) = Tuple(c)
tup(c::Symbolic{CartesianIndex}) = istree(c) ? arguments(c) : error("Cartesian index not found")
tup(c::Symbolic{CartesianIndex}) = iscall(c) ? arguments(c) : error("Cartesian index not found")

@wrapped function -(x::CartesianIndex, y::CartesianIndex)
CartesianIndex((tup(x) .- tup(y))...)
Expand Down Expand Up @@ -251,7 +251,7 @@ isadjointvec(A::Adjoint) = ndims(parent(A)) == 1
isadjointvec(A::Transpose) = ndims(parent(A)) == 1

function isadjointvec(A)
if istree(A)
if iscall(A)
(operation(A) === (adjoint) ||
operation(A) == (transpose)) && ndims(arguments(A)[1]) == 1
else
Expand Down Expand Up @@ -305,7 +305,7 @@ function _matvec(A, b)
end
@wrapped (*)(A::AbstractMatrix, b::AbstractVector) = _matvec(A, b)

# specialize `dot` to dispatch on `Symbolic{<:Number}` to eventually work for
# specialize `dot` to dispatch on `Symbolic{<:Number}` to eventually work for
# arrays of (possibly unwrapped) Symbolic types, see issue #831
@wrapped LinearAlgebra.dot(x::Number, y::Number) = conj(x) * y

Expand Down
40 changes: 20 additions & 20 deletions src/arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ shape(aop::ArrayOp) = aop.shape

const show_arrayop = Ref{Bool}(false)
function Base.show(io::IO, aop::ArrayOp)
if istree(aop.term) && !show_arrayop[]
if iscall(aop.term) && !show_arrayop[]
show(io, aop.term)
else
print(io, "@arrayop")
Expand All @@ -117,7 +117,7 @@ function Base.showarg(io::IO, aop::ArrayOp, toplevel)
end

symtype(a::ArrayOp{T}) where {T} = T
istree(a::ArrayOp) = true
iscall(a::ArrayOp) = true
function operation(a::ArrayOp)
isnothing(a.term) ? typeof(a) : operation(a.term)
end
Expand Down Expand Up @@ -332,7 +332,7 @@ function get_extents(xs)
if all(iszerowrap, boundaries)
get(first(xs))
else
ii = findfirst(x->issym(x) || istree(x), boundaries)
ii = findfirst(x->issym(x) || iscall(x), boundaries)
if !isnothing(ii)
error("Could not find the boundary from symbolic index $(xs[ii]). Please manually specify the range of indices.")
end
Expand All @@ -355,11 +355,11 @@ get_extents(x::AbstractRange) = x
# boundary: how much padding is this indexing requiring, for example
# boundary is 2 for x[i + 2], and boundary = -2 for x[i - 2]
function idx_to_axes(expr, dict=Dict{Any, Vector}(), ranges=Dict())
if istree(expr)
if iscall(expr)
if operation(expr) === (getindex)
args = arguments(expr)
for (axis, idx_expr) in enumerate(@views args[2:end])
if issym(idx_expr) || istree(idx_expr)
if issym(idx_expr) || iscall(idx_expr)
vs = get_variables(idx_expr)
isempty(vs) && continue
sym = only(get_variables(idx_expr))
Expand Down Expand Up @@ -529,9 +529,9 @@ wrapper_type(::Type{<:AbstractVector{T}}) where {T} = Arr{maybewrap(T), 1}

function Base.show(io::IO, arr::Arr)
x = unwrap(arr)
istree(x) && print(io, "(")
iscall(x) && print(io, "(")
print(io, unwrap(arr))
istree(x) && print(io, ")")
iscall(x) && print(io, ")")
if !(shape(x) isa Unknown)
print(io, "[", join(string.(axes(arr)), ","), "]")
end
Expand Down Expand Up @@ -618,7 +618,7 @@ function replace_by_scalarizing(ex, dict)
end

function rewrite_operation(x)
if istree(x) && istree(operation(x))
if iscall(x) && iscall(operation(x))
f = operation(x)
ff = replace_by_scalarizing(f, dict)
if metadata(x) !== nothing
Expand All @@ -638,7 +638,7 @@ end

function prewalk_if(cond, f, t, similarterm)
t′ = cond(t) ? f(t) : return t
if istree(t′)
if iscall(t′)
return similarterm(t′, operation(t′),
map(x->prewalk_if(cond, f, x, similarterm), arguments(t′)))
else
Expand All @@ -652,7 +652,7 @@ function scalarize(arr::AbstractArray, idx)
end

function scalarize(arr, idx)
if istree(arr)
if iscall(arr)
scalarize_op(operation(arr), arr, idx)
else
error("scalarize is not defined for $arr at idx=$idx")
Expand Down Expand Up @@ -761,20 +761,20 @@ eval_array_term(op) = eval_array_term(operation(op), op)

function scalarize(arr)
if arr isa Arr || arr isa Symbolic{<:AbstractArray}
if istree(arr)
if iscall(arr)
arr = eval_array_term(arr)
end
map(Iterators.product(axes(arr)...)) do i
scalarize(arr[i...]) # Use arr[i...] here to trigger any getindex hooks
end
elseif istree(arr) && operation(arr) == getindex
elseif iscall(arr) && operation(arr) == getindex
args = arguments(arr)
scalarize(args[1], (args[2:end]...,))
elseif arr isa Num
wrap(scalarize(unwrap(arr)))
elseif istree(arr) && symtype(arr) <: Number
elseif iscall(arr) && symtype(arr) <: Number
t = similarterm(arr, operation(arr), map(scalarize, arguments(arr)), symtype(arr), metadata=metadata(arr))
istree(t) ? scalarize_op(operation(t), t) : t
iscall(t) ? scalarize_op(operation(t), t) : t
else
arr
end
Expand Down Expand Up @@ -804,7 +804,7 @@ function arraymaker(T, shape, views, seq...)
ArrayMaker{T}(shape, [(views .=> seq)...], nothing)
end

istree(x::ArrayMaker) = true
iscall(x::ArrayMaker) = true
operation(x::ArrayMaker) = arraymaker
arguments(x::ArrayMaker) = [eltype(x), shape(x), map(first, x.sequence), map(last, x.sequence)...]

Expand Down Expand Up @@ -965,7 +965,7 @@ end

function scalarize(x::ArrayMaker, idx)
for (vw, arr) in reverse(x.sequence) # last one wins
if any(x->issym(x) || istree(x), idx)
if any(x->issym(x) || iscall(x), idx)
return term(getindex, x, idx...)
end
if all(in.(idx, vw))
Expand All @@ -979,7 +979,7 @@ function scalarize(x::ArrayMaker, idx)
end
end
end
if !any(x->issym(x) || istree(x), idx) && all(in.(idx, axes(x)))
if !any(x->issym(x) || iscall(x), idx) && all(in.(idx, axes(x)))
throw(UndefRefError())
end

Expand All @@ -992,7 +992,7 @@ end
function SymbolicUtils.Code.toexpr(x::ArrayOp, st)
haskey(st.symbolify, x) && return st.symbolify[x]

if istree(x.term)
if iscall(x.term)
toexpr(x.term, st)
else
_array_toexpr(x, st)
Expand Down Expand Up @@ -1048,7 +1048,7 @@ end

function inplace_builtin(term, outsym)
isarr(n) = x->symtype(x) <: AbstractArray{<:Any, n}
if istree(term) && operation(term) == (*) && length(arguments(term)) == 2
if iscall(term) && operation(term) == (*) && length(arguments(term)) == 2
A, B = arguments(term)
isarr(2)(A) && (isarr(1)(B) || isarr(2)(B)) && return :($mul!($outsym, $A, $B))
end
Expand All @@ -1058,7 +1058,7 @@ end
function find_inter(acc, expr)
if !issym(expr) && symtype(expr) <: AbstractArray
push!(acc, expr)
elseif istree(expr)
elseif iscall(expr)
foreach(x -> find_inter(acc, x), arguments(expr))
end
acc
Expand Down
8 changes: 4 additions & 4 deletions src/build_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ function _build_function(target::JuliaTarget, op::Union{Arr, ArrayOp}, args...;

N = length(shape(op))
op = unwrap(op)
if op isa ArrayOp && istree(op.term)
if op isa ArrayOp && iscall(op.term)
op_body = op.term
else
op_body = :(let $outsym = zeros(Float64, map(length, ($(shape(op)...),)))
Expand Down Expand Up @@ -221,7 +221,7 @@ Build function target: `JuliaTarget`
```julia
function _build_function(target::JuliaTarget, rhss, args...;
conv = toexpr,
conv = toexpr,
expression = Val{true},
checkbounds = false,
linenumbers = false,
Expand Down Expand Up @@ -584,13 +584,13 @@ function numbered_expr(O::Symbolic,varnumbercache,args...;varordering = args[1],
states = LazyState(),
lhsname=:du,rhsnames=[Symbol("MTK$i") for i in 1:length(args)])
O = value(O)
if (issym(O) || issym(operation(O))) || (istree(O) && operation(O) == getindex)
if (issym(O) || issym(operation(O))) || (iscall(O) && operation(O) == getindex)
(j,i) = get(varnumbercache, O, (nothing, nothing))
if !isnothing(j)
return i==0 ? :($(rhsnames[j])) : :($(rhsnames[j])[$(i+offset)])
end
end
if istree(O)
if iscall(O)
if operation(O) === getindex
args = arguments(O)
Expr(:ref, toexpr(args[1], states), toexpr.(args[2:end] .+ offset, (states,))...)
Expand Down
8 changes: 4 additions & 4 deletions src/complex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ function wrapper_type(::Type{Complex{T}}) where T
end

symtype(a::ComplexTerm{T}) where T = Complex{T}
istree(a::ComplexTerm) = true
iscall(a::ComplexTerm) = true
operation(a::ComplexTerm{T}) where T = Complex{T}
arguments(a::ComplexTerm) = [a.re, a.im]
metadata(a::ComplexTerm) = a.re.metadata
metadata(a::ComplexTerm) = metadata(a.re)

function similarterm(t::ComplexTerm, f, args, symtype; metadata=nothing)
if f <: Complex
Expand All @@ -41,8 +41,8 @@ function Base.show(io::IO, a::Complex{Num})
rr = unwrap(real(a))
ii = unwrap(imag(a))

if istree(rr) && (operation(rr) === real) &&
istree(ii) && (operation(ii) === imag) &&
if iscall(rr) && (operation(rr) === real) &&
iscall(ii) && (operation(ii) === imag) &&
isequal(arguments(rr)[1], arguments(ii)[1])

return print(io, arguments(rr)[1])
Expand Down
Loading

0 comments on commit df51d34

Please sign in to comment.