2 changes: 1 addition & 1 deletion paimon_python_api/read_builder.py
@@ -32,7 +32,7 @@ def with_filter(self, predicate: Predicate):
         """

     @abstractmethod
-    def with_projection(self, projection: List[List[int]]) -> 'ReadBuilder':
+    def with_projection(self, projection: List[str]) -> 'ReadBuilder':
         """Push nested projection."""

     @abstractmethod
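
The interface change above replaces nested index lists with plain field names. A minimal usage sketch of the new API, assuming `table` is a pypaimon Table whose schema contains fields f0 through f3 (mirroring the testProjection case added below):

# Hypothetical usage; table name and fields are illustrative.
read_builder = table.new_read_builder().with_projection(['f3', 'f0'])
table_read = read_builder.new_read()
splits = read_builder.new_scan().plan().splits()
df = table_read.to_pandas(splits)  # DataFrame holding only f3 and f0, in that order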
43 changes: 20 additions & 23 deletions paimon_python_java/pypaimon.py
@@ -61,37 +61,36 @@ class Table(table.Table):
     def __init__(self, j_table, catalog_options: dict):
         self._j_table = j_table
         self._catalog_options = catalog_options
-        # init arrow schema
-        schema_bytes = get_gateway().jvm.SchemaUtil.getArrowSchema(j_table.rowType())
-        schema_reader = pa.RecordBatchStreamReader(pa.BufferReader(schema_bytes))
-        self._arrow_schema = schema_reader.schema
-        schema_reader.close()

     def new_read_builder(self) -> 'ReadBuilder':
         j_read_builder = get_gateway().jvm.InvocationUtil.getReadBuilder(self._j_table)
-        return ReadBuilder(
-            j_read_builder, self._j_table.rowType(), self._catalog_options, self._arrow_schema)
+        return ReadBuilder(j_read_builder, self._j_table.rowType(), self._catalog_options)

     def new_batch_write_builder(self) -> 'BatchWriteBuilder':
         java_utils.check_batch_write(self._j_table)
         j_batch_write_builder = get_gateway().jvm.InvocationUtil.getBatchWriteBuilder(self._j_table)
-        return BatchWriteBuilder(j_batch_write_builder, self._j_table.rowType(), self._arrow_schema)
+        return BatchWriteBuilder(j_batch_write_builder)


 class ReadBuilder(read_builder.ReadBuilder):

-    def __init__(self, j_read_builder, j_row_type, catalog_options: dict, arrow_schema: pa.Schema):
+    def __init__(self, j_read_builder, j_row_type, catalog_options: dict):
         self._j_read_builder = j_read_builder
         self._j_row_type = j_row_type
         self._catalog_options = catalog_options
-        self._arrow_schema = arrow_schema

     def with_filter(self, predicate: 'Predicate'):
         self._j_read_builder.withFilter(predicate.to_j_predicate())
         return self

-    def with_projection(self, projection: List[List[int]]) -> 'ReadBuilder':
-        self._j_read_builder.withProjection(projection)
+    def with_projection(self, projection: List[str]) -> 'ReadBuilder':
+        field_names = list(map(lambda field: field.name(), self._j_row_type.getFields()))
+        int_projection = list(map(lambda p: field_names.index(p), projection))
+        gateway = get_gateway()
+        int_projection_arr = gateway.new_array(gateway.jvm.int, len(projection))
+        for i in range(len(projection)):
+            int_projection_arr[i] = int_projection[i]
+        self._j_read_builder.withProjection(int_projection_arr)
         return self

     def with_limit(self, limit: int) -> 'ReadBuilder':
@@ -104,7 +103,7 @@ def new_scan(self) -> 'TableScan':

     def new_read(self) -> 'TableRead':
         j_table_read = self._j_read_builder.newRead().executeFilter()
-        return TableRead(j_table_read, self._j_row_type, self._catalog_options, self._arrow_schema)
+        return TableRead(j_table_read, self._j_read_builder.readType(), self._catalog_options)

     def new_predicate_builder(self) -> 'PredicateBuilder':
         return PredicateBuilder(self._j_row_type)
@@ -141,12 +140,12 @@ def to_j_split(self):

 class TableRead(table_read.TableRead):

-    def __init__(self, j_table_read, j_row_type, catalog_options, arrow_schema):
+    def __init__(self, j_table_read, j_read_type, catalog_options):
         self._j_table_read = j_table_read
-        self._j_row_type = j_row_type
+        self._j_read_type = j_read_type
         self._catalog_options = catalog_options
         self._j_bytes_reader = None
-        self._arrow_schema = arrow_schema
+        self._arrow_schema = java_utils.to_arrow_schema(j_read_type)

     def to_arrow(self, splits):
         record_batch_reader = self.to_arrow_batch_reader(splits)
@@ -174,7 +173,7 @@ def _init(self):
         if max_workers <= 0:
             raise ValueError("max_workers must be greater than 0")
         self._j_bytes_reader = get_gateway().jvm.InvocationUtil.createParallelBytesReader(
-            self._j_table_read, self._j_row_type, max_workers)
+            self._j_table_read, self._j_read_type, max_workers)

     def _batch_generator(self) -> Iterator[pa.RecordBatch]:
         while True:
@@ -188,10 +187,8 @@ def _batch_generator(self) -> Iterator[pa.RecordBatch]:

 class BatchWriteBuilder(write_builder.BatchWriteBuilder):

-    def __init__(self, j_batch_write_builder, j_row_type, arrow_schema: pa.Schema):
+    def __init__(self, j_batch_write_builder):
         self._j_batch_write_builder = j_batch_write_builder
-        self._j_row_type = j_row_type
-        self._arrow_schema = arrow_schema

     def overwrite(self, static_partition: Optional[dict] = None) -> 'BatchWriteBuilder':
         if static_partition is None:
@@ -201,7 +198,7 @@ def overwrite(self, static_partition: Optional[dict] = None) -> 'BatchWriteBuilder':

     def new_write(self) -> 'BatchTableWrite':
         j_batch_table_write = self._j_batch_write_builder.newWrite()
-        return BatchTableWrite(j_batch_table_write, self._j_row_type, self._arrow_schema)
+        return BatchTableWrite(j_batch_table_write, self._j_batch_write_builder.rowType())

     def new_commit(self) -> 'BatchTableCommit':
         j_batch_table_commit = self._j_batch_write_builder.newCommit()
@@ -210,11 +207,11 @@ def new_commit(self) -> 'BatchTableCommit':

 class BatchTableWrite(table_write.BatchTableWrite):

-    def __init__(self, j_batch_table_write, j_row_type, arrow_schema: pa.Schema):
+    def __init__(self, j_batch_table_write, j_row_type):
         self._j_batch_table_write = j_batch_table_write
         self._j_bytes_writer = get_gateway().jvm.InvocationUtil.createBytesWriter(
             j_batch_table_write, j_row_type)
-        self._arrow_schema = arrow_schema
+        self._arrow_schema = java_utils.to_arrow_schema(j_row_type)

     def write_arrow(self, table):
         for record_batch in table.to_reader():
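
Note that with_projection now resolves names to positions on the Python side and must hand withProjection a primitive Java int[]; py4j will not coerce a plain Python list into a primitive array, hence the explicit gateway.new_array loop. A standalone sketch of the pattern, assuming gateway is an active py4j JavaGateway and the field names are illustrative:

# Resolve requested column names to indices, then copy into a Java int[].
field_names = ['f0', 'f1', 'f2', 'f3']   # taken from the table's row type
projection = ['f3', 'f2']                # columns requested by the caller
int_projection = [field_names.index(p) for p in projection]   # [3, 2]
int_projection_arr = gateway.new_array(gateway.jvm.int, len(int_projection))
for i, idx in enumerate(int_projection):
    int_projection_arr[i] = idx          # py4j arrays support index assignment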
62 changes: 62 additions & 0 deletions paimon_python_java/tests/test_write_and_read.py
@@ -445,3 +445,65 @@ def _testIgnoreNullableImpl(self, table_name, table_schema, data_schema):
         df['f0'] = df['f0'].astype('int32')
         pd.testing.assert_frame_equal(
             actual_df.reset_index(drop=True), df.reset_index(drop=True))
+
+    def testProjection(self):
+        pa_schema = pa.schema([
+            ('f0', pa.int64()),
+            ('f1', pa.string()),
+            ('f2', pa.bool_()),
+            ('f3', pa.string())
+        ])
+        schema = Schema(pa_schema)
+        self.catalog.create_table('default.test_projection', schema, False)
+        table = self.catalog.get_table('default.test_projection')
+
+        # prepare data
+        data = {
+            'f0': [1, 2, 3],
+            'f1': ['a', 'b', 'c'],
+            'f2': [True, True, False],
+            'f3': ['A', 'B', 'C']
+        }
+        df = pd.DataFrame(data)
+
+        # write and commit data
+        write_builder = table.new_batch_write_builder()
+        table_write = write_builder.new_write()
+        table_commit = write_builder.new_commit()
+
+        table_write.write_pandas(df)
+        commit_messages = table_write.prepare_commit()
+        table_commit.commit(commit_messages)
+
+        table_write.close()
+        table_commit.close()
+
+        # case 1: read empty
+        read_builder = table.new_read_builder().with_projection([])
+        table_scan = read_builder.new_scan()
+        table_read = read_builder.new_read()
+        splits = table_scan.plan().splits()
+        result1 = table_read.to_pandas(splits)
+        self.assertTrue(result1.empty)
+
+        # case 2: read fully
+        read_builder = table.new_read_builder().with_projection(['f0', 'f1', 'f2', 'f3'])
+        table_scan = read_builder.new_scan()
+        table_read = read_builder.new_read()
+        splits = table_scan.plan().splits()
+        result2 = table_read.to_pandas(splits)
+        pd.testing.assert_frame_equal(
+            result2.reset_index(drop=True), df.reset_index(drop=True))
+
+        # case 3: read partially
+        read_builder = table.new_read_builder().with_projection(['f3', 'f2'])
+        table_scan = read_builder.new_scan()
+        table_read = read_builder.new_read()
+        splits = table_scan.plan().splits()
+        result3 = table_read.to_pandas(splits)
+        expected_df = pd.DataFrame({
+            'f3': ['A', 'B', 'C'],
+            'f2': [True, True, False]
+        })
+        pd.testing.assert_frame_equal(
+            result3.reset_index(drop=True), expected_df.reset_index(drop=True))
9 changes: 9 additions & 0 deletions paimon_python_java/util/java_utils.py
@@ -91,3 +91,12 @@ def _to_j_type(name, pa_type):
         return jvm.DataTypes.STRING()
     else:
         raise ValueError(f'Found unsupported data type {str(pa_type)} for field {name}.')
+
+
+def to_arrow_schema(j_row_type):
+    # read the Arrow schema sent from the Java side as an empty Arrow IPC stream
+    schema_bytes = get_gateway().jvm.SchemaUtil.getArrowSchema(j_row_type)
+    schema_reader = pa.RecordBatchStreamReader(pa.BufferReader(schema_bytes))
+    arrow_schema = schema_reader.schema
+    schema_reader.close()
+    return arrow_schema
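
The new helper assumes SchemaUtil.getArrowSchema returns the row type serialized as an empty Arrow IPC stream, so the schema can be recovered with a stream reader. A pure-pyarrow sketch of that round trip (the schema here is illustrative):

import pyarrow as pa

# Serialize a schema as an IPC stream containing no record batches...
schema = pa.schema([('f0', pa.int64()), ('f1', pa.string())])
sink = pa.BufferOutputStream()
with pa.RecordBatchStreamWriter(sink, schema) as writer:
    pass  # no batches written: the stream carries only the schema
schema_bytes = sink.getvalue()

# ...then recover it exactly the way to_arrow_schema does.
reader = pa.RecordBatchStreamReader(pa.BufferReader(schema_bytes))
assert reader.schema.equals(schema)
reader.close()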