Skip to content

Commit

Permalink
Merge pull request #539 from activeloopai/feature/filtering_improvements
Browse files Browse the repository at this point in the history
Filtering improvements
  • Loading branch information
AbhinavTuli committed Feb 15, 2021
2 parents 8faed02 + 00c79a2 commit cc3abc2
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 68 deletions.
39 changes: 39 additions & 0 deletions docs/source/concepts/filtering.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Dataset Filtering

Using Hub you can filter your dataset to get a DatasetView that only has the items that you're interested in.
Filtering can be applied either to a Dataset or to a DatasetView (obtained by slicing or filtering a Dataset).

## Filtering using a function
Using filter, you can pass in a function that is applied element by element to the dataset. Only those elements for which the function returns True stay in the newly created DatasetView.

Example:

```python
my_schema = {
"img": Tensor((100, 100)),
"name": Text((None,), max_shape=(10,))
}
ds = hub.Dataset("./data/filtering_example", shape=(20,), schema=my_schema)
for i in range(10): # assigning some values to the dataset
ds["img", i] = np.ones((100, 100))
ds["name", i] = "abc" + str(i) if i % 2 == 0 else "def" + str(i)

def my_filter(sample):
return sample["name"].compute().startswith("abc") and (sample["img"].compute() == np.ones((100, 100))).all()
ds2 = ds.filter(my_filter)

# alternatively, we can also use a lambda function to achieve the same results
ds3 = ds.filter(
lambda x: x["name"].compute().startswith("abc")
and (x["img"].compute() == np.ones((100, 100))).all()
)
```

## API
```eval_rst
.. autofunction:: hub.api.dataset.Dataset.filter
.. autofunction:: hub.api.datasetview.DatasetView.filter
```



1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ Wouldn’t it be more convenient to have large datasets stored & version-control
concepts/features.md
concepts/dataset.md
concepts/transform.md
concepts/filtering.md

.. toctree::
:maxdepth: 3
Expand Down
23 changes: 6 additions & 17 deletions hub/api/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
from hub.store.store import get_fs_and_path, get_storage_map
from hub.exceptions import (
HubDatasetNotFoundException,
LargeShapeFilteringException,
NotHubDatasetToOverwriteException,
NotHubDatasetToAppendException,
OutOfBoundsError,
Expand Down Expand Up @@ -470,26 +469,16 @@ def __setitem__(self, slice_, value):
else:
self._tensors[subpath][slice_list] = assign_value

def filter(self, fn):
    """| Applies a function on each element one by one as a filter to get a new DatasetView

    Parameters
    ----------
    fn: function
        Should take in a single sample of the dataset and return True or False.
        This function is applied to every item of the dataset; only the items
        for which it returns True are retained in the resulting DatasetView.

    Returns
    -------
    DatasetView
        A view over this dataset restricted to the indexes that passed the filter.
    """
    # Keep only the indexes whose sample satisfies the predicate.
    indexes = [index for index in self.indexes if fn(self[index])]
    return DatasetView(dataset=self, lazy=self.lazy, indexes=indexes)

def resize_shape(self, size: int) -> None:
Expand Down
35 changes: 15 additions & 20 deletions hub/api/datasetview.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
slice_split,
str_to_int,
)
from hub.exceptions import LargeShapeFilteringException, NoneValueException
from hub.exceptions import NoneValueException
from hub.api.objectview import ObjectView
from hub.schema import Sequence

Expand Down Expand Up @@ -175,29 +175,24 @@ def __setitem__(self, slice_, value):
current_slice = [index] + slice_list[1:]
self.dataset._tensors[subpath][current_slice] = assign_value[i]

def filter(self, fn):
    """| Applies a function on each element one by one as a filter to get a new DatasetView

    Parameters
    ----------
    fn: function
        Should take in a single sample of the dataset and return True or False.
        This function is applied to every item of the datasetview; only the
        items for which it returns True are retained in the resulting DatasetView.

    Returns
    -------
    DatasetView
        A view restricted to the indexes that passed the filter. For a
        single-element view (int indexes), the int index is kept if the
        sample passes, otherwise the result is an empty view.
    """
    # NOTE: a DatasetView over a single element stores `indexes` as a plain
    # int rather than a list, so that case is handled separately.
    if isinstance(self.indexes, int):
        dsv = self.dataset[self.indexes]
        if fn(dsv):
            return DatasetView(
                dataset=self.dataset, lazy=self.lazy, indexes=self.indexes
            )
        indexes = []
    else:
        indexes = [index for index in self.indexes if fn(self.dataset[index])]
    return DatasetView(dataset=self.dataset, lazy=self.lazy, indexes=indexes)

