diff --git a/tests/readers/test_pytorch.py b/tests/readers/test_pytorch.py
index cd70f13e..d244b5ec 100644
--- a/tests/readers/test_pytorch.py
+++ b/tests/readers/test_pytorch.py
@@ -33,9 +33,7 @@ def test_dataloader(
                 assert num_workers and (x_spec.sparse or y_spec.sparse)
             else:
                 assert isinstance(dataloader, torch.utils.data.DataLoader)
-                validate_tensor_generator(
-                    dataloader, x_spec, y_spec, batch_size, supports_csr=True
-                )
+                validate_tensor_generator(dataloader, x_spec, y_spec, batch_size)
                 # ensure the dataloader can be iterated again
                 n1 = sum(1 for _ in dataloader)
                 assert n1 != 0
diff --git a/tests/readers/test_tensorflow.py b/tests/readers/test_tensorflow.py
index 38393751..6d8c6a23 100644
--- a/tests/readers/test_tensorflow.py
+++ b/tests/readers/test_tensorflow.py
@@ -9,14 +9,16 @@
 from .utils import ingest_in_tiledb, parametrize_for_dataset, validate_tensor_generator
 
 
-def dataset_batching_shuffling(dataset: tf.data.Dataset, batch_size: int, shuffle_buffer_size: int) -> tf.data.Dataset:
+def dataset_batching_shuffling(
+    dataset: tf.data.Dataset, batch_size: int, shuffle_buffer_size: int
+) -> tf.data.Dataset:
     if shuffle_buffer_size > 0:
         dataset = dataset.shuffle(shuffle_buffer_size)
     return dataset.batch(batch_size)
 
 
 class TestTensorflowTileDBDataset:
-    @parametrize_for_dataset()
+    @parametrize_for_dataset(x_prefer_csr=[False], y_prefer_csr=[False])
     def test_dataset(
         self, tmpdir, x_spec, y_spec, batch_size, shuffle_buffer_size, num_workers
     ):
@@ -33,14 +35,14 @@ def test_dataset(
                 shuffle_buffer_size=shuffle_buffer_size,
             )
             assert isinstance(dataset, tf.data.Dataset)
-            validate_tensor_generator(
-                dataset, x_spec, y_spec, batch_size, supports_csr=False
-            )
+            validate_tensor_generator(dataset, x_spec, y_spec, batch_size)
 
     @parametrize_for_dataset(
         # Add one extra key on X
         x_shape=((108, 10), (108, 10, 3)),
         y_shape=((107, 5), (107, 5, 2)),
+        x_prefer_csr=[False],
+        y_prefer_csr=[False],
     )
     def test_unequal_num_keys(
         self, tmpdir, x_spec, y_spec, batch_size, shuffle_buffer_size, num_workers
@@ -55,7 +57,13 @@ def test_unequal_num_keys(
         )
         assert "All arrays must have the same key range" in str(ex.value)
 
-    @parametrize_for_dataset(num_fields=[0], shuffle_buffer_size=[0], num_workers=[0])
+    @parametrize_for_dataset(
+        num_fields=[0],
+        shuffle_buffer_size=[0],
+        num_workers=[0],
+        x_prefer_csr=[False],
+        y_prefer_csr=[False],
+    )
     def test_dataset_order(
         self, tmpdir, x_spec, y_spec, batch_size, shuffle_buffer_size, num_workers
     ):
diff --git a/tests/readers/utils.py b/tests/readers/utils.py
index 24494e5c..33e7f687 100644
--- a/tests/readers/utils.py
+++ b/tests/readers/utils.py
@@ -19,18 +19,20 @@
 @dataclass(frozen=True)
 class ArraySpec:
     sparse: bool
+    prefer_csr: bool
     shape: Sequence[int]
     key_dim: int
     key_dim_dtype: np.dtype
     non_key_dim_dtype: np.dtype
     num_fields: int
 
-    def tensor_kind(self, supports_csr: bool) -> TensorKind:
+    @property
+    def tensor_kind(self) -> TensorKind:
         if not self.sparse:
             return TensorKind.DENSE
         elif not np.issubdtype(self.non_key_dim_dtype, np.integer):
             return TensorKind.RAGGED
-        elif len(self.shape) == 2 and supports_csr:
+        elif len(self.shape) == 2 and self.prefer_csr:
             return TensorKind.SPARSE_CSR
         else:
             return TensorKind.SPARSE_COO
@@ -40,6 +42,8 @@ def parametrize_for_dataset(
     *,
     x_sparse=(True, False),
     y_sparse=(True, False),
+    x_prefer_csr=(True, False),
+    y_prefer_csr=(True, False),
     x_shape=((107, 10), (107, 10, 3)),
     y_shape=((107, 5), (107, 5, 2)),
     x_key_dim=(0, 1),
@@ -56,6 +60,8 @@ def parametrize_for_dataset(
     for (
         x_sparse_,
         y_sparse_,
+        x_prefer_csr_,
+        y_prefer_csr_,
         x_shape_,
         y_shape_,
         x_key_dim_,
@@ -69,6 +75,8 @@ def parametrize_for_dataset(
     ) in it.product(
         x_sparse,
         y_sparse,
+        x_prefer_csr,
+        y_prefer_csr,
         x_shape,
         y_shape,
         x_key_dim,
@@ -88,8 +96,15 @@ def parametrize_for_dataset(
             continue
 
         common_args = (key_dim_dtype_, non_key_dim_dtype_, num_fields_)
-        x_spec = ArraySpec(x_sparse_, x_shape_, x_key_dim_, *common_args)
-        y_spec = ArraySpec(y_sparse_, y_shape_, y_key_dim_, *common_args)
+        x_spec = ArraySpec(x_sparse_, x_prefer_csr_, x_shape_, x_key_dim_, *common_args)
+        y_spec = ArraySpec(y_sparse_, y_prefer_csr_, y_shape_, y_key_dim_, *common_args)
+
+        # no need to parametrize for prefer_csr=True if the spec doesn't have TensorKind.SPARSE_CSR
+        if x_prefer_csr_ and x_spec.tensor_kind is not TensorKind.SPARSE_CSR:
+            continue
+        if y_prefer_csr_ and y_spec.tensor_kind is not TensorKind.SPARSE_CSR:
+            continue
+
         argvalues.append(
             (x_spec, y_spec, batch_size_, shuffle_buffer_size_, num_workers_)
         )
@@ -149,7 +164,7 @@ def ingest_in_tiledb(tmpdir, spec: ArraySpec):
     fields = np.random.choice(all_fields, size=spec.num_fields, replace=False).tolist()
 
     with tiledb.open(uri) as array:
-        yield ArrayParams(array, spec.key_dim, fields), original_data
+        yield ArrayParams(array, spec.key_dim, fields, spec.tensor_kind), original_data
 
 
 def _rand_array(shape: Sequence[int], sparse: bool = False) -> np.ndarray:
@@ -214,17 +229,17 @@ def _int_to_bytes(n: int) -> bytes:
     return bytes(s)
 
 
-def validate_tensor_generator(generator, x_spec, y_spec, batch_size, supports_csr):
+def validate_tensor_generator(generator, x_spec, y_spec, batch_size):
     for x_tensors, y_tensors in generator:
         for x_tensor in x_tensors if isinstance(x_tensors, Sequence) else [x_tensors]:
-            _validate_tensor(x_tensor, x_spec, batch_size, supports_csr)
+            _validate_tensor(x_tensor, x_spec, batch_size)
         for y_tensor in y_tensors if isinstance(y_tensors, Sequence) else [y_tensors]:
-            _validate_tensor(y_tensor, y_spec, batch_size, supports_csr)
+            _validate_tensor(y_tensor, y_spec, batch_size)
 
 
-def _validate_tensor(tensor, spec, batch_size, supports_csr):
+def _validate_tensor(tensor, spec, batch_size):
     tensor_kind = _get_tensor_kind(tensor)
-    assert tensor_kind is spec.tensor_kind(supports_csr)
+    assert tensor_kind is spec.tensor_kind
    spec_row_shape = spec.shape[1:]
     if tensor_kind is not TensorKind.RAGGED:
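
For context: the diff moves the CSR decision from a `supports_csr` argument on the validator into a `prefer_csr` field on the spec itself, exposed as the `ArraySpec.tensor_kind` property. The sketch below is a standalone restatement of that decision logic for illustration only; the `TensorKind` member names come from the diff, while the free function, its signature, and the enum values are assumptions made for the example.

import numpy as np
from enum import Enum, auto


class TensorKind(Enum):
    # Member names as referenced in the diff; values here are illustrative.
    DENSE = auto()
    RAGGED = auto()
    SPARSE_COO = auto()
    SPARSE_CSR = auto()


def tensor_kind(sparse, prefer_csr, shape, non_key_dim_dtype):
    """Standalone restatement of the ArraySpec.tensor_kind property above."""
    if not sparse:
        return TensorKind.DENSE
    # Non-integer (e.g. string) non-key dimensions cannot index a regular
    # grid, so sparse reads come back as ragged tensors.
    if not np.issubdtype(non_key_dim_dtype, np.integer):
        return TensorKind.RAGGED
    # CSR is only defined for 2-D tensors, and only chosen when preferred.
    if len(shape) == 2 and prefer_csr:
        return TensorKind.SPARSE_CSR
    return TensorKind.SPARSE_COO


# Spot checks mirroring the shapes used by parametrize_for_dataset:
assert tensor_kind(False, True, (107, 10), np.dtype(np.int32)) is TensorKind.DENSE
assert tensor_kind(True, True, (107, 10), np.dtype(np.int32)) is TensorKind.SPARSE_CSR
assert tensor_kind(True, True, (107, 10, 3), np.dtype(np.int32)) is TensorKind.SPARSE_COO
assert tensor_kind(True, False, (107, 10), np.dtype(np.int32)) is TensorKind.SPARSE_COO
assert tensor_kind(True, True, (107, 10), np.dtype("U1")) is TensorKind.RAGGED

This also explains the pruning step added to parametrize_for_dataset: a `prefer_csr=True` combination whose spec cannot yield `SPARSE_CSR` anyway (dense, ragged, or non-2-D) would duplicate the `prefer_csr=False` case, so it is skipped.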