From 45987be3e6e015ae34e4e6e7631bd72a30c52fcc Mon Sep 17 00:00:00 2001 From: Mark Nunberg Date: Tue, 20 Feb 2018 13:19:28 -0500 Subject: [PATCH 1/5] Add aggregation & query builders This still needs tests --- redisearch/aggregation.py | 297 +++++++++++++++++++++++++++++++++++++ redisearch/client.py | 6 + redisearch/querystring.py | 303 ++++++++++++++++++++++++++++++++++++++ redisearch/reducers.py | 134 +++++++++++++++++ 4 files changed, 740 insertions(+) create mode 100644 redisearch/aggregation.py create mode 100644 redisearch/querystring.py create mode 100644 redisearch/reducers.py diff --git a/redisearch/aggregation.py b/redisearch/aggregation.py new file mode 100644 index 0000000..10a8eae --- /dev/null +++ b/redisearch/aggregation.py @@ -0,0 +1,297 @@ +FIELDNAME = object() + + +class Reducer(object): + """ + Base reducer object for all reducers. + + See the `redisearch.reducers` module for the actual reducers. + """ + NAME = None + + def __init__(self, *args): + self._args = args + self._field = None + self._alias = None + pass + + def alias(self, alias): + """ + Set the alias for this reducer. + + ### Parameters + + - **alias**: The value of the alias for this reducer. If this is the + special value `aggregation.FIELDNAME` then this reducer will be + aliased using the same name as the field upon which it operates. + Note that using `FIELDNAME` is only possible on reducers which + operate on a single field value. + + This method returns the `Reducer` object making it suitable for + chaining. + """ + if alias is FIELDNAME: + if not self._field: + raise ValueError("Cannot use FIELDNAME alias with no field") + # Chop off initial '@' + alias = self._field[1:] + self._alias = alias + return self + + @property + def args(self): + return self._args + + +class SortDirection(object): + """ + This special class is used to indicate sort direction. 
+ """ + DIRSTRING = None + + def __init__(self, field): + self.field = field + + +class Asc(SortDirection): + """ + Indicate that the given field should be sorted in ascending order + """ + DIRSTRING = 'ASC' + + +class Desc(SortDirection): + """ + Indicate that the given field should be sorted in descending order + """ + DIRSTRING = 'DESC' + + +class Group(object): + """ + This object automatically created in the `AggregateRequest.group_by()` + """ + def __init__(self, *fields): + self.fields = fields + self.reducers = [] + self.limit = [0, 0] + + def add_reducer(self, reducer): + self.reducers.append(reducer) + + def validate(self): + if not self.reducers: + raise ValueError('Need at least one reducer') + + def build_args(self): + self.validate() + if not self.fields: + raise Exception('No fields to group by') + ret = [str(len(self.fields))] + ret.extend(self.fields) + for reducer in self.reducers: + ret += ['REDUCE', reducer.NAME, str(len(reducer.args))] + ret.extend(reducer.args) + if reducer._alias: + ret += ['AS', reducer._alias] + return ret + + +class AggregateRequest(object): + """ + Aggregation request which can be passed to `Client.aggregate`. + """ + def __init__(self, query='*'): + """ + Create an aggregation request. This request may then be passed to + `client.aggregate()`. + + In order for the request to be usable, it must contain at least one + group. + + - **query** Query string for filtering records. + + All member methods (except `build_args()`) + return the object itself, making them useful for chaining. + """ + self._query = query + self._groups = [] + self._projections = [] + self._loadfields = [] + self._limit = [] + self._sortby = [] + self._max = 0 + + def load(self, *fields): + """ + Indicate the fields to be returned in the response. These fields are + returned in addition to any others implicitly specified. 
+ + ### Parameters + + - **fields**: One or more fields in the format of `@field` + """ + self._loadfields.extend(fields) + return self + + def group_by(self, fields, *reducers): + """ + Specify by which fields to group the aggregation. + + ### Parameters + + - **fields**: Fields to group by. This can either be a single string, + or a list of strings. both cases, the field should be specified as + `@field`. + - **reducers**: One or more reducers. Reducers may be found in the + `aggregation` module. + """ + if isinstance(fields, basestring): + fields = [fields] + group = Group(*fields) + + if not reducers: + raise ValueError("Must pass at least one reducer") + + for reducer in reducers: + group.add_reducer(reducer) + + self._groups.append(group) + + return self + + def apply(self, **kwexpr): + """ + Specify one or more projection expressions to add to each result + + ### Parameters + + - **kwexpr**: One or more key-value pairs for a projection. The key is + the alias for the projection, and the value is the projection + expression itself, for example `apply(square_root="sqrt(@foo)")` + """ + for alias, expr in kwexpr.items(): + self._projections.append([alias, expr]) + + return self + + def limit(self, offset, num): + """ + Sets the limit for the most recent group or query. + + If no group has been defined yet (via `group_by()`) then this sets + the limit for the initial pool of results from the query. Otherwise, + this limits the number of items operated on from the previous group. + + Setting a limit on the initial search results may be useful when + attempting to execute an aggregation on a sample of a large data set. 
+ + ### Parameters + + - **offset**: Result offset from which to begin paging + - **num**: Number of results to return + + + Example of sorting the initial results: + + ``` + AggregateRequest('@sale_amount:[10000, inf]')\ + .limit(0, 10)\ + .group_by('@state', r.count()) + ``` + + Will only group by the states found in the first 10 results of the + query `@sale_amount:[10000, inf]`. On the other hand, + + ``` + AggregateRequest('@sale_amount:[10000, inf]')\ + .limit(0, 1000)\ + .group_by('@state', r.count()\ + .limit(0, 10) + ``` + + Will group all the results matching the query, but only return the + first 10 groups. + + If you only wish to return a *top-N* style query, consider using + `sort_by()` instead. + + """ + if self._groups: + self._groups[-1].limit = [offset, num] + else: + self._limit = [offset, num] + return self + + def sort_by(self, fields, max=0): + """ + Indicate how the results should be sorted. This can also be used for + *top-N* style queries + + ### Parameters + + - **fields**: The fields by which to sort. This can be either a single + field or a list of fields. If you wish to specify order, you can + use the `Asc` or `Desc` wrapper classes. + - **max**: Maximum number of results to return. This can be used instead + of `LIMIT` and is also faster. + + + Example of sorting by `foo` ascending and `bar` descending: + + ``` + sort_by(Asc('@foo'), Desc('@bar')) + ``` + + Return the top 10 customers: + + ``` + AggregateRequest()\ + .group_by('@customer', r.sum('@paid').alias(FIELDNAME))\ + .sort_by(Desc('@paid'), max=10) + ``` + """ + self._max = max + if isinstance(fields, (basestring, SortDirection)): + fields = [fields] + for f in fields: + if isinstance(f, SortDirection): + self._sortby += [f.field, f.DIRSTRING] + else: + self._sortby.append(f) + return self + + def validate(self): + if not self._groups: + raise ValueError('Request requires at least one group') + + def build_args(self): + self.validate() + # @foo:bar ... 
+ ret = [self._query] + if self._loadfields: + ret.append('LOAD') + ret.append(str(len(self._loadfields))) + ret.extend(self._loadfields) + for group in self._groups: + ret += ['GROUPBY'] + ret.extend(group.build_args()) + if group.limit: + ret += ['LIMIT'] + [str(x) for x in group.limit] + for alias, projector in self._projections: + ret += ['APPLY', projector] + if alias: + ret += ['AS', alias] + + if self._sortby: + ret += ['SORTBY', str(len(self._sortby))] + ret += self._sortby + if self._max: + ret += ['MAX', str(self._max)] + + if self._limit: + ret += ['LIMIT'] + [str(x) for x in self._limit] + + return ret + + diff --git a/redisearch/client.py b/redisearch/client.py index 170f38c..acf0f04 100644 --- a/redisearch/client.py +++ b/redisearch/client.py @@ -4,6 +4,7 @@ from .document import Document from .result import Result from .query import Query, Filter +from .aggregate_request import AggregateRequest class Field(object): @@ -102,6 +103,7 @@ class Client(object): DROP_CMD = 'FT.DROP' EXPLAIN_CMD = 'FT.EXPLAIN' DEL_CMD = 'FT.DEL' + AGGREGATE_CMD = 'FT.AGGREGATE' NOOFFSETS = 'NOOFFSETS' @@ -314,3 +316,7 @@ def search(self, query): def explain(self, query): args, query_text = self._mk_query_args(query) return self.redis.execute_command(self.EXPLAIN_CMD, *args) + + def aggregate(self, query): + cmd = [self.AGGREGATE_CMD, self.index_name] + query.build_args() + return self.redis.execute_command(*cmd) \ No newline at end of file diff --git a/redisearch/querystring.py b/redisearch/querystring.py new file mode 100644 index 0000000..86b951e --- /dev/null +++ b/redisearch/querystring.py @@ -0,0 +1,303 @@ + + +def tags(*t): + """ + Indicate that the values should be matched to a tag field + + ### Parameters + + - **t**: Tags to search for + """ + return '{' + ','.join(t) + '}' + + +def between(a, b, inclusive_min=True, inclusive_max=True): + """ + Indicate that value is a numeric range + """ + return RangeValue(a, b, + inclusive_min=inclusive_min, 
inclusive_max=inclusive_max) + +def equal(n): + """ + Match a numeric value + """ + return between(n, n) + + +def lt(n): + """ + Match any value less than n + """ + return between(None, n, inclusive_max=False) + + +def le(n): + """ + Match any value less or equal to n + """ + return between(None, n, inclusive_max=True) + + +def gt(n): + """ + Match any value greater than n + """ + return between(n, None, inclusive_min=False) + + +def ge(n): + """ + Match any value greater or equal to n + """ + return between(n, None, inclusive_min=True) + + +def geo(lat, lon, radius, unit='km'): + """ + Indicate that value is a geo region + """ + return GeoValue(lat, lon, radius, unit) + + +class Value(object): + @property + def combinable(self): + """ + Whether this type of value may be combined with other values for the same + field. This makes the filter potentially more efficient + """ + return False + + @staticmethod + def make_value(v): + """ + Convert an object to a value, if it is not a value already + """ + if isinstance(v, Value): + return v + return ScalarValue(v) + + +class RangeValue(Value): + combinable = False + + def __init__(self, a, b, inclusive_min=False, inclusive_max=False): + self.range = [a,b] + self.inclusive_min = inclusive_min + self.inclusive_max = inclusive_max + + def to_string(self): + a, b = self.range + a = a if a is not None else '-inf' + b = b if b is not None else 'inf' + if not self.inclusive_min: + a = '(' + a + if not self.inclusive_max: + b = '(' + b + return '[{} {}]'.format(a, b) + + +class ScalarValue(Value): + combinable = True + + def __init__(self, v): + self.v = str(v) + + def to_string(self): + return self.v + + +class TagValue(Value): + combinable = False + + def __init__(self, *tags): + self.tags = tags + + def to_string(self): + return '{' + ','.join(self.tags) + '}' + + +class GeoValue(Value): + def __init__(self, lon, lat, radius, unit='km'): + self.lon = lon + self.lat = lat + self.radius = radius + self.unit = unit + + +class 
Node(object): + def __init__(self, *children, **kwparams): + """ + Create a node + + ### Parameters + + - **children**: One or more sub-conditions. These can be additional + `intersect`, `disjunct`, `union`, `optional`, or any other `Node` + type. + + The semantics of multiple conditions are dependent on the type of + query. For an `intersection` node, this amounts to a logical AND, + for a `union` node, this amounts to a logical `OR`. + + - **kwparams**: key-value parameters. Each key is the name of a field, + and the value should be a field value. This can be one of the + following: + + - Simple string (for text field matches) + - value returned by one of the helper functions + - list of either a string or a value + + + ### Examples + + Field `num` should be between 1 and 10 + ``` + intersect(num=between(1, 10) + ``` + + Name can either be `bob` or `john` + + ``` + union(name=('bob', 'john')) + ``` + + Don't select countries in Israel, Japan, or US + + ``` + disjunct_union(country=('il', 'jp', 'us')) + ``` + """ + + self.params = [] + + kvparams = {} + for k, v in kwparams.items(): + curvals = kvparams.setdefault(k, []) + if isinstance(v, (basestring, int, long, float)): + curvals.append(Value.make_value(v)) + elif isinstance(v, Value): + curvals.append(v) + else: + curvals.extend(Value.make_value(subv) for subv in v) + + self.params += [Node.to_node(p) for p in children] + + for k, v in kvparams.items(): + self.params.extend(self.join_fields(k, v)) + + def join_fields(self, key, vals): + if len(vals) == 1: + return [BaseNode('@{}:{}'.format(key, vals[0].to_string()))] + if not vals[0].combinable: + return [BaseNode('@{}:{}'.format(key, v.to_string())) for v in vals] + s = BaseNode('@{}:({})'.format(key, self.JOINSTR.join(v.to_string() for v in vals))) + return [s] + + @classmethod + def to_node(cls, obj): + if isinstance(obj, Node): + return obj + return BaseNode(obj) + + @property + def JOINSTR(self): + raise NotImplemented() + + def to_string(self, 
with_parens=None): + with_parens = self._should_use_paren(with_parens) + pre, post = ('(', ')') if with_parens else ('', '') + return "{}{}{}".format( + pre, self.JOINSTR.join(n.to_string() for n in self.params), post) + + def _should_use_paren(self, optval): + if optval is not None: + return optval + return len(self.params) > 1 + + def __str__(self): + return self.to_string() + + +class BaseNode(Node): + def __init__(self, s): + super(BaseNode, self).__init__() + self.s = str(s) + + def to_string(self, with_parens=None): + return self.s + + +class IntersectNode(Node): + """ + Create an intersection node. All children need to be satisfied in order for + this node to evaluate as true + """ + JOINSTR = ' ' + + +class UnionNode(Node): + """ + Create a union node. Any of the children need to be satisfied in order for + this node to evaluate as true + """ + JOINSTR = '|' + + +class DisjunctNode(IntersectNode): + """ + Create a disjunct node. In order for this node to be true, all of its + children must evaluate to false + """ + def to_string(self, with_parens=None): + with_parens = self._should_use_paren(with_parens) + ret = super(DisjunctNode, self).to_string(with_parens=False) + if with_parens: + return '(-' + ret + ')' + else: + return '-' + ret + + +class DistjunctUnion(DisjunctNode): + """ + This node is true if *all* of its children are false. This is equivalent to + ``` + disjunct(union(...)) + ``` + """ + JOINSTR = '|' + + +class OptionalNode(IntersectNode): + """ + Create an optional node. If this nodes evaluates to true, then the document + will be rated higher in score/rank. 
+ """ + def to_string(self, with_parens=None): + with_parens = self._should_use_paren(with_parens) + ret = super(OptionalNode, self).to_string(with_parens=False) + if with_parens: + return '(~' + ret + ')' + else: + return '~' + ret + + +def intersect(*args, **kwargs): + return IntersectNode(*args, **kwargs) + + +def union(*args, **kwargs): + return UnionNode(*args, **kwargs) + + +def disjunct(*args, **kwargs): + return DisjunctNode(*args, **kwargs) + + +def disjunct_union(*args, **kwargs): + return DistjunctUnion(*args, **kwargs) + + +def querystring(*args, **kwargs): + return intersect(*args, **kwargs).to_string() diff --git a/redisearch/reducers.py b/redisearch/reducers.py new file mode 100644 index 0000000..4d2e290 --- /dev/null +++ b/redisearch/reducers.py @@ -0,0 +1,134 @@ +from .aggregation import Reducer, SortDirection + + +class FieldOnlyReducer(Reducer): + def __init__(self, field): + super(FieldOnlyReducer, self).__init__(field) + self._field = field + + +class count(Reducer): + """ + Counts the number of results in the group + """ + NAME = 'COUNT' + + def __init__(self): + super(count, self).__init__() + + +class sum(FieldOnlyReducer): + """ + Calculates the sum of all the values in the given fields within the group + """ + NAME = 'SUM' + + def __init__(self, field): + super(sum, self).__init__(field) + + +class min(FieldOnlyReducer): + """ + Calculates the smallest value in the given field within the group + """ + NAME = 'MIN' + + def __init__(self, field): + super(min, self).__init__(field) + + +class max(FieldOnlyReducer): + """ + Calculates the largest value in the given field within the group + """ + NAME = 'MAX' + + def __init__(self, field): + super(max, self).__init__(field) + + +class avg(FieldOnlyReducer): + """ + Calculates the mean value in the given field within the group + """ + NAME = 'AVG' + + def __init__(self, field): + super(avg, self).__init__(field) + + +class count_distinct(FieldOnlyReducer): + """ + Calculate the number of 
distinct values contained in all the results in + the group for the given field + """ + NAME = 'COUNT_DISTINCT' + + def __init__(self, field): + super(count_distinct, self).__init__(field) + + +class count_distinctish(FieldOnlyReducer): + """ + Calculate the number of distinct values contained in all the results in the + group for the given field. This uses a faster algorithm than + `count_distinct` but is less accurate + """ + name = 'COUNT_DISTINCTISH' + + +class quantile(Reducer): + """ + Return the value for the nth percentile within the range of values for the + field within the group. + """ + NAME = 'QUANTILE' + + def __init__(self, field, pct): + super(quantile, self).__init__(field, pct) + self._field = field + + +class stddev(FieldOnlyReducer): + """ + Return the standard deviation for the values within the group + """ + name = 'STDDEV' + + def __init__(self, field): + super(stddev, self).__init__(field) + + +class first_value(Reducer): + """ + Selects the first value within the group according to sorting parameters + """ + NAME = 'first_value' + + def __init__(self, field, *byfields): + """ + Selects the first value of the given field within the group. + + ### Parameter + + - **field**: Source field used for the value + - **byfields**: How to sort the results. This can be either the + *class* of `aggregation.Asc` or `aggregation.Desc` in which + case the field `field` is also used as the sort input. 
+ + `byfields` can also be one or more *instances* of `Asc` or `Desc` + indicating the sort order for these fields + """ + fieldstrs = [] + if len(byfields) == 1 and isinstance(byfields[0], type) and \ + issubclass(byfields[0], SortDirection): + byfields = [byfields[0](field)] + + for f in byfields: + fieldstrs += [f.field, f.DIRSTRING] + + args = [field] + if fieldstrs: + args += ['BY'] + fieldstrs + super(first_value, self).__init__(*args) + self._field = field From a525feb101f859ff4863d9ce991266ae06ff9fe2 Mon Sep 17 00:00:00 2001 From: Mark Nunberg Date: Tue, 20 Feb 2018 15:53:24 -0500 Subject: [PATCH 2/5] Some tests for the builder --- redisearch/querystring.py | 22 ++++++++++++---------- test/test_builder.py | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 10 deletions(-) create mode 100644 test/test_builder.py diff --git a/redisearch/querystring.py b/redisearch/querystring.py index 86b951e..4dffc03 100644 --- a/redisearch/querystring.py +++ b/redisearch/querystring.py @@ -8,7 +8,9 @@ def tags(*t): - **t**: Tags to search for """ - return '{' + ','.join(t) + '}' + if not t: + raise ValueError('At least one tag must be specified') + return TagValue(*t) def between(a, b, inclusive_min=True, inclusive_max=True): @@ -83,19 +85,19 @@ class RangeValue(Value): combinable = False def __init__(self, a, b, inclusive_min=False, inclusive_max=False): - self.range = [a,b] + if a is None: + a = '-inf' + if b is None: + b = 'inf' + self.range = [str(a), str(b)] self.inclusive_min = inclusive_min self.inclusive_max = inclusive_max def to_string(self): a, b = self.range - a = a if a is not None else '-inf' - b = b if b is not None else 'inf' - if not self.inclusive_min: - a = '(' + a - if not self.inclusive_max: - b = '(' + b - return '[{} {}]'.format(a, b) + return '[{1}{0[0]} {2}{0[1]}]'.format(self.range, + '(' if not self.inclusive_min else '', + '(' if not self.inclusive_max else '',) class ScalarValue(Value): @@ -115,7 +117,7 @@ def 
__init__(self, *tags): self.tags = tags def to_string(self): - return '{' + ','.join(self.tags) + '}' + return '{' + ' | '.join(str(t) for t in self.tags) + '}' class GeoValue(Value): diff --git a/test/test_builder.py b/test/test_builder.py new file mode 100644 index 0000000..90c2dab --- /dev/null +++ b/test/test_builder.py @@ -0,0 +1,34 @@ +from unittest import TestCase +import redisearch.aggregation as a +import redisearch.querystring as q +import redisearch.reducers as r + +class QueryBuilderTest(TestCase): + def testBetween(self): + b = q.between(1, 10) + self.assertEqual('[1 10]', b.to_string()) + b = q.between(None, 10) + self.assertEqual('[-inf 10]', b.to_string()) + b = q.between(1, 10, inclusive_min=False) + self.assertEqual('[(1 10]', b.to_string()) + + def testTags(self): + self.assertRaises(ValueError, q.tags) + self.assertEqual('{1 | 2 | 3}', q.tags(1,2,3).to_string()) + self.assertEqual('{foo}', q.tags('foo').to_string()) + + def testUnion(self): + u = q.union() + self.assertEqual('', u.to_string()) + u = q.union(foo='fooval', bar='barval') + self.assertEqual('(@foo:fooval|@bar:barval)', u.to_string()) + u = q.union(q.intersect(foo=1, bar=2), q.intersect(foo=3, bar=4)) + self.assertEqual('((@foo:1 @bar:2)|(@foo:3 @bar:4))', u.to_string()) + + def testSpecialNodes(self): + u = q.union(num=q.between(1, 10)) + self.assertEqual('@num:[1 10]', u.to_string()) + u = q.union(num=[q.between(1, 10), q.between(100, 200)]) + self.assertEqual('(@num:[1 10]|@num:[100 200])', u.to_string()) + u = q.union(num=[q.tags('t1', 't2', 't3'), q.tags('t100', 't200', 't300')]) + self.assertEqual('(@num:{t1 | t2 | t3}|@num:{t100 | t200 | t300})', u.to_string()) From a3a80a603e16db23295fe820d73b12e44d9ea881 Mon Sep 17 00:00:00 2001 From: Mark Nunberg Date: Thu, 15 Mar 2018 16:05:12 -0400 Subject: [PATCH 3/5] bomb --- redisearch/aggregation.py | 73 ++++++++++++++++++------------------ redisearch/querystring.py | 6 +++ redisearch/reducers.py | 2 +- test/test_builder.py | 78 
+++++++++++++++++++++++++++++++++------ 4 files changed, 110 insertions(+), 49 deletions(-) diff --git a/redisearch/aggregation.py b/redisearch/aggregation.py index 10a8eae..b2cb01f 100644 --- a/redisearch/aggregation.py +++ b/redisearch/aggregation.py @@ -1,6 +1,18 @@ FIELDNAME = object() +class Limit(object): + def __init__(self, offset=0, count=0): + self.offset = offset + self.count = count + + def build_args(self): + if self.count: + return ['LIMIT', str(self.offset), str(self.count)] + else: + return [] + + class Reducer(object): """ Base reducer object for all reducers. @@ -71,22 +83,20 @@ class Group(object): """ This object automatically created in the `AggregateRequest.group_by()` """ - def __init__(self, *fields): - self.fields = fields - self.reducers = [] - self.limit = [0, 0] + def __init__(self, fields, reducers): + if not fields: + raise ValueError('need at least one field') + if not reducers: + raise ValueError('Need at least one reducer') - def add_reducer(self, reducer): - self.reducers.append(reducer) + fields = [fields] if isinstance(fields, basestring) else fields + reducers = [reducers] if isinstance(reducers, Reducer) else reducers - def validate(self): - if not self.reducers: - raise ValueError('Need at least one reducer') + self.fields = fields + self.reducers = reducers + self.limit = Limit() def build_args(self): - self.validate() - if not self.fields: - raise Exception('No fields to group by') ret = [str(len(self.fields))] ret.extend(self.fields) for reducer in self.reducers: @@ -118,7 +128,7 @@ def __init__(self, query='*'): self._groups = [] self._projections = [] self._loadfields = [] - self._limit = [] + self._limit = Limit() self._sortby = [] self._max = 0 @@ -146,16 +156,7 @@ def group_by(self, fields, *reducers): - **reducers**: One or more reducers. Reducers may be found in the `aggregation` module. 
""" - if isinstance(fields, basestring): - fields = [fields] - group = Group(*fields) - - if not reducers: - raise ValueError("Must pass at least one reducer") - - for reducer in reducers: - group.add_reducer(reducer) - + group = Group(fields, reducers) self._groups.append(group) return self @@ -217,13 +218,14 @@ def limit(self, offset, num): `sort_by()` instead. """ + limit = Limit(offset, num) if self._groups: - self._groups[-1].limit = [offset, num] + self._groups[-1].limit = limit else: - self._limit = [offset, num] + self._limit = limit return self - def sort_by(self, fields, max=0): + def sort_by(self, *fields, **kwargs): """ Indicate how the results should be sorted. This can also be used for *top-N* style queries @@ -251,7 +253,7 @@ def sort_by(self, fields, max=0): .sort_by(Desc('@paid'), max=10) ``` """ - self._max = max + self._max = kwargs.get('max', 0) if isinstance(fields, (basestring, SortDirection)): fields = [fields] for f in fields: @@ -261,12 +263,13 @@ def sort_by(self, fields, max=0): self._sortby.append(f) return self - def validate(self): - if not self._groups: - raise ValueError('Request requires at least one group') + def _limit_2_args(self, limit): + if limit[1]: + return ['LIMIT'] + [str(x) for x in limit] + else: + return [] def build_args(self): - self.validate() # @foo:bar ... 
ret = [self._query] if self._loadfields: @@ -274,10 +277,7 @@ def build_args(self): ret.append(str(len(self._loadfields))) ret.extend(self._loadfields) for group in self._groups: - ret += ['GROUPBY'] - ret.extend(group.build_args()) - if group.limit: - ret += ['LIMIT'] + [str(x) for x in group.limit] + ret += ['GROUPBY'] + group.build_args() + group.limit.build_args() for alias, projector in self._projections: ret += ['APPLY', projector] if alias: @@ -289,8 +289,7 @@ def build_args(self): if self._max: ret += ['MAX', str(self._max)] - if self._limit: - ret += ['LIMIT'] + [str(x) for x in self._limit] + ret += self._limit.build_args() return ret diff --git a/redisearch/querystring.py b/redisearch/querystring.py index 4dffc03..b0584cc 100644 --- a/redisearch/querystring.py +++ b/redisearch/querystring.py @@ -80,6 +80,12 @@ def make_value(v): return v return ScalarValue(v) + def to_string(self): + raise NotImplemented() + + def __str__(self): + return self.to_string() + class RangeValue(Value): combinable = False diff --git a/redisearch/reducers.py b/redisearch/reducers.py index 4d2e290..9e8ae6d 100644 --- a/redisearch/reducers.py +++ b/redisearch/reducers.py @@ -103,7 +103,7 @@ class first_value(Reducer): """ Selects the first value within the group according to sorting parameters """ - NAME = 'first_value' + NAME = 'FIRST_VALUE' def __init__(self, field, *byfields): """ diff --git a/test/test_builder.py b/test/test_builder.py index 90c2dab..b594ee0 100644 --- a/test/test_builder.py +++ b/test/test_builder.py @@ -6,29 +6,85 @@ class QueryBuilderTest(TestCase): def testBetween(self): b = q.between(1, 10) - self.assertEqual('[1 10]', b.to_string()) + self.assertEqual('[1 10]', str(b)) b = q.between(None, 10) - self.assertEqual('[-inf 10]', b.to_string()) + self.assertEqual('[-inf 10]', str(b)) b = q.between(1, 10, inclusive_min=False) - self.assertEqual('[(1 10]', b.to_string()) + self.assertEqual('[(1 10]', str(b)) def testTags(self): self.assertRaises(ValueError, 
q.tags) - self.assertEqual('{1 | 2 | 3}', q.tags(1,2,3).to_string()) - self.assertEqual('{foo}', q.tags('foo').to_string()) + self.assertEqual('{1 | 2 | 3}', str(q.tags(1,2,3))) + self.assertEqual('{foo}', str(q.tags('foo'))) def testUnion(self): u = q.union() - self.assertEqual('', u.to_string()) + self.assertEqual('', str(u)) u = q.union(foo='fooval', bar='barval') - self.assertEqual('(@foo:fooval|@bar:barval)', u.to_string()) + self.assertEqual('(@foo:fooval|@bar:barval)', str(u)) u = q.union(q.intersect(foo=1, bar=2), q.intersect(foo=3, bar=4)) - self.assertEqual('((@foo:1 @bar:2)|(@foo:3 @bar:4))', u.to_string()) + self.assertEqual('((@foo:1 @bar:2)|(@foo:3 @bar:4))', str(u)) def testSpecialNodes(self): u = q.union(num=q.between(1, 10)) - self.assertEqual('@num:[1 10]', u.to_string()) + self.assertEqual('@num:[1 10]', str(u)) u = q.union(num=[q.between(1, 10), q.between(100, 200)]) - self.assertEqual('(@num:[1 10]|@num:[100 200])', u.to_string()) + self.assertEqual('(@num:[1 10]|@num:[100 200])', str(u)) u = q.union(num=[q.tags('t1', 't2', 't3'), q.tags('t100', 't200', 't300')]) - self.assertEqual('(@num:{t1 | t2 | t3}|@num:{t100 | t200 | t300})', u.to_string()) + self.assertEqual('(@num:{t1 | t2 | t3}|@num:{t100 | t200 | t300})', str(u)) + + def testGroup(self): + # Check the group class on its own + self.assertRaises(ValueError, a.Group, [], []) + self.assertRaises(ValueError, a.Group, ['foo'], []) + self.assertRaises(ValueError, a.Group, [], r.count()) + + # Single field, single reducer + g = a.Group('foo', r.count()) + ret = g.build_args() + self.assertEqual(['1', 'foo', 'REDUCE', 'COUNT', '0'], ret) + + # Multiple fields, single reducer + g = a.Group(['foo', 'bar'], r.count()) + self.assertEqual(['2', 'foo', 'bar', 'REDUCE', 'COUNT', '0'], + g.build_args()) + + # Multiple fields, multiple reducers + g = a.Group(['foo', 'bar'], [r.count(), r.count_distinct('@fld1')]) + self.assertEqual(['2', 'foo', 'bar', 'REDUCE', 'COUNT', '0', 'REDUCE', 'COUNT_DISTINCT', 
'1', '@fld1'], + g.build_args()) + + def testAggRequest(self): + req = a.AggregateRequest() + self.assertEqual(['*'], req.build_args()) + + # Test with group_by + req = a.AggregateRequest().group_by('@foo', r.count()) + self.assertEqual(['*', 'GROUPBY', '1', '@foo', 'REDUCE', 'COUNT', '0'], req.build_args()) + + # Test with limit + req = a.AggregateRequest().\ + group_by('@foo', r.count()).\ + sort_by('@foo') + self.assertEqual(['*', 'GROUPBY', '1', '@foo', 'REDUCE', 'COUNT', '0', 'SORTBY', '1', + '@foo'], req.build_args()) + + # Test with sort_by + req = a.AggregateRequest().group_by('@foo', r.count()).sort_by('@date') + # print req.build_args() + self.assertEqual(['*', 'GROUPBY', '1', '@foo', 'REDUCE', 'COUNT', '0', 'SORTBY', '1', '@date'], + req.build_args()) + + req = a.AggregateRequest().group_by('@foo', r.count()).sort_by(a.Desc('@date')) + # print req.build_args() + self.assertEqual(['*', 'GROUPBY', '1', '@foo', 'REDUCE', 'COUNT', '0', 'SORTBY', '2', '@date', 'DESC'], + req.build_args()) + + req = a.AggregateRequest().group_by('@foo', r.count()).sort_by(a.Desc('@date'), a.Asc('@time')) + # print req.build_args() + self.assertEqual(['*', 'GROUPBY', '1', '@foo', 'REDUCE', 'COUNT', '0', 'SORTBY', '4', '@date', 'DESC', '@time', 'ASC'], + req.build_args()) + + req = a.AggregateRequest().group_by('@foo', r.count()).sort_by(a.Desc('@date'), a.Asc('@time'), max=10) + self.assertEqual(['*', 'GROUPBY', '1', '@foo', 'REDUCE', 'COUNT', '0', 'SORTBY', '4', '@date', 'DESC', '@time', 'ASC', 'MAX', '10'], + req.build_args()) From 11d0b05321597bdee76ddf65af43d181d161f626 Mon Sep 17 00:00:00 2001 From: Mark Nunberg Date: Wed, 25 Apr 2018 11:40:36 -0400 Subject: [PATCH 4/5] Add new reducers --- redisearch/reducers.py | 28 +++++++++++++++++++++++++++- test/test_builder.py | 18 ++++++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/redisearch/reducers.py b/redisearch/reducers.py index 9e8ae6d..0aa677b 100644 --- a/redisearch/reducers.py +++ 
b/redisearch/reducers.py @@ -56,6 +56,14 @@ class avg(FieldOnlyReducer): def __init__(self, field): super(avg, self).__init__(field) +class tolist(FieldOnlyReducer): + """ + Returns all the matched properties in a list + """ + NAME = 'TOLIST' + + def __init__(self, field): + super(tolist, self).__init__(field) class count_distinct(FieldOnlyReducer): """ @@ -85,7 +93,7 @@ class quantile(Reducer): NAME = 'QUANTILE' def __init__(self, field, pct): - super(quantile, self).__init__(field, pct) + super(quantile, self).__init__(field, str(pct)) self._field = field @@ -132,3 +140,21 @@ def __init__(self, field, *byfields): args += ['BY'] + fieldstrs super(first_value, self).__init__(*args) self._field = field + + +class random_sample(Reducer): + """ + Returns a random sample of items from the dataset, from the given property + """ + NAME = 'RANDOM_SAMPLE' + + def __init__(self, field, size): + """ + ### Parameter + + **field**: Field to sample from + **size**: Return this many items (can be less) + """ + args = [field, str(size)] + super(random_sample, self).__init__(*args) + self._field = field \ No newline at end of file diff --git a/test/test_builder.py b/test/test_builder.py index b594ee0..b621890 100644 --- a/test/test_builder.py +++ b/test/test_builder.py @@ -88,3 +88,21 @@ def testAggRequest(self): req = a.AggregateRequest().group_by('@foo', r.count()).sort_by(a.Desc('@date'), a.Asc('@time'), max=10) self.assertEqual(['*', 'GROUPBY', '1', '@foo', 'REDUCE', 'COUNT', '0', 'SORTBY', '4', '@date', 'DESC', '@time', 'ASC', 'MAX', '10'], req.build_args()) + + def test_reducers(self): + self.assertEqual((), r.count().args) + self.assertEqual(('f1',), r.sum('f1').args) + self.assertEqual(('f1',), r.min('f1').args) + self.assertEqual(('f1',), r.max('f1').args) + self.assertEqual(('f1',), r.avg('f1').args) + self.assertEqual(('f1',), r.tolist('f1').args) + self.assertEqual(('f1',), r.count_distinct('f1').args) + self.assertEqual(('f1',), r.count_distinctish('f1').args) + 
self.assertEqual(('f1', '0.95'), r.quantile('f1', 0.95).args) + self.assertEqual(('f1',), r.stddev('f1').args) + + self.assertEqual(('f1',), r.first_value('f1').args) + self.assertEqual(('f1', 'BY', 'f2', 'ASC'), r.first_value('f1', a.Asc('f2')).args) + self.assertEqual(('f1', 'BY', 'f1', 'ASC'), r.first_value('f1', a.Asc).args) + + self.assertEqual(('f1', '50'), r.random_sample('f1', 50).args) \ No newline at end of file From 36c2d2f6413e0eb19d47f0265400f5abc1bc5dd3 Mon Sep 17 00:00:00 2001 From: Mark Nunberg Date: Wed, 25 Apr 2018 12:49:24 -0400 Subject: [PATCH 5/5] Add WITHSCHEMA, WITHCURSOR APIs --- redisearch/aggregation.py | 61 +++++++++++++++++++++++++++++++++++++++ redisearch/client.py | 46 +++++++++++++++++++++++++++-- 2 files changed, 104 insertions(+), 3 deletions(-) diff --git a/redisearch/aggregation.py b/redisearch/aggregation.py index b2cb01f..fead332 100644 --- a/redisearch/aggregation.py +++ b/redisearch/aggregation.py @@ -131,6 +131,9 @@ def __init__(self, query='*'): self._limit = Limit() self._sortby = [] self._max = 0 + self._with_schema = False + self._verbatim = False + self._cursor = [] def load(self, *fields): """ @@ -263,6 +266,27 @@ def sort_by(self, *fields, **kwargs): self._sortby.append(f) return self + def with_schema(self): + """ + If set, the `schema` property will contain a list of `[field, type]` + entries in the result object. + """ + self._with_schema = True + return self + + def verbatim(self): + self._verbatim = True + return self + + def cursor(self, count=0, max_idle=0.0): + args = ['WITHCURSOR'] + if count: + args += ['COUNT', str(count)] + if max_idle: + args += ['MAXIDLE', str(max_idle * 1000)] + self._cursor = args + return self + def _limit_2_args(self, limit): if limit[1]: return ['LIMIT'] + [str(x) for x in limit] @@ -272,6 +296,16 @@ def _limit_2_args(self, limit): def build_args(self): # @foo:bar ... 
ret = [self._query] + + if self._with_schema: + ret.append('WITHSCHEMA') + + if self._verbatim: + ret.append('VERBATIM') + + if self._cursor: + ret += self._cursor + if self._loadfields: ret.append('LOAD') ret.append(str(len(self._loadfields))) @@ -294,3 +328,30 @@ def build_args(self): return ret +class Cursor(object): + def __init__(self, cid): + self.cid = cid + self.max_idle = 0 + self.count = 0 + + def build_args(self): + args = [str(self.cid)] + if self.max_idle: + args += ['MAXIDLE', str(self.max_idle)] + if self.count: + args += ['COUNT', str(self.count)] + return args + + +class AggregateResult(object): + def __init__(self, rows, cursor, schema): + self.rows = rows + self.cursor = cursor + self.schema = schema + + def __repr__(self): + return "<{} at 0x{:x} Rows={}, Cursor={}>".format( + self.__class__.__name__, + id(self), + len(self.rows), + self.cursor.cid if self.cursor else -1) \ No newline at end of file diff --git a/redisearch/client.py b/redisearch/client.py index acf0f04..ed57056 100644 --- a/redisearch/client.py +++ b/redisearch/client.py @@ -4,7 +4,7 @@ from .document import Document from .result import Result from .query import Query, Filter -from .aggregate_request import AggregateRequest +from .aggregation import AggregateRequest, AggregateResult, Cursor class Field(object): @@ -104,6 +104,7 @@ class Client(object): EXPLAIN_CMD = 'FT.EXPLAIN' DEL_CMD = 'FT.DEL' AGGREGATE_CMD = 'FT.AGGREGATE' + CURSOR_CMD = 'FT.CURSOR' NOOFFSETS = 'NOOFFSETS' @@ -318,5 +319,44 @@ def explain(self, query): return self.redis.execute_command(self.EXPLAIN_CMD, *args) def aggregate(self, query): - cmd = [self.AGGREGATE_CMD, self.index_name] + query.build_args() - return self.redis.execute_command(*cmd) \ No newline at end of file + """ + Issue an aggregation query + + ### Parameters + + **query**: This can be either an `AggeregateRequest`, or a `Cursor` + + An `AggregateResult` object is returned. 
You can access the rows from its + `rows` property, which will always yield the rows of the result + """ + if isinstance(query, AggregateRequest): + has_schema = query._with_schema + has_cursor = bool(query._cursor) + cmd = [self.AGGREGATE_CMD, self.index_name] + query.build_args() + elif isinstance(query, Cursor): + has_schema = False + has_cursor = True + cmd = [self.CURSOR_CMD, 'READ', self.index_name] + query.build_args() + else: + raise ValueError('Bad query', query) + + raw = self.redis.execute_command(*cmd) + if has_cursor: + if isinstance(query, Cursor): + query.cid = raw[1] + cursor = query + else: + cursor = Cursor(raw[1]) + raw = raw[0] + else: + cursor = None + + if has_schema: + schema = raw[0] + rows = raw[2:] + else: + schema = None + rows = raw[1:] + + res = AggregateResult(rows, cursor, schema) + return res \ No newline at end of file