Skip to content

Commit

Permalink
advanced slicing added
Browse files Browse the repository at this point in the history
  • Loading branch information
AbhinavTuli committed Feb 15, 2021
1 parent f5a3329 commit 1a8bd16
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 24 deletions.
21 changes: 21 additions & 0 deletions hub/api/compute_list.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from hub.api.dataset import Dataset, DatasetView, TensorView
import numpy as np


class ComputeList:
    """Deferred concatenation of the pieces produced by a multi-part slice.

    Each element of ``ls`` is either a lazy hub view (``Dataset``,
    ``DatasetView`` or ``TensorView``) or an already materialised value.
    Doesn't support further get item operations currently.
    """

    def __init__(self, ls):
        """Store the list of pieces; nothing is computed until requested."""
        self.ls = ls

    def compute(self):
        """Materialise every piece and join them into one numpy array.

        Returns:
            The concatenation (along the first axis) of all pieces, or an
            empty array when there are no pieces.
        """
        if not self.ls:
            # np.concatenate raises ValueError for an empty sequence; a
            # selection with no pieces should simply yield no rows.
            return np.array([])
        results = [
            item.compute()
            if isinstance(item, (Dataset, DatasetView, TensorView))
            else item
            for item in self.ls
        ]
        return np.concatenate(results)

    def numpy(self):
        """Alias of :meth:`compute`."""
        return self.compute()

4 changes: 2 additions & 2 deletions hub/api/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,10 +771,10 @@ def numpy(self, label_name=False):
If the TensorView object is of the ClassLabel type, setting this to True would retrieve the label names
instead of the label encoded integers, otherwise this parameter is ignored.
"""
return [
return np.array([
create_numpy_dict(self, i, label_name=label_name)
for i in range(self._shape[0])
]
])

def compute(self, label_name=False):
"""Gets the values from different tensorview objects in the dataset schema
Expand Down
5 changes: 3 additions & 2 deletions hub/api/datasetview.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from hub.exceptions import NoneValueException
from hub.api.objectview import ObjectView
from hub.schema import Sequence
import numpy as np


class DatasetView:
Expand Down Expand Up @@ -306,10 +307,10 @@ def numpy(self, label_name=False):
if isinstance(self.indexes, int):
return create_numpy_dict(self.dataset, self.indexes, label_name=label_name)
else:
return [
return np.array([
create_numpy_dict(self.dataset, index, label_name=label_name)
for index in self.indexes
]
])

