Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AL-1579] Auto htype #1370

Merged
merged 10 commits into from
Dec 23, 2021
14 changes: 14 additions & 0 deletions hub/api/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,3 +811,17 @@ def test_empty_extend(memory_ds):
ds.create_tensor("y")
ds.y.extend(np.zeros((len(ds), 3)))
assert len(ds) == 0

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add test case for when set_htype is called and self.htype is not None.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

set_htype is internal, analogous to set_dtype; that case shouldn't be hit from user space.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add test case for when set_htype is called and self.length is > 0.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

set_htype is internal, analogous to set_dtype; that case shouldn't be hit from user space.


def test_auto_htype(memory_ds):
    """Check the htype reported by tensors that were created with only a name.

    Three tensors are created without an explicit htype, one sample is
    appended to each, and the resulting ``htype`` attribute is compared
    against the expected value: a ``str`` sample gives ``"text"``, a ``dict``
    sample gives ``"json"`` and a list of ints gives ``"generic"``.

    Args:
        memory_ds: Dataset fixture supplied by the test suite.
    """
    ds = memory_ds

    # (tensor name, first sample appended, htype expected afterwards)
    cases = (
        ("x", "hello", "text"),
        ("y", {"a": [1, 2]}, "json"),
        ("z", [1, 2, 3], "generic"),
    )

    with ds:
        # Create every tensor before appending anything, as the original does.
        for tensor_name, _, _ in cases:
            ds.create_tensor(tensor_name)
        for tensor_name, first_sample, _ in cases:
            getattr(ds, tensor_name).append(first_sample)

    for tensor_name, _, expected_htype in cases:
        assert getattr(ds, tensor_name).htype == expected_htype
