From cb76f4f931614f19cbaac610df89d7b584e9b323 Mon Sep 17 00:00:00 2001 From: seroy Date: Mon, 20 May 2024 13:06:22 +0400 Subject: [PATCH] Fix `StopIteration` in deduplication rows code when `conflict_action == ConflictAction.NOTHING` and rows parameter is iterator or generator --- psqlextra/query.py | 14 +++++++++++--- tests/test_on_conflict_nothing.py | 17 ++++++++++++----- 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/psqlextra/query.py b/psqlextra/query.py index 65a20c50..c75ce9c4 100644 --- a/psqlextra/query.py +++ b/psqlextra/query.py @@ -174,11 +174,19 @@ def bulk_insert( A list of either the dicts of the rows inserted, including the pk or the models of the rows inserted with defaults for any fields not specified """ + if rows is None: + return [] + + def peek(iterable): + try: + first = next(iterable) + except StopIteration: + return None + return list(chain([first], iterable)) - def is_empty(r): - return all([False for _ in r]) + rows = peek(iter(rows)) - if not rows or is_empty(rows): + if not rows: return [] if not self.conflict_target and not self.conflict_action: diff --git a/tests/test_on_conflict_nothing.py b/tests/test_on_conflict_nothing.py index 78c4c5f4..eb3b8a3c 100644 --- a/tests/test_on_conflict_nothing.py +++ b/tests/test_on_conflict_nothing.py @@ -179,8 +179,15 @@ def test_on_conflict_nothing_duplicate_rows(): rows = [dict(amount=1), dict(amount=1)] - ( - model.objects.on_conflict( - ["amount"], ConflictAction.NOTHING - ).bulk_insert(rows) - ) + inserted_rows = model.objects.on_conflict( + ["amount"], ConflictAction.NOTHING + ).bulk_insert(rows) + + assert len(inserted_rows) == 1 + + rows = iter([dict(amount=2), dict(amount=2)]) + inserted_rows = model.objects.on_conflict( + ["amount"], ConflictAction.NOTHING + ).bulk_insert(rows) + + assert len(inserted_rows) == 1