From 79e5d905daba334c192afbaf10adb6456ab34e89 Mon Sep 17 00:00:00 2001 From: Christopher Smith Date: Wed, 16 Apr 2014 16:34:26 -0500 Subject: [PATCH] review-related minor modifications --- sympy/solvers/ode.py | 65 +++++++----------- sympy/solvers/tests/test_constantsimp.py | 86 ++++++++++++------------ sympy/solvers/tests/test_ode.py | 4 +- sympy/utilities/tests/test_iterables.py | 4 +- 4 files changed, 73 insertions(+), 86 deletions(-) diff --git a/sympy/solvers/ode.py b/sympy/solvers/ode.py index 842f8262db7d..3ce0fcade586 100644 --- a/sympy/solvers/ode.py +++ b/sympy/solvers/ode.py @@ -597,12 +597,11 @@ def _helper_simplify(eq, hint, match, simplify=True, **kwargs): # attempt to solve for func, and apply any other hint specific # simplifications sols = solvefunc(eq, func, order, match) + free = eq.free_symbols + cons = lambda s: s.free_symbols.difference(free) if isinstance(sols, C.Expr): - constants = sols.free_symbols.difference(eq.free_symbols) - return odesimp(sols, func, order, constants, hint) - return [odesimp(sol, func, order, - sol.free_symbols.difference(eq.free_symbols), hint) - for sol in sols] + return odesimp(sols, func, order, cons(sols), hint) + return [odesimp(s, func, order, cons(s), hint) for s in sols] else: # We still want to integrate (you can disable it separately with the hint) match['simplify'] = False # Some hints can take advantage of this option @@ -1321,7 +1320,7 @@ def odesimp(eq, func, order, constants, hint): # contract over-expanded exponentials -- this is # a work-around for collect's bad behavior. w1, w2 = Wild('w1',exclude=[x]), Wild('w2') - sol = sol.replace( exp(w1*x)**w2, exp(w1*w2*x) ) + sol = sol.replace(exp(w1*x)**w2, exp(w1*w2*x)) eq[0] = Eq(f(x), sol) else: @@ -1737,7 +1736,7 @@ def _get_constant_subexpressions(expr, Cs): Cs = set(Cs) Ces = [] def _recursive_walk(expr): - expr_syms = expr.atoms(Symbol) + expr_syms = expr.free_symbols if len(expr_syms) > 0 and expr_syms.issubset(Cs): Ces.append(expr) else: @@ -1745,13 +1744,13 @@ def _recursive_walk(expr): expr = expr.expand(mul=True) if expr.func in (Add, Mul): d = sift(expr.args, lambda i : i.free_symbols.issubset(Cs)) - if True in d and len(d[True])>1: + if len(d[True]) > 1: x = expr.func(*d[True]) - if 0 < len(x.atoms(Symbol)): + if not x.is_number: Ces.append(x) - elif expr.func in (C.Integral, ): + elif isinstance(expr, C.Integral): if expr.free_symbols.issubset(Cs) and \ - all(map( lambda x : len(x) == 3, expr.limits)): + all(map(lambda x: len(x) == 3, expr.limits)): Ces.append(expr) for i in expr.args: _recursive_walk(i) @@ -1761,12 +1760,12 @@ def _recursive_walk(expr): def __remove_linear_redundancies(expr, Cs): cnts = dict([(i, expr.count(i)) for i in Cs]) - Cs = [ i for i in Cs if cnts[i] > 0 ] + Cs = [i for i in Cs if cnts[i] > 0] def _linear(expr): if expr.func is Add: - xs = [ i for i in Cs if expr.count(i)==cnts[i] \ - and 0 == expr.diff(i, 2) ] + xs = [i for i in Cs if expr.count(i)==cnts[i] \ + and 0 == expr.diff(i, 2)] d = {} for x in xs: y = expr.diff(x) @@ -1787,18 +1786,14 @@ def _recursive_walk(expr): return expr if expr.func is Equality: - lhs = expr.lhs - rhs = expr.rhs - lhs = _recursive_walk(lhs) - rhs = _recursive_walk(rhs) + lhs, rhs = [_recursive_walk(i) for i in expr.args] f = lambda i: isinstance(i, C.Number) or i in Cs - g = lambda i: isinstance(i, C.Number) or i in Cs if lhs.func is Symbol and lhs in Cs: - (rhs, lhs) = (lhs, rhs) + rhs, lhs = lhs, rhs if lhs.func in (Add, Symbol) and rhs.func in (Add, Symbol): dlhs = sift([lhs] if isinstance(lhs, C.AtomicExpr) else lhs.args, f) drhs = sift([rhs] if isinstance(rhs, C.AtomicExpr) else rhs.args, f) - for i in [ True, False ]: + for i in [True, False]: for hs in [dlhs, drhs]: if i not in hs: hs[i] = [0] @@ -1894,9 +1889,9 @@ def constantsimp(expr, constants): constant_subexprs = _get_constant_subexpressions(expr, Cs) for xe in constant_subexprs: xes = list(xe.free_symbols) - if len(xes) == 0: + if not xes: continue - if all([ expr.count(c) == xe.count(c) for c in xes ]): + if all([expr.count(c) == xe.count(c) for c in xes]): xes.sort(key=str) expr = expr.subs(xe, xes[0]) @@ -1916,7 +1911,7 @@ def constantsimp(expr, constants): pass expr = __remove_linear_redundancies(expr, Cs) - def conditional_term_factoring(expr): + def _conditional_term_factoring(expr): new_expr = terms_gcd(expr, clear=False, deep=True, expand=False) # we do not want to factor exponentials, so handle this separately @@ -1934,15 +1929,7 @@ def conditional_term_factoring(expr): break return new_expr - # XXX terms_gcd sometimes turns an equality into a difference. - # So we explicitly handle equalities seperately. Fix terms_gcd. - if isinstance(expr, Eq): - new_lhs = conditional_term_factoring(expr.lhs) - new_rhs = conditional_term_factoring(expr.rhs) - new_expr = Eq(new_lhs, new_rhs) - else: - new_expr = conditional_term_factoring(expr) - expr = new_expr + expr = _conditional_term_factoring(expr) # call recursively if more simplification is possible if orig_expr != expr: @@ -2000,7 +1987,7 @@ def constant_renumber(expr, symbolname, startnumber, endnumber): ) global newstartnumber newstartnumber = 1 - constants_found = [None]*(endnumber+2) + constants_found = [None]*(endnumber + 2) constantsymbols = [Symbol( symbolname + "%d" % t) for t in range(startnumber, endnumber + 1)] @@ -2041,12 +2028,12 @@ def _constant_renumber(expr): *[_constant_renumber(x) for x in expr.args]) else: sortedargs = list(expr.args) - sortedargs.sort( key=sort_key ) + sortedargs.sort(key=sort_key) return expr.func(*[_constant_renumber(x) for x in sortedargs]) expr = _constant_renumber(expr) # Renumbering happens here - newconsts = symbols('C1:%d'%newstartnumber) - expr = expr.subs( zip(constants_found[1:], newconsts), simultaneous=True) + newconsts = symbols('C1:%d' % newstartnumber) + expr = expr.subs(zip(constants_found[1:], newconsts), simultaneous=True) return expr @@ -3061,7 +3048,7 @@ def _frobenius(n, m, p0, q0, p, q, x0, x, c, check=None): dict_ = Poly(list(ordered(tseries.args))[: -1], x).as_dict() # Fill in with zeros, if coefficients are zero. for i in range(n + 1): - if (i, ) not in dict_: + if (i,) not in dict_: dict_[(i,)] = S(0) serlist.append(dict_) @@ -3855,7 +3842,7 @@ def ode_nth_linear_constant_coeff_homogeneous(eq, func, order, match, chareq += r[i]*symbol**i chareq = Poly(chareq, symbol) - chareqroots = [ RootOf(chareq, k) for k in range(chareq.degree()) ] + chareqroots = [RootOf(chareq, k) for k in range(chareq.degree())] # A generator of constants constants = list(get_numbered_constants(eq, num=chareq.degree()*2)) diff --git a/sympy/solvers/tests/test_constantsimp.py b/sympy/solvers/tests/test_constantsimp.py index 5f00f0e41fbc..025dba300459 100644 --- a/sympy/solvers/tests/test_constantsimp.py +++ b/sympy/solvers/tests/test_constantsimp.py @@ -26,31 +26,31 @@ def test_constant_mul(): assert constant_renumber(constantsimp(C1*x, [C1]), 'C', 1, 1) == x*C1 assert constant_renumber(constantsimp(2*C1, [C1]), 'C', 1, 1) == C1 assert constant_renumber(constantsimp(C1*2, [C1]), 'C', 1, 1) == C1 - assert constant_renumber(constantsimp(y*C1*x, [C1,y]), 'C', 1, 1) == C1*x - assert constant_renumber(constantsimp(x*y*C1, [C1,y]), 'C', 1, 1) == x*C1 - assert constant_renumber(constantsimp(y*x*C1, [C1,y]), 'C', 1, 1) == x*C1 - assert constant_renumber(constantsimp(C1*x*y, [C1,y]), 'C', 1, 1) == C1*x - assert constant_renumber(constantsimp(x*C1*y, [C1,y]), 'C', 1, 1) == x*C1 + assert constant_renumber(constantsimp(y*C1*x, [C1, y]), 'C', 1, 1) == C1*x + assert constant_renumber(constantsimp(x*y*C1, [C1, y]), 'C', 1, 1) == x*C1 + assert constant_renumber(constantsimp(y*x*C1, [C1, y]), 'C', 1, 1) == x*C1 + assert constant_renumber(constantsimp(C1*x*y, [C1, y]), 'C', 1, 1) == C1*x + assert constant_renumber(constantsimp(x*C1*y, [C1, y]), 'C', 1, 1) == x*C1 assert constant_renumber(constantsimp(C1*y*(y + 1), [C1]), 'C', 1, 1) == C1*y*(y+1) assert constant_renumber(constantsimp(y*C1*(y + 1), [C1]), 'C', 1, 1) == C1*y*(y+1) assert constant_renumber(constantsimp(x*(y*C1), [C1]), 'C', 1, 1) == x*y*C1 assert constant_renumber(constantsimp(x*(C1*y), [C1]), 'C', 1, 1) == x*y*C1 - assert constant_renumber(constantsimp(C1*(x*y), [C1,y]), 'C', 1, 1) == C1*x - assert constant_renumber(constantsimp((x*y)*C1, [C1,y]), 'C', 1, 1) == x*C1 - assert constant_renumber(constantsimp((y*x)*C1, [C1,y]), 'C', 1, 1) == x*C1 - assert constant_renumber(constantsimp(y*(y + 1)*C1, [C1,y]), 'C', 1, 1) == C1 - assert constant_renumber(constantsimp((C1*x)*y, [C1,y]), 'C', 1, 1) == C1*x - assert constant_renumber(constantsimp(y*(x*C1), [C1,y]), 'C', 1, 1) == x*C1 - assert constant_renumber(constantsimp((x*C1)*y, [C1,y]), 'C', 1, 1) == x*C1 + assert constant_renumber(constantsimp(C1*(x*y), [C1, y]), 'C', 1, 1) == C1*x + assert constant_renumber(constantsimp((x*y)*C1, [C1, y]), 'C', 1, 1) == x*C1 + assert constant_renumber(constantsimp((y*x)*C1, [C1, y]), 'C', 1, 1) == x*C1 + assert constant_renumber(constantsimp(y*(y + 1)*C1, [C1, y]), 'C', 1, 1) == C1 + assert constant_renumber(constantsimp((C1*x)*y, [C1, y]), 'C', 1, 1) == C1*x + assert constant_renumber(constantsimp(y*(x*C1), [C1, y]), 'C', 1, 1) == x*C1 + assert constant_renumber(constantsimp((x*C1)*y, [C1, y]), 'C', 1, 1) == x*C1 assert constant_renumber( - constantsimp(C1*x*y*x*y*2, [C1,y]), 'C', 1, 1) == C1*x**2 - assert constant_renumber(constantsimp(C1*x*y*z, [C1,y,z]), 'C', 1, 1) == C1*x + constantsimp(C1*x*y*x*y*2, [C1, y]), 'C', 1, 1) == C1*x**2 + assert constant_renumber(constantsimp(C1*x*y*z, [C1, y, z]), 'C', 1, 1) == C1*x assert constant_renumber( - constantsimp(C1*x*y**2*sin(z), [C1,y,z]), 'C', 1, 1) == C1*x + constantsimp(C1*x*y**2*sin(z), [C1, y, z]), 'C', 1, 1) == C1*x assert constant_renumber(constantsimp(C1*C1, [C1]), 'C', 1, 1) == C1 - assert constant_renumber(constantsimp(C1*C2, [C1,C2]), 'C', 1, 2) == C1 - assert constant_renumber(constantsimp(C2*C2, [C1,C2]), 'C', 1, 2) == C1 - assert constant_renumber(constantsimp(C1*C1*C2, [C1,C2]), 'C', 1, 2) == C1 + assert constant_renumber(constantsimp(C1*C2, [C1, C2]), 'C', 1, 2) == C1 + assert constant_renumber(constantsimp(C2*C2, [C1, C2]), 'C', 1, 2) == C1 + assert constant_renumber(constantsimp(C1*C1*C2, [C1, C2]), 'C', 1, 2) == C1 assert constant_renumber( constantsimp(C1*x*2**x, [C1]), 'C', 1, 1) == C1*x*2**x @@ -58,22 +58,22 @@ def test_constant_add(): assert constant_renumber(constantsimp(C1 + C1, [C1]), 'C', 1, 1) == C1 assert constant_renumber(constantsimp(C1 + 2, [C1]), 'C', 1, 1) == C1 assert constant_renumber(constantsimp(2 + C1, [C1]), 'C', 1, 1) == C1 - assert constant_renumber(constantsimp(C1 + y, [C1,y]), 'C', 1, 1) == C1 + assert constant_renumber(constantsimp(C1 + y, [C1, y]), 'C', 1, 1) == C1 assert constant_renumber(constantsimp(C1 + x, [C1]), 'C', 1, 1) == C1 + x assert constant_renumber(constantsimp(C1 + C1, [C1]), 'C', 1, 1) == C1 - assert constant_renumber(constantsimp(C1 + C2, [C1,C2]), 'C', 1, 2) == C1 - assert constant_renumber(constantsimp(C2 + C1, [C1,C2]), 'C', 1, 2) == C1 - assert constant_renumber(constantsimp(C1 + C2 + C1, [C1,C2]), 'C', 1, 2) == C1 + assert constant_renumber(constantsimp(C1 + C2, [C1, C2]), 'C', 1, 2) == C1 + assert constant_renumber(constantsimp(C2 + C1, [C1, C2]), 'C', 1, 2) == C1 + assert constant_renumber(constantsimp(C1 + C2 + C1, [C1, C2]), 'C', 1, 2) == C1 def test_constant_power_as_base(): assert constant_renumber(constantsimp(C1**C1, [C1]), 'C', 1, 1) == C1 assert constant_renumber(constantsimp(Pow(C1, C1), [C1]), 'C', 1, 1) == C1 assert constant_renumber(constantsimp(C1**C1, [C1]), 'C', 1, 1) == C1 - assert constant_renumber(constantsimp(C1**C2, [C1,C2]), 'C', 1, 2) == C1 - assert constant_renumber(constantsimp(C2**C1, [C1,C2]), 'C', 1, 2) == C1 - assert constant_renumber(constantsimp(C2**C2, [C1,C2]), 'C', 1, 2) == C1 - assert constant_renumber(constantsimp(C1**y, [C1,y]), 'C', 1, 1) == C1 + assert constant_renumber(constantsimp(C1**C2, [C1, C2]), 'C', 1, 2) == C1 + assert constant_renumber(constantsimp(C2**C1, [C1, C2]), 'C', 1, 2) == C1 + assert constant_renumber(constantsimp(C2**C2, [C1, C2]), 'C', 1, 2) == C1 + assert constant_renumber(constantsimp(C1**y, [C1, y]), 'C', 1, 1) == C1 assert constant_renumber(constantsimp(C1**x, [C1]), 'C', 1, 1) == C1**x assert constant_renumber(constantsimp(C1**2, [C1]), 'C', 1, 1) == C1 assert constant_renumber( @@ -82,15 +82,15 @@ def test_constant_power_as_base(): def test_constant_power_as_exp(): assert constant_renumber(constantsimp(x**C1, [C1]), 'C', 1, 1) == x**C1 - assert constant_renumber(constantsimp(y**C1, [C1,y]), 'C', 1, 1) == C1 - assert constant_renumber(constantsimp(x**y**C1, [C1,y]), 'C', 1, 1) == x**C1 + assert constant_renumber(constantsimp(y**C1, [C1, y]), 'C', 1, 1) == C1 + assert constant_renumber(constantsimp(x**y**C1, [C1, y]), 'C', 1, 1) == x**C1 assert constant_renumber( constantsimp((x**y)**C1, [C1]), 'C', 1, 1) == (x**y)**C1 assert constant_renumber( - constantsimp(x**(y**C1), [C1,y]), 'C', 1, 1) == x**C1 - assert constant_renumber(constantsimp(x**C1**y, [C1,y]), 'C', 1, 1) == x**C1 + constantsimp(x**(y**C1), [C1, y]), 'C', 1, 1) == x**C1 + assert constant_renumber(constantsimp(x**C1**y, [C1, y]), 'C', 1, 1) == x**C1 assert constant_renumber( - constantsimp(x**(C1**y), [C1,y]), 'C', 1, 1) == x**C1 + constantsimp(x**(C1**y), [C1, y]), 'C', 1, 1) == x**C1 assert constant_renumber( constantsimp((x**C1)**y, [C1]), 'C', 1, 1) == (x**C1)**y assert constant_renumber(constantsimp(2**C1, [C1]), 'C', 1, 1) == C1 @@ -105,14 +105,14 @@ def test_constant_function(): assert constant_renumber(constantsimp(sin(C1), [C1]), 'C', 1, 1) == C1 assert constant_renumber(constantsimp(f(C1), [C1]), 'C', 1, 1) == C1 assert constant_renumber(constantsimp(f(C1, C1), [C1]), 'C', 1, 1) == C1 - assert constant_renumber(constantsimp(f(C1, C2), [C1,C2]), 'C', 1, 2) == C1 - assert constant_renumber(constantsimp(f(C2, C1), [C1,C2]), 'C', 1, 2) == C1 - assert constant_renumber(constantsimp(f(C2, C2), [C1,C2]), 'C', 1, 2) == C1 + assert constant_renumber(constantsimp(f(C1, C2), [C1, C2]), 'C', 1, 2) == C1 + assert constant_renumber(constantsimp(f(C2, C1), [C1, C2]), 'C', 1, 2) == C1 + assert constant_renumber(constantsimp(f(C2, C2), [C1, C2]), 'C', 1, 2) == C1 assert constant_renumber( constantsimp(f(C1, x), [C1]), 'C', 1, 2) == f(C1, x) - assert constant_renumber(constantsimp(f(C1, y), [C1,y]), 'C', 1, 2) == C1 - assert constant_renumber(constantsimp(f(y, C1), [C1,y]), 'C', 1, 2) == C1 - assert constant_renumber(constantsimp(f(C1, y, C2), [C1,C2,y]), 'C', 1, 2) == C1 + assert constant_renumber(constantsimp(f(C1, y), [C1, y]), 'C', 1, 2) == C1 + assert constant_renumber(constantsimp(f(y, C1), [C1, y]), 'C', 1, 2) == C1 + assert constant_renumber(constantsimp(f(C1, y, C2), [C1, C2, y]), 'C', 1, 2) == C1 def test_constant_function_multiple(): @@ -128,17 +128,17 @@ def test_constant_multiple(): assert constant_renumber(constantsimp(C1**2*2 + 2, [C1]), 'C', 1, 1) == C1 assert constant_renumber( constantsimp(sin(2*C1) + x + sqrt(2), [C1]), 'C', 1, 1) == C1 + x - assert constant_renumber(constantsimp(2*C1 + C2, [C1,C2]), 'C', 1, 2) == C1 + assert constant_renumber(constantsimp(2*C1 + C2, [C1, C2]), 'C', 1, 2) == C1 def test_constant_repeated(): assert C1 + C1*x == constant_renumber( C1 + C1*x, 'C', 1, 3) def test_ode_solutions(): # only a few examples here, the rest will be tested in the actual dsolve tests - assert constant_renumber(constantsimp(C1*exp(2*x) + exp(x)*(C2 + C3), [C1,C2,C3]), 'C', 1, 3) == \ + assert constant_renumber(constantsimp(C1*exp(2*x) + exp(x)*(C2 + C3), [C1, C2, C3]), 'C', 1, 3) == \ constant_renumber((C1*exp(x) + C2*exp(2*x)), 'C', 1, 2) assert constant_renumber( - constantsimp(Eq(f(x), I*C1*sinh(x/3) + C2*cosh(x/3)), [C1,C2]), + constantsimp(Eq(f(x), I*C1*sinh(x/3) + C2*cosh(x/3)), [C1, C2]), 'C', 1, 2) == constant_renumber(Eq(f(x), C1*sinh(x/3) + C2*cosh(x/3)), 'C', 1, 2) assert constant_renumber(constantsimp(Eq(f(x), acos((-C1)/cos(x))), [C1]), 'C', 1, 1) == \ Eq(f(x), acos(C1/cos(x))) @@ -164,10 +164,10 @@ def test_ode_solutions(): @XFAIL def test_nonlocal_simplification(): - assert constantsimp(C1 + C2+x*C2, [C1,C2]) == C1 + C2*x + assert constantsimp(C1 + C2+x*C2, [C1, C2]) == C1 + C2*x def test_constant_Eq(): # C1 on the rhs is well-tested, but the lhs is only tested here - assert constantsimp(Eq(C1, 3 + f(x)*x), [C1]) == Eq(x*f(x),C1) - assert constantsimp(Eq(C1, 3 * f(x)*x), [C1]) == Eq(f(x)*x,C1) + assert constantsimp(Eq(C1, 3 + f(x)*x), [C1]) == Eq(x*f(x), C1) + assert constantsimp(Eq(C1, 3 * f(x)*x), [C1]) == Eq(f(x)*x, C1) diff --git a/sympy/solvers/tests/test_ode.py b/sympy/solvers/tests/test_ode.py index a8b90a0d867a..078a3459f0cb 100644 --- a/sympy/solvers/tests/test_ode.py +++ b/sympy/solvers/tests/test_ode.py @@ -255,7 +255,7 @@ def test_old_ode_tests(): sol4 = Eq(f(x), C1*sin(x/3) + C2*cos(x/3)) sol5 = Eq(f(x), C1*exp(-x/3) + C2*exp(x/3)) sol6 = Eq(f(x), (C1 - cos(x))/x**3) - sol7 = Eq(f(x), C1*exp(x) + C2*exp(2*x)) + sol7 = Eq(f(x), (C1 + C2*exp(x))*exp(x)) sol8 = Eq(f(x), (C1 + C2*x)*exp(2*x)) sol9 = Eq(f(x), (C1*sin(x*sqrt(2)) + C2*cos(x*sqrt(2)))*exp(-x)) sol10 = Eq(f(x), C1 + x/3) @@ -709,7 +709,7 @@ def test_nth_linear_constant_coeff_homogeneous(): eq29 = f(x).diff(x, 4) + 4*f(x).diff(x, 2) eq30 = f(x).diff(x, 5) + 2*f(x).diff(x, 3) + f(x).diff(x) sol1 = Eq(f(x), C1 + C2*exp(-2*x)) - sol2 = Eq(f(x), (C1*exp(x) + C2*exp(2*x))) + sol2 = Eq(f(x), (C1 + C2*exp(x))*exp(x)) sol3 = Eq(f(x), C1*exp(x) + C2*exp(-x)) sol4 = Eq(f(x), C1 + C2*exp(-3*x) + C3*exp(2*x)) sol5 = Eq(f(x), C1*exp(x/2) + C2*exp(4*x/3)) diff --git a/sympy/utilities/tests/test_iterables.py b/sympy/utilities/tests/test_iterables.py index 69e97b2b51e6..ed98bffe520e 100644 --- a/sympy/utilities/tests/test_iterables.py +++ b/sympy/utilities/tests/test_iterables.py @@ -168,8 +168,8 @@ def test_cartes(): def test_numbered_symbols(): s = numbered_symbols(cls=Dummy) assert isinstance(next(s), Dummy) - assert next(numbered_symbols('C', start=1, exclude=[Symbol('C1')])) == \ - Symbol('C2') + assert next(numbered_symbols('C', start=1, exclude=[symbols('C1')])) == \ + symbols('C2') def test_sift():