Skip to content

Commit

Permalink
Added support for aggregating and annotating over a subset of GFK'd
Browse files Browse the repository at this point in the history
items (i.e. today's most commented on blog entries)
  • Loading branch information
coleifer committed May 31, 2010
1 parent a65e3ce commit b41d9be
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 18 deletions.
4 changes: 3 additions & 1 deletion generic_aggregation/tests/models.py
@@ -1,3 +1,5 @@
import datetime

from django.contrib.contenttypes.generic import GenericForeignKey
from django.contrib.contenttypes.models import ContentType
from django.db import models
Expand All @@ -11,7 +13,7 @@ def __unicode__(self):

class Rating(models.Model):
rating = models.IntegerField()
created = models.DateTimeField(auto_now_add=True)
created = models.DateTimeField(default=datetime.datetime.now)
object_id = models.IntegerField()
content_type = models.ForeignKey(ContentType)
content_object = GenericForeignKey(ct_field='content_type', fk_field='object_id')
Expand Down
37 changes: 29 additions & 8 deletions generic_aggregation/tests/tests.py
Expand Up @@ -17,7 +17,7 @@ def setUp(self):
Rating.objects.create(content_object=self.apple, rating=5)
Rating.objects.create(content_object=self.apple, rating=3)
Rating.objects.create(content_object=self.apple, rating=1, created=dt)
Rating.objects.create(content_object=self.apple, rating=7, created=dt)
Rating.objects.create(content_object=self.apple, rating=3, created=dt)

Rating.objects.create(content_object=self.orange, rating=4)
Rating.objects.create(content_object=self.orange, rating=3)
Expand All @@ -38,14 +38,14 @@ def test_annotation(self):
annotated_qs = generic_annotate(Food.objects.all(), Rating.content_object, 'rating', models.Sum)
self.assertEqual(annotated_qs.count(), 2)

food_a, food_b = annotated_qs

self.assertEqual(food_a.score, 16)
self.assertEqual(food_a.name, 'apple')
food_b, food_a = annotated_qs

self.assertEqual(food_b.score, 15)
self.assertEqual(food_b.name, 'orange')

self.assertEqual(food_a.score, 12)
self.assertEqual(food_a.name, 'apple')

annotated_qs = generic_annotate(Food.objects.all(), Rating.content_object, 'rating', models.Avg)
self.assertEqual(annotated_qs.count(), 2)

Expand All @@ -54,7 +54,7 @@ def test_annotation(self):
self.assertEqual(food_b.score, 5)
self.assertEqual(food_b.name, 'orange')

self.assertEqual(food_a.score, 4)
self.assertEqual(food_a.score, 3)
self.assertEqual(food_a.name, 'apple')

def test_aggregation(self):
Expand All @@ -64,7 +64,7 @@ def test_aggregation(self):

# total of ratings out there for all foods
aggregated = generic_aggregate(Food.objects.all(), Rating.content_object, 'rating', models.Sum)
self.assertEqual(aggregated, 31)
self.assertEqual(aggregated, 27)

# (showing the use of filters and inner query)

Expand All @@ -76,8 +76,29 @@ def test_aggregation(self):

# avg for apple
aggregated = generic_aggregate(Food.objects.filter(name='apple'), Rating.content_object, 'rating', models.Avg)
self.assertEqual(aggregated, 4)
self.assertEqual(aggregated, 3)

# avg for orange
aggregated = generic_aggregate(Food.objects.filter(name='orange'), Rating.content_object, 'rating', models.Avg)
self.assertEqual(aggregated, 5)

def test_subset_annotation(self):
todays_ratings = Rating.objects.filter(created__gte=datetime.date.today())
annotated_qs = generic_annotate(Food.objects.all(), Rating.content_object, 'rating', models.Sum, todays_ratings)
self.assertEqual(annotated_qs.count(), 2)

food_a, food_b = annotated_qs

self.assertEqual(food_a.score, 8)
self.assertEqual(food_a.name, 'apple')

self.assertEqual(food_b.score, 7)
self.assertEqual(food_b.name, 'orange')

def test_subset_aggregation(self):
todays_ratings = Rating.objects.filter(created__gte=datetime.date.today())
aggregated = generic_aggregate(Food.objects.all(), Rating.content_object, 'rating', models.Sum, todays_ratings)
self.assertEqual(aggregated, 15)

aggregated = generic_aggregate(Food.objects.all(), Rating.content_object, 'rating', models.Count, todays_ratings)
self.assertEqual(aggregated, 4)
46 changes: 37 additions & 9 deletions generic_aggregation/utils.py
@@ -1,7 +1,8 @@
from django.contrib.contenttypes.models import ContentType
from django.db import connection, models

def generic_annotate(queryset, gfk_field, aggregate_field, aggregator=models.Sum, desc=True):
def generic_annotate(queryset, gfk_field, aggregate_field, aggregator=models.Sum,
generic_queryset=None, desc=True):
ordering = desc and '-score' or 'score'
content_type = ContentType.objects.get_for_model(queryset.model)

Expand All @@ -27,19 +28,34 @@ def generic_annotate(queryset, gfk_field, aggregate_field, aggregator=models.Sum
%s=%s.%s
""" % params

queryset = queryset.extra(select={
'score': extra
},
order_by=[ordering])
if generic_queryset is not None:
inner_query, inner_query_params = generic_queryset.values_list('pk').query.as_sql()

inner_params = (
qn(generic_queryset.model._meta.db_table),
qn(generic_queryset.model._meta.pk.name),
)
inner_start = ' AND %s.%s IN (' % inner_params
inner_end = ')'
extra = extra + inner_start + inner_query + inner_end
else:
inner_query_params = []

queryset = queryset.extra(
select={'score': extra},
select_params=inner_query_params,
order_by=[ordering]
)

return queryset


def generic_aggregate(queryset, gfk_field, aggregate_field, aggregator=models.Sum):
def generic_aggregate(queryset, gfk_field, aggregate_field, aggregator=models.Sum,
generic_queryset=None):
content_type = ContentType.objects.get_for_model(queryset.model)

queryset = queryset.values_list('pk') # just the pks
inner_query, inner_params = queryset.query.as_nested_sql()
query, query_params = queryset.query.as_nested_sql()

qn = connection.ops.quote_name

Expand All @@ -63,12 +79,24 @@ def generic_aggregate(queryset, gfk_field, aggregate_field, aggregator=models.Su

query_end = ")"

if generic_queryset is not None:
inner_query, inner_query_params = generic_queryset.values_list('pk').query.as_sql()

query_params += inner_query_params

inner_params = (
qn(generic_queryset.model._meta.pk.name),
)
inner_start = ' AND %s IN (' % inner_params
inner_end = ')'
query_end = query_end + inner_start + inner_query + inner_end

# pass in the inner_query unmodified as we will use the cursor to handle
# quoting the inner parameters correctly
query = query_start + inner_query + query_end
query = query_start + query + query_end

cursor = connection.cursor()
cursor.execute(query, inner_params)
cursor.execute(query, query_params)
row = cursor.fetchone()

return row[0]

0 comments on commit b41d9be

Please sign in to comment.