Skip to content

Commit

Permalink
Merge pull request #1179 from JuliaSymbolics/b/sort-arguments-for-pri…
Browse files Browse the repository at this point in the history
…nting-&-code-generation

Refactor: Migrate from deprecated `unsorted_arguments` to `arguments`
  • Loading branch information
ChrisRackauckas committed Jun 25, 2024
2 parents 2da32c9 + f8ddd52 commit 82b3d32
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 45 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ StaticArrays = "1.1"
SymPy = "2"
SymbolicIndexingInterface = "0.3.14"
SymbolicLimits = "0.2.0"
SymbolicUtils = "2.0.2"
SymbolicUtils = "2.1"
TermInterface = "0.4"
julia = "1.10"

Expand Down
2 changes: 1 addition & 1 deletion src/build_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,7 @@ function numbered_expr(O::Symbolic,varnumbercache,args...;varordering = args[1],
Expr(:ref, toexpr(args[1], states), toexpr.(args[2:end] .+ offset, (states,))...)
else
Expr(:call, Symbol(operation(O)), (numbered_expr(x,varnumbercache,args...;offset=offset,lhsname=lhsname,
rhsnames=rhsnames,varordering=varordering) for x in arguments(O))...)
rhsnames=rhsnames,varordering=varordering) for x in sorted_arguments(O))...)
end
elseif issym(O)
tosymbol(O, escape=false)
Expand Down
10 changes: 5 additions & 5 deletions src/latexify_recipes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ function latexify_derivatives(ex)
integrand
)
elseif x.args[1] === :_textbf
ls = latexify(latexify_derivatives(arguments(x)[1])).s
ls = latexify(latexify_derivatives(sorted_arguments(x)[1])).s
return "\\textbf{" * strip(ls, '\$') * "}"
else
return x
Expand Down Expand Up @@ -134,7 +134,7 @@ function _toexpr(O)

# We need to iterate over each term in m, ignoring the numeric coefficient.
# This iteration needs to be stable, so we can't iterate over m.dict.
for term in Iterators.drop(arguments(m), isone(m.coeff) ? 0 : 1)
for term in Iterators.drop(sorted_arguments(m), isone(m.coeff) ? 0 : 1)
if !ispow(term)
push!(numer, _toexpr(term))
continue
Expand Down Expand Up @@ -182,7 +182,7 @@ function _toexpr(O)
!iscall(O) && return O

op = operation(O)
args = arguments(O)
args = sorted_arguments(O)

if (op===(*)) && (args[1] === -1)
arg_mul = Expr(:call, :(*), _toexpr(args[2:end])...)
Expand Down Expand Up @@ -233,8 +233,8 @@ _toexpr(eqs::AbstractArray) = map(eq->_toexpr(eq), eqs)
_toexpr(x::Num) = _toexpr(value(x))

function getindex_to_symbol(t)
@assert iscall(t) && operation(t) === getindex && symtype(arguments(t)[1]) <: AbstractArray
args = arguments(t)
@assert iscall(t) && operation(t) === getindex && symtype(sorted_arguments(t)[1]) <: AbstractArray
args = sorted_arguments(t)
idxs = args[2:end]
try
sub = join(map(map_subscripts, idxs), "ˏ")
Expand Down
14 changes: 7 additions & 7 deletions src/semipoly.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ function Base.:+(a::SemiMonomial, b::SemiMonomial)
end
function Base.:+(m::SemiMonomial, t)
if iscall(t) && operation(t) == (+)
return Term(+, [unsorted_arguments(t); m])
return Term(+, [arguments(t); m])
end
Term(+, [m, t])
end
Expand All @@ -42,7 +42,7 @@ function Base.:*(m::SemiMonomial, t::Symbolic)
args = collect(all_terms(t))
return Term(+, (m,) .* args)
elseif op == (*)
return Term(*, [unsorted_arguments(t); m])
return Term(*, [arguments(t); m])
end
end
Term(*, [t, m])
Expand Down Expand Up @@ -151,7 +151,7 @@ function mark_and_exponentiate(expr, vars)
@rule (~a::isop(+))^(~b::isreal) => expand(Pow((~a), real(~b)))
@rule *(~~xs::(xs -> all(issemimonomial, xs))) => *(~~xs...)
@rule *(~~xs::(xs -> any(isop(+), xs))) => expand(Term(*, ~~xs))
@rule (~a::isop(+)) / (~b::issemimonomial) => +(map(x->x/~b, unsorted_arguments(~a))...)
@rule (~a::isop(+)) / (~b::issemimonomial) => +(map(x->x/~b, arguments(~a))...)
@rule (~a::issemimonomial) / (~b::issemimonomial) => (~a) / (~b)]
expr′ = Postwalk(RestartedChain(rules), maketerm = simpleterm)(expr′)
end
Expand All @@ -178,7 +178,7 @@ function has_vars(expr, vars)::Bool
if expr in vars
return true
elseif iscall(expr)
for arg in unsorted_arguments(expr)
for arg in arguments(expr)
if has_vars(arg, vars)
return true
end
Expand All @@ -199,7 +199,7 @@ function mark_vars(expr, vars)
@assert length(args) == 2
return Term{symtype(expr)}(op, map(mark_vars(vars), args))
end
args = unsorted_arguments(expr)
args = arguments(expr)
if op === (+) || op === (*)
return Term{symtype(expr)}(op, map(mark_vars(vars), args))
elseif length(args) == 1
Expand Down Expand Up @@ -375,7 +375,7 @@ function semiquadratic_form(exprs, vars)
push!(V2, v)
else
@assert isop(k, *)
a, b = unsorted_arguments(k)
a, b = arguments(k)
p, q = extrema((idxmap[a], idxmap[b]))
j = div(q*(q-1), 2) + p
push!(J2, j)
Expand Down Expand Up @@ -403,7 +403,7 @@ end

