Skip to content

Commit

Permalink
Restore renamed annotations in the original order without temporary r…
Browse files Browse the repository at this point in the history
…enaming

b70d239

Broke annotations. This was because even annotations that didn't
need renaming were renamed temporarily to `_new`. This made it
impossible to use those fields in an expression.

This fixes it by restoring the original behaviour, but fixing
the ordering problem by restoring the renamed annotations
in the original order.
  • Loading branch information
Photonios committed Oct 30, 2020
1 parent 9738c69 commit 96fd08b
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 8 deletions.
13 changes: 10 additions & 3 deletions psqlextra/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,19 @@ def annotate(self, **annotations):
the annotations are stored in an OrderedDict. Renaming only the
conflicts will mess up the order.
"""
fields = {field.name: field for field in self.model._meta.get_fields()}

new_annotations = OrderedDict()

renames = {}

for name, value in annotations.items():
new_name = "%s_new" % name
new_annotations[new_name] = value
renames[new_name] = name
if name in fields:
new_name = "%s_new" % name
new_annotations[new_name] = value
renames[new_name] = name
else:
new_annotations[name] = value

# run the base class's annotate function
result = super().annotate(**new_annotations)
Expand Down
15 changes: 11 additions & 4 deletions psqlextra/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,10 @@ def rename_annotations(self, annotations) -> None:
old name to the new name.
"""

# safety check only, make sure there are no renames
# left that cannot be mapped back to the original name
for old_name, new_name in annotations.items():
annotation = self.annotations.get(old_name)

if not annotation:
raise SuspiciousOperation(
(
Expand All @@ -52,13 +53,19 @@ def rename_annotations(self, annotations) -> None:
).format(old_name=old_name, new_name=new_name)
)

self.annotations[new_name] = annotation
del self.annotations[old_name]
# rebuild the annotations according to the original order
new_annotations = dict()
for old_name, annotation in self.annotations.items():
new_name = annotations.get(old_name)
new_annotations[new_name or old_name] = annotation

if self.annotation_select_mask:
if new_name and self.annotation_select_mask:
self.annotation_select_mask.discard(old_name)
self.annotation_select_mask.add(new_name)

self.annotations.clear()
self.annotations.update(new_annotations)

def add_fields(self, field_names: List[str], *args, **kwargs) -> bool:
"""Adds the given (model) fields to the select set.
Expand Down
15 changes: 15 additions & 0 deletions tests/fake_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,21 @@ def define_fake_partitioned_model(
return model


def get_fake_partitioned_model(
fields=None, partitioning_options={}, meta_options={}
):
"""Defines a fake partitioned model and creates it in the database."""

model = define_fake_partitioned_model(
fields, partitioning_options, meta_options
)

with connection.schema_editor() as schema_editor:
schema_editor.create_model(model)

return model


def get_fake_model(fields=None, model_base=PostgresModel, meta_options={}):
"""Defines a fake model and creates it in the database."""

Expand Down
22 changes: 21 additions & 1 deletion tests/test_query.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from django.db import models
from django.db.models import F
from django.db.models import Case, F, Q, Value, When

from psqlextra.expressions import HStoreRef
from psqlextra.fields import HStoreField
Expand Down Expand Up @@ -75,6 +75,26 @@ def test_query_annotate_rename_order():
assert list(qs.query.annotations.keys()) == ["value", "value_2"]


def test_query_annotate_in_expression():
"""Tests whether annotations can be used in expressions."""

model = get_fake_model({"name": models.CharField(max_length=10)})

model.objects.create(name="henk")

result = model.objects.annotate(
real_name=F("name"),
is_he_henk=Case(
When(Q(real_name="henk"), then=Value("really henk")),
default=Value("definitely not henk"),
output_field=models.CharField(),
),
).first()

assert result.real_name == "henk"
assert result.is_he_henk == "really henk"


def test_query_hstore_value_update_f_ref():
"""Tests whether F(..) expressions can be used in hstore values when
performing update queries."""
Expand Down

0 comments on commit 96fd08b

Please sign in to comment.