Skip to content

Commit

Permalink
Merge pull request #2490 from dhermes/fix-1467
Browse files Browse the repository at this point in the history
Making max_results part of the base Iterator class.
  • Loading branch information
dhermes committed Oct 4, 2016
2 parents 1828fb2 + 78b09c5 commit 8c50e41
Show file tree
Hide file tree
Showing 5 changed files with 157 additions and 50 deletions.
72 changes: 62 additions & 10 deletions core/google/cloud/iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,40 +45,84 @@ def get_items_from_response(self, response):
"""


import six


class Iterator(object):
"""A generic class for iterating through Cloud JSON APIs list responses.
:type client: :class:`google.cloud.client.Client`
:param client: The client, which owns a connection to make requests.
:type path: string
:type path: str
:param path: The path to query for the list of items.
:type page_token: str
:param page_token: (Optional) A token identifying a page in a result set.
:type max_results: int
:param max_results: (Optional) The maximum number of results to fetch.
:type extra_params: dict or None
:param extra_params: Extra query string parameters for the API call.
"""

PAGE_TOKEN = 'pageToken'
RESERVED_PARAMS = frozenset([PAGE_TOKEN])
MAX_RESULTS = 'maxResults'
RESERVED_PARAMS = frozenset([PAGE_TOKEN, MAX_RESULTS])

def __init__(self, client, path, extra_params=None):
def __init__(self, client, path, page_token=None,
max_results=None, extra_params=None):
self.client = client
self.path = path
self.page_number = 0
self.next_page_token = None
self.next_page_token = page_token
self.max_results = max_results
self.num_results = 0
self.extra_params = extra_params or {}
reserved_in_use = self.RESERVED_PARAMS.intersection(
self.extra_params)
if reserved_in_use:
raise ValueError(('Using a reserved parameter',
reserved_in_use))
self._curr_items = iter(())

def __iter__(self):
"""Iterate through the list of items."""
while self.has_next_page():
"""The :class:`Iterator` is an iterator."""
return self

def _update_items(self):
"""Replace the current items iterator.
Intended to be used when the current items iterator is exhausted.
After replacing the iterator, consumes the first value to make sure
it is valid.
:rtype: object
:returns: The first item in the next iterator.
:raises: :class:`~exceptions.StopIteration` if there is no next page.
"""
if self.has_next_page():
response = self.get_next_page_response()
for item in self.get_items_from_response(response):
yield item
items = self.get_items_from_response(response)
self._curr_items = iter(items)
return six.next(self._curr_items)
else:
raise StopIteration

def next(self):
"""Get the next value in the iterator."""
try:
item = six.next(self._curr_items)
except StopIteration:
item = self._update_items()

self.num_results += 1
return item

# Alias needed for Python 2/3 support.
__next__ = next

def has_next_page(self):
"""Determines whether or not this iterator has more pages.
Expand All @@ -89,6 +133,10 @@ def has_next_page(self):
if self.page_number == 0:
return True

if self.max_results is not None:
if self.num_results >= self.max_results:
return False

return self.next_page_token is not None

def get_query_params(self):
Expand All @@ -97,8 +145,11 @@ def get_query_params(self):
:rtype: dict
:returns: A dictionary of query parameters.
"""
result = ({self.PAGE_TOKEN: self.next_page_token}
if self.next_page_token else {})
result = {}
if self.next_page_token is not None:
result[self.PAGE_TOKEN] = self.next_page_token
if self.max_results is not None:
result[self.MAX_RESULTS] = self.max_results - self.num_results
result.update(self.extra_params)
return result

Expand All @@ -123,6 +174,7 @@ def reset(self):
"""Resets the iterator to the beginning."""
self.page_number = 0
self.next_page_token = None
self.num_results = 0

def get_items_from_response(self, response):
"""Factory method called while iterating. This should be overridden.
Expand Down
69 changes: 57 additions & 12 deletions core/unit_tests/test_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,21 +34,49 @@ def test_ctor(self):
self.assertEqual(iterator.page_number, 0)
self.assertIsNone(iterator.next_page_token)

