diff --git a/tiledb/libtiledb.pyx b/tiledb/libtiledb.pyx index 2d33581b8f..9037447431 100644 --- a/tiledb/libtiledb.pyx +++ b/tiledb/libtiledb.pyx @@ -680,7 +680,7 @@ def _tiledb_cast_tile_extent(tile_extent, dtype): cdef int _numpy_typeid(tiledb_datatype_t tiledb_dtype): """Return a numpy type num (int) given a tiledb_datatype_t enum value.""" np_id_type = _tiledb_dtype_to_numpy_typeid_convert.get(tiledb_dtype, None) - if np_id_type: + if np_id_type is not None: return np_id_type return np.NPY_DATETIME if _tiledb_type_is_datetime(tiledb_dtype) else np.NPY_NOTYPE diff --git a/tiledb/query_condition.py b/tiledb/query_condition.py index e01bc71723..8163830f93 100644 --- a/tiledb/query_condition.py +++ b/tiledb/query_condition.py @@ -365,19 +365,27 @@ def cast_value_to_dtype( # casted to numeric types if isinstance(value, str): raise TileDBError(f"Cannot cast `{value}` to {dtype}.") + if np.issubdtype(dtype, np.datetime64): cast = getattr(np, "int64") + elif np.issubdtype(dtype, bool): + cast = getattr(np, "uint8") else: cast = getattr(np, dtype) + value = cast(value) + except ValueError: raise TileDBError(f"Cannot cast `{value}` to {dtype}.") return value def init_pyqc(self, pyqc: PyQueryCondition, dtype: str) -> Callable: - if dtype != "string" and np.issubdtype(dtype, np.datetime64): - dtype = "int64" + if dtype != "string": + if np.issubdtype(dtype, np.datetime64): + dtype = "int64" + elif np.issubdtype(dtype, bool): + dtype = "uint8" init_fn_name = f"init_{dtype}" diff --git a/tiledb/tests/test_query_condition.py b/tiledb/tests/test_query_condition.py index 550da8691f..e94d6fba66 100644 --- a/tiledb/tests/test_query_condition.py +++ b/tiledb/tests/test_query_condition.py @@ -723,6 +723,40 @@ def test_do_not_return_attrs(self): assert "D" in A.query(cond=cond, attrs=None).multi_index[:] assert "D" not in A.query(cond=cond, attrs=[]).multi_index[:] + def test_boolean(self): + path = self.path("test_boolean") + + dom = tiledb.Domain(tiledb.Dim(domain=(1, 10), tile=1, dtype=np.uint32)) + attrs = [ + tiledb.Attr(name="a", dtype=np.bool_), + tiledb.Attr(name="b", dtype=np.bool_), + tiledb.Attr(name="c", dtype=np.bool_), + ] + schema = tiledb.ArraySchema(domain=dom, attrs=attrs, sparse=True) + tiledb.Array.create(path, schema) + + with tiledb.open(path, "w") as arr: + arr[np.arange(1, 11)] = { + "a": np.random.randint(0, high=2, size=10), + "b": np.random.randint(0, high=2, size=10), + "c": np.random.randint(0, high=2, size=10), + } + + with tiledb.open(path) as A: + result = A.query(cond="a == True")[:] + assert all(result["a"]) + + result = A.query(cond="a == False")[:] + assert all(~result["a"]) + + result = A.query(cond="a == True and b == True")[:] + assert all(result["a"]) + assert all(result["b"]) + + result = A.query(cond="a == False and c == True")[:] + assert all(~result["a"]) + assert all(result["c"]) + class QueryDeleteTest(DiskTestCase): def test_basic_sparse(self):