From 1a8bd167fe0103e3c6954abe89066d295b7e1e71 Mon Sep 17 00:00:00 2001 From: AbhinavTuli Date: Mon, 15 Feb 2021 21:42:10 +0530 Subject: [PATCH 1/6] advanced slicing added --- hub/api/compute_list.py | 21 +++++++++++ hub/api/dataset.py | 4 +- hub/api/datasetview.py | 5 ++- hub/api/sharded_datasetview.py | 50 +++++++++++++++---------- hub/api/tests/test_sharded_dataset.py | 54 ++++++++++++++++++++++++++- 5 files changed, 110 insertions(+), 24 deletions(-) create mode 100644 hub/api/compute_list.py diff --git a/hub/api/compute_list.py b/hub/api/compute_list.py new file mode 100644 index 0000000000..546f952825 --- /dev/null +++ b/hub/api/compute_list.py @@ -0,0 +1,21 @@ +from hub.api.dataset import Dataset, DatasetView, TensorView +import numpy as np + + +class ComputeList: + # Doesn't support further get item operations currently + def __init__(self, ls): + self.ls = ls + + def compute(self): + results = [ + item.compute() + if isinstance(item, (Dataset, DatasetView, TensorView)) + else item + for item in self.ls + ] + return np.concatenate(results) + + def numpy(self): + return self.compute() + diff --git a/hub/api/dataset.py b/hub/api/dataset.py index 543ab538a1..20571c9e9a 100644 --- a/hub/api/dataset.py +++ b/hub/api/dataset.py @@ -771,10 +771,10 @@ def numpy(self, label_name=False): If the TensorView object is of the ClassLabel type, setting this to True would retrieve the label names instead of the label encoded integers, otherwise this parameter is ignored. """ - return [ + return np.array([ create_numpy_dict(self, i, label_name=label_name) for i in range(self._shape[0]) - ] + ]) def compute(self, label_name=False): """Gets the values from different tensorview objects in the dataset schema diff --git a/hub/api/datasetview.py b/hub/api/datasetview.py index 41277a3f42..f728723764 100644 --- a/hub/api/datasetview.py +++ b/hub/api/datasetview.py @@ -16,6 +16,7 @@ from hub.exceptions import NoneValueException from hub.api.objectview import ObjectView from hub.schema import Sequence +import numpy as np class DatasetView: @@ -306,10 +307,10 @@ def numpy(self, label_name=False): if isinstance(self.indexes, int): return create_numpy_dict(self.dataset, self.indexes, label_name=label_name) else: - return [ + return np.array([ create_numpy_dict(self.dataset, index, label_name=label_name) for index in self.indexes - ] + ]) def disable_lazy(self): self.lazy = False diff --git a/hub/api/sharded_datasetview.py b/hub/api/sharded_datasetview.py index 4887176224..8b0aeb1304 100644 --- a/hub/api/sharded_datasetview.py +++ b/hub/api/sharded_datasetview.py @@ -5,10 +5,10 @@ """ from collections.abc import Iterable - -from hub.api.datasetview import DatasetView +import numpy as np from hub.exceptions import AdvancedSlicingNotSupported - +from hub.api.dataset_utils import slice_split +from hub.api.compute_list import ComputeList class ShardedDatasetView: def __init__(self, datasets: list) -> None: @@ -55,29 +55,41 @@ def identify_shard(self, index) -> tuple: shard_id += 1 return 0, 0 - def slicing(self, slice_): + def slicing(self, slice_list): """ Identifies the dataset shard that should be used - Notes: - Features of advanced slicing are missing as one would expect from a DatasetView - E.g. cross sharded dataset access is missing """ + shard_id, offset = self.identify_shard(slice_list[0]) + slice_list[0] = slice_list[0] - offset + return slice_list, shard_id + + def __getitem__(self, slice_): if not isinstance(slice_, Iterable) or isinstance(slice_, str): slice_ = [slice_] - slice_ = list(slice_) - if not isinstance(slice_[0], int): - # TODO add advanced slicing options - raise AdvancedSlicingNotSupported() + subpath, slice_list = slice_split(slice_) + slice_list = slice_list or [slice(0, self.num_samples)] + if isinstance(slice_list[0], int): + slice_list, shard_id = self.slicing(slice_list) + slice_ = slice_list + [subpath] if subpath else slice_list + return self.datasets[shard_id][slice_] + else: + results = [] + cur_index = slice_list[0].start or 0 + cur_index = cur_index + self.num_samples if cur_index < 0 else cur_index + cur_index = max(cur_index, 0) + stop_index = slice_list[0].stop or self.num_samples + stop_index = min(stop_index, self.num_samples) + while cur_index < stop_index: + shard_id, offset = self.identify_shard(cur_index) + end_index = min(offset + len(self.datasets[shard_id]), stop_index) + cur_slice_list = [slice(cur_index - offset, end_index - offset)] + slice_list[1:] + current_slice = cur_slice_list + [subpath] if subpath else cur_slice_list + results.append(self.datasets[shard_id][current_slice]) + cur_index = end_index + return ComputeList(results) + - shard_id, offset = self.identify_shard(slice_[0]) - slice_[0] = slice_[0] - offset - - return slice_, shard_id - - def __getitem__(self, slice_) -> DatasetView: - slice_, shard_id = self.slicing(slice_) - return self.datasets[shard_id][slice_] def __setitem__(self, slice_, value) -> None: slice_, shard_id = self.slicing(slice_) diff --git a/hub/api/tests/test_sharded_dataset.py b/hub/api/tests/test_sharded_dataset.py index bb718e7c7b..f4e88528a2 100644 --- a/hub/api/tests/test_sharded_dataset.py +++ b/hub/api/tests/test_sharded_dataset.py @@ -62,5 +62,57 @@ def test_sharded_dataset_with_views(): assert sharded_ds[i, "second"].compute() == 2 * (i - 5) + 1 +def test_sharded_dataset_advanced_slice(): + schema = {"first": "float", "second": "float"} + ds = Dataset("./data/test_sharded_ds", shape=(10,), schema=schema, mode="w") + for i in range(10): + ds[i, "first"] = i + ds[i, "second"] = 2 * i + 1 + + dsv = ds[3:5] + dsv2 = ds[1] + dsv3 = ds[8:] + datasets = [dsv, ds, dsv2, dsv3] + sharded_ds = ShardedDatasetView(datasets) + assert sharded_ds["first", :].compute().tolist() == [ + 3, + 4, + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 1, + 8, + 9, + ] + assert sharded_ds["first"].compute().tolist() == [ + 3, + 4, + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 1, + 8, + 9, + ] + assert sharded_ds["first", -4:].compute().tolist() == [9, 1, 8, 9] + assert sharded_ds[1:3].compute()[0] == {"first": 4.0, "second": 9.0} + assert sharded_ds[1:3].compute()[1] == {"first": 0.0, "second": 1.0} + + if __name__ == "__main__": - test_sharded_dataset() + # test_sharded_dataset() + test_sharded_dataset_advanced_slice() From 6fed71a2592a1eb84824db7a0d643099dc2970d9 Mon Sep 17 00:00:00 2001 From: AbhinavTuli Date: Mon, 15 Feb 2021 21:55:10 +0530 Subject: [PATCH 2/6] fixing tests --- hub/api/tests/test_dataset.py | 2 +- hub/api/tests/test_sharded_dataset.py | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/hub/api/tests/test_dataset.py b/hub/api/tests/test_dataset.py index af3ae2c0fb..3b24dc8eaa 100644 --- a/hub/api/tests/test_dataset.py +++ b/hub/api/tests/test_dataset.py @@ -1032,7 +1032,7 @@ def test_check_label_name(): ds["label", 0] = 1 ds["label", 1] = 2 ds["label", 2] = 0 - assert ds.compute(label_name=True) == [ + assert ds.compute(label_name=True).tolist() == [ {"label": "green"}, {"label": "blue"}, {"label": "red"}, diff --git a/hub/api/tests/test_sharded_dataset.py b/hub/api/tests/test_sharded_dataset.py index f4e88528a2..de300c9197 100644 --- a/hub/api/tests/test_sharded_dataset.py +++ b/hub/api/tests/test_sharded_dataset.py @@ -5,7 +5,6 @@ """ from hub.schema.features import SchemaDict -from hub.exceptions import AdvancedSlicingNotSupported from hub.api.sharded_datasetview import ShardedDatasetView from hub import Dataset import pytest @@ -27,8 +26,6 @@ def test_sharded_dataset(): assert ds.shape == (40,) assert type(ds.schema) == SchemaDict assert ds.__repr__() == "ShardedDatasetView(shape=(40,))" - with pytest.raises(AdvancedSlicingNotSupported): - ds[5:8] ds[4, "first"] = 3 for _ in ds: pass From a0c399d8f2f777986e2b2f1b6151bb98dde27eb7 Mon Sep 17 00:00:00 2001 From: AbhinavTuli Date: Mon, 15 Feb 2021 22:31:28 +0530 Subject: [PATCH 3/6] setitem added for advanced slicing --- hub/api/sharded_datasetview.py | 44 ++++++++++++++++++++++++++++------ 1 file changed, 37 insertions(+), 7 deletions(-) diff --git a/hub/api/sharded_datasetview.py b/hub/api/sharded_datasetview.py index 8b0aeb1304..ff02da1ad3 100644 --- a/hub/api/sharded_datasetview.py +++ b/hub/api/sharded_datasetview.py @@ -10,6 +10,7 @@ from hub.api.dataset_utils import slice_split from hub.api.compute_list import ComputeList + class ShardedDatasetView: def __init__(self, datasets: list) -> None: """ @@ -79,21 +80,50 @@ def __getitem__(self, slice_): cur_index = cur_index + self.num_samples if cur_index < 0 else cur_index cur_index = max(cur_index, 0) stop_index = slice_list[0].stop or self.num_samples - stop_index = min(stop_index, self.num_samples) + stop_index = min(stop_index, self.num_samples) while cur_index < stop_index: shard_id, offset = self.identify_shard(cur_index) end_index = min(offset + len(self.datasets[shard_id]), stop_index) - cur_slice_list = [slice(cur_index - offset, end_index - offset)] + slice_list[1:] - current_slice = cur_slice_list + [subpath] if subpath else cur_slice_list + cur_slice_list = [ + slice(cur_index - offset, end_index - offset) + ] + slice_list[1:] + current_slice = ( + cur_slice_list + [subpath] if subpath else cur_slice_list + ) results.append(self.datasets[shard_id][current_slice]) cur_index = end_index return ComputeList(results) - - def __setitem__(self, slice_, value) -> None: - slice_, shard_id = self.slicing(slice_) - self.datasets[shard_id][slice_] = value + if not isinstance(slice_, Iterable) or isinstance(slice_, str): + slice_ = [slice_] + slice_ = list(slice_) + subpath, slice_list = slice_split(slice_) + slice_list = slice_list or [slice(0, self.num_samples)] + if isinstance(slice_list[0], int): + slice_list, shard_id = self.slicing(slice_list) + slice_ = slice_list + [subpath] if subpath else slice_list + self.datasets[shard_id][slice_] = value + else: + cur_index = slice_list[0].start or 0 + cur_index = cur_index + self.num_samples if cur_index < 0 else cur_index + cur_index = max(cur_index, 0) + start_index = cur_index + stop_index = slice_list[0].stop or self.num_samples + stop_index = min(stop_index, self.num_samples) + while cur_index < stop_index: + shard_id, offset = self.identify_shard(cur_index) + end_index = min(offset + len(self.datasets[shard_id]), stop_index) + cur_slice_list = [ + slice(cur_index - offset, end_index - offset) + ] + slice_list[1:] + current_slice = ( + cur_slice_list + [subpath] if subpath else cur_slice_list + ) + self.datasets[shard_id][current_slice] = value[ + cur_index - start_index : end_index - start_index + ] + cur_index = end_index def __iter__(self): """ Returns Iterable over samples """ From 4ecd521fb89743e069d54ed7c7702718018f8f88 Mon Sep 17 00:00:00 2001 From: AbhinavTuli Date: Mon, 15 Feb 2021 23:04:51 +0530 Subject: [PATCH 4/6] fixes test and linting --- hub/api/compute_list.py | 1 - hub/api/dataset.py | 10 ++++++---- hub/api/datasetview.py | 10 ++++++---- hub/api/tests/test_dataset.py | 9 ++++++--- hub/api/tests/test_sharded_dataset.py | 4 ++++ 5 files changed, 22 insertions(+), 12 deletions(-) diff --git a/hub/api/compute_list.py b/hub/api/compute_list.py index 546f952825..914aff9337 100644 --- a/hub/api/compute_list.py +++ b/hub/api/compute_list.py @@ -18,4 +18,3 @@ def compute(self): def numpy(self): return self.compute() - diff --git a/hub/api/dataset.py b/hub/api/dataset.py index 20571c9e9a..7e9daa6ed7 100644 --- a/hub/api/dataset.py +++ b/hub/api/dataset.py @@ -771,10 +771,12 @@ def numpy(self, label_name=False): If the TensorView object is of the ClassLabel type, setting this to True would retrieve the label names instead of the label encoded integers, otherwise this parameter is ignored. """ - return np.array([ - create_numpy_dict(self, i, label_name=label_name) - for i in range(self._shape[0]) - ]) + return np.array( + [ + create_numpy_dict(self, i, label_name=label_name) + for i in range(self._shape[0]) + ] + ) def compute(self, label_name=False): """Gets the values from different tensorview objects in the dataset schema diff --git a/hub/api/datasetview.py b/hub/api/datasetview.py index f728723764..17418d5043 100644 --- a/hub/api/datasetview.py +++ b/hub/api/datasetview.py @@ -307,10 +307,12 @@ def numpy(self, label_name=False): if isinstance(self.indexes, int): return create_numpy_dict(self.dataset, self.indexes, label_name=label_name) else: - return np.array([ - create_numpy_dict(self.dataset, index, label_name=label_name) - for index in self.indexes - ]) + return np.array( + [ + create_numpy_dict(self.dataset, index, label_name=label_name) + for index in self.indexes + ] + ) def disable_lazy(self): self.lazy = False diff --git a/hub/api/tests/test_dataset.py b/hub/api/tests/test_dataset.py index 3b24dc8eaa..bc8309771b 100644 --- a/hub/api/tests/test_dataset.py +++ b/hub/api/tests/test_dataset.py @@ -1039,7 +1039,7 @@ def test_check_label_name(): {"label": "red"}, {"label": "red"}, ] - assert ds.compute() == [ + assert ds.compute().tolist() == [ {"label": 1}, {"label": 2}, {"label": 0}, @@ -1048,8 +1048,11 @@ def test_check_label_name(): ] assert ds[1].compute(label_name=True) == {"label": "blue"} assert ds[1].compute() == {"label": 2} - assert ds[1:3].compute(label_name=True) == [{"label": "blue"}, {"label": "red"}] - assert ds[1:3].compute() == [{"label": 2}, {"label": 0}] + assert ds[1:3].compute(label_name=True).tolist() == [ + {"label": "blue"}, + {"label": "red"}, + ] + assert ds[1:3].compute().tolist() == [{"label": 2}, {"label": 0}] @pytest.mark.skipif(not minio_creds_exist(), reason="requires minio credentials") diff --git a/hub/api/tests/test_sharded_dataset.py b/hub/api/tests/test_sharded_dataset.py index de300c9197..36ab22c2f0 100644 --- a/hub/api/tests/test_sharded_dataset.py +++ b/hub/api/tests/test_sharded_dataset.py @@ -108,6 +108,10 @@ def test_sharded_dataset_advanced_slice(): assert sharded_ds["first", -4:].compute().tolist() == [9, 1, 8, 9] assert sharded_ds[1:3].compute()[0] == {"first": 4.0, "second": 9.0} assert sharded_ds[1:3].compute()[1] == {"first": 0.0, "second": 1.0} + sharded_ds["first", 1:5] = [10, 11, 12, 13] + assert sharded_ds["first", 1:5].compute().tolist() == [10, 11, 12, 13] + sharded_ds["first", 12] = 50 + assert sharded_ds["first", 12].compute() == 50 if __name__ == "__main__": From f8253edd25357dbe211afaf28bf3719b6254b83c Mon Sep 17 00:00:00 2001 From: AbhinavTuli Date: Fri, 19 Feb 2021 19:05:16 +0530 Subject: [PATCH 5/6] comments added --- hub/api/compute_list.py | 1 + hub/api/sharded_datasetview.py | 7 +++++-- hub/exceptions.py | 6 ------ 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/hub/api/compute_list.py b/hub/api/compute_list.py index 914aff9337..6b1f3a7d80 100644 --- a/hub/api/compute_list.py +++ b/hub/api/compute_list.py @@ -2,6 +2,7 @@ import numpy as np +# a list of Datasets or DatasetViews or Tensorviews that supports compute operation class ComputeList: # Doesn't support further get item operations currently def __init__(self, ls): diff --git a/hub/api/sharded_datasetview.py b/hub/api/sharded_datasetview.py index ff02da1ad3..189ee90346 100644 --- a/hub/api/sharded_datasetview.py +++ b/hub/api/sharded_datasetview.py @@ -5,8 +5,6 @@ """ from collections.abc import Iterable -import numpy as np -from hub.exceptions import AdvancedSlicingNotSupported from hub.api.dataset_utils import slice_split from hub.api.compute_list import ComputeList @@ -71,10 +69,13 @@ def __getitem__(self, slice_): subpath, slice_list = slice_split(slice_) slice_list = slice_list or [slice(0, self.num_samples)] if isinstance(slice_list[0], int): + # if integer it fetches the data from the corresponding dataset slice_list, shard_id = self.slicing(slice_list) slice_ = slice_list + [subpath] if subpath else slice_list return self.datasets[shard_id][slice_] else: + # if slice it finds all the corresponding datasets included in the slice and generates tensorviews or datasetviews (depending on slice) + # these views are stored in a ComputeList, calling compute on which will fetch data from all corresponding datasets and return a single result results = [] cur_index = slice_list[0].start or 0 cur_index = cur_index + self.num_samples if cur_index < 0 else cur_index @@ -101,10 +102,12 @@ def __setitem__(self, slice_, value) -> None: subpath, slice_list = slice_split(slice_) slice_list = slice_list or [slice(0, self.num_samples)] if isinstance(slice_list[0], int): + # if integer it assigns the data to the corresponding dataset slice_list, shard_id = self.slicing(slice_list) slice_ = slice_list + [subpath] if subpath else slice_list self.datasets[shard_id][slice_] = value else: + # if slice it finds all the corresponding datasets and assigns slices of the value one by one cur_index = slice_list[0].start or 0 cur_index = cur_index + self.num_samples if cur_index < 0 else cur_index cur_index = max(cur_index, 0) diff --git a/hub/exceptions.py b/hub/exceptions.py index d93750bc2c..4f823f7f54 100644 --- a/hub/exceptions.py +++ b/hub/exceptions.py @@ -261,12 +261,6 @@ def __init__(self): super(HubException, self).__init__(message=message) -class AdvancedSlicingNotSupported(HubException): - def __init__(self): - message = "Advanced slicing is not supported, only support index" - super(HubException, self).__init__(message=message) - - class NotZarrFolderException(Exception): pass From 9cc562f2151c965b4cf3990de2bc70580554f412 Mon Sep 17 00:00:00 2001 From: AbhinavTuli Date: Fri, 19 Feb 2021 19:08:21 +0530 Subject: [PATCH 6/6] fixed test --- hub/tests/test_exceptions.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/hub/tests/test_exceptions.py b/hub/tests/test_exceptions.py index 2f2e67ca69..2681cb8d38 100644 --- a/hub/tests/test_exceptions.py +++ b/hub/tests/test_exceptions.py @@ -5,7 +5,6 @@ """ from hub.exceptions import ( - AdvancedSlicingNotSupported, DaskModuleNotInstalledException, HubException, AuthenticationException, @@ -70,10 +69,9 @@ def test_exceptions(): NotHubDatasetToAppendException() DynamicTensorNotFoundException() NotIterable() - AdvancedSlicingNotSupported() DaskModuleNotInstalledException() - DynamicTensorShapeException("none") DynamicTensorShapeException("length") DynamicTensorShapeException("not_equal") DynamicTensorShapeException("another_cause") + NotFound()