Skip to content

Commit

Permalink
Added sample_indices api to virtual tensors. (#2654)
Browse files Browse the repository at this point in the history
* Added sample_indices() to virtual tensor and dataset.

* Added test.

* Fixed test for old libdeeplake. Need to revisit after libdeeplake release.

* Added more tests.
  • Loading branch information
khustup committed Oct 13, 2023
1 parent 46a2c0e commit dadbc8c
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 0 deletions.
9 changes: 9 additions & 0 deletions deeplake/core/dataset/deeplake_query_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,15 @@ def no_view_dataset(self):
def index(self):
return self._index

@property
def sample_indices(self):
for t in self.tensors.values():
try:
return t.indra_tensor.indexes
except RuntimeError:
pass
return range(self.num_samples)

def _tensors(
self, include_hidden: bool = True, include_disabled=True
) -> Dict[str, Tensor]:
Expand Down
7 changes: 7 additions & 0 deletions deeplake/core/dataset/deeplake_query_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,13 @@ def min_shape(self):
def chunk_engine(self):
raise NotImplementedError("Virtual tensor does not have chunk engine.")

@property
def sample_indices(self):
try:
return self.indra_tensor.indexes
except RuntimeError:
return range(self.num_samples)

@property
def shape(self):
if (
Expand Down
16 changes: 16 additions & 0 deletions deeplake/core/tests/test_deeplake_indra_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,3 +423,19 @@ def test_virtual_tensors(local_auth_ds_generator):
assert deeplake_indra_ds.score[100 - i].numpy() == [
np.sqrt(2.0 / (i + 1) / (i + 1))
]

assert deeplake_indra_ds.sample_indices == slice(0, 100, 1)
deeplake_indra_ds = deeplake_ds.query(
"SELECT *, l2_norm(embeddings - ARRAY[0, 0, 0]) as score order by l2_norm(embeddings - ARRAY[0, 0, 0]) asc"
)
assert list(deeplake_indra_ds.sample_indices) == list(reversed(range(100)))
assert list(deeplake_indra_ds.embeddings.sample_indices) == list(
reversed(range(100))
)
assert deeplake_indra_ds.score.sample_indices == slice(0, 100, 1)

deeplake_indra_ds = deeplake_ds.query(
"SELECT l2_norm(embeddings - ARRAY[0, 0, 0]) as score order by l2_norm(embeddings - ARRAY[0, 0, 0]) asc"
)
assert deeplake_indra_ds.sample_indices == slice(0, 100, 1)
assert deeplake_indra_ds.score.sample_indices == slice(0, 100, 1)

0 comments on commit dadbc8c

Please sign in to comment.