diff --git a/psqlextra/query.py b/psqlextra/query.py index d3454499..71198be4 100644 --- a/psqlextra/query.py +++ b/psqlextra/query.py @@ -40,12 +40,19 @@ def annotate(self, **annotations): the annotations are stored in an OrderedDict. Renaming only the conflicts will mess up the order. """ + fields = {field.name: field for field in self.model._meta.get_fields()} + new_annotations = OrderedDict() + renames = {} + for name, value in annotations.items(): - new_name = "%s_new" % name - new_annotations[new_name] = value - renames[new_name] = name + if name in fields: + new_name = "%s_new" % name + new_annotations[new_name] = value + renames[new_name] = name + else: + new_annotations[name] = value # run the base class's annotate function result = super().annotate(**new_annotations) diff --git a/psqlextra/sql.py b/psqlextra/sql.py index d27094ff..2bba9821 100644 --- a/psqlextra/sql.py +++ b/psqlextra/sql.py @@ -41,9 +41,10 @@ def rename_annotations(self, annotations) -> None: old name to the new name. """ + # safety check only, make sure there are no renames + # left that cannot be mapped back to the original name for old_name, new_name in annotations.items(): annotation = self.annotations.get(old_name) - if not annotation: raise SuspiciousOperation( ( @@ -52,13 +53,19 @@ def rename_annotations(self, annotations) -> None: ).format(old_name=old_name, new_name=new_name) ) - self.annotations[new_name] = annotation - del self.annotations[old_name] + # rebuild the annotations according to the original order + new_annotations = dict() + for old_name, annotation in self.annotations.items(): + new_name = annotations.get(old_name) + new_annotations[new_name or old_name] = annotation - if self.annotation_select_mask: + if new_name and self.annotation_select_mask: self.annotation_select_mask.discard(old_name) self.annotation_select_mask.add(new_name) + self.annotations.clear() + self.annotations.update(new_annotations) + def add_fields(self, field_names: List[str], *args, **kwargs) -> bool: """Adds the given (model) fields to the select set. diff --git a/tests/fake_model.py b/tests/fake_model.py index 21faa60a..36198944 100644 --- a/tests/fake_model.py +++ b/tests/fake_model.py @@ -89,6 +89,21 @@ def define_fake_partitioned_model( return model +def get_fake_partitioned_model( + fields=None, partitioning_options={}, meta_options={} +): + """Defines a fake partitioned model and creates it in the database.""" + + model = define_fake_partitioned_model( + fields, partitioning_options, meta_options + ) + + with connection.schema_editor() as schema_editor: + schema_editor.create_model(model) + + return model + + def get_fake_model(fields=None, model_base=PostgresModel, meta_options={}): """Defines a fake model and creates it in the database.""" diff --git a/tests/test_query.py b/tests/test_query.py index 35d0cfe7..e1496f51 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -1,5 +1,5 @@ from django.db import models -from django.db.models import F +from django.db.models import Case, F, Q, Value, When from psqlextra.expressions import HStoreRef from psqlextra.fields import HStoreField @@ -75,6 +75,26 @@ def test_query_annotate_rename_order(): assert list(qs.query.annotations.keys()) == ["value", "value_2"] +def test_query_annotate_in_expression(): + """Tests whether annotations can be used in expressions.""" + + model = get_fake_model({"name": models.CharField(max_length=10)}) + + model.objects.create(name="henk") + + result = model.objects.annotate( + real_name=F("name"), + is_he_henk=Case( + When(Q(real_name="henk"), then=Value("really henk")), + default=Value("definitely not henk"), + output_field=models.CharField(), + ), + ).first() + + assert result.real_name == "henk" + assert result.is_he_henk == "really henk" + + def test_query_hstore_value_update_f_ref(): """Tests whether F(..) expressions can be used in hstore values when performing update queries."""