diff --git a/deeplake/api/tests/test_api.py b/deeplake/api/tests/test_api.py index c2a99218c9..f6169c9fe3 100644 --- a/deeplake/api/tests/test_api.py +++ b/deeplake/api/tests/test_api.py @@ -1211,6 +1211,7 @@ def test_htypes_list(): "point_cloud", "polygon", "segment_mask", + "tag", "text", "video", ] diff --git a/deeplake/api/tests/test_tag.py b/deeplake/api/tests/test_tag.py new file mode 100644 index 0000000000..fd380ebc6f --- /dev/null +++ b/deeplake/api/tests/test_tag.py @@ -0,0 +1,15 @@ +import numpy as np +import deeplake +import pytest + + +@pytest.mark.parametrize("chunk_compression", [None, "lz4"]) +def test_tag(memory_ds, chunk_compression): + with memory_ds as ds: + ds.create_tensor("abc", htype="tag", chunk_compression=chunk_compression) + ds.abc.append("a") + ds.abc.append(["a", "b"]) + ds.abc.extend(["a", ["b", "c"]]) + + np.testing.assert_array_equal(ds.abc.shapes(), np.array([[1], [2], [1], [2]])) + assert ds.abc.data()["value"] == [["a"], ["a", "b"], ["a"], ["b", "c"]] diff --git a/deeplake/core/chunk/base_chunk.py b/deeplake/core/chunk/base_chunk.py index 049c8ed382..a81063fa45 100644 --- a/deeplake/core/chunk/base_chunk.py +++ b/deeplake/core/chunk/base_chunk.py @@ -91,7 +91,7 @@ def __init__( else (len(tensor_meta.max_shape) if tensor_meta.max_shape else None) ) self.is_text_like = ( - self.htype in {"json", "list", "text"} or self.tensor_meta.is_link + self.htype in {"json", "list", "text", "tag"} or self.tensor_meta.is_link ) self.compression = compression @@ -328,7 +328,7 @@ def serialize_sample( incoming_sample = incoming_sample.path if incoming_sample is None: htype = "text" if self.tensor_meta.is_link else self.htype - empty_mapping = {"text": "", "list": [], "json": {}} + empty_mapping = {"text": "", "list": [], "json": {}, "tag": []} incoming_sample = empty_mapping[htype] if isinstance(incoming_sample, Sample): diff --git a/deeplake/core/chunk_engine.py b/deeplake/core/chunk_engine.py index 312d50e898..1b6d71ed39 100644 --- 
a/deeplake/core/chunk_engine.py +++ b/deeplake/core/chunk_engine.py @@ -261,7 +261,7 @@ def is_data_cachable(self): tensor_meta = self.tensor_meta return ( self.chunk_class == UncompressedChunk - and tensor_meta.htype not in ["text", "json", "list", "polygon"] + and tensor_meta.htype not in ["text", "json", "list", "polygon", "tag"] and tensor_meta.max_shape and (tensor_meta.max_shape == tensor_meta.min_shape) and (np.prod(tensor_meta.max_shape) < 20) @@ -746,6 +746,10 @@ def _sanitize_samples( p if isinstance(p, Polygons) else Polygons(p, dtype=tensor_meta.dtype) for p in samples ] + elif tensor_meta.htype == "tag": + samples = [ + sample if isinstance(sample, list) else [sample] for sample in samples + ] return samples, verified_samples def _convert_class_labels(self, samples): @@ -3055,7 +3059,7 @@ def get_empty_sample(self): raise ValueError("This tensor has no samples, cannot get empty sample.") htype = self.tensor_meta.htype dtype = self.tensor_meta.dtype - if htype in ("text", "json", "list"): + if htype in ("text", "json", "list", "tag"): return get_empty_text_like_sample(htype) ndim = len(self.tensor_meta.max_shape) if self.is_sequence: @@ -3066,7 +3070,7 @@ def get_empty_sample(self): @property def is_text_like(self): return ( - self.tensor_meta.htype in {"text", "json", "list"} + self.tensor_meta.htype in {"text", "json", "list", "tag"} or self.tensor_meta.is_link ) diff --git a/deeplake/core/dataset/dataset.py b/deeplake/core/dataset/dataset.py index 18ef706267..277f0557d0 100644 --- a/deeplake/core/dataset/dataset.py +++ b/deeplake/core/dataset/dataset.py @@ -858,7 +858,7 @@ def _create_tensor( "nifti", ): self._create_sample_info_tensor(name) - if create_shape_tensor and htype not in ("text", "json"): + if create_shape_tensor and htype not in ("text", "json", "tag"): self._create_sample_shape_tensor(name, htype=htype) if create_id_tensor: self._create_sample_id_tensor(name) diff --git a/deeplake/core/meta/tensor_meta.py 
b/deeplake/core/meta/tensor_meta.py index 529ff40df5..8d90d39208 100644 --- a/deeplake/core/meta/tensor_meta.py +++ b/deeplake/core/meta/tensor_meta.py @@ -361,7 +361,7 @@ def _validate_required_htype_overwrites(htype: str, htype_overwrite: dict): ) if htype_overwrite["dtype"] is not None: - if htype in ("json", "list"): + if htype in ("json", "list", "tag"): validate_json_schema(htype_overwrite["dtype"]) else: _raise_if_condition( @@ -385,7 +385,7 @@ def _format_values(htype: str, htype_overwrite: dict): dtype = htype_overwrite["dtype"] if dtype is not None: - if htype in ("json", "list", "intrinsics"): + if htype in ("json", "list", "tag", "intrinsics"): if getattr(dtype, "__module__", None) == "typing": htype_overwrite["dtype"] = str(dtype) else: diff --git a/deeplake/core/sample.py b/deeplake/core/sample.py index 3ec1defc42..58b495b294 100644 --- a/deeplake/core/sample.py +++ b/deeplake/core/sample.py @@ -147,7 +147,7 @@ def buffer(self): @property def is_text_like(self): - return self.htype in {"text", "list", "json"} + return self.htype in {"text", "list", "json", "tag"} @property def dtype(self): diff --git a/deeplake/core/serialize.py b/deeplake/core/serialize.py index 42513dadd1..770924afbb 100644 --- a/deeplake/core/serialize.py +++ b/deeplake/core/serialize.py @@ -452,7 +452,7 @@ def check_sample_shape(shape, num_dims): def text_to_bytes(sample, dtype, htype): - if htype in ("json", "list"): + if htype in ("json", "list", "tag"): if isinstance(sample, np.ndarray): if htype == "list": sample = list(sample) if sample.dtype == object else sample.tolist() @@ -463,7 +463,7 @@ def text_to_bytes(sample, dtype, htype): sample = list(sample) validate_json_object(sample, dtype) byts = json.dumps(sample, cls=HubJsonEncoder).encode() - shape = (len(sample),) if htype == "list" else (1,) + shape = (len(sample),) if htype in ("list", "tag") else (1,) else: # htype == "text": if isinstance(sample, np.ndarray) and sample.size == 1: sample = str(sample.reshape(())) @@ 
-480,7 +480,7 @@ def bytes_to_text(buffer, htype): arr = np.empty(1, dtype=object) arr[0] = json.loads(bytes.decode(buffer), cls=HubJsonDecoder) return arr - elif htype == "list": + elif htype in ("list", "tag"): lst = json.loads(bytes.decode(buffer), cls=HubJsonDecoder) arr = np.empty(len(lst), dtype=object) arr[:] = lst diff --git a/deeplake/core/tensor.py b/deeplake/core/tensor.py index be03f19a4e..3c36a387b3 100644 --- a/deeplake/core/tensor.py +++ b/deeplake/core/tensor.py @@ -570,7 +570,7 @@ def ndim(self) -> int: @property def dtype(self) -> Optional[np.dtype]: """Dtype of the tensor.""" - if self.base_htype in ("json", "list"): + if self.base_htype in ("json", "list", "tag"): return np.dtype(str) if self.meta.dtype: return np.dtype(self.meta.typestr or self.meta.dtype) @@ -977,7 +977,7 @@ def data(self, aslist: bool = False, fetch_chunks: bool = False) -> Any: return {"value": self.text(fetch_chunks=fetch_chunks)} if htype == "json": return {"value": self.dict(fetch_chunks=fetch_chunks)} - if htype == "list": + if htype in ("list", "tag"): return {"value": self.list(fetch_chunks=fetch_chunks)} if self.htype == "video": data = {} @@ -1441,9 +1441,9 @@ def dict(self, fetch_chunks: bool = False): return self._extract_value("json", fetch_chunks=fetch_chunks) def list(self, fetch_chunks: bool = False): - """Return list data. Only applicable for tensors with 'list' base htype.""" - if self.base_htype != "list": - raise Exception("Only supported for list tensors.") + """Return list data. 
Only applicable for tensors with 'list' or 'tag' base htype.""" + if self.base_htype not in ("list", "tag"): + raise Exception("Only supported for list and tag tensors.") if self.ndim == 1: return list(self.numpy(fetch_chunks=fetch_chunks)) diff --git a/deeplake/enterprise/test_pytorch.py b/deeplake/enterprise/test_pytorch.py index c2b7253dfa..b45529cde5 100644 --- a/deeplake/enterprise/test_pytorch.py +++ b/deeplake/enterprise/test_pytorch.py @@ -402,6 +402,35 @@ def test_string_tensors(local_auth_ds): np.testing.assert_array_equal(batch["strings"], f"string{idx}") +@pytest.mark.slow +@requires_torch +@requires_libdeeplake +@pytest.mark.flaky +def test_tag_tensors(local_auth_ds): + with local_auth_ds as ds: + ds.create_tensor("tags", htype="tag") + ds.tags.extend( + [ + f"tag{idx}" if idx % 2 == 0 else [f"tag{idx}", f"tag{idx}"] + for idx in range(5) + ] + ) + + ptds = ds.pytorch(batch_size=1) + for idx, batch in enumerate(ptds): + if idx % 2 == 0: + np.testing.assert_array_equal(batch["tags"], [[f"tag{idx}"]]) + else: + np.testing.assert_array_equal(batch["tags"], [[f"tag{idx}", f"tag{idx}"]]) + + ptds2 = ds.pytorch(batch_size=None) + for idx, batch in enumerate(ptds2): + if idx % 2 == 0: + np.testing.assert_array_equal(batch["tags"], [f"tag{idx}"]) + else: + np.testing.assert_array_equal(batch["tags"], [f"tag{idx}", f"tag{idx}"]) + + @pytest.mark.xfail(raises=NotImplementedError, strict=True) def test_pytorch_large(): raise NotImplementedError diff --git a/deeplake/htype.py b/deeplake/htype.py index 56e770d724..2933f20932 100644 --- a/deeplake/htype.py +++ b/deeplake/htype.py @@ -19,6 +19,7 @@ class htype: IMAGE_RGB = "image.rgb" IMAGE_GRAY = "image.gray" CLASS_LABEL = "class_label" + TAG = "tag" BBOX = "bbox" BBOX_3D = "bbox.3d" VIDEO = "video" @@ -95,6 +96,7 @@ class htype: }, htype.LIST: {"dtype": "List"}, htype.TEXT: {"dtype": "str"}, + htype.TAG: {"dtype": "List"}, htype.DICOM: {"sample_compression": "dcm"}, htype.NIFTI: {}, htype.POINT_CLOUD: {"dtype": 
"float32"}, @@ -135,6 +137,11 @@ def CLASS_LABEL(shape, dtype): constraints.ndim_error("class_label", len(shape)) ) + @staticmethod + def TAG(shape, dtype): + if dtype.name != "str": + raise IncompatibleHtypeError(constraints.dtype_error("tag", dtype)) + @staticmethod def BBOX(shape, dtype): if len(shape) not in (2, 3): @@ -185,6 +192,7 @@ def POINT(shape, dtype): HTYPE_CONSTRAINTS: Dict[str, Callable] = { htype.IMAGE: constraints.IMAGE, htype.CLASS_LABEL: constraints.CLASS_LABEL, + htype.TAG: constraints.TAG, htype.BBOX: constraints.BBOX, htype.BBOX_3D: constraints.BBOX_3D, htype.EMBEDDING: constraints.EMBEDDING, @@ -211,6 +219,7 @@ def POINT(shape, dtype): htype.IMAGE_GRAY: _image_compressions, htype.VIDEO: VIDEO_COMPRESSIONS[:], htype.AUDIO: AUDIO_COMPRESSIONS[:], + htype.TAG: BYTE_COMPRESSIONS[:], htype.TEXT: BYTE_COMPRESSIONS[:], htype.LIST: BYTE_COMPRESSIONS[:], htype.JSON: BYTE_COMPRESSIONS[:], diff --git a/deeplake/integrations/pytorch/common.py b/deeplake/integrations/pytorch/common.py index 0e8e3cc0bd..2261765053 100644 --- a/deeplake/integrations/pytorch/common.py +++ b/deeplake/integrations/pytorch/common.py @@ -22,7 +22,10 @@ def collate_fn(batch): ) if isinstance(elem, np.ndarray) and elem.size > 0 and isinstance(elem[0], str): - batch = [it[0] for it in batch] + if elem.dtype == object: + return [it.tolist() for it in batch] + else: + batch = [it[0] for it in batch] elif isinstance(elem, (tuple, list)) and len(elem) > 0 and isinstance(elem[0], str): batch = [it[0] for it in batch] elif isinstance(elem, Polygons): @@ -52,7 +55,10 @@ def convert_fn(data): if isinstance(data, IterableOrderedDict): return IterableOrderedDict((k, convert_fn(v)) for k, v in data.items()) if isinstance(data, np.ndarray) and data.size > 0 and isinstance(data[0], str): - data = data[0] + if data.dtype == object: + return data.tolist() + else: + data = data[0] elif isinstance(data, Polygons): data = data.numpy() @@ -85,6 +91,7 @@ def check_tensors(dataset, tensors, 
verbose=True): jpeg_png_compressed_tensors = [] json_tensors = [] list_tensors = [] + tag_tensors = [] supported_image_compressions = {"png", "jpeg"} for tensor_name in tensors: tensor = dataset._get_tensor_from_root(tensor_name) @@ -101,6 +108,8 @@ def check_tensors(dataset, tensors, verbose=True): json_tensors.append(tensor_name) elif meta.htype == "list": list_tensors.append(tensor_name) + elif meta.htype == "tag": + tag_tensors.append(tensor_name) if verbose and (json_tensors or list_tensors): json_list_tensors = set(json_tensors + list_tensors) @@ -108,6 +117,8 @@ def check_tensors(dataset, tensors, verbose=True): f"The following tensors have json or list htype: {json_list_tensors}. Collation of these tensors will fail by default. Ensure that these tensors are either transformed by specifying a transform or a custom collate_fn is specified to handle them." ) + list_tensors += tag_tensors + return jpeg_png_compressed_tensors, json_tensors, list_tensors diff --git a/deeplake/integrations/tests/test_pytorch.py b/deeplake/integrations/tests/test_pytorch.py index bc74930f52..b680eb08d3 100644 --- a/deeplake/integrations/tests/test_pytorch.py +++ b/deeplake/integrations/tests/test_pytorch.py @@ -495,6 +495,33 @@ def test_string_tensors(local_ds): np.testing.assert_array_equal(batch["strings"], f"string{idx}") +@requires_torch +@pytest.mark.flaky +def test_tag_tensors(local_ds): + with local_ds: + local_ds.create_tensor("tags", htype="tag") + local_ds.tags.extend( + [ + f"tag{idx}" if idx % 2 == 0 else [f"tag{idx}", f"tag{idx}"] + for idx in range(5) + ] + ) + + ptds = local_ds.pytorch(batch_size=1) + for idx, batch in enumerate(ptds): + if idx % 2 == 0: + np.testing.assert_array_equal(batch["tags"], [[f"tag{idx}"]]) + else: + np.testing.assert_array_equal(batch["tags"], [[f"tag{idx}", f"tag{idx}"]]) + + ptds2 = local_ds.pytorch(batch_size=None) + for idx, batch in enumerate(ptds2): + if idx % 2 == 0: + np.testing.assert_array_equal(batch["tags"], [f"tag{idx}"]) + 
else: + np.testing.assert_array_equal(batch["tags"], [f"tag{idx}", f"tag{idx}"]) + + @pytest.mark.slow @requires_torch @pytest.mark.flaky diff --git a/deeplake/util/casting.py b/deeplake/util/casting.py index 895bd6c790..52ef2217ef 100644 --- a/deeplake/util/casting.py +++ b/deeplake/util/casting.py @@ -79,7 +79,7 @@ def get_empty_text_like_sample(htype: str): return "" elif htype == "json": return {} - elif htype == "list": + elif htype == "list" or htype == "tag": return [] else: raise ValueError( diff --git a/deeplake/util/htype.py b/deeplake/util/htype.py index 3aba8f713b..d87180d67f 100644 --- a/deeplake/util/htype.py +++ b/deeplake/util/htype.py @@ -27,6 +27,8 @@ def parse_complex_htype(htype: Optional[str]) -> Tuple[bool, bool, Optional[str] raise ValueError( "Can't create a linked tensor with a generic htype, you need to specify htype, for example link[image]" ) + if is_sequence and htype == HTYPE.TAG: + raise ValueError("Htype sequence[tag] is not supported") return is_sequence, is_link, htype diff --git a/docs/source/Htypes.rst b/docs/source/Htypes.rst index d25ba5911d..760f7d06ec 100644 --- a/docs/source/Htypes.rst +++ b/docs/source/Htypes.rst @@ -221,6 +221,44 @@ Appending text labels >>> ds.labels.append(["cars", "airplanes"]) +.. _tag-htype: + +Tag Htype +~~~~~~~~~ + +- :bluebold:`Sample dimensions:` ``(# tags,)`` + +This htype can be used to tag samples with one or more string values. + +:blue:`Creating a tag tensor` +----------------------------- + +A tag tensor can be created using + +>>> ds.create_tensor("tags", htype="tag", chunk_compression="lz4") + +- Optional args: + - :ref:`chunk_compression <chunk_compression>`. + +- Supported compressions: + +>>> ["lz4"] + +:blue:`Appending tag samples` +----------------------------- + +- Tag samples can be appended as ``str`` or ``list`` of ``str``. + +:bluebold:`Examples` + +Appending a tag + +>>> ds.tags.append("verified") + +Extending with a list of single-tag samples + +>>> ds.tags.extend(["verified", "unverified"]) + .. 
_bbox-htype: Bounding Box Htype diff --git a/docs/source/_static/csv/htypes.csv b/docs/source/_static/csv/htypes.csv index f220c012f8..c4f5131f8c 100644 --- a/docs/source/_static/csv/htypes.csv +++ b/docs/source/_static/csv/htypes.csv @@ -6,6 +6,7 @@ generic, None, None :ref:`video `, uint8, Required arg :ref:`audio `, float64, Required arg :ref:`class_label `, uint32, None +:ref:`tag <tag-htype>`, str, None :ref:`bbox <bbox-htype>`, float32, None :ref:`bbox.3d `, float32, None :ref:`intrinsics `,float32,None