From d3b9ebaf7a39b01984db84d0231a3d73e65c898d Mon Sep 17 00:00:00 2001 From: Dmitriy Apollonin Date: Fri, 2 Dec 2022 14:58:53 -0700 Subject: [PATCH 1/4] add get_one and get_many methods to paginated collection --- labelbox/pagination.py | 38 ++++++++++++++++++++++++++----- tests/integration/test_dataset.py | 8 +++++++ 2 files changed, 40 insertions(+), 6 deletions(-) diff --git a/labelbox/pagination.py b/labelbox/pagination.py index 8d46901e5..8034eb2da 100644 --- a/labelbox/pagination.py +++ b/labelbox/pagination.py @@ -45,7 +45,7 @@ def __init__(self, """ self._fetched_all = False self._data: List[Dict[str, Any]] = [] - self._data_ind = 0 + self._data_index = 0 pagination_kwargs = { 'client': client, @@ -62,11 +62,11 @@ def __init__(self, **pagination_kwargs) def __iter__(self): - self._data_ind = 0 + self._data_index = 0 return self def __next__(self): - if len(self._data) <= self._data_ind: + if len(self._data) <= self._data_index: if self._fetched_all: raise StopIteration() @@ -75,9 +75,35 @@ def __next__(self): if len(page_data) == 0: raise StopIteration() - rval = self._data[self._data_ind] - self._data_ind += 1 - return rval + next_value = self._data[self._data_index] + self._data_index += 1 + return next_value + + def get_one(self): + """Iterates over self and returns first value + This method is idempotent + """ + for value in self: + return value + + def get_many(self, n: int): + """Iterates over self and returns first n results + This method is idempotent + + Args: + n (int): Number of elements to retrieve + """ + results = [] + i = 0 + + for value in self: + if i >= n: + break + + results.append(value) + i += 1 + + return results class _Pagination(ABC): diff --git a/tests/integration/test_dataset.py b/tests/integration/test_dataset.py index 89a89b78c..e9094ae98 100644 --- a/tests/integration/test_dataset.py +++ b/tests/integration/test_dataset.py @@ -21,6 +21,14 @@ def test_dataset(client, rand_gen): assert len(after) == len(before) + 1 assert dataset in after + # confirm get_one returns first dataset + get_one_dataset = client.get_datasets().get_one() + assert get_one_dataset.uid == after[0].uid + + # confirm get_many(1) returns first dataset + get_many_datasets = client.get_datasets().get_many(1) + assert get_many_datasets[0].uid == after[0].uid + dataset = client.get_dataset(dataset.uid) assert dataset.name == name From d4c52b65198d451770ba3fda8efcc9e2d865bfac Mon Sep 17 00:00:00 2001 From: Dmitriy Apollonin Date: Fri, 2 Dec 2022 15:07:10 -0700 Subject: [PATCH 2/4] fix lint --- labelbox/pagination.py | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/labelbox/pagination.py b/labelbox/pagination.py index 8034eb2da..f4abb3720 100644 --- a/labelbox/pagination.py +++ b/labelbox/pagination.py @@ -80,30 +80,30 @@ def __next__(self): return next_value def get_one(self): - """Iterates over self and returns first value - This method is idempotent - """ - for value in self: - return value + """Iterates over self and returns first value + This method is idempotent + """ + for value in self: + return value def get_many(self, n: int): - """Iterates over self and returns first n results - This method is idempotent + """Iterates over self and returns first n results + This method is idempotent - Args: - n (int): Number of elements to retrieve - """ - results = [] - i = 0 + Args: + n (int): Number of elements to retrieve + """ + results = [] + i = 0 - for value in self: - if i >= n: - break + for value in self: + if i >= n: + break - results.append(value) - i += 1 + results.append(value) + i += 1 - return results + return results class _Pagination(ABC): From 2439e41fca351d8dbbdc183f0e02fd0e28a121a0 Mon Sep 17 00:00:00 2001 From: Dmitriy Apollonin Date: Fri, 2 Dec 2022 15:07:37 -0700 Subject: [PATCH 3/4] fix lint --- labelbox/pagination.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/labelbox/pagination.py b/labelbox/pagination.py index f4abb3720..3646b7a89 100644 --- a/labelbox/pagination.py +++ b/labelbox/pagination.py @@ -98,7 +98,7 @@ def get_many(self, n: int): for value in self: if i >= n: - break + break results.append(value) i += 1 From 34c864a448fda3b2f91c4f0b1fa571440de8e16d Mon Sep 17 00:00:00 2001 From: Dmitriy Apollonin Date: Fri, 2 Dec 2022 15:13:38 -0700 Subject: [PATCH 4/4] rename variables back --- labelbox/pagination.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/labelbox/pagination.py b/labelbox/pagination.py index 3646b7a89..9f58155ec 100644 --- a/labelbox/pagination.py +++ b/labelbox/pagination.py @@ -45,7 +45,7 @@ def __init__(self, """ self._fetched_all = False self._data: List[Dict[str, Any]] = [] - self._data_index = 0 + self._data_ind = 0 pagination_kwargs = { 'client': client, @@ -62,11 +62,11 @@ def __init__(self, **pagination_kwargs) def __iter__(self): - self._data_index = 0 + self._data_ind = 0 return self def __next__(self): - if len(self._data) <= self._data_index: + if len(self._data) <= self._data_ind: if self._fetched_all: raise StopIteration() @@ -75,9 +75,9 @@ def __next__(self): if len(page_data) == 0: raise StopIteration() - next_value = self._data[self._data_index] - self._data_index += 1 - return next_value + rval = self._data[self._data_ind] + self._data_ind += 1 + return rval def get_one(self): """Iterates over self and returns first value