Skip to content

Commit

Permalink
Merge pull request sympy#1438 from asmeurer/nsolve-fix
Browse files Browse the repository at this point in the history
nsolve uses free_symbols instead of atoms(Symbol) (issue 3309)
  • Loading branch information
smichr committed Jul 25, 2012
2 parents c84e5df + 2ca0aba commit 82b045f
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 5 deletions.
6 changes: 3 additions & 3 deletions sympy/solvers/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2034,10 +2034,10 @@ def nsolve(*args, **kwargs):
if isinstance(f, Equality):
f = f.lhs - f.rhs
f = f.evalf()
atoms = f.atoms(Symbol)
syms = f.free_symbols
if fargs is None:
fargs = atoms.copy().pop()
if not (len(atoms) == 1 and (fargs in atoms or fargs[0] in atoms)):
fargs = syms.copy().pop()
if not (len(syms) == 1 and (fargs in syms or fargs[0] in syms)):
raise ValueError(filldedent('''
expected a one-dimensional and numerical function'''))

Expand Down
13 changes: 11 additions & 2 deletions sympy/solvers/tests/test_numeric.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from sympy import Eq, Matrix, pi, sin, sqrt, Symbol
from sympy import Eq, Matrix, pi, sin, sqrt, Symbol, Integral, Piecewise, symbols
from sympy.mpmath import mnorm, mpf
from sympy.solvers import nsolve
from sympy.utilities.lambdify import lambdify
from sympy.utilities.pytest import raises
from sympy.utilities.pytest import raises, XFAIL

def test_nsolve():
# onedimensional
Expand Down Expand Up @@ -43,3 +43,12 @@ def getroot(x0):
a = Symbol('a')
assert nsolve(1/(0.001 + a)**3 - 6/(0.9 - a)**3, a, 0.3).ae(
mpf('0.31883011387318591'))

def test_issue_3309():
x = Symbol('x')
assert nsolve(Piecewise((x,x<1),(x**2,True)),x,2) == 0.0

@XFAIL
def test_issue_3309_fail():
x, y = symbols('x y')
assert nsolve(Integral(x*y,(x,0,5)),y,2) == 0.0

0 comments on commit 82b045f

Please sign in to comment.