diff --git a/hub/api/dataset.py b/hub/api/dataset.py
index c606170fd3..97e6c9dcfe 100644
--- a/hub/api/dataset.py
+++ b/hub/api/dataset.py
@@ -1,3 +1,4 @@
+from hub.api.info import load_info
 from hub.core.storage.provider import StorageProvider
 from hub.core.tensor import create_tensor
 from typing import Any, Callable, Dict, Optional, Union, Tuple, List, Sequence
@@ -11,7 +12,12 @@
 from hub.core.index import Index
 from hub.integrations import dataset_to_tensorflow
-from hub.util.keys import dataset_exists, get_dataset_meta_key, tensor_exists
+from hub.util.keys import (
+    dataset_exists,
+    get_dataset_info_key,
+    get_dataset_meta_key,
+    tensor_exists,
+)
 from hub.util.bugout_reporter import hub_reporter
 from hub.util.cache_chain import generate_chain
 from hub.util.exceptions import (
@@ -89,6 +95,7 @@ def __init__(
         self.tensors: Dict[str, Tensor] = {}
         self._token = token
         self.public = public
+        self._set_derived_attributes()
 
     def __enter__(self):
@@ -239,18 +246,19 @@ def _load_meta(self):
                 raise PathNotEmptyException
         else:
-            self.meta = DatasetMeta()
-
-            try:
-                self.storage[meta_key] = self.meta
-            except ReadOnlyModeError:
-                # if this is thrown, that means the dataset doesn't exist and the user has no write access.
+            if self.read_only:
+                # cannot create a new dataset when in read_only mode.
                 raise CouldNotCreateNewDatasetException(self.path)
+            self.meta = DatasetMeta()
+            self.storage[meta_key] = self.meta
             self.flush()
             if self.path.startswith("hub://"):
                 self.client.create_dataset_entry(
-                    self.org_id, self.ds_name, self.meta.as_dict(), public=self.public
+                    self.org_id,
+                    self.ds_name,
+                    self.meta.__getstate__(),
+                    public=self.public,
                 )
 
     @property
@@ -320,13 +328,15 @@ def _get_total_meta(self):
 
     def _set_derived_attributes(self):
         """Sets derived attributes during init and unpickling."""
+        self.storage.autoflush = True
         if self.path.startswith("hub://"):
             split_path = self.path.split("/")
             self.org_id, self.ds_name = split_path[2], split_path[3]
             self.client = HubBackendClient(token=self._token)
 
-        self._load_meta()
+        self._load_meta()  # TODO: use the same scheme as `load_info`
+        self.info = load_info(get_dataset_info_key(), self.storage)  # type: ignore
         self.index.validate(self.num_samples)
 
         hub_reporter.feature_report(
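The reworked `_load_meta` guard is worth spelling out: instead of attempting the write and catching `ReadOnlyModeError`, the read-only check now happens before a `DatasetMeta` is ever created. A minimal sketch of what this means for callers (assuming the usual `hub.Dataset` entry point; the path here is hypothetical):

```python
from hub import Dataset
from hub.util.exceptions import CouldNotCreateNewDatasetException

# The path contains no `dataset_meta.json`, so the dataset would have to be
# created -- but read_only=True forbids that, so this now fails fast:
try:
    ds = Dataset("./some/empty/path", read_only=True)
except CouldNotCreateNewDatasetException as e:
    print(e)
```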
diff --git a/hub/api/info.py b/hub/api/info.py
new file mode 100644
index 0000000000..383edb7716
--- /dev/null
+++ b/hub/api/info.py
@@ -0,0 +1,101 @@
+from hub.core.storage.lru_cache import LRUCache
+from typing import Any, Dict
+from hub.core.storage.cachable import CachableCallback, use_callback
+
+
+class Info(CachableCallback):
+    def __init__(self):
+        """Contains **optional** key/values that datasets/tensors use for human-readability.
+        See the `Meta` class for required key/values for datasets/tensors.
+
+        Note:
+            Since `Info` is rarely written to, and mostly by the user, every modifier method calls `cache[key] = self`.
+            Must call `initialize_callback_location` before using any methods.
+        """
+
+        self._info = {}
+        super().__init__()
+
+    @property
+    def nbytes(self):
+        # TODO: optimize this
+        return len(self.tobytes())
+
+    @use_callback(check_only=True)
+    def __len__(self):
+        return len(self._info)
+
+    @use_callback(check_only=True)
+    def __getstate__(self) -> Dict[str, Any]:
+        return self._info
+
+    def __setstate__(self, state: Dict[str, Any]):
+        self._info = state
+
+    @use_callback()
+    def update(self, *args, **kwargs):
+        """Store optional dataset/tensor information. Will be accessible after loading your data from a new script!
+        Inputs must be supported by JSON.
+
+        Note:
+            This method has the same functionality as `dict().update(...)`. Reference: https://www.geeksforgeeks.org/python-dictionary-update-method/.
+            A full list of supported value types can be found here: https://docs.python.org/3/library/json.html#json.JSONEncoder.
+
+        Examples:
+            Normal update usage:
+                >>> ds.info
+                {}
+                >>> ds.info.update(key=0)
+                >>> ds.info
+                {"key": 0}
+                >>> ds.info.update({"key1": 5, "key2": [1, 2, "test"]})
+                >>> ds.info
+                {"key": 0, "key1": 5, "key2": [1, 2, "test"]}
+
+            Alternate update usage:
+                >>> ds.info
+                {}
+                >>> ds.info.update(list=[1, 2, "apple"])
+                >>> ds.info
+                {"list": [1, 2, "apple"]}
+                >>> l = ds.info.list
+                >>> l
+                [1, 2, "apple"]
+                >>> l.append(5)
+                >>> l
+                [1, 2, "apple", 5]
+                >>> ds.info.update()  # required to persist the in-place change!
+
+        """
+
+        self._cache.check_readonly()
+        self._info.update(*args, **kwargs)
+
+    def __getattribute__(self, name: str) -> Any:
+        """Allows access to info values using the `.` syntax. Example: `info.description`."""
+
+        if name == "_info":
+            return super().__getattribute__(name)
+        if name in self._info:
+            return self.__getitem__(name)
+        return super().__getattribute__(name)
+
+    def __getitem__(self, key: str):
+        return self._info[key]
+
+    def __str__(self):
+        return self._info.__str__()
+
+    def __repr__(self):
+        return self._info.__repr__()
+
+
+def load_info(info_key: str, cache: LRUCache):
+    if info_key in cache:
+        info = cache.get_cachable(info_key, Info)
+    else:
+        info = Info()
+        info.initialize_callback_location(info_key, cache)
+
+    return info
diff --git a/hub/api/tensor.py b/hub/api/tensor.py
index 284c298be5..c45b69d650 100644
--- a/hub/api/tensor.py
+++ b/hub/api/tensor.py
@@ -1,4 +1,8 @@
-from hub.util.keys import get_chunk_id_encoder_key, get_tensor_meta_key, tensor_exists
+from hub.api.info import load_info
+from hub.util.keys import (
+    get_tensor_info_key,
+    tensor_exists,
+)
 from hub.core.sample import Sample  # type: ignore
 from typing import List, Sequence, Union, Optional, Tuple, Dict
 from hub.util.shape import ShapeInterval
@@ -47,6 +51,8 @@ def __init__(
         self.chunk_engine = ChunkEngine(self.key, self.storage)
         self.index.validate(self.num_samples)
 
+        self.info = load_info(get_tensor_info_key(self.key), self.storage)
+
     def extend(self, samples: Union[np.ndarray, Sequence[SampleValue]]):
         """Extends the end of the tensor by appending multiple elements from a sequence.
         Accepts a sequence, a single batched numpy array, or a sequence of `hub.read` outputs,
         which can be used to load files. See examples down below.
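Taken together, `load_info` plus the new `Dataset.info` and `Tensor.info` attributes give both objects a persistent, JSON-backed key/value store. A short usage sketch (assuming a local path; the keys shown are hypothetical):

```python
from hub import Dataset

ds = Dataset("./info_demo")
ds.info.update(description="my dataset", source="https://example.com")
images = ds.create_tensor("images")
images.info.update(label_names=["cat", "dog"])
ds.flush()

ds = Dataset("./info_demo")  # reload from storage in a new session
print(ds.info.description)         # "my dataset"
print(ds.images.info.label_names)  # ["cat", "dog"]
```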
diff --git a/hub/api/tests/test_info.py b/hub/api/tests/test_info.py
new file mode 100644
index 0000000000..c2cd25be32
--- /dev/null
+++ b/hub/api/tests/test_info.py
@@ -0,0 +1,110 @@
+def test_dataset(local_ds_generator):
+    ds = local_ds_generator()
+
+    assert len(ds.info) == 0
+
+    ds.info.update(my_key=0)
+    ds.info.update(my_key=1)
+
+    ds.info.update(another_key="hi")
+    ds.info.update({"another_key": "hello"})
+
+    ds.info.update({"something": "aaaaa"}, something="bbbb")
+
+    ds.info.update(test=[1, 2, "5"])
+
+    test_list = ds.info.test
+    with ds:
+        ds.info.update({"test2": (1, 5, (1, "2"), [5, 6, (7, 8)])})
+        ds.info.update(xyz="abc")
+        test_list.extend(["user made change without `update`"])
+
+    ds.info.update({"1_-+": 5})
+
+    ds = local_ds_generator()
+
+    assert len(ds.info) == 7
+
+    assert ds.info.another_key == "hello"
+    assert ds.info.something == "bbbb"
+
+    assert ds.info.test == [1, 2, "5", "user made change without `update`"]
+    assert ds.info.test2 == [1, 5, [1, "2"], [5, 6, [7, 8]]]
+
+    assert ds.info.xyz == "abc"
+    assert ds.info["1_-+"] == 5  # key can't be accessed with `.` syntax
+
+    ds.info.update(test=[99])
+
+    ds = local_ds_generator()
+
+    assert len(ds.info) == 7
+    assert ds.info.test == [99]
+
+
+def test_tensor(local_ds_generator):
+    ds = local_ds_generator()
+
+    t1 = ds.create_tensor("tensor1")
+    t2 = ds.create_tensor("tensor2")
+
+    assert len(t1.info) == 0
+    assert len(t2.info) == 0
+
+    t1.info.update(key=0)
+    t2.info.update(key=1, key1=0)
+
+    ds = local_ds_generator()
+
+    t1 = ds.tensor1
+    t2 = ds.tensor2
+
+    assert len(t1.info) == 1
+    assert len(t2.info) == 2
+
+    assert t1.info.key == 0
+    assert t2.info.key == 1
+    assert t2.info.key1 == 0
+
+    with ds:
+        t1.info.update(key=99)
+
+    ds = local_ds_generator()
+
+    t1 = ds.tensor1
+    t2 = ds.tensor2
+
+    assert len(t1.info) == 1
+    assert len(t2.info) == 2
+
+    assert t1.info.key == 99
+
+
+def test_update_reference_manually(local_ds_generator):
+    """Right now synchronization can only happen when you call `info.update`."""
+
+    ds = local_ds_generator()
+
+    ds.info.update(key=[1, 2, 3])
+
+    ds = local_ds_generator()
+
+    l = ds.info.key
+    assert l == [1, 2, 3]
+
+    # un-registered update
+    l.append(5)
+    assert ds.info.key == [1, 2, 3, 5]
+
+    ds = local_ds_generator()
+
+    l = ds.info.key
+    assert l == [1, 2, 3]
+
+    # registered update
+    l.append(99)
+    ds.info.update()
+
+    ds = local_ds_generator()
+
+    assert ds.info.key == [1, 2, 3, 99]
diff --git a/hub/api/tests/test_readonly.py b/hub/api/tests/test_readonly.py
index 6bd0c361ff..901423b5f2 100644
--- a/hub/api/tests/test_readonly.py
+++ b/hub/api/tests/test_readonly.py
@@ -29,6 +29,12 @@ def test_readonly(local_ds_generator):
     ds.read_only = True
     _assert_readonly_ops(ds, 1, (100, 100))
 
+    with pytest.raises(ReadOnlyModeError):
+        ds.info.update(key=0)
+
+    with pytest.raises(ReadOnlyModeError):
+        ds.tensor.info.update(key=0)
+
 
 @pytest.mark.xfail(raises=CouldNotCreateNewDatasetException, strict=True)
 def test_readonly_doesnt_exist(local_path):
diff --git a/hub/constants.py b/hub/constants.py
index 60f76469cb..52787be452 100644
--- a/hub/constants.py
+++ b/hub/constants.py
@@ -24,6 +24,7 @@
 
 SUPPORTED_MODES = ["r", "a"]
 
+# min chunk size is always half of `DEFAULT_MAX_CHUNK_SIZE`
 DEFAULT_MAX_CHUNK_SIZE = 32 * MB
 
 MIN_FIRST_CACHE_SIZE = 32 * MB
@@ -34,8 +35,14 @@
 
 DEFAULT_LOCAL_CACHE_SIZE = 0
 
+# meta is hub-defined information, necessary for hub Datasets/Tensors to function
 DATASET_META_FILENAME = "dataset_meta.json"
 TENSOR_META_FILENAME = "tensor_meta.json"
+
+# info is user-defined information, entirely optional; may be used by the visualizer
+DATASET_INFO_FILENAME = "dataset_info.json"
+TENSOR_INFO_FILENAME = "tensor_info.json"
+
 META_ENCODING = "utf8"
 
 CHUNKS_FOLDER = "chunks"
diff --git a/hub/core/chunk.py b/hub/core/chunk.py
index 7aa81db641..3c5238a361 100644
--- a/hub/core/chunk.py
+++ b/hub/core/chunk.py
@@ -108,8 +108,10 @@ def update_headers(self, incoming_num_bytes: int, sample_shape: Tuple[int]):
         self.shapes_encoder.add_shape(sample_shape, 1)
         self.byte_positions_encoder.add_byte_position(num_bytes_per_sample, 1)
 
-    def __len__(self):
+    @property
+    def nbytes(self):
         """Calculates the number of bytes `tobytes` will be without having to call `tobytes`. Used by `LRUCache` to determine if this chunk can be cached."""
+
         return infer_chunk_num_bytes(
             hub.__version__,
             self.shapes_encoder.array,
diff --git a/hub/core/chunk_engine.py b/hub/core/chunk_engine.py
index 10cacfefe9..4ef6fad5fd 100644
--- a/hub/core/chunk_engine.py
+++ b/hub/core/chunk_engine.py
@@ -213,7 +213,7 @@ def _synchronize_cache(self):
         # synchronize last chunk
         last_chunk_key = self.last_chunk_key
         last_chunk = self.last_chunk
-        self.cache.update_used_cache_for_path(last_chunk_key, len(last_chunk))  # type: ignore
+        self.cache.update_used_cache_for_path(last_chunk_key, last_chunk.nbytes)  # type: ignore
 
         # synchronize tensor meta
         tensor_meta_key = get_tensor_meta_key(self.key)
diff --git a/hub/core/meta/dataset_meta.py b/hub/core/meta/dataset_meta.py
index 7c7ebc7172..5ca988fbf3 100644
--- a/hub/core/meta/dataset_meta.py
+++ b/hub/core/meta/dataset_meta.py
@@ -1,7 +1,5 @@
-from typing import Dict, List
-from hub.core.storage.provider import StorageProvider
+from typing import Any, Dict
 from hub.core.meta.meta import Meta
-from hub.util.keys import get_dataset_meta_key
 
 
 class DatasetMeta(Meta):
@@ -10,7 +8,12 @@ def __init__(self):
 
         super().__init__()
 
-    def as_dict(self) -> dict:
-        d = super().as_dict()
+    @property
+    def nbytes(self):
+        # TODO: can optimize this
+        return len(self.tobytes())
+
+    def __getstate__(self) -> Dict[str, Any]:
+        d = super().__getstate__()
         d["tensors"] = self.tensors
         return d
diff --git a/hub/core/meta/encode/chunk_id.py b/hub/core/meta/encode/chunk_id.py
index fd6b0ecc41..7b3518979e 100644
--- a/hub/core/meta/encode/chunk_id.py
+++ b/hub/core/meta/encode/chunk_id.py
@@ -72,6 +72,11 @@ def __init__(self):
 
         self._encoded_ids = None
 
+    @property
+    def nbytes(self):
+        # TODO: optimize this
+        return len(self.tobytes())
+
     def tobytes(self) -> memoryview:
         if self._encoded_ids is None:
             return serialize_chunkids(
diff --git a/hub/core/meta/meta.py b/hub/core/meta/meta.py
index 7aa553b955..64e4838690 100644
--- a/hub/core/meta/meta.py
+++ b/hub/core/meta/meta.py
@@ -1,10 +1,15 @@
+from typing import Any, Dict
 import hub
 from hub.core.storage.cachable import Cachable
 
 
 class Meta(Cachable):
     def __init__(self):
+        """Contains **required** key/values that datasets/tensors use to function.
+        See the `Info` class for optional key/values for datasets/tensors.
+        """
+
         self.version = hub.__version__
 
-    def as_dict(self) -> dict:
+    def __getstate__(self) -> Dict[str, Any]:
         return {"version": self.version}
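To make the new serialization contract concrete: `__getstate__` must return a JSON-able dict, and `Cachable.tobytes`/`frombuffer` (shown further down) round-trip it. A sketch, assuming `DatasetMeta.__init__` starts `tensors` as an empty list, as the `__getstate__` override implies:

```python
import hub
from hub.core.meta.dataset_meta import DatasetMeta

meta = DatasetMeta()
meta.tensors.append("images")

buffer = meta.tobytes()  # json.dumps(meta.__getstate__()).encode("utf-8")
restored = DatasetMeta.frombuffer(buffer)
assert restored.__getstate__() == {"version": hub.__version__, "tensors": ["images"]}
```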
+ """ + self.version = hub.__version__ - def as_dict(self) -> dict: + def __getstate__(self) -> Dict[str, Any]: return {"version": self.version} diff --git a/hub/core/meta/tensor_meta.py b/hub/core/meta/tensor_meta.py index 80563bf5fb..9bb1bb59b4 100644 --- a/hub/core/meta/tensor_meta.py +++ b/hub/core/meta/tensor_meta.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, List, Tuple +from typing import Any, Callable, Dict, List, Tuple import numpy as np from hub.util.exceptions import ( TensorInvalidSampleShapeError, @@ -51,7 +51,11 @@ def __init__( required_meta = _required_meta_from_htype(htype) required_meta.update(kwargs) + + self._required_meta_keys = tuple(required_meta.keys()) self.__dict__.update(required_meta) + else: + self._required_meta_keys = tuple() super().__init__() @@ -167,9 +171,22 @@ def _update_shape_interval(self, shape: Tuple[int, ...]): self.min_shape[i] = min(dim, self.min_shape[i]) self.max_shape[i] = max(dim, self.max_shape[i]) - def as_dict(self): - # TODO: tensor meta as_dict - raise NotImplementedError + def __getstate__(self) -> Dict[str, Any]: + d = super().__getstate__() + + for key in self._required_meta_keys: + d[key] = getattr(self, key) + + return d + + def __setstate__(self, state: Dict[str, Any]): + super().__setstate__(state) + self._required_meta_keys = tuple(state.keys()) + + @property + def nbytes(self): + # TODO: optimize this + return len(self.tobytes()) def _required_meta_from_htype(htype: str) -> dict: diff --git a/hub/core/storage/cachable.py b/hub/core/storage/cachable.py index 7a3cb120c3..4a40f0c6c2 100644 --- a/hub/core/storage/cachable.py +++ b/hub/core/storage/cachable.py @@ -1,5 +1,7 @@ from abc import ABC import json +from typing import Any, Dict +from hub.util.exceptions import CallbackInitializationError class Cachable(ABC): @@ -13,14 +15,103 @@ def __init__(self, buffer: bytes = None): if buffer: self.frombuffer(buffer) - def __len__(self): - return len(self.tobytes()) + @property + def nbytes(self): + # do not implement, each class should do this because it could be very slow if `tobytes` is called + raise NotImplementedError + + def __getstate__(self) -> Dict[str, Any]: + raise NotImplementedError + + def __setstate__(self, state: Dict[str, Any]): + self.__dict__.update(state) def tobytes(self) -> bytes: - return bytes(json.dumps(self.__dict__), "utf-8") + return bytes(json.dumps(self.__getstate__()), "utf-8") @classmethod def frombuffer(cls, buffer: bytes): instance = cls() - instance.__dict__.update(json.loads(buffer)) + instance.__setstate__(json.loads(buffer)) return instance + + +def use_callback(check_only: bool = False): + """Decorator function for `CachableCallback` and it's subclasses. + + Note: + Must call `@use_callback()` not `@use_callback`. + Use this decorator on a field method that should use the `CachableCallback.callback` method. + All methods that are decorated will require that `CachableCallback.initialize_callback_location` + is called first. Also, after the function executes, `CachableCallback.callback` is called. + + Args: + check_only (bool): If True, the callback is not actually called. Only the requirement check is executed. Defaults to False. + + Returns: + Decorator function. + """ + + def outer(func): + def inner(obj: "CachableCallback", *args, **kwargs): + if not obj._is_callback_initialized(): + raise CallbackInitializationError( + "Must first call `initialize_callback_location` before any other methods may be called." 
+ ) + + y = func(obj, *args, **kwargs) + + if not check_only: + obj.callback() + + return y + + return inner + + return outer + + +class CachableCallback(Cachable): + def __init__(self): + """CachableCallback objects can be stored in memory cache and when modifier methods are called, this class is synchronized + with the cache. This means the user doesn't have to do `ds.cache[cache_key] = ds.info`. + + Note: + This class should be used as infrequently as possible, as it may lead to slowdowns. + When extending this class, methods that should have a callback called should be decorated with + `@use_callback()`. + """ + + self._key = None + self._cache = None + + def _is_callback_initialized(self) -> bool: + key_exists = self._key is not None + cache_exists = self._cache is not None + return key_exists and cache_exists + + def initialize_callback_location(self, key, cache): + """Must be called once before any other method calls. + + Args: + key: The key for where in `cache` bytes are serialized with each callback call. + cache: The cache for where bytes are serialized with each callback call. + + Raises: + CallbackInitializationError: Cannot re-initialize. + """ + + if self._is_callback_initialized(): + raise CallbackInitializationError( + f"`initialize_callback_location` was already called. key={self._key}" + ) + + self._key = key + self._cache = cache + + def callback(self): + self._cache[self._key] = self + + @use_callback(check_only=True) + def flush(self): + self._cache.flush() diff --git a/hub/core/storage/lru_cache.py b/hub/core/storage/lru_cache.py index b8abd205ba..9de0fe3d86 100644 --- a/hub/core/storage/lru_cache.py +++ b/hub/core/storage/lru_cache.py @@ -1,10 +1,16 @@ from collections import OrderedDict -from hub.core.storage.cachable import Cachable +from hub.core.storage.cachable import Cachable, CachableCallback from typing import Any, Dict, Set, Union from hub.core.storage.provider import StorageProvider +def _get_nbytes(obj: Union[bytes, memoryview, Cachable]): + if isinstance(obj, Cachable): + return obj.nbytes + return len(obj) + + # TODO use lock for multiprocessing class LRUCache(StorageProvider): """LRU Cache that uses StorageProvider for caching""" @@ -82,8 +88,13 @@ def get_cachable(self, path: str, expected_class): if isinstance(item, (bytes, memoryview)): obj = expected_class.frombuffer(item) - if len(obj) <= self.cache_size: + + if isinstance(obj, CachableCallback): + obj.initialize_callback_location(path, self) + + if obj.nbytes <= self.cache_size: self._insert_in_cache(path, obj) + return obj raise ValueError(f"Item at '{path}' got an invalid type: '{type(item)}'.") @@ -106,7 +117,8 @@ def __getitem__(self, path: str): return self.cache_storage[path] else: result = self.next_storage[path] # fetch from storage, may throw KeyError - if len(result) <= self.cache_size: # insert in cache if it fits + + if _get_nbytes(result) <= self.cache_size: # insert in cache if it fits self._insert_in_cache(path, result) return result @@ -125,7 +137,7 @@ def __setitem__(self, path: str, value: Union[bytes, Cachable]): size = self.lru_sizes.pop(path) self.cache_used -= size - if len(value) <= self.cache_size: + if _get_nbytes(value) <= self.cache_size: self._insert_in_cache(path, value) self.dirty_keys.add(path) else: # larger than cache, directly send to next layer @@ -254,10 +266,10 @@ def _insert_in_cache(self, path: str, value: Union[bytes, Cachable]): ReadOnlyError: If the provider is in read-only mode. 
""" - self._free_up_space(len(value)) + self._free_up_space(_get_nbytes(value)) self.cache_storage[path] = value # type: ignore - self.update_used_cache_for_path(path, len(value)) + self.update_used_cache_for_path(path, _get_nbytes(value)) def _list_keys(self): """Helper function that lists all the objects present in the cache and the underlying storage. diff --git a/hub/core/storage/tests/test_readonly.py b/hub/core/storage/tests/test_readonly.py index c5005f4da8..d61d704f52 100644 --- a/hub/core/storage/tests/test_readonly.py +++ b/hub/core/storage/tests/test_readonly.py @@ -1,6 +1,5 @@ import pytest -from hub.util.exceptions import CouldNotCreateNewDatasetException, ReadOnlyModeError -from hub import Dataset +from hub.util.exceptions import ReadOnlyModeError from hub.tests.storage_fixtures import enabled_storages diff --git a/hub/tests/path_fixtures.py b/hub/tests/path_fixtures.py index 08a429be22..cc7a656e0e 100644 --- a/hub/tests/path_fixtures.py +++ b/hub/tests/path_fixtures.py @@ -100,6 +100,7 @@ def local_path(request): return path = _get_storage_path(request, LOCAL) + LocalProvider(path).clear() yield path @@ -115,6 +116,7 @@ def s3_path(request): return path = _get_storage_path(request, S3) + S3Provider(path).clear() yield path @@ -130,6 +132,7 @@ def hub_cloud_path(request, hub_cloud_dev_token): return path = _get_storage_path(request, HUB_CLOUD) + storage_provider_from_hub_path(path, token=hub_cloud_dev_token).clear() yield path diff --git a/hub/util/exceptions.py b/hub/util/exceptions.py index 98aa50af02..243bbcd551 100644 --- a/hub/util/exceptions.py +++ b/hub/util/exceptions.py @@ -455,6 +455,10 @@ def __init__(self): ) +class CallbackInitializationError(Exception): + pass + + class MemoryDatasetCanNotBePickledError(Exception): def __init__(self): super().__init__( diff --git a/hub/util/keys.py b/hub/util/keys.py index f10fb6649e..d120777ee7 100644 --- a/hub/util/keys.py +++ b/hub/util/keys.py @@ -1,3 +1,4 @@ +from hub.util.exceptions import CorruptedMetaError from hub.core.storage.provider import StorageProvider import posixpath @@ -13,10 +14,19 @@ def get_dataset_meta_key() -> str: return constants.DATASET_META_FILENAME +def get_dataset_info_key() -> str: + # dataset info is always relative to the `StorageProvider`'s root + return constants.DATASET_INFO_FILENAME + + def get_tensor_meta_key(key: str) -> str: return posixpath.join(key, constants.TENSOR_META_FILENAME) +def get_tensor_info_key(key: str) -> str: + return posixpath.join(key, constants.TENSOR_INFO_FILENAME) + + def get_chunk_id_encoder_key(key: str) -> str: return posixpath.join( key, @@ -26,7 +36,10 @@ def get_chunk_id_encoder_key(key: str) -> str: def dataset_exists(storage: StorageProvider) -> bool: - return get_dataset_meta_key() in storage + """A dataset exists if the provided `storage` contains a `dataset_meta.json`.""" + + meta_exists = get_dataset_meta_key() in storage + return meta_exists def tensor_exists(key: str, storage: StorageProvider) -> bool: