diff --git a/hub/api/dataset.py b/hub/api/dataset.py
index 5ff6d38eba..c606170fd3 100644
--- a/hub/api/dataset.py
+++ b/hub/api/dataset.py
@@ -1,6 +1,6 @@
 from hub.core.storage.provider import StorageProvider
 from hub.core.tensor import create_tensor
-from typing import Callable, Dict, Optional, Union, Tuple, List, Sequence
+from typing import Any, Callable, Dict, Optional, Union, Tuple, List, Sequence
 from hub.constants import DEFAULT_HTYPE, UNSPECIFIED
 import numpy as np
@@ -17,6 +17,7 @@ from hub.util.exceptions import (
     CouldNotCreateNewDatasetException,
     InvalidKeyTypeError,
+    MemoryDatasetCanNotBePickledError,
     PathNotEmptyException,
     ReadOnlyModeError,
     TensorAlreadyExistsError,
@@ -75,11 +76,7 @@ def __init__(
             creds = {}
 
         base_storage = get_storage_provider(path, storage, read_only, creds, token)
-        # done instead of directly assigning read_only as backend might return read_only permissions
-        if hasattr(base_storage, "read_only") and base_storage.read_only:
-            self._read_only = True
-        else:
-            self._read_only = False
+        self._read_only = base_storage.read_only
 
         # uniquely identifies dataset
         self.path = path or get_path_from_storage(base_storage)
@@ -88,25 +85,11 @@ def __init__(
         self.storage = generate_chain(
             base_storage, memory_cache_size_bytes, local_cache_size_bytes, path
         )
-        self.storage.autoflush = True
         self.index = index or Index()
-
         self.tensors: Dict[str, Tensor] = {}
-
         self._token = token
-
-        if self.path.startswith("hub://"):
-            split_path = self.path.split("/")
-            self.org_id, self.ds_name = split_path[2], split_path[3]
-            self.client = HubBackendClient(token=token)
-
         self.public = public
-        self._load_meta()
-        self.index.validate(self.num_samples)
-
-        hub_reporter.feature_report(
-            feature_name="Dataset", parameters={"Path": str(self.path)}
-        )
+        self._set_derived_attributes()
 
     def __enter__(self):
         self.storage.autoflush = False
@@ -128,6 +111,34 @@ def __len__(self):
         tensor_lengths = [len(tensor[self.index]) for tensor in self.tensors.values()]
         return min(tensor_lengths, default=0)
 
+    def __getstate__(self) -> Dict[str, Any]:
+        """Returns a dict that can be pickled and used to restore this dataset.
+
+        Note:
+            Pickling a dataset does not copy the dataset; it only saves attributes that can be used to restore it.
+            If you pickle a local dataset and try to access it on a machine that does not have the data present, the dataset will not work.
+        """
+        if self.path.startswith("mem://"):
+            raise MemoryDatasetCanNotBePickledError
+        return {
+            "path": self.path,
+            "_read_only": self.read_only,
+            "index": self.index,
+            "public": self.public,
+            "storage": self.storage,
+            "_token": self.token,
+        }
+
+    def __setstate__(self, state: Dict[str, Any]):
+        """Restores dataset from a pickled state.
+
+        Args:
+            state (dict): The pickled state used to restore the dataset.
+        """
+        self.__dict__.update(state)
+        self.tensors = {}
+        self._set_derived_attributes()
+
     def __getitem__(
         self,
         item: Union[
@@ -307,6 +318,21 @@ def _get_total_meta(self):
             for tensor_key, tensor_value in self.tensors.items()
         }
 
+    def _set_derived_attributes(self):
+        """Sets derived attributes during init and unpickling."""
+        self.storage.autoflush = True
+        if self.path.startswith("hub://"):
+            split_path = self.path.split("/")
+            self.org_id, self.ds_name = split_path[2], split_path[3]
+            self.client = HubBackendClient(token=self._token)
+
+        self._load_meta()
+        self.index.validate(self.num_samples)
+
+        hub_reporter.feature_report(
+            feature_name="Dataset", parameters={"Path": str(self.path)}
+        )
+
     def tensorflow(self):
         """Converts the dataset into a tensorflow compatible format.
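Taken together, the new `__getstate__`/`__setstate__` pair lets a `Dataset` round-trip through `pickle`: only restorable attributes (path, index, storage chain, token, ...) are serialized, and `_set_derived_attributes()` rebuilds the rest on load. A minimal usage sketch of the intended behaviour, not part of the patch; the local path and tensor name are hypothetical, and the `Dataset(path)` call assumes the constructor shown in `__init__` above:

    import pickle

    import numpy as np
    from hub.api.dataset import Dataset

    ds = Dataset("./pickle_demo")  # any persistent (non "mem://") storage
    ds.create_tensor("label")
    ds.label.append(np.arange(3))

    payload = pickle.dumps(ds)        # __getstate__ keeps only restorable attrs;
                                      # the LRU cache flushes itself when pickled
    restored = pickle.loads(payload)  # __setstate__ restores the dict, resets tensors,
                                      # then _set_derived_attributes() reloads the meta
    np.testing.assert_array_equal(restored.label[0].numpy(), np.arange(3))

    # pickle.dumps(Dataset("mem://pickle_demo")) would instead raise
    # MemoryDatasetCanNotBePickledError, since in-memory data is never serialized.
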
+ """ + self.__dict__.update(state) + self.tensors = {} + self._set_derived_attributes() + def __getitem__( self, item: Union[ @@ -307,6 +318,21 @@ def _get_total_meta(self): for tensor_key, tensor_value in self.tensors.items() } + def _set_derived_attributes(self): + """Sets derived attributes during init and unpickling.""" + self.storage.autoflush = True + if self.path.startswith("hub://"): + split_path = self.path.split("/") + self.org_id, self.ds_name = split_path[2], split_path[3] + self.client = HubBackendClient(token=self._token) + + self._load_meta() + self.index.validate(self.num_samples) + + hub_reporter.feature_report( + feature_name="Dataset", parameters={"Path": str(self.path)} + ) + def tensorflow(self): """Converts the dataset into a tensorflow compatible format. diff --git a/hub/api/tests/test_api.py b/hub/api/tests/test_api.py index 74b32c69e6..44729b2ff4 100644 --- a/hub/api/tests/test_api.py +++ b/hub/api/tests/test_api.py @@ -1,8 +1,6 @@ import numpy as np import pytest -import uuid import hub -import os from hub.api.dataset import Dataset from hub.tests.common import assert_array_lists_equal from hub.util.exceptions import ( @@ -11,6 +9,7 @@ UnsupportedCompressionError, ) from click.testing import CliRunner +from hub.util.exceptions import TensorDtypeMismatchError, TensorInvalidSampleShapeError from hub.tests.dataset_fixtures import ( enabled_datasets, enabled_persistent_dataset_generators, @@ -205,7 +204,7 @@ def test_empty_samples(ds: Dataset): def test_scalar_samples(ds: Dataset): tensor = ds.create_tensor("scalars") - assert tensor.meta.dtype == None + assert tensor.meta.dtype is None # first sample sets dtype tensor.append(5) diff --git a/hub/api/tests/test_pickle.py b/hub/api/tests/test_pickle.py new file mode 100644 index 0000000000..bda1fbc13b --- /dev/null +++ b/hub/api/tests/test_pickle.py @@ -0,0 +1,46 @@ +import numpy as np +import pytest +from hub.util.exceptions import MemoryDatasetCanNotBePickledError +import pickle +from hub.tests.dataset_fixtures import enabled_datasets + + +@enabled_datasets +def test_dataset(ds): + if ds.path.startswith("mem://"): + with pytest.raises(MemoryDatasetCanNotBePickledError): + pickle.dumps(ds) + return + + with ds: + ds.create_tensor("image", htype="image", sample_compression="jpeg") + ds.create_tensor("label") + for i in range(10): + ds.image.append( + i * np.ones(((i + 1) * 20, (i + 1) * 20, 3), dtype=np.uint8) + ) + + for i in range(5): + ds.label.append(i) + + pickled_ds = pickle.dumps(ds) + unpickled_ds = pickle.loads(pickled_ds) + assert len(unpickled_ds.image) == len(ds.image) + assert len(unpickled_ds.label) == len(ds.label) + assert unpickled_ds.tensors.keys() == ds.tensors.keys() + assert unpickled_ds.index.values[0].value == ds.index.values[0].value + assert unpickled_ds.meta.version == ds.meta.version + + for i in range(10): + np.testing.assert_array_equal( + ds.image[i].numpy(), + (i * np.ones(((i + 1) * 20, (i + 1) * 20, 3), dtype=np.uint8)), + ) + np.testing.assert_array_equal( + ds.image[i].numpy(), unpickled_ds.image[i].numpy() + ) + for i in range(5): + np.testing.assert_array_equal(ds.label[i].numpy(), i) + np.testing.assert_array_equal( + ds.label[i].numpy(), unpickled_ds.label[i].numpy() + ) diff --git a/hub/core/storage/lru_cache.py b/hub/core/storage/lru_cache.py index 36c9a9ba63..b8abd205ba 100644 --- a/hub/core/storage/lru_cache.py +++ b/hub/core/storage/lru_cache.py @@ -1,6 +1,6 @@ from collections import OrderedDict from hub.core.storage.cachable import Cachable -from typing import Callable, 
+from typing import Any, Dict, Set, Union
 
 from hub.core.storage.provider import StorageProvider
@@ -161,11 +161,10 @@ def __delitem__(self, path: str):
         self.maybe_flush()
 
     def clear_cache(self):
-        """Flushes the content of the cache and and then deletes contents of all the layers of it.
+        """Flushes the content of the cache if not in read-only mode and then deletes the contents of all the layers of it.
         This doesn't delete data from the actual storage.
         """
-        self.check_readonly()
-        self.flush()
+        self._flush_if_not_read_only()
         self.cache_used = 0
         self.lru_sizes.clear()
         self.dirty_keys.clear()
@@ -270,3 +269,40 @@ def _list_keys(self):
         for key in self.cache_storage:
             all_keys.add(key)
         return list(all_keys)
+
+    def _flush_if_not_read_only(self):
+        """Flushes the cache if not in read-only mode."""
+        if not self.read_only:
+            self.flush()
+
+    def __getstate__(self) -> Dict[str, Any]:
+        """Returns the state of the cache, for pickling."""
+
+        # flushes the cache before pickling
+        self._flush_if_not_read_only()
+
+        return {
+            "next_storage": self.next_storage,
+            "cache_storage": self.cache_storage,
+            "cache_size": self.cache_size,
+        }
+
+    def __setstate__(self, state: Dict[str, Any]):
+        """Recreates a cache with the same configuration as the state.
+
+        Args:
+            state (dict): The state to be used to recreate the cache.
+
+        Note:
+            While restoring the cache, we reset its contents.
+            In case the cache storage was local/s3 and is still accessible when unpickled (if same machine/s3 creds present respectively), the earlier cache contents are no longer accessible.
+        """
+
+        # TODO: We might want to change this behaviour in the future by having a separate file that keeps track of the lru order for restoring the cache.
+        # This would also allow the cache to persist across different Dataset objects pointing to the same dataset.
+        self.next_storage = state["next_storage"]
+        self.cache_storage = state["cache_storage"]
+        self.cache_size = state["cache_size"]
+        self.lru_sizes = OrderedDict()
+        self.dirty_keys = set()
+        self.cache_used = 0
diff --git a/hub/core/storage/memory.py b/hub/core/storage/memory.py
index 4f53d5e10e..8ae91fd840 100644
--- a/hub/core/storage/memory.py
+++ b/hub/core/storage/memory.py
@@ -1,11 +1,12 @@
+from typing import Any, Dict
 from hub.core.storage.provider import StorageProvider
 
 
 class MemoryProvider(StorageProvider):
     """Provider class for using the memory."""
 
-    def __init__(self, root=""):
-        self.dict = {}
+    def __init__(self, root: str = ""):
+        self.dict: Dict[str, Any] = {}
         self.root = root
 
     def __getitem__(
@@ -96,3 +97,10 @@ def clear(self):
         """Clears the provider."""
         self.check_readonly()
         self.dict = {}
+
+    def __getstate__(self) -> str:
+        """Does NOT save the in-memory data in state."""
+        return self.root
+
+    def __setstate__(self, state: str):
+        self.__init__(root=state)  # type: ignore
diff --git a/hub/core/storage/provider.py b/hub/core/storage/provider.py
index 9576be33ad..125985d771 100644
--- a/hub/core/storage/provider.py
+++ b/hub/core/storage/provider.py
@@ -10,6 +10,7 @@ class StorageProvider(ABC, MutableMapping):
     autoflush = False
+    read_only = False
 
     """An abstract base class for implementing a storage provider.
 
@@ -137,7 +138,7 @@ def disable_readonly(self):
 
     def check_readonly(self):
         """Raises an exception if the provider is in read-only mode."""
-        if hasattr(self, "read_only") and self.read_only:
+        if self.read_only:
            raise ReadOnlyModeError()
 
     def flush(self):
@@ -150,7 +151,7 @@ def maybe_flush(self):
         """Flush cache if autoflush has been enabled.
         Called at the end of methods which write data, to ensure consistency as a default.
         """
-        if hasattr(self, "autoflush") and self.autoflush:
+        if self.autoflush:
             self.flush()
 
     @abstractmethod
diff --git a/hub/core/storage/s3.py b/hub/core/storage/s3.py
index fbeae6c6f4..41a6d7f9cf 100644
--- a/hub/core/storage/s3.py
+++ b/hub/core/storage/s3.py
@@ -252,10 +252,8 @@ def _check_update_creds(self):
         client = HubBackendClient(self.token)
         org_id, ds_name = self.tag.split("/")
 
-        if hasattr(self, "read_only") and self.read_only:
-            mode = "r"
-        else:
-            mode = "a"
+        mode = "r" if self.read_only else "a"
+
         url, creds, mode, expiration = client.get_dataset_credentials(
             org_id, ds_name, mode
         )
diff --git a/hub/core/storage/tests/test_storage_provider.py b/hub/core/storage/tests/test_storage_provider.py
index 93fc8dd25f..23fe2e943d 100644
--- a/hub/core/storage/tests/test_storage_provider.py
+++ b/hub/core/storage/tests/test_storage_provider.py
@@ -1,8 +1,6 @@
-from hub.tests.storage_fixtures import enabled_storages
+from hub.tests.storage_fixtures import enabled_storages, enabled_persistent_storages
 from hub.tests.cache_fixtures import enabled_cache_chains
 import pytest
-
-from click.testing import CliRunner
 from hub.constants import MB
 import pickle
@@ -126,14 +124,11 @@ def test_cache(cache_chain):
     check_cache(cache_chain)
 
 
-@enabled_storages
+@enabled_persistent_storages
 def test_pickling(storage):
-    with CliRunner().isolated_filesystem():
-        FILE_1 = f"{KEY}_1"
-        storage[FILE_1] = b"hello world"
-        assert storage[FILE_1] == b"hello world"
-        pickle_file = open("storage_pickle", "wb")
-        pickle.dump(storage, pickle_file)
-        pickle_file = open("storage_pickle", "rb")
-        unpickled_storage = pickle.load(pickle_file)
-        assert unpickled_storage[FILE_1] == b"hello world"
+    FILE_1 = f"{KEY}_1"
+    storage[FILE_1] = b"hello world"
+    assert storage[FILE_1] == b"hello world"
+    pickled_storage = pickle.dumps(storage)
+    unpickled_storage = pickle.loads(pickled_storage)
+    assert unpickled_storage[FILE_1] == b"hello world"
diff --git a/hub/tests/storage_fixtures.py b/hub/tests/storage_fixtures.py
index d685b492b3..6e2c864d2d 100644
--- a/hub/tests/storage_fixtures.py
+++ b/hub/tests/storage_fixtures.py
@@ -11,6 +11,12 @@
     indirect=True,
 )
 
+enabled_persistent_storages = pytest.mark.parametrize(
+    "storage",
+    ["local_storage", "s3_storage", "hub_cloud_storage"],
+    indirect=True,
+)
+
 
 @pytest.fixture
 def memory_storage(memory_path):
diff --git a/hub/util/exceptions.py b/hub/util/exceptions.py
index d226f46733..98aa50af02 100644
--- a/hub/util/exceptions.py
+++ b/hub/util/exceptions.py
@@ -453,3 +453,10 @@ def __init__(self):
         super().__init__(
             f"Python Shared memory with multiprocessing doesn't work properly on Windows."
         )
+
+
+class MemoryDatasetCanNotBePickledError(Exception):
+    def __init__(self):
+        super().__init__(
+            "A Dataset with MemoryProvider as its underlying storage cannot be pickled, as the data would not be saved."
+        )
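At the storage layer the contract is deliberately lossy for in-memory data: `MemoryProvider.__getstate__` keeps only `root`, and `LRUCache.__setstate__` starts from an empty LRU state, which is why the dataset-level guard raises `MemoryDatasetCanNotBePickledError` instead of silently dropping samples. A rough sketch of that behaviour (illustrative only; the root string and key are hypothetical):

    import pickle

    from hub.core.storage.memory import MemoryProvider

    mem = MemoryProvider("mem://demo")
    mem["file"] = b"hello"

    restored = pickle.loads(pickle.dumps(mem))
    assert restored.root == mem.root      # configuration survives the round trip
    assert "file" not in restored.dict    # ...but the in-memory contents do not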