diff --git a/hub/api/dataset.py b/hub/api/dataset.py old mode 100644 new mode 100755 index 2086d0328a..e3b32b4d57 --- a/hub/api/dataset.py +++ b/hub/api/dataset.py @@ -10,6 +10,7 @@ import collections.abc as abc import json import sys +from typing import Iterable import traceback from collections import defaultdict import numpy as np @@ -43,14 +44,15 @@ _get_compressor, _get_dynamic_tensor_dtype, _store_helper, + check_class_label, same_schema, ) import hub.schema.serialize import hub.schema.deserialize from hub.schema.features import flatten -from hub.schema import ClassLabel from hub import auto + from hub.store.dynamic_tensor import DynamicTensor from hub.store.store import get_fs_and_path, get_storage_map from hub.exceptions import ( @@ -595,9 +597,6 @@ def __setitem__(self, slice_, value): if "r" in self._mode: raise ReadModeException("__setitem__") self._auto_checkout() - assign_value = get_value(value) - # handling strings and bytes - assign_value = str_to_int(assign_value, self.tokenizer) if not isinstance(slice_, abc.Iterable) or isinstance(slice_, str): slice_ = [slice_] @@ -609,6 +608,24 @@ def __setitem__(self, slice_, value): elif subpath not in self.keys: raise KeyError(f"Key {subpath} not found in the dataset") + assign_value = get_value(value) + schema_dict = self.schema + if subpath[1:] in schema_dict.dict_.keys(): + schema_key = schema_dict.dict_.get(subpath[1:], None) + else: + for schema_key in subpath[1:].split("/"): + schema_dict = schema_dict.dict_.get(schema_key, None) + if not isinstance(schema_dict, SchemaDict): + schema_key = schema_dict + if isinstance(schema_key, ClassLabel): + assign_value = check_class_label(assign_value, schema_key) + if isinstance(schema_key, (Text, bytes)) or ( + isinstance(assign_value, Iterable) + and any(isinstance(val, str) for val in assign_value) + ): + # handling strings and bytes + assign_value = str_to_int(assign_value, self.tokenizer) + if not slice_list: self._tensors[subpath][:] = assign_value else: 
def check_class_label(value: Union[np.ndarray, list], label: "ClassLabel"):
    """Validate ``value`` against ``label`` and convert class names to ints.

    Args:
        value: A single class index or class name, or an iterable of them
            (indices and names may be mixed).
        label: The ClassLabel being assigned to. Only its ``str2int``,
            ``names`` and ``num_classes`` members are used here.

    Returns:
        The validated value with every class name replaced by the integer
        returned by ``label.str2int``. A scalar input, or an iterable holding
        exactly one element, is returned as that single element. An iterable
        that contained no names is returned unchanged; otherwise a new list
        is returned and the caller's object is left untouched. An empty
        iterable is returned unchanged because there is nothing to validate.

    Raises:
        ClassLabelValueError: If ``label.str2int`` raises ``KeyError`` for a
            name, or an index lies outside ``[0, label.num_classes - 1]``.
    """
    is_scalar = not isinstance(value, Iterable) or isinstance(value, str)
    candidates = [value] if is_scalar else value

    upper_bound = label.num_classes - 1
    # Collect results in a fresh list instead of writing back into the input:
    # in-place assignment mutates the caller's data, fails on tuples and
    # silently stringifies ints when the input is a string ndarray.
    converted = []
    had_names = False
    for candidate in candidates:
        if isinstance(candidate, str):
            had_names = True
            try:
                candidate = label.str2int(candidate)
            except KeyError as err:
                # Only KeyError is translated; any other exception raised by
                # str2int propagates to the caller unchanged.
                raise ClassLabelValueError(label.names, candidate) from err
        if candidate < 0 or candidate > upper_bound:
            # Report the full valid range and the element that is actually
            # out of range.
            raise ClassLabelValueError(range(label.num_classes), candidate)
        converted.append(candidate)

    if len(converted) == 1:
        return converted[0]
    # Keep the caller's container (e.g. an ndarray) when nothing was converted.
    return converted if had_names else value
If a copy of the MPL was not distributed with this file, You can obtain one at https://mozilla.org/MPL/2.0/. """ - -from hub.utils import _tuple_product +from typing import Iterable from hub.api.tensorview import TensorView import collections.abc as abc from hub.api.dataset_utils import ( @@ -13,10 +12,11 @@ slice_split, str_to_int, _store_helper, + check_class_label, ) from hub.exceptions import NoneValueException from hub.api.objectview import ObjectView -from hub.schema import Sequence +from hub.schema import Sequence, ClassLabel, Text, SchemaDict import numpy as np @@ -134,10 +134,6 @@ def __setitem__(self, slice_, value): >>> ds_view["image", 3, 0:1920, 0:1080, 0:3] = np.zeros((1920, 1080, 3), "uint8") # sets the 8th image """ self.dataset._auto_checkout() - assign_value = get_value(value) - assign_value = str_to_int( - assign_value, self.dataset.tokenizer - ) # handling strings and bytes if not isinstance(slice_, abc.Iterable) or isinstance(slice_, str): slice_ = [slice_] @@ -145,6 +141,24 @@ def __setitem__(self, slice_, value): subpath, slice_list = slice_split(slice_) slice_list = [0] + slice_list if isinstance(self.indexes, int) else slice_list + assign_value = get_value(value) + schema_dict = self.dataset.schema + if subpath[1:] in schema_dict.dict_.keys(): + schema_key = schema_dict.dict_.get(subpath[1:], None) + else: + for schema_key in subpath[1:].split("/"): + schema_dict = schema_dict.dict_.get(schema_key, None) + if not isinstance(schema_dict, SchemaDict): + schema_key = schema_dict + if isinstance(schema_key, ClassLabel): + assign_value = check_class_label(assign_value, schema_key) + if isinstance(schema_key, (Text, bytes)) or ( + isinstance(assign_value, Iterable) + and any(isinstance(val, str) for val in assign_value) + ): + # handling strings and bytes + assign_value = str_to_int(assign_value, self.dataset.tokenizer) + if not subpath: raise ValueError("Can't assign to dataset sliced without key") elif subpath not in self.keys: diff --git 
a/hub/api/tensorview.py b/hub/api/tensorview.py old mode 100644 new mode 100755 index 7594ed10ca..42b6582cd8 --- a/hub/api/tensorview.py +++ b/hub/api/tensorview.py @@ -3,12 +3,13 @@ This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at https://mozilla.org/MPL/2.0/. """ - import numpy as np +from typing import Iterable import hub import collections.abc as abc -from hub.api.dataset_utils import get_value, slice_split, str_to_int +from hub.api.dataset_utils import get_value, slice_split, str_to_int, check_class_label from hub.exceptions import NoneValueException +from hub.schema import ClassLabel, Text, SchemaDict import hub.api.objectview as objv @@ -196,9 +197,6 @@ def __setitem__(self, slice_, value): >>> images_tensorview[7, 0:1920, 0:1080, 0:3] = np.zeros((1920, 1080, 3), "uint8") # sets 7th image """ self.dataset._auto_checkout() - assign_value = get_value(value) - # handling strings and bytes - assign_value = str_to_int(assign_value, self.dataset.tokenizer) if not isinstance(slice_, abc.Iterable) or isinstance(slice_, str): slice_ = [slice_] @@ -207,6 +205,25 @@ def __setitem__(self, slice_, value): subpath, slice_list = slice_split(slice_) if subpath: raise ValueError("Can't setitem of TensorView with subpath") + + assign_value = get_value(value) + schema_dict = self.dataset.schema + if subpath[1:] in schema_dict.dict_.keys(): + schema_key = schema_dict.dict_.get(subpath[1:], None) + else: + for schema_key in subpath[1:].split("/"): + schema_dict = schema_dict.dict_.get(schema_key, None) + if not isinstance(schema_dict, SchemaDict): + schema_key = schema_dict + if isinstance(schema_key, ClassLabel): + assign_value = check_class_label(assign_value, schema_key) + if isinstance(schema_key, (Text, bytes)) or ( + isinstance(assign_value, Iterable) + and any(isinstance(val, str) for val in assign_value) + ): + # handling strings and bytes + assign_value = 
def test_class_label_value():
    """Check that ClassLabel assignment accepts valid values and rejects invalid ones."""
    ds = Dataset(
        "./data/tests/test_check_label",
        mode="w",
        shape=(5,),
        schema={
            "label": ClassLabel(names=["name1", "name2", "name3"]),
            "label/b": ClassLabel(num_classes=5),
        },
    )

    # Assignments that are expected to succeed: a scalar index, an index
    # array, a list of names, and a list mixing indices with names.
    ds["label", 0:7] = 2
    ds["label", 0:2] = np.array([0, 1])
    ds["label", 0:3] = ["name1", "name2", "name3"]
    ds[0:3]["label"] = [0, "name2", 2]

    # pytest.raises fails the test when nothing is raised; a bare try/except
    # that only inspects the caught exception would pass silently instead.
    with pytest.raises(ClassLabelValueError):
        ds["label/b", 0] = 6
    with pytest.raises(ClassLabelValueError):
        ds[0:4]["label/b"] = np.array([0, 1, 2, 3, 7])
    with pytest.raises(ClassLabelValueError):
        ds["label", 4] = "name4"
    # A name assigned to a label declared only with num_classes is expected
    # to surface as a plain ValueError rather than ClassLabelValueError.
    with pytest.raises(ValueError):
        ds[0]["label/b"] = ["name"]
class ClassLabelValueError(HubException):
    """Error for a value that is not among the classes a ClassLabel expects."""

    def __init__(self, expected_classes, value):
        # expected_classes and value are only interpolated into the message.
        super(HubException, self).__init__(
            message=f"Expected ClassLabel to be in {expected_classes}, got {value}"
        )