Default to TensorKind.SPARSE_COO for sparse tensors #181

Merged · 2 commits · Aug 24, 2022
4 changes: 2 additions & 2 deletions tests/readers/test_tensor_schema.py
@@ -95,7 +95,7 @@ def parametrize_fields(*fields):
 def test_max_partition_weight_dense(dense_uri, fields, key_dim_index, memory_budget):
     config = {"py.max_incomplete_retries": 0, "sm.memory_budget": memory_budget}
     with tiledb.open(dense_uri, config=config) as a:
-        schema = ArrayParams(a, key_dim_index, fields).to_tensor_schema()
+        schema = ArrayParams(a, key_dim_index, fields).tensor_schema
         max_weight = schema.max_partition_weight
         for key_range in schema.key_range.partition_by_weight(max_weight):
             # query succeeds without incomplete retries
@@ -118,7 +118,7 @@ def test_max_partition_weight_sparse(sparse_uri, fields, key_dim_index, memory_budget):
     }
     with tiledb.open(sparse_uri, config=config) as a:
         key_dim = a.dim(key_dim_index)
-        schema = ArrayParams(a, key_dim_index, fields).to_tensor_schema()
+        schema = ArrayParams(a, key_dim_index, fields).tensor_schema
         max_weight = schema.max_partition_weight
         for key_range in schema.key_range.partition_by_weight(max_weight):
             # query succeeds without incomplete retries
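
The call-site change in both tests is the whole story: `to_tensor_schema()` with its optional `transforms` mapping becomes a plain `tensor_schema` property. A minimal end-to-end sketch of the new call site, mirroring the dense test above — the scratch URI and the tiny one-dimensional array are made up for illustration, and `ArrayParams(arr)` relies on the default key dimension and fields:

    import os
    import tempfile

    import numpy as np
    import tiledb
    from tiledb.ml.readers.types import ArrayParams

    # Hypothetical scratch array: 10 float32 values on one int32 dimension.
    uri = os.path.join(tempfile.mkdtemp(), "demo_dense")
    dim = tiledb.Dim(name="d0", domain=(0, 9), tile=10, dtype=np.int32)
    tiledb.Array.create(
        uri,
        tiledb.ArraySchema(
            domain=tiledb.Domain(dim),
            attrs=[tiledb.Attr(name="a", dtype=np.float32)],
        ),
    )
    with tiledb.open(uri, "w") as arr:
        arr[:] = np.arange(10, dtype=np.float32)

    with tiledb.open(uri) as arr:
        ts = ArrayParams(arr).tensor_schema  # property; no transforms argument
        max_weight = ts.max_partition_weight
        for key_range in ts.key_range.partition_by_weight(max_weight):
            print(key_range)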
22 changes: 9 additions & 13 deletions tiledb/ml/readers/pytorch.py
@@ -2,7 +2,7 @@
 
 from functools import partial
 from operator import methodcaller
-from typing import Any, Callable, Iterator, Mapping, Sequence, Tuple, Union
+from typing import Any, Callable, Iterator, Sequence, Tuple, Union
 
 import numpy as np
 import scipy.sparse
@@ -12,7 +12,7 @@
 from torchdata.datapipes.iter import IterableWrapper
 
 from ._ranges import InclusiveRange
-from ._tensor_schema import TensorKind, TensorSchema
+from ._tensor_schema import MappedTensorSchema, TensorKind, TensorSchema
 from .types import ArrayParams
 
 Tensor = Union[np.ndarray, sparse.COO, scipy.sparse.csr_matrix]
@@ -50,9 +50,13 @@ def PyTorchTileDBDataLoader(
     Users should NOT pass (TileDB-ML either doesn't support or implements internally the corresponding functionality)
     the following arguments: 'shuffle', 'sampler', 'batch_sampler', 'worker_init_fn' and 'collate_fn'.
     """
-    schemas = tuple(
-        array_params.to_tensor_schema(_transforms) for array_params in all_array_params
-    )
+    schemas = []
+    for array_params in all_array_params:
+        schema = array_params.tensor_schema
+        if schema.kind in (TensorKind.SPARSE_COO, TensorKind.SPARSE_CSR):
+            schema = MappedTensorSchema(schema, methodcaller("to_sparse_array"))
+        schemas.append(schema)
 
     key_range = schemas[0].key_range
     if not all(key_range.equal_values(schema.key_range) for schema in schemas[1:]):
         raise ValueError(f"All arrays must have the same key range: {key_range}")
@@ -231,11 +235,3 @@ def _get_tensor_collator(
         return collator
     else:
        return _CompositeCollator(*(collator,) * num_fields)
-
-
-_transforms: Mapping[TensorKind, Union[Callable[[Any], Any], bool]] = {
-    TensorKind.DENSE: True,
-    TensorKind.SPARSE_COO: methodcaller("to_sparse_array"),
-    TensorKind.SPARSE_CSR: methodcaller("to_sparse_array"),
-    TensorKind.RAGGED: hasattr(torch, "nested_tensor"),
-}
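
The deleted `_transforms` lookup table is replaced by explicit control flow in the loader, and the building block that survives is `MappedTensorSchema`, which applies a callable to every tensor the wrapped schema yields. A self-contained sketch of that pattern — the `Fake*` classes and `MappedSchemaSketch` are hypothetical stand-ins, not the library's implementation:

    from operator import methodcaller

    class FakeBatch:
        """Stands in for a batch that knows how to convert itself."""
        def to_sparse_array(self):
            return "sparse-array"  # a sparse.COO / scipy CSR matrix in reality

    class FakeSchema:
        """Stands in for a TensorSchema yielding raw batches."""
        def iter_tensors(self):
            yield FakeBatch()

    class MappedSchemaSketch:
        """Sketch of MappedTensorSchema: map every yielded tensor."""
        def __init__(self, schema, map_tensor):
            self._schema, self._map = schema, map_tensor

        def iter_tensors(self):
            for tensor in self._schema.iter_tensors():
                yield self._map(tensor)

    wrapped = MappedSchemaSketch(FakeSchema(), methodcaller("to_sparse_array"))
    print(list(wrapped.iter_tensors()))  # ['sparse-array']

`methodcaller("to_sparse_array")` behaves like `lambda batch: batch.to_sparse_array()`, which is why a single wrapper covers both the COO and CSR branches above.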
32 changes: 19 additions & 13 deletions tiledb/ml/readers/tensorflow.py
@@ -1,11 +1,17 @@
 """Functionality for loading data from TileDB arrays to the Tensorflow Data API."""
 
-from typing import Any, Callable, Mapping, Sequence, Union
+from typing import Sequence, Union
 
 import numpy as np
 import tensorflow as tf
 
-from ._tensor_schema import RaggedArray, SparseData, TensorKind, TensorSchema
+from ._tensor_schema import (
+    MappedTensorSchema,
+    RaggedArray,
+    SparseData,
+    TensorKind,
+    TensorSchema,
+)
 from .types import ArrayParams
 
 Tensor = Union[np.ndarray, tf.SparseTensor]
@@ -21,9 +27,17 @@ def TensorflowTileDBDataset(
     used to fetch inputs asynchronously and in parallel. Note: when `num_workers` > 1
     yielded batches may be shuffled even if `shuffle_buffer_size` is zero.
     """
-    schemas = tuple(
-        array_params.to_tensor_schema(_transforms) for array_params in all_array_params
-    )
+    schemas = []
+    for array_params in all_array_params:
+        schema = array_params.tensor_schema
+        if schema.kind is TensorKind.SPARSE_CSR:
+            raise NotImplementedError(f"{schema.kind} tensors not supported")
+        elif schema.kind is TensorKind.SPARSE_COO:
+            schema = MappedTensorSchema(schema, _to_sparse_tensor)
+        elif schema.kind is TensorKind.RAGGED:
+            schema = MappedTensorSchema(schema, _to_ragged_tensor)
+        schemas.append(schema)
 
     key_range = schemas[0].key_range
     if not all(key_range.equal_values(schema.key_range) for schema in schemas[1:]):
         raise ValueError(f"All arrays must have the same key range: {key_range}")
@@ -83,11 +97,3 @@ def _to_sparse_tensor(sd: SparseData) -> tf.SparseTensor:
 
 def _to_ragged_tensor(ra: RaggedArray) -> tf.RaggedTensor:
     return tf.ragged.constant(ra, dtype=ra[0].dtype)
-
-
-_transforms: Mapping[TensorKind, Union[Callable[[Any], Any], bool]] = {
-    TensorKind.DENSE: True,
-    TensorKind.SPARSE_COO: _to_sparse_tensor,
-    TensorKind.SPARSE_CSR: False,
-    TensorKind.RAGGED: _to_ragged_tensor,
-}
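
On the TensorFlow side the decision moves inline as well: CSR is rejected with an explicit `NotImplementedError` instead of a `False` entry in `_transforms`. For readers unfamiliar with `_to_sparse_tensor`'s target format, here is a minimal sketch of turning COO-style coordinates and values into a `tf.SparseTensor` — the coordinates, values, and shape are made-up demo data, not the reader's `SparseData` type:

    import numpy as np
    import tensorflow as tf

    # COO layout: one (row, col) coordinate per non-zero value.
    coords = np.array([[0, 0], [1, 2]], dtype=np.int64)
    values = np.array([1.0, 2.0], dtype=np.float32)
    dense_shape = (2, 3)

    st = tf.SparseTensor(indices=coords, values=values, dense_shape=dense_shape)
    st = tf.sparse.reorder(st)  # TF expects indices in canonical row-major order
    print(tf.sparse.to_dense(st))
    # [[1. 0. 0.]
    #  [0. 0. 2.]]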
42 changes: 7 additions & 35 deletions tiledb/ml/readers/types.py
@@ -1,16 +1,11 @@
 from dataclasses import dataclass, field
-from typing import Any, Callable, Mapping, Optional, Sequence, Union
+from typing import Any, Mapping, Optional, Sequence, Union
 
 import numpy as np
 
 import tiledb
 
-from ._tensor_schema import (
-    MappedTensorSchema,
-    TensorKind,
-    TensorSchema,
-    TensorSchemaFactories,
-)
+from ._tensor_schema import TensorKind, TensorSchema, TensorSchemaFactories
 
 
 @dataclass(frozen=True)
@@ -60,22 +55,9 @@ def __post_init__(self) -> None:
         )
         object.__setattr__(self, "_tensor_schema_kwargs", tensor_schema_kwargs)
 
-    def to_tensor_schema(
-        self,
-        transforms: Mapping[TensorKind, Union[Callable[[Any], Any], bool]] = {},
-    ) -> TensorSchema[Any]:
-        """
-        Create a TensorSchema from an ArrayParams instance.
-
-        :param transforms: A mapping of `TensorKind`s to transformation callables.
-            If `array_params.tensor_kind` (or the inferred tensor_kind for `array_params`)
-            has a callable value in `transforms`, the returned `TensorSchema` will map
-            each tensor yielded by its `iter_tensors` method with this callable.
-
-            A value in transforms may also be a boolean value:
-            - If False, a `NotImplementedError` is raised.
-            - If True, no transformation will be applied (same as if the key is missing).
-        """
+    @property
+    def tensor_schema(self) -> TensorSchema[Any]:
+        """Create a `TensorSchema` from this `ArrayParams` instance."""
         if self.tensor_kind is not None:
             tensor_kind = self.tensor_kind
         elif not self.array.schema.sparse:
@@ -85,18 +67,8 @@
             for dim in self._tensor_schema_kwargs["_all_dims"][1:]
         ):
             tensor_kind = TensorKind.RAGGED
-        elif self.array.ndim != 2 or not transforms.get(TensorKind.SPARSE_CSR, True):
-            tensor_kind = TensorKind.SPARSE_COO
         else:
-            tensor_kind = TensorKind.SPARSE_CSR
+            tensor_kind = TensorKind.SPARSE_COO
 
-        transform = transforms.get(tensor_kind, True)
-        if not transform:
-            raise NotImplementedError(
-                f"Mapping to {tensor_kind} tensors is not implemented"
-            )
         factory = TensorSchemaFactories[tensor_kind]
-        tensor_schema = factory(kind=tensor_kind, **self._tensor_schema_kwargs)
-        if transform is not True:
-            tensor_schema = MappedTensorSchema(tensor_schema, transform)
-        return tensor_schema
+        return factory(kind=tensor_kind, **self._tensor_schema_kwargs)
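
With the `transforms` parameter gone, kind inference no longer peeks at what the caller can handle: a 2-D sparse array that previously defaulted to `SPARSE_CSR` now defaults to `SPARSE_COO`, giving the PR its title. The resulting decision tree, restated as a standalone sketch — the two boolean inputs are hypothetical stand-ins for the array-schema queries the property actually performs:

    from enum import Enum, auto

    class TensorKind(Enum):  # mirrors the library's enum, for this sketch only
        DENSE = auto()
        RAGGED = auto()
        SPARSE_COO = auto()

    def infer_kind(is_dense: bool, non_key_dims_ragged: bool) -> TensorKind:
        if is_dense:
            return TensorKind.DENSE       # dense arrays stay dense
        if non_key_dims_ragged:
            return TensorKind.RAGGED      # irregular non-key dimensions
        return TensorKind.SPARSE_COO      # new default for all other sparse arrays

    assert infer_kind(False, False) is TensorKind.SPARSE_COO

Passing an explicit `tensor_kind` to `ArrayParams` still short-circuits the inference (the `if self.tensor_kind is not None` branch), so callers such as the PyTorch loader can still request CSR tensors deliberately.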