diff --git a/paimon-python/pypaimon/tests/write/table_write_test.py b/paimon-python/pypaimon/tests/write/table_write_test.py index 96abfb73aeec..b8c759ed6ba7 100644 --- a/paimon-python/pypaimon/tests/write/table_write_test.py +++ b/paimon-python/pypaimon/tests/write/table_write_test.py @@ -471,3 +471,47 @@ def test_rolling(self): self.assertGreater(len(writer.committed_files), 0) if writer.pending_data is not None: self.assertLessEqual(writer.pending_data.nbytes, target) + + def test_pk_partial_column_write(self): + schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt'], + primary_keys=['user_id', 'dt'], + options={'bucket': '2', 'merge-engine': 'partial-update'}) + self.catalog.create_table('default.test_pk_partial_write', schema, False) + table = self.catalog.get_table('default.test_pk_partial_write') + + write_builder = table.new_stream_write_builder() + table_write = write_builder.new_write() + table_commit = write_builder.new_commit() + + data1 = pa.Table.from_pydict({ + 'user_id': [1, 2, 3], + 'item_id': [1001, 1002, 1003], + 'behavior': ['a', 'b', 'c'], + 'dt': ['p1', 'p1', 'p2'], + }, schema=self.pk_pa_schema) + table_write.write_arrow(data1) + table_commit.commit(table_write.prepare_commit(0), 0) + + # User passes full-width data with null for missing columns + data2 = pa.Table.from_pydict({ + 'user_id': [1, 2], + 'item_id': [9001, 9002], + 'behavior': [None, None], + 'dt': ['p1', 'p1'], + }, schema=self.pk_pa_schema) + table_write.write_arrow(data2) + table_commit.commit(table_write.prepare_commit(1), 1) + table_write.close() + table_commit.close() + + read_builder = table.new_read_builder() + table_read = read_builder.new_read() + splits = read_builder.new_scan().plan().splits() + actual = table_read.to_arrow(splits).sort_by('user_id') + self.assertEqual(actual.num_rows, 3) + row1 = actual.filter(pa.compute.equal(actual['user_id'], 1)) + self.assertEqual(row1.column('item_id').to_pylist(), [9001]) + self.assertEqual(row1.column('behavior').to_pylist(), [None]) + row3 = actual.filter(pa.compute.equal(actual['user_id'], 3)) + self.assertEqual(row3.column('item_id').to_pylist(), [1003]) + self.assertEqual(row3.column('behavior').to_pylist(), ['c']) diff --git a/paimon-python/pypaimon/write/row_key_extractor.py b/paimon-python/pypaimon/write/row_key_extractor.py index dad93bf1eda7..2a3d73140bc1 100644 --- a/paimon-python/pypaimon/write/row_key_extractor.py +++ b/paimon-python/pypaimon/write/row_key_extractor.py @@ -83,24 +83,17 @@ class RowKeyExtractor(ABC): def __init__(self, table_schema: TableSchema): self.table_schema = table_schema - self.partition_indices = self._get_field_indices(table_schema.partition_keys) def extract_partition_bucket_batch(self, data: pa.RecordBatch) -> Tuple[List[Tuple], List[int]]: partitions = self._extract_partitions_batch(data) buckets = self._extract_buckets_batch(data) return partitions, buckets - def _get_field_indices(self, field_names: List[str]) -> List[int]: - if not field_names: - return [] - field_map = {field.name: i for i, field in enumerate(self.table_schema.fields)} - return [field_map[name] for name in field_names if name in field_map] - def _extract_partitions_batch(self, data: pa.RecordBatch) -> List[Tuple]: - if not self.partition_indices: + if not self.table_schema.partition_keys: return [() for _ in range(data.num_rows)] - partition_columns = [data.column(i) for i in self.partition_indices] + partition_columns = [data.column(name) for name in self.table_schema.partition_keys] partitions = [] for row_idx in range(data.num_rows): @@ -124,15 +117,11 @@ def __init__(self, table_schema: TableSchema): if self.num_buckets <= 0: raise ValueError(f"Fixed bucket mode requires bucket > 0, got {self.num_buckets}") - # Bucket-key resolution lives on TableSchema (mirrors Java - # ``TableSchema.bucketKeys()`` / ``logicalBucketKeyType()``); reuse - # it so any reader path that walks the same logic stays in sync. self.bucket_keys = table_schema.bucket_keys - self.bucket_key_indices = self._get_field_indices(self.bucket_keys) self._bucket_key_fields = table_schema.logical_bucket_key_fields def _extract_buckets_batch(self, data: pa.RecordBatch) -> List[int]: - columns = [data.column(i) for i in self.bucket_key_indices] + columns = [data.column(name) for name in self.bucket_keys] return [ _bucket_from_hash( self._binary_row_hash_code(tuple(col[row_idx].as_py() for col in columns)), @@ -287,7 +276,6 @@ def __init__(self, table_schema: 'TableSchema'): pk for pk in table_schema.primary_keys if pk not in table_schema.partition_keys ] - self.bucket_key_indices = self._get_field_indices(self.bucket_keys) field_map = {f.name: f for f in table_schema.fields} self._bucket_key_fields = [ field_map[name] for name in self.bucket_keys if name in field_map @@ -295,7 +283,7 @@ def __init__(self, table_schema: 'TableSchema'): def _extract_buckets_batch(self, data: pa.RecordBatch) -> List[int]: partitions = self._extract_partitions_batch(data) - columns = [data.column(i) for i in self.bucket_key_indices] + columns = [data.column(name) for name in self.bucket_keys] buckets = [] for row_idx in range(data.num_rows): key_hash = _hash_bytes_by_words( diff --git a/paimon-python/pypaimon/write/table_write.py b/paimon-python/pypaimon/write/table_write.py index 80ef5a3572bb..8092c4367fd1 100644 --- a/paimon-python/pypaimon/write/table_write.py +++ b/paimon-python/pypaimon/write/table_write.py @@ -63,6 +63,8 @@ def write_pandas(self, dataframe): return self.write_arrow_batch(record_batch) def with_write_type(self, write_cols: List[str]): + if self.table.is_primary_key_table: + raise NotImplementedError("with_write_type is not supported for primary key tables.") for col in write_cols: if col not in self.table_pyarrow_schema.names: raise ValueError(f"Column {col} is not in table schema.")