diff --git a/hub/api/compute_list.py b/hub/api/compute_list.py
new file mode 100644
index 0000000000..6b1f3a7d80
--- /dev/null
+++ b/hub/api/compute_list.py
@@ -0,0 +1,21 @@
+from hub.api.dataset import Dataset, DatasetView, TensorView
+import numpy as np
+
+
+# a list of Datasets or DatasetViews or Tensorviews that supports compute operation
+class ComputeList:
+    # Doesn't support further get item operations currently
+    def __init__(self, ls):
+        self.ls = ls
+
+    def compute(self):
+        results = [
+            item.compute()
+            if isinstance(item, (Dataset, DatasetView, TensorView))
+            else item
+            for item in self.ls
+        ]
+        return np.concatenate(results)
+
+    def numpy(self):
+        return self.compute()
diff --git a/hub/api/dataset.py b/hub/api/dataset.py
index 543ab538a1..7e9daa6ed7 100644
--- a/hub/api/dataset.py
+++ b/hub/api/dataset.py
@@ -771,10 +771,12 @@ def numpy(self, label_name=False):
             If the TensorView object is of the ClassLabel type, setting this to True would retrieve the label names
             instead of the label encoded integers, otherwise this parameter is ignored.
         """
-        return [
-            create_numpy_dict(self, i, label_name=label_name)
-            for i in range(self._shape[0])
-        ]
+        return np.array(
+            [
+                create_numpy_dict(self, i, label_name=label_name)
+                for i in range(self._shape[0])
+            ]
+        )
 
     def compute(self, label_name=False):
         """Gets the values from different tensorview objects in the dataset schema
diff --git a/hub/api/datasetview.py b/hub/api/datasetview.py
index 41277a3f42..17418d5043 100644
--- a/hub/api/datasetview.py
+++ b/hub/api/datasetview.py
@@ -16,6 +16,7 @@
 from hub.exceptions import NoneValueException
 from hub.api.objectview import ObjectView
 from hub.schema import Sequence
+import numpy as np
 
 
 class DatasetView:
@@ -306,10 +307,12 @@ def numpy(self, label_name=False):
         if isinstance(self.indexes, int):
             return create_numpy_dict(self.dataset, self.indexes, label_name=label_name)
         else:
-            return [
-                create_numpy_dict(self.dataset, index, label_name=label_name)
-                for index in self.indexes
-            ]
+            return np.array(
+                [
+                    create_numpy_dict(self.dataset, index, label_name=label_name)
+                    for index in self.indexes
+                ]
+            )
 
     def disable_lazy(self):
         self.lazy = False
diff --git a/hub/api/sharded_datasetview.py b/hub/api/sharded_datasetview.py
index 4887176224..189ee90346 100644
--- a/hub/api/sharded_datasetview.py
+++ b/hub/api/sharded_datasetview.py
@@ -5,9 +5,8 @@
 """
 
 from collections.abc import Iterable
-
-from hub.api.datasetview import DatasetView
-from hub.exceptions import AdvancedSlicingNotSupported
+from hub.api.dataset_utils import slice_split
+from hub.api.compute_list import ComputeList
 
 
 class ShardedDatasetView:
@@ -55,33 +54,79 @@ def identify_shard(self, index) -> tuple:
             shard_id += 1
         return 0, 0
 
-    def slicing(self, slice_):
+    def slicing(self, slice_list):
         """
         Identifies the dataset shard that should be used
-        Notes:
-            Features of advanced slicing are missing as one would expect from a DatasetView
-            E.g. cross sharded dataset access is missing
         """
+        shard_id, offset = self.identify_shard(slice_list[0])
+        slice_list[0] = slice_list[0] - offset
+        return slice_list, shard_id
+
+    def __getitem__(self, slice_):
         if not isinstance(slice_, Iterable) or isinstance(slice_, str):
             slice_ = [slice_]
-
         slice_ = list(slice_)
-        if not isinstance(slice_[0], int):
-            # TODO add advanced slicing options
-            raise AdvancedSlicingNotSupported()
-
-        shard_id, offset = self.identify_shard(slice_[0])
-        slice_[0] = slice_[0] - offset
-
-        return slice_, shard_id
-
-    def __getitem__(self, slice_) -> DatasetView:
-        slice_, shard_id = self.slicing(slice_)
-        return self.datasets[shard_id][slice_]
+        subpath, slice_list = slice_split(slice_)
+        slice_list = slice_list or [slice(0, self.num_samples)]
+        if isinstance(slice_list[0], int):
+            # if integer it fetches the data from the corresponding dataset
+            slice_list, shard_id = self.slicing(slice_list)
+            slice_ = slice_list + [subpath] if subpath else slice_list
+            return self.datasets[shard_id][slice_]
+        else:
+            # if slice it finds all the corresponding datasets included in the slice and generates tensorviews or datasetviews (depending on slice)
+            # these views are stored in a ComputeList, calling compute on which will fetch data from all corresponding datasets and return a single result
+            results = []
+            cur_index = slice_list[0].start or 0
+            cur_index = cur_index + self.num_samples if cur_index < 0 else cur_index
+            cur_index = max(cur_index, 0)
+            stop_index = slice_list[0].stop or self.num_samples
+            stop_index = min(stop_index, self.num_samples)
+            while cur_index < stop_index:
+                shard_id, offset = self.identify_shard(cur_index)
+                end_index = min(offset + len(self.datasets[shard_id]), stop_index)
+                cur_slice_list = [
+                    slice(cur_index - offset, end_index - offset)
+                ] + slice_list[1:]
+                current_slice = (
+                    cur_slice_list + [subpath] if subpath else cur_slice_list
+                )
+                results.append(self.datasets[shard_id][current_slice])
+                cur_index = end_index
+            return ComputeList(results)
 
     def __setitem__(self, slice_, value) -> None:
-        slice_, shard_id = self.slicing(slice_)
-        self.datasets[shard_id][slice_] = value
+        if not isinstance(slice_, Iterable) or isinstance(slice_, str):
+            slice_ = [slice_]
+        slice_ = list(slice_)
+        subpath, slice_list = slice_split(slice_)
+        slice_list = slice_list or [slice(0, self.num_samples)]
+        if isinstance(slice_list[0], int):
+            # if integer it assigns the data to the corresponding dataset
+            slice_list, shard_id = self.slicing(slice_list)
+            slice_ = slice_list + [subpath] if subpath else slice_list
+            self.datasets[shard_id][slice_] = value
+        else:
+            # if slice it finds all the corresponding datasets and assigns slices of the value one by one
+            cur_index = slice_list[0].start or 0
+            cur_index = cur_index + self.num_samples if cur_index < 0 else cur_index
+            cur_index = max(cur_index, 0)
+            start_index = cur_index
+            stop_index = slice_list[0].stop or self.num_samples
+            stop_index = min(stop_index, self.num_samples)
+            while cur_index < stop_index:
+                shard_id, offset = self.identify_shard(cur_index)
+                end_index = min(offset + len(self.datasets[shard_id]), stop_index)
+                cur_slice_list = [
+                    slice(cur_index - offset, end_index - offset)
+                ] + slice_list[1:]
+                current_slice = (
+                    cur_slice_list + [subpath] if subpath else cur_slice_list
+                )
+                self.datasets[shard_id][current_slice] = value[
+                    cur_index - start_index : end_index - start_index
+                ]
+                cur_index = end_index
 
     def __iter__(self):
         """ Returns Iterable over samples """
diff --git a/hub/api/tests/test_dataset.py b/hub/api/tests/test_dataset.py
index af3ae2c0fb..bc8309771b 100644
--- a/hub/api/tests/test_dataset.py
+++ b/hub/api/tests/test_dataset.py
@@ -1032,14 +1032,14 @@ def test_check_label_name():
     ds["label", 0] = 1
     ds["label", 1] = 2
     ds["label", 2] = 0
-    assert ds.compute(label_name=True) == [
+    assert ds.compute(label_name=True).tolist() == [
         {"label": "green"},
         {"label": "blue"},
         {"label": "red"},
         {"label": "red"},
         {"label": "red"},
     ]
-    assert ds.compute() == [
+    assert ds.compute().tolist() == [
         {"label": 1},
         {"label": 2},
         {"label": 0},
@@ -1048,8 +1048,11 @@
     ]
     assert ds[1].compute(label_name=True) == {"label": "blue"}
     assert ds[1].compute() == {"label": 2}
-    assert ds[1:3].compute(label_name=True) == [{"label": "blue"}, {"label": "red"}]
-    assert ds[1:3].compute() == [{"label": 2}, {"label": 0}]
+    assert ds[1:3].compute(label_name=True).tolist() == [
+        {"label": "blue"},
+        {"label": "red"},
+    ]
+    assert ds[1:3].compute().tolist() == [{"label": 2}, {"label": 0}]
 
 
 @pytest.mark.skipif(not minio_creds_exist(), reason="requires minio credentials")
diff --git a/hub/api/tests/test_sharded_dataset.py b/hub/api/tests/test_sharded_dataset.py
index bb718e7c7b..36ab22c2f0 100644
--- a/hub/api/tests/test_sharded_dataset.py
+++ b/hub/api/tests/test_sharded_dataset.py
@@ -5,7 +5,6 @@
 """
 
 from hub.schema.features import SchemaDict
-from hub.exceptions import AdvancedSlicingNotSupported
 from hub.api.sharded_datasetview import ShardedDatasetView
 from hub import Dataset
 import pytest
@@ -27,8 +26,6 @@ def test_sharded_dataset():
     assert ds.shape == (40,)
     assert type(ds.schema) == SchemaDict
     assert ds.__repr__() == "ShardedDatasetView(shape=(40,))"
-    with pytest.raises(AdvancedSlicingNotSupported):
-        ds[5:8]
     ds[4, "first"] = 3
     for _ in ds:
         pass
@@ -62,5 +59,61 @@ def test_sharded_dataset_with_views():
         assert sharded_ds[i, "second"].compute() == 2 * (i - 5) + 1
 
 
+def test_sharded_dataset_advanced_slice():
+    schema = {"first": "float", "second": "float"}
+    ds = Dataset("./data/test_sharded_ds", shape=(10,), schema=schema, mode="w")
+    for i in range(10):
+        ds[i, "first"] = i
+        ds[i, "second"] = 2 * i + 1
+
+    dsv = ds[3:5]
+    dsv2 = ds[1]
+    dsv3 = ds[8:]
+    datasets = [dsv, ds, dsv2, dsv3]
+    sharded_ds = ShardedDatasetView(datasets)
+    assert sharded_ds["first", :].compute().tolist() == [
+        3,
+        4,
+        0,
+        1,
+        2,
+        3,
+        4,
+        5,
+        6,
+        7,
+        8,
+        9,
+        1,
+        8,
+        9,
+    ]
+    assert sharded_ds["first"].compute().tolist() == [
+        3,
+        4,
+        0,
+        1,
+        2,
+        3,
+        4,
+        5,
+        6,
+        7,
+        8,
+        9,
+        1,
+        8,
+        9,
+    ]
+    assert sharded_ds["first", -4:].compute().tolist() == [9, 1, 8, 9]
+    assert sharded_ds[1:3].compute()[0] == {"first": 4.0, "second": 9.0}
+    assert sharded_ds[1:3].compute()[1] == {"first": 0.0, "second": 1.0}
+    sharded_ds["first", 1:5] = [10, 11, 12, 13]
+    assert sharded_ds["first", 1:5].compute().tolist() == [10, 11, 12, 13]
+    sharded_ds["first", 12] = 50
+    assert sharded_ds["first", 12].compute() == 50
+
+
 if __name__ == "__main__":
-    test_sharded_dataset()
+    # test_sharded_dataset()
+    test_sharded_dataset_advanced_slice()
diff --git a/hub/exceptions.py b/hub/exceptions.py
index d93750bc2c..4f823f7f54 100644
--- a/hub/exceptions.py
+++ b/hub/exceptions.py
@@ -261,12 +261,6 @@ def __init__(self):
         super(HubException, self).__init__(message=message)
 
 
-class AdvancedSlicingNotSupported(HubException):
-    def __init__(self):
-        message = "Advanced slicing is not supported, only support index"
-        super(HubException, self).__init__(message=message)
-
-
 class NotZarrFolderException(Exception):
     pass
 
diff --git a/hub/tests/test_exceptions.py b/hub/tests/test_exceptions.py
index 2f2e67ca69..2681cb8d38 100644
--- a/hub/tests/test_exceptions.py
+++ b/hub/tests/test_exceptions.py
@@ -5,7 +5,6 @@
 """
 
 from hub.exceptions import (
-    AdvancedSlicingNotSupported,
     DaskModuleNotInstalledException,
     HubException,
     AuthenticationException,
@@ -70,10 +69,9 @@ def test_exceptions():
     NotHubDatasetToAppendException()
     DynamicTensorNotFoundException()
     NotIterable()
-    AdvancedSlicingNotSupported()
     DaskModuleNotInstalledException()
-    DynamicTensorShapeException("none")
     DynamicTensorShapeException("length")
     DynamicTensorShapeException("not_equal")
     DynamicTensorShapeException("another_cause")
+    NotFound()
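
Reviewer note (not part of the patch): a minimal usage sketch of the cross-shard slicing added in hub/api/sharded_datasetview.py, mirroring test_sharded_dataset_advanced_slice above. The store paths, shard sizes, and values are illustrative assumptions; the point is that a slice spanning shards now returns a ComputeList whose compute() concatenates the per-shard results, and a cross-shard assignment is split across the shards it touches.

# Illustrative sketch only -- paths and values below are assumptions, not part of the diff.
from hub import Dataset
from hub.api.sharded_datasetview import ShardedDatasetView

schema = {"first": "float"}
# Two hypothetical shards of 4 samples each.
ds1 = Dataset("./data/shard_demo_1", shape=(4,), schema=schema, mode="w")
ds2 = Dataset("./data/shard_demo_2", shape=(4,), schema=schema, mode="w")
for i in range(4):
    ds1[i, "first"] = i        # shard 1 holds 0..3
    ds2[i, "first"] = 10 + i   # shard 2 holds 10..13

sharded = ShardedDatasetView([ds1, ds2])

# A slice spanning both shards returns a ComputeList; compute() fetches from
# each shard and concatenates the pieces into a single numpy array.
assert sharded["first", 2:6].compute().tolist() == [2.0, 3.0, 10.0, 11.0]

# Cross-shard assignment distributes consecutive chunks of the value
# across the shards covered by the slice.
sharded["first", 2:6] = [20, 21, 22, 23]
assert sharded["first", 3].compute() == 21.0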