diff --git a/crates/iceberg/src/arrow/mod.rs b/crates/iceberg/src/arrow/mod.rs index 28116a4b5e..c091c45177 100644 --- a/crates/iceberg/src/arrow/mod.rs +++ b/crates/iceberg/src/arrow/mod.rs @@ -35,4 +35,9 @@ mod value; pub use reader::*; pub use value::*; -pub(crate) mod record_batch_partition_splitter; +/// Partition value calculator for computing partition values +pub mod partition_value_calculator; +pub use partition_value_calculator::*; +/// Record batch partition splitter for partitioned tables +pub mod record_batch_partition_splitter; +pub use record_batch_partition_splitter::*; diff --git a/crates/iceberg/src/arrow/partition_value_calculator.rs b/crates/iceberg/src/arrow/partition_value_calculator.rs new file mode 100644 index 0000000000..1409503451 --- /dev/null +++ b/crates/iceberg/src/arrow/partition_value_calculator.rs @@ -0,0 +1,254 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Partition value calculation for Iceberg tables. +//! +//! This module provides utilities for calculating partition values from record batches +//! based on a partition specification. + +use std::sync::Arc; + +use arrow_array::{ArrayRef, RecordBatch, StructArray}; +use arrow_schema::DataType; + +use super::record_batch_projector::RecordBatchProjector; +use super::type_to_arrow_type; +use crate::spec::{PartitionSpec, Schema, StructType, Type}; +use crate::transform::{BoxedTransformFunction, create_transform_function}; +use crate::{Error, ErrorKind, Result}; + +/// Calculator for partition values in Iceberg tables. +/// +/// This struct handles the projection of source columns and application of +/// partition transforms to compute partition values for a given record batch. +#[derive(Debug)] +pub struct PartitionValueCalculator { + projector: RecordBatchProjector, + transform_functions: Vec, + partition_type: StructType, + partition_arrow_type: DataType, +} + +impl PartitionValueCalculator { + /// Create a new PartitionValueCalculator. + /// + /// # Arguments + /// + /// * `partition_spec` - The partition specification + /// * `table_schema` - The Iceberg table schema + /// + /// # Returns + /// + /// Returns a new `PartitionValueCalculator` instance or an error if initialization fails. 
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if:
+    /// - The partition spec is unpartitioned
+    /// - Transform function creation fails
+    /// - Projector initialization fails
+    pub fn try_new(partition_spec: &PartitionSpec, table_schema: &Schema) -> Result<Self> {
+        if partition_spec.is_unpartitioned() {
+            return Err(Error::new(
+                ErrorKind::DataInvalid,
+                "Cannot create partition calculator for unpartitioned table",
+            ));
+        }
+
+        // Create transform functions for each partition field
+        let transform_functions: Vec<BoxedTransformFunction> = partition_spec
+            .fields()
+            .iter()
+            .map(|pf| create_transform_function(&pf.transform))
+            .collect::<Result<Vec<_>>>()?;
+
+        // Extract source field IDs for projection
+        let source_field_ids: Vec<i32> = partition_spec
+            .fields()
+            .iter()
+            .map(|pf| pf.source_id)
+            .collect();
+
+        // Create projector for extracting source columns
+        let projector = RecordBatchProjector::from_iceberg_schema(
+            Arc::new(table_schema.clone()),
+            &source_field_ids,
+        )?;
+
+        // Get partition type information
+        let partition_type = partition_spec.partition_type(table_schema)?;
+        let partition_arrow_type = type_to_arrow_type(&Type::Struct(partition_type.clone()))?;
+
+        Ok(Self {
+            projector,
+            transform_functions,
+            partition_type,
+            partition_arrow_type,
+        })
+    }
+
+    /// Get the partition type as an Iceberg StructType.
+    pub fn partition_type(&self) -> &StructType {
+        &self.partition_type
+    }
+
+    /// Get the partition type as an Arrow DataType.
+    pub fn partition_arrow_type(&self) -> &DataType {
+        &self.partition_arrow_type
+    }
+
+    /// Calculate partition values for a record batch.
+    ///
+    /// This method:
+    /// 1. Projects the source columns from the batch
+    /// 2. Applies partition transforms to each source column
+    /// 3. Constructs a StructArray containing the partition values
+    ///
+    /// # Arguments
+    ///
+    /// * `batch` - The record batch to calculate partition values for
+    ///
+    /// # Returns
+    ///
+    /// Returns an `ArrayRef` containing a `StructArray` of partition values, or an error if calculation fails.
+ /// + /// # Errors + /// + /// Returns an error if: + /// - Column projection fails + /// - Transform application fails + /// - StructArray construction fails + pub fn calculate(&self, batch: &RecordBatch) -> Result { + // Project source columns from the batch + let source_columns = self.projector.project_column(batch.columns())?; + + // Get expected struct fields for the result + let expected_struct_fields = match &self.partition_arrow_type { + DataType::Struct(fields) => fields.clone(), + _ => { + return Err(Error::new( + ErrorKind::DataInvalid, + "Expected partition type must be a struct", + )); + } + }; + + // Apply transforms to each source column + let mut partition_values = Vec::with_capacity(self.transform_functions.len()); + for (source_column, transform_fn) in source_columns.iter().zip(&self.transform_functions) { + let partition_value = transform_fn.transform(source_column.clone())?; + partition_values.push(partition_value); + } + + // Construct the StructArray + let struct_array = StructArray::try_new(expected_struct_fields, partition_values, None) + .map_err(|e| { + Error::new( + ErrorKind::DataInvalid, + format!("Failed to create partition struct array: {}", e), + ) + })?; + + Ok(Arc::new(struct_array)) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow_array::{Int32Array, RecordBatch, StringArray}; + use arrow_schema::{Field, Schema as ArrowSchema}; + + use super::*; + use crate::spec::{NestedField, PartitionSpecBuilder, PrimitiveType, Transform}; + + #[test] + fn test_partition_calculator_identity_transform() { + let table_schema = Schema::builder() + .with_schema_id(0) + .with_fields(vec![ + NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(), + NestedField::required(2, "name", Type::Primitive(PrimitiveType::String)).into(), + ]) + .build() + .unwrap(); + + let partition_spec = PartitionSpecBuilder::new(Arc::new(table_schema.clone())) + .add_partition_field("id", "id_partition", Transform::Identity) + .unwrap() + .build() + .unwrap(); + + let calculator = PartitionValueCalculator::try_new(&partition_spec, &table_schema).unwrap(); + + // Verify partition type + assert_eq!(calculator.partition_type().fields().len(), 1); + assert_eq!(calculator.partition_type().fields()[0].name, "id_partition"); + + // Create test batch + let arrow_schema = Arc::new(ArrowSchema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, false), + ])); + + let batch = RecordBatch::try_new(arrow_schema, vec![ + Arc::new(Int32Array::from(vec![10, 20, 30])), + Arc::new(StringArray::from(vec!["a", "b", "c"])), + ]) + .unwrap(); + + // Calculate partition values + let result = calculator.calculate(&batch).unwrap(); + let struct_array = result.as_any().downcast_ref::().unwrap(); + + let id_partition = struct_array + .column_by_name("id_partition") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!(id_partition.value(0), 10); + assert_eq!(id_partition.value(1), 20); + assert_eq!(id_partition.value(2), 30); + } + + #[test] + fn test_partition_calculator_unpartitioned_error() { + let table_schema = Schema::builder() + .with_schema_id(0) + .with_fields(vec![ + NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(), + ]) + .build() + .unwrap(); + + let partition_spec = PartitionSpecBuilder::new(Arc::new(table_schema.clone())) + .build() + .unwrap(); + + let result = PartitionValueCalculator::try_new(&partition_spec, &table_schema); + assert!(result.is_err()); + assert!( + result + 
.unwrap_err() + .to_string() + .contains("unpartitioned table") + ); + } +} diff --git a/crates/iceberg/src/arrow/record_batch_partition_splitter.rs b/crates/iceberg/src/arrow/record_batch_partition_splitter.rs index 704a4e9c15..5b0af2d00e 100644 --- a/crates/iceberg/src/arrow/record_batch_partition_splitter.rs +++ b/crates/iceberg/src/arrow/record_batch_partition_splitter.rs @@ -19,18 +19,17 @@ use std::collections::HashMap; use std::sync::Arc; use arrow_array::{ArrayRef, BooleanArray, RecordBatch, StructArray}; -use arrow_schema::{DataType, SchemaRef as ArrowSchemaRef}; +use arrow_schema::SchemaRef as ArrowSchemaRef; use arrow_select::filter::filter_record_batch; -use itertools::Itertools; -use parquet::arrow::PARQUET_FIELD_ID_META_KEY; use super::arrow_struct_to_literal; -use super::record_batch_projector::RecordBatchProjector; -use crate::arrow::type_to_arrow_type; -use crate::spec::{Literal, PartitionSpecRef, SchemaRef, Struct, StructType, Type}; -use crate::transform::{BoxedTransformFunction, create_transform_function}; +use super::partition_value_calculator::PartitionValueCalculator; +use crate::spec::{Literal, PartitionKey, PartitionSpecRef, SchemaRef, StructType}; use crate::{Error, ErrorKind, Result}; +/// Column name for the projected partition values struct +pub const PROJECTED_PARTITION_VALUE_COLUMN: &str = "_partition"; + /// The splitter used to split the record batch into multiple record batches by the partition spec. /// 1. It will project and transform the input record batch based on the partition spec, get the partitioned record batch. /// 2. Split the input record batch into multiple record batches based on the partitioned record batch. @@ -40,116 +39,123 @@ use crate::{Error, ErrorKind, Result}; pub struct RecordBatchPartitionSplitter { schema: SchemaRef, partition_spec: PartitionSpecRef, - projector: RecordBatchProjector, - transform_functions: Vec, - + calculator: Option, partition_type: StructType, - partition_arrow_type: DataType, + use_projected_partition_column: bool, } // # TODO // Remove this after partition writer supported. #[allow(dead_code)] impl RecordBatchPartitionSplitter { + /// Create a new RecordBatchPartitionSplitter. + /// + /// # Arguments + /// + /// * `_input_schema` - The Arrow schema of the input record batches (unused when use_projected_partition_column is true) + /// * `iceberg_schema` - The Iceberg schema reference + /// * `partition_spec` - The partition specification reference + /// * `use_projected_partition_column` - If true, expects a pre-computed partition column in the input batch + /// + /// # Returns + /// + /// Returns a new `RecordBatchPartitionSplitter` instance or an error if initialization fails. pub fn new( - input_schema: ArrowSchemaRef, + _input_schema: ArrowSchemaRef, iceberg_schema: SchemaRef, partition_spec: PartitionSpecRef, + use_projected_partition_column: bool, ) -> Result { - let projector = RecordBatchProjector::new( - input_schema, - &partition_spec - .fields() - .iter() - .map(|field| field.source_id) - .collect::>(), - // The source columns, selected by ids, must be a primitive type and cannot be contained in a map or list, but may be nested in a struct. 
- // ref: https://iceberg.apache.org/spec/#partitioning - |field| { - if !field.data_type().is_primitive() { - return Ok(None); - } - field - .metadata() - .get(PARQUET_FIELD_ID_META_KEY) - .map(|s| { - s.parse::() - .map_err(|e| Error::new(ErrorKind::Unexpected, e.to_string())) - }) - .transpose() - }, - |_| true, - )?; - let transform_functions = partition_spec - .fields() - .iter() - .map(|field| create_transform_function(&field.transform)) - .collect::>>()?; - let partition_type = partition_spec.partition_type(&iceberg_schema)?; - let partition_arrow_type = type_to_arrow_type(&Type::Struct(partition_type.clone()))?; + + let calculator = if use_projected_partition_column { + // Skip calculator initialization when partition column is pre-computed + None + } else { + // Create calculator for computing partition values from source columns + Some(PartitionValueCalculator::try_new( + &partition_spec, + &iceberg_schema, + )?) + }; Ok(Self { schema: iceberg_schema, partition_spec, - projector, - transform_functions, + calculator, partition_type, - partition_arrow_type, + use_projected_partition_column, }) } - fn partition_columns_to_struct(&self, partition_columns: Vec) -> Result> { - let arrow_struct_array = { - let partition_arrow_fields = { - let DataType::Struct(fields) = &self.partition_arrow_type else { - return Err(Error::new( + /// Split the record batch into multiple record batches based on the partition spec. + pub fn split(&self, batch: &RecordBatch) -> Result> { + let partition_structs = if self.use_projected_partition_column { + // Extract partition values from pre-computed partition column + let partition_column = batch + .column_by_name(PROJECTED_PARTITION_VALUE_COLUMN) + .ok_or_else(|| { + Error::new( + ErrorKind::DataInvalid, + format!( + "Partition column '{}' not found in batch", + PROJECTED_PARTITION_VALUE_COLUMN + ), + ) + })?; + + let partition_struct_array = partition_column + .as_any() + .downcast_ref::() + .ok_or_else(|| { + Error::new( ErrorKind::DataInvalid, - "The partition arrow type is not a struct type", - )); - }; - fields.clone() - }; - Arc::new(StructArray::try_new( - partition_arrow_fields, - partition_columns, - None, - )?) as ArrayRef - }; - let struct_array = { + "Partition column is not a StructArray", + ) + })?; + + let arrow_struct_array = Arc::new(partition_struct_array.clone()) as ArrayRef; let struct_array = arrow_struct_to_literal(&arrow_struct_array, &self.partition_type)?; + struct_array .into_iter() .map(|s| { - if let Some(s) = s { - if let Literal::Struct(s) = s { - Ok(s) - } else { - Err(Error::new( - ErrorKind::DataInvalid, - "The struct is not a struct literal", - )) - } + if let Some(Literal::Struct(s)) = s { + Ok(s) } else { - Err(Error::new(ErrorKind::DataInvalid, "The struct is null")) + Err(Error::new( + ErrorKind::DataInvalid, + "Partition value is not a struct literal or is null", + )) } }) .collect::>>()? - }; + } else { + // Compute partition values from source columns using calculator + let calculator = self.calculator.as_ref().ok_or_else(|| { + Error::new( + ErrorKind::DataInvalid, + "Calculator not initialized for non-partition-column mode", + ) + })?; - Ok(struct_array) - } + let partition_array = calculator.calculate(batch)?; + let struct_array = arrow_struct_to_literal(&partition_array, &self.partition_type)?; - /// Split the record batch into multiple record batches based on the partition spec. 
- pub fn split(&self, batch: &RecordBatch) -> Result> { - let source_columns = self.projector.project_column(batch.columns())?; - let partition_columns = source_columns - .into_iter() - .zip_eq(self.transform_functions.iter()) - .map(|(source_column, transform_function)| transform_function.transform(source_column)) - .collect::>>()?; - - let partition_structs = self.partition_columns_to_struct(partition_columns)?; + struct_array + .into_iter() + .map(|s| { + if let Some(Literal::Struct(s)) = s { + Ok(s) + } else { + Err(Error::new( + ErrorKind::DataInvalid, + "Partition value is not a struct literal or is null", + )) + } + }) + .collect::>>()? + }; // Group the batch by row value. let mut group_ids = HashMap::new(); @@ -172,8 +178,15 @@ impl RecordBatchPartitionSplitter { filter.into() }; + // Create PartitionKey from the partition struct + let partition_key = PartitionKey::new( + self.partition_spec.as_ref().clone(), + self.schema.clone(), + row, + ); + // filter the RecordBatch - partition_batches.push((row, filter_record_batch(batch, &filter_array)?)); + partition_batches.push((partition_key, filter_record_batch(batch, &filter_array)?)); } Ok(partition_batches) @@ -185,11 +198,13 @@ mod tests { use std::sync::Arc; use arrow_array::{Int32Array, RecordBatch, StringArray}; + use arrow_schema::DataType; + use parquet::arrow::PARQUET_FIELD_ID_META_KEY; use super::*; use crate::arrow::schema_to_arrow_schema; use crate::spec::{ - NestedField, PartitionSpecBuilder, PrimitiveLiteral, Schema, Transform, + NestedField, PartitionSpecBuilder, PrimitiveLiteral, Schema, Struct, Transform, Type, UnboundPartitionField, }; @@ -228,9 +243,13 @@ mod tests { .unwrap(), ); let input_schema = Arc::new(schema_to_arrow_schema(&schema).unwrap()); - let partition_splitter = - RecordBatchPartitionSplitter::new(input_schema.clone(), schema.clone(), partition_spec) - .expect("Failed to create splitter"); + let partition_splitter = RecordBatchPartitionSplitter::new( + input_schema.clone(), + schema.clone(), + partition_spec, + false, + ) + .expect("Failed to create splitter"); let id_array = Int32Array::from(vec![1, 2, 1, 3, 2, 3, 1]); let data_array = StringArray::from(vec!["a", "b", "c", "d", "e", "f", "g"]); @@ -243,8 +262,8 @@ mod tests { let mut partitioned_batches = partition_splitter .split(&batch) .expect("Failed to split RecordBatch"); - partitioned_batches.sort_by_key(|(row, _)| { - if let PrimitiveLiteral::Int(i) = row.fields()[0] + partitioned_batches.sort_by_key(|(partition_key, _)| { + if let PrimitiveLiteral::Int(i) = partition_key.data().fields()[0] .as_ref() .unwrap() .as_primitive_literal() @@ -292,7 +311,7 @@ mod tests { let partition_values = partitioned_batches .iter() - .map(|(row, _)| row.clone()) + .map(|(partition_key, _)| partition_key.data().clone()) .collect::>(); // check partition value is struct(1), struct(2), struct(3) assert_eq!(partition_values, vec![ @@ -301,4 +320,119 @@ mod tests { Struct::from_iter(vec![Some(Literal::int(3))]), ]); } + + #[test] + fn test_record_batch_partition_split_with_partition_column() { + use arrow_array::StructArray; + use arrow_schema::{Field, Schema as ArrowSchema}; + + let schema = Arc::new( + Schema::builder() + .with_fields(vec![ + NestedField::required( + 1, + "id", + Type::Primitive(crate::spec::PrimitiveType::Int), + ) + .into(), + NestedField::required( + 2, + "name", + Type::Primitive(crate::spec::PrimitiveType::String), + ) + .into(), + ]) + .build() + .unwrap(), + ); + let partition_spec = Arc::new( + PartitionSpecBuilder::new(schema.clone()) + 
.with_spec_id(1) + .add_unbound_field(UnboundPartitionField { + source_id: 1, + field_id: None, + name: "id_bucket".to_string(), + transform: Transform::Identity, + }) + .unwrap() + .build() + .unwrap(), + ); + + // Create input schema with _partition column + // Note: partition field IDs start from 1000 by default + let partition_field = Field::new("id_bucket", DataType::Int32, false).with_metadata( + HashMap::from([(PARQUET_FIELD_ID_META_KEY.to_string(), "1000".to_string())]), + ); + let partition_struct_field = Field::new( + PROJECTED_PARTITION_VALUE_COLUMN, + DataType::Struct(vec![partition_field.clone()].into()), + false, + ); + + let input_schema = Arc::new(ArrowSchema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, false), + partition_struct_field, + ])); + + // Create splitter with has_partition_column=true + let partition_splitter = RecordBatchPartitionSplitter::new( + input_schema.clone(), + schema.clone(), + partition_spec, + true, + ) + .expect("Failed to create splitter"); + + // Create test data with pre-computed partition column + let id_array = Int32Array::from(vec![1, 2, 1, 3, 2, 3, 1]); + let data_array = StringArray::from(vec!["a", "b", "c", "d", "e", "f", "g"]); + + // Create partition column (same values as id for Identity transform) + let partition_values = Int32Array::from(vec![1, 2, 1, 3, 2, 3, 1]); + let partition_struct = StructArray::from(vec![( + Arc::new(partition_field), + Arc::new(partition_values) as ArrayRef, + )]); + + let batch = RecordBatch::try_new(input_schema.clone(), vec![ + Arc::new(id_array), + Arc::new(data_array), + Arc::new(partition_struct), + ]) + .expect("Failed to create RecordBatch"); + + // Split using the pre-computed partition column + let mut partitioned_batches = partition_splitter + .split(&batch) + .expect("Failed to split RecordBatch"); + + partitioned_batches.sort_by_key(|(partition_key, _)| { + if let PrimitiveLiteral::Int(i) = partition_key.data().fields()[0] + .as_ref() + .unwrap() + .as_primitive_literal() + .unwrap() + { + i + } else { + panic!("The partition value is not a int"); + } + }); + + assert_eq!(partitioned_batches.len(), 3); + + // Verify partition values + let partition_values = partitioned_batches + .iter() + .map(|(partition_key, _)| partition_key.data().clone()) + .collect::>(); + + assert_eq!(partition_values, vec![ + Struct::from_iter(vec![Some(Literal::int(1))]), + Struct::from_iter(vec![Some(Literal::int(2))]), + Struct::from_iter(vec![Some(Literal::int(3))]), + ]); + } } diff --git a/crates/iceberg/src/writer/partitioning/mod.rs b/crates/iceberg/src/writer/partitioning/mod.rs index f63a9d0d26..c8106041ac 100644 --- a/crates/iceberg/src/writer/partitioning/mod.rs +++ b/crates/iceberg/src/writer/partitioning/mod.rs @@ -23,6 +23,7 @@ pub mod clustered_writer; pub mod fanout_writer; +pub mod unpartitioned_writer; use crate::Result; use crate::spec::PartitionKey; diff --git a/crates/iceberg/src/writer/partitioning/unpartitioned_writer.rs b/crates/iceberg/src/writer/partitioning/unpartitioned_writer.rs new file mode 100644 index 0000000000..702a69543d --- /dev/null +++ b/crates/iceberg/src/writer/partitioning/unpartitioned_writer.rs @@ -0,0 +1,412 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This module provides the `UnpartitionedWriter` implementation. + +use std::marker::PhantomData; + +use crate::Result; +use crate::writer::{DefaultInput, DefaultOutput, IcebergWriter, IcebergWriterBuilder}; + +/// A simple wrapper around `IcebergWriterBuilder` for unpartitioned tables. +/// +/// This writer lazily creates the underlying writer on the first write operation +/// and writes all data to a single file (or set of files if rolling). +/// +/// # Type Parameters +/// +/// * `B` - The inner writer builder type +/// * `I` - Input type (defaults to `RecordBatch`) +/// * `O` - Output collection type (defaults to `Vec`) +pub struct UnpartitionedWriter +where + B: IcebergWriterBuilder, + O: IntoIterator + FromIterator<::Item>, + ::Item: Clone, +{ + inner_builder: B, + writer: Option, + output: Vec<::Item>, + _phantom: PhantomData, +} + +impl UnpartitionedWriter +where + B: IcebergWriterBuilder, + I: Send + 'static, + O: IntoIterator + FromIterator<::Item>, + ::Item: Send + Clone, +{ + /// Create a new `UnpartitionedWriter`. + pub fn new(inner_builder: B) -> Self { + Self { + inner_builder, + writer: None, + output: Vec::new(), + _phantom: PhantomData, + } + } + + /// Write data to the writer. + /// + /// The underlying writer is lazily created on the first write operation. + /// + /// # Parameters + /// + /// * `input` - The input data to write + /// + /// # Returns + /// + /// `Ok(())` on success, or an error if the write operation fails. + pub async fn write(&mut self, input: I) -> Result<()> { + // Lazily create writer on first write + if self.writer.is_none() { + self.writer = Some(self.inner_builder.clone().build(None).await?); + } + + // Write directly to inner writer + self.writer + .as_mut() + .expect("Writer should be initialized") + .write(input) + .await + } + + /// Close the writer and return all written data files. + /// + /// This method consumes the writer to prevent further use. + /// + /// # Returns + /// + /// The accumulated output from all write operations, or an empty collection + /// if no data was written. 
+ pub async fn close(mut self) -> Result { + if let Some(mut writer) = self.writer.take() { + self.output.extend(writer.close().await?); + } + Ok(O::from_iter(self.output)) + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + use std::sync::Arc; + + use arrow_array::{Int32Array, RecordBatch, StringArray}; + use arrow_schema::{DataType, Field, Schema}; + use parquet::arrow::PARQUET_FIELD_ID_META_KEY; + use parquet::file::properties::WriterProperties; + use tempfile::TempDir; + + use super::*; + use crate::Result; + use crate::io::FileIOBuilder; + use crate::spec::{DataFileFormat, NestedField, PrimitiveType, Struct, Type}; + use crate::writer::base_writer::data_file_writer::DataFileWriterBuilder; + use crate::writer::file_writer::ParquetWriterBuilder; + use crate::writer::file_writer::location_generator::{ + DefaultFileNameGenerator, DefaultLocationGenerator, + }; + use crate::writer::file_writer::rolling_writer::RollingFileWriterBuilder; + + /// Helper function to create a test writer setup with common configuration + fn create_test_writer_builder( + temp_dir: &TempDir, + schema: Arc, + ) -> Result { + let file_io = FileIOBuilder::new_fs_io().build()?; + let location_gen = DefaultLocationGenerator::with_data_location( + temp_dir.path().to_str().unwrap().to_string(), + ); + let file_name_gen = + DefaultFileNameGenerator::new("test".to_string(), None, DataFileFormat::Parquet); + + let parquet_writer_builder = + ParquetWriterBuilder::new(WriterProperties::builder().build(), schema); + let rolling_writer_builder = RollingFileWriterBuilder::new_with_default_file_size( + parquet_writer_builder, + file_io, + location_gen, + file_name_gen, + ); + + Ok(DataFileWriterBuilder::new(rolling_writer_builder)) + } + + /// Helper function to create a simple test schema + fn create_simple_schema() -> Result> { + Ok(Arc::new( + crate::spec::Schema::builder() + .with_schema_id(1) + .with_fields(vec![ + NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(), + NestedField::required(2, "name", Type::Primitive(PrimitiveType::String)).into(), + ]) + .build()?, + )) + } + + /// Helper function to create a schema with a region partition field + fn create_schema_with_region() -> Result> { + Ok(Arc::new( + crate::spec::Schema::builder() + .with_schema_id(1) + .with_fields(vec![ + NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(), + NestedField::required(2, "name", Type::Primitive(PrimitiveType::String)).into(), + NestedField::required(3, "region", Type::Primitive(PrimitiveType::String)) + .into(), + ]) + .build()?, + )) + } + + /// Helper function to create Arrow schema with field IDs for simple schema + fn create_arrow_schema_simple() -> Schema { + Schema::new(vec![ + Field::new("id", DataType::Int32, false).with_metadata(HashMap::from([( + PARQUET_FIELD_ID_META_KEY.to_string(), + 1.to_string(), + )])), + Field::new("name", DataType::Utf8, false).with_metadata(HashMap::from([( + PARQUET_FIELD_ID_META_KEY.to_string(), + 2.to_string(), + )])), + ]) + } + + /// Helper function to create Arrow schema with field IDs including region + fn create_arrow_schema_with_region() -> Schema { + Schema::new(vec![ + Field::new("id", DataType::Int32, false).with_metadata(HashMap::from([( + PARQUET_FIELD_ID_META_KEY.to_string(), + 1.to_string(), + )])), + Field::new("name", DataType::Utf8, false).with_metadata(HashMap::from([( + PARQUET_FIELD_ID_META_KEY.to_string(), + 2.to_string(), + )])), + Field::new("region", DataType::Utf8, false).with_metadata(HashMap::from([( + 
PARQUET_FIELD_ID_META_KEY.to_string(), + 3.to_string(), + )])), + ]) + } + + #[tokio::test] + async fn test_unpartitioned_writer_basic_functionality() -> Result<()> { + let temp_dir = TempDir::new()?; + let schema = create_simple_schema()?; + let data_file_writer_builder = create_test_writer_builder(&temp_dir, schema.clone())?; + + // Create unpartitioned writer + let mut writer = UnpartitionedWriter::new(data_file_writer_builder); + + // Create test data + let arrow_schema = create_arrow_schema_simple(); + let batch1 = RecordBatch::try_new(Arc::new(arrow_schema.clone()), vec![ + Arc::new(Int32Array::from(vec![1, 2])), + Arc::new(StringArray::from(vec!["Alice", "Bob"])), + ])?; + + let batch2 = RecordBatch::try_new(Arc::new(arrow_schema.clone()), vec![ + Arc::new(Int32Array::from(vec![3, 4])), + Arc::new(StringArray::from(vec!["Charlie", "Dave"])), + ])?; + + // Write data without partition keys + writer.write(batch1).await?; + writer.write(batch2).await?; + + // Close writer and get data files + let data_files = writer.close().await?; + + // Verify at least one file was created + assert!( + !data_files.is_empty(), + "Expected at least one data file to be created" + ); + + // Verify that all data files have empty partition value (unpartitioned) + let partition_value = Struct::empty(); + for data_file in &data_files { + assert_eq!(data_file.partition, partition_value); + } + + Ok(()) + } + + #[tokio::test] + async fn test_unpartitioned_writer_writes_all_data_together() -> Result<()> { + let temp_dir = TempDir::new()?; + let schema = create_schema_with_region()?; + let data_file_writer_builder = create_test_writer_builder(&temp_dir, schema.clone())?; + + // Create unpartitioned writer + let mut writer = UnpartitionedWriter::new(data_file_writer_builder); + + // Create test data with different regions + let arrow_schema = create_arrow_schema_with_region(); + let batch_us = RecordBatch::try_new(Arc::new(arrow_schema.clone()), vec![ + Arc::new(Int32Array::from(vec![1, 2])), + Arc::new(StringArray::from(vec!["Alice", "Bob"])), + Arc::new(StringArray::from(vec!["US", "US"])), + ])?; + + let batch_eu = RecordBatch::try_new(Arc::new(arrow_schema.clone()), vec![ + Arc::new(Int32Array::from(vec![3, 4])), + Arc::new(StringArray::from(vec!["Charlie", "Dave"])), + Arc::new(StringArray::from(vec!["EU", "EU"])), + ])?; + + // Write data from different regions - all goes to same file(s) + writer.write(batch_us).await?; + writer.write(batch_eu).await?; + + // Close writer and get data files + let data_files = writer.close().await?; + + // Verify at least one file was created + assert!( + !data_files.is_empty(), + "Expected at least one data file to be created" + ); + + // All data should be written to the same file(s) with empty partition + for data_file in &data_files { + assert_eq!( + data_file.partition, + Struct::empty(), + "Expected empty partition for unpartitioned writer" + ); + } + + Ok(()) + } + + #[tokio::test] + async fn test_unpartitioned_writer_lazy_initialization() -> Result<()> { + let temp_dir = TempDir::new()?; + let schema = create_simple_schema()?; + let data_file_writer_builder = create_test_writer_builder(&temp_dir, schema.clone())?; + + // Create unpartitioned writer - writer should not be initialized yet + let mut writer = UnpartitionedWriter::new(data_file_writer_builder); + + // Verify writer is None before first write + assert!( + writer.writer.is_none(), + "Writer should not be initialized before first write" + ); + + // Create test data + let arrow_schema = 
create_arrow_schema_simple(); + let batch = RecordBatch::try_new(Arc::new(arrow_schema.clone()), vec![ + Arc::new(Int32Array::from(vec![1, 2])), + Arc::new(StringArray::from(vec!["Alice", "Bob"])), + ])?; + + // Write data - this should trigger lazy initialization + writer.write(batch).await?; + + // Verify writer is now initialized + assert!( + writer.writer.is_some(), + "Writer should be initialized after first write" + ); + + // Close writer + let data_files = writer.close().await?; + + // Verify file was created + assert!(!data_files.is_empty(), "Expected at least one data file"); + + Ok(()) + } + + #[tokio::test] + async fn test_unpartitioned_writer_close_returns_correct_data_files() -> Result<()> { + let temp_dir = TempDir::new()?; + let schema = create_simple_schema()?; + let data_file_writer_builder = create_test_writer_builder(&temp_dir, schema.clone())?; + + // Create unpartitioned writer + let mut writer = UnpartitionedWriter::new(data_file_writer_builder); + + // Create test data + let arrow_schema = create_arrow_schema_simple(); + let batch = RecordBatch::try_new(Arc::new(arrow_schema.clone()), vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(StringArray::from(vec!["Alice", "Bob", "Charlie"])), + ])?; + + // Write data + writer.write(batch).await?; + + // Close writer and get data files + let data_files = writer.close().await?; + + // Verify data files were returned + assert!(!data_files.is_empty(), "Expected at least one data file"); + + // Verify each data file has correct properties + let partition_value = Struct::empty(); + for data_file in &data_files { + // Check partition is empty (unpartitioned) + assert_eq!(data_file.partition, partition_value); + + // Check file format is Parquet + assert_eq!(data_file.file_format, DataFileFormat::Parquet); + + // Check file path is not empty + assert!( + !data_file.file_path.is_empty(), + "File path should not be empty" + ); + + // Check record count is positive + assert!( + data_file.record_count > 0, + "Record count should be positive" + ); + } + + Ok(()) + } + + #[tokio::test] + async fn test_unpartitioned_writer_close_without_writes() -> Result<()> { + let temp_dir = TempDir::new()?; + let schema = create_simple_schema()?; + let data_file_writer_builder = create_test_writer_builder(&temp_dir, schema.clone())?; + + // Create unpartitioned writer + let writer = UnpartitionedWriter::new(data_file_writer_builder); + + // Close writer without writing any data + let data_files = writer.close().await?; + + // Verify no data files were created + assert!( + data_files.is_empty(), + "Expected no data files when closing without writes" + ); + + Ok(()) + } +} diff --git a/crates/integrations/datafusion/src/lib.rs b/crates/integrations/datafusion/src/lib.rs index 09d1cac4ce..c5ea2b3ad2 100644 --- a/crates/integrations/datafusion/src/lib.rs +++ b/crates/integrations/datafusion/src/lib.rs @@ -26,3 +26,5 @@ mod schema; pub mod table; pub use table::table_provider_factory::IcebergTableProviderFactory; pub use table::*; + +pub mod writer; diff --git a/crates/integrations/datafusion/src/physical_plan/project.rs b/crates/integrations/datafusion/src/physical_plan/project.rs index 4bfe8192b0..d8775f6502 100644 --- a/crates/integrations/datafusion/src/physical_plan/project.rs +++ b/crates/integrations/datafusion/src/physical_plan/project.rs @@ -19,24 +19,23 @@ use std::sync::Arc; -use datafusion::arrow::array::{ArrayRef, RecordBatch, StructArray}; +use datafusion::arrow::array::RecordBatch; +#[cfg(test)] +use 
datafusion::arrow::array::{ArrayRef, StructArray}; use datafusion::arrow::datatypes::{DataType, Schema as ArrowSchema}; use datafusion::common::Result as DFResult; -use datafusion::error::DataFusionError; use datafusion::physical_expr::PhysicalExpr; use datafusion::physical_expr::expressions::Column; use datafusion::physical_plan::projection::ProjectionExec; use datafusion::physical_plan::{ColumnarValue, ExecutionPlan}; -use iceberg::arrow::record_batch_projector::RecordBatchProjector; -use iceberg::spec::{PartitionSpec, Schema}; +use iceberg::arrow::{PROJECTED_PARTITION_VALUE_COLUMN, PartitionValueCalculator}; +use iceberg::spec::PartitionSpec; +#[cfg(test)] +use iceberg::spec::Schema; use iceberg::table::Table; -use iceberg::transform::BoxedTransformFunction; use crate::to_datafusion_error; -/// Column name for the combined partition values struct -const PARTITION_VALUES_COLUMN: &str = "_partition"; - /// Extends an ExecutionPlan with partition value calculations for Iceberg tables. /// /// This function takes an input ExecutionPlan and extends it with an additional column @@ -65,12 +64,9 @@ pub fn project_with_partition( let input_schema = input.schema(); // TODO: Validate that input_schema matches the Iceberg table schema. // See: https://github.com/apache/iceberg-rust/issues/1752 - let partition_type = build_partition_type(partition_spec, table_schema.as_ref())?; - let calculator = PartitionValueCalculator::new( - partition_spec.as_ref().clone(), - table_schema.as_ref().clone(), - partition_type, - )?; + let calculator = + PartitionValueCalculator::try_new(partition_spec.as_ref(), table_schema.as_ref()) + .map_err(to_datafusion_error)?; let mut projection_exprs: Vec<(Arc, String)> = Vec::with_capacity(input_schema.fields().len() + 1); @@ -80,8 +76,8 @@ pub fn project_with_partition( projection_exprs.push((column_expr, field.name().clone())); } - let partition_expr = Arc::new(PartitionExpr::new(calculator)); - projection_exprs.push((partition_expr, PARTITION_VALUES_COLUMN.to_string())); + let partition_expr = Arc::new(PartitionExpr::new(calculator, partition_spec.clone())); + projection_exprs.push((partition_expr, PROJECTED_PARTITION_VALUE_COLUMN.to_string())); let projection = ProjectionExec::try_new(projection_exprs, input)?; Ok(Arc::new(projection)) @@ -91,21 +87,24 @@ pub fn project_with_partition( #[derive(Debug, Clone)] struct PartitionExpr { calculator: Arc, + partition_spec: Arc, } impl PartitionExpr { - fn new(calculator: PartitionValueCalculator) -> Self { + fn new(calculator: PartitionValueCalculator, partition_spec: Arc) -> Self { Self { calculator: Arc::new(calculator), + partition_spec, } } } // Manual PartialEq/Eq implementations for pointer-based equality -// (two PartitionExpr are equal if they share the same calculator instance) +// (two PartitionExpr are equal if they share the same calculator and partition_spec instances) impl PartialEq for PartitionExpr { fn eq(&self, other: &Self) -> bool { Arc::ptr_eq(&self.calculator, &other.calculator) + && Arc::ptr_eq(&self.partition_spec, &other.partition_spec) } } @@ -117,7 +116,7 @@ impl PhysicalExpr for PartitionExpr { } fn data_type(&self, _input_schema: &ArrowSchema) -> DFResult { - Ok(self.calculator.partition_type.clone()) + Ok(self.calculator.partition_arrow_type().clone()) } fn nullable(&self, _input_schema: &ArrowSchema) -> DFResult { @@ -125,7 +124,10 @@ impl PhysicalExpr for PartitionExpr { } fn evaluate(&self, batch: &RecordBatch) -> DFResult { - let array = self.calculator.calculate(batch)?; + let array = self + 
.calculator + .calculate(batch) + .map_err(to_datafusion_error)?; Ok(ColumnarValue::Array(array)) } @@ -142,7 +144,6 @@ impl PhysicalExpr for PartitionExpr { fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let field_names: Vec = self - .calculator .partition_spec .fields() .iter() @@ -155,7 +156,6 @@ impl PhysicalExpr for PartitionExpr { impl std::fmt::Display for PartitionExpr { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let field_names: Vec<&str> = self - .calculator .partition_spec .fields() .iter() @@ -167,104 +167,12 @@ impl std::fmt::Display for PartitionExpr { impl std::hash::Hash for PartitionExpr { fn hash(&self, state: &mut H) { - // Two PartitionExpr are equal if they share the same calculator Arc + // Two PartitionExpr are equal if they share the same calculator and partition_spec Arcs Arc::as_ptr(&self.calculator).hash(state); + Arc::as_ptr(&self.partition_spec).hash(state); } } -/// Calculator for partition values in Iceberg tables -#[derive(Debug)] -struct PartitionValueCalculator { - partition_spec: PartitionSpec, - partition_type: DataType, - projector: RecordBatchProjector, - transform_functions: Vec, -} - -impl PartitionValueCalculator { - fn new( - partition_spec: PartitionSpec, - table_schema: Schema, - partition_type: DataType, - ) -> DFResult { - if partition_spec.is_unpartitioned() { - return Err(DataFusionError::Internal( - "Cannot create partition calculator for unpartitioned table".to_string(), - )); - } - - let transform_functions: Result, _> = partition_spec - .fields() - .iter() - .map(|pf| iceberg::transform::create_transform_function(&pf.transform)) - .collect(); - - let transform_functions = transform_functions.map_err(to_datafusion_error)?; - - let source_field_ids: Vec = partition_spec - .fields() - .iter() - .map(|pf| pf.source_id) - .collect(); - - let projector = RecordBatchProjector::from_iceberg_schema( - Arc::new(table_schema.clone()), - &source_field_ids, - ) - .map_err(to_datafusion_error)?; - - Ok(Self { - partition_spec, - partition_type, - projector, - transform_functions, - }) - } - - fn calculate(&self, batch: &RecordBatch) -> DFResult { - let source_columns = self - .projector - .project_column(batch.columns()) - .map_err(to_datafusion_error)?; - - let expected_struct_fields = match &self.partition_type { - DataType::Struct(fields) => fields.clone(), - _ => { - return Err(DataFusionError::Internal( - "Expected partition type must be a struct".to_string(), - )); - } - }; - - let mut partition_values = Vec::with_capacity(self.partition_spec.fields().len()); - - for (source_column, transform_fn) in source_columns.iter().zip(&self.transform_functions) { - let partition_value = transform_fn - .transform(source_column.clone()) - .map_err(to_datafusion_error)?; - - partition_values.push(partition_value); - } - - let struct_array = StructArray::try_new(expected_struct_fields, partition_values, None) - .map_err(|e| DataFusionError::ArrowError(e, None))?; - - Ok(Arc::new(struct_array)) - } -} - -fn build_partition_type( - partition_spec: &PartitionSpec, - table_schema: &Schema, -) -> DFResult { - let partition_struct_type = partition_spec - .partition_type(table_schema) - .map_err(to_datafusion_error)?; - - iceberg::arrow::type_to_arrow_type(&iceberg::spec::Type::Struct(partition_struct_type)) - .map_err(to_datafusion_error) -} - #[cfg(test)] mod tests { use datafusion::arrow::array::Int32Array; @@ -291,20 +199,11 @@ mod tests { .build() .unwrap(); - let _arrow_schema = Arc::new(ArrowSchema::new(vec![ 
- Field::new("id", DataType::Int32, false), - Field::new("name", DataType::Utf8, false), - ])); - - let partition_type = build_partition_type(&partition_spec, &table_schema).unwrap(); - let calculator = PartitionValueCalculator::new( - partition_spec.clone(), - table_schema, - partition_type.clone(), - ) - .unwrap(); + let calculator = PartitionValueCalculator::try_new(&partition_spec, &table_schema).unwrap(); - assert_eq!(calculator.partition_type, partition_type); + // Verify partition type + assert_eq!(calculator.partition_type().fields().len(), 1); + assert_eq!(calculator.partition_type().fields()[0].name, "id_partition"); } #[test] @@ -318,11 +217,13 @@ mod tests { .build() .unwrap(); - let partition_spec = iceberg::spec::PartitionSpec::builder(Arc::new(table_schema.clone())) - .add_partition_field("id", "id_partition", Transform::Identity) - .unwrap() - .build() - .unwrap(); + let partition_spec = Arc::new( + iceberg::spec::PartitionSpec::builder(Arc::new(table_schema.clone())) + .add_partition_field("id", "id_partition", Transform::Identity) + .unwrap() + .build() + .unwrap(), + ); let arrow_schema = Arc::new(ArrowSchema::new(vec![ Field::new("id", DataType::Int32, false), @@ -331,9 +232,7 @@ mod tests { let input = Arc::new(EmptyExec::new(arrow_schema.clone())); - let partition_type = build_partition_type(&partition_spec, &table_schema).unwrap(); - let calculator = - PartitionValueCalculator::new(partition_spec, table_schema, partition_type).unwrap(); + let calculator = PartitionValueCalculator::try_new(&partition_spec, &table_schema).unwrap(); let mut projection_exprs: Vec<(Arc, String)> = Vec::with_capacity(arrow_schema.fields().len() + 1); @@ -342,8 +241,8 @@ mod tests { projection_exprs.push((column_expr, field.name().clone())); } - let partition_expr = Arc::new(PartitionExpr::new(calculator)); - projection_exprs.push((partition_expr, PARTITION_VALUES_COLUMN.to_string())); + let partition_expr = Arc::new(PartitionExpr::new(calculator, partition_spec)); + projection_exprs.push((partition_expr, PROJECTED_PARTITION_VALUE_COLUMN.to_string())); let projection = ProjectionExec::try_new(projection_exprs, input).unwrap(); let result = Arc::new(projection); @@ -384,11 +283,10 @@ mod tests { ]) .unwrap(); - let partition_type = build_partition_type(&partition_spec, &table_schema).unwrap(); - let calculator = - PartitionValueCalculator::new(partition_spec, table_schema, partition_type.clone()) - .unwrap(); - let expr = PartitionExpr::new(calculator); + let partition_spec = Arc::new(partition_spec); + let calculator = PartitionValueCalculator::try_new(&partition_spec, &table_schema).unwrap(); + let partition_type = calculator.partition_arrow_type().clone(); + let expr = PartitionExpr::new(calculator, partition_spec); assert_eq!(expr.data_type(&arrow_schema).unwrap(), partition_type); assert!(!expr.nullable(&arrow_schema).unwrap()); @@ -469,9 +367,7 @@ mod tests { ]) .unwrap(); - let partition_type = build_partition_type(&partition_spec, &table_schema).unwrap(); - let calculator = - PartitionValueCalculator::new(partition_spec, table_schema, partition_type).unwrap(); + let calculator = PartitionValueCalculator::try_new(&partition_spec, &table_schema).unwrap(); let array = calculator.calculate(&batch).unwrap(); let struct_array = array.as_any().downcast_ref::().unwrap(); diff --git a/crates/integrations/datafusion/src/physical_plan/write.rs b/crates/integrations/datafusion/src/physical_plan/write.rs index dff40a3c0d..47553081a8 100644 --- 
a/crates/integrations/datafusion/src/physical_plan/write.rs +++ b/crates/integrations/datafusion/src/physical_plan/write.rs @@ -44,13 +44,13 @@ use iceberg::writer::file_writer::location_generator::{ DefaultFileNameGenerator, DefaultLocationGenerator, }; use iceberg::writer::file_writer::rolling_writer::RollingFileWriterBuilder; -use iceberg::writer::{IcebergWriter, IcebergWriterBuilder}; use iceberg::{Error, ErrorKind}; use parquet::file::properties::WriterProperties; use uuid::Uuid; use crate::physical_plan::DATA_FILES_COL_NAME; use crate::to_datafusion_error; +use crate::writer::task::TaskWriter; /// An execution plan node that writes data to an Iceberg table. /// @@ -202,18 +202,6 @@ impl ExecutionPlan for IcebergWriteExec { partition: usize, context: Arc, ) -> DFResult { - if !self - .table - .metadata() - .default_partition_spec() - .is_unpartitioned() - { - // TODO add support for partitioned tables - return Err(DataFusionError::NotImplemented( - "IcebergWriteExec does not support partitioned tables yet".to_string(), - )); - } - let partition_type = self.table.metadata().default_partition_type().clone(); let format_version = self.table.metadata().format_version(); @@ -277,6 +265,10 @@ impl ExecutionPlan for IcebergWriteExec { ); let data_file_writer_builder = DataFileWriterBuilder::new(rolling_writer_builder); + // Get schema and partition spec for TaskWriter + let schema = self.table.metadata().current_schema().clone(); + let partition_spec = self.table.metadata().default_partition_spec().clone(); + // Get input data let data = execute_input_stream( Arc::clone(&self.input), @@ -290,18 +282,23 @@ impl ExecutionPlan for IcebergWriteExec { // Create write stream let stream = futures::stream::once(async move { - let mut writer = data_file_writer_builder - // todo specify partition key when partitioning writer is supported - .build(None) - .await - .map_err(to_datafusion_error)?; + // Create TaskWriter with fanout_enabled=false (use ClusteredWriter for partitioned tables) + let mut task_writer = TaskWriter::new( + data_file_writer_builder, + false, // todo should be configurable + schema, + partition_spec, + ); let mut input_stream = data; while let Some(batch) = input_stream.next().await { - writer.write(batch?).await.map_err(to_datafusion_error)?; + task_writer + .write(batch?) + .await + .map_err(to_datafusion_error)?; } - let data_files = writer.close().await.map_err(to_datafusion_error)?; + let data_files = task_writer.close().await.map_err(to_datafusion_error)?; // Convert builders to data files and then to JSON strings let data_files_strs: Vec = data_files @@ -475,6 +472,9 @@ mod tests { #[tokio::test] async fn test_iceberg_write_exec() -> Result<()> { + // This test verifies that IcebergWriteExec works correctly with the new TaskWriter + // implementation for unpartitioned tables. + // 1. Set up test environment let iceberg_catalog = get_iceberg_catalog().await; let namespace = NamespaceIdent::new("test_namespace".to_string()); @@ -629,4 +629,132 @@ mod tests { Ok(()) } + + // Note: Partitioned table tests are covered by integration tests in + // crates/integrations/datafusion/tests/integration_datafusion_test.rs + // The test_insert_into test validates end-to-end write functionality including + // the new TaskWriter implementation through the DataFusion SQL interface. 
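As an editorial aside to the note above (not part of the patch): the `_partition` column that the new write path relies on is produced upstream by `project_with_partition` and consumed by `RecordBatchPartitionSplitter` with `use_projected_partition_column = true`. A minimal sketch of that projection step expressed with the `PartitionValueCalculator` API introduced in this patch is shown below; the helper name and the error mapping are illustrative assumptions, not code from the patch.

```rust
use std::sync::Arc;

use arrow_array::{ArrayRef, RecordBatch};
use arrow_schema::{Field, Schema as ArrowSchema};
use iceberg::arrow::{PROJECTED_PARTITION_VALUE_COLUMN, PartitionValueCalculator};
use iceberg::spec::{PartitionSpec, Schema};
use iceberg::{Error, ErrorKind, Result};

/// Append a `_partition` struct column to `batch`, mirroring what
/// `project_with_partition` does inside the DataFusion plan.
fn append_partition_column(
    batch: &RecordBatch,
    partition_spec: &PartitionSpec,
    table_schema: &Schema,
) -> Result<RecordBatch> {
    let calculator = PartitionValueCalculator::try_new(partition_spec, table_schema)?;
    // One partition value (a struct) per input row.
    let partition_array: ArrayRef = calculator.calculate(batch)?;

    // Extend the Arrow schema with the non-nullable `_partition` struct field.
    let input_schema = batch.schema();
    let mut fields: Vec<Field> = input_schema
        .fields()
        .iter()
        .map(|f| f.as_ref().clone())
        .collect();
    fields.push(Field::new(
        PROJECTED_PARTITION_VALUE_COLUMN,
        calculator.partition_arrow_type().clone(),
        false,
    ));

    // Append the computed partition column to the existing columns.
    let mut columns = batch.columns().to_vec();
    columns.push(partition_array);

    RecordBatch::try_new(Arc::new(ArrowSchema::new(fields)), columns)
        .map_err(|e| Error::new(ErrorKind::DataInvalid, e.to_string()))
}
```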
+ + #[tokio::test] + async fn test_iceberg_write_exec_multiple_batches() -> Result<()> { + // This test verifies that IcebergWriteExec correctly handles multiple input batches + // with the new TaskWriter implementation. + + // 1. Set up test environment + let iceberg_catalog = get_iceberg_catalog().await; + let namespace = NamespaceIdent::new("test_multi_batch_namespace".to_string()); + + // Create namespace + iceberg_catalog + .create_namespace(&namespace, HashMap::new()) + .await?; + + // Create schema + let schema = get_test_schema()?; + + // Create table + let table_name = "test_multi_batch_table"; + let table_location = temp_path(); + let creation = get_table_creation(table_location, table_name, schema); + let table = iceberg_catalog.create_table(&namespace, creation).await?; + + // 2. Create multiple test batches + let arrow_schema = Arc::new(ArrowSchema::new(vec![ + Field::new("id", DataType::Int32, false).with_metadata(HashMap::from([( + PARQUET_FIELD_ID_META_KEY.to_string(), + "1".to_string(), + )])), + Field::new("name", DataType::Utf8, false).with_metadata(HashMap::from([( + PARQUET_FIELD_ID_META_KEY.to_string(), + "2".to_string(), + )])), + ])); + + let batch1 = RecordBatch::try_new(arrow_schema.clone(), vec![ + Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef, + Arc::new(StringArray::from(vec!["Alice", "Bob"])) as ArrayRef, + ]) + .map_err(|e| { + Error::new( + ErrorKind::Unexpected, + format!("Failed to create record batch: {}", e), + ) + })?; + + let batch2 = RecordBatch::try_new(arrow_schema.clone(), vec![ + Arc::new(Int32Array::from(vec![3, 4, 5])) as ArrayRef, + Arc::new(StringArray::from(vec!["Charlie", "David", "Eve"])) as ArrayRef, + ]) + .map_err(|e| { + Error::new( + ErrorKind::Unexpected, + format!("Failed to create record batch: {}", e), + ) + })?; + + // 3. Create mock input execution plan with multiple batches + let input_plan = Arc::new(MockExecutionPlan::new(arrow_schema.clone(), vec![ + batch1, batch2, + ])); + + // 4. Create IcebergWriteExec + let write_exec = IcebergWriteExec::new(table.clone(), input_plan, arrow_schema); + + // 5. Execute the plan + let task_ctx = Arc::new(TaskContext::default()); + let stream = write_exec.execute(0, task_ctx).map_err(|e| { + Error::new( + ErrorKind::Unexpected, + format!("Failed to execute plan: {}", e), + ) + })?; + + // Collect the results + let mut results = vec![]; + let mut stream = stream; + while let Some(batch) = stream.next().await { + results.push(batch.map_err(|e| { + Error::new(ErrorKind::Unexpected, format!("Failed to get batch: {}", e)) + })?); + } + + // 6. 
Verify the results + assert_eq!(results.len(), 1, "Expected one result batch"); + let result_batch = &results[0]; + + // Check data - should have at least one data file + assert!( + result_batch.num_rows() >= 1, + "Expected at least one data file" + ); + + // Get the data file JSON and verify total record count + let partition_type = table.metadata().default_partition_type(); + let spec_id = table.metadata().default_partition_spec_id(); + let schema = table.metadata().current_schema(); + + let mut total_records = 0; + for i in 0..result_batch.num_rows() { + let data_file_json = result_batch + .column(0) + .as_any() + .downcast_ref::() + .expect("Expected StringArray") + .value(i); + + let data_file = + deserialize_data_file_from_json(data_file_json, spec_id, partition_type, schema) + .expect("Failed to deserialize data file JSON"); + + total_records += data_file.record_count(); + } + + // Verify total record count matches input (2 + 3 = 5) + assert_eq!( + total_records, 5, + "Total records should be 5 (2 from batch1 + 3 from batch2)" + ); + + Ok(()) + } } diff --git a/crates/integrations/datafusion/src/writer/mod.rs b/crates/integrations/datafusion/src/writer/mod.rs new file mode 100644 index 0000000000..4ed0710db7 --- /dev/null +++ b/crates/integrations/datafusion/src/writer/mod.rs @@ -0,0 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Writer module for DataFusion integration. + +pub mod task; diff --git a/crates/integrations/datafusion/src/writer/task.rs b/crates/integrations/datafusion/src/writer/task.rs new file mode 100644 index 0000000000..34d198b910 --- /dev/null +++ b/crates/integrations/datafusion/src/writer/task.rs @@ -0,0 +1,278 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! TaskWriter for DataFusion integration. +//! +//! This module provides a high-level writer that handles partitioning and routing +//! of RecordBatch data to Iceberg tables. 
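Before the implementation, a compact sketch of how the DataFusion write path above drives this module: build one `TaskWriter`, feed it every incoming batch, then close it to collect the data files. The function below is illustrative only, assuming a prepared `IcebergWriterBuilder` such as the `DataFileWriterBuilder` used in `write.rs`; it is not part of the patch.

```rust
use datafusion::arrow::array::RecordBatch;
use iceberg::Result;
use iceberg::spec::{DataFile, PartitionSpecRef, SchemaRef};
use iceberg::writer::IcebergWriterBuilder;
use iceberg_datafusion::writer::task::TaskWriter;

/// Drive a TaskWriter over a set of batches and collect the resulting data files.
async fn write_all<B: IcebergWriterBuilder>(
    builder: B,
    schema: SchemaRef,
    partition_spec: PartitionSpecRef,
    batches: Vec<RecordBatch>,
) -> Result<Vec<DataFile>> {
    // `false` mirrors write.rs: partitioned tables use a ClusteredWriter unless
    // fanout is enabled; unpartitioned specs fall back to UnpartitionedWriter.
    let mut task_writer = TaskWriter::new(builder, false, schema, partition_spec);

    for batch in batches {
        // For a partitioned spec, each batch must already contain the projected
        // `_partition` column added by `project_with_partition`.
        task_writer.write(batch).await?;
    }

    // Consumes the writer and returns every DataFile written.
    task_writer.close().await
}
```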
+
+use datafusion::arrow::array::RecordBatch;
+use iceberg::Result;
+use iceberg::arrow::RecordBatchPartitionSplitter;
+use iceberg::spec::{DataFile, PartitionSpecRef, SchemaRef};
+use iceberg::writer::IcebergWriterBuilder;
+use iceberg::writer::partitioning::PartitioningWriter;
+use iceberg::writer::partitioning::clustered_writer::ClusteredWriter;
+use iceberg::writer::partitioning::fanout_writer::FanoutWriter;
+use iceberg::writer::partitioning::unpartitioned_writer::UnpartitionedWriter;
+
+/// High-level writer for DataFusion that handles partitioning and routing of RecordBatch data.
+///
+/// TaskWriter coordinates writing data to Iceberg tables by:
+/// - Selecting the appropriate partitioning strategy (unpartitioned, fanout, or clustered)
+/// - Lazily initializing the partition splitter on first write
+/// - Routing data to the underlying writer
+/// - Collecting all written data files
+///
+/// # Type Parameters
+///
+/// * `B` - The IcebergWriterBuilder type used to create underlying writers
+///
+/// # Example
+///
+/// ```rust,ignore
+/// use iceberg::spec::{PartitionSpec, Schema};
+/// use iceberg::writer::base_writer::data_file_writer::DataFileWriterBuilder;
+/// use iceberg_datafusion::writer::task::TaskWriter;
+///
+/// // Create a TaskWriter for an unpartitioned table
+/// let task_writer = TaskWriter::new(
+///     data_file_writer_builder,
+///     false, // fanout_enabled
+///     schema,
+///     partition_spec,
+/// );
+///
+/// // Write data
+/// task_writer.write(record_batch).await?;
+///
+/// // Close and get data files
+/// let data_files = task_writer.close().await?;
+/// ```
+pub struct TaskWriter<B: IcebergWriterBuilder> {
+    /// The underlying writer (UnpartitionedWriter, FanoutWriter, or ClusteredWriter)
+    writer: SupportedWriter<B>,
+    /// Lazily initialized partition splitter for partitioned tables
+    partition_splitter: Option<RecordBatchPartitionSplitter>,
+    /// Iceberg schema reference
+    schema: SchemaRef,
+    /// Partition specification reference
+    partition_spec: PartitionSpecRef,
+}
+
+/// Internal enum to hold the different writer types.
+///
+/// This enum allows TaskWriter to work with different partitioning strategies
+/// while maintaining a unified interface.
+enum SupportedWriter<B: IcebergWriterBuilder> {
+    /// Writer for unpartitioned tables
+    Unpartitioned(UnpartitionedWriter<B>),
+    /// Writer for partitioned tables with unsorted data (maintains multiple active writers)
+    Fanout(FanoutWriter<B>),
+    /// Writer for partitioned tables with sorted data (maintains a single active writer)
+    Clustered(ClusteredWriter<B>),
+}
+
+impl<B: IcebergWriterBuilder> TaskWriter<B> {
+    /// Create a new TaskWriter.
+    ///
+    /// # Parameters
+    ///
+    /// * `writer_builder` - The IcebergWriterBuilder to use for creating underlying writers
+    /// * `fanout_enabled` - If true, use FanoutWriter for partitioned tables; otherwise use ClusteredWriter
+    /// * `schema` - The Iceberg schema reference
+    /// * `partition_spec` - The partition specification reference
+    ///
+    /// # Returns
+    ///
+    /// Returns a new TaskWriter instance.
+ /// + /// # Writer Selection Logic + /// + /// - If partition_spec is unpartitioned: creates UnpartitionedWriter + /// - If partition_spec is partitioned AND fanout_enabled is true: creates FanoutWriter + /// - If partition_spec is partitioned AND fanout_enabled is false: creates ClusteredWriter + /// + /// # Example + /// + /// ```rust,ignore + /// use iceberg::spec::{PartitionSpec, Schema}; + /// use iceberg::writer::base_writer::data_file_writer::DataFileWriterBuilder; + /// use iceberg_datafusion::writer::task::TaskWriter; + /// + /// // Create a TaskWriter for an unpartitioned table + /// let task_writer = TaskWriter::new( + /// data_file_writer_builder, + /// false, // fanout_enabled + /// schema, + /// partition_spec, + /// ); + /// ``` + pub fn new( + writer_builder: B, + fanout_enabled: bool, + schema: SchemaRef, + partition_spec: PartitionSpecRef, + ) -> Self { + let writer = if partition_spec.is_unpartitioned() { + SupportedWriter::Unpartitioned(UnpartitionedWriter::new(writer_builder)) + } else if fanout_enabled { + SupportedWriter::Fanout(FanoutWriter::new(writer_builder)) + } else { + SupportedWriter::Clustered(ClusteredWriter::new(writer_builder)) + }; + + Self { + writer, + partition_splitter: None, + schema, + partition_spec, + } + } + + /// Write a RecordBatch to the TaskWriter. + /// + /// For the first write to a partitioned table, this method initializes the partition splitter. + /// For unpartitioned tables, data is written directly without splitting. + /// + /// # Parameters + /// + /// * `batch` - The RecordBatch to write + /// + /// # Returns + /// + /// Returns `Ok(())` on success, or an error if the write fails. + /// + /// # Errors + /// + /// This method will return an error if: + /// - Partition splitter initialization fails + /// - Splitting the batch by partition fails + /// - Writing to the underlying writer fails + /// + /// # Example + /// + /// ```rust,ignore + /// use arrow_array::RecordBatch; + /// use iceberg_datafusion::writer::task::TaskWriter; + /// + /// // Write a RecordBatch + /// task_writer.write(record_batch).await?; + /// ``` + pub async fn write(&mut self, batch: RecordBatch) -> Result<()> { + match &mut self.writer { + SupportedWriter::Unpartitioned(writer) => { + // Unpartitioned: write directly without splitting + writer.write(batch).await + } + SupportedWriter::Fanout(writer) => { + // Initialize splitter on first write if needed + if self.partition_splitter.is_none() { + let arrow_schema = batch.schema(); + self.partition_splitter = Some(RecordBatchPartitionSplitter::new( + arrow_schema, + self.schema.clone(), + self.partition_spec.clone(), + true, // use_projected_partition_value + )?); + } + + // Split and write partitioned data + Self::write_partitioned_batches(writer, &self.partition_splitter, &batch).await + } + SupportedWriter::Clustered(writer) => { + // Initialize splitter on first write if needed + if self.partition_splitter.is_none() { + let arrow_schema = batch.schema(); + self.partition_splitter = Some(RecordBatchPartitionSplitter::new( + arrow_schema, + self.schema.clone(), + self.partition_spec.clone(), + true, // use_projected_partition_value + )?); + } + + // Split and write partitioned data + Self::write_partitioned_batches(writer, &self.partition_splitter, &batch).await + } + } + } + + /// Helper method to split and write partitioned data. 
+    ///
+    /// This method handles the common logic for both FanoutWriter and ClusteredWriter:
+    /// - Splits the batch by partition key using the provided splitter
+    /// - Writes each partition to the underlying writer
+    ///
+    /// # Parameters
+    ///
+    /// * `writer` - The underlying PartitioningWriter (FanoutWriter or ClusteredWriter)
+    /// * `partition_splitter` - The partition splitter (must be initialized)
+    /// * `batch` - The RecordBatch to write
+    ///
+    /// # Returns
+    ///
+    /// Returns `Ok(())` on success, or an error if the operation fails.
+    async fn write_partitioned_batches<W: PartitioningWriter>(
+        writer: &mut W,
+        partition_splitter: &Option<RecordBatchPartitionSplitter>,
+        batch: &RecordBatch,
+    ) -> Result<()> {
+        // Split batch by partition
+        let splitter = partition_splitter
+            .as_ref()
+            .expect("Partition splitter should be initialized");
+        let partitioned_batches = splitter.split(batch)?;
+
+        // Write each partition
+        for (partition_key, partition_batch) in partitioned_batches {
+            writer.write(partition_key, partition_batch).await?;
+        }
+
+        Ok(())
+    }
+
+    /// Close the TaskWriter and return all written data files.
+    ///
+    /// This method consumes the TaskWriter to prevent further use.
+    ///
+    /// # Returns
+    ///
+    /// Returns a `Vec<DataFile>` containing all written data files, or an error if closing fails.
+    ///
+    /// # Errors
+    ///
+    /// This method will return an error if:
+    /// - Closing the underlying writer fails
+    /// - Any I/O operation fails during the close process
+    ///
+    /// # Example
+    ///
+    /// ```rust,ignore
+    /// use iceberg_datafusion::writer::task::TaskWriter;
+    ///
+    /// // Close the writer and get all data files
+    /// let data_files = task_writer.close().await?;
+    /// ```
+    pub async fn close(self) -> Result<Vec<DataFile>> {
+        match self.writer {
+            SupportedWriter::Unpartitioned(writer) => writer.close().await,
+            SupportedWriter::Fanout(writer) => writer.close().await,
+            SupportedWriter::Clustered(writer) => writer.close().await,
+        }
+    }
+}
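As a recap of the partitioned path that `TaskWriter::write` takes internally, here is a hedged sketch wiring `RecordBatchPartitionSplitter` directly to a `FanoutWriter`. It assumes the incoming batch already carries the projected `_partition` column and that a suitable `IcebergWriterBuilder` (for example the `DataFileWriterBuilder` used elsewhere in this patch) is at hand; the function name and the generic bounds are illustrative rather than taken from the patch.

```rust
use datafusion::arrow::array::RecordBatch;
use iceberg::Result;
use iceberg::arrow::RecordBatchPartitionSplitter;
use iceberg::spec::{DataFile, PartitionSpecRef, SchemaRef};
use iceberg::writer::IcebergWriterBuilder;
use iceberg::writer::partitioning::PartitioningWriter;
use iceberg::writer::partitioning::fanout_writer::FanoutWriter;

/// Split one batch by partition value and fan the pieces out to per-partition writers.
async fn fanout_write<B: IcebergWriterBuilder>(
    builder: B,
    schema: SchemaRef,
    partition_spec: PartitionSpecRef,
    batch: RecordBatch,
) -> Result<Vec<DataFile>> {
    let splitter = RecordBatchPartitionSplitter::new(
        batch.schema(),
        schema,
        partition_spec,
        true, // consume the pre-computed `_partition` column
    )?;

    let mut writer = FanoutWriter::new(builder);
    // One (PartitionKey, RecordBatch) pair per distinct partition value in the batch.
    for (partition_key, partition_batch) in splitter.split(&batch)? {
        writer.write(partition_key, partition_batch).await?;
    }

    // Close all per-partition writers and collect their data files.
    writer.close().await
}
```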