Skip to content

Commit

Permalink
review-related minor modifications
Browse files Browse the repository at this point in the history
  • Loading branch information
smichr committed Apr 27, 2014
1 parent c6d3070 commit 79e5d90
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 86 deletions.
65 changes: 26 additions & 39 deletions sympy/solvers/ode.py
Expand Up @@ -597,12 +597,11 @@ def _helper_simplify(eq, hint, match, simplify=True, **kwargs):
# attempt to solve for func, and apply any other hint specific
# simplifications
sols = solvefunc(eq, func, order, match)
free = eq.free_symbols
cons = lambda s: s.free_symbols.difference(free)
if isinstance(sols, C.Expr):
constants = sols.free_symbols.difference(eq.free_symbols)
return odesimp(sols, func, order, constants, hint)
return [odesimp(sol, func, order,
sol.free_symbols.difference(eq.free_symbols), hint)
for sol in sols]
return odesimp(sols, func, order, cons(sols), hint)
return [odesimp(s, func, order, cons(s), hint) for s in sols]
else:
# We still want to integrate (you can disable it separately with the hint)
match['simplify'] = False # Some hints can take advantage of this option
Expand Down Expand Up @@ -1321,7 +1320,7 @@ def odesimp(eq, func, order, constants, hint):
# contract over-expanded exponentials -- this is
# a work-around for collect's bad behavior.
w1, w2 = Wild('w1',exclude=[x]), Wild('w2')
sol = sol.replace( exp(w1*x)**w2, exp(w1*w2*x) )
sol = sol.replace(exp(w1*x)**w2, exp(w1*w2*x))
eq[0] = Eq(f(x), sol)

else:
Expand Down Expand Up @@ -1737,21 +1736,21 @@ def _get_constant_subexpressions(expr, Cs):
Cs = set(Cs)
Ces = []
def _recursive_walk(expr):
expr_syms = expr.atoms(Symbol)
expr_syms = expr.free_symbols
if len(expr_syms) > 0 and expr_syms.issubset(Cs):
Ces.append(expr)
else:
if expr.func == exp:
expr = expr.expand(mul=True)
if expr.func in (Add, Mul):
d = sift(expr.args, lambda i : i.free_symbols.issubset(Cs))
if True in d and len(d[True])>1:
if len(d[True]) > 1:
x = expr.func(*d[True])
if 0 < len(x.atoms(Symbol)):
if not x.is_number:
Ces.append(x)
elif expr.func in (C.Integral, ):
elif isinstance(expr, C.Integral):
if expr.free_symbols.issubset(Cs) and \
all(map( lambda x : len(x) == 3, expr.limits)):
all(map(lambda x: len(x) == 3, expr.limits)):
Ces.append(expr)
for i in expr.args:
_recursive_walk(i)
Expand All @@ -1761,12 +1760,12 @@ def _recursive_walk(expr):

def __remove_linear_redundancies(expr, Cs):
cnts = dict([(i, expr.count(i)) for i in Cs])
Cs = [ i for i in Cs if cnts[i] > 0 ]
Cs = [i for i in Cs if cnts[i] > 0]

def _linear(expr):
if expr.func is Add:
xs = [ i for i in Cs if expr.count(i)==cnts[i] \
and 0 == expr.diff(i, 2) ]
xs = [i for i in Cs if expr.count(i)==cnts[i] \
and 0 == expr.diff(i, 2)]
d = {}
for x in xs:
y = expr.diff(x)
Expand All @@ -1787,18 +1786,14 @@ def _recursive_walk(expr):
return expr

