Merge pull request #460 from activeloopai/feature/filtering
Dataset filtering support
AbhinavTuli committed Jan 19, 2021
2 parents 4f537c5 + 9bd63d3 commit 66985dd
Showing 15 changed files with 1,022 additions and 812 deletions.
131 changes: 67 additions & 64 deletions hub/api/dataset.py
@@ -23,12 +23,12 @@
from hub.log import logger
import hub.store.pickle_s3_storage

from hub.api.datasetview import DatasetView, ObjectView, TensorView

from hub.api.datasetview import DatasetView
from hub.api.objectview import ObjectView
from hub.api.tensorview import TensorView
from hub.api.dataset_utils import (
create_numpy_dict,
get_value,
slice_extract_info,
slice_split,
str_to_int,
)
@@ -41,8 +41,10 @@
from hub.store.store import get_fs_and_path, get_storage_map
from hub.exceptions import (
HubDatasetNotFoundException,
LargeShapeFilteringException,
NotHubDatasetToOverwriteException,
NotHubDatasetToAppendException,
OutOfBoundsError,
ShapeArgumentNotFoundException,
SchemaArgumentNotFoundException,
ModuleNotInstalledException,
@@ -55,7 +57,7 @@
from hub.schema import Audio, BBox, ClassLabel, Image, Sequence, Text, Video
from hub.numcodecs import PngCodec

from hub.utils import norm_cache, norm_shape
from hub.utils import norm_cache, norm_shape, _tuple_product
from hub import defaults


@@ -195,6 +197,8 @@ def __init__(
logger.error("Deleting the dataset " + traceback.format_exc() + str(e))
raise

self.indexes = list(range(self._shape[0]))

if needcreate and (
self._path.startswith("s3://snark-hub-dev/")
or self._path.startswith("s3://snark-hub/")
@@ -386,12 +390,10 @@ def __getitem__(self, slice_):
raise ValueError(
"Can't slice a dataset with multiple slices without key"
)
num, ofs = slice_extract_info(slice_list[0], self._shape[0])
indexes = self.indexes[slice_list[0]]
return DatasetView(
dataset=self,
num_samples=num,
offset=ofs,
squeeze_dim=isinstance(slice_list[0], int),
indexes=indexes,
lazy=self.lazy,
)
elif not slice_list:
@@ -402,45 +404,35 @@
slice_=slice(0, self._shape[0]),
lazy=self.lazy,
)
if self.lazy:
return tensorview
else:
return tensorview.compute()
return tensorview if self.lazy else tensorview.compute()
for key in self.keys:
if subpath.startswith(key):
objectview = ObjectView(
dataset=self, subpath=subpath, lazy=self.lazy
dataset=self,
subpath=subpath,
lazy=self.lazy,
slice_=[slice(0, self._shape[0])],
)
if self.lazy:
return objectview
else:
return objectview.compute()
return objectview if self.lazy else objectview.compute()
return self._get_dictionary(subpath)
else:
num, ofs = slice_extract_info(slice_list[0], self.shape[0])
schema_obj = self.schema.dict_[subpath.split("/")[1]]
if subpath in self.keys and (
not isinstance(schema_obj, Sequence) or len(slice_list) <= 1
):
tensorview = TensorView(
dataset=self, subpath=subpath, slice_=slice_list, lazy=self.lazy
)
if self.lazy:
return tensorview
else:
return tensorview.compute()
return tensorview if self.lazy else tensorview.compute()
for key in self.keys:
if subpath.startswith(key):
objectview = ObjectView(
dataset=self,
subpath=subpath,
slice_list=slice_list,
slice_=slice_list,
lazy=self.lazy,
)
if self.lazy:
return objectview
else:
return objectview.compute()
return objectview if self.lazy else objectview.compute()
if len(slice_list) > 1:
raise ValueError("You can't slice a dictionary of Tensors")
return self._get_dictionary(subpath, slice_list[0])
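
For context, a brief sketch of how the reworked slicing behaves after this change (hypothetical usage, not part of the diff; `ds` is assumed to be an already-created hub.Dataset with a "label" tensor, in the default lazy mode):

# Hypothetical usage. Slicing now resolves through ds.indexes, and the resulting
# DatasetView carries an explicit list of sample indexes instead of an
# (offset, num_samples) pair.
dv = ds[10:20]            # DatasetView with indexes == list(range(10, 20))
single = ds[3]            # DatasetView holding the single index 3
tv = ds["label", 10:20]   # TensorView over the same slice of one tensor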
@@ -463,26 +455,43 @@ def __setitem__(self, slice_, value):
subpath, slice_list = slice_split(slice_)

if not subpath:
raise ValueError("Can't assign to dataset sliced without key")
elif not slice_list:
if subpath in self.keys:
self._tensors[subpath][:] = assign_value # Add path check
else:
ObjectView(dataset=self, subpath=subpath)[:] = assign_value
raise ValueError("Can't assign to dataset sliced without subpath")
elif subpath not in self.keys:
raise KeyError(f"Key {subpath} not found in the dataset")

if not slice_list:
self._tensors[subpath][:] = assign_value
else:
if subpath in self.keys:
self._tensors[subpath][slice_list] = assign_value
else:
ObjectView(dataset=self, subpath=subpath, slice_list=slice_list)[
:
] = assign_value
self._tensors[subpath][slice_list] = assign_value
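
A hypothetical assignment sketch under the tightened checks above (not part of the diff; unknown keys now raise KeyError instead of falling back to an ObjectView):

ds["label", 0] = 7          # writes one sample of the "label" tensor
ds["missing_key", 0] = 7    # raises KeyError: key not found in the dataset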

def filter(self, dic):
"""| Applies a filter to get a new datasetview that matches the dictionary provided
Parameters
----------
dic: dictionary
A dictionary of key-value pairs used to filter the dataset. For nested schemas, use the flattened dictionary representation,
i.e. instead of {"abc": {"xyz": 5}} use {"abc/xyz": 5}
"""
indexes = self.indexes
for k, v in dic.items():
k = k if k.startswith("/") else "/" + k
if k not in self.keys:
raise KeyError(f"Key {k} not found in the dataset")
tsv = self[k]
max_shape = tsv.dtype.max_shape
prod = _tuple_product(max_shape)
if prod > 100:
raise LargeShapeFilteringException(k)
indexes = [index for index in indexes if tsv[index].compute() == v]
return DatasetView(dataset=self, lazy=self.lazy, indexes=indexes)
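
A minimal usage sketch of the new filter API (hypothetical, not part of the diff; assumes an existing dataset `ds` whose schema has a small "label" tensor — keys whose max_shape product exceeds 100 raise LargeShapeFilteringException):

# Hypothetical usage of the new filter method.
view = ds.filter({"label": 5})      # DatasetView restricted to the matching sample indexes
labels = view["label"].compute()    # every remaining label equals 5
# Nested schemas use the flattened key form, e.g. {"a/b": 1} rather than {"a": {"b": 1}}.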

def resize_shape(self, size: int) -> None:
""" Resize the shape of the dataset by resizing each tensor first dimension """
if size == self._shape[0]:
return

self._shape = (int(size),)
self.indexes = list(range(self.shape[0]))
self.meta = self._store_meta()
for t in self._tensors.values():
t.resize_shape(int(size))
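
A hypothetical one-liner showing the effect of resize_shape (not part of the diff; assumes `ds` was created with shape=(100,)):

ds.resize_shape(200)    # shape becomes (200,) and ds.indexes is rebuilt as list(range(200))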
@@ -518,8 +527,7 @@ def to_pytorch(
transform=None,
inplace=True,
output_type=dict,
offset=None,
num_samples=None,
indexes=None,
):
"""| Converts the dataset into a pytorch compatible format.
@@ -542,18 +550,15 @@
raise ModuleNotInstalledException("torch")

global torch
indexes = indexes or self.indexes

if "r" not in self.mode:
self.flush()  # FIXME Without this, some tests in test_converters.py fail; not clear why
return TorchDataset(
self,
transform,
inplace=inplace,
output_type=output_type,
offset=offset,
num_samples=num_samples,
self, transform, inplace=inplace, output_type=output_type, indexes=indexes
)

def to_tensorflow(self, offset=None, num_samples=None):
def to_tensorflow(self, indexes=None):
"""| Converts the dataset into a tensorflow compatible format
Parameters
@@ -570,11 +575,11 @@ def to_tensorflow(self, offset=None, num_samples=None):

global tf

offset = 0 if offset is None else offset
num_samples = self._shape[0] if num_samples is None else num_samples
indexes = indexes or self.indexes
indexes = [indexes] if isinstance(indexes, int) else indexes

def tf_gen():
for index in range(offset, offset + num_samples):
for index in indexes:
d = {}
for key in self.keys:
split_key = key.split("/")
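
Both converters now take an optional list of sample indexes in place of the old offset/num_samples pair; a hedged sketch of the intended call pattern (hypothetical, not part of the diff; assumes torch and tensorflow are installed):

# Hypothetical usage of the updated converters.
torch_ds = ds.to_pytorch(indexes=list(range(50)))   # iterable over the first 50 samples
tf_ds = ds.to_tensorflow(indexes=[0, 2, 4])         # generator-backed dataset of 3 samples
# Omitting indexes converts the whole dataset, since it defaults to ds.indexes.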
@@ -1144,22 +1149,15 @@ def my_transform(sample):

class TorchDataset:
def __init__(
self,
ds,
transform=None,
inplace=True,
output_type=dict,
num_samples=None,
offset=None,
self, ds, transform=None, inplace=True, output_type=dict, indexes=None
):
self._ds = None
self._url = ds.url
self._token = ds.token
self._transform = transform
self.inplace = inplace
self.output_type = output_type
self.num_samples = num_samples
self.offset = offset
self.indexes = indexes
self._inited = False

def _do_transform(self, data):
@@ -1182,7 +1180,7 @@ def _init_ds(self):

def __len__(self):
self._init_ds()
return self.num_samples if self.num_samples is not None else self._ds.shape[0]
return len(self.indexes) if isinstance(self.indexes, list) else 1

def _get_active_item(self, key, index):
active_range = self._active_chunks_range.get(key)
@@ -1198,8 +1196,13 @@ def _get_active_item(self, key, index):
]
return self._active_chunks[key][index % samples_per_chunk]

def __getitem__(self, index):
index = index + self.offset if self.offset is not None else index
def __getitem__(self, ind):
if isinstance(self.indexes, int):
if ind != 0:
raise OutOfBoundsError(f"Got index {ind} for dataset of length 1")
index = self.indexes
else:
index = self.indexes[ind]
self._init_ds()
d = {}
for key in self._ds._tensors.keys():