From c90cf90ef726cf10b3763f55015d0588e27ba2d2 Mon Sep 17 00:00:00 2001 From: kristinagrig06 Date: Thu, 25 Mar 2021 14:16:53 +0400 Subject: [PATCH 1/7] Add ClassLabel value checks --- hub/api/dataset.py | 23 +++++++++++++++++++++++ hub/api/datasetview.py | 27 +++++++++++++++++++++++++-- hub/api/tensorview.py | 27 ++++++++++++++++++++++++++- hub/exceptions.py | 6 ++++++ 4 files changed, 80 insertions(+), 3 deletions(-) mode change 100644 => 100755 hub/api/dataset.py mode change 100644 => 100755 hub/api/datasetview.py mode change 100644 => 100755 hub/api/tensorview.py mode change 100644 => 100755 hub/exceptions.py diff --git a/hub/api/dataset.py b/hub/api/dataset.py old mode 100644 new mode 100755 index 2d881af3e8..bce08f5b5d --- a/hub/api/dataset.py +++ b/hub/api/dataset.py @@ -14,6 +14,7 @@ from collections import defaultdict import numpy as np from PIL import Image as im, ImageChops +from typing import Iterable import fsspec from fsspec.spec import AbstractFileSystem @@ -63,6 +64,7 @@ VersioningNotSupportedException, WrongUsernameException, InvalidVersionInfoException, + ClassLabelValueError, ) from hub.store.metastore import MetaStorage from hub.client.hub_control import HubControlClient @@ -629,6 +631,27 @@ def __setitem__(self, slice_, value): elif subpath not in self.keys: raise KeyError(f"Key {subpath} not found in the dataset") + subpath_type = self.schema.dict_[subpath.replace("/", "")] + if isinstance(subpath_type, ClassLabel): + if not isinstance(value, Iterable) or isinstance(value, str): + assign_class_labels = [value] + else: + assign_class_labels = value + for assign_class_label in assign_class_labels: + if assign_class_label.isdigit(): + assign_class_label = int(assign_class_label) + if ( + isinstance(assign_class_label, str) + and assign_class_label not in subpath_type.names + ): + raise ClassLabelValueError(subpath_type.names, assign_class_label) + elif ( + isinstance(assign_class_label, int) + and assign_class_label >= subpath_type.num_classes + ): + raise ClassLabelValueError( + range(subpath_type.num_classes - 1), assign_class_label + ) if not slice_list: self._tensors[subpath][:] = assign_value else: diff --git a/hub/api/datasetview.py b/hub/api/datasetview.py old mode 100644 new mode 100755 index c04303c438..2dea108a68 --- a/hub/api/datasetview.py +++ b/hub/api/datasetview.py @@ -4,6 +4,7 @@ If a copy of the MPL was not distributed with this file, You can obtain one at https://mozilla.org/MPL/2.0/. """ +from typing import Iterable from hub.utils import _tuple_product from hub.api.tensorview import TensorView import collections.abc as abc @@ -13,9 +14,9 @@ slice_split, str_to_int, ) -from hub.exceptions import NoneValueException +from hub.exceptions import NoneValueException, ClassLabelValueError from hub.api.objectview import ObjectView -from hub.schema import Sequence +from hub.schema import Sequence, ClassLabel import numpy as np @@ -144,6 +145,28 @@ def __setitem__(self, slice_, value): subpath, slice_list = slice_split(slice_) slice_list = [0] + slice_list if isinstance(self.indexes, int) else slice_list + subpath_type = self.dataset.schema.dict_[subpath.replace("/", "")] + if isinstance(subpath_type, ClassLabel): + if not isinstance(value, Iterable) or isinstance(value, str): + assign_class_labels = [value] + else: + assign_class_labels = value + for assign_class_label in assign_class_labels: + if assign_class_label.isdigit(): + assign_class_label = int(assign_class_label) + if ( + isinstance(assign_class_label, str) + and assign_class_label not in subpath_type.names + ): + raise ClassLabelValueError(subpath_type.names, assign_class_label) + elif ( + isinstance(assign_class_label, int) + and assign_class_label >= subpath_type.num_classes + ): + raise ClassLabelValueError( + range(subpath_type.num_classes - 1), assign_class_label + ) + if not subpath: raise ValueError("Can't assign to dataset sliced without key") elif subpath not in self.keys: diff --git a/hub/api/tensorview.py b/hub/api/tensorview.py old mode 100644 new mode 100755 index 1f49839995..4b48334877 --- a/hub/api/tensorview.py +++ b/hub/api/tensorview.py @@ -4,12 +4,14 @@ If a copy of the MPL was not distributed with this file, You can obtain one at https://mozilla.org/MPL/2.0/. """ +from typing import Iterable import numpy as np import hub import collections.abc as abc from hub.api.dataset_utils import get_value, slice_split, str_to_int -from hub.exceptions import NoneValueException +from hub.exceptions import NoneValueException, ClassLabelValueError import hub.api.objectview as objv +from hub.schema import ClassLabel class TensorView: @@ -214,6 +216,29 @@ def __setitem__(self, slice_, value): subpath, slice_list = slice_split(slice_) if subpath: raise ValueError("Can't setitem of TensorView with subpath") + + subpath_type = self.dataset.schema.dict_[self.subpath.replace("/", "")] + if isinstance(subpath_type, ClassLabel): + if not isinstance(value, Iterable) or isinstance(value, str): + assign_class_labels = [value] + else: + assign_class_labels = value + for assign_class_label in assign_class_labels: + if assign_class_label.isdigit(): + assign_class_label = int(assign_class_label) + if ( + isinstance(assign_class_label, str) + and assign_class_label not in subpath_type.names + ): + raise ClassLabelValueError(subpath_type.names, assign_class_label) + elif ( + isinstance(assign_class_label, int) + and assign_class_label >= subpath_type.num_classes + ): + raise ClassLabelValueError( + range(subpath_type.num_classes - 1), assign_class_label + ) + new_nums = self.nums.copy() new_offsets = self.offsets.copy() if isinstance(self.indexes, list): diff --git a/hub/exceptions.py b/hub/exceptions.py old mode 100644 new mode 100755 index 006e29d209..85395dfd7a --- a/hub/exceptions.py +++ b/hub/exceptions.py @@ -210,6 +210,12 @@ def __init__(self, correct_shape, wrong_shape): hub_reporter.error_report(self, tags=hub_tags) +class ClassLabelValueError(HubException): + def __init__(self, expected_classes, value): + message = f"Expected ClassLabel to be in {expected_classes}, got {value}" + super(HubException, self).__init__(message=message) + + class NoneValueException(HubException): def __init__(self, param): message = f"Parameter '{param}' should be provided" From a2c7d0a8b667df17cc9c5d78706ac73082b5e1ad Mon Sep 17 00:00:00 2001 From: kristinagrig06 Date: Mon, 5 Apr 2021 11:12:23 +0400 Subject: [PATCH 2/7] Move check to function, add test --- hub/api/dataset.py | 23 +--------- hub/api/dataset_utils.py | 32 +++++++++++++- hub/api/datasetview.py | 23 ++-------- hub/api/tensorview.py | 26 ++--------- hub/api/tests/test_dataset.py | 82 ++++++++++++++++++++++------------- 5 files changed, 90 insertions(+), 96 deletions(-) diff --git a/hub/api/dataset.py b/hub/api/dataset.py index bce08f5b5d..a61e29b425 100755 --- a/hub/api/dataset.py +++ b/hub/api/dataset.py @@ -41,12 +41,12 @@ slice_split, str_to_int, _copy_helper, + check_class_label, ) import hub.schema.serialize import hub.schema.deserialize from hub.schema.features import flatten -from hub.schema import ClassLabel from hub.store.dynamic_tensor import DynamicTensor from hub.store.store import get_fs_and_path, get_storage_map @@ -64,7 +64,6 @@ VersioningNotSupportedException, WrongUsernameException, InvalidVersionInfoException, - ClassLabelValueError, ) from hub.store.metastore import MetaStorage from hub.client.hub_control import HubControlClient @@ -633,25 +632,7 @@ def __setitem__(self, slice_, value): subpath_type = self.schema.dict_[subpath.replace("/", "")] if isinstance(subpath_type, ClassLabel): - if not isinstance(value, Iterable) or isinstance(value, str): - assign_class_labels = [value] - else: - assign_class_labels = value - for assign_class_label in assign_class_labels: - if assign_class_label.isdigit(): - assign_class_label = int(assign_class_label) - if ( - isinstance(assign_class_label, str) - and assign_class_label not in subpath_type.names - ): - raise ClassLabelValueError(subpath_type.names, assign_class_label) - elif ( - isinstance(assign_class_label, int) - and assign_class_label >= subpath_type.num_classes - ): - raise ClassLabelValueError( - range(subpath_type.num_classes - 1), assign_class_label - ) + check_class_label(value, subpath_type) if not slice_list: self._tensors[subpath][:] = assign_value else: diff --git a/hub/api/dataset_utils.py b/hub/api/dataset_utils.py index f5f8b2145d..f808427c84 100644 --- a/hub/api/dataset_utils.py +++ b/hub/api/dataset_utils.py @@ -5,12 +5,18 @@ """ import os +import time +from typing import Union, Iterable from hub.store.store import get_fs_and_path import numpy as np import sys -from hub.exceptions import ModuleNotInstalledException, DirectoryNotEmptyException +from hub.exceptions import ( + ModuleNotInstalledException, + DirectoryNotEmptyException, + ClassLabelValueError, +) import hashlib -import time +from hub.schema import ClassLabel def slice_split(slice_): @@ -198,3 +204,25 @@ def _copy_helper( src_fs=src_fs, ) return dst_url + + +def check_class_label(value: Union[np.ndarray, list], subpath_type=None): + if not isinstance(value, Iterable) or isinstance(value, str): + assign_class_labels = [value] + else: + assign_class_labels = value + for assign_class_label in assign_class_labels: + if str(assign_class_label).isdigit(): + assign_class_label = int(assign_class_label) + if ( + isinstance(assign_class_label, str) + and assign_class_label not in subpath_type.names + ): + raise ClassLabelValueError(subpath_type.names, assign_class_label) + elif ( + isinstance(assign_class_label, int) + and assign_class_label >= subpath_type.num_classes + ): + raise ClassLabelValueError( + range(subpath_type.num_classes - 1), assign_class_label + ) diff --git a/hub/api/datasetview.py b/hub/api/datasetview.py index 2dea108a68..9feb3365f0 100755 --- a/hub/api/datasetview.py +++ b/hub/api/datasetview.py @@ -13,8 +13,9 @@ get_value, slice_split, str_to_int, + check_class_label, ) -from hub.exceptions import NoneValueException, ClassLabelValueError +from hub.exceptions import NoneValueException from hub.api.objectview import ObjectView from hub.schema import Sequence, ClassLabel import numpy as np @@ -147,25 +148,7 @@ def __setitem__(self, slice_, value): subpath_type = self.dataset.schema.dict_[subpath.replace("/", "")] if isinstance(subpath_type, ClassLabel): - if not isinstance(value, Iterable) or isinstance(value, str): - assign_class_labels = [value] - else: - assign_class_labels = value - for assign_class_label in assign_class_labels: - if assign_class_label.isdigit(): - assign_class_label = int(assign_class_label) - if ( - isinstance(assign_class_label, str) - and assign_class_label not in subpath_type.names - ): - raise ClassLabelValueError(subpath_type.names, assign_class_label) - elif ( - isinstance(assign_class_label, int) - and assign_class_label >= subpath_type.num_classes - ): - raise ClassLabelValueError( - range(subpath_type.num_classes - 1), assign_class_label - ) + check_class_label(value, subpath_type) if not subpath: raise ValueError("Can't assign to dataset sliced without key") diff --git a/hub/api/tensorview.py b/hub/api/tensorview.py index 4b48334877..39ce87c1d9 100755 --- a/hub/api/tensorview.py +++ b/hub/api/tensorview.py @@ -8,10 +8,10 @@ import numpy as np import hub import collections.abc as abc -from hub.api.dataset_utils import get_value, slice_split, str_to_int -from hub.exceptions import NoneValueException, ClassLabelValueError -import hub.api.objectview as objv +from hub.api.dataset_utils import get_value, slice_split, str_to_int, check_class_label +from hub.exceptions import NoneValueException from hub.schema import ClassLabel +import hub.api.objectview as objv class TensorView: @@ -219,25 +219,7 @@ def __setitem__(self, slice_, value): subpath_type = self.dataset.schema.dict_[self.subpath.replace("/", "")] if isinstance(subpath_type, ClassLabel): - if not isinstance(value, Iterable) or isinstance(value, str): - assign_class_labels = [value] - else: - assign_class_labels = value - for assign_class_label in assign_class_labels: - if assign_class_label.isdigit(): - assign_class_label = int(assign_class_label) - if ( - isinstance(assign_class_label, str) - and assign_class_label not in subpath_type.names - ): - raise ClassLabelValueError(subpath_type.names, assign_class_label) - elif ( - isinstance(assign_class_label, int) - and assign_class_label >= subpath_type.num_classes - ): - raise ClassLabelValueError( - range(subpath_type.num_classes - 1), assign_class_label - ) + check_class_label(value, subpath_type) new_nums = self.nums.copy() new_offsets = self.offsets.copy() diff --git a/hub/api/tests/test_dataset.py b/hub/api/tests/test_dataset.py index d07dac4944..3e6617950b 100644 --- a/hub/api/tests/test_dataset.py +++ b/hub/api/tests/test_dataset.py @@ -4,14 +4,14 @@ If a copy of the MPL was not distributed with this file, You can obtain one at https://mozilla.org/MPL/2.0/. """ -from hub.api.dataset_utils import slice_extract_info, slice_split +from hub.api.dataset_utils import slice_extract_info, slice_split, check_class_label from hub.schema.class_label import ClassLabel import os import shutil import cloudpickle import pickle from hub.cli.auth import login_fn -from hub.exceptions import DirectoryNotEmptyException +from hub.exceptions import DirectoryNotEmptyException, ClassLabelValueError import numpy as np import pytest from hub import transform @@ -1107,6 +1107,25 @@ def test_check_label_name(): assert ds[1:3].compute().tolist() == [{"label": 2}, {"label": 0}] +def test_class_label_value(): + ds = Dataset( + "./data/tests/test_check_label", + mode="a", + shape=(5,), + schema={"label": ClassLabel(names=["name1", "name2", "name3"])}, + ) + ds["label", 0:7] = 2 + ds["label", 0:2] = np.array([0, 1]) + try: + ds["label", 0] = 4 + except Exception as ex: + assert isinstance(ex, ClassLabelValueError) + try: + ds[0:4]["label"] = np.array([0, 1, 2, 3]) + except Exception as ex: + assert isinstance(ex, ClassLabelValueError) + + @pytest.mark.skipif(not minio_creds_exist(), reason="requires minio credentials") def test_minio_endpoint(): token = { @@ -1129,32 +1148,33 @@ def test_minio_endpoint(): if __name__ == "__main__": - test_dataset_assign_value() - test_dataset_setting_shape() - test_datasetview_repr() - test_datasetview_get_dictionary() - test_tensorview_slicing() - test_datasetview_slicing() - test_dataset() - test_dataset_batch_write_2() - test_append_dataset() - test_dataset_2() - - test_text_dataset() - test_text_dataset_tokenizer() - test_dataset_compute() - test_dataset_view_compute() - test_dataset_lazy() - test_dataset_view_lazy() - test_dataset_hub() - test_meta_information() - test_dataset_filter_2() - test_dataset_filter_3() - test_pickleability() - test_dataset_append_and_read() - test_tensorview_iter() - test_dataset_filter_4() - test_datasetview_2() - test_dataset_3() - test_dataset_utils() - test_check_label_name() + # test_dataset_assign_value() + # test_dataset_setting_shape() + # test_datasetview_repr() + # test_datasetview_get_dictionary() + # test_tensorview_slicing() + # test_datasetview_slicing() + # test_dataset() + # test_dataset_batch_write_2() + # test_append_dataset() + # test_dataset_2() + + # test_text_dataset() + # test_text_dataset_tokenizer() + # test_dataset_compute() + # test_dataset_view_compute() + # test_dataset_lazy() + # test_dataset_view_lazy() + # test_dataset_hub() + # test_meta_information() + # test_dataset_filter_2() + # test_dataset_filter_3() + # test_pickleability() + # test_dataset_append_and_read() + # test_tensorview_iter() + # test_dataset_filter_4() + # test_datasetview_2() + # test_dataset_3() + # test_dataset_utils() + # test_check_label_name() + test_class_label_value() From e8d16b81f73a87ba0b861c01bf2eb4b42f4839a1 Mon Sep 17 00:00:00 2001 From: kristinagrig06 Date: Mon, 5 Apr 2021 11:37:48 +0400 Subject: [PATCH 3/7] Remove Iterable from imports --- hub/api/dataset.py | 1 - hub/api/dataset_utils.py | 1 + hub/api/datasetview.py | 1 - hub/api/tensorview.py | 2 -- 4 files changed, 1 insertion(+), 4 deletions(-) diff --git a/hub/api/dataset.py b/hub/api/dataset.py index 1eb03d5fe3..0b5c48479c 100755 --- a/hub/api/dataset.py +++ b/hub/api/dataset.py @@ -14,7 +14,6 @@ from collections import defaultdict import numpy as np from PIL import Image as im, ImageChops -from typing import Iterable import fsspec from fsspec.spec import AbstractFileSystem diff --git a/hub/api/dataset_utils.py b/hub/api/dataset_utils.py index 1ae415b251..2b44b1f80e 100644 --- a/hub/api/dataset_utils.py +++ b/hub/api/dataset_utils.py @@ -258,6 +258,7 @@ def _get_compressor(compressor: str): def check_class_label(value: Union[np.ndarray, list], subpath_type=None): + """Check if value can be assigned to predefined ClassLabel""" if not isinstance(value, Iterable) or isinstance(value, str): assign_class_labels = [value] else: diff --git a/hub/api/datasetview.py b/hub/api/datasetview.py index 070edda5a9..fb3e617060 100755 --- a/hub/api/datasetview.py +++ b/hub/api/datasetview.py @@ -4,7 +4,6 @@ If a copy of the MPL was not distributed with this file, You can obtain one at https://mozilla.org/MPL/2.0/. """ -from typing import Iterable from hub.utils import _tuple_product from hub.api.tensorview import TensorView import collections.abc as abc diff --git a/hub/api/tensorview.py b/hub/api/tensorview.py index 6cbf40c386..3a6f204aac 100755 --- a/hub/api/tensorview.py +++ b/hub/api/tensorview.py @@ -3,8 +3,6 @@ This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at https://mozilla.org/MPL/2.0/. """ - -from typing import Iterable import numpy as np import hub import collections.abc as abc From ce7b10d6122b7a86de43888039366a470cc83cf8 Mon Sep 17 00:00:00 2001 From: kristinagrig06 Date: Mon, 5 Apr 2021 12:30:33 +0400 Subject: [PATCH 4/7] Fix tests --- hub/api/dataset.py | 2 +- hub/api/dataset_utils.py | 5 ++--- hub/api/datasetview.py | 2 +- hub/api/tensorview.py | 2 +- 4 files changed, 5 insertions(+), 6 deletions(-) diff --git a/hub/api/dataset.py b/hub/api/dataset.py index 0b5c48479c..2682bcf093 100755 --- a/hub/api/dataset.py +++ b/hub/api/dataset.py @@ -612,7 +612,7 @@ def __setitem__(self, slice_, value): elif subpath not in self.keys: raise KeyError(f"Key {subpath} not found in the dataset") - subpath_type = self.schema.dict_[subpath.replace("/", "")] + subpath_type = self.schema.dict_.get(subpath.replace("/", ""), None) if isinstance(subpath_type, ClassLabel): check_class_label(value, subpath_type) if not slice_list: diff --git a/hub/api/dataset_utils.py b/hub/api/dataset_utils.py index 2b44b1f80e..d0aaedf429 100644 --- a/hub/api/dataset_utils.py +++ b/hub/api/dataset_utils.py @@ -271,9 +271,8 @@ def check_class_label(value: Union[np.ndarray, list], subpath_type=None): and assign_class_label not in subpath_type.names ): raise ClassLabelValueError(subpath_type.names, assign_class_label) - elif ( - isinstance(assign_class_label, int) - and assign_class_label >= subpath_type.num_classes + elif isinstance(assign_class_label, int) and ( + assign_class_label >= subpath_type.num_classes or assign_class_label < 0 ): raise ClassLabelValueError( range(subpath_type.num_classes - 1), assign_class_label diff --git a/hub/api/datasetview.py b/hub/api/datasetview.py index fb3e617060..3e09c6f96c 100755 --- a/hub/api/datasetview.py +++ b/hub/api/datasetview.py @@ -146,7 +146,7 @@ def __setitem__(self, slice_, value): subpath, slice_list = slice_split(slice_) slice_list = [0] + slice_list if isinstance(self.indexes, int) else slice_list - subpath_type = self.dataset.schema.dict_[subpath.replace("/", "")] + subpath_type = self.dataset.schema.dict_.get(subpath.replace("/", ""), None) if isinstance(subpath_type, ClassLabel): check_class_label(value, subpath_type) diff --git a/hub/api/tensorview.py b/hub/api/tensorview.py index 3a6f204aac..f48b24892e 100755 --- a/hub/api/tensorview.py +++ b/hub/api/tensorview.py @@ -208,7 +208,7 @@ def __setitem__(self, slice_, value): if subpath: raise ValueError("Can't setitem of TensorView with subpath") - subpath_type = self.dataset.schema.dict_[self.subpath.replace("/", "")] + subpath_type = self.dataset.schema.dict_.get(subpath.replace("/", ""), None) if isinstance(subpath_type, ClassLabel): check_class_label(value, subpath_type) From 81ed5542f303ae1d2a49c17b66f660c121ffa3c7 Mon Sep 17 00:00:00 2001 From: kristinagrig06 Date: Wed, 7 Apr 2021 12:49:42 +0400 Subject: [PATCH 5/7] Separate Text conversion --- hub/api/dataset.py | 24 ++++++++++++++++++------ hub/api/dataset_utils.py | 28 +++++++++++++--------------- hub/api/datasetview.py | 26 ++++++++++++++++++-------- hub/api/tensorview.py | 26 ++++++++++++++++++-------- hub/api/tests/test_dataset.py | 18 +++++++++++------- 5 files changed, 78 insertions(+), 44 deletions(-) diff --git a/hub/api/dataset.py b/hub/api/dataset.py index 2682bcf093..a907ab3388 100755 --- a/hub/api/dataset.py +++ b/hub/api/dataset.py @@ -598,9 +598,6 @@ def __setitem__(self, slice_, value): if "r" in self._mode: raise ReadModeException("__setitem__") self._auto_checkout() - assign_value = get_value(value) - # handling strings and bytes - assign_value = str_to_int(assign_value, self.tokenizer) if not isinstance(slice_, abc.Iterable) or isinstance(slice_, str): slice_ = [slice_] @@ -612,9 +609,24 @@ def __setitem__(self, slice_, value): elif subpath not in self.keys: raise KeyError(f"Key {subpath} not found in the dataset") - subpath_type = self.schema.dict_.get(subpath.replace("/", ""), None) - if isinstance(subpath_type, ClassLabel): - check_class_label(value, subpath_type) + assign_value = get_value(value) + schema_dict = self.schema + if subpath[1:] in schema_dict.dict_.keys(): + schema_key = schema_dict.dict_.get(subpath[1:], None) + else: + for schema_key in subpath[1:].split("/"): + schema_dict = schema_dict.dict_.get(schema_key, None) + if not isinstance(schema_dict, SchemaDict): + schema_key = schema_dict + if isinstance(schema_key, ClassLabel): + assign_value = check_class_label(assign_value, schema_key) + if ( + isinstance(schema_key, (Text, bytes)) + or np.array(assign_value).dtype.type is np.str_ + ): + # handling strings and bytes + assign_value = str_to_int(assign_value, self.tokenizer) + if not slice_list: self._tensors[subpath][:] = assign_value else: diff --git a/hub/api/dataset_utils.py b/hub/api/dataset_utils.py index d0aaedf429..5af47d888f 100644 --- a/hub/api/dataset_utils.py +++ b/hub/api/dataset_utils.py @@ -257,23 +257,21 @@ def _get_compressor(compressor: str): ) -def check_class_label(value: Union[np.ndarray, list], subpath_type=None): +def check_class_label(value: Union[np.ndarray, list], label: ClassLabel): """Check if value can be assigned to predefined ClassLabel""" if not isinstance(value, Iterable) or isinstance(value, str): assign_class_labels = [value] else: assign_class_labels = value - for assign_class_label in assign_class_labels: - if str(assign_class_label).isdigit(): - assign_class_label = int(assign_class_label) - if ( - isinstance(assign_class_label, str) - and assign_class_label not in subpath_type.names - ): - raise ClassLabelValueError(subpath_type.names, assign_class_label) - elif isinstance(assign_class_label, int) and ( - assign_class_label >= subpath_type.num_classes or assign_class_label < 0 - ): - raise ClassLabelValueError( - range(subpath_type.num_classes - 1), assign_class_label - ) + for i, assign_class_label in enumerate(assign_class_labels): + if isinstance(assign_class_label, str): + try: + assign_class_labels[i] = label.str2int(assign_class_label) + except KeyError: + raise ClassLabelValueError(label.names, assign_class_label) + + if min(assign_class_labels) < 0 or max(assign_class_labels) > label.num_classes - 1: + raise ClassLabelValueError(range(label.num_classes - 1), assign_class_label) + if len(assign_class_labels) == 1: + return assign_class_labels[0] + return assign_class_labels diff --git a/hub/api/datasetview.py b/hub/api/datasetview.py index 3e09c6f96c..98870bd9d7 100755 --- a/hub/api/datasetview.py +++ b/hub/api/datasetview.py @@ -17,7 +17,7 @@ ) from hub.exceptions import NoneValueException from hub.api.objectview import ObjectView -from hub.schema import Sequence, ClassLabel +from hub.schema import Sequence, ClassLabel, Text, SchemaDict import numpy as np @@ -135,10 +135,6 @@ def __setitem__(self, slice_, value): >>> ds_view["image", 3, 0:1920, 0:1080, 0:3] = np.zeros((1920, 1080, 3), "uint8") # sets the 8th image """ self.dataset._auto_checkout() - assign_value = get_value(value) - assign_value = str_to_int( - assign_value, self.dataset.tokenizer - ) # handling strings and bytes if not isinstance(slice_, abc.Iterable) or isinstance(slice_, str): slice_ = [slice_] @@ -146,9 +142,23 @@ def __setitem__(self, slice_, value): subpath, slice_list = slice_split(slice_) slice_list = [0] + slice_list if isinstance(self.indexes, int) else slice_list - subpath_type = self.dataset.schema.dict_.get(subpath.replace("/", ""), None) - if isinstance(subpath_type, ClassLabel): - check_class_label(value, subpath_type) + assign_value = get_value(value) + schema_dict = self.dataset.schema + if subpath[1:] in schema_dict.dict_.keys(): + schema_key = schema_dict.dict_.get(subpath[1:], None) + else: + for schema_key in subpath[1:].split("/"): + schema_dict = schema_dict.dict_.get(schema_key, None) + if not isinstance(schema_dict, SchemaDict): + schema_key = schema_dict + if isinstance(schema_key, ClassLabel): + assign_value = check_class_label(assign_value, schema_key) + if ( + isinstance(schema_key, (Text, bytes)) + or np.array(assign_value).dtype.type is np.str_ + ): + # handling strings and bytes + assign_value = str_to_int(assign_value, self.dataset.tokenizer) if not subpath: raise ValueError("Can't assign to dataset sliced without key") diff --git a/hub/api/tensorview.py b/hub/api/tensorview.py index f48b24892e..1c27c0afeb 100755 --- a/hub/api/tensorview.py +++ b/hub/api/tensorview.py @@ -8,7 +8,7 @@ import collections.abc as abc from hub.api.dataset_utils import get_value, slice_split, str_to_int, check_class_label from hub.exceptions import NoneValueException -from hub.schema import ClassLabel +from hub.schema import ClassLabel, Text, SchemaDict import hub.api.objectview as objv @@ -196,9 +196,6 @@ def __setitem__(self, slice_, value): >>> images_tensorview[7, 0:1920, 0:1080, 0:3] = np.zeros((1920, 1080, 3), "uint8") # sets 7th image """ self.dataset._auto_checkout() - assign_value = get_value(value) - # handling strings and bytes - assign_value = str_to_int(assign_value, self.dataset.tokenizer) if not isinstance(slice_, abc.Iterable) or isinstance(slice_, str): slice_ = [slice_] @@ -208,10 +205,23 @@ def __setitem__(self, slice_, value): if subpath: raise ValueError("Can't setitem of TensorView with subpath") - subpath_type = self.dataset.schema.dict_.get(subpath.replace("/", ""), None) - if isinstance(subpath_type, ClassLabel): - check_class_label(value, subpath_type) - + assign_value = get_value(value) + schema_dict = self.dataset.schema + if subpath[1:] in schema_dict.dict_.keys(): + schema_key = schema_dict.dict_.get(subpath[1:], None) + else: + for schema_key in subpath[1:].split("/"): + schema_dict = schema_dict.dict_.get(schema_key, None) + if not isinstance(schema_dict, SchemaDict): + schema_key = schema_dict + if isinstance(schema_key, ClassLabel): + assign_value = check_class_label(assign_value, schema_key) + if ( + isinstance(schema_key, (Text, bytes)) + or np.array(assign_value).dtype.type is np.str_ + ): + # handling strings and bytes + assign_value = str_to_int(assign_value, self.dataset.tokenizer) new_nums = self.nums.copy() new_offsets = self.offsets.copy() if isinstance(self.indexes, list): diff --git a/hub/api/tests/test_dataset.py b/hub/api/tests/test_dataset.py index b1b7442a4c..02e390e783 100644 --- a/hub/api/tests/test_dataset.py +++ b/hub/api/tests/test_dataset.py @@ -9,17 +9,14 @@ import cloudpickle import hub.api.dataset as dataset -import pickle from hub.cli.auth import login_fn from hub.exceptions import DirectoryNotEmptyException, ClassLabelValueError import numpy as np import pytest from hub import load, transform from hub.api.dataset_utils import slice_extract_info, slice_split, check_class_label -from hub.cli.auth import login_fn from hub.exceptions import DirectoryNotEmptyException from hub.schema import BBox, ClassLabel, Image, SchemaDict, Sequence, Tensor, Text -from hub.schema.class_label import ClassLabel from hub.utils import ( azure_creds_exist, gcp_creds_exist, @@ -1168,18 +1165,25 @@ def test_check_label_name(): def test_class_label_value(): ds = Dataset( "./data/tests/test_check_label", - mode="a", + mode="w", shape=(5,), - schema={"label": ClassLabel(names=["name1", "name2", "name3"])}, + schema={ + "label": ClassLabel(names=["name1", "name2", "name3"]), + "label/b": ClassLabel(num_classes=5), + }, ) ds["label", 0:7] = 2 ds["label", 0:2] = np.array([0, 1]) try: - ds["label", 0] = 4 + ds["label/b", 0] = 6 + except Exception as ex: + assert isinstance(ex, ClassLabelValueError) + try: + ds[0:4]["label/b"] = np.array([0, 1, 2, 3, 7]) except Exception as ex: assert isinstance(ex, ClassLabelValueError) try: - ds[0:4]["label"] = np.array([0, 1, 2, 3]) + ds["label", 4] = "name4" except Exception as ex: assert isinstance(ex, ClassLabelValueError) From 63aadd19cc24830ae98bf45dc08c1c4f16e0d21c Mon Sep 17 00:00:00 2001 From: kristinagrig06 Date: Wed, 7 Apr 2021 14:51:39 +0400 Subject: [PATCH 6/7] Check str in value --- hub/api/dataset.py | 7 ++++--- hub/api/datasetview.py | 9 ++++----- hub/api/tensorview.py | 8 +++++--- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/hub/api/dataset.py b/hub/api/dataset.py index a907ab3388..5925d524a8 100755 --- a/hub/api/dataset.py +++ b/hub/api/dataset.py @@ -10,6 +10,7 @@ import collections.abc as abc import json import sys +from typing import Iterable import traceback from collections import defaultdict import numpy as np @@ -620,9 +621,9 @@ def __setitem__(self, slice_, value): schema_key = schema_dict if isinstance(schema_key, ClassLabel): assign_value = check_class_label(assign_value, schema_key) - if ( - isinstance(schema_key, (Text, bytes)) - or np.array(assign_value).dtype.type is np.str_ + if isinstance(schema_key, (Text, bytes)) or ( + isinstance(assign_value, Iterable) + and any(isinstance(val, str) for val in assign_value) ): # handling strings and bytes assign_value = str_to_int(assign_value, self.tokenizer) diff --git a/hub/api/datasetview.py b/hub/api/datasetview.py index 98870bd9d7..e021cf5c68 100755 --- a/hub/api/datasetview.py +++ b/hub/api/datasetview.py @@ -3,8 +3,7 @@ This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at https://mozilla.org/MPL/2.0/. """ - -from hub.utils import _tuple_product +from typing import Iterable from hub.api.tensorview import TensorView import collections.abc as abc from hub.api.dataset_utils import ( @@ -153,9 +152,9 @@ def __setitem__(self, slice_, value): schema_key = schema_dict if isinstance(schema_key, ClassLabel): assign_value = check_class_label(assign_value, schema_key) - if ( - isinstance(schema_key, (Text, bytes)) - or np.array(assign_value).dtype.type is np.str_ + if isinstance(schema_key, (Text, bytes)) or ( + isinstance(assign_value, Iterable) + and any(isinstance(val, str) for val in assign_value) ): # handling strings and bytes assign_value = str_to_int(assign_value, self.dataset.tokenizer) diff --git a/hub/api/tensorview.py b/hub/api/tensorview.py index 1c27c0afeb..42b6582cd8 100755 --- a/hub/api/tensorview.py +++ b/hub/api/tensorview.py @@ -4,6 +4,7 @@ If a copy of the MPL was not distributed with this file, You can obtain one at https://mozilla.org/MPL/2.0/. """ import numpy as np +from typing import Iterable import hub import collections.abc as abc from hub.api.dataset_utils import get_value, slice_split, str_to_int, check_class_label @@ -216,12 +217,13 @@ def __setitem__(self, slice_, value): schema_key = schema_dict if isinstance(schema_key, ClassLabel): assign_value = check_class_label(assign_value, schema_key) - if ( - isinstance(schema_key, (Text, bytes)) - or np.array(assign_value).dtype.type is np.str_ + if isinstance(schema_key, (Text, bytes)) or ( + isinstance(assign_value, Iterable) + and any(isinstance(val, str) for val in assign_value) ): # handling strings and bytes assign_value = str_to_int(assign_value, self.dataset.tokenizer) + new_nums = self.nums.copy() new_offsets = self.offsets.copy() if isinstance(self.indexes, list): From 7b0cb518401ed4de98cf285dd42ac58c9c778743 Mon Sep 17 00:00:00 2001 From: kristinagrig06 Date: Thu, 22 Apr 2021 11:54:06 +0400 Subject: [PATCH 7/7] Add new tests --- hub/api/tests/test_dataset.py | 64 +++++++++++++++++++---------------- 1 file changed, 35 insertions(+), 29 deletions(-) diff --git a/hub/api/tests/test_dataset.py b/hub/api/tests/test_dataset.py index 02e390e783..2483eebef2 100644 --- a/hub/api/tests/test_dataset.py +++ b/hub/api/tests/test_dataset.py @@ -1174,6 +1174,8 @@ def test_class_label_value(): ) ds["label", 0:7] = 2 ds["label", 0:2] = np.array([0, 1]) + ds["label", 0:3] = ["name1", "name2", "name3"] + ds[0:3]["label"] = [0, "name2", 2] try: ds["label/b", 0] = 6 except Exception as ex: @@ -1186,6 +1188,10 @@ def test_class_label_value(): ds["label", 4] = "name4" except Exception as ex: assert isinstance(ex, ClassLabelValueError) + try: + ds[0]["label/b"] = ["name"] + except Exception as ex: + assert isinstance(ex, ValueError) @pytest.mark.skipif(not minio_creds_exist(), reason="requires minio credentials") @@ -1234,33 +1240,33 @@ def my_filter(sample): if __name__ == "__main__": - test_dataset_assign_value() - test_dataset_setting_shape() - test_datasetview_repr() - test_datasetview_get_dictionary() - test_tensorview_slicing() - test_datasetview_slicing() - test_dataset() - test_dataset_batch_write_2() - test_append_dataset() - test_dataset_2() - - test_text_dataset() - test_text_dataset_tokenizer() - test_dataset_compute() - test_dataset_view_compute() - test_dataset_lazy() - test_dataset_view_lazy() - test_dataset_hub() - test_meta_information() - test_dataset_filter_2() - test_dataset_filter_3() - test_pickleability() - test_dataset_append_and_read() - test_tensorview_iter() - test_dataset_filter_4() - test_datasetview_2() - test_dataset_3() - test_dataset_utils() - test_check_label_name() + # test_dataset_assign_value() + # test_dataset_setting_shape() + # test_datasetview_repr() + # test_datasetview_get_dictionary() + # test_tensorview_slicing() + # test_datasetview_slicing() + # test_dataset() + # test_dataset_batch_write_2() + # test_append_dataset() + # test_dataset_2() + + # test_text_dataset() + # test_text_dataset_tokenizer() + # test_dataset_compute() + # test_dataset_view_compute() + # test_dataset_lazy() + # test_dataset_view_lazy() + # test_dataset_hub() + # test_meta_information() + # test_dataset_filter_2() + # test_dataset_filter_3() + # test_pickleability() + # test_dataset_append_and_read() + # test_tensorview_iter() + # test_dataset_filter_4() + # test_datasetview_2() + # test_dataset_3() + # test_dataset_utils() + # test_check_label_name() test_class_label_value()