diff --git a/sdks/python/apache_beam/io/gcp/datastore/v1new/query_splitter.py b/sdks/python/apache_beam/io/gcp/datastore/v1new/query_splitter.py index 4f8be839df67..f1420fdcf653 100644 --- a/sdks/python/apache_beam/io/gcp/datastore/v1new/query_splitter.py +++ b/sdks/python/apache_beam/io/gcp/datastore/v1new/query_splitter.py @@ -27,6 +27,7 @@ from builtins import round from apache_beam.io.gcp.datastore.v1new import types +from apache_beam.options.value_provider import ValueProvider __all__ = ['QuerySplitterError', 'SplitNotPossibleError', 'get_splits'] @@ -104,7 +105,11 @@ def validate_split(query): raise SplitNotPossibleError('Query cannot have a limit set.') for filter in query.filters: - if filter[1] in ['<', '<=', '>', '>=']: + if isinstance(filter[1], ValueProvider): + filter_operator = filter[1].get() + else: + filter_operator = filter[1] + if filter_operator in ['<', '<=', '>', '>=']: raise SplitNotPossibleError('Query cannot have any inequality filters.') diff --git a/sdks/python/apache_beam/io/gcp/datastore/v1new/types.py b/sdks/python/apache_beam/io/gcp/datastore/v1new/types.py index c80fe0486b27..b89425c69f0a 100644 --- a/sdks/python/apache_beam/io/gcp/datastore/v1new/types.py +++ b/sdks/python/apache_beam/io/gcp/datastore/v1new/types.py @@ -29,6 +29,8 @@ from google.cloud.datastore import key from google.cloud.datastore import query +from apache_beam.options.value_provider import ValueProvider + __all__ = ['Query', 'Key', 'Entity'] @@ -44,8 +46,11 @@ def __init__(self, kind=None, project=None, namespace=None, ancestor=None, ancestor: (:class:`~apache_beam.io.gcp.datastore.v1new.types.Key`) (Optional) key of the ancestor to which this query's results are restricted. - filters: (sequence of tuple[str, str, str]) Property filters applied by - this query. The sequence is ``(property_name, operator, value)``. 
def _set_runtime_filters(self):
  """Resolves any ``ValueProvider`` components of ``self.filters``.

  Each filter is a ``(property_name, operator, value)`` triple whose
  components may be plain values or ``ValueProvider`` objects. Components
  that are ``ValueProvider`` instances are replaced by the result of their
  ``get()`` call; every other component is passed through unchanged.
  ``self.filters`` itself is not modified by this method.

  Returns:
    A list of ``(property_name, operator, value)`` tuples with every
    ``ValueProvider`` resolved, or an empty tuple when there are no filters
    (including when ``self.filters`` is ``None``).

  Raises:
    TypeError: if any filter is not a sized sequence of exactly three items.
  """
  # Snapshot the filters first: they are walked twice below (validate, then
  # resolve), which would silently drop every filter of a one-shot iterator.
  filters = [] if self.filters is None else list(self.filters)

  # Validate everything up front so that no ValueProvider is resolved when
  # the filters are malformed.
  for filter_tuple in filters:
    try:
      well_formed = len(filter_tuple) == 3
    except TypeError:
      # Entry has no length at all (e.g. a bare number).
      well_formed = False
    if not well_formed:
      raise TypeError('%s: filters must be a sequence of tuple with length=3'
                      ' got %r instead'
                      % (self.__class__.__name__, self.filters))

  def resolve(component):
    # Only ValueProvider instances are unwrapped; plain values pass through.
    if isinstance(component, ValueProvider):
      return component.get()
    return component

  runtime_filters = [
      tuple(resolve(component) for component in filter_tuple)
      for filter_tuple in filters]

  # Preserve the historical contract: an empty result is the empty tuple.
  return runtime_filters or ()
def testValueProviderFilters(self):
  """Checks that ValueProvider-wrapped filters reach the client query resolved.

  Two cases are covered: a filter list made only of ValueProvider triples,
  and a list mixing a ValueProvider triple with a plain-string triple. In
  both cases the client query is expected to hold plain
  ``(property_name, operator, value)`` tuples.
  """
  plain_filter = ('property_name', '=', 'value')

  def as_value_providers(filter_tuple):
    # Wrap each component of a plain filter in a StaticValueProvider.
    return tuple(
        StaticValueProvider(str, component) for component in filter_tuple)

  # (filters handed to Query, filters expected on the client query)
  cases = [
      ([as_value_providers(plain_filter)],
       [plain_filter]),
      ([as_value_providers(plain_filter), plain_filter],
       [plain_filter, plain_filter]),
  ]

  for vp_filter, exp_filter in cases:
    q = Query(kind='kind', project=self._PROJECT, namespace=self._NAMESPACE,
              filters=vp_filter)
    cq = q._to_client_query(self._test_client)
    self.assertEqual(exp_filter, cq.filters)

    # Exercise __repr__() for every case, not just the last loop value.
    logging.info('query: %s', q)