Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ Changes in 0.9.X - DEV
- Updated URL and Email Field regex validators, added schemes argument to URLField validation. #652
- Removed get_or_create() deprecated since 0.8.0. #300
- Capped collection multiple of 256. #1011
- Added `BaseQuerySet.aggregate_sum` and `BaseQuerySet.aggregate_average` methods.

Changes in 0.9.0
================
Expand Down
42 changes: 42 additions & 0 deletions mongoengine/queryset/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1248,6 +1248,27 @@ def sum(self, field):
else:
return 0

def aggregate_sum(self, field):
"""Sum over the values of the specified field.

:param field: the field to sum over; use dot-notation to refer to
embedded document fields

This method is more performant than the regular `sum`, because it uses
the aggregation framework instead of map-reduce.
"""
result = self._document._get_collection().aggregate([
{ '$match': self._query },
{ '$group': { '_id': 'sum', 'total': { '$sum': '$' + field } } }
])
if IS_PYMONGO_3:
result = list(result)
else:
result = result.get('result')
if result:
return result[0]['total']
return 0

def average(self, field):
"""Average over the values of the specified field.

Expand Down Expand Up @@ -1303,6 +1324,27 @@ def average(self, field):
else:
return 0

def aggregate_average(self, field):
"""Average over the values of the specified field.

:param field: the field to average over; use dot-notation to refer to
embedded document fields

This method is more performant than the regular `average`, because it
uses the aggregation framework instead of map-reduce.
"""
result = self._document._get_collection().aggregate([
{ '$match': self._query },
{ '$group': { '_id': 'avg', 'total': { '$avg': '$' + field } } }
])
if IS_PYMONGO_3:
result = list(result)
else:
result = result.get('result')
if result:
return result[0]['total']
return 0

def item_frequencies(self, field, normalize=False, map_reduce=True):
"""Returns a dictionary of all items present in a field across
the whole queried set of documents, and their corresponding frequency.
Expand Down
68 changes: 62 additions & 6 deletions tests/queryset/queryset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2706,26 +2706,58 @@ def test_average(self):

avg = float(sum(ages)) / (len(ages) + 1) # take into account the 0
self.assertAlmostEqual(int(self.Person.objects.average('age')), avg)
self.assertAlmostEqual(
int(self.Person.objects.aggregate_average('age')), avg
)

self.Person(name='ageless person').save()
self.assertEqual(int(self.Person.objects.average('age')), avg)
self.assertEqual(
int(self.Person.objects.aggregate_average('age')), avg
)

# dot notation
self.Person(
name='person meta', person_meta=self.PersonMeta(weight=0)).save()
self.assertAlmostEqual(
int(self.Person.objects.average('person_meta.weight')), 0)
self.assertAlmostEqual(
int(self.Person.objects.aggregate_average('person_meta.weight')),
0
)

for i, weight in enumerate(ages):
self.Person(
name='test meta%i', person_meta=self.PersonMeta(weight=weight)).save()

self.assertAlmostEqual(
int(self.Person.objects.average('person_meta.weight')), avg)
int(self.Person.objects.average('person_meta.weight')), avg
)
self.assertAlmostEqual(
int(self.Person.objects.aggregate_average('person_meta.weight')),
avg
)

self.Person(name='test meta none').save()
self.assertEqual(
int(self.Person.objects.average('person_meta.weight')), avg)
int(self.Person.objects.average('person_meta.weight')), avg
)
self.assertEqual(
int(self.Person.objects.aggregate_average('person_meta.weight')),
avg
)

# test summing over a filtered queryset
over_50 = [a for a in ages if a >= 50]
avg = float(sum(over_50)) / len(over_50)
self.assertEqual(
self.Person.objects.filter(age__gte=50).average('age'),
avg
)
self.assertEqual(
self.Person.objects.filter(age__gte=50).aggregate_average('age'),
avg
)

def test_sum(self):
"""Ensure that field can be summed over correctly.
Expand All @@ -2734,20 +2766,44 @@ def test_sum(self):
for i, age in enumerate(ages):
self.Person(name='test%s' % i, age=age).save()

self.assertEqual(int(self.Person.objects.sum('age')), sum(ages))
self.assertEqual(self.Person.objects.sum('age'), sum(ages))
self.assertEqual(
self.Person.objects.aggregate_sum('age'), sum(ages)
)

self.Person(name='ageless person').save()
self.assertEqual(int(self.Person.objects.sum('age')), sum(ages))
self.assertEqual(self.Person.objects.sum('age'), sum(ages))
self.assertEqual(
self.Person.objects.aggregate_sum('age'), sum(ages)
)

for i, age in enumerate(ages):
self.Person(name='test meta%s' %
i, person_meta=self.PersonMeta(weight=age)).save()

self.assertEqual(
int(self.Person.objects.sum('person_meta.weight')), sum(ages))
self.Person.objects.sum('person_meta.weight'), sum(ages)
)
self.assertEqual(
self.Person.objects.aggregate_sum('person_meta.weight'),
sum(ages)
)

self.Person(name='weightless person').save()
self.assertEqual(int(self.Person.objects.sum('age')), sum(ages))
self.assertEqual(self.Person.objects.sum('age'), sum(ages))
self.assertEqual(
self.Person.objects.aggregate_sum('age'), sum(ages)
)

# test summing over a filtered queryset
self.assertEqual(
self.Person.objects.filter(age__gte=50).sum('age'),
sum([a for a in ages if a >= 50])
)
self.assertEqual(
self.Person.objects.filter(age__gte=50).aggregate_sum('age'),
sum([a for a in ages if a >= 50])
)

def test_embedded_average(self):
class Pay(EmbeddedDocument):
Expand Down