if expr.func is Equality:
lhs = expr.lhs
rhs = expr.rhs
lhs = _recursive_walk(lhs)
rhs = _recursive_walk(rhs)
lhs, rhs = [_recursive_walk(i) for i in expr.args]
f = lambda i: isinstance(i, C.Number) or i in Cs
g = lambda i: isinstance(i, C.Number) or i in Cs
if lhs.func is Symbol and lhs in Cs:
(rhs, lhs) = (lhs, rhs)
rhs, lhs = lhs, rhs
if lhs.func in (Add, Symbol) and rhs.func in (Add, Symbol):
dlhs = sift([lhs] if isinstance(lhs, C.AtomicExpr) else lhs.args, f)
drhs = sift([rhs] if isinstance(rhs, C.AtomicExpr) else rhs.args, f)
for i in [ True, False ]:
for i in [True, False]:
for hs in [dlhs, drhs]:
if i not in hs:
hs[i] = [0]
Expand Down Expand Up @@ -1894,9 +1889,9 @@ def constantsimp(expr, constants):
constant_subexprs = _get_constant_subexpressions(expr, Cs)
for xe in constant_subexprs:
xes = list(xe.free_symbols)
if len(xes) == 0:
if not xes:
continue
if all([ expr.count(c) == xe.count(c) for c in xes ]):
if all([expr.count(c) == xe.count(c) for c in xes]):
xes.sort(key=str)
expr = expr.subs(xe, xes[0])

Expand All @@ -1916,7 +1911,7 @@ def constantsimp(expr, constants):
pass
expr = __remove_linear_redundancies(expr, Cs)

def conditional_term_factoring(expr):
def _conditional_term_factoring(expr):
new_expr = terms_gcd(expr, clear=False, deep=True, expand=False)

# we do not want to factor exponentials, so handle this separately
Expand All @@ -1934,15 +1929,7 @@ def conditional_term_factoring(expr):
break
return new_expr

# XXX terms_gcd sometimes turns an equality into a difference.
# So we explicitly handle equalities seperately. Fix terms_gcd.
if isinstance(expr, Eq):
new_lhs = conditional_term_factoring(expr.lhs)
new_rhs = conditional_term_factoring(expr.rhs)
new_expr = Eq(new_lhs, new_rhs)
else:
new_expr = conditional_term_factoring(expr)
expr = new_expr
expr = _conditional_term_factoring(expr)

# call recursively if more simplification is possible
if orig_expr != expr:
Expand Down Expand Up @@ -2000,7 +1987,7 @@ def constant_renumber(expr, symbolname, startnumber, endnumber):
)
global newstartnumber
newstartnumber = 1
constants_found = [None]*(endnumber+2)
constants_found = [None]*(endnumber + 2)
constantsymbols = [Symbol(
symbolname + "%d" % t) for t in range(startnumber,
endnumber + 1)]
Expand Down Expand Up @@ -2041,12 +2028,12 @@ def _constant_renumber(expr):
*[_constant_renumber(x) for x in expr.args])
else:
sortedargs = list(expr.args)
sortedargs.sort( key=sort_key )
sortedargs.sort(key=sort_key)
return expr.func(*[_constant_renumber(x) for x in sortedargs])
expr = _constant_renumber(expr)
# Renumbering happens here
newconsts = symbols('C1:%d'%newstartnumber)
expr = expr.subs( zip(constants_found[1:], newconsts), simultaneous=True)
newconsts = symbols('C1:%d' % newstartnumber)
expr = expr.subs(zip(constants_found[1:], newconsts), simultaneous=True)
return expr


Expand Down Expand Up @@ -3061,7 +3048,7 @@ def _frobenius(n, m, p0, q0, p, q, x0, x, c, check=None):
dict_ = Poly(list(ordered(tseries.args))[: -1], x).as_dict()
# Fill in with zeros, if coefficients are zero.
for i in range(n + 1):
if (i, ) not in dict_:
if (i,) not in dict_:
dict_[(i,)] = S(0)
serlist.append(dict_)

