Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes/class label shape #744

Merged
merged 11 commits into from
Apr 29, 2021
25 changes: 21 additions & 4 deletions hub/api/dataset.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import collections.abc as abc
import json
import sys
from typing import Iterable
import traceback
from collections import defaultdict
import numpy as np
Expand Down Expand Up @@ -43,14 +44,15 @@
_get_compressor,
_get_dynamic_tensor_dtype,
_store_helper,
check_class_label,
same_schema,
)

import hub.schema.serialize
import hub.schema.deserialize
from hub.schema.features import flatten
from hub.schema import ClassLabel
from hub import auto

from hub.store.dynamic_tensor import DynamicTensor
from hub.store.store import get_fs_and_path, get_storage_map
from hub.exceptions import (
Expand Down Expand Up @@ -595,9 +597,6 @@ def __setitem__(self, slice_, value):
if "r" in self._mode:
raise ReadModeException("__setitem__")
self._auto_checkout()
assign_value = get_value(value)
# handling strings and bytes
assign_value = str_to_int(assign_value, self.tokenizer)

if not isinstance(slice_, abc.Iterable) or isinstance(slice_, str):
slice_ = [slice_]
Expand All @@ -609,6 +608,24 @@ def __setitem__(self, slice_, value):
elif subpath not in self.keys:
raise KeyError(f"Key {subpath} not found in the dataset")

assign_value = get_value(value)
schema_dict = self.schema
if subpath[1:] in schema_dict.dict_.keys():
schema_key = schema_dict.dict_.get(subpath[1:], None)
else:
for schema_key in subpath[1:].split("/"):
schema_dict = schema_dict.dict_.get(schema_key, None)
if not isinstance(schema_dict, SchemaDict):
schema_key = schema_dict
if isinstance(schema_key, ClassLabel):
assign_value = check_class_label(assign_value, schema_key)
if isinstance(schema_key, (Text, bytes)) or (
isinstance(assign_value, Iterable)
and any(isinstance(val, str) for val in assign_value)
):
# handling strings and bytes
assign_value = str_to_int(assign_value, self.tokenizer)

if not slice_list:
self._tensors[subpath][:] = assign_value
else:
Expand Down
29 changes: 28 additions & 1 deletion hub/api/dataset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,24 @@
"""

import os
import time
from typing import Union, Iterable
from hub.store.store import get_fs_and_path
import numpy as np
import sys
from hub.exceptions import ModuleNotInstalledException, DirectoryNotEmptyException
from hub.exceptions import (
ModuleNotInstalledException,
DirectoryNotEmptyException,
ClassLabelValueError,
)
import hashlib
import time
import numcodecs
import numcodecs.lz4
import numcodecs.zstd
from hub.schema.features import Primitive, SchemaDict
from hub.numcodecs import PngCodec
from hub.schema import ClassLabel


def slice_split(slice_):
Expand Down Expand Up @@ -266,3 +273,23 @@ def _get_compressor(compressor: str):
raise ValueError(
f"Wrong compressor: {compressor}, only LZ4, PNG and ZSTD are supported"
)


def check_class_label(value: Union[np.ndarray, list], label: "ClassLabel"):
    """Validate ``value`` against ``label`` and turn class names into indices.

    Args:
        value: A single label (int or str) or an iterable of labels. String
            entries are looked up with ``label.str2int``.
        label: The ClassLabel the value is being assigned to. Only its
            ``str2int``, ``names`` and ``num_classes`` members are used.

    Returns:
        The single resolved label when exactly one label was given. Otherwise
        the original ``value`` object when it is a list, tuple or ndarray that
        needed no conversion, or a new list of resolved labels.
        Empty input is returned unchanged.

    Raises:
        ClassLabelValueError: If a name is unknown to ``label`` (``str2int``
            raised ``KeyError``) or an index is outside
            ``[0, label.num_classes)``.
    """
    if not isinstance(value, Iterable) or isinstance(value, str):
        candidates = [value]
    else:
        candidates = list(value)

    # Build a fresh list instead of writing back into ``value``: in-place
    # assignment mutates the caller's data, fails on tuples and cannot store
    # integers in a string-typed ndarray.
    resolved = []
    converted = False
    for candidate in candidates:
        if isinstance(candidate, str):
            try:
                candidate = label.str2int(candidate)
            except KeyError as err:
                raise ClassLabelValueError(label.names, candidate) from err
            converted = True
        resolved.append(candidate)

    # Nothing to validate; avoids reducing over an empty sequence.
    if not resolved:
        return value

    # Vectorised range check so nested / multi-dimensional labels are handled
    # and the value that actually violates the range can be reported.
    indices = np.asarray(resolved)
    out_of_range = (indices < 0) | (indices >= label.num_classes)
    if out_of_range.any():
        raise ClassLabelValueError(
            range(label.num_classes), indices[out_of_range][0]
        )

    if len(resolved) == 1:
        return resolved[0]
    if converted or not isinstance(value, (list, tuple, np.ndarray)):
        return resolved
    return value
28 changes: 21 additions & 7 deletions hub/api/datasetview.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0.
If a copy of the MPL was not distributed with this file, You can obtain one at https://mozilla.org/MPL/2.0/.
"""

from hub.utils import _tuple_product
from typing import Iterable
from hub.api.tensorview import TensorView
import collections.abc as abc
from hub.api.dataset_utils import (
Expand All @@ -13,10 +12,11 @@
slice_split,
str_to_int,
_store_helper,
check_class_label,
)
from hub.exceptions import NoneValueException
from hub.api.objectview import ObjectView
from hub.schema import Sequence
from hub.schema import Sequence, ClassLabel, Text, SchemaDict
import numpy as np


Expand Down Expand Up @@ -134,17 +134,31 @@ def __setitem__(self, slice_, value):
>>> ds_view["image", 3, 0:1920, 0:1080, 0:3] = np.zeros((1920, 1080, 3), "uint8") # sets the 8th image
"""
self.dataset._auto_checkout()
assign_value = get_value(value)
assign_value = str_to_int(
assign_value, self.dataset.tokenizer
) # handling strings and bytes

if not isinstance(slice_, abc.Iterable) or isinstance(slice_, str):
slice_ = [slice_]
slice_ = list(slice_)
subpath, slice_list = slice_split(slice_)
slice_list = [0] + slice_list if isinstance(self.indexes, int) else slice_list

assign_value = get_value(value)
schema_dict = self.dataset.schema
if subpath[1:] in schema_dict.dict_.keys():
schema_key = schema_dict.dict_.get(subpath[1:], None)
else:
for schema_key in subpath[1:].split("/"):
schema_dict = schema_dict.dict_.get(schema_key, None)
if not isinstance(schema_dict, SchemaDict):
schema_key = schema_dict
if isinstance(schema_key, ClassLabel):
assign_value = check_class_label(assign_value, schema_key)
if isinstance(schema_key, (Text, bytes)) or (
isinstance(assign_value, Iterable)
and any(isinstance(val, str) for val in assign_value)
):
# handling strings and bytes
assign_value = str_to_int(assign_value, self.dataset.tokenizer)

if not subpath:
raise ValueError("Can't assign to dataset sliced without key")
elif subpath not in self.keys:
Expand Down
27 changes: 22 additions & 5 deletions hub/api/tensorview.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0.
If a copy of the MPL was not distributed with this file, You can obtain one at https://mozilla.org/MPL/2.0/.
"""

import numpy as np
from typing import Iterable
import hub
import collections.abc as abc
from hub.api.dataset_utils import get_value, slice_split, str_to_int
from hub.api.dataset_utils import get_value, slice_split, str_to_int, check_class_label
from hub.exceptions import NoneValueException
from hub.schema import ClassLabel, Text, SchemaDict
import hub.api.objectview as objv


Expand Down Expand Up @@ -196,9 +197,6 @@ def __setitem__(self, slice_, value):
>>> images_tensorview[7, 0:1920, 0:1080, 0:3] = np.zeros((1920, 1080, 3), "uint8") # sets 7th image
"""
self.dataset._auto_checkout()
assign_value = get_value(value)
# handling strings and bytes
assign_value = str_to_int(assign_value, self.dataset.tokenizer)

if not isinstance(slice_, abc.Iterable) or isinstance(slice_, str):
slice_ = [slice_]
Expand All @@ -207,6 +205,25 @@ def __setitem__(self, slice_, value):
subpath, slice_list = slice_split(slice_)
if subpath:
raise ValueError("Can't setitem of TensorView with subpath")

assign_value = get_value(value)
schema_dict = self.dataset.schema
if subpath[1:] in schema_dict.dict_.keys():
schema_key = schema_dict.dict_.get(subpath[1:], None)
else:
for schema_key in subpath[1:].split("/"):
schema_dict = schema_dict.dict_.get(schema_key, None)
if not isinstance(schema_dict, SchemaDict):
schema_key = schema_dict
if isinstance(schema_key, ClassLabel):
assign_value = check_class_label(assign_value, schema_key)
if isinstance(schema_key, (Text, bytes)) or (
isinstance(assign_value, Iterable)
and any(isinstance(val, str) for val in assign_value)
):
# handling strings and bytes
assign_value = str_to_int(assign_value, self.dataset.tokenizer)

new_nums = self.nums.copy()
new_offsets = self.offsets.copy()
if isinstance(self.indexes, list):
Expand Down
98 changes: 65 additions & 33 deletions hub/api/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,21 @@
This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0.
If a copy of the MPL was not distributed with this file, You can obtain one at https://mozilla.org/MPL/2.0/.
"""

import os
import pickle
import shutil

import cloudpickle
import hub.api.dataset as dataset
from hub.cli.auth import login_fn
from hub.exceptions import DirectoryNotEmptyException, ClassLabelValueError
import numpy as np
import pytest
from hub import load, transform
from hub.api.dataset_utils import slice_extract_info, slice_split
from hub.api.dataset_utils import slice_extract_info, slice_split, check_class_label
from hub.cli.auth import login_fn
from hub.exceptions import DirectoryNotEmptyException, SchemaMismatchException
from hub.schema import BBox, ClassLabel, Image, SchemaDict, Sequence, Tensor, Text
from hub.schema.class_label import ClassLabel
from hub.utils import (
azure_creds_exist,
gcp_creds_exist,
Expand Down Expand Up @@ -1215,6 +1215,38 @@ def test_check_label_name():
assert ds[1:3].compute().tolist() == [{"label": 2}, {"label": 0}]


def test_class_label_value():
    """Check that ClassLabel assignments accept valid labels and reject bad ones."""
    ds = Dataset(
        "./data/tests/test_check_label",
        mode="w",
        shape=(5,),
        schema={
            "label": ClassLabel(names=["name1", "name2", "name3"]),
            "label/b": ClassLabel(num_classes=5),
        },
    )

    # Assignments that are expected to be accepted: ints, arrays, names and
    # a mix of ints and names.
    ds["label", 0:7] = 2
    ds["label", 0:2] = np.array([0, 1])
    ds["label", 0:3] = ["name1", "name2", "name3"]
    ds[0:3]["label"] = [0, "name2", 2]

    # ``pytest.raises`` fails the test when nothing is raised; a bare
    # try/except with an assert in the handler would pass silently.
    with pytest.raises(ClassLabelValueError):
        ds["label/b", 0] = 6
    with pytest.raises(ClassLabelValueError):
        ds[0:4]["label/b"] = np.array([0, 1, 2, 3, 7])
    with pytest.raises(ClassLabelValueError):
        ds["label", 4] = "name4"
    with pytest.raises(ValueError):
        ds[0]["label/b"] = ["name"]


@pytest.mark.skipif(not minio_creds_exist(), reason="requires minio credentials")
def test_minio_endpoint():
token = {
Expand Down Expand Up @@ -1261,33 +1293,33 @@ def my_filter(sample):


if __name__ == "__main__":
test_dataset_dynamic_shaped_slicing()
test_dataset_assign_value()
test_dataset_setting_shape()
test_datasetview_repr()
test_datasetview_get_dictionary()
test_tensorview_slicing()
test_datasetview_slicing()
test_dataset()
test_dataset_batch_write_2()
test_append_dataset()
test_append_resize()
test_dataset_2()
test_text_dataset()
test_text_dataset_tokenizer()
test_dataset_compute()
test_dataset_view_compute()
test_dataset_lazy()
test_dataset_view_lazy()
test_dataset_hub()
test_meta_information()
test_dataset_filter_2()
test_dataset_filter_3()
test_pickleability()
test_dataset_append_and_read()
test_tensorview_iter()
test_dataset_filter_4()
test_datasetview_2()
test_dataset_3()
test_dataset_utils()
test_check_label_name()
# test_dataset_assign_value()
# test_dataset_setting_shape()
# test_datasetview_repr()
# test_datasetview_get_dictionary()
# test_tensorview_slicing()
# test_datasetview_slicing()
# test_dataset()
# test_dataset_batch_write_2()
# test_append_dataset()
# test_dataset_2()

# test_text_dataset()
# test_text_dataset_tokenizer()
# test_dataset_compute()
# test_dataset_view_compute()
# test_dataset_lazy()
# test_dataset_view_lazy()
# test_dataset_hub()
# test_meta_information()
# test_dataset_filter_2()
# test_dataset_filter_3()
# test_pickleability()
# test_dataset_append_and_read()
# test_tensorview_iter()
# test_dataset_filter_4()
# test_datasetview_2()
# test_dataset_3()
# test_dataset_utils()
# test_check_label_name()
test_class_label_value()
6 changes: 6 additions & 0 deletions hub/exceptions.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,12 @@ def __init__(self, correct_shape, wrong_shape):
super(HubException, self).__init__(message=message)


class ClassLabelValueError(HubException):
    """Error for a value that is not among the classes a ClassLabel expects."""

    def __init__(self, expected_classes, value):
        # ``super(HubException, self)`` deliberately starts the lookup after
        # HubException in the MRO, as the original code does.
        super(HubException, self).__init__(
            message=f"Expected ClassLabel to be in {expected_classes}, got {value}"
        )


class NoneValueException(HubException):
def __init__(self, param):
message = f"Parameter '{param}' should be provided"
Expand Down