-
Notifications
You must be signed in to change notification settings - Fork 602
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[AL-1579] Auto htype #1370
[AL-1579] Auto htype #1370
Changes from 5 commits
83ebe38
d579fa9
625c9a7
1ae75af
e0e038d
ebdc30c
cf44374
ced86d1
b3a59a2
e36f021
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -811,3 +811,17 @@ def test_empty_extend(memory_ds): | |
ds.create_tensor("y") | ||
ds.y.extend(np.zeros((len(ds), 3))) | ||
assert len(ds) == 0 | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add test case for when There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. set_htype is internal , analogous to set_dtype, that case shouldn't be hit from user space. |
||
|
||
def test_auto_htype(memory_ds): | ||
ds = memory_ds | ||
with ds: | ||
ds.create_tensor("x") | ||
ds.create_tensor("y") | ||
ds.create_tensor("z") | ||
ds.x.append("hello") | ||
ds.y.append({"a": [1, 2]}) | ||
ds.z.append([1, 2, 3]) | ||
assert ds.x.htype == "text" | ||
assert ds.y.htype == "json" | ||
assert ds.z.htype == "generic" |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,6 +26,7 @@ | |
) | ||
from hub.htype import ( | ||
HTYPE_CONFIGURATIONS, | ||
DEFAULT_HTYPE, | ||
) | ||
from hub.htype import HTYPE_CONFIGURATIONS, REQUIRE_USER_SPECIFICATION, UNSPECIFIED | ||
from hub.core.meta.meta import Meta | ||
|
@@ -57,22 +58,13 @@ def __init__( | |
**kwargs: Any key that the provided `htype` has can be overridden via **kwargs. For more information, check out `hub.htype`. | ||
""" | ||
|
||
if htype != UNSPECIFIED: | ||
_validate_htype_exists(htype) | ||
_validate_htype_overwrites(htype, kwargs) | ||
_replace_unspecified_values(htype, kwargs) | ||
_validate_required_htype_overwrites(htype, kwargs) | ||
_format_values(htype, kwargs) | ||
|
||
required_meta = _required_meta_from_htype(htype) | ||
required_meta.update(kwargs) | ||
super().__init__() | ||
|
||
self._required_meta_keys = tuple(required_meta.keys()) | ||
self.__dict__.update(required_meta) | ||
if htype != UNSPECIFIED: | ||
self.set_htype(htype, **kwargs) | ||
else: | ||
self._required_meta_keys = tuple() | ||
|
||
super().__init__() | ||
self.set_htype(DEFAULT_HTYPE, **kwargs) | ||
self.htype = None # type: ignore | ||
|
||
def set_dtype(self, dtype: np.dtype): | ||
"""Should only be called once.""" | ||
|
@@ -88,6 +80,38 @@ def set_dtype(self, dtype: np.dtype): | |
|
||
self.dtype = dtype.name | ||
|
||
def set_htype(self, htype: str, **kwargs): | ||
"""Should only be called once.""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. More documentation please.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The docstr is similar to that of set_dtype. This is an internal method and the conditions are explicit in the checks. |
||
ffw_tensor_meta(self) | ||
|
||
if getattr(self, "htype", None) is not None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add a test case when htype is not yet present. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This case is hit by default when create_tensor is called. |
||
raise ValueError( | ||
f"Tensor meta already has a htype ({self.htype}). Incoming: {htype}." | ||
) | ||
|
||
if getattr(self, "length", 0) > 0: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add a test case when length is not yet present. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This case is hit by default when create_tensor is called. |
||
raise ValueError("Htype was None, but length was > 0.") | ||
|
||
if not kwargs: | ||
kwargs = HTYPE_CONFIGURATIONS[htype] | ||
|
||
_validate_htype_exists(htype) | ||
_validate_htype_overwrites(htype, kwargs) | ||
_replace_unspecified_values(htype, kwargs) | ||
_validate_required_htype_overwrites(htype, kwargs) | ||
_format_values(htype, kwargs) | ||
|
||
required_meta = _required_meta_from_htype(htype) | ||
required_meta.update(kwargs) | ||
|
||
self._required_meta_keys = tuple(required_meta.keys()) | ||
|
||
for k in self._required_meta_keys: | ||
if getattr(self, k, None): | ||
required_meta.pop(k, None) | ||
|
||
self.__dict__.update(required_meta) | ||
|
||
def update_shape_interval(self, shape: Tuple[int, ...]): | ||
ffw_tensor_meta(self) | ||
|
||
|
@@ -219,7 +243,6 @@ def _replace_unspecified_values(htype: str, htype_overwrite: dict): | |
|
||
def _validate_required_htype_overwrites(htype: str, htype_overwrite: dict): | ||
"""Raises errors if `htype_overwrite` has invalid values.""" | ||
|
||
sample_compression = htype_overwrite["sample_compression"] | ||
sample_compression = COMPRESSION_ALIASES.get(sample_compression, sample_compression) | ||
if sample_compression not in hub.compressions: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -37,6 +37,21 @@ def get_dtype(val: Union[np.ndarray, Sequence, Sample]) -> np.dtype: | |
raise TypeError(f"Cannot infer numpy dtype for {val}") | ||
|
||
|
||
def get_htype(val: Union[np.ndarray, Sequence, Sample]) -> str: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Method is analogous to get_dtype. |
||
if isinstance(val, np.ndarray): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add docstring. |
||
return "generic" | ||
types = set((map(type, val))) | ||
if dict in types: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Don't all types have to be dict? Or, what's the case when not all are dicts (please add test case)? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Test added. |
||
return "json" | ||
if types == set((str,)): | ||
return "text" | ||
if np.object in [ # type: ignore | ||
farizrahman4u marked this conversation as resolved.
Show resolved
Hide resolved
|
||
np.array(x).dtype if not isinstance(x, np.ndarray) else x.dtype for x in val | ||
]: | ||
return "json" | ||
return "generic" | ||
|
||
|
||
def intelligent_cast( | ||
sample: Any, dtype: Union[np.dtype, str], htype: str | ||
) -> np.ndarray: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add test case for when
set_htype
is called andself.htype
is not None.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
set_htype is internal , analogous to set_dtype, that case shouldn't be hit from user space.