Expand Down Expand Up @@ -3855,7 +3842,7 @@ def ode_nth_linear_constant_coeff_homogeneous(eq, func, order, match,
chareq += r[i]*symbol**i

chareq = Poly(chareq, symbol)
chareqroots = [ RootOf(chareq, k) for k in range(chareq.degree()) ]
chareqroots = [RootOf(chareq, k) for k in range(chareq.degree())]

# A generator of constants
constants = list(get_numbered_constants(eq, num=chareq.degree()*2))
Expand Down
86 changes: 43 additions & 43 deletions sympy/solvers/tests/test_constantsimp.py
Expand Up @@ -26,54 +26,54 @@ def test_constant_mul():
assert constant_renumber(constantsimp(C1*x, [C1]), 'C', 1, 1) == x*C1
assert constant_renumber(constantsimp(2*C1, [C1]), 'C', 1, 1) == C1
assert constant_renumber(constantsimp(C1*2, [C1]), 'C', 1, 1) == C1
assert constant_renumber(constantsimp(y*C1*x, [C1,y]), 'C', 1, 1) == C1*x
assert constant_renumber(constantsimp(x*y*C1, [C1,y]), 'C', 1, 1) == x*C1
assert constant_renumber(constantsimp(y*x*C1, [C1,y]), 'C', 1, 1) == x*C1
assert constant_renumber(constantsimp(C1*x*y, [C1,y]), 'C', 1, 1) == C1*x
assert constant_renumber(constantsimp(x*C1*y, [C1,y]), 'C', 1, 1) == x*C1
assert constant_renumber(constantsimp(y*C1*x, [C1, y]), 'C', 1, 1) == C1*x
assert constant_renumber(constantsimp(x*y*C1, [C1, y]), 'C', 1, 1) == x*C1
assert constant_renumber(constantsimp(y*x*C1, [C1, y]), 'C', 1, 1) == x*C1
assert constant_renumber(constantsimp(C1*x*y, [C1, y]), 'C', 1, 1) == C1*x
assert constant_renumber(constantsimp(x*C1*y, [C1, y]), 'C', 1, 1) == x*C1
assert constant_renumber(constantsimp(C1*y*(y + 1), [C1]), 'C', 1, 1) == C1*y*(y+1)
assert constant_renumber(constantsimp(y*C1*(y + 1), [C1]), 'C', 1, 1) == C1*y*(y+1)
assert constant_renumber(constantsimp(x*(y*C1), [C1]), 'C', 1, 1) == x*y*C1
assert constant_renumber(constantsimp(x*(C1*y), [C1]), 'C', 1, 1) == x*y*C1
assert constant_renumber(constantsimp(C1*(x*y), [C1,y]), 'C', 1, 1) == C1*x
assert constant_renumber(constantsimp((x*y)*C1, [C1,y]), 'C', 1, 1) == x*C1
assert constant_renumber(constantsimp((y*x)*C1, [C1,y]), 'C', 1, 1) == x*C1
assert constant_renumber(constantsimp(y*(y + 1)*C1, [C1,y]), 'C', 1, 1) == C1
assert constant_renumber(constantsimp((C1*x)*y, [C1,y]), 'C', 1, 1) == C1*x
assert constant_renumber(constantsimp(y*(x*C1), [C1,y]), 'C', 1, 1) == x*C1
assert constant_renumber(constantsimp((x*C1)*y, [C1,y]), 'C', 1, 1) == x*C1
assert constant_renumber(constantsimp(C1*(x*y), [C1, y]), 'C', 1, 1) == C1*x
assert constant_renumber(constantsimp((x*y)*C1, [C1, y]), 'C', 1, 1) == x*C1
assert constant_renumber(constantsimp((y*x)*C1, [C1, y]), 'C', 1, 1) == x*C1
assert constant_renumber(constantsimp(y*(y + 1)*C1, [C1, y]), 'C', 1, 1) == C1
assert constant_renumber(constantsimp((C1*x)*y, [C1, y]), 'C', 1, 1) == C1*x
assert constant_renumber(constantsimp(y*(x*C1), [C1, y]), 'C', 1, 1) == x*C1
assert constant_renumber(constantsimp((x*C1)*y, [C1, y]), 'C', 1, 1) == x*C1
assert constant_renumber(
constantsimp(C1*x*y*x*y*2, [C1,y]), 'C', 1, 1) == C1*x**2
assert constant_renumber(constantsimp(C1*x*y*z, [C1,y,z]), 'C', 1, 1) == C1*x
constantsimp(C1*x*y*x*y*2, [C1, y]), 'C', 1, 1) == C1*x**2
assert constant_renumber(constantsimp(C1*x*y*z, [C1, y, z]), 'C', 1, 1) == C1*x
assert constant_renumber(
constantsimp(C1*x*y**2*sin(z), [C1,y,z]), 'C', 1, 1) == C1*x
constantsimp(C1*x*y**2*sin(z), [C1, y, z]), 'C', 1, 1) == C1*x
assert constant_renumber(constantsimp(C1*C1, [C1]), 'C', 1, 1) == C1
assert constant_renumber(constantsimp(C1*C2, [C1,C2]), 'C', 1, 2) == C1
assert constant_renumber(constantsimp(C2*C2, [C1,C2]), 'C', 1, 2) == C1
assert constant_renumber(constantsimp(C1*C1*C2, [C1,C2]), 'C', 1, 2) == C1
assert constant_renumber(constantsimp(C1*C2, [C1, C2]), 'C', 1, 2) == C1
assert constant_renumber(constantsimp(C2*C2, [C1, C2]), 'C', 1, 2) == C1
assert constant_renumber(constantsimp(C1*C1*C2, [C1, C2]), 'C', 1, 2) == C1
assert constant_renumber(
constantsimp(C1*x*2**x, [C1]), 'C', 1, 1) == C1*x*2**x

