
Commit 1ac8ffb

koenvo and Fokko authored
Improve upsert memory pressure (#1995)
### Summary

This PR updates the upsert logic to use batch processing. The main goal is to prevent out-of-memory (OOM) issues when updating large tables by avoiding loading all data at once.

**Note:** This has only been tested against the unit tests; no real-world datasets have been evaluated yet.

This PR partially depends on functionality introduced in [#1817](apache/iceberg#1817).

---

### Notes

- Duplicate detection across multiple batches is **not** possible with this approach.
- ~~All data is read sequentially, which may be slower than the parallel read used by `to_arrow`.~~ Fixed using the `concurrent_tasks` parameter.

---

### Performance Comparison

In setups with many small files, network and metadata overhead become the dominant factor. This impacts batch reading performance, as each file contributes relatively more overhead than payload. In the test setup used here, metadata access was the largest cost.

#### Using `to_arrow_batch_reader` (sequential):

- **Scan:** 9993.50 ms
- **To list:** 19811.09 ms

#### Using `to_arrow` (parallel):

- **Scan:** 10607.88 ms

---------

Co-authored-by: Fokko Driesprong <fokko@apache.org>
1 parent 84c91f0 commit 1ac8ffb

File tree

- pyiceberg/io/pyarrow.py
- pyiceberg/table/__init__.py

2 files changed (+75, -67 lines)

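Before the per-file diffs, here is a minimal sketch of the batch-reading idea described in the summary. This is not code from the PR; the catalog and table names are illustrative placeholders.

```python
import pyarrow as pa
from pyiceberg.catalog import load_catalog

# Illustrative names: any configured catalog and existing table would do.
catalog = load_catalog("default")
tbl = catalog.load_table("db.events")
scan = tbl.scan()

# to_arrow() materializes every matching row into a single in-memory table,
# which is what can trigger OOM on very large tables:
# full_table = scan.to_arrow()

# to_arrow_batch_reader() streams RecordBatches instead, so each batch can be
# processed and released before the next one is read.
for batch in scan.to_arrow_batch_reader():
    rows = pa.Table.from_batches([batch])
    print(rows.num_rows)  # per-batch work goes here
```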

pyiceberg/io/pyarrow.py

Lines changed: 38 additions & 46 deletions
@@ -25,7 +25,6 @@
 
 from __future__ import annotations
 
-import concurrent.futures
 import fnmatch
 import functools
 import itertools
@@ -36,7 +35,6 @@
 import uuid
 import warnings
 from abc import ABC, abstractmethod
-from concurrent.futures import Future
 from copy import copy
 from dataclasses import dataclass
 from enum import Enum
@@ -70,7 +68,6 @@
     FileSystem,
     FileType,
 )
-from sortedcontainers import SortedList
 
 from pyiceberg.conversions import to_bytes
 from pyiceberg.exceptions import ResolveError
@@ -1586,47 +1583,20 @@ def to_table(self, tasks: Iterable[FileScanTask]) -> pa.Table:
             ResolveError: When a required field cannot be found in the file
             ValueError: When a field type in the file cannot be projected to the schema type
         """
-        deletes_per_file = _read_all_delete_files(self._io, tasks)
-        executor = ExecutorFactory.get_or_create()
-
-        def _table_from_scan_task(task: FileScanTask) -> pa.Table:
-            batches = list(self._record_batches_from_scan_tasks_and_deletes([task], deletes_per_file))
-            if len(batches) > 0:
-                return pa.Table.from_batches(batches)
-            else:
-                return None
-
-        futures = [
-            executor.submit(
-                _table_from_scan_task,
-                task,
-            )
-            for task in tasks
-        ]
-        total_row_count = 0
-        # for consistent ordering, we need to maintain future order
-        futures_index = {f: i for i, f in enumerate(futures)}
-        completed_futures: SortedList[Future[pa.Table]] = SortedList(iterable=[], key=lambda f: futures_index[f])
-        for future in concurrent.futures.as_completed(futures):
-            completed_futures.add(future)
-            if table_result := future.result():
-                total_row_count += len(table_result)
-            # stop early if limit is satisfied
-            if self._limit is not None and total_row_count >= self._limit:
-                break
-
-        # by now, we've either completed all tasks or satisfied the limit
-        if self._limit is not None:
-            _ = [f.cancel() for f in futures if not f.done()]
-
-        tables = [f.result() for f in completed_futures if f.result()]
-
         arrow_schema = schema_to_pyarrow(self._projected_schema, include_field_ids=False)
 
-        if len(tables) < 1:
-            return pa.Table.from_batches([], schema=arrow_schema)
-
-        result = pa.concat_tables(tables, promote_options="permissive")
+        batches = self.to_record_batches(tasks)
+        try:
+            first_batch = next(batches)
+        except StopIteration:
+            # Empty
+            return arrow_schema.empty_table()
+
+        # Note: cannot use pa.Table.from_batches(itertools.chain([first_batch], batches)))
+        # as different batches can use different schema's (due to large_ types)
+        result = pa.concat_tables(
+            (pa.Table.from_batches([batch]) for batch in itertools.chain([first_batch], batches)), promote_options="permissive"
+        )
 
         if property_as_bool(self._io.properties, PYARROW_USE_LARGE_TYPES_ON_READ, False):
             deprecation_message(
@@ -1636,9 +1606,6 @@ def _table_from_scan_task(task: FileScanTask) -> pa.Table:
             )
             result = result.cast(arrow_schema)
 
-        if self._limit is not None:
-            return result.slice(0, self._limit)
-
         return result
 
     def to_record_batches(self, tasks: Iterable[FileScanTask]) -> Iterator[pa.RecordBatch]:
@@ -1660,7 +1627,32 @@ def to_record_batches(self, tasks: Iterable[FileScanTask]) -> Iterator[pa.RecordBatch]:
             ValueError: When a field type in the file cannot be projected to the schema type
         """
         deletes_per_file = _read_all_delete_files(self._io, tasks)
-        return self._record_batches_from_scan_tasks_and_deletes(tasks, deletes_per_file)
+
+        total_row_count = 0
+        executor = ExecutorFactory.get_or_create()
+
+        def batches_for_task(task: FileScanTask) -> List[pa.RecordBatch]:
+            # Materialize the iterator here to ensure execution happens within the executor.
+            # Otherwise, the iterator would be lazily consumed later (in the main thread),
+            # defeating the purpose of using executor.map.
+            return list(self._record_batches_from_scan_tasks_and_deletes([task], deletes_per_file))
+
+        limit_reached = False
+        for batches in executor.map(batches_for_task, tasks):
+            for batch in batches:
+                current_batch_size = len(batch)
+                if self._limit is not None and total_row_count + current_batch_size >= self._limit:
+                    yield batch.slice(0, self._limit - total_row_count)
+
+                    limit_reached = True
+                    break
+                else:
+                    yield batch
+                    total_row_count += current_batch_size
+
+            if limit_reached:
+                # This break will also cancel all running tasks in the executor
+                break
 
     def _record_batches_from_scan_tasks_and_deletes(
         self, tasks: Iterable[FileScanTask], deletes_per_file: Dict[str, List[ChunkedArray]]
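The `pa.concat_tables(..., promote_options="permissive")` pattern in the hunk above exists because, as the inline note says, batches from different files may disagree on `string` versus `large_string`. A small standalone PyArrow illustration of that behaviour (not PR code; the column name is made up):

```python
import pyarrow as pa

# Two batches whose schemas differ only in string vs. large_string.
small = pa.RecordBatch.from_arrays([pa.array(["a"], type=pa.string())], names=["city"])
large = pa.RecordBatch.from_arrays([pa.array(["b"], type=pa.large_string())], names=["city"])

# pa.Table.from_batches([small, large]) raises, because every batch passed to it
# must share exactly one schema.

# Wrapping each batch in its own single-batch table and concatenating with
# promote_options="permissive" lets Arrow unify the two schemas instead.
result = pa.concat_tables(
    (pa.Table.from_batches([b]) for b in (small, large)),
    promote_options="permissive",
)
print(result.schema)  # the column is promoted to a common type (large_string)
```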

pyiceberg/table/__init__.py

Lines changed: 37 additions & 21 deletions
@@ -773,39 +773,55 @@ def upsert(
         matched_predicate = upsert_util.create_match_filter(df, join_cols)
 
         # We must use Transaction.table_metadata for the scan. This includes all uncommitted - but relevant - changes.
-        matched_iceberg_table = DataScan(
+        matched_iceberg_record_batches = DataScan(
             table_metadata=self.table_metadata,
             io=self._table.io,
             row_filter=matched_predicate,
             case_sensitive=case_sensitive,
-        ).to_arrow()
+        ).to_arrow_batch_reader()
 
-        update_row_cnt = 0
-        insert_row_cnt = 0
+        batches_to_overwrite = []
+        overwrite_predicates = []
+        rows_to_insert = df
 
-        if when_matched_update_all:
-            # function get_rows_to_update is doing a check on non-key columns to see if any of the values have actually changed
-            # we don't want to do just a blanket overwrite for matched rows if the actual non-key column data hasn't changed
-            # this extra step avoids unnecessary IO and writes
-            rows_to_update = upsert_util.get_rows_to_update(df, matched_iceberg_table, join_cols)
+        for batch in matched_iceberg_record_batches:
+            rows = pa.Table.from_batches([batch])
 
-            update_row_cnt = len(rows_to_update)
+            if when_matched_update_all:
+                # function get_rows_to_update is doing a check on non-key columns to see if any of the values have actually changed
+                # we don't want to do just a blanket overwrite for matched rows if the actual non-key column data hasn't changed
+                # this extra step avoids unnecessary IO and writes
+                rows_to_update = upsert_util.get_rows_to_update(df, rows, join_cols)
 
-            if len(rows_to_update) > 0:
-                # build the match predicate filter
-                overwrite_mask_predicate = upsert_util.create_match_filter(rows_to_update, join_cols)
+                if len(rows_to_update) > 0:
+                    # build the match predicate filter
+                    overwrite_mask_predicate = upsert_util.create_match_filter(rows_to_update, join_cols)
 
-                self.overwrite(rows_to_update, overwrite_filter=overwrite_mask_predicate)
+                    batches_to_overwrite.append(rows_to_update)
+                    overwrite_predicates.append(overwrite_mask_predicate)
 
-        if when_not_matched_insert_all:
-            expr_match = upsert_util.create_match_filter(matched_iceberg_table, join_cols)
-            expr_match_bound = bind(self.table_metadata.schema(), expr_match, case_sensitive=case_sensitive)
-            expr_match_arrow = expression_to_pyarrow(expr_match_bound)
-            rows_to_insert = df.filter(~expr_match_arrow)
+            if when_not_matched_insert_all:
+                expr_match = upsert_util.create_match_filter(rows, join_cols)
+                expr_match_bound = bind(self.table_metadata.schema(), expr_match, case_sensitive=case_sensitive)
+                expr_match_arrow = expression_to_pyarrow(expr_match_bound)
 
-            insert_row_cnt = len(rows_to_insert)
+                # Filter rows per batch.
+                rows_to_insert = rows_to_insert.filter(~expr_match_arrow)
 
-            if insert_row_cnt > 0:
+        update_row_cnt = 0
+        insert_row_cnt = 0
+
+        if batches_to_overwrite:
+            rows_to_update = pa.concat_tables(batches_to_overwrite)
+            update_row_cnt = len(rows_to_update)
+            self.overwrite(
+                rows_to_update,
+                overwrite_filter=Or(*overwrite_predicates) if len(overwrite_predicates) > 1 else overwrite_predicates[0],
+            )
+
+        if when_not_matched_insert_all:
+            insert_row_cnt = len(rows_to_insert)
+            if rows_to_insert:
                 self.append(rows_to_insert)
 
         return UpsertResult(rows_updated=update_row_cnt, rows_inserted=insert_row_cnt)
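For context, a sketch of how this transaction-level code path is typically reached from the public API. The catalog, table identifier, and data below are illustrative, and the keyword arguments shown are the ones visible in the diff; this is how a caller would exercise the new batched upsert, not new functionality added by the commit.

```python
import pyarrow as pa
from pyiceberg.catalog import load_catalog

catalog = load_catalog("default")          # illustrative catalog name
tbl = catalog.load_table("db.customers")   # illustrative table identifier

df = pa.table({"id": [1, 2, 3], "name": ["a", "b", "c"]})

# upsert() matches on join_cols, overwrites matched rows whose non-key columns
# actually changed, and appends rows with no match. With this commit, the
# matched rows are scanned batch by batch instead of as one large Arrow table.
result = tbl.upsert(
    df,
    join_cols=["id"],
    when_matched_update_all=True,
    when_not_matched_insert_all=True,
)
print(result.rows_updated, result.rows_inserted)
```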
