Skip to content

Commit

Permalink
Support for conditions in upserts
Browse files Browse the repository at this point in the history
  • Loading branch information
Photonios committed Feb 24, 2021
1 parent c11dbb5 commit 480714f
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 27 deletions.
41 changes: 21 additions & 20 deletions psqlextra/compiler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections.abc import Iterable

from django.core.exceptions import SuspiciousOperation
from django.db.models import Model
from django.db.models import Expression, Model
from django.db.models.fields.related import RelatedField
from django.db.models.sql.compiler import SQLInsertCompiler, SQLUpdateCompiler
from django.db.utils import ProgrammingError
Expand Down Expand Up @@ -150,30 +150,31 @@ def _rewrite_insert_on_conflict(
# for conflicts
conflict_target = self._build_conflict_target()
index_predicate = self.query.index_predicate
update_condition = self.query.conflict_update_condition

sql_template = (
"{insert} ON CONFLICT {conflict_target} DO {conflict_action}"
)
rewritten_sql = f"{sql} ON CONFLICT {conflict_target}"

if index_predicate:
sql_template = "{insert} ON CONFLICT {conflict_target} WHERE {index_predicate} DO {conflict_action}"
if isinstance(index_predicate, Expression):
expr_sql, expr_params = self.compile(index_predicate)
rewritten_sql += f" WHERE {expr_sql}"
params += tuple(expr_params)
else:
rewritten_sql += f" WHERE {index_predicate}"

rewritten_sql += f" DO {conflict_action}"

if conflict_action == "UPDATE":
sql_template += " SET {update_columns}"

sql_template += " RETURNING {returning}"

return (
sql_template.format(
insert=sql,
conflict_target=conflict_target,
conflict_action=conflict_action,
update_columns=update_columns,
returning=returning,
index_predicate=index_predicate,
),
params,
)
rewritten_sql += f" SET {update_columns}"

if update_condition:
expr_sql, expr_params = self.compile(update_condition)
rewritten_sql += f" WHERE {expr_sql}"
params += tuple(expr_params)

rewritten_sql += f" RETURNING {returning}"

return (rewritten_sql, params)

def _build_conflict_target(self):
"""Builds the `conflict_target` for the ON CONFLICT clause."""
Expand Down
43 changes: 36 additions & 7 deletions psqlextra/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from django.core.exceptions import SuspiciousOperation
from django.db import models, router
from django.db.models import Expression
from django.db.models.fields import NOT_PROVIDED

from .sql import PostgresInsertQuery, PostgresQuery
Expand All @@ -24,6 +25,7 @@ def __init__(self, model=None, query=None, using=None, hints=None):

self.conflict_target = None
self.conflict_action = None
self.conflict_update_condition = None
self.index_predicate = None

def annotate(self, **annotations):
Expand Down Expand Up @@ -80,7 +82,8 @@ def on_conflict(
self,
fields: ConflictTarget,
action: ConflictAction,
index_predicate: Optional[str] = None,
index_predicate: Optional[Union[Expression, str]] = None,
update_condition: Optional[Expression] = None,
):
"""Sets the action to take when conflicts arise when attempting to
insert/create a new row.
Expand All @@ -95,10 +98,14 @@ def on_conflict(
index_predicate:
The index predicate to satisfy an arbiter partial index (i.e. what partial index to use for checking
conflicts)
update_condition:
Only update if this SQL expression evaluates to true.
"""

self.conflict_target = fields
self.conflict_action = action
self.conflict_update_condition = update_condition
self.index_predicate = index_predicate

return self
Expand Down Expand Up @@ -250,8 +257,9 @@ def upsert(
self,
conflict_target: ConflictTarget,
fields: dict,
index_predicate: Optional[str] = None,
index_predicate: Optional[Union[Expression, str]] = None,
using: Optional[str] = None,
update_condition: Optional[Expression] = None,
) -> int:
"""Creates a new record or updates the existing one with the specified
data.
Expand All @@ -271,21 +279,28 @@ def upsert(
The name of the database connection to
use for this query.
update_condition:
Only update if this SQL expression evaluates to true.
Returns:
The primary key of the row that was created/updated.
"""

self.on_conflict(
conflict_target, ConflictAction.UPDATE, index_predicate
conflict_target,
ConflictAction.UPDATE,
index_predicate=index_predicate,
update_condition=update_condition,
)
return self.insert(**fields, using=using)

def upsert_and_get(
self,
conflict_target: ConflictTarget,
fields: dict,
index_predicate: Optional[str] = None,
index_predicate: Optional[Union[Expression, str]] = None,
using: Optional[str] = None,
update_condition: Optional[Expression] = None,
):
"""Creates a new record or updates the existing one with the specified
data and then gets the row.
Expand All @@ -305,23 +320,30 @@ def upsert_and_get(
The name of the database connection to
use for this query.
update_condition:
Only update if this SQL expression evaluates to true.
Returns:
The model instance representing the row
that was created/updated.
"""

self.on_conflict(
conflict_target, ConflictAction.UPDATE, index_predicate
conflict_target,
ConflictAction.UPDATE,
index_predicate=index_predicate,
update_condition=update_condition,
)
return self.insert_and_get(**fields, using=using)

def bulk_upsert(
self,
conflict_target: ConflictTarget,
rows: Iterable[Dict],
index_predicate: str = None,
index_predicate: Optional[Union[Expression, str]] = None,
return_model: bool = False,
using: Optional[str] = None,
update_condition: Optional[Expression] = None,
):
"""Creates a set of new records or updates the existing ones with the
specified data.
Expand All @@ -345,6 +367,9 @@ def bulk_upsert(
The name of the database connection to use
for this query.
update_condition:
Only update if this SQL expression evaluates to true.
Returns:
A list of either the dicts of the rows upserted, including the pk or
the models of the rows upserted
Expand All @@ -357,7 +382,10 @@ def is_empty(r):
return []

self.on_conflict(
conflict_target, ConflictAction.UPDATE, index_predicate
conflict_target,
ConflictAction.UPDATE,
index_predicate=index_predicate,
update_condition=update_condition,
)
return self.bulk_insert(rows, return_model, using=using)

Expand Down Expand Up @@ -425,6 +453,7 @@ def _build_insert_compiler(
query = PostgresInsertQuery(self.model)
query.conflict_action = self.conflict_action
query.conflict_target = self.conflict_target
query.conflict_update_condition = self.conflict_update_condition
query.index_predicate = self.index_predicate
query.values(objs, insert_fields, update_fields)

Expand Down
2 changes: 2 additions & 0 deletions psqlextra/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ def __init__(self, *args, **kwargs):

self.conflict_target = []
self.conflict_action = ConflictAction.UPDATE
self.conflict_update_condition = None
self.index_predicate = None

self.update_fields = []

Expand Down
2 changes: 2 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,5 @@
DJANGO_SETTINGS_MODULE=settings
testpaths=tests
addopts=-m "not benchmark"
filterwarnings =
ignore::UserWarning
30 changes: 30 additions & 0 deletions tests/test_upsert.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from django.db import models
from django.db.models.expressions import CombinedExpression, Value

from psqlextra.fields import HStoreField

Expand Down Expand Up @@ -76,6 +77,35 @@ def test_upsert_explicit_pk():
assert obj2.cookies == "second-boo"


def test_upsert_with_update_condition():
"""Tests that a custom expression can be passed as an update condition."""

model = get_fake_model(
{
"name": models.TextField(unique=True),
"priority": models.IntegerField(),
"active": models.BooleanField(),
}
)

obj1 = model.objects.create(name="joe", priority=1, active=False)

model.objects.upsert(
conflict_target=["name"],
update_condition=CombinedExpression(
model._meta.get_field("active").get_col(model._meta.db_table),
"=",
Value(True),
),
fields=dict(name="joe", priority=2, active=True),
)

obj1.refresh_from_db()

assert obj1.priority == 1
assert not obj1.active


def test_upsert_bulk():
"""Tests whether bulk_upsert works properly."""

Expand Down

0 comments on commit 480714f

Please sign in to comment.