def test_constant_add():
assert constant_renumber(constantsimp(C1 + C1, [C1]), 'C', 1, 1) == C1
assert constant_renumber(constantsimp(C1 + 2, [C1]), 'C', 1, 1) == C1
assert constant_renumber(constantsimp(2 + C1, [C1]), 'C', 1, 1) == C1
assert constant_renumber(constantsimp(C1 + y, [C1,y]), 'C', 1, 1) == C1
assert constant_renumber(constantsimp(C1 + y, [C1, y]), 'C', 1, 1) == C1
assert constant_renumber(constantsimp(C1 + x, [C1]), 'C', 1, 1) == C1 + x
assert constant_renumber(constantsimp(C1 + C1, [C1]), 'C', 1, 1) == C1
assert constant_renumber(constantsimp(C1 + C2, [C1,C2]), 'C', 1, 2) == C1
assert constant_renumber(constantsimp(C2 + C1, [C1,C2]), 'C', 1, 2) == C1
assert constant_renumber(constantsimp(C1 + C2 + C1, [C1,C2]), 'C', 1, 2) == C1
assert constant_renumber(constantsimp(C1 + C2, [C1, C2]), 'C', 1, 2) == C1
assert constant_renumber(constantsimp(C2 + C1, [C1, C2]), 'C', 1, 2) == C1
assert constant_renumber(constantsimp(C1 + C2 + C1, [C1, C2]), 'C', 1, 2) == C1


