diff --git a/psqlextra/compiler.py b/psqlextra/compiler.py index bf0f7e8b..52e40fbd 100644 --- a/psqlextra/compiler.py +++ b/psqlextra/compiler.py @@ -1,7 +1,7 @@ from collections.abc import Iterable from django.core.exceptions import SuspiciousOperation -from django.db.models import Model +from django.db.models import Expression, Model from django.db.models.fields.related import RelatedField from django.db.models.sql.compiler import SQLInsertCompiler, SQLUpdateCompiler from django.db.utils import ProgrammingError @@ -150,30 +150,31 @@ def _rewrite_insert_on_conflict( # for conflicts conflict_target = self._build_conflict_target() index_predicate = self.query.index_predicate + update_condition = self.query.conflict_update_condition - sql_template = ( - "{insert} ON CONFLICT {conflict_target} DO {conflict_action}" - ) + rewritten_sql = f"{sql} ON CONFLICT {conflict_target}" if index_predicate: - sql_template = "{insert} ON CONFLICT {conflict_target} WHERE {index_predicate} DO {conflict_action}" + if isinstance(index_predicate, Expression): + expr_sql, expr_params = self.compile(index_predicate) + rewritten_sql += f" WHERE {expr_sql}" + params += tuple(expr_params) + else: + rewritten_sql += f" WHERE {index_predicate}" + + rewritten_sql += f" DO {conflict_action}" if conflict_action == "UPDATE": - sql_template += " SET {update_columns}" - - sql_template += " RETURNING {returning}" - - return ( - sql_template.format( - insert=sql, - conflict_target=conflict_target, - conflict_action=conflict_action, - update_columns=update_columns, - returning=returning, - index_predicate=index_predicate, - ), - params, - ) + rewritten_sql += f" SET {update_columns}" + + if update_condition: + expr_sql, expr_params = self.compile(update_condition) + rewritten_sql += f" WHERE {expr_sql}" + params += tuple(expr_params) + + rewritten_sql += f" RETURNING {returning}" + + return (rewritten_sql, params) def _build_conflict_target(self): """Builds the `conflict_target` for the ON CONFLICT clause.""" diff --git a/psqlextra/query.py b/psqlextra/query.py index 71198be4..02eaadf5 100644 --- a/psqlextra/query.py +++ b/psqlextra/query.py @@ -4,6 +4,7 @@ from django.core.exceptions import SuspiciousOperation from django.db import models, router +from django.db.models import Expression from django.db.models.fields import NOT_PROVIDED from .sql import PostgresInsertQuery, PostgresQuery @@ -24,6 +25,7 @@ def __init__(self, model=None, query=None, using=None, hints=None): self.conflict_target = None self.conflict_action = None + self.conflict_update_condition = None self.index_predicate = None def annotate(self, **annotations): @@ -80,7 +82,8 @@ def on_conflict( self, fields: ConflictTarget, action: ConflictAction, - index_predicate: Optional[str] = None, + index_predicate: Optional[Union[Expression, str]] = None, + update_condition: Optional[Expression] = None, ): """Sets the action to take when conflicts arise when attempting to insert/create a new row. @@ -95,10 +98,14 @@ def on_conflict( index_predicate: The index predicate to satisfy an arbiter partial index (i.e. what partial index to use for checking conflicts) + + update_condition: + Only update if this SQL expression evaluates to true. """ self.conflict_target = fields self.conflict_action = action + self.conflict_update_condition = update_condition self.index_predicate = index_predicate return self @@ -250,8 +257,9 @@ def upsert( self, conflict_target: ConflictTarget, fields: dict, - index_predicate: Optional[str] = None, + index_predicate: Optional[Union[Expression, str]] = None, using: Optional[str] = None, + update_condition: Optional[Expression] = None, ) -> int: """Creates a new record or updates the existing one with the specified data. @@ -271,12 +279,18 @@ def upsert( The name of the database connection to use for this query. + update_condition: + Only update if this SQL expression evaluates to true. + Returns: The primary key of the row that was created/updated. """ self.on_conflict( - conflict_target, ConflictAction.UPDATE, index_predicate + conflict_target, + ConflictAction.UPDATE, + index_predicate=index_predicate, + update_condition=update_condition, ) return self.insert(**fields, using=using) @@ -284,8 +298,9 @@ def upsert_and_get( self, conflict_target: ConflictTarget, fields: dict, - index_predicate: Optional[str] = None, + index_predicate: Optional[Union[Expression, str]] = None, using: Optional[str] = None, + update_condition: Optional[Expression] = None, ): """Creates a new record or updates the existing one with the specified data and then gets the row. @@ -305,13 +320,19 @@ def upsert_and_get( The name of the database connection to use for this query. + update_condition: + Only update if this SQL expression evaluates to true. + Returns: The model instance representing the row that was created/updated. """ self.on_conflict( - conflict_target, ConflictAction.UPDATE, index_predicate + conflict_target, + ConflictAction.UPDATE, + index_predicate=index_predicate, + update_condition=update_condition, ) return self.insert_and_get(**fields, using=using) @@ -319,9 +340,10 @@ def bulk_upsert( self, conflict_target: ConflictTarget, rows: Iterable[Dict], - index_predicate: str = None, + index_predicate: Optional[Union[Expression, str]] = None, return_model: bool = False, using: Optional[str] = None, + update_condition: Optional[Expression] = None, ): """Creates a set of new records or updates the existing ones with the specified data. @@ -345,6 +367,9 @@ def bulk_upsert( The name of the database connection to use for this query. + update_condition: + Only update if this SQL expression evaluates to true. + Returns: A list of either the dicts of the rows upserted, including the pk or the models of the rows upserted @@ -357,7 +382,10 @@ def is_empty(r): return [] self.on_conflict( - conflict_target, ConflictAction.UPDATE, index_predicate + conflict_target, + ConflictAction.UPDATE, + index_predicate=index_predicate, + update_condition=update_condition, ) return self.bulk_insert(rows, return_model, using=using) @@ -425,6 +453,7 @@ def _build_insert_compiler( query = PostgresInsertQuery(self.model) query.conflict_action = self.conflict_action query.conflict_target = self.conflict_target + query.conflict_update_condition = self.conflict_update_condition query.index_predicate = self.index_predicate query.values(objs, insert_fields, update_fields) diff --git a/psqlextra/sql.py b/psqlextra/sql.py index 4e655ae9..9b11a558 100644 --- a/psqlextra/sql.py +++ b/psqlextra/sql.py @@ -145,6 +145,8 @@ def __init__(self, *args, **kwargs): self.conflict_target = [] self.conflict_action = ConflictAction.UPDATE + self.conflict_update_condition = None + self.index_predicate = None self.update_fields = [] diff --git a/pytest.ini b/pytest.ini index befe129e..b96a2c93 100644 --- a/pytest.ini +++ b/pytest.ini @@ -2,3 +2,5 @@ DJANGO_SETTINGS_MODULE=settings testpaths=tests addopts=-m "not benchmark" +filterwarnings = + ignore::UserWarning diff --git a/tests/test_upsert.py b/tests/test_upsert.py index 15e79be4..8cf94eb1 100644 --- a/tests/test_upsert.py +++ b/tests/test_upsert.py @@ -1,4 +1,5 @@ from django.db import models +from django.db.models.expressions import CombinedExpression, Value from psqlextra.fields import HStoreField @@ -76,6 +77,35 @@ def test_upsert_explicit_pk(): assert obj2.cookies == "second-boo" +def test_upsert_with_update_condition(): + """Tests that a custom expression can be passed as an update condition.""" + + model = get_fake_model( + { + "name": models.TextField(unique=True), + "priority": models.IntegerField(), + "active": models.BooleanField(), + } + ) + + obj1 = model.objects.create(name="joe", priority=1, active=False) + + model.objects.upsert( + conflict_target=["name"], + update_condition=CombinedExpression( + model._meta.get_field("active").get_col(model._meta.db_table), + "=", + Value(True), + ), + fields=dict(name="joe", priority=2, active=True), + ) + + obj1.refresh_from_db() + + assert obj1.priority == 1 + assert not obj1.active + + def test_upsert_bulk(): """Tests whether bulk_upsert works properly."""