From 96fd08bbaefc7a0afb73727eca997db586e6e2b5 Mon Sep 17 00:00:00 2001 From: Swen Kooij Date: Sat, 24 Oct 2020 13:53:13 +0300 Subject: [PATCH] Restore renamed annotations in the original order without temporary renaming https://github.com/SectorLabs/django-postgres-extra/commit/b70d2397dec4c2066c3de53a136f6f1db44675eb Broke annotations. This was because even annotations that didn't need renaming were renamed temporarily to `_new`. This made it impossible to use those fields in an expression. This fixes it by restoring the original behaviour, but fixing the ordering problem by restoring the renamed annotations in the original order. --- psqlextra/query.py | 13 ++++++++++--- psqlextra/sql.py | 15 +++++++++++---- tests/fake_model.py | 15 +++++++++++++++ tests/test_query.py | 22 +++++++++++++++++++++- 4 files changed, 57 insertions(+), 8 deletions(-) diff --git a/psqlextra/query.py b/psqlextra/query.py index d3454499..71198be4 100644 --- a/psqlextra/query.py +++ b/psqlextra/query.py @@ -40,12 +40,19 @@ def annotate(self, **annotations): the annotations are stored in an OrderedDict. Renaming only the conflicts will mess up the order. """ + fields = {field.name: field for field in self.model._meta.get_fields()} + new_annotations = OrderedDict() + renames = {} + for name, value in annotations.items(): - new_name = "%s_new" % name - new_annotations[new_name] = value - renames[new_name] = name + if name in fields: + new_name = "%s_new" % name + new_annotations[new_name] = value + renames[new_name] = name + else: + new_annotations[name] = value # run the base class's annotate function result = super().annotate(**new_annotations) diff --git a/psqlextra/sql.py b/psqlextra/sql.py index d27094ff..2bba9821 100644 --- a/psqlextra/sql.py +++ b/psqlextra/sql.py @@ -41,9 +41,10 @@ def rename_annotations(self, annotations) -> None: old name to the new name. """ + # safety check only, make sure there are no renames + # left that cannot be mapped back to the original name for old_name, new_name in annotations.items(): annotation = self.annotations.get(old_name) - if not annotation: raise SuspiciousOperation( ( @@ -52,13 +53,19 @@ def rename_annotations(self, annotations) -> None: ).format(old_name=old_name, new_name=new_name) ) - self.annotations[new_name] = annotation - del self.annotations[old_name] + # rebuild the annotations according to the original order + new_annotations = dict() + for old_name, annotation in self.annotations.items(): + new_name = annotations.get(old_name) + new_annotations[new_name or old_name] = annotation - if self.annotation_select_mask: + if new_name and self.annotation_select_mask: self.annotation_select_mask.discard(old_name) self.annotation_select_mask.add(new_name) + self.annotations.clear() + self.annotations.update(new_annotations) + def add_fields(self, field_names: List[str], *args, **kwargs) -> bool: """Adds the given (model) fields to the select set. diff --git a/tests/fake_model.py b/tests/fake_model.py index 21faa60a..36198944 100644 --- a/tests/fake_model.py +++ b/tests/fake_model.py @@ -89,6 +89,21 @@ def define_fake_partitioned_model( return model +def get_fake_partitioned_model( + fields=None, partitioning_options={}, meta_options={} +): + """Defines a fake partitioned model and creates it in the database.""" + + model = define_fake_partitioned_model( + fields, partitioning_options, meta_options + ) + + with connection.schema_editor() as schema_editor: + schema_editor.create_model(model) + + return model + + def get_fake_model(fields=None, model_base=PostgresModel, meta_options={}): """Defines a fake model and creates it in the database.""" diff --git a/tests/test_query.py b/tests/test_query.py index 35d0cfe7..e1496f51 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -1,5 +1,5 @@ from django.db import models -from django.db.models import F +from django.db.models import Case, F, Q, Value, When from psqlextra.expressions import HStoreRef from psqlextra.fields import HStoreField @@ -75,6 +75,26 @@ def test_query_annotate_rename_order(): assert list(qs.query.annotations.keys()) == ["value", "value_2"] +def test_query_annotate_in_expression(): + """Tests whether annotations can be used in expressions.""" + + model = get_fake_model({"name": models.CharField(max_length=10)}) + + model.objects.create(name="henk") + + result = model.objects.annotate( + real_name=F("name"), + is_he_henk=Case( + When(Q(real_name="henk"), then=Value("really henk")), + default=Value("definitely not henk"), + output_field=models.CharField(), + ), + ).first() + + assert result.real_name == "henk" + assert result.is_he_henk == "really henk" + + def test_query_hstore_value_update_f_ref(): """Tests whether F(..) expressions can be used in hstore values when performing update queries."""