Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion crates/iceberg/src/arrow/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,9 @@ mod value;

pub use reader::*;
pub use value::*;
pub(crate) mod record_batch_partition_splitter;
/// Partition value calculator for computing partition values
pub mod partition_value_calculator;
pub use partition_value_calculator::*;
/// Record batch partition splitter for partitioned tables
pub mod record_batch_partition_splitter;
pub use record_batch_partition_splitter::*;
254 changes: 254 additions & 0 deletions crates/iceberg/src/arrow/partition_value_calculator.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,254 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Partition value calculation for Iceberg tables.
//!
//! This module provides utilities for calculating partition values from record batches
//! based on a partition specification.

use std::sync::Arc;

use arrow_array::{ArrayRef, RecordBatch, StructArray};
use arrow_schema::DataType;

use super::record_batch_projector::RecordBatchProjector;
use super::type_to_arrow_type;
use crate::spec::{PartitionSpec, Schema, StructType, Type};
use crate::transform::{BoxedTransformFunction, create_transform_function};
use crate::{Error, ErrorKind, Result};

/// Calculator for partition values in Iceberg tables.
///
/// This struct handles the projection of source columns and application of
/// partition transforms to compute partition values for a given record batch.
///
/// All fields are derived from the same `PartitionSpec` at construction time,
/// so the projector's output columns and `transform_functions` are
/// index-aligned: one entry per partition field, in spec order.
#[derive(Debug)]
pub struct PartitionValueCalculator {
    // Extracts the partition-source columns from an incoming RecordBatch,
    // in the same order as the spec's partition fields.
    projector: RecordBatchProjector,
    // One transform function per partition field, in spec order.
    transform_functions: Vec<BoxedTransformFunction>,
    // Iceberg struct type describing the partition tuple.
    partition_type: StructType,
    // Arrow equivalent of `partition_type`; expected to always be
    // `DataType::Struct` (`calculate` returns an error otherwise).
    partition_arrow_type: DataType,
}

impl PartitionValueCalculator {
    /// Create a new `PartitionValueCalculator`.
    ///
    /// # Arguments
    ///
    /// * `partition_spec` - The partition specification
    /// * `table_schema` - The Iceberg table schema
    ///
    /// # Returns
    ///
    /// Returns a new `PartitionValueCalculator` instance or an error if initialization fails.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - The partition spec is unpartitioned
    /// - Transform function creation fails
    /// - Projector initialization fails
    pub fn try_new(partition_spec: &PartitionSpec, table_schema: &Schema) -> Result<Self> {
        if partition_spec.is_unpartitioned() {
            return Err(Error::new(
                ErrorKind::DataInvalid,
                "Cannot create partition calculator for unpartitioned table",
            ));
        }

        // One transform function per partition field, in spec order.
        let transform_functions: Vec<BoxedTransformFunction> = partition_spec
            .fields()
            .iter()
            .map(|pf| create_transform_function(&pf.transform))
            .collect::<Result<Vec<_>>>()?;

        // Source field IDs drive the projector; iteration order must match
        // `transform_functions` above so `calculate` can zip the two.
        let source_field_ids: Vec<i32> = partition_spec
            .fields()
            .iter()
            .map(|pf| pf.source_id)
            .collect();

        // Create projector for extracting source columns.
        let projector = RecordBatchProjector::from_iceberg_schema(
            Arc::new(table_schema.clone()),
            &source_field_ids,
        )?;

        // Resolve the partition tuple type in both Iceberg and Arrow terms.
        let partition_type = partition_spec.partition_type(table_schema)?;
        let partition_arrow_type = type_to_arrow_type(&Type::Struct(partition_type.clone()))?;

        Ok(Self {
            projector,
            transform_functions,
            partition_type,
            partition_arrow_type,
        })
    }

    /// Get the partition type as an Iceberg StructType.
    pub fn partition_type(&self) -> &StructType {
        &self.partition_type
    }

    /// Get the partition type as an Arrow DataType.
    pub fn partition_arrow_type(&self) -> &DataType {
        &self.partition_arrow_type
    }

    /// Calculate partition values for a record batch.
    ///
    /// This method:
    /// 1. Projects the source columns from the batch
    /// 2. Applies partition transforms to each source column
    /// 3. Constructs a StructArray containing the partition values
    ///
    /// # Arguments
    ///
    /// * `batch` - The record batch to calculate partition values for
    ///
    /// # Returns
    ///
    /// Returns an ArrayRef containing a StructArray of partition values, or an error if calculation fails.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - Column projection fails
    /// - The projected column count does not match the partition field count
    /// - Transform application fails
    /// - StructArray construction fails
    pub fn calculate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
        // Project source columns from the batch.
        let source_columns = self.projector.project_column(batch.columns())?;

