diff --git a/Changelog.rst b/Changelog.rst index 546aa97ecf..282d84d459 100644 --- a/Changelog.rst +++ b/Changelog.rst @@ -24,6 +24,8 @@ version 3.16.2 * Fix bug in `cf.read` when reading UM files that caused LBPROC value 131072 (Mean over an ensemble of parallel runs) to be ignored (https://github.com/NCAS-CMS/cf-python/issues/737) +* New keyword parameters to `cf.wi`: ``open_lower`` and ``open_upper`` + (https://github.com/NCAS-CMS/cf-python/issues/740) * Fix bug in `cf.aggregate` that sometimes put a null transpose operation into the Dask graph when one was not needed (https://github.com/NCAS-CMS/cf-python/issues/754) diff --git a/cf/query.py b/cf/query.py index a9548b3106..b4567750c9 100644 --- a/cf/query.py +++ b/cf/query.py @@ -207,6 +207,8 @@ def __init__( exact=True, rtol=None, atol=None, + open_lower=False, + open_upper=False, ): """**Initialisation** @@ -249,6 +251,24 @@ def __init__( .. versionadded:: 3.15.2 + open_lower: `bool`, optional + Only applicable to the ``'wi'`` operator. + If True, open the interval at the lower + bound so that value0 is excluded from the + range. By default the interval is closed + so that value0 is included. + + .. versionadded:: NEXTVERSION + + open_upper: `bool`, optional + Only applicable to the ``'wi'`` operator. + If True, open the interval at the upper + bound so that value1 is excluded from the + range. By default the interval is closed + so that value1 is included. + + .. versionadded:: NEXTVERSION + exact: deprecated at version 3.0.0. Use `re.compile` objects in *value* instead. @@ -289,6 +309,16 @@ def __init__( self._rtol = rtol self._atol = atol + if open_lower or open_upper: + if operator != "wi": + raise ValueError( + "Can only set the 'open_lower' and 'open_upper' " + "parameters for the 'wi' operator" + ) + + self._open_lower = open_lower + self._open_upper = open_upper + def __dask_tokenize__(self): """Return a hashable value fully representative of the object. 
@@ -316,6 +346,9 @@ def __dask_tokenize__(self): if operator == "isclose": value += (self.rtol, self.atol) + if operator == "wi": + value += (self.open_lower, self.open_upper) + return (self.__class__, operator, self._attr) + value def __deepcopy__(self, memo): @@ -452,8 +485,22 @@ def __str__(self): attr = ".".join(self._attr) operator = self._operator compound = self._compound + + # For "wi" queries only, open intervals are supported. For "wi" _value + # is a list of two values, with representation from string list form + # of '[a, b]' which corresponds to the standard mathematical notation + # for a closed interval, the default. But an open endpoint is indicated + # by a parenthesis, so adjust repr. to convert square bracket(s). + repr_value = str(self._value) + if self.open_lower: + repr_value = "(" + repr_value[1:] + + + if self.open_upper: + repr_value = repr_value[:-1] + ")" + if not compound: - out = f"{attr}({operator} {self._value!s}" + out = f"{attr}({operator} {repr_value}" rtol = self.rtol if rtol is not None: out += f" rtol={rtol}" @@ -596,6 +643,28 @@ def Units(self): raise AttributeError(f"{self!r} has indeterminate units") + @property + def open_lower(self): + """True if the interval is open at the (excludes the) lower bound. + + .. versionadded:: NEXTVERSION + + .. seealso:: `open_upper` + + """ + return getattr(self, "_open_lower", False) + + @property + def open_upper(self): + """True if the interval is open at the (excludes the) upper bound. + + .. versionadded:: NEXTVERSION + + .. seealso:: `open_lower` + + """ + return getattr(self, "_open_upper", False) + @property def rtol(self): """The tolerance on relative numerical differences. @@ -644,8 +713,7 @@ def value(self): return value def addattr(self, attr): - """Return a `Query` object with a new left hand side operand - attribute to be used during evaluation. TODO. + """Redefine the query to be on an object's attribute. 
If another attribute has previously been specified, then the new attribute is considered to be an attribute of the existing @@ -803,6 +871,8 @@ def equals(self, other, verbose=None, traceback=False): "_operator", "_rtol", "_atol", + "_open_lower", + "_open_upper", ): x = getattr(self, attr, None) y = getattr(other, attr, None) @@ -905,7 +975,17 @@ def _evaluate(self, x, parent_attr): if _wi is not None: return _wi(value) - return (x >= value[0]) & (x <= value[1]) + if self.open_lower: + lower_bound = x > value[0] + else: + lower_bound = x >= value[0] + + if self.open_upper: + upper_bound = x < value[1] + else: + upper_bound = x <= value[1] + + return lower_bound & upper_bound if operator == "eq": try: @@ -1629,9 +1709,21 @@ def isclose(value, units=None, attr=None, rtol=None, atol=None): ) -def wi(value0, value1, units=None, attr=None): +def wi( + value0, + value1, + units=None, + attr=None, + open_lower=False, + open_upper=False, +): """A `Query` object for a "within a range" condition. + The condition is a closed interval by default, inclusive of + both the endpoints, but can be made open or half-open to exclude + the endpoints on either end with use of the `open_lower` and + `open_upper` parameters. + .. seealso:: `cf.contains`, `cf.eq`, `cf.ge`, `cf.gt`, `cf.ne`, `cf.le`, `cf.lt`, `cf.set`, `cf.wo`, `cf.isclose` @@ -1643,6 +1735,22 @@ def wi(value0, value1, units=None, attr=None): value1: The upper bound of the range. + open_lower: `bool`, optional + If True, open the interval at the lower + bound so that value0 is excluded from the + range. By default the interval is closed + so that value0 is included. + + .. versionadded:: NEXTVERSION + + open_upper: `bool`, optional + If True, open the interval at the upper + bound so that value1 is excluded from the + range. By default the interval is closed + so that value1 is included. + + .. versionadded:: NEXTVERSION + units: `str` or `Units`, optional The units of *value*. 
By default, the same units as the operand being tested are assumed, if applicable. If @@ -1671,9 +1779,42 @@ def wi(value0, value1, units=None, attr=None): True >>> q.evaluate(4) False + >>> q.evaluate(5) + True + >>> q.evaluate(7) + True + + The interval can be made open on either side or both. Note that, + as per mathematical interval notation, square brackets indicate + closed endpoints and parentheses open endpoints in the representation: + + >>> q = cf.wi(5, 7, open_upper=True) + >>> q + <CF Query: (wi [5, 7))> + >>> q.evaluate(7) + False + >>> q = cf.wi(5, 7, open_lower=True) + >>> q + <CF Query: (wi (5, 7])> + >>> q.evaluate(5) + False + >>> q = cf.wi(5, 7, open_lower=True, open_upper=True) + >>> q + <CF Query: (wi (5, 7))> + >>> q.evaluate(5) + False + >>> q.evaluate(7) + False """ - return Query("wi", [value0, value1], units=units, attr=attr) + return Query( + "wi", + [value0, value1], + units=units, + attr=attr, + open_lower=open_lower, + open_upper=open_upper, + ) def wo(value0, value1, units=None, attr=None): @@ -2466,10 +2607,6 @@ def seasons(n=4, start=12): .. seealso:: `cf.year`, `cf.month`, `cf.day`, `cf.hour`, `cf.minute`, `cf.second`, `cf.djf`, `cf.mam`, `cf.jja`, `cf.son` - TODO - - .. seealso:: `cf.mam`, `cf.jja`, `cf.son`, `cf.djf` - :Parameters: n: `int`, optional diff --git a/cf/test/test_Query.py b/cf/test/test_Query.py index 1d9f584e75..dc398ae204 100644 --- a/cf/test/test_Query.py +++ b/cf/test/test_Query.py @@ -36,6 +36,9 @@ def test_Query(self): s = q | r t = cf.Query("gt", 12, attr="bounds") u = s & t + v = cf.wi(2, 5, open_lower=True) + w = cf.wi(2, 5, open_upper=True) + x = cf.wi(2, 5, open_lower=True, open_upper=True) repr(q) repr(s) @@ -47,6 +50,13 @@ def test_Query(self): str(u) u.dump(display=False) + # For "wi", check repr. provides correct notation for open/closed-ness + # of the interval captured.
+ self.assertEqual(repr(q), "<CF Query: (wi [2, 5])>") + self.assertEqual(repr(v), "<CF Query: (wi (2, 5])>") + self.assertEqual(repr(w), "<CF Query: (wi [2, 5))>") + self.assertEqual(repr(x), "<CF Query: (wi (2, 5))>") + u.attr u.operator q.attr @@ -58,6 +68,16 @@ def test_Query(self): self.assertTrue(u.equals(u.copy(), verbose=2)) self.assertFalse(u.equals(t, verbose=0)) + self.assertTrue(q.equals(q.copy())) + self.assertTrue( + q.equals(cf.wi(2, 5, open_lower=False, open_upper=False)) + ) + self.assertFalse(q.equals(v)) + self.assertFalse(q.equals(w)) + self.assertFalse(q.equals(x)) + self.assertFalse(v.equals(w)) + self.assertFalse(v.equals(x)) + copy.deepcopy(u) c = self.f.dimension_coordinate("X") @@ -494,16 +514,65 @@ def test_Query_evaluate(self): self.assertNotEqual(cf.set([3, 8, 11]), x) c = cf.wi(2, 4) - d = cf.wi(6, 8) - - e = d | c + c0 = cf.wi(2, 4, open_lower=False) # equivalent to c, to check default + c1 = cf.wi(2, 4, open_lower=True) + c2 = cf.wi(2, 4, open_upper=True) + c3 = cf.wi(2, 4, open_lower=True, open_upper=True) + all_c = [c, c0, c1, c2, c3] - self.assertTrue(c.evaluate(3)) - self.assertFalse(c.evaluate(5)) - - self.assertTrue(e.evaluate(3)) - self.assertTrue(e.evaluate(7)) - self.assertFalse(e.evaluate(5)) + d = cf.wi(6, 8) + d0 = cf.wi(6, 8, open_lower=False) # equivalent to d, to check default + d1 = cf.wi(6, 8, open_lower=True) + d2 = cf.wi(6, 8, open_upper=True) + d3 = cf.wi(6, 8, open_lower=True, open_upper=True) + + e = d | c # interval: [2, 4] | [6, 8] + e1 = c0 | d1 # interval: [2, 4] | (6, 8] + e2 = c1 | d2 # interval: (2, 4] | [6, 8) + e3 = d3 | c3 # interval: (6, 8) | (2, 4) + all_e = [e, e1, e2, e3] + + for cx in all_c: + self.assertTrue(cx.evaluate(3)) + self.assertFalse(cx.evaluate(5)) + + # Test the two open_* keywords for direct (non-compound) queries + self.assertEqual(c.evaluate(2), c0.evaluate(2)) + self.assertTrue(c0.evaluate(2)) + self.assertFalse(c1.evaluate(2)) + self.assertTrue(c2.evaluate(2)) + self.assertFalse(c3.evaluate(2)) + self.assertEqual(c.evaluate(4), c0.evaluate(4)) +
self.assertTrue(c0.evaluate(4)) + self.assertTrue(c1.evaluate(4)) + self.assertFalse(c2.evaluate(4)) + self.assertFalse(c3.evaluate(4)) + + for ex in all_e: + self.assertTrue(ex.evaluate(3)) + self.assertTrue(ex.evaluate(7)) + self.assertFalse(ex.evaluate(5)) + + # Test the two open_* keywords for compound queries. + # Must be careful to capture correct openness/closure of any inner + # bounds introduced through compound queries, e.g. for 'e' there + # are internal endpoints at 4 and 6 to behave like in 'c' and 'd'. + self.assertTrue(e.evaluate(2)) + self.assertTrue(e1.evaluate(2)) + self.assertFalse(e2.evaluate(2)) + self.assertFalse(e3.evaluate(2)) + self.assertTrue(e.evaluate(4)) + self.assertTrue(e1.evaluate(4)) + self.assertTrue(e2.evaluate(4)) + self.assertFalse(e3.evaluate(4)) + self.assertTrue(e.evaluate(6)) + self.assertFalse(e1.evaluate(6)) + self.assertTrue(e2.evaluate(6)) + self.assertFalse(e3.evaluate(6)) + self.assertTrue(e.evaluate(8)) + self.assertTrue(e1.evaluate(8)) + self.assertFalse(e2.evaluate(8)) + self.assertFalse(e3.evaluate(8)) self.assertEqual(3, c) self.assertNotEqual(5, c) @@ -646,6 +715,8 @@ def test_Query__dask_tokenize__(self): cf.wo(2, 5, attr="day") | cf.set(cf.Data([1, 2], "km")), cf.eq(8) | cf.lt(9) & cf.ge(10), cf.isclose(1, "days", rtol=10, atol=99), + cf.wi(-5, 5, open_lower=True), + cf.wi(-5, 5, open_lower=True, open_upper=True), ): self.assertEqual(tokenize(q), tokenize(q.copy())) @@ -657,6 +728,34 @@ def test_Query__dask_tokenize__(self): tokenize(cf.isclose(9)), tokenize(cf.isclose(9, rtol=10)) ) + self.assertNotEqual( + tokenize(cf.wi(-5, 5, open_lower=True)), tokenize(cf.wi(-5, 5)) + ) + self.assertNotEqual( + tokenize(cf.wi(-5, 5, open_upper=True)), tokenize(cf.wi(-5, 5)) + ) + self.assertNotEqual( + tokenize(cf.wi(-5, 5, open_upper=True)), + tokenize(cf.wi(-5, 5, open_lower=True)), + ) + self.assertNotEqual( + tokenize(cf.wi(-5, 5, open_lower=True, open_upper=True)), + tokenize(cf.wi(-5, 5)), + ) + self.assertNotEqual( +
tokenize(cf.wi(-5, 5, open_lower=True, open_upper=True)), + tokenize(cf.wi(-5, 5, open_lower=True)), + ) + self.assertNotEqual( + tokenize(cf.wi(-5, 5, open_lower=True, open_upper=True)), + tokenize(cf.wi(-5, 5, open_upper=True)), + ) + # Check defaults + self.assertEqual( + tokenize(cf.wi(-5, 5, open_lower=False, open_upper=False)), + tokenize(cf.wi(-5, 5)), + ) + def test_Query_Units(self): self.assertEqual(cf.eq(9).Units, cf.Units()) self.assertEqual(cf.eq(9, "m s-1").Units, cf.Units("m s-1"))