Skip to content

Commit

Permalink
Add only() to RawQuerySet
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexHill committed Dec 13, 2013
1 parent ecd234b commit 3d2a2e2
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 9 deletions.
11 changes: 11 additions & 0 deletions django/db/models/query.py
Expand Up @@ -1628,6 +1628,17 @@ def model_fields(self):
self._model_fields[converter(column)] = field
return self._model_fields

def _clone(self):
return RawQuerySet(self.raw_query, model=self.model,
query=self.query.clone(self.query.using),
params=self.params, translations=self.translations,
using=self._db)

def only(self, *fields):
only_qs = self._clone()
only_qs.query.set_immediate_columns(fields)
return only_qs

def _prepare(self):
return self

Expand Down
69 changes: 60 additions & 9 deletions django/db/models/sql/query.py
Expand Up @@ -18,7 +18,7 @@
from django.db.models.aggregates import refs_aggregate
from django.db.models.expressions import ExpressionNode
from django.db.models.fields import FieldDoesNotExist
from django.db.models.query_utils import Q, InvalidQuery
from django.db.models.query_utils import Q
from django.db.models.related import PathInfo
from django.db.models.sql import aggregates as base_aggregates_module
from django.db.models.sql.constants import (QUERY_TERMS, ORDER_DIR, SINGLE,
Expand All @@ -37,21 +37,40 @@ class RawQuery(object):
A single raw SQL query
"""

SUBQUERY_ALIAS = 'sq'

def __init__(self, sql, using, params=None, pk_column=None):
self.params = params or ()
self.sql = sql
self._sql = sql
self.using = using
self.cursor = None
self.pk_column = pk_column
self.immediate_columns = set()

# Mirror some properties of a normal query so that
# the compiler can be used to process results.
self.low_mark, self.high_mark = 0, None # Used for offset/limit
self.extra_select = {}
self.aggregate_select = {}

@property
def sql(self):
"""
Wraps the raw SQL in an outer SELECT statement to select only
certain columns. Column resf requested by calling RawQuerySet.only().
"""
if not self.immediate_columns:
return self._sql
# Never defer the primary key - without it, it's impossible
# to construct model instances.
columns = list(self.immediate_columns | {self.pk_column})
return self._select_columns_sql(columns)

def clone(self, using):
return RawQuery(self.sql, using, params=self.params)
cloned_query = RawQuery(self._sql, using, params=self.params,
pk_column=self.pk_column)
cloned_query.set_immediate_columns(self.immediate_columns)
return cloned_query

def convert_values(self, value, field, connection):
"""Convert the database-returned value into a type that is consistent
Expand Down Expand Up @@ -88,13 +107,44 @@ def _execute_query(self):
self.cursor = connections[self.using].cursor()
self.cursor.execute(self.sql, self.params)

def set_immediate_columns(self, columns):
self.immediate_columns = set(columns)

def _select_columns_sql(self, columns, connection=None):
"""
Returns this query's raw SQL wrapped in another SELECT query which
selects only the columns specified in the parameter passed.
"""
connection = connection or connections[self.using]
select_format_str = ', '.join('%s' for _ in columns)
sql_query = 'SELECT %s FROM (%%s) %s' % (select_format_str,
self.SUBQUERY_ALIAS)
quoted_columns = [connection.ops.quote_name(col) for col in columns]
return sql_query % tuple(quoted_columns + [self._sql])

def as_nested_sql(self, connection):
if not self.pk_column:
return self.sql, self.params
quoted_column = connection.ops.quote_name(self.pk_column)
sql_params = (quoted_column, self.sql, quoted_column)
return ('SELECT sq.%s FROM (%s) sq WHERE %s IS NOT NULL' % sql_params,
self.params)
"""
Return this query as suitable for inclusion as a subquery. This
consists of the output of _select_columns_sql with an added
IS NOT NULL clause for each column selected.
"""
columns = []
if self.immediate_columns:
columns = list(self.immediate_columns)
elif self.pk_column:
columns = [self.pk_column]

if columns:
qn = connection.ops.quote_name

def not_null_clause(column):
return '%s.%s IS NOT NULL' % (self.SUBQUERY_ALIAS, qn(column))
not_nulls = ' AND '.join(not_null_clause(c) for c in columns)
sql = '%s WHERE %s' % (self._select_columns_sql(columns), not_nulls)
else:
sql = self._sql

return sql, self.params

def as_subquery_condition(self, alias, columns, qn):
connection = connections[self.using]
Expand All @@ -107,6 +157,7 @@ def as_subquery_condition(self, alias, columns, qn):
raise NotImplementedError(
"Can't use a raw query in a multi-column join")


class Query(object):
"""
A single SQL query.
Expand Down
14 changes: 14 additions & 0 deletions tests/raw_query/tests.py
Expand Up @@ -245,6 +245,20 @@ def test_subquery_count(self):
with self.assertNumQueries(2):
list(Book.objects.filter(author__in=list(raw_qs)))

def test_subquery_only(self):
raw_qs = Book.objects.raw('SELECT * FROM raw_query_book WHERE paperback')
self.assertQuerysetEqual(raw_qs.only('author_id'), raw_qs,
transform=lambda i: i, ordered=False)

with self.assertNumQueries(2):
raw_qs.only('author_id')[0].paperback

def test_subquery_only_in(self):
raw_qs = Book.objects.raw('SELECT * FROM raw_query_book WHERE paperback')
self.assertQuerysetEqual(Author.objects.filter(pk__in=raw_qs.only('author_id')),
Author.objects.filter(book__paperback=True).distinct(),
transform=lambda i: i, ordered=False)

def test_subquery_pk(self):
raw_qs = Author.objects.raw("SELECT * FROM raw_query_author WHERE id < 3")
qs = Author.objects.filter(id__lt=3)
Expand Down

0 comments on commit 3d2a2e2

Please sign in to comment.