def disable_lazy(self):
self.lazy = False
Expand Down
50 changes: 31 additions & 19 deletions hub/api/sharded_datasetview.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
"""

from collections.abc import Iterable

from hub.api.datasetview import DatasetView
import numpy as np
from hub.exceptions import AdvancedSlicingNotSupported

from hub.api.dataset_utils import slice_split
from hub.api.compute_list import ComputeList

class ShardedDatasetView:
def __init__(self, datasets: list) -> None:
Expand Down Expand Up @@ -55,29 +55,41 @@ def identify_shard(self, index) -> tuple:
shard_id += 1
return 0, 0

def slicing(self, slice_list):
    """
    Identifies the dataset shard that should be used

    Args:
        slice_list: the indexing key. Its first element is the global
            sample index; any remaining elements are passed through.
            A bare key or a tuple is accepted as well as a list.

    Returns:
        A ``(slice_list, shard_id)`` tuple, where ``slice_list`` is a new
        list whose first element has been rebased to the selected shard.

    Notes:
        Features of advanced slicing are missing as one would expect from a DatasetView
        E.g. cross sharded dataset access is missing
    """
    # __setitem__ forwards the raw key (an int or a tuple), neither of which
    # supports item assignment, so normalise to a list first. Copying also
    # keeps the caller's own list untouched.
    if not isinstance(slice_list, Iterable) or isinstance(slice_list, str):
        slice_list = [slice_list]
    slice_list = list(slice_list)

    shard_id, offset = self.identify_shard(slice_list[0])
    slice_list[0] = slice_list[0] - offset
    return slice_list, shard_id

def __getitem__(self, slice_):
    """Read one sample or a contiguous range of samples across shards.

    Args:
        slice_: an int, a slice, a string subpath, or an iterable mixing
            them; it is split by ``slice_split`` into a subpath and a list
            of index components.

    Returns:
        For an integer first index, whatever the owning shard returns for
        the rebased key. Otherwise a ``ComputeList`` holding one piece per
        shard touched by the requested range.
    """
    if not isinstance(slice_, Iterable) or isinstance(slice_, str):
        slice_ = [slice_]

    slice_ = list(slice_)
    subpath, slice_list = slice_split(slice_)
    # No index component means "every sample".
    slice_list = slice_list or [slice(0, self.num_samples)]

    if isinstance(slice_list[0], int):
        slice_list, shard_id = self.slicing(slice_list)
        slice_ = slice_list + [subpath] if subpath else slice_list
        return self.datasets[shard_id][slice_]

    first = slice_list[0]
    # slice.indices resolves None, negative and out-of-range bounds against
    # the total length. The previous `stop or num_samples` treated stop == 0
    # as "no stop" and never translated negative stops. The step is not
    # applied, as before.
    cur_index, stop_index, _ = slice(first.start, first.stop).indices(
        self.num_samples
    )

    results = []
    while cur_index < stop_index:
        shard_id, offset = self.identify_shard(cur_index)
        # Never read past the end of the current shard or the requested stop.
        end_index = min(offset + len(self.datasets[shard_id]), stop_index)
        cur_slice_list = [
            slice(cur_index - offset, end_index - offset)
        ] + slice_list[1:]
        current_slice = cur_slice_list + [subpath] if subpath else cur_slice_list
        results.append(self.datasets[shard_id][current_slice])
        cur_index = end_index
    return ComputeList(results)


shard_id, offset = self.identify_shard(slice_[0])
slice_[0] = slice_[0] - offset

return slice_, shard_id

def __getitem__(self, slice_) -> DatasetView:
slice_, shard_id = self.slicing(slice_)
return self.datasets[shard_id][slice_]

def __setitem__(self, slice_, value) -> None:
slice_, shard_id = self.slicing(slice_)
Expand Down
54 changes: 53 additions & 1 deletion hub/api/tests/test_sharded_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,5 +62,57 @@ def test_sharded_dataset_with_views():
assert sharded_ds[i, "second"].compute() == 2 * (i - 5) + 1


def test_sharded_dataset_advanced_slice():
    """Check slice-based reads that span several shards of a sharded view."""
    ds = Dataset(
        "./data/test_sharded_ds",
        shape=(10,),
        schema={"first": "float", "second": "float"},
        mode="w",
    )
    for row in range(10):
        ds[row, "first"] = row
        ds[row, "second"] = 2 * row + 1

    # Shards, in order: ds[3:5], the full dataset, ds[1] and ds[8:].
    middle_rows = ds[3:5]
    single_row = ds[1]
    tail_rows = ds[8:]
    sharded_ds = ShardedDatasetView([middle_rows, ds, single_row, tail_rows])

    all_first = [3, 4, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 8, 9]
    assert sharded_ds["first", :].compute().tolist() == all_first
    assert sharded_ds["first"].compute().tolist() == all_first
    assert sharded_ds["first", -4:].compute().tolist() == [9, 1, 8, 9]

    expected_rows = (
        {"first": 4.0, "second": 9.0},
        {"first": 0.0, "second": 1.0},
    )
    for position, expected in enumerate(expected_rows):
        assert sharded_ds[1:3].compute()[position] == expected


if __name__ == "__main__":
    # Run both tests when executed as a script; the basic test was
    # previously disabled here by commenting out its call.
    test_sharded_dataset()
    test_sharded_dataset_advanced_slice()

0 comments on commit 1a8bd16

Please sign in to comment.