Permalink
Browse files

Merge pull request #24 from alphagov/sort_and_limit

Sort and limit
  • Loading branch information...
2 parents 7680db6 + c513fb5 commit 694427e4b003f070a205f95739741245dc96a7b6 @pbadenski pbadenski committed Mar 28, 2013
@@ -26,31 +26,39 @@ def _period_group(self, doc):
'_count': doc['_count']
}
def execute_weekly_group_query(self, group_by, query, sort=None,
                               limit=None):
    """Group matching documents by ``group_by`` and by week.

    Each returned group carries its weekly buckets under 'values',
    where every bucket's '_week_start_at' is replaced by a UTC
    '_start_at'/'_end_at' pair spanning seven days.
    """
    period_key = '_week_start_at'
    grouped = []
    docs = self.repository.multi_group(
        group_by, period_key, query, sort=sort, limit=limit)
    for doc in docs:
        buckets = doc.pop('_subgroup')
        for bucket in buckets:
            week_start = utc(bucket.pop("_week_start_at"))
            bucket["_start_at"] = week_start
            bucket["_end_at"] = week_start + datetime.timedelta(days=7)
        doc['values'] = buckets
        grouped.append(doc)
    return grouped
def execute_grouped_query(self, group_by, query, sort=None, limit=None):
    """Run a single-key group query.

    Returns one ``{group_by: value, '_count': n}`` dict per group,
    discarding any other fields the repository returned.
    """
    docs = self.repository.group(group_by, query, sort, limit)
    return [{group_by: doc[group_by], '_count': doc['_count']}
            for doc in docs]
def execute_period_query(self, query, limit=None):
    """Group matching documents by week and shape each bucket with
    ``_period_group``."""
    docs = self.repository.group('_week_start_at', query, limit=limit)
    return [self._period_group(doc) for doc in docs]
- def execute_query(self, query):
+ def execute_query(self, query, sort=None, limit=None):
result = []
- cursor = self.repository.find(query)
+ cursor = self.repository.find(query, sort=sort, limit=limit)
for doc in cursor:
# stringify the id
doc['_id'] = str(doc['_id'])
@@ -61,14 +69,19 @@ def execute_query(self, query):
def query(self, **params):
    """Build a Mongo query from request params and dispatch to the
    appropriate execute_* method.

    Dispatch order: group+period, group only, period only, plain find.
    Note that a falsy 'group_by' value falls through to the non-grouped
    branches.
    """
    spec = build_query(**params)
    sort_by = params.get('sort_by')
    group_by = params.get('group_by')
    limit = params.get('limit')
    has_period = 'period' in params

    if group_by and has_period:
        return self.execute_weekly_group_query(
            group_by, spec, sort_by, limit)
    if group_by:
        return self.execute_grouped_query(group_by, spec, sort_by, limit)
    if has_period:
        return self.execute_period_query(spec, limit)
    return self.execute_query(spec, sort_by, limit)
@@ -1,7 +1,5 @@
-from itertools import groupby
from bson import Code
-from pprint import pprint
-from pymongo import MongoClient
+import pymongo
def build_query(**params):
@@ -26,7 +24,7 @@ def ensure_has_timestamp(q):
class Database(object):
def __init__(self, host, port, name):
    """Connect to Mongo at host:port and remember the database name."""
    self.name = name
    self._mongo = pymongo.MongoClient(host, port)
def alive(self):
@@ -48,34 +46,45 @@ def __init__(self, collection):
def name(self):
    """Name of the wrapped Mongo collection."""
    collection = self._collection
    return collection.name
- def find(self, query):
- return self._collection.find(query).sort('_timestamp', -1)
-
- def group(self, group_by, query):
- return self._group([group_by], query)
+ def _validate_sort(self, sort):
+ if len(sort) != 2:
+ raise InvalidSortError("Expected a key and direction")
+
+ if sort[1] not in ["ascending", "descending"]:
+ raise InvalidSortError(sort[1])
+
def find(self, query, sort=None, limit=None):
    """Find matching documents, sorted and optionally limited.

    When no sort is given, defaults to '_timestamp' ascending; an
    explicit sort is validated first.
    """
    cursor = self._collection.find(query)
    if sort:
        self._validate_sort(sort)
    else:
        sort = ["_timestamp", "ascending"]
    directions = {
        "ascending": pymongo.ASCENDING,
        "descending": pymongo.DESCENDING,
    }
    cursor.sort(sort[0], directions[sort[1]])
    if limit:
        cursor.limit(limit)
    return cursor
+
def group(self, group_by, query, sort=None, limit=None):
    """Group on a single key, validating any explicit sort first."""
    if sort:
        self._validate_sort(sort)
    return self._group([group_by], query, sort, limit)
def save(self, obj):
    """Persist ``obj`` to the underlying collection."""
    self._collection.save(obj)
def multi_group(self, key1, key2, query, sort=None, limit=None):
    """Group on two distinct keys, delegating to ``_group``.

    Raises GroupingError when both keys are the same.
    """
    if key1 == key2:
        raise GroupingError("Cannot group on two equal keys")
    return self._group([key1, key2], query, sort, limit)
