Skip to content

Commit

Permalink
Merge pull request #2323 from activeloopai/fr_pop_merge
Browse files Browse the repository at this point in the history
[AL-2249] Pop + Merge
  • Loading branch information
FayazRahman committed May 2, 2023
2 parents efee34a + b7fd64c commit 9177d72
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 9 deletions.
3 changes: 2 additions & 1 deletion deeplake/core/chunk_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2112,6 +2112,7 @@ def pop(
self,
global_sample_index: Optional[int] = None,
link_callback: Optional[Callable] = None,
sample_id: Optional[int] = None,
):
if global_sample_index is None:
if self.is_sequence:
Expand All @@ -2134,7 +2135,7 @@ def pop(
if link_callback:
link_callback(global_sample_index)

self.commit_diff.pop(global_sample_index)
self.commit_diff.pop(global_sample_index, sample_id)
if self.is_sequence:
# pop in reverse order else indices get shifted
for idx in reversed(range(*self.sequence_encoder[global_sample_index])):
Expand Down
10 changes: 7 additions & 3 deletions deeplake/core/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1047,11 +1047,15 @@ def _update_links(
@invalid_view_op
def pop(self, index: Optional[int] = None):
"""Removes an element at the given index."""

sample_id_tensor = self._sample_id_tensor
if index is None:
index = self.num_samples - 1
sample_id = int(sample_id_tensor[index].numpy()) if sample_id_tensor else None
self.chunk_engine.pop(
index, link_callback=self._pop_links if self.meta.links else None
index,
link_callback=self._pop_links if self.meta.links else None,
sample_id=sample_id,
)

self.invalidate_libdeeplake_dataset()

def _pop_links(self, global_sample_index: int):
Expand Down
17 changes: 14 additions & 3 deletions deeplake/core/version_control/commit_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ def __init__(self, first_index=0, created=False) -> None:
self.data_added: List[int] = [first_index, first_index]
self.data_updated: Set[int] = set()
self.data_deleted: Set[int] = set()
self.data_deleted_ids: Set[int] = set()
self.info_updated = False
self.cleared = False

Expand All @@ -30,6 +31,7 @@ def tobytes(self) -> bytes:
7. The next byte is a boolean value indicating whether the tensor was cleared in the commit or not.
8. The next 8 bytes are the number of elements in the data_deleted set, let's call this n.
9. The next 8 * n bytes are the elements of the data_deleted set.
9. The next 8 * n bytes are the elements of the data_deleted_ids set.
"""
return b"".join(
[
Expand All @@ -43,6 +45,7 @@ def tobytes(self) -> bytes:
self.cleared.to_bytes(1, "big"),
len(self.data_deleted).to_bytes(8, "big"),
*(idx.to_bytes(8, "big") for idx in self.data_deleted),
*(idx.to_bytes(8, "big") for idx in self.data_deleted_ids),
]
)

Expand All @@ -67,15 +70,21 @@ def frombuffer(cls, data: bytes) -> "CommitDiff":
commit_diff.cleared = bool(int.from_bytes(data[pos : pos + 1], "big"))
commit_diff.is_dirty = False
pos += 1
commit_diff.data_deleted = set()
commit_diff.data_deleted_ids = set()
if len(data) > pos:
num_deletes = int.from_bytes(data[pos : pos + 8], "big")
pos += 8
commit_diff.data_deleted = {
int.from_bytes(data[pos + i * 8 : pos + i * 8 + 8], "big")
for i in range(num_deletes)
}
else:
commit_diff.data_deleted = set()
pos += num_deletes * 8
if len(data) > pos:
commit_diff.data_deleted_ids = {
int.from_bytes(data[pos + i * 8 : pos + i * 8 + 8], "big")
for i in range(num_deletes)
}
return commit_diff

@property
Expand Down Expand Up @@ -119,7 +128,7 @@ def transform_data(self) -> None:
self.data_transformed = True
self.is_dirty = True

def pop(self, index) -> None:
def pop(self, index, id) -> None:
index = self.translate_index(index)
if index not in range(*self.data_added):
self.data_deleted.add(index)
Expand All @@ -132,6 +141,8 @@ def pop(self, index) -> None:
self.data_updated = {
idx - 1 if idx > index else idx for idx in self.data_updated
}
if id is not None:
self.data_deleted_ids.add(id)
self.is_dirty = True

def translate_index(self, index):
Expand Down
16 changes: 16 additions & 0 deletions deeplake/core/version_control/test_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,3 +623,19 @@ def test_merge_with_padding(memory_ds):
ds.checkout("branch1")
ds.merge("branch2")
assert len(ds.x) == 304


def test_merge_with_pop(memory_ds):
with memory_ds as ds:
ds.create_tensor("x")
ds.x.extend([1, 2, 3, 4, 5])
cid = ds.commit()
ds.checkout("branch1", create=True)
ds.pop(2)
ds.checkout(cid)
ds.checkout("branch2", create=True)
ds.pop(3)
ds.x.append(6)
ds.checkout("branch1")
ds.merge("branch2")
np.testing.assert_array_equal(ds.x.numpy().flatten(), [1, 2, 4, 5, 6])
7 changes: 5 additions & 2 deletions deeplake/util/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ def merge_common_tensors(
)
for tensor_name in tensor_names
}

all_new_idxs = set()
for new_idxs, _, _ in idxs.values():
all_new_idxs.update(new_idxs)
Expand Down Expand Up @@ -437,6 +438,9 @@ def find_new_updated_and_conflict_indexes(
target_id_tensor = target_dataset[id_tensor_name]
original_id_tensor = dataset[id_tensor_name]

commit_diff = dataset[tensor_name].chunk_engine.commit_diff
deleted_samples = commit_diff.data_deleted_ids if commit_diff else set()

original_node = nodes["original"]
target_node = nodes["target"]
lca_node = nodes["lca"]
Expand All @@ -456,11 +460,10 @@ def find_new_updated_and_conflict_indexes(
target_id_to_index_map = {id: idx for idx, id in enumerate(target_ids)}

new_elements_ids = set(target_ids) - set(original_ids)

new_elements_ids = new_elements_ids - deleted_samples
new_indexes = get_indexes_from_ids(
new_elements_ids, target_id_changes_commit_map, target_id_to_index_map
)

conflict_indexes: List[Tuple[int, int]] = []
updated_indexes: List[Tuple[int, int]] = []
updated_indexes, conflict_indexes = find_updated_and_conflicts(
Expand Down

0 comments on commit 9177d72

Please sign in to comment.