## Utilities

all_terms(x) = iscall(x) && operation(x) == (+) ? collect(Iterators.flatten(map(all_terms, unsorted_arguments(x)))) : (x,)
all_terms(x) = iscall(x) && operation(x) == (+) ? collect(Iterators.flatten(map(all_terms, arguments(x)))) : (x,)

function unwrap_sp(m::SemiMonomial)
degree_dict = pdegrees(m.p)
Expand Down
56 changes: 28 additions & 28 deletions src/solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ function get_parts_list(a, b, a_list = Vector{Any}(), b_list = Vector{Any}())
push!(a_list, a)
push!(b_list, b)
elseif iscall(a) && iscall(b) && isequal(operation(a), operation(b))
a_args = arguments(a)
b_args = arguments(b)
a_args = sorted_arguments(a)
b_args = sorted_arguments(b)

length(a_args) != length(b_args) && return Nothing

Expand Down Expand Up @@ -163,7 +163,7 @@ function replace_term(expr, dic::Dict)
elseif iscall(expr)
args = Any[]

for arg in arguments(expr)
for arg in sorted_arguments(expr)
push!(args, replace_term(arg, dic))
end

Expand Down Expand Up @@ -205,11 +205,11 @@ function expr_similar(ref_expr, expr, check_matches = true)
SymbolicUtils.issym(expr) && iscall(ref_expr) && return false

if iscall(ref_expr)
ref_args = arguments(ref_expr)
ref_args = sorted_arguments(ref_expr)
ref_len = length(ref_args)
ref_op = operation(ref_expr)

args = arguments(expr)
args = sorted_arguments(expr)
len = length(args)
op = operation(expr)

Expand Down Expand Up @@ -250,12 +250,12 @@ end

function get_base(expr)
(!iscall(expr) || operation(expr) != (^)) && throw("not a power(^) -> $expr")
return arguments(expr)[1]
return sorted_arguments(expr)[1]
end

function get_exp(expr)
(!iscall(expr) || operation(expr) != (^)) && throw("not a power(^) -> $expr")
return arguments(expr)[2]
return sorted_arguments(expr)[2]
end

