Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions sdks/python/apache_beam/io/gcp/datastore/v1/fake_datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,17 @@ def create_response(entities, end_cursor, finish):
return resp


def create_entities(count, id_or_name=False):
  """Creates a list of entity results with random keys.

  Each result gets a key with exactly one path element, identified either by
  a random integer id or by a random string name.

  Args:
    count: number of entity results to create.
    id_or_name: if True, every key path element is given a random integer
      `id`; if False (the default), it is given a random string `name`.

  Returns:
    A list of `count` `query_pb2.EntityResult` objects.
  """
  entities = []

  for _ in range(count):
    entity_result = query_pb2.EntityResult()
    # Add a single path element and set exactly one of `id` / `name` on it.
    path_element = entity_result.entity.key.path.add()
    if id_or_name:
      # Keep only the low 63 bits of the 128-bit UUID so the value is a
      # non-negative integer that fits in a signed 64-bit field.
      path_element.id = uuid.uuid4().int & ((1 << 63) - 1)
    else:
      path_element.name = str(uuid.uuid4())
    entities.append(entity_result)

  return entities
14 changes: 3 additions & 11 deletions sdks/python/apache_beam/io/gcp/datastore/v1/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,28 +80,20 @@ def compare_path(p1, p2):
3. If no `id` is defined for both paths, then their `names` are compared.
"""

result = str_compare(p1.kind, p2.kind)
result = cmp(p1.kind, p2.kind)
if result != 0:
return result

if p1.HasField('id'):
if not p2.HasField('id'):
return -1

return p1.id - p2.id
return cmp(p1.id, p2.id)

if p2.HasField('id'):
return 1

return str_compare(p1.name, p2.name)


def str_compare(s1, s2):
  """Three-way comparison of two values.

  Args:
    s1: the first value to compare.
    s2: the second value to compare.

  Returns:
    0 if `s1 == s2`, -1 if `s1 < s2`, and 1 otherwise.
  """
  if s1 == s2:
    return 0
  if s1 < s2:
    return -1
  return 1
return cmp(p1.name, p2.name)


def get_datastore(project):
Expand Down
37 changes: 21 additions & 16 deletions sdks/python/apache_beam/io/gcp/datastore/v1/query_splitter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,27 +157,32 @@ def check_get_splits(self, query, num_splits, num_entities, batch_size):
batch_size: the number of entities returned by fake datastore in one req.
"""

entities = fake_datastore.create_entities(num_entities)
mock_datastore = MagicMock()
# Assign a fake run_query method as a side_effect to the mock.
mock_datastore.run_query.side_effect = \
fake_datastore.create_run_query(entities, batch_size)
# Test for both random long ids and string ids.
id_or_name = [True, False]

split_queries = query_splitter.get_splits(mock_datastore, query, num_splits)
for id_type in id_or_name:
entities = fake_datastore.create_entities(num_entities, id_type)
mock_datastore = MagicMock()
# Assign a fake run_query method as a side_effect to the mock.
mock_datastore.run_query.side_effect = \
fake_datastore.create_run_query(entities, batch_size)

# if request num_splits is greater than num_entities, the best it can
# do is one entity per split.
expected_num_splits = min(num_splits, num_entities + 1)
self.assertEqual(len(split_queries), expected_num_splits)
split_queries = query_splitter.get_splits(
mock_datastore, query, num_splits)

expected_requests = QuerySplitterTest.create_scatter_requests(
query, num_splits, batch_size, num_entities)
# if request num_splits is greater than num_entities, the best it can
# do is one entity per split.
expected_num_splits = min(num_splits, num_entities + 1)
self.assertEqual(len(split_queries), expected_num_splits)

expected_calls = []
for req in expected_requests:
expected_calls.append(call(req))
expected_requests = QuerySplitterTest.create_scatter_requests(
query, num_splits, batch_size, num_entities)

self.assertEqual(expected_calls, mock_datastore.run_query.call_args_list)
expected_calls = []
for req in expected_requests:
expected_calls.append(call(req))

self.assertEqual(expected_calls, mock_datastore.run_query.call_args_list)

@staticmethod
def create_scatter_requests(query, num_splits, batch_size, num_entities):
Expand Down