        // Guard against a projector/transform mismatch: `zip` below would
        // otherwise silently truncate to the shorter side and produce a
        // partition struct with missing columns.
        if source_columns.len() != self.transform_functions.len() {
            return Err(Error::new(
                ErrorKind::DataInvalid,
                format!(
                    "Projected column count ({}) does not match partition field count ({})",
                    source_columns.len(),
                    self.transform_functions.len()
                ),
            ));
        }

        // Get expected struct fields for the result.
        let expected_struct_fields = match &self.partition_arrow_type {
            DataType::Struct(fields) => fields.clone(),
            _ => {
                return Err(Error::new(
                    ErrorKind::DataInvalid,
                    "Expected partition type must be a struct",
                ));
            }
        };

        // Apply each transform to its source column; short-circuits on the
        // first transform error.
        let partition_values = source_columns
            .iter()
            .zip(&self.transform_functions)
            .map(|(source_column, transform_fn)| transform_fn.transform(source_column.clone()))
            .collect::<Result<Vec<_>>>()?;

        // Construct the StructArray.
        let struct_array = StructArray::try_new(expected_struct_fields, partition_values, None)
            .map_err(|e| {
                Error::new(
                    ErrorKind::DataInvalid,
                    format!("Failed to create partition struct array: {}", e),
                )
            })?;

        Ok(Arc::new(struct_array))
    }
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_array::{Int32Array, RecordBatch, StringArray};
    use arrow_schema::{Field, Schema as ArrowSchema};

    use super::*;
    use crate::spec::{NestedField, PartitionSpecBuilder, PrimitiveType, Transform};

    /// Identity transform on an int column: partition values must equal the
    /// source column values row-for-row.
    #[test]
    fn test_partition_calculator_identity_transform() {
        // Iceberg schema with an int `id` and a string `name` column.
        let schema = Schema::builder()
            .with_schema_id(0)
            .with_fields(vec![
                NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
                NestedField::required(2, "name", Type::Primitive(PrimitiveType::String)).into(),
            ])
            .build()
            .unwrap();

        // Identity-partition on `id`.
        let spec = PartitionSpecBuilder::new(Arc::new(schema.clone()))
            .add_partition_field("id", "id_partition", Transform::Identity)
            .unwrap()
            .build()
            .unwrap();

        let calc = PartitionValueCalculator::try_new(&spec, &schema).unwrap();

        // The partition tuple has exactly one field, named after the spec.
        let partition_fields = calc.partition_type().fields();
        assert_eq!(partition_fields.len(), 1);
        assert_eq!(partition_fields[0].name, "id_partition");

        // Build a three-row batch matching the table schema.
        let batch = {
            let arrow_schema = Arc::new(ArrowSchema::new(vec![
                Field::new("id", DataType::Int32, false),
                Field::new("name", DataType::Utf8, false),
            ]));
            RecordBatch::try_new(arrow_schema, vec![
                Arc::new(Int32Array::from(vec![10, 20, 30])),
                Arc::new(StringArray::from(vec!["a", "b", "c"])),
            ])
            .unwrap()
        };

        let partition_array = calc.calculate(&batch).unwrap();
        let partition_struct = partition_array
            .as_any()
            .downcast_ref::<StructArray>()
            .unwrap();

        let id_values = partition_struct
            .column_by_name("id_partition")
            .unwrap()
            .as_any()
            .downcast_ref::<Int32Array>()
            .unwrap();

        // Identity transform passes the source values through unchanged.
        for (row, expected) in [10, 20, 30].into_iter().enumerate() {
            assert_eq!(id_values.value(row), expected);
        }
    }

    /// Constructing a calculator for an unpartitioned spec must fail with a
    /// descriptive error.
    #[test]
    fn test_partition_calculator_unpartitioned_error() {
        let schema = Schema::builder()
            .with_schema_id(0)
            .with_fields(vec![
                NestedField::required(1, "id", Type::Primitive(PrimitiveType::Int)).into(),
            ])
            .build()
            .unwrap();

        // No partition fields added, so the spec is unpartitioned.
        let spec = PartitionSpecBuilder::new(Arc::new(schema.clone()))
            .build()
            .unwrap();

        let err = PartitionValueCalculator::try_new(&spec, &schema).unwrap_err();
        assert!(err.to_string().contains("unpartitioned table"));
    }
}
Loading
Loading