diff --git a/hub/api/dataset.py b/hub/api/dataset.py index 506afe6c64..96a8e94f32 100644 --- a/hub/api/dataset.py +++ b/hub/api/dataset.py @@ -43,6 +43,7 @@ _get_compressor, _get_dynamic_tensor_dtype, _store_helper, + same_schema, ) import hub.schema.serialize @@ -66,6 +67,7 @@ VersioningNotSupportedException, WrongUsernameException, InvalidVersionInfoException, + SchemaMismatchException, ) from hub.store.metastore import MetaStorage from hub.client.hub_control import HubControlClient @@ -207,14 +209,10 @@ def __init__( if shape != (None,) and shape != self._shape: raise TypeError( - f"Shape in metafile [{self._shape}] and shape in arguments [{shape}] are !=, use mode='w' to overwrite dataset" - ) - if schema is not None and sorted(schema.dict_.keys()) != sorted( - self._schema.dict_.keys() - ): - raise TypeError( - "Schema in metafile and schema in arguments do not match, use mode='w' to overwrite dataset" + f"Shape stored previously [{self._shape}] and shape in arguments [{shape}] are !=, use mode='w' to overwrite dataset" ) + if schema is not None and not same_schema(schema, self._schema): + raise SchemaMismatchException() else: if shape[0] is None: diff --git a/hub/api/dataset_utils.py b/hub/api/dataset_utils.py index 7714c452bd..461c30d5fb 100644 --- a/hub/api/dataset_utils.py +++ b/hub/api/dataset_utils.py @@ -14,7 +14,7 @@ import numcodecs import numcodecs.lz4 import numcodecs.zstd -from hub.schema.features import Primitive +from hub.schema.features import Primitive, SchemaDict from hub.numcodecs import PngCodec @@ -34,6 +34,24 @@ def slice_split(slice_): return path, list_slice +def same_schema(schema1, schema2): + """returns True if same, else False""" + if schema1.dict_.keys() != schema2.dict_.keys(): + return False + for k, v in schema1.dict_.items(): + if isinstance(v, SchemaDict) and not same_schema(v, schema2.dict_[k]): + return False + elif ( + v.shape != schema2.dict_[k].shape + or v.max_shape != schema2.dict_[k].max_shape + or v.chunks != schema2.dict_[k].chunks + or v.dtype != schema2.dict_[k].dtype + or v.compressor != schema2.dict_[k].compressor + ): + return False + return True + + def slice_extract_info(slice_, num): """Extracts number of samples and offset from slice""" if isinstance(slice_, int): diff --git a/hub/api/tests/test_dataset.py b/hub/api/tests/test_dataset.py index 0ebb623c73..94920ba08d 100644 --- a/hub/api/tests/test_dataset.py +++ b/hub/api/tests/test_dataset.py @@ -15,7 +15,7 @@ from hub import load, transform from hub.api.dataset_utils import slice_extract_info, slice_split from hub.cli.auth import login_fn -from hub.exceptions import DirectoryNotEmptyException +from hub.exceptions import DirectoryNotEmptyException, SchemaMismatchException from hub.schema import BBox, ClassLabel, Image, SchemaDict, Sequence, Tensor, Text from hub.schema.class_label import ClassLabel from hub.utils import ( @@ -329,15 +329,67 @@ def test_dataset_wrong_append(url="./data/test/dataset", token=None): } ds = Dataset(url, token=token, shape=(10000,), mode="w", schema=my_schema) ds.close() - try: + with pytest.raises(TypeError): ds = Dataset(url, shape=100) - except Exception as ex: - assert isinstance(ex, TypeError) - try: + with pytest.raises(SchemaMismatchException): ds = Dataset(url, schema={"hello": "uint8"}) - except Exception as ex: - assert isinstance(ex, TypeError) + + +def test_dataset_change_schema(): + schema = { + "abc": "uint8", + "def": { + "ghi": Tensor((100, 100)), + "rst": Tensor((100, 100, 100)), + }, + } + ds = Dataset("./data/test_schema_change", schema=schema, shape=(100,)) + new_schema_1 = { + "abc": "uint8", + "def": { + "ghi": Tensor((200, 100)), + "rst": Tensor((100, 100, 100)), + }, + } + new_schema_2 = { + "abrs": "uint8", + "def": { + "ghi": Tensor((100, 100)), + "rst": Tensor((100, 100, 100)), + }, + } + new_schema_3 = { + "abc": "uint8", + "def": { + "ghijk": Tensor((100, 100)), + "rst": Tensor((100, 100, 100)), + }, + } + new_schema_4 = { + "abc": "uint16", + "def": { + "ghi": Tensor((100, 100)), + "rst": Tensor((100, 100, 100)), + }, + } + new_schema_5 = { + "abc": "uint8", + "def": { + "ghi": Tensor((100, 100, 3)), + "rst": Tensor((100, 100, 100)), + }, + } + with pytest.raises(SchemaMismatchException): + ds = Dataset("./data/test_schema_change", schema=new_schema_1, shape=(100,)) + with pytest.raises(SchemaMismatchException): + ds = Dataset("./data/test_schema_change", schema=new_schema_2, shape=(100,)) + with pytest.raises(SchemaMismatchException): + ds = Dataset("./data/test_schema_change", schema=new_schema_3, shape=(100,)) + with pytest.raises(SchemaMismatchException): + ds = Dataset("./data/test_schema_change", schema=new_schema_4, shape=(100,)) + with pytest.raises(SchemaMismatchException): + ds = Dataset("./data/test_schema_change", schema=new_schema_5, shape=(100,)) def test_dataset_no_shape(url="./data/test/dataset", token=None): @@ -989,7 +1041,7 @@ def abc_filter(sample): return sample["ab"].compute().startswith("abc") my_schema = {"img": Tensor((100, 100)), "ab": Text((None,), max_shape=(10,))} - ds = Dataset("./data/new_filter", shape=(10,), schema=my_schema) + ds = Dataset("./data/new_filter_2", shape=(10,), schema=my_schema) for i in range(10): ds["img", i] = i * np.ones((100, 100)) ds["ab", i] = "abc" + str(i) if i % 2 == 0 else "def" + str(i) diff --git a/hub/exceptions.py b/hub/exceptions.py index 74a69d96cc..df493e9dd3 100644 --- a/hub/exceptions.py +++ b/hub/exceptions.py @@ -260,6 +260,12 @@ def __init__(self): super(HubException, self).__init__(message=message) +class SchemaMismatchException(HubException): + def __init__(self): + message = "Schema stored previously and schema in arguments do not match, use mode='w' to overwrite dataset" + super(HubException, self).__init__(message=message) + + class DynamicTensorShapeException(HubException): def __init__(self, exc_type): if exc_type == "none":