def test_constructor_w_extra_param_collision(self):
connection = _Connection()
client = _Client(connection)
PATH = '/foo'
extra_params = {'pageToken': 'val'}
self.assertRaises(ValueError, self._makeOne, client, PATH,
extra_params=extra_params)

def test___iter__(self):
iterator = self._makeOne(None, None)
self.assertIs(iter(iterator), iterator)

def test_iterate(self):
import six

PATH = '/foo'
KEY1 = 'key1'
KEY2 = 'key2'
ITEM1, ITEM2 = object(), object()
ITEMS = {KEY1: ITEM1, KEY2: ITEM2}

def _get_items(response):
for item in response.get('items', []):
yield ITEMS[item['name']]
connection = _Connection({'items': [{'name': KEY1}, {'name': KEY2}]})
return [ITEMS[item['name']]
for item in response.get('items', [])]

connection = _Connection(
{'items': [{'name': KEY1}, {'name': KEY2}]})
client = _Client(connection)
iterator = self._makeOne(client, PATH)
iterator.get_items_from_response = _get_items
self.assertEqual(list(iterator), [ITEM1, ITEM2])
self.assertEqual(iterator.num_results, 0)

val1 = six.next(iterator)
self.assertEqual(val1, ITEM1)
self.assertEqual(iterator.num_results, 1)

val2 = six.next(iterator)
self.assertEqual(val2, ITEM2)
self.assertEqual(iterator.num_results, 2)

with self.assertRaises(StopIteration):
six.next(iterator)

kw, = connection._requested
self.assertEqual(kw['method'], 'GET')
self.assertEqual(kw['path'], PATH)
Expand Down Expand Up @@ -79,6 +107,19 @@ def test_has_next_page_w_number_w_token(self):
iterator.next_page_token = TOKEN
self.assertTrue(iterator.has_next_page())

def test_has_next_page_w_max_results_not_done(self):
iterator = self._makeOne(None, None, max_results=3,
page_token='definitely-not-none')
iterator.page_number = 1
self.assertLess(iterator.num_results, iterator.max_results)
self.assertTrue(iterator.has_next_page())

def test_has_next_page_w_max_results_done(self):
iterator = self._makeOne(None, None, max_results=3)
iterator.page_number = 1
iterator.num_results = iterator.max_results
self.assertFalse(iterator.has_next_page())

def test_get_query_params_no_token(self):
connection = _Connection()
client = _Client(connection)
Expand All @@ -96,6 +137,18 @@ def test_get_query_params_w_token(self):
self.assertEqual(iterator.get_query_params(),
{'pageToken': TOKEN})

def test_get_query_params_w_max_results(self):
connection = _Connection()
client = _Client(connection)
path = '/foo'
max_results = 3
iterator = self._makeOne(client, path,
max_results=max_results)
iterator.num_results = 1
local_max = max_results - iterator.num_results
self.assertEqual(iterator.get_query_params(),
{'maxResults': local_max})

def test_get_query_params_extra_params(self):
connection = _Connection()
client = _Client(connection)
Expand All @@ -117,14 +170,6 @@ def test_get_query_params_w_token_and_extra_params(self):
expected_query.update({'pageToken': TOKEN})
self.assertEqual(iterator.get_query_params(), expected_query)

def test_get_query_params_w_token_collision(self):
connection = _Connection()
client = _Client(connection)
PATH = '/foo'
extra_params = {'pageToken': 'val'}
self.assertRaises(ValueError, self._makeOne, client, PATH,
extra_params=extra_params)

