4 changes: 1 addition & 3 deletions tests/readers/test_pytorch.py
@@ -33,9 +33,7 @@ def test_dataloader(
assert num_workers and (x_spec.sparse or y_spec.sparse)
else:
assert isinstance(dataloader, torch.utils.data.DataLoader)
validate_tensor_generator(
dataloader, x_spec, y_spec, batch_size, supports_csr=True
)
validate_tensor_generator(dataloader, x_spec, y_spec, batch_size)
# ensure the dataloader can be iterated again
n1 = sum(1 for _ in dataloader)
assert n1 != 0
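With supports_csr dropped from the call above, the expected layout is now driven by each spec. As a hedged, standalone sketch (not the repo's _get_tensor_kind helper, just an assumed equivalent using plain PyTorch layout checks), this is the kind of mapping the validation relies on:

import torch

def guess_tensor_kind(t: torch.Tensor) -> str:
    # Map a torch tensor to a tensor-kind name based on its layout.
    if t.layout is torch.sparse_csr:
        return "SPARSE_CSR"
    if t.layout is torch.sparse_coo:
        return "SPARSE_COO"
    return "DENSE"

dense = torch.zeros(4, 3)
assert guess_tensor_kind(dense) == "DENSE"
assert guess_tensor_kind(dense.to_sparse_csr()) == "SPARSE_CSR"
assert guess_tensor_kind(dense.to_sparse()) == "SPARSE_COO"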
20 changes: 14 additions & 6 deletions tests/readers/test_tensorflow.py
@@ -9,14 +9,16 @@
from .utils import ingest_in_tiledb, parametrize_for_dataset, validate_tensor_generator


def dataset_batching_shuffling(dataset: tf.data.Dataset, batch_size: int, shuffle_buffer_size: int) -> tf.data.Dataset:
def dataset_batching_shuffling(
dataset: tf.data.Dataset, batch_size: int, shuffle_buffer_size: int
) -> tf.data.Dataset:
if shuffle_buffer_size > 0:
dataset = dataset.shuffle(shuffle_buffer_size)
return dataset.batch(batch_size)


class TestTensorflowTileDBDataset:
@parametrize_for_dataset()
@parametrize_for_dataset(x_prefer_csr=[False], y_prefer_csr=[False])
def test_dataset(
self, tmpdir, x_spec, y_spec, batch_size, shuffle_buffer_size, num_workers
):
@@ -33,14 +35,14 @@ def test_dataset(
shuffle_buffer_size=shuffle_buffer_size,
)
assert isinstance(dataset, tf.data.Dataset)
validate_tensor_generator(
dataset, x_spec, y_spec, batch_size, supports_csr=False
)
validate_tensor_generator(dataset, x_spec, y_spec, batch_size)

@parametrize_for_dataset(
# Add one extra key on X
x_shape=((108, 10), (108, 10, 3)),
y_shape=((107, 5), (107, 5, 2)),
x_prefer_csr=[False],
y_prefer_csr=[False],
)
def test_unequal_num_keys(
self, tmpdir, x_spec, y_spec, batch_size, shuffle_buffer_size, num_workers
@@ -55,7 +57,13 @@ def test_unequal_num_keys(
)
assert "All arrays must have the same key range" in str(ex.value)

@parametrize_for_dataset(num_fields=[0], shuffle_buffer_size=[0], num_workers=[0])
@parametrize_for_dataset(
num_fields=[0],
shuffle_buffer_size=[0],
num_workers=[0],
x_prefer_csr=[False],
y_prefer_csr=[False],
)
def test_dataset_order(
self, tmpdir, x_spec, y_spec, batch_size, shuffle_buffer_size, num_workers
):
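For reference, a minimal usage sketch of the reflowed dataset_batching_shuffling helper above (the batch and buffer sizes here are made up):

import tensorflow as tf

ds = tf.data.Dataset.range(100)
ds = dataset_batching_shuffling(ds, batch_size=8, shuffle_buffer_size=16)
first_batch = next(iter(ds))
assert first_batch.shape == (8,)  # shuffled within a 16-element buffer, then batched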
35 changes: 25 additions & 10 deletions tests/readers/utils.py
@@ -19,18 +19,20 @@
@dataclass(frozen=True)
class ArraySpec:
sparse: bool
prefer_csr: bool
shape: Sequence[int]
key_dim: int
key_dim_dtype: np.dtype
non_key_dim_dtype: np.dtype
num_fields: int

def tensor_kind(self, supports_csr: bool) -> TensorKind:
@property
def tensor_kind(self) -> TensorKind:
if not self.sparse:
return TensorKind.DENSE
elif not np.issubdtype(self.non_key_dim_dtype, np.integer):
return TensorKind.RAGGED
elif len(self.shape) == 2 and supports_csr:
elif len(self.shape) == 2 and self.prefer_csr:
return TensorKind.SPARSE_CSR
else:
return TensorKind.SPARSE_COO
@@ -40,6 +42,8 @@ def parametrize_for_dataset(
*,
x_sparse=(True, False),
y_sparse=(True, False),
x_prefer_csr=(True, False),
y_prefer_csr=(True, False),
x_shape=((107, 10), (107, 10, 3)),
y_shape=((107, 5), (107, 5, 2)),
x_key_dim=(0, 1),
@@ -56,6 +60,8 @@
for (
x_sparse_,
y_sparse_,
x_prefer_csr_,
y_prefer_csr_,
x_shape_,
y_shape_,
x_key_dim_,
@@ -69,6 +75,8 @@
) in it.product(
x_sparse,
y_sparse,
x_prefer_csr,
y_prefer_csr,
x_shape,
y_shape,
x_key_dim,
@@ -88,8 +96,15 @@
continue

common_args = (key_dim_dtype_, non_key_dim_dtype_, num_fields_)
x_spec = ArraySpec(x_sparse_, x_shape_, x_key_dim_, *common_args)
y_spec = ArraySpec(y_sparse_, y_shape_, y_key_dim_, *common_args)
x_spec = ArraySpec(x_sparse_, x_prefer_csr_, x_shape_, x_key_dim_, *common_args)
y_spec = ArraySpec(y_sparse_, y_prefer_csr_, y_shape_, y_key_dim_, *common_args)

# no need to parametrize for prefer_csr=True if the spec doesn't have TensorKind.SPARSE_CSR
if x_prefer_csr_ and x_spec.tensor_kind is not TensorKind.SPARSE_CSR:
continue
if y_prefer_csr_ and y_spec.tensor_kind is not TensorKind.SPARSE_CSR:
continue

argvalues.append(
(x_spec, y_spec, batch_size_, shuffle_buffer_size_, num_workers_)
)
@@ -149,7 +164,7 @@ def ingest_in_tiledb(tmpdir, spec: ArraySpec):
fields = np.random.choice(all_fields, size=spec.num_fields, replace=False).tolist()

with tiledb.open(uri) as array:
yield ArrayParams(array, spec.key_dim, fields), original_data
yield ArrayParams(array, spec.key_dim, fields, spec.tensor_kind), original_data


def _rand_array(shape: Sequence[int], sparse: bool = False) -> np.ndarray:
@@ -214,17 +229,17 @@ def _int_to_bytes(n: int) -> bytes:
return bytes(s)


def validate_tensor_generator(generator, x_spec, y_spec, batch_size, supports_csr):
def validate_tensor_generator(generator, x_spec, y_spec, batch_size):
for x_tensors, y_tensors in generator:
for x_tensor in x_tensors if isinstance(x_tensors, Sequence) else [x_tensors]:
_validate_tensor(x_tensor, x_spec, batch_size, supports_csr)
_validate_tensor(x_tensor, x_spec, batch_size)
for y_tensor in y_tensors if isinstance(y_tensors, Sequence) else [y_tensors]:
_validate_tensor(y_tensor, y_spec, batch_size, supports_csr)
_validate_tensor(y_tensor, y_spec, batch_size)


def _validate_tensor(tensor, spec, batch_size, supports_csr):
def _validate_tensor(tensor, spec, batch_size):
tensor_kind = _get_tensor_kind(tensor)
assert tensor_kind is spec.tensor_kind(supports_csr)
assert tensor_kind is spec.tensor_kind

spec_row_shape = spec.shape[1:]
if tensor_kind is not TensorKind.RAGGED:
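To make the new prefer_csr flag and the tensor_kind property concrete, here is a small illustration with hypothetical dtypes and shapes, chosen only to exercise each branch of the property shown above:

import numpy as np

# 2-D sparse spec that prefers CSR -> SPARSE_CSR
csr_spec = ArraySpec(
    sparse=True,
    prefer_csr=True,
    shape=(107, 10),
    key_dim=0,
    key_dim_dtype=np.dtype(np.int32),
    non_key_dim_dtype=np.dtype(np.uint16),
    num_fields=0,
)
assert csr_spec.tensor_kind is TensorKind.SPARSE_CSR

# Non-integer non-key dimension -> RAGGED, regardless of prefer_csr
ragged_spec = ArraySpec(True, True, (107, 10), 0, np.dtype(np.int32), np.dtype("datetime64[D]"), 0)
assert ragged_spec.tensor_kind is TensorKind.RAGGED

# Dense spec -> DENSE
dense_spec = ArraySpec(False, False, (107, 10, 3), 1, np.dtype(np.int32), np.dtype(np.uint16), 0)
assert dense_spec.tensor_kind is TensorKind.DENSE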