Made Dataset and LRUCache objects pickleable #1049

Merged · 16 commits · Jul 20, 2021
68 changes: 47 additions & 21 deletions hub/api/dataset.py
@@ -1,6 +1,6 @@
from hub.core.storage.provider import StorageProvider
from hub.core.tensor import create_tensor
from typing import Callable, Dict, Optional, Union, Tuple, List, Sequence
from typing import Any, Callable, Dict, Optional, Union, Tuple, List, Sequence
from hub.constants import DEFAULT_HTYPE, UNSPECIFIED
import numpy as np

@@ -17,6 +17,7 @@
from hub.util.exceptions import (
CouldNotCreateNewDatasetException,
InvalidKeyTypeError,
MemoryDatasetCanNotBePickledError,
PathNotEmptyException,
ReadOnlyModeError,
TensorAlreadyExistsError,
@@ -75,11 +76,7 @@ def __init__(
creds = {}
base_storage = get_storage_provider(path, storage, read_only, creds, token)

# done instead of directly assigning read_only as backend might return read_only permissions
if hasattr(base_storage, "read_only") and base_storage.read_only:
self._read_only = True
else:
self._read_only = False
self._read_only = base_storage.read_only

# uniquely identifies dataset
self.path = path or get_path_from_storage(base_storage)
@@ -88,25 +85,11 @@
self.storage = generate_chain(
base_storage, memory_cache_size_bytes, local_cache_size_bytes, path
)
self.storage.autoflush = True
self.index = index or Index()

self.tensors: Dict[str, Tensor] = {}

self._token = token

if self.path.startswith("hub://"):
split_path = self.path.split("/")
self.org_id, self.ds_name = split_path[2], split_path[3]
self.client = HubBackendClient(token=token)

self.public = public
self._load_meta()
self.index.validate(self.num_samples)

hub_reporter.feature_report(
feature_name="Dataset", parameters={"Path": str(self.path)}
)
self._set_derived_attributes()

def __enter__(self):
self.storage.autoflush = False
@@ -128,6 +111,34 @@ def __len__(self):
tensor_lengths = [len(tensor[self.index]) for tensor in self.tensors.values()]
return min(tensor_lengths, default=0)

    def __getstate__(self) -> Dict[str, Any]:
        """Returns a dict that can be pickled and used to restore this dataset.

        Note:
            Pickling a dataset does not copy it; only the attributes needed to restore the dataset are saved.
            If you pickle a local dataset and try to access it on a machine that does not have the data present, the dataset will not work.
        """
        if self.path.startswith("mem://"):
            raise MemoryDatasetCanNotBePickledError
        return {
            "path": self.path,
            "_read_only": self.read_only,
            "index": self.index,
            "public": self.public,
            "storage": self.storage,
            "_token": self.token,
        }

    def __setstate__(self, state: Dict[str, Any]):
        """Restores the dataset from a pickled state.

        Args:
            state (dict): The pickled state used to restore the dataset.
        """
        self.__dict__.update(state)
        self.tensors = {}
        self._set_derived_attributes()

def __getitem__(
self,
item: Union[
@@ -307,6 +318,21 @@ def _get_total_meta(self):
for tensor_key, tensor_value in self.tensors.items()
}

    def _set_derived_attributes(self):
        """Sets derived attributes during init and unpickling."""
        self.storage.autoflush = True
        if self.path.startswith("hub://"):
            split_path = self.path.split("/")
            self.org_id, self.ds_name = split_path[2], split_path[3]
            self.client = HubBackendClient(token=self._token)

        self._load_meta()
        self.index.validate(self.num_samples)

        hub_reporter.feature_report(
            feature_name="Dataset", parameters={"Path": str(self.path)}
        )

def tensorflow(self):
"""Converts the dataset into a tensorflow compatible format.

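For orientation, a minimal sketch of the round trip these new methods enable. It is not part of the diff: the helper function and its assertions are illustrative, and `ds` is assumed to be a dataset backed by persistent (non-memory) storage, e.g. one produced by the fixtures in hub.tests.dataset_fixtures.

import pickle

from hub.api.dataset import Dataset


def roundtrip(ds: Dataset) -> Dataset:
    """Pickle and restore a dataset backed by persistent (non-memory) storage."""
    pickled = pickle.dumps(ds)        # __getstate__ keeps only path, index, token, storage, ...
    restored = pickle.loads(pickled)  # __setstate__ re-runs _set_derived_attributes
    assert restored.path == ds.path
    assert restored.tensors.keys() == ds.tensors.keys()
    return restored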
5 changes: 2 additions & 3 deletions hub/api/tests/test_api.py
@@ -1,8 +1,6 @@
import numpy as np
import pytest
import uuid
import hub
import os
from hub.api.dataset import Dataset
from hub.tests.common import assert_array_lists_equal
from hub.util.exceptions import (
@@ -11,6 +9,7 @@
UnsupportedCompressionError,
)
from click.testing import CliRunner
from hub.util.exceptions import TensorDtypeMismatchError, TensorInvalidSampleShapeError
from hub.tests.dataset_fixtures import (
enabled_datasets,
enabled_persistent_dataset_generators,
@@ -205,7 +204,7 @@ def test_empty_samples(ds: Dataset):
def test_scalar_samples(ds: Dataset):
tensor = ds.create_tensor("scalars")

assert tensor.meta.dtype == None
assert tensor.meta.dtype is None

# first sample sets dtype
tensor.append(5)
46 changes: 46 additions & 0 deletions hub/api/tests/test_pickle.py
@@ -0,0 +1,46 @@
import numpy as np
import pytest
from hub.util.exceptions import MemoryDatasetCanNotBePickledError
import pickle
from hub.tests.dataset_fixtures import enabled_datasets


@enabled_datasets
def test_dataset(ds):
    if ds.path.startswith("mem://"):
        with pytest.raises(MemoryDatasetCanNotBePickledError):
            pickle.dumps(ds)
        return

    with ds:
        ds.create_tensor("image", htype="image", sample_compression="jpeg")
        ds.create_tensor("label")
        for i in range(10):
            ds.image.append(
                i * np.ones(((i + 1) * 20, (i + 1) * 20, 3), dtype=np.uint8)
            )

        for i in range(5):
            ds.label.append(i)

    pickled_ds = pickle.dumps(ds)
    unpickled_ds = pickle.loads(pickled_ds)
    assert len(unpickled_ds.image) == len(ds.image)
    assert len(unpickled_ds.label) == len(ds.label)
    assert unpickled_ds.tensors.keys() == ds.tensors.keys()
    assert unpickled_ds.index.values[0].value == ds.index.values[0].value
    assert unpickled_ds.meta.version == ds.meta.version

    for i in range(10):
        np.testing.assert_array_equal(
            ds.image[i].numpy(),
            (i * np.ones(((i + 1) * 20, (i + 1) * 20, 3), dtype=np.uint8)),
        )
        np.testing.assert_array_equal(
            ds.image[i].numpy(), unpickled_ds.image[i].numpy()
        )
    for i in range(5):
        np.testing.assert_array_equal(ds.label[i].numpy(), i)
        np.testing.assert_array_equal(
            ds.label[i].numpy(), unpickled_ds.label[i].numpy()
        )
44 changes: 40 additions & 4 deletions hub/core/storage/lru_cache.py
@@ -1,6 +1,6 @@
from collections import OrderedDict
from hub.core.storage.cachable import Cachable
from typing import Callable, Set, Union
from typing import Any, Dict, Set, Union

from hub.core.storage.provider import StorageProvider

@@ -161,11 +161,10 @@ def __delitem__(self, path: str):
self.maybe_flush()

def clear_cache(self):
"""Flushes the content of the cache and and then deletes contents of all the layers of it.
"""Flushes the content of the cache if not in read mode and and then deletes contents of all the layers of it.
This doesn't delete data from the actual storage.
"""
self.check_readonly()
self.flush()
self._flush_if_not_read_only()
self.cache_used = 0
self.lru_sizes.clear()
self.dirty_keys.clear()
@@ -270,3 +269,40 @@ def _list_keys(self):
for key in self.cache_storage:
all_keys.add(key)
return list(all_keys)

    def _flush_if_not_read_only(self):
        """Flushes the cache if not in read-only mode."""
        if not self.read_only:
            self.flush()

    def __getstate__(self) -> Dict[str, Any]:
        """Returns the state of the cache, for pickling."""

        # flushes the cache before pickling
        self._flush_if_not_read_only()

        return {
            "next_storage": self.next_storage,
            "cache_storage": self.cache_storage,
            "cache_size": self.cache_size,
        }

    def __setstate__(self, state: Dict[str, Any]):
        """Recreates a cache with the same configuration as the state.

        Args:
            state (dict): The state to be used to recreate the cache.

        Note:
            Restoring the cache resets its contents.
            Even if the cache storage was local/S3 and is still accessible when unpickled (same machine or S3 creds present, respectively), the earlier cache contents are no longer accessible.
        """

        # TODO: We might want to change this behaviour in the future by having a separate file that keeps track of the LRU order for restoring the cache.
        # This would also allow the cache to persist across different Dataset objects pointing to the same dataset.
        self.next_storage = state["next_storage"]
        self.cache_storage = state["cache_storage"]
        self.cache_size = state["cache_size"]
        self.lru_sizes = OrderedDict()
        self.dirty_keys = set()
        self.cache_used = 0
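As a rough sketch of the resulting semantics (not part of the diff): the constructor argument order LRUCache(cache_storage, next_storage, cache_size) is an assumption inferred from the state keys above, and the memory providers are only placeholders.

import pickle

from hub.core.storage.lru_cache import LRUCache
from hub.core.storage.memory import MemoryProvider

# Constructor argument order is assumed; see the note above.
cache = LRUCache(MemoryProvider(), MemoryProvider(), 32 * 1024)
cache["key"] = b"value"

restored = pickle.loads(pickle.dumps(cache))
assert restored.cache_size == cache.cache_size
assert restored.cache_used == 0        # LRU bookkeeping is reset on unpickle
assert len(restored.dirty_keys) == 0   # nothing is considered dirty after restore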
12 changes: 10 additions & 2 deletions hub/core/storage/memory.py
@@ -1,11 +1,12 @@
from typing import Any, Dict
from hub.core.storage.provider import StorageProvider


class MemoryProvider(StorageProvider):
"""Provider class for using the memory."""

def __init__(self, root=""):
self.dict = {}
def __init__(self, root: str = ""):
self.dict: Dict[str, Any] = {}
self.root = root

def __getitem__(
@@ -96,3 +97,10 @@ def clear(self):
"""Clears the provider."""
self.check_readonly()
self.dict = {}

    def __getstate__(self) -> str:
        """Does NOT save the in-memory data in state."""
        return self.root

    def __setstate__(self, state: str):
        self.__init__(root=state)  # type: ignore
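A small sketch of what this means in practice (not part of the diff; the root string is an arbitrary placeholder):

import pickle

from hub.core.storage.memory import MemoryProvider

provider = MemoryProvider("mem://example")  # root string is a placeholder
provider["some_key"] = b"abc"

restored = pickle.loads(pickle.dumps(provider))
assert restored.root == provider.root  # only the root survives pickling
assert "some_key" not in restored      # the in-memory dict is not carried over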
5 changes: 3 additions & 2 deletions hub/core/storage/provider.py
@@ -10,6 +10,7 @@

class StorageProvider(ABC, MutableMapping):
autoflush = False
read_only = False

"""An abstract base class for implementing a storage provider.

@@ -137,7 +138,7 @@ def disable_readonly(self):

def check_readonly(self):
"""Raises an exception if the provider is in read-only mode."""
if hasattr(self, "read_only") and self.read_only:
if self.read_only:
raise ReadOnlyModeError()

def flush(self):
@@ -150,7 +151,7 @@ def maybe_flush(self):
"""Flush cache if autoflush has been enabled.
Called at the end of methods which write data, to ensure consistency as a default.
"""
if hasattr(self, "autoflush") and self.autoflush:
if self.autoflush:
self.flush()

@abstractmethod
6 changes: 2 additions & 4 deletions hub/core/storage/s3.py
@@ -252,10 +252,8 @@ def _check_update_creds(self):
client = HubBackendClient(self.token)
org_id, ds_name = self.tag.split("/")

if hasattr(self, "read_only") and self.read_only:
mode = "r"
else:
mode = "a"
mode = "r" if self.read_only else "a"

url, creds, mode, expiration = client.get_dataset_credentials(
org_id, ds_name, mode
)
21 changes: 8 additions & 13 deletions hub/core/storage/tests/test_storage_provider.py
@@ -1,8 +1,6 @@
from hub.tests.storage_fixtures import enabled_storages
from hub.tests.storage_fixtures import enabled_storages, enabled_persistent_storages
from hub.tests.cache_fixtures import enabled_cache_chains
import pytest

from click.testing import CliRunner
from hub.constants import MB
import pickle

@@ -126,14 +124,11 @@ def test_cache(cache_chain):
check_cache(cache_chain)


@enabled_storages
@enabled_persistent_storages
def test_pickling(storage):
with CliRunner().isolated_filesystem():
FILE_1 = f"{KEY}_1"
storage[FILE_1] = b"hello world"
assert storage[FILE_1] == b"hello world"
pickle_file = open("storage_pickle", "wb")
pickle.dump(storage, pickle_file)
pickle_file = open("storage_pickle", "rb")
unpickled_storage = pickle.load(pickle_file)
assert unpickled_storage[FILE_1] == b"hello world"
    FILE_1 = f"{KEY}_1"
    storage[FILE_1] = b"hello world"
    assert storage[FILE_1] == b"hello world"
    pickled_storage = pickle.dumps(storage)
    unpickled_storage = pickle.loads(pickled_storage)
    assert unpickled_storage[FILE_1] == b"hello world"
6 changes: 6 additions & 0 deletions hub/tests/storage_fixtures.py
@@ -11,6 +11,12 @@
indirect=True,
)

enabled_persistent_storages = pytest.mark.parametrize(
    "storage",
    ["local_storage", "s3_storage", "hub_cloud_storage"],
    indirect=True,
)


@pytest.fixture
def memory_storage(memory_path):
7 changes: 7 additions & 0 deletions hub/util/exceptions.py
@@ -453,3 +453,10 @@ def __init__(self):
super().__init__(
f"Python Shared memory with multiprocessing doesn't work properly on Windows."
)


class MemoryDatasetCanNotBePickledError(Exception):
    def __init__(self):
        super().__init__(
            "A Dataset with a MemoryProvider as its underlying storage cannot be pickled, as its data would not be saved."
        )