Skip to content

Commit

Permalink
Relational: return BooleanAtoms instead of bools.
Browse files Browse the repository at this point in the history
Fixed tests in several locations to check for the new results.  Corrected
several areas where Python bools were expected; these areas should now handle
Python bools and SymPy BooleanAtoms equally.
  • Loading branch information
randyheydon committed Jan 27, 2014
1 parent 8614622 commit f5e0a15
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 55 deletions.
18 changes: 9 additions & 9 deletions sympy/core/relational.py
Expand Up @@ -181,11 +181,11 @@ def __new__(cls, lhs, rhs=0, **assumptions):
rhs = _sympify(rhs)
# If expressions have the same structure, they must be equal.
if lhs == rhs:
return True
return S.true
# If one side is real and the other complex, they must be unequal.
elif (lhs.is_real != rhs.is_real and
None not in (lhs.is_real, rhs.is_real)):
return False
return S.false
# Otherwise, see if the difference can be evaluated.
r = cls._eval_sides(lhs, rhs)
if r is not None:
Expand All @@ -195,7 +195,7 @@ def __new__(cls, lhs, rhs=0, **assumptions):

@classmethod
def _eval_relation(cls, lhs, rhs):
return lhs == rhs
return _sympify(lhs == rhs)

Eq = Equality

Expand Down Expand Up @@ -231,12 +231,12 @@ def __new__(cls, lhs, rhs, **assumptions):
rhs = _sympify(rhs)
is_equal = Equality(lhs, rhs)
if is_equal == True or is_equal == False:
return not is_equal
return ~is_equal
return Relational.__new__(cls, lhs, rhs, **assumptions)

@classmethod
def _eval_relation(cls, lhs, rhs):
return lhs != rhs
return _sympify(lhs != rhs)

Ne = Unequality

Expand Down Expand Up @@ -546,7 +546,7 @@ class GreaterThan(_Greater):

@classmethod
def _eval_relation(cls, lhs, rhs):
return lhs >= rhs
return _sympify(lhs >= rhs)

Ge = GreaterThan

Expand All @@ -559,7 +559,7 @@ class LessThan(_Less):

@classmethod
def _eval_relation(cls, lhs, rhs):
return lhs <= rhs
return _sympify(lhs <= rhs)

Le = LessThan

Expand All @@ -572,7 +572,7 @@ class StrictGreaterThan(_Greater):

@classmethod
def _eval_relation(cls, lhs, rhs):
return lhs > rhs
return _sympify(lhs > rhs)

Gt = StrictGreaterThan

Expand All @@ -585,7 +585,7 @@ class StrictLessThan(_Less):

@classmethod
def _eval_relation(cls, lhs, rhs):
return lhs < rhs
return _sympify(lhs < rhs)

Lt = StrictLessThan

Expand Down
4 changes: 2 additions & 2 deletions sympy/core/tests/test_assumptions.py
Expand Up @@ -694,8 +694,8 @@ def test_special_assumptions():
assert (z2*z).is_zero is True

e = -3 - sqrt(5) + (-sqrt(10)/2 - sqrt(2)/2)**2
assert (e < 0) is False
assert (e > 0) is False
assert (e < 0) is S.false
assert (e > 0) is S.false
assert (e == 0) is False # it's not a literal 0
assert e.equals(0) is True

Expand Down
64 changes: 34 additions & 30 deletions sympy/core/tests/test_relational.py
@@ -1,5 +1,6 @@
from sympy.utilities.pytest import XFAIL, raises
from sympy import Symbol, symbols, oo, I, pi, Float, And, Or, Not, Implies, Xor
from sympy import (S, Symbol, symbols, oo, I, pi, Float, And, Or, Not, Implies,
Xor)
from sympy.core.relational import ( Relational, Equality, Unequality,
GreaterThan, LessThan, StrictGreaterThan, StrictLessThan, Rel, Eq, Lt, Le,
Gt, Ge, Ne )
Expand Down Expand Up @@ -49,8 +50,8 @@ def test_rel_subs():
assert e.rhs == y

e = Eq(x, 0)
assert e.subs(x, 0) is True
assert e.subs(x, 1) is False
assert e.subs(x, 0) is S.true
assert e.subs(x, 1) is S.false


def test_wrappers():
Expand Down Expand Up @@ -83,6 +84,9 @@ def test_Eq():


def test_rel_Infinity():
# NOTE: All of these are actually handled by sympy.core.Number, and do
# not create Relational objects. Therefore, they still return True and
# False instead of S.true and S.false.
assert (oo > oo) is False
assert (oo > -oo) is True
assert (oo > 1) is True
Expand Down Expand Up @@ -110,29 +114,29 @@ def test_rel_Infinity():


