Commit ff63116

Merge pull request #1288 from FayazRahman/fy_clear_tensor
tensor.clear() to delete all samples from tensor
FayazRahman authored Mar 28, 2022
2 parents 8c5a710 + 32c4ec2 commit ff63116
Showing 16 changed files with 384 additions and 32 deletions.
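
For orientation before the diff: a minimal sketch of the API this commit adds, assuming an in-memory hub dataset (the mem:// path and tensor name are illustrative):

    import numpy as np
    import hub

    ds = hub.dataset("mem://clear_demo")   # illustrative in-memory dataset
    t = ds.create_tensor("abc")
    t.extend(np.array([1, 2, 3, 4]))
    t.clear()                              # new in this PR: delete every sample
    assert len(t) == 0                     # the tensor itself survives, empty
    assert "abc" in ds.tensors             # tensor meta such as htype is kept

As the tests below show, htype and sample_compression survive a clear, so the tensor can be refilled immediately.
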
28 changes: 28 additions & 0 deletions hub/api/tests/test_api.py
@@ -8,6 +8,7 @@

from hub.tests.common import assert_array_lists_equal
from hub.tests.storage_fixtures import enabled_remote_storages
from hub.tests.dataset_fixtures import enabled_persistent_dataset_generators
from hub.core.storage import GCSProvider
from hub.util.exceptions import (
RenameError,
@@ -1008,6 +1009,33 @@ def test_tobytes(memory_ds, compressed_image_paths, audio_paths):
assert ds.audio[i].tobytes() == audio_bytes


@enabled_persistent_dataset_generators
def test_tensor_clear(ds_generator):
ds = ds_generator()
a = ds.create_tensor("a")
a.extend([1, 2, 3, 4])
a.clear()
assert len(ds) == 0
assert len(a) == 0

image = ds.create_tensor("image", htype="image", sample_compression="png")
image.extend(np.ones((4, 224, 224, 3), dtype="uint8"))
image.extend(np.ones((4, 224, 224, 3), dtype="uint8"))
image.clear()
assert len(ds) == 0
assert len(image) == 0
assert image.htype == "image"
assert image.meta.sample_compression == "png"
image.extend(np.ones((4, 224, 224, 3), dtype="uint8"))
a.append([1, 2, 3])

ds = ds_generator()
assert len(ds) == 1
assert len(image) == 4
assert image.htype == "image"
assert image.meta.sample_compression == "png"


def test_no_view(memory_ds):
memory_ds.create_tensor("a")
memory_ds.a.extend([0, 1, 2, 3])
15 changes: 15 additions & 0 deletions hub/api/tests/test_chunk_sizes.py
@@ -30,6 +30,11 @@ def _extend_tensors(images, labels):
labels.extend(np.ones(100, dtype=np.uint32))


def _clear_tensors(images, labels):
images.clear()
labels.clear()


def test_append(memory_ds):
ds = memory_ds
images, labels = _create_tensors(ds)
@@ -99,3 +104,13 @@ def test_extend_and_append(memory_ds):
_assert_num_chunks(images, 20)

assert len(ds) == 400


def test_clear(memory_ds):
ds = memory_ds
images, labels = _create_tensors(ds)

_clear_tensors(images, labels)

_assert_num_chunks(labels, 0)
_assert_num_chunks(images, 0)
42 changes: 41 additions & 1 deletion hub/core/chunk_engine.py
@@ -41,6 +41,7 @@
get_tensor_commit_chunk_set_key,
get_tensor_meta_key,
get_tensor_tile_encoder_key,
get_tensor_info_key,
)
from hub.util.exceptions import (
CorruptedMetaError,
@@ -656,10 +657,49 @@ def _create_new_chunk(self, register=True) -> BaseChunk:
chunk._update_tensor_meta_length = register
if self.active_appended_chunk is not None:
self.write_chunk_to_storage(self.active_appended_chunk)

self.active_appended_chunk = chunk
return chunk

def clear(self):
"""Clears all samples and cachables."""
self.cache.check_readonly()

commit_id = self.commit_id

chunk_folder_path = get_chunk_key(self.key, "", commit_id)
self.cache.clear(prefix=chunk_folder_path)

enc_key = get_chunk_id_encoder_key(self.key, commit_id)
self._chunk_id_encoder = None
try:
del self.meta_cache[enc_key]
except KeyError:
pass

info_key = get_tensor_info_key(self.key, commit_id)
try:
self._info = None
del self.cache[info_key]
except KeyError:
pass

self.commit_diff.clear_data()

tile_encoder_key = get_tensor_tile_encoder_key(self.key, commit_id)
try:
self._tile_encoder = None
del self.cache[tile_encoder_key]
except KeyError:
pass

self.tensor_meta.length = 0
self.tensor_meta.min_shape = []
self.tensor_meta.max_shape = []
self.tensor_meta.is_dirty = True

self.cache.maybe_flush()
self.meta_cache.maybe_flush()

def _replace_tiled_sample(self, global_sample_index: int, sample):
new_chunks, tiles = self._samples_to_chunks(
[sample], start_chunk=None, register=False
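
Taken together, ChunkEngine.clear removes the tensor's chunk folder by prefix plus its three per-tensor metadata objects (chunk-id encoder, tensor info, tile encoder), then zeroes the tensor meta. A condensed sketch of the keys involved, using the same key helpers the diff imports (the commit_id value is illustrative):

    from hub.util.keys import (
        get_chunk_key,
        get_chunk_id_encoder_key,
        get_tensor_info_key,
        get_tensor_tile_encoder_key,
    )

    commit_id = "some_commit_id"  # illustrative
    # prefix under which every chunk of tensor "abc" lives; cleared by prefix
    chunk_prefix = get_chunk_key("abc", "", commit_id)
    # the three metadata objects deleted individually
    enc_key = get_chunk_id_encoder_key("abc", commit_id)
    info_key = get_tensor_info_key("abc", commit_id)
    tile_key = get_tensor_tile_encoder_key("abc", commit_id)
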
12 changes: 8 additions & 4 deletions hub/core/storage/gcs.py
@@ -271,10 +271,11 @@ def _all_keys(self):
self._blob_objects = self.client_bucket.list_blobs(prefix=self.path)
return {posixpath.relpath(obj.name, self.path) for obj in self._blob_objects}

-    def clear(self):
-        """Remove all keys below root - empties out mapping"""
+    def clear(self, prefix=""):
+        """Remove all keys with given prefix below root - empties out mapping"""
         self.check_readonly()
-        blob_objects = self.client_bucket.list_blobs(prefix=self.path)
+        path = posixpath.join(self.path, prefix) if prefix else self.path
+        blob_objects = self.client_bucket.list_blobs(prefix=path)
         for blob in blob_objects:
             try:
                 blob.delete()
@@ -336,7 +337,10 @@ def __delitem__(self, key):
"""Remove key"""
self.check_readonly()
blob = self.client_bucket.blob(self._get_path_from_key(key))
blob.delete()
try:
blob.delete()
except self.missing_exceptions:
raise KeyError(key)

def __contains__(self, key):
"""Does key exist in mapping?"""
10 changes: 7 additions & 3 deletions hub/core/storage/local.py
@@ -184,11 +184,15 @@ def _check_is_file(self, path: str):
raise DirectoryAtPathException
return full_path

-    def clear(self):
-        """Deletes ALL data on the local machine (under self.root). Exercise caution!"""
+    def clear(self, prefix=""):
+        """Deletes ALL data with keys having given prefix on the local machine (under self.root). Exercise caution!"""
         self.check_readonly()
-        self.files = set()
         full_path = os.path.expanduser(self.root)
+        if prefix and self.files:
+            self.files = set(file for file in self.files if not file.startswith(prefix))
+            full_path = os.path.join(full_path, prefix)
+        else:
+            self.files = set()
         if os.path.exists(full_path):
             shutil.rmtree(full_path)

26 changes: 19 additions & 7 deletions hub/core/storage/lru_cache.py
@@ -246,18 +246,30 @@ def clear_cache_without_flush(self):
if self.next_storage is not None and hasattr(self.next_storage, "clear_cache"):
self.next_storage.clear_cache()

-    def clear(self):
+    def clear(self, prefix=""):
         """Deletes ALL the data from all the layers of the cache and the actual storage.
         This is an IRREVERSIBLE operation. Data once deleted can not be recovered.
         """
         self.check_readonly()
-        self.cache_used = 0
-        self.lru_sizes.clear()
-        self.dirty_keys.clear()
-        self.cache_storage.clear()
-        self.hub_objects.clear()
+        if prefix:
+            rm = [path for path in self.hub_objects if path.startswith(prefix)]
+            for path in rm:
+                self.remove_hub_object(path)
+
+            rm = [path for path in self.lru_sizes if path.startswith(prefix)]
+            for path in rm:
+                size = self.lru_sizes.pop(path)
+                self.cache_used -= size
+                self.dirty_keys.discard(path)
+        else:
+            self.cache_used = 0
+            self.lru_sizes.clear()
+            self.dirty_keys.clear()
+            self.hub_objects.clear()
+
+        self.cache_storage.clear(prefix=prefix)
         if self.next_storage is not None:
-            self.next_storage.clear()
+            self.next_storage.clear(prefix=prefix)

def __len__(self):
"""Returns the number of files present in the cache and the underlying storage.
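
The cache layer threads the prefix through the whole storage stack, so one prefix-scoped clear evicts matching cached entries and deletes the matching keys from the underlying provider. A rough sketch of that cascade, assuming the LRUCache(cache_storage, next_storage, cache_size) constructor and illustrative keys:

    from hub.core.storage.lru_cache import LRUCache
    from hub.core.storage.memory import MemoryProvider

    base = MemoryProvider("mem://base")
    cache = LRUCache(MemoryProvider("mem://cache"), base, 256 * 1024)

    cache["abc/chunks/0"] = b"\x00" * 10
    cache["xyz/tensor_meta.json"] = b"{}"
    cache.flush()

    cache.clear(prefix="abc/")                 # cache layers and base storage
    assert "xyz/tensor_meta.json" in cache     # keys outside the prefix survive
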
7 changes: 5 additions & 2 deletions hub/core/storage/memory.py
@@ -101,10 +101,13 @@ def _all_keys(self):
"""
return set(self.dict.keys())

-    def clear(self):
+    def clear(self, prefix=""):
         """Clears the provider."""
         self.check_readonly()
-        self.dict = {}
+        if prefix:
+            self.dict = {k: v for k, v in self.dict.items() if not k.startswith(prefix)}
+        else:
+            self.dict = {}

def __getstate__(self) -> str:
"""Does NOT save the in memory data in state."""
2 changes: 1 addition & 1 deletion hub/core/storage/provider.py
@@ -165,7 +165,7 @@ def maybe_flush(self):
self.flush()

@abstractmethod
-    def clear(self):
+    def clear(self, prefix=""):
         """Delete the contents of the provider."""

def delete_multiple(self, paths: Sequence[str]):
7 changes: 4 additions & 3 deletions hub/core/storage/s3.py
@@ -320,14 +320,15 @@ def __iter__(self):
self._check_update_creds()
yield from self._all_keys()

-    def clear(self):
-        """Deletes ALL data on the s3 bucket (under self.root). Exercise caution!"""
+    def clear(self, prefix=""):
+        """Deletes ALL data with keys having given prefix on the s3 bucket (under self.root). Exercise caution!"""
         self.check_readonly()
         self._check_update_creds()
+        path = posixpath.join(self.path, prefix) if prefix else self.path
         if self.resource is not None:
             try:
                 bucket = self.resource.Bucket(self.bucket)
-                bucket.objects.filter(Prefix=self.path).delete()
+                bucket.objects.filter(Prefix=path).delete()
except Exception as err:
reload = self.need_to_reload_creds(err)
manager = S3ReloadCredentialsManager if reload else S3ResetClientManager
13 changes: 13 additions & 0 deletions hub/core/tensor.py
@@ -22,6 +22,7 @@
get_tensor_tile_encoder_key,
tensor_exists,
get_tensor_info_key,
get_sample_id_tensor_key,
)
from hub.util.keys import get_tensor_meta_key, tensor_exists, get_tensor_info_key
from hub.util.modified import get_modified_indexes
@@ -311,6 +312,18 @@ def append(
"""
self.extend([sample])

def clear(self):
"""Deletes all samples from the tensor"""
self.chunk_engine.clear()
sample_id_key = get_sample_id_tensor_key(self.key)
try:
sample_id_tensor = Tensor(sample_id_key, self.dataset)
sample_id_tensor.chunk_engine.clear()
self.meta.links.clear()
self.meta.is_dirty = True
except TensorDoesNotExistError:
pass

def modified_samples(
self, target_id: Optional[str] = None, return_indexes: Optional[bool] = False
):
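
Note that Tensor.clear also wipes the hidden sample-id tensor that hub keeps alongside a tensor (when one exists), so sample ids stay in sync with the now-empty samples. get_sample_id_tensor_key, imported above, maps a tensor name to that hidden tensor's key; a sketch, with the exact key format being hub-internal:

    from hub.util.keys import get_sample_id_tensor_key

    hidden_key = get_sample_id_tensor_key("image")
    # clearing ds.image also clears the tensor stored under hidden_key;
    # the TensorDoesNotExistError guard covers tensors without id links
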
6 changes: 6 additions & 0 deletions hub/core/transform/test_transform.py
@@ -352,11 +352,15 @@ def test_add_to_non_empty_dataset(local_ds, scheduler, do_commit):
},
}
if do_commit:
change["image"]["cleared"] = False
change["label"]["cleared"] = False
change["image"]["created"] = False
change["label"]["created"] = False
change["image"]["data_added"] = [10, 610]
change["label"]["data_added"] = [10, 610]
else:
change["image"]["cleared"] = False
change["label"]["cleared"] = False
change["image"]["created"] = True
change["label"]["created"] = True
change["image"]["data_added"] = [0, 610]
@@ -617,13 +621,15 @@ def test_inplace_transform(local_ds_generator):
change = {
"img": {
"created": False,
"cleared": False,
"data_added": [0, 20],
"data_updated": set(),
"data_transformed_in_place": True,
"info_updated": False,
},
"label": {
"created": False,
"cleared": False,
"data_added": [0, 20],
"data_updated": set(),
"data_transformed_in_place": True,
13 changes: 13 additions & 0 deletions hub/core/version_control/commit_diff.py
@@ -11,6 +11,7 @@ def __init__(self, first_index=0, created=False) -> None:
self.data_added: List[int] = [first_index, first_index]
self.data_updated: Set[int] = set()
self.info_updated = False
self.cleared = False

# this is stored for in place transforms, in which we no longer need to consider older diffs about added/updated data
self.data_transformed = False
@@ -25,6 +26,7 @@ def tobytes(self) -> bytes:
4. The next 8 + 8 bytes are the two elements of the data_added list.
5. The next 8 bytes are the number of elements in the data_updated set, let's call this m.
6. The next 8 * m bytes are the elements of the data_updated set.
7. The last byte is a boolean value indicating whether the tensor was cleared in the commit or not.
"""
return b"".join(
[
@@ -35,6 +37,7 @@ def tobytes(self) -> bytes:
self.data_added[1].to_bytes(8, "big"),
len(self.data_updated).to_bytes(8, "big"),
*(idx.to_bytes(8, "big") for idx in self.data_updated),
self.cleared.to_bytes(1, "big"),
]
)

@@ -55,6 +58,8 @@ def frombuffer(cls, data: bytes) -> "CommitDiff":
int.from_bytes(data[27 + i * 8 : 35 + i * 8], "big")
for i in range(num_updates)
}
pos = 35 + (num_updates - 1) * 8
commit_diff.cleared = bool(int.from_bytes(data[pos : pos + 1], "big"))
commit_diff.is_dirty = False
return commit_diff

@@ -84,6 +89,14 @@ def update_data(self, global_index: int) -> None:
self.data_updated.add(global_index)
self.is_dirty = True

def clear_data(self):
"""Clears data"""
self.data_added = [0, 0]
self.data_updated = set()
self.info_updated = False
self.cleared = True
self.is_dirty = True

def transform_data(self) -> None:
"""Stores information that the data has been transformed using an inplace transform."""
self.data_transformed = True
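
A small round-trip sketch of the new cleared flag, using only methods visible in this diff (round-trip behavior inferred from the serialization layout described in tobytes):

    from hub.core.version_control.commit_diff import CommitDiff

    diff = CommitDiff(first_index=0, created=True)
    diff.update_data(2)        # mark sample 2 as updated
    diff.clear_data()          # reset the diff and set the cleared flag

    restored = CommitDiff.frombuffer(diff.tobytes())
    assert restored.cleared
    assert restored.data_added == [0, 0] and restored.data_updated == set()
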
20 changes: 20 additions & 0 deletions hub/core/version_control/test_merge.py
@@ -250,3 +250,23 @@ def test_conflicts(local_ds):

ds.merge("other", conflict_resolution="theirs")
np.testing.assert_array_equal(ds.image[4].numpy(), 25 * np.ones((200, 200, 3)))


def test_clear_merge(local_ds):
with local_ds as ds:
ds.create_tensor("abc")
ds.abc.append([1, 2, 3])
a = ds.commit()

ds.checkout("alt", create=True)
ds.abc.append([2, 3, 4])
b = ds.commit()
ds.abc.clear()
c = ds.commit()

ds.checkout("main")
ds.abc.append([5, 6, 3])
d = ds.commit()
ds.merge("alt")

np.testing.assert_array_equal(ds.abc.numpy(), np.array([]))