Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions mongoengine/queryset/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1271,9 +1271,10 @@ def sum(self, field):
:param field: the field to sum over; use dot notation to refer to
embedded document fields
"""
db_field = self._fields_to_dbfields([field]).pop()
pipeline = [
{'$match': self._query},
{'$group': {'_id': 'sum', 'total': {'$sum': '$' + field}}}
{'$group': {'_id': 'sum', 'total': {'$sum': '$' + db_field}}}
]

# if we're performing a sum over a list field, we sum up all the
Expand All @@ -1300,9 +1301,10 @@ def average(self, field):
:param field: the field to average over; use dot notation to refer to
embedded document fields
"""
db_field = self._fields_to_dbfields([field]).pop()
pipeline = [
{'$match': self._query},
{'$group': {'_id': 'avg', 'total': {'$avg': '$' + field}}}
{'$group': {'_id': 'avg', 'total': {'$avg': '$' + db_field}}}
]

# if we're performing an average over a list field, we average out
Expand Down
28 changes: 28 additions & 0 deletions tests/queryset/queryset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2838,6 +2838,34 @@ def test_sum(self):
sum([a for a in ages if a >= 50])
)

def test_sum_over_db_field(self):
"""Ensure that a field mapped to a db field with a different name
can be summed over correctly.
"""
class UserVisit(Document):
num_visits = IntField(db_field='visits')

UserVisit.drop_collection()

UserVisit.objects.create(num_visits=10)
UserVisit.objects.create(num_visits=5)

self.assertEqual(UserVisit.objects.sum('num_visits'), 15)

def test_average_over_db_field(self):
"""Ensure that a field mapped to a db field with a different name
can have its average computed correctly.
"""
class UserVisit(Document):
num_visits = IntField(db_field='visits')

UserVisit.drop_collection()

UserVisit.objects.create(num_visits=20)
UserVisit.objects.create(num_visits=10)

self.assertEqual(UserVisit.objects.average('num_visits'), 15)

def test_embedded_average(self):
class Pay(EmbeddedDocument):
value = DecimalField()
Expand Down