Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes/class label shape #744

Merged
merged 11 commits into from
Apr 29, 2021
6 changes: 5 additions & 1 deletion hub/api/dataset.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,14 @@
_get_compressor,
_get_dynamic_tensor_dtype,
_store_helper,
check_class_label,
)

import hub.schema.serialize
import hub.schema.deserialize
from hub.schema.features import flatten
from hub.schema import ClassLabel
from hub import auto

from hub.store.dynamic_tensor import DynamicTensor
from hub.store.store import get_fs_and_path, get_storage_map
from hub.exceptions import (
Expand Down Expand Up @@ -611,6 +612,9 @@ def __setitem__(self, slice_, value):
elif subpath not in self.keys:
raise KeyError(f"Key {subpath} not found in the dataset")

subpath_type = self.schema.dict_.get(subpath.replace("/", ""), None)
if isinstance(subpath_type, ClassLabel):
check_class_label(value, subpath_type)
if not slice_list:
self._tensors[subpath][:] = assign_value
else:
Expand Down
31 changes: 30 additions & 1 deletion hub/api/dataset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,24 @@
"""

import os
import time
from typing import Union, Iterable
from hub.store.store import get_fs_and_path
import numpy as np
import sys
from hub.exceptions import ModuleNotInstalledException, DirectoryNotEmptyException
from hub.exceptions import (
ModuleNotInstalledException,
DirectoryNotEmptyException,
ClassLabelValueError,
)
import hashlib
import time
import numcodecs
import numcodecs.lz4
import numcodecs.zstd
from hub.schema.features import Primitive
from hub.numcodecs import PngCodec
from hub.schema import ClassLabel


def slice_split(slice_):
Expand Down Expand Up @@ -248,3 +255,25 @@ def _get_compressor(compressor: str):
raise ValueError(
f"Wrong compressor: {compressor}, only LZ4, PNG and ZSTD are supported"
)


def check_class_label(value: Union[np.ndarray, list], subpath_type=None):
    """Check that value can be assigned to a predefined ClassLabel.

    Args:
        value: A single label or a collection of labels. A label is either a
            class name (str) or a class index (int); a string made only of
            decimal digits is treated as an index.
        subpath_type: The ClassLabel schema entry being assigned to. Its
            ``names`` and ``num_classes`` attributes are read, so it must not
            be None when a label actually needs checking.

    Raises:
        ClassLabelValueError: If a name is not in ``subpath_type.names`` or an
            index falls outside ``[0, subpath_type.num_classes)``.
    """
    if isinstance(value, np.ndarray):
        # ravel() covers 0-d arrays (which cannot be iterated) and N-d arrays
        # (whose iteration yields sub-arrays that would bypass every check).
        labels = value.ravel()
    elif not isinstance(value, Iterable) or isinstance(value, str):
        labels = [value]
    else:
        labels = value

    for label in labels:
        # isdecimal() only accepts characters int() can parse; isdigit() also
        # accepts e.g. superscript digits, which make int() raise ValueError.
        if str(label).isdecimal():
            label = int(label)
        if isinstance(label, str):
            if label not in subpath_type.names:
                raise ClassLabelValueError(subpath_type.names, label)
        elif isinstance(label, (int, np.integer)):
            # np.integer is included so negative NumPy scalars (which the
            # digit-string coercion above does not convert) are range-checked.
            if not 0 <= label < subpath_type.num_classes:
                # Valid indices are 0 .. num_classes - 1 inclusive.
                raise ClassLabelValueError(range(subpath_type.num_classes), label)
7 changes: 6 additions & 1 deletion hub/api/datasetview.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
slice_split,
str_to_int,
_store_helper,
check_class_label,
)
from hub.exceptions import NoneValueException
from hub.api.objectview import ObjectView
from hub.schema import Sequence
from hub.schema import Sequence, ClassLabel
import numpy as np


Expand Down Expand Up @@ -145,6 +146,10 @@ def __setitem__(self, slice_, value):
subpath, slice_list = slice_split(slice_)
slice_list = [0] + slice_list if isinstance(self.indexes, int) else slice_list

subpath_type = self.dataset.schema.dict_.get(subpath.replace("/", ""), None)
if isinstance(subpath_type, ClassLabel):
check_class_label(value, subpath_type)

if not subpath:
raise ValueError("Can't assign to dataset sliced without key")
elif subpath not in self.keys:
Expand Down
9 changes: 7 additions & 2 deletions hub/api/tensorview.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0.
If a copy of the MPL was not distributed with this file, You can obtain one at https://mozilla.org/MPL/2.0/.
"""

import numpy as np
import hub
import collections.abc as abc
from hub.api.dataset_utils import get_value, slice_split, str_to_int
from hub.api.dataset_utils import get_value, slice_split, str_to_int, check_class_label
from hub.exceptions import NoneValueException
from hub.schema import ClassLabel
import hub.api.objectview as objv


Expand Down Expand Up @@ -207,6 +207,11 @@ def __setitem__(self, slice_, value):
subpath, slice_list = slice_split(slice_)
if subpath:
raise ValueError("Can't setitem of TensorView with subpath")

subpath_type = self.dataset.schema.dict_.get(subpath.replace("/", ""), None)
if isinstance(subpath_type, ClassLabel):
check_class_label(value, subpath_type)

new_nums = self.nums.copy()
new_offsets = self.offsets.copy()
if isinstance(self.indexes, list):
Expand Down
29 changes: 25 additions & 4 deletions hub/api/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,19 @@
This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0.
If a copy of the MPL was not distributed with this file, You can obtain one at https://mozilla.org/MPL/2.0/.
"""

import os
import pickle
import shutil

import cloudpickle
import hub.api.dataset as dataset
import pickle
from hub.cli.auth import login_fn
from hub.exceptions import DirectoryNotEmptyException, ClassLabelValueError
import numpy as np
import pytest
from hub import load, transform
from hub.api.dataset_utils import slice_extract_info, slice_split
from hub.api.dataset_utils import slice_extract_info, slice_split, check_class_label
from hub.cli.auth import login_fn
from hub.exceptions import DirectoryNotEmptyException
from hub.schema import BBox, ClassLabel, Image, SchemaDict, Sequence, Tensor, Text
Expand Down Expand Up @@ -1163,6 +1165,25 @@ def test_check_label_name():
assert ds[1:3].compute().tolist() == [{"label": 2}, {"label": 0}]


def test_class_label_value():
    """Assigning an out-of-range index to a ClassLabel raises ClassLabelValueError."""
    ds = Dataset(
        "./data/tests/test_check_label",
        mode="a",
        shape=(5,),
        schema={"label": ClassLabel(names=["name1", "name2", "name3"])},
    )
    # Valid assignments must not raise.
    # NOTE(review): the slice 0:7 exceeds shape (5,); kept as in the original — confirm intent.
    ds["label", 0:7] = 2
    ds["label", 0:2] = np.array([0, 1])

    # pytest.raises fails the test when nothing is raised; the previous
    # try/except form passed silently in that case.
    with pytest.raises(ClassLabelValueError):
        ds["label", 0] = 4
    with pytest.raises(ClassLabelValueError):
        ds[0:4]["label"] = np.array([0, 1, 2, 3])


@pytest.mark.skipif(not minio_creds_exist(), reason="requires minio credentials")
def test_minio_endpoint():
token = {
Expand Down Expand Up @@ -1209,7 +1230,6 @@ def my_filter(sample):


if __name__ == "__main__":
test_dataset_dynamic_shaped_slicing()
test_dataset_assign_value()
test_dataset_setting_shape()
test_datasetview_repr()
Expand All @@ -1219,8 +1239,8 @@ def my_filter(sample):
test_dataset()
test_dataset_batch_write_2()
test_append_dataset()
test_append_resize()
test_dataset_2()

test_text_dataset()
test_text_dataset_tokenizer()
test_dataset_compute()
Expand All @@ -1239,3 +1259,4 @@ def my_filter(sample):
test_dataset_3()
test_dataset_utils()
test_check_label_name()
test_class_label_value()
6 changes: 6 additions & 0 deletions hub/exceptions.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,12 @@ def __init__(self, correct_shape, wrong_shape):
super(HubException, self).__init__(message=message)


class ClassLabelValueError(HubException):
    """Raised when a value assigned to a ClassLabel is not one of its classes."""

    def __init__(self, expected_classes, value):
        # super(HubException, self) resolves __init__ starting after
        # HubException in the MRO, the same call form the preceding
        # exception in this module uses.
        super(HubException, self).__init__(
            message=f"Expected ClassLabel to be in {expected_classes}, got {value}"
        )


class NoneValueException(HubException):
def __init__(self, param):
message = f"Parameter '{param}' should be provided"
Expand Down