@property
Expand Down
78 changes: 53 additions & 25 deletions hub/api/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import cloudpickle
import pickle
from hub.cli.auth import login_fn
from hub.exceptions import HubException, LargeShapeFilteringException
import numpy as np
import pytest
from hub import transform
Expand Down Expand Up @@ -748,7 +747,38 @@ def test_dataset_assign_value():
assert ds["text", 6].compute() == "YGFJN75NF"


def test_dataset_filter():
    """Filter a Dataset with a named predicate and check the surviving indexes."""

    def abc_filter(sample):
        # Even indexes are assigned names starting with "abc", odd ones "def".
        return sample["ab"].compute().startswith("abc")

    my_schema = {"img": Tensor((100, 100)), "ab": Text((None,), max_shape=(10,))}
    ds = Dataset("./data/new_filter", shape=(10,), schema=my_schema)
    for i in range(10):
        ds["img", i] = i * np.ones((100, 100))
        ds["ab", i] = "abc" + str(i) if i % 2 == 0 else "def" + str(i)

    ds2 = ds.filter(abc_filter)
    assert ds2.indexes == [0, 2, 4, 6, 8]


def test_datasetview_filter():
    """Filter a sliced DatasetView and a single-element DatasetView."""

    def abc_filter(sample):
        return sample["ab"].compute().startswith("abc")

    my_schema = {"img": Tensor((100, 100)), "ab": Text((None,), max_shape=(10,))}
    # Use a dedicated path so this test does not share on-disk state with
    # test_dataset_filter (which previously used the same "./data/new_filter").
    ds = Dataset("./data/new_filter_dsv", shape=(10,), schema=my_schema)
    for i in range(10):
        ds["img", i] = i * np.ones((100, 100))
        ds["ab", i] = "abc" + str(i) if i % 2 == 0 else "def" + str(i)
    dsv = ds[2:7]
    ds2 = dsv.filter(abc_filter)
    assert ds2.indexes == [2, 4, 6]
    # Filtering a single-element view keeps the plain int index when it passes.
    dsv2 = ds[2]
    ds3 = dsv2.filter(abc_filter)
    assert ds3.indexes == 2


def test_dataset_filter_2():
my_schema = {
"fname": Text((None,), max_shape=(10,)),
"lname": Text((None,), max_shape=(10,)),
Expand All @@ -764,7 +794,9 @@ def test_dataset_filtering():
for i in [15, 31, 25, 75, 3, 6]:
ds["lname", i] = "loop"

dsv_combined = ds.filter({"fname": "Active", "lname": "loop"})
dsv_combined = ds.filter(
lambda x: x["fname"].compute() == "Active" and x["lname"].compute() == "loop"
)
tsv_combined_fname = dsv_combined["fname"]
tsv_combined_lname = dsv_combined["lname"]
for item in dsv_combined:
Expand All @@ -773,8 +805,8 @@ def test_dataset_filtering():
assert item.compute() == "Active"
for item in tsv_combined_lname:
assert item.compute() == "loop"
dsv_1 = ds.filter({"fname": "Active"})
dsv_2 = dsv_1.filter({"lname": "loop"})
dsv_1 = ds.filter(lambda x: x["fname"].compute() == "Active")
dsv_2 = dsv_1.filter(lambda x: x["lname"].compute() == "loop")
for item in dsv_1:
assert item.compute()["fname"] == "Active"
tsv_1 = dsv_1["fname"]
Expand All @@ -789,8 +821,8 @@ def test_dataset_filtering():
assert dsv_1.indexes == [1, 3, 6, 15, 63, 75, 96]
assert dsv_2.indexes == [3, 6, 15, 75]

dsv_3 = ds.filter({"lname": "loop"})
dsv_4 = dsv_3.filter({"fname": "Active"})
dsv_3 = ds.filter(lambda x: x["lname"].compute() == "loop")
dsv_4 = dsv_3.filter(lambda x: x["fname"].compute() == "Active")
for item in dsv_3:
assert item.compute()["lname"] == "loop"
for item in dsv_4:
Expand All @@ -803,52 +835,48 @@ def test_dataset_filtering():
"lname": Text((None,), max_shape=(10,)),
"image": Image((1920, 1080, 3)),
}
ds = Dataset("./data/tests/filtering", shape=(100,), schema=my_schema2, mode="w")
with pytest.raises(LargeShapeFilteringException):
ds.filter({"image": np.ones((1920, 1080, 3))})
ds = Dataset("./data/tests/filtering2", shape=(100,), schema=my_schema2, mode="w")
with pytest.raises(KeyError):
ds.filter({"random": np.ones((1920, 1080, 3))})
ds.filter(lambda x: (x["random"].compute() == np.ones((1920, 1080, 3))).all())