def test_constant_power_as_base():
assert constant_renumber(constantsimp(C1**C1, [C1]), 'C', 1, 1) == C1
assert constant_renumber(constantsimp(Pow(C1, C1), [C1]), 'C', 1, 1) == C1
assert constant_renumber(constantsimp(C1**C1, [C1]), 'C', 1, 1) == C1
assert constant_renumber(constantsimp(C1**C2, [C1,C2]), 'C', 1, 2) == C1
assert constant_renumber(constantsimp(C2**C1, [C1,C2]), 'C', 1, 2) == C1
assert constant_renumber(constantsimp(C2**C2, [C1,C2]), 'C', 1, 2) == C1
assert constant_renumber(constantsimp(C1**y, [C1,y]), 'C', 1, 1) == C1
assert constant_renumber(constantsimp(C1**C2, [C1, C2]), 'C', 1, 2) == C1
assert constant_renumber(constantsimp(C2**C1, [C1, C2]), 'C', 1, 2) == C1
assert constant_renumber(constantsimp(C2**C2, [C1, C2]), 'C', 1, 2) == C1
assert constant_renumber(constantsimp(C1**y, [C1, y]), 'C', 1, 1) == C1
assert constant_renumber(constantsimp(C1**x, [C1]), 'C', 1, 1) == C1**x
assert constant_renumber(constantsimp(C1**2, [C1]), 'C', 1, 1) == C1
assert constant_renumber(
Expand All @@ -82,15 +82,15 @@ def test_constant_power_as_base():

def test_constant_power_as_exp():
assert constant_renumber(constantsimp(x**C1, [C1]), 'C', 1, 1) == x**C1
assert constant_renumber(constantsimp(y**C1, [C1,y]), 'C', 1, 1) == C1
assert constant_renumber(constantsimp(x**y**C1, [C1,y]), 'C', 1, 1) == x**C1
assert constant_renumber(constantsimp(y**C1, [C1, y]), 'C', 1, 1) == C1
assert constant_renumber(constantsimp(x**y**C1, [C1, y]), 'C', 1, 1) == x**C1
assert constant_renumber(
constantsimp((x**y)**C1, [C1]), 'C', 1, 1) == (x**y)**C1
assert constant_renumber(
constantsimp(x**(y**C1), [C1,y]), 'C', 1, 1) == x**C1
assert constant_renumber(constantsimp(x**C1**y, [C1,y]), 'C', 1, 1) == x**C1
constantsimp(x**(y**C1), [C1, y]), 'C', 1, 1) == x**C1
assert constant_renumber(constantsimp(x**C1**y, [C1, y]), 'C', 1, 1) == x**C1
assert constant_renumber(
constantsimp(x**(C1**y), [C1,y]), 'C', 1, 1) == x**C1
constantsimp(x**(C1**y), [C1, y]), 'C', 1, 1) == x**C1
assert constant_renumber(
constantsimp((x**C1)**y, [C1]), 'C', 1, 1) == (x**C1)**y
assert constant_renumber(constantsimp(2**C1, [C1]), 'C', 1, 1) == C1
Expand All @@ -105,14 +105,14 @@ def test_constant_function():
assert constant_renumber(constantsimp(sin(C1), [C1]), 'C', 1, 1) == C1
assert constant_renumber(constantsimp(f(C1), [C1]), 'C', 1, 1) == C1
assert constant_renumber(constantsimp(f(C1, C1), [C1]), 'C', 1, 1) == C1
assert constant_renumber(constantsimp(f(C1, C2), [C1,C2]), 'C', 1, 2) == C1
assert constant_renumber(constantsimp(f(C2, C1), [C1,C2]), 'C', 1, 2) == C1
assert constant_renumber(constantsimp(f(C2, C2), [C1,C2]), 'C', 1, 2) == C1
assert constant_renumber(constantsimp(f(C1, C2), [C1, C2]), 'C', 1, 2) == C1
assert constant_renumber(constantsimp(f(C2, C1), [C1, C2]), 'C', 1, 2) == C1
assert constant_renumber(constantsimp(f(C2, C2), [C1, C2]), 'C', 1, 2) == C1
assert constant_renumber(
constantsimp(f(C1, x), [C1]), 'C', 1, 2) == f(C1, x)
assert constant_renumber(constantsimp(f(C1, y), [C1,y]), 'C', 1, 2) == C1
assert constant_renumber(constantsimp(f(y, C1), [C1,y]), 'C', 1, 2) == C1
assert constant_renumber(constantsimp(f(C1, y, C2), [C1,C2,y]), 'C', 1, 2) == C1
assert constant_renumber(constantsimp(f(C1, y), [C1, y]), 'C', 1, 2) == C1
assert constant_renumber(constantsimp(f(y, C1), [C1, y]), 'C', 1, 2) == C1
assert constant_renumber(constantsimp(f(C1, y, C2), [C1, C2, y]), 'C', 1, 2) == C1


def test_constant_function_multiple():
Expand All @@ -128,17 +128,17 @@ def test_constant_multiple():
assert constant_renumber(constantsimp(C1**2*2 + 2, [C1]), 'C', 1, 1) == C1
assert constant_renumber(
constantsimp(sin(2*C1) + x + sqrt(2), [C1]), 'C', 1, 1) == C1 + x
assert constant_renumber(constantsimp(2*C1 + C2, [C1,C2]), 'C', 1, 2) == C1
assert constant_renumber(constantsimp(2*C1 + C2, [C1, C2]), 'C', 1, 2) == C1

def test_constant_repeated():
assert C1 + C1*x == constant_renumber( C1 + C1*x, 'C', 1, 3)

