Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions paimon-python/pypaimon/tests/write/table_write_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,3 +471,47 @@ def test_rolling(self):
self.assertGreater(len(writer.committed_files), 0)
if writer.pending_data is not None:
self.assertLessEqual(writer.pending_data.nbytes, target)

def test_pk_partial_column_write(self):
    """Write a primary-key table twice and verify the rows read back.

    The table is partitioned by ``dt``, keyed by ``(user_id, dt)``, uses two
    buckets and the ``partial-update`` merge engine. The first commit writes
    three rows; the second commit rewrites users 1 and 2 with new
    ``item_id`` values and ``None`` for ``behavior``. The read side must
    return three rows, with user 1 carrying the second write's values and
    user 3 untouched.
    """
    # Import the submodule explicitly instead of relying on ``pa.compute``
    # having been loaded as a side effect of some earlier pyarrow call.
    import pyarrow.compute as pc

    # NOTE(review): the table schema is derived from ``self.pa_schema`` while
    # the batches below are built with ``self.pk_pa_schema`` -- presumably the
    # two are compatible; confirm against the fixture setup.
    schema = Schema.from_pyarrow_schema(
        self.pa_schema,
        partition_keys=['dt'],
        primary_keys=['user_id', 'dt'],
        options={'bucket': '2', 'merge-engine': 'partial-update'},
    )
    self.catalog.create_table('default.test_pk_partial_write', schema, False)
    table = self.catalog.get_table('default.test_pk_partial_write')

    write_builder = table.new_stream_write_builder()
    table_write = write_builder.new_write()
    table_commit = write_builder.new_commit()
    try:
        data1 = pa.Table.from_pydict({
            'user_id': [1, 2, 3],
            'item_id': [1001, 1002, 1003],
            'behavior': ['a', 'b', 'c'],
            'dt': ['p1', 'p1', 'p2'],
        }, schema=self.pk_pa_schema)
        table_write.write_arrow(data1)
        table_commit.commit(table_write.prepare_commit(0), 0)

        # User passes full-width data with null for missing columns
        data2 = pa.Table.from_pydict({
            'user_id': [1, 2],
            'item_id': [9001, 9002],
            'behavior': [None, None],
            'dt': ['p1', 'p1'],
        }, schema=self.pk_pa_schema)
        table_write.write_arrow(data2)
        table_commit.commit(table_write.prepare_commit(1), 1)
    finally:
        # Release the writer and committer even when a write or commit fails,
        # so a failing run does not leak resources into later tests.
        try:
            table_write.close()
        finally:
            table_commit.close()

    read_builder = table.new_read_builder()
    table_read = read_builder.new_read()
    splits = read_builder.new_scan().plan().splits()
    actual = table_read.to_arrow(splits).sort_by('user_id')
    self.assertEqual(actual.num_rows, 3)

    # NOTE(review): this expects the null written in the second commit to be
    # visible for ``behavior``; confirm that is the intended result for a
    # table created with 'merge-engine': 'partial-update'.
    row1 = actual.filter(pc.equal(actual['user_id'], 1))
    self.assertEqual(row1.column('item_id').to_pylist(), [9001])
    self.assertEqual(row1.column('behavior').to_pylist(), [None])

    # User 3 was only written by the first commit and must be unchanged.
    row3 = actual.filter(pc.equal(actual['user_id'], 3))
    self.assertEqual(row3.column('item_id').to_pylist(), [1003])
    self.assertEqual(row3.column('behavior').to_pylist(), ['c'])
20 changes: 4 additions & 16 deletions paimon-python/pypaimon/write/row_key_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,24 +83,17 @@ class RowKeyExtractor(ABC):

def __init__(self, table_schema: TableSchema):
    """Keep the table schema and cache the positions of its partition keys.

    Args:
        table_schema: Schema whose ``partition_keys`` are resolved to field
            indices through ``_get_field_indices``.
    """
    self.table_schema = table_schema
    partition_keys = table_schema.partition_keys
    self.partition_indices = self._get_field_indices(partition_keys)

def extract_partition_bucket_batch(self, data: pa.RecordBatch) -> Tuple[List[Tuple], List[int]]:
    """Compute the partitions and the buckets for a record batch.

    Args:
        data: Batch handed to ``_extract_partitions_batch`` and then to
            ``_extract_buckets_batch``.

    Returns:
        A ``(partitions, buckets)`` pair holding the results of those two
        calls, in that order.
    """
    # Partitions are resolved first, then buckets, matching the call order
    # callers have always observed.
    partition_values = self._extract_partitions_batch(data)
    return partition_values, self._extract_buckets_batch(data)

def _get_field_indices(self, field_names: List[str]) -> List[int]:
if not field_names:
return []
field_map = {field.name: i for i, field in enumerate(self.table_schema.fields)}
return [field_map[name] for name in field_names if name in field_map]

def _extract_partitions_batch(self, data: pa.RecordBatch) -> List[Tuple]:
if not self.partition_indices:
if not self.table_schema.partition_keys:
return [() for _ in range(data.num_rows)]

partition_columns = [data.column(i) for i in self.partition_indices]
partition_columns = [data.column(name) for name in self.table_schema.partition_keys]

partitions = []
for row_idx in range(data.num_rows):
Expand All @@ -124,15 +117,11 @@ def __init__(self, table_schema: TableSchema):
if self.num_buckets <= 0:
raise ValueError(f"Fixed bucket mode requires bucket > 0, got {self.num_buckets}")

# Bucket-key resolution lives on TableSchema (mirrors Java
# ``TableSchema.bucketKeys()`` / ``logicalBucketKeyType()``); reuse
# it so any reader path that walks the same logic stays in sync.
self.bucket_keys = table_schema.bucket_keys
self.bucket_key_indices = self._get_field_indices(self.bucket_keys)
self._bucket_key_fields = table_schema.logical_bucket_key_fields

def _extract_buckets_batch(self, data: pa.RecordBatch) -> List[int]:
columns = [data.column(i) for i in self.bucket_key_indices]
columns = [data.column(name) for name in self.bucket_keys]
return [
_bucket_from_hash(
self._binary_row_hash_code(tuple(col[row_idx].as_py() for col in columns)),
Expand Down Expand Up @@ -287,15 +276,14 @@ def __init__(self, table_schema: 'TableSchema'):
pk for pk in table_schema.primary_keys
if pk not in table_schema.partition_keys
]
self.bucket_key_indices = self._get_field_indices(self.bucket_keys)
field_map = {f.name: f for f in table_schema.fields}
self._bucket_key_fields = [
field_map[name] for name in self.bucket_keys if name in field_map
]

def _extract_buckets_batch(self, data: pa.RecordBatch) -> List[int]:
partitions = self._extract_partitions_batch(data)
columns = [data.column(i) for i in self.bucket_key_indices]
columns = [data.column(name) for name in self.bucket_keys]
buckets = []
for row_idx in range(data.num_rows):
key_hash = _hash_bytes_by_words(
Expand Down
2 changes: 2 additions & 0 deletions paimon-python/pypaimon/write/table_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ def write_pandas(self, dataframe):
return self.write_arrow_batch(record_batch)

def with_write_type(self, write_cols: List[str]):
if self.table.is_primary_key_table:
raise NotImplementedError("with_write_type is not supported for primary key tables.")
for col in write_cols:
if col not in self.table_pyarrow_schema.names:
raise ValueError(f"Column {col} is not in table schema.")
Expand Down
Loading