- def _group(self, keys, query):
+ def _group(self, keys, query, sort=None, limit=None):
results = self._collection.group(
key=keys,
condition=query,
@@ -88,26 +97,87 @@ def _group(self, keys, query):
for key in keys:
if result[key] is None:
return []
+
+ results = nested_merge(keys, results)
+
+ if sort:
+ sorters = {
+ "ascending": lambda a, b: cmp(a, b),
+ "descending": lambda a, b: cmp(b, a)
+ }
+ sorter = sorters[sort[1]]
+ try:
+ results.sort(cmp=sorter, key=lambda a: a[sort[0]])
+ except KeyError:
+ raise InvalidSortError('Invalid sort key {0}'.format(sort[0]))
+ if limit:
+ results = results[:limit]
+
return results
class GroupingError(ValueError):
    """Raised when a multi-key group is asked to group on the same key
    twice."""
class InvalidSortError(ValueError):
    """Raised when a sort specification is malformed or names an
    unknown direction/key."""
+
+
def nested_merge(keys, results):
    """Fold flat group() rows into a nested group/subgroup structure,
    one row at a time."""
    merged = []
    for row in results:
        merged = _merge(merged, keys, row)
    return merged
+
+
def _merge(groups, keys, result):
    """Merge one flat group() row into the nested ``groups`` list.

    The first key in ``keys`` selects (or creates) the node of
    ``groups`` the row belongs to; any remaining keys are merged
    recursively into that node's '_subgroup'. Returns the (mutated)
    ``groups`` list.
    """
    # Copy so pop() below does not consume the caller's key list.
    keys = list(keys)
    key = keys.pop(0)
    is_leaf = (len(keys) == 0)
    # NOTE: mutates the incoming row; for leaves _new_leaf_node puts
    # the key/value back onto it.
    value = result.pop(key)

    group = _find_group(group for group in groups if group[key] == value)
    if not group:
        if is_leaf:
            group = _new_leaf_node(key, value, result)
        else:
            group = _new_branch_node(key, value)
        groups.append(group)

    if not is_leaf:
        # Recurse with the remaining keys, then refresh this branch's
        # aggregate counts.
        _merge_and_sort_subgroup(group, keys, result)
        _add_branch_node_counts(group)
    return groups
+
+
+def _find_group(items):
+ """Return the first item in an iterator or None"""
+ try:
+ return next(items)
+ except StopIteration:
+ return
+
+
+def _new_branch_node(key, value):
+ """Create a new node that has further sub-groups"""
+ return {
+ key: value,
+ "_subgroup": []
+ }
+
+
+def _new_leaf_node(key, value, result):
+ """Create a new node that has no further sub-groups"""
+ result[key] = value
+ return result
+
def _merge_and_sort_subgroup(group, keys, result):
    """Merge the row into the node's subgroup and keep the subgroup
    ordered by the next grouping key."""
    subgroup = _merge(group['_subgroup'], keys, result)
    subgroup.sort(key=lambda node: node[keys[0]])
    group['_subgroup'] = subgroup
-def _inner_merge(output, keys, value):
- if len(keys) == 0:
- return value
- key = value.pop(keys[0])
- if key not in output:
- output[key] = {}
- output[key].update(_inner_merge(output[key], keys[1:], value))
- return output
+def _add_branch_node_counts(group):
+ group['_count'] = sum(doc.get('_count', 0) for doc in group['_subgroup'])
+ group['_group_count'] = len(group['_subgroup'])
@@ -1,29 +1,29 @@
from unittest import TestCase
from hamcrest import *
from backdrop.core.database import build_query
-from tests.support.test_helpers import d
+from tests.support.test_helpers import d_tz
class TestBuild_query(TestCase):
    """Unit tests for build_query's start_at/end_at timestamp bounds."""

    # NOTE: the original literals 05 / 03 are Python-2-only octal-style
    # syntax and a SyntaxError on Python 3; the plain decimals below
    # have identical values.

    def test_build_query_with_start_at(self):
        query = build_query(start_at=d_tz(2013, 3, 18, 18, 10, 5))
        assert_that(query, is_(
            {"_timestamp": {"$gte": d_tz(2013, 3, 18, 18, 10, 5)}}))

    def test_build_query_with_end_at(self):
        query = build_query(end_at=d_tz(2012, 3, 17, 17, 10, 6))
        assert_that(query, is_(
            {"_timestamp": {"$lt": d_tz(2012, 3, 17, 17, 10, 6)}}))

    def test_build_query_with_start_and_end_at(self):
        query = build_query(
            start_at=d_tz(2012, 3, 17, 17, 10, 6),
            end_at=d_tz(2012, 3, 19, 17, 10, 6)
        )
        assert_that(query, is_({
            "_timestamp": {
                "$gte": d_tz(2012, 3, 17, 17, 10, 6),
                "$lt": d_tz(2012, 3, 19, 17, 10, 6)
            }
        }))
@@ -30,6 +30,7 @@ def parse_request_args(request_args):
if 'start_at' in request_args:
args['start_at'] = parse_time_string(request_args['start_at'])
+
if 'end_at' in request_args:
args['end_at'] = parse_time_string(request_args['end_at'])
@@ -44,6 +45,11 @@ def parse_request_args(request_args):
if 'group_by' in request_args:
args['group_by'] = request_args['group_by']
+ if 'sort_by' in request_args:
+ args['sort_by'] = request_args['sort_by'].split(':', 1)
+
+ if 'limit' in request_args:
+ args['limit'] = int(request_args['limit'])
return args
Oops, something went wrong.

0 comments on commit 694427e

Please sign in to comment.