Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from builtins import round

from apache_beam.io.gcp.datastore.v1new import types
from apache_beam.options.value_provider import ValueProvider

__all__ = ['QuerySplitterError', 'SplitNotPossibleError', 'get_splits']

Expand Down Expand Up @@ -104,7 +105,11 @@ def validate_split(query):
raise SplitNotPossibleError('Query cannot have a limit set.')

for filter in query.filters:
if filter[1] in ['<', '<=', '>', '>=']:
if isinstance(filter[1], ValueProvider):
filter_operator = filter[1].get()
else:
filter_operator = filter[1]
if filter_operator in ['<', '<=', '>', '>=']:
raise SplitNotPossibleError('Query cannot have any inequality filters.')


Expand Down
36 changes: 34 additions & 2 deletions sdks/python/apache_beam/io/gcp/datastore/v1new/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
from google.cloud.datastore import key
from google.cloud.datastore import query

from apache_beam.options.value_provider import ValueProvider

__all__ = ['Query', 'Key', 'Entity']


Expand All @@ -44,8 +46,11 @@ def __init__(self, kind=None, project=None, namespace=None, ancestor=None,
ancestor: (:class:`~apache_beam.io.gcp.datastore.v1new.types.Key`)
(Optional) key of the ancestor to which this query's results are
restricted.
filters: (sequence of tuple[str, str, str]) Property filters applied by
this query. The sequence is ``(property_name, operator, value)``.
filters: (sequence of tuple[str, str, str],
sequence of
tuple[ValueProvider(str), ValueProvider(str), ValueProvider(str)])
Property filters applied by this query.
The sequence is ``(property_name, operator, value)``.
projection: (sequence of string) fields returned as part of query results.
order: (sequence of string) field names used to order query results.
Prepend ``-`` to a field name to sort it in descending order.
Expand Down Expand Up @@ -75,12 +80,39 @@ def _to_client_query(self, client):
ancestor_client_key = None
if self.ancestor is not None:
ancestor_client_key = self.ancestor.to_client_key()

self.filters = self._set_runtime_filters()

return query.Query(
client, kind=self.kind, project=self.project, namespace=self.namespace,
ancestor=ancestor_client_key, filters=self.filters,
projection=self.projection, order=self.order,
distinct_on=self.distinct_on)

def _set_runtime_filters(self):
  """Resolves any ``ValueProvider`` elements found in ``self.filters``.

  Each filter is a ``(property_name, operator, value)`` triple whose elements
  may be plain values or ``ValueProvider`` instances. Every ``ValueProvider``
  element is replaced by the result of its ``get()``; any other element is
  kept unchanged. This method only reads ``self.filters``; it does not
  assign to it.

  Returns:
    A list of resolved 3-tuples in the original order, or an empty tuple when
    there are no filters (``self.filters`` is ``None`` or empty).

  Raises:
    TypeError: if any filter does not have exactly three elements.
  """
  if self.filters is None:
    return ()

  # Materialize once: both the validation and the resolution pass iterate the
  # filters, and a one-shot iterable would otherwise be exhausted by the
  # length check, silently dropping every filter.
  filters = list(self.filters)
  if any(len(filter_tuple) != 3 for filter_tuple in filters):
    raise TypeError('%s: filters must be a sequence of tuple with length=3'
                    ' got %r instead'
                    % (self.__class__.__name__, self.filters))

  def resolve(element):
    # Unwrap a ValueProvider; pass any other value through untouched.
    if isinstance(element, ValueProvider):
      return element.get()
    return element

  runtime_filters = [
      tuple(resolve(element) for element in filter_tuple)
      for filter_tuple in filters]

  # Keep the established contract: an empty result is reported as ``()``.
  return runtime_filters or ()

def clone(self):
  """Returns a shallow copy of this object, made with ``copy.copy``."""
  duplicate = copy.copy(self)
  return duplicate

Expand Down
26 changes: 26 additions & 0 deletions sdks/python/apache_beam/io/gcp/datastore/v1new/types_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from apache_beam.io.gcp.datastore.v1new.types import Entity
from apache_beam.io.gcp.datastore.v1new.types import Key
from apache_beam.io.gcp.datastore.v1new.types import Query
from apache_beam.options.value_provider import StaticValueProvider
# TODO(BEAM-4543): Remove TypeError once googledatastore dependency is removed.
except (ImportError, TypeError):
client = None
Expand Down Expand Up @@ -134,6 +135,31 @@ def testQuery(self):

logging.info('query: %s', q) # Test __repr__()

def testValueProviderFilters(self):
  """Checks that ValueProvider filters become plain client query filters."""
  plain_filter = ('property_name', '=', 'value')

  def wrapped_filter():
    # The same triple with every element wrapped in a StaticValueProvider.
    return (
        StaticValueProvider(str, 'property_name'),
        StaticValueProvider(str, '='),
        StaticValueProvider(str, 'value'))

  # Case 1: only ValueProvider filters. Case 2: ValueProvider and plain
  # filters mixed in the same sequence.
  self.vp_filters = [
      [wrapped_filter()],
      [wrapped_filter(), plain_filter],
  ]
  self.expected_filters = [
      [plain_filter],
      [plain_filter, plain_filter],
  ]

  cases = zip(self.vp_filters, self.expected_filters)
  for vp_filter, exp_filter in cases:
    q = Query(
        kind='kind',
        project=self._PROJECT,
        namespace=self._NAMESPACE,
        filters=vp_filter)
    cq = q._to_client_query(self._test_client)
    self.assertEqual(exp_filter, cq.filters)

    logging.info('query: %s', q)  # Test __repr__()

def testQueryEmptyNamespace(self):
# Test that we can pass a namespace of None.
self._test_client.namespace = None
Expand Down