diff --git a/labelbox/pagination.py b/labelbox/pagination.py index 8d46901e5..9f58155ec 100644 --- a/labelbox/pagination.py +++ b/labelbox/pagination.py @@ -79,6 +79,32 @@ def __next__(self): self._data_ind += 1 return rval + def get_one(self): + """Iterates over self and returns first value + This method is idempotent + """ + for value in self: + return value + + def get_many(self, n: int): + """Iterates over self and returns first n results + This method is idempotent + + Args: + n (int): Number of elements to retrieve + """ + results = [] + i = 0 + + for value in self: + if i >= n: + break + + results.append(value) + i += 1 + + return results + class _Pagination(ABC): diff --git a/tests/integration/test_dataset.py b/tests/integration/test_dataset.py index 89a89b78c..e9094ae98 100644 --- a/tests/integration/test_dataset.py +++ b/tests/integration/test_dataset.py @@ -21,6 +21,14 @@ def test_dataset(client, rand_gen): assert len(after) == len(before) + 1 assert dataset in after + # confirm get_one returns first dataset + get_one_dataset = client.get_datasets().get_one() + assert get_one_dataset.uid == after[0].uid + + # confirm get_many(1) returns first dataset + get_many_datasets = client.get_datasets().get_many(1) + assert get_many_datasets[0].uid == after[0].uid + dataset = client.get_dataset(dataset.uid) assert dataset.name == name