def test_bool():
assert Eq(0, 0) is True
assert Eq(1, 0) is False
assert Ne(0, 0) is False
assert Ne(1, 0) is True
assert Lt(0, 1) is True
assert Lt(1, 0) is False
assert Le(0, 1) is True
assert Le(1, 0) is False
assert Le(0, 0) is True
assert Gt(1, 0) is True
assert Gt(0, 1) is False
assert Ge(1, 0) is True
assert Ge(0, 1) is False
assert Ge(1, 1) is True
assert Eq(I, 2) is False
assert Ne(I, 2) is True
assert Gt(I, 2) not in [True, False]
assert Ge(I, 2) not in [True, False]
assert Lt(I, 2) not in [True, False]
assert Le(I, 2) not in [True, False]
assert Eq(0, 0) is S.true
assert Eq(1, 0) is S.false
assert Ne(0, 0) is S.false
assert Ne(1, 0) is S.true
assert Lt(0, 1) is S.true
assert Lt(1, 0) is S.false
assert Le(0, 1) is S.true
assert Le(1, 0) is S.false
assert Le(0, 0) is S.true
assert Gt(1, 0) is S.true
assert Gt(0, 1) is S.false
assert Ge(1, 0) is S.true
assert Ge(0, 1) is S.false
assert Ge(1, 1) is S.true
assert Eq(I, 2) is S.false
assert Ne(I, 2) is S.true
assert Gt(I, 2) not in [S.true, S.false]
assert Ge(I, 2) not in [S.true, S.false]
assert Lt(I, 2) not in [S.true, S.false]
assert Le(I, 2) not in [S.true, S.false]
a = Float('.000000000000000000001', '')
b = Float('.0000000000000000000001', '')
assert Eq(pi + a, pi + b) is False
assert Eq(pi + a, pi + b) is S.false


def test_rich_cmp():
Expand All @@ -149,14 +153,14 @@ def test_doit():
np = Symbol('np', nonpositive=True)
nn = Symbol('nn', nonnegative=True)

assert Gt(p, 0).doit() is True
assert Gt(p, 0).doit() is S.true
assert Gt(p, 1).doit() == Gt(p, 1)
assert Ge(p, 0).doit() is True
assert Le(p, 0).doit() is False
assert Lt(n, 0).doit() is True
assert Le(np, 0).doit() is True
assert Ge(p, 0).doit() is S.true
assert Le(p, 0).doit() is S.false
assert Lt(n, 0).doit() is S.true
assert Le(np, 0).doit() is S.true
assert Gt(nn, 0).doit() == Gt(nn, 0)
assert Lt(nn, 0).doit() is False
assert Lt(nn, 0).doit() is S.false

assert Eq(x, 0).doit() == Eq(x, 0)

Expand Down
2 changes: 1 addition & 1 deletion sympy/simplify/tests/test_simplify.py
Expand Up @@ -1702,7 +1702,7 @@ def test_issue_2998():

def test_signsimp():
e = x*(-x + 1) + x*(x - 1)
assert signsimp(Eq(e, 0)) is True
assert signsimp(Eq(e, 0)) is S.true


def test_besselsimp():
Expand Down
12 changes: 6 additions & 6 deletions sympy/solvers/solvers.py
Expand Up @@ -683,7 +683,7 @@ def _sympified_list(w):
f[i] = fi.lhs - fi.rhs
elif isinstance(fi, Poly):
f[i] = fi.as_expr()
elif isinstance(fi, bool) or fi.is_Relational:
elif isinstance(fi, (bool, C.BooleanAtom)) or fi.is_Relational:
return reduce_inequalities(f, assume=flags.get('assume'),
symbols=symbols)

Expand Down Expand Up @@ -1162,22 +1162,22 @@ def _solve(f, *symbols, **flags):
for candidate in candidates:
if candidate in result:
continue
cond = cond is True or cond.subs(symbol, candidate)
if cond is not False:
cond = (cond == True) or cond.subs(symbol, candidate)
if cond != False:
# Only include solutions that do not match the condition
# of any previous pieces.
matches_other_piece = False
for other_n, (other_expr, other_cond) in enumerate(f.args):
if other_n == n:
break
if other_cond is False:
if other_cond == False:
continue
if other_cond.subs(symbol, candidate) is True:
if other_cond.subs(symbol, candidate) == True:
matches_other_piece = True
break
if not matches_other_piece:
result.add(Piecewise(
(candidate, cond is True or cond.doit()),
(candidate, cond == True or cond.doit()),
(S.NaN, True)
))
check = False
Expand Down
14 changes: 7 additions & 7 deletions sympy/stats/rv.py
Expand Up @@ -832,10 +832,10 @@ def return_generator():

if condition: # Check that these values satisfy the condition
gd = given_fn(*args)
if not isinstance(gd, bool):
if gd != True and gd != False:
raise ValueError(
"Conditions must not contain free symbols")
if gd is False: # If the values don't satisfy then try again
if not gd: # If the values don't satisfy then try again
continue

yield fn(*args)
Expand All @@ -861,9 +861,9 @@ def sample_iter_subs(expr, condition=None, numsamples=S.Infinity, **kwargs):

if condition is not None: # Check that these values satisfy the condition
gd = condition.xreplace(d)
if not isinstance(gd, bool):
if gd != True and gd != False:
raise ValueError("Conditions must not contain free symbols")
if gd is False: # If the values don't satisfy then try again
if not gd: # If the values don't satisfy then try again
continue

yield expr.xreplace(d)
Expand All @@ -889,10 +889,10 @@ def sampling_P(condition, given_condition=None, numsamples=1,
numsamples=numsamples, **kwargs)

for x in samples:
if not isinstance(x, bool):
if x != True and x != False:
raise ValueError("Conditions must not contain free symbols")

if x is True:
if x:
count_true += 1
else:
count_false += 1
Expand Down Expand Up @@ -1035,5 +1035,5 @@ def _value_check(condition, message):
Raises ValueError with message if condition is not True
"""
if condition is not True:
if condition != True:
raise ValueError(message)

0 comments on commit f5e0a15

Please sign in to comment.