From 8621453c57c2a77faec6d9028003211574cf7bad Mon Sep 17 00:00:00 2001 From: xiaohongbo Date: Tue, 19 May 2026 14:15:55 +0800 Subject: [PATCH 1/3] [python] Support partial column write for PK tables --- .../pypaimon/tests/write/table_write_test.py | 53 +++++++++++++++++++ .../pypaimon/write/file_store_write.py | 3 +- .../pypaimon/write/row_key_extractor.py | 20 ++----- paimon-python/pypaimon/write/table_write.py | 16 ++++++ .../write/writer/key_value_data_writer.py | 24 +++++++-- 5 files changed, 95 insertions(+), 21 deletions(-) diff --git a/paimon-python/pypaimon/tests/write/table_write_test.py b/paimon-python/pypaimon/tests/write/table_write_test.py index 96abfb73aeec..48d2f5037971 100644 --- a/paimon-python/pypaimon/tests/write/table_write_test.py +++ b/paimon-python/pypaimon/tests/write/table_write_test.py @@ -471,3 +471,56 @@ def test_rolling(self): self.assertGreater(len(writer.committed_files), 0) if writer.pending_data is not None: self.assertLessEqual(writer.pending_data.nbytes, target) + + def test_pk_partial_column_write(self): + schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt'], + primary_keys=['user_id', 'dt'], + options={'bucket': '2', 'merge-engine': 'partial-update'}) + self.catalog.create_table('default.test_pk_partial_write', schema, False) + table = self.catalog.get_table('default.test_pk_partial_write') + + write_builder = table.new_stream_write_builder() + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + + # First write: full columns + data1 = pa.Table.from_pydict({ + 'user_id': [1, 2, 3], + 'item_id': [1001, 1002, 1003], + 'behavior': ['a', 'b', 'c'], + 'dt': ['p1', 'p1', 'p2'], + }, schema=self.pk_pa_schema) + table_write.write_arrow(data1) + table_commit.commit(table_write.prepare_commit(0), 0) + + # Second write: partial columns (only update item_id) + table_write.with_write_type(['user_id', 'dt', 'item_id']) + partial_schema = pa.schema([ + pa.field('user_id', pa.int32(), nullable=False), + pa.field('dt', pa.string(), nullable=False), + ('item_id', pa.int64()), + ]) + data2 = pa.Table.from_pydict({ + 'user_id': [1, 2], + 'dt': ['p1', 'p1'], + 'item_id': [9001, 9002], + }, schema=partial_schema) + table_write.write_arrow(data2) + table_commit.commit(table_write.prepare_commit(1), 1) + table_write.close() + table_commit.close() + + # Read back — PK merge-on-read deduplicates by key (last-write-wins) + read_builder = table.new_read_builder() + table_read = read_builder.new_read() + splits = read_builder.new_scan().plan().splits() + actual = table_read.to_arrow(splits).sort_by('user_id') + self.assertEqual(actual.num_rows, 3) + # user_id=1,2 overwritten by partial write: item_id updated, behavior=null + row1 = actual.filter(pa.compute.equal(actual['user_id'], 1)) + self.assertEqual(row1.column('item_id').to_pylist(), [9001]) + self.assertEqual(row1.column('behavior').to_pylist(), [None]) + # user_id=3 unchanged + row3 = actual.filter(pa.compute.equal(actual['user_id'], 3)) + self.assertEqual(row3.column('item_id').to_pylist(), [1003]) + self.assertEqual(row3.column('behavior').to_pylist(), ['c']) diff --git a/paimon-python/pypaimon/write/file_store_write.py b/paimon-python/pypaimon/write/file_store_write.py index c33fca3792eb..8fc16a11f79e 100644 --- a/paimon-python/pypaimon/write/file_store_write.py +++ b/paimon-python/pypaimon/write/file_store_write.py @@ -78,7 +78,8 @@ def max_seq_number(): partition=partition, bucket=bucket, max_seq_number=max_seq_number(), - options=options) + options=options, + write_cols=self.write_cols) else: seq_number = 0 if self.table.bucket_mode() == BucketMode.BUCKET_UNAWARE else max_seq_number() return AppendOnlyDataWriter( diff --git a/paimon-python/pypaimon/write/row_key_extractor.py b/paimon-python/pypaimon/write/row_key_extractor.py index dad93bf1eda7..2a3d73140bc1 100644 --- a/paimon-python/pypaimon/write/row_key_extractor.py +++ b/paimon-python/pypaimon/write/row_key_extractor.py @@ -83,24 +83,17 @@ class RowKeyExtractor(ABC): def __init__(self, table_schema: TableSchema): self.table_schema = table_schema - self.partition_indices = self._get_field_indices(table_schema.partition_keys) def extract_partition_bucket_batch(self, data: pa.RecordBatch) -> Tuple[List[Tuple], List[int]]: partitions = self._extract_partitions_batch(data) buckets = self._extract_buckets_batch(data) return partitions, buckets - def _get_field_indices(self, field_names: List[str]) -> List[int]: - if not field_names: - return [] - field_map = {field.name: i for i, field in enumerate(self.table_schema.fields)} - return [field_map[name] for name in field_names if name in field_map] - def _extract_partitions_batch(self, data: pa.RecordBatch) -> List[Tuple]: - if not self.partition_indices: + if not self.table_schema.partition_keys: return [() for _ in range(data.num_rows)] - partition_columns = [data.column(i) for i in self.partition_indices] + partition_columns = [data.column(name) for name in self.table_schema.partition_keys] partitions = [] for row_idx in range(data.num_rows): @@ -124,15 +117,11 @@ def __init__(self, table_schema: TableSchema): if self.num_buckets <= 0: raise ValueError(f"Fixed bucket mode requires bucket > 0, got {self.num_buckets}") - # Bucket-key resolution lives on TableSchema (mirrors Java - # ``TableSchema.bucketKeys()`` / ``logicalBucketKeyType()``); reuse - # it so any reader path that walks the same logic stays in sync. self.bucket_keys = table_schema.bucket_keys - self.bucket_key_indices = self._get_field_indices(self.bucket_keys) self._bucket_key_fields = table_schema.logical_bucket_key_fields def _extract_buckets_batch(self, data: pa.RecordBatch) -> List[int]: - columns = [data.column(i) for i in self.bucket_key_indices] + columns = [data.column(name) for name in self.bucket_keys] return [ _bucket_from_hash( self._binary_row_hash_code(tuple(col[row_idx].as_py() for col in columns)), @@ -287,7 +276,6 @@ def __init__(self, table_schema: 'TableSchema'): pk for pk in table_schema.primary_keys if pk not in table_schema.partition_keys ] - self.bucket_key_indices = self._get_field_indices(self.bucket_keys) field_map = {f.name: f for f in table_schema.fields} self._bucket_key_fields = [ field_map[name] for name in self.bucket_keys if name in field_map @@ -295,7 +283,7 @@ def __init__(self, table_schema: 'TableSchema'): def _extract_buckets_batch(self, data: pa.RecordBatch) -> List[int]: partitions = self._extract_partitions_batch(data) - columns = [data.column(i) for i in self.bucket_key_indices] + columns = [data.column(name) for name in self.bucket_keys] buckets = [] for row_idx in range(data.num_rows): key_hash = _hash_bytes_by_words( diff --git a/paimon-python/pypaimon/write/table_write.py b/paimon-python/pypaimon/write/table_write.py index 80ef5a3572bb..1ce39ea29b7c 100644 --- a/paimon-python/pypaimon/write/table_write.py +++ b/paimon-python/pypaimon/write/table_write.py @@ -66,9 +66,25 @@ def with_write_type(self, write_cols: List[str]): for col in write_cols: if col not in self.table_pyarrow_schema.names: raise ValueError(f"Column {col} is not in table schema.") + # Partition keys are always needed for routing + missing_partitions = [pk for pk in self.table.partition_keys if pk not in write_cols] + if missing_partitions: + raise ValueError( + f"Partition key columns {missing_partitions} must be included in write_cols." + ) + if self.table.is_primary_key_table: + missing_pks = [pk for pk in self.table.trimmed_primary_keys if pk not in write_cols] + if missing_pks: + raise ValueError( + f"Primary key columns {missing_pks} must be included in write_cols " + f"for partial column write on PK tables." + ) if len(write_cols) == len(self.table_pyarrow_schema.names): write_cols = None self.file_store_write.write_cols = write_cols + # Propagate to existing writers so partial-column logic takes effect + for writer in self.file_store_write.data_writers.values(): + writer.write_cols = write_cols return self def write_ray( diff --git a/paimon-python/pypaimon/write/writer/key_value_data_writer.py b/paimon-python/pypaimon/write/writer/key_value_data_writer.py index 6c6f292f575c..90b850e26aef 100644 --- a/paimon-python/pypaimon/write/writer/key_value_data_writer.py +++ b/paimon-python/pypaimon/write/writer/key_value_data_writer.py @@ -18,6 +18,7 @@ import pyarrow as pa import pyarrow.compute as pc +from pypaimon.schema.data_types import PyarrowFieldParser from pypaimon.write.writer.data_writer import DataWriter @@ -33,7 +34,11 @@ def _merge_data(self, existing_data: pa.Table, new_data: pa.Table) -> pa.Table: return self._sort_by_primary_key(combined) def _add_system_fields(self, data: pa.RecordBatch) -> pa.RecordBatch: - """Add system fields: _KEY_{pk_key}, _SEQUENCE_NUMBER, _VALUE_KIND.""" + """Add system fields: _KEY_{pk_key}, _SEQUENCE_NUMBER, _VALUE_KIND. + + When write_cols is set (partial column write), missing value columns + are filled with null arrays so the output KV file has the full schema. + """ num_rows = data.num_rows new_arrays = [] @@ -55,9 +60,20 @@ def _add_system_fields(self, data: pa.RecordBatch) -> pa.RecordBatch: new_arrays.append(value_kind_column) new_fields.append(pa.field('_VALUE_KIND', pa.int8(), nullable=False)) - for i in range(data.num_columns): - new_arrays.append(data.column(i)) - new_fields.append(data.schema.field(i)) + if self.write_cols is not None: + data_col_names = set(data.schema.names) + for field in self.table.fields: + if field.name in data_col_names: + new_arrays.append(data.column(field.name)) + new_fields.append(data.schema.field(field.name)) + else: + pa_type = PyarrowFieldParser.from_paimon_type(field.type) + new_arrays.append(pa.nulls(num_rows, type=pa_type)) + new_fields.append(pa.field(field.name, pa_type, nullable=True)) + else: + for i in range(data.num_columns): + new_arrays.append(data.column(i)) + new_fields.append(data.schema.field(i)) return pa.RecordBatch.from_arrays(new_arrays, schema=pa.schema(new_fields)) From 6a63976cb0e1e72c89a82bc7d332cd142f430da9 Mon Sep 17 00:00:00 2001 From: xiaohongbo Date: Tue, 19 May 2026 14:22:33 +0800 Subject: [PATCH 2/3] clean up comments --- paimon-python/pypaimon/tests/write/table_write_test.py | 5 ----- paimon-python/pypaimon/write/table_write.py | 2 -- .../pypaimon/write/writer/key_value_data_writer.py | 6 +----- 3 files changed, 1 insertion(+), 12 deletions(-) diff --git a/paimon-python/pypaimon/tests/write/table_write_test.py b/paimon-python/pypaimon/tests/write/table_write_test.py index 48d2f5037971..61df4db4e427 100644 --- a/paimon-python/pypaimon/tests/write/table_write_test.py +++ b/paimon-python/pypaimon/tests/write/table_write_test.py @@ -483,7 +483,6 @@ def test_pk_partial_column_write(self): table_write = write_builder.new_write() table_commit = write_builder.new_commit() - # First write: full columns data1 = pa.Table.from_pydict({ 'user_id': [1, 2, 3], 'item_id': [1001, 1002, 1003], @@ -493,7 +492,6 @@ def test_pk_partial_column_write(self): table_write.write_arrow(data1) table_commit.commit(table_write.prepare_commit(0), 0) - # Second write: partial columns (only update item_id) table_write.with_write_type(['user_id', 'dt', 'item_id']) partial_schema = pa.schema([ pa.field('user_id', pa.int32(), nullable=False), @@ -510,17 +508,14 @@ def test_pk_partial_column_write(self): table_write.close() table_commit.close() - # Read back — PK merge-on-read deduplicates by key (last-write-wins) read_builder = table.new_read_builder() table_read = read_builder.new_read() splits = read_builder.new_scan().plan().splits() actual = table_read.to_arrow(splits).sort_by('user_id') self.assertEqual(actual.num_rows, 3) - # user_id=1,2 overwritten by partial write: item_id updated, behavior=null row1 = actual.filter(pa.compute.equal(actual['user_id'], 1)) self.assertEqual(row1.column('item_id').to_pylist(), [9001]) self.assertEqual(row1.column('behavior').to_pylist(), [None]) - # user_id=3 unchanged row3 = actual.filter(pa.compute.equal(actual['user_id'], 3)) self.assertEqual(row3.column('item_id').to_pylist(), [1003]) self.assertEqual(row3.column('behavior').to_pylist(), ['c']) diff --git a/paimon-python/pypaimon/write/table_write.py b/paimon-python/pypaimon/write/table_write.py index 1ce39ea29b7c..ada0ec985823 100644 --- a/paimon-python/pypaimon/write/table_write.py +++ b/paimon-python/pypaimon/write/table_write.py @@ -66,7 +66,6 @@ def with_write_type(self, write_cols: List[str]): for col in write_cols: if col not in self.table_pyarrow_schema.names: raise ValueError(f"Column {col} is not in table schema.") - # Partition keys are always needed for routing missing_partitions = [pk for pk in self.table.partition_keys if pk not in write_cols] if missing_partitions: raise ValueError( @@ -82,7 +81,6 @@ def with_write_type(self, write_cols: List[str]): if len(write_cols) == len(self.table_pyarrow_schema.names): write_cols = None self.file_store_write.write_cols = write_cols - # Propagate to existing writers so partial-column logic takes effect for writer in self.file_store_write.data_writers.values(): writer.write_cols = write_cols return self diff --git a/paimon-python/pypaimon/write/writer/key_value_data_writer.py b/paimon-python/pypaimon/write/writer/key_value_data_writer.py index 90b850e26aef..bc96782fd4fa 100644 --- a/paimon-python/pypaimon/write/writer/key_value_data_writer.py +++ b/paimon-python/pypaimon/write/writer/key_value_data_writer.py @@ -34,11 +34,7 @@ def _merge_data(self, existing_data: pa.Table, new_data: pa.Table) -> pa.Table: return self._sort_by_primary_key(combined) def _add_system_fields(self, data: pa.RecordBatch) -> pa.RecordBatch: - """Add system fields: _KEY_{pk_key}, _SEQUENCE_NUMBER, _VALUE_KIND. - - When write_cols is set (partial column write), missing value columns - are filled with null arrays so the output KV file has the full schema. - """ + """Add system fields: _KEY_{pk_key}, _SEQUENCE_NUMBER, _VALUE_KIND.""" num_rows = data.num_rows new_arrays = [] From e9aa9c3b7bcaf86df9b1656b53eab86fb0f77ce5 Mon Sep 17 00:00:00 2001 From: xiaohongbo Date: Tue, 19 May 2026 14:42:32 +0800 Subject: [PATCH 3/3] align with Java: disallow with_write_type on PK tables --- .../pypaimon/tests/write/table_write_test.py | 12 ++++-------- .../pypaimon/write/file_store_write.py | 3 +-- paimon-python/pypaimon/write/table_write.py | 16 ++-------------- .../write/writer/key_value_data_writer.py | 18 +++--------------- 4 files changed, 10 insertions(+), 39 deletions(-) diff --git a/paimon-python/pypaimon/tests/write/table_write_test.py b/paimon-python/pypaimon/tests/write/table_write_test.py index 61df4db4e427..b8c759ed6ba7 100644 --- a/paimon-python/pypaimon/tests/write/table_write_test.py +++ b/paimon-python/pypaimon/tests/write/table_write_test.py @@ -492,17 +492,13 @@ def test_pk_partial_column_write(self): table_write.write_arrow(data1) table_commit.commit(table_write.prepare_commit(0), 0) - table_write.with_write_type(['user_id', 'dt', 'item_id']) - partial_schema = pa.schema([ - pa.field('user_id', pa.int32(), nullable=False), - pa.field('dt', pa.string(), nullable=False), - ('item_id', pa.int64()), - ]) + # User passes full-width data with null for missing columns data2 = pa.Table.from_pydict({ 'user_id': [1, 2], - 'dt': ['p1', 'p1'], 'item_id': [9001, 9002], - }, schema=partial_schema) + 'behavior': [None, None], + 'dt': ['p1', 'p1'], + }, schema=self.pk_pa_schema) table_write.write_arrow(data2) table_commit.commit(table_write.prepare_commit(1), 1) table_write.close() diff --git a/paimon-python/pypaimon/write/file_store_write.py b/paimon-python/pypaimon/write/file_store_write.py index 8fc16a11f79e..c33fca3792eb 100644 --- a/paimon-python/pypaimon/write/file_store_write.py +++ b/paimon-python/pypaimon/write/file_store_write.py @@ -78,8 +78,7 @@ def max_seq_number(): partition=partition, bucket=bucket, max_seq_number=max_seq_number(), - options=options, - write_cols=self.write_cols) + options=options) else: seq_number = 0 if self.table.bucket_mode() == BucketMode.BUCKET_UNAWARE else max_seq_number() return AppendOnlyDataWriter( diff --git a/paimon-python/pypaimon/write/table_write.py b/paimon-python/pypaimon/write/table_write.py index ada0ec985823..8092c4367fd1 100644 --- a/paimon-python/pypaimon/write/table_write.py +++ b/paimon-python/pypaimon/write/table_write.py @@ -63,26 +63,14 @@ def write_pandas(self, dataframe): return self.write_arrow_batch(record_batch) def with_write_type(self, write_cols: List[str]): + if self.table.is_primary_key_table: + raise NotImplementedError("with_write_type is not supported for primary key tables.") for col in write_cols: if col not in self.table_pyarrow_schema.names: raise ValueError(f"Column {col} is not in table schema.") - missing_partitions = [pk for pk in self.table.partition_keys if pk not in write_cols] - if missing_partitions: - raise ValueError( - f"Partition key columns {missing_partitions} must be included in write_cols." - ) - if self.table.is_primary_key_table: - missing_pks = [pk for pk in self.table.trimmed_primary_keys if pk not in write_cols] - if missing_pks: - raise ValueError( - f"Primary key columns {missing_pks} must be included in write_cols " - f"for partial column write on PK tables." - ) if len(write_cols) == len(self.table_pyarrow_schema.names): write_cols = None self.file_store_write.write_cols = write_cols - for writer in self.file_store_write.data_writers.values(): - writer.write_cols = write_cols return self def write_ray( diff --git a/paimon-python/pypaimon/write/writer/key_value_data_writer.py b/paimon-python/pypaimon/write/writer/key_value_data_writer.py index bc96782fd4fa..6c6f292f575c 100644 --- a/paimon-python/pypaimon/write/writer/key_value_data_writer.py +++ b/paimon-python/pypaimon/write/writer/key_value_data_writer.py @@ -18,7 +18,6 @@ import pyarrow as pa import pyarrow.compute as pc -from pypaimon.schema.data_types import PyarrowFieldParser from pypaimon.write.writer.data_writer import DataWriter @@ -56,20 +55,9 @@ def _add_system_fields(self, data: pa.RecordBatch) -> pa.RecordBatch: new_arrays.append(value_kind_column) new_fields.append(pa.field('_VALUE_KIND', pa.int8(), nullable=False)) - if self.write_cols is not None: - data_col_names = set(data.schema.names) - for field in self.table.fields: - if field.name in data_col_names: - new_arrays.append(data.column(field.name)) - new_fields.append(data.schema.field(field.name)) - else: - pa_type = PyarrowFieldParser.from_paimon_type(field.type) - new_arrays.append(pa.nulls(num_rows, type=pa_type)) - new_fields.append(pa.field(field.name, pa_type, nullable=True)) - else: - for i in range(data.num_columns): - new_arrays.append(data.column(i)) - new_fields.append(data.schema.field(i)) + for i in range(data.num_columns): + new_arrays.append(data.column(i)) + new_fields.append(data.schema.field(i)) return pa.RecordBatch.from_arrays(new_arrays, schema=pa.schema(new_fields))