Made Dataset and LRUCache objects pickleable #1049

Merged · 16 commits · Jul 20, 2021
Changes from 6 commits
34 changes: 29 additions & 5 deletions hub/api/dataset.py
@@ -16,6 +16,7 @@
from hub.util.cache_chain import generate_chain
from hub.util.exceptions import (
InvalidKeyTypeError,
MemoryDatasetCanNotBePickledError,
PathNotEmptyException,
TensorAlreadyExistsError,
TensorDoesNotExistError,
@@ -86,19 +87,19 @@ def __init__(
self.storage = generate_chain(
base_storage, memory_cache_size_bytes, local_cache_size_bytes, path
)
self.storage.autoflush = True
self.index = index or Index()

self.tensors: Dict[str, Tensor] = {}

self._token = token
self.public = public
self._init_helper()

def _init_helper(self):
self.storage.autoflush = True
if self.path.startswith("hub://"):
split_path = self.path.split("/")
self.org_id, self.ds_name = split_path[2], split_path[3]
self.client = HubBackendClient(token=token)
self.client = HubBackendClient(token=self._token)

self.public = public
self._load_meta()

hub_reporter.feature_report(
@@ -118,6 +119,29 @@ def __len__(self):
tensor_lengths = [len(tensor[self.index]) for tensor in self.tensors.values()]
return min(tensor_lengths, default=0)

def __getstate__(self):
"""Returns a dict that can be pickled and used to restore this dataset.

Note: Pickling a dataset does not copy it; only the attributes needed to restore the dataset are saved.
If you pickle a local dataset and unpickle it on a machine where the underlying data is not present, the dataset will not work.
"""
if self.path.startswith("mem://"):
raise MemoryDatasetCanNotBePickledError
return {
"path": self.path,
"_read_only": self.read_only,
"index": self.index,
"public": self.public,
"storage": self.storage,
"_token": self.token,
}

def __setstate__(self, state):
"""Restores dataset from a pickled state."""
self.__dict__.update(state)
self.tensors = {}
self._init_helper()

def __getitem__(
self,
item: Union[
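Taken together, the two methods above make a `Dataset` round-trippable through `pickle`. A minimal usage sketch, assuming `Dataset` is importable from the top-level `hub` package (as the tests below suggest) and using a hypothetical local path:

```python
import pickle

import numpy as np
from hub import Dataset  # assumed top-level import; the tests use Dataset(...) directly

ds = Dataset("./pickle_demo")  # hypothetical local path; "mem://" paths raise instead
with ds:
    ds.create_tensor("x")
    ds.x.append(np.ones((3, 3), dtype=np.uint8))

blob = pickle.dumps(ds)        # __getstate__ captures path, index, storage, token, ...
restored = pickle.loads(blob)  # __setstate__ re-runs _init_helper() to reconnect
np.testing.assert_array_equal(restored.x[0].numpy(), ds.x[0].numpy())
```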
43 changes: 43 additions & 0 deletions hub/api/tests/test_api.py
@@ -7,12 +7,14 @@
from hub.core.tests.common import parametrize_all_dataset_storages
from hub.tests.common import assert_array_lists_equal
from hub.util.exceptions import (
MemoryDatasetCanNotBePickledError,
TensorDtypeMismatchError,
TensorInvalidSampleShapeError,
)
from hub.client.client import HubBackendClient
from hub.client.utils import has_hub_testing_creds
from click.testing import CliRunner
import pickle


# need this for 32-bit and 64-bit systems to have correct tests
@@ -502,3 +504,44 @@ def test_empty_dataset():
ds.create_tensor("z")
ds = Dataset("test")
assert list(ds.tensors) == ["x", "y", "z"]


@parametrize_all_dataset_storages
def test_dataset_pickling(ds):
if ds.path.startswith("mem://"):
with pytest.raises(MemoryDatasetCanNotBePickledError):
pickle.dumps(ds)
return

with ds:
ds.create_tensor("image", htype="image", sample_compression="jpeg")
ds.create_tensor("label")
for i in range(10):
ds.image.append(
i * np.ones(((i + 1) * 20, (i + 1) * 20, 3), dtype=np.uint8)
)

for i in range(5):
ds.label.append(i)

pickled_ds = pickle.dumps(ds)
unpickled_ds = pickle.loads(pickled_ds)
assert len(unpickled_ds.image) == len(ds.image)
assert len(unpickled_ds.label) == len(ds.label)
assert unpickled_ds.tensors.keys() == ds.tensors.keys()
assert unpickled_ds.index.values[0].value == ds.index.values[0].value
assert unpickled_ds.meta.version == ds.meta.version

for i in range(10):
np.testing.assert_array_equal(
ds.image[i].numpy(),
(i * np.ones(((i + 1) * 20, (i + 1) * 20, 3), dtype=np.uint8)),
)
np.testing.assert_array_equal(
ds.image[i].numpy(), unpickled_ds.image[i].numpy()
)
for i in range(5):
np.testing.assert_array_equal(ds.label[i].numpy(), i)
np.testing.assert_array_equal(
ds.label[i].numpy(), unpickled_ds.label[i].numpy()
)
28 changes: 28 additions & 0 deletions hub/core/storage/lru_cache.py
@@ -270,3 +270,31 @@ def _list_keys(self):
for key in self.cache_storage:
all_keys.add(key)
return list(all_keys)

def __getstate__(self):
"""Returns the state of the cache, for pickling"""

# flush the cache before pickling so dirty keys reach next_storage (skipped for read-only caches)
if not hasattr(self, "read_only") or not self.read_only:
self.flush()
return {
"next_storage": self.next_storage,
"cache_storage": self.cache_storage,
"cache_size": self.cache_size,
}

def __setstate__(self, state):
"""Recreates a cache with the same configuration as the state.

Note: While restoring the cache, its contents are reset.
If the cache storage was local/s3 and is still accessible when unpickled (same machine / s3 creds present, respectively), the earlier cache contents are no longer accessible.

TODO: We might want to change this behaviour in the future by keeping a separate file that tracks the LRU order, so the cache can be restored.
This would also allow the cache to persist across different Dataset objects pointing to the same dataset.
"""
self.next_storage = state["next_storage"]
self.cache_storage = state["cache_storage"]
self.cache_size = state["cache_size"]
self.lru_sizes = OrderedDict()
self.dirty_keys = set()
self.cache_used = 0
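The net effect of this pair: pickling flushes any dirty keys into `next_storage`, and unpickling comes back with empty LRU bookkeeping. A sketch of those semantics, assuming the module paths shown and the `(cache_storage, next_storage, cache_size)` constructor order — both are assumptions, not confirmed by the diff:

```python
import pickle

from hub.core.storage.lru_cache import LRUCache      # assumed module path
from hub.core.storage.memory import MemoryProvider   # path confirmed by the tests
from hub.core.storage.local import LocalProvider     # assumed module path

base = LocalProvider("./cache_demo")                       # persistent next_storage
cache = LRUCache(MemoryProvider("mem://demo"), base, 256)  # assumed arg order
cache["k"] = b"v"                                          # lands in cache_storage as a dirty key

restored = pickle.loads(pickle.dumps(cache))  # __getstate__ flushed "k" into base first
assert restored["k"] == b"v"                  # data survives via next_storage
assert restored.cache_used == 0               # LRU state was reset by __setstate__
```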
7 changes: 7 additions & 0 deletions hub/core/storage/memory.py
@@ -96,3 +96,10 @@ def clear(self):
"""Clears the provider."""
self.check_readonly()
self.dict = {}

def __getstate__(self):
"""Does NOT save the in memory data in state."""
return self.root

def __setstate__(self, state):
self.__init__(root=state)
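A short sketch of what these two methods imply for `MemoryProvider`: the `root` string survives a pickle round trip, the stored data does not. The positional `root` argument and provider iteration are assumptions based on the surrounding code:

```python
import pickle

from hub.core.storage.memory import MemoryProvider

mem = MemoryProvider("mem://demo")  # __setstate__ passes root back as a kwarg
mem["key"] = b"data"

restored = pickle.loads(pickle.dumps(mem))
assert restored.root == mem.root  # the root string survives the round trip
assert list(restored) == []       # the in-memory dict does not (providers iterate their keys)
```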
20 changes: 11 additions & 9 deletions hub/core/storage/tests/test_storage_provider.py
@@ -1,3 +1,4 @@
from hub.core.storage.memory import MemoryProvider
import pytest

from click.testing import CliRunner
@@ -126,12 +127,13 @@ def test_cache(storage):

@parametrize_all_storages
def test_pickling(storage):
with CliRunner().isolated_filesystem():
FILE_1 = f"{KEY}_1"
storage[FILE_1] = b"hello world"
assert storage[FILE_1] == b"hello world"
pickle_file = open("storage_pickle", "wb")
pickle.dump(storage, pickle_file)
pickle_file = open("storage_pickle", "rb")
unpickled_storage = pickle.load(pickle_file)
assert unpickled_storage[FILE_1] == b"hello world"
if isinstance(storage, MemoryProvider):
# skip pickling test for memory provider as the actual data isn't pickled for it
return

FILE_1 = f"{KEY}_1"
storage[FILE_1] = b"hello world"
assert storage[FILE_1] == b"hello world"
pickled_storage = pickle.dumps(storage)
unpickled_storage = pickle.loads(pickled_storage)
assert unpickled_storage[FILE_1] == b"hello world"
7 changes: 7 additions & 0 deletions hub/util/exceptions.py
@@ -431,3 +431,10 @@ def __init__(self):
super().__init__(
f"Python Shared memory with multiprocessing doesn't work properly on Windows."
)


class MemoryDatasetCanNotBePickledError(Exception):
def __init__(self):
super().__init__(
"Dataset having MemoryProvider as underlying storage should not be pickled as data won't be saved."
)
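
For reference, the exception fires as soon as a `mem://` dataset reaches `Dataset.__getstate__`; a sketch mirroring the test above, with a hypothetical path:

```python
import pickle

from hub import Dataset  # assumed top-level import, as in the earlier sketch
from hub.util.exceptions import MemoryDatasetCanNotBePickledError

ds = Dataset("mem://demo")  # hypothetical in-memory dataset path
try:
    pickle.dumps(ds)
except MemoryDatasetCanNotBePickledError:
    print("in-memory datasets refuse to pickle, as expected")
```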