From 39a9f8671e71b8198b676edd60d0eb6aae91a197 Mon Sep 17 00:00:00 2001 From: Swen Kooij Date: Wed, 24 Feb 2021 09:51:54 +0200 Subject: [PATCH] Apply converters to results return in upsert_and_get --- psqlextra/query.py | 32 ++++++++++++++++++++++++++++---- tests/test_upsert.py | 16 ++++++++++++++++ 2 files changed, 44 insertions(+), 4 deletions(-) diff --git a/psqlextra/query.py b/psqlextra/query.py index 02eaadf5..7369a0cc 100644 --- a/psqlextra/query.py +++ b/psqlextra/query.py @@ -3,7 +3,7 @@ from typing import Dict, Iterable, List, Optional, Tuple, Union from django.core.exceptions import SuspiciousOperation -from django.db import models, router +from django.db import connections, models, router from django.db.models import Expression from django.db.models.fields import NOT_PROVIDED @@ -389,14 +389,36 @@ def is_empty(r): ) return self.bulk_insert(rows, return_model, using=using) - def _create_model_instance(self, field_values, using: Optional[str] = None): + def _create_model_instance( + self, field_values: dict, using: str, apply_converters: bool = True + ): """Creates a new instance of the model with the specified field. Use this after the row was inserted into the database. The new instance will marked as "saved". """ - instance = self.model(**field_values) + converted_field_values = field_values.copy() + + if apply_converters: + connection = connections[using] + + for field in self.model._meta.local_concrete_fields: + if field.attname not in converted_field_values: + continue + + # converters can be defined on the field, or by + # the database back-end we're using + converters = field.get_db_converters( + connection + ) + connection.ops.get_db_converters(field) + + for converter in converters: + converted_field_values[field.attname] = converter( + converted_field_values[field.attname], field, connection + ) + + instance = self.model(**converted_field_values) instance._state.db = using instance._state.adding = False @@ -444,7 +466,9 @@ def _build_insert_compiler( ).format(index) ) - objs.append(self._create_model_instance(row, using)) + objs.append( + self._create_model_instance(row, using, apply_converters=False) + ) # get the fields to be used during update/insert insert_fields, update_fields = self._get_upsert_fields(first_row) diff --git a/tests/test_upsert.py b/tests/test_upsert.py index 8cf94eb1..b04e09ea 100644 --- a/tests/test_upsert.py +++ b/tests/test_upsert.py @@ -106,6 +106,22 @@ def test_upsert_with_update_condition(): assert not obj1.active +def test_upsert_and_get_applies_converters(): + """Tests that converters are properly applied when using upsert_and_get.""" + + class MyCustomField(models.TextField): + def from_db_value(self, value, expression, connection): + return value.replace("hello", "bye") + + model = get_fake_model({"title": MyCustomField(unique=True)}) + + obj = model.objects.upsert_and_get( + conflict_target=["title"], fields=dict(title="hello") + ) + + assert obj.title == "bye" + + def test_upsert_bulk(): """Tests whether bulk_upsert works properly."""