Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow non-int sample ids in DatasetDefault #259

Merged
merged 7 commits into from Jan 24, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
35 changes: 22 additions & 13 deletions fuse/data/datasets/dataset_default.py
Expand Up @@ -38,17 +38,21 @@
class DatasetDefault(DatasetBase):
def __init__(
self,
sample_ids: Union[int, Sequence[Hashable]],
sample_ids: Union[int, Sequence[Hashable], None],
static_pipeline: Optional[PipelineDefault] = None,
dynamic_pipeline: Optional[PipelineDefault] = None,
cacher: Optional[SamplesCacher] = None,
allow_uncached_sample_morphing: bool = False,
):
"""
:param sample_ids: list of sample_ids included in dataset.
Optionally, you can provide an integer that describes only the size of the dataset. This is useful in massive datasets
(for example 100M samples). In such case, multiple functionalities will not be supported, mainly -
cacher, allow_uncached_sample_morphing and get_all_sample_ids
:param sample_ids: list of sample_ids included in dataset. Or:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

- An integer that describes only the size of the dataset. This is useful in massive datasets
(for example 100M samples). In such case, multiple functionalities will not be supported, mainly -
cacher, allow_uncached_sample_morphing and get_all_sample_ids
- None. In this case, the dataset will not deal with sample ids. It is the user's responsibility to handle
iterations w.r.t. the length of the dataset, as well as the index passed to __getitem__.
This is useful for massive datasets in which the sample ids are not expected to be running integers from 0 to a given length.

:param static_pipeline: static_pipeline, the output of this pipeline will be automatically cached.
:param dynamic_pipeline: dynamic_pipeline. applied sequentially after the static_pipeline, but not automatically cached.
changing it will NOT trigger recaching of the static_pipeline part.
Expand All @@ -67,9 +71,11 @@ def __init__(
)
if cacher is not None:
raise Exception("providing a cacher is not allowed when providing sample_ids=an integer value")
self._explicit_sample_ids_mode = False
self._sample_ids_mode = "running_int"
elif sample_ids is None:
self._sample_ids_mode = "external"
else:
self._explicit_sample_ids_mode = True
self._sample_ids_mode = "explicit"

# self._orig_sample_ids = sample_ids
self._allow_uncached_sample_morphing = allow_uncached_sample_morphing
Expand Down Expand Up @@ -155,8 +161,8 @@ def get_all_sample_ids(self):
if not self._created:
raise Exception("you must first call create()")

if not self._explicit_sample_ids_mode:
raise Exception("get_all_sample_ids is not supported when constructed with an integer for sample_ids")
if self._sample_ids_mode != "explicit":
raise Exception("get_all_sample_ids is not supported when constructed with non explicit sample_ids")

return copy.deepcopy(self._final_sample_ids)

Expand Down Expand Up @@ -185,10 +191,11 @@ def getitem(
raise Exception("you must first call create()")

# get sample id
if not self._explicit_sample_ids_mode:
if self._sample_ids_mode != "explicit":
sample_id = item
if sample_id >= self._final_sample_ids:
raise IndexError
if self._sample_ids_mode == "running_int": # allow using non int sample_ids
if sample_id >= self._final_sample_ids:
raise IndexError

elif not isinstance(item, (int, np.integer)):
sample_id = item
Expand Down Expand Up @@ -280,8 +287,10 @@ def __len__(self):
if not self._created:
raise Exception("you must first call create()")

if not self._explicit_sample_ids_mode:
if self._sample_ids_mode == "running_int":
return self._final_sample_ids
elif self._sample_ids_mode == "external":
raise Exception("__len__ is not defined where explicit sample_ids or an interer len are not provided.")

return len(self._final_sample_ids)

Expand Down