for i in [1, 3, 6, 15, 63, 96, 75]:
ds["fname", i] = "Active"
dsv = ds.filter({"fname": "Active"})
with pytest.raises(LargeShapeFilteringException):
dsv.filter({"image": np.ones((1920, 1080, 3))})
dsv = ds.filter(lambda x: x["fname"].compute() == "Active")
with pytest.raises(KeyError):
dsv.filter({"random": np.ones((1920, 1080, 3))})
dsv.filter(lambda x: (x["random"].compute() == np.ones((1920, 1080, 3))).all())


def test_dataset_filtering_2():
def test_dataset_filter_3():
schema = {
"img": Image((None, None, 3), max_shape=(100, 100, 3)),
"cl": ClassLabel(names=["cat", "dog", "horse"]),
}
ds = Dataset("./data/tests/filtering_2", shape=(100,), schema=schema, mode="w")
ds = Dataset("./data/tests/filtering_3", shape=(100,), schema=schema, mode="w")
for i in range(100):
ds["cl", i] = 0 if i % 5 == 0 else 1
ds["img", i] = i * np.ones((5, 6, 3))
ds["cl", 4] = 2
ds_filtered = ds.filter({"cl": 0})
ds_filtered = ds.filter(lambda x: x["cl"].compute() == 0)
assert ds_filtered.indexes == [5 * i for i in range(20)]
with pytest.raises(ValueError):
ds_filtered["img"].compute()
ds_filtered_2 = ds.filter({"cl": 2})
ds_filtered_2 = ds.filter(lambda x: x["cl"].compute() == 2)
assert (ds_filtered_2["img"].compute() == 4 * np.ones((1, 5, 6, 3))).all()
for item in ds_filtered_2:
assert (item["img"].compute() == 4 * np.ones((5, 6, 3))).all()
assert item["cl"].compute() == 2


def test_dataset_filtering_3():
def test_dataset_filter_4():
schema = {
"img": Image((None, None, 3), max_shape=(100, 100, 3)),
"cl": ClassLabel(names=["cat", "dog", "horse"]),
}
ds = Dataset("./data/tests/filtering_3", shape=(100,), schema=schema, mode="w")
ds = Dataset("./data/tests/filtering_4", shape=(100,), schema=schema, mode="w")
for i in range(100):
ds["cl", i] = 0 if i < 10 else 1
ds["img", i] = i * np.ones((5, 6, 3))
ds_filtered = ds.filter({"cl": 0})
ds_filtered = ds.filter(lambda x: x["cl"].compute() == 0)
assert (ds_filtered[3:8, "cl"].compute() == np.zeros((5,))).all()


Expand Down Expand Up @@ -951,12 +979,12 @@ def test_minio_endpoint():
test_dataset_view_lazy()
test_dataset_hub()
test_meta_information()
test_dataset_filtering()
test_dataset_filtering_2()
test_dataset_filter_2()
test_dataset_filter_3()
test_pickleability()
test_dataset_append_and_read()
test_tensorview_iter()
test_dataset_filtering_3()
test_dataset_filter_4()
test_datasetview_2()
test_dataset_3()
test_dataset_utils()
Expand Down
6 changes: 0 additions & 6 deletions hub/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,12 +170,6 @@ def __init__(self):
super(HubException, self).__init__(message=message)


class LargeShapeFilteringException(HubException):
def __init__(self, key):
message = f"The shape of {key} is large (product > 100), use smaller keys for filtering"
super(HubException, self).__init__(message=message)


class ValueShapeError(HubException):
def __init__(self, correct_shape, wrong_shape):
message = f"parameter 'value': expected array with shape {correct_shape}, got {wrong_shape}"
Expand Down

0 comments on commit cc3abc2

Please sign in to comment.