function solve_single_eq_unchecked(
Expand Down Expand Up @@ -326,10 +326,10 @@ end
function left_prod_right_zero(eq::Equation, var, single_solution)
if SymbolicUtils.ismul(eq.lhs) && isequal(0, eq.rhs)
if (single_solution)
eq = arguments(eq.lhs)[1] ~ 0
eq = sorted_arguments(eq.lhs)[1] ~ 0
else
solutions = Equation[]
for arg in arguments(eq.lhs)
for arg in sorted_arguments(eq.lhs)
temp = solve_single_eq(arg ~ 0, var)
temp = temp isa Equation ? [temp] : temp
push!(solutions, temp...)
Expand Down Expand Up @@ -371,7 +371,7 @@ function move_to_other_side(eq::Equation, var)
op = operation(eq.lhs)

if op in (+, *)
elements = arguments(eq.lhs)
elements = sorted_arguments(eq.lhs)

stays = []#has variable
move = []#does not have variable
Expand Down Expand Up @@ -424,7 +424,7 @@ function special_strategy(eq::Equation, var)
!iscall(eq.lhs) && return eq#make sure left side is tree form

op = operation(eq.lhs)
elements = arguments(eq.lhs)
elements = sorted_arguments(eq.lhs)

if (op == +) &&
length(elements) == 2 &&
Expand All @@ -446,13 +446,13 @@ function special_strategy(eq::Equation, var)
isequal(eq.rhs, 0) &&
length(elements) == 2 &&
sum(iscall.(elements)) == length(elements) &&
length(arguments(elements[1])) == 2 &&
isequal(arguments(elements[1])[1], -1) &&
iscall(arguments(elements[1])[2]) &&
operation(elements[2]) == operation(arguments(elements[1])[2])#-f(y)+f(x)=0 -> x-y=0
length(sorted_arguments(elements[1])) == 2 &&
isequal(sorted_arguments(elements[1])[1], -1) &&
iscall(sorted_arguments(elements[1])[2]) &&
operation(elements[2]) == operation(sorted_arguments(elements[1])[2])#-f(y)+f(x)=0 -> x-y=0

x = arguments(elements[2])[1]
y = arguments(arguments(elements[1])[2])[1]
x = sorted_arguments(elements[2])[1]
y = sorted_arguments(sorted_arguments(elements[1])[2])[1]

eq = x - y ~ 0
end
Expand All @@ -474,20 +474,20 @@ function reduce_root(a)
end

if iscall(a) && (operation(a) == sqrt)
a = SymbolicUtils.Pow(arguments(a)[1], 1 // 2)
a = SymbolicUtils.Pow(sorted_arguments(a)[1], 1 // 2)
elseif iscall(a) &&
(operation(a) == ^) &&
isequal(arguments(a)[2], 1 // 2) &&
!(arguments(a)[1] isa Number)
a = term(sqrt, arguments(a)[1])
isequal(sorted_arguments(a)[2], 1 // 2) &&
!(sorted_arguments(a)[1] isa Number)
a = term(sqrt, sorted_arguments(a)[1])
end

if iscall(a) &&
(operation(a) == ^) &&
arguments(a)[2] isa Rational &&
isequal((arguments(a)[2]).num, 1)
value = demote_rational(arguments(a)[1])
root = (arguments(a)[2]).den
sorted_arguments(a)[2] isa Rational &&
isequal((sorted_arguments(a)[2]).num, 1)
value = demote_rational(sorted_arguments(a)[1])
root = (sorted_arguments(a)[2]).den

if value isa Integer && value > 0
if isinteger(value^(1.0 / root))
Expand Down Expand Up @@ -596,13 +596,13 @@ function inverse_funcs(eq::Equation, var)

if haskey(inverseOps, op)
inverseOp = inverseOps[op]
inner = arguments(eq.lhs)[1]
inner = sorted_arguments(eq.lhs)[1]
eq = inner ~ term(inverseOp, eq.rhs)
elseif (op == sqrt)
inner = arguments(eq.lhs)[1]
inner = sorted_arguments(eq.lhs)[1]
eq = inner ~ (eq.rhs)^2
elseif (op == lambertw)
inner = arguments(eq.lhs)[1]
inner = sorted_arguments(eq.lhs)[1]
eq = inner ~ eq.rhs * term(exp, eq.rhs)
end

Expand Down
2 changes: 1 addition & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ function coeff(p, sym=nothing)
sum(coeff(k, sym) * v for (k, v) in p.dict)
end
elseif ismul(p)
args = unsorted_arguments(p)
args = arguments(p)
coeffs = map(a->coeff(a, sym), args)
if all(iszero, coeffs)
return 0
Expand Down
4 changes: 2 additions & 2 deletions src/variable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ function fast_substitute(expr, subs; operator = Nothing)
end
iscall(expr) || return expr
op = fast_substitute(operation(expr), subs; operator)
args = SymbolicUtils.unsorted_arguments(expr)
args = SymbolicUtils.arguments(expr)
if !(op isa operator)
canfold = Ref(!(op isa Symbolic))
args = let canfold = canfold
Expand Down Expand Up @@ -504,7 +504,7 @@ function fast_substitute(expr, pair::Pair; operator = Nothing)
end
iscall(expr) || return expr
op = fast_substitute(operation(expr), pair; operator)
args = SymbolicUtils.unsorted_arguments(expr)
args = SymbolicUtils.arguments(expr)
if !(op isa operator)
canfold = Ref(!(op isa Symbolic))
args = let canfold = canfold
Expand Down

0 comments on commit 82b3d32

Please sign in to comment.