From 03e8eb61039e67dc197f17aba91e19daa5501341 Mon Sep 17 00:00:00 2001 From: Swen Kooij Date: Tue, 18 Jun 2024 13:36:49 +0200 Subject: [PATCH] Add additional tests for ON CONFLICT DO NOTHING duplicate rows filtering --- psqlextra/query.py | 17 +++++++++-------- tests/test_on_conflict_nothing.py | 24 +++++++++++++----------- 2 files changed, 22 insertions(+), 19 deletions(-) diff --git a/psqlextra/query.py b/psqlextra/query.py index c75ce9c..6a86f18 100644 --- a/psqlextra/query.py +++ b/psqlextra/query.py @@ -41,6 +41,14 @@ QuerySetBase = QuerySet +def peek_iterator(iterable): + try: + first = next(iterable) + except StopIteration: + return None + return list(chain([first], iterable)) + + class PostgresQuerySet(QuerySetBase, Generic[TModel]): """Adds support for PostgreSQL specifics.""" @@ -177,14 +185,7 @@ def bulk_insert( if rows is None: return [] - def peek(iterable): - try: - first = next(iterable) - except StopIteration: - return None - return list(chain([first], iterable)) - - rows = peek(iter(rows)) + rows = peek_iterator(iter(rows)) if not rows: return [] diff --git a/tests/test_on_conflict_nothing.py b/tests/test_on_conflict_nothing.py index eb3b8a3..92e74df 100644 --- a/tests/test_on_conflict_nothing.py +++ b/tests/test_on_conflict_nothing.py @@ -170,24 +170,26 @@ def test_on_conflict_nothing_foreign_key_by_id(): assert obj1.data == "some data" -def test_on_conflict_nothing_duplicate_rows(): +@pytest.mark.parametrize( + "rows,expected_row_count", + [ + ([dict(amount=1), dict(amount=1)], 1), + (iter([dict(amount=1), dict(amount=1)]), 1), + ((row for row in [dict(amount=1), dict(amount=1)]), 1), + ([], 0), + (iter([]), 0), + ((row for row in []), 0), + ], +) +def test_on_conflict_nothing_duplicate_rows(rows, expected_row_count): """Tests whether duplicate rows are filtered out when doing a insert NOTHING and no error is raised when the list of rows contains duplicates.""" model = get_fake_model({"amount": models.IntegerField(unique=True)}) - rows = [dict(amount=1), dict(amount=1)] - - inserted_rows = model.objects.on_conflict( - ["amount"], ConflictAction.NOTHING - ).bulk_insert(rows) - - assert len(inserted_rows) == 1 - - rows = iter([dict(amount=2), dict(amount=2)]) inserted_rows = model.objects.on_conflict( ["amount"], ConflictAction.NOTHING ).bulk_insert(rows) - assert len(inserted_rows) == 1 + assert len(inserted_rows) == expected_row_count