Skip to content

Commit

Permalink
Merge pull request #2340 from activeloopai/fy_tfm_append_empty
Browse files Browse the repository at this point in the history
[AL-2278] append_empty option in transforms
  • Loading branch information
FayazRahman committed May 11, 2023
2 parents 54a1b8c + c3cf1cb commit 266e1e8
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 10 deletions.
26 changes: 26 additions & 0 deletions deeplake/core/transform/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from deeplake.util.check_installation import ray_installed
from deeplake.util.exceptions import (
AllSamplesSkippedError,
EmptyTensorError,
InvalidOutputDatasetError,
TransformError,
)
Expand Down Expand Up @@ -1549,3 +1550,28 @@ def upload(stuff, ds):
np.testing.assert_array_equal(ds2.xyz.numpy(), ds.xyz.numpy())

ds2.delete()


def test_ds_append_empty(local_ds):
@deeplake.compute
def upload(stuff, ds):
ds.append(stuff, append_empty=True)

with local_ds as ds:
ds.create_tensor("images", htype="image", sample_compression="png")
ds.create_tensor("label1", htype="class_label")
ds.create_tensor("label2", htype="class_label")

samples = [
{"images": np.random.randint(0, 255, (10, 10, 3), dtype=np.uint8), "label1": 1}
for _ in range(20)
]

upload().eval(samples, ds, num_workers=TRANSFORM_TEST_NUM_WORKERS)

with pytest.raises(EmptyTensorError):
ds.label2.numpy()

ds.label2.append(1)

np.testing.assert_array_equal(ds.label2[:20].numpy(), np.array([]).reshape((20, 0)))
20 changes: 15 additions & 5 deletions deeplake/core/transform/transform_dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from deeplake.util.exceptions import SampleAppendError, TensorDoesNotExistError
from deeplake.core.transform.transform_tensor import TransformTensor
from deeplake.core.linked_tiled_sample import LinkedTiledSample
from deeplake.util.exceptions import SampleAppendError
from deeplake.core.partial_sample import PartialSample
from deeplake.core.linked_sample import LinkedSample
from deeplake.core.sample import Sample
Expand Down Expand Up @@ -59,12 +59,22 @@ def __iter__(self):
for i in range(len(self)):
yield self[i]

def append(self, sample):
def append(self, sample, skip_ok=False, append_empty=False):
if skip_ok:
raise ValueError(
"`skip_ok` is not supported for `ds.append` in transforms. Use `skip_ok` parameter of the `eval` method instead."
)

if len(set(map(len, (self[k] for k in sample)))) != 1:
raise ValueError("All tensors are expected to have the same length.")
raise ValueError(
"All tensors are expected to have the same length before `ds.append`."
)

for k, v in sample.items():
self[k].append(v)
for k in self.tensors:
if k in sample:
self[k].append(sample[k])
elif append_empty:
self[k].append(None)

def item_added(self, item):
if isinstance(item, Sample):
Expand Down
5 changes: 2 additions & 3 deletions deeplake/core/transform/transform_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,9 @@ def _verify_item(self, item):

def append(self, item):
"""Adds an item to the tensor."""
if self.is_group:
raise TensorDoesNotExistError(self.name)
try:
if self.is_group:
raise TensorDoesNotExistError(self.name)

# optimization applicable only if extending
self.non_numpy_only()

Expand Down
8 changes: 6 additions & 2 deletions deeplake/util/class_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import List

from deeplake.util.hash import hash_str_to_int32
from deeplake.util.exceptions import EmptyTensorError
from deeplake.client.log import logger
import numpy as np
import deeplake
Expand Down Expand Up @@ -84,8 +85,11 @@ def class_label_sync(
label_tensor: str,
hash_idx_map,
):
hashes = hash_tensor_sample.numpy().tolist()
idxs = convert_hash_to_idx(hashes, hash_idx_map)
try:
hashes = hash_tensor_sample.numpy().tolist()
idxs = convert_hash_to_idx(hashes, hash_idx_map)
except EmptyTensorError:
idxs = None
samples_out[label_tensor].append(idxs)

for tensor, temp_tensor in label_temp_tensors.items():
Expand Down

0 comments on commit 266e1e8

Please sign in to comment.