def test_get_next_page_response_new_no_token_in_response(self):
PATH = '/foo'
TOKEN = 'token'
Expand Down
14 changes: 11 additions & 3 deletions resource_manager/google/cloud/resource_manager/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,14 +168,22 @@ class _ProjectIterator(Iterator):
:type client: :class:`~google.cloud.resource_manager.client.Client`
:param client: The client to use for making connections.
:type page_token: str
:param page_token: (Optional) A token identifying a page in a result set.
:type max_results: int
:param max_results: (Optional) The maximum number of results to fetch.
:type extra_params: dict
:param extra_params: (Optional) Extra query string parameters for
the API call.
"""

def __init__(self, client, extra_params=None):
super(_ProjectIterator, self).__init__(client=client, path='/projects',
extra_params=extra_params)
def __init__(self, client, page_token=None,
max_results=None, extra_params=None):
super(_ProjectIterator, self).__init__(
client=client, path='/projects', page_token=page_token,
max_results=max_results, extra_params=extra_params)

def get_items_from_response(self, response):
"""Yield projects from response.
Expand Down
22 changes: 11 additions & 11 deletions storage/google/cloud/storage/bucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,21 +37,29 @@ class _BlobIterator(Iterator):
:type bucket: :class:`google.cloud.storage.bucket.Bucket`
:param bucket: The bucket from which to list blobs.
:type page_token: str
:param page_token: (Optional) A token identifying a page in a result set.
:type max_results: int
:param max_results: (Optional) The maximum number of results to fetch.
:type extra_params: dict or None
:param extra_params: Extra query string parameters for the API call.
:type client: :class:`google.cloud.storage.client.Client`
:param client: Optional. The client to use for making connections.
Defaults to the bucket's client.
"""
def __init__(self, bucket, extra_params=None, client=None):
def __init__(self, bucket, page_token=None, max_results=None,
extra_params=None, client=None):
if client is None:
client = bucket.client
self.bucket = bucket
self.prefixes = set()
self._current_prefixes = None
super(_BlobIterator, self).__init__(
client=client, path=bucket.path + '/o',
page_token=page_token, max_results=max_results,
extra_params=extra_params)

def get_items_from_response(self, response):
Expand Down Expand Up @@ -285,9 +293,6 @@ def list_blobs(self, max_results=None, page_token=None, prefix=None,
"""
extra_params = {}

if max_results is not None:
extra_params['maxResults'] = max_results

if prefix is not None:
extra_params['prefix'] = prefix

Expand All @@ -303,13 +308,8 @@ def list_blobs(self, max_results=None, page_token=None, prefix=None,
extra_params['fields'] = fields

result = self._iterator_class(
self, extra_params=extra_params, client=client)
# Page token must be handled specially since the base `Iterator`
# class has it as a reserved property.
if page_token is not None:
# pylint: disable=attribute-defined-outside-init
result.next_page_token = page_token
# pylint: enable=attribute-defined-outside-init
self, page_token=page_token, max_results=max_results,
extra_params=extra_params, client=client)
return result

def delete(self, force=False, client=None):
Expand Down
30 changes: 16 additions & 14 deletions storage/google/cloud/storage/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,6 @@ def list_buckets(self, max_results=None, page_token=None, prefix=None,
"""
extra_params = {'project': self.project}

if max_results is not None:
extra_params['maxResults'] = max_results

if prefix is not None:
extra_params['prefix'] = prefix

Expand All @@ -267,14 +264,10 @@ def list_buckets(self, max_results=None, page_token=None, prefix=None,
if fields is not None:
extra_params['fields'] = fields

result = _BucketIterator(client=self,
extra_params=extra_params)
# Page token must be handled specially since the base `Iterator`
# class has it as a reserved property.
if page_token is not None:
# pylint: disable=attribute-defined-outside-init
result.next_page_token = page_token
# pylint: enable=attribute-defined-outside-init
result = _BucketIterator(
client=self, page_token=page_token,
max_results=max_results, extra_params=extra_params)

return result


Expand All @@ -288,13 +281,22 @@ class _BucketIterator(Iterator):
:type client: :class:`google.cloud.storage.client.Client`
:param client: The client to use for making connections.
:type page_token: str
:param page_token: (Optional) A token identifying a page in a result set.
:type max_results: int
:param max_results: (Optional) The maximum number of results to fetch.
:type extra_params: dict or ``NoneType``
:param extra_params: Extra query string parameters for the API call.
"""

def __init__(self, client, extra_params=None):
super(_BucketIterator, self).__init__(client=client, path='/b',
extra_params=extra_params)
def __init__(self, client, page_token=None,
max_results=None, extra_params=None):
super(_BucketIterator, self).__init__(
client=client, path='/b',
page_token=page_token, max_results=max_results,
extra_params=extra_params)

def get_items_from_response(self, response):
"""Factory method which yields :class:`.Bucket` items from a response.
Expand Down

0 comments on commit 8c50e41

Please sign in to comment.