Skip to content

Commit

Permalink
Apply converters to results return in upsert_and_get
Browse files Browse the repository at this point in the history
  • Loading branch information
Photonios committed Feb 24, 2021
1 parent 480714f commit 39a9f86
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 4 deletions.
32 changes: 28 additions & 4 deletions psqlextra/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Dict, Iterable, List, Optional, Tuple, Union

from django.core.exceptions import SuspiciousOperation
from django.db import models, router
from django.db import connections, models, router
from django.db.models import Expression
from django.db.models.fields import NOT_PROVIDED

Expand Down Expand Up @@ -389,14 +389,36 @@ def is_empty(r):
)
return self.bulk_insert(rows, return_model, using=using)

def _create_model_instance(self, field_values, using: Optional[str] = None):
def _create_model_instance(
self, field_values: dict, using: str, apply_converters: bool = True
):
"""Creates a new instance of the model with the specified field.
Use this after the row was inserted into the database. The new
instance will marked as "saved".
"""

instance = self.model(**field_values)
converted_field_values = field_values.copy()

if apply_converters:
connection = connections[using]

for field in self.model._meta.local_concrete_fields:
if field.attname not in converted_field_values:
continue

# converters can be defined on the field, or by
# the database back-end we're using
converters = field.get_db_converters(
connection
) + connection.ops.get_db_converters(field)

for converter in converters:
converted_field_values[field.attname] = converter(
converted_field_values[field.attname], field, connection
)

instance = self.model(**converted_field_values)
instance._state.db = using
instance._state.adding = False

Expand Down Expand Up @@ -444,7 +466,9 @@ def _build_insert_compiler(
).format(index)
)

objs.append(self._create_model_instance(row, using))
objs.append(
self._create_model_instance(row, using, apply_converters=False)
)

# get the fields to be used during update/insert
insert_fields, update_fields = self._get_upsert_fields(first_row)
Expand Down
16 changes: 16 additions & 0 deletions tests/test_upsert.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,22 @@ def test_upsert_with_update_condition():
assert not obj1.active


def test_upsert_and_get_applies_converters():
"""Tests that converters are properly applied when using upsert_and_get."""

class MyCustomField(models.TextField):
def from_db_value(self, value, expression, connection):
return value.replace("hello", "bye")

model = get_fake_model({"title": MyCustomField(unique=True)})

obj = model.objects.upsert_and_get(
conflict_target=["title"], fields=dict(title="hello")
)

assert obj.title == "bye"


def test_upsert_bulk():
"""Tests whether bulk_upsert works properly."""

Expand Down

0 comments on commit 39a9f86

Please sign in to comment.