diff --git a/sdks/python/apache_beam/io/gcp/datastore/v1/fake_datastore.py b/sdks/python/apache_beam/io/gcp/datastore/v1/fake_datastore.py index 23325793c15c..aa3780558d7f 100644 --- a/sdks/python/apache_beam/io/gcp/datastore/v1/fake_datastore.py +++ b/sdks/python/apache_beam/io/gcp/datastore/v1/fake_datastore.py @@ -90,13 +90,17 @@ def create_response(entities, end_cursor, finish): return resp -def create_entities(count): +def create_entities(count, id_or_name=False): """Creates a list of entities with random keys.""" entities = [] for _ in range(count): entity_result = query_pb2.EntityResult() - entity_result.entity.key.path.add().name = str(uuid.uuid4()) + if id_or_name: + entity_result.entity.key.path.add().id = ( + uuid.uuid4().int & ((1 << 63) - 1)) + else: + entity_result.entity.key.path.add().name = str(uuid.uuid4()) entities.append(entity_result) return entities diff --git a/sdks/python/apache_beam/io/gcp/datastore/v1/helper.py b/sdks/python/apache_beam/io/gcp/datastore/v1/helper.py index 9e2c0531e540..f977536f32aa 100644 --- a/sdks/python/apache_beam/io/gcp/datastore/v1/helper.py +++ b/sdks/python/apache_beam/io/gcp/datastore/v1/helper.py @@ -80,7 +80,7 @@ def compare_path(p1, p2): 3. If no `id` is defined for both paths, then their `names` are compared. """ - result = str_compare(p1.kind, p2.kind) + result = cmp(p1.kind, p2.kind) if result != 0: return result @@ -88,20 +88,12 @@ def compare_path(p1, p2): if not p2.HasField('id'): return -1 - return p1.id - p2.id + return cmp(p1.id, p2.id) if p2.HasField('id'): return 1 - return str_compare(p1.name, p2.name) - - -def str_compare(s1, s2): - if s1 == s2: - return 0 - elif s1 < s2: - return -1 - return 1 + return cmp(p1.name, p2.name) def get_datastore(project): diff --git a/sdks/python/apache_beam/io/gcp/datastore/v1/query_splitter_test.py b/sdks/python/apache_beam/io/gcp/datastore/v1/query_splitter_test.py index b7b054f382bd..52f25facd058 100644 --- a/sdks/python/apache_beam/io/gcp/datastore/v1/query_splitter_test.py +++ b/sdks/python/apache_beam/io/gcp/datastore/v1/query_splitter_test.py @@ -157,27 +157,32 @@ def check_get_splits(self, query, num_splits, num_entities, batch_size): batch_size: the number of entities returned by fake datastore in one req. """ - entities = fake_datastore.create_entities(num_entities) - mock_datastore = MagicMock() - # Assign a fake run_query method as a side_effect to the mock. - mock_datastore.run_query.side_effect = \ - fake_datastore.create_run_query(entities, batch_size) + # Test for both random long ids and string ids. + id_or_name = [True, False] - split_queries = query_splitter.get_splits(mock_datastore, query, num_splits) + for id_type in id_or_name: + entities = fake_datastore.create_entities(num_entities, id_type) + mock_datastore = MagicMock() + # Assign a fake run_query method as a side_effect to the mock. + mock_datastore.run_query.side_effect = \ + fake_datastore.create_run_query(entities, batch_size) - # if request num_splits is greater than num_entities, the best it can - # do is one entity per split. - expected_num_splits = min(num_splits, num_entities + 1) - self.assertEqual(len(split_queries), expected_num_splits) + split_queries = query_splitter.get_splits( + mock_datastore, query, num_splits) - expected_requests = QuerySplitterTest.create_scatter_requests( - query, num_splits, batch_size, num_entities) + # if request num_splits is greater than num_entities, the best it can + # do is one entity per split. + expected_num_splits = min(num_splits, num_entities + 1) + self.assertEqual(len(split_queries), expected_num_splits) - expected_calls = [] - for req in expected_requests: - expected_calls.append(call(req)) + expected_requests = QuerySplitterTest.create_scatter_requests( + query, num_splits, batch_size, num_entities) - self.assertEqual(expected_calls, mock_datastore.run_query.call_args_list) + expected_calls = [] + for req in expected_requests: + expected_calls.append(call(req)) + + self.assertEqual(expected_calls, mock_datastore.run_query.call_args_list) @staticmethod def create_scatter_requests(query, num_splits, batch_size, num_entities):