Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Schema check #788

Merged
merged 4 commits into from
Apr 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 5 additions & 7 deletions hub/api/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
_get_compressor,
_get_dynamic_tensor_dtype,
_store_helper,
same_schema,
)

import hub.schema.serialize
Expand All @@ -66,6 +67,7 @@
VersioningNotSupportedException,
WrongUsernameException,
InvalidVersionInfoException,
SchemaMismatchException,
)
from hub.store.metastore import MetaStorage
from hub.client.hub_control import HubControlClient
Expand Down Expand Up @@ -207,14 +209,10 @@ def __init__(

if shape != (None,) and shape != self._shape:
raise TypeError(
f"Shape in metafile [{self._shape}] and shape in arguments [{shape}] are !=, use mode='w' to overwrite dataset"
)
if schema is not None and sorted(schema.dict_.keys()) != sorted(
self._schema.dict_.keys()
):
raise TypeError(
"Schema in metafile and schema in arguments do not match, use mode='w' to overwrite dataset"
f"Shape stored previously [{self._shape}] and shape in arguments [{shape}] are !=, use mode='w' to overwrite dataset"
)
if schema is not None and not same_schema(schema, self._schema):
raise SchemaMismatchException()

else:
if shape[0] is None:
Expand Down
20 changes: 19 additions & 1 deletion hub/api/dataset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import numcodecs
import numcodecs.lz4
import numcodecs.zstd
from hub.schema.features import Primitive
from hub.schema.features import Primitive, SchemaDict
from hub.numcodecs import PngCodec


Expand All @@ -34,6 +34,24 @@ def slice_split(slice_):
return path, list_slice


def same_schema(schema1, schema2):
    """Return True if two schemas describe the same structure, else False.

    The schemas match when their ``dict_`` mappings have identical keys and,
    for every key, either both values are nested ``SchemaDict`` objects that
    match recursively, or both are leaf entries whose ``shape``,
    ``max_shape``, ``chunks``, ``dtype`` and ``compressor`` are all equal.

    Args:
        schema1: Schema object exposing a ``dict_`` mapping of entries.
        schema2: Schema object exposing a ``dict_`` mapping of entries.

    Returns:
        bool: True when the schemas match, False otherwise.
    """
    if schema1.dict_.keys() != schema2.dict_.keys():
        return False
    for key, left in schema1.dict_.items():
        right = schema2.dict_[key]
        left_nested = isinstance(left, SchemaDict)
        right_nested = isinstance(right, SchemaDict)
        if left_nested or right_nested:
            # A nested schema can only match another nested schema. Checking
            # both sides avoids recursing into a leaf (which has no `dict_`)
            # or reading leaf attributes off a SchemaDict.
            if not (left_nested and right_nested and same_schema(left, right)):
                return False
            # Nested schemas that matched recursively are done; do not fall
            # through to the leaf-attribute comparison below.
            continue
        if (
            left.shape != right.shape
            or left.max_shape != right.max_shape
            or left.chunks != right.chunks
            or left.dtype != right.dtype
            or left.compressor != right.compressor
        ):
            return False
    return True


def slice_extract_info(slice_, num):
"""Extracts number of samples and offset from slice"""
if isinstance(slice_, int):
Expand Down
68 changes: 60 additions & 8 deletions hub/api/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from hub import load, transform
from hub.api.dataset_utils import slice_extract_info, slice_split
from hub.cli.auth import login_fn
from hub.exceptions import DirectoryNotEmptyException
from hub.exceptions import DirectoryNotEmptyException, SchemaMismatchException
from hub.schema import BBox, ClassLabel, Image, SchemaDict, Sequence, Tensor, Text
from hub.schema.class_label import ClassLabel
from hub.utils import (
Expand Down Expand Up @@ -329,15 +329,67 @@ def test_dataset_wrong_append(url="./data/test/dataset", token=None):
}
ds = Dataset(url, token=token, shape=(10000,), mode="w", schema=my_schema)
ds.close()
try:
with pytest.raises(TypeError):
ds = Dataset(url, shape=100)
except Exception as ex:
assert isinstance(ex, TypeError)

try:
with pytest.raises(SchemaMismatchException):
ds = Dataset(url, schema={"hello": "uint8"})
except Exception as ex:
assert isinstance(ex, TypeError)


def test_dataset_change_schema():
    """Reopening a dataset with a different schema must raise.

    A dataset is created with a nested schema, then reopened five times with
    a schema that differs in exactly one place (nested tensor shape, top-level
    key, nested key, top-level dtype, nested tensor rank). Each attempt is
    expected to raise SchemaMismatchException.
    """
    path = "./data/test_schema_change"

    def build(top_key="abc", top_dtype="uint8", inner_key="ghi", inner_shape=(100, 100)):
        # Defaults reproduce the baseline schema; each override changes one detail.
        return {
            top_key: top_dtype,
            "def": {
                inner_key: Tensor(inner_shape),
                "rst": Tensor((100, 100, 100)),
            },
        }

    ds = Dataset(path, schema=build(), shape=(100,))

    mismatched_schemas = [
        build(inner_shape=(200, 100)),  # nested tensor shape differs
        build(top_key="abrs"),  # top-level key renamed
        build(inner_key="ghijk"),  # nested key renamed
        build(top_dtype="uint16"),  # top-level dtype differs
        build(inner_shape=(100, 100, 3)),  # nested tensor rank differs
    ]
    for candidate in mismatched_schemas:
        with pytest.raises(SchemaMismatchException):
            ds = Dataset(path, schema=candidate, shape=(100,))


def test_dataset_no_shape(url="./data/test/dataset", token=None):
Expand Down Expand Up @@ -989,7 +1041,7 @@ def abc_filter(sample):
return sample["ab"].compute().startswith("abc")

my_schema = {"img": Tensor((100, 100)), "ab": Text((None,), max_shape=(10,))}
ds = Dataset("./data/new_filter", shape=(10,), schema=my_schema)
ds = Dataset("./data/new_filter_2", shape=(10,), schema=my_schema)
for i in range(10):
ds["img", i] = i * np.ones((100, 100))
ds["ab", i] = "abc" + str(i) if i % 2 == 0 else "def" + str(i)
Expand Down
6 changes: 6 additions & 0 deletions hub/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,12 @@ def __init__(self):
super(HubException, self).__init__(message=message)


class SchemaMismatchException(HubException):
    """Raised when the schema passed in arguments differs from the stored one."""

    def __init__(self):
        # Kept as super(HubException, self) on purpose: this deliberately
        # starts the lookup after HubException, as the original did.
        super(HubException, self).__init__(
            message=(
                "Schema stored previously and schema in arguments do not match, "
                "use mode='w' to overwrite dataset"
            )
        )


class DynamicTensorShapeException(HubException):
def __init__(self, exc_type):
if exc_type == "none":
Expand Down