Skip to content

Commit

Permalink
Support Booleans for Query Conditions
Browse files Browse the repository at this point in the history
  • Loading branch information
nguyenv committed Dec 1, 2022
1 parent d04c802 commit 3d699e7
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 3 deletions.
2 changes: 1 addition & 1 deletion tiledb/libtiledb.pyx
Expand Up @@ -680,7 +680,7 @@ def _tiledb_cast_tile_extent(tile_extent, dtype):
cdef int _numpy_typeid(tiledb_datatype_t tiledb_dtype):
"""Return a numpy type num (int) given a tiledb_datatype_t enum value."""
np_id_type = _tiledb_dtype_to_numpy_typeid_convert.get(tiledb_dtype, None)
if np_id_type:
if np_id_type is not None:
return np_id_type
return np.NPY_DATETIME if _tiledb_type_is_datetime(tiledb_dtype) else np.NPY_NOTYPE
Expand Down
12 changes: 10 additions & 2 deletions tiledb/query_condition.py
Expand Up @@ -365,19 +365,27 @@ def cast_value_to_dtype(
# casted to numeric types
if isinstance(value, str):
raise TileDBError(f"Cannot cast `{value}` to {dtype}.")

if np.issubdtype(dtype, np.datetime64):
cast = getattr(np, "int64")
elif np.issubdtype(dtype, bool):
cast = getattr(np, "uint8")
else:
cast = getattr(np, dtype)

value = cast(value)

except ValueError:
raise TileDBError(f"Cannot cast `{value}` to {dtype}.")

return value

def init_pyqc(self, pyqc: PyQueryCondition, dtype: str) -> Callable:
if dtype != "string" and np.issubdtype(dtype, np.datetime64):
dtype = "int64"
if dtype != "string":
if np.issubdtype(dtype, np.datetime64):
dtype = "int64"
elif np.issubdtype(dtype, bool):
dtype = "uint8"

init_fn_name = f"init_{dtype}"

Expand Down
34 changes: 34 additions & 0 deletions tiledb/tests/test_query_condition.py
Expand Up @@ -723,6 +723,40 @@ def test_do_not_return_attrs(self):
assert "D" in A.query(cond=cond, attrs=None).multi_index[:]
assert "D" not in A.query(cond=cond, attrs=[]).multi_index[:]

def test_boolean(self):
path = self.path("test_boolean")

dom = tiledb.Domain(tiledb.Dim(domain=(1, 10), tile=1, dtype=np.uint32))
attrs = [
tiledb.Attr(name="a", dtype=np.bool_),
tiledb.Attr(name="b", dtype=np.bool_),
tiledb.Attr(name="c", dtype=np.bool_),
]
schema = tiledb.ArraySchema(domain=dom, attrs=attrs, sparse=True)
tiledb.Array.create(path, schema)

with tiledb.open(path, "w") as arr:
arr[np.arange(1, 11)] = {
"a": np.random.randint(0, high=2, size=10),
"b": np.random.randint(0, high=2, size=10),
"c": np.random.randint(0, high=2, size=10),
}

with tiledb.open(path) as A:
result = A.query(cond="a == True")[:]
assert all(result["a"])

result = A.query(cond="a == False")[:]
assert all(~result["a"])

result = A.query(cond="a == True and b == True")[:]
assert all(result["a"])
assert all(result["b"])

result = A.query(cond="a == False and c == True")[:]
assert all(~result["a"])
assert all(result["c"])


class QueryDeleteTest(DiskTestCase):
def test_basic_sparse(self):
Expand Down

0 comments on commit 3d699e7

Please sign in to comment.