def test_ode_solutions():
# only a few examples here, the rest will be tested in the actual dsolve tests
assert constant_renumber(constantsimp(C1*exp(2*x) + exp(x)*(C2 + C3), [C1,C2,C3]), 'C', 1, 3) == \
assert constant_renumber(constantsimp(C1*exp(2*x) + exp(x)*(C2 + C3), [C1, C2, C3]), 'C', 1, 3) == \
constant_renumber((C1*exp(x) + C2*exp(2*x)), 'C', 1, 2)
assert constant_renumber(
constantsimp(Eq(f(x), I*C1*sinh(x/3) + C2*cosh(x/3)), [C1,C2]),
constantsimp(Eq(f(x), I*C1*sinh(x/3) + C2*cosh(x/3)), [C1, C2]),
'C', 1, 2) == constant_renumber(Eq(f(x), C1*sinh(x/3) + C2*cosh(x/3)), 'C', 1, 2)
assert constant_renumber(constantsimp(Eq(f(x), acos((-C1)/cos(x))), [C1]), 'C', 1, 1) == \
Eq(f(x), acos(C1/cos(x)))
Expand All @@ -164,10 +164,10 @@ def test_ode_solutions():

@XFAIL
def test_nonlocal_simplification():
assert constantsimp(C1 + C2+x*C2, [C1,C2]) == C1 + C2*x
assert constantsimp(C1 + C2+x*C2, [C1, C2]) == C1 + C2*x


def test_constant_Eq():
# C1 on the rhs is well-tested, but the lhs is only tested here
assert constantsimp(Eq(C1, 3 + f(x)*x), [C1]) == Eq(x*f(x),C1)
assert constantsimp(Eq(C1, 3 * f(x)*x), [C1]) == Eq(f(x)*x,C1)
assert constantsimp(Eq(C1, 3 + f(x)*x), [C1]) == Eq(x*f(x), C1)
assert constantsimp(Eq(C1, 3 * f(x)*x), [C1]) == Eq(f(x)*x, C1)
4 changes: 2 additions & 2 deletions sympy/solvers/tests/test_ode.py
Expand Up @@ -255,7 +255,7 @@ def test_old_ode_tests():
sol4 = Eq(f(x), C1*sin(x/3) + C2*cos(x/3))
sol5 = Eq(f(x), C1*exp(-x/3) + C2*exp(x/3))
sol6 = Eq(f(x), (C1 - cos(x))/x**3)
sol7 = Eq(f(x), C1*exp(x) + C2*exp(2*x))
sol7 = Eq(f(x), (C1 + C2*exp(x))*exp(x))
sol8 = Eq(f(x), (C1 + C2*x)*exp(2*x))
sol9 = Eq(f(x), (C1*sin(x*sqrt(2)) + C2*cos(x*sqrt(2)))*exp(-x))
sol10 = Eq(f(x), C1 + x/3)
Expand Down Expand Up @@ -709,7 +709,7 @@ def test_nth_linear_constant_coeff_homogeneous():
eq29 = f(x).diff(x, 4) + 4*f(x).diff(x, 2)
eq30 = f(x).diff(x, 5) + 2*f(x).diff(x, 3) + f(x).diff(x)
sol1 = Eq(f(x), C1 + C2*exp(-2*x))
sol2 = Eq(f(x), (C1*exp(x) + C2*exp(2*x)))
sol2 = Eq(f(x), (C1 + C2*exp(x))*exp(x))
sol3 = Eq(f(x), C1*exp(x) + C2*exp(-x))
sol4 = Eq(f(x), C1 + C2*exp(-3*x) + C3*exp(2*x))
sol5 = Eq(f(x), C1*exp(x/2) + C2*exp(4*x/3))
Expand Down
4 changes: 2 additions & 2 deletions sympy/utilities/tests/test_iterables.py
Expand Up @@ -168,8 +168,8 @@ def test_cartes():
def test_numbered_symbols():
s = numbered_symbols(cls=Dummy)
assert isinstance(next(s), Dummy)
assert next(numbered_symbols('C', start=1, exclude=[Symbol('C1')])) == \
Symbol('C2')
assert next(numbered_symbols('C', start=1, exclude=[symbols('C1')])) == \
symbols('C2')


def test_sift():
Expand Down

0 comments on commit 79e5d90

Please sign in to comment.