import numpy as np


class ComputeList:
    """Lazy container for the per-shard results of a sliced sharded view.

    Holds a sequence of items, each of which is either a lazy object that
    exposes a ``compute()`` method or an already materialised value, and
    joins them into a single numpy array on demand.

    Doesn't support further get item operations currently.
    """

    def __init__(self, ls):
        """Store the per-shard items.

        Args:
            ls: Iterable of items to be concatenated by ``compute()``.
        """
        self.ls = ls

    def compute(self):
        """Materialise every item and concatenate them along the first axis.

        Items exposing a callable ``compute`` attribute are computed first;
        anything else is used as is. Duck typing is used instead of an
        ``isinstance`` check so that nested lazy containers are resolved as
        well and this module does not need to import the dataset classes.

        Returns:
            numpy.ndarray: The concatenated results. An empty array is
            returned when there are no items.
        """
        results = []
        for item in self.ls:
            compute = getattr(item, "compute", None)
            value = compute() if callable(compute) else item
            # np.concatenate rejects zero-dimensional inputs, so promote
            # scalars (e.g. a shard contributing a single sample) to 1-D.
            results.append(np.atleast_1d(value))

        if not results:
            # np.concatenate raises ValueError on an empty sequence.
            return np.array([])
        return np.concatenate(results)

    def numpy(self):
        """Alias of ``compute()``."""
        return self.compute()
def slicing(self, slice_list):
    """Identify the dataset shard that should be used for an integer index.

    Method of ``ShardedDatasetView``. The first element of ``slice_list``
    is a global sample index; it is translated into an index local to the
    shard that contains it.

    Args:
        slice_list: The key to resolve. A bare value or a tuple is accepted
            and normalised to a list, because ``__setitem__`` forwards the
            raw key it receives.

    Returns:
        tuple: ``(local_slice_list, shard_id)``. ``local_slice_list`` is a
        new list; the caller's object is never modified.

    Raises:
        AdvancedSlicingNotSupported: If the key is empty or its first
            element is not an integer.
    """
    if not isinstance(slice_list, Iterable) or isinstance(slice_list, str):
        slice_list = [slice_list]
    # Copy so the offset subtraction below does not leak into the caller.
    local_slice = list(slice_list)
    if not local_slice or not isinstance(local_slice[0], int):
        raise AdvancedSlicingNotSupported()

    shard_id, offset = self.identify_shard(local_slice[0])
    local_slice[0] -= offset
    return local_slice, shard_id


def __getitem__(self, slice_):
    """Fetch a sample or a range of samples across the shards.

    Method of ``ShardedDatasetView``.

    Args:
        slice_: An integer, a slice, a subpath string, or a combination of
            them. It is split into a subpath and a slice list by
            ``slice_split``.

    Returns:
        For an integer sample index, whatever the owning shard returns for
        the localised key. For a slice, a ``ComputeList`` holding one entry
        per shard touched by the requested range.

    Raises:
        AdvancedSlicingNotSupported: If the first slice element is neither
            an integer nor a slice, or if the slice has a step other than 1
            (previously the step was silently dropped).
        IndexError: If shard lookup fails to make progress.
    """
    if not isinstance(slice_, Iterable) or isinstance(slice_, str):
        slice_ = [slice_]

    subpath, slice_list = slice_split(slice_)
    slice_list = slice_list or [slice(0, self.num_samples)]
    suffix = [subpath] if subpath else []
    first = slice_list[0]

    if isinstance(first, int):
        local_slice, shard_id = self.slicing(slice_list)
        return self.datasets[shard_id][local_slice + suffix]

    if not isinstance(first, slice) or first.step not in (None, 1):
        raise AdvancedSlicingNotSupported()

    # slice.indices clamps to [0, num_samples] and resolves None and
    # negative bounds, so stop=0 and stop=-k behave like normal slicing.
    cur_index, stop_index, _ = first.indices(self.num_samples)

    results = []
    while cur_index < stop_index:
        shard_id, offset = self.identify_shard(cur_index)
        end_index = min(offset + len(self.datasets[shard_id]), stop_index)
        if end_index <= cur_index:
            # identify_shard falls back to (0, 0) when it finds no shard;
            # bail out instead of looping forever on a failed lookup.
            raise IndexError(
                "could not locate a shard for sample index {}".format(cur_index)
            )
        local_slice = [slice(cur_index - offset, end_index - offset)]
        local_slice += slice_list[1:]
        results.append(self.datasets[shard_id][local_slice + suffix])
        cur_index = end_index
    return ComputeList(results)