5 changes: 3 additions & 2 deletions hub/core/chunk_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from hub.core.meta.encode.chunk_id import ChunkIdEncoder
from hub.core.meta.tensor_meta import TensorMeta
from hub.core.storage.lru_cache import LRUCache
from hub.util.casting import get_dtype
from hub.util.casting import get_dtype, get_htype
from hub.util.chunk_engine import (
check_samples_type,
make_sequence,
Expand Down Expand Up @@ -353,6 +353,8 @@ def _convert_to_list(self, samples):

def _sanitize_samples(self, samples):
check_samples_type(samples)
if self.tensor_meta.htype is None:
self.tensor_meta.set_htype(get_htype(samples))
if self.tensor_meta.dtype is None:
self.tensor_meta.set_dtype(get_dtype(samples))
if self._convert_to_list(samples):
Expand All @@ -369,7 +371,6 @@ def extend(self, samples):
current_chunk = self.last_chunk() or self._create_new_chunk()
updated_chunks = {current_chunk}
enc = self.chunk_id_encoder

while len(samples) > 0:
num_samples_added = current_chunk.extend_if_has_space(samples)
if num_samples_added == 0:
Expand Down
6 changes: 3 additions & 3 deletions hub/core/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def __getitem__(
def create_tensor(
self,
name: str,
htype: str = DEFAULT_HTYPE,
htype: str = UNSPECIFIED,
dtype: Union[str, np.dtype] = UNSPECIFIED,
sample_compression: str = UNSPECIFIED,
chunk_compression: str = UNSPECIFIED,
Expand Down Expand Up @@ -289,8 +289,8 @@ def create_tensor(

# Separate meta and info

htype_config = HTYPE_CONFIGURATIONS[htype].copy()
info_keys = htype_config.pop("_info", [])
htype_config = HTYPE_CONFIGURATIONS.get(htype, {})
info_keys = htype_config.copy().pop("_info", [])
info_kwargs = {}
meta_kwargs = {}
for k, v in kwargs.items():
Expand Down
53 changes: 38 additions & 15 deletions hub/core/meta/tensor_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
)
from hub.htype import (
HTYPE_CONFIGURATIONS,
DEFAULT_HTYPE,
)
from hub.htype import HTYPE_CONFIGURATIONS, REQUIRE_USER_SPECIFICATION, UNSPECIFIED
from hub.core.meta.meta import Meta
Expand Down Expand Up @@ -57,22 +58,13 @@ def __init__(
**kwargs: Any key that the provided `htype` has can be overridden via **kwargs. For more information, check out `hub.htype`.
"""

if htype != UNSPECIFIED:
_validate_htype_exists(htype)
_validate_htype_overwrites(htype, kwargs)
_replace_unspecified_values(htype, kwargs)
_validate_required_htype_overwrites(htype, kwargs)
_format_values(htype, kwargs)

required_meta = _required_meta_from_htype(htype)
required_meta.update(kwargs)
super().__init__()

self._required_meta_keys = tuple(required_meta.keys())
self.__dict__.update(required_meta)
if htype != UNSPECIFIED:
self.set_htype(htype, **kwargs)
else:
self._required_meta_keys = tuple()

super().__init__()
self.set_htype(DEFAULT_HTYPE, **kwargs)
self.htype = None # type: ignore

def set_dtype(self, dtype: np.dtype):
"""Should only be called once."""
Expand All @@ -88,6 +80,38 @@ def set_dtype(self, dtype: np.dtype):

self.dtype = dtype.name

def set_htype(self, htype: str, **kwargs):
"""Should only be called once."""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

More documentation please.

  • If self.htype is present, it must be None.
  • If self.length is present, it must be zero.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docstring is similar to that of set_dtype. This is an internal method, and the conditions are explicit in the checks.

ffw_tensor_meta(self)

if getattr(self, "htype", None) is not None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a test case when htype is not yet present.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This case is hit by default when create_tensor is called.

raise ValueError(
f"Tensor meta already has a htype ({self.htype}). Incoming: {htype}."
)

if getattr(self, "length", 0) > 0:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a test case when length is not yet present.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This case is hit by default when create_tensor is called.

raise ValueError("Htype was None, but length was > 0.")

if not kwargs:
kwargs = HTYPE_CONFIGURATIONS[htype]

_validate_htype_exists(htype)
_validate_htype_overwrites(htype, kwargs)
_replace_unspecified_values(htype, kwargs)
_validate_required_htype_overwrites(htype, kwargs)
_format_values(htype, kwargs)

required_meta = _required_meta_from_htype(htype)
required_meta.update(kwargs)

self._required_meta_keys = tuple(required_meta.keys())

for k in self._required_meta_keys:
if getattr(self, k, None):
required_meta.pop(k, None)

self.__dict__.update(required_meta)

def update_shape_interval(self, shape: Tuple[int, ...]):
ffw_tensor_meta(self)

Expand Down Expand Up @@ -219,7 +243,6 @@ def _replace_unspecified_values(htype: str, htype_overwrite: dict):

def _validate_required_htype_overwrites(htype: str, htype_overwrite: dict):
"""Raises errors if `htype_overwrite` has invalid values."""

sample_compression = htype_overwrite["sample_compression"]
sample_compression = COMPRESSION_ALIASES.get(sample_compression, sample_compression)
if sample_compression not in hub.compressions:
Expand Down
15 changes: 15 additions & 0 deletions hub/util/casting.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,21 @@ def get_dtype(val: Union[np.ndarray, Sequence, Sample]) -> np.dtype:
raise TypeError(f"Cannot infer numpy dtype for {val}")


def get_htype(val: Union[np.ndarray, Sequence, Sample]) -> str:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

samples instead of val is easier to parse for me.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Method is analogous to get_dtype.

if isinstance(val, np.ndarray):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add docstring.

return "generic"
types = set((map(type, val)))
if dict in types:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't all types have to be dict? Or, what's the case when not all are dicts (please add test case)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test added.

return "json"
if types == set((str,)):
return "text"
if np.object in [ # type: ignore
farizrahman4u marked this conversation as resolved.
Show resolved Hide resolved
np.array(x).dtype if not isinstance(x, np.ndarray) else x.dtype for x in val
]:
return "json"
return "generic"


def intelligent_cast(
sample: Any, dtype: Union[np.dtype, str], htype: str
) -> np.ndarray:
Expand Down